feat(threatscore): implement ThreatScoreSnapshot model, filter, serializer, and view for ThreatScore metrics retrieval (#9148)

This commit is contained in:
Adrián Jesús Peña Rodríguez
2025-11-11 10:19:48 +01:00
committed by GitHub
parent 73a277f27b
commit beec37b0da
11 changed files with 1721 additions and 140 deletions

View File

@@ -14,6 +14,7 @@ All notable changes to the **Prowler API** are documented in this file.
- Support muting findings based on simple rules with custom reason [(#9051)](https://github.com/prowler-cloud/prowler/pull/9051)
- Support C5 compliance framework for the GCP provider [(#9097)](https://github.com/prowler-cloud/prowler/pull/9097)
- Support for Amazon Bedrock and OpenAI compatible providers in Lighthouse AI [(#8957)](https://github.com/prowler-cloud/prowler/pull/8957)
- Tenant-wide ThreatScore overview aggregation and snapshot persistence with backfill support [(#9148)](https://github.com/prowler-cloud/prowler/pull/9148)
- Support for MongoDB Atlas provider [(#9167)](https://github.com/prowler-cloud/prowler/pull/9167)
---

View File

@@ -47,6 +47,7 @@ from api.models import (
StatusChoices,
Task,
TenantAPIKey,
ThreatScoreSnapshot,
User,
)
from api.rls import Tenant
@@ -998,3 +999,36 @@ class MuteRuleFilter(FilterSet):
"inserted_at": ["gte", "lte"],
"updated_at": ["gte", "lte"],
}
class ThreatScoreSnapshotFilter(FilterSet):
    """
    Filter for ThreatScore snapshots.

    Allows filtering by scan, provider, compliance_id, and date ranges.
    """

    # Date-only match on the snapshot timestamp (e.g. filter[inserted_at]=2025-01-01);
    # gte/lte range lookups are generated from Meta.fields below.
    inserted_at = DateFilter(field_name="inserted_at", lookup_expr="date")
    # Exact / __in lookups against the related scan and provider primary keys,
    # exposed under the flat `scan_id` / `provider_id` query-param names.
    scan_id = UUIDFilter(field_name="scan__id", lookup_expr="exact")
    scan_id__in = UUIDInFilter(field_name="scan__id", lookup_expr="in")
    provider_id = UUIDFilter(field_name="provider__id", lookup_expr="exact")
    provider_id__in = UUIDInFilter(field_name="provider__id", lookup_expr="in")
    # Provider cloud type (aws, azure, gcp, ...), validated against the model choices.
    provider_type = ChoiceFilter(
        field_name="provider__provider", choices=Provider.ProviderChoices.choices
    )
    provider_type__in = ChoiceInFilter(
        field_name="provider__provider",
        choices=Provider.ProviderChoices.choices,
        lookup_expr="in",
    )
    compliance_id = CharFilter(field_name="compliance_id", lookup_expr="exact")
    compliance_id__in = CharInFilter(field_name="compliance_id", lookup_expr="in")

    class Meta:
        model = ThreatScoreSnapshot
        # Auto-generated lookups in addition to the declared filters above.
        fields = {
            "scan": ["exact", "in"],
            "provider": ["exact", "in"],
            "compliance_id": ["exact", "in"],
            "inserted_at": ["date", "gte", "lte"],
            "overall_score": ["exact", "gte", "lte"],
        }

View File

@@ -0,0 +1,170 @@
# Generated by Django 5.1.13 on 2025-10-31 09:04
import uuid
import django.db.models.deletion
from django.db import migrations, models
import api.rls
class Migration(migrations.Migration):
    """Create the threatscore_snapshots table with tenant-scoped indexes and RLS.

    Auto-generated by Django (makemigrations); do not hand-edit once applied.
    """

    dependencies = [
        ("api", "0056_remove_provider_unique_provider_uids_and_more"),
    ]

    operations = [
        migrations.CreateModel(
            name="ThreatScoreSnapshot",
            fields=[
                (
                    "id",
                    models.UUIDField(
                        default=uuid.uuid4,
                        editable=False,
                        primary_key=True,
                        serialize=False,
                    ),
                ),
                ("inserted_at", models.DateTimeField(auto_now_add=True)),
                (
                    "compliance_id",
                    models.CharField(
                        help_text="Compliance framework ID (e.g., 'prowler_threatscore_aws')",
                        max_length=100,
                    ),
                ),
                (
                    "overall_score",
                    models.DecimalField(
                        decimal_places=2,
                        help_text="Overall ThreatScore percentage (0-100)",
                        max_digits=5,
                    ),
                ),
                (
                    "score_delta",
                    models.DecimalField(
                        blank=True,
                        decimal_places=2,
                        help_text="Score change compared to previous snapshot (positive = improvement)",
                        max_digits=5,
                        null=True,
                    ),
                ),
                (
                    "section_scores",
                    models.JSONField(
                        blank=True,
                        default=dict,
                        help_text="ThreatScore breakdown by section",
                    ),
                ),
                (
                    "critical_requirements",
                    models.JSONField(
                        blank=True,
                        default=list,
                        help_text="List of critical failed requirements (risk >= 4)",
                    ),
                ),
                (
                    "total_requirements",
                    models.IntegerField(
                        default=0, help_text="Total number of requirements evaluated"
                    ),
                ),
                (
                    "passed_requirements",
                    models.IntegerField(
                        default=0, help_text="Number of requirements with PASS status"
                    ),
                ),
                (
                    "failed_requirements",
                    models.IntegerField(
                        default=0, help_text="Number of requirements with FAIL status"
                    ),
                ),
                (
                    "manual_requirements",
                    models.IntegerField(
                        default=0, help_text="Number of requirements with MANUAL status"
                    ),
                ),
                (
                    "total_findings",
                    models.IntegerField(
                        default=0,
                        help_text="Total number of findings across all requirements",
                    ),
                ),
                (
                    "passed_findings",
                    models.IntegerField(
                        default=0, help_text="Number of findings with PASS status"
                    ),
                ),
                (
                    "failed_findings",
                    models.IntegerField(
                        default=0, help_text="Number of findings with FAIL status"
                    ),
                ),
                (
                    "provider",
                    models.ForeignKey(
                        on_delete=django.db.models.deletion.CASCADE,
                        related_name="threatscore_snapshots",
                        related_query_name="threatscore_snapshot",
                        to="api.provider",
                    ),
                ),
                (
                    "scan",
                    models.ForeignKey(
                        on_delete=django.db.models.deletion.CASCADE,
                        related_name="threatscore_snapshots",
                        related_query_name="threatscore_snapshot",
                        to="api.scan",
                    ),
                ),
                (
                    "tenant",
                    models.ForeignKey(
                        on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
                    ),
                ),
            ],
            options={
                "db_table": "threatscore_snapshots",
                "abstract": False,
            },
        ),
        # Composite indexes always lead with tenant_id to cooperate with RLS.
        migrations.AddIndex(
            model_name="threatscoresnapshot",
            index=models.Index(
                fields=["tenant_id", "scan_id"], name="threatscore_snap_t_scan_idx"
            ),
        ),
        migrations.AddIndex(
            model_name="threatscoresnapshot",
            index=models.Index(
                fields=["tenant_id", "provider_id"], name="threatscore_snap_t_prov_idx"
            ),
        ),
        migrations.AddIndex(
            model_name="threatscoresnapshot",
            index=models.Index(
                fields=["tenant_id", "inserted_at"], name="threatscore_snap_t_time_idx"
            ),
        ),
        # Row-level security: every statement type is constrained by tenant_id.
        migrations.AddConstraint(
            model_name="threatscoresnapshot",
            constraint=api.rls.RowLevelSecurityConstraint(
                "tenant_id",
                name="rls_on_threatscoresnapshot",
                statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
            ),
        ),
    ]

View File

@@ -2239,3 +2239,137 @@ class LighthouseProviderModels(RowLevelSecurityProtectedModel):
class JSONAPIMeta:
resource_name = "lighthouse-models"
class ThreatScoreSnapshot(RowLevelSecurityProtectedModel):
    """
    Stores historical ThreatScore metrics for a given scan.

    Snapshots are created automatically after each ThreatScore report generation.
    """

    # NOTE(review): both managers are plain `models.Manager` instances and thus
    # behave identically; RLS models typically differentiate a tenant-scoped
    # `objects` from an unrestricted `all_objects` — confirm this is intended.
    objects = models.Manager()
    all_objects = models.Manager()

    id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
    # Set once on creation (auto_now_add); snapshots are immutable history.
    inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
    scan = models.ForeignKey(
        Scan,
        on_delete=models.CASCADE,
        related_name="threatscore_snapshots",
        related_query_name="threatscore_snapshot",
    )
    provider = models.ForeignKey(
        Provider,
        on_delete=models.CASCADE,
        related_name="threatscore_snapshots",
        related_query_name="threatscore_snapshot",
    )
    compliance_id = models.CharField(
        max_length=100,
        blank=False,
        null=False,
        help_text="Compliance framework ID (e.g., 'prowler_threatscore_aws')",
    )
    # Overall ThreatScore metrics
    overall_score = models.DecimalField(
        max_digits=5,
        decimal_places=2,
        help_text="Overall ThreatScore percentage (0-100)",
    )
    # Score improvement/degradation compared to previous snapshot
    score_delta = models.DecimalField(
        max_digits=5,
        decimal_places=2,
        null=True,
        blank=True,
        help_text="Score change compared to previous snapshot (positive = improvement)",
    )
    # Section breakdown stored as JSON
    # Format: {"1. IAM": 85.5, "2. Attack Surface": 92.3, ...}
    section_scores = models.JSONField(
        default=dict,
        blank=True,
        help_text="ThreatScore breakdown by section",
    )
    # Critical requirements metadata stored as JSON
    # Format: [{"requirement_id": "...", "risk_level": 5, "weight": 150, ...}, ...]
    critical_requirements = models.JSONField(
        default=list,
        blank=True,
        help_text="List of critical failed requirements (risk >= 4)",
    )
    # Summary statistics
    total_requirements = models.IntegerField(
        default=0,
        help_text="Total number of requirements evaluated",
    )
    passed_requirements = models.IntegerField(
        default=0,
        help_text="Number of requirements with PASS status",
    )
    failed_requirements = models.IntegerField(
        default=0,
        help_text="Number of requirements with FAIL status",
    )
    manual_requirements = models.IntegerField(
        default=0,
        help_text="Number of requirements with MANUAL status",
    )
    total_findings = models.IntegerField(
        default=0,
        help_text="Total number of findings across all requirements",
    )
    passed_findings = models.IntegerField(
        default=0,
        help_text="Number of findings with PASS status",
    )
    failed_findings = models.IntegerField(
        default=0,
        help_text="Number of findings with FAIL status",
    )

    def __str__(self):
        return f"ThreatScore {self.overall_score}% for scan {self.scan_id} ({self.inserted_at})"

    class Meta(RowLevelSecurityProtectedModel.Meta):
        db_table = "threatscore_snapshots"

        constraints = [
            RowLevelSecurityConstraint(
                field="tenant_id",
                name="rls_on_%(class)s",
                statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
            ),
        ]
        # All indexes lead with tenant_id so they are usable under RLS filtering.
        indexes = [
            models.Index(
                fields=["tenant_id", "scan_id"],
                name="threatscore_snap_t_scan_idx",
            ),
            models.Index(
                fields=["tenant_id", "provider_id"],
                name="threatscore_snap_t_prov_idx",
            ),
            models.Index(
                fields=["tenant_id", "inserted_at"],
                name="threatscore_snap_t_time_idx",
            ),
        ]

    class JSONAPIMeta:
        resource_name = "threatscore-snapshots"

View File

@@ -4,6 +4,7 @@ import json
import os
import tempfile
from datetime import datetime, timedelta, timezone
from decimal import Decimal
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import ANY, MagicMock, Mock, patch
@@ -56,6 +57,7 @@ from api.models import (
StateChoices,
Task,
TenantAPIKey,
ThreatScoreSnapshot,
User,
UserRoleRelationship,
)
@@ -6221,6 +6223,407 @@ class TestOverviewViewSet:
for entry in grouped_data:
assert "findings" not in entry["attributes"]
def _create_scan(self, tenant, provider, name, started_at=None):
    """Create a completed manual scan for *tenant*/*provider* named *name*.

    When *started_at* is not given, the scan is backdated one hour; the
    completion time is always 30 minutes after the start.
    """
    default_start = datetime.now(timezone.utc) - timedelta(hours=1)
    begin = started_at or default_start
    end = begin + timedelta(minutes=30)
    return Scan.objects.create(
        tenant=tenant,
        provider=provider,
        name=name,
        trigger=Scan.TriggerChoices.MANUAL,
        state=StateChoices.COMPLETED,
        started_at=begin,
        completed_at=end,
    )
def _create_threatscore_snapshot(
    self,
    tenant,
    scan,
    provider,
    *,
    compliance_id,
    overall_score,
    score_delta,
    section_scores,
    critical_requirements,
    total_requirements,
    passed_requirements,
    failed_requirements,
    manual_requirements,
    total_findings,
    passed_findings,
    failed_findings,
):
    """Persist a ThreatScoreSnapshot row from the given metrics.

    Score strings are converted to Decimal; a None *score_delta* is stored
    as NULL (first snapshot has nothing to compare against).
    """
    delta = None if score_delta is None else Decimal(score_delta)
    return ThreatScoreSnapshot.objects.create(
        tenant=tenant,
        scan=scan,
        provider=provider,
        compliance_id=compliance_id,
        overall_score=Decimal(overall_score),
        score_delta=delta,
        section_scores=section_scores,
        critical_requirements=critical_requirements,
        total_requirements=total_requirements,
        passed_requirements=passed_requirements,
        failed_requirements=failed_requirements,
        manual_requirements=manual_requirements,
        total_findings=total_findings,
        passed_findings=passed_findings,
        failed_findings=failed_findings,
    )
def test_overview_threatscore_returns_weighted_aggregate_snapshot(
    self, authenticated_client, tenants_fixture, providers_fixture
):
    """
    Two providers' latest snapshots must merge into one synthetic aggregate
    entry (id "n/a", null scan/provider relationships) whose scores are
    weighted by each snapshot's total_findings (100 vs 50 here), e.g.
    overall = (80*100 + 20*50) / 150 = 60.00. Critical requirements are
    de-duplicated by requirement_id, keeping the higher (risk_level, weight).
    """
    tenant = tenants_fixture[0]
    provider1, provider2, *_ = providers_fixture

    scan1 = self._create_scan(tenant, provider1, "agg-scan-one")
    scan2 = self._create_scan(tenant, provider2, "agg-scan-two")

    snapshot1 = self._create_threatscore_snapshot(
        tenant,
        scan1,
        provider1,
        compliance_id="prowler_threatscore_aws",
        overall_score="80.00",
        score_delta="5.00",
        section_scores={"1. IAM": "70.00", "2. Attack Surface": "60.00"},
        critical_requirements=[
            {
                "requirement_id": "req_shared",
                "title": "Shared requirement (preferred)",
                "section": "1. IAM",
                "subsection": "Sub IAM",
                "risk_level": 5,
                "weight": 150,
                "passed_findings": 14,
                "total_findings": 20,
                "description": "Higher risk duplicate",
            },
            {
                "requirement_id": "req_unique_one",
                "title": "Unique provider one",
                "section": "2. Attack Surface",
                "subsection": "Sub Attack",
                "risk_level": 4,
                "weight": 90,
                "passed_findings": 20,
                "total_findings": 30,
                "description": "Lower risk",
            },
        ],
        total_requirements=120,
        passed_requirements=90,
        failed_requirements=30,
        manual_requirements=0,
        total_findings=100,
        passed_findings=70,
        failed_findings=30,
    )
    snapshot2 = self._create_threatscore_snapshot(
        tenant,
        scan2,
        provider2,
        compliance_id="prowler_threatscore_aws",
        overall_score="20.00",
        score_delta="-2.00",
        section_scores={
            "1. IAM": "10.00",
            "2. Attack Surface": "40.00",
            "3. Logging": "30.00",
        },
        critical_requirements=[
            {
                # Same requirement_id as snapshot1 but lower risk_level:
                # must lose the de-duplication to the "preferred" copy.
                "requirement_id": "req_shared",
                "title": "Shared requirement (secondary)",
                "section": "1. IAM",
                "subsection": "Sub IAM",
                "risk_level": 4,
                "weight": 120,
                "passed_findings": 8,
                "total_findings": 12,
                "description": "Lower risk duplicate",
            },
            {
                "requirement_id": "req_unique_two",
                "title": "Unique provider two",
                "section": "3. Logging",
                "subsection": "Sub Logging",
                "risk_level": 5,
                "weight": 110,
                "passed_findings": 6,
                "total_findings": 10,
                "description": "Another critical requirement",
            },
        ],
        total_requirements=80,
        passed_requirements=30,
        failed_requirements=50,
        manual_requirements=0,
        total_findings=50,
        passed_findings=15,
        failed_findings=35,
    )

    # inserted_at is auto_now_add, so force deterministic timestamps after
    # creation; the aggregate must expose the newest one (snapshot2's).
    older_inserted = datetime(2025, 1, 1, 12, 0, tzinfo=timezone.utc)
    newer_inserted = datetime(2025, 1, 2, 12, 0, tzinfo=timezone.utc)
    ThreatScoreSnapshot.objects.filter(id=snapshot1.id).update(
        inserted_at=older_inserted
    )
    ThreatScoreSnapshot.objects.filter(id=snapshot2.id).update(
        inserted_at=newer_inserted
    )
    snapshot2.refresh_from_db()

    response = authenticated_client.get(reverse("overview-threatscore"))

    assert response.status_code == status.HTTP_200_OK
    body = response.json()
    assert len(body["data"]) == 1
    aggregated = body["data"][0]
    # The synthetic aggregate has no backing row, hence no real id/relations.
    assert aggregated["id"] == "n/a"
    assert aggregated["relationships"]["scan"]["data"] is None
    assert aggregated["relationships"]["provider"]["data"] is None

    attrs = aggregated["attributes"]
    # Findings-weighted means: (80*100+20*50)/150 and (5*100-2*50)/150.
    assert Decimal(attrs["overall_score"]) == Decimal("60.00")
    assert Decimal(attrs["score_delta"]) == Decimal("2.67")
    assert attrs["inserted_at"] == snapshot2.inserted_at.isoformat().replace(
        "+00:00", "Z"
    )
    # Counters are straight sums across both snapshots.
    assert attrs["total_findings"] == 150
    assert attrs["passed_findings"] == 85
    assert attrs["failed_findings"] == 65
    assert attrs["total_requirements"] == 200
    assert attrs["passed_requirements"] == 120
    assert attrs["failed_requirements"] == 80
    assert attrs["manual_requirements"] == 0
    # Per-section findings-weighted means; "3. Logging" exists only in snapshot2.
    assert attrs["section_scores"] == {
        "1. IAM": "50.00",
        "2. Attack Surface": "53.33",
        "3. Logging": "30.00",
    }
    # Sorted by (risk_level, weight) descending, duplicates collapsed.
    expected_critical = [
        {
            "requirement_id": "req_shared",
            "title": "Shared requirement (preferred)",
            "section": "1. IAM",
            "subsection": "Sub IAM",
            "risk_level": 5,
            "weight": 150,
            "passed_findings": 14,
            "total_findings": 20,
            "description": "Higher risk duplicate",
        },
        {
            "requirement_id": "req_unique_two",
            "title": "Unique provider two",
            "section": "3. Logging",
            "subsection": "Sub Logging",
            "risk_level": 5,
            "weight": 110,
            "passed_findings": 6,
            "total_findings": 10,
            "description": "Another critical requirement",
        },
        {
            "requirement_id": "req_unique_one",
            "title": "Unique provider one",
            "section": "2. Attack Surface",
            "subsection": "Sub Attack",
            "risk_level": 4,
            "weight": 90,
            "passed_findings": 20,
            "total_findings": 30,
            "description": "Lower risk",
        },
    ]
    assert attrs["critical_requirements"] == expected_critical
def test_overview_threatscore_weight_fallback_to_requirements(
    self, authenticated_client, tenants_fixture, providers_fixture
):
    """
    When a snapshot has zero findings, its aggregation weight falls back to
    its active (non-manual) requirement count: snapshot1 weighs 10-2=8,
    snapshot2 weighs its 10 findings, so overall = (90*8 + 50*10)/18 = 67.78.
    """
    tenant = tenants_fixture[0]
    provider1, provider2, *_ = providers_fixture

    scan1 = self._create_scan(tenant, provider1, "fallback-scan-1")
    scan2 = self._create_scan(tenant, provider2, "fallback-scan-2")

    # No findings at all: weight must come from 10 - 2 manual = 8 requirements.
    self._create_threatscore_snapshot(
        tenant,
        scan1,
        provider1,
        compliance_id="prowler_threatscore_aws",
        overall_score="90.00",
        score_delta="4.00",
        section_scores={"1. IAM": "90.00"},
        critical_requirements=[],
        total_requirements=10,
        passed_requirements=8,
        failed_requirements=0,
        manual_requirements=2,
        total_findings=0,
        passed_findings=0,
        failed_findings=0,
    )
    # Has findings: weighted by total_findings=10 as usual.
    self._create_threatscore_snapshot(
        tenant,
        scan2,
        provider2,
        compliance_id="prowler_threatscore_aws",
        overall_score="50.00",
        score_delta="1.00",
        section_scores={"1. IAM": "40.00"},
        critical_requirements=[],
        total_requirements=12,
        passed_requirements=5,
        failed_requirements=7,
        manual_requirements=0,
        total_findings=10,
        passed_findings=4,
        failed_findings=6,
    )

    response = authenticated_client.get(reverse("overview-threatscore"))

    assert response.status_code == status.HTTP_200_OK
    aggregate = response.json()["data"][0]["attributes"]
    assert Decimal(aggregate["overall_score"]) == Decimal("67.78")
    assert Decimal(aggregate["score_delta"]) == Decimal("2.33")
    assert aggregate["total_findings"] == 10
    assert aggregate["total_requirements"] == 22
    assert aggregate["manual_requirements"] == 2
    # (90*8 + 40*10) / 18 = 62.22 for the shared section.
    assert aggregate["section_scores"] == {"1. IAM": "62.22"}
def test_overview_threatscore_filter_by_scan_id_returns_snapshot(
    self, authenticated_client, tenants_fixture, providers_fixture
):
    """
    A JSON:API filter[scan_id] that narrows the result to a single snapshot
    must return the real snapshot (its UUID as id), not an aggregate.
    """
    tenant = tenants_fixture[0]
    provider1, *_ = providers_fixture

    scan = self._create_scan(tenant, provider1, "filter-scan")
    snapshot = self._create_threatscore_snapshot(
        tenant,
        scan,
        provider1,
        compliance_id="prowler_threatscore_aws",
        overall_score="75.00",
        score_delta="3.00",
        section_scores={"1. IAM": "70.00"},
        critical_requirements=[],
        total_requirements=50,
        passed_requirements=30,
        failed_requirements=20,
        manual_requirements=0,
        total_findings=25,
        passed_findings=15,
        failed_findings=10,
    )

    response = authenticated_client.get(
        reverse("overview-threatscore"), {"filter[scan_id]": str(scan.id)}
    )

    assert response.status_code == status.HTTP_200_OK
    body = response.json()
    assert len(body["data"]) == 1
    assert body["data"][0]["id"] == str(snapshot.id)
    assert body["data"][0]["attributes"]["overall_score"] == "75.00"
def test_overview_threatscore_snapshot_id_returns_specific_snapshot(
    self, authenticated_client, tenants_fixture, providers_fixture
):
    """
    The snapshot_id query parameter must return exactly that historical
    snapshot as a single resource (null score_delta preserved as null).
    """
    tenant = tenants_fixture[0]
    provider1, *_ = providers_fixture

    scan = self._create_scan(tenant, provider1, "snapshot-id-scan")
    snapshot = self._create_threatscore_snapshot(
        tenant,
        scan,
        provider1,
        compliance_id="prowler_threatscore_aws",
        overall_score="88.50",
        # First snapshot for this provider: no previous score to diff against.
        score_delta=None,
        section_scores={"1. IAM": "80.00"},
        critical_requirements=[],
        total_requirements=60,
        passed_requirements=45,
        failed_requirements=15,
        manual_requirements=0,
        total_findings=30,
        passed_findings=25,
        failed_findings=5,
    )

    response = authenticated_client.get(
        reverse("overview-threatscore"), {"snapshot_id": str(snapshot.id)}
    )

    assert response.status_code == status.HTTP_200_OK
    data = response.json()
    assert data["data"]["id"] == str(snapshot.id)
    assert data["data"]["attributes"]["score_delta"] is None
def test_overview_threatscore_provider_filter_returns_unaggregated_snapshot(
    self, authenticated_client, tenants_fixture, providers_fixture
):
    """
    filter[provider_id__in] limited to one provider must yield that
    provider's own latest snapshot unmodified — aggregation only happens
    when more than one snapshot survives filtering.
    """
    tenant = tenants_fixture[0]
    provider1, provider2, *_ = providers_fixture

    scan1 = self._create_scan(tenant, provider1, "provider-filter-scan-1")
    scan2 = self._create_scan(tenant, provider2, "provider-filter-scan-2")

    snapshot1 = self._create_threatscore_snapshot(
        tenant,
        scan1,
        provider1,
        compliance_id="prowler_threatscore_aws",
        overall_score="55.55",
        score_delta="1.10",
        section_scores={"1. IAM": "50.00"},
        critical_requirements=[],
        total_requirements=40,
        passed_requirements=25,
        failed_requirements=15,
        manual_requirements=0,
        total_findings=12,
        passed_findings=7,
        failed_findings=5,
    )
    # Second provider's snapshot must be excluded by the filter below.
    self._create_threatscore_snapshot(
        tenant,
        scan2,
        provider2,
        compliance_id="prowler_threatscore_aws",
        overall_score="44.44",
        score_delta="0.80",
        section_scores={"1. IAM": "40.00"},
        critical_requirements=[],
        total_requirements=30,
        passed_requirements=18,
        failed_requirements=12,
        manual_requirements=0,
        total_findings=10,
        passed_findings=6,
        failed_findings=4,
    )

    response = authenticated_client.get(
        reverse("overview-threatscore"),
        {"filter[provider_id__in]": str(provider1.id)},
    )

    assert response.status_code == status.HTTP_200_OK
    data = response.json()["data"]
    assert len(data) == 1
    assert data[0]["id"] == str(snapshot1.id)
    assert data[0]["attributes"]["overall_score"] == "55.55"
def test_overview_services_list_no_required_filters(
self, authenticated_client, scan_summaries_fixture
):

View File

@@ -47,6 +47,7 @@ from api.models import (
StatusChoices,
Task,
TenantAPIKey,
ThreatScoreSnapshot,
User,
UserRoleRelationship,
)
@@ -3626,3 +3627,64 @@ class MuteRuleUpdateSerializer(BaseWriteSerializer):
):
raise ValidationError("A mute rule with this name already exists.")
return value
# ThreatScore Snapshots
class ThreatScoreSnapshotSerializer(RLSSerializer):
    """
    Serializer for ThreatScore snapshots.

    Read-only serializer for retrieving historical ThreatScore metrics.
    Snapshots are produced by the backend after each ThreatScore report
    generation and are never created or modified through the API.
    """

    # Declared so the synthetic tenant-wide aggregate (which has no DB row)
    # can report "n/a" instead of a UUID; SerializerMethodField is inherently
    # read-only.
    id = serializers.SerializerMethodField()

    class Meta:
        model = ThreatScoreSnapshot
        fields = [
            "id",
            "inserted_at",
            "scan",
            "provider",
            "compliance_id",
            "overall_score",
            "score_delta",
            "section_scores",
            "critical_requirements",
            "total_requirements",
            "passed_requirements",
            "failed_requirements",
            "manual_requirements",
            "total_findings",
            "passed_findings",
            "failed_findings",
        ]
        # All model-backed fields are read-only. Use DRF's `read_only_fields`
        # shortcut instead of repeating `{"read_only": True}` per field in
        # `extra_kwargs`. `id` is omitted on purpose: DRF forbids listing
        # explicitly declared fields in `read_only_fields` (and the
        # SerializerMethodField above is read-only anyway).
        read_only_fields = [
            "inserted_at",
            "scan",
            "provider",
            "compliance_id",
            "overall_score",
            "score_delta",
            "section_scores",
            "critical_requirements",
            "total_requirements",
            "passed_requirements",
            "failed_requirements",
            "manual_requirements",
            "total_findings",
            "passed_findings",
            "failed_findings",
        ]

    included_serializers = {
        "scan": "api.v1.serializers.ScanIncludeSerializer",
        "provider": "api.v1.serializers.ProviderIncludeSerializer",
    }

    def get_id(self, obj):
        """Return "n/a" for the aggregated (virtual) snapshot, else its UUID."""
        if getattr(obj, "_aggregated", False):
            return "n/a"
        return str(obj.id)

View File

@@ -3,7 +3,10 @@ import glob
import json
import logging
import os
from collections import defaultdict
from copy import deepcopy
from datetime import datetime, timedelta, timezone
from decimal import ROUND_HALF_UP, Decimal, InvalidOperation
from urllib.parse import urljoin
import sentry_sdk
@@ -24,9 +27,23 @@ from django.conf import settings as django_settings
from django.contrib.postgres.aggregates import ArrayAgg
from django.contrib.postgres.search import SearchQuery
from django.db import transaction
from django.db.models import Count, F, Prefetch, Q, Subquery, Sum
from django.db.models import (
Case,
Count,
DecimalField,
ExpressionWrapper,
F,
IntegerField,
Max,
Prefetch,
Q,
Subquery,
Sum,
Value,
When,
)
from django.db.models.functions import Coalesce
from django.http import HttpResponse
from django.http import HttpResponse, QueryDict
from django.shortcuts import redirect
from django.urls import reverse
from django.utils.dateparse import parse_date
@@ -105,6 +122,7 @@ from api.filters import (
TaskFilter,
TenantApiKeyFilter,
TenantFilter,
ThreatScoreSnapshotFilter,
UserFilter,
)
from api.models import (
@@ -138,6 +156,7 @@ from api.models import (
StateChoices,
Task,
TenantAPIKey,
ThreatScoreSnapshot,
User,
UserRoleRelationship,
)
@@ -218,6 +237,7 @@ from api.v1.serializers import (
TenantApiKeySerializer,
TenantApiKeyUpdateSerializer,
TenantSerializer,
ThreatScoreSnapshotSerializer,
TokenRefreshSerializer,
TokenSerializer,
TokenSocialLoginSerializer,
@@ -3770,6 +3790,8 @@ class OverviewViewSet(BaseRLSViewSet):
return OverviewSeveritySerializer
elif self.action == "services":
return OverviewServiceSerializer
elif self.action == "threatscore":
return ThreatScoreSnapshotSerializer
return super().get_serializer_class()
def get_filterset_class(self):
@@ -4011,6 +4033,332 @@ class OverviewViewSet(BaseRLSViewSet):
return Response(serializer.data, status=status.HTTP_200_OK)
@extend_schema(
    summary="Get ThreatScore snapshots",
    description=(
        "Retrieve ThreatScore metrics. By default, returns the latest snapshot for each provider. "
        "Use snapshot_id to retrieve a specific historical snapshot."
    ),
    tags=["Overviews"],
    parameters=[
        OpenApiParameter(
            name="snapshot_id",
            type=OpenApiTypes.UUID,
            location=OpenApiParameter.QUERY,
            description="Retrieve a specific snapshot by ID. If not provided, returns latest snapshots.",
        ),
        OpenApiParameter(
            name="provider_id",
            type=OpenApiTypes.UUID,
            location=OpenApiParameter.QUERY,
            description="Filter by specific provider ID",
        ),
        OpenApiParameter(
            name="provider_id__in",
            type=OpenApiTypes.STR,
            location=OpenApiParameter.QUERY,
            description="Filter by multiple provider IDs (comma-separated UUIDs)",
        ),
        OpenApiParameter(
            name="provider_type",
            type=OpenApiTypes.STR,
            location=OpenApiParameter.QUERY,
            description="Filter by provider type (aws, azure, gcp, etc.)",
        ),
        OpenApiParameter(
            name="provider_type__in",
            type=OpenApiTypes.STR,
            location=OpenApiParameter.QUERY,
            description="Filter by multiple provider types (comma-separated)",
        ),
    ],
)
@action(detail=False, methods=["get"], url_name="threatscore")
def threatscore(self, request):
    """
    Get ThreatScore snapshots.

    Default behavior: Returns the latest snapshot for each provider.
    With snapshot_id: Returns the specific snapshot requested.
    """
    tenant_id = self.request.tenant_id
    snapshot_id = request.query_params.get("snapshot_id")

    # Base queryset with RLS
    base_queryset = ThreatScoreSnapshot.objects.filter(tenant_id=tenant_id)

    # Apply RBAC filtering
    # NOTE(review): `allowed_providers` is presumably populated by the RBAC
    # layer when the user has limited provider visibility — confirm upstream.
    if hasattr(self, "allowed_providers"):
        base_queryset = base_queryset.filter(provider__in=self.allowed_providers)

    # Case 1: Specific snapshot requested
    # NOTE(review): a malformed (non-UUID) snapshot_id raises Django's
    # ValidationError from `.get(id=...)`, which DRF does not translate —
    # likely surfaces as a 500 instead of a 400/404. Worth hardening.
    if snapshot_id:
        try:
            snapshot = base_queryset.get(id=snapshot_id)
            serializer = ThreatScoreSnapshotSerializer(
                snapshot, context={"request": request}
            )
            return Response(serializer.data, status=status.HTTP_200_OK)
        except ThreatScoreSnapshot.DoesNotExist:
            raise NotFound(detail="ThreatScore snapshot not found")

    # Case 2: Latest snapshot per provider (default)
    # Apply filters manually: this @action is outside the standard list endpoint flow,
    # so DRF's filter backends don't execute and we must flatten JSON:API params ourselves.
    normalized_params = QueryDict(mutable=True)
    for param_key, values in request.query_params.lists():
        normalized_key = param_key
        # Strip the JSON:API "filter[...]" wrapper (7 chars) and the closing "]".
        if param_key.startswith("filter[") and param_key.endswith("]"):
            normalized_key = param_key[7:-1]
        if normalized_key == "snapshot_id":
            continue
        normalized_params.setlist(normalized_key, values)

    filterset = ThreatScoreSnapshotFilter(normalized_params, queryset=base_queryset)
    filtered_queryset = filterset.qs

    # Get distinct provider IDs from filtered queryset
    # Pick the latest snapshot per provider using Postgres DISTINCT ON pattern.
    # This avoids issuing one query per provider (N+1) when the filtered dataset is large.
    latest_snapshot_ids = list(
        filtered_queryset.order_by("provider_id", "-inserted_at")
        .distinct("provider_id")
        .values_list("id", flat=True)
    )
    # Re-fetch by id, then restore the DISTINCT ON ordering via the id list.
    latest_snapshot_map = {
        snapshot.id: snapshot
        for snapshot in filtered_queryset.filter(id__in=latest_snapshot_ids)
    }
    latest_snapshots = [
        latest_snapshot_map[snapshot_id]
        for snapshot_id in latest_snapshot_ids
        if snapshot_id in latest_snapshot_map
    ]

    # Zero or one snapshot: return as-is, no aggregation needed.
    if len(latest_snapshots) <= 1:
        serializer = ThreatScoreSnapshotSerializer(
            latest_snapshots, many=True, context={"request": request}
        )
        return Response(serializer.data, status=status.HTTP_200_OK)

    snapshot_ids = [
        snapshot.id for snapshot in latest_snapshots if snapshot and snapshot.id
    ]
    # Multiple providers: collapse their latest snapshots into one synthetic
    # tenant-wide aggregate (serialized with id "n/a").
    aggregated_snapshot = self._build_threatscore_overview_snapshot(
        snapshot_ids, tenant_id
    )
    serializer = ThreatScoreSnapshotSerializer(
        [aggregated_snapshot], many=True, context={"request": request}
    )
    return Response(serializer.data, status=status.HTTP_200_OK)
def _build_threatscore_overview_snapshot(self, snapshot_ids, tenant_id):
    """
    Aggregate the latest snapshots into a single overview snapshot for the tenant.

    Each snapshot contributes with a weight of its total_findings, falling
    back to its active (non-manual) requirement count, and finally to 1.
    Overall score, score delta, and per-section scores are weighted means;
    counters are straight sums. Critical requirements are de-duplicated by
    requirement_id (or title), keeping the entry with the highest
    (risk_level, weight). The returned ThreatScoreSnapshot is NOT saved; it
    is flagged with `_aggregated = True` so the serializer emits id "n/a".

    Raises:
        ValueError: if `snapshot_ids` is empty.
    """
    if not snapshot_ids:
        raise ValueError(
            "Snapshot id list cannot be empty when aggregating threatscore overview"
        )

    def requirement_sort_key(item):
        # Rank critical requirements by risk first, then weight; missing
        # values sort lowest. Defined once here (rather than re-created on
        # every loop iteration) and reused for both de-duplication and the
        # final ordering so the two stay consistent.
        return (
            item.get("risk_level") or 0,
            item.get("weight") or 0,
        )

    base_queryset = ThreatScoreSnapshot.objects.filter(
        tenant_id=tenant_id, id__in=snapshot_ids
    )

    # Annotate each snapshot with its aggregation weight:
    # total_findings > 0 -> total_findings; else active requirements; else 1.
    annotated_queryset = (
        base_queryset.annotate(
            active_requirements=ExpressionWrapper(
                F("total_requirements") - F("manual_requirements"),
                output_field=IntegerField(),
            )
        )
        .annotate(
            weight=Case(
                When(total_findings__gt=0, then=F("total_findings")),
                When(
                    active_requirements__gt=0,
                    then=F("active_requirements"),
                ),
                default=Value(1, output_field=IntegerField()),
                output_field=IntegerField(),
            )
        )
        .order_by()
    )

    aggregated_metrics = annotated_queryset.aggregate(
        total_requirements=Sum("total_requirements"),
        passed_requirements=Sum("passed_requirements"),
        failed_requirements=Sum("failed_requirements"),
        manual_requirements=Sum("manual_requirements"),
        total_findings=Sum("total_findings"),
        passed_findings=Sum("passed_findings"),
        failed_findings=Sum("failed_findings"),
        weighted_overall_sum=Sum(
            ExpressionWrapper(
                F("overall_score") * F("weight"),
                output_field=DecimalField(max_digits=14, decimal_places=4),
            )
        ),
        overall_weight=Sum("weight"),
        unweighted_overall_sum=Sum("overall_score"),
        # Deltas may be NULL (first snapshot of a provider); only non-null
        # deltas contribute to the weighted average.
        weighted_delta_sum=Sum(
            Case(
                When(
                    score_delta__isnull=False,
                    then=ExpressionWrapper(
                        F("score_delta") * F("weight"),
                        output_field=DecimalField(max_digits=14, decimal_places=4),
                    ),
                ),
                default=Value(
                    Decimal("0"),
                    output_field=DecimalField(max_digits=14, decimal_places=4),
                ),
                output_field=DecimalField(max_digits=14, decimal_places=4),
            )
        ),
        delta_weight=Sum(
            Case(
                When(score_delta__isnull=False, then=F("weight")),
                default=Value(0, output_field=IntegerField()),
                output_field=IntegerField(),
            )
        ),
        provider_count=Count("id"),
        latest_inserted_at=Max("inserted_at"),
    )

    total_requirements = aggregated_metrics["total_requirements"] or 0
    passed_requirements = aggregated_metrics["passed_requirements"] or 0
    failed_requirements = aggregated_metrics["failed_requirements"] or 0
    manual_requirements = aggregated_metrics["manual_requirements"] or 0
    total_findings = aggregated_metrics["total_findings"] or 0
    passed_findings = aggregated_metrics["passed_findings"] or 0
    failed_findings = aggregated_metrics["failed_findings"] or 0

    weighted_overall_sum = aggregated_metrics["weighted_overall_sum"]
    if weighted_overall_sum is None:
        weighted_overall_sum = Decimal("0")
    unweighted_overall_sum = aggregated_metrics["unweighted_overall_sum"]
    if unweighted_overall_sum is None:
        unweighted_overall_sum = Decimal("0")
    overall_weight = aggregated_metrics["overall_weight"] or 0
    provider_count = aggregated_metrics["provider_count"] or 0
    weighted_delta_sum = aggregated_metrics["weighted_delta_sum"]
    if weighted_delta_sum is None:
        weighted_delta_sum = Decimal("0")
    delta_weight = aggregated_metrics["delta_weight"] or 0

    # Weighted mean when weights exist; plain mean as a last resort.
    if overall_weight > 0:
        overall_score = (weighted_overall_sum / Decimal(overall_weight)).quantize(
            Decimal("0.01"), rounding=ROUND_HALF_UP
        )
    elif provider_count > 0:
        overall_score = (unweighted_overall_sum / Decimal(provider_count)).quantize(
            Decimal("0.01"), rounding=ROUND_HALF_UP
        )
    else:
        overall_score = Decimal("0.00")

    if delta_weight > 0:
        score_delta = (weighted_delta_sum / Decimal(delta_weight)).quantize(
            Decimal("0.01"), rounding=ROUND_HALF_UP
        )
    else:
        # No snapshot carried a delta: the aggregate has none either.
        score_delta = None

    # Section scores and critical requirements live in JSON columns, so they
    # are combined in Python rather than in SQL.
    section_weighted_sums = defaultdict(lambda: Decimal("0"))
    section_weights = defaultdict(lambda: Decimal("0"))
    combined_critical_requirements = {}

    snapshots_with_weight = list(annotated_queryset)
    for snapshot in snapshots_with_weight:
        weight_value = getattr(snapshot, "weight", None)
        try:
            weight_decimal = Decimal(weight_value)
        except (InvalidOperation, TypeError):
            weight_decimal = Decimal("1")
        if weight_decimal <= 0:
            weight_decimal = Decimal("1")

        section_scores = snapshot.section_scores or {}
        for section, score in section_scores.items():
            try:
                score_decimal = Decimal(str(score))
            except (InvalidOperation, TypeError):
                # Skip malformed section values rather than failing the overview.
                continue
            section_weighted_sums[section] += score_decimal * weight_decimal
            section_weights[section] += weight_decimal

        for requirement in snapshot.critical_requirements or []:
            key = requirement.get("requirement_id") or requirement.get("title")
            if not key:
                continue
            existing = combined_critical_requirements.get(key)
            # Keep the highest-ranked duplicate across providers.
            if existing is None or requirement_sort_key(
                requirement
            ) > requirement_sort_key(existing):
                combined_critical_requirements[key] = deepcopy(requirement)

    aggregated_section_scores = {}
    for section, total in section_weighted_sums.items():
        weight_total = section_weights[section]
        if weight_total > 0:
            aggregated_section_scores[section] = str(
                (total / weight_total).quantize(
                    Decimal("0.01"), rounding=ROUND_HALF_UP
                )
            )
    aggregated_section_scores = dict(sorted(aggregated_section_scores.items()))

    aggregated_critical_requirements = sorted(
        combined_critical_requirements.values(),
        key=requirement_sort_key,
        reverse=True,
    )

    # Synthetic, unsaved instance — no scan/provider since it spans the tenant.
    aggregated_snapshot = ThreatScoreSnapshot(
        tenant_id=tenant_id,
        scan=None,
        provider=None,
        compliance_id="prowler_threatscore_overview",
        overall_score=overall_score,
        score_delta=score_delta,
        section_scores=aggregated_section_scores,
        critical_requirements=aggregated_critical_requirements,
        total_requirements=total_requirements,
        passed_requirements=passed_requirements,
        failed_requirements=failed_requirements,
        manual_requirements=manual_requirements,
        total_findings=total_findings,
        passed_findings=passed_findings,
        failed_findings=failed_findings,
    )
    latest_inserted_at = aggregated_metrics["latest_inserted_at"]
    if latest_inserted_at is not None:
        aggregated_snapshot.inserted_at = latest_inserted_at
    # Flag consumed by ThreatScoreSnapshotSerializer.get_id to emit "n/a".
    aggregated_snapshot._aggregated = True
    return aggregated_snapshot
@extend_schema(tags=["Schedule"])
@extend_schema_view(

View File

@@ -7,7 +7,6 @@ from shutil import rmtree
import matplotlib.pyplot as plt
from celery.utils.log import get_task_logger
from config.django.base import DJANGO_FINDINGS_BATCH_SIZE, DJANGO_TMP_OUTPUT_DIRECTORY
from django.db.models import Count, Q
from reportlab.lib import colors
from reportlab.lib.enums import TA_CENTER
from reportlab.lib.pagesizes import letter
@@ -26,11 +25,22 @@ from reportlab.platypus import (
TableStyle,
)
from tasks.jobs.export import _generate_output_directory, _upload_to_s3
from tasks.jobs.threatscore import compute_threatscore_metrics
from tasks.jobs.threatscore_utils import (
_aggregate_requirement_statistics_from_database,
_calculate_requirements_data_from_statistics,
)
from tasks.utils import batched
from api.db_router import READ_REPLICA_ALIAS
from api.db_utils import rls_transaction
from api.models import Finding, Provider, ScanSummary, StatusChoices
from api.models import (
Finding,
Provider,
ScanSummary,
StatusChoices,
ThreatScoreSnapshot,
)
from api.utils import initialize_prowler_provider
from prowler.lib.check.compliance_models import Compliance
from prowler.lib.outputs.finding import Finding as FindingOutput
@@ -434,56 +444,6 @@ def _add_pdf_footer(canvas_obj: canvas.Canvas, doc: SimpleDocTemplate) -> None:
canvas_obj.drawString(width - text_width - 30, 20, powered_text)
def _aggregate_requirement_statistics_from_database(
    tenant_id: str, scan_id: str
) -> dict[str, dict[str, int]]:
    """
    Aggregate finding statistics by check_id using database aggregation.

    This function uses Django ORM aggregation to calculate pass/fail statistics
    entirely in the database, avoiding the need to load findings into memory.

    Args:
        tenant_id (str): The tenant ID for Row-Level Security context.
        scan_id (str): The ID of the scan to retrieve findings for.

    Returns:
        dict[str, dict[str, int]]: Dictionary mapping check_id to statistics:
            - 'passed' (int): Number of passed findings for this check
            - 'total' (int): Total number of findings for this check

    Example:
        {
            'aws_iam_user_mfa_enabled': {'passed': 10, 'total': 15},
            'aws_s3_bucket_public_access': {'passed': 0, 'total': 5}
        }
    """
    requirement_statistics_by_check_id = {}
    # Run the query against the read replica inside the tenant's RLS context.
    with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
        # Use database aggregation to calculate stats without loading findings into memory
        # (single GROUP BY check_id; the filtered Count produces the PASS tally
        # in the same query).
        aggregated_statistics_queryset = (
            Finding.all_objects.filter(tenant_id=tenant_id, scan_id=scan_id)
            .values("check_id")
            .annotate(
                total_findings=Count("id"),
                passed_findings=Count("id", filter=Q(status=StatusChoices.PASS)),
            )
        )
        # Materialize the aggregated rows into a plain dict keyed by check_id.
        for aggregated_stat in aggregated_statistics_queryset:
            check_id = aggregated_stat["check_id"]
            requirement_statistics_by_check_id[check_id] = {
                "passed": aggregated_stat["passed_findings"],
                "total": aggregated_stat["total_findings"],
            }
    logger.info(
        f"Aggregated statistics for {len(requirement_statistics_by_check_id)} unique checks"
    )
    return requirement_statistics_by_check_id
def _load_findings_for_requirement_checks(
tenant_id: str, scan_id: str, check_ids: list[str], prowler_provider
) -> dict[str, list[FindingOutput]]:
@@ -544,84 +504,6 @@ def _load_findings_for_requirement_checks(
return dict(findings_by_check_id)
def _calculate_requirements_data_from_statistics(
    compliance_obj, requirement_statistics_by_check_id: dict[str, dict[str, int]]
) -> tuple[dict[str, dict], list[dict]]:
    """
    Calculate requirement status and statistics using pre-aggregated database statistics.

    This function uses O(n) lookups with pre-aggregated statistics from the database,
    avoiding the need to iterate over all findings for each requirement.

    Args:
        compliance_obj: The compliance framework object containing requirements.
        requirement_statistics_by_check_id (dict[str, dict[str, int]]): Pre-aggregated statistics
            mapping check_id to {'passed': int, 'total': int} counts.

    Returns:
        tuple[dict[str, dict], list[dict]]: A tuple containing:
            - attributes_by_requirement_id: Dictionary mapping requirement IDs to their attributes.
            - requirements_list: List of requirement dictionaries with status and statistics.
    """
    attributes_by_requirement_id = {}
    requirements_list = []
    compliance_framework = getattr(compliance_obj, "Framework", "N/A")
    compliance_version = getattr(compliance_obj, "Version", "N/A")
    for requirement in compliance_obj.Requirements:
        requirement_id = requirement.Id
        requirement_description = getattr(requirement, "Description", "")
        requirement_checks = getattr(requirement, "Checks", [])
        requirement_attributes = getattr(requirement, "Attributes", [])
        # Store requirement metadata for later use
        attributes_by_requirement_id[requirement_id] = {
            "attributes": {
                "req_attributes": requirement_attributes,
                "checks": requirement_checks,
            },
            "description": requirement_description,
        }
        # Calculate aggregated passed and total findings for this requirement.
        # Checks with no recorded findings are simply absent from the stats dict.
        total_passed_findings = 0
        total_findings_count = 0
        for check_id in requirement_checks:
            if check_id in requirement_statistics_by_check_id:
                check_statistics = requirement_statistics_by_check_id[check_id]
                total_findings_count += check_statistics["total"]
                total_passed_findings += check_statistics["passed"]
        # Determine overall requirement status based on findings
        if total_findings_count > 0:
            if total_passed_findings == total_findings_count:
                requirement_status = StatusChoices.PASS
            else:
                # Partial pass or complete fail both count as FAIL
                requirement_status = StatusChoices.FAIL
        else:
            # No findings means manual review required
            requirement_status = StatusChoices.MANUAL
        requirements_list.append(
            {
                "id": requirement_id,
                "attributes": {
                    "framework": compliance_framework,
                    "version": compliance_version,
                    "status": requirement_status,
                    "description": requirement_description,
                    "passed_findings": total_passed_findings,
                    "total_findings": total_findings_count,
                },
            }
        )
    return attributes_by_requirement_id, requirements_list
def generate_threatscore_report(
tenant_id: str,
scan_id: str,
@@ -1262,8 +1144,9 @@ def generate_threatscore_report_job(
2. Checks provider type compatibility
3. Generates the output directory
4. Calls generate_threatscore_report to create the PDF
5. Uploads the PDF to S3
6. Cleans up temporary files
5. Computes and stores ThreatScore metrics snapshot
6. Uploads the PDF to S3
7. Cleans up temporary files
Args:
tenant_id (str): The tenant ID for Row-Level Security context.
@@ -1317,6 +1200,66 @@ def generate_threatscore_report_job(
min_risk_level=4,
)
# Compute and store ThreatScore metrics snapshot
logger.info(f"Computing ThreatScore metrics for scan {scan_id}")
try:
metrics = compute_threatscore_metrics(
tenant_id=tenant_id,
scan_id=scan_id,
provider_id=provider_id,
compliance_id=compliance_id,
min_risk_level=4,
)
# Create snapshot in database
with rls_transaction(tenant_id):
# Get previous snapshot for the same provider to calculate delta
previous_snapshot = (
ThreatScoreSnapshot.objects.filter(
tenant_id=tenant_id,
provider_id=provider_id,
compliance_id=compliance_id,
)
.order_by("-inserted_at")
.first()
)
# Calculate score delta (improvement)
score_delta = None
if previous_snapshot:
score_delta = metrics["overall_score"] - float(
previous_snapshot.overall_score
)
snapshot = ThreatScoreSnapshot.objects.create(
tenant_id=tenant_id,
scan_id=scan_id,
provider_id=provider_id,
compliance_id=compliance_id,
overall_score=metrics["overall_score"],
score_delta=score_delta,
section_scores=metrics["section_scores"],
critical_requirements=metrics["critical_requirements"],
total_requirements=metrics["total_requirements"],
passed_requirements=metrics["passed_requirements"],
failed_requirements=metrics["failed_requirements"],
manual_requirements=metrics["manual_requirements"],
total_findings=metrics["total_findings"],
passed_findings=metrics["passed_findings"],
failed_findings=metrics["failed_findings"],
)
delta_msg = (
f" (delta: {score_delta:+.2f}%)" if score_delta is not None else ""
)
logger.info(
f"ThreatScore snapshot created with ID {snapshot.id} "
f"(score: {snapshot.overall_score}%{delta_msg})"
)
except Exception as e:
# Log error but don't fail the job if snapshot creation fails
logger.error(f"Error creating ThreatScore snapshot: {e}")
upload_uri = _upload_to_s3(
tenant_id,
scan_id,

View File

@@ -0,0 +1,214 @@
from celery.utils.log import get_task_logger
from tasks.jobs.threatscore_utils import (
_aggregate_requirement_statistics_from_database,
_calculate_requirements_data_from_statistics,
)
from api.db_router import READ_REPLICA_ALIAS
from api.db_utils import rls_transaction
from api.models import Provider, StatusChoices
from prowler.lib.check.compliance_models import Compliance
logger = get_task_logger(__name__)
def compute_threatscore_metrics(
    tenant_id: str,
    scan_id: str,
    provider_id: str,
    compliance_id: str,
    min_risk_level: int = 4,
) -> dict:
    """
    Compute ThreatScore metrics for a given scan.

    This function calculates all the metrics needed for a ThreatScore snapshot:
    - Overall ThreatScore percentage
    - Section-by-section scores
    - Critical failed requirements (risk >= min_risk_level)
    - Summary statistics (requirements and findings counts)

    Args:
        tenant_id (str): The tenant ID for Row-Level Security context.
        scan_id (str): The ID of the scan to analyze.
        provider_id (str): The ID of the provider used in the scan.
        compliance_id (str): Compliance framework ID (e.g., "prowler_threatscore_aws").
        min_risk_level (int): Minimum risk level for critical requirements. Defaults to 4.

    Returns:
        dict: A dictionary containing:
            - overall_score (float): Overall ThreatScore percentage (0-100)
            - section_scores (dict): Section name -> score percentage mapping
            - critical_requirements (list): List of critical failed requirement dicts
            - total_requirements (int): Total number of requirements
            - passed_requirements (int): Number of PASS requirements
            - failed_requirements (int): Number of FAIL requirements
            - manual_requirements (int): Number of MANUAL requirements
            - total_findings (int): Total findings count
            - passed_findings (int): Passed findings count
            - failed_findings (int): Failed findings count

    Example:
        >>> metrics = compute_threatscore_metrics(
        ...     tenant_id="tenant-123",
        ...     scan_id="scan-456",
        ...     provider_id="provider-789",
        ...     compliance_id="prowler_threatscore_aws"
        ... )
        >>> print(f"Overall ThreatScore: {metrics['overall_score']:.2f}%")
    """
    # Get provider and compliance information (read replica under tenant RLS).
    with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
        provider_obj = Provider.objects.get(id=provider_id)
        provider_type = provider_obj.provider
    # KeyError here means compliance_id does not exist for this provider type.
    frameworks_bulk = Compliance.get_bulk(provider_type)
    compliance_obj = frameworks_bulk[compliance_id]
    # Aggregate requirement statistics from database
    requirement_statistics_by_check_id = (
        _aggregate_requirement_statistics_from_database(tenant_id, scan_id)
    )
    # Calculate requirements data using aggregated statistics
    attributes_by_requirement_id, requirements_list = (
        _calculate_requirements_data_from_statistics(
            compliance_obj, requirement_statistics_by_check_id
        )
    )
    # Initialize metrics
    overall_numerator = 0
    overall_denominator = 0
    overall_has_findings = False
    sections_data = {}
    total_requirements = len(requirements_list)
    passed_requirements = 0
    failed_requirements = 0
    manual_requirements = 0
    total_findings = 0
    passed_findings = 0
    failed_findings = 0
    critical_requirements_list = []
    # Process each requirement
    for requirement in requirements_list:
        requirement_id = requirement["id"]
        requirement_status = requirement["attributes"]["status"]
        requirement_attributes = attributes_by_requirement_id.get(requirement_id, {})
        # Count requirements by status
        if requirement_status == StatusChoices.PASS:
            passed_requirements += 1
        elif requirement_status == StatusChoices.FAIL:
            failed_requirements += 1
        elif requirement_status == StatusChoices.MANUAL:
            manual_requirements += 1
        # Get findings data
        req_passed_findings = requirement["attributes"].get("passed_findings", 0)
        req_total_findings = requirement["attributes"].get("total_findings", 0)
        # Accumulate findings counts (failed = total - passed by construction).
        total_findings += req_total_findings
        passed_findings += req_passed_findings
        failed_findings += req_total_findings - req_passed_findings
        # Skip requirements with no findings: they count toward the MANUAL
        # totals above but do not influence the score.
        if req_total_findings == 0:
            continue
        overall_has_findings = True
        # Get requirement metadata; requirements without attribute metadata
        # are counted in the totals but excluded from scoring.
        metadata = requirement_attributes.get("attributes", {}).get(
            "req_attributes", []
        )
        if not metadata or len(metadata) == 0:
            continue
        m = metadata[0]
        risk_level = getattr(m, "LevelOfRisk", 0)
        weight = getattr(m, "Weight", 0)
        section = getattr(m, "Section", "Unknown")
        # Calculate ThreatScore components using formula from UI:
        # each requirement contributes pass-rate * findings * weight * risk factor.
        # NOTE(review): a metadata Weight of 0 adds nothing to numerator or
        # denominator, effectively removing the requirement from the score.
        rate_i = req_passed_findings / req_total_findings
        # Risk factor scales with LevelOfRisk: 0 -> 1.0, 5 -> 2.25.
        rfac_i = 1 + 0.25 * risk_level
        # Update overall score
        overall_numerator += rate_i * req_total_findings * weight * rfac_i
        overall_denominator += req_total_findings * weight * rfac_i
        # Update section scores
        if section not in sections_data:
            sections_data[section] = {
                "numerator": 0,
                "denominator": 0,
                "has_findings": False,
            }
        sections_data[section]["has_findings"] = True
        sections_data[section]["numerator"] += (
            rate_i * req_total_findings * weight * rfac_i
        )
        sections_data[section]["denominator"] += req_total_findings * weight * rfac_i
        # Identify critical failed requirements
        if requirement_status == StatusChoices.FAIL and risk_level >= min_risk_level:
            critical_requirements_list.append(
                {
                    "requirement_id": requirement_id,
                    "title": getattr(m, "Title", "N/A"),
                    "section": section,
                    "subsection": getattr(m, "SubSection", "N/A"),
                    "risk_level": risk_level,
                    "weight": weight,
                    "passed_findings": req_passed_findings,
                    "total_findings": req_total_findings,
                    "description": getattr(m, "AttributeDescription", "N/A"),
                }
            )
    # Calculate overall ThreatScore. No findings at all scores a perfect 100;
    # findings with an all-zero weighted denominator score 0.
    if not overall_has_findings:
        overall_score = 100.0
    elif overall_denominator > 0:
        overall_score = (overall_numerator / overall_denominator) * 100
    else:
        overall_score = 0.0
    # Calculate section scores. NOTE(review): sections only enter
    # sections_data with has_findings=True, so the 100.0 fallback can only
    # trigger when a section's weighted denominator sums to zero.
    section_scores = {}
    for section, data in sections_data.items():
        if data["has_findings"] and data["denominator"] > 0:
            section_scores[section] = (data["numerator"] / data["denominator"]) * 100
        else:
            section_scores[section] = 100.0
    # Sort critical requirements by risk level (desc) and weight (desc)
    critical_requirements_list.sort(
        key=lambda x: (x["risk_level"], x["weight"]), reverse=True
    )
    logger.info(
        f"ThreatScore computed: {overall_score:.2f}% "
        f"({passed_requirements}/{total_requirements} requirements passed, "
        f"{len(critical_requirements_list)} critical failures)"
    )
    # Scores are rounded to two decimals for snapshot persistence.
    return {
        "overall_score": round(overall_score, 2),
        "section_scores": {k: round(v, 2) for k, v in section_scores.items()},
        "critical_requirements": critical_requirements_list,
        "total_requirements": total_requirements,
        "passed_requirements": passed_requirements,
        "failed_requirements": failed_requirements,
        "manual_requirements": manual_requirements,
        "total_findings": total_findings,
        "passed_findings": passed_findings,
        "failed_findings": failed_findings,
    }

View File

@@ -0,0 +1,127 @@
from celery.utils.log import get_task_logger
from django.db.models import Count, Q
from api.db_router import READ_REPLICA_ALIAS
from api.db_utils import rls_transaction
from api.models import Finding, StatusChoices
logger = get_task_logger(__name__)
def _aggregate_requirement_statistics_from_database(
    tenant_id: str, scan_id: str
) -> dict[str, dict[str, int]]:
    """
    Build per-check pass/total counters for a scan via SQL aggregation.

    All counting happens inside the database — a single GROUP BY on
    check_id with a filtered COUNT — so no Finding rows are ever
    materialized in Python memory.

    Args:
        tenant_id (str): Tenant ID establishing the Row-Level Security context.
        scan_id (str): ID of the scan whose findings are aggregated.

    Returns:
        dict[str, dict[str, int]]: Maps each check_id to
            {'passed': <passed findings>, 'total': <all findings>}.

    Example:
        {
            'aws_iam_user_mfa_enabled': {'passed': 10, 'total': 15},
            'aws_s3_bucket_public_access': {'passed': 0, 'total': 5}
        }
    """
    with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
        # One GROUP BY check_id query; the filtered Count yields the
        # conditional "passed" tally in the same pass.
        per_check_rows = (
            Finding.all_objects.filter(tenant_id=tenant_id, scan_id=scan_id)
            .values("check_id")
            .annotate(
                total_findings=Count("id"),
                passed_findings=Count("id", filter=Q(status=StatusChoices.PASS)),
            )
        )
        stats_by_check_id = {
            row["check_id"]: {
                "passed": row["passed_findings"],
                "total": row["total_findings"],
            }
            for row in per_check_rows
        }
    logger.info(
        f"Aggregated statistics for {len(stats_by_check_id)} unique checks"
    )
    return stats_by_check_id
def _calculate_requirements_data_from_statistics(
    compliance_obj, requirement_statistics_by_check_id: dict[str, dict[str, int]]
) -> tuple[dict[str, dict], list[dict]]:
    """
    Derive per-requirement status and counters from pre-aggregated statistics.

    Works entirely from the check_id -> {'passed', 'total'} lookup table, so
    each requirement costs O(#checks) dictionary lookups instead of a pass
    over the raw findings.

    Args:
        compliance_obj: The compliance framework object containing requirements.
        requirement_statistics_by_check_id (dict[str, dict[str, int]]):
            Pre-aggregated statistics mapping check_id to
            {'passed': int, 'total': int} counts.

    Returns:
        tuple[dict[str, dict], list[dict]]:
            - attributes_by_requirement_id: requirement ID -> metadata
              (raw attributes, check list, description).
            - requirements_list: requirement dicts with framework info,
              status, and finding counters.
    """
    framework_name = getattr(compliance_obj, "Framework", "N/A")
    framework_version = getattr(compliance_obj, "Version", "N/A")
    attributes_by_requirement_id = {}
    requirements_list = []
    for requirement in compliance_obj.Requirements:
        req_id = requirement.Id
        description = getattr(requirement, "Description", "")
        checks = getattr(requirement, "Checks", [])
        # Keep the raw metadata around for downstream scoring/reporting.
        attributes_by_requirement_id[req_id] = {
            "attributes": {
                "req_attributes": getattr(requirement, "Attributes", []),
                "checks": checks,
            },
            "description": description,
        }
        # Sum the counters of every check that actually produced findings;
        # checks absent from the stats table contribute nothing.
        check_stats = [
            requirement_statistics_by_check_id[check_id]
            for check_id in checks
            if check_id in requirement_statistics_by_check_id
        ]
        passed = sum(stat["passed"] for stat in check_stats)
        total = sum(stat["total"] for stat in check_stats)
        # No findings -> manual review; anything less than a full pass -> FAIL.
        if total == 0:
            status = StatusChoices.MANUAL
        elif passed == total:
            status = StatusChoices.PASS
        else:
            status = StatusChoices.FAIL
        requirements_list.append(
            {
                "id": req_id,
                "attributes": {
                    "framework": framework_name,
                    "version": framework_version,
                    "status": status,
                    "description": description,
                    "passed_findings": passed,
                    "total_findings": total,
                },
            }
        )
    return attributes_by_requirement_id, requirements_list

View File

@@ -1,19 +1,25 @@
import uuid
from datetime import timedelta
from decimal import Decimal
from pathlib import Path
from unittest.mock import MagicMock, patch
import matplotlib
import pytest
from django.utils import timezone
from freezegun import freeze_time
from tasks.jobs.report import (
_aggregate_requirement_statistics_from_database,
_calculate_requirements_data_from_statistics,
_load_findings_for_requirement_checks,
generate_threatscore_report,
generate_threatscore_report_job,
)
from tasks.jobs.threatscore_utils import (
_aggregate_requirement_statistics_from_database,
_calculate_requirements_data_from_statistics,
)
from tasks.tasks import generate_threatscore_report_task
from api.models import Finding, StatusChoices
from api.models import Finding, Scan, StateChoices, StatusChoices, ThreatScoreSnapshot
from prowler.lib.check.models import Severity
matplotlib.use("Agg") # Use non-interactive backend for tests
@@ -39,6 +45,7 @@ class TestGenerateThreatscoreReport:
assert result == {"upload": False}
mock_filter.assert_called_once_with(scan_id=self.scan_id)
@patch("tasks.jobs.report.ThreatScoreSnapshot.objects.create")
@patch("tasks.jobs.report.rmtree")
@patch("tasks.jobs.report._upload_to_s3")
@patch("tasks.jobs.report.generate_threatscore_report")
@@ -53,6 +60,7 @@ class TestGenerateThreatscoreReport:
mock_generate_report,
mock_upload,
mock_rmtree,
mock_snapshot_create,
):
mock_scan_summary_filter.return_value.exists.return_value = True
@@ -95,8 +103,10 @@ class TestGenerateThreatscoreReport:
Path("/tmp/threatscore_path_threatscore_report.pdf").parent,
ignore_errors=True,
)
mock_snapshot_create.assert_called_once()
def test_generate_threatscore_report_fails_upload(self):
@patch("tasks.jobs.report.ThreatScoreSnapshot.objects.create")
def test_generate_threatscore_report_fails_upload(self, mock_snapshot_create):
with (
patch("tasks.jobs.report.ScanSummary.objects.filter") as mock_filter,
patch("tasks.jobs.report.Provider.objects.get") as mock_provider_get,
@@ -125,8 +135,12 @@ class TestGenerateThreatscoreReport:
)
assert result == {"upload": False}
mock_snapshot_create.assert_called_once()
def test_generate_threatscore_report_logs_rmtree_exception(self, caplog):
@patch("tasks.jobs.report.ThreatScoreSnapshot.objects.create")
def test_generate_threatscore_report_logs_rmtree_exception(
self, mock_snapshot_create, caplog
):
with (
patch("tasks.jobs.report.ScanSummary.objects.filter") as mock_filter,
patch("tasks.jobs.report.Provider.objects.get") as mock_provider_get,
@@ -160,8 +174,10 @@ class TestGenerateThreatscoreReport:
provider_id=self.provider_id,
)
assert "Error deleting output files" in caplog.text
mock_snapshot_create.assert_called_once()
def test_generate_threatscore_report_azure_provider(self):
@patch("tasks.jobs.report.ThreatScoreSnapshot.objects.create")
def test_generate_threatscore_report_azure_provider(self, mock_snapshot_create):
with (
patch("tasks.jobs.report.ScanSummary.objects.filter") as mock_filter,
patch("tasks.jobs.report.Provider.objects.get") as mock_provider_get,
@@ -200,6 +216,135 @@ class TestGenerateThreatscoreReport:
only_failed=True,
min_risk_level=4,
)
mock_snapshot_create.assert_called_once()
    @patch("tasks.jobs.report.rmtree")
    @patch(
        "tasks.jobs.report._upload_to_s3",
        return_value="s3://bucket/threatscore/threatscore_report.pdf",
    )
    @patch("tasks.jobs.report.generate_threatscore_report")
    @patch("tasks.jobs.report._generate_output_directory")
    @patch("tasks.jobs.report.ScanSummary.objects.filter")
    @patch("tasks.jobs.report.compute_threatscore_metrics")
    @pytest.mark.django_db
    @freeze_time("2025-01-10T12:00:00Z")
    def test_generate_threatscore_report_persists_snapshot_and_delta(
        self,
        mock_compute_metrics,
        mock_scan_summary_filter,
        mock_generate_output_directory,
        mock_generate_report,
        mock_upload,
        mock_rmtree,
        tenants_fixture,
        providers_fixture,
    ):
        """Job persists a ThreatScoreSnapshot and derives the delta from the
        previous snapshot of the same provider/compliance pair."""
        # Arrange: seed an earlier snapshot (score 70.00) so the job has a
        # baseline against which to compute score_delta.
        tenant = tenants_fixture[0]
        provider = providers_fixture[0]
        scan_previous = Scan.objects.create(
            tenant=tenant,
            provider=provider,
            name="previous-threatscore-scan",
            trigger=Scan.TriggerChoices.MANUAL,
            state=StateChoices.COMPLETED,
            started_at=timezone.now() - timedelta(hours=4),
            completed_at=timezone.now() - timedelta(hours=3),
        )
        ThreatScoreSnapshot.objects.create(
            tenant=tenant,
            scan=scan_previous,
            provider=provider,
            compliance_id="prowler_threatscore_aws",
            overall_score=Decimal("70.00"),
            score_delta=None,
            section_scores={"1. IAM": "65.00"},
            critical_requirements=[],
            total_requirements=50,
            passed_requirements=35,
            failed_requirements=15,
            manual_requirements=0,
            total_findings=40,
            passed_findings=25,
            failed_findings=15,
        )
        scan_current = Scan.objects.create(
            tenant=tenant,
            provider=provider,
            name="current-threatscore-scan",
            trigger=Scan.TriggerChoices.MANUAL,
            state=StateChoices.COMPLETED,
            started_at=timezone.now() - timedelta(hours=2),
            completed_at=timezone.now() - timedelta(hours=1),
        )
        # Report generation dependencies are mocked; the job still runs its
        # real snapshot-persistence path against the test database.
        mock_scan_summary_filter.return_value.exists.return_value = True
        mock_generate_output_directory.return_value = (
            "/tmp/output",
            "/tmp/compressed",
            "/tmp/threatscore_path",
        )
        # Canned metrics returned by the mocked compute_threatscore_metrics.
        metrics = {
            "overall_score": 85.5,
            "score_delta": 10.0,
            "section_scores": {"1. IAM": 82.3, "2. Attack Surface": 60.0},
            "critical_requirements": [
                {
                    "requirement_id": "req_new",
                    "title": "New high-risk requirement",
                    "section": "1. IAM",
                    "subsection": "Root Account",
                    "risk_level": 5,
                    "weight": 150,
                    "passed_findings": 7,
                    "total_findings": 10,
                    "description": "Critical requirement description",
                }
            ],
            "total_requirements": 140,
            "passed_requirements": 100,
            "failed_requirements": 40,
            "manual_requirements": 0,
            "total_findings": 200,
            "passed_findings": 150,
            "failed_findings": 50,
        }
        mock_compute_metrics.return_value = metrics
        # Act: run the job for the current scan.
        result = generate_threatscore_report_job(
            tenant_id=str(tenant.id),
            scan_id=str(scan_current.id),
            provider_id=str(provider.id),
        )
        # Assert: job succeeded and metrics were computed with the expected args.
        assert result == {"upload": True}
        mock_compute_metrics.assert_called_once_with(
            tenant_id=str(tenant.id),
            scan_id=str(scan_current.id),
            provider_id=str(provider.id),
            compliance_id="prowler_threatscore_aws",
            min_risk_level=4,
        )
        mock_generate_report.assert_called_once()
        mock_upload.assert_called_once()
        mock_rmtree.assert_called_once()
        # Assert: a second snapshot was persisted for this provider.
        snapshots = ThreatScoreSnapshot.objects.filter(
            tenant=tenant, provider=provider
        ).order_by("inserted_at")
        assert snapshots.count() == 2
        new_snapshot = ThreatScoreSnapshot.objects.get(scan=scan_current)
        assert new_snapshot.compliance_id == "prowler_threatscore_aws"
        assert Decimal(new_snapshot.overall_score) == Decimal("85.50")
        # Delta is computed by the job (85.50 - 70.00), not taken from the
        # metrics dict's own "score_delta" key.
        assert Decimal(new_snapshot.score_delta) == Decimal("15.50")
        assert new_snapshot.section_scores == metrics["section_scores"]
        assert new_snapshot.critical_requirements == metrics["critical_requirements"]
        assert new_snapshot.total_requirements == metrics["total_requirements"]
        assert new_snapshot.total_findings == metrics["total_findings"]
@pytest.mark.django_db