fix(attack-paths): load findings in batches into Neo4j (#9862)

Co-authored-by: Josema Camacho <josema@prowler.com>
This commit is contained in:
Pepe Fagoaga
2026-01-22 18:17:50 +01:00
committed by GitHub
parent 6cb0edf3e1
commit 91e3c01f51
4 changed files with 434 additions and 83 deletions

View File

@@ -4,14 +4,12 @@ All notable changes to the **Prowler API** are documented in this file.
## [1.18.1] (Prowler v5.17.1)
### Changed
### Fixed
- Improve API startup process by `manage.py` argument detection [(#9856)](https://github.com/prowler-cloud/prowler/pull/9856)
- Deleting providers don't try to delete a `None` Neo4j database when an Attack Paths scan is scheduled [(#9858)](https://github.com/prowler-cloud/prowler/pull/9858)
### Fixed
- Use replica database for reading Findings to add them to the Attack Paths graph [(#9861)](https://github.com/prowler-cloud/prowler/pull/9861)
- Attack paths findings loading query to use streaming generator for O(batch_size) memory instead of O(total_findings) [(#9862)](https://github.com/prowler-cloud/prowler/pull/9862)
## [1.18.0] (Prowler v5.17.0)

View File

@@ -1,19 +1,22 @@
from collections import defaultdict
from typing import Generator
import neo4j
from cartography.client.core.tx import run_write_query
from cartography.config import Config as CartographyConfig
from celery.utils.log import get_task_logger
from api.db_router import MainRouter
from api.db_utils import rls_transaction
from api.models import Provider, ResourceFindingMapping
from config.env import env
from prowler.config import config as ProwlerConfig
from tasks.jobs.attack_paths.providers import get_node_uid_field, get_root_node_label
from api.db_router import READ_REPLICA_ALIAS
from api.db_utils import rls_transaction
from api.models import Finding, Provider, ResourceFindingMapping
from prowler.config import config as ProwlerConfig
logger = get_task_logger(__name__)
BATCH_SIZE = env.int("NEO4J_INSERT_BATCH_SIZE", 500)
BATCH_SIZE = env.int("ATTACK_PATHS_FINDINGS_BATCH_SIZE", 1000)
INDEX_STATEMENTS = [
"CREATE INDEX prowler_finding_id IF NOT EXISTS FOR (n:ProwlerFinding) ON (n.id);",
@@ -84,9 +87,8 @@ def create_indexes(neo4j_session: neo4j.Session) -> None:
Code based on Cartography version 0.122.0, specifically on `cartography.intel.create_indexes.run`.
"""
logger.info("Creating indexes for Prowler node types.")
logger.info("Creating indexes for Prowler Findings node types")
for statement in INDEX_STATEMENTS:
logger.debug("Executing statement: %s", statement)
run_write_query(neo4j_session, statement)
@@ -96,77 +98,136 @@ def analysis(
scan_id: str,
config: CartographyConfig,
) -> None:
logger.info(f"Getting Prowler findings for AWS account {prowler_api_provider.uid}")
findings_data = get_provider_last_scan_findings(prowler_api_provider, scan_id)
logger.info(f"Loading Prowler findings for AWS account {prowler_api_provider.uid}")
load_findings(neo4j_session, findings_data, prowler_api_provider, config)
logger.info(
f"Cleaning up Prowler findings for AWS account {prowler_api_provider.uid}"
)
cleanup_findings(neo4j_session, prowler_api_provider, config)
def get_provider_last_scan_findings(
prowler_api_provider: Provider,
scan_id: str,
) -> list[dict[str, str]]:
with rls_transaction(prowler_api_provider.tenant_id):
resource_finding_qs = (
ResourceFindingMapping.objects.using(MainRouter.replica_db)
.filter(
finding__scan_id=scan_id,
)
.values(
"resource__uid",
"finding__id",
"finding__uid",
"finding__inserted_at",
"finding__updated_at",
"finding__first_seen_at",
"finding__scan_id",
"finding__delta",
"finding__status",
"finding__status_extended",
"finding__severity",
"finding__check_id",
"finding__check_metadata__checktitle",
"finding__muted",
"finding__muted_reason",
)
)
) -> Generator[list[dict[str, str]], None, None]:
"""
Generator that yields batches of finding-resource pairs.
findings = []
for resource_finding in resource_finding_qs:
findings.append(
Two-step query approach per batch:
1. Paginate findings for scan (single table, indexed by scan_id)
2. Batch-fetch resource UIDs via mapping table (single join)
3. Merge and yield flat structure for Neo4j
Memory efficient: never holds more than BATCH_SIZE findings in memory.
"""
logger.info(
f"Starting findings fetch for scan {scan_id} (tenant {prowler_api_provider.tenant_id}) with batch size {BATCH_SIZE}"
)
iteration = 0
last_id = None
while True:
iteration += 1
with rls_transaction(prowler_api_provider.tenant_id, using=READ_REPLICA_ALIAS):
qs = Finding.objects.filter(scan_id=scan_id).order_by("id")
if last_id is not None:
qs = qs.filter(id__gt=last_id)
findings_batch = list(
qs.values(
"id",
"uid",
"inserted_at",
"updated_at",
"first_seen_at",
"scan_id",
"delta",
"status",
"status_extended",
"severity",
"check_id",
"check_metadata__checktitle",
"muted",
"muted_reason",
)[:BATCH_SIZE]
)
logger.info(
f"Iteration #{iteration} fetched {len(findings_batch)} findings"
)
if not findings_batch:
logger.info(
f"No findings returned for iteration #{iteration}; stopping pagination"
)
break
last_id = findings_batch[-1]["id"]
enriched_batch = _enrich_and_flatten_batch(findings_batch)
# Yield outside the transaction
if enriched_batch:
yield enriched_batch
logger.info(f"Finished fetching findings for scan {scan_id}")
def _enrich_and_flatten_batch(
findings_batch: list[dict],
) -> list[dict[str, str]]:
"""
Fetch resource UIDs for a batch of findings and return flat structure.
One finding with 3 resources becomes 3 dicts (same output format as before).
Must be called within an RLS transaction context.
"""
finding_ids = [f["id"] for f in findings_batch]
# Single join: mapping -> resource
resource_mappings = ResourceFindingMapping.objects.filter(
finding_id__in=finding_ids
).values_list("finding_id", "resource__uid")
# Build finding_id -> [resource_uids] mapping
finding_resources = defaultdict(list)
for finding_id, resource_uid in resource_mappings:
finding_resources[finding_id].append(resource_uid)
# Flatten: one dict per (finding, resource) pair
results = []
for f in findings_batch:
resource_uids = finding_resources.get(f["id"], [])
if not resource_uids:
continue
for resource_uid in resource_uids:
results.append(
{
"resource_uid": str(resource_finding["resource__uid"]),
"id": str(resource_finding["finding__id"]),
"uid": resource_finding["finding__uid"],
"inserted_at": resource_finding["finding__inserted_at"],
"updated_at": resource_finding["finding__updated_at"],
"first_seen_at": resource_finding["finding__first_seen_at"],
"scan_id": str(resource_finding["finding__scan_id"]),
"delta": resource_finding["finding__delta"],
"status": resource_finding["finding__status"],
"status_extended": resource_finding["finding__status_extended"],
"severity": resource_finding["finding__severity"],
"check_id": str(resource_finding["finding__check_id"]),
"check_title": resource_finding[
"finding__check_metadata__checktitle"
],
"muted": resource_finding["finding__muted"],
"muted_reason": resource_finding["finding__muted_reason"],
"resource_uid": str(resource_uid),
"id": str(f["id"]),
"uid": f["uid"],
"inserted_at": f["inserted_at"],
"updated_at": f["updated_at"],
"first_seen_at": f["first_seen_at"],
"scan_id": str(f["scan_id"]),
"delta": f["delta"],
"status": f["status"],
"status_extended": f["status_extended"],
"severity": f["severity"],
"check_id": str(f["check_id"]),
"check_title": f["check_metadata__checktitle"],
"muted": f["muted"],
"muted_reason": f["muted_reason"],
}
)
return findings
return results
def load_findings(
neo4j_session: neo4j.Session,
findings_data: list[dict[str, str]],
findings_batches: Generator[list[dict[str, str]], None, None],
prowler_api_provider: Provider,
config: CartographyConfig,
) -> None:
@@ -184,16 +245,20 @@ def load_findings(
"prowler_version": ProwlerConfig.prowler_version,
}
total_length = len(findings_data)
for i in range(0, total_length, BATCH_SIZE):
parameters["findings_data"] = findings_data[i : i + BATCH_SIZE]
batch_num = 0
total_records = 0
for batch in findings_batches:
batch_num += 1
batch_size = len(batch)
total_records += batch_size
logger.info(
f"Loading findings batch {i // BATCH_SIZE + 1} / {(total_length + BATCH_SIZE - 1) // BATCH_SIZE}"
)
parameters["findings_data"] = batch
logger.info(f"Loading findings batch {batch_num} ({batch_size} records)")
neo4j_session.run(query, parameters)
logger.info(f"Finished loading {total_records} records in {batch_num} batches")
def cleanup_findings(
neo4j_session: neo4j.Session,

View File

@@ -3,6 +3,8 @@ from types import SimpleNamespace
from unittest.mock import MagicMock, call, patch
import pytest
from tasks.jobs.attack_paths import prowler as prowler_module
from tasks.jobs.attack_paths.scan import run as attack_paths_run
from api.models import (
AttackPathsScan,
@@ -15,8 +17,6 @@ from api.models import (
StatusChoices,
)
from prowler.lib.check.models import Severity
from tasks.jobs.attack_paths import prowler as prowler_module
from tasks.jobs.attack_paths.scan import run as attack_paths_run
@pytest.mark.django_db
@@ -276,15 +276,15 @@ class TestAttackPathsProwlerHelpers:
provider.provider = Provider.ProviderChoices.AWS
provider.save()
findings = [
{"id": "1", "resource_uid": "r-1"},
{"id": "2", "resource_uid": "r-2"},
]
# Create a generator that yields two batches
def findings_generator():
yield [{"id": "1", "resource_uid": "r-1"}]
yield [{"id": "2", "resource_uid": "r-2"}]
config = SimpleNamespace(update_tag=12345)
mock_session = MagicMock()
with (
patch.object(prowler_module, "BATCH_SIZE", 1),
patch(
"tasks.jobs.attack_paths.prowler.get_root_node_label",
return_value="AWSAccount",
@@ -294,7 +294,9 @@ class TestAttackPathsProwlerHelpers:
return_value="arn",
),
):
prowler_module.load_findings(mock_session, findings, provider, config)
prowler_module.load_findings(
mock_session, findings_generator(), provider, config
)
assert mock_session.run.call_count == 2
for call_args in mock_session.run.call_args_list:
@@ -403,13 +405,17 @@ class TestAttackPathsProwlerHelpers:
"tasks.jobs.attack_paths.prowler.rls_transaction",
new=lambda *args, **kwargs: nullcontext(),
), patch(
"tasks.jobs.attack_paths.prowler.MainRouter.replica_db",
"tasks.jobs.attack_paths.prowler.READ_REPLICA_ALIAS",
"default",
):
findings_data = prowler_module.get_provider_last_scan_findings(
# Generator yields batches, collect all findings from all batches
findings_batches = prowler_module.get_provider_last_scan_findings(
provider,
str(latest_scan.id),
)
findings_data = []
for batch in findings_batches:
findings_data.extend(batch)
assert len(findings_data) == 1
finding_dict = findings_data[0]
@@ -417,3 +423,285 @@ class TestAttackPathsProwlerHelpers:
assert finding_dict["resource_uid"] == resource.uid
assert finding_dict["check_title"] == "Check title"
assert finding_dict["scan_id"] == str(latest_scan.id)
def test_enrich_and_flatten_batch_single_resource(
self,
tenants_fixture,
providers_fixture,
):
"""One finding + one resource = one output dict"""
tenant = tenants_fixture[0]
provider = providers_fixture[0]
provider.provider = Provider.ProviderChoices.AWS
provider.save()
resource = Resource.objects.create(
tenant_id=tenant.id,
provider=provider,
uid="resource-uid-1",
name="Resource 1",
region="us-east-1",
service="ec2",
type="instance",
)
scan = Scan.objects.create(
name="Test Scan",
provider=provider,
trigger=Scan.TriggerChoices.MANUAL,
state=StateChoices.COMPLETED,
tenant_id=tenant.id,
)
finding = Finding.objects.create(
tenant_id=tenant.id,
uid="finding-uid",
scan=scan,
delta=Finding.DeltaChoices.NEW,
status=StatusChoices.FAIL,
status_extended="failed",
severity=Severity.high,
impact=Severity.high,
impact_extended="",
raw_result={},
check_id="check-1",
check_metadata={"checktitle": "Check title"},
first_seen_at=scan.inserted_at,
)
ResourceFindingMapping.objects.create(
tenant_id=tenant.id,
resource=resource,
finding=finding,
)
# Simulate the dict returned by .values()
finding_dict = {
"id": finding.id,
"uid": finding.uid,
"inserted_at": finding.inserted_at,
"updated_at": finding.updated_at,
"first_seen_at": finding.first_seen_at,
"scan_id": scan.id,
"delta": finding.delta,
"status": finding.status,
"status_extended": finding.status_extended,
"severity": finding.severity,
"check_id": finding.check_id,
"check_metadata__checktitle": finding.check_metadata["checktitle"],
"muted": finding.muted,
"muted_reason": finding.muted_reason,
}
# _enrich_and_flatten_batch queries ResourceFindingMapping directly
# No RLS mock needed - test DB doesn't enforce RLS policies
with patch(
"tasks.jobs.attack_paths.prowler.READ_REPLICA_ALIAS",
"default",
):
result = prowler_module._enrich_and_flatten_batch([finding_dict])
assert len(result) == 1
assert result[0]["resource_uid"] == resource.uid
assert result[0]["id"] == str(finding.id)
assert result[0]["status"] == "FAIL"
def test_enrich_and_flatten_batch_multiple_resources(
self,
tenants_fixture,
providers_fixture,
):
"""One finding + three resources = three output dicts"""
tenant = tenants_fixture[0]
provider = providers_fixture[0]
provider.provider = Provider.ProviderChoices.AWS
provider.save()
resources = []
for i in range(3):
resource = Resource.objects.create(
tenant_id=tenant.id,
provider=provider,
uid=f"resource-uid-{i}",
name=f"Resource {i}",
region="us-east-1",
service="ec2",
type="instance",
)
resources.append(resource)
scan = Scan.objects.create(
name="Test Scan",
provider=provider,
trigger=Scan.TriggerChoices.MANUAL,
state=StateChoices.COMPLETED,
tenant_id=tenant.id,
)
finding = Finding.objects.create(
tenant_id=tenant.id,
uid="finding-uid",
scan=scan,
delta=Finding.DeltaChoices.NEW,
status=StatusChoices.FAIL,
status_extended="failed",
severity=Severity.high,
impact=Severity.high,
impact_extended="",
raw_result={},
check_id="check-1",
check_metadata={"checktitle": "Check title"},
first_seen_at=scan.inserted_at,
)
# Map finding to all 3 resources
for resource in resources:
ResourceFindingMapping.objects.create(
tenant_id=tenant.id,
resource=resource,
finding=finding,
)
finding_dict = {
"id": finding.id,
"uid": finding.uid,
"inserted_at": finding.inserted_at,
"updated_at": finding.updated_at,
"first_seen_at": finding.first_seen_at,
"scan_id": scan.id,
"delta": finding.delta,
"status": finding.status,
"status_extended": finding.status_extended,
"severity": finding.severity,
"check_id": finding.check_id,
"check_metadata__checktitle": finding.check_metadata["checktitle"],
"muted": finding.muted,
"muted_reason": finding.muted_reason,
}
# _enrich_and_flatten_batch queries ResourceFindingMapping directly
# No RLS mock needed - test DB doesn't enforce RLS policies
with patch(
"tasks.jobs.attack_paths.prowler.READ_REPLICA_ALIAS",
"default",
):
result = prowler_module._enrich_and_flatten_batch([finding_dict])
assert len(result) == 3
result_resource_uids = {r["resource_uid"] for r in result}
assert result_resource_uids == {r.uid for r in resources}
# All should have same finding data
for r in result:
assert r["id"] == str(finding.id)
assert r["status"] == "FAIL"
def test_enrich_and_flatten_batch_no_resources_skips(
self,
tenants_fixture,
providers_fixture,
):
"""Finding without resources should be skipped"""
tenant = tenants_fixture[0]
provider = providers_fixture[0]
provider.provider = Provider.ProviderChoices.AWS
provider.save()
scan = Scan.objects.create(
name="Test Scan",
provider=provider,
trigger=Scan.TriggerChoices.MANUAL,
state=StateChoices.COMPLETED,
tenant_id=tenant.id,
)
finding = Finding.objects.create(
tenant_id=tenant.id,
uid="orphan-finding",
scan=scan,
delta=Finding.DeltaChoices.NEW,
status=StatusChoices.FAIL,
status_extended="failed",
severity=Severity.high,
impact=Severity.high,
impact_extended="",
raw_result={},
check_id="check-1",
check_metadata={"checktitle": "Check title"},
first_seen_at=scan.inserted_at,
)
# Note: No ResourceFindingMapping created
finding_dict = {
"id": finding.id,
"uid": finding.uid,
"inserted_at": finding.inserted_at,
"updated_at": finding.updated_at,
"first_seen_at": finding.first_seen_at,
"scan_id": scan.id,
"delta": finding.delta,
"status": finding.status,
"status_extended": finding.status_extended,
"severity": finding.severity,
"check_id": finding.check_id,
"check_metadata__checktitle": finding.check_metadata["checktitle"],
"muted": finding.muted,
"muted_reason": finding.muted_reason,
}
# Mock logger to verify no warning is emitted
with (
patch(
"tasks.jobs.attack_paths.prowler.READ_REPLICA_ALIAS",
"default",
),
patch("tasks.jobs.attack_paths.prowler.logger") as mock_logger,
):
result = prowler_module._enrich_and_flatten_batch([finding_dict])
assert len(result) == 0
mock_logger.warning.assert_not_called()
def test_generator_is_lazy(self, providers_fixture):
"""Generator should not execute queries until iterated"""
provider = providers_fixture[0]
provider.provider = Provider.ProviderChoices.AWS
provider.save()
scan_id = "some-scan-id"
with (
patch("tasks.jobs.attack_paths.prowler.rls_transaction") as mock_rls,
patch("tasks.jobs.attack_paths.prowler.Finding") as mock_finding,
):
# Create generator but don't iterate
prowler_module.get_provider_last_scan_findings(provider, scan_id)
# Nothing should be called yet
mock_rls.assert_not_called()
mock_finding.objects.filter.assert_not_called()
def test_load_findings_empty_generator(self, providers_fixture):
"""Empty generator should not call neo4j"""
provider = providers_fixture[0]
provider.provider = Provider.ProviderChoices.AWS
provider.save()
mock_session = MagicMock()
config = SimpleNamespace(update_tag=12345)
def empty_gen():
return
yield # Make it a generator
with (
patch(
"tasks.jobs.attack_paths.prowler.get_root_node_label",
return_value="AWSAccount",
),
patch(
"tasks.jobs.attack_paths.prowler.get_node_uid_field",
return_value="arn",
),
):
prowler_module.load_findings(mock_session, empty_gen(), provider, config)
mock_session.run.assert_not_called()