Files
prowler/api/src/backend/tasks/tests/test_reports.py

444 lines
15 KiB
Python

import uuid
from unittest.mock import Mock, patch
import matplotlib
import pytest
from reportlab.lib import colors
from tasks.jobs.report import generate_compliance_reports, generate_threatscore_report
from tasks.jobs.reports import (
CHART_COLOR_GREEN_1,
CHART_COLOR_GREEN_2,
CHART_COLOR_ORANGE,
CHART_COLOR_RED,
CHART_COLOR_YELLOW,
COLOR_BLUE,
COLOR_ENS_ALTO,
COLOR_HIGH_RISK,
COLOR_LOW_RISK,
COLOR_MEDIUM_RISK,
COLOR_NIS2_PRIMARY,
COLOR_SAFE,
create_pdf_styles,
get_chart_color_for_percentage,
get_color_for_compliance,
get_color_for_risk_level,
get_color_for_weight,
)
from tasks.jobs.threatscore_utils import (
_aggregate_requirement_statistics_from_database,
_load_findings_for_requirement_checks,
)
from api.models import Finding, Resource, ResourceFindingMapping, StatusChoices
from prowler.lib.check.models import Severity
matplotlib.use("Agg")  # Select the non-interactive "Agg" backend at import time for every test in this module
@pytest.mark.django_db
class TestAggregateRequirementStatistics:
    """Database-backed tests for ``_aggregate_requirement_statistics_from_database``."""

    @staticmethod
    def _aggregate(tenant, scan):
        """Call the aggregation under test with stringified tenant and scan ids."""
        return _aggregate_requirement_statistics_from_database(
            str(tenant.id), str(scan.id)
        )

    def _create_finding_with_resource(
        self, tenant, scan, uid, check_id, status, severity=Severity.high
    ):
        """Persist a finding, one resource, and the mapping that links them.

        The resource uid and name are both ``resource-<uid>``. ``severity`` is
        stored as the finding's severity and as its impact.

        Returns:
            The newly created ``Finding``.
        """
        resource_label = f"resource-{uid}"
        new_finding = Finding.objects.create(
            tenant_id=tenant.id,
            scan=scan,
            uid=uid,
            check_id=check_id,
            status=status,
            severity=severity,
            impact=severity,
            check_metadata={},
            raw_result={},
        )
        linked_resource = Resource.objects.create(
            tenant_id=tenant.id,
            provider=scan.provider,
            uid=resource_label,
            name=resource_label,
            region="us-east-1",
            service="test",
            type="test::resource",
        )
        ResourceFindingMapping.objects.create(
            tenant_id=tenant.id,
            finding=new_finding,
            resource=linked_resource,
        )
        return new_finding

    def test_aggregates_findings_correctly(self, tenants_fixture, scans_fixture):
        """Pass/total counts are reported separately for each check id."""
        tenant, scan = tenants_fixture[0], scans_fixture[0]
        finding_specs = (
            ("finding-1", "check_1", StatusChoices.PASS, Severity.high),
            ("finding-2", "check_1", StatusChoices.FAIL, Severity.high),
            ("finding-3", "check_2", StatusChoices.PASS, Severity.medium),
        )
        for uid, check_id, status, severity in finding_specs:
            self._create_finding_with_resource(
                tenant, scan, uid, check_id, status, severity
            )

        stats = self._aggregate(tenant, scan)

        assert "check_1" in stats
        assert stats["check_1"]["passed"] == 1
        assert stats["check_1"]["total"] == 2
        assert "check_2" in stats
        assert stats["check_2"]["passed"] == 1
        assert stats["check_2"]["total"] == 1

    def test_handles_empty_scan(self, tenants_fixture, scans_fixture):
        """A scan without any findings aggregates to an empty dict."""
        stats = self._aggregate(tenants_fixture[0], scans_fixture[0])

        assert stats == {}

    def test_only_failed_findings(self, tenants_fixture, scans_fixture):
        """Two FAIL findings on one check give passed == 0 and total == 2."""
        tenant, scan = tenants_fixture[0], scans_fixture[0]
        for uid in ("finding-1", "finding-2"):
            self._create_finding_with_resource(
                tenant, scan, uid, "check_1", StatusChoices.FAIL
            )

        stats = self._aggregate(tenant, scan)

        assert stats["check_1"]["passed"] == 0
        assert stats["check_1"]["total"] == 2

    def test_multiple_findings_same_check(self, tenants_fixture, scans_fixture):
        """Five alternating PASS/FAIL findings on one check give 3 passed of 5."""
        tenant, scan = tenants_fixture[0], scans_fixture[0]
        # Even positions PASS, odd positions FAIL -> three PASS and two FAIL.
        alternating_statuses = [
            StatusChoices.FAIL if position % 2 else StatusChoices.PASS
            for position in range(5)
        ]
        for position, status in enumerate(alternating_statuses):
            self._create_finding_with_resource(
                tenant,
                scan,
                f"finding-{position}",
                "check_1",
                status,
            )

        stats = self._aggregate(tenant, scan)

        assert stats["check_1"]["passed"] == 3
        assert stats["check_1"]["total"] == 5

    def test_mixed_statuses(self, tenants_fixture, scans_fixture):
        """A MANUAL finding contributes to neither ``passed`` nor ``total``."""
        tenant, scan = tenants_fixture[0], scans_fixture[0]
        for uid, status in (
            ("finding-1", StatusChoices.PASS),
            ("finding-2", StatusChoices.MANUAL),
        ):
            self._create_finding_with_resource(tenant, scan, uid, "check_1", status)

        stats = self._aggregate(tenant, scan)

        # The expected counts reflect only the PASS finding: the MANUAL one is
        # expected to be left out because only PASS and FAIL are counted.
        assert stats["check_1"]["passed"] == 1
        assert stats["check_1"]["total"] == 1

    def test_skips_aggregation_for_deleted_provider(
        self, tenants_fixture, scans_fixture
    ):
        """Nothing is aggregated once the scan's provider is flagged as deleted."""
        tenant, scan = tenants_fixture[0], scans_fixture[0]
        self._create_finding_with_resource(
            tenant, scan, "finding-1", "check_1", StatusChoices.PASS
        )

        # Flip the soft-delete flag and persist only that column.
        soft_deleted_provider = scan.provider
        soft_deleted_provider.is_deleted = True
        soft_deleted_provider.save(update_fields=["is_deleted"])

        stats = self._aggregate(tenant, scan)

        assert stats == {}

    def test_multiple_resources_no_double_count(self, tenants_fixture, scans_fixture):
        """One finding mapped to two resources still counts exactly once."""
        tenant, scan = tenants_fixture[0], scans_fixture[0]
        shared_finding = Finding.objects.create(
            tenant_id=tenant.id,
            scan=scan,
            uid="finding-1",
            check_id="check_1",
            status=StatusChoices.PASS,
            severity=Severity.high,
            impact=Severity.high,
            check_metadata={},
            raw_result={},
        )

        # Attach two distinct resources to the single finding created above.
        for index in range(2):
            resource_label = f"resource-{index}"
            attached_resource = Resource.objects.create(
                tenant_id=tenant.id,
                provider=scan.provider,
                uid=resource_label,
                name=resource_label,
                region="us-east-1",
                service="test",
                type="test::resource",
            )
            ResourceFindingMapping.objects.create(
                tenant_id=tenant.id,
                finding=shared_finding,
                resource=attached_resource,
            )

        stats = self._aggregate(tenant, scan)

        assert stats["check_1"]["passed"] == 1
        assert stats["check_1"]["total"] == 1
