Compare commits

...

3 Commits

36 changed files with 2362 additions and 2092 deletions

View File

@@ -17,6 +17,7 @@ All notable changes to the **Prowler API** are documented in this file.
### 🐞 Fixed
- Attack Paths: Recover `graph_data_ready` flag when scan fails during graph swap, preventing query endpoints from staying blocked until the next successful scan [(#10354)](https://github.com/prowler-cloud/prowler/pull/10354)
- Rewrite `rls_transaction` to retry mid-query read replica failures with primary DB fallback, fixing scan crashes during RDS replica recovery [(#10374)](https://github.com/prowler-cloud/prowler/pull/10374)
### 🔐 Security

View File

@@ -45,26 +45,27 @@ class ProwlerSocialAccountAdapter(DefaultSocialAccountAdapter):
tenant = Tenant.objects.using(MainRouter.admin_db).create(
name=f"{user.email.split('@')[0]} default tenant"
)
with rls_transaction(str(tenant.id)):
Membership.objects.using(MainRouter.admin_db).create(
user=user, tenant=tenant, role=Membership.RoleChoices.OWNER
)
role = Role.objects.using(MainRouter.admin_db).create(
name="admin",
tenant_id=tenant.id,
manage_users=True,
manage_account=True,
manage_billing=True,
manage_providers=True,
manage_integrations=True,
manage_scans=True,
unlimited_visibility=True,
)
UserRoleRelationship.objects.using(MainRouter.admin_db).create(
user=user,
role=role,
tenant_id=tenant.id,
)
for attempt in rls_transaction(str(tenant.id)):
with attempt:
Membership.objects.using(MainRouter.admin_db).create(
user=user, tenant=tenant, role=Membership.RoleChoices.OWNER
)
role = Role.objects.using(MainRouter.admin_db).create(
name="admin",
tenant_id=tenant.id,
manage_users=True,
manage_account=True,
manage_billing=True,
manage_providers=True,
manage_integrations=True,
manage_scans=True,
unlimited_visibility=True,
)
UserRoleRelationship.objects.using(MainRouter.admin_db).create(
user=user,
role=role,
tenant_id=tenant.id,
)
else:
request.session["saml_user_created"] = str(user.id)

View File

@@ -90,11 +90,12 @@ class BaseRLSViewSet(BaseViewSet):
if tenant_id is None:
raise NotAuthenticated("Tenant ID is not present in token")
with rls_transaction(
for attempt in rls_transaction(
tenant_id, using=getattr(self, "db_alias", MainRouter.default_db)
):
self.request.tenant_id = tenant_id
return super().initial(request, *args, **kwargs)
with attempt:
self.request.tenant_id = tenant_id
return super().initial(request, *args, **kwargs)
def get_serializer_context(self):
context = super().get_serializer_context()
@@ -163,12 +164,13 @@ class BaseTenantViewset(BaseViewSet):
raise NotAuthenticated("Tenant ID is not present in token")
user_id = str(request.user.id)
with rls_transaction(
for attempt in rls_transaction(
value=user_id,
parameter=POSTGRES_USER_VAR,
using=getattr(self, "db_alias", MainRouter.default_db),
):
return super().initial(request, *args, **kwargs)
with attempt:
return super().initial(request, *args, **kwargs)
class BaseUserViewset(BaseViewSet):
@@ -200,8 +202,9 @@ class BaseUserViewset(BaseViewSet):
if tenant_id is None:
raise NotAuthenticated("Tenant ID is not present in token")
with rls_transaction(
for attempt in rls_transaction(
tenant_id, using=getattr(self, "db_alias", MainRouter.default_db)
):
self.request.tenant_id = tenant_id
return super().initial(request, *args, **kwargs)
with attempt:
self.request.tenant_id = tenant_id
return super().initial(request, *args, **kwargs)

View File

@@ -71,21 +71,14 @@ def psycopg_connection(database_alias: str):
@contextmanager
def rls_transaction(
def _rls_transaction_context_manager(
value: str,
parameter: str = POSTGRES_TENANT_VAR,
using: str | None = None,
retry_on_replica: bool = True,
):
"""
Creates a new database transaction setting the given configuration value for Postgres RLS. It validates
that the value is a valid UUID.
"""Internal context manager that opens a single RLS transaction.
Args:
value (str): Database configuration parameter value.
parameter (str): Database configuration parameter name, by default is 'api.tenant_id'.
using (str | None): Optional database alias to run the transaction against. Defaults to the
active read alias (if any) or Django's default connection.
Callers should use ``rls_transaction`` (the public class) instead.
"""
requested_alias = using or get_read_db_alias()
db_alias = requested_alias or DEFAULT_DB_ALIAS
@@ -93,54 +86,164 @@ def rls_transaction(
db_alias = DEFAULT_DB_ALIAS
alias = db_alias
is_replica = READ_REPLICA_ALIAS and alias == READ_REPLICA_ALIAS
max_attempts = REPLICA_MAX_ATTEMPTS if is_replica and retry_on_replica else 1
router_token = None
conn = connections[alias]
try:
if alias != DEFAULT_DB_ALIAS:
router_token = set_read_db_alias(alias)
for attempt in range(1, max_attempts + 1):
router_token = None
yielded_cursor = False
with transaction.atomic(using=alias):
with conn.cursor() as cursor:
try:
uuid.UUID(str(value))
except ValueError:
raise ValidationError("Must be a valid UUID")
cursor.execute(SET_CONFIG_QUERY, [parameter, value])
yield cursor
finally:
if router_token is not None:
reset_read_db_alias(router_token)
# On final attempt, fallback to primary
if attempt == max_attempts and is_replica:
logger.warning(
f"RLS transaction failed after {attempt - 1} attempts on replica, "
f"falling back to primary DB"
)
class rls_transaction:
    """Iterable RLS transaction with retry and replica-to-primary fallback.

    Usage::

        for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
            with attempt:
                result = Model.objects.filter(...)

    When ``using`` points to a read replica and ``retry_on_replica`` is True,
    iterating yields up to ``REPLICA_MAX_ATTEMPTS`` attempts on the replica
    followed by one final attempt on the primary DB. For primary-only calls
    a single attempt is yielded and no retry takes place.
    """

    def __init__(
        self,
        value: str,
        parameter: str = POSTGRES_TENANT_VAR,
        using: str | None = None,
        retry_on_replica: bool = True,
    ):
        # Only record the arguments here; alias resolution and attempt
        # bookkeeping live in the iterator so each ``for`` loop starts fresh.
        self.value, self.parameter = value, parameter
        self.using, self.retry_on_replica = using, retry_on_replica

    def __iter__(self):
        # A new iterator per loop: attempt counters must not be shared
        # between concurrent or repeated iterations of the same object.
        return _RLSRetryIterator(self)
class _RLSRetryIterator:
    """Yields ``_RLSAttempt`` objects: N replica attempts, then one primary fallback.

    Internal companion of ``rls_transaction``; not part of the public API.
    ``_done`` is flipped by ``_RLSAttempt`` after a successful attempt so the
    ``for`` loop terminates early.
    """

    def __init__(self, parent: rls_transaction):
        self._parent = parent
        self._attempt = 0  # number of attempts handed out so far
        self._done = False  # set by _RLSAttempt on a clean exit

        # Resolve the effective DB alias: explicit ``using`` wins, then the
        # thread-local read alias, then Django's default connection.
        requested = parent.using or get_read_db_alias()
        self._db_alias = requested or DEFAULT_DB_ALIAS
        if self._db_alias not in connections:
            self._db_alias = DEFAULT_DB_ALIAS

        self._is_replica = bool(
            READ_REPLICA_ALIAS and self._db_alias == READ_REPLICA_ALIAS
        )
        if self._is_replica and parent.retry_on_replica:
            # N replica attempts + 1 primary fallback
            self._max_attempts = REPLICA_MAX_ATTEMPTS + 1
        else:
            self._max_attempts = 1

    def __iter__(self):
        return self

    def __next__(self):
        if self._done or self._attempt >= self._max_attempts:
            raise StopIteration
        self._attempt += 1
        is_final = self._attempt == self._max_attempts
        if is_final and self._is_replica and self._parent.retry_on_replica:
            # Last chance: abandon the replica and run on the primary DB.
            alias = DEFAULT_DB_ALIAS
            if self._attempt > 1:
                # Only warn when replica attempts actually failed
                # (guards the REPLICA_MAX_ATTEMPTS == 0 edge case).
                logger.warning(
                    f"RLS transaction failed after {self._attempt - 1} attempts on replica, "
                    f"falling back to primary DB"
                )
        else:
            alias = self._db_alias
        # NOTE: the original grabbed ``connections[alias]`` here into an unused
        # local; that dead statement is removed — _RLSAttempt resolves the
        # connection itself when the attempt is entered.
        return _RLSAttempt(self, alias, is_final)
class _RLSAttempt:
    """A single RLS transaction attempt, usable as a ``with`` context manager.

    Delegates to ``_rls_transaction_context_manager`` and converts
    ``OperationalError`` into a retry (with exponential backoff) while further
    attempts remain on the replica; on the final attempt, or when running on
    the primary, errors propagate unchanged.
    """

    def __init__(self, iterator: _RLSRetryIterator, alias: str, is_final: bool):
        self._iterator = iterator
        self._alias = alias
        self._is_final = is_final
        self._context_manager = None

    def __enter__(self):
        # Retry loop for connection-setup errors (pre-yield).
        # Python does NOT call __exit__ when __enter__ raises, so we
        # must handle retries here for errors like "connection refused".
        while True:
            try:
                self._context_manager = _rls_transaction_context_manager(
                    self._iterator._parent.value,
                    self._iterator._parent.parameter,
                    using=self._alias,
                )
                return self._context_manager.__enter__()
            except OperationalError as exc:
                if self._is_final or not self._iterator._is_replica:
                    raise
                self._handle_retry(exc)
                # Consume the next attempt from the iterator and adopt its
                # target alias/finality for the next setup try.
                next_att = next(self._iterator)
                self._alias = next_att._alias
                self._is_final = next_att._is_final

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            self._context_manager.__exit__(exc_type, exc_val, exc_tb)
        except OperationalError as exc:
            # The COMMIT (or the context manager's own cleanup) failed.
            if self._is_final or not self._iterator._is_replica:
                raise
            self._handle_retry(exc)
            return True  # swallow so the for-loop yields another attempt
        if exc_type is None:
            # Attempt succeeded: stop the surrounding iteration.
            self._iterator._done = True
            return False
        if (
            issubclass(exc_type, OperationalError)
            and not self._is_final
            and self._iterator._is_replica
        ):
            # Body failed mid-transaction on the replica: back off and retry.
            self._handle_retry(exc_val)
            return True
        return False

    def _handle_retry(self, error):
        # Drop the (possibly broken) connection so the next attempt
        # reconnects from scratch rather than reusing a dead socket.
        try:
            connections[self._alias].close()
        except Exception:
            pass
        attempt = self._iterator._attempt
        max_att = self._iterator._max_attempts
        # Exponential backoff: base * 2^(attempt - 1)
        delay = REPLICA_RETRY_BASE_DELAY * (2 ** (attempt - 1))
        logger.info(
            f"RLS transaction failed on replica (attempt {attempt}/{max_att}), "
            f"retrying in {delay:.1f}s. Error: {error}"
        )
        time.sleep(delay)
class CustomUserManager(BaseUserManager):
@@ -197,23 +300,21 @@ def batch_delete(tenant_id, queryset, batch_size=settings.DJANGO_DELETION_BATCH_
deletion_summary = {}
while True:
with rls_transaction(tenant_id, POSTGRES_TENANT_VAR):
# Get a batch of IDs to delete
batch_ids = set(
queryset.values_list("id", flat=True).order_by("id")[:batch_size]
)
if not batch_ids:
# No more objects to delete
break
for attempt in rls_transaction(tenant_id, POSTGRES_TENANT_VAR):
with attempt:
# Get a batch of IDs to delete
batch_ids = set(
queryset.values_list("id", flat=True).order_by("id")[:batch_size]
)
if not batch_ids:
return total_deleted, deletion_summary
deleted_count, deleted_info = queryset.filter(id__in=batch_ids).delete()
deleted_count, deleted_info = queryset.filter(id__in=batch_ids).delete()
total_deleted += deleted_count
for model_label, count in deleted_info.items():
deletion_summary[model_label] = deletion_summary.get(model_label, 0) + count
return total_deleted, deletion_summary
def delete_related_daily_task(provider_id: str):
"""
@@ -245,8 +346,9 @@ def create_objects_in_batches(
total = len(objects)
for i in range(0, total, batch_size):
chunk = objects[i : i + batch_size]
with rls_transaction(value=tenant_id, parameter=POSTGRES_TENANT_VAR):
model.objects.bulk_create(chunk, batch_size)
for attempt in rls_transaction(value=tenant_id, parameter=POSTGRES_TENANT_VAR):
with attempt:
model.objects.bulk_create(chunk, batch_size)
def update_objects_in_batches(
@@ -268,8 +370,9 @@ def update_objects_in_batches(
total = len(objects)
for start in range(0, total, batch_size):
chunk = objects[start : start + batch_size]
with rls_transaction(value=tenant_id, parameter=POSTGRES_TENANT_VAR):
model.objects.bulk_update(chunk, fields, batch_size)
for attempt in rls_transaction(value=tenant_id, parameter=POSTGRES_TENANT_VAR):
with attempt:
model.objects.bulk_update(chunk, fields, batch_size)
# Postgres Enums

View File

@@ -97,23 +97,24 @@ def handle_provider_deletion(func):
tenant_id = kwargs.get("tenant_id")
provider_id = kwargs.get("provider_id")
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
if provider_id is None:
scan_id = kwargs.get("scan_id")
if scan_id is None:
raise AssertionError(
"This task does not have provider or scan in the kwargs"
)
scan = Scan.objects.filter(pk=scan_id).first()
if scan is None:
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
with attempt:
if provider_id is None:
scan_id = kwargs.get("scan_id")
if scan_id is None:
raise AssertionError(
"This task does not have provider or scan in the kwargs"
)
scan = Scan.objects.filter(pk=scan_id).first()
if scan is None:
raise ProviderDeletedException(
f"Provider for scan '{scan_id}' was deleted during the scan"
) from None
provider_id = str(scan.provider_id)
if not Provider.objects.filter(pk=provider_id).exists():
raise ProviderDeletedException(
f"Provider for scan '{scan_id}' was deleted during the scan"
f"Provider '{provider_id}' was deleted during the scan"
) from None
provider_id = str(scan.provider_id)
if not Provider.objects.filter(pk=provider_id).exists():
raise ProviderDeletedException(
f"Provider '{provider_id}' was deleted during the scan"
) from None
raise
return wrapper

View File

@@ -97,27 +97,29 @@ class Command(BaseCommand):
):
possible_types.append(check_metadata.ResourceType)
with rls_transaction(tenant_id):
provider, _ = Provider.all_objects.get_or_create(
tenant_id=tenant_id,
provider="aws",
connected=True,
uid=str(random.randint(100000000000, 999999999999)),
defaults={
"alias": alias,
},
)
for attempt in rls_transaction(tenant_id):
with attempt:
provider, _ = Provider.all_objects.get_or_create(
tenant_id=tenant_id,
provider="aws",
connected=True,
uid=str(random.randint(100000000000, 999999999999)),
defaults={
"alias": alias,
},
)
with rls_transaction(tenant_id):
scan = Scan.all_objects.create(
tenant_id=tenant_id,
provider=provider,
name=alias,
trigger="manual",
state="executing",
progress=0,
started_at=datetime.now(timezone.utc),
)
for attempt in rls_transaction(tenant_id):
with attempt:
scan = Scan.all_objects.create(
tenant_id=tenant_id,
provider=provider,
name=alias,
trigger="manual",
state="executing",
progress=0,
started_at=datetime.now(timezone.utc),
)
scan_state = "completed"
try:
@@ -141,13 +143,15 @@ class Command(BaseCommand):
num_batches = ceil(len(resources) / batch_size)
self.stdout.write(self.style.WARNING("Creating resources..."))
for i in tqdm(range(0, len(resources), batch_size), total=num_batches):
with rls_transaction(tenant_id):
Resource.all_objects.bulk_create(resources[i : i + batch_size])
for attempt in rls_transaction(tenant_id):
with attempt:
Resource.all_objects.bulk_create(resources[i : i + batch_size])
self.stdout.write(self.style.SUCCESS("Resources created successfully.\n\n"))
with rls_transaction(tenant_id):
scan.progress = 33
scan.save()
for attempt in rls_transaction(tenant_id):
with attempt:
scan.progress = 33
scan.save()
# Create Findings
findings = []
@@ -193,13 +197,15 @@ class Command(BaseCommand):
num_batches = ceil(len(findings) / batch_size)
self.stdout.write(self.style.WARNING("Creating findings..."))
for i in tqdm(range(0, len(findings), batch_size), total=num_batches):
with rls_transaction(tenant_id):
Finding.all_objects.bulk_create(findings[i : i + batch_size])
for attempt in rls_transaction(tenant_id):
with attempt:
Finding.all_objects.bulk_create(findings[i : i + batch_size])
self.stdout.write(self.style.SUCCESS("Findings created successfully.\n\n"))
with rls_transaction(tenant_id):
scan.progress = 66
scan.save()
for attempt in rls_transaction(tenant_id):
with attempt:
scan.progress = 66
scan.save()
# Create ResourceFindingMapping
mappings = []
@@ -227,19 +233,21 @@ class Command(BaseCommand):
self.style.WARNING("Creating resource-finding mappings...")
)
for i in tqdm(range(0, len(mappings), batch_size), total=num_batches):
with rls_transaction(tenant_id):
ResourceFindingMapping.objects.bulk_create(
mappings[i : i + batch_size]
)
for attempt in rls_transaction(tenant_id):
with attempt:
ResourceFindingMapping.objects.bulk_create(
mappings[i : i + batch_size]
)
self.stdout.write(
self.style.SUCCESS(
"Resource-finding mappings created successfully.\n\n"
)
)
with rls_transaction(tenant_id):
scan.progress = 99
scan.save()
for attempt in rls_transaction(tenant_id):
with attempt:
scan.progress = 99
scan.save()
self.stdout.write(self.style.WARNING("Creating finding filter values..."))
resource_scan_summaries = [
@@ -254,16 +262,18 @@ class Command(BaseCommand):
for resource_id, service, region, resource_type in scan_resource_cache
]
num_batches = ceil(len(resource_scan_summaries) / batch_size)
with rls_transaction(tenant_id):
for i in tqdm(
range(0, len(resource_scan_summaries), batch_size),
total=num_batches,
):
with rls_transaction(tenant_id):
ResourceScanSummary.objects.bulk_create(
resource_scan_summaries[i : i + batch_size],
ignore_conflicts=True,
)
for attempt in rls_transaction(tenant_id):
with attempt:
for i in tqdm(
range(0, len(resource_scan_summaries), batch_size),
total=num_batches,
):
for inner_attempt in rls_transaction(tenant_id):
with inner_attempt:
ResourceScanSummary.objects.bulk_create(
resource_scan_summaries[i : i + batch_size],
ignore_conflicts=True,
)
self.stdout.write(
self.style.SUCCESS("Finding filter values created successfully.\n\n")
@@ -279,7 +289,8 @@ class Command(BaseCommand):
scan.progress = 100
scan.state = scan_state
scan.unique_resource_count = num_resources
with rls_transaction(tenant_id):
scan.save()
for attempt in rls_transaction(tenant_id):
with attempt:
scan.save()
self.stdout.write(self.style.NOTICE("Successfully populated test data."))

View File

@@ -29,16 +29,17 @@ def migrate_daily_scheduled_scan_tasks(apps, schema_editor):
else:
next_scan_date = scheduled_time_today + timedelta(days=1)
with rls_transaction(tenant_id):
Scan.objects.create(
tenant_id=tenant_id,
name="Daily scheduled scan",
provider_id=provider_id,
trigger=Scan.TriggerChoices.SCHEDULED,
state=StateChoices.SCHEDULED,
scheduled_at=next_scan_date,
scheduler_task_id=daily_scheduled_scan_task.id,
)
for attempt in rls_transaction(tenant_id):
with attempt:
Scan.objects.create(
tenant_id=tenant_id,
name="Daily scheduled scan",
provider_id=provider_id,
trigger=Scan.TriggerChoices.SCHEDULED,
state=StateChoices.SCHEDULED,
scheduled_at=next_scan_date,
scheduler_task_id=daily_scheduled_scan_task.id,
)
class Migration(migrations.Migration):

View File

@@ -1,5 +1,3 @@
from contextlib import nullcontext
from rest_framework.renderers import BaseRenderer
from rest_framework_json_api.renderers import JSONRenderer
@@ -28,11 +26,8 @@ class APIJSONRenderer(JSONRenderer):
db_alias = getattr(request, "db_alias", None) if request else None
include_param_present = "include" in request.query_params if request else False
# Use rls_transaction if needed for included resources, otherwise do nothing
context_manager = (
rls_transaction(tenant_id, using=db_alias)
if tenant_id and include_param_present
else nullcontext()
)
with context_manager:
return super().render(data, accepted_media_type, renderer_context)
if tenant_id and include_param_present:
for attempt in rls_transaction(tenant_id, using=db_alias):
with attempt:
return super().render(data, accepted_media_type, renderer_context)
return super().render(data, accepted_media_type, renderer_context)

View File

@@ -17,23 +17,26 @@ class TestRLSTransaction:
def test_success_on_primary(self, tenant):
"""Basic: transaction succeeds on primary database."""
with rls_transaction(str(tenant.id), using=DEFAULT_DB_ALIAS) as cursor:
cursor.execute("SELECT 1")
result = cursor.fetchone()
assert result == (1,)
for attempt in rls_transaction(str(tenant.id), using=DEFAULT_DB_ALIAS):
with attempt as cursor:
cursor.execute("SELECT 1")
result = cursor.fetchone()
assert result == (1,)
def test_invalid_uuid_raises_validation_error(self):
"""Invalid UUID raises ValidationError before DB operations."""
with pytest.raises(ValidationError, match="Must be a valid UUID"):
with rls_transaction("not-a-uuid", using=DEFAULT_DB_ALIAS):
pass
for attempt in rls_transaction("not-a-uuid", using=DEFAULT_DB_ALIAS):
with attempt:
pass
def test_custom_parameter_name(self, tenant):
"""Test custom RLS parameter name."""
custom_param = "api.custom_id"
with rls_transaction(
for attempt in rls_transaction(
str(tenant.id), parameter=custom_param, using=DEFAULT_DB_ALIAS
) as cursor:
cursor.execute("SELECT current_setting(%s, true)", [custom_param])
result = cursor.fetchone()
assert result == (str(tenant.id),)
):
with attempt as cursor:
cursor.execute("SELECT current_setting(%s, true)", [custom_param])
result = cursor.fetchone()
assert result == (str(tenant.id),)

View File

@@ -364,29 +364,32 @@ class TestRlsTransaction:
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
with rls_transaction(tenant_id) as cursor:
assert cursor is not None
cursor.execute("SELECT current_setting(%s)", [POSTGRES_TENANT_VAR])
result = cursor.fetchone()
assert result[0] == tenant_id
for attempt in rls_transaction(tenant_id):
with attempt as cursor:
assert cursor is not None
cursor.execute("SELECT current_setting(%s)", [POSTGRES_TENANT_VAR])
result = cursor.fetchone()
assert result[0] == tenant_id
def test_rls_transaction_valid_uuid_object(self, tenants_fixture):
"""Test rls_transaction with UUID object."""
tenant = tenants_fixture[0]
with rls_transaction(tenant.id) as cursor:
assert cursor is not None
cursor.execute("SELECT current_setting(%s)", [POSTGRES_TENANT_VAR])
result = cursor.fetchone()
assert result[0] == str(tenant.id)
for attempt in rls_transaction(tenant.id):
with attempt as cursor:
assert cursor is not None
cursor.execute("SELECT current_setting(%s)", [POSTGRES_TENANT_VAR])
result = cursor.fetchone()
assert result[0] == str(tenant.id)
def test_rls_transaction_invalid_uuid_raises_validation_error(self):
"""Test rls_transaction raises ValidationError for invalid UUID."""
invalid_uuid = "not-a-valid-uuid"
with pytest.raises(ValidationError, match="Must be a valid UUID"):
with rls_transaction(invalid_uuid):
pass
for attempt in rls_transaction(invalid_uuid):
with attempt:
pass
def test_rls_transaction_uses_default_database_when_no_alias(self, tenants_fixture):
"""Test rls_transaction uses DEFAULT_DB_ALIAS when no alias specified."""
@@ -402,8 +405,9 @@ class TestRlsTransaction:
mock_connections.__contains__.return_value = True
with patch("api.db_utils.transaction.atomic"):
with rls_transaction(tenant_id):
pass
for attempt in rls_transaction(tenant_id):
with attempt:
pass
mock_connections.__getitem__.assert_called_with(DEFAULT_DB_ALIAS)
@@ -424,8 +428,9 @@ class TestRlsTransaction:
with patch("api.db_utils.set_read_db_alias") as mock_set_alias:
with patch("api.db_utils.reset_read_db_alias") as mock_reset_alias:
mock_set_alias.return_value = "test_token"
with rls_transaction(tenant_id, using=custom_alias):
pass
for attempt in rls_transaction(tenant_id, using=custom_alias):
with attempt:
pass
mock_connections.__getitem__.assert_called_with(custom_alias)
mock_set_alias.assert_called_once_with(custom_alias)
@@ -452,8 +457,9 @@ class TestRlsTransaction:
"api.db_utils.reset_read_db_alias"
) as mock_reset_alias:
mock_set_alias.return_value = "test_token"
with rls_transaction(tenant_id):
pass
for attempt in rls_transaction(tenant_id):
with attempt:
pass
mock_connections.__getitem__.assert_called()
mock_set_alias.assert_called_once()
@@ -480,8 +486,9 @@ class TestRlsTransaction:
mock_connections.__getitem__.return_value = mock_conn
with patch("api.db_utils.transaction.atomic"):
with rls_transaction(tenant_id):
pass
for attempt in rls_transaction(tenant_id):
with attempt:
pass
mock_connections.__getitem__.assert_called_with(DEFAULT_DB_ALIAS)
@@ -503,15 +510,20 @@ class TestRlsTransaction:
with patch("api.db_utils.transaction.atomic"):
with patch("api.db_utils.set_read_db_alias", return_value="token"):
with patch("api.db_utils.reset_read_db_alias"):
with rls_transaction(tenant_id):
pass
for attempt in rls_transaction(tenant_id):
with attempt:
pass
assert mock_cursor.execute.call_count == 1
def test_rls_transaction_retry_with_exponential_backoff_on_operational_error(
self, tenants_fixture, enable_read_replica
):
"""Test retry with exponential backoff on OperationalError on replica."""
"""Test retry with exponential backoff on OperationalError on replica.
REPLICA_MAX_ATTEMPTS=3 means 3 replica tries + 1 primary fallback = 4 total.
First 3 attempts fail, 4th succeeds on primary.
"""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
@@ -528,7 +540,7 @@ class TestRlsTransaction:
def atomic_side_effect(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count < 3:
if call_count <= 3:
raise OperationalError("Connection error")
return MagicMock(
__enter__=MagicMock(return_value=None),
@@ -544,48 +556,24 @@ class TestRlsTransaction:
):
with patch("api.db_utils.reset_read_db_alias"):
with patch("api.db_utils.logger") as mock_logger:
with rls_transaction(tenant_id):
pass
for attempt in rls_transaction(tenant_id):
with attempt:
pass
assert mock_sleep.call_count == 2
assert mock_sleep.call_count == 3
mock_sleep.assert_any_call(0.5)
mock_sleep.assert_any_call(1.0)
assert mock_logger.info.call_count == 2
def test_rls_transaction_operational_error_inside_context_no_retry(
self, tenants_fixture, enable_read_replica
):
"""Test OperationalError raised inside context does not retry."""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
with patch("api.db_utils.connections") as mock_connections:
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
mock_connections.__getitem__.return_value = mock_conn
mock_connections.__contains__.return_value = True
with patch("api.db_utils.transaction.atomic") as mock_atomic:
mock_atomic.return_value.__enter__.return_value = None
mock_atomic.return_value.__exit__.return_value = False
with patch("api.db_utils.time.sleep") as mock_sleep:
with patch(
"api.db_utils.set_read_db_alias", return_value="token"
):
with patch("api.db_utils.reset_read_db_alias"):
with pytest.raises(OperationalError):
with rls_transaction(tenant_id):
raise OperationalError("Conflict with recovery")
mock_sleep.assert_not_called()
mock_sleep.assert_any_call(2.0)
assert mock_logger.info.call_count == 3
assert mock_logger.warning.call_count == 1
def test_rls_transaction_max_three_attempts_for_replica(
self, tenants_fixture, enable_read_replica
):
"""Test maximum 3 attempts for replica database."""
"""Test maximum attempts for replica database.
REPLICA_MAX_ATTEMPTS=3 means 3 replica + 1 primary = 4 total attempts.
"""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
@@ -606,10 +594,11 @@ class TestRlsTransaction:
):
with patch("api.db_utils.reset_read_db_alias"):
with pytest.raises(OperationalError):
with rls_transaction(tenant_id):
pass
for attempt in rls_transaction(tenant_id):
with attempt:
pass
assert mock_atomic.call_count == 3
assert mock_atomic.call_count == 4
def test_rls_transaction_replica_no_retry_when_disabled(
self, tenants_fixture, enable_read_replica
@@ -635,10 +624,11 @@ class TestRlsTransaction:
):
with patch("api.db_utils.reset_read_db_alias"):
with pytest.raises(OperationalError):
with rls_transaction(
for attempt in rls_transaction(
tenant_id, retry_on_replica=False
):
pass
with attempt:
pass
assert mock_atomic.call_count == 1
mock_sleep.assert_not_called()
@@ -660,15 +650,19 @@ class TestRlsTransaction:
mock_atomic.side_effect = OperationalError("Primary error")
with pytest.raises(OperationalError):
with rls_transaction(tenant_id):
pass
for attempt in rls_transaction(tenant_id):
with attempt:
pass
assert mock_atomic.call_count == 1
def test_rls_transaction_fallback_to_primary_after_max_attempts(
self, tenants_fixture, enable_read_replica
):
"""Test fallback to primary DB after max attempts on replica."""
"""Test fallback to primary DB after max attempts on replica.
First 3 attempts fail on replica, 4th succeeds on primary.
"""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
@@ -685,7 +679,7 @@ class TestRlsTransaction:
def atomic_side_effect(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count < 3:
if call_count <= 3:
raise OperationalError("Replica error")
return MagicMock(
__enter__=MagicMock(return_value=None),
@@ -701,8 +695,9 @@ class TestRlsTransaction:
):
with patch("api.db_utils.reset_read_db_alias"):
with patch("api.db_utils.logger") as mock_logger:
with rls_transaction(tenant_id):
pass
for attempt in rls_transaction(tenant_id):
with attempt:
pass
mock_logger.warning.assert_called_once()
warning_msg = mock_logger.warning.call_args[0][0]
@@ -711,7 +706,10 @@ class TestRlsTransaction:
def test_rls_transaction_logger_warning_on_fallback(
self, tenants_fixture, enable_read_replica
):
"""Test logger warnings are emitted on fallback to primary."""
"""Test logger warnings are emitted on fallback to primary.
3 replica failures produce 3 info logs, then 1 warning on fallback.
"""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
@@ -728,7 +726,7 @@ class TestRlsTransaction:
def atomic_side_effect(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count < 3:
if call_count <= 3:
raise OperationalError("Replica error")
return MagicMock(
__enter__=MagicMock(return_value=None),
@@ -744,10 +742,11 @@ class TestRlsTransaction:
):
with patch("api.db_utils.reset_read_db_alias"):
with patch("api.db_utils.logger") as mock_logger:
with rls_transaction(tenant_id):
pass
for attempt in rls_transaction(tenant_id):
with attempt:
pass
assert mock_logger.info.call_count == 2
assert mock_logger.info.call_count == 3
assert mock_logger.warning.call_count == 1
def test_rls_transaction_operational_error_raised_immediately_on_primary(
@@ -770,15 +769,16 @@ class TestRlsTransaction:
with patch("api.db_utils.time.sleep") as mock_sleep:
with pytest.raises(OperationalError):
with rls_transaction(tenant_id):
pass
for attempt in rls_transaction(tenant_id):
with attempt:
pass
mock_sleep.assert_not_called()
def test_rls_transaction_operational_error_raised_after_max_attempts(
self, tenants_fixture, enable_read_replica
):
"""Test OperationalError raised after max attempts on replica."""
"""Test OperationalError raised after all 4 attempts (3 replica + 1 primary)."""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
@@ -801,8 +801,9 @@ class TestRlsTransaction:
):
with patch("api.db_utils.reset_read_db_alias"):
with pytest.raises(OperationalError):
with rls_transaction(tenant_id):
pass
for attempt in rls_transaction(tenant_id):
with attempt:
pass
def test_rls_transaction_router_token_set_for_non_default_alias(
self, tenants_fixture
@@ -823,8 +824,9 @@ class TestRlsTransaction:
with patch("api.db_utils.set_read_db_alias") as mock_set_alias:
with patch("api.db_utils.reset_read_db_alias") as mock_reset_alias:
mock_set_alias.return_value = "test_token"
with rls_transaction(tenant_id, using=custom_alias):
pass
for attempt in rls_transaction(tenant_id, using=custom_alias):
with attempt:
pass
mock_set_alias.assert_called_once_with(custom_alias)
mock_reset_alias.assert_called_once_with("test_token")
@@ -848,8 +850,11 @@ class TestRlsTransaction:
with patch("api.db_utils.set_read_db_alias", return_value="test_token"):
with patch("api.db_utils.reset_read_db_alias") as mock_reset_alias:
with pytest.raises(Exception):
with rls_transaction(tenant_id, using=custom_alias):
pass
for attempt in rls_transaction(
tenant_id, using=custom_alias
):
with attempt:
pass
mock_reset_alias.assert_called_once_with("test_token")
@@ -873,8 +878,9 @@ class TestRlsTransaction:
with patch(
"api.db_utils.reset_read_db_alias"
) as mock_reset_alias:
with rls_transaction(tenant_id):
pass
for attempt in rls_transaction(tenant_id):
with attempt:
pass
mock_set_alias.assert_not_called()
mock_reset_alias.assert_not_called()
@@ -886,10 +892,11 @@ class TestRlsTransaction:
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
with rls_transaction(tenant_id) as cursor:
cursor.execute("SELECT current_setting(%s)", [POSTGRES_TENANT_VAR])
result = cursor.fetchone()
assert result[0] == tenant_id
for attempt in rls_transaction(tenant_id):
with attempt as cursor:
cursor.execute("SELECT current_setting(%s)", [POSTGRES_TENANT_VAR])
result = cursor.fetchone()
assert result[0] == tenant_id
def test_rls_transaction_custom_parameter(self, tenants_fixture):
"""Test rls_transaction with custom parameter name."""
@@ -897,21 +904,205 @@ class TestRlsTransaction:
tenant_id = str(tenant.id)
custom_param = "api.user_id"
with rls_transaction(tenant_id, parameter=custom_param) as cursor:
cursor.execute("SELECT current_setting(%s)", [custom_param])
result = cursor.fetchone()
assert result[0] == tenant_id
for attempt in rls_transaction(tenant_id, parameter=custom_param):
with attempt as cursor:
cursor.execute("SELECT current_setting(%s)", [custom_param])
result = cursor.fetchone()
assert result[0] == tenant_id
def test_rls_transaction_cursor_yielded_correctly(self, tenants_fixture):
"""Test cursor is yielded correctly."""
tenant = tenants_fixture[0]
tenant_id = str(tenant.id)
with rls_transaction(tenant_id) as cursor:
assert cursor is not None
cursor.execute("SELECT 1")
result = cursor.fetchone()
assert result[0] == 1
for attempt in rls_transaction(tenant_id):
with attempt as cursor:
assert cursor is not None
cursor.execute("SELECT 1")
result = cursor.fetchone()
assert result[0] == 1
def test_rls_transaction_for_with_retries_mid_body_error(
    self, tenants_fixture, enable_read_replica
):
    """Test that OperationalError raised inside the body triggers retry.

    The first loop iteration raises mid-body; ``rls_transaction`` is
    expected to close the stale replica connection, back off once via
    ``time.sleep``, and yield a second attempt that runs the body to
    completion.
    """
    tenant = tenants_fixture[0]
    tenant_id = str(tenant.id)

    with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
        with patch("api.db_utils.connections") as mock_connections:
            mock_conn = MagicMock()
            mock_cursor = MagicMock()
            mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
            mock_connections.__getitem__.return_value = mock_conn
            mock_connections.__contains__.return_value = True
            with patch("api.db_utils.transaction.atomic"):
                with patch("api.db_utils.time.sleep") as mock_sleep:
                    with patch(
                        "api.db_utils.set_read_db_alias", return_value="token"
                    ):
                        with patch("api.db_utils.reset_read_db_alias"):
                            body_call_count = 0
                            for attempt in rls_transaction(tenant_id):
                                with attempt:
                                    body_call_count += 1
                                    # Fail only on the first attempt to
                                    # exercise exactly one retry cycle.
                                    if body_call_count == 1:
                                        raise OperationalError(
                                            "Conflict with recovery"
                                        )
                            # Body ran twice: the failing attempt plus the retry.
                            assert body_call_count == 2
                            mock_connections.__getitem__.return_value.close.assert_called()
                            assert mock_sleep.call_count == 1
def test_rls_transaction_for_with_success_first_attempt(
    self, tenants_fixture, enable_read_replica
):
    """Test happy path on replica: body succeeds on first attempt, no retry.

    With all DB machinery mocked to succeed, the generator must yield
    exactly one attempt and never invoke the backoff sleep.
    """
    tenant = tenants_fixture[0]
    tenant_id = str(tenant.id)

    with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
        with patch("api.db_utils.connections") as mock_connections:
            mock_conn = MagicMock()
            mock_cursor = MagicMock()
            mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
            mock_connections.__getitem__.return_value = mock_conn
            mock_connections.__contains__.return_value = True
            with patch("api.db_utils.transaction.atomic"):
                with patch("api.db_utils.time.sleep") as mock_sleep:
                    with patch(
                        "api.db_utils.set_read_db_alias", return_value="token"
                    ):
                        with patch("api.db_utils.reset_read_db_alias"):
                            iterations = 0
                            for attempt in rls_transaction(tenant_id):
                                with attempt:
                                    iterations += 1
                            # One successful attempt, no backoff performed.
                            assert iterations == 1
                            mock_sleep.assert_not_called()
def test_rls_transaction_for_with_closes_stale_connection(
    self, tenants_fixture, enable_read_replica
):
    """Test that stale connection is closed on OperationalError retry.

    When the first attempt's body raises, the replica connection object
    must have ``close()`` called before the next attempt reconnects.
    """
    tenant = tenants_fixture[0]
    tenant_id = str(tenant.id)

    with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
        with patch("api.db_utils.connections") as mock_connections:
            mock_conn = MagicMock()
            mock_cursor = MagicMock()
            mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
            mock_connections.__getitem__.return_value = mock_conn
            mock_connections.__contains__.return_value = True
            with patch("api.db_utils.transaction.atomic"):
                with patch("api.db_utils.time.sleep"):
                    with patch(
                        "api.db_utils.set_read_db_alias", return_value="token"
                    ):
                        with patch("api.db_utils.reset_read_db_alias"):
                            body_call_count = 0
                            for attempt in rls_transaction(tenant_id):
                                with attempt:
                                    body_call_count += 1
                                    # Only the first attempt fails; the
                                    # retry should find a fresh connection.
                                    if body_call_count == 1:
                                        raise OperationalError("stale connection")
                            mock_conn.close.assert_called()
def test_rls_transaction_for_with_non_operational_error_propagates(
    self, tenants_fixture, enable_read_replica
):
    """Test that non-OperationalError propagates immediately without retry.

    Only OperationalError is retry-worthy; any other exception raised by
    the body must escape on the first attempt with no backoff sleep.
    """
    tenant = tenants_fixture[0]
    tenant_id = str(tenant.id)

    with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
        with patch("api.db_utils.connections") as mock_connections:
            mock_conn = MagicMock()
            mock_cursor = MagicMock()
            mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
            mock_connections.__getitem__.return_value = mock_conn
            mock_connections.__contains__.return_value = True
            with patch("api.db_utils.transaction.atomic"):
                with patch("api.db_utils.time.sleep") as mock_sleep:
                    with patch(
                        "api.db_utils.set_read_db_alias", return_value="token"
                    ):
                        with patch("api.db_utils.reset_read_db_alias"):
                            with pytest.raises(ValueError, match="bad value"):
                                for attempt in rls_transaction(tenant_id):
                                    with attempt:
                                        raise ValueError("bad value")
                            # No retry cycle, hence no backoff.
                            mock_sleep.assert_not_called()
def test_rls_transaction_for_with_primary_no_retry(self, tenants_fixture):
    """Test that OperationalError on primary propagates immediately.

    Retries only make sense for the read replica; with no replica alias
    configured (``get_read_db_alias`` returns None) a failure on the
    primary must surface at once without any backoff sleep.
    """
    tenant = tenants_fixture[0]
    tenant_id = str(tenant.id)

    with patch("api.db_utils.get_read_db_alias", return_value=None):
        with patch("api.db_utils.connections") as mock_connections:
            mock_conn = MagicMock()
            mock_cursor = MagicMock()
            mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
            mock_connections.__getitem__.return_value = mock_conn
            mock_connections.__contains__.return_value = True
            with patch("api.db_utils.transaction.atomic"):
                with patch("api.db_utils.time.sleep") as mock_sleep:
                    with pytest.raises(OperationalError):
                        for attempt in rls_transaction(tenant_id):
                            with attempt:
                                raise OperationalError("primary failure")
                    mock_sleep.assert_not_called()
def test_rls_transaction_for_with_replica_max_attempts_semantics(
    self, tenants_fixture, enable_read_replica
):
    """Test that REPLICA_MAX_ATTEMPTS=3 produces 4 atomic calls (3 replica + 1 primary).

    When transaction.atomic always fails, _RLSAttempt.__enter__ retries
    internally and consumes all attempts. The for-loop body never
    executes, so we assert on mock_atomic.call_count instead.
    """
    tenant = tenants_fixture[0]
    tenant_id = str(tenant.id)

    with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
        with patch("api.db_utils.connections") as mock_connections:
            mock_conn = MagicMock()
            mock_cursor = MagicMock()
            mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
            mock_connections.__getitem__.return_value = mock_conn
            mock_connections.__contains__.return_value = True
            with patch("api.db_utils.transaction.atomic") as mock_atomic:
                mock_atomic.side_effect = OperationalError("always fails")
                with patch("api.db_utils.time.sleep"):
                    with patch(
                        "api.db_utils.set_read_db_alias", return_value="token"
                    ):
                        with patch("api.db_utils.reset_read_db_alias"):
                            with pytest.raises(OperationalError):
                                for attempt in rls_transaction(tenant_id):
                                    with attempt:
                                        pass
                            # 3 replica attempts + 1 primary fallback
                            assert mock_atomic.call_count == 4
                            # Last call should use the default (primary) alias
                            last_call = mock_atomic.call_args_list[-1]
                            assert last_call.kwargs.get("using") == DEFAULT_DB_ALIAS
class TestPostgresEnumMigration:

View File

@@ -1,5 +1,5 @@
import uuid
from unittest.mock import call, patch
from unittest.mock import MagicMock, call, patch
import pytest
from django.core.exceptions import ObjectDoesNotExist
@@ -65,8 +65,7 @@ class TestHandleProviderDeletionDecorator:
tenant = tenants_fixture[0]
deleted_provider_id = str(uuid.uuid4())
mock_rls.return_value.__enter__ = lambda s: None
mock_rls.return_value.__exit__ = lambda s, *args: None
mock_rls.return_value = [MagicMock()]
mock_filter.return_value.exists.return_value = False
@handle_provider_deletion
@@ -89,8 +88,7 @@ class TestHandleProviderDeletionDecorator:
scan_id = str(uuid.uuid4())
provider_id = str(uuid.uuid4())
mock_rls.return_value.__enter__ = lambda s: None
mock_rls.return_value.__exit__ = lambda s, *args: None
mock_rls.return_value = [MagicMock()]
mock_scan = type("MockScan", (), {"provider_id": provider_id})()
mock_scan_filter.return_value.first.return_value = mock_scan
@@ -112,8 +110,7 @@ class TestHandleProviderDeletionDecorator:
tenant = tenants_fixture[0]
scan_id = str(uuid.uuid4())
mock_rls.return_value.__enter__ = lambda s: None
mock_rls.return_value.__exit__ = lambda s, *args: None
mock_rls.return_value = [MagicMock()]
mock_scan_filter.return_value.first.return_value = None
@handle_provider_deletion
@@ -134,8 +131,7 @@ class TestHandleProviderDeletionDecorator:
tenant = tenants_fixture[0]
provider = providers_fixture[0]
mock_rls.return_value.__enter__ = lambda s: None
mock_rls.return_value.__exit__ = lambda s, *args: None
mock_rls.return_value = [MagicMock()]
mock_filter.return_value.exists.return_value = True
@handle_provider_deletion
@@ -154,8 +150,7 @@ class TestHandleProviderDeletionDecorator:
tenant = tenants_fixture[0]
deleted_provider_id = str(uuid.uuid4())
mock_rls.return_value.__enter__ = lambda s: None
mock_rls.return_value.__exit__ = lambda s, *args: None
mock_rls.return_value = [MagicMock()]
mock_filter.return_value.exists.return_value = False
@handle_provider_deletion
@@ -174,8 +169,7 @@ class TestHandleProviderDeletionDecorator:
tenant = tenants_fixture[0]
deleted_provider_id = str(uuid.uuid4())
mock_rls.return_value.__enter__ = lambda s: None
mock_rls.return_value.__exit__ = lambda s, *args: None
mock_rls.return_value = [MagicMock()]
mock_filter.return_value.exists.return_value = False
@handle_provider_deletion
@@ -194,8 +188,7 @@ class TestHandleProviderDeletionDecorator:
tenant = tenants_fixture[0]
provider = providers_fixture[0]
mock_rls.return_value.__enter__ = lambda s: None
mock_rls.return_value.__exit__ = lambda s, *args: None
mock_rls.return_value = [MagicMock()]
mock_filter.return_value.exists.return_value = True
@handle_provider_deletion

View File

@@ -789,9 +789,8 @@ class TestProwlerIntegrationConnectionTest:
mock_connection.projects = {"PROJ1": "Project 1", "PROJ2": "Project 2"}
mock_jira_class.test_connection.return_value = mock_connection
# Mock rls_transaction context manager
mock_rls_transaction.return_value.__enter__ = MagicMock()
mock_rls_transaction.return_value.__exit__ = MagicMock()
# Mock rls_transaction as an iterable yielding a context manager
mock_rls_transaction.return_value = [MagicMock()]
result = prowler_integration_connection_test(integration)
@@ -840,9 +839,8 @@ class TestProwlerIntegrationConnectionTest:
mock_connection.projects = {} # Empty projects when connection fails
mock_jira_class.test_connection.return_value = mock_connection
# Mock rls_transaction context manager
mock_rls_transaction.return_value.__enter__ = MagicMock()
mock_rls_transaction.return_value.__exit__ = MagicMock()
# Mock rls_transaction as an iterable yielding a context manager
mock_rls_transaction.return_value = [MagicMock()]
result = prowler_integration_connection_test(integration)
@@ -895,9 +893,8 @@ class TestProwlerIntegrationConnectionTest:
}
mock_jira_class.test_connection.return_value = mock_connection
# Mock rls_transaction context manager
mock_rls_transaction.return_value.__enter__ = MagicMock()
mock_rls_transaction.return_value.__exit__ = MagicMock()
# Mock rls_transaction as an iterable yielding a context manager
mock_rls_transaction.return_value = [MagicMock()]
result = prowler_integration_connection_test(integration)

View File

@@ -415,9 +415,10 @@ def prowler_integration_connection_test(integration: Integration) -> Connection:
raise_on_exception=False,
)
project_keys = jira_connection.projects if jira_connection.is_connected else {}
with rls_transaction(str(integration.tenant_id)):
integration.configuration["projects"] = project_keys
integration.save()
for attempt in rls_transaction(str(integration.tenant_id)):
with attempt:
integration.configuration["projects"] = project_keys
integration.save()
return jira_connection
elif integration.integration_type == Integration.IntegrationChoices.SLACK:
pass
@@ -544,9 +545,12 @@ def initialize_prowler_integration(integration: Integration) -> Jira:
try:
return Jira(**integration.credentials)
except JiraBasicAuthError as jira_auth_error:
with rls_transaction(str(integration.tenant_id)):
integration.configuration["projects"] = {}
integration.connected = False
integration.connection_last_checked_at = datetime.now(tz=timezone.utc)
integration.save()
for attempt in rls_transaction(str(integration.tenant_id)):
with attempt:
integration.configuration["projects"] = {}
integration.connected = False
integration.connection_last_checked_at = datetime.now(
tz=timezone.utc
)
integration.save()
raise jira_auth_error

View File

@@ -630,8 +630,11 @@ class SAMLInitiateAPIView(GenericAPIView):
# Retrieve the SAML configuration for the given email domain
try:
check = SAMLDomainIndex.objects.get(email_domain=domain)
with rls_transaction(str(check.tenant_id)):
config = SAMLConfiguration.objects.get(tenant_id=str(check.tenant_id))
for attempt in rls_transaction(str(check.tenant_id)):
with attempt:
config = SAMLConfiguration.objects.get(
tenant_id=str(check.tenant_id)
)
except (SAMLDomainIndex.DoesNotExist, SAMLConfiguration.DoesNotExist):
return Response(
{"detail": "Unauthorized domain."}, status=status.HTTP_403_FORBIDDEN
@@ -738,8 +741,9 @@ class TenantFinishACSView(FinishACSView):
# This handles scenarios like partially deleted or missing related objects
try:
check = SAMLDomainIndex.objects.get(email_domain=organization_slug)
with rls_transaction(str(check.tenant_id)):
SAMLConfiguration.objects.get(tenant_id=str(check.tenant_id))
for attempt in rls_transaction(str(check.tenant_id)):
with attempt:
SAMLConfiguration.objects.get(tenant_id=str(check.tenant_id))
social_app = SocialApp.objects.get(
provider="saml", client_id=organization_slug
)

View File

@@ -57,10 +57,11 @@ class RLSTask(Task):
from api.db_utils import rls_transaction
tenant_id = kwargs.get("tenant_id")
with rls_transaction(tenant_id):
APITask.objects.update_or_create(
id=task_result_instance.task_id,
tenant_id=tenant_id,
defaults={"task_runner_task": task_result_instance},
)
for attempt in rls_transaction(tenant_id):
with attempt:
APITask.objects.update_or_create(
id=task_result_instance.task_id,
tenant_id=tenant_id,
defaults={"task_runner_task": task_result_instance},
)
return result

View File

@@ -418,23 +418,24 @@ def tenants_fixture(create_test_user):
def set_user_admin_roles_fixture(create_test_user, tenants_fixture):
user = create_test_user
for tenant in tenants_fixture[:2]:
with rls_transaction(str(tenant.id)):
role = Role.objects.create(
name="admin",
tenant_id=tenant.id,
manage_users=True,
manage_account=True,
manage_billing=True,
manage_providers=True,
manage_integrations=True,
manage_scans=True,
unlimited_visibility=True,
)
UserRoleRelationship.objects.create(
user=user,
role=role,
tenant_id=tenant.id,
)
for attempt in rls_transaction(str(tenant.id)):
with attempt:
role = Role.objects.create(
name="admin",
tenant_id=tenant.id,
manage_users=True,
manage_account=True,
manage_billing=True,
manage_providers=True,
manage_integrations=True,
manage_scans=True,
unlimited_visibility=True,
)
UserRoleRelationship.objects.create(
user=user,
role=role,
tenant_id=tenant.id,
)
@pytest.fixture

View File

@@ -30,15 +30,16 @@ def schedule_provider_scan(provider_instance: Provider):
pointer="/data/attributes/provider_id",
)
with rls_transaction(tenant_id):
scheduled_scan = Scan.objects.create(
tenant_id=tenant_id,
name="Daily scheduled scan",
provider_id=provider_id,
trigger=Scan.TriggerChoices.SCHEDULED,
state=StateChoices.AVAILABLE,
scheduled_at=datetime.now(timezone.utc),
)
for attempt in rls_transaction(tenant_id):
with attempt:
scheduled_scan = Scan.objects.create(
tenant_id=tenant_id,
name="Daily scheduled scan",
provider_id=provider_id,
trigger=Scan.TriggerChoices.SCHEDULED,
state=StateChoices.AVAILABLE,
scheduled_at=datetime.now(timezone.utc),
)
attack_paths_db_utils.create_attack_paths_scan(
tenant_id=tenant_id,

View File

@@ -15,8 +15,9 @@ logger = get_task_logger(__name__)
def can_provider_run_attack_paths_scan(tenant_id: str, provider_id: int) -> bool:
with rls_transaction(tenant_id):
prowler_api_provider = ProwlerAPIProvider.objects.get(id=provider_id)
for attempt in rls_transaction(tenant_id):
with attempt:
prowler_api_provider = ProwlerAPIProvider.objects.get(id=provider_id)
return is_provider_available(prowler_api_provider.provider)
@@ -29,24 +30,25 @@ def create_attack_paths_scan(
if not can_provider_run_attack_paths_scan(tenant_id, provider_id):
return None
with rls_transaction(tenant_id):
# Inherit graph_data_ready from the previous scan for this provider,
# so queries remain available while the new scan runs.
previous_data_ready = ProwlerAPIAttackPathsScan.objects.filter(
tenant_id=tenant_id,
provider_id=provider_id,
graph_data_ready=True,
).exists()
for attempt in rls_transaction(tenant_id):
with attempt:
# Inherit graph_data_ready from the previous scan for this provider,
# so queries remain available while the new scan runs.
previous_data_ready = ProwlerAPIAttackPathsScan.objects.filter(
tenant_id=tenant_id,
provider_id=provider_id,
graph_data_ready=True,
).exists()
attack_paths_scan = ProwlerAPIAttackPathsScan.objects.create(
tenant_id=tenant_id,
provider_id=provider_id,
scan_id=scan_id,
state=StateChoices.SCHEDULED,
started_at=datetime.now(tz=timezone.utc),
graph_data_ready=previous_data_ready,
)
attack_paths_scan.save()
attack_paths_scan = ProwlerAPIAttackPathsScan.objects.create(
tenant_id=tenant_id,
provider_id=provider_id,
scan_id=scan_id,
state=StateChoices.SCHEDULED,
started_at=datetime.now(tz=timezone.utc),
graph_data_ready=previous_data_ready,
)
attack_paths_scan.save()
return attack_paths_scan
@@ -56,10 +58,11 @@ def retrieve_attack_paths_scan(
scan_id: str,
) -> ProwlerAPIAttackPathsScan | None:
try:
with rls_transaction(tenant_id):
attack_paths_scan = ProwlerAPIAttackPathsScan.objects.get(
scan_id=scan_id,
)
for attempt in rls_transaction(tenant_id):
with attempt:
attack_paths_scan = ProwlerAPIAttackPathsScan.objects.get(
scan_id=scan_id,
)
return attack_paths_scan
@@ -72,20 +75,21 @@ def starting_attack_paths_scan(
task_id: str,
cartography_config: CartographyConfig,
) -> None:
with rls_transaction(attack_paths_scan.tenant_id):
attack_paths_scan.task_id = task_id
attack_paths_scan.state = StateChoices.EXECUTING
attack_paths_scan.started_at = datetime.now(tz=timezone.utc)
attack_paths_scan.update_tag = cartography_config.update_tag
for attempt in rls_transaction(attack_paths_scan.tenant_id):
with attempt:
attack_paths_scan.task_id = task_id
attack_paths_scan.state = StateChoices.EXECUTING
attack_paths_scan.started_at = datetime.now(tz=timezone.utc)
attack_paths_scan.update_tag = cartography_config.update_tag
attack_paths_scan.save(
update_fields=[
"task_id",
"state",
"started_at",
"update_tag",
]
)
attack_paths_scan.save(
update_fields=[
"task_id",
"state",
"started_at",
"update_tag",
]
)
def finish_attack_paths_scan(
@@ -93,47 +97,50 @@ def finish_attack_paths_scan(
state: StateChoices,
ingestion_exceptions: dict[str, Any],
) -> None:
with rls_transaction(attack_paths_scan.tenant_id):
now = datetime.now(tz=timezone.utc)
duration = (
int((now - attack_paths_scan.started_at).total_seconds())
if attack_paths_scan.started_at
else 0
)
for attempt in rls_transaction(attack_paths_scan.tenant_id):
with attempt:
now = datetime.now(tz=timezone.utc)
duration = (
int((now - attack_paths_scan.started_at).total_seconds())
if attack_paths_scan.started_at
else 0
)
attack_paths_scan.state = state
attack_paths_scan.progress = 100
attack_paths_scan.completed_at = now
attack_paths_scan.duration = duration
attack_paths_scan.ingestion_exceptions = ingestion_exceptions
attack_paths_scan.state = state
attack_paths_scan.progress = 100
attack_paths_scan.completed_at = now
attack_paths_scan.duration = duration
attack_paths_scan.ingestion_exceptions = ingestion_exceptions
attack_paths_scan.save(
update_fields=[
"state",
"progress",
"completed_at",
"duration",
"ingestion_exceptions",
]
)
attack_paths_scan.save(
update_fields=[
"state",
"progress",
"completed_at",
"duration",
"ingestion_exceptions",
]
)
def update_attack_paths_scan_progress(
attack_paths_scan: ProwlerAPIAttackPathsScan,
progress: int,
) -> None:
with rls_transaction(attack_paths_scan.tenant_id):
attack_paths_scan.progress = progress
attack_paths_scan.save(update_fields=["progress"])
for attempt in rls_transaction(attack_paths_scan.tenant_id):
with attempt:
attack_paths_scan.progress = progress
attack_paths_scan.save(update_fields=["progress"])
def set_graph_data_ready(
attack_paths_scan: ProwlerAPIAttackPathsScan,
ready: bool,
) -> None:
with rls_transaction(attack_paths_scan.tenant_id):
attack_paths_scan.graph_data_ready = ready
attack_paths_scan.save(update_fields=["graph_data_ready"])
for attempt in rls_transaction(attack_paths_scan.tenant_id):
with attempt:
attack_paths_scan.graph_data_ready = ready
attack_paths_scan.save(update_fields=["graph_data_ready"])
def set_provider_graph_data_ready(
@@ -145,12 +152,13 @@ def set_provider_graph_data_ready(
Used before drop/sync so that older scan IDs cannot bypass the query gate while the graph is being replaced.
"""
with rls_transaction(attack_paths_scan.tenant_id):
ProwlerAPIAttackPathsScan.objects.filter(
tenant_id=attack_paths_scan.tenant_id,
provider_id=attack_paths_scan.provider_id,
).update(graph_data_ready=ready)
attack_paths_scan.refresh_from_db(fields=["graph_data_ready"])
for attempt in rls_transaction(attack_paths_scan.tenant_id):
with attempt:
ProwlerAPIAttackPathsScan.objects.filter(
tenant_id=attack_paths_scan.tenant_id,
provider_id=attack_paths_scan.provider_id,
).update(graph_data_ready=ready)
attack_paths_scan.refresh_from_db(fields=["graph_data_ready"])
def recover_graph_data_ready(

View File

@@ -277,15 +277,16 @@ def _fetch_findings_batch(
Uses read replica and RLS-scoped transaction.
"""
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
# Use `all_objects` to get `Findings` even on soft-deleted `Providers`
# But even the provider is already validated as active in this context
qs = FindingModel.all_objects.filter(scan_id=scan_id).order_by("id")
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
with attempt:
# Use `all_objects` to get `Findings` even on soft-deleted `Providers`
# But even the provider is already validated as active in this context
qs = FindingModel.all_objects.filter(scan_id=scan_id).order_by("id")
if after_id is not None:
qs = qs.filter(id__gt=after_id)
if after_id is not None:
qs = qs.filter(id__gt=after_id)
return list(qs.values(*_DB_QUERY_FIELDS)[:FINDINGS_BATCH_SIZE])
return list(qs.values(*_DB_QUERY_FIELDS)[:FINDINGS_BATCH_SIZE])
# Batch Enrichment
@@ -316,12 +317,13 @@ def _build_finding_resource_map(
finding_ids: list[UUID], tenant_id: str
) -> dict[UUID, list[str]]:
"""Build mapping from finding_id to list of resource UIDs."""
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
resource_mappings = ResourceFindingMapping.objects.filter(
finding_id__in=finding_ids
).values_list("finding_id", "resource__uid")
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
with attempt:
resource_mappings = ResourceFindingMapping.objects.filter(
finding_id__in=finding_ids
).values_list("finding_id", "resource__uid")
result = defaultdict(list)
for finding_id, resource_uid in resource_mappings:
result[finding_id].append(resource_uid)
return result
result = defaultdict(list)
for finding_id, resource_uid in resource_mappings:
result[finding_id].append(resource_uid)
return result

View File

@@ -86,9 +86,10 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
ingestion_exceptions = {} # This will hold any exceptions raised during ingestion
# Prowler necessary objects
with rls_transaction(tenant_id):
prowler_api_provider = ProwlerAPIProvider.objects.get(scan__pk=scan_id)
prowler_sdk_provider = initialize_prowler_provider(prowler_api_provider)
for attempt in rls_transaction(tenant_id):
with attempt:
prowler_api_provider = ProwlerAPIProvider.objects.get(scan__pk=scan_id)
prowler_sdk_provider = initialize_prowler_provider(prowler_api_provider)
# Attack Paths Scan necessary objects
cartography_ingestion_function = get_cartography_ingestion_function(

View File

@@ -41,54 +41,56 @@ logger = get_task_logger(__name__)
def backfill_resource_scan_summaries(tenant_id: str, scan_id: str):
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
if ResourceScanSummary.objects.filter(
tenant_id=tenant_id, scan_id=scan_id
).exists():
return {"status": "already backfilled"}
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
with attempt:
if ResourceScanSummary.objects.filter(
tenant_id=tenant_id, scan_id=scan_id
).exists():
return {"status": "already backfilled"}
with rls_transaction(tenant_id):
if not Scan.objects.filter(
tenant_id=tenant_id,
id=scan_id,
state__in=(StateChoices.COMPLETED, StateChoices.FAILED),
).exists():
return {"status": "scan is not completed"}
for attempt in rls_transaction(tenant_id):
with attempt:
if not Scan.objects.filter(
tenant_id=tenant_id,
id=scan_id,
state__in=(StateChoices.COMPLETED, StateChoices.FAILED),
).exists():
return {"status": "scan is not completed"}
resource_ids_qs = (
ResourceFindingMapping.objects.filter(
tenant_id=tenant_id, finding__scan_id=scan_id
)
.values_list("resource_id", flat=True)
.distinct()
)
resource_ids = list(resource_ids_qs)
if not resource_ids:
return {"status": "no resources to backfill"}
resources_qs = Resource.objects.filter(
tenant_id=tenant_id, id__in=resource_ids
).only("id", "service", "region", "type")
summaries = []
for resource in resources_qs.iterator():
summaries.append(
ResourceScanSummary(
tenant_id=tenant_id,
scan_id=scan_id,
resource_id=str(resource.id),
service=resource.service,
region=resource.region,
resource_type=resource.type,
resource_ids_qs = (
ResourceFindingMapping.objects.filter(
tenant_id=tenant_id, finding__scan_id=scan_id
)
.values_list("resource_id", flat=True)
.distinct()
)
for i in range(0, len(summaries), 500):
ResourceScanSummary.objects.bulk_create(
summaries[i : i + 500], ignore_conflicts=True
)
resource_ids = list(resource_ids_qs)
if not resource_ids:
return {"status": "no resources to backfill"}
resources_qs = Resource.objects.filter(
tenant_id=tenant_id, id__in=resource_ids
).only("id", "service", "region", "type")
summaries = []
for resource in resources_qs.iterator():
summaries.append(
ResourceScanSummary(
tenant_id=tenant_id,
scan_id=scan_id,
resource_id=str(resource.id),
service=resource.service,
region=resource.region,
resource_type=resource.type,
)
)
for i in range(0, len(summaries), 500):
ResourceScanSummary.objects.bulk_create(
summaries[i : i + 500], ignore_conflicts=True
)
return {"status": "backfilled", "inserted": len(summaries)}
@@ -107,99 +109,101 @@ def backfill_compliance_summaries(tenant_id: str, scan_id: str):
Returns:
dict: Status indicating whether backfill was performed
"""
with rls_transaction(tenant_id):
if ComplianceOverviewSummary.objects.filter(
tenant_id=tenant_id, scan_id=scan_id
).exists():
return {"status": "already backfilled"}
for attempt in rls_transaction(tenant_id):
with attempt:
if ComplianceOverviewSummary.objects.filter(
tenant_id=tenant_id, scan_id=scan_id
).exists():
return {"status": "already backfilled"}
with rls_transaction(tenant_id):
if not Scan.objects.filter(
tenant_id=tenant_id,
id=scan_id,
state__in=(StateChoices.COMPLETED, StateChoices.FAILED),
).exists():
return {"status": "scan is not completed"}
for attempt in rls_transaction(tenant_id):
with attempt:
if not Scan.objects.filter(
tenant_id=tenant_id,
id=scan_id,
state__in=(StateChoices.COMPLETED, StateChoices.FAILED),
).exists():
return {"status": "scan is not completed"}
# Fetch all compliance requirement overview rows for this scan
requirement_rows = ComplianceRequirementOverview.objects.filter(
tenant_id=tenant_id, scan_id=scan_id
).values(
"compliance_id",
"requirement_id",
"requirement_status",
)
# Fetch all compliance requirement overview rows for this scan
requirement_rows = ComplianceRequirementOverview.objects.filter(
tenant_id=tenant_id, scan_id=scan_id
).values(
"compliance_id",
"requirement_id",
"requirement_status",
)
if not requirement_rows:
return {"status": "no compliance data to backfill"}
if not requirement_rows:
return {"status": "no compliance data to backfill"}
# Group by (compliance_id, requirement_id) across regions
requirement_statuses = defaultdict(
lambda: {"fail_count": 0, "pass_count": 0, "total_count": 0}
)
# Group by (compliance_id, requirement_id) across regions
requirement_statuses = defaultdict(
lambda: {"fail_count": 0, "pass_count": 0, "total_count": 0}
)
for row in requirement_rows:
compliance_id = row["compliance_id"]
requirement_id = row["requirement_id"]
requirement_status = row["requirement_status"]
for row in requirement_rows:
compliance_id = row["compliance_id"]
requirement_id = row["requirement_id"]
requirement_status = row["requirement_status"]
# Aggregate requirement status across regions
key = (compliance_id, requirement_id)
requirement_statuses[key]["total_count"] += 1
# Aggregate requirement status across regions
key = (compliance_id, requirement_id)
requirement_statuses[key]["total_count"] += 1
if requirement_status == "FAIL":
requirement_statuses[key]["fail_count"] += 1
elif requirement_status == "PASS":
requirement_statuses[key]["pass_count"] += 1
if requirement_status == "FAIL":
requirement_statuses[key]["fail_count"] += 1
elif requirement_status == "PASS":
requirement_statuses[key]["pass_count"] += 1
# Determine per-requirement status and aggregate to compliance level
compliance_summaries = defaultdict(
lambda: {
"total_requirements": 0,
"requirements_passed": 0,
"requirements_failed": 0,
"requirements_manual": 0,
}
)
# Determine per-requirement status and aggregate to compliance level
compliance_summaries = defaultdict(
lambda: {
"total_requirements": 0,
"requirements_passed": 0,
"requirements_failed": 0,
"requirements_manual": 0,
}
)
for (compliance_id, requirement_id), counts in requirement_statuses.items():
# Apply business rule: any FAIL → requirement fails
if counts["fail_count"] > 0:
req_status = "FAIL"
elif counts["pass_count"] == counts["total_count"]:
req_status = "PASS"
else:
req_status = "MANUAL"
for (compliance_id, requirement_id), counts in requirement_statuses.items():
# Apply business rule: any FAIL → requirement fails
if counts["fail_count"] > 0:
req_status = "FAIL"
elif counts["pass_count"] == counts["total_count"]:
req_status = "PASS"
else:
req_status = "MANUAL"
# Aggregate to compliance level
compliance_summaries[compliance_id]["total_requirements"] += 1
if req_status == "PASS":
compliance_summaries[compliance_id]["requirements_passed"] += 1
elif req_status == "FAIL":
compliance_summaries[compliance_id]["requirements_failed"] += 1
else:
compliance_summaries[compliance_id]["requirements_manual"] += 1
# Aggregate to compliance level
compliance_summaries[compliance_id]["total_requirements"] += 1
if req_status == "PASS":
compliance_summaries[compliance_id]["requirements_passed"] += 1
elif req_status == "FAIL":
compliance_summaries[compliance_id]["requirements_failed"] += 1
else:
compliance_summaries[compliance_id]["requirements_manual"] += 1
# Create summary objects
summary_objects = []
for compliance_id, data in compliance_summaries.items():
summary_objects.append(
ComplianceOverviewSummary(
tenant_id=tenant_id,
scan_id=scan_id,
compliance_id=compliance_id,
requirements_passed=data["requirements_passed"],
requirements_failed=data["requirements_failed"],
requirements_manual=data["requirements_manual"],
total_requirements=data["total_requirements"],
# Create summary objects
summary_objects = []
for compliance_id, data in compliance_summaries.items():
summary_objects.append(
ComplianceOverviewSummary(
tenant_id=tenant_id,
scan_id=scan_id,
compliance_id=compliance_id,
requirements_passed=data["requirements_passed"],
requirements_failed=data["requirements_failed"],
requirements_manual=data["requirements_manual"],
total_requirements=data["total_requirements"],
)
)
)
# Bulk insert summaries
if summary_objects:
ComplianceOverviewSummary.objects.bulk_create(
summary_objects, batch_size=500, ignore_conflicts=True
)
# Bulk insert summaries
if summary_objects:
ComplianceOverviewSummary.objects.bulk_create(
summary_objects, batch_size=500, ignore_conflicts=True
)
return {"status": "backfilled", "inserted": len(summary_objects)}
@@ -212,82 +216,85 @@ def backfill_daily_severity_summaries(tenant_id: str, days: int = None):
created_count = 0
updated_count = 0
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
scan_filter = {
"tenant_id": tenant_id,
"state": StateChoices.COMPLETED,
"completed_at__isnull": False,
}
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
with attempt:
scan_filter = {
"tenant_id": tenant_id,
"state": StateChoices.COMPLETED,
"completed_at__isnull": False,
}
if days is not None:
cutoff_date = timezone.now() - timedelta(days=days)
scan_filter["completed_at__gte"] = cutoff_date
if days is not None:
cutoff_date = timezone.now() - timedelta(days=days)
scan_filter["completed_at__gte"] = cutoff_date
completed_scans = (
Scan.objects.filter(**scan_filter)
.order_by("provider_id", "-completed_at")
.values("id", "provider_id", "completed_at")
)
completed_scans = (
Scan.objects.filter(**scan_filter)
.order_by("provider_id", "-completed_at")
.values("id", "provider_id", "completed_at")
)
if not completed_scans:
return {"status": "no scans to backfill"}
if not completed_scans:
return {"status": "no scans to backfill"}
# Keep only latest scan per provider/day
latest_scans_by_day = {}
for scan in completed_scans:
key = (scan["provider_id"], scan["completed_at"].date())
if key not in latest_scans_by_day:
latest_scans_by_day[key] = scan
# Keep only latest scan per provider/day
latest_scans_by_day = {}
for scan in completed_scans:
key = (scan["provider_id"], scan["completed_at"].date())
if key not in latest_scans_by_day:
latest_scans_by_day[key] = scan
# Process each provider/day
for (provider_id, scan_date), scan in latest_scans_by_day.items():
scan_id = scan["id"]
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
severity_totals = (
ScanSummary.objects.filter(
tenant_id=tenant_id,
scan_id=scan_id,
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
with attempt:
severity_totals = (
ScanSummary.objects.filter(
tenant_id=tenant_id,
scan_id=scan_id,
)
.values("severity")
.annotate(total_fail=Sum("fail"), total_muted=Sum("muted"))
)
.values("severity")
.annotate(total_fail=Sum("fail"), total_muted=Sum("muted"))
)
severity_data = {
"critical": 0,
"high": 0,
"medium": 0,
"low": 0,
"informational": 0,
"muted": 0,
}
severity_data = {
"critical": 0,
"high": 0,
"medium": 0,
"low": 0,
"informational": 0,
"muted": 0,
}
for row in severity_totals:
severity = row["severity"]
if severity in severity_data:
severity_data[severity] = row["total_fail"] or 0
severity_data["muted"] += row["total_muted"] or 0
for row in severity_totals:
severity = row["severity"]
if severity in severity_data:
severity_data[severity] = row["total_fail"] or 0
severity_data["muted"] += row["total_muted"] or 0
with rls_transaction(tenant_id):
_, created = DailySeveritySummary.objects.update_or_create(
tenant_id=tenant_id,
provider_id=provider_id,
date=scan_date,
defaults={
"scan_id": scan_id,
"critical": severity_data["critical"],
"high": severity_data["high"],
"medium": severity_data["medium"],
"low": severity_data["low"],
"informational": severity_data["informational"],
"muted": severity_data["muted"],
},
)
for attempt in rls_transaction(tenant_id):
with attempt:
_, created = DailySeveritySummary.objects.update_or_create(
tenant_id=tenant_id,
provider_id=provider_id,
date=scan_date,
defaults={
"scan_id": scan_id,
"critical": severity_data["critical"],
"high": severity_data["high"],
"medium": severity_data["medium"],
"low": severity_data["low"],
"informational": severity_data["informational"],
"muted": severity_data["muted"],
},
)
if created:
created_count += 1
else:
updated_count += 1
if created:
created_count += 1
else:
updated_count += 1
return {
"status": "backfilled",
@@ -311,34 +318,35 @@ def backfill_scan_category_summaries(tenant_id: str, scan_id: str):
Returns:
dict: Status indicating whether backfill was performed
"""
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
if ScanCategorySummary.objects.filter(
tenant_id=tenant_id, scan_id=scan_id
).exists():
return {"status": "already backfilled"}
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
with attempt:
if ScanCategorySummary.objects.filter(
tenant_id=tenant_id, scan_id=scan_id
).exists():
return {"status": "already backfilled"}
if not Scan.objects.filter(
tenant_id=tenant_id,
id=scan_id,
state__in=(StateChoices.COMPLETED, StateChoices.FAILED),
).exists():
return {"status": "scan is not completed"}
if not Scan.objects.filter(
tenant_id=tenant_id,
id=scan_id,
state__in=(StateChoices.COMPLETED, StateChoices.FAILED),
).exists():
return {"status": "scan is not completed"}
category_counts: dict[tuple[str, str], dict[str, int]] = {}
for finding in Finding.all_objects.filter(
tenant_id=tenant_id, scan_id=scan_id
).values("categories", "severity", "status", "delta", "muted"):
aggregate_category_counts(
categories=finding.get("categories") or [],
severity=finding.get("severity"),
status=finding.get("status"),
delta=finding.get("delta"),
muted=finding.get("muted", False),
cache=category_counts,
)
category_counts: dict[tuple[str, str], dict[str, int]] = {}
for finding in Finding.all_objects.filter(
tenant_id=tenant_id, scan_id=scan_id
).values("categories", "severity", "status", "delta", "muted"):
aggregate_category_counts(
categories=finding.get("categories") or [],
severity=finding.get("severity"),
status=finding.get("status"),
delta=finding.get("delta"),
muted=finding.get("muted", False),
cache=category_counts,
)
if not category_counts:
return {"status": "no categories to backfill"}
if not category_counts:
return {"status": "no categories to backfill"}
category_summaries = [
ScanCategorySummary(
@@ -353,10 +361,11 @@ def backfill_scan_category_summaries(tenant_id: str, scan_id: str):
for (category, severity), counts in category_counts.items()
]
with rls_transaction(tenant_id):
ScanCategorySummary.objects.bulk_create(
category_summaries, batch_size=500, ignore_conflicts=True
)
for attempt in rls_transaction(tenant_id):
with attempt:
ScanCategorySummary.objects.bulk_create(
category_summaries, batch_size=500, ignore_conflicts=True
)
return {"status": "backfilled", "categories_count": len(category_counts)}
@@ -375,51 +384,52 @@ def backfill_scan_resource_group_summaries(tenant_id: str, scan_id: str):
Returns:
dict: Status indicating whether backfill was performed
"""
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
if ScanGroupSummary.objects.filter(
tenant_id=tenant_id, scan_id=scan_id
).exists():
return {"status": "already backfilled"}
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
with attempt:
if ScanGroupSummary.objects.filter(
tenant_id=tenant_id, scan_id=scan_id
).exists():
return {"status": "already backfilled"}
if not Scan.objects.filter(
tenant_id=tenant_id,
id=scan_id,
state__in=(StateChoices.COMPLETED, StateChoices.FAILED),
).exists():
return {"status": "scan is not completed"}
if not Scan.objects.filter(
tenant_id=tenant_id,
id=scan_id,
state__in=(StateChoices.COMPLETED, StateChoices.FAILED),
).exists():
return {"status": "scan is not completed"}
resource_group_counts: dict[tuple[str, str], dict[str, int]] = {}
group_resources_cache: dict[str, set] = {}
# Get findings with their first resource UID via annotation
resource_uid_subquery = ResourceFindingMapping.objects.filter(
finding_id=OuterRef("id"), tenant_id=tenant_id
).values("resource__uid")[:1]
resource_group_counts: dict[tuple[str, str], dict[str, int]] = {}
group_resources_cache: dict[str, set] = {}
# Get findings with their first resource UID via annotation
resource_uid_subquery = ResourceFindingMapping.objects.filter(
finding_id=OuterRef("id"), tenant_id=tenant_id
).values("resource__uid")[:1]
for finding in (
Finding.all_objects.filter(tenant_id=tenant_id, scan_id=scan_id)
.annotate(resource_uid=Subquery(resource_uid_subquery))
.values(
"resource_groups",
"severity",
"status",
"delta",
"muted",
"resource_uid",
)
):
aggregate_resource_group_counts(
resource_group=finding.get("resource_groups"),
severity=finding.get("severity"),
status=finding.get("status"),
delta=finding.get("delta"),
muted=finding.get("muted", False),
resource_uid=finding.get("resource_uid") or "",
cache=resource_group_counts,
group_resources_cache=group_resources_cache,
)
for finding in (
Finding.all_objects.filter(tenant_id=tenant_id, scan_id=scan_id)
.annotate(resource_uid=Subquery(resource_uid_subquery))
.values(
"resource_groups",
"severity",
"status",
"delta",
"muted",
"resource_uid",
)
):
aggregate_resource_group_counts(
resource_group=finding.get("resource_groups"),
severity=finding.get("severity"),
status=finding.get("status"),
delta=finding.get("delta"),
muted=finding.get("muted", False),
resource_uid=finding.get("resource_uid") or "",
cache=resource_group_counts,
group_resources_cache=group_resources_cache,
)
if not resource_group_counts:
return {"status": "no resource groups to backfill"}
if not resource_group_counts:
return {"status": "no resource groups to backfill"}
# Compute group-level resource counts (same value for all severity rows in a group)
group_resource_counts = {
@@ -439,10 +449,11 @@ def backfill_scan_resource_group_summaries(tenant_id: str, scan_id: str):
for (grp, severity), counts in resource_group_counts.items()
]
with rls_transaction(tenant_id):
ScanGroupSummary.objects.bulk_create(
resource_group_summaries, batch_size=500, ignore_conflicts=True
)
for attempt in rls_transaction(tenant_id):
with attempt:
ScanGroupSummary.objects.bulk_create(
resource_group_summaries, batch_size=500, ignore_conflicts=True
)
return {"status": "backfilled", "resource_groups_count": len(resource_group_counts)}
@@ -460,34 +471,35 @@ def backfill_provider_compliance_scores(tenant_id: str) -> dict:
Returns:
dict: Statistics about the backfill operation
"""
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
completed_scans = Scan.all_objects.filter(
tenant_id=tenant_id,
state=StateChoices.COMPLETED,
completed_at__isnull=False,
)
if not completed_scans.exists():
return {"status": "no completed scans"}
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
with attempt:
completed_scans = Scan.all_objects.filter(
tenant_id=tenant_id,
state=StateChoices.COMPLETED,
completed_at__isnull=False,
)
if not completed_scans.exists():
return {"status": "no completed scans"}
existing_providers = set(
ProviderComplianceScore.objects.filter(tenant_id=tenant_id)
.values_list("provider_id", flat=True)
.distinct()
)
if existing_providers:
completed_scans = completed_scans.exclude(
provider_id__in=existing_providers
existing_providers = set(
ProviderComplianceScore.objects.filter(tenant_id=tenant_id)
.values_list("provider_id", flat=True)
.distinct()
)
scan_info = list(
completed_scans.order_by("provider_id", "-completed_at")
.distinct("provider_id")
.values("id", "provider_id", "completed_at")
)
if existing_providers:
completed_scans = completed_scans.exclude(
provider_id__in=existing_providers
)
if not scan_info:
return {"status": "no scans to process"}
scan_info = list(
completed_scans.order_by("provider_id", "-completed_at")
.distinct("provider_id")
.values("id", "provider_id", "completed_at")
)
if not scan_info:
return {"status": "no scans to process"}
total_upserted = 0
providers_processed = 0
@@ -577,32 +589,33 @@ def backfill_finding_group_summaries(tenant_id: str, days: int = None):
total_created = 0
total_updated = 0
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
scan_filter = {
"tenant_id": tenant_id,
"state": StateChoices.COMPLETED,
"completed_at__isnull": False,
}
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
with attempt:
scan_filter = {
"tenant_id": tenant_id,
"state": StateChoices.COMPLETED,
"completed_at__isnull": False,
}
if days is not None:
cutoff_date = timezone.now() - timedelta(days=days)
scan_filter["completed_at__gte"] = cutoff_date
if days is not None:
cutoff_date = timezone.now() - timedelta(days=days)
scan_filter["completed_at__gte"] = cutoff_date
completed_scans = (
Scan.objects.filter(**scan_filter)
.order_by("-completed_at")
.values("id", "completed_at")
)
completed_scans = (
Scan.objects.filter(**scan_filter)
.order_by("-completed_at")
.values("id", "completed_at")
)
if not completed_scans:
return {"status": "no scans to backfill"}
if not completed_scans:
return {"status": "no scans to backfill"}
# Keep only latest scan per day
latest_scans_by_day = {}
for scan in completed_scans:
key = scan["completed_at"].date()
if key not in latest_scans_by_day:
latest_scans_by_day[key] = scan
# Keep only latest scan per day
latest_scans_by_day = {}
for scan in completed_scans:
key = scan["completed_at"].date()
if key not in latest_scans_by_day:
latest_scans_by_day[key] = scan
# Process each day's scan
for scan_date, scan in latest_scans_by_day.items():

View File

@@ -28,20 +28,21 @@ def _recalculate_tenant_compliance_summary(tenant_id: str, compliance_ids: list[
compliance_ids = sorted(set(compliance_ids))
with rls_transaction(tenant_id, using=MainRouter.default_db) as cursor:
# Serialize tenant-level summary updates to avoid concurrent recomputes
cursor.execute(
"SELECT pg_advisory_xact_lock(hashtext(%s))",
[tenant_id],
)
cursor.execute(
COMPLIANCE_UPSERT_TENANT_SUMMARY_SQL,
[tenant_id, tenant_id, compliance_ids],
)
cursor.execute(
COMPLIANCE_DELETE_EMPTY_TENANT_SUMMARY_SQL,
[tenant_id, compliance_ids],
)
for attempt in rls_transaction(tenant_id, using=MainRouter.default_db):
with attempt as cursor:
# Serialize tenant-level summary updates to avoid concurrent recomputes
cursor.execute(
"SELECT pg_advisory_xact_lock(hashtext(%s))",
[tenant_id],
)
cursor.execute(
COMPLIANCE_UPSERT_TENANT_SUMMARY_SQL,
[tenant_id, tenant_id, compliance_ids],
)
cursor.execute(
COMPLIANCE_DELETE_EMPTY_TENANT_SUMMARY_SQL,
[tenant_id, compliance_ids],
)
def delete_provider(tenant_id: str, pk: str):
@@ -59,32 +60,39 @@ def delete_provider(tenant_id: str, pk: str):
"""
# Get all provider related data to delete them in batches
with rls_transaction(tenant_id):
try:
instance = Provider.all_objects.get(pk=pk)
except Provider.DoesNotExist:
logger.info(f"Provider `{pk}` already deleted, skipping")
return {}
for attempt in rls_transaction(tenant_id):
with attempt:
try:
instance = Provider.all_objects.get(pk=pk)
except Provider.DoesNotExist:
logger.info(f"Provider `{pk}` already deleted, skipping")
return {}
compliance_ids = list(
ProviderComplianceScore.objects.filter(provider=instance)
.values_list("compliance_id", flat=True)
.distinct()
)
attack_paths_scan_ids = list(
AttackPathsScan.all_objects.filter(provider=instance).values_list(
"id", flat=True
compliance_ids = list(
ProviderComplianceScore.objects.filter(provider=instance)
.values_list("compliance_id", flat=True)
.distinct()
)
)
deletion_steps = [
("Scan Summaries", ScanSummary.all_objects.filter(scan__provider=instance)),
("Findings", Finding.all_objects.filter(scan__provider=instance)),
("Resources", Resource.all_objects.filter(provider=instance)),
("Scans", Scan.all_objects.filter(provider=instance)),
("AttackPathsScans", AttackPathsScan.all_objects.filter(provider=instance)),
]
attack_paths_scan_ids = list(
AttackPathsScan.all_objects.filter(provider=instance).values_list(
"id", flat=True
)
)
deletion_steps = [
(
"Scan Summaries",
ScanSummary.all_objects.filter(scan__provider=instance),
),
("Findings", Finding.all_objects.filter(scan__provider=instance)),
("Resources", Resource.all_objects.filter(provider=instance)),
("Scans", Scan.all_objects.filter(provider=instance)),
(
"AttackPathsScans",
AttackPathsScan.all_objects.filter(provider=instance),
),
]
# Drop orphaned temporary Neo4j databases
for aps_id in attack_paths_scan_ids:
@@ -116,8 +124,9 @@ def delete_provider(tenant_id: str, pk: str):
# Delete the provider instance itself
try:
with rls_transaction(tenant_id):
_, provider_summary = instance.delete()
for attempt in rls_transaction(tenant_id):
with attempt:
_, provider_summary = instance.delete()
deletion_summary.update(provider_summary)
except DatabaseError as db_error:
logger.error(f"Error deleting Provider: {db_error}")

View File

@@ -295,8 +295,9 @@ def _build_output_path(
# Sanitize the prowler provider name to ensure it is a valid directory name
prowler_provider_sanitized = re.sub(r"[^\w\-]", "-", prowler_provider)
with rls_transaction(tenant_id):
started_at = Scan.objects.get(id=scan_id).started_at
for attempt in rls_transaction(tenant_id):
with attempt:
started_at = Scan.objects.get(id=scan_id).started_at
set_output_timestamp(started_at)

View File

@@ -1,14 +1,12 @@
import os
import time
from glob import glob
from celery.utils.log import get_task_logger
from config.django.base import DJANGO_FINDINGS_BATCH_SIZE
from django.db import OperationalError
from tasks.utils import batched
from api.db_router import READ_REPLICA_ALIAS, MainRouter
from api.db_utils import REPLICA_MAX_ATTEMPTS, REPLICA_RETRY_BASE_DELAY, rls_transaction
from api.db_utils import rls_transaction
from api.models import Finding, Integration, Provider
from api.utils import initialize_prowler_integration, initialize_prowler_provider
from prowler.lib.outputs.asff.asff import ASFF
@@ -78,14 +76,15 @@ def upload_s3_integration(
logger.info(f"Processing S3 integrations for provider {provider_id}")
try:
with rls_transaction(tenant_id):
integrations = list(
Integration.objects.filter(
integrationproviderrelationship__provider_id=provider_id,
integration_type=Integration.IntegrationChoices.AMAZON_S3,
enabled=True,
for attempt in rls_transaction(tenant_id):
with attempt:
integrations = list(
Integration.objects.filter(
integrationproviderrelationship__provider_id=provider_id,
integration_type=Integration.IntegrationChoices.AMAZON_S3,
enabled=True,
)
)
)
if not integrations:
logger.error(f"No S3 integrations found for provider {provider_id}")
@@ -184,14 +183,18 @@ def get_security_hub_client_from_integration(
if the connection was successful and the SecurityHub client or connection object.
"""
# Get the provider associated with this integration
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
provider_relationship = integration.integrationproviderrelationship_set.first()
if not provider_relationship:
return Connection(
is_connected=False, error="No provider associated with this integration"
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
with attempt:
provider_relationship = (
integration.integrationproviderrelationship_set.first()
)
provider_uid = provider_relationship.provider.uid
provider_secret = provider_relationship.provider.secret.secret
if not provider_relationship:
return Connection(
is_connected=False,
error="No provider associated with this integration",
)
provider_uid = provider_relationship.provider.uid
provider_secret = provider_relationship.provider.secret.secret
credentials = (
integration.credentials if integration.credentials else provider_secret
@@ -213,9 +216,10 @@ def get_security_hub_client_from_integration(
regions_status[region] = region in connection.enabled_regions
# Save regions information in the integration configuration
with rls_transaction(tenant_id, using=MainRouter.default_db):
integration.configuration["regions"] = regions_status
integration.save()
for attempt in rls_transaction(tenant_id, using=MainRouter.default_db):
with attempt:
integration.configuration["regions"] = regions_status
integration.save()
# Create SecurityHub client with all necessary parameters
security_hub = SecurityHub(
@@ -228,10 +232,11 @@ def get_security_hub_client_from_integration(
return True, security_hub
else:
# Reset regions information if connection fails and integration is not connected
with rls_transaction(tenant_id, using=MainRouter.default_db):
integration.connected = False
integration.configuration["regions"] = {}
integration.save()
for attempt in rls_transaction(tenant_id, using=MainRouter.default_db):
with attempt:
integration.connected = False
integration.configuration["regions"] = {}
integration.save()
return False, connection
@@ -256,27 +261,28 @@ def upload_security_hub_integration(
logger.info(f"Processing Security Hub integrations for provider {provider_id}")
try:
with rls_transaction(tenant_id):
# Get Security Hub integrations for this provider
integrations = list(
Integration.objects.filter(
integrationproviderrelationship__provider_id=provider_id,
integration_type=Integration.IntegrationChoices.AWS_SECURITY_HUB,
enabled=True,
for attempt in rls_transaction(tenant_id):
with attempt:
# Get Security Hub integrations for this provider
integrations = list(
Integration.objects.filter(
integrationproviderrelationship__provider_id=provider_id,
integration_type=Integration.IntegrationChoices.AWS_SECURITY_HUB,
enabled=True,
)
)
)
if not integrations:
logger.error(
f"No Security Hub integrations found for provider {provider_id}"
)
return False
if not integrations:
logger.error(
f"No Security Hub integrations found for provider {provider_id}"
)
return False
# Get the provider object
provider = Provider.objects.get(id=provider_id)
# Get the provider object
provider = Provider.objects.get(id=provider_id)
# Initialize prowler provider for finding transformation
prowler_provider = initialize_prowler_provider(provider)
# Initialize prowler provider for finding transformation
prowler_provider = initialize_prowler_provider(provider)
# Process each Security Hub integration
integration_executions = 0
@@ -293,130 +299,104 @@ def upload_security_hub_integration(
total_findings_sent[integration.id] = 0
# Process findings in batches to avoid memory issues
max_attempts = REPLICA_MAX_ATTEMPTS if READ_REPLICA_ALIAS else 1
has_findings = False
batch_number = 0
for attempt in range(1, max_attempts + 1):
read_alias = None
if READ_REPLICA_ALIAS:
read_alias = (
READ_REPLICA_ALIAS
if attempt < max_attempts
else MainRouter.default_db
)
try:
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
with attempt:
batch_number = 0
has_findings = False
with rls_transaction(
tenant_id,
using=read_alias,
retry_on_replica=False,
):
qs = (
Finding.all_objects.filter(
tenant_id=tenant_id, scan_id=scan_id
)
.order_by("uid")
.iterator()
qs = (
Finding.all_objects.filter(
tenant_id=tenant_id, scan_id=scan_id
)
for batch, _ in batched(qs, DJANGO_FINDINGS_BATCH_SIZE):
batch_number += 1
has_findings = True
# Transform findings for this batch
transformed_findings = [
FindingOutput.transform_api_finding(
finding, prowler_provider
)
for finding in batch
]
# Convert to ASFF format
asff_transformer = ASFF(
findings=transformed_findings,
file_path="",
file_extension="json",
)
asff_transformer.transform(transformed_findings)
# Get the batch of ASFF findings
batch_asff_findings = asff_transformer.data
if batch_asff_findings:
# Create Security Hub client for first batch or reuse existing
if not security_hub_client:
connected, security_hub = (
get_security_hub_client_from_integration(
integration,
tenant_id,
batch_asff_findings,
)
)
if not connected:
if isinstance(
security_hub.error,
SecurityHubNoEnabledRegionsError,
):
logger.warning(
f"Security Hub integration {integration.id} has no enabled regions"
)
else:
logger.error(
f"Security Hub connection failed for integration {integration.id}: "
f"{security_hub.error}"
)
break # Skip this integration
security_hub_client = security_hub
logger.info(
f"Sending {'fail' if send_only_fails else 'all'} findings to Security Hub via "
f"integration {integration.id}"
)
else:
# Update findings in existing client for this batch
security_hub_client._findings_per_region = (
security_hub_client.filter(
batch_asff_findings,
send_only_fails,
)
)
# Send this batch to Security Hub
try:
findings_sent = security_hub_client.batch_send_to_security_hub()
total_findings_sent[integration.id] += (
findings_sent
)
if findings_sent > 0:
logger.debug(
f"Sent batch {batch_number} with {findings_sent} findings to Security Hub"
)
except Exception as batch_error:
logger.error(
f"Failed to send batch {batch_number} to Security Hub: {str(batch_error)}"
)
# Clear memory after processing each batch
asff_transformer._data.clear()
del batch_asff_findings
del transformed_findings
break
except OperationalError as e:
if attempt == max_attempts:
raise
delay = REPLICA_RETRY_BASE_DELAY * (2 ** (attempt - 1))
logger.info(
"RLS query failed during Security Hub integration "
f"(attempt {attempt}/{max_attempts}), retrying in {delay}s. Error: {e}"
.order_by("uid")
.iterator()
)
time.sleep(delay)
for batch, _ in batched(qs, DJANGO_FINDINGS_BATCH_SIZE):
batch_number += 1
has_findings = True
# Transform findings for this batch
transformed_findings = [
FindingOutput.transform_api_finding(
finding, prowler_provider
)
for finding in batch
]
# Convert to ASFF format
asff_transformer = ASFF(
findings=transformed_findings,
file_path="",
file_extension="json",
)
asff_transformer.transform(transformed_findings)
# Get the batch of ASFF findings
batch_asff_findings = asff_transformer.data
if batch_asff_findings:
# Create Security Hub client for first batch or reuse existing
if not security_hub_client:
connected, security_hub = (
get_security_hub_client_from_integration(
integration,
tenant_id,
batch_asff_findings,
)
)
if not connected:
if isinstance(
security_hub.error,
SecurityHubNoEnabledRegionsError,
):
logger.warning(
f"Security Hub integration {integration.id} has no enabled regions"
)
else:
logger.error(
f"Security Hub connection failed for integration {integration.id}: "
f"{security_hub.error}"
)
break # Skip this integration
security_hub_client = security_hub
logger.info(
f"Sending {'fail' if send_only_fails else 'all'} findings to Security Hub via "
f"integration {integration.id}"
)
else:
# Update findings in existing client for this batch
security_hub_client._findings_per_region = (
security_hub_client.filter(
batch_asff_findings,
send_only_fails,
)
)
# Send this batch to Security Hub
try:
findings_sent = (
security_hub_client.batch_send_to_security_hub()
)
total_findings_sent[integration.id] += findings_sent
if findings_sent > 0:
logger.debug(
f"Sent batch {batch_number} with {findings_sent} findings to Security Hub"
)
except Exception as batch_error:
logger.error(
f"Failed to send batch {batch_number} to Security Hub: {str(batch_error)}"
)
# Clear memory after processing each batch
asff_transformer._data.clear()
del batch_asff_findings
del transformed_findings
if not has_findings:
logger.info(
@@ -479,67 +459,69 @@ def send_findings_to_jira(
issue_type: str,
finding_ids: list[str],
):
with rls_transaction(tenant_id):
integration = Integration.objects.get(id=integration_id)
jira_integration = initialize_prowler_integration(integration)
for attempt in rls_transaction(tenant_id):
with attempt:
integration = Integration.objects.get(id=integration_id)
jira_integration = initialize_prowler_integration(integration)
num_tickets_created = 0
for finding_id in finding_ids:
with rls_transaction(tenant_id):
finding_instance = (
Finding.all_objects.select_related("scan__provider")
.prefetch_related("resources")
.get(id=finding_id)
)
for attempt in rls_transaction(tenant_id):
with attempt:
finding_instance = (
Finding.all_objects.select_related("scan__provider")
.prefetch_related("resources")
.get(id=finding_id)
)
# Extract resource information
resource = (
finding_instance.resources.first()
if finding_instance.resources.exists()
else None
)
resource_uid = resource.uid if resource else ""
resource_name = resource.name if resource else ""
resource_tags = {}
if resource and hasattr(resource, "tags"):
resource_tags = resource.get_tags(tenant_id)
# Extract resource information
resource = (
finding_instance.resources.first()
if finding_instance.resources.exists()
else None
)
resource_uid = resource.uid if resource else ""
resource_name = resource.name if resource else ""
resource_tags = {}
if resource and hasattr(resource, "tags"):
resource_tags = resource.get_tags(tenant_id)
# Get region
region = resource.region if resource and resource.region else ""
# Get region
region = resource.region if resource and resource.region else ""
# Extract remediation information from check_metadata
check_metadata = finding_instance.check_metadata
remediation = check_metadata.get("remediation", {})
recommendation = remediation.get("recommendation", {})
remediation_code = remediation.get("code", {})
# Extract remediation information from check_metadata
check_metadata = finding_instance.check_metadata
remediation = check_metadata.get("remediation", {})
recommendation = remediation.get("recommendation", {})
remediation_code = remediation.get("code", {})
# Send the individual finding to Jira
result = jira_integration.send_finding(
check_id=finding_instance.check_id,
check_title=check_metadata.get("checktitle", ""),
severity=finding_instance.severity,
status=finding_instance.status,
status_extended=finding_instance.status_extended or "",
provider=finding_instance.scan.provider.provider,
region=region,
resource_uid=resource_uid,
resource_name=resource_name,
risk=check_metadata.get("risk", ""),
recommendation_text=recommendation.get("text", ""),
recommendation_url=recommendation.get("url", ""),
remediation_code_native_iac=remediation_code.get("nativeiac", ""),
remediation_code_terraform=remediation_code.get("terraform", ""),
remediation_code_cli=remediation_code.get("cli", ""),
remediation_code_other=remediation_code.get("other", ""),
resource_tags=resource_tags,
compliance=finding_instance.compliance or {},
project_key=project_key,
issue_type=issue_type,
)
if result:
num_tickets_created += 1
else:
logger.error(f"Failed to send finding {finding_id} to Jira")
# Send the individual finding to Jira
result = jira_integration.send_finding(
check_id=finding_instance.check_id,
check_title=check_metadata.get("checktitle", ""),
severity=finding_instance.severity,
status=finding_instance.status,
status_extended=finding_instance.status_extended or "",
provider=finding_instance.scan.provider.provider,
region=region,
resource_uid=resource_uid,
resource_name=resource_name,
risk=check_metadata.get("risk", ""),
recommendation_text=recommendation.get("text", ""),
recommendation_url=recommendation.get("url", ""),
remediation_code_native_iac=remediation_code.get("nativeiac", ""),
remediation_code_terraform=remediation_code.get("terraform", ""),
remediation_code_cli=remediation_code.get("cli", ""),
remediation_code_other=remediation_code.get("other", ""),
resource_tags=resource_tags,
compliance=finding_instance.compliance or {},
project_key=project_key,
issue_type=issue_type,
)
if result:
num_tickets_created += 1
else:
logger.error(f"Failed to send finding {finding_id} to Jira")
return {
"created_count": num_tickets_created,

View File

@@ -25,36 +25,38 @@ def mute_historical_findings(tenant_id: str, mute_rule_id: str):
findings_muted_count = 0
# Get the list of UIDs to mute and the reason
with rls_transaction(tenant_id):
mute_rule = MuteRule.objects.get(id=mute_rule_id, tenant_id=tenant_id)
finding_uids = mute_rule.finding_uids
mute_reason = mute_rule.reason
muted_at = mute_rule.inserted_at
for attempt in rls_transaction(tenant_id):
with attempt:
mute_rule = MuteRule.objects.get(id=mute_rule_id, tenant_id=tenant_id)
finding_uids = mute_rule.finding_uids
mute_reason = mute_rule.reason
muted_at = mute_rule.inserted_at
# Query findings that match the UIDs and are not already muted
with rls_transaction(tenant_id):
findings_to_mute = Finding.objects.filter(
tenant_id=tenant_id, uid__in=finding_uids, muted=False
)
total_findings = findings_to_mute.count()
for attempt in rls_transaction(tenant_id):
with attempt:
findings_to_mute = Finding.objects.filter(
tenant_id=tenant_id, uid__in=finding_uids, muted=False
)
total_findings = findings_to_mute.count()
logger.info(
f"Processing {total_findings} findings for mute rule {mute_rule_id}"
)
logger.info(
f"Processing {total_findings} findings for mute rule {mute_rule_id}"
)
if total_findings > 0:
for batch, is_last in batched(
findings_to_mute.iterator(), DJANGO_FINDINGS_BATCH_SIZE
):
batch_ids = [f.id for f in batch]
updated_count = Finding.all_objects.filter(
id__in=batch_ids, tenant_id=tenant_id
).update(
muted=True,
muted_at=muted_at,
muted_reason=mute_reason,
)
findings_muted_count += updated_count
if total_findings > 0:
for batch, is_last in batched(
findings_to_mute.iterator(), DJANGO_FINDINGS_BATCH_SIZE
):
batch_ids = [f.id for f in batch]
updated_count = Finding.all_objects.filter(
id__in=batch_ids, tenant_id=tenant_id
).update(
muted=True,
muted_at=muted_at,
muted_reason=mute_reason,
)
findings_muted_count += updated_count
logger.info(f"Muted {findings_muted_count} findings for rule {mute_rule_id}")

View File

@@ -248,22 +248,23 @@ def generate_compliance_reports(
results = {}
# Validate that the scan has findings and get provider info
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
if not ScanSummary.objects.filter(scan_id=scan_id).exists():
logger.info("No findings found for scan %s", scan_id)
if generate_threatscore:
results["threatscore"] = {"upload": False, "path": ""}
if generate_ens:
results["ens"] = {"upload": False, "path": ""}
if generate_nis2:
results["nis2"] = {"upload": False, "path": ""}
if generate_csa:
results["csa"] = {"upload": False, "path": ""}
return results
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
with attempt:
if not ScanSummary.objects.filter(scan_id=scan_id).exists():
logger.info("No findings found for scan %s", scan_id)
if generate_threatscore:
results["threatscore"] = {"upload": False, "path": ""}
if generate_ens:
results["ens"] = {"upload": False, "path": ""}
if generate_nis2:
results["nis2"] = {"upload": False, "path": ""}
if generate_csa:
results["csa"] = {"upload": False, "path": ""}
return results
provider_obj = Provider.objects.get(id=provider_id)
provider_uid = provider_obj.uid
provider_type = provider_obj.provider
provider_obj = Provider.objects.get(id=provider_id)
provider_uid = provider_obj.uid
provider_type = provider_obj.provider
# Check provider compatibility
if generate_threatscore and provider_type not in [
@@ -398,49 +399,50 @@ def generate_compliance_reports(
min_risk_level=min_risk_level_threatscore,
)
with rls_transaction(tenant_id):
previous_snapshot = (
ThreatScoreSnapshot.objects.filter(
for attempt in rls_transaction(tenant_id):
with attempt:
previous_snapshot = (
ThreatScoreSnapshot.objects.filter(
tenant_id=tenant_id,
provider_id=provider_id,
compliance_id=compliance_id_threatscore,
)
.order_by("-inserted_at")
.first()
)
score_delta = None
if previous_snapshot:
score_delta = metrics["overall_score"] - float(
previous_snapshot.overall_score
)
snapshot = ThreatScoreSnapshot.objects.create(
tenant_id=tenant_id,
scan_id=scan_id,
provider_id=provider_id,
compliance_id=compliance_id_threatscore,
)
.order_by("-inserted_at")
.first()
)
score_delta = None
if previous_snapshot:
score_delta = metrics["overall_score"] - float(
previous_snapshot.overall_score
overall_score=metrics["overall_score"],
score_delta=score_delta,
section_scores=metrics["section_scores"],
critical_requirements=metrics["critical_requirements"],
total_requirements=metrics["total_requirements"],
passed_requirements=metrics["passed_requirements"],
failed_requirements=metrics["failed_requirements"],
manual_requirements=metrics["manual_requirements"],
total_findings=metrics["total_findings"],
passed_findings=metrics["passed_findings"],
failed_findings=metrics["failed_findings"],
)
snapshot = ThreatScoreSnapshot.objects.create(
tenant_id=tenant_id,
scan_id=scan_id,
provider_id=provider_id,
compliance_id=compliance_id_threatscore,
overall_score=metrics["overall_score"],
score_delta=score_delta,
section_scores=metrics["section_scores"],
critical_requirements=metrics["critical_requirements"],
total_requirements=metrics["total_requirements"],
passed_requirements=metrics["passed_requirements"],
failed_requirements=metrics["failed_requirements"],
manual_requirements=metrics["manual_requirements"],
total_findings=metrics["total_findings"],
passed_findings=metrics["passed_findings"],
failed_findings=metrics["failed_findings"],
)
delta_msg = (
f" (delta: {score_delta:+.2f}%)"
if score_delta is not None
else ""
)
logger.info(
f"ThreatScore snapshot created with ID {snapshot.id} (score: {snapshot.overall_score}%{delta_msg})",
)
delta_msg = (
f" (delta: {score_delta:+.2f}%)"
if score_delta is not None
else ""
)
logger.info(
f"ThreatScore snapshot created with ID {snapshot.id} (score: {snapshot.overall_score}%{delta_msg})",
)
except Exception as e:
logger.error("Error creating ThreatScore snapshot: %s", e)

View File

@@ -750,25 +750,26 @@ class BaseComplianceReportGenerator(ABC):
Returns:
Aggregated ComplianceData object
"""
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
# Load provider
if provider_obj is None:
provider_obj = Provider.objects.get(id=provider_id)
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
with attempt:
# Load provider
if provider_obj is None:
provider_obj = Provider.objects.get(id=provider_id)
prowler_provider = initialize_prowler_provider(provider_obj)
provider_type = provider_obj.provider
prowler_provider = initialize_prowler_provider(provider_obj)
provider_type = provider_obj.provider
# Load compliance framework
frameworks_bulk = Compliance.get_bulk(provider_type)
compliance_obj = frameworks_bulk.get(compliance_id)
# Load compliance framework
frameworks_bulk = Compliance.get_bulk(provider_type)
compliance_obj = frameworks_bulk.get(compliance_id)
if not compliance_obj:
raise ValueError(f"Compliance framework not found: {compliance_id}")
if not compliance_obj:
raise ValueError(f"Compliance framework not found: {compliance_id}")
framework = getattr(compliance_obj, "Framework", "N/A")
name = getattr(compliance_obj, "Name", "N/A")
version = getattr(compliance_obj, "Version", "N/A")
description = getattr(compliance_obj, "Description", "")
framework = getattr(compliance_obj, "Framework", "N/A")
name = getattr(compliance_obj, "Name", "N/A")
version = getattr(compliance_obj, "Version", "N/A")
description = getattr(compliance_obj, "Description", "")
# Aggregate requirement statistics
if requirement_statistics is None:

File diff suppressed because it is too large Load Diff

View File

@@ -58,12 +58,13 @@ def compute_threatscore_metrics(
>>> print(f"Overall ThreatScore: {metrics['overall_score']:.2f}%")
"""
# Get provider and compliance information
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
provider_obj = Provider.objects.get(id=provider_id)
provider_type = provider_obj.provider
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
with attempt:
provider_obj = Provider.objects.get(id=provider_id)
provider_type = provider_obj.provider
frameworks_bulk = Compliance.get_bulk(provider_type)
compliance_obj = frameworks_bulk[compliance_id]
frameworks_bulk = Compliance.get_bulk(provider_type)
compliance_obj = frameworks_bulk[compliance_id]
# Aggregate requirement statistics from database
requirement_statistics_by_check_id = (

View File

@@ -36,35 +36,36 @@ def _aggregate_requirement_statistics_from_database(
"""
requirement_statistics_by_check_id = {}
# TODO: take into account that now the relation is 1 finding == 1 resource, review this when the logic changes
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
aggregated_statistics_queryset = (
Finding.all_objects.filter(
tenant_id=tenant_id,
scan_id=scan_id,
muted=False,
resources__provider__is_deleted=False,
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
with attempt:
aggregated_statistics_queryset = (
Finding.all_objects.filter(
tenant_id=tenant_id,
scan_id=scan_id,
muted=False,
resources__provider__is_deleted=False,
)
.values("check_id")
.annotate(
total_findings=Count(
"id",
distinct=True,
filter=Q(status__in=[StatusChoices.PASS, StatusChoices.FAIL]),
),
passed_findings=Count(
"id",
distinct=True,
filter=Q(status=StatusChoices.PASS),
),
)
)
.values("check_id")
.annotate(
total_findings=Count(
"id",
distinct=True,
filter=Q(status__in=[StatusChoices.PASS, StatusChoices.FAIL]),
),
passed_findings=Count(
"id",
distinct=True,
filter=Q(status=StatusChoices.PASS),
),
)
)
for aggregated_stat in aggregated_statistics_queryset:
check_id = aggregated_stat["check_id"]
requirement_statistics_by_check_id[check_id] = {
"passed": aggregated_stat["passed_findings"],
"total": aggregated_stat["total_findings"],
}
for aggregated_stat in aggregated_statistics_queryset:
check_id = aggregated_stat["check_id"]
requirement_statistics_by_check_id[check_id] = {
"passed": aggregated_stat["passed_findings"],
"total": aggregated_stat["total_findings"],
}
logger.info(
f"Aggregated statistics for {len(requirement_statistics_by_check_id)} unique checks"
@@ -220,35 +221,36 @@ def _load_findings_for_requirement_checks(
f"Loading findings for {len(check_ids_to_load)} checks from database"
)
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
# Use iterator with chunk_size for memory-efficient streaming
# chunk_size controls how many rows Django fetches from DB at once
findings_queryset = (
Finding.all_objects.filter(
tenant_id=tenant_id,
scan_id=scan_id,
check_id__in=check_ids_to_load,
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
with attempt:
# Use iterator with chunk_size for memory-efficient streaming
# chunk_size controls how many rows Django fetches from DB at once
findings_queryset = (
Finding.all_objects.filter(
tenant_id=tenant_id,
scan_id=scan_id,
check_id__in=check_ids_to_load,
)
.order_by("check_id", "uid")
.iterator(chunk_size=DJANGO_FINDINGS_BATCH_SIZE)
)
.order_by("check_id", "uid")
.iterator(chunk_size=DJANGO_FINDINGS_BATCH_SIZE)
)
# Pre-initialize empty lists for all check_ids to load
# This avoids repeated dict lookups and 'if not in' checks
for check_id in check_ids_to_load:
findings_cache[check_id] = []
# Pre-initialize empty lists for all check_ids to load
# This avoids repeated dict lookups and 'if not in' checks
for check_id in check_ids_to_load:
findings_cache[check_id] = []
findings_count = 0
for finding_model in findings_queryset:
finding_output = FindingOutput.transform_api_finding(
finding_model, prowler_provider
findings_count = 0
for finding_model in findings_queryset:
finding_output = FindingOutput.transform_api_finding(
finding_model, prowler_provider
)
findings_cache[finding_output.check_id].append(finding_output)
findings_count += 1
logger.info(
f"Loaded {findings_count} findings for {len(check_ids_to_load)} checks"
)
findings_cache[finding_output.check_id].append(finding_output)
findings_count += 1
logger.info(
f"Loaded {findings_count} findings for {len(check_ids_to_load)} checks"
)
# Build result dict using cache references (no data duplication)
# This shares the same list objects between cache and result

View File

@@ -280,58 +280,59 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
"""
task_id = self.request.id
with rls_transaction(tenant_id):
periodic_task_instance = PeriodicTask.objects.get(
name=f"scan-perform-scheduled-{provider_id}"
)
executing_scan = (
Scan.objects.filter(
for attempt in rls_transaction(tenant_id):
with attempt:
periodic_task_instance = PeriodicTask.objects.get(
name=f"scan-perform-scheduled-{provider_id}"
)
executing_scan = (
Scan.objects.filter(
tenant_id=tenant_id,
provider_id=provider_id,
trigger=Scan.TriggerChoices.SCHEDULED,
state=StateChoices.EXECUTING,
)
.order_by("-started_at")
.first()
)
if executing_scan:
logger.warning(
f"Scheduled scan already executing for provider {provider_id}. Skipping."
)
return ScanTaskSerializer(instance=executing_scan).data
executed_scan = Scan.objects.filter(
tenant_id=tenant_id,
provider_id=provider_id,
trigger=Scan.TriggerChoices.SCHEDULED,
state=StateChoices.EXECUTING,
task__task_runner_task__task_id=task_id,
).first()
if executed_scan:
# Duplicated task execution due to visibility timeout
logger.warning(f"Duplicated scheduled scan for provider {provider_id}.")
return ScanTaskSerializer(instance=executed_scan).data
interval = periodic_task_instance.interval
next_scan_datetime = get_next_execution_datetime(task_id, provider_id)
current_scan_datetime = next_scan_datetime - timedelta(
**{interval.period: interval.every}
)
.order_by("-started_at")
.first()
)
if executing_scan:
logger.warning(
f"Scheduled scan already executing for provider {provider_id}. Skipping."
# TEMPORARY WORKAROUND: Clean up orphan scans from transaction isolation issue
_cleanup_orphan_scheduled_scans(
tenant_id=tenant_id,
provider_id=provider_id,
scheduler_task_id=periodic_task_instance.id,
)
return ScanTaskSerializer(instance=executing_scan).data
executed_scan = Scan.objects.filter(
tenant_id=tenant_id,
provider_id=provider_id,
task__task_runner_task__task_id=task_id,
).first()
if executed_scan:
# Duplicated task execution due to visibility timeout
logger.warning(f"Duplicated scheduled scan for provider {provider_id}.")
return ScanTaskSerializer(instance=executed_scan).data
interval = periodic_task_instance.interval
next_scan_datetime = get_next_execution_datetime(task_id, provider_id)
current_scan_datetime = next_scan_datetime - timedelta(
**{interval.period: interval.every}
)
# TEMPORARY WORKAROUND: Clean up orphan scans from transaction isolation issue
_cleanup_orphan_scheduled_scans(
tenant_id=tenant_id,
provider_id=provider_id,
scheduler_task_id=periodic_task_instance.id,
)
scan_instance = _get_or_create_scheduled_scan(
tenant_id=tenant_id,
provider_id=provider_id,
scheduler_task_id=periodic_task_instance.id,
scheduled_at=current_scan_datetime,
)
scan_instance.task_id = task_id
scan_instance.save()
scan_instance = _get_or_create_scheduled_scan(
tenant_id=tenant_id,
provider_id=provider_id,
scheduler_task_id=periodic_task_instance.id,
scheduled_at=current_scan_datetime,
)
scan_instance.task_id = task_id
scan_instance.save()
try:
result = perform_prowler_scan(
@@ -340,19 +341,20 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
provider_id=provider_id,
)
finally:
with rls_transaction(tenant_id):
now = datetime.now(timezone.utc)
if next_scan_datetime <= now:
interval_delta = timedelta(**{interval.period: interval.every})
while next_scan_datetime <= now:
next_scan_datetime += interval_delta
_get_or_create_scheduled_scan(
tenant_id=tenant_id,
provider_id=provider_id,
scheduler_task_id=periodic_task_instance.id,
scheduled_at=next_scan_datetime,
update_state=True,
)
for attempt in rls_transaction(tenant_id):
with attempt:
now = datetime.now(timezone.utc)
if next_scan_datetime <= now:
interval_delta = timedelta(**{interval.period: interval.every})
while next_scan_datetime <= now:
next_scan_datetime += interval_delta
_get_or_create_scheduled_scan(
tenant_id=tenant_id,
provider_id=provider_id,
scheduler_task_id=periodic_task_instance.id,
scheduled_at=next_scan_datetime,
update_state=True,
)
_perform_scan_complete_tasks(tenant_id, str(scan_instance.id), provider_id)
@@ -486,67 +488,69 @@ def generate_outputs_task(scan_id: str, provider_id: str, tenant_id: str):
.order_by("uid")
.iterator()
)
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
for batch, is_last in batched(qs, DJANGO_FINDINGS_BATCH_SIZE):
fos = [
FindingOutput.transform_api_finding(f, prowler_provider) for f in batch
]
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
with attempt:
for batch, is_last in batched(qs, DJANGO_FINDINGS_BATCH_SIZE):
fos = [
FindingOutput.transform_api_finding(f, prowler_provider)
for f in batch
]
# Outputs
for mode, cfg in OUTPUT_FORMATS_MAPPING.items():
# Skip ASFF generation if not needed
if mode == "json-asff" and not generate_asff:
continue
# Outputs
for mode, cfg in OUTPUT_FORMATS_MAPPING.items():
# Skip ASFF generation if not needed
if mode == "json-asff" and not generate_asff:
continue
cls = cfg["class"]
suffix = cfg["suffix"]
extra = cfg.get("kwargs", {}).copy()
if mode == "html":
extra.update(provider=prowler_provider, stats=scan_summary)
cls = cfg["class"]
suffix = cfg["suffix"]
extra = cfg.get("kwargs", {}).copy()
if mode == "html":
extra.update(provider=prowler_provider, stats=scan_summary)
writer, initialization = get_writer(
output_writers,
cls,
lambda cls=cls, fos=fos, suffix=suffix: cls(
findings=fos,
file_path=out_dir,
file_extension=suffix,
from_cli=False,
),
is_last,
)
if not initialization:
writer.transform(fos)
writer.batch_write_data_to_file(**extra)
writer._data.clear()
writer, initialization = get_writer(
output_writers,
cls,
lambda cls=cls, fos=fos, suffix=suffix: cls(
findings=fos,
file_path=out_dir,
file_extension=suffix,
from_cli=False,
),
is_last,
)
if not initialization:
writer.transform(fos)
writer.batch_write_data_to_file(**extra)
writer._data.clear()
# Compliance CSVs
for name in frameworks_avail:
compliance_obj = frameworks_bulk[name]
# Compliance CSVs
for name in frameworks_avail:
compliance_obj = frameworks_bulk[name]
klass = GenericCompliance
for condition, cls in COMPLIANCE_CLASS_MAP.get(provider_type, []):
if condition(name):
klass = cls
break
klass = GenericCompliance
for condition, cls in COMPLIANCE_CLASS_MAP.get(provider_type, []):
if condition(name):
klass = cls
break
filename = f"{comp_dir}_{name}.csv"
filename = f"{comp_dir}_{name}.csv"
writer, initialization = get_writer(
compliance_writers,
name,
lambda klass=klass, fos=fos: klass(
findings=fos,
compliance=compliance_obj,
file_path=filename,
from_cli=False,
),
is_last,
)
if not initialization:
writer.transform(fos, compliance_obj, name)
writer.batch_write_data_to_file()
writer._data.clear()
writer, initialization = get_writer(
compliance_writers,
name,
lambda klass=klass, fos=fos: klass(
findings=fos,
compliance=compliance_obj,
file_path=filename,
from_cli=False,
),
is_last,
)
if not initialization:
writer.transform(fos, compliance_obj, name)
writer.batch_write_data_to_file()
writer._data.clear()
compressed = _compress_output_files(out_dir)
@@ -569,12 +573,13 @@ def generate_outputs_task(scan_id: str, provider_id: str, tenant_id: str):
)
# S3 integrations (need output_directory)
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
s3_integrations = Integration.objects.filter(
integrationproviderrelationship__provider_id=provider_id,
integration_type=Integration.IntegrationChoices.AMAZON_S3,
enabled=True,
)
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
with attempt:
s3_integrations = Integration.objects.filter(
integrationproviderrelationship__provider_id=provider_id,
integration_type=Integration.IntegrationChoices.AMAZON_S3,
enabled=True,
)
if s3_integrations:
# Pass the output directory path to S3 integration task to reconstruct objects from files
@@ -812,27 +817,32 @@ def check_integrations_task(tenant_id: str, provider_id: str, scan_id: str = Non
try:
integration_tasks = []
with rls_transaction(tenant_id):
integrations = Integration.objects.filter(
integrationproviderrelationship__provider_id=provider_id,
enabled=True,
)
if not integrations.exists():
logger.info(f"No integrations configured for provider {provider_id}")
return {"integrations_processed": 0}
# Security Hub integration
security_hub_integrations = integrations.filter(
integration_type=Integration.IntegrationChoices.AWS_SECURITY_HUB
)
if security_hub_integrations.exists():
integration_tasks.append(
security_hub_integration_task.s(
tenant_id=tenant_id, provider_id=provider_id, scan_id=scan_id
)
for attempt in rls_transaction(tenant_id):
with attempt:
integrations = Integration.objects.filter(
integrationproviderrelationship__provider_id=provider_id,
enabled=True,
)
if not integrations.exists():
logger.info(
f"No integrations configured for provider {provider_id}"
)
return {"integrations_processed": 0}
# Security Hub integration
security_hub_integrations = integrations.filter(
integration_type=Integration.IntegrationChoices.AWS_SECURITY_HUB
)
if security_hub_integrations.exists():
integration_tasks.append(
security_hub_integration_task.s(
tenant_id=tenant_id,
provider_id=provider_id,
scan_id=scan_id,
)
)
# TODO: Add other integration types here
# slack_integrations = integrations.filter(
# integration_type=Integration.IntegrationChoices.SLACK

View File

@@ -55,7 +55,7 @@ class TestAttackPathsRun:
)
@patch(
"tasks.jobs.attack_paths.scan.rls_transaction",
new=lambda *args, **kwargs: nullcontext(),
new=lambda *args, **kwargs: [nullcontext()],
)
def test_run_success_flow(
self,
@@ -201,7 +201,7 @@ class TestAttackPathsRun:
)
@patch(
"tasks.jobs.attack_paths.scan.rls_transaction",
new=lambda *args, **kwargs: nullcontext(),
new=lambda *args, **kwargs: [nullcontext()],
)
def test_run_failure_marks_scan_failed(
self,
@@ -300,7 +300,7 @@ class TestAttackPathsRun:
)
@patch(
"tasks.jobs.attack_paths.scan.rls_transaction",
new=lambda *args, **kwargs: nullcontext(),
new=lambda *args, **kwargs: [nullcontext()],
)
def test_failure_before_gate_does_not_flip_graph_data_ready_true(
self,
@@ -403,7 +403,7 @@ class TestAttackPathsRun:
)
@patch(
"tasks.jobs.attack_paths.scan.rls_transaction",
new=lambda *args, **kwargs: nullcontext(),
new=lambda *args, **kwargs: [nullcontext()],
)
def test_run_failure_marks_scan_failed_even_when_drop_database_fails(
self,
@@ -509,7 +509,7 @@ class TestAttackPathsRun:
)
@patch(
"tasks.jobs.attack_paths.scan.rls_transaction",
new=lambda *args, **kwargs: nullcontext(),
new=lambda *args, **kwargs: [nullcontext()],
)
def test_failure_after_gate_before_drop_restores_graph_data_ready(
self,
@@ -622,7 +622,7 @@ class TestAttackPathsRun:
)
@patch(
"tasks.jobs.attack_paths.scan.rls_transaction",
new=lambda *args, **kwargs: nullcontext(),
new=lambda *args, **kwargs: [nullcontext()],
)
def test_failure_after_drop_before_sync_leaves_graph_data_ready_false(
self,
@@ -735,7 +735,7 @@ class TestAttackPathsRun:
)
@patch(
"tasks.jobs.attack_paths.scan.rls_transaction",
new=lambda *args, **kwargs: nullcontext(),
new=lambda *args, **kwargs: [nullcontext()],
)
def test_failure_after_sync_restores_graph_data_ready(
self,
@@ -853,7 +853,7 @@ class TestAttackPathsRun:
)
@patch(
"tasks.jobs.attack_paths.scan.rls_transaction",
new=lambda *args, **kwargs: nullcontext(),
new=lambda *args, **kwargs: [nullcontext()],
)
def test_recovery_failure_does_not_suppress_original_exception(
self,
@@ -949,7 +949,7 @@ class TestAttackPathsRun:
with (
patch(
"tasks.jobs.attack_paths.scan.rls_transaction",
new=lambda *args, **kwargs: nullcontext(),
new=lambda *args, **kwargs: [nullcontext()],
),
patch(
"tasks.jobs.attack_paths.scan.initialize_prowler_provider",
@@ -1411,7 +1411,7 @@ class TestAttackPathsFindingsHelpers:
with (
patch(
"tasks.jobs.attack_paths.findings.rls_transaction",
new=lambda *args, **kwargs: nullcontext(),
new=lambda *args, **kwargs: [nullcontext()],
),
patch(
"tasks.jobs.attack_paths.findings.READ_REPLICA_ALIAS",
@@ -2024,7 +2024,7 @@ class TestAttackPathsDbUtilsGraphDataReady:
with patch(
"tasks.jobs.attack_paths.db_utils.rls_transaction",
new=lambda *args, **kwargs: nullcontext(),
new=lambda *args, **kwargs: [nullcontext()],
):
attack_paths_scan = create_attack_paths_scan(
str(tenant.id), str(scan.id), provider.id
@@ -2064,7 +2064,7 @@ class TestAttackPathsDbUtilsGraphDataReady:
with patch(
"tasks.jobs.attack_paths.db_utils.rls_transaction",
new=lambda *args, **kwargs: nullcontext(),
new=lambda *args, **kwargs: [nullcontext()],
):
attack_paths_scan = create_attack_paths_scan(
str(tenant.id), str(new_scan.id), provider.id
@@ -2104,7 +2104,7 @@ class TestAttackPathsDbUtilsGraphDataReady:
with patch(
"tasks.jobs.attack_paths.db_utils.rls_transaction",
new=lambda *args, **kwargs: nullcontext(),
new=lambda *args, **kwargs: [nullcontext()],
):
attack_paths_scan = create_attack_paths_scan(
str(tenant.id), str(new_scan.id), provider.id
@@ -2136,7 +2136,7 @@ class TestAttackPathsDbUtilsGraphDataReady:
with patch(
"tasks.jobs.attack_paths.db_utils.rls_transaction",
new=lambda *args, **kwargs: nullcontext(),
new=lambda *args, **kwargs: [nullcontext()],
):
set_graph_data_ready(attack_paths_scan, False)
@@ -2145,7 +2145,7 @@ class TestAttackPathsDbUtilsGraphDataReady:
with patch(
"tasks.jobs.attack_paths.db_utils.rls_transaction",
new=lambda *args, **kwargs: nullcontext(),
new=lambda *args, **kwargs: [nullcontext()],
):
set_graph_data_ready(attack_paths_scan, True)
@@ -2175,7 +2175,7 @@ class TestAttackPathsDbUtilsGraphDataReady:
with patch(
"tasks.jobs.attack_paths.db_utils.rls_transaction",
new=lambda *args, **kwargs: nullcontext(),
new=lambda *args, **kwargs: [nullcontext()],
):
finish_attack_paths_scan(attack_paths_scan, StateChoices.COMPLETED, {})
@@ -2206,7 +2206,7 @@ class TestAttackPathsDbUtilsGraphDataReady:
with patch(
"tasks.jobs.attack_paths.db_utils.rls_transaction",
new=lambda *args, **kwargs: nullcontext(),
new=lambda *args, **kwargs: [nullcontext()],
):
finish_attack_paths_scan(
attack_paths_scan,
@@ -2257,7 +2257,7 @@ class TestAttackPathsDbUtilsGraphDataReady:
with patch(
"tasks.jobs.attack_paths.db_utils.rls_transaction",
new=lambda *args, **kwargs: nullcontext(),
new=lambda *args, **kwargs: [nullcontext()],
):
set_provider_graph_data_ready(new_ap_scan, False)
@@ -2309,7 +2309,7 @@ class TestAttackPathsDbUtilsGraphDataReady:
with patch(
"tasks.jobs.attack_paths.db_utils.rls_transaction",
new=lambda *args, **kwargs: nullcontext(),
new=lambda *args, **kwargs: [nullcontext()],
):
set_provider_graph_data_ready(ap_scan_a, False)

View File

@@ -159,9 +159,8 @@ class TestOutputs:
mock_scan_instance.started_at = datetime(2023, 6, 15, 10, 30, 45)
mock_scan.objects.get.return_value = mock_scan_instance
# Mock rls_transaction as a context manager
mock_rls_transaction.return_value.__enter__ = MagicMock()
mock_rls_transaction.return_value.__exit__ = MagicMock(return_value=False)
# Mock rls_transaction as an iterable yielding a context manager
mock_rls_transaction.return_value = [MagicMock()]
base_tmp = Path(str(tmpdir.mkdir("generate_output")))
base_dir = str(base_tmp)
@@ -210,9 +209,8 @@ class TestOutputs:
mock_scan_instance.started_at = datetime(2023, 6, 15, 10, 30, 45)
mock_scan.objects.get.return_value = mock_scan_instance
# Mock rls_transaction as a context manager
mock_rls_transaction.return_value.__enter__ = MagicMock()
mock_rls_transaction.return_value.__exit__ = MagicMock(return_value=False)
# Mock rls_transaction as an iterable yielding a context manager
mock_rls_transaction.return_value = [MagicMock()]
base_tmp = Path(str(tmpdir.mkdir("generate_output")))
base_dir = str(base_tmp)

View File

@@ -1,7 +1,6 @@
from unittest.mock import MagicMock, patch
import pytest
from django.db import OperationalError
from tasks.jobs.integrations import (
get_s3_client_from_integration,
get_security_hub_client_from_integration,
@@ -99,6 +98,7 @@ class TestS3IntegrationUploads:
):
tenant_id = "tenant-id"
provider_id = "provider-id"
mock_rls.return_value = [MagicMock()]
integration = MagicMock()
integration.id = "i-1"
@@ -150,6 +150,7 @@ class TestS3IntegrationUploads:
):
tenant_id = "tenant-id"
provider_id = "provider-id"
mock_rls.return_value = [MagicMock()]
integration = MagicMock()
integration.id = "i-1"
@@ -176,6 +177,7 @@ class TestS3IntegrationUploads:
def test_upload_s3_integration_logs_if_no_integrations(
self, mock_logger, mock_integration_model, mock_rls
):
mock_rls.return_value = [MagicMock()]
mock_integration_model.objects.filter.return_value = []
output_directory = "/tmp/prowler_output/scan123"
result = upload_s3_integration("tenant", "provider", output_directory)
@@ -197,6 +199,7 @@ class TestS3IntegrationUploads:
):
tenant_id = "tenant-id"
provider_id = "provider-id"
mock_rls.return_value = [MagicMock()]
integration = MagicMock()
integration.id = "i-1"
@@ -226,7 +229,7 @@ class TestS3IntegrationUploads:
# Mock that no enabled integrations are found
mock_integration_filter.return_value = []
mock_rls.return_value.__enter__.return_value = None
mock_rls.return_value = [MagicMock()]
result = upload_s3_integration(tenant_id, provider_id, output_directory)
@@ -784,7 +787,7 @@ class TestSecurityHubIntegrationUploads:
mock_findings = [{"finding": "test"}]
# Mock RLS context manager
mock_rls.return_value.__enter__.return_value = None
mock_rls.return_value = [MagicMock()]
connected, connection = get_security_hub_client_from_integration(
mock_integration, tenant_id, mock_findings
@@ -857,7 +860,7 @@ class TestSecurityHubIntegrationUploads:
mock_findings = [{"finding": "test1"}, {"finding": "test2"}]
# Mock RLS context manager
mock_rls.return_value.__enter__.return_value = None
mock_rls.return_value = [MagicMock()]
# Call the function
connected, connection = get_security_hub_client_from_integration(
@@ -999,6 +1002,7 @@ class TestSecurityHubIntegrationUploads:
tenant_id = "tenant-id"
provider_id = "provider-id"
scan_id = "scan-123"
mock_rls.return_value = [MagicMock()]
# Mock integration
integration = MagicMock()
@@ -1057,84 +1061,6 @@ class TestSecurityHubIntegrationUploads:
mock_security_hub.batch_send_to_security_hub.assert_called_once()
mock_security_hub.archive_previous_findings.assert_called_once()
@patch("tasks.jobs.integrations.time.sleep")
@patch("tasks.jobs.integrations.batched")
@patch("tasks.jobs.integrations.get_security_hub_client_from_integration")
@patch("tasks.jobs.integrations.initialize_prowler_provider")
@patch("tasks.jobs.integrations.rls_transaction")
@patch("tasks.jobs.integrations.Integration")
@patch("tasks.jobs.integrations.Provider")
@patch("tasks.jobs.integrations.Finding")
def test_upload_security_hub_integration_retries_on_operational_error(
self,
mock_finding_model,
mock_provider_model,
mock_integration_model,
mock_rls,
mock_initialize_provider,
mock_get_security_hub,
mock_batched,
mock_sleep,
):
"""Test SecurityHub upload retries on transient OperationalError."""
tenant_id = "tenant-id"
provider_id = "provider-id"
scan_id = "scan-123"
integration = MagicMock()
integration.id = "integration-1"
integration.configuration = {
"send_only_fails": True,
"archive_previous_findings": False,
}
mock_integration_model.objects.filter.return_value = [integration]
provider = MagicMock()
mock_provider_model.objects.get.return_value = provider
mock_prowler_provider = MagicMock()
mock_initialize_provider.return_value = mock_prowler_provider
mock_findings = [MagicMock(), MagicMock()]
mock_finding_model.all_objects.filter.return_value.order_by.return_value.iterator.return_value = iter(
mock_findings
)
transformed_findings = [MagicMock(), MagicMock()]
with patch("tasks.jobs.integrations.FindingOutput") as mock_finding_output:
mock_finding_output.transform_api_finding.side_effect = transformed_findings
with patch("tasks.jobs.integrations.ASFF") as mock_asff:
mock_asff_instance = MagicMock()
finding1 = MagicMock()
finding1.Compliance.Status = "FAILED"
finding2 = MagicMock()
finding2.Compliance.Status = "FAILED"
mock_asff_instance.data = [finding1, finding2]
mock_asff_instance._data = MagicMock()
mock_asff.return_value = mock_asff_instance
mock_security_hub = MagicMock()
mock_security_hub.batch_send_to_security_hub.return_value = 2
mock_get_security_hub.return_value = (True, mock_security_hub)
mock_rls.return_value.__enter__.return_value = None
mock_rls.return_value.__exit__.return_value = False
mock_batched.side_effect = [
OperationalError("Conflict with recovery"),
[(mock_findings, None)],
]
with patch("tasks.jobs.integrations.REPLICA_MAX_ATTEMPTS", 2):
with patch("tasks.jobs.integrations.READ_REPLICA_ALIAS", "replica"):
result = upload_security_hub_integration(
tenant_id, provider_id, scan_id
)
assert result is True
mock_sleep.assert_called_once()
@patch("tasks.jobs.integrations.get_security_hub_client_from_integration")
@patch("tasks.jobs.integrations.initialize_prowler_provider")
@patch("tasks.jobs.integrations.rls_transaction")
@@ -1154,6 +1080,7 @@ class TestSecurityHubIntegrationUploads:
tenant_id = "tenant-id"
provider_id = "provider-id"
scan_id = "scan-123"
mock_rls.return_value = [MagicMock()]
# Mock no integrations found
mock_integration_model.objects.filter.return_value = []
@@ -1183,6 +1110,7 @@ class TestSecurityHubIntegrationUploads:
tenant_id = "tenant-id"
provider_id = "provider-id"
scan_id = "scan-123"
mock_rls.return_value = [MagicMock()]
# Mock integration
integration = MagicMock()
@@ -1226,6 +1154,7 @@ class TestSecurityHubIntegrationUploads:
tenant_id = "tenant-id"
provider_id = "provider-id"
scan_id = "scan-123"
mock_rls.return_value = [MagicMock()]
# Mock integration
integration = MagicMock()
@@ -1300,6 +1229,7 @@ class TestSecurityHubIntegrationUploads:
tenant_id = "tenant-id"
provider_id = "provider-id"
scan_id = "scan-123"
mock_rls.return_value = [MagicMock()]
# Mock integration with archive_previous_findings disabled
integration = MagicMock()
@@ -1375,6 +1305,7 @@ class TestSecurityHubIntegrationUploads:
tenant_id = "tenant-id"
provider_id = "provider-id"
scan_id = "scan-123"
mock_rls.return_value = [MagicMock()]
# Mock integration
integration = MagicMock()
@@ -1443,6 +1374,7 @@ class TestSecurityHubIntegrationUploads:
tenant_id = "tenant-id"
provider_id = "provider-id"
scan_id = "scan-123"
mock_rls.return_value = [MagicMock()]
# Mock exception during integration retrieval
mock_integration_model.objects.filter.side_effect = Exception("Database error")
@@ -1476,6 +1408,7 @@ class TestSecurityHubIntegrationUploads:
tenant_id = "tenant-id"
provider_id = "provider-id"
scan_id = "scan-123"
mock_rls.return_value = [MagicMock()]
# Mock integration with send_only_fails=True
integration = MagicMock()
@@ -1570,6 +1503,7 @@ class TestSecurityHubIntegrationUploads:
tenant_id = "tenant-id"
provider_id = "provider-id"
scan_id = "scan-123"
mock_rls.return_value = [MagicMock()]
# Mock integration with send_only_fails=False
integration = MagicMock()
@@ -1654,9 +1588,8 @@ class TestJiraIntegration:
issue_type = "Task"
finding_ids = ["finding-1", "finding-2"]
# Mock RLS transaction
mock_rls_transaction.return_value.__enter__ = MagicMock()
mock_rls_transaction.return_value.__exit__ = MagicMock()
# Mock RLS transaction as an iterable yielding a context manager
mock_rls_transaction.return_value = [MagicMock()]
# Mock integration
integration = MagicMock()
@@ -1786,9 +1719,8 @@ class TestJiraIntegration:
issue_type = "Task"
finding_ids = ["finding-1", "finding-2", "finding-3"]
# Mock RLS transaction
mock_rls_transaction.return_value.__enter__ = MagicMock()
mock_rls_transaction.return_value.__exit__ = MagicMock()
# Mock RLS transaction as an iterable yielding a context manager
mock_rls_transaction.return_value = [MagicMock()]
# Mock integration
integration = MagicMock()
@@ -1856,9 +1788,8 @@ class TestJiraIntegration:
issue_type = "Task"
finding_ids = ["finding-1"]
# Mock RLS transaction
mock_rls_transaction.return_value.__enter__ = MagicMock()
mock_rls_transaction.return_value.__exit__ = MagicMock()
# Mock RLS transaction as an iterable yielding a context manager
mock_rls_transaction.return_value = [MagicMock()]
# Mock integration
integration = MagicMock()
@@ -1934,9 +1865,8 @@ class TestJiraIntegration:
issue_type = "Task"
finding_ids = ["finding-1"]
# Mock RLS transaction
mock_rls_transaction.return_value.__enter__ = MagicMock()
mock_rls_transaction.return_value.__exit__ = MagicMock()
# Mock RLS transaction as an iterable yielding a context manager
mock_rls_transaction.return_value = [MagicMock()]
# Mock integration
integration = MagicMock()

View File

@@ -43,9 +43,12 @@ from prowler.lib.check.models import Severity
from prowler.lib.outputs.finding import Status
@contextmanager
def noop_rls_transaction(*args, **kwargs):
yield
@contextmanager
def _ctx():
yield
return [_ctx()]
class FakeFinding:
@@ -75,7 +78,7 @@ class TestPerformScan:
providers_fixture,
):
with (
patch("api.db_utils.rls_transaction"),
patch("api.db_utils.rls_transaction", new=lambda *a, **kw: [MagicMock()]),
patch(
"tasks.jobs.scan.initialize_prowler_provider"
) as mock_initialize_prowler_provider,
@@ -241,6 +244,8 @@ class TestPerformScan:
scans_fixture,
providers_fixture,
):
mock_rls_transaction.return_value = [MagicMock()]
tenant = tenants_fixture[0]
scan = scans_fixture[0]
provider = providers_fixture[0]
@@ -282,6 +287,8 @@ class TestPerformScan:
mock_get_or_create_resource,
mock_get_or_create_tag,
):
mock_rls_transaction.return_value = [MagicMock()]
tenant_id = uuid.uuid4()
provider_instance = MagicMock()
provider_instance.id = "provider123"
@@ -331,6 +338,8 @@ class TestPerformScan:
mock_get_or_create_resource,
mock_get_or_create_tag,
):
mock_rls_transaction.return_value = [MagicMock()]
tenant_id = uuid.uuid4()
provider_instance = MagicMock()
provider_instance.id = "provider456"
@@ -388,6 +397,8 @@ class TestPerformScan:
mock_get_or_create_resource,
mock_get_or_create_tag,
):
mock_rls_transaction.return_value = [MagicMock()]
tenant_id = uuid.uuid4()
provider_instance = MagicMock()
provider_instance.id = "provider456"
@@ -438,7 +449,7 @@ class TestPerformScan:
):
"""Test that failed findings increment the failed_findings_count"""
with (
patch("api.db_utils.rls_transaction"),
patch("api.db_utils.rls_transaction", new=lambda *a, **kw: [MagicMock()]),
patch(
"tasks.jobs.scan.initialize_prowler_provider"
) as mock_initialize_prowler_provider,
@@ -516,7 +527,7 @@ class TestPerformScan:
):
"""Test that multiple FAIL findings on the same resource increment the counter correctly"""
with (
patch("api.db_utils.rls_transaction"),
patch("api.db_utils.rls_transaction", new=lambda *a, **kw: [MagicMock()]),
patch(
"tasks.jobs.scan.initialize_prowler_provider"
) as mock_initialize_prowler_provider,
@@ -633,7 +644,7 @@ class TestPerformScan:
):
"""Test that muted FAIL findings do not increment the failed_findings_count"""
with (
patch("api.db_utils.rls_transaction"),
patch("api.db_utils.rls_transaction", new=lambda *a, **kw: [MagicMock()]),
patch(
"tasks.jobs.scan.initialize_prowler_provider"
) as mock_initialize_prowler_provider,
@@ -723,7 +734,7 @@ class TestPerformScan:
)
with (
patch("api.db_utils.rls_transaction"),
patch("api.db_utils.rls_transaction", new=lambda *a, **kw: [MagicMock()]),
patch(
"tasks.jobs.scan.initialize_prowler_provider"
) as mock_initialize_prowler_provider,
@@ -791,7 +802,7 @@ class TestPerformScan:
):
"""Test active MuteRule mutes findings with correct reason"""
with (
patch("api.db_utils.rls_transaction"),
patch("api.db_utils.rls_transaction", new=lambda *a, **kw: [MagicMock()]),
patch(
"tasks.jobs.scan.initialize_prowler_provider"
) as mock_initialize_prowler_provider,
@@ -908,7 +919,7 @@ class TestPerformScan:
):
"""Test inactive MuteRule does not mute findings"""
with (
patch("api.db_utils.rls_transaction"),
patch("api.db_utils.rls_transaction", new=lambda *a, **kw: [MagicMock()]),
patch(
"tasks.jobs.scan.initialize_prowler_provider"
) as mock_initialize_prowler_provider,
@@ -994,7 +1005,7 @@ class TestPerformScan:
):
"""Test mutelist processor takes precedence over MuteRule"""
with (
patch("api.db_utils.rls_transaction"),
patch("api.db_utils.rls_transaction", new=lambda *a, **kw: [MagicMock()]),
patch(
"tasks.jobs.scan.initialize_prowler_provider"
) as mock_initialize_prowler_provider,
@@ -1080,7 +1091,7 @@ class TestPerformScan:
):
"""Test MuteRule with multiple finding UIDs mutes all findings"""
with (
patch("api.db_utils.rls_transaction"),
patch("api.db_utils.rls_transaction", new=lambda *a, **kw: [MagicMock()]),
patch(
"tasks.jobs.scan.initialize_prowler_provider"
) as mock_initialize_prowler_provider,
@@ -1179,7 +1190,7 @@ class TestPerformScan:
):
"""Test scan continues when MuteRule loading fails"""
with (
patch("api.db_utils.rls_transaction"),
patch("api.db_utils.rls_transaction", new=lambda *a, **kw: [MagicMock()]),
patch(
"tasks.jobs.scan.initialize_prowler_provider"
) as mock_initialize_prowler_provider,
@@ -1262,7 +1273,7 @@ class TestPerformScan:
):
"""Test muted_at timestamp is set correctly for muted findings"""
with (
patch("api.db_utils.rls_transaction"),
patch("api.db_utils.rls_transaction", new=lambda *a, **kw: [MagicMock()]),
patch(
"tasks.jobs.scan.initialize_prowler_provider"
) as mock_initialize_prowler_provider,
@@ -2200,9 +2211,7 @@ class TestComplianceRequirementCopy:
tenant_id = row["tenant_id"]
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
mock_rls_transaction.return_value = [ctx]
_persist_compliance_requirement_rows(tenant_id, [row])
@@ -2727,9 +2736,7 @@ class TestComplianceRequirementCopy:
}
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
mock_rls_transaction.return_value = [ctx]
_persist_compliance_requirement_rows(tenant_id, [row])
@@ -2790,9 +2797,7 @@ class TestComplianceRequirementCopy:
]
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
mock_rls_transaction.return_value = [ctx]
_persist_compliance_requirement_rows(tenant_id, rows)
@@ -2854,9 +2859,7 @@ class TestComplianceRequirementCopy:
}
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
mock_rls_transaction.return_value = [ctx]
_persist_compliance_requirement_rows(tenant_id, [row])
@@ -2920,9 +2923,7 @@ class TestCreateComplianceSummaries:
}
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
mock_rls_transaction.return_value = [ctx]
_create_compliance_summaries(tenant_id, scan_id, requirement_statuses)
@@ -2961,9 +2962,7 @@ class TestCreateComplianceSummaries:
requirement_statuses = {}
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
mock_rls_transaction.return_value = [ctx]
_create_compliance_summaries(tenant_id, scan_id, requirement_statuses)
@@ -2985,9 +2984,7 @@ class TestCreateComplianceSummaries:
}
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
mock_rls_transaction.return_value = [ctx]
_create_compliance_summaries(tenant_id, scan_id, requirement_statuses)
@@ -3017,9 +3014,7 @@ class TestCreateComplianceSummaries:
}
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
mock_rls_transaction.return_value = [ctx]
_create_compliance_summaries(tenant_id, scan_id, requirement_statuses)
@@ -3067,9 +3062,7 @@ class TestCreateComplianceSummaries:
}
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
mock_rls_transaction.return_value = [ctx]
_create_compliance_summaries(tenant_id, scan_id, requirement_statuses)
@@ -3157,9 +3150,7 @@ class TestAggregateFindings:
scan = scans_fixture[0]
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
mock_rls_transaction.return_value = [ctx]
aggregate_findings(str(tenant.id), str(scan.id))
@@ -3208,9 +3199,7 @@ class TestAggregateFindings:
]
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
mock_rls_transaction.return_value = [ctx]
mock_findings_filter.return_value = mock_queryset
aggregate_findings(tenant_id, scan_id)
@@ -3260,9 +3249,7 @@ class TestAggregateFindings:
]
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
mock_rls_transaction.return_value = [ctx]
mock_findings_filter.return_value = mock_queryset
aggregate_findings(tenant_id, scan_id)
@@ -3335,9 +3322,7 @@ class TestAggregateFindings:
]
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
mock_rls_transaction.return_value = [ctx]
mock_findings_filter.return_value = mock_queryset
aggregate_findings(tenant_id, scan_id)
@@ -3386,9 +3371,7 @@ class TestAggregateFindingsByRegion:
mock_queryset.prefetch_related.return_value = [mock_finding1]
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
mock_rls_transaction.return_value = [ctx]
mock_findings_filter.return_value = mock_queryset
check_status_by_region, findings_count_by_compliance = (
@@ -3439,9 +3422,7 @@ class TestAggregateFindingsByRegion:
mock_queryset.prefetch_related.return_value = [mock_finding1, mock_finding2]
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
mock_rls_transaction.return_value = [ctx]
mock_findings_filter.return_value = mock_queryset
check_status_by_region, _ = _aggregate_findings_by_region(
@@ -3466,9 +3447,7 @@ class TestAggregateFindingsByRegion:
mock_queryset.prefetch_related.return_value = []
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
mock_rls_transaction.return_value = [ctx]
mock_findings_filter.return_value = mock_queryset
_aggregate_findings_by_region(
@@ -3516,9 +3495,7 @@ class TestAggregateFindingsByRegion:
mock_queryset.prefetch_related.return_value = [mock_finding1, mock_finding2]
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
mock_rls_transaction.return_value = [ctx]
mock_findings_filter.return_value = mock_queryset
_, findings_count_by_compliance = _aggregate_findings_by_region(
@@ -3570,9 +3547,7 @@ class TestAggregateFindingsByRegion:
mock_queryset.prefetch_related.return_value = [mock_finding1, mock_finding2]
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
mock_rls_transaction.return_value = [ctx]
mock_findings_filter.return_value = mock_queryset
check_status_by_region, _ = _aggregate_findings_by_region(
@@ -3600,9 +3575,7 @@ class TestAggregateFindingsByRegion:
mock_queryset.prefetch_related.return_value = []
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
mock_rls_transaction.return_value = [ctx]
mock_findings_filter.return_value = mock_queryset
check_status_by_region, findings_count_by_compliance = (
@@ -3715,9 +3688,7 @@ class TestAggregateAttackSurface:
]
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
mock_rls_transaction.return_value = [ctx]
mock_findings_filter.return_value = mock_queryset
aggregate_attack_surface(str(tenant.id), str(scan.id))
@@ -3764,9 +3735,7 @@ class TestAggregateAttackSurface:
]
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
mock_rls_transaction.return_value = [ctx]
mock_findings_filter.return_value = mock_queryset
aggregate_attack_surface(str(tenant.id), str(scan.id))
@@ -3805,9 +3774,7 @@ class TestAggregateAttackSurface:
mock_queryset.annotate.return_value = [] # No findings
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
mock_rls_transaction.return_value = [ctx]
mock_findings_filter.return_value = mock_queryset
aggregate_attack_surface(str(tenant.id), str(scan.id))
@@ -3848,9 +3815,7 @@ class TestAggregateAttackSurface:
]
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
mock_rls_transaction.return_value = [ctx]
mock_findings_filter.return_value = mock_queryset
aggregate_attack_surface(str(tenant.id), str(scan.id))
@@ -3880,9 +3845,7 @@ class TestAggregateAttackSurface:
mock_select_related.return_value.get.return_value = mock_scan
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
mock_rls_transaction.return_value = [ctx]
with patch(
"tasks.jobs.scan._get_attack_surface_mapping_from_provider"

View File

@@ -713,7 +713,7 @@ class TestGenerateOutputs:
True,
]
mock_integration_filter.return_value = [MagicMock()]
mock_rls.return_value.__enter__.return_value = None
mock_rls.return_value = [MagicMock()]
with (
patch("tasks.tasks.OUTPUT_FORMATS_MAPPING", {}),
@@ -880,7 +880,7 @@ class TestCheckIntegrationsTask:
):
mock_integration_filter.return_value.exists.return_value = False
# Ensure rls_transaction is mocked
mock_rls.return_value.__enter__.return_value = None
mock_rls.return_value = [MagicMock()]
result = check_integrations_task(
tenant_id=self.tenant_id,
@@ -922,7 +922,7 @@ class TestCheckIntegrationsTask:
mock_group.return_value = mock_job
# Ensure rls_transaction is mocked
mock_rls.return_value.__enter__.return_value = None
mock_rls.return_value = [MagicMock()]
# Execute the function
result = check_integrations_task(
@@ -963,7 +963,7 @@ class TestCheckIntegrationsTask:
):
"""Test that disabled integrations are not processed."""
mock_integration_filter.return_value.exists.return_value = False
mock_rls.return_value.__enter__.return_value = None
mock_rls.return_value = [MagicMock()]
result = check_integrations_task(
tenant_id=self.tenant_id,