mirror of
https://github.com/prowler-cloud/prowler.git
synced 2026-04-09 11:17:08 +00:00
Compare commits
3 Commits
PROWLER-12
...
PROWLER-12
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f12c72ff91 | ||
|
|
27bc88e8f3 | ||
|
|
864559c508 |
@@ -17,6 +17,7 @@ All notable changes to the **Prowler API** are documented in this file.
|
||||
### 🐞 Fixed
|
||||
|
||||
- Attack Paths: Recover `graph_data_ready` flag when scan fails during graph swap, preventing query endpoints from staying blocked until the next successful scan [(#10354)](https://github.com/prowler-cloud/prowler/pull/10354)
|
||||
- Rewrite `rls_transaction` to retry mid-query read replica failures with primary DB fallback, fixing scan crashes during RDS replica recovery [(#10374)](https://github.com/prowler-cloud/prowler/pull/10374)
|
||||
|
||||
### 🔐 Security
|
||||
|
||||
|
||||
@@ -45,26 +45,27 @@ class ProwlerSocialAccountAdapter(DefaultSocialAccountAdapter):
|
||||
tenant = Tenant.objects.using(MainRouter.admin_db).create(
|
||||
name=f"{user.email.split('@')[0]} default tenant"
|
||||
)
|
||||
with rls_transaction(str(tenant.id)):
|
||||
Membership.objects.using(MainRouter.admin_db).create(
|
||||
user=user, tenant=tenant, role=Membership.RoleChoices.OWNER
|
||||
)
|
||||
role = Role.objects.using(MainRouter.admin_db).create(
|
||||
name="admin",
|
||||
tenant_id=tenant.id,
|
||||
manage_users=True,
|
||||
manage_account=True,
|
||||
manage_billing=True,
|
||||
manage_providers=True,
|
||||
manage_integrations=True,
|
||||
manage_scans=True,
|
||||
unlimited_visibility=True,
|
||||
)
|
||||
UserRoleRelationship.objects.using(MainRouter.admin_db).create(
|
||||
user=user,
|
||||
role=role,
|
||||
tenant_id=tenant.id,
|
||||
)
|
||||
for attempt in rls_transaction(str(tenant.id)):
|
||||
with attempt:
|
||||
Membership.objects.using(MainRouter.admin_db).create(
|
||||
user=user, tenant=tenant, role=Membership.RoleChoices.OWNER
|
||||
)
|
||||
role = Role.objects.using(MainRouter.admin_db).create(
|
||||
name="admin",
|
||||
tenant_id=tenant.id,
|
||||
manage_users=True,
|
||||
manage_account=True,
|
||||
manage_billing=True,
|
||||
manage_providers=True,
|
||||
manage_integrations=True,
|
||||
manage_scans=True,
|
||||
unlimited_visibility=True,
|
||||
)
|
||||
UserRoleRelationship.objects.using(MainRouter.admin_db).create(
|
||||
user=user,
|
||||
role=role,
|
||||
tenant_id=tenant.id,
|
||||
)
|
||||
else:
|
||||
request.session["saml_user_created"] = str(user.id)
|
||||
|
||||
|
||||
@@ -90,11 +90,12 @@ class BaseRLSViewSet(BaseViewSet):
|
||||
if tenant_id is None:
|
||||
raise NotAuthenticated("Tenant ID is not present in token")
|
||||
|
||||
with rls_transaction(
|
||||
for attempt in rls_transaction(
|
||||
tenant_id, using=getattr(self, "db_alias", MainRouter.default_db)
|
||||
):
|
||||
self.request.tenant_id = tenant_id
|
||||
return super().initial(request, *args, **kwargs)
|
||||
with attempt:
|
||||
self.request.tenant_id = tenant_id
|
||||
return super().initial(request, *args, **kwargs)
|
||||
|
||||
def get_serializer_context(self):
|
||||
context = super().get_serializer_context()
|
||||
@@ -163,12 +164,13 @@ class BaseTenantViewset(BaseViewSet):
|
||||
raise NotAuthenticated("Tenant ID is not present in token")
|
||||
|
||||
user_id = str(request.user.id)
|
||||
with rls_transaction(
|
||||
for attempt in rls_transaction(
|
||||
value=user_id,
|
||||
parameter=POSTGRES_USER_VAR,
|
||||
using=getattr(self, "db_alias", MainRouter.default_db),
|
||||
):
|
||||
return super().initial(request, *args, **kwargs)
|
||||
with attempt:
|
||||
return super().initial(request, *args, **kwargs)
|
||||
|
||||
|
||||
class BaseUserViewset(BaseViewSet):
|
||||
@@ -200,8 +202,9 @@ class BaseUserViewset(BaseViewSet):
|
||||
if tenant_id is None:
|
||||
raise NotAuthenticated("Tenant ID is not present in token")
|
||||
|
||||
with rls_transaction(
|
||||
for attempt in rls_transaction(
|
||||
tenant_id, using=getattr(self, "db_alias", MainRouter.default_db)
|
||||
):
|
||||
self.request.tenant_id = tenant_id
|
||||
return super().initial(request, *args, **kwargs)
|
||||
with attempt:
|
||||
self.request.tenant_id = tenant_id
|
||||
return super().initial(request, *args, **kwargs)
|
||||
|
||||
@@ -71,21 +71,14 @@ def psycopg_connection(database_alias: str):
|
||||
|
||||
|
||||
@contextmanager
|
||||
def rls_transaction(
|
||||
def _rls_transaction_context_manager(
|
||||
value: str,
|
||||
parameter: str = POSTGRES_TENANT_VAR,
|
||||
using: str | None = None,
|
||||
retry_on_replica: bool = True,
|
||||
):
|
||||
"""
|
||||
Creates a new database transaction setting the given configuration value for Postgres RLS. It validates the
|
||||
if the value is a valid UUID.
|
||||
"""Internal context manager that opens a single RLS transaction.
|
||||
|
||||
Args:
|
||||
value (str): Database configuration parameter value.
|
||||
parameter (str): Database configuration parameter name, by default is 'api.tenant_id'.
|
||||
using (str | None): Optional database alias to run the transaction against. Defaults to the
|
||||
active read alias (if any) or Django's default connection.
|
||||
Callers should use ``rls_transaction`` (the public class) instead.
|
||||
"""
|
||||
requested_alias = using or get_read_db_alias()
|
||||
db_alias = requested_alias or DEFAULT_DB_ALIAS
|
||||
@@ -93,54 +86,164 @@ def rls_transaction(
|
||||
db_alias = DEFAULT_DB_ALIAS
|
||||
|
||||
alias = db_alias
|
||||
is_replica = READ_REPLICA_ALIAS and alias == READ_REPLICA_ALIAS
|
||||
max_attempts = REPLICA_MAX_ATTEMPTS if is_replica and retry_on_replica else 1
|
||||
router_token = None
|
||||
conn = connections[alias]
|
||||
try:
|
||||
if alias != DEFAULT_DB_ALIAS:
|
||||
router_token = set_read_db_alias(alias)
|
||||
|
||||
for attempt in range(1, max_attempts + 1):
|
||||
router_token = None
|
||||
yielded_cursor = False
|
||||
with transaction.atomic(using=alias):
|
||||
with conn.cursor() as cursor:
|
||||
try:
|
||||
uuid.UUID(str(value))
|
||||
except ValueError:
|
||||
raise ValidationError("Must be a valid UUID")
|
||||
cursor.execute(SET_CONFIG_QUERY, [parameter, value])
|
||||
yield cursor
|
||||
finally:
|
||||
if router_token is not None:
|
||||
reset_read_db_alias(router_token)
|
||||
|
||||
# On final attempt, fallback to primary
|
||||
if attempt == max_attempts and is_replica:
|
||||
logger.warning(
|
||||
f"RLS transaction failed after {attempt - 1} attempts on replica, "
|
||||
f"falling back to primary DB"
|
||||
)
|
||||
|
||||
class rls_transaction:
|
||||
"""RLS transaction with retry and replica-to-primary fallback.
|
||||
|
||||
Usage::
|
||||
|
||||
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
with attempt:
|
||||
result = Model.objects.filter(...)
|
||||
|
||||
When ``using`` points to a read replica and ``retry_on_replica`` is True,
|
||||
the iterator yields up to ``REPLICA_MAX_ATTEMPTS`` attempts on the replica
|
||||
followed by one final attempt on the primary DB. For primary-only calls
|
||||
the iterator yields a single attempt with no retry.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
value: str,
|
||||
parameter: str = POSTGRES_TENANT_VAR,
|
||||
using: str | None = None,
|
||||
retry_on_replica: bool = True,
|
||||
):
|
||||
self.value = value
|
||||
self.parameter = parameter
|
||||
self.using = using
|
||||
self.retry_on_replica = retry_on_replica
|
||||
|
||||
def __iter__(self):
|
||||
return _RLSRetryIterator(self)
|
||||
|
||||
|
||||
class _RLSRetryIterator:
|
||||
def __init__(self, parent: rls_transaction):
|
||||
self._parent = parent
|
||||
self._attempt = 0
|
||||
self._done = False
|
||||
|
||||
requested = parent.using or get_read_db_alias()
|
||||
self._db_alias = requested or DEFAULT_DB_ALIAS
|
||||
if self._db_alias not in connections:
|
||||
self._db_alias = DEFAULT_DB_ALIAS
|
||||
|
||||
self._is_replica = bool(
|
||||
READ_REPLICA_ALIAS and self._db_alias == READ_REPLICA_ALIAS
|
||||
)
|
||||
|
||||
if self._is_replica and parent.retry_on_replica:
|
||||
# N replica attempts + 1 primary fallback
|
||||
self._max_attempts = REPLICA_MAX_ATTEMPTS + 1
|
||||
else:
|
||||
self._max_attempts = 1
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self._done or self._attempt >= self._max_attempts:
|
||||
raise StopIteration
|
||||
|
||||
self._attempt += 1
|
||||
is_final = self._attempt == self._max_attempts
|
||||
|
||||
if is_final and self._is_replica and self._parent.retry_on_replica:
|
||||
alias = DEFAULT_DB_ALIAS
|
||||
if self._attempt > 1:
|
||||
logger.warning(
|
||||
f"RLS transaction failed after {self._attempt - 1} attempts on replica, "
|
||||
f"falling back to primary DB"
|
||||
)
|
||||
else:
|
||||
alias = self._db_alias
|
||||
|
||||
conn = connections[alias]
|
||||
return _RLSAttempt(self, alias, is_final)
|
||||
|
||||
|
||||
class _RLSAttempt:
|
||||
def __init__(self, iterator: _RLSRetryIterator, alias: str, is_final: bool):
|
||||
self._iterator = iterator
|
||||
self._alias = alias
|
||||
self._is_final = is_final
|
||||
self._context_manager = None
|
||||
|
||||
def __enter__(self):
|
||||
# Retry loop for connection-setup errors (pre-yield).
|
||||
# Python does NOT call __exit__ when __enter__ raises, so we
|
||||
# must handle retries here for errors like "connection refused".
|
||||
while True:
|
||||
try:
|
||||
self._context_manager = _rls_transaction_context_manager(
|
||||
self._iterator._parent.value,
|
||||
self._iterator._parent.parameter,
|
||||
using=self._alias,
|
||||
)
|
||||
return self._context_manager.__enter__()
|
||||
except OperationalError as exc:
|
||||
if self._is_final or not self._iterator._is_replica:
|
||||
raise
|
||||
self._handle_retry(exc)
|
||||
# Consume the next attempt from the iterator
|
||||
next_att = next(self._iterator)
|
||||
self._alias = next_att._alias
|
||||
self._is_final = next_att._is_final
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
try:
|
||||
if alias != DEFAULT_DB_ALIAS:
|
||||
router_token = set_read_db_alias(alias)
|
||||
|
||||
with transaction.atomic(using=alias):
|
||||
with conn.cursor() as cursor:
|
||||
try:
|
||||
# just in case the value is a UUID object
|
||||
uuid.UUID(str(value))
|
||||
except ValueError:
|
||||
raise ValidationError("Must be a valid UUID")
|
||||
cursor.execute(SET_CONFIG_QUERY, [parameter, value])
|
||||
yielded_cursor = True
|
||||
yield cursor
|
||||
return
|
||||
except OperationalError as e:
|
||||
if yielded_cursor:
|
||||
raise
|
||||
# If on primary or max attempts reached, raise
|
||||
if not is_replica or attempt == max_attempts:
|
||||
self._context_manager.__exit__(exc_type, exc_val, exc_tb)
|
||||
except OperationalError as exc:
|
||||
if self._is_final or not self._iterator._is_replica:
|
||||
raise
|
||||
self._handle_retry(exc)
|
||||
return True
|
||||
|
||||
# Retry with exponential backoff
|
||||
delay = REPLICA_RETRY_BASE_DELAY * (2 ** (attempt - 1))
|
||||
logger.info(
|
||||
f"RLS transaction failed on replica (attempt {attempt}/{max_attempts}), "
|
||||
f"retrying in {delay}s. Error: {e}"
|
||||
)
|
||||
time.sleep(delay)
|
||||
finally:
|
||||
if router_token is not None:
|
||||
reset_read_db_alias(router_token)
|
||||
if exc_type is None:
|
||||
self._iterator._done = True
|
||||
return False
|
||||
|
||||
if (
|
||||
issubclass(exc_type, OperationalError)
|
||||
and not self._is_final
|
||||
and self._iterator._is_replica
|
||||
):
|
||||
self._handle_retry(exc_val)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _handle_retry(self, error):
|
||||
try:
|
||||
connections[self._alias].close()
|
||||
except Exception:
|
||||
pass
|
||||
attempt = self._iterator._attempt
|
||||
max_att = self._iterator._max_attempts
|
||||
delay = REPLICA_RETRY_BASE_DELAY * (2 ** (attempt - 1))
|
||||
logger.info(
|
||||
f"RLS transaction failed on replica (attempt {attempt}/{max_att}), "
|
||||
f"retrying in {delay:.1f}s. Error: {error}"
|
||||
)
|
||||
time.sleep(delay)
|
||||
|
||||
|
||||
class CustomUserManager(BaseUserManager):
|
||||
@@ -197,23 +300,21 @@ def batch_delete(tenant_id, queryset, batch_size=settings.DJANGO_DELETION_BATCH_
|
||||
deletion_summary = {}
|
||||
|
||||
while True:
|
||||
with rls_transaction(tenant_id, POSTGRES_TENANT_VAR):
|
||||
# Get a batch of IDs to delete
|
||||
batch_ids = set(
|
||||
queryset.values_list("id", flat=True).order_by("id")[:batch_size]
|
||||
)
|
||||
if not batch_ids:
|
||||
# No more objects to delete
|
||||
break
|
||||
for attempt in rls_transaction(tenant_id, POSTGRES_TENANT_VAR):
|
||||
with attempt:
|
||||
# Get a batch of IDs to delete
|
||||
batch_ids = set(
|
||||
queryset.values_list("id", flat=True).order_by("id")[:batch_size]
|
||||
)
|
||||
if not batch_ids:
|
||||
return total_deleted, deletion_summary
|
||||
|
||||
deleted_count, deleted_info = queryset.filter(id__in=batch_ids).delete()
|
||||
deleted_count, deleted_info = queryset.filter(id__in=batch_ids).delete()
|
||||
|
||||
total_deleted += deleted_count
|
||||
for model_label, count in deleted_info.items():
|
||||
deletion_summary[model_label] = deletion_summary.get(model_label, 0) + count
|
||||
|
||||
return total_deleted, deletion_summary
|
||||
|
||||
|
||||
def delete_related_daily_task(provider_id: str):
|
||||
"""
|
||||
@@ -245,8 +346,9 @@ def create_objects_in_batches(
|
||||
total = len(objects)
|
||||
for i in range(0, total, batch_size):
|
||||
chunk = objects[i : i + batch_size]
|
||||
with rls_transaction(value=tenant_id, parameter=POSTGRES_TENANT_VAR):
|
||||
model.objects.bulk_create(chunk, batch_size)
|
||||
for attempt in rls_transaction(value=tenant_id, parameter=POSTGRES_TENANT_VAR):
|
||||
with attempt:
|
||||
model.objects.bulk_create(chunk, batch_size)
|
||||
|
||||
|
||||
def update_objects_in_batches(
|
||||
@@ -268,8 +370,9 @@ def update_objects_in_batches(
|
||||
total = len(objects)
|
||||
for start in range(0, total, batch_size):
|
||||
chunk = objects[start : start + batch_size]
|
||||
with rls_transaction(value=tenant_id, parameter=POSTGRES_TENANT_VAR):
|
||||
model.objects.bulk_update(chunk, fields, batch_size)
|
||||
for attempt in rls_transaction(value=tenant_id, parameter=POSTGRES_TENANT_VAR):
|
||||
with attempt:
|
||||
model.objects.bulk_update(chunk, fields, batch_size)
|
||||
|
||||
|
||||
# Postgres Enums
|
||||
|
||||
@@ -97,23 +97,24 @@ def handle_provider_deletion(func):
|
||||
tenant_id = kwargs.get("tenant_id")
|
||||
provider_id = kwargs.get("provider_id")
|
||||
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
if provider_id is None:
|
||||
scan_id = kwargs.get("scan_id")
|
||||
if scan_id is None:
|
||||
raise AssertionError(
|
||||
"This task does not have provider or scan in the kwargs"
|
||||
)
|
||||
scan = Scan.objects.filter(pk=scan_id).first()
|
||||
if scan is None:
|
||||
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
with attempt:
|
||||
if provider_id is None:
|
||||
scan_id = kwargs.get("scan_id")
|
||||
if scan_id is None:
|
||||
raise AssertionError(
|
||||
"This task does not have provider or scan in the kwargs"
|
||||
)
|
||||
scan = Scan.objects.filter(pk=scan_id).first()
|
||||
if scan is None:
|
||||
raise ProviderDeletedException(
|
||||
f"Provider for scan '{scan_id}' was deleted during the scan"
|
||||
) from None
|
||||
provider_id = str(scan.provider_id)
|
||||
if not Provider.objects.filter(pk=provider_id).exists():
|
||||
raise ProviderDeletedException(
|
||||
f"Provider for scan '{scan_id}' was deleted during the scan"
|
||||
f"Provider '{provider_id}' was deleted during the scan"
|
||||
) from None
|
||||
provider_id = str(scan.provider_id)
|
||||
if not Provider.objects.filter(pk=provider_id).exists():
|
||||
raise ProviderDeletedException(
|
||||
f"Provider '{provider_id}' was deleted during the scan"
|
||||
) from None
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -97,27 +97,29 @@ class Command(BaseCommand):
|
||||
):
|
||||
possible_types.append(check_metadata.ResourceType)
|
||||
|
||||
with rls_transaction(tenant_id):
|
||||
provider, _ = Provider.all_objects.get_or_create(
|
||||
tenant_id=tenant_id,
|
||||
provider="aws",
|
||||
connected=True,
|
||||
uid=str(random.randint(100000000000, 999999999999)),
|
||||
defaults={
|
||||
"alias": alias,
|
||||
},
|
||||
)
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
provider, _ = Provider.all_objects.get_or_create(
|
||||
tenant_id=tenant_id,
|
||||
provider="aws",
|
||||
connected=True,
|
||||
uid=str(random.randint(100000000000, 999999999999)),
|
||||
defaults={
|
||||
"alias": alias,
|
||||
},
|
||||
)
|
||||
|
||||
with rls_transaction(tenant_id):
|
||||
scan = Scan.all_objects.create(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
name=alias,
|
||||
trigger="manual",
|
||||
state="executing",
|
||||
progress=0,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
)
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
scan = Scan.all_objects.create(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
name=alias,
|
||||
trigger="manual",
|
||||
state="executing",
|
||||
progress=0,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
)
|
||||
scan_state = "completed"
|
||||
|
||||
try:
|
||||
@@ -141,13 +143,15 @@ class Command(BaseCommand):
|
||||
num_batches = ceil(len(resources) / batch_size)
|
||||
self.stdout.write(self.style.WARNING("Creating resources..."))
|
||||
for i in tqdm(range(0, len(resources), batch_size), total=num_batches):
|
||||
with rls_transaction(tenant_id):
|
||||
Resource.all_objects.bulk_create(resources[i : i + batch_size])
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
Resource.all_objects.bulk_create(resources[i : i + batch_size])
|
||||
self.stdout.write(self.style.SUCCESS("Resources created successfully.\n\n"))
|
||||
|
||||
with rls_transaction(tenant_id):
|
||||
scan.progress = 33
|
||||
scan.save()
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
scan.progress = 33
|
||||
scan.save()
|
||||
|
||||
# Create Findings
|
||||
findings = []
|
||||
@@ -193,13 +197,15 @@ class Command(BaseCommand):
|
||||
num_batches = ceil(len(findings) / batch_size)
|
||||
self.stdout.write(self.style.WARNING("Creating findings..."))
|
||||
for i in tqdm(range(0, len(findings), batch_size), total=num_batches):
|
||||
with rls_transaction(tenant_id):
|
||||
Finding.all_objects.bulk_create(findings[i : i + batch_size])
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
Finding.all_objects.bulk_create(findings[i : i + batch_size])
|
||||
self.stdout.write(self.style.SUCCESS("Findings created successfully.\n\n"))
|
||||
|
||||
with rls_transaction(tenant_id):
|
||||
scan.progress = 66
|
||||
scan.save()
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
scan.progress = 66
|
||||
scan.save()
|
||||
|
||||
# Create ResourceFindingMapping
|
||||
mappings = []
|
||||
@@ -227,19 +233,21 @@ class Command(BaseCommand):
|
||||
self.style.WARNING("Creating resource-finding mappings...")
|
||||
)
|
||||
for i in tqdm(range(0, len(mappings), batch_size), total=num_batches):
|
||||
with rls_transaction(tenant_id):
|
||||
ResourceFindingMapping.objects.bulk_create(
|
||||
mappings[i : i + batch_size]
|
||||
)
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
ResourceFindingMapping.objects.bulk_create(
|
||||
mappings[i : i + batch_size]
|
||||
)
|
||||
self.stdout.write(
|
||||
self.style.SUCCESS(
|
||||
"Resource-finding mappings created successfully.\n\n"
|
||||
)
|
||||
)
|
||||
|
||||
with rls_transaction(tenant_id):
|
||||
scan.progress = 99
|
||||
scan.save()
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
scan.progress = 99
|
||||
scan.save()
|
||||
|
||||
self.stdout.write(self.style.WARNING("Creating finding filter values..."))
|
||||
resource_scan_summaries = [
|
||||
@@ -254,16 +262,18 @@ class Command(BaseCommand):
|
||||
for resource_id, service, region, resource_type in scan_resource_cache
|
||||
]
|
||||
num_batches = ceil(len(resource_scan_summaries) / batch_size)
|
||||
with rls_transaction(tenant_id):
|
||||
for i in tqdm(
|
||||
range(0, len(resource_scan_summaries), batch_size),
|
||||
total=num_batches,
|
||||
):
|
||||
with rls_transaction(tenant_id):
|
||||
ResourceScanSummary.objects.bulk_create(
|
||||
resource_scan_summaries[i : i + batch_size],
|
||||
ignore_conflicts=True,
|
||||
)
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
for i in tqdm(
|
||||
range(0, len(resource_scan_summaries), batch_size),
|
||||
total=num_batches,
|
||||
):
|
||||
for inner_attempt in rls_transaction(tenant_id):
|
||||
with inner_attempt:
|
||||
ResourceScanSummary.objects.bulk_create(
|
||||
resource_scan_summaries[i : i + batch_size],
|
||||
ignore_conflicts=True,
|
||||
)
|
||||
|
||||
self.stdout.write(
|
||||
self.style.SUCCESS("Finding filter values created successfully.\n\n")
|
||||
@@ -279,7 +289,8 @@ class Command(BaseCommand):
|
||||
scan.progress = 100
|
||||
scan.state = scan_state
|
||||
scan.unique_resource_count = num_resources
|
||||
with rls_transaction(tenant_id):
|
||||
scan.save()
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
scan.save()
|
||||
|
||||
self.stdout.write(self.style.NOTICE("Successfully populated test data."))
|
||||
|
||||
@@ -29,16 +29,17 @@ def migrate_daily_scheduled_scan_tasks(apps, schema_editor):
|
||||
else:
|
||||
next_scan_date = scheduled_time_today + timedelta(days=1)
|
||||
|
||||
with rls_transaction(tenant_id):
|
||||
Scan.objects.create(
|
||||
tenant_id=tenant_id,
|
||||
name="Daily scheduled scan",
|
||||
provider_id=provider_id,
|
||||
trigger=Scan.TriggerChoices.SCHEDULED,
|
||||
state=StateChoices.SCHEDULED,
|
||||
scheduled_at=next_scan_date,
|
||||
scheduler_task_id=daily_scheduled_scan_task.id,
|
||||
)
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
Scan.objects.create(
|
||||
tenant_id=tenant_id,
|
||||
name="Daily scheduled scan",
|
||||
provider_id=provider_id,
|
||||
trigger=Scan.TriggerChoices.SCHEDULED,
|
||||
state=StateChoices.SCHEDULED,
|
||||
scheduled_at=next_scan_date,
|
||||
scheduler_task_id=daily_scheduled_scan_task.id,
|
||||
)
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from contextlib import nullcontext
|
||||
|
||||
from rest_framework.renderers import BaseRenderer
|
||||
from rest_framework_json_api.renderers import JSONRenderer
|
||||
|
||||
@@ -28,11 +26,8 @@ class APIJSONRenderer(JSONRenderer):
|
||||
db_alias = getattr(request, "db_alias", None) if request else None
|
||||
include_param_present = "include" in request.query_params if request else False
|
||||
|
||||
# Use rls_transaction if needed for included resources, otherwise do nothing
|
||||
context_manager = (
|
||||
rls_transaction(tenant_id, using=db_alias)
|
||||
if tenant_id and include_param_present
|
||||
else nullcontext()
|
||||
)
|
||||
with context_manager:
|
||||
return super().render(data, accepted_media_type, renderer_context)
|
||||
if tenant_id and include_param_present:
|
||||
for attempt in rls_transaction(tenant_id, using=db_alias):
|
||||
with attempt:
|
||||
return super().render(data, accepted_media_type, renderer_context)
|
||||
return super().render(data, accepted_media_type, renderer_context)
|
||||
|
||||
@@ -17,23 +17,26 @@ class TestRLSTransaction:
|
||||
|
||||
def test_success_on_primary(self, tenant):
|
||||
"""Basic: transaction succeeds on primary database."""
|
||||
with rls_transaction(str(tenant.id), using=DEFAULT_DB_ALIAS) as cursor:
|
||||
cursor.execute("SELECT 1")
|
||||
result = cursor.fetchone()
|
||||
assert result == (1,)
|
||||
for attempt in rls_transaction(str(tenant.id), using=DEFAULT_DB_ALIAS):
|
||||
with attempt as cursor:
|
||||
cursor.execute("SELECT 1")
|
||||
result = cursor.fetchone()
|
||||
assert result == (1,)
|
||||
|
||||
def test_invalid_uuid_raises_validation_error(self):
|
||||
"""Invalid UUID raises ValidationError before DB operations."""
|
||||
with pytest.raises(ValidationError, match="Must be a valid UUID"):
|
||||
with rls_transaction("not-a-uuid", using=DEFAULT_DB_ALIAS):
|
||||
pass
|
||||
for attempt in rls_transaction("not-a-uuid", using=DEFAULT_DB_ALIAS):
|
||||
with attempt:
|
||||
pass
|
||||
|
||||
def test_custom_parameter_name(self, tenant):
|
||||
"""Test custom RLS parameter name."""
|
||||
custom_param = "api.custom_id"
|
||||
with rls_transaction(
|
||||
for attempt in rls_transaction(
|
||||
str(tenant.id), parameter=custom_param, using=DEFAULT_DB_ALIAS
|
||||
) as cursor:
|
||||
cursor.execute("SELECT current_setting(%s, true)", [custom_param])
|
||||
result = cursor.fetchone()
|
||||
assert result == (str(tenant.id),)
|
||||
):
|
||||
with attempt as cursor:
|
||||
cursor.execute("SELECT current_setting(%s, true)", [custom_param])
|
||||
result = cursor.fetchone()
|
||||
assert result == (str(tenant.id),)
|
||||
|
||||
@@ -364,29 +364,32 @@ class TestRlsTransaction:
|
||||
tenant = tenants_fixture[0]
|
||||
tenant_id = str(tenant.id)
|
||||
|
||||
with rls_transaction(tenant_id) as cursor:
|
||||
assert cursor is not None
|
||||
cursor.execute("SELECT current_setting(%s)", [POSTGRES_TENANT_VAR])
|
||||
result = cursor.fetchone()
|
||||
assert result[0] == tenant_id
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt as cursor:
|
||||
assert cursor is not None
|
||||
cursor.execute("SELECT current_setting(%s)", [POSTGRES_TENANT_VAR])
|
||||
result = cursor.fetchone()
|
||||
assert result[0] == tenant_id
|
||||
|
||||
def test_rls_transaction_valid_uuid_object(self, tenants_fixture):
|
||||
"""Test rls_transaction with UUID object."""
|
||||
tenant = tenants_fixture[0]
|
||||
|
||||
with rls_transaction(tenant.id) as cursor:
|
||||
assert cursor is not None
|
||||
cursor.execute("SELECT current_setting(%s)", [POSTGRES_TENANT_VAR])
|
||||
result = cursor.fetchone()
|
||||
assert result[0] == str(tenant.id)
|
||||
for attempt in rls_transaction(tenant.id):
|
||||
with attempt as cursor:
|
||||
assert cursor is not None
|
||||
cursor.execute("SELECT current_setting(%s)", [POSTGRES_TENANT_VAR])
|
||||
result = cursor.fetchone()
|
||||
assert result[0] == str(tenant.id)
|
||||
|
||||
def test_rls_transaction_invalid_uuid_raises_validation_error(self):
|
||||
"""Test rls_transaction raises ValidationError for invalid UUID."""
|
||||
invalid_uuid = "not-a-valid-uuid"
|
||||
|
||||
with pytest.raises(ValidationError, match="Must be a valid UUID"):
|
||||
with rls_transaction(invalid_uuid):
|
||||
pass
|
||||
for attempt in rls_transaction(invalid_uuid):
|
||||
with attempt:
|
||||
pass
|
||||
|
||||
def test_rls_transaction_uses_default_database_when_no_alias(self, tenants_fixture):
|
||||
"""Test rls_transaction uses DEFAULT_DB_ALIAS when no alias specified."""
|
||||
@@ -402,8 +405,9 @@ class TestRlsTransaction:
|
||||
mock_connections.__contains__.return_value = True
|
||||
|
||||
with patch("api.db_utils.transaction.atomic"):
|
||||
with rls_transaction(tenant_id):
|
||||
pass
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
pass
|
||||
|
||||
mock_connections.__getitem__.assert_called_with(DEFAULT_DB_ALIAS)
|
||||
|
||||
@@ -424,8 +428,9 @@ class TestRlsTransaction:
|
||||
with patch("api.db_utils.set_read_db_alias") as mock_set_alias:
|
||||
with patch("api.db_utils.reset_read_db_alias") as mock_reset_alias:
|
||||
mock_set_alias.return_value = "test_token"
|
||||
with rls_transaction(tenant_id, using=custom_alias):
|
||||
pass
|
||||
for attempt in rls_transaction(tenant_id, using=custom_alias):
|
||||
with attempt:
|
||||
pass
|
||||
|
||||
mock_connections.__getitem__.assert_called_with(custom_alias)
|
||||
mock_set_alias.assert_called_once_with(custom_alias)
|
||||
@@ -452,8 +457,9 @@ class TestRlsTransaction:
|
||||
"api.db_utils.reset_read_db_alias"
|
||||
) as mock_reset_alias:
|
||||
mock_set_alias.return_value = "test_token"
|
||||
with rls_transaction(tenant_id):
|
||||
pass
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
pass
|
||||
|
||||
mock_connections.__getitem__.assert_called()
|
||||
mock_set_alias.assert_called_once()
|
||||
@@ -480,8 +486,9 @@ class TestRlsTransaction:
|
||||
mock_connections.__getitem__.return_value = mock_conn
|
||||
|
||||
with patch("api.db_utils.transaction.atomic"):
|
||||
with rls_transaction(tenant_id):
|
||||
pass
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
pass
|
||||
|
||||
mock_connections.__getitem__.assert_called_with(DEFAULT_DB_ALIAS)
|
||||
|
||||
@@ -503,15 +510,20 @@ class TestRlsTransaction:
|
||||
with patch("api.db_utils.transaction.atomic"):
|
||||
with patch("api.db_utils.set_read_db_alias", return_value="token"):
|
||||
with patch("api.db_utils.reset_read_db_alias"):
|
||||
with rls_transaction(tenant_id):
|
||||
pass
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
pass
|
||||
|
||||
assert mock_cursor.execute.call_count == 1
|
||||
|
||||
def test_rls_transaction_retry_with_exponential_backoff_on_operational_error(
|
||||
self, tenants_fixture, enable_read_replica
|
||||
):
|
||||
"""Test retry with exponential backoff on OperationalError on replica."""
|
||||
"""Test retry with exponential backoff on OperationalError on replica.
|
||||
|
||||
REPLICA_MAX_ATTEMPTS=3 means 3 replica tries + 1 primary fallback = 4 total.
|
||||
First 3 attempts fail, 4th succeeds on primary.
|
||||
"""
|
||||
tenant = tenants_fixture[0]
|
||||
tenant_id = str(tenant.id)
|
||||
|
||||
@@ -528,7 +540,7 @@ class TestRlsTransaction:
|
||||
def atomic_side_effect(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count < 3:
|
||||
if call_count <= 3:
|
||||
raise OperationalError("Connection error")
|
||||
return MagicMock(
|
||||
__enter__=MagicMock(return_value=None),
|
||||
@@ -544,48 +556,24 @@ class TestRlsTransaction:
|
||||
):
|
||||
with patch("api.db_utils.reset_read_db_alias"):
|
||||
with patch("api.db_utils.logger") as mock_logger:
|
||||
with rls_transaction(tenant_id):
|
||||
pass
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
pass
|
||||
|
||||
assert mock_sleep.call_count == 2
|
||||
assert mock_sleep.call_count == 3
|
||||
mock_sleep.assert_any_call(0.5)
|
||||
mock_sleep.assert_any_call(1.0)
|
||||
assert mock_logger.info.call_count == 2
|
||||
|
||||
def test_rls_transaction_operational_error_inside_context_no_retry(
|
||||
self, tenants_fixture, enable_read_replica
|
||||
):
|
||||
"""Test OperationalError raised inside context does not retry."""
|
||||
tenant = tenants_fixture[0]
|
||||
tenant_id = str(tenant.id)
|
||||
|
||||
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
|
||||
with patch("api.db_utils.connections") as mock_connections:
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
|
||||
mock_connections.__getitem__.return_value = mock_conn
|
||||
mock_connections.__contains__.return_value = True
|
||||
|
||||
with patch("api.db_utils.transaction.atomic") as mock_atomic:
|
||||
mock_atomic.return_value.__enter__.return_value = None
|
||||
mock_atomic.return_value.__exit__.return_value = False
|
||||
|
||||
with patch("api.db_utils.time.sleep") as mock_sleep:
|
||||
with patch(
|
||||
"api.db_utils.set_read_db_alias", return_value="token"
|
||||
):
|
||||
with patch("api.db_utils.reset_read_db_alias"):
|
||||
with pytest.raises(OperationalError):
|
||||
with rls_transaction(tenant_id):
|
||||
raise OperationalError("Conflict with recovery")
|
||||
|
||||
mock_sleep.assert_not_called()
|
||||
mock_sleep.assert_any_call(2.0)
|
||||
assert mock_logger.info.call_count == 3
|
||||
assert mock_logger.warning.call_count == 1
|
||||
|
||||
def test_rls_transaction_max_three_attempts_for_replica(
|
||||
self, tenants_fixture, enable_read_replica
|
||||
):
|
||||
"""Test maximum 3 attempts for replica database."""
|
||||
"""Test maximum attempts for replica database.
|
||||
|
||||
REPLICA_MAX_ATTEMPTS=3 means 3 replica + 1 primary = 4 total attempts.
|
||||
"""
|
||||
tenant = tenants_fixture[0]
|
||||
tenant_id = str(tenant.id)
|
||||
|
||||
@@ -606,10 +594,11 @@ class TestRlsTransaction:
|
||||
):
|
||||
with patch("api.db_utils.reset_read_db_alias"):
|
||||
with pytest.raises(OperationalError):
|
||||
with rls_transaction(tenant_id):
|
||||
pass
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
pass
|
||||
|
||||
assert mock_atomic.call_count == 3
|
||||
assert mock_atomic.call_count == 4
|
||||
|
||||
def test_rls_transaction_replica_no_retry_when_disabled(
|
||||
self, tenants_fixture, enable_read_replica
|
||||
@@ -635,10 +624,11 @@ class TestRlsTransaction:
|
||||
):
|
||||
with patch("api.db_utils.reset_read_db_alias"):
|
||||
with pytest.raises(OperationalError):
|
||||
with rls_transaction(
|
||||
for attempt in rls_transaction(
|
||||
tenant_id, retry_on_replica=False
|
||||
):
|
||||
pass
|
||||
with attempt:
|
||||
pass
|
||||
|
||||
assert mock_atomic.call_count == 1
|
||||
mock_sleep.assert_not_called()
|
||||
@@ -660,15 +650,19 @@ class TestRlsTransaction:
|
||||
mock_atomic.side_effect = OperationalError("Primary error")
|
||||
|
||||
with pytest.raises(OperationalError):
|
||||
with rls_transaction(tenant_id):
|
||||
pass
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
pass
|
||||
|
||||
assert mock_atomic.call_count == 1
|
||||
|
||||
def test_rls_transaction_fallback_to_primary_after_max_attempts(
|
||||
self, tenants_fixture, enable_read_replica
|
||||
):
|
||||
"""Test fallback to primary DB after max attempts on replica."""
|
||||
"""Test fallback to primary DB after max attempts on replica.
|
||||
|
||||
First 3 attempts fail on replica, 4th succeeds on primary.
|
||||
"""
|
||||
tenant = tenants_fixture[0]
|
||||
tenant_id = str(tenant.id)
|
||||
|
||||
@@ -685,7 +679,7 @@ class TestRlsTransaction:
|
||||
def atomic_side_effect(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count < 3:
|
||||
if call_count <= 3:
|
||||
raise OperationalError("Replica error")
|
||||
return MagicMock(
|
||||
__enter__=MagicMock(return_value=None),
|
||||
@@ -701,8 +695,9 @@ class TestRlsTransaction:
|
||||
):
|
||||
with patch("api.db_utils.reset_read_db_alias"):
|
||||
with patch("api.db_utils.logger") as mock_logger:
|
||||
with rls_transaction(tenant_id):
|
||||
pass
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
pass
|
||||
|
||||
mock_logger.warning.assert_called_once()
|
||||
warning_msg = mock_logger.warning.call_args[0][0]
|
||||
@@ -711,7 +706,10 @@ class TestRlsTransaction:
|
||||
def test_rls_transaction_logger_warning_on_fallback(
|
||||
self, tenants_fixture, enable_read_replica
|
||||
):
|
||||
"""Test logger warnings are emitted on fallback to primary."""
|
||||
"""Test logger warnings are emitted on fallback to primary.
|
||||
|
||||
3 replica failures produce 3 info logs, then 1 warning on fallback.
|
||||
"""
|
||||
tenant = tenants_fixture[0]
|
||||
tenant_id = str(tenant.id)
|
||||
|
||||
@@ -728,7 +726,7 @@ class TestRlsTransaction:
|
||||
def atomic_side_effect(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count < 3:
|
||||
if call_count <= 3:
|
||||
raise OperationalError("Replica error")
|
||||
return MagicMock(
|
||||
__enter__=MagicMock(return_value=None),
|
||||
@@ -744,10 +742,11 @@ class TestRlsTransaction:
|
||||
):
|
||||
with patch("api.db_utils.reset_read_db_alias"):
|
||||
with patch("api.db_utils.logger") as mock_logger:
|
||||
with rls_transaction(tenant_id):
|
||||
pass
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
pass
|
||||
|
||||
assert mock_logger.info.call_count == 2
|
||||
assert mock_logger.info.call_count == 3
|
||||
assert mock_logger.warning.call_count == 1
|
||||
|
||||
def test_rls_transaction_operational_error_raised_immediately_on_primary(
|
||||
@@ -770,15 +769,16 @@ class TestRlsTransaction:
|
||||
|
||||
with patch("api.db_utils.time.sleep") as mock_sleep:
|
||||
with pytest.raises(OperationalError):
|
||||
with rls_transaction(tenant_id):
|
||||
pass
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
pass
|
||||
|
||||
mock_sleep.assert_not_called()
|
||||
|
||||
def test_rls_transaction_operational_error_raised_after_max_attempts(
|
||||
self, tenants_fixture, enable_read_replica
|
||||
):
|
||||
"""Test OperationalError raised after max attempts on replica."""
|
||||
"""Test OperationalError raised after all 4 attempts (3 replica + 1 primary)."""
|
||||
tenant = tenants_fixture[0]
|
||||
tenant_id = str(tenant.id)
|
||||
|
||||
@@ -801,8 +801,9 @@ class TestRlsTransaction:
|
||||
):
|
||||
with patch("api.db_utils.reset_read_db_alias"):
|
||||
with pytest.raises(OperationalError):
|
||||
with rls_transaction(tenant_id):
|
||||
pass
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
pass
|
||||
|
||||
def test_rls_transaction_router_token_set_for_non_default_alias(
|
||||
self, tenants_fixture
|
||||
@@ -823,8 +824,9 @@ class TestRlsTransaction:
|
||||
with patch("api.db_utils.set_read_db_alias") as mock_set_alias:
|
||||
with patch("api.db_utils.reset_read_db_alias") as mock_reset_alias:
|
||||
mock_set_alias.return_value = "test_token"
|
||||
with rls_transaction(tenant_id, using=custom_alias):
|
||||
pass
|
||||
for attempt in rls_transaction(tenant_id, using=custom_alias):
|
||||
with attempt:
|
||||
pass
|
||||
|
||||
mock_set_alias.assert_called_once_with(custom_alias)
|
||||
mock_reset_alias.assert_called_once_with("test_token")
|
||||
@@ -848,8 +850,11 @@ class TestRlsTransaction:
|
||||
with patch("api.db_utils.set_read_db_alias", return_value="test_token"):
|
||||
with patch("api.db_utils.reset_read_db_alias") as mock_reset_alias:
|
||||
with pytest.raises(Exception):
|
||||
with rls_transaction(tenant_id, using=custom_alias):
|
||||
pass
|
||||
for attempt in rls_transaction(
|
||||
tenant_id, using=custom_alias
|
||||
):
|
||||
with attempt:
|
||||
pass
|
||||
|
||||
mock_reset_alias.assert_called_once_with("test_token")
|
||||
|
||||
@@ -873,8 +878,9 @@ class TestRlsTransaction:
|
||||
with patch(
|
||||
"api.db_utils.reset_read_db_alias"
|
||||
) as mock_reset_alias:
|
||||
with rls_transaction(tenant_id):
|
||||
pass
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
pass
|
||||
|
||||
mock_set_alias.assert_not_called()
|
||||
mock_reset_alias.assert_not_called()
|
||||
@@ -886,10 +892,11 @@ class TestRlsTransaction:
|
||||
tenant = tenants_fixture[0]
|
||||
tenant_id = str(tenant.id)
|
||||
|
||||
with rls_transaction(tenant_id) as cursor:
|
||||
cursor.execute("SELECT current_setting(%s)", [POSTGRES_TENANT_VAR])
|
||||
result = cursor.fetchone()
|
||||
assert result[0] == tenant_id
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt as cursor:
|
||||
cursor.execute("SELECT current_setting(%s)", [POSTGRES_TENANT_VAR])
|
||||
result = cursor.fetchone()
|
||||
assert result[0] == tenant_id
|
||||
|
||||
def test_rls_transaction_custom_parameter(self, tenants_fixture):
|
||||
"""Test rls_transaction with custom parameter name."""
|
||||
@@ -897,21 +904,205 @@ class TestRlsTransaction:
|
||||
tenant_id = str(tenant.id)
|
||||
custom_param = "api.user_id"
|
||||
|
||||
with rls_transaction(tenant_id, parameter=custom_param) as cursor:
|
||||
cursor.execute("SELECT current_setting(%s)", [custom_param])
|
||||
result = cursor.fetchone()
|
||||
assert result[0] == tenant_id
|
||||
for attempt in rls_transaction(tenant_id, parameter=custom_param):
|
||||
with attempt as cursor:
|
||||
cursor.execute("SELECT current_setting(%s)", [custom_param])
|
||||
result = cursor.fetchone()
|
||||
assert result[0] == tenant_id
|
||||
|
||||
def test_rls_transaction_cursor_yielded_correctly(self, tenants_fixture):
|
||||
"""Test cursor is yielded correctly."""
|
||||
tenant = tenants_fixture[0]
|
||||
tenant_id = str(tenant.id)
|
||||
|
||||
with rls_transaction(tenant_id) as cursor:
|
||||
assert cursor is not None
|
||||
cursor.execute("SELECT 1")
|
||||
result = cursor.fetchone()
|
||||
assert result[0] == 1
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt as cursor:
|
||||
assert cursor is not None
|
||||
cursor.execute("SELECT 1")
|
||||
result = cursor.fetchone()
|
||||
assert result[0] == 1
|
||||
|
||||
def test_rls_transaction_for_with_retries_mid_body_error(
|
||||
self, tenants_fixture, enable_read_replica
|
||||
):
|
||||
"""Test that OperationalError raised inside the body triggers retry."""
|
||||
tenant = tenants_fixture[0]
|
||||
tenant_id = str(tenant.id)
|
||||
|
||||
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
|
||||
with patch("api.db_utils.connections") as mock_connections:
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
|
||||
mock_connections.__getitem__.return_value = mock_conn
|
||||
mock_connections.__contains__.return_value = True
|
||||
|
||||
with patch("api.db_utils.transaction.atomic"):
|
||||
with patch("api.db_utils.time.sleep") as mock_sleep:
|
||||
with patch(
|
||||
"api.db_utils.set_read_db_alias", return_value="token"
|
||||
):
|
||||
with patch("api.db_utils.reset_read_db_alias"):
|
||||
body_call_count = 0
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
body_call_count += 1
|
||||
if body_call_count == 1:
|
||||
raise OperationalError(
|
||||
"Conflict with recovery"
|
||||
)
|
||||
|
||||
assert body_call_count == 2
|
||||
mock_connections.__getitem__.return_value.close.assert_called()
|
||||
assert mock_sleep.call_count == 1
|
||||
|
||||
def test_rls_transaction_for_with_success_first_attempt(
|
||||
self, tenants_fixture, enable_read_replica
|
||||
):
|
||||
"""Test happy path on replica: body succeeds on first attempt, no retry."""
|
||||
tenant = tenants_fixture[0]
|
||||
tenant_id = str(tenant.id)
|
||||
|
||||
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
|
||||
with patch("api.db_utils.connections") as mock_connections:
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
|
||||
mock_connections.__getitem__.return_value = mock_conn
|
||||
mock_connections.__contains__.return_value = True
|
||||
|
||||
with patch("api.db_utils.transaction.atomic"):
|
||||
with patch("api.db_utils.time.sleep") as mock_sleep:
|
||||
with patch(
|
||||
"api.db_utils.set_read_db_alias", return_value="token"
|
||||
):
|
||||
with patch("api.db_utils.reset_read_db_alias"):
|
||||
iterations = 0
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
iterations += 1
|
||||
|
||||
assert iterations == 1
|
||||
mock_sleep.assert_not_called()
|
||||
|
||||
def test_rls_transaction_for_with_closes_stale_connection(
|
||||
self, tenants_fixture, enable_read_replica
|
||||
):
|
||||
"""Test that stale connection is closed on OperationalError retry."""
|
||||
tenant = tenants_fixture[0]
|
||||
tenant_id = str(tenant.id)
|
||||
|
||||
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
|
||||
with patch("api.db_utils.connections") as mock_connections:
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
|
||||
mock_connections.__getitem__.return_value = mock_conn
|
||||
mock_connections.__contains__.return_value = True
|
||||
|
||||
with patch("api.db_utils.transaction.atomic"):
|
||||
with patch("api.db_utils.time.sleep"):
|
||||
with patch(
|
||||
"api.db_utils.set_read_db_alias", return_value="token"
|
||||
):
|
||||
with patch("api.db_utils.reset_read_db_alias"):
|
||||
body_call_count = 0
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
body_call_count += 1
|
||||
if body_call_count == 1:
|
||||
raise OperationalError("stale connection")
|
||||
|
||||
mock_conn.close.assert_called()
|
||||
|
||||
def test_rls_transaction_for_with_non_operational_error_propagates(
|
||||
self, tenants_fixture, enable_read_replica
|
||||
):
|
||||
"""Test that non-OperationalError propagates immediately without retry."""
|
||||
tenant = tenants_fixture[0]
|
||||
tenant_id = str(tenant.id)
|
||||
|
||||
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
|
||||
with patch("api.db_utils.connections") as mock_connections:
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
|
||||
mock_connections.__getitem__.return_value = mock_conn
|
||||
mock_connections.__contains__.return_value = True
|
||||
|
||||
with patch("api.db_utils.transaction.atomic"):
|
||||
with patch("api.db_utils.time.sleep") as mock_sleep:
|
||||
with patch(
|
||||
"api.db_utils.set_read_db_alias", return_value="token"
|
||||
):
|
||||
with patch("api.db_utils.reset_read_db_alias"):
|
||||
with pytest.raises(ValueError, match="bad value"):
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
raise ValueError("bad value")
|
||||
|
||||
mock_sleep.assert_not_called()
|
||||
|
||||
def test_rls_transaction_for_with_primary_no_retry(self, tenants_fixture):
|
||||
"""Test that OperationalError on primary propagates immediately."""
|
||||
tenant = tenants_fixture[0]
|
||||
tenant_id = str(tenant.id)
|
||||
|
||||
with patch("api.db_utils.get_read_db_alias", return_value=None):
|
||||
with patch("api.db_utils.connections") as mock_connections:
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
|
||||
mock_connections.__getitem__.return_value = mock_conn
|
||||
mock_connections.__contains__.return_value = True
|
||||
|
||||
with patch("api.db_utils.transaction.atomic"):
|
||||
with patch("api.db_utils.time.sleep") as mock_sleep:
|
||||
with pytest.raises(OperationalError):
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
raise OperationalError("primary failure")
|
||||
|
||||
mock_sleep.assert_not_called()
|
||||
|
||||
def test_rls_transaction_for_with_replica_max_attempts_semantics(
|
||||
self, tenants_fixture, enable_read_replica
|
||||
):
|
||||
"""Test that REPLICA_MAX_ATTEMPTS=3 produces 4 atomic calls (3 replica + 1 primary).
|
||||
|
||||
When transaction.atomic always fails, _RLSAttempt.__enter__ retries
|
||||
internally and consumes all attempts. The for-loop body never
|
||||
executes, so we assert on mock_atomic.call_count instead.
|
||||
"""
|
||||
tenant = tenants_fixture[0]
|
||||
tenant_id = str(tenant.id)
|
||||
|
||||
with patch("api.db_utils.get_read_db_alias", return_value=enable_read_replica):
|
||||
with patch("api.db_utils.connections") as mock_connections:
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
|
||||
mock_connections.__getitem__.return_value = mock_conn
|
||||
mock_connections.__contains__.return_value = True
|
||||
|
||||
with patch("api.db_utils.transaction.atomic") as mock_atomic:
|
||||
mock_atomic.side_effect = OperationalError("always fails")
|
||||
|
||||
with patch("api.db_utils.time.sleep"):
|
||||
with patch(
|
||||
"api.db_utils.set_read_db_alias", return_value="token"
|
||||
):
|
||||
with patch("api.db_utils.reset_read_db_alias"):
|
||||
with pytest.raises(OperationalError):
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
pass
|
||||
|
||||
# 3 replica attempts + 1 primary fallback
|
||||
assert mock_atomic.call_count == 4
|
||||
# Last call should use the default (primary) alias
|
||||
last_call = mock_atomic.call_args_list[-1]
|
||||
assert last_call.kwargs.get("using") == DEFAULT_DB_ALIAS
|
||||
|
||||
|
||||
class TestPostgresEnumMigration:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import uuid
|
||||
from unittest.mock import call, patch
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
@@ -65,8 +65,7 @@ class TestHandleProviderDeletionDecorator:
|
||||
tenant = tenants_fixture[0]
|
||||
deleted_provider_id = str(uuid.uuid4())
|
||||
|
||||
mock_rls.return_value.__enter__ = lambda s: None
|
||||
mock_rls.return_value.__exit__ = lambda s, *args: None
|
||||
mock_rls.return_value = [MagicMock()]
|
||||
mock_filter.return_value.exists.return_value = False
|
||||
|
||||
@handle_provider_deletion
|
||||
@@ -89,8 +88,7 @@ class TestHandleProviderDeletionDecorator:
|
||||
scan_id = str(uuid.uuid4())
|
||||
provider_id = str(uuid.uuid4())
|
||||
|
||||
mock_rls.return_value.__enter__ = lambda s: None
|
||||
mock_rls.return_value.__exit__ = lambda s, *args: None
|
||||
mock_rls.return_value = [MagicMock()]
|
||||
|
||||
mock_scan = type("MockScan", (), {"provider_id": provider_id})()
|
||||
mock_scan_filter.return_value.first.return_value = mock_scan
|
||||
@@ -112,8 +110,7 @@ class TestHandleProviderDeletionDecorator:
|
||||
tenant = tenants_fixture[0]
|
||||
scan_id = str(uuid.uuid4())
|
||||
|
||||
mock_rls.return_value.__enter__ = lambda s: None
|
||||
mock_rls.return_value.__exit__ = lambda s, *args: None
|
||||
mock_rls.return_value = [MagicMock()]
|
||||
mock_scan_filter.return_value.first.return_value = None
|
||||
|
||||
@handle_provider_deletion
|
||||
@@ -134,8 +131,7 @@ class TestHandleProviderDeletionDecorator:
|
||||
tenant = tenants_fixture[0]
|
||||
provider = providers_fixture[0]
|
||||
|
||||
mock_rls.return_value.__enter__ = lambda s: None
|
||||
mock_rls.return_value.__exit__ = lambda s, *args: None
|
||||
mock_rls.return_value = [MagicMock()]
|
||||
mock_filter.return_value.exists.return_value = True
|
||||
|
||||
@handle_provider_deletion
|
||||
@@ -154,8 +150,7 @@ class TestHandleProviderDeletionDecorator:
|
||||
tenant = tenants_fixture[0]
|
||||
deleted_provider_id = str(uuid.uuid4())
|
||||
|
||||
mock_rls.return_value.__enter__ = lambda s: None
|
||||
mock_rls.return_value.__exit__ = lambda s, *args: None
|
||||
mock_rls.return_value = [MagicMock()]
|
||||
mock_filter.return_value.exists.return_value = False
|
||||
|
||||
@handle_provider_deletion
|
||||
@@ -174,8 +169,7 @@ class TestHandleProviderDeletionDecorator:
|
||||
tenant = tenants_fixture[0]
|
||||
deleted_provider_id = str(uuid.uuid4())
|
||||
|
||||
mock_rls.return_value.__enter__ = lambda s: None
|
||||
mock_rls.return_value.__exit__ = lambda s, *args: None
|
||||
mock_rls.return_value = [MagicMock()]
|
||||
mock_filter.return_value.exists.return_value = False
|
||||
|
||||
@handle_provider_deletion
|
||||
@@ -194,8 +188,7 @@ class TestHandleProviderDeletionDecorator:
|
||||
tenant = tenants_fixture[0]
|
||||
provider = providers_fixture[0]
|
||||
|
||||
mock_rls.return_value.__enter__ = lambda s: None
|
||||
mock_rls.return_value.__exit__ = lambda s, *args: None
|
||||
mock_rls.return_value = [MagicMock()]
|
||||
mock_filter.return_value.exists.return_value = True
|
||||
|
||||
@handle_provider_deletion
|
||||
|
||||
@@ -789,9 +789,8 @@ class TestProwlerIntegrationConnectionTest:
|
||||
mock_connection.projects = {"PROJ1": "Project 1", "PROJ2": "Project 2"}
|
||||
mock_jira_class.test_connection.return_value = mock_connection
|
||||
|
||||
# Mock rls_transaction context manager
|
||||
mock_rls_transaction.return_value.__enter__ = MagicMock()
|
||||
mock_rls_transaction.return_value.__exit__ = MagicMock()
|
||||
# Mock rls_transaction as an iterable yielding a context manager
|
||||
mock_rls_transaction.return_value = [MagicMock()]
|
||||
|
||||
result = prowler_integration_connection_test(integration)
|
||||
|
||||
@@ -840,9 +839,8 @@ class TestProwlerIntegrationConnectionTest:
|
||||
mock_connection.projects = {} # Empty projects when connection fails
|
||||
mock_jira_class.test_connection.return_value = mock_connection
|
||||
|
||||
# Mock rls_transaction context manager
|
||||
mock_rls_transaction.return_value.__enter__ = MagicMock()
|
||||
mock_rls_transaction.return_value.__exit__ = MagicMock()
|
||||
# Mock rls_transaction as an iterable yielding a context manager
|
||||
mock_rls_transaction.return_value = [MagicMock()]
|
||||
|
||||
result = prowler_integration_connection_test(integration)
|
||||
|
||||
@@ -895,9 +893,8 @@ class TestProwlerIntegrationConnectionTest:
|
||||
}
|
||||
mock_jira_class.test_connection.return_value = mock_connection
|
||||
|
||||
# Mock rls_transaction context manager
|
||||
mock_rls_transaction.return_value.__enter__ = MagicMock()
|
||||
mock_rls_transaction.return_value.__exit__ = MagicMock()
|
||||
# Mock rls_transaction as an iterable yielding a context manager
|
||||
mock_rls_transaction.return_value = [MagicMock()]
|
||||
|
||||
result = prowler_integration_connection_test(integration)
|
||||
|
||||
|
||||
@@ -415,9 +415,10 @@ def prowler_integration_connection_test(integration: Integration) -> Connection:
|
||||
raise_on_exception=False,
|
||||
)
|
||||
project_keys = jira_connection.projects if jira_connection.is_connected else {}
|
||||
with rls_transaction(str(integration.tenant_id)):
|
||||
integration.configuration["projects"] = project_keys
|
||||
integration.save()
|
||||
for attempt in rls_transaction(str(integration.tenant_id)):
|
||||
with attempt:
|
||||
integration.configuration["projects"] = project_keys
|
||||
integration.save()
|
||||
return jira_connection
|
||||
elif integration.integration_type == Integration.IntegrationChoices.SLACK:
|
||||
pass
|
||||
@@ -544,9 +545,12 @@ def initialize_prowler_integration(integration: Integration) -> Jira:
|
||||
try:
|
||||
return Jira(**integration.credentials)
|
||||
except JiraBasicAuthError as jira_auth_error:
|
||||
with rls_transaction(str(integration.tenant_id)):
|
||||
integration.configuration["projects"] = {}
|
||||
integration.connected = False
|
||||
integration.connection_last_checked_at = datetime.now(tz=timezone.utc)
|
||||
integration.save()
|
||||
for attempt in rls_transaction(str(integration.tenant_id)):
|
||||
with attempt:
|
||||
integration.configuration["projects"] = {}
|
||||
integration.connected = False
|
||||
integration.connection_last_checked_at = datetime.now(
|
||||
tz=timezone.utc
|
||||
)
|
||||
integration.save()
|
||||
raise jira_auth_error
|
||||
|
||||
@@ -630,8 +630,11 @@ class SAMLInitiateAPIView(GenericAPIView):
|
||||
# Retrieve the SAML configuration for the given email domain
|
||||
try:
|
||||
check = SAMLDomainIndex.objects.get(email_domain=domain)
|
||||
with rls_transaction(str(check.tenant_id)):
|
||||
config = SAMLConfiguration.objects.get(tenant_id=str(check.tenant_id))
|
||||
for attempt in rls_transaction(str(check.tenant_id)):
|
||||
with attempt:
|
||||
config = SAMLConfiguration.objects.get(
|
||||
tenant_id=str(check.tenant_id)
|
||||
)
|
||||
except (SAMLDomainIndex.DoesNotExist, SAMLConfiguration.DoesNotExist):
|
||||
return Response(
|
||||
{"detail": "Unauthorized domain."}, status=status.HTTP_403_FORBIDDEN
|
||||
@@ -738,8 +741,9 @@ class TenantFinishACSView(FinishACSView):
|
||||
# This handles scenarios like partially deleted or missing related objects
|
||||
try:
|
||||
check = SAMLDomainIndex.objects.get(email_domain=organization_slug)
|
||||
with rls_transaction(str(check.tenant_id)):
|
||||
SAMLConfiguration.objects.get(tenant_id=str(check.tenant_id))
|
||||
for attempt in rls_transaction(str(check.tenant_id)):
|
||||
with attempt:
|
||||
SAMLConfiguration.objects.get(tenant_id=str(check.tenant_id))
|
||||
social_app = SocialApp.objects.get(
|
||||
provider="saml", client_id=organization_slug
|
||||
)
|
||||
|
||||
@@ -57,10 +57,11 @@ class RLSTask(Task):
|
||||
from api.db_utils import rls_transaction
|
||||
|
||||
tenant_id = kwargs.get("tenant_id")
|
||||
with rls_transaction(tenant_id):
|
||||
APITask.objects.update_or_create(
|
||||
id=task_result_instance.task_id,
|
||||
tenant_id=tenant_id,
|
||||
defaults={"task_runner_task": task_result_instance},
|
||||
)
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
APITask.objects.update_or_create(
|
||||
id=task_result_instance.task_id,
|
||||
tenant_id=tenant_id,
|
||||
defaults={"task_runner_task": task_result_instance},
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -418,23 +418,24 @@ def tenants_fixture(create_test_user):
|
||||
def set_user_admin_roles_fixture(create_test_user, tenants_fixture):
|
||||
user = create_test_user
|
||||
for tenant in tenants_fixture[:2]:
|
||||
with rls_transaction(str(tenant.id)):
|
||||
role = Role.objects.create(
|
||||
name="admin",
|
||||
tenant_id=tenant.id,
|
||||
manage_users=True,
|
||||
manage_account=True,
|
||||
manage_billing=True,
|
||||
manage_providers=True,
|
||||
manage_integrations=True,
|
||||
manage_scans=True,
|
||||
unlimited_visibility=True,
|
||||
)
|
||||
UserRoleRelationship.objects.create(
|
||||
user=user,
|
||||
role=role,
|
||||
tenant_id=tenant.id,
|
||||
)
|
||||
for attempt in rls_transaction(str(tenant.id)):
|
||||
with attempt:
|
||||
role = Role.objects.create(
|
||||
name="admin",
|
||||
tenant_id=tenant.id,
|
||||
manage_users=True,
|
||||
manage_account=True,
|
||||
manage_billing=True,
|
||||
manage_providers=True,
|
||||
manage_integrations=True,
|
||||
manage_scans=True,
|
||||
unlimited_visibility=True,
|
||||
)
|
||||
UserRoleRelationship.objects.create(
|
||||
user=user,
|
||||
role=role,
|
||||
tenant_id=tenant.id,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -30,15 +30,16 @@ def schedule_provider_scan(provider_instance: Provider):
|
||||
pointer="/data/attributes/provider_id",
|
||||
)
|
||||
|
||||
with rls_transaction(tenant_id):
|
||||
scheduled_scan = Scan.objects.create(
|
||||
tenant_id=tenant_id,
|
||||
name="Daily scheduled scan",
|
||||
provider_id=provider_id,
|
||||
trigger=Scan.TriggerChoices.SCHEDULED,
|
||||
state=StateChoices.AVAILABLE,
|
||||
scheduled_at=datetime.now(timezone.utc),
|
||||
)
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
scheduled_scan = Scan.objects.create(
|
||||
tenant_id=tenant_id,
|
||||
name="Daily scheduled scan",
|
||||
provider_id=provider_id,
|
||||
trigger=Scan.TriggerChoices.SCHEDULED,
|
||||
state=StateChoices.AVAILABLE,
|
||||
scheduled_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
attack_paths_db_utils.create_attack_paths_scan(
|
||||
tenant_id=tenant_id,
|
||||
|
||||
@@ -15,8 +15,9 @@ logger = get_task_logger(__name__)
|
||||
|
||||
|
||||
def can_provider_run_attack_paths_scan(tenant_id: str, provider_id: int) -> bool:
|
||||
with rls_transaction(tenant_id):
|
||||
prowler_api_provider = ProwlerAPIProvider.objects.get(id=provider_id)
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
prowler_api_provider = ProwlerAPIProvider.objects.get(id=provider_id)
|
||||
|
||||
return is_provider_available(prowler_api_provider.provider)
|
||||
|
||||
@@ -29,24 +30,25 @@ def create_attack_paths_scan(
|
||||
if not can_provider_run_attack_paths_scan(tenant_id, provider_id):
|
||||
return None
|
||||
|
||||
with rls_transaction(tenant_id):
|
||||
# Inherit graph_data_ready from the previous scan for this provider,
|
||||
# so queries remain available while the new scan runs.
|
||||
previous_data_ready = ProwlerAPIAttackPathsScan.objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
graph_data_ready=True,
|
||||
).exists()
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
# Inherit graph_data_ready from the previous scan for this provider,
|
||||
# so queries remain available while the new scan runs.
|
||||
previous_data_ready = ProwlerAPIAttackPathsScan.objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
graph_data_ready=True,
|
||||
).exists()
|
||||
|
||||
attack_paths_scan = ProwlerAPIAttackPathsScan.objects.create(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
scan_id=scan_id,
|
||||
state=StateChoices.SCHEDULED,
|
||||
started_at=datetime.now(tz=timezone.utc),
|
||||
graph_data_ready=previous_data_ready,
|
||||
)
|
||||
attack_paths_scan.save()
|
||||
attack_paths_scan = ProwlerAPIAttackPathsScan.objects.create(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
scan_id=scan_id,
|
||||
state=StateChoices.SCHEDULED,
|
||||
started_at=datetime.now(tz=timezone.utc),
|
||||
graph_data_ready=previous_data_ready,
|
||||
)
|
||||
attack_paths_scan.save()
|
||||
|
||||
return attack_paths_scan
|
||||
|
||||
@@ -56,10 +58,11 @@ def retrieve_attack_paths_scan(
|
||||
scan_id: str,
|
||||
) -> ProwlerAPIAttackPathsScan | None:
|
||||
try:
|
||||
with rls_transaction(tenant_id):
|
||||
attack_paths_scan = ProwlerAPIAttackPathsScan.objects.get(
|
||||
scan_id=scan_id,
|
||||
)
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
attack_paths_scan = ProwlerAPIAttackPathsScan.objects.get(
|
||||
scan_id=scan_id,
|
||||
)
|
||||
|
||||
return attack_paths_scan
|
||||
|
||||
@@ -72,20 +75,21 @@ def starting_attack_paths_scan(
|
||||
task_id: str,
|
||||
cartography_config: CartographyConfig,
|
||||
) -> None:
|
||||
with rls_transaction(attack_paths_scan.tenant_id):
|
||||
attack_paths_scan.task_id = task_id
|
||||
attack_paths_scan.state = StateChoices.EXECUTING
|
||||
attack_paths_scan.started_at = datetime.now(tz=timezone.utc)
|
||||
attack_paths_scan.update_tag = cartography_config.update_tag
|
||||
for attempt in rls_transaction(attack_paths_scan.tenant_id):
|
||||
with attempt:
|
||||
attack_paths_scan.task_id = task_id
|
||||
attack_paths_scan.state = StateChoices.EXECUTING
|
||||
attack_paths_scan.started_at = datetime.now(tz=timezone.utc)
|
||||
attack_paths_scan.update_tag = cartography_config.update_tag
|
||||
|
||||
attack_paths_scan.save(
|
||||
update_fields=[
|
||||
"task_id",
|
||||
"state",
|
||||
"started_at",
|
||||
"update_tag",
|
||||
]
|
||||
)
|
||||
attack_paths_scan.save(
|
||||
update_fields=[
|
||||
"task_id",
|
||||
"state",
|
||||
"started_at",
|
||||
"update_tag",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def finish_attack_paths_scan(
|
||||
@@ -93,47 +97,50 @@ def finish_attack_paths_scan(
|
||||
state: StateChoices,
|
||||
ingestion_exceptions: dict[str, Any],
|
||||
) -> None:
|
||||
with rls_transaction(attack_paths_scan.tenant_id):
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
duration = (
|
||||
int((now - attack_paths_scan.started_at).total_seconds())
|
||||
if attack_paths_scan.started_at
|
||||
else 0
|
||||
)
|
||||
for attempt in rls_transaction(attack_paths_scan.tenant_id):
|
||||
with attempt:
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
duration = (
|
||||
int((now - attack_paths_scan.started_at).total_seconds())
|
||||
if attack_paths_scan.started_at
|
||||
else 0
|
||||
)
|
||||
|
||||
attack_paths_scan.state = state
|
||||
attack_paths_scan.progress = 100
|
||||
attack_paths_scan.completed_at = now
|
||||
attack_paths_scan.duration = duration
|
||||
attack_paths_scan.ingestion_exceptions = ingestion_exceptions
|
||||
attack_paths_scan.state = state
|
||||
attack_paths_scan.progress = 100
|
||||
attack_paths_scan.completed_at = now
|
||||
attack_paths_scan.duration = duration
|
||||
attack_paths_scan.ingestion_exceptions = ingestion_exceptions
|
||||
|
||||
attack_paths_scan.save(
|
||||
update_fields=[
|
||||
"state",
|
||||
"progress",
|
||||
"completed_at",
|
||||
"duration",
|
||||
"ingestion_exceptions",
|
||||
]
|
||||
)
|
||||
attack_paths_scan.save(
|
||||
update_fields=[
|
||||
"state",
|
||||
"progress",
|
||||
"completed_at",
|
||||
"duration",
|
||||
"ingestion_exceptions",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def update_attack_paths_scan_progress(
|
||||
attack_paths_scan: ProwlerAPIAttackPathsScan,
|
||||
progress: int,
|
||||
) -> None:
|
||||
with rls_transaction(attack_paths_scan.tenant_id):
|
||||
attack_paths_scan.progress = progress
|
||||
attack_paths_scan.save(update_fields=["progress"])
|
||||
for attempt in rls_transaction(attack_paths_scan.tenant_id):
|
||||
with attempt:
|
||||
attack_paths_scan.progress = progress
|
||||
attack_paths_scan.save(update_fields=["progress"])
|
||||
|
||||
|
||||
def set_graph_data_ready(
|
||||
attack_paths_scan: ProwlerAPIAttackPathsScan,
|
||||
ready: bool,
|
||||
) -> None:
|
||||
with rls_transaction(attack_paths_scan.tenant_id):
|
||||
attack_paths_scan.graph_data_ready = ready
|
||||
attack_paths_scan.save(update_fields=["graph_data_ready"])
|
||||
for attempt in rls_transaction(attack_paths_scan.tenant_id):
|
||||
with attempt:
|
||||
attack_paths_scan.graph_data_ready = ready
|
||||
attack_paths_scan.save(update_fields=["graph_data_ready"])
|
||||
|
||||
|
||||
def set_provider_graph_data_ready(
|
||||
@@ -145,12 +152,13 @@ def set_provider_graph_data_ready(
|
||||
|
||||
Used before drop/sync so that older scan IDs cannot bypass the query gate while the graph is being replaced.
|
||||
"""
|
||||
with rls_transaction(attack_paths_scan.tenant_id):
|
||||
ProwlerAPIAttackPathsScan.objects.filter(
|
||||
tenant_id=attack_paths_scan.tenant_id,
|
||||
provider_id=attack_paths_scan.provider_id,
|
||||
).update(graph_data_ready=ready)
|
||||
attack_paths_scan.refresh_from_db(fields=["graph_data_ready"])
|
||||
for attempt in rls_transaction(attack_paths_scan.tenant_id):
|
||||
with attempt:
|
||||
ProwlerAPIAttackPathsScan.objects.filter(
|
||||
tenant_id=attack_paths_scan.tenant_id,
|
||||
provider_id=attack_paths_scan.provider_id,
|
||||
).update(graph_data_ready=ready)
|
||||
attack_paths_scan.refresh_from_db(fields=["graph_data_ready"])
|
||||
|
||||
|
||||
def recover_graph_data_ready(
|
||||
|
||||
@@ -277,15 +277,16 @@ def _fetch_findings_batch(
|
||||
|
||||
Uses read replica and RLS-scoped transaction.
|
||||
"""
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
# Use `all_objects` to get `Findings` even on soft-deleted `Providers`
|
||||
# But even the provider is already validated as active in this context
|
||||
qs = FindingModel.all_objects.filter(scan_id=scan_id).order_by("id")
|
||||
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
with attempt:
|
||||
# Use `all_objects` to get `Findings` even on soft-deleted `Providers`
|
||||
# But even the provider is already validated as active in this context
|
||||
qs = FindingModel.all_objects.filter(scan_id=scan_id).order_by("id")
|
||||
|
||||
if after_id is not None:
|
||||
qs = qs.filter(id__gt=after_id)
|
||||
if after_id is not None:
|
||||
qs = qs.filter(id__gt=after_id)
|
||||
|
||||
return list(qs.values(*_DB_QUERY_FIELDS)[:FINDINGS_BATCH_SIZE])
|
||||
return list(qs.values(*_DB_QUERY_FIELDS)[:FINDINGS_BATCH_SIZE])
|
||||
|
||||
|
||||
# Batch Enrichment
|
||||
@@ -316,12 +317,13 @@ def _build_finding_resource_map(
|
||||
finding_ids: list[UUID], tenant_id: str
|
||||
) -> dict[UUID, list[str]]:
|
||||
"""Build mapping from finding_id to list of resource UIDs."""
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
resource_mappings = ResourceFindingMapping.objects.filter(
|
||||
finding_id__in=finding_ids
|
||||
).values_list("finding_id", "resource__uid")
|
||||
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
with attempt:
|
||||
resource_mappings = ResourceFindingMapping.objects.filter(
|
||||
finding_id__in=finding_ids
|
||||
).values_list("finding_id", "resource__uid")
|
||||
|
||||
result = defaultdict(list)
|
||||
for finding_id, resource_uid in resource_mappings:
|
||||
result[finding_id].append(resource_uid)
|
||||
return result
|
||||
result = defaultdict(list)
|
||||
for finding_id, resource_uid in resource_mappings:
|
||||
result[finding_id].append(resource_uid)
|
||||
return result
|
||||
|
||||
@@ -86,9 +86,10 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
|
||||
ingestion_exceptions = {} # This will hold any exceptions raised during ingestion
|
||||
|
||||
# Prowler necessary objects
|
||||
with rls_transaction(tenant_id):
|
||||
prowler_api_provider = ProwlerAPIProvider.objects.get(scan__pk=scan_id)
|
||||
prowler_sdk_provider = initialize_prowler_provider(prowler_api_provider)
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
prowler_api_provider = ProwlerAPIProvider.objects.get(scan__pk=scan_id)
|
||||
prowler_sdk_provider = initialize_prowler_provider(prowler_api_provider)
|
||||
|
||||
# Attack Paths Scan necessary objects
|
||||
cartography_ingestion_function = get_cartography_ingestion_function(
|
||||
|
||||
@@ -41,54 +41,56 @@ logger = get_task_logger(__name__)
|
||||
|
||||
|
||||
def backfill_resource_scan_summaries(tenant_id: str, scan_id: str):
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
if ResourceScanSummary.objects.filter(
|
||||
tenant_id=tenant_id, scan_id=scan_id
|
||||
).exists():
|
||||
return {"status": "already backfilled"}
|
||||
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
with attempt:
|
||||
if ResourceScanSummary.objects.filter(
|
||||
tenant_id=tenant_id, scan_id=scan_id
|
||||
).exists():
|
||||
return {"status": "already backfilled"}
|
||||
|
||||
with rls_transaction(tenant_id):
|
||||
if not Scan.objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
id=scan_id,
|
||||
state__in=(StateChoices.COMPLETED, StateChoices.FAILED),
|
||||
).exists():
|
||||
return {"status": "scan is not completed"}
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
if not Scan.objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
id=scan_id,
|
||||
state__in=(StateChoices.COMPLETED, StateChoices.FAILED),
|
||||
).exists():
|
||||
return {"status": "scan is not completed"}
|
||||
|
||||
resource_ids_qs = (
|
||||
ResourceFindingMapping.objects.filter(
|
||||
tenant_id=tenant_id, finding__scan_id=scan_id
|
||||
)
|
||||
.values_list("resource_id", flat=True)
|
||||
.distinct()
|
||||
)
|
||||
|
||||
resource_ids = list(resource_ids_qs)
|
||||
|
||||
if not resource_ids:
|
||||
return {"status": "no resources to backfill"}
|
||||
|
||||
resources_qs = Resource.objects.filter(
|
||||
tenant_id=tenant_id, id__in=resource_ids
|
||||
).only("id", "service", "region", "type")
|
||||
|
||||
summaries = []
|
||||
for resource in resources_qs.iterator():
|
||||
summaries.append(
|
||||
ResourceScanSummary(
|
||||
tenant_id=tenant_id,
|
||||
scan_id=scan_id,
|
||||
resource_id=str(resource.id),
|
||||
service=resource.service,
|
||||
region=resource.region,
|
||||
resource_type=resource.type,
|
||||
resource_ids_qs = (
|
||||
ResourceFindingMapping.objects.filter(
|
||||
tenant_id=tenant_id, finding__scan_id=scan_id
|
||||
)
|
||||
.values_list("resource_id", flat=True)
|
||||
.distinct()
|
||||
)
|
||||
|
||||
for i in range(0, len(summaries), 500):
|
||||
ResourceScanSummary.objects.bulk_create(
|
||||
summaries[i : i + 500], ignore_conflicts=True
|
||||
)
|
||||
resource_ids = list(resource_ids_qs)
|
||||
|
||||
if not resource_ids:
|
||||
return {"status": "no resources to backfill"}
|
||||
|
||||
resources_qs = Resource.objects.filter(
|
||||
tenant_id=tenant_id, id__in=resource_ids
|
||||
).only("id", "service", "region", "type")
|
||||
|
||||
summaries = []
|
||||
for resource in resources_qs.iterator():
|
||||
summaries.append(
|
||||
ResourceScanSummary(
|
||||
tenant_id=tenant_id,
|
||||
scan_id=scan_id,
|
||||
resource_id=str(resource.id),
|
||||
service=resource.service,
|
||||
region=resource.region,
|
||||
resource_type=resource.type,
|
||||
)
|
||||
)
|
||||
|
||||
for i in range(0, len(summaries), 500):
|
||||
ResourceScanSummary.objects.bulk_create(
|
||||
summaries[i : i + 500], ignore_conflicts=True
|
||||
)
|
||||
|
||||
return {"status": "backfilled", "inserted": len(summaries)}
|
||||
|
||||
@@ -107,99 +109,101 @@ def backfill_compliance_summaries(tenant_id: str, scan_id: str):
|
||||
Returns:
|
||||
dict: Status indicating whether backfill was performed
|
||||
"""
|
||||
with rls_transaction(tenant_id):
|
||||
if ComplianceOverviewSummary.objects.filter(
|
||||
tenant_id=tenant_id, scan_id=scan_id
|
||||
).exists():
|
||||
return {"status": "already backfilled"}
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
if ComplianceOverviewSummary.objects.filter(
|
||||
tenant_id=tenant_id, scan_id=scan_id
|
||||
).exists():
|
||||
return {"status": "already backfilled"}
|
||||
|
||||
with rls_transaction(tenant_id):
|
||||
if not Scan.objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
id=scan_id,
|
||||
state__in=(StateChoices.COMPLETED, StateChoices.FAILED),
|
||||
).exists():
|
||||
return {"status": "scan is not completed"}
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
if not Scan.objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
id=scan_id,
|
||||
state__in=(StateChoices.COMPLETED, StateChoices.FAILED),
|
||||
).exists():
|
||||
return {"status": "scan is not completed"}
|
||||
|
||||
# Fetch all compliance requirement overview rows for this scan
|
||||
requirement_rows = ComplianceRequirementOverview.objects.filter(
|
||||
tenant_id=tenant_id, scan_id=scan_id
|
||||
).values(
|
||||
"compliance_id",
|
||||
"requirement_id",
|
||||
"requirement_status",
|
||||
)
|
||||
# Fetch all compliance requirement overview rows for this scan
|
||||
requirement_rows = ComplianceRequirementOverview.objects.filter(
|
||||
tenant_id=tenant_id, scan_id=scan_id
|
||||
).values(
|
||||
"compliance_id",
|
||||
"requirement_id",
|
||||
"requirement_status",
|
||||
)
|
||||
|
||||
if not requirement_rows:
|
||||
return {"status": "no compliance data to backfill"}
|
||||
if not requirement_rows:
|
||||
return {"status": "no compliance data to backfill"}
|
||||
|
||||
# Group by (compliance_id, requirement_id) across regions
|
||||
requirement_statuses = defaultdict(
|
||||
lambda: {"fail_count": 0, "pass_count": 0, "total_count": 0}
|
||||
)
|
||||
# Group by (compliance_id, requirement_id) across regions
|
||||
requirement_statuses = defaultdict(
|
||||
lambda: {"fail_count": 0, "pass_count": 0, "total_count": 0}
|
||||
)
|
||||
|
||||
for row in requirement_rows:
|
||||
compliance_id = row["compliance_id"]
|
||||
requirement_id = row["requirement_id"]
|
||||
requirement_status = row["requirement_status"]
|
||||
for row in requirement_rows:
|
||||
compliance_id = row["compliance_id"]
|
||||
requirement_id = row["requirement_id"]
|
||||
requirement_status = row["requirement_status"]
|
||||
|
||||
# Aggregate requirement status across regions
|
||||
key = (compliance_id, requirement_id)
|
||||
requirement_statuses[key]["total_count"] += 1
|
||||
# Aggregate requirement status across regions
|
||||
key = (compliance_id, requirement_id)
|
||||
requirement_statuses[key]["total_count"] += 1
|
||||
|
||||
if requirement_status == "FAIL":
|
||||
requirement_statuses[key]["fail_count"] += 1
|
||||
elif requirement_status == "PASS":
|
||||
requirement_statuses[key]["pass_count"] += 1
|
||||
if requirement_status == "FAIL":
|
||||
requirement_statuses[key]["fail_count"] += 1
|
||||
elif requirement_status == "PASS":
|
||||
requirement_statuses[key]["pass_count"] += 1
|
||||
|
||||
# Determine per-requirement status and aggregate to compliance level
|
||||
compliance_summaries = defaultdict(
|
||||
lambda: {
|
||||
"total_requirements": 0,
|
||||
"requirements_passed": 0,
|
||||
"requirements_failed": 0,
|
||||
"requirements_manual": 0,
|
||||
}
|
||||
)
|
||||
# Determine per-requirement status and aggregate to compliance level
|
||||
compliance_summaries = defaultdict(
|
||||
lambda: {
|
||||
"total_requirements": 0,
|
||||
"requirements_passed": 0,
|
||||
"requirements_failed": 0,
|
||||
"requirements_manual": 0,
|
||||
}
|
||||
)
|
||||
|
||||
for (compliance_id, requirement_id), counts in requirement_statuses.items():
|
||||
# Apply business rule: any FAIL → requirement fails
|
||||
if counts["fail_count"] > 0:
|
||||
req_status = "FAIL"
|
||||
elif counts["pass_count"] == counts["total_count"]:
|
||||
req_status = "PASS"
|
||||
else:
|
||||
req_status = "MANUAL"
|
||||
for (compliance_id, requirement_id), counts in requirement_statuses.items():
|
||||
# Apply business rule: any FAIL → requirement fails
|
||||
if counts["fail_count"] > 0:
|
||||
req_status = "FAIL"
|
||||
elif counts["pass_count"] == counts["total_count"]:
|
||||
req_status = "PASS"
|
||||
else:
|
||||
req_status = "MANUAL"
|
||||
|
||||
# Aggregate to compliance level
|
||||
compliance_summaries[compliance_id]["total_requirements"] += 1
|
||||
if req_status == "PASS":
|
||||
compliance_summaries[compliance_id]["requirements_passed"] += 1
|
||||
elif req_status == "FAIL":
|
||||
compliance_summaries[compliance_id]["requirements_failed"] += 1
|
||||
else:
|
||||
compliance_summaries[compliance_id]["requirements_manual"] += 1
|
||||
# Aggregate to compliance level
|
||||
compliance_summaries[compliance_id]["total_requirements"] += 1
|
||||
if req_status == "PASS":
|
||||
compliance_summaries[compliance_id]["requirements_passed"] += 1
|
||||
elif req_status == "FAIL":
|
||||
compliance_summaries[compliance_id]["requirements_failed"] += 1
|
||||
else:
|
||||
compliance_summaries[compliance_id]["requirements_manual"] += 1
|
||||
|
||||
# Create summary objects
|
||||
summary_objects = []
|
||||
for compliance_id, data in compliance_summaries.items():
|
||||
summary_objects.append(
|
||||
ComplianceOverviewSummary(
|
||||
tenant_id=tenant_id,
|
||||
scan_id=scan_id,
|
||||
compliance_id=compliance_id,
|
||||
requirements_passed=data["requirements_passed"],
|
||||
requirements_failed=data["requirements_failed"],
|
||||
requirements_manual=data["requirements_manual"],
|
||||
total_requirements=data["total_requirements"],
|
||||
# Create summary objects
|
||||
summary_objects = []
|
||||
for compliance_id, data in compliance_summaries.items():
|
||||
summary_objects.append(
|
||||
ComplianceOverviewSummary(
|
||||
tenant_id=tenant_id,
|
||||
scan_id=scan_id,
|
||||
compliance_id=compliance_id,
|
||||
requirements_passed=data["requirements_passed"],
|
||||
requirements_failed=data["requirements_failed"],
|
||||
requirements_manual=data["requirements_manual"],
|
||||
total_requirements=data["total_requirements"],
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Bulk insert summaries
|
||||
if summary_objects:
|
||||
ComplianceOverviewSummary.objects.bulk_create(
|
||||
summary_objects, batch_size=500, ignore_conflicts=True
|
||||
)
|
||||
# Bulk insert summaries
|
||||
if summary_objects:
|
||||
ComplianceOverviewSummary.objects.bulk_create(
|
||||
summary_objects, batch_size=500, ignore_conflicts=True
|
||||
)
|
||||
|
||||
return {"status": "backfilled", "inserted": len(summary_objects)}
|
||||
|
||||
@@ -212,82 +216,85 @@ def backfill_daily_severity_summaries(tenant_id: str, days: int = None):
|
||||
created_count = 0
|
||||
updated_count = 0
|
||||
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
scan_filter = {
|
||||
"tenant_id": tenant_id,
|
||||
"state": StateChoices.COMPLETED,
|
||||
"completed_at__isnull": False,
|
||||
}
|
||||
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
with attempt:
|
||||
scan_filter = {
|
||||
"tenant_id": tenant_id,
|
||||
"state": StateChoices.COMPLETED,
|
||||
"completed_at__isnull": False,
|
||||
}
|
||||
|
||||
if days is not None:
|
||||
cutoff_date = timezone.now() - timedelta(days=days)
|
||||
scan_filter["completed_at__gte"] = cutoff_date
|
||||
if days is not None:
|
||||
cutoff_date = timezone.now() - timedelta(days=days)
|
||||
scan_filter["completed_at__gte"] = cutoff_date
|
||||
|
||||
completed_scans = (
|
||||
Scan.objects.filter(**scan_filter)
|
||||
.order_by("provider_id", "-completed_at")
|
||||
.values("id", "provider_id", "completed_at")
|
||||
)
|
||||
completed_scans = (
|
||||
Scan.objects.filter(**scan_filter)
|
||||
.order_by("provider_id", "-completed_at")
|
||||
.values("id", "provider_id", "completed_at")
|
||||
)
|
||||
|
||||
if not completed_scans:
|
||||
return {"status": "no scans to backfill"}
|
||||
if not completed_scans:
|
||||
return {"status": "no scans to backfill"}
|
||||
|
||||
# Keep only latest scan per provider/day
|
||||
latest_scans_by_day = {}
|
||||
for scan in completed_scans:
|
||||
key = (scan["provider_id"], scan["completed_at"].date())
|
||||
if key not in latest_scans_by_day:
|
||||
latest_scans_by_day[key] = scan
|
||||
# Keep only latest scan per provider/day
|
||||
latest_scans_by_day = {}
|
||||
for scan in completed_scans:
|
||||
key = (scan["provider_id"], scan["completed_at"].date())
|
||||
if key not in latest_scans_by_day:
|
||||
latest_scans_by_day[key] = scan
|
||||
|
||||
# Process each provider/day
|
||||
for (provider_id, scan_date), scan in latest_scans_by_day.items():
|
||||
scan_id = scan["id"]
|
||||
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
severity_totals = (
|
||||
ScanSummary.objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
scan_id=scan_id,
|
||||
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
with attempt:
|
||||
severity_totals = (
|
||||
ScanSummary.objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
scan_id=scan_id,
|
||||
)
|
||||
.values("severity")
|
||||
.annotate(total_fail=Sum("fail"), total_muted=Sum("muted"))
|
||||
)
|
||||
.values("severity")
|
||||
.annotate(total_fail=Sum("fail"), total_muted=Sum("muted"))
|
||||
)
|
||||
|
||||
severity_data = {
|
||||
"critical": 0,
|
||||
"high": 0,
|
||||
"medium": 0,
|
||||
"low": 0,
|
||||
"informational": 0,
|
||||
"muted": 0,
|
||||
}
|
||||
severity_data = {
|
||||
"critical": 0,
|
||||
"high": 0,
|
||||
"medium": 0,
|
||||
"low": 0,
|
||||
"informational": 0,
|
||||
"muted": 0,
|
||||
}
|
||||
|
||||
for row in severity_totals:
|
||||
severity = row["severity"]
|
||||
if severity in severity_data:
|
||||
severity_data[severity] = row["total_fail"] or 0
|
||||
severity_data["muted"] += row["total_muted"] or 0
|
||||
for row in severity_totals:
|
||||
severity = row["severity"]
|
||||
if severity in severity_data:
|
||||
severity_data[severity] = row["total_fail"] or 0
|
||||
severity_data["muted"] += row["total_muted"] or 0
|
||||
|
||||
with rls_transaction(tenant_id):
|
||||
_, created = DailySeveritySummary.objects.update_or_create(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
date=scan_date,
|
||||
defaults={
|
||||
"scan_id": scan_id,
|
||||
"critical": severity_data["critical"],
|
||||
"high": severity_data["high"],
|
||||
"medium": severity_data["medium"],
|
||||
"low": severity_data["low"],
|
||||
"informational": severity_data["informational"],
|
||||
"muted": severity_data["muted"],
|
||||
},
|
||||
)
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
_, created = DailySeveritySummary.objects.update_or_create(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
date=scan_date,
|
||||
defaults={
|
||||
"scan_id": scan_id,
|
||||
"critical": severity_data["critical"],
|
||||
"high": severity_data["high"],
|
||||
"medium": severity_data["medium"],
|
||||
"low": severity_data["low"],
|
||||
"informational": severity_data["informational"],
|
||||
"muted": severity_data["muted"],
|
||||
},
|
||||
)
|
||||
|
||||
if created:
|
||||
created_count += 1
|
||||
else:
|
||||
updated_count += 1
|
||||
if created:
|
||||
created_count += 1
|
||||
else:
|
||||
updated_count += 1
|
||||
|
||||
return {
|
||||
"status": "backfilled",
|
||||
@@ -311,34 +318,35 @@ def backfill_scan_category_summaries(tenant_id: str, scan_id: str):
|
||||
Returns:
|
||||
dict: Status indicating whether backfill was performed
|
||||
"""
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
if ScanCategorySummary.objects.filter(
|
||||
tenant_id=tenant_id, scan_id=scan_id
|
||||
).exists():
|
||||
return {"status": "already backfilled"}
|
||||
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
with attempt:
|
||||
if ScanCategorySummary.objects.filter(
|
||||
tenant_id=tenant_id, scan_id=scan_id
|
||||
).exists():
|
||||
return {"status": "already backfilled"}
|
||||
|
||||
if not Scan.objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
id=scan_id,
|
||||
state__in=(StateChoices.COMPLETED, StateChoices.FAILED),
|
||||
).exists():
|
||||
return {"status": "scan is not completed"}
|
||||
if not Scan.objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
id=scan_id,
|
||||
state__in=(StateChoices.COMPLETED, StateChoices.FAILED),
|
||||
).exists():
|
||||
return {"status": "scan is not completed"}
|
||||
|
||||
category_counts: dict[tuple[str, str], dict[str, int]] = {}
|
||||
for finding in Finding.all_objects.filter(
|
||||
tenant_id=tenant_id, scan_id=scan_id
|
||||
).values("categories", "severity", "status", "delta", "muted"):
|
||||
aggregate_category_counts(
|
||||
categories=finding.get("categories") or [],
|
||||
severity=finding.get("severity"),
|
||||
status=finding.get("status"),
|
||||
delta=finding.get("delta"),
|
||||
muted=finding.get("muted", False),
|
||||
cache=category_counts,
|
||||
)
|
||||
category_counts: dict[tuple[str, str], dict[str, int]] = {}
|
||||
for finding in Finding.all_objects.filter(
|
||||
tenant_id=tenant_id, scan_id=scan_id
|
||||
).values("categories", "severity", "status", "delta", "muted"):
|
||||
aggregate_category_counts(
|
||||
categories=finding.get("categories") or [],
|
||||
severity=finding.get("severity"),
|
||||
status=finding.get("status"),
|
||||
delta=finding.get("delta"),
|
||||
muted=finding.get("muted", False),
|
||||
cache=category_counts,
|
||||
)
|
||||
|
||||
if not category_counts:
|
||||
return {"status": "no categories to backfill"}
|
||||
if not category_counts:
|
||||
return {"status": "no categories to backfill"}
|
||||
|
||||
category_summaries = [
|
||||
ScanCategorySummary(
|
||||
@@ -353,10 +361,11 @@ def backfill_scan_category_summaries(tenant_id: str, scan_id: str):
|
||||
for (category, severity), counts in category_counts.items()
|
||||
]
|
||||
|
||||
with rls_transaction(tenant_id):
|
||||
ScanCategorySummary.objects.bulk_create(
|
||||
category_summaries, batch_size=500, ignore_conflicts=True
|
||||
)
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
ScanCategorySummary.objects.bulk_create(
|
||||
category_summaries, batch_size=500, ignore_conflicts=True
|
||||
)
|
||||
|
||||
return {"status": "backfilled", "categories_count": len(category_counts)}
|
||||
|
||||
@@ -375,51 +384,52 @@ def backfill_scan_resource_group_summaries(tenant_id: str, scan_id: str):
|
||||
Returns:
|
||||
dict: Status indicating whether backfill was performed
|
||||
"""
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
if ScanGroupSummary.objects.filter(
|
||||
tenant_id=tenant_id, scan_id=scan_id
|
||||
).exists():
|
||||
return {"status": "already backfilled"}
|
||||
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
with attempt:
|
||||
if ScanGroupSummary.objects.filter(
|
||||
tenant_id=tenant_id, scan_id=scan_id
|
||||
).exists():
|
||||
return {"status": "already backfilled"}
|
||||
|
||||
if not Scan.objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
id=scan_id,
|
||||
state__in=(StateChoices.COMPLETED, StateChoices.FAILED),
|
||||
).exists():
|
||||
return {"status": "scan is not completed"}
|
||||
if not Scan.objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
id=scan_id,
|
||||
state__in=(StateChoices.COMPLETED, StateChoices.FAILED),
|
||||
).exists():
|
||||
return {"status": "scan is not completed"}
|
||||
|
||||
resource_group_counts: dict[tuple[str, str], dict[str, int]] = {}
|
||||
group_resources_cache: dict[str, set] = {}
|
||||
# Get findings with their first resource UID via annotation
|
||||
resource_uid_subquery = ResourceFindingMapping.objects.filter(
|
||||
finding_id=OuterRef("id"), tenant_id=tenant_id
|
||||
).values("resource__uid")[:1]
|
||||
resource_group_counts: dict[tuple[str, str], dict[str, int]] = {}
|
||||
group_resources_cache: dict[str, set] = {}
|
||||
# Get findings with their first resource UID via annotation
|
||||
resource_uid_subquery = ResourceFindingMapping.objects.filter(
|
||||
finding_id=OuterRef("id"), tenant_id=tenant_id
|
||||
).values("resource__uid")[:1]
|
||||
|
||||
for finding in (
|
||||
Finding.all_objects.filter(tenant_id=tenant_id, scan_id=scan_id)
|
||||
.annotate(resource_uid=Subquery(resource_uid_subquery))
|
||||
.values(
|
||||
"resource_groups",
|
||||
"severity",
|
||||
"status",
|
||||
"delta",
|
||||
"muted",
|
||||
"resource_uid",
|
||||
)
|
||||
):
|
||||
aggregate_resource_group_counts(
|
||||
resource_group=finding.get("resource_groups"),
|
||||
severity=finding.get("severity"),
|
||||
status=finding.get("status"),
|
||||
delta=finding.get("delta"),
|
||||
muted=finding.get("muted", False),
|
||||
resource_uid=finding.get("resource_uid") or "",
|
||||
cache=resource_group_counts,
|
||||
group_resources_cache=group_resources_cache,
|
||||
)
|
||||
for finding in (
|
||||
Finding.all_objects.filter(tenant_id=tenant_id, scan_id=scan_id)
|
||||
.annotate(resource_uid=Subquery(resource_uid_subquery))
|
||||
.values(
|
||||
"resource_groups",
|
||||
"severity",
|
||||
"status",
|
||||
"delta",
|
||||
"muted",
|
||||
"resource_uid",
|
||||
)
|
||||
):
|
||||
aggregate_resource_group_counts(
|
||||
resource_group=finding.get("resource_groups"),
|
||||
severity=finding.get("severity"),
|
||||
status=finding.get("status"),
|
||||
delta=finding.get("delta"),
|
||||
muted=finding.get("muted", False),
|
||||
resource_uid=finding.get("resource_uid") or "",
|
||||
cache=resource_group_counts,
|
||||
group_resources_cache=group_resources_cache,
|
||||
)
|
||||
|
||||
if not resource_group_counts:
|
||||
return {"status": "no resource groups to backfill"}
|
||||
if not resource_group_counts:
|
||||
return {"status": "no resource groups to backfill"}
|
||||
|
||||
# Compute group-level resource counts (same value for all severity rows in a group)
|
||||
group_resource_counts = {
|
||||
@@ -439,10 +449,11 @@ def backfill_scan_resource_group_summaries(tenant_id: str, scan_id: str):
|
||||
for (grp, severity), counts in resource_group_counts.items()
|
||||
]
|
||||
|
||||
with rls_transaction(tenant_id):
|
||||
ScanGroupSummary.objects.bulk_create(
|
||||
resource_group_summaries, batch_size=500, ignore_conflicts=True
|
||||
)
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
ScanGroupSummary.objects.bulk_create(
|
||||
resource_group_summaries, batch_size=500, ignore_conflicts=True
|
||||
)
|
||||
|
||||
return {"status": "backfilled", "resource_groups_count": len(resource_group_counts)}
|
||||
|
||||
@@ -460,34 +471,35 @@ def backfill_provider_compliance_scores(tenant_id: str) -> dict:
|
||||
Returns:
|
||||
dict: Statistics about the backfill operation
|
||||
"""
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
completed_scans = Scan.all_objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
state=StateChoices.COMPLETED,
|
||||
completed_at__isnull=False,
|
||||
)
|
||||
if not completed_scans.exists():
|
||||
return {"status": "no completed scans"}
|
||||
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
with attempt:
|
||||
completed_scans = Scan.all_objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
state=StateChoices.COMPLETED,
|
||||
completed_at__isnull=False,
|
||||
)
|
||||
if not completed_scans.exists():
|
||||
return {"status": "no completed scans"}
|
||||
|
||||
existing_providers = set(
|
||||
ProviderComplianceScore.objects.filter(tenant_id=tenant_id)
|
||||
.values_list("provider_id", flat=True)
|
||||
.distinct()
|
||||
)
|
||||
|
||||
if existing_providers:
|
||||
completed_scans = completed_scans.exclude(
|
||||
provider_id__in=existing_providers
|
||||
existing_providers = set(
|
||||
ProviderComplianceScore.objects.filter(tenant_id=tenant_id)
|
||||
.values_list("provider_id", flat=True)
|
||||
.distinct()
|
||||
)
|
||||
|
||||
scan_info = list(
|
||||
completed_scans.order_by("provider_id", "-completed_at")
|
||||
.distinct("provider_id")
|
||||
.values("id", "provider_id", "completed_at")
|
||||
)
|
||||
if existing_providers:
|
||||
completed_scans = completed_scans.exclude(
|
||||
provider_id__in=existing_providers
|
||||
)
|
||||
|
||||
if not scan_info:
|
||||
return {"status": "no scans to process"}
|
||||
scan_info = list(
|
||||
completed_scans.order_by("provider_id", "-completed_at")
|
||||
.distinct("provider_id")
|
||||
.values("id", "provider_id", "completed_at")
|
||||
)
|
||||
|
||||
if not scan_info:
|
||||
return {"status": "no scans to process"}
|
||||
|
||||
total_upserted = 0
|
||||
providers_processed = 0
|
||||
@@ -577,32 +589,33 @@ def backfill_finding_group_summaries(tenant_id: str, days: int = None):
|
||||
total_created = 0
|
||||
total_updated = 0
|
||||
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
scan_filter = {
|
||||
"tenant_id": tenant_id,
|
||||
"state": StateChoices.COMPLETED,
|
||||
"completed_at__isnull": False,
|
||||
}
|
||||
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
with attempt:
|
||||
scan_filter = {
|
||||
"tenant_id": tenant_id,
|
||||
"state": StateChoices.COMPLETED,
|
||||
"completed_at__isnull": False,
|
||||
}
|
||||
|
||||
if days is not None:
|
||||
cutoff_date = timezone.now() - timedelta(days=days)
|
||||
scan_filter["completed_at__gte"] = cutoff_date
|
||||
if days is not None:
|
||||
cutoff_date = timezone.now() - timedelta(days=days)
|
||||
scan_filter["completed_at__gte"] = cutoff_date
|
||||
|
||||
completed_scans = (
|
||||
Scan.objects.filter(**scan_filter)
|
||||
.order_by("-completed_at")
|
||||
.values("id", "completed_at")
|
||||
)
|
||||
completed_scans = (
|
||||
Scan.objects.filter(**scan_filter)
|
||||
.order_by("-completed_at")
|
||||
.values("id", "completed_at")
|
||||
)
|
||||
|
||||
if not completed_scans:
|
||||
return {"status": "no scans to backfill"}
|
||||
if not completed_scans:
|
||||
return {"status": "no scans to backfill"}
|
||||
|
||||
# Keep only latest scan per day
|
||||
latest_scans_by_day = {}
|
||||
for scan in completed_scans:
|
||||
key = scan["completed_at"].date()
|
||||
if key not in latest_scans_by_day:
|
||||
latest_scans_by_day[key] = scan
|
||||
# Keep only latest scan per day
|
||||
latest_scans_by_day = {}
|
||||
for scan in completed_scans:
|
||||
key = scan["completed_at"].date()
|
||||
if key not in latest_scans_by_day:
|
||||
latest_scans_by_day[key] = scan
|
||||
|
||||
# Process each day's scan
|
||||
for scan_date, scan in latest_scans_by_day.items():
|
||||
|
||||
@@ -28,20 +28,21 @@ def _recalculate_tenant_compliance_summary(tenant_id: str, compliance_ids: list[
|
||||
|
||||
compliance_ids = sorted(set(compliance_ids))
|
||||
|
||||
with rls_transaction(tenant_id, using=MainRouter.default_db) as cursor:
|
||||
# Serialize tenant-level summary updates to avoid concurrent recomputes
|
||||
cursor.execute(
|
||||
"SELECT pg_advisory_xact_lock(hashtext(%s))",
|
||||
[tenant_id],
|
||||
)
|
||||
cursor.execute(
|
||||
COMPLIANCE_UPSERT_TENANT_SUMMARY_SQL,
|
||||
[tenant_id, tenant_id, compliance_ids],
|
||||
)
|
||||
cursor.execute(
|
||||
COMPLIANCE_DELETE_EMPTY_TENANT_SUMMARY_SQL,
|
||||
[tenant_id, compliance_ids],
|
||||
)
|
||||
for attempt in rls_transaction(tenant_id, using=MainRouter.default_db):
|
||||
with attempt as cursor:
|
||||
# Serialize tenant-level summary updates to avoid concurrent recomputes
|
||||
cursor.execute(
|
||||
"SELECT pg_advisory_xact_lock(hashtext(%s))",
|
||||
[tenant_id],
|
||||
)
|
||||
cursor.execute(
|
||||
COMPLIANCE_UPSERT_TENANT_SUMMARY_SQL,
|
||||
[tenant_id, tenant_id, compliance_ids],
|
||||
)
|
||||
cursor.execute(
|
||||
COMPLIANCE_DELETE_EMPTY_TENANT_SUMMARY_SQL,
|
||||
[tenant_id, compliance_ids],
|
||||
)
|
||||
|
||||
|
||||
def delete_provider(tenant_id: str, pk: str):
|
||||
@@ -59,32 +60,39 @@ def delete_provider(tenant_id: str, pk: str):
|
||||
"""
|
||||
|
||||
# Get all provider related data to delete them in batches
|
||||
with rls_transaction(tenant_id):
|
||||
try:
|
||||
instance = Provider.all_objects.get(pk=pk)
|
||||
except Provider.DoesNotExist:
|
||||
logger.info(f"Provider `{pk}` already deleted, skipping")
|
||||
return {}
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
try:
|
||||
instance = Provider.all_objects.get(pk=pk)
|
||||
except Provider.DoesNotExist:
|
||||
logger.info(f"Provider `{pk}` already deleted, skipping")
|
||||
return {}
|
||||
|
||||
compliance_ids = list(
|
||||
ProviderComplianceScore.objects.filter(provider=instance)
|
||||
.values_list("compliance_id", flat=True)
|
||||
.distinct()
|
||||
)
|
||||
|
||||
attack_paths_scan_ids = list(
|
||||
AttackPathsScan.all_objects.filter(provider=instance).values_list(
|
||||
"id", flat=True
|
||||
compliance_ids = list(
|
||||
ProviderComplianceScore.objects.filter(provider=instance)
|
||||
.values_list("compliance_id", flat=True)
|
||||
.distinct()
|
||||
)
|
||||
)
|
||||
|
||||
deletion_steps = [
|
||||
("Scan Summaries", ScanSummary.all_objects.filter(scan__provider=instance)),
|
||||
("Findings", Finding.all_objects.filter(scan__provider=instance)),
|
||||
("Resources", Resource.all_objects.filter(provider=instance)),
|
||||
("Scans", Scan.all_objects.filter(provider=instance)),
|
||||
("AttackPathsScans", AttackPathsScan.all_objects.filter(provider=instance)),
|
||||
]
|
||||
attack_paths_scan_ids = list(
|
||||
AttackPathsScan.all_objects.filter(provider=instance).values_list(
|
||||
"id", flat=True
|
||||
)
|
||||
)
|
||||
|
||||
deletion_steps = [
|
||||
(
|
||||
"Scan Summaries",
|
||||
ScanSummary.all_objects.filter(scan__provider=instance),
|
||||
),
|
||||
("Findings", Finding.all_objects.filter(scan__provider=instance)),
|
||||
("Resources", Resource.all_objects.filter(provider=instance)),
|
||||
("Scans", Scan.all_objects.filter(provider=instance)),
|
||||
(
|
||||
"AttackPathsScans",
|
||||
AttackPathsScan.all_objects.filter(provider=instance),
|
||||
),
|
||||
]
|
||||
|
||||
# Drop orphaned temporary Neo4j databases
|
||||
for aps_id in attack_paths_scan_ids:
|
||||
@@ -116,8 +124,9 @@ def delete_provider(tenant_id: str, pk: str):
|
||||
|
||||
# Delete the provider instance itself
|
||||
try:
|
||||
with rls_transaction(tenant_id):
|
||||
_, provider_summary = instance.delete()
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
_, provider_summary = instance.delete()
|
||||
deletion_summary.update(provider_summary)
|
||||
except DatabaseError as db_error:
|
||||
logger.error(f"Error deleting Provider: {db_error}")
|
||||
|
||||
@@ -295,8 +295,9 @@ def _build_output_path(
|
||||
# Sanitize the prowler provider name to ensure it is a valid directory name
|
||||
prowler_provider_sanitized = re.sub(r"[^\w\-]", "-", prowler_provider)
|
||||
|
||||
with rls_transaction(tenant_id):
|
||||
started_at = Scan.objects.get(id=scan_id).started_at
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
started_at = Scan.objects.get(id=scan_id).started_at
|
||||
|
||||
set_output_timestamp(started_at)
|
||||
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
import os
|
||||
import time
|
||||
from glob import glob
|
||||
|
||||
from celery.utils.log import get_task_logger
|
||||
from config.django.base import DJANGO_FINDINGS_BATCH_SIZE
|
||||
from django.db import OperationalError
|
||||
from tasks.utils import batched
|
||||
|
||||
from api.db_router import READ_REPLICA_ALIAS, MainRouter
|
||||
from api.db_utils import REPLICA_MAX_ATTEMPTS, REPLICA_RETRY_BASE_DELAY, rls_transaction
|
||||
from api.db_utils import rls_transaction
|
||||
from api.models import Finding, Integration, Provider
|
||||
from api.utils import initialize_prowler_integration, initialize_prowler_provider
|
||||
from prowler.lib.outputs.asff.asff import ASFF
|
||||
@@ -78,14 +76,15 @@ def upload_s3_integration(
|
||||
logger.info(f"Processing S3 integrations for provider {provider_id}")
|
||||
|
||||
try:
|
||||
with rls_transaction(tenant_id):
|
||||
integrations = list(
|
||||
Integration.objects.filter(
|
||||
integrationproviderrelationship__provider_id=provider_id,
|
||||
integration_type=Integration.IntegrationChoices.AMAZON_S3,
|
||||
enabled=True,
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
integrations = list(
|
||||
Integration.objects.filter(
|
||||
integrationproviderrelationship__provider_id=provider_id,
|
||||
integration_type=Integration.IntegrationChoices.AMAZON_S3,
|
||||
enabled=True,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if not integrations:
|
||||
logger.error(f"No S3 integrations found for provider {provider_id}")
|
||||
@@ -184,14 +183,18 @@ def get_security_hub_client_from_integration(
|
||||
if the connection was successful and the SecurityHub client or connection object.
|
||||
"""
|
||||
# Get the provider associated with this integration
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
provider_relationship = integration.integrationproviderrelationship_set.first()
|
||||
if not provider_relationship:
|
||||
return Connection(
|
||||
is_connected=False, error="No provider associated with this integration"
|
||||
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
with attempt:
|
||||
provider_relationship = (
|
||||
integration.integrationproviderrelationship_set.first()
|
||||
)
|
||||
provider_uid = provider_relationship.provider.uid
|
||||
provider_secret = provider_relationship.provider.secret.secret
|
||||
if not provider_relationship:
|
||||
return Connection(
|
||||
is_connected=False,
|
||||
error="No provider associated with this integration",
|
||||
)
|
||||
provider_uid = provider_relationship.provider.uid
|
||||
provider_secret = provider_relationship.provider.secret.secret
|
||||
|
||||
credentials = (
|
||||
integration.credentials if integration.credentials else provider_secret
|
||||
@@ -213,9 +216,10 @@ def get_security_hub_client_from_integration(
|
||||
regions_status[region] = region in connection.enabled_regions
|
||||
|
||||
# Save regions information in the integration configuration
|
||||
with rls_transaction(tenant_id, using=MainRouter.default_db):
|
||||
integration.configuration["regions"] = regions_status
|
||||
integration.save()
|
||||
for attempt in rls_transaction(tenant_id, using=MainRouter.default_db):
|
||||
with attempt:
|
||||
integration.configuration["regions"] = regions_status
|
||||
integration.save()
|
||||
|
||||
# Create SecurityHub client with all necessary parameters
|
||||
security_hub = SecurityHub(
|
||||
@@ -228,10 +232,11 @@ def get_security_hub_client_from_integration(
|
||||
return True, security_hub
|
||||
else:
|
||||
# Reset regions information if connection fails and integration is not connected
|
||||
with rls_transaction(tenant_id, using=MainRouter.default_db):
|
||||
integration.connected = False
|
||||
integration.configuration["regions"] = {}
|
||||
integration.save()
|
||||
for attempt in rls_transaction(tenant_id, using=MainRouter.default_db):
|
||||
with attempt:
|
||||
integration.connected = False
|
||||
integration.configuration["regions"] = {}
|
||||
integration.save()
|
||||
|
||||
return False, connection
|
||||
|
||||
@@ -256,27 +261,28 @@ def upload_security_hub_integration(
|
||||
logger.info(f"Processing Security Hub integrations for provider {provider_id}")
|
||||
|
||||
try:
|
||||
with rls_transaction(tenant_id):
|
||||
# Get Security Hub integrations for this provider
|
||||
integrations = list(
|
||||
Integration.objects.filter(
|
||||
integrationproviderrelationship__provider_id=provider_id,
|
||||
integration_type=Integration.IntegrationChoices.AWS_SECURITY_HUB,
|
||||
enabled=True,
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
# Get Security Hub integrations for this provider
|
||||
integrations = list(
|
||||
Integration.objects.filter(
|
||||
integrationproviderrelationship__provider_id=provider_id,
|
||||
integration_type=Integration.IntegrationChoices.AWS_SECURITY_HUB,
|
||||
enabled=True,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if not integrations:
|
||||
logger.error(
|
||||
f"No Security Hub integrations found for provider {provider_id}"
|
||||
)
|
||||
return False
|
||||
if not integrations:
|
||||
logger.error(
|
||||
f"No Security Hub integrations found for provider {provider_id}"
|
||||
)
|
||||
return False
|
||||
|
||||
# Get the provider object
|
||||
provider = Provider.objects.get(id=provider_id)
|
||||
# Get the provider object
|
||||
provider = Provider.objects.get(id=provider_id)
|
||||
|
||||
# Initialize prowler provider for finding transformation
|
||||
prowler_provider = initialize_prowler_provider(provider)
|
||||
# Initialize prowler provider for finding transformation
|
||||
prowler_provider = initialize_prowler_provider(provider)
|
||||
|
||||
# Process each Security Hub integration
|
||||
integration_executions = 0
|
||||
@@ -293,130 +299,104 @@ def upload_security_hub_integration(
|
||||
total_findings_sent[integration.id] = 0
|
||||
|
||||
# Process findings in batches to avoid memory issues
|
||||
max_attempts = REPLICA_MAX_ATTEMPTS if READ_REPLICA_ALIAS else 1
|
||||
has_findings = False
|
||||
batch_number = 0
|
||||
|
||||
for attempt in range(1, max_attempts + 1):
|
||||
read_alias = None
|
||||
if READ_REPLICA_ALIAS:
|
||||
read_alias = (
|
||||
READ_REPLICA_ALIAS
|
||||
if attempt < max_attempts
|
||||
else MainRouter.default_db
|
||||
)
|
||||
|
||||
try:
|
||||
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
with attempt:
|
||||
batch_number = 0
|
||||
has_findings = False
|
||||
with rls_transaction(
|
||||
tenant_id,
|
||||
using=read_alias,
|
||||
retry_on_replica=False,
|
||||
):
|
||||
qs = (
|
||||
Finding.all_objects.filter(
|
||||
tenant_id=tenant_id, scan_id=scan_id
|
||||
)
|
||||
.order_by("uid")
|
||||
.iterator()
|
||||
qs = (
|
||||
Finding.all_objects.filter(
|
||||
tenant_id=tenant_id, scan_id=scan_id
|
||||
)
|
||||
|
||||
for batch, _ in batched(qs, DJANGO_FINDINGS_BATCH_SIZE):
|
||||
batch_number += 1
|
||||
has_findings = True
|
||||
|
||||
# Transform findings for this batch
|
||||
transformed_findings = [
|
||||
FindingOutput.transform_api_finding(
|
||||
finding, prowler_provider
|
||||
)
|
||||
for finding in batch
|
||||
]
|
||||
|
||||
# Convert to ASFF format
|
||||
asff_transformer = ASFF(
|
||||
findings=transformed_findings,
|
||||
file_path="",
|
||||
file_extension="json",
|
||||
)
|
||||
asff_transformer.transform(transformed_findings)
|
||||
|
||||
# Get the batch of ASFF findings
|
||||
batch_asff_findings = asff_transformer.data
|
||||
|
||||
if batch_asff_findings:
|
||||
# Create Security Hub client for first batch or reuse existing
|
||||
if not security_hub_client:
|
||||
connected, security_hub = (
|
||||
get_security_hub_client_from_integration(
|
||||
integration,
|
||||
tenant_id,
|
||||
batch_asff_findings,
|
||||
)
|
||||
)
|
||||
|
||||
if not connected:
|
||||
if isinstance(
|
||||
security_hub.error,
|
||||
SecurityHubNoEnabledRegionsError,
|
||||
):
|
||||
logger.warning(
|
||||
f"Security Hub integration {integration.id} has no enabled regions"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Security Hub connection failed for integration {integration.id}: "
|
||||
f"{security_hub.error}"
|
||||
)
|
||||
break # Skip this integration
|
||||
|
||||
security_hub_client = security_hub
|
||||
logger.info(
|
||||
f"Sending {'fail' if send_only_fails else 'all'} findings to Security Hub via "
|
||||
f"integration {integration.id}"
|
||||
)
|
||||
else:
|
||||
# Update findings in existing client for this batch
|
||||
security_hub_client._findings_per_region = (
|
||||
security_hub_client.filter(
|
||||
batch_asff_findings,
|
||||
send_only_fails,
|
||||
)
|
||||
)
|
||||
|
||||
# Send this batch to Security Hub
|
||||
try:
|
||||
findings_sent = security_hub_client.batch_send_to_security_hub()
|
||||
total_findings_sent[integration.id] += (
|
||||
findings_sent
|
||||
)
|
||||
|
||||
if findings_sent > 0:
|
||||
logger.debug(
|
||||
f"Sent batch {batch_number} with {findings_sent} findings to Security Hub"
|
||||
)
|
||||
except Exception as batch_error:
|
||||
logger.error(
|
||||
f"Failed to send batch {batch_number} to Security Hub: {str(batch_error)}"
|
||||
)
|
||||
|
||||
# Clear memory after processing each batch
|
||||
asff_transformer._data.clear()
|
||||
del batch_asff_findings
|
||||
del transformed_findings
|
||||
|
||||
break
|
||||
except OperationalError as e:
|
||||
if attempt == max_attempts:
|
||||
raise
|
||||
|
||||
delay = REPLICA_RETRY_BASE_DELAY * (2 ** (attempt - 1))
|
||||
logger.info(
|
||||
"RLS query failed during Security Hub integration "
|
||||
f"(attempt {attempt}/{max_attempts}), retrying in {delay}s. Error: {e}"
|
||||
.order_by("uid")
|
||||
.iterator()
|
||||
)
|
||||
time.sleep(delay)
|
||||
|
||||
for batch, _ in batched(qs, DJANGO_FINDINGS_BATCH_SIZE):
|
||||
batch_number += 1
|
||||
has_findings = True
|
||||
|
||||
# Transform findings for this batch
|
||||
transformed_findings = [
|
||||
FindingOutput.transform_api_finding(
|
||||
finding, prowler_provider
|
||||
)
|
||||
for finding in batch
|
||||
]
|
||||
|
||||
# Convert to ASFF format
|
||||
asff_transformer = ASFF(
|
||||
findings=transformed_findings,
|
||||
file_path="",
|
||||
file_extension="json",
|
||||
)
|
||||
asff_transformer.transform(transformed_findings)
|
||||
|
||||
# Get the batch of ASFF findings
|
||||
batch_asff_findings = asff_transformer.data
|
||||
|
||||
if batch_asff_findings:
|
||||
# Create Security Hub client for first batch or reuse existing
|
||||
if not security_hub_client:
|
||||
connected, security_hub = (
|
||||
get_security_hub_client_from_integration(
|
||||
integration,
|
||||
tenant_id,
|
||||
batch_asff_findings,
|
||||
)
|
||||
)
|
||||
|
||||
if not connected:
|
||||
if isinstance(
|
||||
security_hub.error,
|
||||
SecurityHubNoEnabledRegionsError,
|
||||
):
|
||||
logger.warning(
|
||||
f"Security Hub integration {integration.id} has no enabled regions"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Security Hub connection failed for integration {integration.id}: "
|
||||
f"{security_hub.error}"
|
||||
)
|
||||
break # Skip this integration
|
||||
|
||||
security_hub_client = security_hub
|
||||
logger.info(
|
||||
f"Sending {'fail' if send_only_fails else 'all'} findings to Security Hub via "
|
||||
f"integration {integration.id}"
|
||||
)
|
||||
else:
|
||||
# Update findings in existing client for this batch
|
||||
security_hub_client._findings_per_region = (
|
||||
security_hub_client.filter(
|
||||
batch_asff_findings,
|
||||
send_only_fails,
|
||||
)
|
||||
)
|
||||
|
||||
# Send this batch to Security Hub
|
||||
try:
|
||||
findings_sent = (
|
||||
security_hub_client.batch_send_to_security_hub()
|
||||
)
|
||||
total_findings_sent[integration.id] += findings_sent
|
||||
|
||||
if findings_sent > 0:
|
||||
logger.debug(
|
||||
f"Sent batch {batch_number} with {findings_sent} findings to Security Hub"
|
||||
)
|
||||
except Exception as batch_error:
|
||||
logger.error(
|
||||
f"Failed to send batch {batch_number} to Security Hub: {str(batch_error)}"
|
||||
)
|
||||
|
||||
# Clear memory after processing each batch
|
||||
asff_transformer._data.clear()
|
||||
del batch_asff_findings
|
||||
del transformed_findings
|
||||
|
||||
if not has_findings:
|
||||
logger.info(
|
||||
@@ -479,67 +459,69 @@ def send_findings_to_jira(
|
||||
issue_type: str,
|
||||
finding_ids: list[str],
|
||||
):
|
||||
with rls_transaction(tenant_id):
|
||||
integration = Integration.objects.get(id=integration_id)
|
||||
jira_integration = initialize_prowler_integration(integration)
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
integration = Integration.objects.get(id=integration_id)
|
||||
jira_integration = initialize_prowler_integration(integration)
|
||||
|
||||
num_tickets_created = 0
|
||||
for finding_id in finding_ids:
|
||||
with rls_transaction(tenant_id):
|
||||
finding_instance = (
|
||||
Finding.all_objects.select_related("scan__provider")
|
||||
.prefetch_related("resources")
|
||||
.get(id=finding_id)
|
||||
)
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
finding_instance = (
|
||||
Finding.all_objects.select_related("scan__provider")
|
||||
.prefetch_related("resources")
|
||||
.get(id=finding_id)
|
||||
)
|
||||
|
||||
# Extract resource information
|
||||
resource = (
|
||||
finding_instance.resources.first()
|
||||
if finding_instance.resources.exists()
|
||||
else None
|
||||
)
|
||||
resource_uid = resource.uid if resource else ""
|
||||
resource_name = resource.name if resource else ""
|
||||
resource_tags = {}
|
||||
if resource and hasattr(resource, "tags"):
|
||||
resource_tags = resource.get_tags(tenant_id)
|
||||
# Extract resource information
|
||||
resource = (
|
||||
finding_instance.resources.first()
|
||||
if finding_instance.resources.exists()
|
||||
else None
|
||||
)
|
||||
resource_uid = resource.uid if resource else ""
|
||||
resource_name = resource.name if resource else ""
|
||||
resource_tags = {}
|
||||
if resource and hasattr(resource, "tags"):
|
||||
resource_tags = resource.get_tags(tenant_id)
|
||||
|
||||
# Get region
|
||||
region = resource.region if resource and resource.region else ""
|
||||
# Get region
|
||||
region = resource.region if resource and resource.region else ""
|
||||
|
||||
# Extract remediation information from check_metadata
|
||||
check_metadata = finding_instance.check_metadata
|
||||
remediation = check_metadata.get("remediation", {})
|
||||
recommendation = remediation.get("recommendation", {})
|
||||
remediation_code = remediation.get("code", {})
|
||||
# Extract remediation information from check_metadata
|
||||
check_metadata = finding_instance.check_metadata
|
||||
remediation = check_metadata.get("remediation", {})
|
||||
recommendation = remediation.get("recommendation", {})
|
||||
remediation_code = remediation.get("code", {})
|
||||
|
||||
# Send the individual finding to Jira
|
||||
result = jira_integration.send_finding(
|
||||
check_id=finding_instance.check_id,
|
||||
check_title=check_metadata.get("checktitle", ""),
|
||||
severity=finding_instance.severity,
|
||||
status=finding_instance.status,
|
||||
status_extended=finding_instance.status_extended or "",
|
||||
provider=finding_instance.scan.provider.provider,
|
||||
region=region,
|
||||
resource_uid=resource_uid,
|
||||
resource_name=resource_name,
|
||||
risk=check_metadata.get("risk", ""),
|
||||
recommendation_text=recommendation.get("text", ""),
|
||||
recommendation_url=recommendation.get("url", ""),
|
||||
remediation_code_native_iac=remediation_code.get("nativeiac", ""),
|
||||
remediation_code_terraform=remediation_code.get("terraform", ""),
|
||||
remediation_code_cli=remediation_code.get("cli", ""),
|
||||
remediation_code_other=remediation_code.get("other", ""),
|
||||
resource_tags=resource_tags,
|
||||
compliance=finding_instance.compliance or {},
|
||||
project_key=project_key,
|
||||
issue_type=issue_type,
|
||||
)
|
||||
if result:
|
||||
num_tickets_created += 1
|
||||
else:
|
||||
logger.error(f"Failed to send finding {finding_id} to Jira")
|
||||
# Send the individual finding to Jira
|
||||
result = jira_integration.send_finding(
|
||||
check_id=finding_instance.check_id,
|
||||
check_title=check_metadata.get("checktitle", ""),
|
||||
severity=finding_instance.severity,
|
||||
status=finding_instance.status,
|
||||
status_extended=finding_instance.status_extended or "",
|
||||
provider=finding_instance.scan.provider.provider,
|
||||
region=region,
|
||||
resource_uid=resource_uid,
|
||||
resource_name=resource_name,
|
||||
risk=check_metadata.get("risk", ""),
|
||||
recommendation_text=recommendation.get("text", ""),
|
||||
recommendation_url=recommendation.get("url", ""),
|
||||
remediation_code_native_iac=remediation_code.get("nativeiac", ""),
|
||||
remediation_code_terraform=remediation_code.get("terraform", ""),
|
||||
remediation_code_cli=remediation_code.get("cli", ""),
|
||||
remediation_code_other=remediation_code.get("other", ""),
|
||||
resource_tags=resource_tags,
|
||||
compliance=finding_instance.compliance or {},
|
||||
project_key=project_key,
|
||||
issue_type=issue_type,
|
||||
)
|
||||
if result:
|
||||
num_tickets_created += 1
|
||||
else:
|
||||
logger.error(f"Failed to send finding {finding_id} to Jira")
|
||||
|
||||
return {
|
||||
"created_count": num_tickets_created,
|
||||
|
||||
@@ -25,36 +25,38 @@ def mute_historical_findings(tenant_id: str, mute_rule_id: str):
|
||||
findings_muted_count = 0
|
||||
|
||||
# Get the list of UIDs to mute and the reason
|
||||
with rls_transaction(tenant_id):
|
||||
mute_rule = MuteRule.objects.get(id=mute_rule_id, tenant_id=tenant_id)
|
||||
finding_uids = mute_rule.finding_uids
|
||||
mute_reason = mute_rule.reason
|
||||
muted_at = mute_rule.inserted_at
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
mute_rule = MuteRule.objects.get(id=mute_rule_id, tenant_id=tenant_id)
|
||||
finding_uids = mute_rule.finding_uids
|
||||
mute_reason = mute_rule.reason
|
||||
muted_at = mute_rule.inserted_at
|
||||
|
||||
# Query findings that match the UIDs and are not already muted
|
||||
with rls_transaction(tenant_id):
|
||||
findings_to_mute = Finding.objects.filter(
|
||||
tenant_id=tenant_id, uid__in=finding_uids, muted=False
|
||||
)
|
||||
total_findings = findings_to_mute.count()
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
findings_to_mute = Finding.objects.filter(
|
||||
tenant_id=tenant_id, uid__in=finding_uids, muted=False
|
||||
)
|
||||
total_findings = findings_to_mute.count()
|
||||
|
||||
logger.info(
|
||||
f"Processing {total_findings} findings for mute rule {mute_rule_id}"
|
||||
)
|
||||
logger.info(
|
||||
f"Processing {total_findings} findings for mute rule {mute_rule_id}"
|
||||
)
|
||||
|
||||
if total_findings > 0:
|
||||
for batch, is_last in batched(
|
||||
findings_to_mute.iterator(), DJANGO_FINDINGS_BATCH_SIZE
|
||||
):
|
||||
batch_ids = [f.id for f in batch]
|
||||
updated_count = Finding.all_objects.filter(
|
||||
id__in=batch_ids, tenant_id=tenant_id
|
||||
).update(
|
||||
muted=True,
|
||||
muted_at=muted_at,
|
||||
muted_reason=mute_reason,
|
||||
)
|
||||
findings_muted_count += updated_count
|
||||
if total_findings > 0:
|
||||
for batch, is_last in batched(
|
||||
findings_to_mute.iterator(), DJANGO_FINDINGS_BATCH_SIZE
|
||||
):
|
||||
batch_ids = [f.id for f in batch]
|
||||
updated_count = Finding.all_objects.filter(
|
||||
id__in=batch_ids, tenant_id=tenant_id
|
||||
).update(
|
||||
muted=True,
|
||||
muted_at=muted_at,
|
||||
muted_reason=mute_reason,
|
||||
)
|
||||
findings_muted_count += updated_count
|
||||
|
||||
logger.info(f"Muted {findings_muted_count} findings for rule {mute_rule_id}")
|
||||
|
||||
|
||||
@@ -248,22 +248,23 @@ def generate_compliance_reports(
|
||||
results = {}
|
||||
|
||||
# Validate that the scan has findings and get provider info
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
if not ScanSummary.objects.filter(scan_id=scan_id).exists():
|
||||
logger.info("No findings found for scan %s", scan_id)
|
||||
if generate_threatscore:
|
||||
results["threatscore"] = {"upload": False, "path": ""}
|
||||
if generate_ens:
|
||||
results["ens"] = {"upload": False, "path": ""}
|
||||
if generate_nis2:
|
||||
results["nis2"] = {"upload": False, "path": ""}
|
||||
if generate_csa:
|
||||
results["csa"] = {"upload": False, "path": ""}
|
||||
return results
|
||||
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
with attempt:
|
||||
if not ScanSummary.objects.filter(scan_id=scan_id).exists():
|
||||
logger.info("No findings found for scan %s", scan_id)
|
||||
if generate_threatscore:
|
||||
results["threatscore"] = {"upload": False, "path": ""}
|
||||
if generate_ens:
|
||||
results["ens"] = {"upload": False, "path": ""}
|
||||
if generate_nis2:
|
||||
results["nis2"] = {"upload": False, "path": ""}
|
||||
if generate_csa:
|
||||
results["csa"] = {"upload": False, "path": ""}
|
||||
return results
|
||||
|
||||
provider_obj = Provider.objects.get(id=provider_id)
|
||||
provider_uid = provider_obj.uid
|
||||
provider_type = provider_obj.provider
|
||||
provider_obj = Provider.objects.get(id=provider_id)
|
||||
provider_uid = provider_obj.uid
|
||||
provider_type = provider_obj.provider
|
||||
|
||||
# Check provider compatibility
|
||||
if generate_threatscore and provider_type not in [
|
||||
@@ -398,49 +399,50 @@ def generate_compliance_reports(
|
||||
min_risk_level=min_risk_level_threatscore,
|
||||
)
|
||||
|
||||
with rls_transaction(tenant_id):
|
||||
previous_snapshot = (
|
||||
ThreatScoreSnapshot.objects.filter(
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
previous_snapshot = (
|
||||
ThreatScoreSnapshot.objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
compliance_id=compliance_id_threatscore,
|
||||
)
|
||||
.order_by("-inserted_at")
|
||||
.first()
|
||||
)
|
||||
|
||||
score_delta = None
|
||||
if previous_snapshot:
|
||||
score_delta = metrics["overall_score"] - float(
|
||||
previous_snapshot.overall_score
|
||||
)
|
||||
|
||||
snapshot = ThreatScoreSnapshot.objects.create(
|
||||
tenant_id=tenant_id,
|
||||
scan_id=scan_id,
|
||||
provider_id=provider_id,
|
||||
compliance_id=compliance_id_threatscore,
|
||||
)
|
||||
.order_by("-inserted_at")
|
||||
.first()
|
||||
)
|
||||
|
||||
score_delta = None
|
||||
if previous_snapshot:
|
||||
score_delta = metrics["overall_score"] - float(
|
||||
previous_snapshot.overall_score
|
||||
overall_score=metrics["overall_score"],
|
||||
score_delta=score_delta,
|
||||
section_scores=metrics["section_scores"],
|
||||
critical_requirements=metrics["critical_requirements"],
|
||||
total_requirements=metrics["total_requirements"],
|
||||
passed_requirements=metrics["passed_requirements"],
|
||||
failed_requirements=metrics["failed_requirements"],
|
||||
manual_requirements=metrics["manual_requirements"],
|
||||
total_findings=metrics["total_findings"],
|
||||
passed_findings=metrics["passed_findings"],
|
||||
failed_findings=metrics["failed_findings"],
|
||||
)
|
||||
|
||||
snapshot = ThreatScoreSnapshot.objects.create(
|
||||
tenant_id=tenant_id,
|
||||
scan_id=scan_id,
|
||||
provider_id=provider_id,
|
||||
compliance_id=compliance_id_threatscore,
|
||||
overall_score=metrics["overall_score"],
|
||||
score_delta=score_delta,
|
||||
section_scores=metrics["section_scores"],
|
||||
critical_requirements=metrics["critical_requirements"],
|
||||
total_requirements=metrics["total_requirements"],
|
||||
passed_requirements=metrics["passed_requirements"],
|
||||
failed_requirements=metrics["failed_requirements"],
|
||||
manual_requirements=metrics["manual_requirements"],
|
||||
total_findings=metrics["total_findings"],
|
||||
passed_findings=metrics["passed_findings"],
|
||||
failed_findings=metrics["failed_findings"],
|
||||
)
|
||||
|
||||
delta_msg = (
|
||||
f" (delta: {score_delta:+.2f}%)"
|
||||
if score_delta is not None
|
||||
else ""
|
||||
)
|
||||
logger.info(
|
||||
f"ThreatScore snapshot created with ID {snapshot.id} (score: {snapshot.overall_score}%{delta_msg})",
|
||||
)
|
||||
delta_msg = (
|
||||
f" (delta: {score_delta:+.2f}%)"
|
||||
if score_delta is not None
|
||||
else ""
|
||||
)
|
||||
logger.info(
|
||||
f"ThreatScore snapshot created with ID {snapshot.id} (score: {snapshot.overall_score}%{delta_msg})",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error creating ThreatScore snapshot: %s", e)
|
||||
|
||||
|
||||
@@ -750,25 +750,26 @@ class BaseComplianceReportGenerator(ABC):
|
||||
Returns:
|
||||
Aggregated ComplianceData object
|
||||
"""
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
# Load provider
|
||||
if provider_obj is None:
|
||||
provider_obj = Provider.objects.get(id=provider_id)
|
||||
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
with attempt:
|
||||
# Load provider
|
||||
if provider_obj is None:
|
||||
provider_obj = Provider.objects.get(id=provider_id)
|
||||
|
||||
prowler_provider = initialize_prowler_provider(provider_obj)
|
||||
provider_type = provider_obj.provider
|
||||
prowler_provider = initialize_prowler_provider(provider_obj)
|
||||
provider_type = provider_obj.provider
|
||||
|
||||
# Load compliance framework
|
||||
frameworks_bulk = Compliance.get_bulk(provider_type)
|
||||
compliance_obj = frameworks_bulk.get(compliance_id)
|
||||
# Load compliance framework
|
||||
frameworks_bulk = Compliance.get_bulk(provider_type)
|
||||
compliance_obj = frameworks_bulk.get(compliance_id)
|
||||
|
||||
if not compliance_obj:
|
||||
raise ValueError(f"Compliance framework not found: {compliance_id}")
|
||||
if not compliance_obj:
|
||||
raise ValueError(f"Compliance framework not found: {compliance_id}")
|
||||
|
||||
framework = getattr(compliance_obj, "Framework", "N/A")
|
||||
name = getattr(compliance_obj, "Name", "N/A")
|
||||
version = getattr(compliance_obj, "Version", "N/A")
|
||||
description = getattr(compliance_obj, "Description", "")
|
||||
framework = getattr(compliance_obj, "Framework", "N/A")
|
||||
name = getattr(compliance_obj, "Name", "N/A")
|
||||
version = getattr(compliance_obj, "Version", "N/A")
|
||||
description = getattr(compliance_obj, "Description", "")
|
||||
|
||||
# Aggregate requirement statistics
|
||||
if requirement_statistics is None:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -58,12 +58,13 @@ def compute_threatscore_metrics(
|
||||
>>> print(f"Overall ThreatScore: {metrics['overall_score']:.2f}%")
|
||||
"""
|
||||
# Get provider and compliance information
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
provider_obj = Provider.objects.get(id=provider_id)
|
||||
provider_type = provider_obj.provider
|
||||
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
with attempt:
|
||||
provider_obj = Provider.objects.get(id=provider_id)
|
||||
provider_type = provider_obj.provider
|
||||
|
||||
frameworks_bulk = Compliance.get_bulk(provider_type)
|
||||
compliance_obj = frameworks_bulk[compliance_id]
|
||||
frameworks_bulk = Compliance.get_bulk(provider_type)
|
||||
compliance_obj = frameworks_bulk[compliance_id]
|
||||
|
||||
# Aggregate requirement statistics from database
|
||||
requirement_statistics_by_check_id = (
|
||||
|
||||
@@ -36,35 +36,36 @@ def _aggregate_requirement_statistics_from_database(
|
||||
"""
|
||||
requirement_statistics_by_check_id = {}
|
||||
# TODO: take into account that now the relation is 1 finding == 1 resource, review this when the logic changes
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
aggregated_statistics_queryset = (
|
||||
Finding.all_objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
scan_id=scan_id,
|
||||
muted=False,
|
||||
resources__provider__is_deleted=False,
|
||||
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
with attempt:
|
||||
aggregated_statistics_queryset = (
|
||||
Finding.all_objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
scan_id=scan_id,
|
||||
muted=False,
|
||||
resources__provider__is_deleted=False,
|
||||
)
|
||||
.values("check_id")
|
||||
.annotate(
|
||||
total_findings=Count(
|
||||
"id",
|
||||
distinct=True,
|
||||
filter=Q(status__in=[StatusChoices.PASS, StatusChoices.FAIL]),
|
||||
),
|
||||
passed_findings=Count(
|
||||
"id",
|
||||
distinct=True,
|
||||
filter=Q(status=StatusChoices.PASS),
|
||||
),
|
||||
)
|
||||
)
|
||||
.values("check_id")
|
||||
.annotate(
|
||||
total_findings=Count(
|
||||
"id",
|
||||
distinct=True,
|
||||
filter=Q(status__in=[StatusChoices.PASS, StatusChoices.FAIL]),
|
||||
),
|
||||
passed_findings=Count(
|
||||
"id",
|
||||
distinct=True,
|
||||
filter=Q(status=StatusChoices.PASS),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
for aggregated_stat in aggregated_statistics_queryset:
|
||||
check_id = aggregated_stat["check_id"]
|
||||
requirement_statistics_by_check_id[check_id] = {
|
||||
"passed": aggregated_stat["passed_findings"],
|
||||
"total": aggregated_stat["total_findings"],
|
||||
}
|
||||
for aggregated_stat in aggregated_statistics_queryset:
|
||||
check_id = aggregated_stat["check_id"]
|
||||
requirement_statistics_by_check_id[check_id] = {
|
||||
"passed": aggregated_stat["passed_findings"],
|
||||
"total": aggregated_stat["total_findings"],
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Aggregated statistics for {len(requirement_statistics_by_check_id)} unique checks"
|
||||
@@ -220,35 +221,36 @@ def _load_findings_for_requirement_checks(
|
||||
f"Loading findings for {len(check_ids_to_load)} checks from database"
|
||||
)
|
||||
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
# Use iterator with chunk_size for memory-efficient streaming
|
||||
# chunk_size controls how many rows Django fetches from DB at once
|
||||
findings_queryset = (
|
||||
Finding.all_objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
scan_id=scan_id,
|
||||
check_id__in=check_ids_to_load,
|
||||
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
with attempt:
|
||||
# Use iterator with chunk_size for memory-efficient streaming
|
||||
# chunk_size controls how many rows Django fetches from DB at once
|
||||
findings_queryset = (
|
||||
Finding.all_objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
scan_id=scan_id,
|
||||
check_id__in=check_ids_to_load,
|
||||
)
|
||||
.order_by("check_id", "uid")
|
||||
.iterator(chunk_size=DJANGO_FINDINGS_BATCH_SIZE)
|
||||
)
|
||||
.order_by("check_id", "uid")
|
||||
.iterator(chunk_size=DJANGO_FINDINGS_BATCH_SIZE)
|
||||
)
|
||||
|
||||
# Pre-initialize empty lists for all check_ids to load
|
||||
# This avoids repeated dict lookups and 'if not in' checks
|
||||
for check_id in check_ids_to_load:
|
||||
findings_cache[check_id] = []
|
||||
# Pre-initialize empty lists for all check_ids to load
|
||||
# This avoids repeated dict lookups and 'if not in' checks
|
||||
for check_id in check_ids_to_load:
|
||||
findings_cache[check_id] = []
|
||||
|
||||
findings_count = 0
|
||||
for finding_model in findings_queryset:
|
||||
finding_output = FindingOutput.transform_api_finding(
|
||||
finding_model, prowler_provider
|
||||
findings_count = 0
|
||||
for finding_model in findings_queryset:
|
||||
finding_output = FindingOutput.transform_api_finding(
|
||||
finding_model, prowler_provider
|
||||
)
|
||||
findings_cache[finding_output.check_id].append(finding_output)
|
||||
findings_count += 1
|
||||
|
||||
logger.info(
|
||||
f"Loaded {findings_count} findings for {len(check_ids_to_load)} checks"
|
||||
)
|
||||
findings_cache[finding_output.check_id].append(finding_output)
|
||||
findings_count += 1
|
||||
|
||||
logger.info(
|
||||
f"Loaded {findings_count} findings for {len(check_ids_to_load)} checks"
|
||||
)
|
||||
|
||||
# Build result dict using cache references (no data duplication)
|
||||
# This shares the same list objects between cache and result
|
||||
|
||||
@@ -280,58 +280,59 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
|
||||
"""
|
||||
task_id = self.request.id
|
||||
|
||||
with rls_transaction(tenant_id):
|
||||
periodic_task_instance = PeriodicTask.objects.get(
|
||||
name=f"scan-perform-scheduled-{provider_id}"
|
||||
)
|
||||
executing_scan = (
|
||||
Scan.objects.filter(
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
periodic_task_instance = PeriodicTask.objects.get(
|
||||
name=f"scan-perform-scheduled-{provider_id}"
|
||||
)
|
||||
executing_scan = (
|
||||
Scan.objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
trigger=Scan.TriggerChoices.SCHEDULED,
|
||||
state=StateChoices.EXECUTING,
|
||||
)
|
||||
.order_by("-started_at")
|
||||
.first()
|
||||
)
|
||||
if executing_scan:
|
||||
logger.warning(
|
||||
f"Scheduled scan already executing for provider {provider_id}. Skipping."
|
||||
)
|
||||
return ScanTaskSerializer(instance=executing_scan).data
|
||||
|
||||
executed_scan = Scan.objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
trigger=Scan.TriggerChoices.SCHEDULED,
|
||||
state=StateChoices.EXECUTING,
|
||||
task__task_runner_task__task_id=task_id,
|
||||
).first()
|
||||
|
||||
if executed_scan:
|
||||
# Duplicated task execution due to visibility timeout
|
||||
logger.warning(f"Duplicated scheduled scan for provider {provider_id}.")
|
||||
return ScanTaskSerializer(instance=executed_scan).data
|
||||
|
||||
interval = periodic_task_instance.interval
|
||||
next_scan_datetime = get_next_execution_datetime(task_id, provider_id)
|
||||
current_scan_datetime = next_scan_datetime - timedelta(
|
||||
**{interval.period: interval.every}
|
||||
)
|
||||
.order_by("-started_at")
|
||||
.first()
|
||||
)
|
||||
if executing_scan:
|
||||
logger.warning(
|
||||
f"Scheduled scan already executing for provider {provider_id}. Skipping."
|
||||
|
||||
# TEMPORARY WORKAROUND: Clean up orphan scans from transaction isolation issue
|
||||
_cleanup_orphan_scheduled_scans(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
scheduler_task_id=periodic_task_instance.id,
|
||||
)
|
||||
return ScanTaskSerializer(instance=executing_scan).data
|
||||
|
||||
executed_scan = Scan.objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
task__task_runner_task__task_id=task_id,
|
||||
).first()
|
||||
|
||||
if executed_scan:
|
||||
# Duplicated task execution due to visibility timeout
|
||||
logger.warning(f"Duplicated scheduled scan for provider {provider_id}.")
|
||||
return ScanTaskSerializer(instance=executed_scan).data
|
||||
|
||||
interval = periodic_task_instance.interval
|
||||
next_scan_datetime = get_next_execution_datetime(task_id, provider_id)
|
||||
current_scan_datetime = next_scan_datetime - timedelta(
|
||||
**{interval.period: interval.every}
|
||||
)
|
||||
|
||||
# TEMPORARY WORKAROUND: Clean up orphan scans from transaction isolation issue
|
||||
_cleanup_orphan_scheduled_scans(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
scheduler_task_id=periodic_task_instance.id,
|
||||
)
|
||||
|
||||
scan_instance = _get_or_create_scheduled_scan(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
scheduler_task_id=periodic_task_instance.id,
|
||||
scheduled_at=current_scan_datetime,
|
||||
)
|
||||
scan_instance.task_id = task_id
|
||||
scan_instance.save()
|
||||
scan_instance = _get_or_create_scheduled_scan(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
scheduler_task_id=periodic_task_instance.id,
|
||||
scheduled_at=current_scan_datetime,
|
||||
)
|
||||
scan_instance.task_id = task_id
|
||||
scan_instance.save()
|
||||
|
||||
try:
|
||||
result = perform_prowler_scan(
|
||||
@@ -340,19 +341,20 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
|
||||
provider_id=provider_id,
|
||||
)
|
||||
finally:
|
||||
with rls_transaction(tenant_id):
|
||||
now = datetime.now(timezone.utc)
|
||||
if next_scan_datetime <= now:
|
||||
interval_delta = timedelta(**{interval.period: interval.every})
|
||||
while next_scan_datetime <= now:
|
||||
next_scan_datetime += interval_delta
|
||||
_get_or_create_scheduled_scan(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
scheduler_task_id=periodic_task_instance.id,
|
||||
scheduled_at=next_scan_datetime,
|
||||
update_state=True,
|
||||
)
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
now = datetime.now(timezone.utc)
|
||||
if next_scan_datetime <= now:
|
||||
interval_delta = timedelta(**{interval.period: interval.every})
|
||||
while next_scan_datetime <= now:
|
||||
next_scan_datetime += interval_delta
|
||||
_get_or_create_scheduled_scan(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
scheduler_task_id=periodic_task_instance.id,
|
||||
scheduled_at=next_scan_datetime,
|
||||
update_state=True,
|
||||
)
|
||||
|
||||
_perform_scan_complete_tasks(tenant_id, str(scan_instance.id), provider_id)
|
||||
|
||||
@@ -486,67 +488,69 @@ def generate_outputs_task(scan_id: str, provider_id: str, tenant_id: str):
|
||||
.order_by("uid")
|
||||
.iterator()
|
||||
)
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
for batch, is_last in batched(qs, DJANGO_FINDINGS_BATCH_SIZE):
|
||||
fos = [
|
||||
FindingOutput.transform_api_finding(f, prowler_provider) for f in batch
|
||||
]
|
||||
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
with attempt:
|
||||
for batch, is_last in batched(qs, DJANGO_FINDINGS_BATCH_SIZE):
|
||||
fos = [
|
||||
FindingOutput.transform_api_finding(f, prowler_provider)
|
||||
for f in batch
|
||||
]
|
||||
|
||||
# Outputs
|
||||
for mode, cfg in OUTPUT_FORMATS_MAPPING.items():
|
||||
# Skip ASFF generation if not needed
|
||||
if mode == "json-asff" and not generate_asff:
|
||||
continue
|
||||
# Outputs
|
||||
for mode, cfg in OUTPUT_FORMATS_MAPPING.items():
|
||||
# Skip ASFF generation if not needed
|
||||
if mode == "json-asff" and not generate_asff:
|
||||
continue
|
||||
|
||||
cls = cfg["class"]
|
||||
suffix = cfg["suffix"]
|
||||
extra = cfg.get("kwargs", {}).copy()
|
||||
if mode == "html":
|
||||
extra.update(provider=prowler_provider, stats=scan_summary)
|
||||
cls = cfg["class"]
|
||||
suffix = cfg["suffix"]
|
||||
extra = cfg.get("kwargs", {}).copy()
|
||||
if mode == "html":
|
||||
extra.update(provider=prowler_provider, stats=scan_summary)
|
||||
|
||||
writer, initialization = get_writer(
|
||||
output_writers,
|
||||
cls,
|
||||
lambda cls=cls, fos=fos, suffix=suffix: cls(
|
||||
findings=fos,
|
||||
file_path=out_dir,
|
||||
file_extension=suffix,
|
||||
from_cli=False,
|
||||
),
|
||||
is_last,
|
||||
)
|
||||
if not initialization:
|
||||
writer.transform(fos)
|
||||
writer.batch_write_data_to_file(**extra)
|
||||
writer._data.clear()
|
||||
writer, initialization = get_writer(
|
||||
output_writers,
|
||||
cls,
|
||||
lambda cls=cls, fos=fos, suffix=suffix: cls(
|
||||
findings=fos,
|
||||
file_path=out_dir,
|
||||
file_extension=suffix,
|
||||
from_cli=False,
|
||||
),
|
||||
is_last,
|
||||
)
|
||||
if not initialization:
|
||||
writer.transform(fos)
|
||||
writer.batch_write_data_to_file(**extra)
|
||||
writer._data.clear()
|
||||
|
||||
# Compliance CSVs
|
||||
for name in frameworks_avail:
|
||||
compliance_obj = frameworks_bulk[name]
|
||||
# Compliance CSVs
|
||||
for name in frameworks_avail:
|
||||
compliance_obj = frameworks_bulk[name]
|
||||
|
||||
klass = GenericCompliance
|
||||
for condition, cls in COMPLIANCE_CLASS_MAP.get(provider_type, []):
|
||||
if condition(name):
|
||||
klass = cls
|
||||
break
|
||||
klass = GenericCompliance
|
||||
for condition, cls in COMPLIANCE_CLASS_MAP.get(provider_type, []):
|
||||
if condition(name):
|
||||
klass = cls
|
||||
break
|
||||
|
||||
filename = f"{comp_dir}_{name}.csv"
|
||||
filename = f"{comp_dir}_{name}.csv"
|
||||
|
||||
writer, initialization = get_writer(
|
||||
compliance_writers,
|
||||
name,
|
||||
lambda klass=klass, fos=fos: klass(
|
||||
findings=fos,
|
||||
compliance=compliance_obj,
|
||||
file_path=filename,
|
||||
from_cli=False,
|
||||
),
|
||||
is_last,
|
||||
)
|
||||
if not initialization:
|
||||
writer.transform(fos, compliance_obj, name)
|
||||
writer.batch_write_data_to_file()
|
||||
writer._data.clear()
|
||||
writer, initialization = get_writer(
|
||||
compliance_writers,
|
||||
name,
|
||||
lambda klass=klass, fos=fos: klass(
|
||||
findings=fos,
|
||||
compliance=compliance_obj,
|
||||
file_path=filename,
|
||||
from_cli=False,
|
||||
),
|
||||
is_last,
|
||||
)
|
||||
if not initialization:
|
||||
writer.transform(fos, compliance_obj, name)
|
||||
writer.batch_write_data_to_file()
|
||||
writer._data.clear()
|
||||
|
||||
compressed = _compress_output_files(out_dir)
|
||||
|
||||
@@ -569,12 +573,13 @@ def generate_outputs_task(scan_id: str, provider_id: str, tenant_id: str):
|
||||
)
|
||||
|
||||
# S3 integrations (need output_directory)
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
s3_integrations = Integration.objects.filter(
|
||||
integrationproviderrelationship__provider_id=provider_id,
|
||||
integration_type=Integration.IntegrationChoices.AMAZON_S3,
|
||||
enabled=True,
|
||||
)
|
||||
for attempt in rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
with attempt:
|
||||
s3_integrations = Integration.objects.filter(
|
||||
integrationproviderrelationship__provider_id=provider_id,
|
||||
integration_type=Integration.IntegrationChoices.AMAZON_S3,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
if s3_integrations:
|
||||
# Pass the output directory path to S3 integration task to reconstruct objects from files
|
||||
@@ -812,27 +817,32 @@ def check_integrations_task(tenant_id: str, provider_id: str, scan_id: str = Non
|
||||
|
||||
try:
|
||||
integration_tasks = []
|
||||
with rls_transaction(tenant_id):
|
||||
integrations = Integration.objects.filter(
|
||||
integrationproviderrelationship__provider_id=provider_id,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
if not integrations.exists():
|
||||
logger.info(f"No integrations configured for provider {provider_id}")
|
||||
return {"integrations_processed": 0}
|
||||
|
||||
# Security Hub integration
|
||||
security_hub_integrations = integrations.filter(
|
||||
integration_type=Integration.IntegrationChoices.AWS_SECURITY_HUB
|
||||
)
|
||||
if security_hub_integrations.exists():
|
||||
integration_tasks.append(
|
||||
security_hub_integration_task.s(
|
||||
tenant_id=tenant_id, provider_id=provider_id, scan_id=scan_id
|
||||
)
|
||||
for attempt in rls_transaction(tenant_id):
|
||||
with attempt:
|
||||
integrations = Integration.objects.filter(
|
||||
integrationproviderrelationship__provider_id=provider_id,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
if not integrations.exists():
|
||||
logger.info(
|
||||
f"No integrations configured for provider {provider_id}"
|
||||
)
|
||||
return {"integrations_processed": 0}
|
||||
|
||||
# Security Hub integration
|
||||
security_hub_integrations = integrations.filter(
|
||||
integration_type=Integration.IntegrationChoices.AWS_SECURITY_HUB
|
||||
)
|
||||
if security_hub_integrations.exists():
|
||||
integration_tasks.append(
|
||||
security_hub_integration_task.s(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
scan_id=scan_id,
|
||||
)
|
||||
)
|
||||
|
||||
# TODO: Add other integration types here
|
||||
# slack_integrations = integrations.filter(
|
||||
# integration_type=Integration.IntegrationChoices.SLACK
|
||||
|
||||
@@ -55,7 +55,7 @@ class TestAttackPathsRun:
|
||||
)
|
||||
@patch(
|
||||
"tasks.jobs.attack_paths.scan.rls_transaction",
|
||||
new=lambda *args, **kwargs: nullcontext(),
|
||||
new=lambda *args, **kwargs: [nullcontext()],
|
||||
)
|
||||
def test_run_success_flow(
|
||||
self,
|
||||
@@ -201,7 +201,7 @@ class TestAttackPathsRun:
|
||||
)
|
||||
@patch(
|
||||
"tasks.jobs.attack_paths.scan.rls_transaction",
|
||||
new=lambda *args, **kwargs: nullcontext(),
|
||||
new=lambda *args, **kwargs: [nullcontext()],
|
||||
)
|
||||
def test_run_failure_marks_scan_failed(
|
||||
self,
|
||||
@@ -300,7 +300,7 @@ class TestAttackPathsRun:
|
||||
)
|
||||
@patch(
|
||||
"tasks.jobs.attack_paths.scan.rls_transaction",
|
||||
new=lambda *args, **kwargs: nullcontext(),
|
||||
new=lambda *args, **kwargs: [nullcontext()],
|
||||
)
|
||||
def test_failure_before_gate_does_not_flip_graph_data_ready_true(
|
||||
self,
|
||||
@@ -403,7 +403,7 @@ class TestAttackPathsRun:
|
||||
)
|
||||
@patch(
|
||||
"tasks.jobs.attack_paths.scan.rls_transaction",
|
||||
new=lambda *args, **kwargs: nullcontext(),
|
||||
new=lambda *args, **kwargs: [nullcontext()],
|
||||
)
|
||||
def test_run_failure_marks_scan_failed_even_when_drop_database_fails(
|
||||
self,
|
||||
@@ -509,7 +509,7 @@ class TestAttackPathsRun:
|
||||
)
|
||||
@patch(
|
||||
"tasks.jobs.attack_paths.scan.rls_transaction",
|
||||
new=lambda *args, **kwargs: nullcontext(),
|
||||
new=lambda *args, **kwargs: [nullcontext()],
|
||||
)
|
||||
def test_failure_after_gate_before_drop_restores_graph_data_ready(
|
||||
self,
|
||||
@@ -622,7 +622,7 @@ class TestAttackPathsRun:
|
||||
)
|
||||
@patch(
|
||||
"tasks.jobs.attack_paths.scan.rls_transaction",
|
||||
new=lambda *args, **kwargs: nullcontext(),
|
||||
new=lambda *args, **kwargs: [nullcontext()],
|
||||
)
|
||||
def test_failure_after_drop_before_sync_leaves_graph_data_ready_false(
|
||||
self,
|
||||
@@ -735,7 +735,7 @@ class TestAttackPathsRun:
|
||||
)
|
||||
@patch(
|
||||
"tasks.jobs.attack_paths.scan.rls_transaction",
|
||||
new=lambda *args, **kwargs: nullcontext(),
|
||||
new=lambda *args, **kwargs: [nullcontext()],
|
||||
)
|
||||
def test_failure_after_sync_restores_graph_data_ready(
|
||||
self,
|
||||
@@ -853,7 +853,7 @@ class TestAttackPathsRun:
|
||||
)
|
||||
@patch(
|
||||
"tasks.jobs.attack_paths.scan.rls_transaction",
|
||||
new=lambda *args, **kwargs: nullcontext(),
|
||||
new=lambda *args, **kwargs: [nullcontext()],
|
||||
)
|
||||
def test_recovery_failure_does_not_suppress_original_exception(
|
||||
self,
|
||||
@@ -949,7 +949,7 @@ class TestAttackPathsRun:
|
||||
with (
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.scan.rls_transaction",
|
||||
new=lambda *args, **kwargs: nullcontext(),
|
||||
new=lambda *args, **kwargs: [nullcontext()],
|
||||
),
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.scan.initialize_prowler_provider",
|
||||
@@ -1411,7 +1411,7 @@ class TestAttackPathsFindingsHelpers:
|
||||
with (
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.findings.rls_transaction",
|
||||
new=lambda *args, **kwargs: nullcontext(),
|
||||
new=lambda *args, **kwargs: [nullcontext()],
|
||||
),
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.findings.READ_REPLICA_ALIAS",
|
||||
@@ -2024,7 +2024,7 @@ class TestAttackPathsDbUtilsGraphDataReady:
|
||||
|
||||
with patch(
|
||||
"tasks.jobs.attack_paths.db_utils.rls_transaction",
|
||||
new=lambda *args, **kwargs: nullcontext(),
|
||||
new=lambda *args, **kwargs: [nullcontext()],
|
||||
):
|
||||
attack_paths_scan = create_attack_paths_scan(
|
||||
str(tenant.id), str(scan.id), provider.id
|
||||
@@ -2064,7 +2064,7 @@ class TestAttackPathsDbUtilsGraphDataReady:
|
||||
|
||||
with patch(
|
||||
"tasks.jobs.attack_paths.db_utils.rls_transaction",
|
||||
new=lambda *args, **kwargs: nullcontext(),
|
||||
new=lambda *args, **kwargs: [nullcontext()],
|
||||
):
|
||||
attack_paths_scan = create_attack_paths_scan(
|
||||
str(tenant.id), str(new_scan.id), provider.id
|
||||
@@ -2104,7 +2104,7 @@ class TestAttackPathsDbUtilsGraphDataReady:
|
||||
|
||||
with patch(
|
||||
"tasks.jobs.attack_paths.db_utils.rls_transaction",
|
||||
new=lambda *args, **kwargs: nullcontext(),
|
||||
new=lambda *args, **kwargs: [nullcontext()],
|
||||
):
|
||||
attack_paths_scan = create_attack_paths_scan(
|
||||
str(tenant.id), str(new_scan.id), provider.id
|
||||
@@ -2136,7 +2136,7 @@ class TestAttackPathsDbUtilsGraphDataReady:
|
||||
|
||||
with patch(
|
||||
"tasks.jobs.attack_paths.db_utils.rls_transaction",
|
||||
new=lambda *args, **kwargs: nullcontext(),
|
||||
new=lambda *args, **kwargs: [nullcontext()],
|
||||
):
|
||||
set_graph_data_ready(attack_paths_scan, False)
|
||||
|
||||
@@ -2145,7 +2145,7 @@ class TestAttackPathsDbUtilsGraphDataReady:
|
||||
|
||||
with patch(
|
||||
"tasks.jobs.attack_paths.db_utils.rls_transaction",
|
||||
new=lambda *args, **kwargs: nullcontext(),
|
||||
new=lambda *args, **kwargs: [nullcontext()],
|
||||
):
|
||||
set_graph_data_ready(attack_paths_scan, True)
|
||||
|
||||
@@ -2175,7 +2175,7 @@ class TestAttackPathsDbUtilsGraphDataReady:
|
||||
|
||||
with patch(
|
||||
"tasks.jobs.attack_paths.db_utils.rls_transaction",
|
||||
new=lambda *args, **kwargs: nullcontext(),
|
||||
new=lambda *args, **kwargs: [nullcontext()],
|
||||
):
|
||||
finish_attack_paths_scan(attack_paths_scan, StateChoices.COMPLETED, {})
|
||||
|
||||
@@ -2206,7 +2206,7 @@ class TestAttackPathsDbUtilsGraphDataReady:
|
||||
|
||||
with patch(
|
||||
"tasks.jobs.attack_paths.db_utils.rls_transaction",
|
||||
new=lambda *args, **kwargs: nullcontext(),
|
||||
new=lambda *args, **kwargs: [nullcontext()],
|
||||
):
|
||||
finish_attack_paths_scan(
|
||||
attack_paths_scan,
|
||||
@@ -2257,7 +2257,7 @@ class TestAttackPathsDbUtilsGraphDataReady:
|
||||
|
||||
with patch(
|
||||
"tasks.jobs.attack_paths.db_utils.rls_transaction",
|
||||
new=lambda *args, **kwargs: nullcontext(),
|
||||
new=lambda *args, **kwargs: [nullcontext()],
|
||||
):
|
||||
set_provider_graph_data_ready(new_ap_scan, False)
|
||||
|
||||
@@ -2309,7 +2309,7 @@ class TestAttackPathsDbUtilsGraphDataReady:
|
||||
|
||||
with patch(
|
||||
"tasks.jobs.attack_paths.db_utils.rls_transaction",
|
||||
new=lambda *args, **kwargs: nullcontext(),
|
||||
new=lambda *args, **kwargs: [nullcontext()],
|
||||
):
|
||||
set_provider_graph_data_ready(ap_scan_a, False)
|
||||
|
||||
|
||||
@@ -159,9 +159,8 @@ class TestOutputs:
|
||||
mock_scan_instance.started_at = datetime(2023, 6, 15, 10, 30, 45)
|
||||
mock_scan.objects.get.return_value = mock_scan_instance
|
||||
|
||||
# Mock rls_transaction as a context manager
|
||||
mock_rls_transaction.return_value.__enter__ = MagicMock()
|
||||
mock_rls_transaction.return_value.__exit__ = MagicMock(return_value=False)
|
||||
# Mock rls_transaction as an iterable yielding a context manager
|
||||
mock_rls_transaction.return_value = [MagicMock()]
|
||||
|
||||
base_tmp = Path(str(tmpdir.mkdir("generate_output")))
|
||||
base_dir = str(base_tmp)
|
||||
@@ -210,9 +209,8 @@ class TestOutputs:
|
||||
mock_scan_instance.started_at = datetime(2023, 6, 15, 10, 30, 45)
|
||||
mock_scan.objects.get.return_value = mock_scan_instance
|
||||
|
||||
# Mock rls_transaction as a context manager
|
||||
mock_rls_transaction.return_value.__enter__ = MagicMock()
|
||||
mock_rls_transaction.return_value.__exit__ = MagicMock(return_value=False)
|
||||
# Mock rls_transaction as an iterable yielding a context manager
|
||||
mock_rls_transaction.return_value = [MagicMock()]
|
||||
|
||||
base_tmp = Path(str(tmpdir.mkdir("generate_output")))
|
||||
base_dir = str(base_tmp)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from django.db import OperationalError
|
||||
from tasks.jobs.integrations import (
|
||||
get_s3_client_from_integration,
|
||||
get_security_hub_client_from_integration,
|
||||
@@ -99,6 +98,7 @@ class TestS3IntegrationUploads:
|
||||
):
|
||||
tenant_id = "tenant-id"
|
||||
provider_id = "provider-id"
|
||||
mock_rls.return_value = [MagicMock()]
|
||||
|
||||
integration = MagicMock()
|
||||
integration.id = "i-1"
|
||||
@@ -150,6 +150,7 @@ class TestS3IntegrationUploads:
|
||||
):
|
||||
tenant_id = "tenant-id"
|
||||
provider_id = "provider-id"
|
||||
mock_rls.return_value = [MagicMock()]
|
||||
|
||||
integration = MagicMock()
|
||||
integration.id = "i-1"
|
||||
@@ -176,6 +177,7 @@ class TestS3IntegrationUploads:
|
||||
def test_upload_s3_integration_logs_if_no_integrations(
|
||||
self, mock_logger, mock_integration_model, mock_rls
|
||||
):
|
||||
mock_rls.return_value = [MagicMock()]
|
||||
mock_integration_model.objects.filter.return_value = []
|
||||
output_directory = "/tmp/prowler_output/scan123"
|
||||
result = upload_s3_integration("tenant", "provider", output_directory)
|
||||
@@ -197,6 +199,7 @@ class TestS3IntegrationUploads:
|
||||
):
|
||||
tenant_id = "tenant-id"
|
||||
provider_id = "provider-id"
|
||||
mock_rls.return_value = [MagicMock()]
|
||||
|
||||
integration = MagicMock()
|
||||
integration.id = "i-1"
|
||||
@@ -226,7 +229,7 @@ class TestS3IntegrationUploads:
|
||||
|
||||
# Mock that no enabled integrations are found
|
||||
mock_integration_filter.return_value = []
|
||||
mock_rls.return_value.__enter__.return_value = None
|
||||
mock_rls.return_value = [MagicMock()]
|
||||
|
||||
result = upload_s3_integration(tenant_id, provider_id, output_directory)
|
||||
|
||||
@@ -784,7 +787,7 @@ class TestSecurityHubIntegrationUploads:
|
||||
mock_findings = [{"finding": "test"}]
|
||||
|
||||
# Mock RLS context manager
|
||||
mock_rls.return_value.__enter__.return_value = None
|
||||
mock_rls.return_value = [MagicMock()]
|
||||
|
||||
connected, connection = get_security_hub_client_from_integration(
|
||||
mock_integration, tenant_id, mock_findings
|
||||
@@ -857,7 +860,7 @@ class TestSecurityHubIntegrationUploads:
|
||||
mock_findings = [{"finding": "test1"}, {"finding": "test2"}]
|
||||
|
||||
# Mock RLS context manager
|
||||
mock_rls.return_value.__enter__.return_value = None
|
||||
mock_rls.return_value = [MagicMock()]
|
||||
|
||||
# Call the function
|
||||
connected, connection = get_security_hub_client_from_integration(
|
||||
@@ -999,6 +1002,7 @@ class TestSecurityHubIntegrationUploads:
|
||||
tenant_id = "tenant-id"
|
||||
provider_id = "provider-id"
|
||||
scan_id = "scan-123"
|
||||
mock_rls.return_value = [MagicMock()]
|
||||
|
||||
# Mock integration
|
||||
integration = MagicMock()
|
||||
@@ -1057,84 +1061,6 @@ class TestSecurityHubIntegrationUploads:
|
||||
mock_security_hub.batch_send_to_security_hub.assert_called_once()
|
||||
mock_security_hub.archive_previous_findings.assert_called_once()
|
||||
|
||||
@patch("tasks.jobs.integrations.time.sleep")
|
||||
@patch("tasks.jobs.integrations.batched")
|
||||
@patch("tasks.jobs.integrations.get_security_hub_client_from_integration")
|
||||
@patch("tasks.jobs.integrations.initialize_prowler_provider")
|
||||
@patch("tasks.jobs.integrations.rls_transaction")
|
||||
@patch("tasks.jobs.integrations.Integration")
|
||||
@patch("tasks.jobs.integrations.Provider")
|
||||
@patch("tasks.jobs.integrations.Finding")
|
||||
def test_upload_security_hub_integration_retries_on_operational_error(
|
||||
self,
|
||||
mock_finding_model,
|
||||
mock_provider_model,
|
||||
mock_integration_model,
|
||||
mock_rls,
|
||||
mock_initialize_provider,
|
||||
mock_get_security_hub,
|
||||
mock_batched,
|
||||
mock_sleep,
|
||||
):
|
||||
"""Test SecurityHub upload retries on transient OperationalError."""
|
||||
tenant_id = "tenant-id"
|
||||
provider_id = "provider-id"
|
||||
scan_id = "scan-123"
|
||||
|
||||
integration = MagicMock()
|
||||
integration.id = "integration-1"
|
||||
integration.configuration = {
|
||||
"send_only_fails": True,
|
||||
"archive_previous_findings": False,
|
||||
}
|
||||
mock_integration_model.objects.filter.return_value = [integration]
|
||||
|
||||
provider = MagicMock()
|
||||
mock_provider_model.objects.get.return_value = provider
|
||||
|
||||
mock_prowler_provider = MagicMock()
|
||||
mock_initialize_provider.return_value = mock_prowler_provider
|
||||
|
||||
mock_findings = [MagicMock(), MagicMock()]
|
||||
mock_finding_model.all_objects.filter.return_value.order_by.return_value.iterator.return_value = iter(
|
||||
mock_findings
|
||||
)
|
||||
|
||||
transformed_findings = [MagicMock(), MagicMock()]
|
||||
with patch("tasks.jobs.integrations.FindingOutput") as mock_finding_output:
|
||||
mock_finding_output.transform_api_finding.side_effect = transformed_findings
|
||||
|
||||
with patch("tasks.jobs.integrations.ASFF") as mock_asff:
|
||||
mock_asff_instance = MagicMock()
|
||||
finding1 = MagicMock()
|
||||
finding1.Compliance.Status = "FAILED"
|
||||
finding2 = MagicMock()
|
||||
finding2.Compliance.Status = "FAILED"
|
||||
mock_asff_instance.data = [finding1, finding2]
|
||||
mock_asff_instance._data = MagicMock()
|
||||
mock_asff.return_value = mock_asff_instance
|
||||
|
||||
mock_security_hub = MagicMock()
|
||||
mock_security_hub.batch_send_to_security_hub.return_value = 2
|
||||
mock_get_security_hub.return_value = (True, mock_security_hub)
|
||||
|
||||
mock_rls.return_value.__enter__.return_value = None
|
||||
mock_rls.return_value.__exit__.return_value = False
|
||||
|
||||
mock_batched.side_effect = [
|
||||
OperationalError("Conflict with recovery"),
|
||||
[(mock_findings, None)],
|
||||
]
|
||||
|
||||
with patch("tasks.jobs.integrations.REPLICA_MAX_ATTEMPTS", 2):
|
||||
with patch("tasks.jobs.integrations.READ_REPLICA_ALIAS", "replica"):
|
||||
result = upload_security_hub_integration(
|
||||
tenant_id, provider_id, scan_id
|
||||
)
|
||||
|
||||
assert result is True
|
||||
mock_sleep.assert_called_once()
|
||||
|
||||
@patch("tasks.jobs.integrations.get_security_hub_client_from_integration")
|
||||
@patch("tasks.jobs.integrations.initialize_prowler_provider")
|
||||
@patch("tasks.jobs.integrations.rls_transaction")
|
||||
@@ -1154,6 +1080,7 @@ class TestSecurityHubIntegrationUploads:
|
||||
tenant_id = "tenant-id"
|
||||
provider_id = "provider-id"
|
||||
scan_id = "scan-123"
|
||||
mock_rls.return_value = [MagicMock()]
|
||||
|
||||
# Mock no integrations found
|
||||
mock_integration_model.objects.filter.return_value = []
|
||||
@@ -1183,6 +1110,7 @@ class TestSecurityHubIntegrationUploads:
|
||||
tenant_id = "tenant-id"
|
||||
provider_id = "provider-id"
|
||||
scan_id = "scan-123"
|
||||
mock_rls.return_value = [MagicMock()]
|
||||
|
||||
# Mock integration
|
||||
integration = MagicMock()
|
||||
@@ -1226,6 +1154,7 @@ class TestSecurityHubIntegrationUploads:
|
||||
tenant_id = "tenant-id"
|
||||
provider_id = "provider-id"
|
||||
scan_id = "scan-123"
|
||||
mock_rls.return_value = [MagicMock()]
|
||||
|
||||
# Mock integration
|
||||
integration = MagicMock()
|
||||
@@ -1300,6 +1229,7 @@ class TestSecurityHubIntegrationUploads:
|
||||
tenant_id = "tenant-id"
|
||||
provider_id = "provider-id"
|
||||
scan_id = "scan-123"
|
||||
mock_rls.return_value = [MagicMock()]
|
||||
|
||||
# Mock integration with archive_previous_findings disabled
|
||||
integration = MagicMock()
|
||||
@@ -1375,6 +1305,7 @@ class TestSecurityHubIntegrationUploads:
|
||||
tenant_id = "tenant-id"
|
||||
provider_id = "provider-id"
|
||||
scan_id = "scan-123"
|
||||
mock_rls.return_value = [MagicMock()]
|
||||
|
||||
# Mock integration
|
||||
integration = MagicMock()
|
||||
@@ -1443,6 +1374,7 @@ class TestSecurityHubIntegrationUploads:
|
||||
tenant_id = "tenant-id"
|
||||
provider_id = "provider-id"
|
||||
scan_id = "scan-123"
|
||||
mock_rls.return_value = [MagicMock()]
|
||||
|
||||
# Mock exception during integration retrieval
|
||||
mock_integration_model.objects.filter.side_effect = Exception("Database error")
|
||||
@@ -1476,6 +1408,7 @@ class TestSecurityHubIntegrationUploads:
|
||||
tenant_id = "tenant-id"
|
||||
provider_id = "provider-id"
|
||||
scan_id = "scan-123"
|
||||
mock_rls.return_value = [MagicMock()]
|
||||
|
||||
# Mock integration with send_only_fails=True
|
||||
integration = MagicMock()
|
||||
@@ -1570,6 +1503,7 @@ class TestSecurityHubIntegrationUploads:
|
||||
tenant_id = "tenant-id"
|
||||
provider_id = "provider-id"
|
||||
scan_id = "scan-123"
|
||||
mock_rls.return_value = [MagicMock()]
|
||||
|
||||
# Mock integration with send_only_fails=False
|
||||
integration = MagicMock()
|
||||
@@ -1654,9 +1588,8 @@ class TestJiraIntegration:
|
||||
issue_type = "Task"
|
||||
finding_ids = ["finding-1", "finding-2"]
|
||||
|
||||
# Mock RLS transaction
|
||||
mock_rls_transaction.return_value.__enter__ = MagicMock()
|
||||
mock_rls_transaction.return_value.__exit__ = MagicMock()
|
||||
# Mock RLS transaction as an iterable yielding a context manager
|
||||
mock_rls_transaction.return_value = [MagicMock()]
|
||||
|
||||
# Mock integration
|
||||
integration = MagicMock()
|
||||
@@ -1786,9 +1719,8 @@ class TestJiraIntegration:
|
||||
issue_type = "Task"
|
||||
finding_ids = ["finding-1", "finding-2", "finding-3"]
|
||||
|
||||
# Mock RLS transaction
|
||||
mock_rls_transaction.return_value.__enter__ = MagicMock()
|
||||
mock_rls_transaction.return_value.__exit__ = MagicMock()
|
||||
# Mock RLS transaction as an iterable yielding a context manager
|
||||
mock_rls_transaction.return_value = [MagicMock()]
|
||||
|
||||
# Mock integration
|
||||
integration = MagicMock()
|
||||
@@ -1856,9 +1788,8 @@ class TestJiraIntegration:
|
||||
issue_type = "Task"
|
||||
finding_ids = ["finding-1"]
|
||||
|
||||
# Mock RLS transaction
|
||||
mock_rls_transaction.return_value.__enter__ = MagicMock()
|
||||
mock_rls_transaction.return_value.__exit__ = MagicMock()
|
||||
# Mock RLS transaction as an iterable yielding a context manager
|
||||
mock_rls_transaction.return_value = [MagicMock()]
|
||||
|
||||
# Mock integration
|
||||
integration = MagicMock()
|
||||
@@ -1934,9 +1865,8 @@ class TestJiraIntegration:
|
||||
issue_type = "Task"
|
||||
finding_ids = ["finding-1"]
|
||||
|
||||
# Mock RLS transaction
|
||||
mock_rls_transaction.return_value.__enter__ = MagicMock()
|
||||
mock_rls_transaction.return_value.__exit__ = MagicMock()
|
||||
# Mock RLS transaction as an iterable yielding a context manager
|
||||
mock_rls_transaction.return_value = [MagicMock()]
|
||||
|
||||
# Mock integration
|
||||
integration = MagicMock()
|
||||
|
||||
@@ -43,9 +43,12 @@ from prowler.lib.check.models import Severity
|
||||
from prowler.lib.outputs.finding import Status
|
||||
|
||||
|
||||
@contextmanager
|
||||
def noop_rls_transaction(*args, **kwargs):
|
||||
yield
|
||||
@contextmanager
|
||||
def _ctx():
|
||||
yield
|
||||
|
||||
return [_ctx()]
|
||||
|
||||
|
||||
class FakeFinding:
|
||||
@@ -75,7 +78,7 @@ class TestPerformScan:
|
||||
providers_fixture,
|
||||
):
|
||||
with (
|
||||
patch("api.db_utils.rls_transaction"),
|
||||
patch("api.db_utils.rls_transaction", new=lambda *a, **kw: [MagicMock()]),
|
||||
patch(
|
||||
"tasks.jobs.scan.initialize_prowler_provider"
|
||||
) as mock_initialize_prowler_provider,
|
||||
@@ -241,6 +244,8 @@ class TestPerformScan:
|
||||
scans_fixture,
|
||||
providers_fixture,
|
||||
):
|
||||
mock_rls_transaction.return_value = [MagicMock()]
|
||||
|
||||
tenant = tenants_fixture[0]
|
||||
scan = scans_fixture[0]
|
||||
provider = providers_fixture[0]
|
||||
@@ -282,6 +287,8 @@ class TestPerformScan:
|
||||
mock_get_or_create_resource,
|
||||
mock_get_or_create_tag,
|
||||
):
|
||||
mock_rls_transaction.return_value = [MagicMock()]
|
||||
|
||||
tenant_id = uuid.uuid4()
|
||||
provider_instance = MagicMock()
|
||||
provider_instance.id = "provider123"
|
||||
@@ -331,6 +338,8 @@ class TestPerformScan:
|
||||
mock_get_or_create_resource,
|
||||
mock_get_or_create_tag,
|
||||
):
|
||||
mock_rls_transaction.return_value = [MagicMock()]
|
||||
|
||||
tenant_id = uuid.uuid4()
|
||||
provider_instance = MagicMock()
|
||||
provider_instance.id = "provider456"
|
||||
@@ -388,6 +397,8 @@ class TestPerformScan:
|
||||
mock_get_or_create_resource,
|
||||
mock_get_or_create_tag,
|
||||
):
|
||||
mock_rls_transaction.return_value = [MagicMock()]
|
||||
|
||||
tenant_id = uuid.uuid4()
|
||||
provider_instance = MagicMock()
|
||||
provider_instance.id = "provider456"
|
||||
@@ -438,7 +449,7 @@ class TestPerformScan:
|
||||
):
|
||||
"""Test that failed findings increment the failed_findings_count"""
|
||||
with (
|
||||
patch("api.db_utils.rls_transaction"),
|
||||
patch("api.db_utils.rls_transaction", new=lambda *a, **kw: [MagicMock()]),
|
||||
patch(
|
||||
"tasks.jobs.scan.initialize_prowler_provider"
|
||||
) as mock_initialize_prowler_provider,
|
||||
@@ -516,7 +527,7 @@ class TestPerformScan:
|
||||
):
|
||||
"""Test that multiple FAIL findings on the same resource increment the counter correctly"""
|
||||
with (
|
||||
patch("api.db_utils.rls_transaction"),
|
||||
patch("api.db_utils.rls_transaction", new=lambda *a, **kw: [MagicMock()]),
|
||||
patch(
|
||||
"tasks.jobs.scan.initialize_prowler_provider"
|
||||
) as mock_initialize_prowler_provider,
|
||||
@@ -633,7 +644,7 @@ class TestPerformScan:
|
||||
):
|
||||
"""Test that muted FAIL findings do not increment the failed_findings_count"""
|
||||
with (
|
||||
patch("api.db_utils.rls_transaction"),
|
||||
patch("api.db_utils.rls_transaction", new=lambda *a, **kw: [MagicMock()]),
|
||||
patch(
|
||||
"tasks.jobs.scan.initialize_prowler_provider"
|
||||
) as mock_initialize_prowler_provider,
|
||||
@@ -723,7 +734,7 @@ class TestPerformScan:
|
||||
)
|
||||
|
||||
with (
|
||||
patch("api.db_utils.rls_transaction"),
|
||||
patch("api.db_utils.rls_transaction", new=lambda *a, **kw: [MagicMock()]),
|
||||
patch(
|
||||
"tasks.jobs.scan.initialize_prowler_provider"
|
||||
) as mock_initialize_prowler_provider,
|
||||
@@ -791,7 +802,7 @@ class TestPerformScan:
|
||||
):
|
||||
"""Test active MuteRule mutes findings with correct reason"""
|
||||
with (
|
||||
patch("api.db_utils.rls_transaction"),
|
||||
patch("api.db_utils.rls_transaction", new=lambda *a, **kw: [MagicMock()]),
|
||||
patch(
|
||||
"tasks.jobs.scan.initialize_prowler_provider"
|
||||
) as mock_initialize_prowler_provider,
|
||||
@@ -908,7 +919,7 @@ class TestPerformScan:
|
||||
):
|
||||
"""Test inactive MuteRule does not mute findings"""
|
||||
with (
|
||||
patch("api.db_utils.rls_transaction"),
|
||||
patch("api.db_utils.rls_transaction", new=lambda *a, **kw: [MagicMock()]),
|
||||
patch(
|
||||
"tasks.jobs.scan.initialize_prowler_provider"
|
||||
) as mock_initialize_prowler_provider,
|
||||
@@ -994,7 +1005,7 @@ class TestPerformScan:
|
||||
):
|
||||
"""Test mutelist processor takes precedence over MuteRule"""
|
||||
with (
|
||||
patch("api.db_utils.rls_transaction"),
|
||||
patch("api.db_utils.rls_transaction", new=lambda *a, **kw: [MagicMock()]),
|
||||
patch(
|
||||
"tasks.jobs.scan.initialize_prowler_provider"
|
||||
) as mock_initialize_prowler_provider,
|
||||
@@ -1080,7 +1091,7 @@ class TestPerformScan:
|
||||
):
|
||||
"""Test MuteRule with multiple finding UIDs mutes all findings"""
|
||||
with (
|
||||
patch("api.db_utils.rls_transaction"),
|
||||
patch("api.db_utils.rls_transaction", new=lambda *a, **kw: [MagicMock()]),
|
||||
patch(
|
||||
"tasks.jobs.scan.initialize_prowler_provider"
|
||||
) as mock_initialize_prowler_provider,
|
||||
@@ -1179,7 +1190,7 @@ class TestPerformScan:
|
||||
):
|
||||
"""Test scan continues when MuteRule loading fails"""
|
||||
with (
|
||||
patch("api.db_utils.rls_transaction"),
|
||||
patch("api.db_utils.rls_transaction", new=lambda *a, **kw: [MagicMock()]),
|
||||
patch(
|
||||
"tasks.jobs.scan.initialize_prowler_provider"
|
||||
) as mock_initialize_prowler_provider,
|
||||
@@ -1262,7 +1273,7 @@ class TestPerformScan:
|
||||
):
|
||||
"""Test muted_at timestamp is set correctly for muted findings"""
|
||||
with (
|
||||
patch("api.db_utils.rls_transaction"),
|
||||
patch("api.db_utils.rls_transaction", new=lambda *a, **kw: [MagicMock()]),
|
||||
patch(
|
||||
"tasks.jobs.scan.initialize_prowler_provider"
|
||||
) as mock_initialize_prowler_provider,
|
||||
@@ -2200,9 +2211,7 @@ class TestComplianceRequirementCopy:
|
||||
tenant_id = row["tenant_id"]
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
ctx.__exit__.return_value = False
|
||||
mock_rls_transaction.return_value = ctx
|
||||
mock_rls_transaction.return_value = [ctx]
|
||||
|
||||
_persist_compliance_requirement_rows(tenant_id, [row])
|
||||
|
||||
@@ -2727,9 +2736,7 @@ class TestComplianceRequirementCopy:
|
||||
}
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
ctx.__exit__.return_value = False
|
||||
mock_rls_transaction.return_value = ctx
|
||||
mock_rls_transaction.return_value = [ctx]
|
||||
|
||||
_persist_compliance_requirement_rows(tenant_id, [row])
|
||||
|
||||
@@ -2790,9 +2797,7 @@ class TestComplianceRequirementCopy:
|
||||
]
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
ctx.__exit__.return_value = False
|
||||
mock_rls_transaction.return_value = ctx
|
||||
mock_rls_transaction.return_value = [ctx]
|
||||
|
||||
_persist_compliance_requirement_rows(tenant_id, rows)
|
||||
|
||||
@@ -2854,9 +2859,7 @@ class TestComplianceRequirementCopy:
|
||||
}
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
ctx.__exit__.return_value = False
|
||||
mock_rls_transaction.return_value = ctx
|
||||
mock_rls_transaction.return_value = [ctx]
|
||||
|
||||
_persist_compliance_requirement_rows(tenant_id, [row])
|
||||
|
||||
@@ -2920,9 +2923,7 @@ class TestCreateComplianceSummaries:
|
||||
}
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
ctx.__exit__.return_value = False
|
||||
mock_rls_transaction.return_value = ctx
|
||||
mock_rls_transaction.return_value = [ctx]
|
||||
|
||||
_create_compliance_summaries(tenant_id, scan_id, requirement_statuses)
|
||||
|
||||
@@ -2961,9 +2962,7 @@ class TestCreateComplianceSummaries:
|
||||
requirement_statuses = {}
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
ctx.__exit__.return_value = False
|
||||
mock_rls_transaction.return_value = ctx
|
||||
mock_rls_transaction.return_value = [ctx]
|
||||
|
||||
_create_compliance_summaries(tenant_id, scan_id, requirement_statuses)
|
||||
|
||||
@@ -2985,9 +2984,7 @@ class TestCreateComplianceSummaries:
|
||||
}
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
ctx.__exit__.return_value = False
|
||||
mock_rls_transaction.return_value = ctx
|
||||
mock_rls_transaction.return_value = [ctx]
|
||||
|
||||
_create_compliance_summaries(tenant_id, scan_id, requirement_statuses)
|
||||
|
||||
@@ -3017,9 +3014,7 @@ class TestCreateComplianceSummaries:
|
||||
}
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
ctx.__exit__.return_value = False
|
||||
mock_rls_transaction.return_value = ctx
|
||||
mock_rls_transaction.return_value = [ctx]
|
||||
|
||||
_create_compliance_summaries(tenant_id, scan_id, requirement_statuses)
|
||||
|
||||
@@ -3067,9 +3062,7 @@ class TestCreateComplianceSummaries:
|
||||
}
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
ctx.__exit__.return_value = False
|
||||
mock_rls_transaction.return_value = ctx
|
||||
mock_rls_transaction.return_value = [ctx]
|
||||
|
||||
_create_compliance_summaries(tenant_id, scan_id, requirement_statuses)
|
||||
|
||||
@@ -3157,9 +3150,7 @@ class TestAggregateFindings:
|
||||
scan = scans_fixture[0]
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
ctx.__exit__.return_value = False
|
||||
mock_rls_transaction.return_value = ctx
|
||||
mock_rls_transaction.return_value = [ctx]
|
||||
|
||||
aggregate_findings(str(tenant.id), str(scan.id))
|
||||
|
||||
@@ -3208,9 +3199,7 @@ class TestAggregateFindings:
|
||||
]
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
ctx.__exit__.return_value = False
|
||||
mock_rls_transaction.return_value = ctx
|
||||
mock_rls_transaction.return_value = [ctx]
|
||||
mock_findings_filter.return_value = mock_queryset
|
||||
|
||||
aggregate_findings(tenant_id, scan_id)
|
||||
@@ -3260,9 +3249,7 @@ class TestAggregateFindings:
|
||||
]
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
ctx.__exit__.return_value = False
|
||||
mock_rls_transaction.return_value = ctx
|
||||
mock_rls_transaction.return_value = [ctx]
|
||||
mock_findings_filter.return_value = mock_queryset
|
||||
|
||||
aggregate_findings(tenant_id, scan_id)
|
||||
@@ -3335,9 +3322,7 @@ class TestAggregateFindings:
|
||||
]
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
ctx.__exit__.return_value = False
|
||||
mock_rls_transaction.return_value = ctx
|
||||
mock_rls_transaction.return_value = [ctx]
|
||||
mock_findings_filter.return_value = mock_queryset
|
||||
|
||||
aggregate_findings(tenant_id, scan_id)
|
||||
@@ -3386,9 +3371,7 @@ class TestAggregateFindingsByRegion:
|
||||
mock_queryset.prefetch_related.return_value = [mock_finding1]
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
ctx.__exit__.return_value = False
|
||||
mock_rls_transaction.return_value = ctx
|
||||
mock_rls_transaction.return_value = [ctx]
|
||||
mock_findings_filter.return_value = mock_queryset
|
||||
|
||||
check_status_by_region, findings_count_by_compliance = (
|
||||
@@ -3439,9 +3422,7 @@ class TestAggregateFindingsByRegion:
|
||||
mock_queryset.prefetch_related.return_value = [mock_finding1, mock_finding2]
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
ctx.__exit__.return_value = False
|
||||
mock_rls_transaction.return_value = ctx
|
||||
mock_rls_transaction.return_value = [ctx]
|
||||
mock_findings_filter.return_value = mock_queryset
|
||||
|
||||
check_status_by_region, _ = _aggregate_findings_by_region(
|
||||
@@ -3466,9 +3447,7 @@ class TestAggregateFindingsByRegion:
|
||||
mock_queryset.prefetch_related.return_value = []
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
ctx.__exit__.return_value = False
|
||||
mock_rls_transaction.return_value = ctx
|
||||
mock_rls_transaction.return_value = [ctx]
|
||||
mock_findings_filter.return_value = mock_queryset
|
||||
|
||||
_aggregate_findings_by_region(
|
||||
@@ -3516,9 +3495,7 @@ class TestAggregateFindingsByRegion:
|
||||
mock_queryset.prefetch_related.return_value = [mock_finding1, mock_finding2]
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
ctx.__exit__.return_value = False
|
||||
mock_rls_transaction.return_value = ctx
|
||||
mock_rls_transaction.return_value = [ctx]
|
||||
mock_findings_filter.return_value = mock_queryset
|
||||
|
||||
_, findings_count_by_compliance = _aggregate_findings_by_region(
|
||||
@@ -3570,9 +3547,7 @@ class TestAggregateFindingsByRegion:
|
||||
mock_queryset.prefetch_related.return_value = [mock_finding1, mock_finding2]
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
ctx.__exit__.return_value = False
|
||||
mock_rls_transaction.return_value = ctx
|
||||
mock_rls_transaction.return_value = [ctx]
|
||||
mock_findings_filter.return_value = mock_queryset
|
||||
|
||||
check_status_by_region, _ = _aggregate_findings_by_region(
|
||||
@@ -3600,9 +3575,7 @@ class TestAggregateFindingsByRegion:
|
||||
mock_queryset.prefetch_related.return_value = []
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
ctx.__exit__.return_value = False
|
||||
mock_rls_transaction.return_value = ctx
|
||||
mock_rls_transaction.return_value = [ctx]
|
||||
mock_findings_filter.return_value = mock_queryset
|
||||
|
||||
check_status_by_region, findings_count_by_compliance = (
|
||||
@@ -3715,9 +3688,7 @@ class TestAggregateAttackSurface:
|
||||
]
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
ctx.__exit__.return_value = False
|
||||
mock_rls_transaction.return_value = ctx
|
||||
mock_rls_transaction.return_value = [ctx]
|
||||
mock_findings_filter.return_value = mock_queryset
|
||||
|
||||
aggregate_attack_surface(str(tenant.id), str(scan.id))
|
||||
@@ -3764,9 +3735,7 @@ class TestAggregateAttackSurface:
|
||||
]
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
ctx.__exit__.return_value = False
|
||||
mock_rls_transaction.return_value = ctx
|
||||
mock_rls_transaction.return_value = [ctx]
|
||||
mock_findings_filter.return_value = mock_queryset
|
||||
|
||||
aggregate_attack_surface(str(tenant.id), str(scan.id))
|
||||
@@ -3805,9 +3774,7 @@ class TestAggregateAttackSurface:
|
||||
mock_queryset.annotate.return_value = [] # No findings
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
ctx.__exit__.return_value = False
|
||||
mock_rls_transaction.return_value = ctx
|
||||
mock_rls_transaction.return_value = [ctx]
|
||||
mock_findings_filter.return_value = mock_queryset
|
||||
|
||||
aggregate_attack_surface(str(tenant.id), str(scan.id))
|
||||
@@ -3848,9 +3815,7 @@ class TestAggregateAttackSurface:
|
||||
]
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
ctx.__exit__.return_value = False
|
||||
mock_rls_transaction.return_value = ctx
|
||||
mock_rls_transaction.return_value = [ctx]
|
||||
mock_findings_filter.return_value = mock_queryset
|
||||
|
||||
aggregate_attack_surface(str(tenant.id), str(scan.id))
|
||||
@@ -3880,9 +3845,7 @@ class TestAggregateAttackSurface:
|
||||
mock_select_related.return_value.get.return_value = mock_scan
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
ctx.__exit__.return_value = False
|
||||
mock_rls_transaction.return_value = ctx
|
||||
mock_rls_transaction.return_value = [ctx]
|
||||
|
||||
with patch(
|
||||
"tasks.jobs.scan._get_attack_surface_mapping_from_provider"
|
||||
|
||||
@@ -713,7 +713,7 @@ class TestGenerateOutputs:
|
||||
True,
|
||||
]
|
||||
mock_integration_filter.return_value = [MagicMock()]
|
||||
mock_rls.return_value.__enter__.return_value = None
|
||||
mock_rls.return_value = [MagicMock()]
|
||||
|
||||
with (
|
||||
patch("tasks.tasks.OUTPUT_FORMATS_MAPPING", {}),
|
||||
@@ -880,7 +880,7 @@ class TestCheckIntegrationsTask:
|
||||
):
|
||||
mock_integration_filter.return_value.exists.return_value = False
|
||||
# Ensure rls_transaction is mocked
|
||||
mock_rls.return_value.__enter__.return_value = None
|
||||
mock_rls.return_value = [MagicMock()]
|
||||
|
||||
result = check_integrations_task(
|
||||
tenant_id=self.tenant_id,
|
||||
@@ -922,7 +922,7 @@ class TestCheckIntegrationsTask:
|
||||
mock_group.return_value = mock_job
|
||||
|
||||
# Ensure rls_transaction is mocked
|
||||
mock_rls.return_value.__enter__.return_value = None
|
||||
mock_rls.return_value = [MagicMock()]
|
||||
|
||||
# Execute the function
|
||||
result = check_integrations_task(
|
||||
@@ -963,7 +963,7 @@ class TestCheckIntegrationsTask:
|
||||
):
|
||||
"""Test that disabled integrations are not processed."""
|
||||
mock_integration_filter.return_value.exists.return_value = False
|
||||
mock_rls.return_value.__enter__.return_value = None
|
||||
mock_rls.return_value = [MagicMock()]
|
||||
|
||||
result = check_integrations_task(
|
||||
tenant_id=self.tenant_id,
|
||||
|
||||
Reference in New Issue
Block a user