feat(db): optimize write queries for scan related tasks (#9190)

Co-authored-by: Josema Camacho <josema@prowler.com>
This commit is contained in:
Víctor Fernández Poyatos
2025-11-13 12:27:57 +01:00
committed by GitHub
parent ce5f2cc5ed
commit 46666d29d3
11 changed files with 2347 additions and 502 deletions

View File

@@ -17,12 +17,15 @@ All notable changes to the **Prowler API** are documented in this file.
- Tenant-wide ThreatScore overview aggregation and snapshot persistence with backfill support [(#9148)](https://github.com/prowler-cloud/prowler/pull/9148)
- Support for MongoDB Atlas provider [(#9167)](https://github.com/prowler-cloud/prowler/pull/9167)
### Changed
- Optimized database write queries for scan related tasks [(#9190)](https://github.com/prowler-cloud/prowler/pull/9190)
### Security
- Django updated to the latest 5.1 security release, 5.1.14, due to problems with potential [SQL injection](https://github.com/prowler-cloud/prowler/security/dependabot/113) and [denial-of-service vulnerability](https://github.com/prowler-cloud/prowler/security/dependabot/114) [(#9176)](https://github.com/prowler-cloud/prowler/pull/9176)
---
## [1.14.2] (Prowler UNRELEASED)
## [1.14.2] (Prowler 5.13.2)
### Fixed
- Update unique constraint for `Provider` model to exclude soft-deleted entries, resolving duplicate errors when re-deleting providers. [(#9054)](https://github.com/prowler-cloud/prowler/pull/9054)

View File

@@ -0,0 +1,25 @@
from django.contrib.postgres.operations import RemoveIndexConcurrently
from django.db import migrations
class Migration(migrations.Migration):
    """Drop ComplianceRequirementOverview indexes that are leading prefixes
    of wider composite indexes and therefore redundant for lookups."""

    # DROP INDEX CONCURRENTLY cannot run inside a transaction block, so the
    # migration must be declared non-atomic.
    atomic = False

    dependencies = [
        ("api", "0057_threatscoresnapshot"),
    ]

    operations = [
        # Each removed index is a prefix of a surviving composite index
        # (e.g. cro_scan_comp_reg_idx, cro_scan_comp_req_reg_idx), so the
        # planner can still serve these query shapes from the wider index.
        RemoveIndexConcurrently(
            model_name="compliancerequirementoverview",
            name="cro_tenant_scan_idx",
        ),
        RemoveIndexConcurrently(
            model_name="compliancerequirementoverview",
            name="cro_scan_comp_idx",
        ),
        RemoveIndexConcurrently(
            model_name="compliancerequirementoverview",
            name="cro_scan_comp_req_idx",
        ),
    ]

View File

@@ -0,0 +1,75 @@
# Generated by Django 5.1.13 on 2025-10-30 15:23
import uuid
import django.db.models.deletion
from django.db import migrations, models
import api.rls
class Migration(migrations.Migration):
    """Create the ComplianceOverviewSummary table with its scan/tenant
    foreign keys, per-scan uniqueness constraint, index, and row-level
    security policy."""

    dependencies = [
        ("api", "0058_drop_redundant_compliance_requirement_indexes"),
    ]

    operations = [
        migrations.CreateModel(
            name="ComplianceOverviewSummary",
            fields=[
                (
                    "id",
                    models.UUIDField(
                        default=uuid.uuid4,
                        editable=False,
                        primary_key=True,
                        serialize=False,
                    ),
                ),
                ("inserted_at", models.DateTimeField(auto_now_add=True)),
                ("compliance_id", models.TextField()),
                # Pre-aggregated requirement counters, computed across all
                # regions of the scan.
                ("requirements_passed", models.IntegerField(default=0)),
                ("requirements_failed", models.IntegerField(default=0)),
                ("requirements_manual", models.IntegerField(default=0)),
                ("total_requirements", models.IntegerField(default=0)),
                (
                    "scan",
                    models.ForeignKey(
                        on_delete=django.db.models.deletion.CASCADE,
                        related_name="compliance_summaries",
                        related_query_name="compliance_summary",
                        to="api.scan",
                    ),
                ),
                (
                    "tenant",
                    models.ForeignKey(
                        on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
                    ),
                ),
            ],
            options={
                "db_table": "compliance_overview_summaries",
                "abstract": False,
                "indexes": [
                    # Common access path: all summaries for a given scan.
                    models.Index(
                        fields=["tenant_id", "scan_id"], name="cos_tenant_scan_idx"
                    )
                ],
                "constraints": [
                    # One summary row per compliance framework per scan.
                    models.UniqueConstraint(
                        fields=("tenant_id", "scan_id", "compliance_id"),
                        name="unique_compliance_summary_per_scan",
                    )
                ],
            },
        ),
        # Enforce tenant isolation at the database level via RLS policies.
        migrations.AddConstraint(
            model_name="complianceoverviewsummary",
            constraint=api.rls.RowLevelSecurityConstraint(
                "tenant_id",
                name="rls_on_complianceoverviewsummary",
                statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
            ),
        ),
    ]

View File

@@ -1371,35 +1371,70 @@ class ComplianceRequirementOverview(RowLevelSecurityProtectedModel):
),
]
indexes = [
models.Index(fields=["tenant_id", "scan_id"], name="cro_tenant_scan_idx"),
models.Index(
fields=["tenant_id", "scan_id", "compliance_id"],
name="cro_scan_comp_idx",
),
models.Index(
fields=["tenant_id", "scan_id", "compliance_id", "region"],
name="cro_scan_comp_reg_idx",
),
models.Index(
fields=["tenant_id", "scan_id", "compliance_id", "requirement_id"],
name="cro_scan_comp_req_idx",
),
models.Index(
fields=[
"tenant_id",
"scan_id",
"compliance_id",
"requirement_id",
"region",
],
name="cro_scan_comp_req_reg_idx",
),
]
class JSONAPIMeta:
resource_name = "compliance-requirements-overviews"
class ComplianceOverviewSummary(RowLevelSecurityProtectedModel):
    """
    Pre-aggregated compliance overview aggregated across ALL regions.
    One row per (scan_id, compliance_id) combination.
    This table optimizes the common case where users view overall compliance
    without filtering by region. For region-specific views, the detailed
    ComplianceRequirementOverview table is used instead.
    """

    id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
    inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
    # Deleting a scan cascades to its summary rows.
    scan = models.ForeignKey(
        Scan,
        on_delete=models.CASCADE,
        related_name="compliance_summaries",
        related_query_name="compliance_summary",
    )
    compliance_id = models.TextField(blank=False)

    # Pre-aggregated scores (computed across ALL regions)
    requirements_passed = models.IntegerField(default=0)
    requirements_failed = models.IntegerField(default=0)
    requirements_manual = models.IntegerField(default=0)
    total_requirements = models.IntegerField(default=0)

    class Meta(RowLevelSecurityProtectedModel.Meta):
        db_table = "compliance_overview_summaries"

        constraints = [
            # Guarantees a single summary row per (tenant, scan, compliance).
            models.UniqueConstraint(
                fields=("tenant_id", "scan_id", "compliance_id"),
                name="unique_compliance_summary_per_scan",
            ),
            # NOTE(review): migration 0059 creates this RLS constraint with
            # "UPDATE" included in its statements, while this declaration
            # omits it — confirm which statement set is intended, otherwise
            # makemigrations will detect a pending change.
            RowLevelSecurityConstraint(
                field="tenant_id",
                name="rls_on_%(class)s",
                statements=["SELECT", "INSERT", "DELETE"],
            ),
        ]
        indexes = [
            # Covers the common lookup: all summaries for one scan.
            models.Index(
                fields=["tenant_id", "scan_id"],
                name="cos_tenant_scan_idx",
            ),
        ]

    class JSONAPIMeta:
        resource_name = "compliance-overview-summaries"
class ScanSummary(RowLevelSecurityProtectedModel):
objects = ActiveProviderManager()
all_objects = models.Manager()

View File

@@ -35,6 +35,8 @@ from rest_framework.response import Response
from api.compliance import get_compliance_frameworks
from api.db_router import MainRouter
from api.models import (
ComplianceOverviewSummary,
ComplianceRequirementOverview,
Finding,
Integration,
Invitation,
@@ -55,6 +57,7 @@ from api.models import (
Scan,
ScanSummary,
StateChoices,
StatusChoices,
Task,
TenantAPIKey,
ThreatScoreSnapshot,
@@ -5814,16 +5817,44 @@ class TestProviderGroupMembershipViewSet:
@pytest.mark.django_db
class TestComplianceOverviewViewSet:
def test_compliance_overview_list_none(self, authenticated_client):
@pytest.fixture(autouse=True)
def mock_backfill_task(self):
with patch("api.v1.views.backfill_compliance_summaries_task.delay") as mock:
yield mock
def test_compliance_overview_list_none(
self,
authenticated_client,
tenants_fixture,
providers_fixture,
mock_backfill_task,
):
tenant = tenants_fixture[0]
provider = providers_fixture[0]
scan = Scan.objects.create(
name="empty-compliance-scan",
provider=provider,
trigger=Scan.TriggerChoices.MANUAL,
state=StateChoices.COMPLETED,
tenant=tenant,
)
response = authenticated_client.get(
reverse("complianceoverview-list"),
{"filter[scan_id]": "8d20ac7d-4cbc-435e-85f4-359be37af821"},
{"filter[scan_id]": str(scan.id)},
)
assert response.status_code == status.HTTP_200_OK
assert len(response.json()["data"]) == 0
mock_backfill_task.assert_called_once()
_, kwargs = mock_backfill_task.call_args
assert kwargs["scan_id"] == str(scan.id)
assert str(kwargs["tenant_id"]) == str(tenant.id)
def test_compliance_overview_list(
self, authenticated_client, compliance_requirements_overviews_fixture
self,
authenticated_client,
compliance_requirements_overviews_fixture,
mock_backfill_task,
):
# List compliance overviews with existing data
requirement_overview1 = compliance_requirements_overviews_fixture[0]
@@ -5853,6 +5884,90 @@ class TestComplianceOverviewViewSet:
assert "requirements_failed" in attributes
assert "requirements_manual" in attributes
assert "total_requirements" in attributes
mock_backfill_task.assert_called_once()
_, kwargs = mock_backfill_task.call_args
assert kwargs["scan_id"] == scan_id
def test_compliance_overview_list_uses_preaggregated_summaries(
    self,
    authenticated_client,
    tenants_fixture,
    providers_fixture,
    mock_backfill_task,
):
    """When ComplianceOverviewSummary rows exist for the scan, the list
    endpoint serves them directly and never schedules a backfill task."""
    tenant = tenants_fixture[0]
    provider = providers_fixture[0]
    scan = Scan.objects.create(
        name="preaggregated-scan",
        provider=provider,
        trigger=Scan.TriggerChoices.MANUAL,
        state=StateChoices.COMPLETED,
        tenant=tenant,
    )
    # Detailed requirement row exists but must NOT be used for the response;
    # the pre-aggregated summary below carries different (distinct) counts.
    ComplianceRequirementOverview.objects.create(
        tenant=tenant,
        scan=scan,
        compliance_id="cis_1.4_aws",
        framework="CIS-1.4-AWS",
        version="1.4",
        description="CIS AWS Foundations Benchmark v1.4.0",
        region="eu-west-1",
        requirement_id="framework-metadata",
        requirement_status=StatusChoices.PASS,
        passed_checks=1,
        failed_checks=0,
        total_checks=1,
    )
    ComplianceOverviewSummary.objects.create(
        tenant=tenant,
        scan=scan,
        compliance_id="cis_1.4_aws",
        requirements_passed=5,
        requirements_failed=1,
        requirements_manual=2,
        total_requirements=8,
    )
    response = authenticated_client.get(
        reverse("complianceoverview-list"),
        {"filter[scan_id]": str(scan.id)},
    )
    assert response.status_code == status.HTTP_200_OK
    data = response.json()["data"]
    assert len(data) == 1
    overview = data[0]
    assert overview["id"] == "cis_1.4_aws"
    # Counts must come from the summary table, not the requirement rows.
    assert overview["attributes"]["requirements_passed"] == 5
    assert overview["attributes"]["requirements_failed"] == 1
    assert overview["attributes"]["requirements_manual"] == 2
    assert overview["attributes"]["total_requirements"] == 8
    assert "framework" in overview["attributes"]
    assert "version" in overview["attributes"]
    mock_backfill_task.assert_not_called()
def test_compliance_overview_region_filter_skips_backfill(
    self,
    authenticated_client,
    compliance_requirements_overviews_fixture,
    mock_backfill_task,
):
    """Region-filtered requests take the detailed aggregation path and
    must not enqueue the summary backfill task."""
    requirement_overview = compliance_requirements_overviews_fixture[0]
    scan_id = str(requirement_overview.scan.id)
    response = authenticated_client.get(
        reverse("complianceoverview-list"),
        {
            "filter[scan_id]": scan_id,
            "filter[region]": requirement_overview.region,
        },
    )
    assert response.status_code == status.HTTP_200_OK
    assert len(response.json()["data"]) >= 1
    mock_backfill_task.assert_not_called()
def test_compliance_overview_metadata(
self, authenticated_client, compliance_requirements_overviews_fixture
@@ -6006,6 +6121,11 @@ class TestComplianceOverviewViewSet:
requirement_overview1 = compliance_requirements_overviews_fixture[0]
scan_id = str(requirement_overview1.scan.id)
# Remove existing compliance data so the view falls back to task checks
scan = requirement_overview1.scan
ComplianceOverviewSummary.objects.filter(scan=scan).delete()
ComplianceRequirementOverview.objects.filter(scan=scan).delete()
# Mock a running task
with patch.object(
ComplianceOverviewViewSet, "get_task_response_if_running"
@@ -6033,6 +6153,11 @@ class TestComplianceOverviewViewSet:
requirement_overview1 = compliance_requirements_overviews_fixture[0]
scan_id = str(requirement_overview1.scan.id)
# Remove existing compliance data so the view falls back to task checks
scan = requirement_overview1.scan
ComplianceOverviewSummary.objects.filter(scan=scan).delete()
ComplianceRequirementOverview.objects.filter(scan=scan).delete()
# Mock a failed task
with patch.object(
ComplianceOverviewViewSet, "get_task_response_if_running"
@@ -6056,6 +6181,8 @@ class TestComplianceOverviewViewSet:
("framework", "framework", 1),
("version", "version", 1),
("region", "region", 1),
("region__in", "region", 1),
("region.in", "region", 1),
],
)
def test_compliance_overview_filters(

View File

@@ -75,6 +75,7 @@ from rest_framework_simplejwt.exceptions import InvalidToken, TokenError
from tasks.beat import schedule_provider_scan
from tasks.jobs.export import get_s3_client
from tasks.tasks import (
backfill_compliance_summaries_task,
backfill_scan_resource_summaries_task,
check_integration_connection_task,
check_lighthouse_connection_task,
@@ -126,6 +127,7 @@ from api.filters import (
UserFilter,
)
from api.models import (
ComplianceOverviewSummary,
ComplianceRequirementOverview,
Finding,
Integration,
@@ -3398,6 +3400,168 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
def retrieve(self, request, *args, **kwargs):
raise MethodNotAllowed(method="GET")
def _compliance_summaries_queryset(self, scan_id):
    """Pre-aggregated compliance summaries for *scan_id*, limited to the
    providers the requesting user's role is allowed to see."""
    queryset = ComplianceOverviewSummary.objects.filter(
        tenant_id=self.request.tenant_id,
        scan_id=scan_id,
    )
    user_role = get_role(self.request.user)
    if getattr(user_role, Permissions.UNLIMITED_VISIBILITY.value, False):
        # Unlimited visibility: no provider-group restriction applies.
        return queryset
    visible_providers = Provider.all_objects.filter(
        provider_groups__in=user_role.provider_groups.all()
    ).distinct()
    return queryset.filter(scan__provider__in=visible_providers)
def _get_compliance_template(self, *, provider=None, scan_id=None):
    """Look up the static compliance template for a provider.

    When only *scan_id* is given, the provider is resolved from the scan.
    Returns an empty dict when no provider can be determined.
    """
    resolved = provider
    if resolved is None and scan_id is not None:
        resolved = (
            Scan.all_objects.select_related("provider").get(pk=scan_id).provider
        )
    if resolved:
        return PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE.get(resolved.provider, {})
    return {}
def _aggregate_compliance_overview(self, queryset, template_metadata=None):
    """Collapse requirement-level rows into per-compliance overview dicts.

    Each (compliance_id, requirement_id) group is FAIL when any region
    failed, PASS when every region passed, and MANUAL otherwise; the
    per-requirement statuses are then tallied per compliance framework.

    Args:
        queryset: ComplianceRequirementOverview queryset already filtered.
        template_metadata: Optional dict mapping compliance_id -> metadata
            used to enrich framework/version fields.
    """
    metadata_by_id = template_metadata or {}

    # framework/version fall back to the rows themselves when the template
    # has no entry for a compliance_id.
    row_metadata = {}
    for row in queryset.values("compliance_id", "framework", "version").distinct():
        row_metadata[row["compliance_id"]] = {
            "framework": row["framework"],
            "version": row["version"],
        }

    grouped = queryset.values("compliance_id", "requirement_id").annotate(
        fail_count=Count("id", filter=Q(requirement_status="FAIL")),
        pass_count=Count("id", filter=Q(requirement_status="PASS")),
        total_count=Count("id"),
    )

    totals = {}
    for group in grouped:
        bucket = totals.setdefault(
            group["compliance_id"],
            {
                "total_requirements": 0,
                "requirements_passed": 0,
                "requirements_failed": 0,
                "requirements_manual": 0,
            },
        )
        bucket["total_requirements"] += 1
        if group["fail_count"] > 0:
            bucket["requirements_failed"] += 1
        elif group["pass_count"] == group["total_count"]:
            bucket["requirements_passed"] += 1
        else:
            bucket["requirements_manual"] += 1

    payload = []
    for compliance_id, counters in totals.items():
        template = metadata_by_id.get(compliance_id, {})
        fallback = row_metadata.get(compliance_id, {})
        payload.append(
            {
                "id": compliance_id,
                "compliance_id": compliance_id,
                "framework": template.get("framework")
                or fallback.get("framework", ""),
                "version": template.get("version") or fallback.get("version", ""),
                "requirements_passed": counters["requirements_passed"],
                "requirements_failed": counters["requirements_failed"],
                "requirements_manual": counters["requirements_manual"],
                "total_requirements": counters["total_requirements"],
            }
        )
    return self.get_serializer(payload, many=True).data
def _task_response_if_running(self, scan_id):
    """Return a task-status response when the compliance task is in flight.

    Returns None when no matching task exists; returns a 500 response when
    the generating task has already failed.
    """
    task_kwargs = {"tenant_id": self.request.tenant_id, "scan_id": scan_id}
    try:
        return self.get_task_response_if_running(
            task_name="scan-compliance-overviews",
            task_kwargs=task_kwargs,
            raise_on_not_found=False,
        )
    except TaskFailedException:
        return Response(
            {"detail": "Task failed to generate compliance overview data."},
            status=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )
def _list_with_region_filter(self, scan_id, region_filter):
    """
    Fall back to the detailed ComplianceRequirementOverview query when a
    region filter is applied.

    This uses the original aggregation logic across the filtered regions.

    Args:
        scan_id: Scan UUID whose requirement rows are aggregated.
        region_filter: Region name, or comma-separated region names, taken
            from the query string.

    Returns:
        Response with aggregated overview data, or a task-status response
        when no data exists yet and the generating task is still running
        (or has failed).
    """
    # str.split(",") already yields [region_filter] when there is no comma,
    # so no special-casing of the single-region form is needed.
    regions = region_filter.split(",")
    queryset = self.filter_queryset(self.get_queryset()).filter(
        scan_id=scan_id,
        region__in=regions,
    )
    data = self._aggregate_compliance_overview(queryset)
    if data:
        return Response(data)
    # No aggregated rows yet: surface the generating task's state instead of
    # an empty payload when one is running or failed.
    task_response = self._task_response_if_running(scan_id)
    if task_response:
        return task_response
    return Response(data)
def _list_without_region_aggregation(self, scan_id):
    """Aggregate requirement rows across ALL regions for *scan_id*.

    Fallback path used when pre-computed compliance summaries do not exist
    yet for the scan.
    """
    rows = self.filter_queryset(self.get_queryset()).filter(scan_id=scan_id)
    data = self._aggregate_compliance_overview(
        rows, template_metadata=self._get_compliance_template(scan_id=scan_id)
    )
    if not data:
        # Nothing aggregated: report the generating task's state if any.
        task_response = self._task_response_if_running(scan_id)
        if task_response:
            return task_response
    return Response(data)
def list(self, request, *args, **kwargs):
scan_id = request.query_params.get("filter[scan_id]")
if not scan_id:
@@ -3411,77 +3575,41 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
}
]
)
try:
if task := self.get_task_response_if_running(
task_name="scan-compliance-overviews",
task_kwargs={"tenant_id": self.request.tenant_id, "scan_id": scan_id},
raise_on_not_found=False,
):
return task
except TaskFailedException:
return Response(
{"detail": "Task failed to generate compliance overview data."},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
region_filter = request.query_params.get(
"filter[region]"
) or request.query_params.get("filter[region__in]")
if region_filter:
# Fall back to detailed query with region filtering
return self._list_with_region_filter(scan_id, region_filter)
summaries = list(self._compliance_summaries_queryset(scan_id))
if not summaries:
# Trigger async backfill for next time
backfill_compliance_summaries_task.delay(
tenant_id=self.request.tenant_id, scan_id=scan_id
)
queryset = self.filter_queryset(self.filter_queryset(self.get_queryset()))
# Use fallback aggregation for this request
return self._list_without_region_aggregation(scan_id)
requirement_status_subquery = queryset.values(
"compliance_id", "requirement_id"
).annotate(
fail_count=Count("id", filter=Q(requirement_status="FAIL")),
pass_count=Count("id", filter=Q(requirement_status="PASS")),
total_count=Count("id"),
)
compliance_data = {}
framework_info = {}
for item in queryset.values("compliance_id", "framework", "version").distinct():
framework_info[item["compliance_id"]] = {
"framework": item["framework"],
"version": item["version"],
}
for item in requirement_status_subquery:
compliance_id = item["compliance_id"]
if item["fail_count"] > 0:
req_status = "FAIL"
elif item["pass_count"] == item["total_count"]:
req_status = "PASS"
else:
req_status = "MANUAL"
if compliance_id not in compliance_data:
compliance_data[compliance_id] = {
"total_requirements": 0,
"requirements_passed": 0,
"requirements_failed": 0,
"requirements_manual": 0,
}
compliance_data[compliance_id]["total_requirements"] += 1
if req_status == "PASS":
compliance_data[compliance_id]["requirements_passed"] += 1
elif req_status == "FAIL":
compliance_data[compliance_id]["requirements_failed"] += 1
else:
compliance_data[compliance_id]["requirements_manual"] += 1
# Get compliance template for provider to enrich with framework/version
compliance_template = self._get_compliance_template(scan_id=scan_id)
# Convert to response format with framework/version enrichment
response_data = []
for compliance_id, data in compliance_data.items():
framework = framework_info.get(compliance_id, {})
for summary in summaries:
compliance_metadata = compliance_template.get(summary.compliance_id, {})
response_data.append(
{
"id": compliance_id,
"compliance_id": compliance_id,
"framework": framework.get("framework", ""),
"version": framework.get("version", ""),
"requirements_passed": data["requirements_passed"],
"requirements_failed": data["requirements_failed"],
"requirements_manual": data["requirements_manual"],
"total_requirements": data["total_requirements"],
"id": summary.compliance_id,
"compliance_id": summary.compliance_id,
"framework": compliance_metadata.get("framework", ""),
"version": compliance_metadata.get("version", ""),
"requirements_passed": summary.requirements_passed,
"requirements_failed": summary.requirements_failed,
"requirements_manual": summary.requirements_manual,
"total_requirements": summary.total_requirements,
}
)
@@ -3502,18 +3630,6 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
}
]
)
try:
if task := self.get_task_response_if_running(
task_name="scan-compliance-overviews",
task_kwargs={"tenant_id": self.request.tenant_id, "scan_id": scan_id},
raise_on_not_found=False,
):
return task
except TaskFailedException:
return Response(
{"detail": "Task failed to generate compliance overview data."},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
regions = list(
self.get_queryset()
.filter(scan_id=scan_id)
@@ -3523,6 +3639,15 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
)
result = {"regions": regions}
if regions:
serializer = self.get_serializer(data=result)
serializer.is_valid(raise_exception=True)
return Response(serializer.data, status=status.HTTP_200_OK)
task_response = self._task_response_if_running(scan_id)
if task_response:
return task_response
serializer = self.get_serializer(data=result)
serializer.is_valid(raise_exception=True)
return Response(serializer.data, status=status.HTTP_200_OK)
@@ -3555,18 +3680,6 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
}
]
)
try:
if task := self.get_task_response_if_running(
task_name="scan-compliance-overviews",
task_kwargs={"tenant_id": self.request.tenant_id, "scan_id": scan_id},
raise_on_not_found=False,
):
return task
except TaskFailedException:
return Response(
{"detail": "Task failed to generate compliance overview data."},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
filtered_queryset = self.filter_queryset(self.get_queryset())
all_requirements = filtered_queryset.values(
@@ -3626,6 +3739,13 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
requirements_summary, many=True
)
if requirements_summary:
return Response(serializer.data, status=status.HTTP_200_OK)
task_response = self._task_response_if_running(scan_id)
if task_response:
return task_response
return Response(serializer.data, status=status.HTTP_200_OK)
@action(detail=False, methods=["get"], url_name="attributes")
@@ -3644,15 +3764,6 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
)
provider_type = None
try:
sample_requirement = (
self.get_queryset().filter(compliance_id=compliance_id).first()
)
if sample_requirement:
provider_type = sample_requirement.scan.provider.provider
except Exception:
pass
# If we couldn't determine from database, try each provider type
if not provider_type:

View File

@@ -1,5 +1,10 @@
from collections import defaultdict
from api.db_router import READ_REPLICA_ALIAS
from api.db_utils import rls_transaction
from api.models import (
ComplianceOverviewSummary,
ComplianceRequirementOverview,
Resource,
ResourceFindingMapping,
ResourceScanSummary,
@@ -9,7 +14,7 @@ from api.models import (
def backfill_resource_scan_summaries(tenant_id: str, scan_id: str):
with rls_transaction(tenant_id):
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
if ResourceScanSummary.objects.filter(
tenant_id=tenant_id, scan_id=scan_id
).exists():
@@ -59,3 +64,114 @@ def backfill_resource_scan_summaries(tenant_id: str, scan_id: str):
)
return {"status": "backfilled", "inserted": len(summaries)}
def backfill_compliance_summaries(tenant_id: str, scan_id: str) -> dict:
    """
    Backfill ComplianceOverviewSummary records for a completed scan.

    This function checks if summary records already exist for the scan.
    If not, it aggregates compliance requirement data and creates the summaries.

    Args:
        tenant_id: Target tenant UUID
        scan_id: Scan UUID to backfill

    Returns:
        dict: Status indicating whether backfill was performed
    """
    # Idempotency fast-path: any existing summary row means this scan was
    # already processed.
    with rls_transaction(tenant_id):
        if ComplianceOverviewSummary.objects.filter(
            tenant_id=tenant_id, scan_id=scan_id
        ).exists():
            return {"status": "already backfilled"}

    with rls_transaction(tenant_id):
        # Both COMPLETED and FAILED scans are eligible: a failed scan may
        # still hold partial compliance rows worth summarizing.
        if not Scan.objects.filter(
            tenant_id=tenant_id,
            id=scan_id,
            state__in=(StateChoices.COMPLETED, StateChoices.FAILED),
        ).exists():
            return {"status": "scan is not completed"}

        # Fetch all compliance requirement overview rows for this scan
        requirement_rows = ComplianceRequirementOverview.objects.filter(
            tenant_id=tenant_id, scan_id=scan_id
        ).values(
            "compliance_id",
            "requirement_id",
            "requirement_status",
        )
        if not requirement_rows:
            return {"status": "no compliance data to backfill"}

        # Group by (compliance_id, requirement_id) across regions
        requirement_statuses = defaultdict(
            lambda: {"fail_count": 0, "pass_count": 0, "total_count": 0}
        )
        for row in requirement_rows:
            compliance_id = row["compliance_id"]
            requirement_id = row["requirement_id"]
            requirement_status = row["requirement_status"]

            # Aggregate requirement status across regions
            key = (compliance_id, requirement_id)
            requirement_statuses[key]["total_count"] += 1
            if requirement_status == "FAIL":
                requirement_statuses[key]["fail_count"] += 1
            elif requirement_status == "PASS":
                requirement_statuses[key]["pass_count"] += 1

        # Determine per-requirement status and aggregate to compliance level
        compliance_summaries = defaultdict(
            lambda: {
                "total_requirements": 0,
                "requirements_passed": 0,
                "requirements_failed": 0,
                "requirements_manual": 0,
            }
        )
        for (compliance_id, requirement_id), counts in requirement_statuses.items():
            # Apply business rule: any FAIL → requirement fails
            if counts["fail_count"] > 0:
                req_status = "FAIL"
            elif counts["pass_count"] == counts["total_count"]:
                req_status = "PASS"
            else:
                # Any status other than PASS/FAIL (e.g. MANUAL) keeps
                # pass_count below total_count and lands here.
                req_status = "MANUAL"

            # Aggregate to compliance level
            compliance_summaries[compliance_id]["total_requirements"] += 1
            if req_status == "PASS":
                compliance_summaries[compliance_id]["requirements_passed"] += 1
            elif req_status == "FAIL":
                compliance_summaries[compliance_id]["requirements_failed"] += 1
            else:
                compliance_summaries[compliance_id]["requirements_manual"] += 1

        # Create summary objects
        summary_objects = []
        for compliance_id, data in compliance_summaries.items():
            summary_objects.append(
                ComplianceOverviewSummary(
                    tenant_id=tenant_id,
                    scan_id=scan_id,
                    compliance_id=compliance_id,
                    requirements_passed=data["requirements_passed"],
                    requirements_failed=data["requirements_failed"],
                    requirements_manual=data["requirements_manual"],
                    total_requirements=data["total_requirements"],
                )
            )

        # Bulk insert summaries
        if summary_objects:
            # ignore_conflicts makes concurrent backfills safe: the unique
            # (tenant, scan, compliance) constraint silently drops duplicates.
            ComplianceOverviewSummary.objects.bulk_create(
                summary_objects, batch_size=500, ignore_conflicts=True
            )
        return {"status": "backfilled", "inserted": len(summary_objects)}

File diff suppressed because it is too large Load Diff

View File

@@ -8,7 +8,10 @@ from celery.utils.log import get_task_logger
from config.celery import RLSTask
from config.django.base import DJANGO_FINDINGS_BATCH_SIZE, DJANGO_TMP_OUTPUT_DIRECTORY
from django_celery_beat.models import PeriodicTask
from tasks.jobs.backfill import backfill_resource_scan_summaries
from tasks.jobs.backfill import (
backfill_compliance_summaries,
backfill_resource_scan_summaries,
)
from tasks.jobs.connection import (
check_integration_connection,
check_lighthouse_connection,
@@ -494,6 +497,21 @@ def backfill_scan_resource_summaries_task(tenant_id: str, scan_id: str):
return backfill_resource_scan_summaries(tenant_id=tenant_id, scan_id=scan_id)
@shared_task(name="backfill-compliance-summaries", queue="backfill")
def backfill_compliance_summaries_task(tenant_id: str, scan_id: str):
    """
    Tries to backfill compliance overview summaries for a completed scan.

    This task aggregates compliance requirement data across regions
    to create pre-computed summary records for fast compliance overview queries.

    Args:
        tenant_id (str): The tenant identifier.
        scan_id (str): The scan identifier.

    Returns:
        dict: Status payload from backfill_compliance_summaries (e.g.
        "already backfilled", "backfilled").
    """
    # Thin Celery wrapper; all logic lives in tasks.jobs.backfill.
    return backfill_compliance_summaries(tenant_id=tenant_id, scan_id=scan_id)
@shared_task(base=RLSTask, name="scan-compliance-overviews", queue="compliance")
def create_compliance_requirements_task(tenant_id: str, scan_id: str):
"""

View File

@@ -1,43 +1,53 @@
from uuid import uuid4
import pytest
from tasks.jobs.backfill import backfill_resource_scan_summaries
from tasks.jobs.backfill import (
backfill_compliance_summaries,
backfill_resource_scan_summaries,
)
from api.models import ResourceScanSummary, Scan, StateChoices
from api.models import (
ComplianceOverviewSummary,
ResourceScanSummary,
Scan,
StateChoices,
)
@pytest.fixture(scope="function")
def resource_scan_summary_data(scans_fixture):
scan = scans_fixture[0]
return ResourceScanSummary.objects.create(
tenant_id=scan.tenant_id,
scan_id=scan.id,
resource_id=str(uuid4()),
service="aws",
region="us-east-1",
resource_type="instance",
)
@pytest.fixture(scope="function")
def get_not_completed_scans(providers_fixture):
provider_id = providers_fixture[0].id
tenant_id = providers_fixture[0].tenant_id
scan_1 = Scan.objects.create(
tenant_id=tenant_id,
trigger=Scan.TriggerChoices.MANUAL,
state=StateChoices.EXECUTING,
provider_id=provider_id,
)
scan_2 = Scan.objects.create(
tenant_id=tenant_id,
trigger=Scan.TriggerChoices.MANUAL,
state=StateChoices.AVAILABLE,
provider_id=provider_id,
)
return scan_1, scan_2
@pytest.mark.django_db
class TestBackfillResourceScanSummaries:
@pytest.fixture(scope="function")
def resource_scan_summary_data(self, scans_fixture):
scan = scans_fixture[0]
return ResourceScanSummary.objects.create(
tenant_id=scan.tenant_id,
scan_id=scan.id,
resource_id=str(uuid4()),
service="aws",
region="us-east-1",
resource_type="instance",
)
@pytest.fixture(scope="function")
def get_not_completed_scans(self, providers_fixture):
provider_id = providers_fixture[0].id
tenant_id = providers_fixture[0].tenant_id
scan_1 = Scan.objects.create(
tenant_id=tenant_id,
trigger=Scan.TriggerChoices.MANUAL,
state=StateChoices.EXECUTING,
provider_id=provider_id,
)
scan_2 = Scan.objects.create(
tenant_id=tenant_id,
trigger=Scan.TriggerChoices.MANUAL,
state=StateChoices.AVAILABLE,
provider_id=provider_id,
)
return scan_1, scan_2
def test_already_backfilled(self, resource_scan_summary_data):
tenant_id = resource_scan_summary_data.tenant_id
scan_id = resource_scan_summary_data.scan_id
@@ -77,3 +87,88 @@ class TestBackfillResourceScanSummaries:
assert summary.service == resource.service
assert summary.region == resource.region
assert summary.resource_type == resource.type
def test_no_resources_to_backfill(self, scans_fixture):
    """A scan with no findings/resources yields a no-op status."""
    scan = scans_fixture[1]  # Failed scan with no findings/resources
    tenant_id = str(scan.tenant_id)
    scan_id = str(scan.id)
    result = backfill_resource_scan_summaries(tenant_id, scan_id)
    assert result == {"status": "no resources to backfill"}
@pytest.mark.django_db
class TestBackfillComplianceSummaries:
    """Covers every exit path of backfill_compliance_summaries: already
    backfilled, scan not finished, no compliance rows, and a successful
    aggregation run."""

    def test_already_backfilled(self, scans_fixture):
        """Existing summary rows short-circuit the backfill."""
        scan = scans_fixture[0]
        tenant_id = str(scan.tenant_id)
        ComplianceOverviewSummary.objects.create(
            tenant_id=scan.tenant_id,
            scan=scan,
            compliance_id="aws_account_security_onboarding_aws",
            requirements_passed=1,
            requirements_failed=0,
            requirements_manual=0,
            total_requirements=1,
        )
        result = backfill_compliance_summaries(tenant_id, str(scan.id))
        assert result == {"status": "already backfilled"}

    def test_not_completed_scan(self, get_not_completed_scans):
        """EXECUTING/AVAILABLE scans are rejected before any aggregation."""
        for scan in get_not_completed_scans:
            result = backfill_compliance_summaries(str(scan.tenant_id), str(scan.id))
            assert result == {"status": "scan is not completed"}

    def test_no_compliance_data(self, scans_fixture):
        """Eligible scan without requirement rows returns a no-op status."""
        scan = scans_fixture[1]  # Failed scan with no compliance rows
        result = backfill_compliance_summaries(str(scan.tenant_id), str(scan.id))
        assert result == {"status": "no compliance data to backfill"}

    def test_backfill_creates_compliance_summaries(
        self, tenants_fixture, scans_fixture, compliance_requirements_overviews_fixture
    ):
        """Aggregated counts per compliance framework match the fixture data."""
        tenant = tenants_fixture[0]
        scan = scans_fixture[0]
        result = backfill_compliance_summaries(str(tenant.id), str(scan.id))
        # Expected per-framework counters derived from the fixture's
        # requirement rows (any FAIL -> failed, all PASS -> passed,
        # otherwise manual).
        expected = {
            "aws_account_security_onboarding_aws": {
                "requirements_passed": 1,
                "requirements_failed": 1,
                "requirements_manual": 1,
                "total_requirements": 3,
            },
            "cis_1.4_aws": {
                "requirements_passed": 0,
                "requirements_failed": 1,
                "requirements_manual": 0,
                "total_requirements": 1,
            },
            "mitre_attack_aws": {
                "requirements_passed": 0,
                "requirements_failed": 1,
                "requirements_manual": 0,
                "total_requirements": 1,
            },
        }
        assert result == {"status": "backfilled", "inserted": len(expected)}
        summaries = ComplianceOverviewSummary.objects.filter(
            tenant_id=str(tenant.id), scan_id=str(scan.id)
        )
        assert summaries.count() == len(expected)
        for summary in summaries:
            assert summary.compliance_id in expected
            expected_counts = expected[summary.compliance_id]
            assert summary.requirements_passed == expected_counts["requirements_passed"]
            assert summary.requirements_failed == expected_counts["requirements_failed"]
            assert summary.requirements_manual == expected_counts["requirements_manual"]
            assert summary.total_requirements == expected_counts["total_requirements"]

File diff suppressed because it is too large Load Diff