fix(api): harden security hub retries (#10144)

This commit is contained in:
Adrián Peña
2026-02-25 11:34:41 +01:00
committed by GitHub
parent 7077a56331
commit e47f2b4033
5 changed files with 263 additions and 80 deletions

View File

@@ -35,6 +35,7 @@ All notable changes to the **Prowler API** are documented in this file.
- Attack Paths: Orphaned temporary Neo4j databases are now cleaned up on scan failure and provider deletion [(#10101)](https://github.com/prowler-cloud/prowler/pull/10101)
- Attack Paths: scan no longer raises `DatabaseError` when provider is deleted mid-scan [(#10116)](https://github.com/prowler-cloud/prowler/pull/10116)
- Security Hub export retries transient replica conflicts without failing integrations [(#10144)](https://github.com/prowler-cloud/prowler/pull/10144)
### 🔐 Security

View File

@@ -74,6 +74,7 @@ def rls_transaction(
value: str,
parameter: str = POSTGRES_TENANT_VAR,
using: str | None = None,
retry_on_replica: bool = True,
):
"""
Creates a new database transaction setting the given configuration value for Postgres RLS. It validates the
@@ -92,10 +93,11 @@ def rls_transaction(
alias = db_alias
is_replica = READ_REPLICA_ALIAS and alias == READ_REPLICA_ALIAS
max_attempts = REPLICA_MAX_ATTEMPTS if is_replica else 1
max_attempts = REPLICA_MAX_ATTEMPTS if is_replica and retry_on_replica else 1
for attempt in range(1, max_attempts + 1):
router_token = None
yielded_cursor = False
# On final attempt, fallback to primary
if attempt == max_attempts and is_replica:
@@ -118,9 +120,12 @@ def rls_transaction(
except ValueError:
raise ValidationError("Must be a valid UUID")
cursor.execute(SET_CONFIG_QUERY, [parameter, value])
yielded_cursor = True
yield cursor
return
except OperationalError as e:
if yielded_cursor:
raise
# If on primary or max attempts reached, raise
if not is_replica or attempt == max_attempts:
raise

View File

@@ -550,6 +550,36 @@ class TestRlsTransaction:
mock_sleep.assert_any_call(1.0)
assert mock_logger.info.call_count == 2
def test_rls_transaction_operational_error_inside_context_no_retry(
    self, tenants_fixture, enable_read_replica
):
    """An OperationalError raised inside the caller's block must propagate
    immediately: the replica retry loop only covers setup, never the body."""
    tenant_id = str(tenants_fixture[0].id)
    fake_conn = MagicMock()
    fake_conn.cursor.return_value.__enter__.return_value = MagicMock()
    with (
        patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica),
        patch("api.db_utils.connections") as fake_connections,
        patch("api.db_utils.transaction.atomic") as fake_atomic,
        patch("api.db_utils.time.sleep") as fake_sleep,
        patch("api.db_utils.set_read_db_alias", return_value="token"),
        patch("api.db_utils.reset_read_db_alias"),
    ):
        fake_connections.__getitem__.return_value = fake_conn
        fake_connections.__contains__.return_value = True
        fake_atomic.return_value.__enter__.return_value = None
        fake_atomic.return_value.__exit__.return_value = False
        # The error originates from user code inside the context manager,
        # so it must surface unchanged and without any backoff sleep.
        with pytest.raises(OperationalError):
            with rls_transaction(tenant_id):
                raise OperationalError("Conflict with recovery")
        fake_sleep.assert_not_called()
def test_rls_transaction_max_three_attempts_for_replica(
self, tenants_fixture, enable_read_replica
):
@@ -579,6 +609,38 @@ class TestRlsTransaction:
assert mock_atomic.call_count == 3
def test_rls_transaction_replica_no_retry_when_disabled(
    self, tenants_fixture, enable_read_replica
):
    """With retry_on_replica=False a replica failure gets exactly one
    attempt and no backoff sleep before the error is re-raised."""
    tenant_id = str(tenants_fixture[0].id)
    fake_conn = MagicMock()
    fake_conn.cursor.return_value.__enter__.return_value = MagicMock()
    with (
        patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica),
        patch("api.db_utils.connections") as fake_connections,
        patch("api.db_utils.transaction.atomic") as fake_atomic,
        patch("api.db_utils.time.sleep") as fake_sleep,
        patch("api.db_utils.set_read_db_alias", return_value="token"),
        patch("api.db_utils.reset_read_db_alias"),
    ):
        fake_connections.__getitem__.return_value = fake_conn
        fake_connections.__contains__.return_value = True
        # Every transaction attempt blows up as if the replica rejected it.
        fake_atomic.side_effect = OperationalError("Replica error")
        with pytest.raises(OperationalError):
            with rls_transaction(tenant_id, retry_on_replica=False):
                pass
        # Retry disabled: a single atomic() attempt, no sleep between tries.
        assert fake_atomic.call_count == 1
        fake_sleep.assert_not_called()
def test_rls_transaction_only_one_attempt_for_primary(self, tenants_fixture):
"""Test only 1 attempt for primary database."""
tenant = tenants_fixture[0]

View File

@@ -1,12 +1,14 @@
import os
import time
from glob import glob
from celery.utils.log import get_task_logger
from config.django.base import DJANGO_FINDINGS_BATCH_SIZE
from django.db import OperationalError
from tasks.utils import batched
from api.db_router import READ_REPLICA_ALIAS, MainRouter
from api.db_utils import rls_transaction
from api.db_utils import REPLICA_MAX_ATTEMPTS, REPLICA_RETRY_BASE_DELAY, rls_transaction
from api.models import Finding, Integration, Provider
from api.utils import initialize_prowler_integration, initialize_prowler_provider
from prowler.lib.outputs.asff.asff import ASFF
@@ -17,11 +19,11 @@ from prowler.lib.outputs.html.html import HTML
from prowler.lib.outputs.ocsf.ocsf import OCSF
from prowler.providers.aws.aws_provider import AwsProvider
from prowler.providers.aws.lib.s3.s3 import S3
from prowler.providers.aws.lib.security_hub.security_hub import SecurityHub
from prowler.providers.common.models import Connection
from prowler.providers.aws.lib.security_hub.exceptions.exceptions import (
SecurityHubNoEnabledRegionsError,
)
from prowler.providers.aws.lib.security_hub.security_hub import SecurityHub
from prowler.providers.common.models import Connection
logger = get_task_logger(__name__)
@@ -291,96 +293,130 @@ def upload_security_hub_integration(
total_findings_sent[integration.id] = 0
# Process findings in batches to avoid memory issues
max_attempts = REPLICA_MAX_ATTEMPTS if READ_REPLICA_ALIAS else 1
has_findings = False
batch_number = 0
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
qs = (
Finding.all_objects.filter(tenant_id=tenant_id, scan_id=scan_id)
.order_by("uid")
.iterator()
)
for batch, _ in batched(qs, DJANGO_FINDINGS_BATCH_SIZE):
batch_number += 1
has_findings = True
# Transform findings for this batch
transformed_findings = [
FindingOutput.transform_api_finding(
finding, prowler_provider
)
for finding in batch
]
# Convert to ASFF format
asff_transformer = ASFF(
findings=transformed_findings,
file_path="",
file_extension="json",
for attempt in range(1, max_attempts + 1):
read_alias = None
if READ_REPLICA_ALIAS:
read_alias = (
READ_REPLICA_ALIAS
if attempt < max_attempts
else MainRouter.default_db
)
asff_transformer.transform(transformed_findings)
# Get the batch of ASFF findings
batch_asff_findings = asff_transformer.data
if batch_asff_findings:
# Create Security Hub client for first batch or reuse existing
if not security_hub_client:
connected, security_hub = (
get_security_hub_client_from_integration(
integration, tenant_id, batch_asff_findings
)
try:
batch_number = 0
has_findings = False
with rls_transaction(
tenant_id,
using=read_alias,
retry_on_replica=False,
):
qs = (
Finding.all_objects.filter(
tenant_id=tenant_id, scan_id=scan_id
)
.order_by("uid")
.iterator()
)
if not connected:
if isinstance(
security_hub.error,
SecurityHubNoEnabledRegionsError,
):
logger.warning(
f"Security Hub integration {integration.id} has no enabled regions"
for batch, _ in batched(qs, DJANGO_FINDINGS_BATCH_SIZE):
batch_number += 1
has_findings = True
# Transform findings for this batch
transformed_findings = [
FindingOutput.transform_api_finding(
finding, prowler_provider
)
for finding in batch
]
# Convert to ASFF format
asff_transformer = ASFF(
findings=transformed_findings,
file_path="",
file_extension="json",
)
asff_transformer.transform(transformed_findings)
# Get the batch of ASFF findings
batch_asff_findings = asff_transformer.data
if batch_asff_findings:
# Create Security Hub client for first batch or reuse existing
if not security_hub_client:
connected, security_hub = (
get_security_hub_client_from_integration(
integration,
tenant_id,
batch_asff_findings,
)
)
if not connected:
if isinstance(
security_hub.error,
SecurityHubNoEnabledRegionsError,
):
logger.warning(
f"Security Hub integration {integration.id} has no enabled regions"
)
else:
logger.error(
f"Security Hub connection failed for integration {integration.id}: "
f"{security_hub.error}"
)
break # Skip this integration
security_hub_client = security_hub
logger.info(
f"Sending {'fail' if send_only_fails else 'all'} findings to Security Hub via "
f"integration {integration.id}"
)
else:
logger.error(
f"Security Hub connection failed for integration {integration.id}: "
f"{security_hub.error}"
# Update findings in existing client for this batch
security_hub_client._findings_per_region = (
security_hub_client.filter(
batch_asff_findings,
send_only_fails,
)
)
break # Skip this integration
security_hub_client = security_hub
logger.info(
f"Sending {'fail' if send_only_fails else 'all'} findings to Security Hub via "
f"integration {integration.id}"
)
else:
# Update findings in existing client for this batch
security_hub_client._findings_per_region = (
security_hub_client.filter(
batch_asff_findings, send_only_fails
)
)
# Send this batch to Security Hub
try:
findings_sent = security_hub_client.batch_send_to_security_hub()
total_findings_sent[integration.id] += (
findings_sent
)
# Send this batch to Security Hub
try:
findings_sent = (
security_hub_client.batch_send_to_security_hub()
)
total_findings_sent[integration.id] += findings_sent
if findings_sent > 0:
logger.debug(
f"Sent batch {batch_number} with {findings_sent} findings to Security Hub"
)
except Exception as batch_error:
logger.error(
f"Failed to send batch {batch_number} to Security Hub: {str(batch_error)}"
)
if findings_sent > 0:
logger.debug(
f"Sent batch {batch_number} with {findings_sent} findings to Security Hub"
)
except Exception as batch_error:
logger.error(
f"Failed to send batch {batch_number} to Security Hub: {str(batch_error)}"
)
# Clear memory after processing each batch
asff_transformer._data.clear()
del batch_asff_findings
del transformed_findings
# Clear memory after processing each batch
asff_transformer._data.clear()
del batch_asff_findings
del transformed_findings
break
except OperationalError as e:
if attempt == max_attempts:
raise
delay = REPLICA_RETRY_BASE_DELAY * (2 ** (attempt - 1))
logger.info(
"RLS query failed during Security Hub integration "
f"(attempt {attempt}/{max_attempts}), retrying in {delay}s. Error: {e}"
)
time.sleep(delay)
if not has_findings:
logger.info(

View File

@@ -1,6 +1,7 @@
from unittest.mock import MagicMock, patch
import pytest
from django.db import OperationalError
from tasks.jobs.integrations import (
get_s3_client_from_integration,
get_security_hub_client_from_integration,
@@ -1056,6 +1057,84 @@ class TestSecurityHubIntegrationUploads:
mock_security_hub.batch_send_to_security_hub.assert_called_once()
mock_security_hub.archive_previous_findings.assert_called_once()
# NOTE: decorators apply bottom-up, so the bottom @patch (Finding) maps to the
# first mock parameter (mock_finding_model) and the top one (time.sleep) to the last.
@patch("tasks.jobs.integrations.time.sleep")
@patch("tasks.jobs.integrations.batched")
@patch("tasks.jobs.integrations.get_security_hub_client_from_integration")
@patch("tasks.jobs.integrations.initialize_prowler_provider")
@patch("tasks.jobs.integrations.rls_transaction")
@patch("tasks.jobs.integrations.Integration")
@patch("tasks.jobs.integrations.Provider")
@patch("tasks.jobs.integrations.Finding")
def test_upload_security_hub_integration_retries_on_operational_error(
    self,
    mock_finding_model,
    mock_provider_model,
    mock_integration_model,
    mock_rls,
    mock_initialize_provider,
    mock_get_security_hub,
    mock_batched,
    mock_sleep,
):
    """Test SecurityHub upload retries on transient OperationalError."""
    tenant_id = "tenant-id"
    provider_id = "provider-id"
    scan_id = "scan-123"
    # One enabled Security Hub integration configured to send only failed findings.
    integration = MagicMock()
    integration.id = "integration-1"
    integration.configuration = {
        "send_only_fails": True,
        "archive_previous_findings": False,
    }
    mock_integration_model.objects.filter.return_value = [integration]
    provider = MagicMock()
    mock_provider_model.objects.get.return_value = provider
    mock_prowler_provider = MagicMock()
    mock_initialize_provider.return_value = mock_prowler_provider
    # Two findings returned by the RLS-scoped queryset iterator.
    mock_findings = [MagicMock(), MagicMock()]
    mock_finding_model.all_objects.filter.return_value.order_by.return_value.iterator.return_value = iter(
        mock_findings
    )
    transformed_findings = [MagicMock(), MagicMock()]
    with patch("tasks.jobs.integrations.FindingOutput") as mock_finding_output:
        # One transformed finding per raw finding in the (single) batch.
        mock_finding_output.transform_api_finding.side_effect = transformed_findings
        with patch("tasks.jobs.integrations.ASFF") as mock_asff:
            # ASFF transformer yields two FAILED findings so they pass the
            # send_only_fails filter.
            mock_asff_instance = MagicMock()
            finding1 = MagicMock()
            finding1.Compliance.Status = "FAILED"
            finding2 = MagicMock()
            finding2.Compliance.Status = "FAILED"
            mock_asff_instance.data = [finding1, finding2]
            mock_asff_instance._data = MagicMock()
            mock_asff.return_value = mock_asff_instance
            mock_security_hub = MagicMock()
            mock_security_hub.batch_send_to_security_hub.return_value = 2
            mock_get_security_hub.return_value = (True, mock_security_hub)
            mock_rls.return_value.__enter__.return_value = None
            mock_rls.return_value.__exit__.return_value = False
            # First batched() call simulates a transient replica conflict;
            # the second call (retry attempt) succeeds with one batch.
            mock_batched.side_effect = [
                OperationalError("Conflict with recovery"),
                [(mock_findings, None)],
            ]
            # Allow exactly 2 attempts and pretend a read replica is configured
            # so the retry path is exercised.
            with patch("tasks.jobs.integrations.REPLICA_MAX_ATTEMPTS", 2):
                with patch("tasks.jobs.integrations.READ_REPLICA_ALIAS", "replica"):
                    result = upload_security_hub_integration(
                        tenant_id, provider_id, scan_id
                    )
                    # Upload succeeded overall and exactly one backoff sleep
                    # happened between the failed and successful attempts.
                    assert result is True
                    mock_sleep.assert_called_once()
@patch("tasks.jobs.integrations.get_security_hub_client_from_integration")
@patch("tasks.jobs.integrations.initialize_prowler_provider")
@patch("tasks.jobs.integrations.rls_transaction")