From 8cde5a16363cf6b8dd4c9b15e6917c1881afc950 Mon Sep 17 00:00:00 2001 From: pedrooot Date: Thu, 18 Dec 2025 12:33:24 +0100 Subject: [PATCH] chore(revision): resolve comments --- .../backend/tasks/jobs/reports/__init__.py | 2 + api/src/backend/tasks/jobs/reports/base.py | 84 ++++++++++++++----- api/src/backend/tasks/jobs/reports/ens.py | 64 +++++++------- api/src/backend/tasks/jobs/reports/nis2.py | 24 +++--- .../backend/tasks/jobs/reports/threatscore.py | 31 ++++--- .../backend/tasks/tests/test_reports_base.py | 52 +----------- .../tasks/tests/test_reports_threatscore.py | 11 +-- 7 files changed, 126 insertions(+), 142 deletions(-) diff --git a/api/src/backend/tasks/jobs/reports/__init__.py b/api/src/backend/tasks/jobs/reports/__init__.py index cfa7b3a75f..60602b93ab 100644 --- a/api/src/backend/tasks/jobs/reports/__init__.py +++ b/api/src/backend/tasks/jobs/reports/__init__.py @@ -4,6 +4,7 @@ from .base import ( ComplianceData, RequirementData, create_pdf_styles, + get_requirement_metadata, ) # Chart functions @@ -99,6 +100,7 @@ __all__ = [ "ComplianceData", "RequirementData", "create_pdf_styles", + "get_requirement_metadata", # Framework-specific generators "ThreatScoreReportGenerator", "ENSReportGenerator", diff --git a/api/src/backend/tasks/jobs/reports/base.py b/api/src/backend/tasks/jobs/reports/base.py index 9ec35a10a7..91c4656d1f 100644 --- a/api/src/backend/tasks/jobs/reports/base.py +++ b/api/src/backend/tasks/jobs/reports/base.py @@ -13,13 +13,25 @@ from reportlab.pdfbase import pdfmetrics from reportlab.pdfbase.ttfonts import TTFont from reportlab.pdfgen import canvas from reportlab.platypus import Image, PageBreak, Paragraph, SimpleDocTemplate, Spacer +from tasks.jobs.threatscore_utils import ( + _aggregate_requirement_statistics_from_database, + _calculate_requirements_data_from_statistics, + _load_findings_for_requirement_checks, +) from api.db_router import READ_REPLICA_ALIAS from api.db_utils import rls_transaction from api.models import Provider, StatusChoices +from api.utils import initialize_prowler_provider from prowler.lib.check.compliance_models import Compliance from prowler.lib.outputs.finding import Finding as FindingOutput +from .components import ( + ColumnConfig, + create_data_table, + create_info_table, + create_status_badge, +) from .config import ( COLOR_BG_BLUE, COLOR_BG_LIGHT_BLUE, @@ -37,13 +49,17 @@ from .config import ( logger = get_task_logger(__name__) # Register fonts (done once at module load) -_FONTS_REGISTERED = False +_fonts_registered: bool = False def _register_fonts() -> None: - """Register custom fonts for PDF generation.""" - global _FONTS_REGISTERED - if _FONTS_REGISTERED: + """Register custom fonts for PDF generation. + + Uses a module-level flag to ensure fonts are only registered once, + avoiding duplicate registration errors from reportlab. + """ + global _fonts_registered + if _fonts_registered: return fonts_dir = os.path.join(os.path.dirname(__file__), "../../assets/fonts") @@ -62,7 +78,7 @@ def _register_fonts() -> None: ) ) - _FONTS_REGISTERED = True + _fonts_registered = True # ============================================================================= @@ -133,6 +149,35 @@ class ComplianceData: prowler_provider: Any = None +def get_requirement_metadata( + requirement_id: str, + attributes_by_requirement_id: dict[str, dict], +) -> Any | None: + """Get the first requirement metadata object from attributes. + + This helper function extracts the requirement metadata (req_attributes) + from the attributes dictionary. It's a common pattern used across all + report generators. + + Args: + requirement_id: The requirement ID to look up. + attributes_by_requirement_id: Mapping of requirement IDs to their attributes. + + Returns: + The first requirement attribute object, or None if not found. + + Example: + >>> meta = get_requirement_metadata(req.id, data.attributes_by_requirement_id) + >>> if meta: + ... section = getattr(meta, "Section", "Unknown") + """ + req_attrs = attributes_by_requirement_id.get(requirement_id, {}) + meta_list = req_attrs.get("attributes", {}).get("req_attributes", []) + if meta_list: + return meta_list[0] + return None + + # ============================================================================= # PDF Styles Cache # ============================================================================= @@ -435,8 +480,6 @@ class BaseComplianceReportGenerator(ABC): Returns: List of ReportLab elements """ - from .components import create_info_table - elements = [] # Prowler logo @@ -493,17 +536,24 @@ class BaseComplianceReportGenerator(ABC): Returns: List of ReportLab elements """ - from tasks.jobs.threatscore_utils import _load_findings_for_requirement_checks - - from .components import create_status_badge - elements = [] only_failed = kwargs.get("only_failed", True) + include_manual = kwargs.get("include_manual", False) # Filter requirements if needed requirements = data.requirements if only_failed: - requirements = [r for r in requirements if r.status == StatusChoices.FAIL] + # Include FAIL requirements, and optionally MANUAL if include_manual is True + if include_manual: + requirements = [ + r + for r in requirements + if r.status in (StatusChoices.FAIL, StatusChoices.MANUAL) + ] + else: + requirements = [ + r for r in requirements if r.status == StatusChoices.FAIL + ] # Collect all check IDs for requirements that will be displayed # This allows us to load only the findings we actually need (memory optimization) @@ -602,13 +652,6 @@ class BaseComplianceReportGenerator(ABC): Returns: Aggregated ComplianceData object """ - from tasks.jobs.threatscore_utils import ( - _aggregate_requirement_statistics_from_database, - _calculate_requirements_data_from_statistics, - ) - - from api.utils import initialize_prowler_provider - with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS): # Load provider if provider_obj is None: @@ -672,7 +715,7 @@ class BaseComplianceReportGenerator(ABC): description=description, requirements=requirements, attributes_by_requirement_id=attributes_by_requirement_id, - findings_by_check_id=findings_cache or {}, + findings_by_check_id=findings_cache if findings_cache is not None else {}, provider_obj=provider_obj, prowler_provider=prowler_provider, ) @@ -744,7 +787,6 @@ class BaseComplianceReportGenerator(ABC): Returns: ReportLab Table element """ - from .components import ColumnConfig, create_data_table def get_finding_title(f): metadata = getattr(f, "metadata", None) diff --git a/api/src/backend/tasks/jobs/reports/ens.py b/api/src/backend/tasks/jobs/reports/ens.py index bdb59db2cb..46d2793c14 100644 --- a/api/src/backend/tasks/jobs/reports/ens.py +++ b/api/src/backend/tasks/jobs/reports/ens.py @@ -8,7 +8,11 @@ from reportlab.platypus import Image, PageBreak, Paragraph, Spacer, Table, Table from api.models import StatusChoices -from .base import BaseComplianceReportGenerator, ComplianceData +from .base import ( + BaseComplianceReportGenerator, + ComplianceData, + get_requirement_metadata, +) from .charts import create_horizontal_bar_chart, create_radar_chart from .components import get_color_for_compliance from .config import ( @@ -330,10 +334,8 @@ class ENSReportGenerator(BaseComplianceReportGenerator): if req.status == StatusChoices.MANUAL: continue - req_attrs = data.attributes_by_requirement_id.get(req.id, {}) - meta = req_attrs.get("attributes", {}).get("req_attributes", [{}]) - if meta: - m = meta[0] + m = get_requirement_metadata(req.id, data.attributes_by_requirement_id) + if m: marco = getattr(m, "Marco", "Otros") categoria = getattr(m, "Categoria", "Sin categoría") descripcion = getattr(m, "DescripcionControl", req.description) @@ -442,10 +444,8 @@ class ENSReportGenerator(BaseComplianceReportGenerator): if req.status == StatusChoices.MANUAL: continue - req_attrs = data.attributes_by_requirement_id.get(req.id, {}) - meta = req_attrs.get("attributes", {}).get("req_attributes", [{}]) - if meta: - m = meta[0] + m = get_requirement_metadata(req.id, data.attributes_by_requirement_id) + if m: nivel = getattr(m, "Nivel", "").lower() nivel_data[nivel]["total"] += 1 if req.status == StatusChoices.PASS: @@ -520,10 +520,8 @@ class ENSReportGenerator(BaseComplianceReportGenerator): if req.status == StatusChoices.MANUAL: continue - req_attrs = data.attributes_by_requirement_id.get(req.id, {}) - meta = req_attrs.get("attributes", {}).get("req_attributes", [{}]) - if meta: - m = meta[0] + m = get_requirement_metadata(req.id, data.attributes_by_requirement_id) + if m: marco = getattr(m, "Marco", "otros") categoria = getattr(m, "Categoria", "sin categoría") # Combined key: "marco - categoría" @@ -554,10 +552,8 @@ class ENSReportGenerator(BaseComplianceReportGenerator): if req.status == StatusChoices.MANUAL: continue - req_attrs = data.attributes_by_requirement_id.get(req.id, {}) - meta = req_attrs.get("attributes", {}).get("req_attributes", [{}]) - if meta: - m = meta[0] + m = get_requirement_metadata(req.id, data.attributes_by_requirement_id) + if m: dimensiones = getattr(m, "Dimensiones", []) if isinstance(dimensiones, str): dimensiones = [d.strip().lower() for d in dimensiones.split(",")] @@ -600,10 +596,8 @@ class ENSReportGenerator(BaseComplianceReportGenerator): if req.status == StatusChoices.MANUAL: continue - req_attrs = data.attributes_by_requirement_id.get(req.id, {}) - meta = req_attrs.get("attributes", {}).get("req_attributes", [{}]) - if meta: - m = meta[0] + m = get_requirement_metadata(req.id, data.attributes_by_requirement_id) + if m: tipo = getattr(m, "Tipo", "").lower() tipo_data[tipo]["total"] += 1 if req.status == StatusChoices.PASS: @@ -661,10 +655,8 @@ class ENSReportGenerator(BaseComplianceReportGenerator): if req.status != StatusChoices.FAIL: continue - req_attrs = data.attributes_by_requirement_id.get(req.id, {}) - meta = req_attrs.get("attributes", {}).get("req_attributes", [{}]) - if meta: - m = meta[0] + m = get_requirement_metadata(req.id, data.attributes_by_requirement_id) + if m: nivel = getattr(m, "Nivel", "").lower() if nivel == "alto": critical_failed.append( @@ -766,14 +758,22 @@ class ENSReportGenerator(BaseComplianceReportGenerator): List of ReportLab elements. """ elements = [] + include_manual = kwargs.get("include_manual", True) elements.append(Paragraph("Detalle de Requisitos", self.styles["h1"])) elements.append(Spacer(1, 0.2 * inch)) - # Get failed requirements (non-manual) - failed_requirements = [ - r for r in data.requirements if r.status == StatusChoices.FAIL - ] + # Get failed requirements, and optionally manual requirements + if include_manual: + failed_requirements = [ + r + for r in data.requirements + if r.status in (StatusChoices.FAIL, StatusChoices.MANUAL) + ] + else: + failed_requirements = [ + r for r in data.requirements if r.status == StatusChoices.FAIL + ] if not failed_requirements: elements.append( @@ -802,13 +802,11 @@ class ENSReportGenerator(BaseComplianceReportGenerator): } for req in failed_requirements: - req_attrs = data.attributes_by_requirement_id.get(req.id, {}) - meta = req_attrs.get("attributes", {}).get("req_attributes", [{}]) + m = get_requirement_metadata(req.id, data.attributes_by_requirement_id) - if not meta: + if not m: continue - m = meta[0] nivel = getattr(m, "Nivel", "").lower() tipo = getattr(m, "Tipo", "") modo = getattr(m, "ModoEjecucion", "") diff --git a/api/src/backend/tasks/jobs/reports/nis2.py b/api/src/backend/tasks/jobs/reports/nis2.py index e4219ee5d2..1f3d20336e 100644 --- a/api/src/backend/tasks/jobs/reports/nis2.py +++ b/api/src/backend/tasks/jobs/reports/nis2.py @@ -6,7 +6,11 @@ from reportlab.platypus import Image, PageBreak, Paragraph, Spacer, Table, Table from api.models import StatusChoices -from .base import BaseComplianceReportGenerator, ComplianceData +from .base import ( + BaseComplianceReportGenerator, + ComplianceData, + get_requirement_metadata, +) from .charts import create_horizontal_bar_chart, get_chart_color_for_percentage from .config import ( COLOR_BORDER_GRAY, @@ -263,10 +267,8 @@ class NIS2ReportGenerator(BaseComplianceReportGenerator): # Organize by section number and subsection sections = {} for req in data.requirements: - req_attrs = data.attributes_by_requirement_id.get(req.id, {}) - meta = req_attrs.get("attributes", {}).get("req_attributes", [{}]) - if meta: - m = meta[0] + m = get_requirement_metadata(req.id, data.attributes_by_requirement_id) + if m: full_section = getattr(m, "Section", "Other") # Extract section number from full title (e.g., "1 POLICY..." -> "1") section_num = _extract_section_number(full_section) @@ -343,10 +345,8 @@ class NIS2ReportGenerator(BaseComplianceReportGenerator): if req.status == StatusChoices.MANUAL: continue - req_attrs = data.attributes_by_requirement_id.get(req.id, {}) - meta = req_attrs.get("attributes", {}).get("req_attributes", [{}]) - if meta: - m = meta[0] + m = get_requirement_metadata(req.id, data.attributes_by_requirement_id) + if m: full_section = getattr(m, "Section", "Other") # Extract section number from full title (e.g., "1 POLICY..." -> "1") section_num = _extract_section_number(full_section) @@ -385,10 +385,8 @@ class NIS2ReportGenerator(BaseComplianceReportGenerator): subsection_scores = defaultdict(lambda: {"passed": 0, "failed": 0, "manual": 0}) for req in data.requirements: - req_attrs = data.attributes_by_requirement_id.get(req.id, {}) - meta = req_attrs.get("attributes", {}).get("req_attributes", [{}]) - if meta: - m = meta[0] + m = get_requirement_metadata(req.id, data.attributes_by_requirement_id) + if m: full_section = getattr(m, "Section", "") subsection = getattr(m, "SubSection", "") # Use section number + subsection for grouping diff --git a/api/src/backend/tasks/jobs/reports/threatscore.py b/api/src/backend/tasks/jobs/reports/threatscore.py index 9e2bae8883..76e9505ad0 100644 --- a/api/src/backend/tasks/jobs/reports/threatscore.py +++ b/api/src/backend/tasks/jobs/reports/threatscore.py @@ -4,7 +4,11 @@ from reportlab.platypus import Image, PageBreak, Paragraph, Spacer, Table, Table from api.models import StatusChoices -from .base import BaseComplianceReportGenerator, ComplianceData +from .base import ( + BaseComplianceReportGenerator, + ComplianceData, + get_requirement_metadata, +) from .charts import create_vertical_bar_chart, get_chart_color_for_percentage from .components import get_color_for_compliance, get_color_for_weight from .config import COLOR_HIGH_RISK, COLOR_WHITE @@ -145,10 +149,9 @@ class ThreatScoreReportGenerator(BaseComplianceReportGenerator): # Organize requirements by section and subsection sections = {} - for req_id, req_attrs in data.attributes_by_requirement_id.items(): - meta = req_attrs.get("attributes", {}).get("req_attributes", [{}]) - if meta: - m = meta[0] + for req_id in data.attributes_by_requirement_id: + m = get_requirement_metadata(req_id, data.attributes_by_requirement_id) + if m: section = getattr(m, "Section", "N/A") subsection = getattr(m, "SubSection", "N/A") title = getattr(m, "Title", "N/A") @@ -202,10 +205,8 @@ class ThreatScoreReportGenerator(BaseComplianceReportGenerator): sections_data = {} for req in data.requirements: - req_attrs = data.attributes_by_requirement_id.get(req.id, {}) - meta = req_attrs.get("attributes", {}).get("req_attributes", [{}]) - if meta: - m = meta[0] + m = get_requirement_metadata(req.id, data.attributes_by_requirement_id) + if m: section = getattr(m, "Section", "Other") all_sections.add(section) @@ -285,11 +286,9 @@ class ThreatScoreReportGenerator(BaseComplianceReportGenerator): continue has_findings = True - req_attrs = data.attributes_by_requirement_id.get(req.id, {}) - meta = req_attrs.get("attributes", {}).get("req_attributes", [{}]) + m = get_requirement_metadata(req.id, data.attributes_by_requirement_id) - if meta: - m = meta[0] + if m: risk_level_raw = getattr(m, "LevelOfRisk", 0) weight_raw = getattr(m, "Weight", 0) # Ensure numeric types for calculations (compliance data may have str) @@ -333,11 +332,9 @@ class ThreatScoreReportGenerator(BaseComplianceReportGenerator): if req.status != StatusChoices.FAIL: continue - req_attrs = data.attributes_by_requirement_id.get(req.id, {}) - meta = req_attrs.get("attributes", {}).get("req_attributes", [{}]) + m = get_requirement_metadata(req.id, data.attributes_by_requirement_id) - if meta: - m = meta[0] + if m: risk_level_raw = getattr(m, "LevelOfRisk", 0) weight_raw = getattr(m, "Weight", 0) # Ensure numeric types for calculations (compliance data may have str) diff --git a/api/src/backend/tasks/tests/test_reports_base.py b/api/src/backend/tasks/tests/test_reports_base.py index 0f48724e93..d6231c90cf 100644 --- a/api/src/backend/tasks/tests/test_reports_base.py +++ b/api/src/backend/tasks/tests/test_reports_base.py @@ -2,6 +2,7 @@ import io import pytest from reportlab.lib.units import inch +from reportlab.platypus import Image, LongTable, Paragraph, Spacer, Table from tasks.jobs.reports import ( # Configuration; Colors; Components; Charts; Base CHART_COLOR_GREEN_1, CHART_COLOR_GREEN_2, @@ -9,7 +10,10 @@ from tasks.jobs.reports import ( # Configuration; Colors; Components; Charts; B CHART_COLOR_RED, CHART_COLOR_YELLOW, COLOR_BLUE, + COLOR_DARK_GRAY, COLOR_HIGH_RISK, + COLOR_LOW_RISK, + COLOR_MEDIUM_RISK, COLOR_SAFE, FRAMEWORK_REGISTRY, BaseComplianceReportGenerator, @@ -155,14 +159,10 @@ class TestColorHelpers: def test_get_color_for_risk_level_medium(self): """Test medium risk level returns orange.""" - from tasks.jobs.reports import COLOR_MEDIUM_RISK - assert get_color_for_risk_level(3) == COLOR_MEDIUM_RISK def test_get_color_for_risk_level_low(self): """Test low risk level returns yellow.""" - from tasks.jobs.reports import COLOR_LOW_RISK - assert get_color_for_risk_level(2) == COLOR_LOW_RISK def test_get_color_for_risk_level_safe(self): @@ -181,8 +181,6 @@ class TestColorHelpers: def test_get_color_for_weight_medium(self): """Test medium weight returns yellow.""" - from tasks.jobs.reports import COLOR_LOW_RISK - assert get_color_for_weight(100) == COLOR_LOW_RISK assert get_color_for_weight(51) == COLOR_LOW_RISK @@ -198,8 +196,6 @@ class TestColorHelpers: def test_get_color_for_compliance_medium(self): """Test medium compliance returns yellow.""" - from tasks.jobs.reports import COLOR_LOW_RISK - assert get_color_for_compliance(79) == COLOR_LOW_RISK assert get_color_for_compliance(60) == COLOR_LOW_RISK @@ -220,8 +216,6 @@ class TestColorHelpers: def test_get_status_color_manual(self): """Test MANUAL status returns gray.""" - from tasks.jobs.reports import COLOR_DARK_GRAY - assert get_status_color("MANUAL") == COLOR_DARK_GRAY @@ -235,8 +229,6 @@ class TestChartColorHelpers: def test_chart_color_for_medium_high_percentage(self): """Test medium-high percentage returns light green.""" - from tasks.jobs.reports import CHART_COLOR_GREEN_2 - assert get_chart_color_for_percentage(79) == CHART_COLOR_GREEN_2 assert get_chart_color_for_percentage(60) == CHART_COLOR_GREEN_2 @@ -274,8 +266,6 @@ class TestBadgeComponents: def test_create_badge_returns_table(self): """Test create_badge returns a Table object.""" - from reportlab.platypus import Table - badge = create_badge("Test", COLOR_BLUE) assert isinstance(badge, Table) @@ -286,8 +276,6 @@ class TestBadgeComponents: def test_create_status_badge_pass(self): """Test status badge for PASS.""" - from reportlab.platypus import Table - badge = create_status_badge("PASS") assert isinstance(badge, Table) @@ -298,8 +286,6 @@ class TestBadgeComponents: def test_create_multi_badge_row_with_badges(self): """Test multi-badge row with data.""" - from reportlab.platypus import Table - badges = [ ("A", COLOR_BLUE), ("B", COLOR_SAFE), @@ -318,8 +304,6 @@ class TestRiskComponent: def test_create_risk_component_returns_table(self): """Test risk component returns a Table.""" - from reportlab.platypus import Table - component = create_risk_component(risk_level=4, weight=100, score=50) assert isinstance(component, Table) @@ -339,8 +323,6 @@ class TestTableComponents: def test_create_info_table(self): """Test info table creation.""" - from reportlab.platypus import Table - rows = [ ("Label 1:", "Value 1"), ("Label 2:", "Value 2"), @@ -356,8 +338,6 @@ class TestTableComponents: def test_create_data_table(self): """Test data table creation.""" - from reportlab.platypus import Table - data = [ {"name": "Item 1", "value": "100"}, {"name": "Item 2", "value": "200"}, @@ -380,8 +360,6 @@ class TestTableComponents: def test_create_summary_table(self): """Test summary table creation.""" - from reportlab.platypus import Table - table = create_summary_table( label="Score:", value="85%", @@ -391,8 +369,6 @@ class TestTableComponents: def test_create_summary_table_with_custom_widths(self): """Test summary table with custom widths.""" - from reportlab.platypus import Table - table = create_summary_table( label="ThreatScore:", value="92.5%", @@ -408,8 +384,6 @@ class TestFindingsTable: def test_create_findings_table_with_dicts(self): """Test findings table creation with dict data.""" - from reportlab.platypus import Table - findings = [ { "title": "Finding 1", @@ -450,8 +424,6 @@ class TestSectionHeader: def test_create_section_header_with_spacer(self): """Test section header with spacer.""" - from reportlab.platypus import Paragraph, Spacer - styles = create_pdf_styles() elements = create_section_header("Test Header", styles["h1"]) @@ -461,8 +433,6 @@ class TestSectionHeader: def test_create_section_header_without_spacer(self): """Test section header without spacer.""" - from reportlab.platypus import Paragraph - styles = create_pdf_styles() elements = create_section_header("Test Header", styles["h1"], add_spacer=False) @@ -864,8 +834,6 @@ class TestExampleReportGenerator: """Example concrete implementation for testing.""" def create_executive_summary(self, data): - from reportlab.platypus import Paragraph - return [ Paragraph("Executive Summary", self.styles["h1"]), Paragraph( @@ -875,8 +843,6 @@ class TestExampleReportGenerator: ] def create_charts_section(self, data): - from reportlab.platypus import Image - chart_buffer = create_vertical_bar_chart( labels=["Pass", "Fail"], values=[80, 20], @@ -884,8 +850,6 @@ class TestExampleReportGenerator: return [Image(chart_buffer, width=6 * inch, height=4 * inch)] def create_requirements_index(self, data): - from reportlab.platypus import Paragraph - elements = [Paragraph("Requirements Index", self.styles["h1"])] for req in data.requirements: elements.append( @@ -1063,8 +1027,6 @@ class TestComponentEdgeCases: def test_create_info_table_empty(self): """Test info table with empty rows.""" - from reportlab.platypus import Table - table = create_info_table([]) assert isinstance(table, Table) @@ -1092,8 +1054,6 @@ class TestComponentEdgeCases: columns = [ColumnConfig("Name", 2 * inch, "name")] table = create_data_table(data, columns) # Should be a LongTable for large datasets - from reportlab.platypus import LongTable - assert isinstance(table, LongTable) def test_create_risk_component_zero_values(self): @@ -1116,8 +1076,6 @@ class TestColorEdgeCases: def test_get_color_for_compliance_boundary_60(self): """Test compliance color at exactly 60%.""" - from tasks.jobs.reports import COLOR_LOW_RISK - assert get_color_for_compliance(60) == COLOR_LOW_RISK def test_get_color_for_compliance_over_100(self): @@ -1126,8 +1084,6 @@ class TestColorEdgeCases: def test_get_color_for_weight_boundary_100(self): """Test weight color at exactly 100.""" - from tasks.jobs.reports import COLOR_LOW_RISK - assert get_color_for_weight(100) == COLOR_LOW_RISK def test_get_color_for_weight_boundary_50(self): diff --git a/api/src/backend/tasks/tests/test_reports_threatscore.py b/api/src/backend/tasks/tests/test_reports_threatscore.py index 60eddc7548..c79c0b16e9 100644 --- a/api/src/backend/tasks/tests/test_reports_threatscore.py +++ b/api/src/backend/tasks/tests/test_reports_threatscore.py @@ -10,16 +10,7 @@ from tasks.jobs.reports import ( ThreatScoreReportGenerator, ) - -# Use string status values directly to avoid Django DB initialization -# These match api.models.StatusChoices values -class StatusChoices: - """Mock StatusChoices to avoid Django DB initialization.""" - - PASS = "PASS" - FAIL = "FAIL" - MANUAL = "MANUAL" - +from api.models import StatusChoices # ============================================================================= # Fixtures