Compare commits

...

9 Commits

Author SHA1 Message Date
Josema Camacho
d8a654a1a1 Merge branch 'master' of github.com:prowler-cloud/prowler into PROWLER-1225-fix-read-queries-on-the-read-replica-that-doesnt-use-the-write-replica-on-retries--second-attempt 2026-03-24 09:56:33 +01:00
Josema Camacho
f58bc362df Merge branch 'master' of github.com:prowler-cloud/prowler into PROWLER-1225-fix-read-queries-on-the-read-replica-that-doesnt-use-the-write-replica-on-retries--second-attempt 2026-03-23 17:04:55 +01:00
Josema Camacho
fb245d8136 Update api/CHANGELOG.md
Co-authored-by: Pepe Fagoaga <pepe@prowler.com>
2026-03-23 13:36:07 +01:00
Josema Camacho
e43d0aa806 Merge branch 'master' into PROWLER-1225-fix-read-queries-on-the-read-replica-that-doesnt-use-the-write-replica-on-retries--second-attempt 2026-03-23 10:23:11 +01:00
Josema Camacho
dd3de0bd7a Merge branch 'master' into PROWLER-1225-fix-read-queries-on-the-read-replica-that-doesnt-use-the-write-replica-on-retries--second-attempt 2026-03-19 13:19:32 +01:00
Josema Camacho
32d2127e89 docs(api): move rls_transaction failover entry to its own v1.23.0 changelog section - Fix no new line between sections 2026-03-19 13:11:31 +01:00
Josema Camacho
33af438dc5 docs(api): move rls_transaction failover entry to its own v1.23.0 changelog section 2026-03-19 12:22:30 +01:00
Josema Camacho
e86e895e96 fix(api): address review findings in rls_transaction failover logic 2026-03-19 10:49:48 +01:00
Josema Camacho
e31829ebcd fix(api): add query-level retry with primary fallback to rls_transaction via execute_wrapper 2026-03-18 16:22:17 +01:00
3 changed files with 671 additions and 29 deletions

View File

