mirror of
https://github.com/prowler-cloud/prowler.git
synced 2026-03-22 03:08:23 +00:00
fix(api): address review findings in rls_transaction failover logic
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
import re
|
||||
import secrets
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from celery.utils.log import get_task_logger
|
||||
@@ -131,7 +132,7 @@ def rls_transaction(
|
||||
try:
|
||||
connections[replica_alias].close()
|
||||
except Exception:
|
||||
pass
|
||||
pass # Best-effort; connection may already be dead
|
||||
|
||||
delay = REPLICA_RETRY_BASE_DELAY * (2 ** (retry - 1))
|
||||
logger.info(
|
||||
@@ -160,7 +161,7 @@ def rls_transaction(
|
||||
try:
|
||||
connections[replica_alias].close()
|
||||
except Exception:
|
||||
pass
|
||||
pass # Best-effort; connection may already be dead
|
||||
|
||||
logger.warning(
|
||||
"Mid-query replica retries exhausted, falling back to primary DB"
|
||||
@@ -185,7 +186,7 @@ def rls_transaction(
|
||||
for attempt in range(1, max_attempts + 1):
|
||||
router_token = None
|
||||
yielded_cursor = False
|
||||
wrapper_installed = False
|
||||
_caller_exited_cleanly = False
|
||||
|
||||
# On final attempt, fall back to primary
|
||||
if attempt == max_attempts and can_failover:
|
||||
@@ -209,21 +210,24 @@ def rls_transaction(
|
||||
raise ValidationError("Must be a valid UUID")
|
||||
cursor.execute(SET_CONFIG_QUERY, [parameter, value])
|
||||
|
||||
if can_failover and alias == replica_alias:
|
||||
conn.execute_wrappers.append(_query_failover)
|
||||
wrapper_installed = True
|
||||
|
||||
yielded_cursor = True
|
||||
yield cursor
|
||||
wrapper_cm = (
|
||||
conn.execute_wrapper(_query_failover)
|
||||
if can_failover and alias == replica_alias
|
||||
else nullcontext()
|
||||
)
|
||||
with wrapper_cm:
|
||||
yielded_cursor = True
|
||||
yield cursor
|
||||
_caller_exited_cleanly = True
|
||||
return
|
||||
except OperationalError as e:
|
||||
try:
|
||||
connections[alias].close()
|
||||
except Exception:
|
||||
pass
|
||||
pass # Best-effort; connection may already be dead
|
||||
|
||||
if yielded_cursor:
|
||||
if _fallback["succeeded"]:
|
||||
if _fallback["succeeded"] and _caller_exited_cleanly:
|
||||
# Caller's queries succeeded on primary via failover.
|
||||
# This error is transaction.atomic() cleanup on the
|
||||
# dead replica connection — suppress it.
|
||||
@@ -241,17 +245,11 @@ def rls_transaction(
|
||||
)
|
||||
time.sleep(delay)
|
||||
finally:
|
||||
if wrapper_installed:
|
||||
try:
|
||||
conn.execute_wrappers.remove(_query_failover)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if _fallback["atomic"] is not None:
|
||||
try:
|
||||
_fallback["atomic"].__exit__(None, None, None)
|
||||
_fallback["atomic"].__exit__(*sys.exc_info())
|
||||
except Exception:
|
||||
pass
|
||||
pass # Best-effort; primary connection may be dead
|
||||
_fallback["atomic"] = None
|
||||
if _fallback["token"] is not None:
|
||||
reset_read_db_alias(_fallback["token"])
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from unittest.mock import MagicMock, patch
|
||||
@@ -928,6 +929,16 @@ class TestRlsTransaction:
|
||||
with patch("api.db_utils.connections") as mock_connections:
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.execute_wrappers = []
|
||||
|
||||
@contextmanager
|
||||
def _execute_wrapper(fn):
|
||||
mock_conn.execute_wrappers.append(fn)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
mock_conn.execute_wrappers.remove(fn)
|
||||
|
||||
mock_conn.execute_wrapper = _execute_wrapper
|
||||
mock_cursor = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
|
||||
mock_connections.__getitem__.return_value = mock_conn
|
||||
@@ -987,6 +998,16 @@ class TestRlsTransaction:
|
||||
with patch("api.db_utils.connections") as mock_connections:
|
||||
mock_replica_conn = MagicMock()
|
||||
mock_replica_conn.execute_wrappers = []
|
||||
|
||||
@contextmanager
|
||||
def _execute_wrapper_replica(fn):
|
||||
mock_replica_conn.execute_wrappers.append(fn)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
mock_replica_conn.execute_wrappers.remove(fn)
|
||||
|
||||
mock_replica_conn.execute_wrapper = _execute_wrapper_replica
|
||||
mock_cursor = MagicMock()
|
||||
mock_replica_conn.cursor.return_value.__enter__.return_value = (
|
||||
mock_cursor
|
||||
@@ -1064,6 +1085,16 @@ class TestRlsTransaction:
|
||||
with patch("api.db_utils.connections") as mock_connections:
|
||||
mock_replica_conn = MagicMock()
|
||||
mock_replica_conn.execute_wrappers = []
|
||||
|
||||
@contextmanager
|
||||
def _execute_wrapper_replica(fn):
|
||||
mock_replica_conn.execute_wrappers.append(fn)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
mock_replica_conn.execute_wrappers.remove(fn)
|
||||
|
||||
mock_replica_conn.execute_wrapper = _execute_wrapper_replica
|
||||
mock_cursor = MagicMock()
|
||||
mock_replica_conn.cursor.return_value.__enter__.return_value = (
|
||||
mock_cursor
|
||||
@@ -1125,6 +1156,16 @@ class TestRlsTransaction:
|
||||
with patch("api.db_utils.connections") as mock_connections:
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.execute_wrappers = []
|
||||
|
||||
@contextmanager
|
||||
def _execute_wrapper(fn):
|
||||
mock_conn.execute_wrappers.append(fn)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
mock_conn.execute_wrappers.remove(fn)
|
||||
|
||||
mock_conn.execute_wrapper = _execute_wrapper
|
||||
mock_cursor = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
|
||||
# Both replica retry and primary fail
|
||||
@@ -1223,6 +1264,180 @@ class TestRlsTransaction:
|
||||
# close() called for each failed pre-yield attempt
|
||||
assert mock_conn.close.call_count == 2
|
||||
|
||||
def test_caller_error_propagates_after_successful_failover(
|
||||
self, tenants_fixture, enable_read_replica
|
||||
):
|
||||
"""OperationalError raised by caller after failover is NOT suppressed."""
|
||||
tenant = tenants_fixture[0]
|
||||
tenant_id = str(tenant.id)
|
||||
|
||||
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
|
||||
with patch("api.db_utils.connections") as mock_connections:
|
||||
mock_replica_conn = MagicMock()
|
||||
mock_replica_conn.execute_wrappers = []
|
||||
|
||||
@contextmanager
|
||||
def _execute_wrapper_replica(fn):
|
||||
mock_replica_conn.execute_wrappers.append(fn)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
mock_replica_conn.execute_wrappers.remove(fn)
|
||||
|
||||
mock_replica_conn.execute_wrapper = _execute_wrapper_replica
|
||||
mock_cursor = MagicMock()
|
||||
mock_replica_conn.cursor.return_value.__enter__.return_value = (
|
||||
mock_cursor
|
||||
)
|
||||
mock_replica_conn.ensure_connection.side_effect = OperationalError(
|
||||
"replica down"
|
||||
)
|
||||
|
||||
mock_primary_conn = MagicMock()
|
||||
mock_primary_raw_cursor = MagicMock()
|
||||
mock_primary_conn.connection.cursor.return_value = (
|
||||
mock_primary_raw_cursor
|
||||
)
|
||||
|
||||
def connections_getitem(alias):
|
||||
if alias == "replica":
|
||||
return mock_replica_conn
|
||||
return mock_primary_conn
|
||||
|
||||
mock_connections.__getitem__.side_effect = connections_getitem
|
||||
mock_connections.__contains__.return_value = True
|
||||
|
||||
# transaction.atomic().__exit__ raises on dead replica cleanup
|
||||
mock_atomic_cm = MagicMock()
|
||||
mock_atomic_cm.__enter__ = MagicMock(return_value=None)
|
||||
mock_atomic_cm.__exit__ = MagicMock(
|
||||
side_effect=OperationalError("cleanup on dead replica")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"api.db_utils.transaction.atomic", return_value=mock_atomic_cm
|
||||
):
|
||||
with patch("api.db_utils.time.sleep"):
|
||||
with patch(
|
||||
"api.db_utils.set_read_db_alias", return_value="token"
|
||||
):
|
||||
with patch("api.db_utils.reset_read_db_alias"):
|
||||
with pytest.raises(OperationalError):
|
||||
with rls_transaction(tenant_id):
|
||||
# Trigger failover (succeeds on primary)
|
||||
wrapper = mock_replica_conn.execute_wrappers[0]
|
||||
mock_execute = MagicMock(
|
||||
side_effect=OperationalError("EOF")
|
||||
)
|
||||
mock_context = {"cursor": MagicMock()}
|
||||
wrapper(
|
||||
mock_execute,
|
||||
"SELECT 1",
|
||||
None,
|
||||
False,
|
||||
mock_context,
|
||||
)
|
||||
# Caller raises after successful failover —
|
||||
# must NOT be suppressed
|
||||
raise OperationalError("caller error")
|
||||
|
||||
def test_fallback_atomic_receives_exception_info(
|
||||
self, tenants_fixture, enable_read_replica
|
||||
):
|
||||
"""Primary transaction rollback uses sys.exc_info(), not (None, None, None)."""
|
||||
tenant = tenants_fixture[0]
|
||||
tenant_id = str(tenant.id)
|
||||
|
||||
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
|
||||
with patch("api.db_utils.connections") as mock_connections:
|
||||
mock_replica_conn = MagicMock()
|
||||
mock_replica_conn.execute_wrappers = []
|
||||
|
||||
@contextmanager
|
||||
def _execute_wrapper_replica(fn):
|
||||
mock_replica_conn.execute_wrappers.append(fn)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
mock_replica_conn.execute_wrappers.remove(fn)
|
||||
|
||||
mock_replica_conn.execute_wrapper = _execute_wrapper_replica
|
||||
mock_cursor = MagicMock()
|
||||
mock_replica_conn.cursor.return_value.__enter__.return_value = (
|
||||
mock_cursor
|
||||
)
|
||||
mock_replica_conn.ensure_connection.side_effect = OperationalError(
|
||||
"replica down"
|
||||
)
|
||||
|
||||
mock_primary_conn = MagicMock()
|
||||
mock_primary_raw_cursor = MagicMock()
|
||||
mock_primary_conn.connection.cursor.return_value = (
|
||||
mock_primary_raw_cursor
|
||||
)
|
||||
|
||||
def connections_getitem(alias):
|
||||
if alias == "replica":
|
||||
return mock_replica_conn
|
||||
return mock_primary_conn
|
||||
|
||||
mock_connections.__getitem__.side_effect = connections_getitem
|
||||
mock_connections.__contains__.return_value = True
|
||||
|
||||
mock_fallback_atomic = MagicMock()
|
||||
mock_fallback_atomic.__enter__ = MagicMock(return_value=None)
|
||||
mock_fallback_atomic.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
# Outer atomic raises on dead replica cleanup
|
||||
mock_outer_atomic = MagicMock()
|
||||
mock_outer_atomic.__enter__ = MagicMock(return_value=None)
|
||||
mock_outer_atomic.__exit__ = MagicMock(
|
||||
side_effect=OperationalError("cleanup on dead replica")
|
||||
)
|
||||
|
||||
atomic_call_count = 0
|
||||
|
||||
def atomic_side_effect(*args, **kwargs):
|
||||
nonlocal atomic_call_count
|
||||
atomic_call_count += 1
|
||||
# First call is the outer transaction.atomic(using=replica)
|
||||
if atomic_call_count == 1:
|
||||
return mock_outer_atomic
|
||||
# Second call is _fallback["atomic"] inside the wrapper
|
||||
return mock_fallback_atomic
|
||||
|
||||
with patch(
|
||||
"api.db_utils.transaction.atomic",
|
||||
side_effect=atomic_side_effect,
|
||||
):
|
||||
with patch("api.db_utils.time.sleep"):
|
||||
with patch(
|
||||
"api.db_utils.set_read_db_alias", return_value="token"
|
||||
):
|
||||
with patch("api.db_utils.reset_read_db_alias"):
|
||||
with pytest.raises(OperationalError):
|
||||
with rls_transaction(tenant_id):
|
||||
wrapper = mock_replica_conn.execute_wrappers[0]
|
||||
mock_execute = MagicMock(
|
||||
side_effect=OperationalError("EOF")
|
||||
)
|
||||
mock_context = {"cursor": MagicMock()}
|
||||
wrapper(
|
||||
mock_execute,
|
||||
"SELECT 1",
|
||||
None,
|
||||
False,
|
||||
mock_context,
|
||||
)
|
||||
raise OperationalError("caller error")
|
||||
|
||||
# Verify __exit__ was called with exception info,
|
||||
# not (None, None, None)
|
||||
exit_args = mock_fallback_atomic.__exit__.call_args
|
||||
exc_type = exit_args[0][0]
|
||||
assert exc_type is not None
|
||||
assert issubclass(exc_type, OperationalError)
|
||||
|
||||
|
||||
class TestPostgresEnumMigration:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user