Mirror of https://github.com/prowler-cloud/prowler.git (synced 2025-12-19 05:17:47 +00:00)

feat(threatscore): implement ThreatScoreSnapshot model, filter, serializer, and view for ThreatScore metrics retrieval (#9148)

Committed via GitHub. Parent: 73a277f27b. Commit: beec37b0da
@@ -14,6 +14,7 @@ All notable changes to the **Prowler API** are documented in this file.

- Support muting findings based on simple rules with custom reason [(#9051)](https://github.com/prowler-cloud/prowler/pull/9051)
- Support C5 compliance framework for the GCP provider [(#9097)](https://github.com/prowler-cloud/prowler/pull/9097)
- Support for Amazon Bedrock and OpenAI compatible providers in Lighthouse AI [(#8957)](https://github.com/prowler-cloud/prowler/pull/8957)
- Tenant-wide ThreatScore overview aggregation and snapshot persistence with backfill support [(#9148)](https://github.com/prowler-cloud/prowler/pull/9148)
- Support for MongoDB Atlas provider [(#9167)](https://github.com/prowler-cloud/prowler/pull/9167)

---
@@ -47,6 +47,7 @@ from api.models import (
    StatusChoices,
    Task,
    TenantAPIKey,
    ThreatScoreSnapshot,
    User,
)
from api.rls import Tenant
@@ -998,3 +999,36 @@ class MuteRuleFilter(FilterSet):
            "inserted_at": ["gte", "lte"],
            "updated_at": ["gte", "lte"],
        }


class ThreatScoreSnapshotFilter(FilterSet):
    """
    Filter for ThreatScore snapshots.
    Allows filtering by scan, provider, compliance_id, and date ranges.
    """

    inserted_at = DateFilter(field_name="inserted_at", lookup_expr="date")
    scan_id = UUIDFilter(field_name="scan__id", lookup_expr="exact")
    scan_id__in = UUIDInFilter(field_name="scan__id", lookup_expr="in")
    provider_id = UUIDFilter(field_name="provider__id", lookup_expr="exact")
    provider_id__in = UUIDInFilter(field_name="provider__id", lookup_expr="in")
    provider_type = ChoiceFilter(
        field_name="provider__provider", choices=Provider.ProviderChoices.choices
    )
    provider_type__in = ChoiceInFilter(
        field_name="provider__provider",
        choices=Provider.ProviderChoices.choices,
        lookup_expr="in",
    )
    compliance_id = CharFilter(field_name="compliance_id", lookup_expr="exact")
    compliance_id__in = CharInFilter(field_name="compliance_id", lookup_expr="in")

    class Meta:
        model = ThreatScoreSnapshot
        fields = {
            "scan": ["exact", "in"],
            "provider": ["exact", "in"],
            "compliance_id": ["exact", "in"],
            "inserted_at": ["date", "gte", "lte"],
            "overall_score": ["exact", "gte", "lte"],
        }

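As a quick illustration of how these declared filters are consumed, a minimal sketch (the parameter values are made up; JSON:API-style keys such as `filter[provider_type__in]` are flattened by the overview view before reaching the FilterSet, as shown later in this diff):

```python
from api.filters import ThreatScoreSnapshotFilter
from api.models import ThreatScoreSnapshot

# Hypothetical, already-flattened query parameters.
params = {
    "provider_type__in": "aws,gcp",
    "compliance_id": "prowler_threatscore_aws",
    "inserted_at__gte": "2025-01-01",
}

filterset = ThreatScoreSnapshotFilter(
    params, queryset=ThreatScoreSnapshot.objects.all()
)
snapshots = filterset.qs  # lazy queryset; evaluated on iteration
```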
api/src/backend/api/migrations/0057_threatscoresnapshot.py (new file, 170 lines)
@@ -0,0 +1,170 @@
# Generated by Django 5.1.13 on 2025-10-31 09:04

import uuid

import django.db.models.deletion
from django.db import migrations, models

import api.rls


class Migration(migrations.Migration):
    dependencies = [
        ("api", "0056_remove_provider_unique_provider_uids_and_more"),
    ]

    operations = [
        migrations.CreateModel(
            name="ThreatScoreSnapshot",
            fields=[
                (
                    "id",
                    models.UUIDField(
                        default=uuid.uuid4,
                        editable=False,
                        primary_key=True,
                        serialize=False,
                    ),
                ),
                ("inserted_at", models.DateTimeField(auto_now_add=True)),
                (
                    "compliance_id",
                    models.CharField(
                        help_text="Compliance framework ID (e.g., 'prowler_threatscore_aws')",
                        max_length=100,
                    ),
                ),
                (
                    "overall_score",
                    models.DecimalField(
                        decimal_places=2,
                        help_text="Overall ThreatScore percentage (0-100)",
                        max_digits=5,
                    ),
                ),
                (
                    "score_delta",
                    models.DecimalField(
                        blank=True,
                        decimal_places=2,
                        help_text="Score change compared to previous snapshot (positive = improvement)",
                        max_digits=5,
                        null=True,
                    ),
                ),
                (
                    "section_scores",
                    models.JSONField(
                        blank=True,
                        default=dict,
                        help_text="ThreatScore breakdown by section",
                    ),
                ),
                (
                    "critical_requirements",
                    models.JSONField(
                        blank=True,
                        default=list,
                        help_text="List of critical failed requirements (risk >= 4)",
                    ),
                ),
                (
                    "total_requirements",
                    models.IntegerField(
                        default=0, help_text="Total number of requirements evaluated"
                    ),
                ),
                (
                    "passed_requirements",
                    models.IntegerField(
                        default=0, help_text="Number of requirements with PASS status"
                    ),
                ),
                (
                    "failed_requirements",
                    models.IntegerField(
                        default=0, help_text="Number of requirements with FAIL status"
                    ),
                ),
                (
                    "manual_requirements",
                    models.IntegerField(
                        default=0, help_text="Number of requirements with MANUAL status"
                    ),
                ),
                (
                    "total_findings",
                    models.IntegerField(
                        default=0,
                        help_text="Total number of findings across all requirements",
                    ),
                ),
                (
                    "passed_findings",
                    models.IntegerField(
                        default=0, help_text="Number of findings with PASS status"
                    ),
                ),
                (
                    "failed_findings",
                    models.IntegerField(
                        default=0, help_text="Number of findings with FAIL status"
                    ),
                ),
                (
                    "provider",
                    models.ForeignKey(
                        on_delete=django.db.models.deletion.CASCADE,
                        related_name="threatscore_snapshots",
                        related_query_name="threatscore_snapshot",
                        to="api.provider",
                    ),
                ),
                (
                    "scan",
                    models.ForeignKey(
                        on_delete=django.db.models.deletion.CASCADE,
                        related_name="threatscore_snapshots",
                        related_query_name="threatscore_snapshot",
                        to="api.scan",
                    ),
                ),
                (
                    "tenant",
                    models.ForeignKey(
                        on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
                    ),
                ),
            ],
            options={
                "db_table": "threatscore_snapshots",
                "abstract": False,
            },
        ),
        migrations.AddIndex(
            model_name="threatscoresnapshot",
            index=models.Index(
                fields=["tenant_id", "scan_id"], name="threatscore_snap_t_scan_idx"
            ),
        ),
        migrations.AddIndex(
            model_name="threatscoresnapshot",
            index=models.Index(
                fields=["tenant_id", "provider_id"], name="threatscore_snap_t_prov_idx"
            ),
        ),
        migrations.AddIndex(
            model_name="threatscoresnapshot",
            index=models.Index(
                fields=["tenant_id", "inserted_at"], name="threatscore_snap_t_time_idx"
            ),
        ),
        migrations.AddConstraint(
            model_name="threatscoresnapshot",
            constraint=api.rls.RowLevelSecurityConstraint(
                "tenant_id",
                name="rls_on_threatscoresnapshot",
                statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
            ),
        ),
    ]
@@ -2239,3 +2239,137 @@ class LighthouseProviderModels(RowLevelSecurityProtectedModel):

    class JSONAPIMeta:
        resource_name = "lighthouse-models"


class ThreatScoreSnapshot(RowLevelSecurityProtectedModel):
    """
    Stores historical ThreatScore metrics for a given scan.
    Snapshots are created automatically after each ThreatScore report generation.
    """

    objects = models.Manager()
    all_objects = models.Manager()

    id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
    inserted_at = models.DateTimeField(auto_now_add=True, editable=False)

    scan = models.ForeignKey(
        Scan,
        on_delete=models.CASCADE,
        related_name="threatscore_snapshots",
        related_query_name="threatscore_snapshot",
    )

    provider = models.ForeignKey(
        Provider,
        on_delete=models.CASCADE,
        related_name="threatscore_snapshots",
        related_query_name="threatscore_snapshot",
    )

    compliance_id = models.CharField(
        max_length=100,
        blank=False,
        null=False,
        help_text="Compliance framework ID (e.g., 'prowler_threatscore_aws')",
    )

    # Overall ThreatScore metrics
    overall_score = models.DecimalField(
        max_digits=5,
        decimal_places=2,
        help_text="Overall ThreatScore percentage (0-100)",
    )

    # Score improvement/degradation compared to previous snapshot
    score_delta = models.DecimalField(
        max_digits=5,
        decimal_places=2,
        null=True,
        blank=True,
        help_text="Score change compared to previous snapshot (positive = improvement)",
    )

    # Section breakdown stored as JSON
    # Format: {"1. IAM": 85.5, "2. Attack Surface": 92.3, ...}
    section_scores = models.JSONField(
        default=dict,
        blank=True,
        help_text="ThreatScore breakdown by section",
    )

    # Critical requirements metadata stored as JSON
    # Format: [{"requirement_id": "...", "risk_level": 5, "weight": 150, ...}, ...]
    critical_requirements = models.JSONField(
        default=list,
        blank=True,
        help_text="List of critical failed requirements (risk >= 4)",
    )

    # Summary statistics
    total_requirements = models.IntegerField(
        default=0,
        help_text="Total number of requirements evaluated",
    )

    passed_requirements = models.IntegerField(
        default=0,
        help_text="Number of requirements with PASS status",
    )

    failed_requirements = models.IntegerField(
        default=0,
        help_text="Number of requirements with FAIL status",
    )

    manual_requirements = models.IntegerField(
        default=0,
        help_text="Number of requirements with MANUAL status",
    )

    total_findings = models.IntegerField(
        default=0,
        help_text="Total number of findings across all requirements",
    )

    passed_findings = models.IntegerField(
        default=0,
        help_text="Number of findings with PASS status",
    )

    failed_findings = models.IntegerField(
        default=0,
        help_text="Number of findings with FAIL status",
    )

    def __str__(self):
        return f"ThreatScore {self.overall_score}% for scan {self.scan_id} ({self.inserted_at})"

    class Meta(RowLevelSecurityProtectedModel.Meta):
        db_table = "threatscore_snapshots"

        constraints = [
            RowLevelSecurityConstraint(
                field="tenant_id",
                name="rls_on_%(class)s",
                statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
            ),
        ]

        indexes = [
            models.Index(
                fields=["tenant_id", "scan_id"],
                name="threatscore_snap_t_scan_idx",
            ),
            models.Index(
                fields=["tenant_id", "provider_id"],
                name="threatscore_snap_t_prov_idx",
            ),
            models.Index(
                fields=["tenant_id", "inserted_at"],
                name="threatscore_snap_t_time_idx",
            ),
        ]

    class JSONAPIMeta:
        resource_name = "threatscore-snapshots"

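A minimal usage sketch for the model above (the UUIDs are placeholders; in production, row-level security restricts results to the active tenant):

```python
from api.models import ThreatScoreSnapshot

tenant_id = "00000000-0000-0000-0000-000000000001"    # placeholder
provider_id = "00000000-0000-0000-0000-000000000002"  # placeholder

# Latest persisted snapshot for one provider and framework.
latest = (
    ThreatScoreSnapshot.objects.filter(
        tenant_id=tenant_id,
        provider_id=provider_id,
        compliance_id="prowler_threatscore_aws",
    )
    .order_by("-inserted_at")
    .first()
)
if latest is not None:
    print(latest.overall_score, latest.section_scores)
```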
@@ -4,6 +4,7 @@ import json
import os
import tempfile
from datetime import datetime, timedelta, timezone
from decimal import Decimal
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import ANY, MagicMock, Mock, patch
@@ -56,6 +57,7 @@ from api.models import (
    StateChoices,
    Task,
    TenantAPIKey,
    ThreatScoreSnapshot,
    User,
    UserRoleRelationship,
)
@@ -6221,6 +6223,407 @@ class TestOverviewViewSet:
        for entry in grouped_data:
            assert "findings" not in entry["attributes"]

    def _create_scan(self, tenant, provider, name, started_at=None):
        scan_started = started_at or datetime.now(timezone.utc) - timedelta(hours=1)
        return Scan.objects.create(
            tenant=tenant,
            provider=provider,
            name=name,
            trigger=Scan.TriggerChoices.MANUAL,
            state=StateChoices.COMPLETED,
            started_at=scan_started,
            completed_at=scan_started + timedelta(minutes=30),
        )

    def _create_threatscore_snapshot(
        self,
        tenant,
        scan,
        provider,
        *,
        compliance_id,
        overall_score,
        score_delta,
        section_scores,
        critical_requirements,
        total_requirements,
        passed_requirements,
        failed_requirements,
        manual_requirements,
        total_findings,
        passed_findings,
        failed_findings,
    ):
        return ThreatScoreSnapshot.objects.create(
            tenant=tenant,
            scan=scan,
            provider=provider,
            compliance_id=compliance_id,
            overall_score=Decimal(overall_score),
            score_delta=Decimal(score_delta) if score_delta is not None else None,
            section_scores=section_scores,
            critical_requirements=critical_requirements,
            total_requirements=total_requirements,
            passed_requirements=passed_requirements,
            failed_requirements=failed_requirements,
            manual_requirements=manual_requirements,
            total_findings=total_findings,
            passed_findings=passed_findings,
            failed_findings=failed_findings,
        )

    def test_overview_threatscore_returns_weighted_aggregate_snapshot(
        self, authenticated_client, tenants_fixture, providers_fixture
    ):
        tenant = tenants_fixture[0]
        provider1, provider2, *_ = providers_fixture

        scan1 = self._create_scan(tenant, provider1, "agg-scan-one")
        scan2 = self._create_scan(tenant, provider2, "agg-scan-two")

        snapshot1 = self._create_threatscore_snapshot(
            tenant,
            scan1,
            provider1,
            compliance_id="prowler_threatscore_aws",
            overall_score="80.00",
            score_delta="5.00",
            section_scores={"1. IAM": "70.00", "2. Attack Surface": "60.00"},
            critical_requirements=[
                {
                    "requirement_id": "req_shared",
                    "title": "Shared requirement (preferred)",
                    "section": "1. IAM",
                    "subsection": "Sub IAM",
                    "risk_level": 5,
                    "weight": 150,
                    "passed_findings": 14,
                    "total_findings": 20,
                    "description": "Higher risk duplicate",
                },
                {
                    "requirement_id": "req_unique_one",
                    "title": "Unique provider one",
                    "section": "2. Attack Surface",
                    "subsection": "Sub Attack",
                    "risk_level": 4,
                    "weight": 90,
                    "passed_findings": 20,
                    "total_findings": 30,
                    "description": "Lower risk",
                },
            ],
            total_requirements=120,
            passed_requirements=90,
            failed_requirements=30,
            manual_requirements=0,
            total_findings=100,
            passed_findings=70,
            failed_findings=30,
        )

        snapshot2 = self._create_threatscore_snapshot(
            tenant,
            scan2,
            provider2,
            compliance_id="prowler_threatscore_aws",
            overall_score="20.00",
            score_delta="-2.00",
            section_scores={
                "1. IAM": "10.00",
                "2. Attack Surface": "40.00",
                "3. Logging": "30.00",
            },
            critical_requirements=[
                {
                    "requirement_id": "req_shared",
                    "title": "Shared requirement (secondary)",
                    "section": "1. IAM",
                    "subsection": "Sub IAM",
                    "risk_level": 4,
                    "weight": 120,
                    "passed_findings": 8,
                    "total_findings": 12,
                    "description": "Lower risk duplicate",
                },
                {
                    "requirement_id": "req_unique_two",
                    "title": "Unique provider two",
                    "section": "3. Logging",
                    "subsection": "Sub Logging",
                    "risk_level": 5,
                    "weight": 110,
                    "passed_findings": 6,
                    "total_findings": 10,
                    "description": "Another critical requirement",
                },
            ],
            total_requirements=80,
            passed_requirements=30,
            failed_requirements=50,
            manual_requirements=0,
            total_findings=50,
            passed_findings=15,
            failed_findings=35,
        )

        older_inserted = datetime(2025, 1, 1, 12, 0, tzinfo=timezone.utc)
        newer_inserted = datetime(2025, 1, 2, 12, 0, tzinfo=timezone.utc)
        ThreatScoreSnapshot.objects.filter(id=snapshot1.id).update(
            inserted_at=older_inserted
        )
        ThreatScoreSnapshot.objects.filter(id=snapshot2.id).update(
            inserted_at=newer_inserted
        )
        snapshot2.refresh_from_db()

        response = authenticated_client.get(reverse("overview-threatscore"))

        assert response.status_code == status.HTTP_200_OK
        body = response.json()
        assert len(body["data"]) == 1
        aggregated = body["data"][0]

        assert aggregated["id"] == "n/a"
        assert aggregated["relationships"]["scan"]["data"] is None
        assert aggregated["relationships"]["provider"]["data"] is None

        attrs = aggregated["attributes"]
        assert Decimal(attrs["overall_score"]) == Decimal("60.00")
        assert Decimal(attrs["score_delta"]) == Decimal("2.67")
        assert attrs["inserted_at"] == snapshot2.inserted_at.isoformat().replace(
            "+00:00", "Z"
        )
        assert attrs["total_findings"] == 150
        assert attrs["passed_findings"] == 85
        assert attrs["failed_findings"] == 65
        assert attrs["total_requirements"] == 200
        assert attrs["passed_requirements"] == 120
        assert attrs["failed_requirements"] == 80
        assert attrs["manual_requirements"] == 0

        assert attrs["section_scores"] == {
            "1. IAM": "50.00",
            "2. Attack Surface": "53.33",
            "3. Logging": "30.00",
        }

        expected_critical = [
            {
                "requirement_id": "req_shared",
                "title": "Shared requirement (preferred)",
                "section": "1. IAM",
                "subsection": "Sub IAM",
                "risk_level": 5,
                "weight": 150,
                "passed_findings": 14,
                "total_findings": 20,
                "description": "Higher risk duplicate",
            },
            {
                "requirement_id": "req_unique_two",
                "title": "Unique provider two",
                "section": "3. Logging",
                "subsection": "Sub Logging",
                "risk_level": 5,
                "weight": 110,
                "passed_findings": 6,
                "total_findings": 10,
                "description": "Another critical requirement",
            },
            {
                "requirement_id": "req_unique_one",
                "title": "Unique provider one",
                "section": "2. Attack Surface",
                "subsection": "Sub Attack",
                "risk_level": 4,
                "weight": 90,
                "passed_findings": 20,
                "total_findings": 30,
                "description": "Lower risk",
            },
        ]
        assert attrs["critical_requirements"] == expected_critical

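To make the expected values in the test above concrete: each snapshot is weighted by its total_findings (100 and 50 here), so the aggregate reduces to the arithmetic below (a worked check, not part of the test suite):

```python
from decimal import ROUND_HALF_UP, Decimal

w1, w2 = 100, 50  # weights = total_findings of snapshot1 and snapshot2


def q(d):
    return d.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)


overall = (Decimal("80.00") * w1 + Decimal("20.00") * w2) / (w1 + w2)
delta = (Decimal("5.00") * w1 + Decimal("-2.00") * w2) / (w1 + w2)
iam = (Decimal("70.00") * w1 + Decimal("10.00") * w2) / (w1 + w2)

assert q(overall) == Decimal("60.00")
assert q(delta) == Decimal("2.67")
assert q(iam) == Decimal("50.00")
```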
    def test_overview_threatscore_weight_fallback_to_requirements(
        self, authenticated_client, tenants_fixture, providers_fixture
    ):
        tenant = tenants_fixture[0]
        provider1, provider2, *_ = providers_fixture

        scan1 = self._create_scan(tenant, provider1, "fallback-scan-1")
        scan2 = self._create_scan(tenant, provider2, "fallback-scan-2")

        self._create_threatscore_snapshot(
            tenant,
            scan1,
            provider1,
            compliance_id="prowler_threatscore_aws",
            overall_score="90.00",
            score_delta="4.00",
            section_scores={"1. IAM": "90.00"},
            critical_requirements=[],
            total_requirements=10,
            passed_requirements=8,
            failed_requirements=0,
            manual_requirements=2,
            total_findings=0,
            passed_findings=0,
            failed_findings=0,
        )
        self._create_threatscore_snapshot(
            tenant,
            scan2,
            provider2,
            compliance_id="prowler_threatscore_aws",
            overall_score="50.00",
            score_delta="1.00",
            section_scores={"1. IAM": "40.00"},
            critical_requirements=[],
            total_requirements=12,
            passed_requirements=5,
            failed_requirements=7,
            manual_requirements=0,
            total_findings=10,
            passed_findings=4,
            failed_findings=6,
        )

        response = authenticated_client.get(reverse("overview-threatscore"))
        assert response.status_code == status.HTTP_200_OK
        aggregate = response.json()["data"][0]["attributes"]

        assert Decimal(aggregate["overall_score"]) == Decimal("67.78")
        assert Decimal(aggregate["score_delta"]) == Decimal("2.33")
        assert aggregate["total_findings"] == 10
        assert aggregate["total_requirements"] == 22
        assert aggregate["manual_requirements"] == 2
        assert aggregate["section_scores"] == {"1. IAM": "62.22"}

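The fallback numbers check out the same way: snapshot1 has no findings, so its weight falls back to total_requirements minus manual_requirements (10 - 2 = 8), while snapshot2 keeps weight 10 (its total_findings). A worked check:

```python
from decimal import ROUND_HALF_UP, Decimal

w1, w2 = 10 - 2, 10  # active requirements vs. total findings
overall = (Decimal("90.00") * w1 + Decimal("50.00") * w2) / (w1 + w2)
iam = (Decimal("90.00") * w1 + Decimal("40.00") * w2) / (w1 + w2)

assert overall.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) == Decimal("67.78")
assert iam.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) == Decimal("62.22")
```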
    def test_overview_threatscore_filter_by_scan_id_returns_snapshot(
        self, authenticated_client, tenants_fixture, providers_fixture
    ):
        tenant = tenants_fixture[0]
        provider1, *_ = providers_fixture
        scan = self._create_scan(tenant, provider1, "filter-scan")

        snapshot = self._create_threatscore_snapshot(
            tenant,
            scan,
            provider1,
            compliance_id="prowler_threatscore_aws",
            overall_score="75.00",
            score_delta="3.00",
            section_scores={"1. IAM": "70.00"},
            critical_requirements=[],
            total_requirements=50,
            passed_requirements=30,
            failed_requirements=20,
            manual_requirements=0,
            total_findings=25,
            passed_findings=15,
            failed_findings=10,
        )

        response = authenticated_client.get(
            reverse("overview-threatscore"), {"filter[scan_id]": str(scan.id)}
        )

        assert response.status_code == status.HTTP_200_OK
        body = response.json()
        assert len(body["data"]) == 1
        assert body["data"][0]["id"] == str(snapshot.id)
        assert body["data"][0]["attributes"]["overall_score"] == "75.00"

    def test_overview_threatscore_snapshot_id_returns_specific_snapshot(
        self, authenticated_client, tenants_fixture, providers_fixture
    ):
        tenant = tenants_fixture[0]
        provider1, *_ = providers_fixture
        scan = self._create_scan(tenant, provider1, "snapshot-id-scan")

        snapshot = self._create_threatscore_snapshot(
            tenant,
            scan,
            provider1,
            compliance_id="prowler_threatscore_aws",
            overall_score="88.50",
            score_delta=None,
            section_scores={"1. IAM": "80.00"},
            critical_requirements=[],
            total_requirements=60,
            passed_requirements=45,
            failed_requirements=15,
            manual_requirements=0,
            total_findings=30,
            passed_findings=25,
            failed_findings=5,
        )

        response = authenticated_client.get(
            reverse("overview-threatscore"), {"snapshot_id": str(snapshot.id)}
        )

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["data"]["id"] == str(snapshot.id)
        assert data["data"]["attributes"]["score_delta"] is None

    def test_overview_threatscore_provider_filter_returns_unaggregated_snapshot(
        self, authenticated_client, tenants_fixture, providers_fixture
    ):
        tenant = tenants_fixture[0]
        provider1, provider2, *_ = providers_fixture

        scan1 = self._create_scan(tenant, provider1, "provider-filter-scan-1")
        scan2 = self._create_scan(tenant, provider2, "provider-filter-scan-2")

        snapshot1 = self._create_threatscore_snapshot(
            tenant,
            scan1,
            provider1,
            compliance_id="prowler_threatscore_aws",
            overall_score="55.55",
            score_delta="1.10",
            section_scores={"1. IAM": "50.00"},
            critical_requirements=[],
            total_requirements=40,
            passed_requirements=25,
            failed_requirements=15,
            manual_requirements=0,
            total_findings=12,
            passed_findings=7,
            failed_findings=5,
        )
        self._create_threatscore_snapshot(
            tenant,
            scan2,
            provider2,
            compliance_id="prowler_threatscore_aws",
            overall_score="44.44",
            score_delta="0.80",
            section_scores={"1. IAM": "40.00"},
            critical_requirements=[],
            total_requirements=30,
            passed_requirements=18,
            failed_requirements=12,
            manual_requirements=0,
            total_findings=10,
            passed_findings=6,
            failed_findings=4,
        )

        response = authenticated_client.get(
            reverse("overview-threatscore"),
            {"filter[provider_id__in]": str(provider1.id)},
        )

        assert response.status_code == status.HTTP_200_OK
        data = response.json()["data"]
        assert len(data) == 1
        assert data[0]["id"] == str(snapshot1.id)
        assert data[0]["attributes"]["overall_score"] == "55.55"

    def test_overview_services_list_no_required_filters(
        self, authenticated_client, scan_summaries_fixture
    ):
@@ -47,6 +47,7 @@ from api.models import (
    StatusChoices,
    Task,
    TenantAPIKey,
    ThreatScoreSnapshot,
    User,
    UserRoleRelationship,
)
@@ -3626,3 +3627,64 @@ class MuteRuleUpdateSerializer(BaseWriteSerializer):
        ):
            raise ValidationError("A mute rule with this name already exists.")
        return value


# ThreatScore Snapshots


class ThreatScoreSnapshotSerializer(RLSSerializer):
    """
    Serializer for ThreatScore snapshots.
    Read-only serializer for retrieving historical ThreatScore metrics.
    """

    id = serializers.SerializerMethodField()

    class Meta:
        model = ThreatScoreSnapshot
        fields = [
            "id",
            "inserted_at",
            "scan",
            "provider",
            "compliance_id",
            "overall_score",
            "score_delta",
            "section_scores",
            "critical_requirements",
            "total_requirements",
            "passed_requirements",
            "failed_requirements",
            "manual_requirements",
            "total_findings",
            "passed_findings",
            "failed_findings",
        ]
        extra_kwargs = {
            "id": {"read_only": True},
            "inserted_at": {"read_only": True},
            "scan": {"read_only": True},
            "provider": {"read_only": True},
            "compliance_id": {"read_only": True},
            "overall_score": {"read_only": True},
            "score_delta": {"read_only": True},
            "section_scores": {"read_only": True},
            "critical_requirements": {"read_only": True},
            "total_requirements": {"read_only": True},
            "passed_requirements": {"read_only": True},
            "failed_requirements": {"read_only": True},
            "manual_requirements": {"read_only": True},
            "total_findings": {"read_only": True},
            "passed_findings": {"read_only": True},
            "failed_findings": {"read_only": True},
        }

    included_serializers = {
        "scan": "api.v1.serializers.ScanIncludeSerializer",
        "provider": "api.v1.serializers.ProviderIncludeSerializer",
    }

    def get_id(self, obj):
        if getattr(obj, "_aggregated", False):
            return "n/a"
        return str(obj.id)

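The get_id override is what lets the view return a synthetic tenant-wide aggregate through this same serializer; a sketch of the dispatch (instances built in memory purely for illustration):

```python
from decimal import Decimal
from uuid import uuid4

from api.models import ThreatScoreSnapshot
from api.v1.serializers import ThreatScoreSnapshotSerializer

# A persisted snapshot serializes with its real UUID...
snapshot = ThreatScoreSnapshot(id=uuid4(), overall_score=Decimal("75.00"))
assert ThreatScoreSnapshotSerializer().get_id(snapshot) == str(snapshot.id)

# ...while the in-memory aggregate built by the overview view is flagged
# with `_aggregated = True` and serializes with the sentinel id "n/a".
aggregate = ThreatScoreSnapshot(overall_score=Decimal("60.00"))
aggregate._aggregated = True
assert ThreatScoreSnapshotSerializer().get_id(aggregate) == "n/a"
```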
@@ -3,7 +3,10 @@ import glob
import json
import logging
import os
from collections import defaultdict
from copy import deepcopy
from datetime import datetime, timedelta, timezone
from decimal import ROUND_HALF_UP, Decimal, InvalidOperation
from urllib.parse import urljoin

import sentry_sdk
@@ -24,9 +27,23 @@ from django.conf import settings as django_settings
from django.contrib.postgres.aggregates import ArrayAgg
from django.contrib.postgres.search import SearchQuery
from django.db import transaction
from django.db.models import Count, F, Prefetch, Q, Subquery, Sum
from django.db.models import (
    Case,
    Count,
    DecimalField,
    ExpressionWrapper,
    F,
    IntegerField,
    Max,
    Prefetch,
    Q,
    Subquery,
    Sum,
    Value,
    When,
)
from django.db.models.functions import Coalesce
from django.http import HttpResponse
from django.http import HttpResponse, QueryDict
from django.shortcuts import redirect
from django.urls import reverse
from django.utils.dateparse import parse_date
@@ -105,6 +122,7 @@ from api.filters import (
    TaskFilter,
    TenantApiKeyFilter,
    TenantFilter,
    ThreatScoreSnapshotFilter,
    UserFilter,
)
from api.models import (
@@ -138,6 +156,7 @@ from api.models import (
    StateChoices,
    Task,
    TenantAPIKey,
    ThreatScoreSnapshot,
    User,
    UserRoleRelationship,
)
@@ -218,6 +237,7 @@ from api.v1.serializers import (
    TenantApiKeySerializer,
    TenantApiKeyUpdateSerializer,
    TenantSerializer,
    ThreatScoreSnapshotSerializer,
    TokenRefreshSerializer,
    TokenSerializer,
    TokenSocialLoginSerializer,
@@ -3770,6 +3790,8 @@ class OverviewViewSet(BaseRLSViewSet):
            return OverviewSeveritySerializer
        elif self.action == "services":
            return OverviewServiceSerializer
        elif self.action == "threatscore":
            return ThreatScoreSnapshotSerializer
        return super().get_serializer_class()

    def get_filterset_class(self):
@@ -4011,6 +4033,332 @@ class OverviewViewSet(BaseRLSViewSet):

        return Response(serializer.data, status=status.HTTP_200_OK)

    @extend_schema(
        summary="Get ThreatScore snapshots",
        description=(
            "Retrieve ThreatScore metrics. By default, returns the latest snapshot for each provider. "
            "Use snapshot_id to retrieve a specific historical snapshot."
        ),
        tags=["Overviews"],
        parameters=[
            OpenApiParameter(
                name="snapshot_id",
                type=OpenApiTypes.UUID,
                location=OpenApiParameter.QUERY,
                description="Retrieve a specific snapshot by ID. If not provided, returns latest snapshots.",
            ),
            OpenApiParameter(
                name="provider_id",
                type=OpenApiTypes.UUID,
                location=OpenApiParameter.QUERY,
                description="Filter by specific provider ID",
            ),
            OpenApiParameter(
                name="provider_id__in",
                type=OpenApiTypes.STR,
                location=OpenApiParameter.QUERY,
                description="Filter by multiple provider IDs (comma-separated UUIDs)",
            ),
            OpenApiParameter(
                name="provider_type",
                type=OpenApiTypes.STR,
                location=OpenApiParameter.QUERY,
                description="Filter by provider type (aws, azure, gcp, etc.)",
            ),
            OpenApiParameter(
                name="provider_type__in",
                type=OpenApiTypes.STR,
                location=OpenApiParameter.QUERY,
                description="Filter by multiple provider types (comma-separated)",
            ),
        ],
    )
    @action(detail=False, methods=["get"], url_name="threatscore")
    def threatscore(self, request):
        """
        Get ThreatScore snapshots.

        Default behavior: Returns the latest snapshot for each provider.
        With snapshot_id: Returns the specific snapshot requested.
        """
        tenant_id = self.request.tenant_id
        snapshot_id = request.query_params.get("snapshot_id")

        # Base queryset with RLS
        base_queryset = ThreatScoreSnapshot.objects.filter(tenant_id=tenant_id)

        # Apply RBAC filtering
        if hasattr(self, "allowed_providers"):
            base_queryset = base_queryset.filter(provider__in=self.allowed_providers)

        # Case 1: Specific snapshot requested
        if snapshot_id:
            try:
                snapshot = base_queryset.get(id=snapshot_id)
                serializer = ThreatScoreSnapshotSerializer(
                    snapshot, context={"request": request}
                )
                return Response(serializer.data, status=status.HTTP_200_OK)
            except ThreatScoreSnapshot.DoesNotExist:
                raise NotFound(detail="ThreatScore snapshot not found")

        # Case 2: Latest snapshot per provider (default)
        # Apply filters manually: this @action is outside the standard list endpoint flow,
        # so DRF's filter backends don't execute and we must flatten JSON:API params ourselves.
        normalized_params = QueryDict(mutable=True)
        for param_key, values in request.query_params.lists():
            normalized_key = param_key
            if param_key.startswith("filter[") and param_key.endswith("]"):
                normalized_key = param_key[7:-1]
            if normalized_key == "snapshot_id":
                continue
            normalized_params.setlist(normalized_key, values)

        filterset = ThreatScoreSnapshotFilter(normalized_params, queryset=base_queryset)
        filtered_queryset = filterset.qs

        # Pick the latest snapshot id per provider using the Postgres DISTINCT ON pattern.
        # This avoids issuing one query per provider (N+1) when the filtered dataset is large.
        latest_snapshot_ids = list(
            filtered_queryset.order_by("provider_id", "-inserted_at")
            .distinct("provider_id")
            .values_list("id", flat=True)
        )
        latest_snapshot_map = {
            snapshot.id: snapshot
            for snapshot in filtered_queryset.filter(id__in=latest_snapshot_ids)
        }
        latest_snapshots = [
            latest_snapshot_map[snapshot_id]
            for snapshot_id in latest_snapshot_ids
            if snapshot_id in latest_snapshot_map
        ]

        if len(latest_snapshots) <= 1:
            serializer = ThreatScoreSnapshotSerializer(
                latest_snapshots, many=True, context={"request": request}
            )
            return Response(serializer.data, status=status.HTTP_200_OK)

        snapshot_ids = [
            snapshot.id for snapshot in latest_snapshots if snapshot and snapshot.id
        ]
        aggregated_snapshot = self._build_threatscore_overview_snapshot(
            snapshot_ids, tenant_id
        )
        serializer = ThreatScoreSnapshotSerializer(
            [aggregated_snapshot], many=True, context={"request": request}
        )
        return Response(serializer.data, status=status.HTTP_200_OK)

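For readers unfamiliar with the latest-per-provider trick used above, it relies on PostgreSQL's DISTINCT ON; a sketch isolated from the view (the SQL in the comment is the rough equivalent, not generated output):

```python
from api.models import ThreatScoreSnapshot

# Roughly: SELECT DISTINCT ON (provider_id) id
#          FROM threatscore_snapshots
#          ORDER BY provider_id, inserted_at DESC;
# DISTINCT ON keeps the first row per provider_id under that ordering,
# i.e. the most recent snapshot, in a single query.
latest_ids = (
    ThreatScoreSnapshot.objects.order_by("provider_id", "-inserted_at")
    .distinct("provider_id")
    .values_list("id", flat=True)
)
```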
    def _build_threatscore_overview_snapshot(self, snapshot_ids, tenant_id):
        """
        Aggregate the latest snapshots into a single overview snapshot for the tenant.
        """
        if not snapshot_ids:
            raise ValueError(
                "Snapshot id list cannot be empty when aggregating threatscore overview"
            )

        base_queryset = ThreatScoreSnapshot.objects.filter(
            tenant_id=tenant_id, id__in=snapshot_ids
        )

        annotated_queryset = (
            base_queryset.annotate(
                active_requirements=ExpressionWrapper(
                    F("total_requirements") - F("manual_requirements"),
                    output_field=IntegerField(),
                )
            )
            .annotate(
                weight=Case(
                    When(total_findings__gt=0, then=F("total_findings")),
                    When(
                        active_requirements__gt=0,
                        then=F("active_requirements"),
                    ),
                    default=Value(1, output_field=IntegerField()),
                    output_field=IntegerField(),
                )
            )
            .order_by()
        )

        aggregated_metrics = annotated_queryset.aggregate(
            total_requirements=Sum("total_requirements"),
            passed_requirements=Sum("passed_requirements"),
            failed_requirements=Sum("failed_requirements"),
            manual_requirements=Sum("manual_requirements"),
            total_findings=Sum("total_findings"),
            passed_findings=Sum("passed_findings"),
            failed_findings=Sum("failed_findings"),
            weighted_overall_sum=Sum(
                ExpressionWrapper(
                    F("overall_score") * F("weight"),
                    output_field=DecimalField(max_digits=14, decimal_places=4),
                )
            ),
            overall_weight=Sum("weight"),
            unweighted_overall_sum=Sum("overall_score"),
            weighted_delta_sum=Sum(
                Case(
                    When(
                        score_delta__isnull=False,
                        then=ExpressionWrapper(
                            F("score_delta") * F("weight"),
                            output_field=DecimalField(max_digits=14, decimal_places=4),
                        ),
                    ),
                    default=Value(
                        Decimal("0"),
                        output_field=DecimalField(max_digits=14, decimal_places=4),
                    ),
                    output_field=DecimalField(max_digits=14, decimal_places=4),
                )
            ),
            delta_weight=Sum(
                Case(
                    When(score_delta__isnull=False, then=F("weight")),
                    default=Value(0, output_field=IntegerField()),
                    output_field=IntegerField(),
                )
            ),
            provider_count=Count("id"),
            latest_inserted_at=Max("inserted_at"),
        )

        total_requirements = aggregated_metrics["total_requirements"] or 0
        passed_requirements = aggregated_metrics["passed_requirements"] or 0
        failed_requirements = aggregated_metrics["failed_requirements"] or 0
        manual_requirements = aggregated_metrics["manual_requirements"] or 0
        total_findings = aggregated_metrics["total_findings"] or 0
        passed_findings = aggregated_metrics["passed_findings"] or 0
        failed_findings = aggregated_metrics["failed_findings"] or 0

        weighted_overall_sum = aggregated_metrics["weighted_overall_sum"]
        if weighted_overall_sum is None:
            weighted_overall_sum = Decimal("0")
        unweighted_overall_sum = aggregated_metrics["unweighted_overall_sum"]
        if unweighted_overall_sum is None:
            unweighted_overall_sum = Decimal("0")

        overall_weight = aggregated_metrics["overall_weight"] or 0
        provider_count = aggregated_metrics["provider_count"] or 0

        weighted_delta_sum = aggregated_metrics["weighted_delta_sum"]
        if weighted_delta_sum is None:
            weighted_delta_sum = Decimal("0")
        delta_weight = aggregated_metrics["delta_weight"] or 0

        if overall_weight > 0:
            overall_score = (weighted_overall_sum / Decimal(overall_weight)).quantize(
                Decimal("0.01"), rounding=ROUND_HALF_UP
            )
        elif provider_count > 0:
            overall_score = (unweighted_overall_sum / Decimal(provider_count)).quantize(
                Decimal("0.01"), rounding=ROUND_HALF_UP
            )
        else:
            overall_score = Decimal("0.00")

        if delta_weight > 0:
            score_delta = (weighted_delta_sum / Decimal(delta_weight)).quantize(
                Decimal("0.01"), rounding=ROUND_HALF_UP
            )
        else:
            score_delta = None

        section_weighted_sums = defaultdict(lambda: Decimal("0"))
        section_weights = defaultdict(lambda: Decimal("0"))

        combined_critical_requirements = {}

        snapshots_with_weight = list(annotated_queryset)

        for snapshot in snapshots_with_weight:
            weight_value = getattr(snapshot, "weight", None)
            try:
                weight_decimal = Decimal(weight_value)
            except (InvalidOperation, TypeError):
                weight_decimal = Decimal("1")
            if weight_decimal <= 0:
                weight_decimal = Decimal("1")

            section_scores = snapshot.section_scores or {}
            for section, score in section_scores.items():
                try:
                    score_decimal = Decimal(str(score))
                except (InvalidOperation, TypeError):
                    continue
                section_weighted_sums[section] += score_decimal * weight_decimal
                section_weights[section] += weight_decimal

            for requirement in snapshot.critical_requirements or []:
                key = requirement.get("requirement_id") or requirement.get("title")
                if not key:
                    continue
                existing = combined_critical_requirements.get(key)

                def requirement_sort_key(item):
                    return (
                        item.get("risk_level") or 0,
                        item.get("weight") or 0,
                    )

                if existing is None or requirement_sort_key(
                    requirement
                ) > requirement_sort_key(existing):
                    combined_critical_requirements[key] = deepcopy(requirement)

        aggregated_section_scores = {}
        for section, total in section_weighted_sums.items():
            weight_total = section_weights[section]
            if weight_total > 0:
                aggregated_section_scores[section] = str(
                    (total / weight_total).quantize(
                        Decimal("0.01"), rounding=ROUND_HALF_UP
                    )
                )

        aggregated_section_scores = dict(sorted(aggregated_section_scores.items()))

        aggregated_critical_requirements = sorted(
            combined_critical_requirements.values(),
            key=lambda item: (
                item.get("risk_level") or 0,
                item.get("weight") or 0,
            ),
            reverse=True,
        )

        aggregated_snapshot = ThreatScoreSnapshot(
            tenant_id=tenant_id,
            scan=None,
            provider=None,
            compliance_id="prowler_threatscore_overview",
            overall_score=overall_score,
            score_delta=score_delta,
            section_scores=aggregated_section_scores,
            critical_requirements=aggregated_critical_requirements,
            total_requirements=total_requirements,
            passed_requirements=passed_requirements,
            failed_requirements=failed_requirements,
            manual_requirements=manual_requirements,
            total_findings=total_findings,
            passed_findings=passed_findings,
            failed_findings=failed_findings,
        )

        latest_inserted_at = aggregated_metrics["latest_inserted_at"]
        if latest_inserted_at is not None:
            aggregated_snapshot.inserted_at = latest_inserted_at

        aggregated_snapshot._aggregated = True

        return aggregated_snapshot


@extend_schema(tags=["Schedule"])
@extend_schema_view(
@@ -7,7 +7,6 @@ from shutil import rmtree
import matplotlib.pyplot as plt
from celery.utils.log import get_task_logger
from config.django.base import DJANGO_FINDINGS_BATCH_SIZE, DJANGO_TMP_OUTPUT_DIRECTORY
from django.db.models import Count, Q
from reportlab.lib import colors
from reportlab.lib.enums import TA_CENTER
from reportlab.lib.pagesizes import letter
@@ -26,11 +25,22 @@ from reportlab.platypus import (
    TableStyle,
)
from tasks.jobs.export import _generate_output_directory, _upload_to_s3
from tasks.jobs.threatscore import compute_threatscore_metrics
from tasks.jobs.threatscore_utils import (
    _aggregate_requirement_statistics_from_database,
    _calculate_requirements_data_from_statistics,
)
from tasks.utils import batched

from api.db_router import READ_REPLICA_ALIAS
from api.db_utils import rls_transaction
from api.models import Finding, Provider, ScanSummary, StatusChoices
from api.models import (
    Finding,
    Provider,
    ScanSummary,
    StatusChoices,
    ThreatScoreSnapshot,
)
from api.utils import initialize_prowler_provider
from prowler.lib.check.compliance_models import Compliance
from prowler.lib.outputs.finding import Finding as FindingOutput
@@ -434,56 +444,6 @@ def _add_pdf_footer(canvas_obj: canvas.Canvas, doc: SimpleDocTemplate) -> None:
    canvas_obj.drawString(width - text_width - 30, 20, powered_text)


def _aggregate_requirement_statistics_from_database(
    tenant_id: str, scan_id: str
) -> dict[str, dict[str, int]]:
    """
    Aggregate finding statistics by check_id using database aggregation.

    This function uses Django ORM aggregation to calculate pass/fail statistics
    entirely in the database, avoiding the need to load findings into memory.

    Args:
        tenant_id (str): The tenant ID for Row-Level Security context.
        scan_id (str): The ID of the scan to retrieve findings for.

    Returns:
        dict[str, dict[str, int]]: Dictionary mapping check_id to statistics:
            - 'passed' (int): Number of passed findings for this check
            - 'total' (int): Total number of findings for this check

    Example:
        {
            'aws_iam_user_mfa_enabled': {'passed': 10, 'total': 15},
            'aws_s3_bucket_public_access': {'passed': 0, 'total': 5}
        }
    """
    requirement_statistics_by_check_id = {}

    with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
        # Use database aggregation to calculate stats without loading findings into memory
        aggregated_statistics_queryset = (
            Finding.all_objects.filter(tenant_id=tenant_id, scan_id=scan_id)
            .values("check_id")
            .annotate(
                total_findings=Count("id"),
                passed_findings=Count("id", filter=Q(status=StatusChoices.PASS)),
            )
        )

        for aggregated_stat in aggregated_statistics_queryset:
            check_id = aggregated_stat["check_id"]
            requirement_statistics_by_check_id[check_id] = {
                "passed": aggregated_stat["passed_findings"],
                "total": aggregated_stat["total_findings"],
            }

    logger.info(
        f"Aggregated statistics for {len(requirement_statistics_by_check_id)} unique checks"
    )
    return requirement_statistics_by_check_id


def _load_findings_for_requirement_checks(
    tenant_id: str, scan_id: str, check_ids: list[str], prowler_provider
) -> dict[str, list[FindingOutput]]:
@@ -544,84 +504,6 @@ def _load_findings_for_requirement_checks(
    return dict(findings_by_check_id)


def _calculate_requirements_data_from_statistics(
    compliance_obj, requirement_statistics_by_check_id: dict[str, dict[str, int]]
) -> tuple[dict[str, dict], list[dict]]:
    """
    Calculate requirement status and statistics using pre-aggregated database statistics.

    This function uses O(n) lookups with pre-aggregated statistics from the database,
    avoiding the need to iterate over all findings for each requirement.

    Args:
        compliance_obj: The compliance framework object containing requirements.
        requirement_statistics_by_check_id (dict[str, dict[str, int]]): Pre-aggregated statistics
            mapping check_id to {'passed': int, 'total': int} counts.

    Returns:
        tuple[dict[str, dict], list[dict]]: A tuple containing:
            - attributes_by_requirement_id: Dictionary mapping requirement IDs to their attributes.
            - requirements_list: List of requirement dictionaries with status and statistics.
    """
    attributes_by_requirement_id = {}
    requirements_list = []

    compliance_framework = getattr(compliance_obj, "Framework", "N/A")
    compliance_version = getattr(compliance_obj, "Version", "N/A")

    for requirement in compliance_obj.Requirements:
        requirement_id = requirement.Id
        requirement_description = getattr(requirement, "Description", "")
        requirement_checks = getattr(requirement, "Checks", [])
        requirement_attributes = getattr(requirement, "Attributes", [])

        # Store requirement metadata for later use
        attributes_by_requirement_id[requirement_id] = {
            "attributes": {
                "req_attributes": requirement_attributes,
                "checks": requirement_checks,
            },
            "description": requirement_description,
        }

        # Calculate aggregated passed and total findings for this requirement
        total_passed_findings = 0
        total_findings_count = 0

        for check_id in requirement_checks:
            if check_id in requirement_statistics_by_check_id:
                check_statistics = requirement_statistics_by_check_id[check_id]
                total_findings_count += check_statistics["total"]
                total_passed_findings += check_statistics["passed"]

        # Determine overall requirement status based on findings
        if total_findings_count > 0:
            if total_passed_findings == total_findings_count:
                requirement_status = StatusChoices.PASS
            else:
                # Partial pass or complete fail both count as FAIL
                requirement_status = StatusChoices.FAIL
        else:
            # No findings means manual review required
            requirement_status = StatusChoices.MANUAL

        requirements_list.append(
            {
                "id": requirement_id,
                "attributes": {
                    "framework": compliance_framework,
                    "version": compliance_version,
                    "status": requirement_status,
                    "description": requirement_description,
                    "passed_findings": total_passed_findings,
                    "total_findings": total_findings_count,
                },
            }
        )

    return attributes_by_requirement_id, requirements_list

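The status rule buried in the loop above is worth restating as a tiny pure function (a sketch for illustration, not part of the module):

```python
def requirement_status(passed: int, total: int) -> str:
    """No findings -> MANUAL; any failing finding -> FAIL; else PASS."""
    if total == 0:
        return "MANUAL"
    return "PASS" if passed == total else "FAIL"


assert requirement_status(15, 15) == "PASS"
assert requirement_status(10, 15) == "FAIL"  # partial pass counts as FAIL
assert requirement_status(0, 0) == "MANUAL"
```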

def generate_threatscore_report(
    tenant_id: str,
    scan_id: str,
@@ -1262,8 +1144,9 @@ def generate_threatscore_report_job(
    2. Checks provider type compatibility
    3. Generates the output directory
    4. Calls generate_threatscore_report to create the PDF
    5. Uploads the PDF to S3
    6. Cleans up temporary files
    5. Computes and stores ThreatScore metrics snapshot
    6. Uploads the PDF to S3
    7. Cleans up temporary files

    Args:
        tenant_id (str): The tenant ID for Row-Level Security context.
@@ -1317,6 +1200,66 @@ def generate_threatscore_report_job(
        min_risk_level=4,
    )

    # Compute and store ThreatScore metrics snapshot
    logger.info(f"Computing ThreatScore metrics for scan {scan_id}")
    try:
        metrics = compute_threatscore_metrics(
            tenant_id=tenant_id,
            scan_id=scan_id,
            provider_id=provider_id,
            compliance_id=compliance_id,
            min_risk_level=4,
        )

        # Create snapshot in database
        with rls_transaction(tenant_id):
            # Get previous snapshot for the same provider to calculate delta
            previous_snapshot = (
                ThreatScoreSnapshot.objects.filter(
                    tenant_id=tenant_id,
                    provider_id=provider_id,
                    compliance_id=compliance_id,
                )
                .order_by("-inserted_at")
                .first()
            )

            # Calculate score delta (improvement)
            score_delta = None
            if previous_snapshot:
                score_delta = metrics["overall_score"] - float(
                    previous_snapshot.overall_score
                )

            snapshot = ThreatScoreSnapshot.objects.create(
                tenant_id=tenant_id,
                scan_id=scan_id,
                provider_id=provider_id,
                compliance_id=compliance_id,
                overall_score=metrics["overall_score"],
                score_delta=score_delta,
                section_scores=metrics["section_scores"],
                critical_requirements=metrics["critical_requirements"],
                total_requirements=metrics["total_requirements"],
                passed_requirements=metrics["passed_requirements"],
                failed_requirements=metrics["failed_requirements"],
                manual_requirements=metrics["manual_requirements"],
                total_findings=metrics["total_findings"],
                passed_findings=metrics["passed_findings"],
                failed_findings=metrics["failed_findings"],
            )

        delta_msg = (
            f" (delta: {score_delta:+.2f}%)" if score_delta is not None else ""
        )
        logger.info(
            f"ThreatScore snapshot created with ID {snapshot.id} "
            f"(score: {snapshot.overall_score}%{delta_msg})"
        )
    except Exception as e:
        # Log error but don't fail the job if snapshot creation fails
        logger.error(f"Error creating ThreatScore snapshot: {e}")

    upload_uri = _upload_to_s3(
        tenant_id,
        scan_id,
api/src/backend/tasks/jobs/threatscore.py (new file, 214 lines)
@@ -0,0 +1,214 @@
|
||||
from celery.utils.log import get_task_logger
|
||||
from tasks.jobs.threatscore_utils import (
|
||||
_aggregate_requirement_statistics_from_database,
|
||||
_calculate_requirements_data_from_statistics,
|
||||
)
|
||||
|
||||
from api.db_router import READ_REPLICA_ALIAS
|
||||
from api.db_utils import rls_transaction
|
||||
from api.models import Provider, StatusChoices
|
||||
from prowler.lib.check.compliance_models import Compliance
|
||||
|
||||
logger = get_task_logger(__name__)
|
||||
|
||||
|
||||
def compute_threatscore_metrics(
|
||||
tenant_id: str,
|
||||
scan_id: str,
|
||||
provider_id: str,
|
||||
compliance_id: str,
|
||||
min_risk_level: int = 4,
|
||||
) -> dict:
|
||||
"""
|
||||
Compute ThreatScore metrics for a given scan.
|
||||
|
||||
This function calculates all the metrics needed for a ThreatScore snapshot:
|
||||
- Overall ThreatScore percentage
|
||||
- Section-by-section scores
|
||||
- Critical failed requirements (risk >= min_risk_level)
|
||||
- Summary statistics (requirements and findings counts)
|
||||
|
||||
Args:
|
||||
tenant_id (str): The tenant ID for Row-Level Security context.
|
||||
scan_id (str): The ID of the scan to analyze.
|
||||
provider_id (str): The ID of the provider used in the scan.
|
||||
compliance_id (str): Compliance framework ID (e.g., "prowler_threatscore_aws").
|
||||
min_risk_level (int): Minimum risk level for critical requirements. Defaults to 4.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing:
|
||||
- overall_score (float): Overall ThreatScore percentage (0-100)
|
||||
- section_scores (dict): Section name -> score percentage mapping
|
||||
- critical_requirements (list): List of critical failed requirement dicts
|
||||
- total_requirements (int): Total number of requirements
|
||||
- passed_requirements (int): Number of PASS requirements
|
||||
- failed_requirements (int): Number of FAIL requirements
|
||||
- manual_requirements (int): Number of MANUAL requirements
|
||||
- total_findings (int): Total findings count
|
||||
- passed_findings (int): Passed findings count
|
||||
- failed_findings (int): Failed findings count
|
||||
|
||||
Example:
|
||||
>>> metrics = compute_threatscore_metrics(
|
||||
... tenant_id="tenant-123",
|
||||
... scan_id="scan-456",
|
||||
... provider_id="provider-789",
|
||||
... compliance_id="prowler_threatscore_aws"
|
||||
... )
|
||||
>>> print(f"Overall ThreatScore: {metrics['overall_score']:.2f}%")
|
||||
"""
    # Get provider and compliance information
    with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
        provider_obj = Provider.objects.get(id=provider_id)
        provider_type = provider_obj.provider

    frameworks_bulk = Compliance.get_bulk(provider_type)
    compliance_obj = frameworks_bulk[compliance_id]

    # Aggregate requirement statistics from database
    requirement_statistics_by_check_id = (
        _aggregate_requirement_statistics_from_database(tenant_id, scan_id)
    )

    # Calculate requirements data using aggregated statistics
    attributes_by_requirement_id, requirements_list = (
        _calculate_requirements_data_from_statistics(
            compliance_obj, requirement_statistics_by_check_id
        )
    )

    # Initialize metrics
    overall_numerator = 0
    overall_denominator = 0
    overall_has_findings = False

    sections_data = {}

    total_requirements = len(requirements_list)
    passed_requirements = 0
    failed_requirements = 0
    manual_requirements = 0
    total_findings = 0
    passed_findings = 0
    failed_findings = 0

    critical_requirements_list = []

    # Process each requirement
    for requirement in requirements_list:
        requirement_id = requirement["id"]
        requirement_status = requirement["attributes"]["status"]
        requirement_attributes = attributes_by_requirement_id.get(requirement_id, {})

        # Count requirements by status
        if requirement_status == StatusChoices.PASS:
            passed_requirements += 1
        elif requirement_status == StatusChoices.FAIL:
            failed_requirements += 1
        elif requirement_status == StatusChoices.MANUAL:
            manual_requirements += 1

        # Get findings data
        req_passed_findings = requirement["attributes"].get("passed_findings", 0)
        req_total_findings = requirement["attributes"].get("total_findings", 0)

        # Accumulate findings counts
        total_findings += req_total_findings
        passed_findings += req_passed_findings
        failed_findings += req_total_findings - req_passed_findings

        # Skip requirements with no findings
        if req_total_findings == 0:
            continue

        overall_has_findings = True

        # Get requirement metadata
        metadata = requirement_attributes.get("attributes", {}).get(
            "req_attributes", []
        )
        if not metadata or len(metadata) == 0:
            continue

        m = metadata[0]
        risk_level = getattr(m, "LevelOfRisk", 0)
        weight = getattr(m, "Weight", 0)
        section = getattr(m, "Section", "Unknown")

        # Calculate ThreatScore components using formula from UI
        rate_i = req_passed_findings / req_total_findings
        rfac_i = 1 + 0.25 * risk_level

        # Update overall score
        overall_numerator += rate_i * req_total_findings * weight * rfac_i
        overall_denominator += req_total_findings * weight * rfac_i

        # Update section scores
        if section not in sections_data:
            sections_data[section] = {
                "numerator": 0,
                "denominator": 0,
                "has_findings": False,
            }

        sections_data[section]["has_findings"] = True
        sections_data[section]["numerator"] += (
            rate_i * req_total_findings * weight * rfac_i
        )
        sections_data[section]["denominator"] += req_total_findings * weight * rfac_i

        # Identify critical failed requirements
        if requirement_status == StatusChoices.FAIL and risk_level >= min_risk_level:
            critical_requirements_list.append(
                {
                    "requirement_id": requirement_id,
                    "title": getattr(m, "Title", "N/A"),
                    "section": section,
                    "subsection": getattr(m, "SubSection", "N/A"),
                    "risk_level": risk_level,
                    "weight": weight,
                    "passed_findings": req_passed_findings,
                    "total_findings": req_total_findings,
                    "description": getattr(m, "AttributeDescription", "N/A"),
                }
            )

    # Calculate overall ThreatScore
    if not overall_has_findings:
        overall_score = 100.0
    elif overall_denominator > 0:
        overall_score = (overall_numerator / overall_denominator) * 100
    else:
        overall_score = 0.0

    # Calculate section scores
    section_scores = {}
    for section, data in sections_data.items():
        if data["has_findings"] and data["denominator"] > 0:
            section_scores[section] = (data["numerator"] / data["denominator"]) * 100
        else:
            section_scores[section] = 100.0

    # Sort critical requirements by risk level (desc) and weight (desc)
    critical_requirements_list.sort(
        key=lambda x: (x["risk_level"], x["weight"]), reverse=True
    )

    logger.info(
        f"ThreatScore computed: {overall_score:.2f}% "
        f"({passed_requirements}/{total_requirements} requirements passed, "
        f"{len(critical_requirements_list)} critical failures)"
    )

    return {
        "overall_score": round(overall_score, 2),
        "section_scores": {k: round(v, 2) for k, v in section_scores.items()},
        "critical_requirements": critical_requirements_list,
        "total_requirements": total_requirements,
        "passed_requirements": passed_requirements,
        "failed_requirements": failed_requirements,
        "manual_requirements": manual_requirements,
        "total_findings": total_findings,
        "passed_findings": passed_findings,
        "failed_findings": failed_findings,
    }
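

# Editorial aside (not part of the diff): the weighted average above reduces to
#   score = 100 * sum(rate_i * total_i * weight_i * rfac_i)
#               / sum(total_i * weight_i * rfac_i)
# with rate_i = passed_i / total_i and rfac_i = 1 + 0.25 * risk_i.
# The toy numbers below are invented purely to make the weighting concrete; note
# the edge cases in the code: no findings at all scores 100.0, while a zero
# denominator with findings present scores 0.0.
requirements = [
    # (passed, total, weight, risk_level)
    (10, 10, 100, 3),  # fully passing, moderate risk
    (0, 5, 150, 5),    # fully failing, maximum risk
]
numerator = denominator = 0.0
for passed, total, weight, risk in requirements:
    rate = passed / total
    rfac = 1 + 0.25 * risk
    numerator += rate * total * weight * rfac
    denominator += total * weight * rfac
print(f"{100 * numerator / denominator:.2f}%")  # 50.91% -- high-risk failures drag the score hardest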
127
api/src/backend/tasks/jobs/threatscore_utils.py
Normal file
@@ -0,0 +1,127 @@
from celery.utils.log import get_task_logger
from django.db.models import Count, Q

from api.db_router import READ_REPLICA_ALIAS
from api.db_utils import rls_transaction
from api.models import Finding, StatusChoices

logger = get_task_logger(__name__)


def _aggregate_requirement_statistics_from_database(
    tenant_id: str, scan_id: str
) -> dict[str, dict[str, int]]:
    """
    Aggregate finding statistics by check_id using database aggregation.

    This function uses Django ORM aggregation to calculate pass/fail statistics
    entirely in the database, avoiding the need to load findings into memory.

    Args:
        tenant_id (str): The tenant ID for Row-Level Security context.
        scan_id (str): The ID of the scan to retrieve findings for.

    Returns:
        dict[str, dict[str, int]]: Dictionary mapping check_id to statistics:
            - 'passed' (int): Number of passed findings for this check
            - 'total' (int): Total number of findings for this check

    Example:
        {
            'aws_iam_user_mfa_enabled': {'passed': 10, 'total': 15},
            'aws_s3_bucket_public_access': {'passed': 0, 'total': 5}
        }
    """
    requirement_statistics_by_check_id = {}

    with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
        aggregated_statistics_queryset = (
            Finding.all_objects.filter(tenant_id=tenant_id, scan_id=scan_id)
            .values("check_id")
            .annotate(
                total_findings=Count("id"),
                passed_findings=Count("id", filter=Q(status=StatusChoices.PASS)),
            )
        )

        for aggregated_stat in aggregated_statistics_queryset:
            check_id = aggregated_stat["check_id"]
            requirement_statistics_by_check_id[check_id] = {
                "passed": aggregated_stat["passed_findings"],
                "total": aggregated_stat["total_findings"],
            }

    logger.info(
        f"Aggregated statistics for {len(requirement_statistics_by_check_id)} unique checks"
    )
    return requirement_statistics_by_check_id
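

# Editorial aside (not part of the diff): for contrast, the naive in-memory
# equivalent of the GROUP BY aggregation above would pull every Finding row
# into Python just to count pass/fail pairs. Hypothetical sketch, assuming
# findings arrive as (check_id, status) tuples:
from collections import defaultdict

def _aggregate_in_python(findings):
    stats = defaultdict(lambda: {"passed": 0, "total": 0})
    for check_id, status in findings:
        stats[check_id]["total"] += 1
        if status == "PASS":
            stats[check_id]["passed"] += 1
    return dict(stats)

assert _aggregate_in_python(
    [("aws_iam_user_mfa_enabled", "PASS"), ("aws_iam_user_mfa_enabled", "FAIL")]
) == {"aws_iam_user_mfa_enabled": {"passed": 1, "total": 2}}
# The queryset version instead compiles to a single GROUP BY check_id query,
# shipping only one aggregated row per check back to the worker.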


def _calculate_requirements_data_from_statistics(
    compliance_obj, requirement_statistics_by_check_id: dict[str, dict[str, int]]
) -> tuple[dict[str, dict], list[dict]]:
    """
    Calculate requirement status and statistics using pre-aggregated database statistics.

    Args:
        compliance_obj: The compliance framework object containing requirements.
        requirement_statistics_by_check_id (dict[str, dict[str, int]]): Pre-aggregated
            statistics mapping check_id to {'passed': int, 'total': int} counts.

    Returns:
        tuple[dict[str, dict], list[dict]]: A tuple containing:
            - attributes_by_requirement_id: Dictionary mapping requirement IDs to their attributes.
            - requirements_list: List of requirement dictionaries with status and statistics.
    """
    attributes_by_requirement_id = {}
    requirements_list = []

    compliance_framework = getattr(compliance_obj, "Framework", "N/A")
    compliance_version = getattr(compliance_obj, "Version", "N/A")

    for requirement in compliance_obj.Requirements:
        requirement_id = requirement.Id
        requirement_description = getattr(requirement, "Description", "")
        requirement_checks = getattr(requirement, "Checks", [])
        requirement_attributes = getattr(requirement, "Attributes", [])

        attributes_by_requirement_id[requirement_id] = {
            "attributes": {
                "req_attributes": requirement_attributes,
                "checks": requirement_checks,
            },
            "description": requirement_description,
        }

        total_passed_findings = 0
        total_findings_count = 0

        for check_id in requirement_checks:
            if check_id in requirement_statistics_by_check_id:
                check_statistics = requirement_statistics_by_check_id[check_id]
                total_findings_count += check_statistics["total"]
                total_passed_findings += check_statistics["passed"]

        if total_findings_count > 0:
            if total_passed_findings == total_findings_count:
                requirement_status = StatusChoices.PASS
            else:
                requirement_status = StatusChoices.FAIL
        else:
            requirement_status = StatusChoices.MANUAL

        requirements_list.append(
            {
                "id": requirement_id,
                "attributes": {
                    "framework": compliance_framework,
                    "version": compliance_version,
                    "status": requirement_status,
                    "description": requirement_description,
                    "passed_findings": total_passed_findings,
                    "total_findings": total_findings_count,
                },
            }
        )

    return attributes_by_requirement_id, requirements_list
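

# Editorial aside (not part of the diff): the status rule above means a
# requirement is PASS only when every finding across all of its checks passed,
# FAIL as soon as a single finding failed, and MANUAL when none of its checks
# produced findings. A minimal sketch with invented statistics:
def _status_for(checks, stats):
    passed = sum(stats[c]["passed"] for c in checks if c in stats)
    total = sum(stats[c]["total"] for c in checks if c in stats)
    if total == 0:
        return "MANUAL"  # no automated findings -> requires manual review
    return "PASS" if passed == total else "FAIL"

_stats = {
    "check_a": {"passed": 3, "total": 3},
    "check_b": {"passed": 2, "total": 4},
}
assert _status_for(["check_a"], _stats) == "PASS"
assert _status_for(["check_a", "check_b"], _stats) == "FAIL"  # one failure fails the requirement
assert _status_for(["check_unscored"], _stats) == "MANUAL"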

@@ -1,19 +1,25 @@
import uuid
from datetime import timedelta
from decimal import Decimal
from pathlib import Path
from unittest.mock import MagicMock, patch

import matplotlib
import pytest
from django.utils import timezone
from freezegun import freeze_time
from tasks.jobs.report import (
    _aggregate_requirement_statistics_from_database,
    _calculate_requirements_data_from_statistics,
    _load_findings_for_requirement_checks,
    generate_threatscore_report,
    generate_threatscore_report_job,
)
from tasks.jobs.threatscore_utils import (
    _aggregate_requirement_statistics_from_database,
    _calculate_requirements_data_from_statistics,
)
from tasks.tasks import generate_threatscore_report_task

from api.models import Finding, StatusChoices
from api.models import Finding, Scan, StateChoices, StatusChoices, ThreatScoreSnapshot
from prowler.lib.check.models import Severity

matplotlib.use("Agg")  # Use non-interactive backend for tests
@@ -39,6 +45,7 @@ class TestGenerateThreatscoreReport:
        assert result == {"upload": False}
        mock_filter.assert_called_once_with(scan_id=self.scan_id)

    @patch("tasks.jobs.report.ThreatScoreSnapshot.objects.create")
    @patch("tasks.jobs.report.rmtree")
    @patch("tasks.jobs.report._upload_to_s3")
    @patch("tasks.jobs.report.generate_threatscore_report")
@@ -53,6 +60,7 @@ class TestGenerateThreatscoreReport:
        mock_generate_report,
        mock_upload,
        mock_rmtree,
        mock_snapshot_create,
    ):
        mock_scan_summary_filter.return_value.exists.return_value = True

@@ -95,8 +103,10 @@ class TestGenerateThreatscoreReport:
            Path("/tmp/threatscore_path_threatscore_report.pdf").parent,
            ignore_errors=True,
        )
        mock_snapshot_create.assert_called_once()

    def test_generate_threatscore_report_fails_upload(self):
    @patch("tasks.jobs.report.ThreatScoreSnapshot.objects.create")
    def test_generate_threatscore_report_fails_upload(self, mock_snapshot_create):
        with (
            patch("tasks.jobs.report.ScanSummary.objects.filter") as mock_filter,
            patch("tasks.jobs.report.Provider.objects.get") as mock_provider_get,
@@ -125,8 +135,12 @@ class TestGenerateThreatscoreReport:
        )

        assert result == {"upload": False}
        mock_snapshot_create.assert_called_once()

    def test_generate_threatscore_report_logs_rmtree_exception(self, caplog):
    @patch("tasks.jobs.report.ThreatScoreSnapshot.objects.create")
    def test_generate_threatscore_report_logs_rmtree_exception(
        self, mock_snapshot_create, caplog
    ):
        with (
            patch("tasks.jobs.report.ScanSummary.objects.filter") as mock_filter,
            patch("tasks.jobs.report.Provider.objects.get") as mock_provider_get,
@@ -160,8 +174,10 @@ class TestGenerateThreatscoreReport:
            provider_id=self.provider_id,
        )
        assert "Error deleting output files" in caplog.text
        mock_snapshot_create.assert_called_once()

    def test_generate_threatscore_report_azure_provider(self):
    @patch("tasks.jobs.report.ThreatScoreSnapshot.objects.create")
    def test_generate_threatscore_report_azure_provider(self, mock_snapshot_create):
        with (
            patch("tasks.jobs.report.ScanSummary.objects.filter") as mock_filter,
            patch("tasks.jobs.report.Provider.objects.get") as mock_provider_get,
@@ -200,6 +216,135 @@ class TestGenerateThreatscoreReport:
            only_failed=True,
            min_risk_level=4,
        )
        mock_snapshot_create.assert_called_once()

    @patch("tasks.jobs.report.rmtree")
    @patch(
        "tasks.jobs.report._upload_to_s3",
        return_value="s3://bucket/threatscore/threatscore_report.pdf",
    )
    @patch("tasks.jobs.report.generate_threatscore_report")
    @patch("tasks.jobs.report._generate_output_directory")
    @patch("tasks.jobs.report.ScanSummary.objects.filter")
    @patch("tasks.jobs.report.compute_threatscore_metrics")
    @pytest.mark.django_db
    @freeze_time("2025-01-10T12:00:00Z")
    def test_generate_threatscore_report_persists_snapshot_and_delta(
        self,
        mock_compute_metrics,
        mock_scan_summary_filter,
        mock_generate_output_directory,
        mock_generate_report,
        mock_upload,
        mock_rmtree,
        tenants_fixture,
        providers_fixture,
    ):
        tenant = tenants_fixture[0]
        provider = providers_fixture[0]

        scan_previous = Scan.objects.create(
            tenant=tenant,
            provider=provider,
            name="previous-threatscore-scan",
            trigger=Scan.TriggerChoices.MANUAL,
            state=StateChoices.COMPLETED,
            started_at=timezone.now() - timedelta(hours=4),
            completed_at=timezone.now() - timedelta(hours=3),
        )
        ThreatScoreSnapshot.objects.create(
            tenant=tenant,
            scan=scan_previous,
            provider=provider,
            compliance_id="prowler_threatscore_aws",
            overall_score=Decimal("70.00"),
            score_delta=None,
            section_scores={"1. IAM": "65.00"},
            critical_requirements=[],
            total_requirements=50,
            passed_requirements=35,
            failed_requirements=15,
            manual_requirements=0,
            total_findings=40,
            passed_findings=25,
            failed_findings=15,
        )

        scan_current = Scan.objects.create(
            tenant=tenant,
            provider=provider,
            name="current-threatscore-scan",
            trigger=Scan.TriggerChoices.MANUAL,
            state=StateChoices.COMPLETED,
            started_at=timezone.now() - timedelta(hours=2),
            completed_at=timezone.now() - timedelta(hours=1),
        )

        mock_scan_summary_filter.return_value.exists.return_value = True
        mock_generate_output_directory.return_value = (
            "/tmp/output",
            "/tmp/compressed",
            "/tmp/threatscore_path",
        )

        metrics = {
            "overall_score": 85.5,
            "score_delta": 10.0,
            "section_scores": {"1. IAM": 82.3, "2. Attack Surface": 60.0},
            "critical_requirements": [
                {
                    "requirement_id": "req_new",
                    "title": "New high-risk requirement",
                    "section": "1. IAM",
                    "subsection": "Root Account",
                    "risk_level": 5,
                    "weight": 150,
                    "passed_findings": 7,
                    "total_findings": 10,
                    "description": "Critical requirement description",
                }
            ],
            "total_requirements": 140,
            "passed_requirements": 100,
            "failed_requirements": 40,
            "manual_requirements": 0,
            "total_findings": 200,
            "passed_findings": 150,
            "failed_findings": 50,
        }
        mock_compute_metrics.return_value = metrics

        result = generate_threatscore_report_job(
            tenant_id=str(tenant.id),
            scan_id=str(scan_current.id),
            provider_id=str(provider.id),
        )

        assert result == {"upload": True}
        mock_compute_metrics.assert_called_once_with(
            tenant_id=str(tenant.id),
            scan_id=str(scan_current.id),
            provider_id=str(provider.id),
            compliance_id="prowler_threatscore_aws",
            min_risk_level=4,
        )
        mock_generate_report.assert_called_once()
        mock_upload.assert_called_once()
        mock_rmtree.assert_called_once()

        snapshots = ThreatScoreSnapshot.objects.filter(
            tenant=tenant, provider=provider
        ).order_by("inserted_at")
        assert snapshots.count() == 2

        new_snapshot = ThreatScoreSnapshot.objects.get(scan=scan_current)
        assert new_snapshot.compliance_id == "prowler_threatscore_aws"
        assert Decimal(new_snapshot.overall_score) == Decimal("85.50")
        assert Decimal(new_snapshot.score_delta) == Decimal("15.50")
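        # Editorial note: 85.50 (current overall_score) - 70.00 (previous
        # snapshot) = 15.50; the job derives score_delta from the stored
        # history, ignoring the mocked metrics' "score_delta" of 10.0.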
        assert new_snapshot.section_scores == metrics["section_scores"]
        assert new_snapshot.critical_requirements == metrics["critical_requirements"]
        assert new_snapshot.total_requirements == metrics["total_requirements"]
        assert new_snapshot.total_findings == metrics["total_findings"]


@pytest.mark.django_db