@@ -6,6 +6,7 @@ All notable changes to the **Prowler API** are documented in this file.
### 🐞 Fixed
- `rls_transaction` to retry mid-query read replica failures with primary DB fallback via `execute_wrapper`, preventing scan crashes during replica recovery [(#10379)](https://github.com/prowler-cloud/prowler/pull/10379)
- Finding groups latest endpoint now aggregates the latest snapshot per provider before check-level totals, keeping impacted resources aligned across providers [(#10419)](https://github.com/prowler-cloud/prowler/pull/10419)
- Mute rule creation now triggers finding-group summary re-aggregation after historical muting, keeping stats in sync after mute operations [(#10419)](https://github.com/prowler-cloud/prowler/pull/10419)

View File

@@ -1,8 +1,9 @@
import re
import secrets
import sys
import time
import uuid
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from datetime import datetime, timedelta, timezone
from celery.utils.log import get_task_logger
@@ -78,14 +79,34 @@ def rls_transaction(
retry_on_replica: bool = True,
):
"""
Creates a new database transaction setting the given configuration value for Postgres RLS. It validates
that the value is a valid UUID.
Context manager that opens an RLS-scoped database transaction.
Sets a Postgres configuration variable (``set_config``) so that
Row-Level Security policies can filter by tenant. When *using*
points to a read replica and *retry_on_replica* is True, two
layers of retry protect against replica failures:
1. **Pre-yield** (connection-setup failures): the function retries
up to ``REPLICA_MAX_ATTEMPTS`` times on the replica, then falls
back to the primary DB.
2. **Post-yield** (mid-query failures): an ``execute_wrapper``
intercepts ``OperationalError`` during ``cursor.execute()``
calls, retries on the replica with backoff, and falls back to
the primary if the replica stays down. The wrapper swaps the
inner psycopg2 cursor so ``fetchall()`` / ``fetchone()`` read
from the new connection transparently.
Limitation: server-side cursors (``.iterator()``) fetch rows via
``fetchmany()``, which the wrapper does not intercept. Call sites
that iterate large result sets with ``.iterator()`` on the replica
should add their own retry logic.
Args:
value (str): Database configuration parameter value.
parameter (str): Database configuration parameter name, by default is 'api.tenant_id'.
using (str | None): Optional database alias to run the transaction against. Defaults to the
active read alias (if any) or Django's default connection.
value: Database configuration parameter value (must be a valid UUID).
parameter: Database configuration parameter name.
using: Optional database alias. Defaults to the active read
alias or Django's default connection.
retry_on_replica: Whether to retry on replica failures.
"""
requested_alias = using or get_read_db_alias()
db_alias = requested_alias or DEFAULT_DB_ALIAS
@@ -93,19 +114,87 @@ def rls_transaction(
db_alias = DEFAULT_DB_ALIAS
alias = db_alias
is_replica = READ_REPLICA_ALIAS and alias == READ_REPLICA_ALIAS
max_attempts = REPLICA_MAX_ATTEMPTS if is_replica and retry_on_replica else 1
is_replica = bool(READ_REPLICA_ALIAS and alias == READ_REPLICA_ALIAS)
can_failover = is_replica and retry_on_replica
replica_alias = alias # captured before the loop mutates alias
max_attempts = (REPLICA_MAX_ATTEMPTS + 1) if can_failover else 1
# State shared between the generator and the _query_failover closure
_fallback = {"succeeded": False, "atomic": None, "token": None}
def _query_failover(execute, sql, params, many, context):
"""execute_wrapper: retry failed replica queries, then fall back to primary."""
try:
return execute(sql, params, many, context)
except OperationalError as err:
# Phase 1 — retry on replica with exponential backoff
for retry in range(1, REPLICA_MAX_ATTEMPTS + 1):
try:
connections[replica_alias].close()
except Exception:
pass # Best-effort; connection may already be dead
delay = REPLICA_RETRY_BASE_DELAY * (2 ** (retry - 1))
logger.info(
f"Mid-query failure on replica (retry {retry}/{REPLICA_MAX_ATTEMPTS}), "
f"retrying in {delay:.1f}s. Error: {err}"
)
time.sleep(delay)
try:
replica_conn = connections[replica_alias]
replica_conn.ensure_connection()
replica_conn.connection.autocommit = False
raw = replica_conn.connection.cursor()
raw.execute(SET_CONFIG_QUERY, [parameter, value])
if many:
raw.executemany(sql, params)
else:
raw.execute(sql, params)
context["cursor"].cursor = raw
return None
except OperationalError as retry_err:
err = retry_err
continue
# Phase 2 — fall back to primary
try:
connections[replica_alias].close()
except Exception:
pass # Best-effort; connection may already be dead
logger.warning(
"Mid-query replica retries exhausted, falling back to primary DB"
)
primary = connections[DEFAULT_DB_ALIAS]
primary.ensure_connection()
_fallback["atomic"] = transaction.atomic(using=DEFAULT_DB_ALIAS)
_fallback["atomic"].__enter__()
with primary.cursor() as setup_cursor:
setup_cursor.execute(SET_CONFIG_QUERY, [parameter, value])
_fallback["token"] = set_read_db_alias(DEFAULT_DB_ALIAS)
raw = primary.connection.cursor()
if many:
raw.executemany(sql, params)
else:
raw.execute(sql, params)
context["cursor"].cursor = raw
_fallback["succeeded"] = True
return None
for attempt in range(1, max_attempts + 1):
router_token = None
yielded_cursor = False
_caller_exited_cleanly = False
# On final attempt, fallback to primary
if attempt == max_attempts and is_replica:
logger.warning(
f"RLS transaction failed after {attempt - 1} attempts on replica, "
f"falling back to primary DB"
)
# On final attempt, fall back to primary
if attempt == max_attempts and can_failover:
if attempt > 1:
logger.warning(
f"RLS transaction failed after {attempt - 1} attempts on replica, "
f"falling back to primary DB"
)
alias = DEFAULT_DB_ALIAS
conn = connections[alias]
@@ -116,19 +205,36 @@ def rls_transaction(
with transaction.atomic(using=alias):
with conn.cursor() as cursor:
try:
# just in case the value is a UUID object
uuid.UUID(str(value))
except ValueError:
raise ValidationError("Must be a valid UUID")
cursor.execute(SET_CONFIG_QUERY, [parameter, value])
yielded_cursor = True
yield cursor
wrapper_cm = (
conn.execute_wrapper(_query_failover)
if can_failover and alias == replica_alias
else nullcontext()
)
with wrapper_cm:
yielded_cursor = True
yield cursor
_caller_exited_cleanly = True
return
except OperationalError as e:
try:
connections[alias].close()
except Exception:
pass # Best-effort; connection may already be dead
if yielded_cursor:
if _fallback["succeeded"] and _caller_exited_cleanly:
# Caller's queries succeeded on primary via failover.
# This error is transaction.atomic() cleanup on the
# dead replica connection — suppress it.
return
raise
# If on primary or max attempts reached, raise
if not is_replica or attempt == max_attempts:
if not can_failover or attempt == max_attempts:
raise
# Retry with exponential backoff
@@ -139,6 +245,16 @@ def rls_transaction(
)
time.sleep(delay)
finally:
if _fallback["atomic"] is not None:
try:
_fallback["atomic"].__exit__(*sys.exc_info())
except Exception:
pass # Best-effort; primary connection may be dead
_fallback["atomic"] = None
if _fallback["token"] is not None:
reset_read_db_alias(_fallback["token"])
_fallback["token"] = None
if router_token is not None:
reset_read_db_alias(router_token)

View File

@@ -1,3 +1,4 @@
from contextlib import contextmanager
from datetime import datetime, timezone
from enum import Enum
from unittest.mock import MagicMock, patch
@@ -528,7 +529,7 @@ class TestRlsTransaction:
def atomic_side_effect(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count < 3:
if call_count < 4:
raise OperationalError("Connection error")
return MagicMock(
__enter__=MagicMock(return_value=None),
@@ -547,10 +548,11 @@ class TestRlsTransaction:
with rls_transaction(tenant_id):
pass
assert mock_sleep.call_count == 2
assert mock_sleep.call_count == 3
mock_sleep.assert_any_call(0.5)
mock_sleep.assert_any_call(1.0)
assert mock_logger.info.call_count == 2
mock_sleep.assert_any_call(2.0)
assert mock_logger.info.call_count == 3
def test_rls_transaction_operational_error_inside_context_no_retry(
self, tenants_fixture, enable_read_replica
@@ -582,10 +584,10 @@ class TestRlsTransaction:
mock_sleep.assert_not_called()
def test_rls_transaction_max_three_attempts_for_replica(
def test_rls_transaction_max_attempts_for_replica(
self, tenants_fixture, enable_read_replica
):
"""Test maximum 3 attempts for replica database."""
"""Test REPLICA_MAX_ATTEMPTS replica tries + 1 primary fallback."""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
@@ -609,7 +611,8 @@ class TestRlsTransaction:
with rls_transaction(tenant_id):
pass
assert mock_atomic.call_count == 3
# 3 replica + 1 primary = 4 total
assert mock_atomic.call_count == 4
def test_rls_transaction_replica_no_retry_when_disabled(
self, tenants_fixture, enable_read_replica
@@ -685,7 +688,7 @@ class TestRlsTransaction:
def atomic_side_effect(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count < 3:
if call_count < 4:
raise OperationalError("Replica error")
return MagicMock(
__enter__=MagicMock(return_value=None),
@@ -728,7 +731,7 @@ class TestRlsTransaction:
def atomic_side_effect(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count < 3:
if call_count < 4:
raise OperationalError("Replica error")
return MagicMock(
__enter__=MagicMock(return_value=None),
@@ -747,7 +750,7 @@ class TestRlsTransaction:
with rls_transaction(tenant_id):
pass
assert mock_logger.info.call_count == 2
assert mock_logger.info.call_count == 3
assert mock_logger.warning.call_count == 1
def test_rls_transaction_operational_error_raised_immediately_on_primary(
@@ -913,6 +916,528 @@ class TestRlsTransaction:
result = cursor.fetchone()
assert result[0] == 1
# --- Mid-query failover tests ---
def test_mid_query_failure_retries_on_replica(
self, tenants_fixture, enable_read_replica
):
"""Mid-query OperationalError retries on replica via execute_wrapper."""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
with patch("api.db_utils.connections") as mock_connections:
mock_conn = MagicMock()
mock_conn.execute_wrappers = []
@contextmanager
def _execute_wrapper(fn):
mock_conn.execute_wrappers.append(fn)
try:
yield
finally:
mock_conn.execute_wrappers.remove(fn)
mock_conn.execute_wrapper = _execute_wrapper
mock_cursor = MagicMock()
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
mock_connections.__getitem__.return_value = mock_conn
mock_connections.__contains__.return_value = True
# The raw replica cursor that succeeds on retry
mock_raw_cursor = MagicMock()
mock_conn.connection.cursor.return_value = mock_raw_cursor
with patch("api.db_utils.transaction.atomic") as mock_atomic:
mock_atomic.return_value.__enter__ = MagicMock(return_value=None)
mock_atomic.return_value.__exit__ = MagicMock(return_value=False)
with patch("api.db_utils.time.sleep") as mock_sleep:
with patch(
"api.db_utils.set_read_db_alias", return_value="token"
):
with patch("api.db_utils.reset_read_db_alias"):
with rls_transaction(tenant_id):
# Wrapper is installed
assert len(mock_conn.execute_wrappers) == 1
wrapper = mock_conn.execute_wrappers[0]
# Simulate a mid-query failure that succeeds on retry
mock_execute = MagicMock(
side_effect=OperationalError("EOF")
)
mock_context = {"cursor": MagicMock()}
wrapper(
mock_execute,
"SELECT 1",
None,
False,
mock_context,
)
# Verify retry happened with backoff
mock_sleep.assert_called_once_with(0.5)
mock_conn.ensure_connection.assert_called_once()
# The raw cursor was swapped
assert (
mock_context["cursor"].cursor == mock_raw_cursor
)
# Wrapper removed after exiting
assert len(mock_conn.execute_wrappers) == 0
def test_mid_query_failure_falls_back_to_primary(
self, tenants_fixture, enable_read_replica
):
"""Mid-query failure falls back to primary after replica retries exhausted."""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
with patch("api.db_utils.connections") as mock_connections:
mock_replica_conn = MagicMock()
mock_replica_conn.execute_wrappers = []
@contextmanager
def _execute_wrapper_replica(fn):
mock_replica_conn.execute_wrappers.append(fn)
try:
yield
finally:
mock_replica_conn.execute_wrappers.remove(fn)
mock_replica_conn.execute_wrapper = _execute_wrapper_replica
mock_cursor = MagicMock()
mock_replica_conn.cursor.return_value.__enter__.return_value = (
mock_cursor
)
# Replica always fails on retry
mock_replica_conn.ensure_connection.side_effect = OperationalError(
"replica down"
)
mock_primary_conn = MagicMock()
mock_primary_raw_cursor = MagicMock()
mock_primary_conn.connection.cursor.return_value = (
mock_primary_raw_cursor
)
def connections_getitem(alias):
if alias == "replica":
return mock_replica_conn
return mock_primary_conn
mock_connections.__getitem__.side_effect = connections_getitem
mock_connections.__contains__.return_value = True
with patch("api.db_utils.transaction.atomic") as mock_atomic:
mock_atomic.return_value.__enter__ = MagicMock(return_value=None)
mock_atomic.return_value.__exit__ = MagicMock(return_value=False)
with patch("api.db_utils.time.sleep"):
with patch(
"api.db_utils.set_read_db_alias", return_value="token"
):
with patch("api.db_utils.reset_read_db_alias"):
with patch("api.db_utils.logger") as mock_logger:
with rls_transaction(tenant_id):
wrapper = mock_replica_conn.execute_wrappers[0]
mock_execute = MagicMock(
side_effect=OperationalError("EOF")
)
mock_context = {"cursor": MagicMock()}
wrapper(
mock_execute,
"SELECT 1",
None,
False,
mock_context,
)
# Verify primary was used
mock_primary_conn.ensure_connection.assert_called_once()
assert (
mock_context["cursor"].cursor
== mock_primary_raw_cursor
)
# Verify warning logged
warning_msgs = [
c[0][0]
for c in mock_logger.warning.call_args_list
]
assert any(
"falling back to primary" in m
for m in warning_msgs
)
def test_mid_query_fallback_suppresses_cleanup_error(
self, tenants_fixture, enable_read_replica
):
"""After successful primary fallback, replica cleanup error is suppressed."""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
with patch("api.db_utils.connections") as mock_connections:
mock_replica_conn = MagicMock()
mock_replica_conn.execute_wrappers = []
@contextmanager
def _execute_wrapper_replica(fn):
mock_replica_conn.execute_wrappers.append(fn)
try:
yield
finally:
mock_replica_conn.execute_wrappers.remove(fn)
mock_replica_conn.execute_wrapper = _execute_wrapper_replica
mock_cursor = MagicMock()
mock_replica_conn.cursor.return_value.__enter__.return_value = (
mock_cursor
)
mock_replica_conn.ensure_connection.side_effect = OperationalError(
"replica down"
)
mock_primary_conn = MagicMock()
mock_primary_raw_cursor = MagicMock()
mock_primary_conn.connection.cursor.return_value = (
mock_primary_raw_cursor
)
def connections_getitem(alias):
if alias == "replica":
return mock_replica_conn
return mock_primary_conn
mock_connections.__getitem__.side_effect = connections_getitem
mock_connections.__contains__.return_value = True
# transaction.atomic().__exit__ raises on dead replica
mock_atomic_cm = MagicMock()
mock_atomic_cm.__enter__ = MagicMock(return_value=None)
mock_atomic_cm.__exit__ = MagicMock(
side_effect=OperationalError("cleanup failed on dead replica")
)
with patch(
"api.db_utils.transaction.atomic", return_value=mock_atomic_cm
):
with patch("api.db_utils.time.sleep"):
with patch(
"api.db_utils.set_read_db_alias", return_value="token"
):
with patch("api.db_utils.reset_read_db_alias"):
# Should NOT raise — the cleanup error is suppressed
with rls_transaction(tenant_id):
wrapper = mock_replica_conn.execute_wrappers[0]
mock_execute = MagicMock(
side_effect=OperationalError("EOF")
)
mock_context = {"cursor": MagicMock()}
wrapper(
mock_execute,
"SELECT 1",
None,
False,
mock_context,
)
def test_mid_query_primary_also_fails(self, tenants_fixture, enable_read_replica):
"""When both replica and primary fail, OperationalError propagates."""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
with patch("api.db_utils.connections") as mock_connections:
mock_conn = MagicMock()
mock_conn.execute_wrappers = []
@contextmanager
def _execute_wrapper(fn):
mock_conn.execute_wrappers.append(fn)
try:
yield
finally:
mock_conn.execute_wrappers.remove(fn)
mock_conn.execute_wrapper = _execute_wrapper
mock_cursor = MagicMock()
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
# Both replica retry and primary fail
mock_conn.ensure_connection.side_effect = OperationalError("down")
mock_conn.connection.cursor.side_effect = OperationalError("down")
mock_connections.__getitem__.return_value = mock_conn
mock_connections.__contains__.return_value = True
with patch("api.db_utils.transaction.atomic") as mock_atomic:
mock_atomic.return_value.__enter__ = MagicMock(return_value=None)
mock_atomic.return_value.__exit__ = MagicMock(return_value=False)
with patch("api.db_utils.time.sleep"):
with patch(
"api.db_utils.set_read_db_alias", return_value="token"
):
with patch("api.db_utils.reset_read_db_alias"):
with rls_transaction(tenant_id):
wrapper = mock_conn.execute_wrappers[0]
mock_execute = MagicMock(
side_effect=OperationalError("EOF")
)
mock_context = {"cursor": MagicMock()}
with pytest.raises(OperationalError):
wrapper(
mock_execute,
"SELECT 1",
None,
False,
mock_context,
)
def test_wrapper_not_installed_on_primary(self, tenants_fixture):
"""execute_wrapper is not installed when targeting primary DB."""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
with patch("api.db_utils.get_read_db_alias", return_value=None):
with patch("api.db_utils.connections") as mock_connections:
mock_conn = MagicMock()
mock_conn.execute_wrappers = []
mock_cursor = MagicMock()
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
mock_connections.__getitem__.return_value = mock_conn
mock_connections.__contains__.return_value = True
with patch("api.db_utils.transaction.atomic") as mock_atomic:
mock_atomic.return_value.__enter__ = MagicMock(return_value=None)
mock_atomic.return_value.__exit__ = MagicMock(return_value=False)
with rls_transaction(tenant_id):
# No wrapper installed on primary
assert len(mock_conn.execute_wrappers) == 0
def test_stale_connection_closed_on_pre_yield_retry(
self, tenants_fixture, enable_read_replica
):
"""Stale connection is closed before each pre-yield retry."""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
with patch("api.db_utils.connections") as mock_connections:
mock_conn = MagicMock()
mock_conn.execute_wrappers = []
mock_cursor = MagicMock()
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
mock_connections.__getitem__.return_value = mock_conn
mock_connections.__contains__.return_value = True
call_count = 0
def atomic_side_effect(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count < 3:
raise OperationalError("Connection error")
return MagicMock(
__enter__=MagicMock(return_value=None),
__exit__=MagicMock(return_value=False),
)
with patch(
"api.db_utils.transaction.atomic", side_effect=atomic_side_effect
):
with patch("api.db_utils.time.sleep"):
with patch(
"api.db_utils.set_read_db_alias", return_value="token"
):
with patch("api.db_utils.reset_read_db_alias"):
with rls_transaction(tenant_id):
pass
# close() called for each failed pre-yield attempt
assert mock_conn.close.call_count == 2
def test_caller_error_propagates_after_successful_failover(
self, tenants_fixture, enable_read_replica
):
"""OperationalError raised by caller after failover is NOT suppressed."""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
with patch("api.db_utils.connections") as mock_connections:
mock_replica_conn = MagicMock()
mock_replica_conn.execute_wrappers = []
@contextmanager
def _execute_wrapper_replica(fn):
mock_replica_conn.execute_wrappers.append(fn)
try:
yield
finally:
mock_replica_conn.execute_wrappers.remove(fn)
mock_replica_conn.execute_wrapper = _execute_wrapper_replica
mock_cursor = MagicMock()
mock_replica_conn.cursor.return_value.__enter__.return_value = (
mock_cursor
)
mock_replica_conn.ensure_connection.side_effect = OperationalError(
"replica down"
)
mock_primary_conn = MagicMock()
mock_primary_raw_cursor = MagicMock()
mock_primary_conn.connection.cursor.return_value = (
mock_primary_raw_cursor
)
def connections_getitem(alias):
if alias == "replica":
return mock_replica_conn
return mock_primary_conn
mock_connections.__getitem__.side_effect = connections_getitem
mock_connections.__contains__.return_value = True
# transaction.atomic().__exit__ raises on dead replica cleanup
mock_atomic_cm = MagicMock()
mock_atomic_cm.__enter__ = MagicMock(return_value=None)
mock_atomic_cm.__exit__ = MagicMock(
side_effect=OperationalError("cleanup on dead replica")
)
with patch(
"api.db_utils.transaction.atomic", return_value=mock_atomic_cm
):
with patch("api.db_utils.time.sleep"):
with patch(
"api.db_utils.set_read_db_alias", return_value="token"
):
with patch("api.db_utils.reset_read_db_alias"):
with pytest.raises(OperationalError):
with rls_transaction(tenant_id):
# Trigger failover (succeeds on primary)
wrapper = mock_replica_conn.execute_wrappers[0]
mock_execute = MagicMock(
side_effect=OperationalError("EOF")
)
mock_context = {"cursor": MagicMock()}
wrapper(
mock_execute,
"SELECT 1",
None,
False,
mock_context,
)
# Caller raises after successful failover —
# must NOT be suppressed
raise OperationalError("caller error")
def test_fallback_atomic_receives_exception_info(
self, tenants_fixture, enable_read_replica
):
"""Primary transaction rollback uses sys.exc_info(), not (None, None, None)."""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
with patch("api.db_utils.connections") as mock_connections:
mock_replica_conn = MagicMock()
mock_replica_conn.execute_wrappers = []
@contextmanager
def _execute_wrapper_replica(fn):
mock_replica_conn.execute_wrappers.append(fn)
try:
yield
finally:
mock_replica_conn.execute_wrappers.remove(fn)
mock_replica_conn.execute_wrapper = _execute_wrapper_replica
mock_cursor = MagicMock()
mock_replica_conn.cursor.return_value.__enter__.return_value = (
mock_cursor
)
mock_replica_conn.ensure_connection.side_effect = OperationalError(
"replica down"
)
mock_primary_conn = MagicMock()
mock_primary_raw_cursor = MagicMock()
mock_primary_conn.connection.cursor.return_value = (
mock_primary_raw_cursor
)
def connections_getitem(alias):
if alias == "replica":
return mock_replica_conn
return mock_primary_conn
mock_connections.__getitem__.side_effect = connections_getitem
mock_connections.__contains__.return_value = True
mock_fallback_atomic = MagicMock()
mock_fallback_atomic.__enter__ = MagicMock(return_value=None)
mock_fallback_atomic.__exit__ = MagicMock(return_value=False)
# Outer atomic raises on dead replica cleanup
mock_outer_atomic = MagicMock()
mock_outer_atomic.__enter__ = MagicMock(return_value=None)
mock_outer_atomic.__exit__ = MagicMock(
side_effect=OperationalError("cleanup on dead replica")
)
atomic_call_count = 0
def atomic_side_effect(*args, **kwargs):
nonlocal atomic_call_count
atomic_call_count += 1
# First call is the outer transaction.atomic(using=replica)
if atomic_call_count == 1:
return mock_outer_atomic
# Second call is _fallback["atomic"] inside the wrapper
return mock_fallback_atomic
with patch(
"api.db_utils.transaction.atomic",
side_effect=atomic_side_effect,
):
with patch("api.db_utils.time.sleep"):
with patch(
"api.db_utils.set_read_db_alias", return_value="token"
):
with patch("api.db_utils.reset_read_db_alias"):
with pytest.raises(OperationalError):
with rls_transaction(tenant_id):
wrapper = mock_replica_conn.execute_wrappers[0]
mock_execute = MagicMock(
side_effect=OperationalError("EOF")
)
mock_context = {"cursor": MagicMock()}
wrapper(
mock_execute,
"SELECT 1",
None,
False,
mock_context,
)
raise OperationalError("caller error")
# Verify __exit__ was called with exception info,
# not (None, None, None)
exit_args = mock_fallback_atomic.__exit__.call_args
exc_type = exit_args[0][0]
assert exc_type is not None
assert issubclass(exc_type, OperationalError)
class TestPostgresEnumMigration:
"""