class TestColorHelperFunctions:
    """Checks the colour returned by each colour-selection helper for sample inputs."""

    def test_get_color_for_risk_level_high(self):
        """Risk level 5 maps to ``COLOR_HIGH_RISK``."""
        assert get_color_for_risk_level(5) == COLOR_HIGH_RISK

    def test_get_color_for_risk_level_medium_high(self):
        """Risk level 4 maps to ``COLOR_HIGH_RISK``."""
        # >= 4 is high risk
        assert get_color_for_risk_level(4) == COLOR_HIGH_RISK

    def test_get_color_for_risk_level_medium(self):
        """Risk level 3 maps to ``COLOR_MEDIUM_RISK``."""
        # >= 3 is medium risk
        assert get_color_for_risk_level(3) == COLOR_MEDIUM_RISK

    def test_get_color_for_risk_level_low(self):
        """Risk level 1 maps to ``COLOR_SAFE``."""
        # < 2 is safe
        assert get_color_for_risk_level(1) == COLOR_SAFE

    def test_get_color_for_weight_high(self):
        """Weight 150 maps to ``COLOR_HIGH_RISK``."""
        # > 100 is high risk
        assert get_color_for_weight(150) == COLOR_HIGH_RISK

    def test_get_color_for_weight_medium(self):
        """Weight 100 maps to ``COLOR_LOW_RISK``."""
        # 51-100 is low risk
        assert get_color_for_weight(100) == COLOR_LOW_RISK

    def test_get_color_for_weight_low(self):
        """Weight 50 maps to ``COLOR_SAFE``."""
        # <= 50 is safe
        assert get_color_for_weight(50) == COLOR_SAFE

    def test_get_color_for_compliance_high(self):
        """Compliance of 85 maps to ``COLOR_SAFE``."""
        assert get_color_for_compliance(85) == COLOR_SAFE

    def test_get_color_for_compliance_medium(self):
        """Compliance of 70 maps to ``COLOR_LOW_RISK``."""
        assert get_color_for_compliance(70) == COLOR_LOW_RISK

    def test_get_color_for_compliance_low(self):
        """Compliance of 50 maps to ``COLOR_HIGH_RISK``."""
        assert get_color_for_compliance(50) == COLOR_HIGH_RISK

    def test_get_chart_color_for_percentage_excellent(self):
        """A percentage of 90 maps to ``CHART_COLOR_GREEN_1``."""
        assert get_chart_color_for_percentage(90) == CHART_COLOR_GREEN_1

    def test_get_chart_color_for_percentage_good(self):
        """A percentage of 70 maps to ``CHART_COLOR_GREEN_2``."""
        assert get_chart_color_for_percentage(70) == CHART_COLOR_GREEN_2

    def test_get_chart_color_for_percentage_fair(self):
        """A percentage of 50 maps to ``CHART_COLOR_YELLOW``."""
        assert get_chart_color_for_percentage(50) == CHART_COLOR_YELLOW

    def test_get_chart_color_for_percentage_poor(self):
        """A percentage of 30 maps to ``CHART_COLOR_ORANGE``."""
        assert get_chart_color_for_percentage(30) == CHART_COLOR_ORANGE

    def test_get_chart_color_for_percentage_critical(self):
        """A percentage of 10 maps to ``CHART_COLOR_RED``."""
        assert get_chart_color_for_percentage(10) == CHART_COLOR_RED
