feat(threatscore): restore API threatscore snapshots (#9291)

Author: Adrián Jesús Peña Rodríguez
Date: 2025-11-24 10:47:03 +01:00
Committed by: GitHub
Parent: e2e06a78f9
Commit: 2f184a493b
3 changed files with 186 additions and 243 deletions

tasks/jobs/report.py

@@ -6,8 +6,7 @@ from shutil import rmtree
import matplotlib.pyplot as plt
from celery.utils.log import get_task_logger
from config.django.base import DJANGO_FINDINGS_BATCH_SIZE, DJANGO_TMP_OUTPUT_DIRECTORY
from django.db.models import Count, Q
from config.django.base import DJANGO_TMP_OUTPUT_DIRECTORY
from reportlab.lib import colors
from reportlab.lib.enums import TA_CENTER
from reportlab.lib.pagesizes import letter
@@ -26,11 +25,16 @@ from reportlab.platypus import (
    TableStyle,
)
from tasks.jobs.export import _generate_compliance_output_directory, _upload_to_s3
from tasks.utils import batched
from tasks.jobs.threatscore import compute_threatscore_metrics
from tasks.jobs.threatscore_utils import (
    _aggregate_requirement_statistics_from_database,
    _calculate_requirements_data_from_statistics,
    _load_findings_for_requirement_checks,
)
from api.db_router import READ_REPLICA_ALIAS
from api.db_utils import rls_transaction
from api.models import Finding, Provider, ScanSummary, StatusChoices
from api.models import Provider, ScanSummary, StatusChoices, ThreatScoreSnapshot
from api.utils import initialize_prowler_provider
from prowler.lib.check.compliance_models import Compliance
from prowler.lib.outputs.finding import Finding as FindingOutput
@@ -974,236 +978,6 @@ def _create_dimensions_radar_chart(
    return buffer


def _aggregate_requirement_statistics_from_database(
    tenant_id: str, scan_id: str
) -> dict[str, dict[str, int]]:
    """
    Aggregate finding statistics by check_id using database aggregation.

    This function uses Django ORM aggregation to calculate pass/fail statistics
    entirely in the database, avoiding the need to load findings into memory.

    Args:
        tenant_id (str): The tenant ID for Row-Level Security context.
        scan_id (str): The ID of the scan to retrieve findings for.

    Returns:
        dict[str, dict[str, int]]: Dictionary mapping check_id to statistics:
            - 'passed' (int): Number of passed findings for this check
            - 'total' (int): Total number of findings for this check

    Example:
        {
            'aws_iam_user_mfa_enabled': {'passed': 10, 'total': 15},
            'aws_s3_bucket_public_access': {'passed': 0, 'total': 5}
        }
    """
    requirement_statistics_by_check_id = {}

    with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
        # Use database aggregation to calculate stats without loading findings into memory
        aggregated_statistics_queryset = (
            Finding.all_objects.filter(tenant_id=tenant_id, scan_id=scan_id)
            .values("check_id")
            .annotate(
                total_findings=Count("id"),
                passed_findings=Count("id", filter=Q(status=StatusChoices.PASS)),
            )
        )

        for aggregated_stat in aggregated_statistics_queryset:
            check_id = aggregated_stat["check_id"]
            requirement_statistics_by_check_id[check_id] = {
                "passed": aggregated_stat["passed_findings"],
                "total": aggregated_stat["total_findings"],
            }

    logger.info(
        f"Aggregated statistics for {len(requirement_statistics_by_check_id)} unique checks"
    )

    return requirement_statistics_by_check_id
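
For illustration, a sketch of how the returned mapping can be consumed; the check IDs and counts below are hypothetical, not taken from this diff:

    # Derive per-check pass rates from the aggregated statistics.
    stats = {
        "aws_iam_user_mfa_enabled": {"passed": 10, "total": 15},
        "aws_s3_bucket_public_access": {"passed": 0, "total": 5},
    }
    pass_rate = {
        check_id: counts["passed"] / counts["total"]
        for check_id, counts in stats.items()
        if counts["total"] > 0  # skip checks with no findings
    }
    assert pass_rate["aws_s3_bucket_public_access"] == 0.0
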
def _load_findings_for_requirement_checks(
    tenant_id: str,
    scan_id: str,
    check_ids: list[str],
    prowler_provider,
    findings_cache: dict[str, list[FindingOutput]] | None = None,
) -> dict[str, list[FindingOutput]]:
    """
    Load findings for specific check IDs on-demand with optional caching.

    This function loads only the findings needed for a specific set of checks,
    minimizing memory usage by avoiding loading all findings at once. This is used
    when generating detailed findings tables for specific requirements in the PDF.
    Supports optional caching to avoid duplicate queries when generating multiple
    reports for the same scan.

    Args:
        tenant_id (str): The tenant ID for Row-Level Security context.
        scan_id (str): The ID of the scan to retrieve findings for.
        check_ids (list[str]): List of check IDs to load findings for.
        prowler_provider: The initialized Prowler provider instance.
        findings_cache (dict, optional): Cache of already loaded findings.
            If provided, check IDs are first looked up in the cache before
            querying the database.

    Returns:
        dict[str, list[FindingOutput]]: Dictionary mapping check_id to a list of FindingOutput objects.

    Example:
        {
            'aws_iam_user_mfa_enabled': [FindingOutput(...), FindingOutput(...)],
            'aws_s3_bucket_public_access': [FindingOutput(...)]
        }
    """
    findings_by_check_id = defaultdict(list)

    if not check_ids:
        return dict(findings_by_check_id)

    # Initialize cache if not provided
    if findings_cache is None:
        findings_cache = {}

    # Separate cached and non-cached check_ids
    check_ids_to_load = []
    cache_hits = 0
    cache_misses = 0

    for check_id in check_ids:
        if check_id in findings_cache:
            # Reuse from cache
            findings_by_check_id[check_id] = findings_cache[check_id]
            cache_hits += 1
        else:
            # Need to load from database
            check_ids_to_load.append(check_id)
            cache_misses += 1

    if cache_hits > 0:
        logger.info(
            f"Findings cache: {cache_hits} hits, {cache_misses} misses "
            f"({cache_hits / (cache_hits + cache_misses) * 100:.1f}% hit rate)"
        )

    # If all check_ids were in cache, return early
    if not check_ids_to_load:
        return dict(findings_by_check_id)

    logger.info(f"Loading findings for {len(check_ids_to_load)} checks on-demand")

    findings_queryset = (
        Finding.all_objects.filter(
            tenant_id=tenant_id, scan_id=scan_id, check_id__in=check_ids_to_load
        )
        .order_by("uid")
        .iterator()
    )

    with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
        for batch, is_last_batch in batched(
            findings_queryset, DJANGO_FINDINGS_BATCH_SIZE
        ):
            for finding_model in batch:
                finding_output = FindingOutput.transform_api_finding(
                    finding_model, prowler_provider
                )
                findings_by_check_id[finding_output.check_id].append(finding_output)
                # Update cache with newly loaded findings
                if finding_output.check_id not in findings_cache:
                    findings_cache[finding_output.check_id] = []
                findings_cache[finding_output.check_id].append(finding_output)

    total_findings_loaded = sum(
        len(findings) for findings in findings_by_check_id.values()
    )
    logger.info(
        f"Loaded {total_findings_loaded} findings for {len(findings_by_check_id)} checks"
    )

    return dict(findings_by_check_id)
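
A sketch of the intended call pattern: one cache dict shared across reports for the same scan, so each check's findings are loaded at most once (the check ID below is hypothetical):

    findings_cache = {}
    first = _load_findings_for_requirement_checks(
        tenant_id, scan_id, ["aws_iam_user_mfa_enabled"], prowler_provider,
        findings_cache=findings_cache,  # cache miss: loads from the read replica
    )
    second = _load_findings_for_requirement_checks(
        tenant_id, scan_id, ["aws_iam_user_mfa_enabled"], prowler_provider,
        findings_cache=findings_cache,  # cache hit: no database query
    )
    assert first.keys() == second.keys()
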
def _calculate_requirements_data_from_statistics(
    compliance_obj, requirement_statistics_by_check_id: dict[str, dict[str, int]]
) -> tuple[dict[str, dict], list[dict]]:
    """
    Calculate requirement status and statistics using pre-aggregated database statistics.

    This function relies on O(1) dictionary lookups into the pre-aggregated
    database statistics, avoiding the need to iterate over all findings for
    each requirement.

    Args:
        compliance_obj: The compliance framework object containing requirements.
        requirement_statistics_by_check_id (dict[str, dict[str, int]]): Pre-aggregated statistics
            mapping check_id to {'passed': int, 'total': int} counts.

    Returns:
        tuple[dict[str, dict], list[dict]]: A tuple containing:
            - attributes_by_requirement_id: Dictionary mapping requirement IDs to their attributes.
            - requirements_list: List of requirement dictionaries with status and statistics.
    """
    attributes_by_requirement_id = {}
    requirements_list = []

    compliance_framework = getattr(compliance_obj, "Framework", "N/A")
    compliance_version = getattr(compliance_obj, "Version", "N/A")

    for requirement in compliance_obj.Requirements:
        requirement_id = requirement.Id
        requirement_description = getattr(requirement, "Description", "")
        requirement_checks = getattr(requirement, "Checks", [])
        requirement_attributes = getattr(requirement, "Attributes", [])

        # Store requirement metadata for later use
        attributes_by_requirement_id[requirement_id] = {
            "attributes": {
                "req_attributes": requirement_attributes,
                "checks": requirement_checks,
            },
            "description": requirement_description,
        }

        # Calculate aggregated passed and total findings for this requirement
        total_passed_findings = 0
        total_findings_count = 0
        for check_id in requirement_checks:
            if check_id in requirement_statistics_by_check_id:
                check_statistics = requirement_statistics_by_check_id[check_id]
                total_findings_count += check_statistics["total"]
                total_passed_findings += check_statistics["passed"]

        # Determine overall requirement status based on findings
        if total_findings_count > 0:
            if total_passed_findings == total_findings_count:
                requirement_status = StatusChoices.PASS
            else:
                # Partial pass or complete fail both count as FAIL
                requirement_status = StatusChoices.FAIL
        else:
            # No findings means manual review required
            requirement_status = StatusChoices.MANUAL

        requirements_list.append(
            {
                "id": requirement_id,
                "attributes": {
                    "framework": compliance_framework,
                    "version": compliance_version,
                    "status": requirement_status,
                    "description": requirement_description,
                    "passed_findings": total_passed_findings,
                    "total_findings": total_findings_count,
                },
            }
        )

    return attributes_by_requirement_id, requirements_list
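
A worked example of the status rule above, with illustrative numbers: checks at 10/10 and 4/5 passed aggregate to 14 of 15, which counts as FAIL because partial passes are failures; zero findings would yield MANUAL:

    stats = {"check_a": {"passed": 10, "total": 10}, "check_b": {"passed": 4, "total": 5}}
    passed = sum(s["passed"] for s in stats.values())  # 14
    total = sum(s["total"] for s in stats.values())    # 15
    status = "MANUAL" if total == 0 else "PASS" if passed == total else "FAIL"
    assert status == "FAIL"
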
def generate_threatscore_report(
    tenant_id: str,
    scan_id: str,
@@ -3782,6 +3556,68 @@ def generate_compliance_reports(
        findings_cache=findings_cache,  # Share findings cache
    )

    # Compute and store ThreatScore metrics snapshot
    logger.info(f"Computing ThreatScore metrics for scan {scan_id}")
    try:
        metrics = compute_threatscore_metrics(
            tenant_id=tenant_id,
            scan_id=scan_id,
            provider_id=provider_id,
            compliance_id=compliance_id_threatscore,
            min_risk_level=min_risk_level_threatscore,
        )

        # Create snapshot in database
        with rls_transaction(tenant_id):
            # Get previous snapshot for the same provider to calculate delta
            previous_snapshot = (
                ThreatScoreSnapshot.objects.filter(
                    tenant_id=tenant_id,
                    provider_id=provider_id,
                    compliance_id=compliance_id_threatscore,
                )
                .order_by("-inserted_at")
                .first()
            )

            # Calculate score delta (improvement)
            score_delta = None
            if previous_snapshot:
                score_delta = metrics["overall_score"] - float(
                    previous_snapshot.overall_score
                )

            snapshot = ThreatScoreSnapshot.objects.create(
                tenant_id=tenant_id,
                scan_id=scan_id,
                provider_id=provider_id,
                compliance_id=compliance_id_threatscore,
                overall_score=metrics["overall_score"],
                score_delta=score_delta,
                section_scores=metrics["section_scores"],
                critical_requirements=metrics["critical_requirements"],
                total_requirements=metrics["total_requirements"],
                passed_requirements=metrics["passed_requirements"],
                failed_requirements=metrics["failed_requirements"],
                manual_requirements=metrics["manual_requirements"],
                total_findings=metrics["total_findings"],
                passed_findings=metrics["passed_findings"],
                failed_findings=metrics["failed_findings"],
            )

            delta_msg = (
                f" (delta: {score_delta:+.2f}%)"
                if score_delta is not None
                else ""
            )
            logger.info(
                f"ThreatScore snapshot created with ID {snapshot.id} "
                f"(score: {snapshot.overall_score}%{delta_msg})"
            )
    except Exception as e:
        # Log error but don't fail the job if snapshot creation fails
        logger.error(f"Error creating ThreatScore snapshot: {e}")

    upload_uri_threatscore = _upload_to_s3(
        tenant_id,
        scan_id,
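
To make the delta rule concrete, a minimal sketch (the helper name is hypothetical; in the diff the logic is inline):

    def compute_score_delta(new_score, previous_score):
        # New overall score minus the latest prior snapshot's score for the
        # same provider/compliance pair; None when no prior snapshot exists.
        return None if previous_score is None else new_score - float(previous_score)

    assert compute_score_delta(82.5, 80.0) == 2.5   # logged as "delta: +2.50%"
    assert compute_score_delta(82.5, None) is None  # first snapshot: no delta
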

tasks/jobs/threatscore_utils.py

@@ -1,9 +1,14 @@
from collections import defaultdict

from celery.utils.log import get_task_logger
from config.django.base import DJANGO_FINDINGS_BATCH_SIZE
from django.db.models import Count, Q
from tasks.utils import batched

from api.db_router import READ_REPLICA_ALIAS
from api.db_utils import rls_transaction
from api.models import Finding, StatusChoices
from prowler.lib.outputs.finding import Finding as FindingOutput

logger = get_task_logger(__name__)
@@ -125,3 +130,105 @@ def _calculate_requirements_data_from_statistics(
        )

    return attributes_by_requirement_id, requirements_list

def _load_findings_for_requirement_checks(
    tenant_id: str,
    scan_id: str,
    check_ids: list[str],
    prowler_provider,
    findings_cache: dict[str, list[FindingOutput]] | None = None,
) -> dict[str, list[FindingOutput]]:
    """
    Load findings for specific check IDs on-demand with optional caching.

    This function loads only the findings needed for a specific set of checks,
    minimizing memory usage by avoiding loading all findings at once. This is used
    when generating detailed findings tables for specific requirements in the PDF.
    Supports optional caching to avoid duplicate queries when generating multiple
    reports for the same scan.

    Args:
        tenant_id (str): The tenant ID for Row-Level Security context.
        scan_id (str): The ID of the scan to retrieve findings for.
        check_ids (list[str]): List of check IDs to load findings for.
        prowler_provider: The initialized Prowler provider instance.
        findings_cache (dict, optional): Cache of already loaded findings.
            If provided, check IDs are first looked up in the cache before
            querying the database.

    Returns:
        dict[str, list[FindingOutput]]: Dictionary mapping check_id to a list of FindingOutput objects.

    Example:
        {
            'aws_iam_user_mfa_enabled': [FindingOutput(...), FindingOutput(...)],
            'aws_s3_bucket_public_access': [FindingOutput(...)]
        }
    """
    findings_by_check_id = defaultdict(list)

    if not check_ids:
        return dict(findings_by_check_id)

    # Initialize cache if not provided
    if findings_cache is None:
        findings_cache = {}

    # Separate cached and non-cached check_ids
    check_ids_to_load = []
    cache_hits = 0
    cache_misses = 0

    for check_id in check_ids:
        if check_id in findings_cache:
            # Reuse from cache
            findings_by_check_id[check_id] = findings_cache[check_id]
            cache_hits += 1
        else:
            # Need to load from database
            check_ids_to_load.append(check_id)
            cache_misses += 1

    if cache_hits > 0:
        logger.info(
            f"Findings cache: {cache_hits} hits, {cache_misses} misses "
            f"({cache_hits / (cache_hits + cache_misses) * 100:.1f}% hit rate)"
        )

    # If all check_ids were in cache, return early
    if not check_ids_to_load:
        return dict(findings_by_check_id)

    logger.info(f"Loading findings for {len(check_ids_to_load)} checks on-demand")

    findings_queryset = (
        Finding.all_objects.filter(
            tenant_id=tenant_id, scan_id=scan_id, check_id__in=check_ids_to_load
        )
        .order_by("uid")
        .iterator()
    )

    with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
        for batch, is_last_batch in batched(
            findings_queryset, DJANGO_FINDINGS_BATCH_SIZE
        ):
            for finding_model in batch:
                finding_output = FindingOutput.transform_api_finding(
                    finding_model, prowler_provider
                )
                findings_by_check_id[finding_output.check_id].append(finding_output)
                # Update cache with newly loaded findings
                if finding_output.check_id not in findings_cache:
                    findings_cache[finding_output.check_id] = []
                findings_cache[finding_output.check_id].append(finding_output)

    total_findings_loaded = sum(
        len(findings) for findings in findings_by_check_id.values()
    )
    logger.info(
        f"Loaded {total_findings_loaded} findings for {len(findings_by_check_id)} checks"
    )

    return dict(findings_by_check_id)
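
For context, batched(queryset, size) as consumed above yields (batch, is_last_batch) pairs. An illustrative implementation compatible with that contract; the real helper lives in tasks.utils and may differ:

    from itertools import islice

    def batched(iterable, size):
        # Yield (batch, is_last_batch) pairs of at most `size` items,
        # streaming the iterable without materializing it fully.
        iterator = iter(iterable)
        batch = list(islice(iterator, size))
        while batch:
            next_batch = list(islice(iterator, size))
            yield batch, not next_batch  # True only for the final batch
            batch = next_batch
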

tests for tasks/jobs/report.py

@@ -281,7 +281,7 @@ class TestLoadFindingsForChecks:
        mock_provider = MagicMock()

        with patch(
            "tasks.jobs.report.FindingOutput.transform_api_finding"
            "tasks.jobs.threatscore_utils.FindingOutput.transform_api_finding"
        ) as mock_transform:
            mock_finding_output = MagicMock()
            mock_finding_output.check_id = "check_requested"
@@ -335,7 +335,7 @@ class TestLoadFindingsForChecks:
        mock_provider = MagicMock()

        with patch(
            "tasks.jobs.report.FindingOutput.transform_api_finding"
            "tasks.jobs.threatscore_utils.FindingOutput.transform_api_finding"
        ) as mock_transform:
            mock_finding_output = MagicMock()
            mock_finding_output.check_id = "check_group"
@@ -369,7 +369,7 @@ class TestLoadFindingsForChecks:
        mock_provider = MagicMock()

        with patch(
            "tasks.jobs.report.FindingOutput.transform_api_finding"
            "tasks.jobs.threatscore_utils.FindingOutput.transform_api_finding"
        ) as mock_transform:
            mock_finding_output = MagicMock()
            mock_finding_output.check_id = "check_transform"
@@ -406,7 +406,7 @@ class TestLoadFindingsForChecks:
        mock_provider = MagicMock()

        with patch(
            "tasks.jobs.report.FindingOutput.transform_api_finding"
            "tasks.jobs.threatscore_utils.FindingOutput.transform_api_finding"
        ) as mock_transform:
            mock_finding_output = MagicMock()
            mock_finding_output.check_id = "check_batch"
@@ -760,7 +760,7 @@ class TestGenerateThreatscoreReportFunction:
@patch("tasks.jobs.report.initialize_prowler_provider")
@patch("tasks.jobs.report.Provider.objects.get")
@patch("tasks.jobs.report.Compliance.get_bulk")
@patch("tasks.jobs.report.Finding.all_objects.filter")
@patch("tasks.jobs.threatscore_utils.Finding.all_objects.filter")
def test_generate_threatscore_report_exception_handling(
self,
mock_finding_filter,
@@ -1408,10 +1408,10 @@ class TestGenerateComplianceReportsOptimized:
        }

        with (
            patch("tasks.jobs.report.Finding") as mock_finding_class,
            patch("tasks.jobs.report.FindingOutput") as mock_finding_output,
            patch("tasks.jobs.report.rls_transaction"),
            patch("tasks.jobs.report.batched") as mock_batched,
            patch("tasks.jobs.threatscore_utils.Finding") as mock_finding_class,
            patch("tasks.jobs.threatscore_utils.FindingOutput") as mock_finding_output,
            patch("tasks.jobs.threatscore_utils.rls_transaction"),
            patch("tasks.jobs.threatscore_utils.batched") as mock_batched,
        ):
            # Setup mocks
            mock_finding_class.all_objects.filter.return_value.order_by.return_value.iterator.return_value = [
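
The common thread in these test changes: unittest.mock.patch replaces a name in the module where it is looked up, not where it is defined. Since the helpers moved into tasks.jobs.threatscore_utils, that module's globals now hold the references resolved at call time, so the patch targets must follow. A minimal sketch (the mocked return value is illustrative):

    from unittest.mock import MagicMock, patch

    with patch(
        "tasks.jobs.threatscore_utils.FindingOutput.transform_api_finding"
    ) as mock_transform:
        mock_transform.return_value = MagicMock(check_id="check_example")
        # Code inside threatscore_utils now sees the mock; patching the old
        # tasks.jobs.report path would leave the real function in place.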