fix(api): address review findings in rls_transaction failover logic

This commit is contained in:
Josema Camacho
2026-03-19 10:49:48 +01:00
parent e31829ebcd
commit e86e895e96
2 changed files with 233 additions and 20 deletions

View File

@@ -1,8 +1,9 @@
import re
import secrets
import sys
import time
import uuid
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from datetime import datetime, timedelta, timezone
from celery.utils.log import get_task_logger
@@ -131,7 +132,7 @@ def rls_transaction(
try:
connections[replica_alias].close()
except Exception:
pass
pass # Best-effort; connection may already be dead
delay = REPLICA_RETRY_BASE_DELAY * (2 ** (retry - 1))
logger.info(
@@ -160,7 +161,7 @@ def rls_transaction(
try:
connections[replica_alias].close()
except Exception:
pass
pass # Best-effort; connection may already be dead
logger.warning(
"Mid-query replica retries exhausted, falling back to primary DB"
@@ -185,7 +186,7 @@ def rls_transaction(
for attempt in range(1, max_attempts + 1):
router_token = None
yielded_cursor = False
wrapper_installed = False
_caller_exited_cleanly = False
# On final attempt, fall back to primary
if attempt == max_attempts and can_failover:
@@ -209,21 +210,24 @@ def rls_transaction(
raise ValidationError("Must be a valid UUID")
cursor.execute(SET_CONFIG_QUERY, [parameter, value])
if can_failover and alias == replica_alias:
conn.execute_wrappers.append(_query_failover)
wrapper_installed = True
yielded_cursor = True
yield cursor
wrapper_cm = (
conn.execute_wrapper(_query_failover)
if can_failover and alias == replica_alias
else nullcontext()
)
with wrapper_cm:
yielded_cursor = True
yield cursor
_caller_exited_cleanly = True
return
except OperationalError as e:
try:
connections[alias].close()
except Exception:
pass
pass # Best-effort; connection may already be dead
if yielded_cursor:
if _fallback["succeeded"]:
if _fallback["succeeded"] and _caller_exited_cleanly:
# Caller's queries succeeded on primary via failover.
# This error is transaction.atomic() cleanup on the
# dead replica connection — suppress it.
@@ -241,17 +245,11 @@ def rls_transaction(
)
time.sleep(delay)
finally:
if wrapper_installed:
try:
conn.execute_wrappers.remove(_query_failover)
except ValueError:
pass
if _fallback["atomic"] is not None:
try:
_fallback["atomic"].__exit__(None, None, None)
_fallback["atomic"].__exit__(*sys.exc_info())
except Exception:
pass
pass # Best-effort; primary connection may be dead
_fallback["atomic"] = None
if _fallback["token"] is not None:
reset_read_db_alias(_fallback["token"])

View File

@@ -1,3 +1,4 @@
from contextlib import contextmanager
from datetime import datetime, timezone
from enum import Enum
from unittest.mock import MagicMock, patch
@@ -928,6 +929,16 @@ class TestRlsTransaction:
with patch("api.db_utils.connections") as mock_connections:
mock_conn = MagicMock()
mock_conn.execute_wrappers = []
@contextmanager
def _execute_wrapper(fn):
mock_conn.execute_wrappers.append(fn)
try:
yield
finally:
mock_conn.execute_wrappers.remove(fn)
mock_conn.execute_wrapper = _execute_wrapper
mock_cursor = MagicMock()
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
mock_connections.__getitem__.return_value = mock_conn
@@ -987,6 +998,16 @@ class TestRlsTransaction:
with patch("api.db_utils.connections") as mock_connections:
mock_replica_conn = MagicMock()
mock_replica_conn.execute_wrappers = []
@contextmanager
def _execute_wrapper_replica(fn):
mock_replica_conn.execute_wrappers.append(fn)
try:
yield
finally:
mock_replica_conn.execute_wrappers.remove(fn)
mock_replica_conn.execute_wrapper = _execute_wrapper_replica
mock_cursor = MagicMock()
mock_replica_conn.cursor.return_value.__enter__.return_value = (
mock_cursor
@@ -1064,6 +1085,16 @@ class TestRlsTransaction:
with patch("api.db_utils.connections") as mock_connections:
mock_replica_conn = MagicMock()
mock_replica_conn.execute_wrappers = []
@contextmanager
def _execute_wrapper_replica(fn):
mock_replica_conn.execute_wrappers.append(fn)
try:
yield
finally:
mock_replica_conn.execute_wrappers.remove(fn)
mock_replica_conn.execute_wrapper = _execute_wrapper_replica
mock_cursor = MagicMock()
mock_replica_conn.cursor.return_value.__enter__.return_value = (
mock_cursor
@@ -1125,6 +1156,16 @@ class TestRlsTransaction:
with patch("api.db_utils.connections") as mock_connections:
mock_conn = MagicMock()
mock_conn.execute_wrappers = []
@contextmanager
def _execute_wrapper(fn):
mock_conn.execute_wrappers.append(fn)
try:
yield
finally:
mock_conn.execute_wrappers.remove(fn)
mock_conn.execute_wrapper = _execute_wrapper
mock_cursor = MagicMock()
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
# Both replica retry and primary fail
@@ -1223,6 +1264,180 @@ class TestRlsTransaction:
# close() called for each failed pre-yield attempt
assert mock_conn.close.call_count == 2
def test_caller_error_propagates_after_successful_failover(
self, tenants_fixture, enable_read_replica
):
"""OperationalError raised by caller after failover is NOT suppressed.

The test drives the installed execute wrapper by hand with an ``execute``
callable that raises, then raises ``OperationalError`` from inside the
``with rls_transaction(...)`` body and expects an ``OperationalError`` to
escape the context manager.
"""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
with patch("api.db_utils.connections") as mock_connections:
mock_replica_conn = MagicMock()
mock_replica_conn.execute_wrappers = []
# Stand-in for conn.execute_wrapper(): keeps fn registered in
# execute_wrappers only while the with-block is active, so the test
# body can fetch the wrapper from the list.
@contextmanager
def _execute_wrapper_replica(fn):
mock_replica_conn.execute_wrappers.append(fn)
try:
yield
finally:
mock_replica_conn.execute_wrappers.remove(fn)
mock_replica_conn.execute_wrapper = _execute_wrapper_replica
mock_cursor = MagicMock()
mock_replica_conn.cursor.return_value.__enter__.return_value = (
mock_cursor
)
# Every ensure_connection() call on the replica mock raises.
mock_replica_conn.ensure_connection.side_effect = OperationalError(
"replica down"
)
mock_primary_conn = MagicMock()
mock_primary_raw_cursor = MagicMock()
mock_primary_conn.connection.cursor.return_value = (
mock_primary_raw_cursor
)
# Route the "replica" alias to the replica mock; every other alias
# resolves to the primary mock.
def connections_getitem(alias):
if alias == "replica":
return mock_replica_conn
return mock_primary_conn
mock_connections.__getitem__.side_effect = connections_getitem
# `alias in connections` is True for any alias.
mock_connections.__contains__.return_value = True
# transaction.atomic().__exit__ raises on dead replica cleanup
mock_atomic_cm = MagicMock()
mock_atomic_cm.__enter__ = MagicMock(return_value=None)
mock_atomic_cm.__exit__ = MagicMock(
side_effect=OperationalError("cleanup on dead replica")
)
with patch(
"api.db_utils.transaction.atomic", return_value=mock_atomic_cm
):
# time.sleep is replaced with a mock so no real delay happens.
with patch("api.db_utils.time.sleep"):
with patch(
"api.db_utils.set_read_db_alias", return_value="token"
):
with patch("api.db_utils.reset_read_db_alias"):
# NOTE(review): no match= is given, so any OperationalError
# satisfies this check, including the mocked
# "cleanup on dead replica" one. Confirm which error is
# expected to propagate.
with pytest.raises(OperationalError):
with rls_transaction(tenant_id):
# Trigger failover (succeeds on primary)
wrapper = mock_replica_conn.execute_wrappers[0]
mock_execute = MagicMock(
side_effect=OperationalError("EOF")
)
mock_context = {"cursor": MagicMock()}
# Arguments follow Django's execute-wrapper call order:
# (execute, sql, params, many, context).
wrapper(
mock_execute,
"SELECT 1",
None,
False,
mock_context,
)
# Caller raises after successful failover —
# must NOT be suppressed
raise OperationalError("caller error")
def test_fallback_atomic_receives_exception_info(
self, tenants_fixture, enable_read_replica
):
"""Primary transaction rollback uses sys.exc_info(), not (None, None, None).

Two distinct ``transaction.atomic`` mocks are handed out in call order so
the ``__exit__`` arguments of the second one (the fallback transaction)
can be inspected after the caller's block raises.
"""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
with patch("api.db_utils.connections") as mock_connections:
mock_replica_conn = MagicMock()
mock_replica_conn.execute_wrappers = []
# Stand-in for conn.execute_wrapper(): keeps fn registered in
# execute_wrappers only while the with-block is active, so the test
# body can fetch the wrapper from the list.
@contextmanager
def _execute_wrapper_replica(fn):
mock_replica_conn.execute_wrappers.append(fn)
try:
yield
finally:
mock_replica_conn.execute_wrappers.remove(fn)
mock_replica_conn.execute_wrapper = _execute_wrapper_replica
mock_cursor = MagicMock()
mock_replica_conn.cursor.return_value.__enter__.return_value = (
mock_cursor
)
# Every ensure_connection() call on the replica mock raises.
mock_replica_conn.ensure_connection.side_effect = OperationalError(
"replica down"
)
mock_primary_conn = MagicMock()
mock_primary_raw_cursor = MagicMock()
mock_primary_conn.connection.cursor.return_value = (
mock_primary_raw_cursor
)
# Route the "replica" alias to the replica mock; every other alias
# resolves to the primary mock.
def connections_getitem(alias):
if alias == "replica":
return mock_replica_conn
return mock_primary_conn
mock_connections.__getitem__.side_effect = connections_getitem
# `alias in connections` is True for any alias.
mock_connections.__contains__.return_value = True
# Fallback atomic: __exit__ returns False, so it does not swallow
# whatever exception it is given.
mock_fallback_atomic = MagicMock()
mock_fallback_atomic.__enter__ = MagicMock(return_value=None)
mock_fallback_atomic.__exit__ = MagicMock(return_value=False)
# Outer atomic raises on dead replica cleanup
mock_outer_atomic = MagicMock()
mock_outer_atomic.__enter__ = MagicMock(return_value=None)
mock_outer_atomic.__exit__ = MagicMock(
side_effect=OperationalError("cleanup on dead replica")
)
# Counts transaction.atomic() calls so the two mocks can be
# distinguished by call order.
atomic_call_count = 0
def atomic_side_effect(*args, **kwargs):
nonlocal atomic_call_count
atomic_call_count += 1
# First call is the outer transaction.atomic(using=replica)
if atomic_call_count == 1:
return mock_outer_atomic
# Second call is _fallback["atomic"] inside the wrapper
return mock_fallback_atomic
with patch(
"api.db_utils.transaction.atomic",
side_effect=atomic_side_effect,
):
# time.sleep is replaced with a mock so no real delay happens.
with patch("api.db_utils.time.sleep"):
with patch(
"api.db_utils.set_read_db_alias", return_value="token"
):
with patch("api.db_utils.reset_read_db_alias"):
with pytest.raises(OperationalError):
with rls_transaction(tenant_id):
wrapper = mock_replica_conn.execute_wrappers[0]
mock_execute = MagicMock(
side_effect=OperationalError("EOF")
)
mock_context = {"cursor": MagicMock()}
# Arguments follow Django's execute-wrapper call order:
# (execute, sql, params, many, context).
wrapper(
mock_execute,
"SELECT 1",
None,
False,
mock_context,
)
raise OperationalError("caller error")
# Verify __exit__ was called with exception info,
# not (None, None, None)
# NOTE(review): these assertions only run if they sit outside the
# pytest.raises block above; confirm the indentation in the real file.
# call_args is the most recent call; [0] is its positional-args tuple
# and [0][0] the first positional argument (the exception type).
exit_args = mock_fallback_atomic.__exit__.call_args
exc_type = exit_args[0][0]
assert exc_type is not None
assert issubclass(exc_type, OperationalError)
class TestPostgresEnumMigration:
"""