class TestPDFStylesCreation:
    """Tests for the style mapping returned by ``create_pdf_styles``."""

    def test_create_pdf_styles_returns_dict(self):
        """The return value is a ``dict`` instance."""
        assert isinstance(create_pdf_styles(), dict)

    def test_create_pdf_styles_caches_result(self):
        """Two consecutive calls return the very same object."""
        first_call = create_pdf_styles()
        second_call = create_pdf_styles()
        assert first_call is second_call

    def test_pdf_styles_have_correct_keys(self):
        """Every expected style name is present in the returned mapping."""
        styles = create_pdf_styles()
        required_names = ("title", "h1", "h2", "h3", "normal", "normal_center")
        absent = [name for name in required_names if name not in styles]
        assert not absent
@pytest.mark.django_db
class TestLoadFindingsForChecks:
    """Tests for ``_load_findings_for_requirement_checks``."""

    def test_empty_check_ids_returns_empty(self, tenants_fixture, providers_fixture):
        """An empty list of check ids produces an empty dict."""
        tenant_id = str(tenants_fixture[0].id)

        # Stand-in provider object; only ``identity.account`` is configured.
        provider_stub = Mock()
        provider_stub.identity.account = "test-account"

        # A freshly generated UUID is used as the scan id.
        random_scan_id = str(uuid.uuid4())
        no_check_ids = []

        loaded = _load_findings_for_requirement_checks(
            tenant_id, random_scan_id, no_check_ids, provider_stub
        )

        assert loaded == {}
