mirror of
https://github.com/prowler-cloud/prowler.git
synced 2025-12-19 05:17:47 +00:00
feat(db): optimize write queries for scan related tasks (#9190)
Co-authored-by: Josema Camacho <josema@prowler.com>
This commit is contained in:
committed by
GitHub
parent
ce5f2cc5ed
commit
46666d29d3
@@ -17,12 +17,15 @@ All notable changes to the **Prowler API** are documented in this file.
|
||||
- Tenant-wide ThreatScore overview aggregation and snapshot persistence with backfill support [(#9148)](https://github.com/prowler-cloud/prowler/pull/9148)
|
||||
- Support for MongoDB Atlas provider [(#9167)](https://github.com/prowler-cloud/prowler/pull/9167)
|
||||
|
||||
### Changed
|
||||
- Optimized database write queries for scan related tasks [(#9190)](https://github.com/prowler-cloud/prowler/pull/9190)
|
||||
|
||||
### Security
|
||||
- Django updated to the latest 5.1 security release, 5.1.14, due to problems with potential [SQL injection](https://github.com/prowler-cloud/prowler/security/dependabot/113) and [denial-of-service vulnerability](https://github.com/prowler-cloud/prowler/security/dependabot/114) [(#9176)](https://github.com/prowler-cloud/prowler/pull/9176)
|
||||
|
||||
---
|
||||
|
||||
## [1.14.2] (Prowler UNRELEASED)
|
||||
## [1.14.2] (Prowler 5.13.2)
|
||||
|
||||
### Fixed
|
||||
- Update unique constraint for `Provider` model to exclude soft-deleted entries, resolving duplicate errors when re-deleting providers.[(#9054)](https://github.com/prowler-cloud/prowler/pull/9054)
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
from django.contrib.postgres.operations import RemoveIndexConcurrently
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
atomic = False
|
||||
|
||||
dependencies = [
|
||||
("api", "0057_threatscoresnapshot"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
RemoveIndexConcurrently(
|
||||
model_name="compliancerequirementoverview",
|
||||
name="cro_tenant_scan_idx",
|
||||
),
|
||||
RemoveIndexConcurrently(
|
||||
model_name="compliancerequirementoverview",
|
||||
name="cro_scan_comp_idx",
|
||||
),
|
||||
RemoveIndexConcurrently(
|
||||
model_name="compliancerequirementoverview",
|
||||
name="cro_scan_comp_req_idx",
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,75 @@
|
||||
# Generated by Django 5.1.13 on 2025-10-30 15:23
|
||||
|
||||
import uuid
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
import api.rls
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("api", "0058_drop_redundant_compliance_requirement_indexes"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="ComplianceOverviewSummary",
|
||||
fields=[
|
||||
(
|
||||
"id",
|
||||
models.UUIDField(
|
||||
default=uuid.uuid4,
|
||||
editable=False,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
),
|
||||
),
|
||||
("inserted_at", models.DateTimeField(auto_now_add=True)),
|
||||
("compliance_id", models.TextField()),
|
||||
("requirements_passed", models.IntegerField(default=0)),
|
||||
("requirements_failed", models.IntegerField(default=0)),
|
||||
("requirements_manual", models.IntegerField(default=0)),
|
||||
("total_requirements", models.IntegerField(default=0)),
|
||||
(
|
||||
"scan",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
related_name="compliance_summaries",
|
||||
related_query_name="compliance_summary",
|
||||
to="api.scan",
|
||||
),
|
||||
),
|
||||
(
|
||||
"tenant",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"db_table": "compliance_overview_summaries",
|
||||
"abstract": False,
|
||||
"indexes": [
|
||||
models.Index(
|
||||
fields=["tenant_id", "scan_id"], name="cos_tenant_scan_idx"
|
||||
)
|
||||
],
|
||||
"constraints": [
|
||||
models.UniqueConstraint(
|
||||
fields=("tenant_id", "scan_id", "compliance_id"),
|
||||
name="unique_compliance_summary_per_scan",
|
||||
)
|
||||
],
|
||||
},
|
||||
),
|
||||
migrations.AddConstraint(
|
||||
model_name="complianceoverviewsummary",
|
||||
constraint=api.rls.RowLevelSecurityConstraint(
|
||||
"tenant_id",
|
||||
name="rls_on_complianceoverviewsummary",
|
||||
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -1371,35 +1371,70 @@ class ComplianceRequirementOverview(RowLevelSecurityProtectedModel):
|
||||
),
|
||||
]
|
||||
indexes = [
|
||||
models.Index(fields=["tenant_id", "scan_id"], name="cro_tenant_scan_idx"),
|
||||
models.Index(
|
||||
fields=["tenant_id", "scan_id", "compliance_id"],
|
||||
name="cro_scan_comp_idx",
|
||||
),
|
||||
models.Index(
|
||||
fields=["tenant_id", "scan_id", "compliance_id", "region"],
|
||||
name="cro_scan_comp_reg_idx",
|
||||
),
|
||||
models.Index(
|
||||
fields=["tenant_id", "scan_id", "compliance_id", "requirement_id"],
|
||||
name="cro_scan_comp_req_idx",
|
||||
),
|
||||
models.Index(
|
||||
fields=[
|
||||
"tenant_id",
|
||||
"scan_id",
|
||||
"compliance_id",
|
||||
"requirement_id",
|
||||
"region",
|
||||
],
|
||||
name="cro_scan_comp_req_reg_idx",
|
||||
),
|
||||
]
|
||||
|
||||
class JSONAPIMeta:
|
||||
resource_name = "compliance-requirements-overviews"
|
||||
|
||||
|
||||
class ComplianceOverviewSummary(RowLevelSecurityProtectedModel):
|
||||
"""
|
||||
Pre-aggregated compliance overview aggregated across ALL regions.
|
||||
One row per (scan_id, compliance_id) combination.
|
||||
|
||||
This table optimizes the common case where users view overall compliance
|
||||
without filtering by region. For region-specific views, the detailed
|
||||
ComplianceRequirementOverview table is used instead.
|
||||
"""
|
||||
|
||||
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
|
||||
inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
|
||||
|
||||
scan = models.ForeignKey(
|
||||
Scan,
|
||||
on_delete=models.CASCADE,
|
||||
related_name="compliance_summaries",
|
||||
related_query_name="compliance_summary",
|
||||
)
|
||||
|
||||
compliance_id = models.TextField(blank=False)
|
||||
|
||||
# Pre-aggregated scores (computed across ALL regions)
|
||||
requirements_passed = models.IntegerField(default=0)
|
||||
requirements_failed = models.IntegerField(default=0)
|
||||
requirements_manual = models.IntegerField(default=0)
|
||||
total_requirements = models.IntegerField(default=0)
|
||||
|
||||
class Meta(RowLevelSecurityProtectedModel.Meta):
|
||||
db_table = "compliance_overview_summaries"
|
||||
|
||||
constraints = [
|
||||
models.UniqueConstraint(
|
||||
fields=("tenant_id", "scan_id", "compliance_id"),
|
||||
name="unique_compliance_summary_per_scan",
|
||||
),
|
||||
RowLevelSecurityConstraint(
|
||||
field="tenant_id",
|
||||
name="rls_on_%(class)s",
|
||||
statements=["SELECT", "INSERT", "DELETE"],
|
||||
),
|
||||
]
|
||||
|
||||
indexes = [
|
||||
models.Index(
|
||||
fields=["tenant_id", "scan_id"],
|
||||
name="cos_tenant_scan_idx",
|
||||
),
|
||||
]
|
||||
|
||||
class JSONAPIMeta:
|
||||
resource_name = "compliance-overview-summaries"
|
||||
|
||||
|
||||
class ScanSummary(RowLevelSecurityProtectedModel):
|
||||
objects = ActiveProviderManager()
|
||||
all_objects = models.Manager()
|
||||
|
||||
@@ -35,6 +35,8 @@ from rest_framework.response import Response
|
||||
from api.compliance import get_compliance_frameworks
|
||||
from api.db_router import MainRouter
|
||||
from api.models import (
|
||||
ComplianceOverviewSummary,
|
||||
ComplianceRequirementOverview,
|
||||
Finding,
|
||||
Integration,
|
||||
Invitation,
|
||||
@@ -55,6 +57,7 @@ from api.models import (
|
||||
Scan,
|
||||
ScanSummary,
|
||||
StateChoices,
|
||||
StatusChoices,
|
||||
Task,
|
||||
TenantAPIKey,
|
||||
ThreatScoreSnapshot,
|
||||
@@ -5814,16 +5817,44 @@ class TestProviderGroupMembershipViewSet:
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestComplianceOverviewViewSet:
|
||||
def test_compliance_overview_list_none(self, authenticated_client):
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_backfill_task(self):
|
||||
with patch("api.v1.views.backfill_compliance_summaries_task.delay") as mock:
|
||||
yield mock
|
||||
|
||||
def test_compliance_overview_list_none(
|
||||
self,
|
||||
authenticated_client,
|
||||
tenants_fixture,
|
||||
providers_fixture,
|
||||
mock_backfill_task,
|
||||
):
|
||||
tenant = tenants_fixture[0]
|
||||
provider = providers_fixture[0]
|
||||
scan = Scan.objects.create(
|
||||
name="empty-compliance-scan",
|
||||
provider=provider,
|
||||
trigger=Scan.TriggerChoices.MANUAL,
|
||||
state=StateChoices.COMPLETED,
|
||||
tenant=tenant,
|
||||
)
|
||||
|
||||
response = authenticated_client.get(
|
||||
reverse("complianceoverview-list"),
|
||||
{"filter[scan_id]": "8d20ac7d-4cbc-435e-85f4-359be37af821"},
|
||||
{"filter[scan_id]": str(scan.id)},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert len(response.json()["data"]) == 0
|
||||
mock_backfill_task.assert_called_once()
|
||||
_, kwargs = mock_backfill_task.call_args
|
||||
assert kwargs["scan_id"] == str(scan.id)
|
||||
assert str(kwargs["tenant_id"]) == str(tenant.id)
|
||||
|
||||
def test_compliance_overview_list(
|
||||
self, authenticated_client, compliance_requirements_overviews_fixture
|
||||
self,
|
||||
authenticated_client,
|
||||
compliance_requirements_overviews_fixture,
|
||||
mock_backfill_task,
|
||||
):
|
||||
# List compliance overviews with existing data
|
||||
requirement_overview1 = compliance_requirements_overviews_fixture[0]
|
||||
@@ -5853,6 +5884,90 @@ class TestComplianceOverviewViewSet:
|
||||
assert "requirements_failed" in attributes
|
||||
assert "requirements_manual" in attributes
|
||||
assert "total_requirements" in attributes
|
||||
mock_backfill_task.assert_called_once()
|
||||
_, kwargs = mock_backfill_task.call_args
|
||||
assert kwargs["scan_id"] == scan_id
|
||||
|
||||
def test_compliance_overview_list_uses_preaggregated_summaries(
|
||||
self,
|
||||
authenticated_client,
|
||||
tenants_fixture,
|
||||
providers_fixture,
|
||||
mock_backfill_task,
|
||||
):
|
||||
tenant = tenants_fixture[0]
|
||||
provider = providers_fixture[0]
|
||||
scan = Scan.objects.create(
|
||||
name="preaggregated-scan",
|
||||
provider=provider,
|
||||
trigger=Scan.TriggerChoices.MANUAL,
|
||||
state=StateChoices.COMPLETED,
|
||||
tenant=tenant,
|
||||
)
|
||||
|
||||
ComplianceRequirementOverview.objects.create(
|
||||
tenant=tenant,
|
||||
scan=scan,
|
||||
compliance_id="cis_1.4_aws",
|
||||
framework="CIS-1.4-AWS",
|
||||
version="1.4",
|
||||
description="CIS AWS Foundations Benchmark v1.4.0",
|
||||
region="eu-west-1",
|
||||
requirement_id="framework-metadata",
|
||||
requirement_status=StatusChoices.PASS,
|
||||
passed_checks=1,
|
||||
failed_checks=0,
|
||||
total_checks=1,
|
||||
)
|
||||
|
||||
ComplianceOverviewSummary.objects.create(
|
||||
tenant=tenant,
|
||||
scan=scan,
|
||||
compliance_id="cis_1.4_aws",
|
||||
requirements_passed=5,
|
||||
requirements_failed=1,
|
||||
requirements_manual=2,
|
||||
total_requirements=8,
|
||||
)
|
||||
|
||||
response = authenticated_client.get(
|
||||
reverse("complianceoverview-list"),
|
||||
{"filter[scan_id]": str(scan.id)},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()["data"]
|
||||
assert len(data) == 1
|
||||
overview = data[0]
|
||||
assert overview["id"] == "cis_1.4_aws"
|
||||
assert overview["attributes"]["requirements_passed"] == 5
|
||||
assert overview["attributes"]["requirements_failed"] == 1
|
||||
assert overview["attributes"]["requirements_manual"] == 2
|
||||
assert overview["attributes"]["total_requirements"] == 8
|
||||
assert "framework" in overview["attributes"]
|
||||
assert "version" in overview["attributes"]
|
||||
mock_backfill_task.assert_not_called()
|
||||
|
||||
def test_compliance_overview_region_filter_skips_backfill(
|
||||
self,
|
||||
authenticated_client,
|
||||
compliance_requirements_overviews_fixture,
|
||||
mock_backfill_task,
|
||||
):
|
||||
requirement_overview = compliance_requirements_overviews_fixture[0]
|
||||
scan_id = str(requirement_overview.scan.id)
|
||||
|
||||
response = authenticated_client.get(
|
||||
reverse("complianceoverview-list"),
|
||||
{
|
||||
"filter[scan_id]": scan_id,
|
||||
"filter[region]": requirement_overview.region,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert len(response.json()["data"]) >= 1
|
||||
mock_backfill_task.assert_not_called()
|
||||
|
||||
def test_compliance_overview_metadata(
|
||||
self, authenticated_client, compliance_requirements_overviews_fixture
|
||||
@@ -6006,6 +6121,11 @@ class TestComplianceOverviewViewSet:
|
||||
requirement_overview1 = compliance_requirements_overviews_fixture[0]
|
||||
scan_id = str(requirement_overview1.scan.id)
|
||||
|
||||
# Remove existing compliance data so the view falls back to task checks
|
||||
scan = requirement_overview1.scan
|
||||
ComplianceOverviewSummary.objects.filter(scan=scan).delete()
|
||||
ComplianceRequirementOverview.objects.filter(scan=scan).delete()
|
||||
|
||||
# Mock a running task
|
||||
with patch.object(
|
||||
ComplianceOverviewViewSet, "get_task_response_if_running"
|
||||
@@ -6033,6 +6153,11 @@ class TestComplianceOverviewViewSet:
|
||||
requirement_overview1 = compliance_requirements_overviews_fixture[0]
|
||||
scan_id = str(requirement_overview1.scan.id)
|
||||
|
||||
# Remove existing compliance data so the view falls back to task checks
|
||||
scan = requirement_overview1.scan
|
||||
ComplianceOverviewSummary.objects.filter(scan=scan).delete()
|
||||
ComplianceRequirementOverview.objects.filter(scan=scan).delete()
|
||||
|
||||
# Mock a failed task
|
||||
with patch.object(
|
||||
ComplianceOverviewViewSet, "get_task_response_if_running"
|
||||
@@ -6056,6 +6181,8 @@ class TestComplianceOverviewViewSet:
|
||||
("framework", "framework", 1),
|
||||
("version", "version", 1),
|
||||
("region", "region", 1),
|
||||
("region__in", "region", 1),
|
||||
("region.in", "region", 1),
|
||||
],
|
||||
)
|
||||
def test_compliance_overview_filters(
|
||||
|
||||
@@ -75,6 +75,7 @@ from rest_framework_simplejwt.exceptions import InvalidToken, TokenError
|
||||
from tasks.beat import schedule_provider_scan
|
||||
from tasks.jobs.export import get_s3_client
|
||||
from tasks.tasks import (
|
||||
backfill_compliance_summaries_task,
|
||||
backfill_scan_resource_summaries_task,
|
||||
check_integration_connection_task,
|
||||
check_lighthouse_connection_task,
|
||||
@@ -126,6 +127,7 @@ from api.filters import (
|
||||
UserFilter,
|
||||
)
|
||||
from api.models import (
|
||||
ComplianceOverviewSummary,
|
||||
ComplianceRequirementOverview,
|
||||
Finding,
|
||||
Integration,
|
||||
@@ -3398,6 +3400,168 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
|
||||
def retrieve(self, request, *args, **kwargs):
|
||||
raise MethodNotAllowed(method="GET")
|
||||
|
||||
def _compliance_summaries_queryset(self, scan_id):
|
||||
"""Return pre-aggregated summaries constrained by RBAC visibility."""
|
||||
role = get_role(self.request.user)
|
||||
unlimited_visibility = getattr(
|
||||
role, Permissions.UNLIMITED_VISIBILITY.value, False
|
||||
)
|
||||
summaries = ComplianceOverviewSummary.objects.filter(
|
||||
tenant_id=self.request.tenant_id,
|
||||
scan_id=scan_id,
|
||||
)
|
||||
|
||||
if not unlimited_visibility:
|
||||
providers = Provider.all_objects.filter(
|
||||
provider_groups__in=role.provider_groups.all()
|
||||
).distinct()
|
||||
summaries = summaries.filter(scan__provider__in=providers)
|
||||
|
||||
return summaries
|
||||
|
||||
def _get_compliance_template(self, *, provider=None, scan_id=None):
|
||||
"""Return the compliance template for the given provider or scan."""
|
||||
if provider is None and scan_id is not None:
|
||||
scan = Scan.all_objects.select_related("provider").get(pk=scan_id)
|
||||
provider = scan.provider
|
||||
|
||||
if not provider:
|
||||
return {}
|
||||
|
||||
return PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE.get(provider.provider, {})
|
||||
|
||||
def _aggregate_compliance_overview(self, queryset, template_metadata=None):
|
||||
"""
|
||||
Aggregate requirement rows into compliance overview dictionaries.
|
||||
|
||||
Args:
|
||||
queryset: ComplianceRequirementOverview queryset already filtered.
|
||||
template_metadata: Optional dict mapping compliance_id -> metadata.
|
||||
"""
|
||||
template_metadata = template_metadata or {}
|
||||
requirement_status_subquery = queryset.values(
|
||||
"compliance_id", "requirement_id"
|
||||
).annotate(
|
||||
fail_count=Count("id", filter=Q(requirement_status="FAIL")),
|
||||
pass_count=Count("id", filter=Q(requirement_status="PASS")),
|
||||
total_count=Count("id"),
|
||||
)
|
||||
|
||||
compliance_data = {}
|
||||
fallback_metadata = {
|
||||
item["compliance_id"]: {
|
||||
"framework": item["framework"],
|
||||
"version": item["version"],
|
||||
}
|
||||
for item in queryset.values(
|
||||
"compliance_id", "framework", "version"
|
||||
).distinct()
|
||||
}
|
||||
|
||||
for item in requirement_status_subquery:
|
||||
compliance_id = item["compliance_id"]
|
||||
|
||||
if item["fail_count"] > 0:
|
||||
req_status = "FAIL"
|
||||
elif item["pass_count"] == item["total_count"]:
|
||||
req_status = "PASS"
|
||||
else:
|
||||
req_status = "MANUAL"
|
||||
|
||||
compliance_status = compliance_data.setdefault(
|
||||
compliance_id,
|
||||
{
|
||||
"total_requirements": 0,
|
||||
"requirements_passed": 0,
|
||||
"requirements_failed": 0,
|
||||
"requirements_manual": 0,
|
||||
},
|
||||
)
|
||||
|
||||
compliance_status["total_requirements"] += 1
|
||||
if req_status == "PASS":
|
||||
compliance_status["requirements_passed"] += 1
|
||||
elif req_status == "FAIL":
|
||||
compliance_status["requirements_failed"] += 1
|
||||
else:
|
||||
compliance_status["requirements_manual"] += 1
|
||||
|
||||
response_data = []
|
||||
for compliance_id, data in compliance_data.items():
|
||||
template = template_metadata.get(compliance_id, {})
|
||||
fallback = fallback_metadata.get(compliance_id, {})
|
||||
|
||||
response_data.append(
|
||||
{
|
||||
"id": compliance_id,
|
||||
"compliance_id": compliance_id,
|
||||
"framework": template.get("framework")
|
||||
or fallback.get("framework", ""),
|
||||
"version": template.get("version") or fallback.get("version", ""),
|
||||
"requirements_passed": data["requirements_passed"],
|
||||
"requirements_failed": data["requirements_failed"],
|
||||
"requirements_manual": data["requirements_manual"],
|
||||
"total_requirements": data["total_requirements"],
|
||||
}
|
||||
)
|
||||
|
||||
serializer = self.get_serializer(response_data, many=True)
|
||||
return serializer.data
|
||||
|
||||
def _task_response_if_running(self, scan_id):
|
||||
"""Check for an in-progress task only when no compliance data exists."""
|
||||
try:
|
||||
return self.get_task_response_if_running(
|
||||
task_name="scan-compliance-overviews",
|
||||
task_kwargs={"tenant_id": self.request.tenant_id, "scan_id": scan_id},
|
||||
raise_on_not_found=False,
|
||||
)
|
||||
except TaskFailedException:
|
||||
return Response(
|
||||
{"detail": "Task failed to generate compliance overview data."},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
def _list_with_region_filter(self, scan_id, region_filter):
|
||||
"""
|
||||
Fall back to detailed ComplianceRequirementOverview query when region filter is applied.
|
||||
This uses the original aggregation logic across filtered regions.
|
||||
"""
|
||||
regions = region_filter.split(",") if "," in region_filter else [region_filter]
|
||||
queryset = self.filter_queryset(self.get_queryset()).filter(
|
||||
scan_id=scan_id,
|
||||
region__in=regions,
|
||||
)
|
||||
|
||||
data = self._aggregate_compliance_overview(queryset)
|
||||
if data:
|
||||
return Response(data)
|
||||
|
||||
task_response = self._task_response_if_running(scan_id)
|
||||
if task_response:
|
||||
return task_response
|
||||
|
||||
return Response(data)
|
||||
|
||||
def _list_without_region_aggregation(self, scan_id):
|
||||
"""
|
||||
Fall back aggregation when compliance summaries don't exist yet.
|
||||
Aggregates ComplianceRequirementOverview data across ALL regions.
|
||||
"""
|
||||
queryset = self.filter_queryset(self.get_queryset()).filter(scan_id=scan_id)
|
||||
compliance_template = self._get_compliance_template(scan_id=scan_id)
|
||||
data = self._aggregate_compliance_overview(
|
||||
queryset, template_metadata=compliance_template
|
||||
)
|
||||
if data:
|
||||
return Response(data)
|
||||
|
||||
task_response = self._task_response_if_running(scan_id)
|
||||
if task_response:
|
||||
return task_response
|
||||
|
||||
return Response(data)
|
||||
|
||||
def list(self, request, *args, **kwargs):
|
||||
scan_id = request.query_params.get("filter[scan_id]")
|
||||
if not scan_id:
|
||||
@@ -3411,77 +3575,41 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
|
||||
}
|
||||
]
|
||||
)
|
||||
try:
|
||||
if task := self.get_task_response_if_running(
|
||||
task_name="scan-compliance-overviews",
|
||||
task_kwargs={"tenant_id": self.request.tenant_id, "scan_id": scan_id},
|
||||
raise_on_not_found=False,
|
||||
):
|
||||
return task
|
||||
except TaskFailedException:
|
||||
return Response(
|
||||
{"detail": "Task failed to generate compliance overview data."},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
|
||||
region_filter = request.query_params.get(
|
||||
"filter[region]"
|
||||
) or request.query_params.get("filter[region__in]")
|
||||
|
||||
if region_filter:
|
||||
# Fall back to detailed query with region filtering
|
||||
return self._list_with_region_filter(scan_id, region_filter)
|
||||
|
||||
summaries = list(self._compliance_summaries_queryset(scan_id))
|
||||
if not summaries:
|
||||
# Trigger async backfill for next time
|
||||
backfill_compliance_summaries_task.delay(
|
||||
tenant_id=self.request.tenant_id, scan_id=scan_id
|
||||
)
|
||||
queryset = self.filter_queryset(self.filter_queryset(self.get_queryset()))
|
||||
# Use fallback aggregation for this request
|
||||
return self._list_without_region_aggregation(scan_id)
|
||||
|
||||
requirement_status_subquery = queryset.values(
|
||||
"compliance_id", "requirement_id"
|
||||
).annotate(
|
||||
fail_count=Count("id", filter=Q(requirement_status="FAIL")),
|
||||
pass_count=Count("id", filter=Q(requirement_status="PASS")),
|
||||
total_count=Count("id"),
|
||||
)
|
||||
|
||||
compliance_data = {}
|
||||
framework_info = {}
|
||||
|
||||
for item in queryset.values("compliance_id", "framework", "version").distinct():
|
||||
framework_info[item["compliance_id"]] = {
|
||||
"framework": item["framework"],
|
||||
"version": item["version"],
|
||||
}
|
||||
|
||||
for item in requirement_status_subquery:
|
||||
compliance_id = item["compliance_id"]
|
||||
|
||||
if item["fail_count"] > 0:
|
||||
req_status = "FAIL"
|
||||
elif item["pass_count"] == item["total_count"]:
|
||||
req_status = "PASS"
|
||||
else:
|
||||
req_status = "MANUAL"
|
||||
|
||||
if compliance_id not in compliance_data:
|
||||
compliance_data[compliance_id] = {
|
||||
"total_requirements": 0,
|
||||
"requirements_passed": 0,
|
||||
"requirements_failed": 0,
|
||||
"requirements_manual": 0,
|
||||
}
|
||||
|
||||
compliance_data[compliance_id]["total_requirements"] += 1
|
||||
if req_status == "PASS":
|
||||
compliance_data[compliance_id]["requirements_passed"] += 1
|
||||
elif req_status == "FAIL":
|
||||
compliance_data[compliance_id]["requirements_failed"] += 1
|
||||
else:
|
||||
compliance_data[compliance_id]["requirements_manual"] += 1
|
||||
# Get compliance template for provider to enrich with framework/version
|
||||
compliance_template = self._get_compliance_template(scan_id=scan_id)
|
||||
|
||||
# Convert to response format with framework/version enrichment
|
||||
response_data = []
|
||||
for compliance_id, data in compliance_data.items():
|
||||
framework = framework_info.get(compliance_id, {})
|
||||
|
||||
for summary in summaries:
|
||||
compliance_metadata = compliance_template.get(summary.compliance_id, {})
|
||||
response_data.append(
|
||||
{
|
||||
"id": compliance_id,
|
||||
"compliance_id": compliance_id,
|
||||
"framework": framework.get("framework", ""),
|
||||
"version": framework.get("version", ""),
|
||||
"requirements_passed": data["requirements_passed"],
|
||||
"requirements_failed": data["requirements_failed"],
|
||||
"requirements_manual": data["requirements_manual"],
|
||||
"total_requirements": data["total_requirements"],
|
||||
"id": summary.compliance_id,
|
||||
"compliance_id": summary.compliance_id,
|
||||
"framework": compliance_metadata.get("framework", ""),
|
||||
"version": compliance_metadata.get("version", ""),
|
||||
"requirements_passed": summary.requirements_passed,
|
||||
"requirements_failed": summary.requirements_failed,
|
||||
"requirements_manual": summary.requirements_manual,
|
||||
"total_requirements": summary.total_requirements,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -3502,18 +3630,6 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
|
||||
}
|
||||
]
|
||||
)
|
||||
try:
|
||||
if task := self.get_task_response_if_running(
|
||||
task_name="scan-compliance-overviews",
|
||||
task_kwargs={"tenant_id": self.request.tenant_id, "scan_id": scan_id},
|
||||
raise_on_not_found=False,
|
||||
):
|
||||
return task
|
||||
except TaskFailedException:
|
||||
return Response(
|
||||
{"detail": "Task failed to generate compliance overview data."},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
regions = list(
|
||||
self.get_queryset()
|
||||
.filter(scan_id=scan_id)
|
||||
@@ -3523,6 +3639,15 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
|
||||
)
|
||||
result = {"regions": regions}
|
||||
|
||||
if regions:
|
||||
serializer = self.get_serializer(data=result)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
return Response(serializer.data, status=status.HTTP_200_OK)
|
||||
|
||||
task_response = self._task_response_if_running(scan_id)
|
||||
if task_response:
|
||||
return task_response
|
||||
|
||||
serializer = self.get_serializer(data=result)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
return Response(serializer.data, status=status.HTTP_200_OK)
|
||||
@@ -3555,18 +3680,6 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
|
||||
}
|
||||
]
|
||||
)
|
||||
try:
|
||||
if task := self.get_task_response_if_running(
|
||||
task_name="scan-compliance-overviews",
|
||||
task_kwargs={"tenant_id": self.request.tenant_id, "scan_id": scan_id},
|
||||
raise_on_not_found=False,
|
||||
):
|
||||
return task
|
||||
except TaskFailedException:
|
||||
return Response(
|
||||
{"detail": "Task failed to generate compliance overview data."},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
filtered_queryset = self.filter_queryset(self.get_queryset())
|
||||
|
||||
all_requirements = filtered_queryset.values(
|
||||
@@ -3626,6 +3739,13 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
|
||||
requirements_summary, many=True
|
||||
)
|
||||
|
||||
if requirements_summary:
|
||||
return Response(serializer.data, status=status.HTTP_200_OK)
|
||||
|
||||
task_response = self._task_response_if_running(scan_id)
|
||||
if task_response:
|
||||
return task_response
|
||||
|
||||
return Response(serializer.data, status=status.HTTP_200_OK)
|
||||
|
||||
@action(detail=False, methods=["get"], url_name="attributes")
|
||||
@@ -3644,15 +3764,6 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
|
||||
)
|
||||
|
||||
provider_type = None
|
||||
try:
|
||||
sample_requirement = (
|
||||
self.get_queryset().filter(compliance_id=compliance_id).first()
|
||||
)
|
||||
|
||||
if sample_requirement:
|
||||
provider_type = sample_requirement.scan.provider.provider
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# If we couldn't determine from database, try each provider type
|
||||
if not provider_type:
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
from collections import defaultdict
|
||||
|
||||
from api.db_router import READ_REPLICA_ALIAS
|
||||
from api.db_utils import rls_transaction
|
||||
from api.models import (
|
||||
ComplianceOverviewSummary,
|
||||
ComplianceRequirementOverview,
|
||||
Resource,
|
||||
ResourceFindingMapping,
|
||||
ResourceScanSummary,
|
||||
@@ -9,7 +14,7 @@ from api.models import (
|
||||
|
||||
|
||||
def backfill_resource_scan_summaries(tenant_id: str, scan_id: str):
|
||||
with rls_transaction(tenant_id):
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
if ResourceScanSummary.objects.filter(
|
||||
tenant_id=tenant_id, scan_id=scan_id
|
||||
).exists():
|
||||
@@ -59,3 +64,114 @@ def backfill_resource_scan_summaries(tenant_id: str, scan_id: str):
|
||||
)
|
||||
|
||||
return {"status": "backfilled", "inserted": len(summaries)}
|
||||
|
||||
|
||||
def backfill_compliance_summaries(tenant_id: str, scan_id: str):
|
||||
"""
|
||||
Backfill ComplianceOverviewSummary records for a completed scan.
|
||||
|
||||
This function checks if summary records already exist for the scan.
|
||||
If not, it aggregates compliance requirement data and creates the summaries.
|
||||
|
||||
Args:
|
||||
tenant_id: Target tenant UUID
|
||||
scan_id: Scan UUID to backfill
|
||||
|
||||
Returns:
|
||||
dict: Status indicating whether backfill was performed
|
||||
"""
|
||||
with rls_transaction(tenant_id):
|
||||
if ComplianceOverviewSummary.objects.filter(
|
||||
tenant_id=tenant_id, scan_id=scan_id
|
||||
).exists():
|
||||
return {"status": "already backfilled"}
|
||||
|
||||
with rls_transaction(tenant_id):
|
||||
if not Scan.objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
id=scan_id,
|
||||
state__in=(StateChoices.COMPLETED, StateChoices.FAILED),
|
||||
).exists():
|
||||
return {"status": "scan is not completed"}
|
||||
|
||||
# Fetch all compliance requirement overview rows for this scan
|
||||
requirement_rows = ComplianceRequirementOverview.objects.filter(
|
||||
tenant_id=tenant_id, scan_id=scan_id
|
||||
).values(
|
||||
"compliance_id",
|
||||
"requirement_id",
|
||||
"requirement_status",
|
||||
)
|
||||
|
||||
if not requirement_rows:
|
||||
return {"status": "no compliance data to backfill"}
|
||||
|
||||
# Group by (compliance_id, requirement_id) across regions
|
||||
requirement_statuses = defaultdict(
|
||||
lambda: {"fail_count": 0, "pass_count": 0, "total_count": 0}
|
||||
)
|
||||
|
||||
for row in requirement_rows:
|
||||
compliance_id = row["compliance_id"]
|
||||
requirement_id = row["requirement_id"]
|
||||
requirement_status = row["requirement_status"]
|
||||
|
||||
# Aggregate requirement status across regions
|
||||
key = (compliance_id, requirement_id)
|
||||
requirement_statuses[key]["total_count"] += 1
|
||||
|
||||
if requirement_status == "FAIL":
|
||||
requirement_statuses[key]["fail_count"] += 1
|
||||
elif requirement_status == "PASS":
|
||||
requirement_statuses[key]["pass_count"] += 1
|
||||
|
||||
# Determine per-requirement status and aggregate to compliance level
|
||||
compliance_summaries = defaultdict(
|
||||
lambda: {
|
||||
"total_requirements": 0,
|
||||
"requirements_passed": 0,
|
||||
"requirements_failed": 0,
|
||||
"requirements_manual": 0,
|
||||
}
|
||||
)
|
||||
|
||||
for (compliance_id, requirement_id), counts in requirement_statuses.items():
|
||||
# Apply business rule: any FAIL → requirement fails
|
||||
if counts["fail_count"] > 0:
|
||||
req_status = "FAIL"
|
||||
elif counts["pass_count"] == counts["total_count"]:
|
||||
req_status = "PASS"
|
||||
else:
|
||||
req_status = "MANUAL"
|
||||
|
||||
# Aggregate to compliance level
|
||||
compliance_summaries[compliance_id]["total_requirements"] += 1
|
||||
if req_status == "PASS":
|
||||
compliance_summaries[compliance_id]["requirements_passed"] += 1
|
||||
elif req_status == "FAIL":
|
||||
compliance_summaries[compliance_id]["requirements_failed"] += 1
|
||||
else:
|
||||
compliance_summaries[compliance_id]["requirements_manual"] += 1
|
||||
|
||||
# Create summary objects
|
||||
summary_objects = []
|
||||
for compliance_id, data in compliance_summaries.items():
|
||||
summary_objects.append(
|
||||
ComplianceOverviewSummary(
|
||||
tenant_id=tenant_id,
|
||||
scan_id=scan_id,
|
||||
compliance_id=compliance_id,
|
||||
requirements_passed=data["requirements_passed"],
|
||||
requirements_failed=data["requirements_failed"],
|
||||
requirements_manual=data["requirements_manual"],
|
||||
total_requirements=data["total_requirements"],
|
||||
)
|
||||
)
|
||||
|
||||
# Bulk insert summaries
|
||||
if summary_objects:
|
||||
ComplianceOverviewSummary.objects.bulk_create(
|
||||
summary_objects, batch_size=500, ignore_conflicts=True
|
||||
)
|
||||
|
||||
return {"status": "backfilled", "inserted": len(summary_objects)}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -8,7 +8,10 @@ from celery.utils.log import get_task_logger
|
||||
from config.celery import RLSTask
|
||||
from config.django.base import DJANGO_FINDINGS_BATCH_SIZE, DJANGO_TMP_OUTPUT_DIRECTORY
|
||||
from django_celery_beat.models import PeriodicTask
|
||||
from tasks.jobs.backfill import backfill_resource_scan_summaries
|
||||
from tasks.jobs.backfill import (
|
||||
backfill_compliance_summaries,
|
||||
backfill_resource_scan_summaries,
|
||||
)
|
||||
from tasks.jobs.connection import (
|
||||
check_integration_connection,
|
||||
check_lighthouse_connection,
|
||||
@@ -494,6 +497,21 @@ def backfill_scan_resource_summaries_task(tenant_id: str, scan_id: str):
|
||||
return backfill_resource_scan_summaries(tenant_id=tenant_id, scan_id=scan_id)
|
||||
|
||||
|
||||
@shared_task(name="backfill-compliance-summaries", queue="backfill")
|
||||
def backfill_compliance_summaries_task(tenant_id: str, scan_id: str):
|
||||
"""
|
||||
Tries to backfill compliance overview summaries for a completed scan.
|
||||
|
||||
This task aggregates compliance requirement data across regions
|
||||
to create pre-computed summary records for fast compliance overview queries.
|
||||
|
||||
Args:
|
||||
tenant_id (str): The tenant identifier.
|
||||
scan_id (str): The scan identifier.
|
||||
"""
|
||||
return backfill_compliance_summaries(tenant_id=tenant_id, scan_id=scan_id)
|
||||
|
||||
|
||||
@shared_task(base=RLSTask, name="scan-compliance-overviews", queue="compliance")
|
||||
def create_compliance_requirements_task(tenant_id: str, scan_id: str):
|
||||
"""
|
||||
|
||||
@@ -1,43 +1,53 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from tasks.jobs.backfill import backfill_resource_scan_summaries
|
||||
from tasks.jobs.backfill import (
|
||||
backfill_compliance_summaries,
|
||||
backfill_resource_scan_summaries,
|
||||
)
|
||||
|
||||
from api.models import ResourceScanSummary, Scan, StateChoices
|
||||
from api.models import (
|
||||
ComplianceOverviewSummary,
|
||||
ResourceScanSummary,
|
||||
Scan,
|
||||
StateChoices,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def resource_scan_summary_data(scans_fixture):
|
||||
scan = scans_fixture[0]
|
||||
return ResourceScanSummary.objects.create(
|
||||
tenant_id=scan.tenant_id,
|
||||
scan_id=scan.id,
|
||||
resource_id=str(uuid4()),
|
||||
service="aws",
|
||||
region="us-east-1",
|
||||
resource_type="instance",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def get_not_completed_scans(providers_fixture):
|
||||
provider_id = providers_fixture[0].id
|
||||
tenant_id = providers_fixture[0].tenant_id
|
||||
scan_1 = Scan.objects.create(
|
||||
tenant_id=tenant_id,
|
||||
trigger=Scan.TriggerChoices.MANUAL,
|
||||
state=StateChoices.EXECUTING,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
scan_2 = Scan.objects.create(
|
||||
tenant_id=tenant_id,
|
||||
trigger=Scan.TriggerChoices.MANUAL,
|
||||
state=StateChoices.AVAILABLE,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
return scan_1, scan_2
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestBackfillResourceScanSummaries:
|
||||
@pytest.fixture(scope="function")
|
||||
def resource_scan_summary_data(self, scans_fixture):
|
||||
scan = scans_fixture[0]
|
||||
return ResourceScanSummary.objects.create(
|
||||
tenant_id=scan.tenant_id,
|
||||
scan_id=scan.id,
|
||||
resource_id=str(uuid4()),
|
||||
service="aws",
|
||||
region="us-east-1",
|
||||
resource_type="instance",
|
||||
)
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def get_not_completed_scans(self, providers_fixture):
|
||||
provider_id = providers_fixture[0].id
|
||||
tenant_id = providers_fixture[0].tenant_id
|
||||
scan_1 = Scan.objects.create(
|
||||
tenant_id=tenant_id,
|
||||
trigger=Scan.TriggerChoices.MANUAL,
|
||||
state=StateChoices.EXECUTING,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
scan_2 = Scan.objects.create(
|
||||
tenant_id=tenant_id,
|
||||
trigger=Scan.TriggerChoices.MANUAL,
|
||||
state=StateChoices.AVAILABLE,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
return scan_1, scan_2
|
||||
|
||||
def test_already_backfilled(self, resource_scan_summary_data):
|
||||
tenant_id = resource_scan_summary_data.tenant_id
|
||||
scan_id = resource_scan_summary_data.scan_id
|
||||
@@ -77,3 +87,88 @@ class TestBackfillResourceScanSummaries:
|
||||
assert summary.service == resource.service
|
||||
assert summary.region == resource.region
|
||||
assert summary.resource_type == resource.type
|
||||
|
||||
def test_no_resources_to_backfill(self, scans_fixture):
|
||||
scan = scans_fixture[1] # Failed scan with no findings/resources
|
||||
tenant_id = str(scan.tenant_id)
|
||||
scan_id = str(scan.id)
|
||||
|
||||
result = backfill_resource_scan_summaries(tenant_id, scan_id)
|
||||
|
||||
assert result == {"status": "no resources to backfill"}
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestBackfillComplianceSummaries:
|
||||
def test_already_backfilled(self, scans_fixture):
|
||||
scan = scans_fixture[0]
|
||||
tenant_id = str(scan.tenant_id)
|
||||
ComplianceOverviewSummary.objects.create(
|
||||
tenant_id=scan.tenant_id,
|
||||
scan=scan,
|
||||
compliance_id="aws_account_security_onboarding_aws",
|
||||
requirements_passed=1,
|
||||
requirements_failed=0,
|
||||
requirements_manual=0,
|
||||
total_requirements=1,
|
||||
)
|
||||
|
||||
result = backfill_compliance_summaries(tenant_id, str(scan.id))
|
||||
|
||||
assert result == {"status": "already backfilled"}
|
||||
|
||||
def test_not_completed_scan(self, get_not_completed_scans):
|
||||
for scan in get_not_completed_scans:
|
||||
result = backfill_compliance_summaries(str(scan.tenant_id), str(scan.id))
|
||||
assert result == {"status": "scan is not completed"}
|
||||
|
||||
def test_no_compliance_data(self, scans_fixture):
|
||||
scan = scans_fixture[1] # Failed scan with no compliance rows
|
||||
|
||||
result = backfill_compliance_summaries(str(scan.tenant_id), str(scan.id))
|
||||
|
||||
assert result == {"status": "no compliance data to backfill"}
|
||||
|
||||
def test_backfill_creates_compliance_summaries(
|
||||
self, tenants_fixture, scans_fixture, compliance_requirements_overviews_fixture
|
||||
):
|
||||
tenant = tenants_fixture[0]
|
||||
scan = scans_fixture[0]
|
||||
|
||||
result = backfill_compliance_summaries(str(tenant.id), str(scan.id))
|
||||
|
||||
expected = {
|
||||
"aws_account_security_onboarding_aws": {
|
||||
"requirements_passed": 1,
|
||||
"requirements_failed": 1,
|
||||
"requirements_manual": 1,
|
||||
"total_requirements": 3,
|
||||
},
|
||||
"cis_1.4_aws": {
|
||||
"requirements_passed": 0,
|
||||
"requirements_failed": 1,
|
||||
"requirements_manual": 0,
|
||||
"total_requirements": 1,
|
||||
},
|
||||
"mitre_attack_aws": {
|
||||
"requirements_passed": 0,
|
||||
"requirements_failed": 1,
|
||||
"requirements_manual": 0,
|
||||
"total_requirements": 1,
|
||||
},
|
||||
}
|
||||
|
||||
assert result == {"status": "backfilled", "inserted": len(expected)}
|
||||
|
||||
summaries = ComplianceOverviewSummary.objects.filter(
|
||||
tenant_id=str(tenant.id), scan_id=str(scan.id)
|
||||
)
|
||||
assert summaries.count() == len(expected)
|
||||
|
||||
for summary in summaries:
|
||||
assert summary.compliance_id in expected
|
||||
expected_counts = expected[summary.compliance_id]
|
||||
assert summary.requirements_passed == expected_counts["requirements_passed"]
|
||||
assert summary.requirements_failed == expected_counts["requirements_failed"]
|
||||
assert summary.requirements_manual == expected_counts["requirements_manual"]
|
||||
assert summary.total_requirements == expected_counts["total_requirements"]
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user