From e47f2b4033d0df59d48de02bdab5000979c4d021 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Pe=C3=B1a?= Date: Wed, 25 Feb 2026 11:34:41 +0100 Subject: [PATCH] fix(api): harden security hub retries (#10144) --- api/CHANGELOG.md | 1 + api/src/backend/api/db_utils.py | 7 +- api/src/backend/api/tests/test_db_utils.py | 62 ++++++ api/src/backend/tasks/jobs/integrations.py | 194 +++++++++++------- .../backend/tasks/tests/test_integrations.py | 79 +++++++ 5 files changed, 263 insertions(+), 80 deletions(-) diff --git a/api/CHANGELOG.md b/api/CHANGELOG.md index 96a1d210a8..9da5688c4d 100644 --- a/api/CHANGELOG.md +++ b/api/CHANGELOG.md @@ -35,6 +35,7 @@ All notable changes to the **Prowler API** are documented in this file. - Attack Paths: Orphaned temporary Neo4j databases are now cleaned up on scan failure and provider deletion [(#10101)](https://github.com/prowler-cloud/prowler/pull/10101) - Attack Paths: scan no longer raises `DatabaseError` when provider is deleted mid-scan [(#10116)](https://github.com/prowler-cloud/prowler/pull/10116) +- Security Hub export retries transient replica conflicts without failing integrations [(#10144)](https://github.com/prowler-cloud/prowler/pull/10144) ### 🔐 Security diff --git a/api/src/backend/api/db_utils.py b/api/src/backend/api/db_utils.py index b719d4b736..7a71084ccd 100644 --- a/api/src/backend/api/db_utils.py +++ b/api/src/backend/api/db_utils.py @@ -74,6 +74,7 @@ def rls_transaction( value: str, parameter: str = POSTGRES_TENANT_VAR, using: str | None = None, + retry_on_replica: bool = True, ): """ Creates a new database transaction setting the given configuration value for Postgres RLS. It validates the @@ -92,10 +93,11 @@ def rls_transaction( alias = db_alias is_replica = READ_REPLICA_ALIAS and alias == READ_REPLICA_ALIAS - max_attempts = REPLICA_MAX_ATTEMPTS if is_replica else 1 + max_attempts = REPLICA_MAX_ATTEMPTS if is_replica and retry_on_replica else 1 for attempt in range(1, max_attempts + 1): router_token = None + yielded_cursor = False # On final attempt, fallback to primary if attempt == max_attempts and is_replica: @@ -118,9 +120,12 @@ def rls_transaction( except ValueError: raise ValidationError("Must be a valid UUID") cursor.execute(SET_CONFIG_QUERY, [parameter, value]) + yielded_cursor = True yield cursor return except OperationalError as e: + if yielded_cursor: + raise # If on primary or max attempts reached, raise if not is_replica or attempt == max_attempts: raise diff --git a/api/src/backend/api/tests/test_db_utils.py b/api/src/backend/api/tests/test_db_utils.py index f706ab8c71..f52bb349aa 100644 --- a/api/src/backend/api/tests/test_db_utils.py +++ b/api/src/backend/api/tests/test_db_utils.py @@ -550,6 +550,36 @@ class TestRlsTransaction: mock_sleep.assert_any_call(1.0) assert mock_logger.info.call_count == 2 + def test_rls_transaction_operational_error_inside_context_no_retry( + self, tenants_fixture, enable_read_replica + ): + """Test OperationalError raised inside context does not retry.""" + tenant = tenants_fixture[0] + tenant_id = str(tenant.id) + + with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica): + with patch("api.db_utils.connections") as mock_connections: + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_conn.cursor.return_value.__enter__.return_value = mock_cursor + mock_connections.__getitem__.return_value = mock_conn + mock_connections.__contains__.return_value = True + + with patch("api.db_utils.transaction.atomic") as mock_atomic: + 
mock_atomic.return_value.__enter__.return_value = None + mock_atomic.return_value.__exit__.return_value = False + + with patch("api.db_utils.time.sleep") as mock_sleep: + with patch( + "api.db_utils.set_read_db_alias", return_value="token" + ): + with patch("api.db_utils.reset_read_db_alias"): + with pytest.raises(OperationalError): + with rls_transaction(tenant_id): + raise OperationalError("Conflict with recovery") + + mock_sleep.assert_not_called() + def test_rls_transaction_max_three_attempts_for_replica( self, tenants_fixture, enable_read_replica ): @@ -579,6 +609,38 @@ class TestRlsTransaction: assert mock_atomic.call_count == 3 + def test_rls_transaction_replica_no_retry_when_disabled( + self, tenants_fixture, enable_read_replica + ): + """Test replica retry is disabled when retry_on_replica=False.""" + tenant = tenants_fixture[0] + tenant_id = str(tenant.id) + + with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica): + with patch("api.db_utils.connections") as mock_connections: + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_conn.cursor.return_value.__enter__.return_value = mock_cursor + mock_connections.__getitem__.return_value = mock_conn + mock_connections.__contains__.return_value = True + + with patch("api.db_utils.transaction.atomic") as mock_atomic: + mock_atomic.side_effect = OperationalError("Replica error") + + with patch("api.db_utils.time.sleep") as mock_sleep: + with patch( + "api.db_utils.set_read_db_alias", return_value="token" + ): + with patch("api.db_utils.reset_read_db_alias"): + with pytest.raises(OperationalError): + with rls_transaction( + tenant_id, retry_on_replica=False + ): + pass + + assert mock_atomic.call_count == 1 + mock_sleep.assert_not_called() + def test_rls_transaction_only_one_attempt_for_primary(self, tenants_fixture): """Test only 1 attempt for primary database.""" tenant = tenants_fixture[0] diff --git a/api/src/backend/tasks/jobs/integrations.py b/api/src/backend/tasks/jobs/integrations.py index cd76762a40..5ca94057da 100644 --- a/api/src/backend/tasks/jobs/integrations.py +++ b/api/src/backend/tasks/jobs/integrations.py @@ -1,12 +1,14 @@ import os +import time from glob import glob from celery.utils.log import get_task_logger from config.django.base import DJANGO_FINDINGS_BATCH_SIZE +from django.db import OperationalError from tasks.utils import batched from api.db_router import READ_REPLICA_ALIAS, MainRouter -from api.db_utils import rls_transaction +from api.db_utils import REPLICA_MAX_ATTEMPTS, REPLICA_RETRY_BASE_DELAY, rls_transaction from api.models import Finding, Integration, Provider from api.utils import initialize_prowler_integration, initialize_prowler_provider from prowler.lib.outputs.asff.asff import ASFF @@ -17,11 +19,11 @@ from prowler.lib.outputs.html.html import HTML from prowler.lib.outputs.ocsf.ocsf import OCSF from prowler.providers.aws.aws_provider import AwsProvider from prowler.providers.aws.lib.s3.s3 import S3 -from prowler.providers.aws.lib.security_hub.security_hub import SecurityHub -from prowler.providers.common.models import Connection from prowler.providers.aws.lib.security_hub.exceptions.exceptions import ( SecurityHubNoEnabledRegionsError, ) +from prowler.providers.aws.lib.security_hub.security_hub import SecurityHub +from prowler.providers.common.models import Connection logger = get_task_logger(__name__) @@ -291,96 +293,130 @@ def upload_security_hub_integration( total_findings_sent[integration.id] = 0 # Process findings in batches to avoid memory issues + max_attempts = 
REPLICA_MAX_ATTEMPTS if READ_REPLICA_ALIAS else 1 has_findings = False batch_number = 0 - with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS): - qs = ( - Finding.all_objects.filter(tenant_id=tenant_id, scan_id=scan_id) - .order_by("uid") - .iterator() - ) - - for batch, _ in batched(qs, DJANGO_FINDINGS_BATCH_SIZE): - batch_number += 1 - has_findings = True - - # Transform findings for this batch - transformed_findings = [ - FindingOutput.transform_api_finding( - finding, prowler_provider - ) - for finding in batch - ] - - # Convert to ASFF format - asff_transformer = ASFF( - findings=transformed_findings, - file_path="", - file_extension="json", + for attempt in range(1, max_attempts + 1): + read_alias = None + if READ_REPLICA_ALIAS: + read_alias = ( + READ_REPLICA_ALIAS + if attempt < max_attempts + else MainRouter.default_db ) - asff_transformer.transform(transformed_findings) - # Get the batch of ASFF findings - batch_asff_findings = asff_transformer.data - - if batch_asff_findings: - # Create Security Hub client for first batch or reuse existing - if not security_hub_client: - connected, security_hub = ( - get_security_hub_client_from_integration( - integration, tenant_id, batch_asff_findings - ) + try: + batch_number = 0 + has_findings = False + with rls_transaction( + tenant_id, + using=read_alias, + retry_on_replica=False, + ): + qs = ( + Finding.all_objects.filter( + tenant_id=tenant_id, scan_id=scan_id ) + .order_by("uid") + .iterator() + ) - if not connected: - if isinstance( - security_hub.error, - SecurityHubNoEnabledRegionsError, - ): - logger.warning( - f"Security Hub integration {integration.id} has no enabled regions" + for batch, _ in batched(qs, DJANGO_FINDINGS_BATCH_SIZE): + batch_number += 1 + has_findings = True + + # Transform findings for this batch + transformed_findings = [ + FindingOutput.transform_api_finding( + finding, prowler_provider + ) + for finding in batch + ] + + # Convert to ASFF format + asff_transformer = ASFF( + findings=transformed_findings, + file_path="", + file_extension="json", + ) + asff_transformer.transform(transformed_findings) + + # Get the batch of ASFF findings + batch_asff_findings = asff_transformer.data + + if batch_asff_findings: + # Create Security Hub client for first batch or reuse existing + if not security_hub_client: + connected, security_hub = ( + get_security_hub_client_from_integration( + integration, + tenant_id, + batch_asff_findings, + ) + ) + + if not connected: + if isinstance( + security_hub.error, + SecurityHubNoEnabledRegionsError, + ): + logger.warning( + f"Security Hub integration {integration.id} has no enabled regions" + ) + else: + logger.error( + f"Security Hub connection failed for integration {integration.id}: " + f"{security_hub.error}" + ) + break # Skip this integration + + security_hub_client = security_hub + logger.info( + f"Sending {'fail' if send_only_fails else 'all'} findings to Security Hub via " + f"integration {integration.id}" ) else: - logger.error( - f"Security Hub connection failed for integration {integration.id}: " - f"{security_hub.error}" + # Update findings in existing client for this batch + security_hub_client._findings_per_region = ( + security_hub_client.filter( + batch_asff_findings, + send_only_fails, + ) ) - break # Skip this integration - security_hub_client = security_hub - logger.info( - f"Sending {'fail' if send_only_fails else 'all'} findings to Security Hub via " - f"integration {integration.id}" - ) - else: - # Update findings in existing client for this batch - 
security_hub_client._findings_per_region = ( - security_hub_client.filter( - batch_asff_findings, send_only_fails - ) - ) + # Send this batch to Security Hub + try: + findings_sent = security_hub_client.batch_send_to_security_hub() + total_findings_sent[integration.id] += ( + findings_sent + ) - # Send this batch to Security Hub - try: - findings_sent = ( - security_hub_client.batch_send_to_security_hub() - ) - total_findings_sent[integration.id] += findings_sent + if findings_sent > 0: + logger.debug( + f"Sent batch {batch_number} with {findings_sent} findings to Security Hub" + ) + except Exception as batch_error: + logger.error( + f"Failed to send batch {batch_number} to Security Hub: {str(batch_error)}" + ) - if findings_sent > 0: - logger.debug( - f"Sent batch {batch_number} with {findings_sent} findings to Security Hub" - ) - except Exception as batch_error: - logger.error( - f"Failed to send batch {batch_number} to Security Hub: {str(batch_error)}" - ) + # Clear memory after processing each batch + asff_transformer._data.clear() + del batch_asff_findings + del transformed_findings - # Clear memory after processing each batch - asff_transformer._data.clear() - del batch_asff_findings - del transformed_findings + break + except OperationalError as e: + if attempt == max_attempts: + raise + + delay = REPLICA_RETRY_BASE_DELAY * (2 ** (attempt - 1)) + logger.info( + "RLS query failed during Security Hub integration " + f"(attempt {attempt}/{max_attempts}), retrying in {delay}s. Error: {e}" + ) + time.sleep(delay) if not has_findings: logger.info( diff --git a/api/src/backend/tasks/tests/test_integrations.py b/api/src/backend/tasks/tests/test_integrations.py index 954c645998..e246405cdd 100644 --- a/api/src/backend/tasks/tests/test_integrations.py +++ b/api/src/backend/tasks/tests/test_integrations.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, patch import pytest +from django.db import OperationalError from tasks.jobs.integrations import ( get_s3_client_from_integration, get_security_hub_client_from_integration, @@ -1056,6 +1057,84 @@ class TestSecurityHubIntegrationUploads: mock_security_hub.batch_send_to_security_hub.assert_called_once() mock_security_hub.archive_previous_findings.assert_called_once() + @patch("tasks.jobs.integrations.time.sleep") + @patch("tasks.jobs.integrations.batched") + @patch("tasks.jobs.integrations.get_security_hub_client_from_integration") + @patch("tasks.jobs.integrations.initialize_prowler_provider") + @patch("tasks.jobs.integrations.rls_transaction") + @patch("tasks.jobs.integrations.Integration") + @patch("tasks.jobs.integrations.Provider") + @patch("tasks.jobs.integrations.Finding") + def test_upload_security_hub_integration_retries_on_operational_error( + self, + mock_finding_model, + mock_provider_model, + mock_integration_model, + mock_rls, + mock_initialize_provider, + mock_get_security_hub, + mock_batched, + mock_sleep, + ): + """Test SecurityHub upload retries on transient OperationalError.""" + tenant_id = "tenant-id" + provider_id = "provider-id" + scan_id = "scan-123" + + integration = MagicMock() + integration.id = "integration-1" + integration.configuration = { + "send_only_fails": True, + "archive_previous_findings": False, + } + mock_integration_model.objects.filter.return_value = [integration] + + provider = MagicMock() + mock_provider_model.objects.get.return_value = provider + + mock_prowler_provider = MagicMock() + mock_initialize_provider.return_value = mock_prowler_provider + + mock_findings = [MagicMock(), MagicMock()] + 
mock_finding_model.all_objects.filter.return_value.order_by.return_value.iterator.return_value = iter( + mock_findings + ) + + transformed_findings = [MagicMock(), MagicMock()] + with patch("tasks.jobs.integrations.FindingOutput") as mock_finding_output: + mock_finding_output.transform_api_finding.side_effect = transformed_findings + + with patch("tasks.jobs.integrations.ASFF") as mock_asff: + mock_asff_instance = MagicMock() + finding1 = MagicMock() + finding1.Compliance.Status = "FAILED" + finding2 = MagicMock() + finding2.Compliance.Status = "FAILED" + mock_asff_instance.data = [finding1, finding2] + mock_asff_instance._data = MagicMock() + mock_asff.return_value = mock_asff_instance + + mock_security_hub = MagicMock() + mock_security_hub.batch_send_to_security_hub.return_value = 2 + mock_get_security_hub.return_value = (True, mock_security_hub) + + mock_rls.return_value.__enter__.return_value = None + mock_rls.return_value.__exit__.return_value = False + + mock_batched.side_effect = [ + OperationalError("Conflict with recovery"), + [(mock_findings, None)], + ] + + with patch("tasks.jobs.integrations.REPLICA_MAX_ATTEMPTS", 2): + with patch("tasks.jobs.integrations.READ_REPLICA_ALIAS", "replica"): + result = upload_security_hub_integration( + tenant_id, provider_id, scan_id + ) + + assert result is True + mock_sleep.assert_called_once() + @patch("tasks.jobs.integrations.get_security_hub_client_from_integration") @patch("tasks.jobs.integrations.initialize_prowler_provider") @patch("tasks.jobs.integrations.rls_transaction")
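
A brief note on the new `yielded_cursor` guard in `rls_transaction`: once the cursor has been yielded, an `OperationalError` comes from the caller's own block rather than from establishing the replica session, so retrying would replay the block and could duplicate its side effects. A minimal sketch of the resulting behavior, assuming a configured read replica; the tenant UUID below is a placeholder:

    from django.db import OperationalError

    from api.db_utils import rls_transaction

    tenant_id = "11111111-1111-1111-1111-111111111111"  # placeholder UUID

    # Failures while opening the transaction or running SET_CONFIG_QUERY are
    # retried on the replica, falling back to the primary on the last attempt.
    with rls_transaction(tenant_id) as cursor:
        cursor.execute("SELECT 1")

    # Once the body has started (yielded_cursor is True), the same exception
    # type propagates immediately: no retry, no sleep.
    with rls_transaction(tenant_id) as cursor:
        raise OperationalError("Conflict with recovery")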
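The Security Hub job now owns its retry policy, which is why it opens the RLS transaction with `retry_on_replica=False`: the outer loop retries the whole batch pass, and nesting `rls_transaction`'s built-in replica retry inside it would compound the backoff. A condensed sketch of the pattern, distilled from the diff above (`do_read_work` is a hypothetical stand-in for the batch-processing body):

    import time

    from django.db import OperationalError

    from api.db_router import READ_REPLICA_ALIAS, MainRouter
    from api.db_utils import (
        REPLICA_MAX_ATTEMPTS,
        REPLICA_RETRY_BASE_DELAY,
        rls_transaction,
    )

    def read_with_fallback(tenant_id, do_read_work):
        # A single attempt against the primary when no replica is configured.
        max_attempts = REPLICA_MAX_ATTEMPTS if READ_REPLICA_ALIAS else 1

        for attempt in range(1, max_attempts + 1):
            # Every attempt but the last targets the replica; the final one
            # falls back to the primary, mirroring the integration job.
            read_alias = None
            if READ_REPLICA_ALIAS:
                read_alias = (
                    READ_REPLICA_ALIAS
                    if attempt < max_attempts
                    else MainRouter.default_db
                )
            try:
                with rls_transaction(
                    tenant_id, using=read_alias, retry_on_replica=False
                ):
                    do_read_work()
                return
            except OperationalError:
                if attempt == max_attempts:
                    raise
                # Exponential backoff between passes: base, 2x, 4x, ...
                time.sleep(REPLICA_RETRY_BASE_DELAY * (2 ** (attempt - 1)))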
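For reference, the schedule implied by `delay = REPLICA_RETRY_BASE_DELAY * (2 ** (attempt - 1))`. The constant values below are assumptions, not taken from this patch: a base delay of 1.0s (suggested by the pre-existing test's `mock_sleep.assert_any_call(1.0)`) and three attempts (per the three-attempt replica test):

    REPLICA_RETRY_BASE_DELAY = 1.0  # assumed; the real value lives in api.db_utils
    REPLICA_MAX_ATTEMPTS = 3        # assumed from the three-attempt replica test

    for attempt in range(1, REPLICA_MAX_ATTEMPTS + 1):
        if attempt == REPLICA_MAX_ATTEMPTS:
            print(f"attempt {attempt}: primary, re-raise on failure")
        else:
            delay = REPLICA_RETRY_BASE_DELAY * (2 ** (attempt - 1))
            print(f"attempt {attempt}: replica, sleep {delay}s on failure")

    # attempt 1: replica, sleep 1.0s on failure
    # attempt 2: replica, sleep 2.0s on failure
    # attempt 3: primary, re-raise on failure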