feat(rls-transaction): add retry for read replica connections (#9064)

This commit is contained in:
Víctor Fernández Poyatos
2025-10-31 11:09:05 +01:00
committed by GitHub
parent 5d4415d090
commit c5e88f4a74
5 changed files with 622 additions and 20 deletions

2
.env
View File

@@ -35,6 +35,8 @@ POSTGRES_DB=prowler_db
# POSTGRES_REPLICA_USER=prowler
# POSTGRES_REPLICA_PASSWORD=postgres
# POSTGRES_REPLICA_DB=prowler_db
# POSTGRES_REPLICA_MAX_ATTEMPTS=3
# POSTGRES_REPLICA_RETRY_BASE_DELAY=0.5
# Celery-Prowler task settings
TASK_RETRY_DELAY_SECONDS=0.1

View File

@@ -13,12 +13,17 @@ All notable changes to the **Prowler API** are documented in this file.
- Support muting findings based on simple rules with custom reason [(#9051)](https://github.com/prowler-cloud/prowler/pull/9051)
- Support C5 compliance framework for the GCP provider [(#9097)](https://github.com/prowler-cloud/prowler/pull/9097)
---
## [1.14.1] (Prowler 5.13.1)
### Fixed
- `/api/v1/overviews/providers` collapses data by provider type so the UI receives a single aggregated record per cloud family even when multiple accounts exist [(#9053)](https://github.com/prowler-cloud/prowler/pull/9053)
- Added retry logic to database transactions to handle Aurora read replica connection failures during scale-down events [(#9064)](https://github.com/prowler-cloud/prowler/pull/9064)
- Security Hub integrations stop failing when they read relationships via the replica by allowing replica relations and saving updates through the primary [(#9080)](https://github.com/prowler-cloud/prowler/pull/9080)
---
## [1.14.0] (Prowler 5.13.0)
### Added

View File

@@ -1,18 +1,35 @@
import re
import secrets
import time
import uuid
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
from celery.utils.log import get_task_logger
from config.env import env
from django.conf import settings
from django.contrib.auth.models import BaseUserManager
from django.db import DEFAULT_DB_ALIAS, connection, connections, models, transaction
from django.db import (
DEFAULT_DB_ALIAS,
OperationalError,
connection,
connections,
models,
transaction,
)
from django_celery_beat.models import PeriodicTask
from psycopg2 import connect as psycopg2_connect
from psycopg2.extensions import AsIs, new_type, register_adapter, register_type
from rest_framework_json_api.serializers import ValidationError
from api.db_router import get_read_db_alias, reset_read_db_alias, set_read_db_alias
from api.db_router import (
READ_REPLICA_ALIAS,
get_read_db_alias,
reset_read_db_alias,
set_read_db_alias,
)
logger = get_task_logger(__name__)
DB_USER = settings.DATABASES["default"]["USER"] if not settings.TESTING else "test"
DB_PASSWORD = (
@@ -28,6 +45,9 @@ TASK_RUNNER_DB_TABLE = "django_celery_results_taskresult"
POSTGRES_TENANT_VAR = "api.tenant_id"
POSTGRES_USER_VAR = "api.user_id"
REPLICA_MAX_ATTEMPTS = env.int("POSTGRES_REPLICA_MAX_ATTEMPTS", default=3)
REPLICA_RETRY_BASE_DELAY = env.float("POSTGRES_REPLICA_RETRY_BASE_DELAY", default=0.5)
SET_CONFIG_QUERY = "SELECT set_config(%s, %s::text, TRUE);"
@@ -71,24 +91,51 @@ def rls_transaction(
if db_alias not in connections:
db_alias = DEFAULT_DB_ALIAS
router_token = None
try:
if db_alias != DEFAULT_DB_ALIAS:
router_token = set_read_db_alias(db_alias)
alias = db_alias
is_replica = READ_REPLICA_ALIAS and alias == READ_REPLICA_ALIAS
max_attempts = REPLICA_MAX_ATTEMPTS if is_replica else 1
with transaction.atomic(using=db_alias):
conn = connections[db_alias]
with conn.cursor() as cursor:
try:
# just in case the value is a UUID object
uuid.UUID(str(value))
except ValueError:
raise ValidationError("Must be a valid UUID")
cursor.execute(SET_CONFIG_QUERY, [parameter, value])
yield cursor
finally:
if router_token is not None:
reset_read_db_alias(router_token)
for attempt in range(1, max_attempts + 1):
router_token = None
# On final attempt, fallback to primary
if attempt == max_attempts and is_replica:
logger.warning(
f"RLS transaction failed after {attempt - 1} attempts on replica, "
f"falling back to primary DB"
)
alias = DEFAULT_DB_ALIAS
conn = connections[alias]
try:
if alias != DEFAULT_DB_ALIAS:
router_token = set_read_db_alias(alias)
with transaction.atomic(using=alias):
with conn.cursor() as cursor:
try:
# just in case the value is a UUID object
uuid.UUID(str(value))
except ValueError:
raise ValidationError("Must be a valid UUID")
cursor.execute(SET_CONFIG_QUERY, [parameter, value])
yield cursor
return
except OperationalError as e:
# If on primary or max attempts reached, raise
if not is_replica or attempt == max_attempts:
raise
# Retry with exponential backoff
delay = REPLICA_RETRY_BASE_DELAY * (2 ** (attempt - 1))
logger.info(
f"RLS transaction failed on replica (attempt {attempt}/{max_attempts}), "
f"retrying in {delay}s. Error: {e}"
)
time.sleep(delay)
finally:
if router_token is not None:
reset_read_db_alias(router_token)
class CustomUserManager(BaseUserManager):

View File

@@ -0,0 +1,39 @@
"""Tests for rls_transaction retry and fallback logic."""
import pytest
from django.db import DEFAULT_DB_ALIAS
from rest_framework_json_api.serializers import ValidationError
from api.db_utils import rls_transaction
@pytest.mark.django_db
class TestRLSTransaction:
"""Simple integration tests for rls_transaction using real DB."""
@pytest.fixture
def tenant(self, tenants_fixture):
return tenants_fixture[0]
def test_success_on_primary(self, tenant):
"""Basic: transaction succeeds on primary database."""
with rls_transaction(str(tenant.id), using=DEFAULT_DB_ALIAS) as cursor:
cursor.execute("SELECT 1")
result = cursor.fetchone()
assert result == (1,)
def test_invalid_uuid_raises_validation_error(self):
"""Invalid UUID raises ValidationError before DB operations."""
with pytest.raises(ValidationError, match="Must be a valid UUID"):
with rls_transaction("not-a-uuid", using=DEFAULT_DB_ALIAS):
pass
def test_custom_parameter_name(self, tenant):
"""Test custom RLS parameter name."""
custom_param = "api.custom_id"
with rls_transaction(
str(tenant.id), parameter=custom_param, using=DEFAULT_DB_ALIAS
) as cursor:
cursor.execute("SELECT current_setting(%s, true)", [custom_param])
result = cursor.fetchone()
assert result == (str(tenant.id),)

View File

@@ -1,12 +1,15 @@
from datetime import datetime, timezone
from enum import Enum
from unittest.mock import patch
from unittest.mock import MagicMock, patch
import pytest
from django.conf import settings
from django.db import DEFAULT_DB_ALIAS, OperationalError
from freezegun import freeze_time
from rest_framework_json_api.serializers import ValidationError
from api.db_utils import (
POSTGRES_TENANT_VAR,
_should_create_index_on_partition,
batch_delete,
create_objects_in_batches,
@@ -14,11 +17,22 @@ from api.db_utils import (
generate_api_key_prefix,
generate_random_token,
one_week_from_now,
rls_transaction,
update_objects_in_batches,
)
from api.models import Provider
@pytest.fixture
def enable_read_replica():
"""
Fixture to enable READ_REPLICA_ALIAS for tests that need replica functionality.
This avoids polluting the global test configuration.
"""
with patch("api.db_utils.READ_REPLICA_ALIAS", "replica"):
yield "replica"
class TestEnumToChoices:
def test_enum_to_choices_simple(self):
class Color(Enum):
@@ -339,3 +353,498 @@ class TestGenerateApiKeyPrefix:
prefix = generate_api_key_prefix()
random_part = prefix[3:] # Strip 'pk_'
assert all(char in allowed_chars for char in random_part)
@pytest.mark.django_db
class TestRlsTransaction:
def test_rls_transaction_valid_uuid_string(self, tenants_fixture):
"""Test rls_transaction with valid UUID string."""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
with rls_transaction(tenant_id) as cursor:
assert cursor is not None
cursor.execute("SELECT current_setting(%s)", [POSTGRES_TENANT_VAR])
result = cursor.fetchone()
assert result[0] == tenant_id
def test_rls_transaction_valid_uuid_object(self, tenants_fixture):
"""Test rls_transaction with UUID object."""
tenant = tenants_fixture[0]
with rls_transaction(tenant.id) as cursor:
assert cursor is not None
cursor.execute("SELECT current_setting(%s)", [POSTGRES_TENANT_VAR])
result = cursor.fetchone()
assert result[0] == str(tenant.id)
def test_rls_transaction_invalid_uuid_raises_validation_error(self):
"""Test rls_transaction raises ValidationError for invalid UUID."""
invalid_uuid = "not-a-valid-uuid"
with pytest.raises(ValidationError, match="Must be a valid UUID"):
with rls_transaction(invalid_uuid):
pass
def test_rls_transaction_uses_default_database_when_no_alias(self, tenants_fixture):
"""Test rls_transaction uses DEFAULT_DB_ALIAS when no alias specified."""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
with patch("api.db_utils.get_read_db_alias", return_value=None):
with patch("api.db_utils.connections") as mock_connections:
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
mock_connections.__getitem__.return_value = mock_conn
mock_connections.__contains__.return_value = True
with patch("api.db_utils.transaction.atomic"):
with rls_transaction(tenant_id):
pass
mock_connections.__getitem__.assert_called_with(DEFAULT_DB_ALIAS)
def test_rls_transaction_uses_specified_alias(self, tenants_fixture):
"""Test rls_transaction uses specified database alias via using parameter."""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
custom_alias = "custom_db"
with patch("api.db_utils.connections") as mock_connections:
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
mock_connections.__getitem__.return_value = mock_conn
mock_connections.__contains__.return_value = True
with patch("api.db_utils.transaction.atomic"):
with patch("api.db_utils.set_read_db_alias") as mock_set_alias:
with patch("api.db_utils.reset_read_db_alias") as mock_reset_alias:
mock_set_alias.return_value = "test_token"
with rls_transaction(tenant_id, using=custom_alias):
pass
mock_connections.__getitem__.assert_called_with(custom_alias)
mock_set_alias.assert_called_once_with(custom_alias)
mock_reset_alias.assert_called_once_with("test_token")
def test_rls_transaction_uses_read_replica_from_router(
self, tenants_fixture, enable_read_replica
):
"""Test rls_transaction uses read replica alias from router."""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
with patch("api.db_utils.connections") as mock_connections:
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
mock_connections.__getitem__.return_value = mock_conn
mock_connections.__contains__.return_value = True
with patch("api.db_utils.transaction.atomic"):
with patch("api.db_utils.set_read_db_alias") as mock_set_alias:
with patch(
"api.db_utils.reset_read_db_alias"
) as mock_reset_alias:
mock_set_alias.return_value = "test_token"
with rls_transaction(tenant_id):
pass
mock_connections.__getitem__.assert_called()
mock_set_alias.assert_called_once()
mock_reset_alias.assert_called_once()
def test_rls_transaction_fallback_to_default_when_alias_not_in_connections(
self, tenants_fixture
):
"""Test rls_transaction falls back to DEFAULT_DB_ALIAS when alias not in connections."""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
invalid_alias = "nonexistent_db"
with patch("api.db_utils.get_read_db_alias", return_value=invalid_alias):
with patch("api.db_utils.connections") as mock_connections:
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
def contains_check(alias):
return alias == DEFAULT_DB_ALIAS
mock_connections.__contains__.side_effect = contains_check
mock_connections.__getitem__.return_value = mock_conn
with patch("api.db_utils.transaction.atomic"):
with rls_transaction(tenant_id):
pass
mock_connections.__getitem__.assert_called_with(DEFAULT_DB_ALIAS)
def test_rls_transaction_successful_execution_on_replica_no_retries(
self, tenants_fixture, enable_read_replica
):
"""Test successful execution on replica without retries."""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
with patch("api.db_utils.connections") as mock_connections:
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
mock_connections.__getitem__.return_value = mock_conn
mock_connections.__contains__.return_value = True
with patch("api.db_utils.transaction.atomic"):
with patch("api.db_utils.set_read_db_alias", return_value="token"):
with patch("api.db_utils.reset_read_db_alias"):
with rls_transaction(tenant_id):
pass
assert mock_cursor.execute.call_count == 1
def test_rls_transaction_retry_with_exponential_backoff_on_operational_error(
self, tenants_fixture, enable_read_replica
):
"""Test retry with exponential backoff on OperationalError on replica."""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
with patch("api.db_utils.connections") as mock_connections:
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
mock_connections.__getitem__.return_value = mock_conn
mock_connections.__contains__.return_value = True
call_count = 0
def atomic_side_effect(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count < 3:
raise OperationalError("Connection error")
return MagicMock(
__enter__=MagicMock(return_value=None),
__exit__=MagicMock(return_value=False),
)
with patch(
"api.db_utils.transaction.atomic", side_effect=atomic_side_effect
):
with patch("api.db_utils.time.sleep") as mock_sleep:
with patch(
"api.db_utils.set_read_db_alias", return_value="token"
):
with patch("api.db_utils.reset_read_db_alias"):
with patch("api.db_utils.logger") as mock_logger:
with rls_transaction(tenant_id):
pass
assert mock_sleep.call_count == 2
mock_sleep.assert_any_call(0.5)
mock_sleep.assert_any_call(1.0)
assert mock_logger.info.call_count == 2
def test_rls_transaction_max_three_attempts_for_replica(
self, tenants_fixture, enable_read_replica
):
"""Test maximum 3 attempts for replica database."""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
with patch("api.db_utils.connections") as mock_connections:
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
mock_connections.__getitem__.return_value = mock_conn
mock_connections.__contains__.return_value = True
with patch("api.db_utils.transaction.atomic") as mock_atomic:
mock_atomic.side_effect = OperationalError("Persistent error")
with patch("api.db_utils.time.sleep"):
with patch(
"api.db_utils.set_read_db_alias", return_value="token"
):
with patch("api.db_utils.reset_read_db_alias"):
with pytest.raises(OperationalError):
with rls_transaction(tenant_id):
pass
assert mock_atomic.call_count == 3
def test_rls_transaction_only_one_attempt_for_primary(self, tenants_fixture):
"""Test only 1 attempt for primary database."""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
with patch("api.db_utils.get_read_db_alias", return_value=None):
with patch("api.db_utils.connections") as mock_connections:
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
mock_connections.__getitem__.return_value = mock_conn
mock_connections.__contains__.return_value = True
with patch("api.db_utils.transaction.atomic") as mock_atomic:
mock_atomic.side_effect = OperationalError("Primary error")
with pytest.raises(OperationalError):
with rls_transaction(tenant_id):
pass
assert mock_atomic.call_count == 1
def test_rls_transaction_fallback_to_primary_after_max_attempts(
self, tenants_fixture, enable_read_replica
):
"""Test fallback to primary DB after max attempts on replica."""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
with patch("api.db_utils.connections") as mock_connections:
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
mock_connections.__getitem__.return_value = mock_conn
mock_connections.__contains__.return_value = True
call_count = 0
def atomic_side_effect(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count < 3:
raise OperationalError("Replica error")
return MagicMock(
__enter__=MagicMock(return_value=None),
__exit__=MagicMock(return_value=False),
)
with patch(
"api.db_utils.transaction.atomic", side_effect=atomic_side_effect
):
with patch("api.db_utils.time.sleep"):
with patch(
"api.db_utils.set_read_db_alias", return_value="token"
):
with patch("api.db_utils.reset_read_db_alias"):
with patch("api.db_utils.logger") as mock_logger:
with rls_transaction(tenant_id):
pass
mock_logger.warning.assert_called_once()
warning_msg = mock_logger.warning.call_args[0][0]
assert "falling back to primary DB" in warning_msg
def test_rls_transaction_logger_warning_on_fallback(
self, tenants_fixture, enable_read_replica
):
"""Test logger warnings are emitted on fallback to primary."""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
with patch("api.db_utils.connections") as mock_connections:
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
mock_connections.__getitem__.return_value = mock_conn
mock_connections.__contains__.return_value = True
call_count = 0
def atomic_side_effect(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count < 3:
raise OperationalError("Replica error")
return MagicMock(
__enter__=MagicMock(return_value=None),
__exit__=MagicMock(return_value=False),
)
with patch(
"api.db_utils.transaction.atomic", side_effect=atomic_side_effect
):
with patch("api.db_utils.time.sleep"):
with patch(
"api.db_utils.set_read_db_alias", return_value="token"
):
with patch("api.db_utils.reset_read_db_alias"):
with patch("api.db_utils.logger") as mock_logger:
with rls_transaction(tenant_id):
pass
assert mock_logger.info.call_count == 2
assert mock_logger.warning.call_count == 1
def test_rls_transaction_operational_error_raised_immediately_on_primary(
self, tenants_fixture
):
"""Test OperationalError raised immediately on primary without retry."""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
with patch("api.db_utils.get_read_db_alias", return_value=None):
with patch("api.db_utils.connections") as mock_connections:
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
mock_connections.__getitem__.return_value = mock_conn
mock_connections.__contains__.return_value = True
with patch("api.db_utils.transaction.atomic") as mock_atomic:
mock_atomic.side_effect = OperationalError("Primary error")
with patch("api.db_utils.time.sleep") as mock_sleep:
with pytest.raises(OperationalError):
with rls_transaction(tenant_id):
pass
mock_sleep.assert_not_called()
def test_rls_transaction_operational_error_raised_after_max_attempts(
self, tenants_fixture, enable_read_replica
):
"""Test OperationalError raised after max attempts on replica."""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
with patch("api.db_utils.connections") as mock_connections:
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
mock_connections.__getitem__.return_value = mock_conn
mock_connections.__contains__.return_value = True
with patch("api.db_utils.transaction.atomic") as mock_atomic:
mock_atomic.side_effect = OperationalError(
"Persistent replica error"
)
with patch("api.db_utils.time.sleep"):
with patch(
"api.db_utils.set_read_db_alias", return_value="token"
):
with patch("api.db_utils.reset_read_db_alias"):
with pytest.raises(OperationalError):
with rls_transaction(tenant_id):
pass
def test_rls_transaction_router_token_set_for_non_default_alias(
self, tenants_fixture
):
"""Test router token is set when using non-default alias."""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
custom_alias = "custom_db"
with patch("api.db_utils.connections") as mock_connections:
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
mock_connections.__getitem__.return_value = mock_conn
mock_connections.__contains__.return_value = True
with patch("api.db_utils.transaction.atomic"):
with patch("api.db_utils.set_read_db_alias") as mock_set_alias:
with patch("api.db_utils.reset_read_db_alias") as mock_reset_alias:
mock_set_alias.return_value = "test_token"
with rls_transaction(tenant_id, using=custom_alias):
pass
mock_set_alias.assert_called_once_with(custom_alias)
mock_reset_alias.assert_called_once_with("test_token")
def test_rls_transaction_router_token_reset_in_finally_block(self, tenants_fixture):
"""Test router token is reset in finally block even on error."""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
custom_alias = "custom_db"
with patch("api.db_utils.connections") as mock_connections:
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
mock_connections.__getitem__.return_value = mock_conn
mock_connections.__contains__.return_value = True
with patch("api.db_utils.transaction.atomic") as mock_atomic:
mock_atomic.side_effect = Exception("Unexpected error")
with patch("api.db_utils.set_read_db_alias", return_value="test_token"):
with patch("api.db_utils.reset_read_db_alias") as mock_reset_alias:
with pytest.raises(Exception):
with rls_transaction(tenant_id, using=custom_alias):
pass
mock_reset_alias.assert_called_once_with("test_token")
def test_rls_transaction_router_token_not_set_for_default_alias(
self, tenants_fixture
):
"""Test router token is not set when using default alias."""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
with patch("api.db_utils.get_read_db_alias", return_value=None):
with patch("api.db_utils.connections") as mock_connections:
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
mock_connections.__getitem__.return_value = mock_conn
mock_connections.__contains__.return_value = True
with patch("api.db_utils.transaction.atomic"):
with patch("api.db_utils.set_read_db_alias") as mock_set_alias:
with patch(
"api.db_utils.reset_read_db_alias"
) as mock_reset_alias:
with rls_transaction(tenant_id):
pass
mock_set_alias.assert_not_called()
mock_reset_alias.assert_not_called()
def test_rls_transaction_set_config_query_executed_with_correct_params(
self, tenants_fixture
):
"""Test SET_CONFIG_QUERY executed with correct parameters."""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
with rls_transaction(tenant_id) as cursor:
cursor.execute("SELECT current_setting(%s)", [POSTGRES_TENANT_VAR])
result = cursor.fetchone()
assert result[0] == tenant_id
def test_rls_transaction_custom_parameter(self, tenants_fixture):
"""Test rls_transaction with custom parameter name."""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
custom_param = "api.user_id"
with rls_transaction(tenant_id, parameter=custom_param) as cursor:
cursor.execute("SELECT current_setting(%s)", [custom_param])
result = cursor.fetchone()
assert result[0] == tenant_id
def test_rls_transaction_cursor_yielded_correctly(self, tenants_fixture):
"""Test cursor is yielded correctly."""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
with rls_transaction(tenant_id) as cursor:
assert cursor is not None
cursor.execute("SELECT 1")
result = cursor.fetchone()
assert result[0] == 1