refactor(api): update compliance report endpoints and enhance query parameters (#9338)

This commit is contained in:
Adrián Peña
2025-12-03 11:41:07 +01:00
committed by GitHub
parent 7b1915e489
commit a4e12a94f9
3 changed files with 320 additions and 77 deletions

View File

@@ -9,7 +9,7 @@ All notable changes to the **Prowler API** are documented in this file.
- Exception handler for provider deletions during scans [(#9414)](https://github.com/prowler-cloud/prowler/pull/9414)
### Changed
- Restore the compliance overview endpoint's mandatory filters [(#9330)](https://github.com/prowler-cloud/prowler/pull/9330)
- Restore the compliance overview endpoint's mandatory filters [(#9338)](https://github.com/prowler-cloud/prowler/pull/9338)
---

View File

@@ -36,6 +36,8 @@ from api.compliance import get_compliance_frameworks
from api.db_router import MainRouter
from api.models import (
AttackSurfaceOverview,
ComplianceOverviewSummary,
ComplianceRequirementOverview,
Finding,
Integration,
Invitation,
@@ -56,6 +58,7 @@ from api.models import (
Scan,
ScanSummary,
StateChoices,
StatusChoices,
Task,
TenantAPIKey,
ThreatScoreSnapshot,
@@ -5820,16 +5823,44 @@ class TestProviderGroupMembershipViewSet:
@pytest.mark.django_db
class TestComplianceOverviewViewSet:
def test_compliance_overview_list_none(self, authenticated_client):
@pytest.fixture(autouse=True)
def mock_backfill_task(self):
with patch("api.v1.views.backfill_compliance_summaries_task.delay") as mock:
yield mock
def test_compliance_overview_list_none(
self,
authenticated_client,
tenants_fixture,
providers_fixture,
mock_backfill_task,
):
tenant = tenants_fixture[0]
provider = providers_fixture[0]
scan = Scan.objects.create(
name="empty-compliance-scan",
provider=provider,
trigger=Scan.TriggerChoices.MANUAL,
state=StateChoices.COMPLETED,
tenant=tenant,
)
response = authenticated_client.get(
reverse("complianceoverview-list"),
{"filter[scan_id]": "8d20ac7d-4cbc-435e-85f4-359be37af821"},
{"filter[scan_id]": str(scan.id)},
)
assert response.status_code == status.HTTP_200_OK
assert len(response.json()["data"]) == 0
mock_backfill_task.assert_called_once()
_, kwargs = mock_backfill_task.call_args
assert kwargs["scan_id"] == str(scan.id)
assert str(kwargs["tenant_id"]) == str(tenant.id)
def test_compliance_overview_list(
self, authenticated_client, compliance_requirements_overviews_fixture
self,
authenticated_client,
compliance_requirements_overviews_fixture,
mock_backfill_task,
):
# List compliance overviews with existing data
requirement_overview1 = compliance_requirements_overviews_fixture[0]
@@ -5859,6 +5890,90 @@ class TestComplianceOverviewViewSet:
assert "requirements_failed" in attributes
assert "requirements_manual" in attributes
assert "total_requirements" in attributes
mock_backfill_task.assert_called_once()
_, kwargs = mock_backfill_task.call_args
assert kwargs["scan_id"] == scan_id
def test_compliance_overview_list_uses_preaggregated_summaries(
self,
authenticated_client,
tenants_fixture,
providers_fixture,
mock_backfill_task,
):
tenant = tenants_fixture[0]
provider = providers_fixture[0]
scan = Scan.objects.create(
name="preaggregated-scan",
provider=provider,
trigger=Scan.TriggerChoices.MANUAL,
state=StateChoices.COMPLETED,
tenant=tenant,
)
ComplianceRequirementOverview.objects.create(
tenant=tenant,
scan=scan,
compliance_id="cis_1.4_aws",
framework="CIS-1.4-AWS",
version="1.4",
description="CIS AWS Foundations Benchmark v1.4.0",
region="eu-west-1",
requirement_id="framework-metadata",
requirement_status=StatusChoices.PASS,
passed_checks=1,
failed_checks=0,
total_checks=1,
)
ComplianceOverviewSummary.objects.create(
tenant=tenant,
scan=scan,
compliance_id="cis_1.4_aws",
requirements_passed=5,
requirements_failed=1,
requirements_manual=2,
total_requirements=8,
)
response = authenticated_client.get(
reverse("complianceoverview-list"),
{"filter[scan_id]": str(scan.id)},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()["data"]
assert len(data) == 1
overview = data[0]
assert overview["id"] == "cis_1.4_aws"
assert overview["attributes"]["requirements_passed"] == 5
assert overview["attributes"]["requirements_failed"] == 1
assert overview["attributes"]["requirements_manual"] == 2
assert overview["attributes"]["total_requirements"] == 8
assert "framework" in overview["attributes"]
assert "version" in overview["attributes"]
mock_backfill_task.assert_not_called()
def test_compliance_overview_region_filter_skips_backfill(
self,
authenticated_client,
compliance_requirements_overviews_fixture,
mock_backfill_task,
):
requirement_overview = compliance_requirements_overviews_fixture[0]
scan_id = str(requirement_overview.scan.id)
response = authenticated_client.get(
reverse("complianceoverview-list"),
{
"filter[scan_id]": scan_id,
"filter[region]": requirement_overview.region,
},
)
assert response.status_code == status.HTTP_200_OK
assert len(response.json()["data"]) >= 1
mock_backfill_task.assert_not_called()
def test_compliance_overview_metadata(
self, authenticated_client, compliance_requirements_overviews_fixture
@@ -6012,6 +6127,11 @@ class TestComplianceOverviewViewSet:
requirement_overview1 = compliance_requirements_overviews_fixture[0]
scan_id = str(requirement_overview1.scan.id)
# Remove existing compliance data so the view falls back to task checks
scan = requirement_overview1.scan
ComplianceOverviewSummary.objects.filter(scan=scan).delete()
ComplianceRequirementOverview.objects.filter(scan=scan).delete()
# Mock a running task
with patch.object(
ComplianceOverviewViewSet, "get_task_response_if_running"
@@ -6039,6 +6159,11 @@ class TestComplianceOverviewViewSet:
requirement_overview1 = compliance_requirements_overviews_fixture[0]
scan_id = str(requirement_overview1.scan.id)
# Remove existing compliance data so the view falls back to task checks
scan = requirement_overview1.scan
ComplianceOverviewSummary.objects.filter(scan=scan).delete()
ComplianceRequirementOverview.objects.filter(scan=scan).delete()
# Mock a failed task
with patch.object(
ComplianceOverviewViewSet, "get_task_response_if_running"
@@ -6062,6 +6187,8 @@ class TestComplianceOverviewViewSet:
("framework", "framework", 1),
("version", "version", 1),
("region", "region", 1),
("region__in", "region", 1),
("region.in", "region", 1),
],
)
def test_compliance_overview_filters(

View File

@@ -76,6 +76,7 @@ from tasks.beat import schedule_provider_scan
from tasks.jobs.export import get_s3_client
from tasks.jobs.scan import _get_attack_surface_mapping_from_provider
from tasks.tasks import (
backfill_compliance_summaries_task,
backfill_scan_resource_summaries_task,
check_integration_connection_task,
check_lighthouse_connection_task,
@@ -128,6 +129,7 @@ from api.filters import (
)
from api.models import (
AttackSurfaceOverview,
ComplianceOverviewSummary,
ComplianceRequirementOverview,
Finding,
Integration,
@@ -3526,6 +3528,126 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
def retrieve(self, request, *args, **kwargs):
raise MethodNotAllowed(method="GET")
def _compliance_summaries_queryset(self, scan_id):
"""Return pre-aggregated summaries constrained by RBAC visibility."""
role = get_role(self.request.user)
unlimited_visibility = getattr(
role, Permissions.UNLIMITED_VISIBILITY.value, False
)
summaries = ComplianceOverviewSummary.objects.filter(
tenant_id=self.request.tenant_id,
scan_id=scan_id,
)
if not unlimited_visibility:
providers = Provider.all_objects.filter(
provider_groups__in=role.provider_groups.all()
).distinct()
summaries = summaries.filter(scan__provider__in=providers)
return summaries
def _get_compliance_template(self, *, provider=None, scan_id=None):
"""Return the compliance template for the given provider or scan."""
if provider is None and scan_id is not None:
try:
scan = Scan.all_objects.select_related("provider").get(pk=scan_id)
except Scan.DoesNotExist:
raise ValidationError(
[
{
"detail": "Scan not found",
"status": 404,
"source": {"pointer": "filter[scan_id]"},
"code": "not_found",
}
]
)
provider = scan.provider
if not provider:
return {}
return PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE.get(provider.provider, {})
def _aggregate_compliance_overview(self, queryset, template_metadata=None):
"""
Aggregate requirement rows into compliance overview dictionaries.
Args:
queryset: ComplianceRequirementOverview queryset already filtered.
template_metadata: Optional dict mapping compliance_id -> metadata.
"""
template_metadata = template_metadata or {}
requirement_status_subquery = queryset.values(
"compliance_id", "requirement_id"
).annotate(
fail_count=Count("id", filter=Q(requirement_status="FAIL")),
pass_count=Count("id", filter=Q(requirement_status="PASS")),
total_count=Count("id"),
)
compliance_data = {}
fallback_metadata = {
item["compliance_id"]: {
"framework": item["framework"],
"version": item["version"],
}
for item in queryset.values(
"compliance_id", "framework", "version"
).distinct()
}
for item in requirement_status_subquery:
compliance_id = item["compliance_id"]
if item["fail_count"] > 0:
req_status = "FAIL"
elif item["pass_count"] == item["total_count"]:
req_status = "PASS"
else:
req_status = "MANUAL"
compliance_status = compliance_data.setdefault(
compliance_id,
{
"total_requirements": 0,
"requirements_passed": 0,
"requirements_failed": 0,
"requirements_manual": 0,
},
)
compliance_status["total_requirements"] += 1
if req_status == "PASS":
compliance_status["requirements_passed"] += 1
elif req_status == "FAIL":
compliance_status["requirements_failed"] += 1
else:
compliance_status["requirements_manual"] += 1
response_data = []
for compliance_id, data in compliance_data.items():
template = template_metadata.get(compliance_id, {})
fallback = fallback_metadata.get(compliance_id, {})
response_data.append(
{
"id": compliance_id,
"compliance_id": compliance_id,
"framework": template.get("framework")
or fallback.get("framework", ""),
"version": template.get("version") or fallback.get("version", ""),
"requirements_passed": data["requirements_passed"],
"requirements_failed": data["requirements_failed"],
"requirements_manual": data["requirements_manual"],
"total_requirements": data["total_requirements"],
}
)
serializer = self.get_serializer(response_data, many=True)
return serializer.data
def _task_response_if_running(self, scan_id):
"""Check for an in-progress task only when no compliance data exists."""
try:
@@ -3540,90 +3662,84 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
def list(self, request, *args, **kwargs):
scan_id = request.query_params.get("filter[scan_id]")
if not scan_id:
raise ValidationError(
[
{
"detail": "This query parameter is required.",
"status": 400,
"source": {"pointer": "filter[scan_id]"},
"code": "required",
}
]
)
try:
if task := self.get_task_response_if_running(
task_name="scan-compliance-overviews",
task_kwargs={"tenant_id": self.request.tenant_id, "scan_id": scan_id},
raise_on_not_found=False,
):
return task
except TaskFailedException:
return Response(
{"detail": "Task failed to generate compliance overview data."},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
queryset = self.filter_queryset(self.filter_queryset(self.get_queryset()))
requirement_status_subquery = queryset.values(
"compliance_id", "requirement_id"
).annotate(
fail_count=Count("id", filter=Q(requirement_status="FAIL")),
pass_count=Count("id", filter=Q(requirement_status="PASS")),
total_count=Count("id"),
def _list_with_region_filter(self, scan_id, region_filter):
"""
Fall back to detailed ComplianceRequirementOverview query when region filter is applied.
This uses the original aggregation logic across filtered regions.
"""
regions = region_filter.split(",") if "," in region_filter else [region_filter]
queryset = self.filter_queryset(self.get_queryset()).filter(
scan_id=scan_id,
region__in=regions,
)
compliance_data = {}
framework_info = {}
data = self._aggregate_compliance_overview(queryset)
if data:
return Response(data)
for item in queryset.values("compliance_id", "framework", "version").distinct():
framework_info[item["compliance_id"]] = {
"framework": item["framework"],
"version": item["version"],
}
task_response = self._task_response_if_running(scan_id)
if task_response:
return task_response
for item in requirement_status_subquery:
compliance_id = item["compliance_id"]
return Response(data)
if item["fail_count"] > 0:
req_status = "FAIL"
elif item["pass_count"] == item["total_count"]:
req_status = "PASS"
else:
req_status = "MANUAL"
def _list_without_region_aggregation(self, scan_id):
"""
Fall back aggregation when compliance summaries don't exist yet.
Aggregates ComplianceRequirementOverview data across ALL regions.
"""
queryset = self.filter_queryset(self.get_queryset()).filter(scan_id=scan_id)
compliance_template = self._get_compliance_template(scan_id=scan_id)
data = self._aggregate_compliance_overview(
queryset, template_metadata=compliance_template
)
if data:
return Response(data)
if compliance_id not in compliance_data:
compliance_data[compliance_id] = {
"total_requirements": 0,
"requirements_passed": 0,
"requirements_failed": 0,
"requirements_manual": 0,
}
task_response = self._task_response_if_running(scan_id)
if task_response:
return task_response
compliance_data[compliance_id]["total_requirements"] += 1
if req_status == "PASS":
compliance_data[compliance_id]["requirements_passed"] += 1
elif req_status == "FAIL":
compliance_data[compliance_id]["requirements_failed"] += 1
else:
compliance_data[compliance_id]["requirements_manual"] += 1
return Response(data)
def list(self, request, *args, **kwargs):
scan_id = request.query_params.get("filter[scan_id]")
# Specific scan requested - use optimized summaries with region support
region_filter = request.query_params.get(
"filter[region]"
) or request.query_params.get("filter[region__in]")
if region_filter:
# Fall back to detailed query with region filtering
return self._list_with_region_filter(scan_id, region_filter)
summaries = list(self._compliance_summaries_queryset(scan_id))
if not summaries:
# Trigger async backfill for next time
backfill_compliance_summaries_task.delay(
tenant_id=self.request.tenant_id, scan_id=scan_id
)
# Use fallback aggregation for this request
return self._list_without_region_aggregation(scan_id)
# Get compliance template for provider to enrich with framework/version
compliance_template = self._get_compliance_template(scan_id=scan_id)
# Convert to response format with framework/version enrichment
response_data = []
for compliance_id, data in compliance_data.items():
framework = framework_info.get(compliance_id, {})
for summary in summaries:
compliance_metadata = compliance_template.get(summary.compliance_id, {})
response_data.append(
{
"id": compliance_id,
"compliance_id": compliance_id,
"framework": framework.get("framework", ""),
"version": framework.get("version", ""),
"requirements_passed": data["requirements_passed"],
"requirements_failed": data["requirements_failed"],
"requirements_manual": data["requirements_manual"],
"total_requirements": data["total_requirements"],
"id": summary.compliance_id,
"compliance_id": summary.compliance_id,
"framework": compliance_metadata.get("framework", ""),
"version": compliance_metadata.get("version", ""),
"requirements_passed": summary.requirements_passed,
"requirements_failed": summary.requirements_failed,
"requirements_manual": summary.requirements_manual,
"total_requirements": summary.total_requirements,
}
)