perf(api): optimize scan-compliance-overviews task (#11591)

This commit is contained in:
Pedro Martín
2026-06-16 10:48:55 +02:00
committed by GitHub
parent 94ce76d679
commit e419771b04
3 changed files with 342 additions and 211 deletions
+8
View File
@@ -16,6 +16,14 @@ All notable changes to the **Prowler API** are documented in this file.
---
## [1.31.2] (Prowler UNRELEASED)
### 🔄 Changed
- `scan-compliance-overviews` task now streams the findings aggregation and the requirement-row writes (reading the denormalized `resource_regions` instead of prefetching resources, and batching rows into COPY instead of building the full list first), so it runs faster and its peak memory no longer grows with the number of regions and frameworks — a previous worker OOM risk on large scans — with no change to the compliance overview output [(#11591)](https://github.com/prowler-cloud/prowler/pull/11591)
---
## [1.31.1] (Prowler v5.30.1)
### 🐞 Fixed
+161 -137
View File
@@ -5,6 +5,7 @@ import re
import time
import uuid
from collections import defaultdict
from collections.abc import Iterable
from datetime import datetime, timezone
from typing import Any
@@ -22,7 +23,6 @@ from django.db.models import (
Max,
Min,
OuterRef,
Prefetch,
Q,
Sum,
When,
@@ -357,68 +357,71 @@ def _copy_compliance_requirement_rows(
def _persist_compliance_requirement_rows(
tenant_id: str, rows: list[dict[str, Any]], batch_size: int = 10000
) -> None:
tenant_id: str, rows: Iterable[dict[str, Any]], batch_size: int = 10000
) -> int:
"""Persist compliance requirement rows using batched COPY with ORM fallback.
Splits large row sets into batches to reduce lock duration and improve concurrency.
``rows`` is consumed lazily in batches, so peak memory stays at ~``batch_size``
rows instead of the full set. A batch that fails COPY falls back to an ORM
``bulk_create`` of just that batch.
Args:
tenant_id: Target tenant UUID.
rows: Precomputed row dictionaries that reflect the compliance
overview state for a scan.
rows: Iterable of row dictionaries reflecting the compliance overview
state for a scan.
batch_size: Number of rows per COPY batch (default: 10000).
Returns:
int: total number of rows persisted.
"""
if not rows:
return
total_rows = len(rows)
total_batches = (total_rows + batch_size - 1) // batch_size
try:
# Process rows in batches to reduce lock duration
for batch_num in range(total_batches):
start_idx = batch_num * batch_size
end_idx = min(start_idx + batch_size, total_rows)
batch = rows[start_idx:end_idx]
total_rows = 0
batch_num = 0
for batch, _is_last in batched(rows, batch_size):
if not batch:
continue
batch_num += 1
try:
_copy_compliance_requirement_rows(tenant_id, batch)
except Exception as error:
logger.exception(
f"COPY bulk insert for compliance requirements batch {batch_num} "
"failed; falling back to ORM bulk_create for this batch",
exc_info=error,
)
fallback_objects = [
ComplianceRequirementOverview(
id=row["id"],
tenant_id=row["tenant_id"],
inserted_at=row["inserted_at"],
compliance_id=row["compliance_id"],
framework=row["framework"],
version=row["version"],
description=row["description"],
region=row["region"],
requirement_id=row["requirement_id"],
requirement_status=row["requirement_status"],
passed_checks=row["passed_checks"],
failed_checks=row["failed_checks"],
total_checks=row["total_checks"],
passed_findings=row.get("passed_findings", 0),
total_findings=row.get("total_findings", 0),
scan_id=row["scan_id"],
)
for row in batch
]
with rls_transaction(tenant_id):
ComplianceRequirementOverview.objects.bulk_create(
fallback_objects, batch_size=500
)
logger.info(
f"Compliance COPY batch {batch_num + 1}/{total_batches}: "
f"inserted {len(batch)} rows ({start_idx + len(batch)}/{total_rows} total)"
)
except Exception as error:
logger.exception(
"COPY bulk insert for compliance requirements failed; falling back to ORM bulk_create",
exc_info=error,
total_rows += len(batch)
logger.info(
f"Compliance COPY batch {batch_num}: inserted {len(batch)} rows "
f"({total_rows} total)"
)
# Fallback: use ORM bulk_create for all remaining rows
fallback_objects = [
ComplianceRequirementOverview(
id=row["id"],
tenant_id=row["tenant_id"],
inserted_at=row["inserted_at"],
compliance_id=row["compliance_id"],
framework=row["framework"],
version=row["version"],
description=row["description"],
region=row["region"],
requirement_id=row["requirement_id"],
requirement_status=row["requirement_status"],
passed_checks=row["passed_checks"],
failed_checks=row["failed_checks"],
total_checks=row["total_checks"],
passed_findings=row.get("passed_findings", 0),
total_findings=row.get("total_findings", 0),
scan_id=row["scan_id"],
)
for row in rows
]
with rls_transaction(tenant_id):
ComplianceRequirementOverview.objects.bulk_create(
fallback_objects, batch_size=500
)
return total_rows
def _create_compliance_summaries(
@@ -1445,9 +1448,13 @@ def _aggregate_findings_by_region(
tenant_id: str, scan_id: str, modeled_threatscore_compliance_id: str
) -> tuple[dict, dict]:
"""
Aggregate findings by region using optimized ORM queries.
Aggregate findings by region using streaming, column-scoped ORM reads.
Replaces nested Python loops with efficient queries and aggregation.
Reads only the consumed columns as tuples via ``values_list`` and streams
them with ``.iterator()``, using the denormalized ``resource_regions`` array
instead of ``prefetch_related("resources")``. ``resource_regions`` mirrors the
regions of a finding's related resources, so it yields the same per-region
tally without joining the resource table.
Args:
tenant_id: Tenant UUID
@@ -1459,12 +1466,12 @@ def _aggregate_findings_by_region(
- check_status_by_region: {region: {check_id: status}}
- findings_count_by_compliance: {region: {normalized_id: {requirement_id: {total, pass}}}}
"""
check_status_by_region = {}
findings_count_by_compliance = {}
check_status_by_region: dict = {}
findings_count_by_compliance: dict = {}
normalized_id = re.sub(r"[^a-z0-9]", "", modeled_threatscore_compliance_id.lower())
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
# Fetch only PASS/FAIL findings (optimized query reduces data transfer)
# Other statuses are not needed for check_status or ThreatScore calculation
findings = (
Finding.all_objects.filter(
tenant_id=tenant_id,
@@ -1472,42 +1479,28 @@ def _aggregate_findings_by_region(
muted=False,
status__in=["PASS", "FAIL"],
)
.only("id", "check_id", "status", "compliance")
.prefetch_related(
Prefetch(
"resources",
queryset=Resource.objects.only("id", "region"),
to_attr="small_resources",
)
.values_list("check_id", "status", "resource_regions", "compliance")
.iterator(chunk_size=DJANGO_FINDINGS_BATCH_SIZE)
)
for check_id, status, resource_regions, compliance in findings:
threatscore_requirements = (compliance or {}).get(
modeled_threatscore_compliance_id
)
)
# Process findings in a single pass (more efficient than original nested loops)
normalized_id = re.sub(
r"[^a-z0-9]", "", modeled_threatscore_compliance_id.lower()
)
for finding in findings:
status = finding.status
for resource in finding.small_resources:
region = resource.region
# Aggregate check status by region
current_status = check_status_by_region.setdefault(region, {})
for region in resource_regions or ():
# Priority: FAIL > any other status
if current_status.get(finding.check_id) != "FAIL":
current_status[finding.check_id] = status
current_status = check_status_by_region.setdefault(region, {})
if current_status.get(check_id) != "FAIL":
current_status[check_id] = status
# Aggregate ThreatScore compliance counts
if modeled_threatscore_compliance_id in (finding.compliance or {}):
if threatscore_requirements:
compliance_key = findings_count_by_compliance.setdefault(
region, {}
).setdefault(normalized_id, {})
for requirement_id in finding.compliance[
modeled_threatscore_compliance_id
]:
for requirement_id in threatscore_requirements:
requirement_stats = compliance_key.setdefault(
requirement_id, {"total": 0, "pass": 0}
)
@@ -1554,8 +1547,8 @@ def create_compliance_requirements(tenant_id: str, scan_id: str):
(compliance_id, requirement_id)
)
compliance_requirement_rows: list[dict[str, Any]] = []
regions = []
requirements_created = 0
requirement_statuses = defaultdict(
lambda: {"fail_count": 0, "pass_count": 0, "total_count": 0}
)
@@ -1595,44 +1588,93 @@ def create_compliance_requirements(tenant_id: str, scan_id: str):
else:
requirement_stats["failed_checks"] += 1
# Prepare compliance requirement rows and compute summaries in single pass
utc_datetime_now = datetime.now(tz=timezone.utc)
# Pre-compute shared strings (optimization: reduces string conversions)
tenant_id_str = str(tenant_id)
scan_id_str = str(scan_instance.id)
for region in regions:
region_stats = region_requirement_stats.get(region, {})
for compliance_id, compliance in compliance_template.items():
modeled_compliance_id = _normalized_compliance_key(
compliance["framework"], compliance["version"]
# Per-framework constants that don't depend on the region.
compliance_plan = []
for compliance_id, compliance in compliance_template.items():
modeled_compliance_id = _normalized_compliance_key(
compliance["framework"], compliance["version"]
)
framework = compliance["framework"]
version = compliance["version"] or ""
requirements = [
(
requirement_id,
requirement.get("description") or "",
len(requirement["checks"]),
)
compliance_stats = region_stats.get(compliance_id, {})
# Create an overview record for each requirement within each compliance framework
for requirement_id, requirement in compliance[
"requirements"
].items():
stats = compliance_stats.get(requirement_id)
passed_checks = stats["passed_checks"] if stats else 0
failed_checks = stats["failed_checks"] if stats else 0
total_checks = len(requirement["checks"])
if total_checks == 0:
requirement_status = "MANUAL"
elif failed_checks > 0:
requirement_status = "FAIL"
else:
requirement_status = "PASS"
].items()
]
compliance_plan.append(
(
compliance_id,
framework,
version,
modeled_compliance_id,
requirements,
)
)
compliance_requirement_rows.append(
{
# Yield rows lazily (consumed batch-by-batch by COPY) so peak memory
# stays bounded; tally requirement_statuses in the same pass.
def _iter_compliance_requirement_rows():
for region in regions:
region_stats = region_requirement_stats.get(region, {})
region_findings = findings_count_by_compliance.get(region, {})
for (
compliance_id,
framework,
version,
modeled_compliance_id,
requirements,
) in compliance_plan:
compliance_stats = region_stats.get(compliance_id, {})
compliance_findings = region_findings.get(
modeled_compliance_id, {}
)
for requirement_id, description, total_checks in requirements:
stats = compliance_stats.get(requirement_id)
if stats:
passed_checks = stats["passed_checks"]
failed_checks = stats["failed_checks"]
else:
passed_checks = 0
failed_checks = 0
if total_checks == 0:
requirement_status = "MANUAL"
elif failed_checks > 0:
requirement_status = "FAIL"
else:
requirement_status = "PASS"
finding_counts = compliance_findings.get(requirement_id)
if finding_counts:
passed_findings = finding_counts.get("pass", 0)
total_findings = finding_counts.get("total", 0)
else:
passed_findings = 0
total_findings = 0
key = (compliance_id, requirement_id)
requirement_statuses[key]["total_count"] += 1
if requirement_status == "FAIL":
requirement_statuses[key]["fail_count"] += 1
elif requirement_status == "PASS":
requirement_statuses[key]["pass_count"] += 1
yield {
"id": uuid.uuid4(),
"tenant_id": tenant_id_str,
"inserted_at": utc_datetime_now,
"compliance_id": compliance_id,
"framework": compliance["framework"],
"version": compliance["version"] or "",
"description": requirement.get("description") or "",
"framework": framework,
"version": version,
"description": description,
"region": region,
"requirement_id": requirement_id,
"requirement_status": requirement_status,
@@ -1640,41 +1682,23 @@ def create_compliance_requirements(tenant_id: str, scan_id: str):
"failed_checks": failed_checks,
"total_checks": total_checks,
"scan_id": scan_id_str,
"passed_findings": findings_count_by_compliance.get(
region, {}
)
.get(modeled_compliance_id, {})
.get(requirement_id, {})
.get("pass", 0),
"total_findings": findings_count_by_compliance.get(
region, {}
)
.get(modeled_compliance_id, {})
.get(requirement_id, {})
.get("total", 0),
"passed_findings": passed_findings,
"total_findings": total_findings,
}
)
# Update summary tracking (single-pass optimization)
key = (compliance_id, requirement_id)
requirement_statuses[key]["total_count"] += 1
if requirement_status == "FAIL":
requirement_statuses[key]["fail_count"] += 1
elif requirement_status == "PASS":
requirement_statuses[key]["pass_count"] += 1
# Idempotent re-run: COPY can't ON CONFLICT, so clear this scan's rows first.
# Idempotent re-run: clear this scan's rows before re-inserting.
with rls_transaction(tenant_id):
ComplianceRequirementOverview.objects.filter(scan_id=scan_id).delete()
# Bulk create requirement records using PostgreSQL COPY
_persist_compliance_requirement_rows(tenant_id, compliance_requirement_rows)
requirements_created = _persist_compliance_requirement_rows(
tenant_id, _iter_compliance_requirement_rows()
)
# Create pre-aggregated summaries for fast compliance overview lookups
_create_compliance_summaries(tenant_id, scan_id, requirement_statuses)
return {
"requirements_created": len(compliance_requirement_rows),
"requirements_created": requirements_created,
"regions_processed": list(regions),
"compliance_frameworks": (
list(compliance_template.keys()) if regions else []
+173 -74
View File
@@ -3674,19 +3674,19 @@ class TestAggregateFindingsByRegion:
scan_id = str(uuid.uuid4())
modeled_threatscore_compliance_id = "ProwlerThreatScore-1.0"
# Mock findings with resources
mock_finding1 = MagicMock()
mock_finding1.check_id = "check1"
mock_finding1.status = "FAIL"
mock_finding1.compliance = {modeled_threatscore_compliance_id: ["req1", "req2"]}
mock_resource1 = MagicMock()
mock_resource1.region = "us-east-1"
mock_finding1.small_resources = [mock_resource1]
# (check_id, status, resource_regions, compliance) tuples
finding_rows = [
(
"check1",
"FAIL",
["us-east-1"],
{modeled_threatscore_compliance_id: ["req1", "req2"]},
)
]
mock_queryset = MagicMock()
mock_queryset.only.return_value = mock_queryset
mock_queryset.prefetch_related.return_value = [mock_finding1]
mock_queryset.values_list.return_value = mock_queryset
mock_queryset.iterator.return_value = finding_rows
ctx = MagicMock()
ctx.__enter__.return_value = None
@@ -3700,6 +3700,12 @@ class TestAggregateFindingsByRegion:
)
)
# Streaming query contract: column-scoped values_list + iterator
mock_queryset.values_list.assert_called_once_with(
"check_id", "status", "resource_regions", "compliance"
)
mock_queryset.iterator.assert_called_once()
# Verify structure of check_status_by_region
assert isinstance(check_status_by_region, dict)
assert "us-east-1" in check_status_by_region
@@ -3719,27 +3725,15 @@ class TestAggregateFindingsByRegion:
scan_id = str(uuid.uuid4())
modeled_threatscore_compliance_id = "ProwlerThreatScore-1.0"
# First finding with PASS status
mock_finding1 = MagicMock()
mock_finding1.check_id = "check1"
mock_finding1.status = "PASS"
mock_finding1.compliance = {}
mock_resource1 = MagicMock()
mock_resource1.region = "us-east-1"
mock_finding1.small_resources = [mock_resource1]
# Second finding with FAIL status for same check/region
mock_finding2 = MagicMock()
mock_finding2.check_id = "check1"
mock_finding2.status = "FAIL"
mock_finding2.compliance = {}
mock_resource2 = MagicMock()
mock_resource2.region = "us-east-1"
mock_finding2.small_resources = [mock_resource2]
# Same check/region: PASS first, then FAIL — FAIL must win
finding_rows = [
("check1", "PASS", ["us-east-1"], {}),
("check1", "FAIL", ["us-east-1"], {}),
]
mock_queryset = MagicMock()
mock_queryset.only.return_value = mock_queryset
mock_queryset.prefetch_related.return_value = [mock_finding1, mock_finding2]
mock_queryset.values_list.return_value = mock_queryset
mock_queryset.iterator.return_value = finding_rows
ctx = MagicMock()
ctx.__enter__.return_value = None
@@ -3751,6 +3745,12 @@ class TestAggregateFindingsByRegion:
tenant_id, scan_id, modeled_threatscore_compliance_id
)
# Streaming query contract: column-scoped values_list + iterator
mock_queryset.values_list.assert_called_once_with(
"check_id", "status", "resource_regions", "compliance"
)
mock_queryset.iterator.assert_called_once()
# FAIL should override PASS
assert check_status_by_region["us-east-1"]["check1"] == "FAIL"
@@ -3765,8 +3765,8 @@ class TestAggregateFindingsByRegion:
modeled_threatscore_compliance_id = "ProwlerThreatScore-1.0"
mock_queryset = MagicMock()
mock_queryset.only.return_value = mock_queryset
mock_queryset.prefetch_related.return_value = []
mock_queryset.values_list.return_value = mock_queryset
mock_queryset.iterator.return_value = []
ctx = MagicMock()
ctx.__enter__.return_value = None
@@ -3778,6 +3778,12 @@ class TestAggregateFindingsByRegion:
tenant_id, scan_id, modeled_threatscore_compliance_id
)
# Streaming query contract: column-scoped values_list + iterator
mock_queryset.values_list.assert_called_once_with(
"check_id", "status", "resource_regions", "compliance"
)
mock_queryset.iterator.assert_called_once()
# Verify filter was called with muted=False
mock_findings_filter.assert_called_once_with(
tenant_id=tenant_id,
@@ -3796,27 +3802,25 @@ class TestAggregateFindingsByRegion:
scan_id = str(uuid.uuid4())
modeled_threatscore_compliance_id = "ProwlerThreatScore-1.0"
# Finding with PASS status
mock_finding1 = MagicMock()
mock_finding1.check_id = "check1"
mock_finding1.status = "PASS"
mock_finding1.compliance = {modeled_threatscore_compliance_id: ["req1"]}
mock_resource1 = MagicMock()
mock_resource1.region = "us-east-1"
mock_finding1.small_resources = [mock_resource1]
# Finding with FAIL status
mock_finding2 = MagicMock()
mock_finding2.check_id = "check2"
mock_finding2.status = "FAIL"
mock_finding2.compliance = {modeled_threatscore_compliance_id: ["req1"]}
mock_resource2 = MagicMock()
mock_resource2.region = "us-east-1"
mock_finding2.small_resources = [mock_resource2]
# PASS and FAIL findings mapped to the same ThreatScore requirement
finding_rows = [
(
"check1",
"PASS",
["us-east-1"],
{modeled_threatscore_compliance_id: ["req1"]},
),
(
"check2",
"FAIL",
["us-east-1"],
{modeled_threatscore_compliance_id: ["req1"]},
),
]
mock_queryset = MagicMock()
mock_queryset.only.return_value = mock_queryset
mock_queryset.prefetch_related.return_value = [mock_finding1, mock_finding2]
mock_queryset.values_list.return_value = mock_queryset
mock_queryset.iterator.return_value = finding_rows
ctx = MagicMock()
ctx.__enter__.return_value = None
@@ -3828,6 +3832,12 @@ class TestAggregateFindingsByRegion:
tenant_id, scan_id, modeled_threatscore_compliance_id
)
# Streaming query contract: column-scoped values_list + iterator
mock_queryset.values_list.assert_called_once_with(
"check_id", "status", "resource_regions", "compliance"
)
mock_queryset.iterator.assert_called_once()
# Verify compliance counts
normalized_id = re.sub(
r"[^a-z0-9]", "", modeled_threatscore_compliance_id.lower()
@@ -3850,27 +3860,15 @@ class TestAggregateFindingsByRegion:
scan_id = str(uuid.uuid4())
modeled_threatscore_compliance_id = "ProwlerThreatScore-1.0"
# Finding in us-east-1
mock_finding1 = MagicMock()
mock_finding1.check_id = "check1"
mock_finding1.status = "FAIL"
mock_finding1.compliance = {}
mock_resource1 = MagicMock()
mock_resource1.region = "us-east-1"
mock_finding1.small_resources = [mock_resource1]
# Finding in us-west-2
mock_finding2 = MagicMock()
mock_finding2.check_id = "check1"
mock_finding2.status = "PASS"
mock_finding2.compliance = {}
mock_resource2 = MagicMock()
mock_resource2.region = "us-west-2"
mock_finding2.small_resources = [mock_resource2]
# One finding per region
finding_rows = [
("check1", "FAIL", ["us-east-1"], {}),
("check1", "PASS", ["us-west-2"], {}),
]
mock_queryset = MagicMock()
mock_queryset.only.return_value = mock_queryset
mock_queryset.prefetch_related.return_value = [mock_finding1, mock_finding2]
mock_queryset.values_list.return_value = mock_queryset
mock_queryset.iterator.return_value = finding_rows
ctx = MagicMock()
ctx.__enter__.return_value = None
@@ -3882,6 +3880,12 @@ class TestAggregateFindingsByRegion:
tenant_id, scan_id, modeled_threatscore_compliance_id
)
# Streaming query contract: column-scoped values_list + iterator
mock_queryset.values_list.assert_called_once_with(
"check_id", "status", "resource_regions", "compliance"
)
mock_queryset.iterator.assert_called_once()
# Verify both regions are present with correct statuses
assert "us-east-1" in check_status_by_region
assert "us-west-2" in check_status_by_region
@@ -3890,17 +3894,26 @@ class TestAggregateFindingsByRegion:
@patch("tasks.jobs.scan.Finding.all_objects.filter")
@patch("tasks.jobs.scan.rls_transaction")
def test_aggregate_findings_by_region_empty_findings(
def test_aggregate_findings_by_region_multi_region_finding(
self, mock_rls_transaction, mock_findings_filter
):
"""Test with no findings - should return empty dicts."""
"""A finding with multiple resource_regions is tallied in every region."""
tenant_id = str(uuid.uuid4())
scan_id = str(uuid.uuid4())
modeled_threatscore_compliance_id = "ProwlerThreatScore-1.0"
finding_rows = [
(
"check1",
"FAIL",
["us-east-1", "eu-west-1"],
{modeled_threatscore_compliance_id: ["req1"]},
)
]
mock_queryset = MagicMock()
mock_queryset.only.return_value = mock_queryset
mock_queryset.prefetch_related.return_value = []
mock_queryset.values_list.return_value = mock_queryset
mock_queryset.iterator.return_value = finding_rows
ctx = MagicMock()
ctx.__enter__.return_value = None
@@ -3914,6 +3927,92 @@ class TestAggregateFindingsByRegion:
)
)
# Streaming query contract: column-scoped values_list + iterator
mock_queryset.values_list.assert_called_once_with(
"check_id", "status", "resource_regions", "compliance"
)
mock_queryset.iterator.assert_called_once()
normalized_id = re.sub(
r"[^a-z0-9]", "", modeled_threatscore_compliance_id.lower()
)
for region in ("us-east-1", "eu-west-1"):
assert check_status_by_region[region]["check1"] == "FAIL"
req_stats = findings_count_by_compliance[region][normalized_id]["req1"]
assert req_stats == {"total": 1, "pass": 0}
@patch("tasks.jobs.scan.Finding.all_objects.filter")
@patch("tasks.jobs.scan.rls_transaction")
def test_aggregate_findings_by_region_skips_empty_regions(
self, mock_rls_transaction, mock_findings_filter
):
"""A finding with no denormalized regions contributes nothing."""
tenant_id = str(uuid.uuid4())
scan_id = str(uuid.uuid4())
modeled_threatscore_compliance_id = "ProwlerThreatScore-1.0"
finding_rows = [
("check1", "FAIL", [], {modeled_threatscore_compliance_id: ["req1"]}),
("check2", "PASS", None, {}),
]
mock_queryset = MagicMock()
mock_queryset.values_list.return_value = mock_queryset
mock_queryset.iterator.return_value = finding_rows
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
mock_findings_filter.return_value = mock_queryset
check_status_by_region, findings_count_by_compliance = (
_aggregate_findings_by_region(
tenant_id, scan_id, modeled_threatscore_compliance_id
)
)
# Streaming query contract: column-scoped values_list + iterator
mock_queryset.values_list.assert_called_once_with(
"check_id", "status", "resource_regions", "compliance"
)
mock_queryset.iterator.assert_called_once()
assert check_status_by_region == {}
assert findings_count_by_compliance == {}
@patch("tasks.jobs.scan.Finding.all_objects.filter")
@patch("tasks.jobs.scan.rls_transaction")
def test_aggregate_findings_by_region_empty_findings(
self, mock_rls_transaction, mock_findings_filter
):
"""Test with no findings - should return empty dicts."""
tenant_id = str(uuid.uuid4())
scan_id = str(uuid.uuid4())
modeled_threatscore_compliance_id = "ProwlerThreatScore-1.0"
mock_queryset = MagicMock()
mock_queryset.values_list.return_value = mock_queryset
mock_queryset.iterator.return_value = []
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
mock_findings_filter.return_value = mock_queryset
check_status_by_region, findings_count_by_compliance = (
_aggregate_findings_by_region(
tenant_id, scan_id, modeled_threatscore_compliance_id
)
)
# Streaming query contract: column-scoped values_list + iterator
mock_queryset.values_list.assert_called_once_with(
"check_id", "status", "resource_regions", "compliance"
)
mock_queryset.iterator.assert_called_once()
assert check_status_by_region == {}
assert findings_count_by_compliance == {}