From 91e3c01f51757b47bd724d61f2691c126e6a7649 Mon Sep 17 00:00:00 2001 From: Pepe Fagoaga Date: Thu, 22 Jan 2026 18:17:50 +0100 Subject: [PATCH] fix(attack-paths): load findings in batches into Neo4j (#9862) Co-authored-by: Josema Camacho --- .env | 2 +- api/CHANGELOG.md | 6 +- .../tasks/jobs/attack_paths/prowler.py | 201 ++++++++---- .../tasks/tests/test_attack_paths_scan.py | 308 +++++++++++++++++- 4 files changed, 434 insertions(+), 83 deletions(-) diff --git a/.env b/.env index 54e2c3149d..734b9df42a 100644 --- a/.env +++ b/.env @@ -66,7 +66,7 @@ NEO4J_DBMS_SECURITY_PROCEDURES_ALLOWLIST=apoc.* NEO4J_DBMS_SECURITY_PROCEDURES_UNRESTRICTED=apoc.* NEO4J_DBMS_CONNECTOR_BOLT_LISTEN_ADDRESS=0.0.0.0:7687 # Neo4j Prowler settings -NEO4J_INSERT_BATCH_SIZE=500 +ATTACK_PATHS_FINDINGS_BATCH_SIZE=1000 # Celery-Prowler task settings TASK_RETRY_DELAY_SECONDS=0.1 diff --git a/api/CHANGELOG.md b/api/CHANGELOG.md index c37ff6de14..e1fa9942f1 100644 --- a/api/CHANGELOG.md +++ b/api/CHANGELOG.md @@ -4,14 +4,12 @@ All notable changes to the **Prowler API** are documented in this file. ## [1.18.1] (Prowler v5.17.1) -### Changed +### Fixed - Improve API startup process by `manage.py` argument detection [(#9856)](https://github.com/prowler-cloud/prowler/pull/9856) - Deleting providers don't try to delete a `None` Neo4j database when an Attack Paths scan is scheduled [(#9858)](https://github.com/prowler-cloud/prowler/pull/9858) - -### Fixed - - Use replica database for reading Findings to add them to the Attack Paths graph [(#9861)](https://github.com/prowler-cloud/prowler/pull/9861) +- Attack paths findings loading query to use streaming generator for O(batch_size) memory instead of O(total_findings) [(#9862)](https://github.com/prowler-cloud/prowler/pull/9862) ## [1.18.0] (Prowler v5.17.0) diff --git a/api/src/backend/tasks/jobs/attack_paths/prowler.py b/api/src/backend/tasks/jobs/attack_paths/prowler.py index 1cf904fea8..d1d4c6dfdb 100644 --- a/api/src/backend/tasks/jobs/attack_paths/prowler.py +++ b/api/src/backend/tasks/jobs/attack_paths/prowler.py @@ -1,19 +1,22 @@ +from collections import defaultdict +from typing import Generator + import neo4j from cartography.client.core.tx import run_write_query from cartography.config import Config as CartographyConfig from celery.utils.log import get_task_logger - -from api.db_router import MainRouter -from api.db_utils import rls_transaction -from api.models import Provider, ResourceFindingMapping from config.env import env -from prowler.config import config as ProwlerConfig from tasks.jobs.attack_paths.providers import get_node_uid_field, get_root_node_label +from api.db_router import READ_REPLICA_ALIAS +from api.db_utils import rls_transaction +from api.models import Finding, Provider, ResourceFindingMapping +from prowler.config import config as ProwlerConfig + logger = get_task_logger(__name__) -BATCH_SIZE = env.int("NEO4J_INSERT_BATCH_SIZE", 500) +BATCH_SIZE = env.int("ATTACK_PATHS_FINDINGS_BATCH_SIZE", 1000) INDEX_STATEMENTS = [ "CREATE INDEX prowler_finding_id IF NOT EXISTS FOR (n:ProwlerFinding) ON (n.id);", @@ -84,9 +87,8 @@ def create_indexes(neo4j_session: neo4j.Session) -> None: Code based on Cartography version 0.122.0, specifically on `cartography.intel.create_indexes.run`. 
""" - logger.info("Creating indexes for Prowler node types.") + logger.info("Creating indexes for Prowler Findings node types") for statement in INDEX_STATEMENTS: - logger.debug("Executing statement: %s", statement) run_write_query(neo4j_session, statement) @@ -96,77 +98,136 @@ def analysis( scan_id: str, config: CartographyConfig, ) -> None: - logger.info(f"Getting Prowler findings for AWS account {prowler_api_provider.uid}") findings_data = get_provider_last_scan_findings(prowler_api_provider, scan_id) - - logger.info(f"Loading Prowler findings for AWS account {prowler_api_provider.uid}") load_findings(neo4j_session, findings_data, prowler_api_provider, config) - - logger.info( - f"Cleaning up Prowler findings for AWS account {prowler_api_provider.uid}" - ) cleanup_findings(neo4j_session, prowler_api_provider, config) def get_provider_last_scan_findings( prowler_api_provider: Provider, scan_id: str, -) -> list[dict[str, str]]: - with rls_transaction(prowler_api_provider.tenant_id): - resource_finding_qs = ( - ResourceFindingMapping.objects.using(MainRouter.replica_db) - .filter( - finding__scan_id=scan_id, - ) - .values( - "resource__uid", - "finding__id", - "finding__uid", - "finding__inserted_at", - "finding__updated_at", - "finding__first_seen_at", - "finding__scan_id", - "finding__delta", - "finding__status", - "finding__status_extended", - "finding__severity", - "finding__check_id", - "finding__check_metadata__checktitle", - "finding__muted", - "finding__muted_reason", - ) - ) +) -> Generator[list[dict[str, str]], None, None]: + """ + Generator that yields batches of finding-resource pairs. - findings = [] - for resource_finding in resource_finding_qs: - findings.append( + Two-step query approach per batch: + 1. Paginate findings for scan (single table, indexed by scan_id) + 2. Batch-fetch resource UIDs via mapping table (single join) + 3. Merge and yield flat structure for Neo4j + + Memory efficient: never holds more than BATCH_SIZE findings in memory. + """ + + logger.info( + f"Starting findings fetch for scan {scan_id} (tenant {prowler_api_provider.tenant_id}) with batch size {BATCH_SIZE}" + ) + + iteration = 0 + last_id = None + + while True: + iteration += 1 + + with rls_transaction(prowler_api_provider.tenant_id, using=READ_REPLICA_ALIAS): + qs = Finding.objects.filter(scan_id=scan_id).order_by("id") + if last_id is not None: + qs = qs.filter(id__gt=last_id) + + findings_batch = list( + qs.values( + "id", + "uid", + "inserted_at", + "updated_at", + "first_seen_at", + "scan_id", + "delta", + "status", + "status_extended", + "severity", + "check_id", + "check_metadata__checktitle", + "muted", + "muted_reason", + )[:BATCH_SIZE] + ) + + logger.info( + f"Iteration #{iteration} fetched {len(findings_batch)} findings" + ) + + if not findings_batch: + logger.info( + f"No findings returned for iteration #{iteration}; stopping pagination" + ) + break + + last_id = findings_batch[-1]["id"] + enriched_batch = _enrich_and_flatten_batch(findings_batch) + + # Yield outside the transaction + if enriched_batch: + yield enriched_batch + + logger.info(f"Finished fetching findings for scan {scan_id}") + + +def _enrich_and_flatten_batch( + findings_batch: list[dict], +) -> list[dict[str, str]]: + """ + Fetch resource UIDs for a batch of findings and return flat structure. + + One finding with 3 resources becomes 3 dicts (same output format as before). + Must be called within an RLS transaction context. 
+ """ + finding_ids = [f["id"] for f in findings_batch] + + # Single join: mapping -> resource + resource_mappings = ResourceFindingMapping.objects.filter( + finding_id__in=finding_ids + ).values_list("finding_id", "resource__uid") + + # Build finding_id -> [resource_uids] mapping + finding_resources = defaultdict(list) + for finding_id, resource_uid in resource_mappings: + finding_resources[finding_id].append(resource_uid) + + # Flatten: one dict per (finding, resource) pair + results = [] + for f in findings_batch: + resource_uids = finding_resources.get(f["id"], []) + + if not resource_uids: + continue + + for resource_uid in resource_uids: + results.append( { - "resource_uid": str(resource_finding["resource__uid"]), - "id": str(resource_finding["finding__id"]), - "uid": resource_finding["finding__uid"], - "inserted_at": resource_finding["finding__inserted_at"], - "updated_at": resource_finding["finding__updated_at"], - "first_seen_at": resource_finding["finding__first_seen_at"], - "scan_id": str(resource_finding["finding__scan_id"]), - "delta": resource_finding["finding__delta"], - "status": resource_finding["finding__status"], - "status_extended": resource_finding["finding__status_extended"], - "severity": resource_finding["finding__severity"], - "check_id": str(resource_finding["finding__check_id"]), - "check_title": resource_finding[ - "finding__check_metadata__checktitle" - ], - "muted": resource_finding["finding__muted"], - "muted_reason": resource_finding["finding__muted_reason"], + "resource_uid": str(resource_uid), + "id": str(f["id"]), + "uid": f["uid"], + "inserted_at": f["inserted_at"], + "updated_at": f["updated_at"], + "first_seen_at": f["first_seen_at"], + "scan_id": str(f["scan_id"]), + "delta": f["delta"], + "status": f["status"], + "status_extended": f["status_extended"], + "severity": f["severity"], + "check_id": str(f["check_id"]), + "check_title": f["check_metadata__checktitle"], + "muted": f["muted"], + "muted_reason": f["muted_reason"], } ) - return findings + return results def load_findings( neo4j_session: neo4j.Session, - findings_data: list[dict[str, str]], + findings_batches: Generator[list[dict[str, str]], None, None], prowler_api_provider: Provider, config: CartographyConfig, ) -> None: @@ -184,16 +245,20 @@ def load_findings( "prowler_version": ProwlerConfig.prowler_version, } - total_length = len(findings_data) - for i in range(0, total_length, BATCH_SIZE): - parameters["findings_data"] = findings_data[i : i + BATCH_SIZE] + batch_num = 0 + total_records = 0 + for batch in findings_batches: + batch_num += 1 + batch_size = len(batch) + total_records += batch_size - logger.info( - f"Loading findings batch {i // BATCH_SIZE + 1} / {(total_length + BATCH_SIZE - 1) // BATCH_SIZE}" - ) + parameters["findings_data"] = batch + logger.info(f"Loading findings batch {batch_num} ({batch_size} records)") neo4j_session.run(query, parameters) + logger.info(f"Finished loading {total_records} records in {batch_num} batches") + def cleanup_findings( neo4j_session: neo4j.Session, diff --git a/api/src/backend/tasks/tests/test_attack_paths_scan.py b/api/src/backend/tasks/tests/test_attack_paths_scan.py index 226d53f60f..923cde4cba 100644 --- a/api/src/backend/tasks/tests/test_attack_paths_scan.py +++ b/api/src/backend/tasks/tests/test_attack_paths_scan.py @@ -3,6 +3,8 @@ from types import SimpleNamespace from unittest.mock import MagicMock, call, patch import pytest +from tasks.jobs.attack_paths import prowler as prowler_module +from tasks.jobs.attack_paths.scan import run as 
attack_paths_run from api.models import ( AttackPathsScan, @@ -15,8 +17,6 @@ from api.models import ( StatusChoices, ) from prowler.lib.check.models import Severity -from tasks.jobs.attack_paths import prowler as prowler_module -from tasks.jobs.attack_paths.scan import run as attack_paths_run @pytest.mark.django_db @@ -276,15 +276,15 @@ class TestAttackPathsProwlerHelpers: provider.provider = Provider.ProviderChoices.AWS provider.save() - findings = [ - {"id": "1", "resource_uid": "r-1"}, - {"id": "2", "resource_uid": "r-2"}, - ] + # Create a generator that yields two batches + def findings_generator(): + yield [{"id": "1", "resource_uid": "r-1"}] + yield [{"id": "2", "resource_uid": "r-2"}] + config = SimpleNamespace(update_tag=12345) mock_session = MagicMock() with ( - patch.object(prowler_module, "BATCH_SIZE", 1), patch( "tasks.jobs.attack_paths.prowler.get_root_node_label", return_value="AWSAccount", @@ -294,7 +294,9 @@ class TestAttackPathsProwlerHelpers: return_value="arn", ), ): - prowler_module.load_findings(mock_session, findings, provider, config) + prowler_module.load_findings( + mock_session, findings_generator(), provider, config + ) assert mock_session.run.call_count == 2 for call_args in mock_session.run.call_args_list: @@ -403,13 +405,17 @@ class TestAttackPathsProwlerHelpers: "tasks.jobs.attack_paths.prowler.rls_transaction", new=lambda *args, **kwargs: nullcontext(), ), patch( - "tasks.jobs.attack_paths.prowler.MainRouter.replica_db", + "tasks.jobs.attack_paths.prowler.READ_REPLICA_ALIAS", "default", ): - findings_data = prowler_module.get_provider_last_scan_findings( + # Generator yields batches, collect all findings from all batches + findings_batches = prowler_module.get_provider_last_scan_findings( provider, str(latest_scan.id), ) + findings_data = [] + for batch in findings_batches: + findings_data.extend(batch) assert len(findings_data) == 1 finding_dict = findings_data[0] @@ -417,3 +423,285 @@ class TestAttackPathsProwlerHelpers: assert finding_dict["resource_uid"] == resource.uid assert finding_dict["check_title"] == "Check title" assert finding_dict["scan_id"] == str(latest_scan.id) + + def test_enrich_and_flatten_batch_single_resource( + self, + tenants_fixture, + providers_fixture, + ): + """One finding + one resource = one output dict""" + tenant = tenants_fixture[0] + provider = providers_fixture[0] + provider.provider = Provider.ProviderChoices.AWS + provider.save() + + resource = Resource.objects.create( + tenant_id=tenant.id, + provider=provider, + uid="resource-uid-1", + name="Resource 1", + region="us-east-1", + service="ec2", + type="instance", + ) + + scan = Scan.objects.create( + name="Test Scan", + provider=provider, + trigger=Scan.TriggerChoices.MANUAL, + state=StateChoices.COMPLETED, + tenant_id=tenant.id, + ) + + finding = Finding.objects.create( + tenant_id=tenant.id, + uid="finding-uid", + scan=scan, + delta=Finding.DeltaChoices.NEW, + status=StatusChoices.FAIL, + status_extended="failed", + severity=Severity.high, + impact=Severity.high, + impact_extended="", + raw_result={}, + check_id="check-1", + check_metadata={"checktitle": "Check title"}, + first_seen_at=scan.inserted_at, + ) + ResourceFindingMapping.objects.create( + tenant_id=tenant.id, + resource=resource, + finding=finding, + ) + + # Simulate the dict returned by .values() + finding_dict = { + "id": finding.id, + "uid": finding.uid, + "inserted_at": finding.inserted_at, + "updated_at": finding.updated_at, + "first_seen_at": finding.first_seen_at, + "scan_id": scan.id, + "delta": 
finding.delta, + "status": finding.status, + "status_extended": finding.status_extended, + "severity": finding.severity, + "check_id": finding.check_id, + "check_metadata__checktitle": finding.check_metadata["checktitle"], + "muted": finding.muted, + "muted_reason": finding.muted_reason, + } + + # _enrich_and_flatten_batch queries ResourceFindingMapping directly + # No RLS mock needed - test DB doesn't enforce RLS policies + with patch( + "tasks.jobs.attack_paths.prowler.READ_REPLICA_ALIAS", + "default", + ): + result = prowler_module._enrich_and_flatten_batch([finding_dict]) + + assert len(result) == 1 + assert result[0]["resource_uid"] == resource.uid + assert result[0]["id"] == str(finding.id) + assert result[0]["status"] == "FAIL" + + def test_enrich_and_flatten_batch_multiple_resources( + self, + tenants_fixture, + providers_fixture, + ): + """One finding + three resources = three output dicts""" + tenant = tenants_fixture[0] + provider = providers_fixture[0] + provider.provider = Provider.ProviderChoices.AWS + provider.save() + + resources = [] + for i in range(3): + resource = Resource.objects.create( + tenant_id=tenant.id, + provider=provider, + uid=f"resource-uid-{i}", + name=f"Resource {i}", + region="us-east-1", + service="ec2", + type="instance", + ) + resources.append(resource) + + scan = Scan.objects.create( + name="Test Scan", + provider=provider, + trigger=Scan.TriggerChoices.MANUAL, + state=StateChoices.COMPLETED, + tenant_id=tenant.id, + ) + + finding = Finding.objects.create( + tenant_id=tenant.id, + uid="finding-uid", + scan=scan, + delta=Finding.DeltaChoices.NEW, + status=StatusChoices.FAIL, + status_extended="failed", + severity=Severity.high, + impact=Severity.high, + impact_extended="", + raw_result={}, + check_id="check-1", + check_metadata={"checktitle": "Check title"}, + first_seen_at=scan.inserted_at, + ) + + # Map finding to all 3 resources + for resource in resources: + ResourceFindingMapping.objects.create( + tenant_id=tenant.id, + resource=resource, + finding=finding, + ) + + finding_dict = { + "id": finding.id, + "uid": finding.uid, + "inserted_at": finding.inserted_at, + "updated_at": finding.updated_at, + "first_seen_at": finding.first_seen_at, + "scan_id": scan.id, + "delta": finding.delta, + "status": finding.status, + "status_extended": finding.status_extended, + "severity": finding.severity, + "check_id": finding.check_id, + "check_metadata__checktitle": finding.check_metadata["checktitle"], + "muted": finding.muted, + "muted_reason": finding.muted_reason, + } + + # _enrich_and_flatten_batch queries ResourceFindingMapping directly + # No RLS mock needed - test DB doesn't enforce RLS policies + with patch( + "tasks.jobs.attack_paths.prowler.READ_REPLICA_ALIAS", + "default", + ): + result = prowler_module._enrich_and_flatten_batch([finding_dict]) + + assert len(result) == 3 + result_resource_uids = {r["resource_uid"] for r in result} + assert result_resource_uids == {r.uid for r in resources} + + # All should have same finding data + for r in result: + assert r["id"] == str(finding.id) + assert r["status"] == "FAIL" + + def test_enrich_and_flatten_batch_no_resources_skips( + self, + tenants_fixture, + providers_fixture, + ): + """Finding without resources should be skipped""" + tenant = tenants_fixture[0] + provider = providers_fixture[0] + provider.provider = Provider.ProviderChoices.AWS + provider.save() + + scan = Scan.objects.create( + name="Test Scan", + provider=provider, + trigger=Scan.TriggerChoices.MANUAL, + state=StateChoices.COMPLETED, + 
tenant_id=tenant.id, + ) + + finding = Finding.objects.create( + tenant_id=tenant.id, + uid="orphan-finding", + scan=scan, + delta=Finding.DeltaChoices.NEW, + status=StatusChoices.FAIL, + status_extended="failed", + severity=Severity.high, + impact=Severity.high, + impact_extended="", + raw_result={}, + check_id="check-1", + check_metadata={"checktitle": "Check title"}, + first_seen_at=scan.inserted_at, + ) + # Note: No ResourceFindingMapping created + + finding_dict = { + "id": finding.id, + "uid": finding.uid, + "inserted_at": finding.inserted_at, + "updated_at": finding.updated_at, + "first_seen_at": finding.first_seen_at, + "scan_id": scan.id, + "delta": finding.delta, + "status": finding.status, + "status_extended": finding.status_extended, + "severity": finding.severity, + "check_id": finding.check_id, + "check_metadata__checktitle": finding.check_metadata["checktitle"], + "muted": finding.muted, + "muted_reason": finding.muted_reason, + } + + # Mock logger to verify no warning is emitted + with ( + patch( + "tasks.jobs.attack_paths.prowler.READ_REPLICA_ALIAS", + "default", + ), + patch("tasks.jobs.attack_paths.prowler.logger") as mock_logger, + ): + result = prowler_module._enrich_and_flatten_batch([finding_dict]) + + assert len(result) == 0 + mock_logger.warning.assert_not_called() + + def test_generator_is_lazy(self, providers_fixture): + """Generator should not execute queries until iterated""" + provider = providers_fixture[0] + provider.provider = Provider.ProviderChoices.AWS + provider.save() + scan_id = "some-scan-id" + + with ( + patch("tasks.jobs.attack_paths.prowler.rls_transaction") as mock_rls, + patch("tasks.jobs.attack_paths.prowler.Finding") as mock_finding, + ): + # Create generator but don't iterate + prowler_module.get_provider_last_scan_findings(provider, scan_id) + + # Nothing should be called yet + mock_rls.assert_not_called() + mock_finding.objects.filter.assert_not_called() + + def test_load_findings_empty_generator(self, providers_fixture): + """Empty generator should not call neo4j""" + provider = providers_fixture[0] + provider.provider = Provider.ProviderChoices.AWS + provider.save() + + mock_session = MagicMock() + config = SimpleNamespace(update_tag=12345) + + def empty_gen(): + return + yield # Make it a generator + + with ( + patch( + "tasks.jobs.attack_paths.prowler.get_root_node_label", + return_value="AWSAccount", + ), + patch( + "tasks.jobs.attack_paths.prowler.get_node_uid_field", + return_value="arn", + ), + ): + prowler_module.load_findings(mock_session, empty_gen(), provider, config) + + mock_session.run.assert_not_called()
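
Note (illustrative, not part of the patch): a minimal standalone sketch of the keyset-pagination pattern the new get_provider_last_scan_findings generator relies on — fetch at most BATCH_SIZE rows ordered by primary key, remember the last id, repeat until a fetch comes back empty, so memory stays O(BATCH_SIZE). All names below (FAKE_TABLE, fetch_page, stream_batches) are made up for the example; the real code runs the equivalent Django ORM query against Finding inside an RLS transaction and enriches each batch with resource UIDs before yielding it.

    from typing import Generator, Optional

    BATCH_SIZE = 3

    # Stand-in for the findings table, already ordered by primary key.
    FAKE_TABLE = [{"id": i, "status": "FAIL"} for i in range(1, 11)]


    def fetch_page(last_id: Optional[int]) -> list[dict]:
        """Mimics Finding.objects.filter(id__gt=last_id).order_by("id")[:BATCH_SIZE]."""
        rows = [r for r in FAKE_TABLE if last_id is None or r["id"] > last_id]
        return rows[:BATCH_SIZE]


    def stream_batches() -> Generator[list[dict], None, None]:
        last_id = None
        while True:
            batch = fetch_page(last_id)
            if not batch:
                break
            last_id = batch[-1]["id"]
            yield batch


    if __name__ == "__main__":
        for n, batch in enumerate(stream_batches(), start=1):
            # The real code enriches each batch with resource UIDs and loads it
            # into Neo4j; here we just print the ids to show the pagination.
            print(f"batch {n}: {[r['id'] for r in batch]}")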