mirror of
https://github.com/prowler-cloud/prowler.git
synced 2026-07-04 19:21:51 +00:00
perf(api): optimize scan-compliance-overviews task (#11591)
This commit is contained in:
@@ -16,6 +16,14 @@ All notable changes to the **Prowler API** are documented in this file.
|
||||
|
||||
---
|
||||
|
||||
## [1.31.2] (Prowler UNRELEASED)
|
||||
|
||||
### 🔄 Changed
|
||||
|
||||
- `scan-compliance-overviews` task now streams the findings aggregation and the requirement-row writes (reading the denormalized `resource_regions` instead of prefetching resources, and batching rows into COPY instead of building the full list first), so it runs faster and its peak memory no longer grows with the number of regions and frameworks — a previous worker OOM risk on large scans — with no change to the compliance overview output [(#11591)](https://github.com/prowler-cloud/prowler/pull/11591)
|
||||
|
||||
---
|
||||
|
||||
## [1.31.1] (Prowler v5.30.1)
|
||||
|
||||
### 🐞 Fixed
|
||||
|
||||
+161
-137
@@ -5,6 +5,7 @@ import re
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
@@ -22,7 +23,6 @@ from django.db.models import (
|
||||
Max,
|
||||
Min,
|
||||
OuterRef,
|
||||
Prefetch,
|
||||
Q,
|
||||
Sum,
|
||||
When,
|
||||
@@ -357,68 +357,71 @@ def _copy_compliance_requirement_rows(
|
||||
|
||||
|
||||
def _persist_compliance_requirement_rows(
|
||||
tenant_id: str, rows: list[dict[str, Any]], batch_size: int = 10000
|
||||
) -> None:
|
||||
tenant_id: str, rows: Iterable[dict[str, Any]], batch_size: int = 10000
|
||||
) -> int:
|
||||
"""Persist compliance requirement rows using batched COPY with ORM fallback.
|
||||
|
||||
Splits large row sets into batches to reduce lock duration and improve concurrency.
|
||||
``rows`` is consumed lazily in batches, so peak memory stays at ~``batch_size``
|
||||
rows instead of the full set. A batch that fails COPY falls back to an ORM
|
||||
``bulk_create`` of just that batch.
|
||||
|
||||
Args:
|
||||
tenant_id: Target tenant UUID.
|
||||
rows: Precomputed row dictionaries that reflect the compliance
|
||||
overview state for a scan.
|
||||
rows: Iterable of row dictionaries reflecting the compliance overview
|
||||
state for a scan.
|
||||
batch_size: Number of rows per COPY batch (default: 10000).
|
||||
|
||||
Returns:
|
||||
int: total number of rows persisted.
|
||||
"""
|
||||
if not rows:
|
||||
return
|
||||
|
||||
total_rows = len(rows)
|
||||
total_batches = (total_rows + batch_size - 1) // batch_size
|
||||
|
||||
try:
|
||||
# Process rows in batches to reduce lock duration
|
||||
for batch_num in range(total_batches):
|
||||
start_idx = batch_num * batch_size
|
||||
end_idx = min(start_idx + batch_size, total_rows)
|
||||
batch = rows[start_idx:end_idx]
|
||||
total_rows = 0
|
||||
batch_num = 0
|
||||
|
||||
for batch, _is_last in batched(rows, batch_size):
|
||||
if not batch:
|
||||
continue
|
||||
batch_num += 1
|
||||
try:
|
||||
_copy_compliance_requirement_rows(tenant_id, batch)
|
||||
except Exception as error:
|
||||
logger.exception(
|
||||
f"COPY bulk insert for compliance requirements batch {batch_num} "
|
||||
"failed; falling back to ORM bulk_create for this batch",
|
||||
exc_info=error,
|
||||
)
|
||||
fallback_objects = [
|
||||
ComplianceRequirementOverview(
|
||||
id=row["id"],
|
||||
tenant_id=row["tenant_id"],
|
||||
inserted_at=row["inserted_at"],
|
||||
compliance_id=row["compliance_id"],
|
||||
framework=row["framework"],
|
||||
version=row["version"],
|
||||
description=row["description"],
|
||||
region=row["region"],
|
||||
requirement_id=row["requirement_id"],
|
||||
requirement_status=row["requirement_status"],
|
||||
passed_checks=row["passed_checks"],
|
||||
failed_checks=row["failed_checks"],
|
||||
total_checks=row["total_checks"],
|
||||
passed_findings=row.get("passed_findings", 0),
|
||||
total_findings=row.get("total_findings", 0),
|
||||
scan_id=row["scan_id"],
|
||||
)
|
||||
for row in batch
|
||||
]
|
||||
with rls_transaction(tenant_id):
|
||||
ComplianceRequirementOverview.objects.bulk_create(
|
||||
fallback_objects, batch_size=500
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Compliance COPY batch {batch_num + 1}/{total_batches}: "
|
||||
f"inserted {len(batch)} rows ({start_idx + len(batch)}/{total_rows} total)"
|
||||
)
|
||||
except Exception as error:
|
||||
logger.exception(
|
||||
"COPY bulk insert for compliance requirements failed; falling back to ORM bulk_create",
|
||||
exc_info=error,
|
||||
total_rows += len(batch)
|
||||
logger.info(
|
||||
f"Compliance COPY batch {batch_num}: inserted {len(batch)} rows "
|
||||
f"({total_rows} total)"
|
||||
)
|
||||
# Fallback: use ORM bulk_create for all remaining rows
|
||||
fallback_objects = [
|
||||
ComplianceRequirementOverview(
|
||||
id=row["id"],
|
||||
tenant_id=row["tenant_id"],
|
||||
inserted_at=row["inserted_at"],
|
||||
compliance_id=row["compliance_id"],
|
||||
framework=row["framework"],
|
||||
version=row["version"],
|
||||
description=row["description"],
|
||||
region=row["region"],
|
||||
requirement_id=row["requirement_id"],
|
||||
requirement_status=row["requirement_status"],
|
||||
passed_checks=row["passed_checks"],
|
||||
failed_checks=row["failed_checks"],
|
||||
total_checks=row["total_checks"],
|
||||
passed_findings=row.get("passed_findings", 0),
|
||||
total_findings=row.get("total_findings", 0),
|
||||
scan_id=row["scan_id"],
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
with rls_transaction(tenant_id):
|
||||
ComplianceRequirementOverview.objects.bulk_create(
|
||||
fallback_objects, batch_size=500
|
||||
)
|
||||
|
||||
return total_rows
|
||||
|
||||
|
||||
def _create_compliance_summaries(
|
||||
@@ -1445,9 +1448,13 @@ def _aggregate_findings_by_region(
|
||||
tenant_id: str, scan_id: str, modeled_threatscore_compliance_id: str
|
||||
) -> tuple[dict, dict]:
|
||||
"""
|
||||
Aggregate findings by region using optimized ORM queries.
|
||||
Aggregate findings by region using streaming, column-scoped ORM reads.
|
||||
|
||||
Replaces nested Python loops with efficient queries and aggregation.
|
||||
Reads only the consumed columns as tuples via ``values_list`` and streams
|
||||
them with ``.iterator()``, using the denormalized ``resource_regions`` array
|
||||
instead of ``prefetch_related("resources")``. ``resource_regions`` mirrors the
|
||||
regions of a finding's related resources, so it yields the same per-region
|
||||
tally without joining the resource table.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant UUID
|
||||
@@ -1459,12 +1466,12 @@ def _aggregate_findings_by_region(
|
||||
- check_status_by_region: {region: {check_id: status}}
|
||||
- findings_count_by_compliance: {region: {normalized_id: {requirement_id: {total, pass}}}}
|
||||
"""
|
||||
check_status_by_region = {}
|
||||
findings_count_by_compliance = {}
|
||||
check_status_by_region: dict = {}
|
||||
findings_count_by_compliance: dict = {}
|
||||
|
||||
normalized_id = re.sub(r"[^a-z0-9]", "", modeled_threatscore_compliance_id.lower())
|
||||
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
# Fetch only PASS/FAIL findings (optimized query reduces data transfer)
|
||||
# Other statuses are not needed for check_status or ThreatScore calculation
|
||||
findings = (
|
||||
Finding.all_objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
@@ -1472,42 +1479,28 @@ def _aggregate_findings_by_region(
|
||||
muted=False,
|
||||
status__in=["PASS", "FAIL"],
|
||||
)
|
||||
.only("id", "check_id", "status", "compliance")
|
||||
.prefetch_related(
|
||||
Prefetch(
|
||||
"resources",
|
||||
queryset=Resource.objects.only("id", "region"),
|
||||
to_attr="small_resources",
|
||||
)
|
||||
.values_list("check_id", "status", "resource_regions", "compliance")
|
||||
.iterator(chunk_size=DJANGO_FINDINGS_BATCH_SIZE)
|
||||
)
|
||||
|
||||
for check_id, status, resource_regions, compliance in findings:
|
||||
threatscore_requirements = (compliance or {}).get(
|
||||
modeled_threatscore_compliance_id
|
||||
)
|
||||
)
|
||||
|
||||
# Process findings in a single pass (more efficient than original nested loops)
|
||||
normalized_id = re.sub(
|
||||
r"[^a-z0-9]", "", modeled_threatscore_compliance_id.lower()
|
||||
)
|
||||
|
||||
for finding in findings:
|
||||
status = finding.status
|
||||
|
||||
for resource in finding.small_resources:
|
||||
region = resource.region
|
||||
|
||||
# Aggregate check status by region
|
||||
current_status = check_status_by_region.setdefault(region, {})
|
||||
for region in resource_regions or ():
|
||||
# Priority: FAIL > any other status
|
||||
if current_status.get(finding.check_id) != "FAIL":
|
||||
current_status[finding.check_id] = status
|
||||
current_status = check_status_by_region.setdefault(region, {})
|
||||
if current_status.get(check_id) != "FAIL":
|
||||
current_status[check_id] = status
|
||||
|
||||
# Aggregate ThreatScore compliance counts
|
||||
if modeled_threatscore_compliance_id in (finding.compliance or {}):
|
||||
if threatscore_requirements:
|
||||
compliance_key = findings_count_by_compliance.setdefault(
|
||||
region, {}
|
||||
).setdefault(normalized_id, {})
|
||||
|
||||
for requirement_id in finding.compliance[
|
||||
modeled_threatscore_compliance_id
|
||||
]:
|
||||
for requirement_id in threatscore_requirements:
|
||||
requirement_stats = compliance_key.setdefault(
|
||||
requirement_id, {"total": 0, "pass": 0}
|
||||
)
|
||||
@@ -1554,8 +1547,8 @@ def create_compliance_requirements(tenant_id: str, scan_id: str):
|
||||
(compliance_id, requirement_id)
|
||||
)
|
||||
|
||||
compliance_requirement_rows: list[dict[str, Any]] = []
|
||||
regions = []
|
||||
requirements_created = 0
|
||||
requirement_statuses = defaultdict(
|
||||
lambda: {"fail_count": 0, "pass_count": 0, "total_count": 0}
|
||||
)
|
||||
@@ -1595,44 +1588,93 @@ def create_compliance_requirements(tenant_id: str, scan_id: str):
|
||||
else:
|
||||
requirement_stats["failed_checks"] += 1
|
||||
|
||||
# Prepare compliance requirement rows and compute summaries in single pass
|
||||
utc_datetime_now = datetime.now(tz=timezone.utc)
|
||||
|
||||
# Pre-compute shared strings (optimization: reduces string conversions)
|
||||
tenant_id_str = str(tenant_id)
|
||||
scan_id_str = str(scan_instance.id)
|
||||
|
||||
for region in regions:
|
||||
region_stats = region_requirement_stats.get(region, {})
|
||||
for compliance_id, compliance in compliance_template.items():
|
||||
modeled_compliance_id = _normalized_compliance_key(
|
||||
compliance["framework"], compliance["version"]
|
||||
# Per-framework constants that don't depend on the region.
|
||||
compliance_plan = []
|
||||
for compliance_id, compliance in compliance_template.items():
|
||||
modeled_compliance_id = _normalized_compliance_key(
|
||||
compliance["framework"], compliance["version"]
|
||||
)
|
||||
framework = compliance["framework"]
|
||||
version = compliance["version"] or ""
|
||||
requirements = [
|
||||
(
|
||||
requirement_id,
|
||||
requirement.get("description") or "",
|
||||
len(requirement["checks"]),
|
||||
)
|
||||
compliance_stats = region_stats.get(compliance_id, {})
|
||||
# Create an overview record for each requirement within each compliance framework
|
||||
for requirement_id, requirement in compliance[
|
||||
"requirements"
|
||||
].items():
|
||||
stats = compliance_stats.get(requirement_id)
|
||||
passed_checks = stats["passed_checks"] if stats else 0
|
||||
failed_checks = stats["failed_checks"] if stats else 0
|
||||
total_checks = len(requirement["checks"])
|
||||
if total_checks == 0:
|
||||
requirement_status = "MANUAL"
|
||||
elif failed_checks > 0:
|
||||
requirement_status = "FAIL"
|
||||
else:
|
||||
requirement_status = "PASS"
|
||||
].items()
|
||||
]
|
||||
compliance_plan.append(
|
||||
(
|
||||
compliance_id,
|
||||
framework,
|
||||
version,
|
||||
modeled_compliance_id,
|
||||
requirements,
|
||||
)
|
||||
)
|
||||
|
||||
compliance_requirement_rows.append(
|
||||
{
|
||||
# Yield rows lazily (consumed batch-by-batch by COPY) so peak memory
|
||||
# stays bounded; tally requirement_statuses in the same pass.
|
||||
def _iter_compliance_requirement_rows():
|
||||
for region in regions:
|
||||
region_stats = region_requirement_stats.get(region, {})
|
||||
region_findings = findings_count_by_compliance.get(region, {})
|
||||
for (
|
||||
compliance_id,
|
||||
framework,
|
||||
version,
|
||||
modeled_compliance_id,
|
||||
requirements,
|
||||
) in compliance_plan:
|
||||
compliance_stats = region_stats.get(compliance_id, {})
|
||||
compliance_findings = region_findings.get(
|
||||
modeled_compliance_id, {}
|
||||
)
|
||||
for requirement_id, description, total_checks in requirements:
|
||||
stats = compliance_stats.get(requirement_id)
|
||||
if stats:
|
||||
passed_checks = stats["passed_checks"]
|
||||
failed_checks = stats["failed_checks"]
|
||||
else:
|
||||
passed_checks = 0
|
||||
failed_checks = 0
|
||||
if total_checks == 0:
|
||||
requirement_status = "MANUAL"
|
||||
elif failed_checks > 0:
|
||||
requirement_status = "FAIL"
|
||||
else:
|
||||
requirement_status = "PASS"
|
||||
|
||||
finding_counts = compliance_findings.get(requirement_id)
|
||||
if finding_counts:
|
||||
passed_findings = finding_counts.get("pass", 0)
|
||||
total_findings = finding_counts.get("total", 0)
|
||||
else:
|
||||
passed_findings = 0
|
||||
total_findings = 0
|
||||
|
||||
key = (compliance_id, requirement_id)
|
||||
requirement_statuses[key]["total_count"] += 1
|
||||
if requirement_status == "FAIL":
|
||||
requirement_statuses[key]["fail_count"] += 1
|
||||
elif requirement_status == "PASS":
|
||||
requirement_statuses[key]["pass_count"] += 1
|
||||
|
||||
yield {
|
||||
"id": uuid.uuid4(),
|
||||
"tenant_id": tenant_id_str,
|
||||
"inserted_at": utc_datetime_now,
|
||||
"compliance_id": compliance_id,
|
||||
"framework": compliance["framework"],
|
||||
"version": compliance["version"] or "",
|
||||
"description": requirement.get("description") or "",
|
||||
"framework": framework,
|
||||
"version": version,
|
||||
"description": description,
|
||||
"region": region,
|
||||
"requirement_id": requirement_id,
|
||||
"requirement_status": requirement_status,
|
||||
@@ -1640,41 +1682,23 @@ def create_compliance_requirements(tenant_id: str, scan_id: str):
|
||||
"failed_checks": failed_checks,
|
||||
"total_checks": total_checks,
|
||||
"scan_id": scan_id_str,
|
||||
"passed_findings": findings_count_by_compliance.get(
|
||||
region, {}
|
||||
)
|
||||
.get(modeled_compliance_id, {})
|
||||
.get(requirement_id, {})
|
||||
.get("pass", 0),
|
||||
"total_findings": findings_count_by_compliance.get(
|
||||
region, {}
|
||||
)
|
||||
.get(modeled_compliance_id, {})
|
||||
.get(requirement_id, {})
|
||||
.get("total", 0),
|
||||
"passed_findings": passed_findings,
|
||||
"total_findings": total_findings,
|
||||
}
|
||||
)
|
||||
|
||||
# Update summary tracking (single-pass optimization)
|
||||
key = (compliance_id, requirement_id)
|
||||
requirement_statuses[key]["total_count"] += 1
|
||||
if requirement_status == "FAIL":
|
||||
requirement_statuses[key]["fail_count"] += 1
|
||||
elif requirement_status == "PASS":
|
||||
requirement_statuses[key]["pass_count"] += 1
|
||||
|
||||
# Idempotent re-run: COPY can't ON CONFLICT, so clear this scan's rows first.
|
||||
# Idempotent re-run: clear this scan's rows before re-inserting.
|
||||
with rls_transaction(tenant_id):
|
||||
ComplianceRequirementOverview.objects.filter(scan_id=scan_id).delete()
|
||||
|
||||
# Bulk create requirement records using PostgreSQL COPY
|
||||
_persist_compliance_requirement_rows(tenant_id, compliance_requirement_rows)
|
||||
requirements_created = _persist_compliance_requirement_rows(
|
||||
tenant_id, _iter_compliance_requirement_rows()
|
||||
)
|
||||
|
||||
# Create pre-aggregated summaries for fast compliance overview lookups
|
||||
_create_compliance_summaries(tenant_id, scan_id, requirement_statuses)
|
||||
|
||||
return {
|
||||
"requirements_created": len(compliance_requirement_rows),
|
||||
"requirements_created": requirements_created,
|
||||
"regions_processed": list(regions),
|
||||
"compliance_frameworks": (
|
||||
list(compliance_template.keys()) if regions else []
|
||||
|
||||
@@ -3674,19 +3674,19 @@ class TestAggregateFindingsByRegion:
|
||||
scan_id = str(uuid.uuid4())
|
||||
modeled_threatscore_compliance_id = "ProwlerThreatScore-1.0"
|
||||
|
||||
# Mock findings with resources
|
||||
mock_finding1 = MagicMock()
|
||||
mock_finding1.check_id = "check1"
|
||||
mock_finding1.status = "FAIL"
|
||||
mock_finding1.compliance = {modeled_threatscore_compliance_id: ["req1", "req2"]}
|
||||
|
||||
mock_resource1 = MagicMock()
|
||||
mock_resource1.region = "us-east-1"
|
||||
mock_finding1.small_resources = [mock_resource1]
|
||||
# (check_id, status, resource_regions, compliance) tuples
|
||||
finding_rows = [
|
||||
(
|
||||
"check1",
|
||||
"FAIL",
|
||||
["us-east-1"],
|
||||
{modeled_threatscore_compliance_id: ["req1", "req2"]},
|
||||
)
|
||||
]
|
||||
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.only.return_value = mock_queryset
|
||||
mock_queryset.prefetch_related.return_value = [mock_finding1]
|
||||
mock_queryset.values_list.return_value = mock_queryset
|
||||
mock_queryset.iterator.return_value = finding_rows
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
@@ -3700,6 +3700,12 @@ class TestAggregateFindingsByRegion:
|
||||
)
|
||||
)
|
||||
|
||||
# Streaming query contract: column-scoped values_list + iterator
|
||||
mock_queryset.values_list.assert_called_once_with(
|
||||
"check_id", "status", "resource_regions", "compliance"
|
||||
)
|
||||
mock_queryset.iterator.assert_called_once()
|
||||
|
||||
# Verify structure of check_status_by_region
|
||||
assert isinstance(check_status_by_region, dict)
|
||||
assert "us-east-1" in check_status_by_region
|
||||
@@ -3719,27 +3725,15 @@ class TestAggregateFindingsByRegion:
|
||||
scan_id = str(uuid.uuid4())
|
||||
modeled_threatscore_compliance_id = "ProwlerThreatScore-1.0"
|
||||
|
||||
# First finding with PASS status
|
||||
mock_finding1 = MagicMock()
|
||||
mock_finding1.check_id = "check1"
|
||||
mock_finding1.status = "PASS"
|
||||
mock_finding1.compliance = {}
|
||||
mock_resource1 = MagicMock()
|
||||
mock_resource1.region = "us-east-1"
|
||||
mock_finding1.small_resources = [mock_resource1]
|
||||
|
||||
# Second finding with FAIL status for same check/region
|
||||
mock_finding2 = MagicMock()
|
||||
mock_finding2.check_id = "check1"
|
||||
mock_finding2.status = "FAIL"
|
||||
mock_finding2.compliance = {}
|
||||
mock_resource2 = MagicMock()
|
||||
mock_resource2.region = "us-east-1"
|
||||
mock_finding2.small_resources = [mock_resource2]
|
||||
# Same check/region: PASS first, then FAIL — FAIL must win
|
||||
finding_rows = [
|
||||
("check1", "PASS", ["us-east-1"], {}),
|
||||
("check1", "FAIL", ["us-east-1"], {}),
|
||||
]
|
||||
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.only.return_value = mock_queryset
|
||||
mock_queryset.prefetch_related.return_value = [mock_finding1, mock_finding2]
|
||||
mock_queryset.values_list.return_value = mock_queryset
|
||||
mock_queryset.iterator.return_value = finding_rows
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
@@ -3751,6 +3745,12 @@ class TestAggregateFindingsByRegion:
|
||||
tenant_id, scan_id, modeled_threatscore_compliance_id
|
||||
)
|
||||
|
||||
# Streaming query contract: column-scoped values_list + iterator
|
||||
mock_queryset.values_list.assert_called_once_with(
|
||||
"check_id", "status", "resource_regions", "compliance"
|
||||
)
|
||||
mock_queryset.iterator.assert_called_once()
|
||||
|
||||
# FAIL should override PASS
|
||||
assert check_status_by_region["us-east-1"]["check1"] == "FAIL"
|
||||
|
||||
@@ -3765,8 +3765,8 @@ class TestAggregateFindingsByRegion:
|
||||
modeled_threatscore_compliance_id = "ProwlerThreatScore-1.0"
|
||||
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.only.return_value = mock_queryset
|
||||
mock_queryset.prefetch_related.return_value = []
|
||||
mock_queryset.values_list.return_value = mock_queryset
|
||||
mock_queryset.iterator.return_value = []
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
@@ -3778,6 +3778,12 @@ class TestAggregateFindingsByRegion:
|
||||
tenant_id, scan_id, modeled_threatscore_compliance_id
|
||||
)
|
||||
|
||||
# Streaming query contract: column-scoped values_list + iterator
|
||||
mock_queryset.values_list.assert_called_once_with(
|
||||
"check_id", "status", "resource_regions", "compliance"
|
||||
)
|
||||
mock_queryset.iterator.assert_called_once()
|
||||
|
||||
# Verify filter was called with muted=False
|
||||
mock_findings_filter.assert_called_once_with(
|
||||
tenant_id=tenant_id,
|
||||
@@ -3796,27 +3802,25 @@ class TestAggregateFindingsByRegion:
|
||||
scan_id = str(uuid.uuid4())
|
||||
modeled_threatscore_compliance_id = "ProwlerThreatScore-1.0"
|
||||
|
||||
# Finding with PASS status
|
||||
mock_finding1 = MagicMock()
|
||||
mock_finding1.check_id = "check1"
|
||||
mock_finding1.status = "PASS"
|
||||
mock_finding1.compliance = {modeled_threatscore_compliance_id: ["req1"]}
|
||||
mock_resource1 = MagicMock()
|
||||
mock_resource1.region = "us-east-1"
|
||||
mock_finding1.small_resources = [mock_resource1]
|
||||
|
||||
# Finding with FAIL status
|
||||
mock_finding2 = MagicMock()
|
||||
mock_finding2.check_id = "check2"
|
||||
mock_finding2.status = "FAIL"
|
||||
mock_finding2.compliance = {modeled_threatscore_compliance_id: ["req1"]}
|
||||
mock_resource2 = MagicMock()
|
||||
mock_resource2.region = "us-east-1"
|
||||
mock_finding2.small_resources = [mock_resource2]
|
||||
# PASS and FAIL findings mapped to the same ThreatScore requirement
|
||||
finding_rows = [
|
||||
(
|
||||
"check1",
|
||||
"PASS",
|
||||
["us-east-1"],
|
||||
{modeled_threatscore_compliance_id: ["req1"]},
|
||||
),
|
||||
(
|
||||
"check2",
|
||||
"FAIL",
|
||||
["us-east-1"],
|
||||
{modeled_threatscore_compliance_id: ["req1"]},
|
||||
),
|
||||
]
|
||||
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.only.return_value = mock_queryset
|
||||
mock_queryset.prefetch_related.return_value = [mock_finding1, mock_finding2]
|
||||
mock_queryset.values_list.return_value = mock_queryset
|
||||
mock_queryset.iterator.return_value = finding_rows
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
@@ -3828,6 +3832,12 @@ class TestAggregateFindingsByRegion:
|
||||
tenant_id, scan_id, modeled_threatscore_compliance_id
|
||||
)
|
||||
|
||||
# Streaming query contract: column-scoped values_list + iterator
|
||||
mock_queryset.values_list.assert_called_once_with(
|
||||
"check_id", "status", "resource_regions", "compliance"
|
||||
)
|
||||
mock_queryset.iterator.assert_called_once()
|
||||
|
||||
# Verify compliance counts
|
||||
normalized_id = re.sub(
|
||||
r"[^a-z0-9]", "", modeled_threatscore_compliance_id.lower()
|
||||
@@ -3850,27 +3860,15 @@ class TestAggregateFindingsByRegion:
|
||||
scan_id = str(uuid.uuid4())
|
||||
modeled_threatscore_compliance_id = "ProwlerThreatScore-1.0"
|
||||
|
||||
# Finding in us-east-1
|
||||
mock_finding1 = MagicMock()
|
||||
mock_finding1.check_id = "check1"
|
||||
mock_finding1.status = "FAIL"
|
||||
mock_finding1.compliance = {}
|
||||
mock_resource1 = MagicMock()
|
||||
mock_resource1.region = "us-east-1"
|
||||
mock_finding1.small_resources = [mock_resource1]
|
||||
|
||||
# Finding in us-west-2
|
||||
mock_finding2 = MagicMock()
|
||||
mock_finding2.check_id = "check1"
|
||||
mock_finding2.status = "PASS"
|
||||
mock_finding2.compliance = {}
|
||||
mock_resource2 = MagicMock()
|
||||
mock_resource2.region = "us-west-2"
|
||||
mock_finding2.small_resources = [mock_resource2]
|
||||
# One finding per region
|
||||
finding_rows = [
|
||||
("check1", "FAIL", ["us-east-1"], {}),
|
||||
("check1", "PASS", ["us-west-2"], {}),
|
||||
]
|
||||
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.only.return_value = mock_queryset
|
||||
mock_queryset.prefetch_related.return_value = [mock_finding1, mock_finding2]
|
||||
mock_queryset.values_list.return_value = mock_queryset
|
||||
mock_queryset.iterator.return_value = finding_rows
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
@@ -3882,6 +3880,12 @@ class TestAggregateFindingsByRegion:
|
||||
tenant_id, scan_id, modeled_threatscore_compliance_id
|
||||
)
|
||||
|
||||
# Streaming query contract: column-scoped values_list + iterator
|
||||
mock_queryset.values_list.assert_called_once_with(
|
||||
"check_id", "status", "resource_regions", "compliance"
|
||||
)
|
||||
mock_queryset.iterator.assert_called_once()
|
||||
|
||||
# Verify both regions are present with correct statuses
|
||||
assert "us-east-1" in check_status_by_region
|
||||
assert "us-west-2" in check_status_by_region
|
||||
@@ -3890,17 +3894,26 @@ class TestAggregateFindingsByRegion:
|
||||
|
||||
@patch("tasks.jobs.scan.Finding.all_objects.filter")
|
||||
@patch("tasks.jobs.scan.rls_transaction")
|
||||
def test_aggregate_findings_by_region_empty_findings(
|
||||
def test_aggregate_findings_by_region_multi_region_finding(
|
||||
self, mock_rls_transaction, mock_findings_filter
|
||||
):
|
||||
"""Test with no findings - should return empty dicts."""
|
||||
"""A finding with multiple resource_regions is tallied in every region."""
|
||||
tenant_id = str(uuid.uuid4())
|
||||
scan_id = str(uuid.uuid4())
|
||||
modeled_threatscore_compliance_id = "ProwlerThreatScore-1.0"
|
||||
|
||||
finding_rows = [
|
||||
(
|
||||
"check1",
|
||||
"FAIL",
|
||||
["us-east-1", "eu-west-1"],
|
||||
{modeled_threatscore_compliance_id: ["req1"]},
|
||||
)
|
||||
]
|
||||
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.only.return_value = mock_queryset
|
||||
mock_queryset.prefetch_related.return_value = []
|
||||
mock_queryset.values_list.return_value = mock_queryset
|
||||
mock_queryset.iterator.return_value = finding_rows
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
@@ -3914,6 +3927,92 @@ class TestAggregateFindingsByRegion:
|
||||
)
|
||||
)
|
||||
|
||||
# Streaming query contract: column-scoped values_list + iterator
|
||||
mock_queryset.values_list.assert_called_once_with(
|
||||
"check_id", "status", "resource_regions", "compliance"
|
||||
)
|
||||
mock_queryset.iterator.assert_called_once()
|
||||
|
||||
normalized_id = re.sub(
|
||||
r"[^a-z0-9]", "", modeled_threatscore_compliance_id.lower()
|
||||
)
|
||||
for region in ("us-east-1", "eu-west-1"):
|
||||
assert check_status_by_region[region]["check1"] == "FAIL"
|
||||
req_stats = findings_count_by_compliance[region][normalized_id]["req1"]
|
||||
assert req_stats == {"total": 1, "pass": 0}
|
||||
|
||||
@patch("tasks.jobs.scan.Finding.all_objects.filter")
|
||||
@patch("tasks.jobs.scan.rls_transaction")
|
||||
def test_aggregate_findings_by_region_skips_empty_regions(
|
||||
self, mock_rls_transaction, mock_findings_filter
|
||||
):
|
||||
"""A finding with no denormalized regions contributes nothing."""
|
||||
tenant_id = str(uuid.uuid4())
|
||||
scan_id = str(uuid.uuid4())
|
||||
modeled_threatscore_compliance_id = "ProwlerThreatScore-1.0"
|
||||
|
||||
finding_rows = [
|
||||
("check1", "FAIL", [], {modeled_threatscore_compliance_id: ["req1"]}),
|
||||
("check2", "PASS", None, {}),
|
||||
]
|
||||
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.values_list.return_value = mock_queryset
|
||||
mock_queryset.iterator.return_value = finding_rows
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
ctx.__exit__.return_value = False
|
||||
mock_rls_transaction.return_value = ctx
|
||||
mock_findings_filter.return_value = mock_queryset
|
||||
|
||||
check_status_by_region, findings_count_by_compliance = (
|
||||
_aggregate_findings_by_region(
|
||||
tenant_id, scan_id, modeled_threatscore_compliance_id
|
||||
)
|
||||
)
|
||||
|
||||
# Streaming query contract: column-scoped values_list + iterator
|
||||
mock_queryset.values_list.assert_called_once_with(
|
||||
"check_id", "status", "resource_regions", "compliance"
|
||||
)
|
||||
mock_queryset.iterator.assert_called_once()
|
||||
|
||||
assert check_status_by_region == {}
|
||||
assert findings_count_by_compliance == {}
|
||||
|
||||
@patch("tasks.jobs.scan.Finding.all_objects.filter")
|
||||
@patch("tasks.jobs.scan.rls_transaction")
|
||||
def test_aggregate_findings_by_region_empty_findings(
|
||||
self, mock_rls_transaction, mock_findings_filter
|
||||
):
|
||||
"""Test with no findings - should return empty dicts."""
|
||||
tenant_id = str(uuid.uuid4())
|
||||
scan_id = str(uuid.uuid4())
|
||||
modeled_threatscore_compliance_id = "ProwlerThreatScore-1.0"
|
||||
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.values_list.return_value = mock_queryset
|
||||
mock_queryset.iterator.return_value = []
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
ctx.__exit__.return_value = False
|
||||
mock_rls_transaction.return_value = ctx
|
||||
mock_findings_filter.return_value = mock_queryset
|
||||
|
||||
check_status_by_region, findings_count_by_compliance = (
|
||||
_aggregate_findings_by_region(
|
||||
tenant_id, scan_id, modeled_threatscore_compliance_id
|
||||
)
|
||||
)
|
||||
|
||||
# Streaming query contract: column-scoped values_list + iterator
|
||||
mock_queryset.values_list.assert_called_once_with(
|
||||
"check_id", "status", "resource_regions", "compliance"
|
||||
)
|
||||
mock_queryset.iterator.assert_called_once()
|
||||
|
||||
assert check_status_by_region == {}
|
||||
assert findings_count_by_compliance == {}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user