From e86e895e96c95fe63e7c1c4483a1b29b3fce14e5 Mon Sep 17 00:00:00 2001 From: Josema Camacho Date: Thu, 19 Mar 2026 10:49:48 +0100 Subject: [PATCH] fix(api): address review findings in rls_transaction failover logic --- api/src/backend/api/db_utils.py | 38 ++-- api/src/backend/api/tests/test_db_utils.py | 215 +++++++++++++++++++++ 2 files changed, 233 insertions(+), 20 deletions(-) diff --git a/api/src/backend/api/db_utils.py b/api/src/backend/api/db_utils.py index 8e89bc4867..129bab57a9 100644 --- a/api/src/backend/api/db_utils.py +++ b/api/src/backend/api/db_utils.py @@ -1,8 +1,9 @@ import re import secrets +import sys import time import uuid -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from datetime import datetime, timedelta, timezone from celery.utils.log import get_task_logger @@ -131,7 +132,7 @@ def rls_transaction( try: connections[replica_alias].close() except Exception: - pass + pass # Best-effort; connection may already be dead delay = REPLICA_RETRY_BASE_DELAY * (2 ** (retry - 1)) logger.info( @@ -160,7 +161,7 @@ def rls_transaction( try: connections[replica_alias].close() except Exception: - pass + pass # Best-effort; connection may already be dead logger.warning( "Mid-query replica retries exhausted, falling back to primary DB" @@ -185,7 +186,7 @@ def rls_transaction( for attempt in range(1, max_attempts + 1): router_token = None yielded_cursor = False - wrapper_installed = False + _caller_exited_cleanly = False # On final attempt, fall back to primary if attempt == max_attempts and can_failover: @@ -209,21 +210,24 @@ def rls_transaction( raise ValidationError("Must be a valid UUID") cursor.execute(SET_CONFIG_QUERY, [parameter, value]) - if can_failover and alias == replica_alias: - conn.execute_wrappers.append(_query_failover) - wrapper_installed = True - - yielded_cursor = True - yield cursor + wrapper_cm = ( + conn.execute_wrapper(_query_failover) + if can_failover and alias == replica_alias + else nullcontext() + ) + with wrapper_cm: + yielded_cursor = True + yield cursor + _caller_exited_cleanly = True return except OperationalError as e: try: connections[alias].close() except Exception: - pass + pass # Best-effort; connection may already be dead if yielded_cursor: - if _fallback["succeeded"]: + if _fallback["succeeded"] and _caller_exited_cleanly: # Caller's queries succeeded on primary via failover. # This error is transaction.atomic() cleanup on the # dead replica connection — suppress it. @@ -241,17 +245,11 @@ def rls_transaction( ) time.sleep(delay) finally: - if wrapper_installed: - try: - conn.execute_wrappers.remove(_query_failover) - except ValueError: - pass - if _fallback["atomic"] is not None: try: - _fallback["atomic"].__exit__(None, None, None) + _fallback["atomic"].__exit__(*sys.exc_info()) except Exception: - pass + pass # Best-effort; primary connection may be dead _fallback["atomic"] = None if _fallback["token"] is not None: reset_read_db_alias(_fallback["token"]) diff --git a/api/src/backend/api/tests/test_db_utils.py b/api/src/backend/api/tests/test_db_utils.py index 7a98c6a2d7..3be72638f5 100644 --- a/api/src/backend/api/tests/test_db_utils.py +++ b/api/src/backend/api/tests/test_db_utils.py @@ -1,3 +1,4 @@ +from contextlib import contextmanager from datetime import datetime, timezone from enum import Enum from unittest.mock import MagicMock, patch @@ -928,6 +929,16 @@ class TestRlsTransaction: with patch("api.db_utils.connections") as mock_connections: mock_conn = MagicMock() mock_conn.execute_wrappers = [] + + @contextmanager + def _execute_wrapper(fn): + mock_conn.execute_wrappers.append(fn) + try: + yield + finally: + mock_conn.execute_wrappers.remove(fn) + + mock_conn.execute_wrapper = _execute_wrapper mock_cursor = MagicMock() mock_conn.cursor.return_value.__enter__.return_value = mock_cursor mock_connections.__getitem__.return_value = mock_conn @@ -987,6 +998,16 @@ class TestRlsTransaction: with patch("api.db_utils.connections") as mock_connections: mock_replica_conn = MagicMock() mock_replica_conn.execute_wrappers = [] + + @contextmanager + def _execute_wrapper_replica(fn): + mock_replica_conn.execute_wrappers.append(fn) + try: + yield + finally: + mock_replica_conn.execute_wrappers.remove(fn) + + mock_replica_conn.execute_wrapper = _execute_wrapper_replica mock_cursor = MagicMock() mock_replica_conn.cursor.return_value.__enter__.return_value = ( mock_cursor @@ -1064,6 +1085,16 @@ class TestRlsTransaction: with patch("api.db_utils.connections") as mock_connections: mock_replica_conn = MagicMock() mock_replica_conn.execute_wrappers = [] + + @contextmanager + def _execute_wrapper_replica(fn): + mock_replica_conn.execute_wrappers.append(fn) + try: + yield + finally: + mock_replica_conn.execute_wrappers.remove(fn) + + mock_replica_conn.execute_wrapper = _execute_wrapper_replica mock_cursor = MagicMock() mock_replica_conn.cursor.return_value.__enter__.return_value = ( mock_cursor @@ -1125,6 +1156,16 @@ class TestRlsTransaction: with patch("api.db_utils.connections") as mock_connections: mock_conn = MagicMock() mock_conn.execute_wrappers = [] + + @contextmanager + def _execute_wrapper(fn): + mock_conn.execute_wrappers.append(fn) + try: + yield + finally: + mock_conn.execute_wrappers.remove(fn) + + mock_conn.execute_wrapper = _execute_wrapper mock_cursor = MagicMock() mock_conn.cursor.return_value.__enter__.return_value = mock_cursor # Both replica retry and primary fail @@ -1223,6 +1264,180 @@ class TestRlsTransaction: # close() called for each failed pre-yield attempt assert mock_conn.close.call_count == 2 + def test_caller_error_propagates_after_successful_failover( + self, tenants_fixture, enable_read_replica + ): + """OperationalError raised by caller after failover is NOT suppressed.""" + tenant = tenants_fixture[0] + tenant_id = str(tenant.id) + + with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica): + with patch("api.db_utils.connections") as mock_connections: + mock_replica_conn = MagicMock() + mock_replica_conn.execute_wrappers = [] + + @contextmanager + def _execute_wrapper_replica(fn): + mock_replica_conn.execute_wrappers.append(fn) + try: + yield + finally: + mock_replica_conn.execute_wrappers.remove(fn) + + mock_replica_conn.execute_wrapper = _execute_wrapper_replica + mock_cursor = MagicMock() + mock_replica_conn.cursor.return_value.__enter__.return_value = ( + mock_cursor + ) + mock_replica_conn.ensure_connection.side_effect = OperationalError( + "replica down" + ) + + mock_primary_conn = MagicMock() + mock_primary_raw_cursor = MagicMock() + mock_primary_conn.connection.cursor.return_value = ( + mock_primary_raw_cursor + ) + + def connections_getitem(alias): + if alias == "replica": + return mock_replica_conn + return mock_primary_conn + + mock_connections.__getitem__.side_effect = connections_getitem + mock_connections.__contains__.return_value = True + + # transaction.atomic().__exit__ raises on dead replica cleanup + mock_atomic_cm = MagicMock() + mock_atomic_cm.__enter__ = MagicMock(return_value=None) + mock_atomic_cm.__exit__ = MagicMock( + side_effect=OperationalError("cleanup on dead replica") + ) + + with patch( + "api.db_utils.transaction.atomic", return_value=mock_atomic_cm + ): + with patch("api.db_utils.time.sleep"): + with patch( + "api.db_utils.set_read_db_alias", return_value="token" + ): + with patch("api.db_utils.reset_read_db_alias"): + with pytest.raises(OperationalError): + with rls_transaction(tenant_id): + # Trigger failover (succeeds on primary) + wrapper = mock_replica_conn.execute_wrappers[0] + mock_execute = MagicMock( + side_effect=OperationalError("EOF") + ) + mock_context = {"cursor": MagicMock()} + wrapper( + mock_execute, + "SELECT 1", + None, + False, + mock_context, + ) + # Caller raises after successful failover — + # must NOT be suppressed + raise OperationalError("caller error") + + def test_fallback_atomic_receives_exception_info( + self, tenants_fixture, enable_read_replica + ): + """Primary transaction rollback uses sys.exc_info(), not (None, None, None).""" + tenant = tenants_fixture[0] + tenant_id = str(tenant.id) + + with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica): + with patch("api.db_utils.connections") as mock_connections: + mock_replica_conn = MagicMock() + mock_replica_conn.execute_wrappers = [] + + @contextmanager + def _execute_wrapper_replica(fn): + mock_replica_conn.execute_wrappers.append(fn) + try: + yield + finally: + mock_replica_conn.execute_wrappers.remove(fn) + + mock_replica_conn.execute_wrapper = _execute_wrapper_replica + mock_cursor = MagicMock() + mock_replica_conn.cursor.return_value.__enter__.return_value = ( + mock_cursor + ) + mock_replica_conn.ensure_connection.side_effect = OperationalError( + "replica down" + ) + + mock_primary_conn = MagicMock() + mock_primary_raw_cursor = MagicMock() + mock_primary_conn.connection.cursor.return_value = ( + mock_primary_raw_cursor + ) + + def connections_getitem(alias): + if alias == "replica": + return mock_replica_conn + return mock_primary_conn + + mock_connections.__getitem__.side_effect = connections_getitem + mock_connections.__contains__.return_value = True + + mock_fallback_atomic = MagicMock() + mock_fallback_atomic.__enter__ = MagicMock(return_value=None) + mock_fallback_atomic.__exit__ = MagicMock(return_value=False) + + # Outer atomic raises on dead replica cleanup + mock_outer_atomic = MagicMock() + mock_outer_atomic.__enter__ = MagicMock(return_value=None) + mock_outer_atomic.__exit__ = MagicMock( + side_effect=OperationalError("cleanup on dead replica") + ) + + atomic_call_count = 0 + + def atomic_side_effect(*args, **kwargs): + nonlocal atomic_call_count + atomic_call_count += 1 + # First call is the outer transaction.atomic(using=replica) + if atomic_call_count == 1: + return mock_outer_atomic + # Second call is _fallback["atomic"] inside the wrapper + return mock_fallback_atomic + + with patch( + "api.db_utils.transaction.atomic", + side_effect=atomic_side_effect, + ): + with patch("api.db_utils.time.sleep"): + with patch( + "api.db_utils.set_read_db_alias", return_value="token" + ): + with patch("api.db_utils.reset_read_db_alias"): + with pytest.raises(OperationalError): + with rls_transaction(tenant_id): + wrapper = mock_replica_conn.execute_wrappers[0] + mock_execute = MagicMock( + side_effect=OperationalError("EOF") + ) + mock_context = {"cursor": MagicMock()} + wrapper( + mock_execute, + "SELECT 1", + None, + False, + mock_context, + ) + raise OperationalError("caller error") + + # Verify __exit__ was called with exception info, + # not (None, None, None) + exit_args = mock_fallback_atomic.__exit__.call_args + exc_type = exit_args[0][0] + assert exc_type is not None + assert issubclass(exc_type, OperationalError) + class TestPostgresEnumMigration: """