@pytest.mark.django_db
class TestGenerateThreatscoreReportFunction:
    """Test suite for generate_threatscore_report function."""

    @patch("tasks.jobs.reports.base.initialize_prowler_provider")
    def test_generate_threatscore_report_exception_handling(
        self,
        mock_initialize_provider,
        tenants_fixture,
        scans_fixture,
        providers_fixture,
        tmp_path,
    ):
        """An error raised while initialising the provider reaches the caller.

        ``initialize_prowler_provider`` is patched to raise, and the test
        expects ``generate_threatscore_report`` to surface an exception whose
        message contains "Test exception".

        The output path lives under pytest's per-test ``tmp_path`` instead of a
        fixed location in ``/tmp``, so the test is portable and cannot collide
        with other runs or leave stray files behind.
        """
        tenant = tenants_fixture[0]
        scan = scans_fixture[0]
        provider = providers_fixture[0]
        report_path = tmp_path / "test_report.pdf"

        mock_initialize_provider.side_effect = Exception("Test exception")

        # ``match`` does a regex search on str(exception), which is equivalent
        # to the substring check on the message.
        with pytest.raises(Exception, match="Test exception"):
            generate_threatscore_report(
                tenant_id=str(tenant.id),
                scan_id=str(scan.id),
                compliance_id="prowler_threatscore_aws",
                output_path=str(report_path),
                provider_id=str(provider.id),
            )
@pytest.mark.django_db
class TestGenerateComplianceReportsOptimized:
    """Tests for ``generate_compliance_reports``."""

    @patch("tasks.jobs.report._upload_to_s3")
    @patch("tasks.jobs.report.generate_threatscore_report")
    @patch("tasks.jobs.report.generate_ens_report")
    @patch("tasks.jobs.report.generate_nis2_report")
    def test_no_findings_returns_early_for_both_reports(
        self,
        mock_nis2,
        mock_ens,
        mock_threatscore,
        mock_upload,
        tenants_fixture,
        scans_fixture,
        providers_fixture,
    ):
        """With no findings in the scan, no report generator is invoked.

        All three report types are requested; each result entry must carry
        ``upload is False`` and none of the patched generators may be called.
        """
        # Patches are injected bottom-up, hence the nis2 -> ens -> threatscore
        # -> upload order of the mock parameters.
        outcome = generate_compliance_reports(
            tenant_id=str(tenants_fixture[0].id),
            scan_id=str(scans_fixture[0].id),
            provider_id=str(providers_fixture[0].id),
            generate_threatscore=True,
            generate_ens=True,
            generate_nis2=True,
        )

        for report_name in ("threatscore", "ens", "nis2"):
            assert outcome[report_name]["upload"] is False

        for generator_mock in (mock_threatscore, mock_ens, mock_nis2):
            generator_mock.assert_not_called()
class TestOptimizationImprovements:
    """Sanity checks on the types of the shared colour constants."""

    def test_chart_color_constants_are_strings(self):
        """Each chart colour constant starts with ``#``."""
        chart_palette = (
            CHART_COLOR_GREEN_1,
            CHART_COLOR_GREEN_2,
            CHART_COLOR_YELLOW,
            CHART_COLOR_ORANGE,
            CHART_COLOR_RED,
        )
        for chart_color in chart_palette:
            assert chart_color.startswith("#")

    def test_color_constants_are_color_objects(self):
        """Each checked PDF colour constant is a reportlab ``colors.Color``."""
        pdf_palette = (
            COLOR_BLUE,
            COLOR_HIGH_RISK,
            COLOR_SAFE,
            COLOR_ENS_ALTO,
            COLOR_NIS2_PRIMARY,
        )
        for pdf_color in pdf_palette:
            assert isinstance(pdf_color, colors.Color)