mirror of
https://github.com/prowler-cloud/prowler.git
synced 2026-03-22 03:08:23 +00:00
fix(api): harden security hub retries (#10144)
This commit is contained in:
@@ -35,6 +35,7 @@ All notable changes to the **Prowler API** are documented in this file.
|
||||
|
||||
- Attack Paths: Orphaned temporary Neo4j databases are now cleaned up on scan failure and provider deletion [(#10101)](https://github.com/prowler-cloud/prowler/pull/10101)
|
||||
- Attack Paths: scan no longer raises `DatabaseError` when provider is deleted mid-scan [(#10116)](https://github.com/prowler-cloud/prowler/pull/10116)
|
||||
- Security Hub export retries transient replica conflicts without failing integrations [(#10144)](https://github.com/prowler-cloud/prowler/pull/10144)
|
||||
|
||||
### 🔐 Security
|
||||
|
||||
|
||||
@@ -74,6 +74,7 @@ def rls_transaction(
|
||||
value: str,
|
||||
parameter: str = POSTGRES_TENANT_VAR,
|
||||
using: str | None = None,
|
||||
retry_on_replica: bool = True,
|
||||
):
|
||||
"""
|
||||
Creates a new database transaction setting the given configuration value for Postgres RLS. It validates the
|
||||
@@ -92,10 +93,11 @@ def rls_transaction(
|
||||
|
||||
alias = db_alias
|
||||
is_replica = READ_REPLICA_ALIAS and alias == READ_REPLICA_ALIAS
|
||||
max_attempts = REPLICA_MAX_ATTEMPTS if is_replica else 1
|
||||
max_attempts = REPLICA_MAX_ATTEMPTS if is_replica and retry_on_replica else 1
|
||||
|
||||
for attempt in range(1, max_attempts + 1):
|
||||
router_token = None
|
||||
yielded_cursor = False
|
||||
|
||||
# On final attempt, fallback to primary
|
||||
if attempt == max_attempts and is_replica:
|
||||
@@ -118,9 +120,12 @@ def rls_transaction(
|
||||
except ValueError:
|
||||
raise ValidationError("Must be a valid UUID")
|
||||
cursor.execute(SET_CONFIG_QUERY, [parameter, value])
|
||||
yielded_cursor = True
|
||||
yield cursor
|
||||
return
|
||||
except OperationalError as e:
|
||||
if yielded_cursor:
|
||||
raise
|
||||
# If on primary or max attempts reached, raise
|
||||
if not is_replica or attempt == max_attempts:
|
||||
raise
|
||||
|
||||
@@ -550,6 +550,36 @@ class TestRlsTransaction:
|
||||
mock_sleep.assert_any_call(1.0)
|
||||
assert mock_logger.info.call_count == 2
|
||||
|
||||
def test_rls_transaction_operational_error_inside_context_no_retry(
|
||||
self, tenants_fixture, enable_read_replica
|
||||
):
|
||||
"""Test OperationalError raised inside context does not retry."""
|
||||
tenant = tenants_fixture[0]
|
||||
tenant_id = str(tenant.id)
|
||||
|
||||
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
|
||||
with patch("api.db_utils.connections") as mock_connections:
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
|
||||
mock_connections.__getitem__.return_value = mock_conn
|
||||
mock_connections.__contains__.return_value = True
|
||||
|
||||
with patch("api.db_utils.transaction.atomic") as mock_atomic:
|
||||
mock_atomic.return_value.__enter__.return_value = None
|
||||
mock_atomic.return_value.__exit__.return_value = False
|
||||
|
||||
with patch("api.db_utils.time.sleep") as mock_sleep:
|
||||
with patch(
|
||||
"api.db_utils.set_read_db_alias", return_value="token"
|
||||
):
|
||||
with patch("api.db_utils.reset_read_db_alias"):
|
||||
with pytest.raises(OperationalError):
|
||||
with rls_transaction(tenant_id):
|
||||
raise OperationalError("Conflict with recovery")
|
||||
|
||||
mock_sleep.assert_not_called()
|
||||
|
||||
def test_rls_transaction_max_three_attempts_for_replica(
|
||||
self, tenants_fixture, enable_read_replica
|
||||
):
|
||||
@@ -579,6 +609,38 @@ class TestRlsTransaction:
|
||||
|
||||
assert mock_atomic.call_count == 3
|
||||
|
||||
def test_rls_transaction_replica_no_retry_when_disabled(
|
||||
self, tenants_fixture, enable_read_replica
|
||||
):
|
||||
"""Test replica retry is disabled when retry_on_replica=False."""
|
||||
tenant = tenants_fixture[0]
|
||||
tenant_id = str(tenant.id)
|
||||
|
||||
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
|
||||
with patch("api.db_utils.connections") as mock_connections:
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
|
||||
mock_connections.__getitem__.return_value = mock_conn
|
||||
mock_connections.__contains__.return_value = True
|
||||
|
||||
with patch("api.db_utils.transaction.atomic") as mock_atomic:
|
||||
mock_atomic.side_effect = OperationalError("Replica error")
|
||||
|
||||
with patch("api.db_utils.time.sleep") as mock_sleep:
|
||||
with patch(
|
||||
"api.db_utils.set_read_db_alias", return_value="token"
|
||||
):
|
||||
with patch("api.db_utils.reset_read_db_alias"):
|
||||
with pytest.raises(OperationalError):
|
||||
with rls_transaction(
|
||||
tenant_id, retry_on_replica=False
|
||||
):
|
||||
pass
|
||||
|
||||
assert mock_atomic.call_count == 1
|
||||
mock_sleep.assert_not_called()
|
||||
|
||||
def test_rls_transaction_only_one_attempt_for_primary(self, tenants_fixture):
|
||||
"""Test only 1 attempt for primary database."""
|
||||
tenant = tenants_fixture[0]
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
import os
|
||||
import time
|
||||
from glob import glob
|
||||
|
||||
from celery.utils.log import get_task_logger
|
||||
from config.django.base import DJANGO_FINDINGS_BATCH_SIZE
|
||||
from django.db import OperationalError
|
||||
from tasks.utils import batched
|
||||
|
||||
from api.db_router import READ_REPLICA_ALIAS, MainRouter
|
||||
from api.db_utils import rls_transaction
|
||||
from api.db_utils import REPLICA_MAX_ATTEMPTS, REPLICA_RETRY_BASE_DELAY, rls_transaction
|
||||
from api.models import Finding, Integration, Provider
|
||||
from api.utils import initialize_prowler_integration, initialize_prowler_provider
|
||||
from prowler.lib.outputs.asff.asff import ASFF
|
||||
@@ -17,11 +19,11 @@ from prowler.lib.outputs.html.html import HTML
|
||||
from prowler.lib.outputs.ocsf.ocsf import OCSF
|
||||
from prowler.providers.aws.aws_provider import AwsProvider
|
||||
from prowler.providers.aws.lib.s3.s3 import S3
|
||||
from prowler.providers.aws.lib.security_hub.security_hub import SecurityHub
|
||||
from prowler.providers.common.models import Connection
|
||||
from prowler.providers.aws.lib.security_hub.exceptions.exceptions import (
|
||||
SecurityHubNoEnabledRegionsError,
|
||||
)
|
||||
from prowler.providers.aws.lib.security_hub.security_hub import SecurityHub
|
||||
from prowler.providers.common.models import Connection
|
||||
|
||||
logger = get_task_logger(__name__)
|
||||
|
||||
@@ -291,96 +293,130 @@ def upload_security_hub_integration(
|
||||
total_findings_sent[integration.id] = 0
|
||||
|
||||
# Process findings in batches to avoid memory issues
|
||||
max_attempts = REPLICA_MAX_ATTEMPTS if READ_REPLICA_ALIAS else 1
|
||||
has_findings = False
|
||||
batch_number = 0
|
||||
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
qs = (
|
||||
Finding.all_objects.filter(tenant_id=tenant_id, scan_id=scan_id)
|
||||
.order_by("uid")
|
||||
.iterator()
|
||||
)
|
||||
|
||||
for batch, _ in batched(qs, DJANGO_FINDINGS_BATCH_SIZE):
|
||||
batch_number += 1
|
||||
has_findings = True
|
||||
|
||||
# Transform findings for this batch
|
||||
transformed_findings = [
|
||||
FindingOutput.transform_api_finding(
|
||||
finding, prowler_provider
|
||||
)
|
||||
for finding in batch
|
||||
]
|
||||
|
||||
# Convert to ASFF format
|
||||
asff_transformer = ASFF(
|
||||
findings=transformed_findings,
|
||||
file_path="",
|
||||
file_extension="json",
|
||||
for attempt in range(1, max_attempts + 1):
|
||||
read_alias = None
|
||||
if READ_REPLICA_ALIAS:
|
||||
read_alias = (
|
||||
READ_REPLICA_ALIAS
|
||||
if attempt < max_attempts
|
||||
else MainRouter.default_db
|
||||
)
|
||||
asff_transformer.transform(transformed_findings)
|
||||
|
||||
# Get the batch of ASFF findings
|
||||
batch_asff_findings = asff_transformer.data
|
||||
|
||||
if batch_asff_findings:
|
||||
# Create Security Hub client for first batch or reuse existing
|
||||
if not security_hub_client:
|
||||
connected, security_hub = (
|
||||
get_security_hub_client_from_integration(
|
||||
integration, tenant_id, batch_asff_findings
|
||||
)
|
||||
try:
|
||||
batch_number = 0
|
||||
has_findings = False
|
||||
with rls_transaction(
|
||||
tenant_id,
|
||||
using=read_alias,
|
||||
retry_on_replica=False,
|
||||
):
|
||||
qs = (
|
||||
Finding.all_objects.filter(
|
||||
tenant_id=tenant_id, scan_id=scan_id
|
||||
)
|
||||
.order_by("uid")
|
||||
.iterator()
|
||||
)
|
||||
|
||||
if not connected:
|
||||
if isinstance(
|
||||
security_hub.error,
|
||||
SecurityHubNoEnabledRegionsError,
|
||||
):
|
||||
logger.warning(
|
||||
f"Security Hub integration {integration.id} has no enabled regions"
|
||||
for batch, _ in batched(qs, DJANGO_FINDINGS_BATCH_SIZE):
|
||||
batch_number += 1
|
||||
has_findings = True
|
||||
|
||||
# Transform findings for this batch
|
||||
transformed_findings = [
|
||||
FindingOutput.transform_api_finding(
|
||||
finding, prowler_provider
|
||||
)
|
||||
for finding in batch
|
||||
]
|
||||
|
||||
# Convert to ASFF format
|
||||
asff_transformer = ASFF(
|
||||
findings=transformed_findings,
|
||||
file_path="",
|
||||
file_extension="json",
|
||||
)
|
||||
asff_transformer.transform(transformed_findings)
|
||||
|
||||
# Get the batch of ASFF findings
|
||||
batch_asff_findings = asff_transformer.data
|
||||
|
||||
if batch_asff_findings:
|
||||
# Create Security Hub client for first batch or reuse existing
|
||||
if not security_hub_client:
|
||||
connected, security_hub = (
|
||||
get_security_hub_client_from_integration(
|
||||
integration,
|
||||
tenant_id,
|
||||
batch_asff_findings,
|
||||
)
|
||||
)
|
||||
|
||||
if not connected:
|
||||
if isinstance(
|
||||
security_hub.error,
|
||||
SecurityHubNoEnabledRegionsError,
|
||||
):
|
||||
logger.warning(
|
||||
f"Security Hub integration {integration.id} has no enabled regions"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Security Hub connection failed for integration {integration.id}: "
|
||||
f"{security_hub.error}"
|
||||
)
|
||||
break # Skip this integration
|
||||
|
||||
security_hub_client = security_hub
|
||||
logger.info(
|
||||
f"Sending {'fail' if send_only_fails else 'all'} findings to Security Hub via "
|
||||
f"integration {integration.id}"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Security Hub connection failed for integration {integration.id}: "
|
||||
f"{security_hub.error}"
|
||||
# Update findings in existing client for this batch
|
||||
security_hub_client._findings_per_region = (
|
||||
security_hub_client.filter(
|
||||
batch_asff_findings,
|
||||
send_only_fails,
|
||||
)
|
||||
)
|
||||
break # Skip this integration
|
||||
|
||||
security_hub_client = security_hub
|
||||
logger.info(
|
||||
f"Sending {'fail' if send_only_fails else 'all'} findings to Security Hub via "
|
||||
f"integration {integration.id}"
|
||||
)
|
||||
else:
|
||||
# Update findings in existing client for this batch
|
||||
security_hub_client._findings_per_region = (
|
||||
security_hub_client.filter(
|
||||
batch_asff_findings, send_only_fails
|
||||
)
|
||||
)
|
||||
# Send this batch to Security Hub
|
||||
try:
|
||||
findings_sent = security_hub_client.batch_send_to_security_hub()
|
||||
total_findings_sent[integration.id] += (
|
||||
findings_sent
|
||||
)
|
||||
|
||||
# Send this batch to Security Hub
|
||||
try:
|
||||
findings_sent = (
|
||||
security_hub_client.batch_send_to_security_hub()
|
||||
)
|
||||
total_findings_sent[integration.id] += findings_sent
|
||||
if findings_sent > 0:
|
||||
logger.debug(
|
||||
f"Sent batch {batch_number} with {findings_sent} findings to Security Hub"
|
||||
)
|
||||
except Exception as batch_error:
|
||||
logger.error(
|
||||
f"Failed to send batch {batch_number} to Security Hub: {str(batch_error)}"
|
||||
)
|
||||
|
||||
if findings_sent > 0:
|
||||
logger.debug(
|
||||
f"Sent batch {batch_number} with {findings_sent} findings to Security Hub"
|
||||
)
|
||||
except Exception as batch_error:
|
||||
logger.error(
|
||||
f"Failed to send batch {batch_number} to Security Hub: {str(batch_error)}"
|
||||
)
|
||||
# Clear memory after processing each batch
|
||||
asff_transformer._data.clear()
|
||||
del batch_asff_findings
|
||||
del transformed_findings
|
||||
|
||||
# Clear memory after processing each batch
|
||||
asff_transformer._data.clear()
|
||||
del batch_asff_findings
|
||||
del transformed_findings
|
||||
break
|
||||
except OperationalError as e:
|
||||
if attempt == max_attempts:
|
||||
raise
|
||||
|
||||
delay = REPLICA_RETRY_BASE_DELAY * (2 ** (attempt - 1))
|
||||
logger.info(
|
||||
"RLS query failed during Security Hub integration "
|
||||
f"(attempt {attempt}/{max_attempts}), retrying in {delay}s. Error: {e}"
|
||||
)
|
||||
time.sleep(delay)
|
||||
|
||||
if not has_findings:
|
||||
logger.info(
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from django.db import OperationalError
|
||||
from tasks.jobs.integrations import (
|
||||
get_s3_client_from_integration,
|
||||
get_security_hub_client_from_integration,
|
||||
@@ -1056,6 +1057,84 @@ class TestSecurityHubIntegrationUploads:
|
||||
mock_security_hub.batch_send_to_security_hub.assert_called_once()
|
||||
mock_security_hub.archive_previous_findings.assert_called_once()
|
||||
|
||||
@patch("tasks.jobs.integrations.time.sleep")
|
||||
@patch("tasks.jobs.integrations.batched")
|
||||
@patch("tasks.jobs.integrations.get_security_hub_client_from_integration")
|
||||
@patch("tasks.jobs.integrations.initialize_prowler_provider")
|
||||
@patch("tasks.jobs.integrations.rls_transaction")
|
||||
@patch("tasks.jobs.integrations.Integration")
|
||||
@patch("tasks.jobs.integrations.Provider")
|
||||
@patch("tasks.jobs.integrations.Finding")
|
||||
def test_upload_security_hub_integration_retries_on_operational_error(
|
||||
self,
|
||||
mock_finding_model,
|
||||
mock_provider_model,
|
||||
mock_integration_model,
|
||||
mock_rls,
|
||||
mock_initialize_provider,
|
||||
mock_get_security_hub,
|
||||
mock_batched,
|
||||
mock_sleep,
|
||||
):
|
||||
"""Test SecurityHub upload retries on transient OperationalError."""
|
||||
tenant_id = "tenant-id"
|
||||
provider_id = "provider-id"
|
||||
scan_id = "scan-123"
|
||||
|
||||
integration = MagicMock()
|
||||
integration.id = "integration-1"
|
||||
integration.configuration = {
|
||||
"send_only_fails": True,
|
||||
"archive_previous_findings": False,
|
||||
}
|
||||
mock_integration_model.objects.filter.return_value = [integration]
|
||||
|
||||
provider = MagicMock()
|
||||
mock_provider_model.objects.get.return_value = provider
|
||||
|
||||
mock_prowler_provider = MagicMock()
|
||||
mock_initialize_provider.return_value = mock_prowler_provider
|
||||
|
||||
mock_findings = [MagicMock(), MagicMock()]
|
||||
mock_finding_model.all_objects.filter.return_value.order_by.return_value.iterator.return_value = iter(
|
||||
mock_findings
|
||||
)
|
||||
|
||||
transformed_findings = [MagicMock(), MagicMock()]
|
||||
with patch("tasks.jobs.integrations.FindingOutput") as mock_finding_output:
|
||||
mock_finding_output.transform_api_finding.side_effect = transformed_findings
|
||||
|
||||
with patch("tasks.jobs.integrations.ASFF") as mock_asff:
|
||||
mock_asff_instance = MagicMock()
|
||||
finding1 = MagicMock()
|
||||
finding1.Compliance.Status = "FAILED"
|
||||
finding2 = MagicMock()
|
||||
finding2.Compliance.Status = "FAILED"
|
||||
mock_asff_instance.data = [finding1, finding2]
|
||||
mock_asff_instance._data = MagicMock()
|
||||
mock_asff.return_value = mock_asff_instance
|
||||
|
||||
mock_security_hub = MagicMock()
|
||||
mock_security_hub.batch_send_to_security_hub.return_value = 2
|
||||
mock_get_security_hub.return_value = (True, mock_security_hub)
|
||||
|
||||
mock_rls.return_value.__enter__.return_value = None
|
||||
mock_rls.return_value.__exit__.return_value = False
|
||||
|
||||
mock_batched.side_effect = [
|
||||
OperationalError("Conflict with recovery"),
|
||||
[(mock_findings, None)],
|
||||
]
|
||||
|
||||
with patch("tasks.jobs.integrations.REPLICA_MAX_ATTEMPTS", 2):
|
||||
with patch("tasks.jobs.integrations.READ_REPLICA_ALIAS", "replica"):
|
||||
result = upload_security_hub_integration(
|
||||
tenant_id, provider_id, scan_id
|
||||
)
|
||||
|
||||
assert result is True
|
||||
mock_sleep.assert_called_once()
|
||||
|
||||
@patch("tasks.jobs.integrations.get_security_hub_client_from_integration")
|
||||
@patch("tasks.jobs.integrations.initialize_prowler_provider")
|
||||
@patch("tasks.jobs.integrations.rls_transaction")
|
||||
|
||||
Reference in New Issue
Block a user