mirror of
https://github.com/prowler-cloud/prowler.git
synced 2026-04-03 05:55:54 +00:00
Compare commits
9 Commits
fix/ui-fin
...
PROWLER-12
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d8a654a1a1 | ||
|
|
f58bc362df | ||
|
|
fb245d8136 | ||
|
|
e43d0aa806 | ||
|
|
dd3de0bd7a | ||
|
|
32d2127e89 | ||
|
|
33af438dc5 | ||
|
|
e86e895e96 | ||
|
|
e31829ebcd |
@@ -6,6 +6,7 @@ All notable changes to the **Prowler API** are documented in this file.
|
||||
|
||||
### 🐞 Fixed
|
||||
|
||||
- `rls_transaction` now retries mid-query read replica failures, with primary DB fallback, via `execute_wrapper`, preventing scan crashes during replica recovery [(#10379)](https://github.com/prowler-cloud/prowler/pull/10379)
|
||||
- Finding groups latest endpoint now aggregates the latest snapshot per provider before check-level totals, keeping impacted resources aligned across providers [(#10419)](https://github.com/prowler-cloud/prowler/pull/10419)
|
||||
- Mute rule creation now triggers finding-group summary re-aggregation after historical muting, keeping stats in sync after mute operations [(#10419)](https://github.com/prowler-cloud/prowler/pull/10419)
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import re
|
||||
import secrets
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from celery.utils.log import get_task_logger
|
||||
@@ -78,14 +79,34 @@ def rls_transaction(
|
||||
retry_on_replica: bool = True,
|
||||
):
|
||||
"""
|
||||
Creates a new database transaction setting the given configuration value for Postgres RLS. It validates the
|
||||
if the value is a valid UUID.
|
||||
Context manager that opens an RLS-scoped database transaction.
|
||||
|
||||
Sets a Postgres configuration variable (``set_config``) so that
|
||||
Row-Level Security policies can filter by tenant. When *using*
|
||||
points to a read replica and *retry_on_replica* is True, two
|
||||
layers of retry protect against replica failures:
|
||||
|
||||
1. **Pre-yield** (connection-setup failures): the function retries
|
||||
up to ``REPLICA_MAX_ATTEMPTS`` times on the replica, then falls
|
||||
back to the primary DB.
|
||||
2. **Post-yield** (mid-query failures): an ``execute_wrapper``
|
||||
intercepts ``OperationalError`` during ``cursor.execute()``
|
||||
calls, retries on the replica with backoff, and falls back to
|
||||
the primary if the replica stays down. The wrapper swaps the
|
||||
inner psycopg2 cursor so ``fetchall()`` / ``fetchone()`` read
|
||||
from the new connection transparently.
|
||||
|
||||
Limitation: server-side cursors (``.iterator()``) fetch rows via
|
||||
``fetchmany()``, which the wrapper does not intercept. Call sites
|
||||
that iterate large result sets with ``.iterator()`` on the replica
|
||||
should add their own retry logic.
|
||||
|
||||
Args:
|
||||
value (str): Database configuration parameter value.
|
||||
parameter (str): Database configuration parameter name, by default is 'api.tenant_id'.
|
||||
using (str | None): Optional database alias to run the transaction against. Defaults to the
|
||||
active read alias (if any) or Django's default connection.
|
||||
value: Database configuration parameter value (must be a valid UUID).
|
||||
parameter: Database configuration parameter name.
|
||||
using: Optional database alias. Defaults to the active read
|
||||
alias or Django's default connection.
|
||||
retry_on_replica: Whether to retry on replica failures.
|
||||
"""
|
||||
requested_alias = using or get_read_db_alias()
|
||||
db_alias = requested_alias or DEFAULT_DB_ALIAS
|
||||
@@ -93,19 +114,87 @@ def rls_transaction(
|
||||
db_alias = DEFAULT_DB_ALIAS
|
||||
|
||||
alias = db_alias
|
||||
is_replica = READ_REPLICA_ALIAS and alias == READ_REPLICA_ALIAS
|
||||
max_attempts = REPLICA_MAX_ATTEMPTS if is_replica and retry_on_replica else 1
|
||||
is_replica = bool(READ_REPLICA_ALIAS and alias == READ_REPLICA_ALIAS)
|
||||
can_failover = is_replica and retry_on_replica
|
||||
replica_alias = alias # captured before the loop mutates alias
|
||||
max_attempts = (REPLICA_MAX_ATTEMPTS + 1) if can_failover else 1
|
||||
|
||||
# State shared between the generator and the _query_failover closure
|
||||
_fallback = {"succeeded": False, "atomic": None, "token": None}
|
||||
|
||||
def _query_failover(execute, sql, params, many, context):
|
||||
"""execute_wrapper: retry failed replica queries, then fall back to primary."""
|
||||
try:
|
||||
return execute(sql, params, many, context)
|
||||
except OperationalError as err:
|
||||
# Phase 1 — retry on replica with exponential backoff
|
||||
for retry in range(1, REPLICA_MAX_ATTEMPTS + 1):
|
||||
try:
|
||||
connections[replica_alias].close()
|
||||
except Exception:
|
||||
pass # Best-effort; connection may already be dead
|
||||
|
||||
delay = REPLICA_RETRY_BASE_DELAY * (2 ** (retry - 1))
|
||||
logger.info(
|
||||
f"Mid-query failure on replica (retry {retry}/{REPLICA_MAX_ATTEMPTS}), "
|
||||
f"retrying in {delay:.1f}s. Error: {err}"
|
||||
)
|
||||
time.sleep(delay)
|
||||
|
||||
try:
|
||||
replica_conn = connections[replica_alias]
|
||||
replica_conn.ensure_connection()
|
||||
replica_conn.connection.autocommit = False
|
||||
raw = replica_conn.connection.cursor()
|
||||
raw.execute(SET_CONFIG_QUERY, [parameter, value])
|
||||
if many:
|
||||
raw.executemany(sql, params)
|
||||
else:
|
||||
raw.execute(sql, params)
|
||||
context["cursor"].cursor = raw
|
||||
return None
|
||||
except OperationalError as retry_err:
|
||||
err = retry_err
|
||||
continue
|
||||
|
||||
# Phase 2 — fall back to primary
|
||||
try:
|
||||
connections[replica_alias].close()
|
||||
except Exception:
|
||||
pass # Best-effort; connection may already be dead
|
||||
|
||||
logger.warning(
|
||||
"Mid-query replica retries exhausted, falling back to primary DB"
|
||||
)
|
||||
primary = connections[DEFAULT_DB_ALIAS]
|
||||
primary.ensure_connection()
|
||||
_fallback["atomic"] = transaction.atomic(using=DEFAULT_DB_ALIAS)
|
||||
_fallback["atomic"].__enter__()
|
||||
with primary.cursor() as setup_cursor:
|
||||
setup_cursor.execute(SET_CONFIG_QUERY, [parameter, value])
|
||||
_fallback["token"] = set_read_db_alias(DEFAULT_DB_ALIAS)
|
||||
|
||||
raw = primary.connection.cursor()
|
||||
if many:
|
||||
raw.executemany(sql, params)
|
||||
else:
|
||||
raw.execute(sql, params)
|
||||
context["cursor"].cursor = raw
|
||||
_fallback["succeeded"] = True
|
||||
return None
|
||||
|
||||
for attempt in range(1, max_attempts + 1):
|
||||
router_token = None
|
||||
yielded_cursor = False
|
||||
_caller_exited_cleanly = False
|
||||
|
||||
# On final attempt, fallback to primary
|
||||
if attempt == max_attempts and is_replica:
|
||||
logger.warning(
|
||||
f"RLS transaction failed after {attempt - 1} attempts on replica, "
|
||||
f"falling back to primary DB"
|
||||
)
|
||||
# On final attempt, fall back to primary
|
||||
if attempt == max_attempts and can_failover:
|
||||
if attempt > 1:
|
||||
logger.warning(
|
||||
f"RLS transaction failed after {attempt - 1} attempts on replica, "
|
||||
f"falling back to primary DB"
|
||||
)
|
||||
alias = DEFAULT_DB_ALIAS
|
||||
|
||||
conn = connections[alias]
|
||||
@@ -116,19 +205,36 @@ def rls_transaction(
|
||||
with transaction.atomic(using=alias):
|
||||
with conn.cursor() as cursor:
|
||||
try:
|
||||
# just in case the value is a UUID object
|
||||
uuid.UUID(str(value))
|
||||
except ValueError:
|
||||
raise ValidationError("Must be a valid UUID")
|
||||
cursor.execute(SET_CONFIG_QUERY, [parameter, value])
|
||||
yielded_cursor = True
|
||||
yield cursor
|
||||
|
||||
wrapper_cm = (
|
||||
conn.execute_wrapper(_query_failover)
|
||||
if can_failover and alias == replica_alias
|
||||
else nullcontext()
|
||||
)
|
||||
with wrapper_cm:
|
||||
yielded_cursor = True
|
||||
yield cursor
|
||||
_caller_exited_cleanly = True
|
||||
return
|
||||
except OperationalError as e:
|
||||
try:
|
||||
connections[alias].close()
|
||||
except Exception:
|
||||
pass # Best-effort; connection may already be dead
|
||||
|
||||
if yielded_cursor:
|
||||
if _fallback["succeeded"] and _caller_exited_cleanly:
|
||||
# Caller's queries succeeded on primary via failover.
|
||||
# This error is transaction.atomic() cleanup on the
|
||||
# dead replica connection — suppress it.
|
||||
return
|
||||
raise
|
||||
# If on primary or max attempts reached, raise
|
||||
if not is_replica or attempt == max_attempts:
|
||||
|
||||
if not can_failover or attempt == max_attempts:
|
||||
raise
|
||||
|
||||
# Retry with exponential backoff
|
||||
@@ -139,6 +245,16 @@ def rls_transaction(
|
||||
)
|
||||
time.sleep(delay)
|
||||
finally:
|
||||
if _fallback["atomic"] is not None:
|
||||
try:
|
||||
_fallback["atomic"].__exit__(*sys.exc_info())
|
||||
except Exception:
|
||||
pass # Best-effort; primary connection may be dead
|
||||
_fallback["atomic"] = None
|
||||
if _fallback["token"] is not None:
|
||||
reset_read_db_alias(_fallback["token"])
|
||||
_fallback["token"] = None
|
||||
|
||||
if router_token is not None:
|
||||
reset_read_db_alias(router_token)
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from unittest.mock import MagicMock, patch
|
||||
@@ -528,7 +529,7 @@ class TestRlsTransaction:
|
||||
def atomic_side_effect(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count < 3:
|
||||
if call_count < 4:
|
||||
raise OperationalError("Connection error")
|
||||
return MagicMock(
|
||||
__enter__=MagicMock(return_value=None),
|
||||
@@ -547,10 +548,11 @@ class TestRlsTransaction:
|
||||
with rls_transaction(tenant_id):
|
||||
pass
|
||||
|
||||
assert mock_sleep.call_count == 2
|
||||
assert mock_sleep.call_count == 3
|
||||
mock_sleep.assert_any_call(0.5)
|
||||
mock_sleep.assert_any_call(1.0)
|
||||
assert mock_logger.info.call_count == 2
|
||||
mock_sleep.assert_any_call(2.0)
|
||||
assert mock_logger.info.call_count == 3
|
||||
|
||||
def test_rls_transaction_operational_error_inside_context_no_retry(
|
||||
self, tenants_fixture, enable_read_replica
|
||||
@@ -582,10 +584,10 @@ class TestRlsTransaction:
|
||||
|
||||
mock_sleep.assert_not_called()
|
||||
|
||||
def test_rls_transaction_max_three_attempts_for_replica(
|
||||
def test_rls_transaction_max_attempts_for_replica(
|
||||
self, tenants_fixture, enable_read_replica
|
||||
):
|
||||
"""Test maximum 3 attempts for replica database."""
|
||||
"""Test REPLICA_MAX_ATTEMPTS replica tries + 1 primary fallback."""
|
||||
tenant = tenants_fixture[0]
|
||||
tenant_id = str(tenant.id)
|
||||
|
||||
@@ -609,7 +611,8 @@ class TestRlsTransaction:
|
||||
with rls_transaction(tenant_id):
|
||||
pass
|
||||
|
||||
assert mock_atomic.call_count == 3
|
||||
# 3 replica + 1 primary = 4 total
|
||||
assert mock_atomic.call_count == 4
|
||||
|
||||
def test_rls_transaction_replica_no_retry_when_disabled(
|
||||
self, tenants_fixture, enable_read_replica
|
||||
@@ -685,7 +688,7 @@ class TestRlsTransaction:
|
||||
def atomic_side_effect(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count < 3:
|
||||
if call_count < 4:
|
||||
raise OperationalError("Replica error")
|
||||
return MagicMock(
|
||||
__enter__=MagicMock(return_value=None),
|
||||
@@ -728,7 +731,7 @@ class TestRlsTransaction:
|
||||
def atomic_side_effect(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count < 3:
|
||||
if call_count < 4:
|
||||
raise OperationalError("Replica error")
|
||||
return MagicMock(
|
||||
__enter__=MagicMock(return_value=None),
|
||||
@@ -747,7 +750,7 @@ class TestRlsTransaction:
|
||||
with rls_transaction(tenant_id):
|
||||
pass
|
||||
|
||||
assert mock_logger.info.call_count == 2
|
||||
assert mock_logger.info.call_count == 3
|
||||
assert mock_logger.warning.call_count == 1
|
||||
|
||||
def test_rls_transaction_operational_error_raised_immediately_on_primary(
|
||||
@@ -913,6 +916,528 @@ class TestRlsTransaction:
|
||||
result = cursor.fetchone()
|
||||
assert result[0] == 1
|
||||
|
||||
# --- Mid-query failover tests ---
|
||||
|
||||
def test_mid_query_failure_retries_on_replica(
|
||||
self, tenants_fixture, enable_read_replica
|
||||
):
|
||||
"""Mid-query OperationalError retries on replica via execute_wrapper."""
|
||||
tenant = tenants_fixture[0]
|
||||
tenant_id = str(tenant.id)
|
||||
|
||||
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
|
||||
with patch("api.db_utils.connections") as mock_connections:
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.execute_wrappers = []
|
||||
|
||||
@contextmanager
|
||||
def _execute_wrapper(fn):
|
||||
mock_conn.execute_wrappers.append(fn)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
mock_conn.execute_wrappers.remove(fn)
|
||||
|
||||
mock_conn.execute_wrapper = _execute_wrapper
|
||||
mock_cursor = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
|
||||
mock_connections.__getitem__.return_value = mock_conn
|
||||
mock_connections.__contains__.return_value = True
|
||||
|
||||
# The raw replica cursor that succeeds on retry
|
||||
mock_raw_cursor = MagicMock()
|
||||
mock_conn.connection.cursor.return_value = mock_raw_cursor
|
||||
|
||||
with patch("api.db_utils.transaction.atomic") as mock_atomic:
|
||||
mock_atomic.return_value.__enter__ = MagicMock(return_value=None)
|
||||
mock_atomic.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with patch("api.db_utils.time.sleep") as mock_sleep:
|
||||
with patch(
|
||||
"api.db_utils.set_read_db_alias", return_value="token"
|
||||
):
|
||||
with patch("api.db_utils.reset_read_db_alias"):
|
||||
with rls_transaction(tenant_id):
|
||||
# Wrapper is installed
|
||||
assert len(mock_conn.execute_wrappers) == 1
|
||||
wrapper = mock_conn.execute_wrappers[0]
|
||||
|
||||
# Simulate a mid-query failure that succeeds on retry
|
||||
mock_execute = MagicMock(
|
||||
side_effect=OperationalError("EOF")
|
||||
)
|
||||
mock_context = {"cursor": MagicMock()}
|
||||
|
||||
wrapper(
|
||||
mock_execute,
|
||||
"SELECT 1",
|
||||
None,
|
||||
False,
|
||||
mock_context,
|
||||
)
|
||||
|
||||
# Verify retry happened with backoff
|
||||
mock_sleep.assert_called_once_with(0.5)
|
||||
mock_conn.ensure_connection.assert_called_once()
|
||||
# The raw cursor was swapped
|
||||
assert (
|
||||
mock_context["cursor"].cursor == mock_raw_cursor
|
||||
)
|
||||
|
||||
# Wrapper removed after exiting
|
||||
assert len(mock_conn.execute_wrappers) == 0
|
||||
|
||||
def test_mid_query_failure_falls_back_to_primary(
|
||||
self, tenants_fixture, enable_read_replica
|
||||
):
|
||||
"""Mid-query failure falls back to primary after replica retries exhausted."""
|
||||
tenant = tenants_fixture[0]
|
||||
tenant_id = str(tenant.id)
|
||||
|
||||
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
|
||||
with patch("api.db_utils.connections") as mock_connections:
|
||||
mock_replica_conn = MagicMock()
|
||||
mock_replica_conn.execute_wrappers = []
|
||||
|
||||
@contextmanager
|
||||
def _execute_wrapper_replica(fn):
|
||||
mock_replica_conn.execute_wrappers.append(fn)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
mock_replica_conn.execute_wrappers.remove(fn)
|
||||
|
||||
mock_replica_conn.execute_wrapper = _execute_wrapper_replica
|
||||
mock_cursor = MagicMock()
|
||||
mock_replica_conn.cursor.return_value.__enter__.return_value = (
|
||||
mock_cursor
|
||||
)
|
||||
# Replica always fails on retry
|
||||
mock_replica_conn.ensure_connection.side_effect = OperationalError(
|
||||
"replica down"
|
||||
)
|
||||
|
||||
mock_primary_conn = MagicMock()
|
||||
mock_primary_raw_cursor = MagicMock()
|
||||
mock_primary_conn.connection.cursor.return_value = (
|
||||
mock_primary_raw_cursor
|
||||
)
|
||||
|
||||
def connections_getitem(alias):
|
||||
if alias == "replica":
|
||||
return mock_replica_conn
|
||||
return mock_primary_conn
|
||||
|
||||
mock_connections.__getitem__.side_effect = connections_getitem
|
||||
mock_connections.__contains__.return_value = True
|
||||
|
||||
with patch("api.db_utils.transaction.atomic") as mock_atomic:
|
||||
mock_atomic.return_value.__enter__ = MagicMock(return_value=None)
|
||||
mock_atomic.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with patch("api.db_utils.time.sleep"):
|
||||
with patch(
|
||||
"api.db_utils.set_read_db_alias", return_value="token"
|
||||
):
|
||||
with patch("api.db_utils.reset_read_db_alias"):
|
||||
with patch("api.db_utils.logger") as mock_logger:
|
||||
with rls_transaction(tenant_id):
|
||||
wrapper = mock_replica_conn.execute_wrappers[0]
|
||||
|
||||
mock_execute = MagicMock(
|
||||
side_effect=OperationalError("EOF")
|
||||
)
|
||||
mock_context = {"cursor": MagicMock()}
|
||||
|
||||
wrapper(
|
||||
mock_execute,
|
||||
"SELECT 1",
|
||||
None,
|
||||
False,
|
||||
mock_context,
|
||||
)
|
||||
|
||||
# Verify primary was used
|
||||
mock_primary_conn.ensure_connection.assert_called_once()
|
||||
assert (
|
||||
mock_context["cursor"].cursor
|
||||
== mock_primary_raw_cursor
|
||||
)
|
||||
|
||||
# Verify warning logged
|
||||
warning_msgs = [
|
||||
c[0][0]
|
||||
for c in mock_logger.warning.call_args_list
|
||||
]
|
||||
assert any(
|
||||
"falling back to primary" in m
|
||||
for m in warning_msgs
|
||||
)
|
||||
|
||||
def test_mid_query_fallback_suppresses_cleanup_error(
|
||||
self, tenants_fixture, enable_read_replica
|
||||
):
|
||||
"""After successful primary fallback, replica cleanup error is suppressed."""
|
||||
tenant = tenants_fixture[0]
|
||||
tenant_id = str(tenant.id)
|
||||
|
||||
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
|
||||
with patch("api.db_utils.connections") as mock_connections:
|
||||
mock_replica_conn = MagicMock()
|
||||
mock_replica_conn.execute_wrappers = []
|
||||
|
||||
@contextmanager
|
||||
def _execute_wrapper_replica(fn):
|
||||
mock_replica_conn.execute_wrappers.append(fn)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
mock_replica_conn.execute_wrappers.remove(fn)
|
||||
|
||||
mock_replica_conn.execute_wrapper = _execute_wrapper_replica
|
||||
mock_cursor = MagicMock()
|
||||
mock_replica_conn.cursor.return_value.__enter__.return_value = (
|
||||
mock_cursor
|
||||
)
|
||||
mock_replica_conn.ensure_connection.side_effect = OperationalError(
|
||||
"replica down"
|
||||
)
|
||||
|
||||
mock_primary_conn = MagicMock()
|
||||
mock_primary_raw_cursor = MagicMock()
|
||||
mock_primary_conn.connection.cursor.return_value = (
|
||||
mock_primary_raw_cursor
|
||||
)
|
||||
|
||||
def connections_getitem(alias):
|
||||
if alias == "replica":
|
||||
return mock_replica_conn
|
||||
return mock_primary_conn
|
||||
|
||||
mock_connections.__getitem__.side_effect = connections_getitem
|
||||
mock_connections.__contains__.return_value = True
|
||||
|
||||
# transaction.atomic().__exit__ raises on dead replica
|
||||
mock_atomic_cm = MagicMock()
|
||||
mock_atomic_cm.__enter__ = MagicMock(return_value=None)
|
||||
mock_atomic_cm.__exit__ = MagicMock(
|
||||
side_effect=OperationalError("cleanup failed on dead replica")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"api.db_utils.transaction.atomic", return_value=mock_atomic_cm
|
||||
):
|
||||
with patch("api.db_utils.time.sleep"):
|
||||
with patch(
|
||||
"api.db_utils.set_read_db_alias", return_value="token"
|
||||
):
|
||||
with patch("api.db_utils.reset_read_db_alias"):
|
||||
# Should NOT raise — the cleanup error is suppressed
|
||||
with rls_transaction(tenant_id):
|
||||
wrapper = mock_replica_conn.execute_wrappers[0]
|
||||
mock_execute = MagicMock(
|
||||
side_effect=OperationalError("EOF")
|
||||
)
|
||||
mock_context = {"cursor": MagicMock()}
|
||||
wrapper(
|
||||
mock_execute,
|
||||
"SELECT 1",
|
||||
None,
|
||||
False,
|
||||
mock_context,
|
||||
)
|
||||
|
||||
def test_mid_query_primary_also_fails(self, tenants_fixture, enable_read_replica):
|
||||
"""When both replica and primary fail, OperationalError propagates."""
|
||||
tenant = tenants_fixture[0]
|
||||
tenant_id = str(tenant.id)
|
||||
|
||||
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
|
||||
with patch("api.db_utils.connections") as mock_connections:
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.execute_wrappers = []
|
||||
|
||||
@contextmanager
|
||||
def _execute_wrapper(fn):
|
||||
mock_conn.execute_wrappers.append(fn)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
mock_conn.execute_wrappers.remove(fn)
|
||||
|
||||
mock_conn.execute_wrapper = _execute_wrapper
|
||||
mock_cursor = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
|
||||
# Both replica retry and primary fail
|
||||
mock_conn.ensure_connection.side_effect = OperationalError("down")
|
||||
mock_conn.connection.cursor.side_effect = OperationalError("down")
|
||||
|
||||
mock_connections.__getitem__.return_value = mock_conn
|
||||
mock_connections.__contains__.return_value = True
|
||||
|
||||
with patch("api.db_utils.transaction.atomic") as mock_atomic:
|
||||
mock_atomic.return_value.__enter__ = MagicMock(return_value=None)
|
||||
mock_atomic.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with patch("api.db_utils.time.sleep"):
|
||||
with patch(
|
||||
"api.db_utils.set_read_db_alias", return_value="token"
|
||||
):
|
||||
with patch("api.db_utils.reset_read_db_alias"):
|
||||
with rls_transaction(tenant_id):
|
||||
wrapper = mock_conn.execute_wrappers[0]
|
||||
mock_execute = MagicMock(
|
||||
side_effect=OperationalError("EOF")
|
||||
)
|
||||
mock_context = {"cursor": MagicMock()}
|
||||
|
||||
with pytest.raises(OperationalError):
|
||||
wrapper(
|
||||
mock_execute,
|
||||
"SELECT 1",
|
||||
None,
|
||||
False,
|
||||
mock_context,
|
||||
)
|
||||
|
||||
def test_wrapper_not_installed_on_primary(self, tenants_fixture):
|
||||
"""execute_wrapper is not installed when targeting primary DB."""
|
||||
tenant = tenants_fixture[0]
|
||||
tenant_id = str(tenant.id)
|
||||
|
||||
with patch("api.db_utils.get_read_db_alias", return_value=None):
|
||||
with patch("api.db_utils.connections") as mock_connections:
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.execute_wrappers = []
|
||||
mock_cursor = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
|
||||
mock_connections.__getitem__.return_value = mock_conn
|
||||
mock_connections.__contains__.return_value = True
|
||||
|
||||
with patch("api.db_utils.transaction.atomic") as mock_atomic:
|
||||
mock_atomic.return_value.__enter__ = MagicMock(return_value=None)
|
||||
mock_atomic.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with rls_transaction(tenant_id):
|
||||
# No wrapper installed on primary
|
||||
assert len(mock_conn.execute_wrappers) == 0
|
||||
|
||||
def test_stale_connection_closed_on_pre_yield_retry(
|
||||
self, tenants_fixture, enable_read_replica
|
||||
):
|
||||
"""Stale connection is closed before each pre-yield retry."""
|
||||
tenant = tenants_fixture[0]
|
||||
tenant_id = str(tenant.id)
|
||||
|
||||
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
|
||||
with patch("api.db_utils.connections") as mock_connections:
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.execute_wrappers = []
|
||||
mock_cursor = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
|
||||
mock_connections.__getitem__.return_value = mock_conn
|
||||
mock_connections.__contains__.return_value = True
|
||||
|
||||
call_count = 0
|
||||
|
||||
def atomic_side_effect(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count < 3:
|
||||
raise OperationalError("Connection error")
|
||||
return MagicMock(
|
||||
__enter__=MagicMock(return_value=None),
|
||||
__exit__=MagicMock(return_value=False),
|
||||
)
|
||||
|
||||
with patch(
|
||||
"api.db_utils.transaction.atomic", side_effect=atomic_side_effect
|
||||
):
|
||||
with patch("api.db_utils.time.sleep"):
|
||||
with patch(
|
||||
"api.db_utils.set_read_db_alias", return_value="token"
|
||||
):
|
||||
with patch("api.db_utils.reset_read_db_alias"):
|
||||
with rls_transaction(tenant_id):
|
||||
pass
|
||||
|
||||
# close() called for each failed pre-yield attempt
|
||||
assert mock_conn.close.call_count == 2
|
||||
|
||||
def test_caller_error_propagates_after_successful_failover(
|
||||
self, tenants_fixture, enable_read_replica
|
||||
):
|
||||
"""OperationalError raised by caller after failover is NOT suppressed."""
|
||||
tenant = tenants_fixture[0]
|
||||
tenant_id = str(tenant.id)
|
||||
|
||||
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
|
||||
with patch("api.db_utils.connections") as mock_connections:
|
||||
mock_replica_conn = MagicMock()
|
||||
mock_replica_conn.execute_wrappers = []
|
||||
|
||||
@contextmanager
|
||||
def _execute_wrapper_replica(fn):
|
||||
mock_replica_conn.execute_wrappers.append(fn)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
mock_replica_conn.execute_wrappers.remove(fn)
|
||||
|
||||
mock_replica_conn.execute_wrapper = _execute_wrapper_replica
|
||||
mock_cursor = MagicMock()
|
||||
mock_replica_conn.cursor.return_value.__enter__.return_value = (
|
||||
mock_cursor
|
||||
)
|
||||
mock_replica_conn.ensure_connection.side_effect = OperationalError(
|
||||
"replica down"
|
||||
)
|
||||
|
||||
mock_primary_conn = MagicMock()
|
||||
mock_primary_raw_cursor = MagicMock()
|
||||
mock_primary_conn.connection.cursor.return_value = (
|
||||
mock_primary_raw_cursor
|
||||
)
|
||||
|
||||
def connections_getitem(alias):
|
||||
if alias == "replica":
|
||||
return mock_replica_conn
|
||||
return mock_primary_conn
|
||||
|
||||
mock_connections.__getitem__.side_effect = connections_getitem
|
||||
mock_connections.__contains__.return_value = True
|
||||
|
||||
# transaction.atomic().__exit__ raises on dead replica cleanup
|
||||
mock_atomic_cm = MagicMock()
|
||||
mock_atomic_cm.__enter__ = MagicMock(return_value=None)
|
||||
mock_atomic_cm.__exit__ = MagicMock(
|
||||
side_effect=OperationalError("cleanup on dead replica")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"api.db_utils.transaction.atomic", return_value=mock_atomic_cm
|
||||
):
|
||||
with patch("api.db_utils.time.sleep"):
|
||||
with patch(
|
||||
"api.db_utils.set_read_db_alias", return_value="token"
|
||||
):
|
||||
with patch("api.db_utils.reset_read_db_alias"):
|
||||
with pytest.raises(OperationalError):
|
||||
with rls_transaction(tenant_id):
|
||||
# Trigger failover (succeeds on primary)
|
||||
wrapper = mock_replica_conn.execute_wrappers[0]
|
||||
mock_execute = MagicMock(
|
||||
side_effect=OperationalError("EOF")
|
||||
)
|
||||
mock_context = {"cursor": MagicMock()}
|
||||
wrapper(
|
||||
mock_execute,
|
||||
"SELECT 1",
|
||||
None,
|
||||
False,
|
||||
mock_context,
|
||||
)
|
||||
# Caller raises after successful failover —
|
||||
# must NOT be suppressed
|
||||
raise OperationalError("caller error")
|
||||
|
||||
def test_fallback_atomic_receives_exception_info(
|
||||
self, tenants_fixture, enable_read_replica
|
||||
):
|
||||
"""Primary transaction rollback uses sys.exc_info(), not (None, None, None)."""
|
||||
tenant = tenants_fixture[0]
|
||||
tenant_id = str(tenant.id)
|
||||
|
||||
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
|
||||
with patch("api.db_utils.connections") as mock_connections:
|
||||
mock_replica_conn = MagicMock()
|
||||
mock_replica_conn.execute_wrappers = []
|
||||
|
||||
@contextmanager
|
||||
def _execute_wrapper_replica(fn):
|
||||
mock_replica_conn.execute_wrappers.append(fn)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
mock_replica_conn.execute_wrappers.remove(fn)
|
||||
|
||||
mock_replica_conn.execute_wrapper = _execute_wrapper_replica
|
||||
mock_cursor = MagicMock()
|
||||
mock_replica_conn.cursor.return_value.__enter__.return_value = (
|
||||
mock_cursor
|
||||
)
|
||||
mock_replica_conn.ensure_connection.side_effect = OperationalError(
|
||||
"replica down"
|
||||
)
|
||||
|
||||
mock_primary_conn = MagicMock()
|
||||
mock_primary_raw_cursor = MagicMock()
|
||||
mock_primary_conn.connection.cursor.return_value = (
|
||||
mock_primary_raw_cursor
|
||||
)
|
||||
|
||||
def connections_getitem(alias):
|
||||
if alias == "replica":
|
||||
return mock_replica_conn
|
||||
return mock_primary_conn
|
||||
|
||||
mock_connections.__getitem__.side_effect = connections_getitem
|
||||
mock_connections.__contains__.return_value = True
|
||||
|
||||
mock_fallback_atomic = MagicMock()
|
||||
mock_fallback_atomic.__enter__ = MagicMock(return_value=None)
|
||||
mock_fallback_atomic.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
# Outer atomic raises on dead replica cleanup
|
||||
mock_outer_atomic = MagicMock()
|
||||
mock_outer_atomic.__enter__ = MagicMock(return_value=None)
|
||||
mock_outer_atomic.__exit__ = MagicMock(
|
||||
side_effect=OperationalError("cleanup on dead replica")
|
||||
)
|
||||
|
||||
atomic_call_count = 0
|
||||
|
||||
def atomic_side_effect(*args, **kwargs):
|
||||
nonlocal atomic_call_count
|
||||
atomic_call_count += 1
|
||||
# First call is the outer transaction.atomic(using=replica)
|
||||
if atomic_call_count == 1:
|
||||
return mock_outer_atomic
|
||||
# Second call is _fallback["atomic"] inside the wrapper
|
||||
return mock_fallback_atomic
|
||||
|
||||
with patch(
|
||||
"api.db_utils.transaction.atomic",
|
||||
side_effect=atomic_side_effect,
|
||||
):
|
||||
with patch("api.db_utils.time.sleep"):
|
||||
with patch(
|
||||
"api.db_utils.set_read_db_alias", return_value="token"
|
||||
):
|
||||
with patch("api.db_utils.reset_read_db_alias"):
|
||||
with pytest.raises(OperationalError):
|
||||
with rls_transaction(tenant_id):
|
||||
wrapper = mock_replica_conn.execute_wrappers[0]
|
||||
mock_execute = MagicMock(
|
||||
side_effect=OperationalError("EOF")
|
||||
)
|
||||
mock_context = {"cursor": MagicMock()}
|
||||
wrapper(
|
||||
mock_execute,
|
||||
"SELECT 1",
|
||||
None,
|
||||
False,
|
||||
mock_context,
|
||||
)
|
||||
raise OperationalError("caller error")
|
||||
|
||||
# Verify __exit__ was called with exception info,
|
||||
# not (None, None, None)
|
||||
exit_args = mock_fallback_atomic.__exit__.call_args
|
||||
exc_type = exit_args[0][0]
|
||||
assert exc_type is not None
|
||||
assert issubclass(exc_type, OperationalError)
|
||||
|
||||
|
||||
class TestPostgresEnumMigration:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user