From 1da10611e7eadac24a9a6619c232a35811965a92 Mon Sep 17 00:00:00 2001 From: Josema Camacho Date: Wed, 18 Mar 2026 10:08:30 +0100 Subject: [PATCH] perf(attack-paths): reduce sync and findings memory usage with smaller batches and cursor iteration (#10359) --- api/CHANGELOG.md | 1 + .../backend/tasks/jobs/attack_paths/config.py | 8 +- .../tasks/jobs/attack_paths/findings.py | 138 ++++------ .../backend/tasks/jobs/attack_paths/sync.py | 140 +++++----- .../tasks/tests/test_attack_paths_scan.py | 253 +++++++++++++++--- 5 files changed, 351 insertions(+), 189 deletions(-) diff --git a/api/CHANGELOG.md b/api/CHANGELOG.md index c7ff2bf99e..15702e342d 100644 --- a/api/CHANGELOG.md +++ b/api/CHANGELOG.md @@ -12,6 +12,7 @@ All notable changes to the **Prowler API** are documented in this file. - Attack Paths: Complete migration to private graph labels and properties, removing deprecated dual-write support [(#10268)](https://github.com/prowler-cloud/prowler/pull/10268) - Attack Paths: Added tenant and provider related labels to the nodes so they can be easily filtered on custom queries [(#10308)](https://github.com/prowler-cloud/prowler/pull/10308) +- Attack Paths: Reduce sync and findings memory usage with smaller batches, cursor iteration, and sequential sessions [(#10359)](https://github.com/prowler-cloud/prowler/pull/10359) ### 🐞 Fixed diff --git a/api/src/backend/tasks/jobs/attack_paths/config.py b/api/src/backend/tasks/jobs/attack_paths/config.py index be9305d8a5..315d2ecd24 100644 --- a/api/src/backend/tasks/jobs/attack_paths/config.py +++ b/api/src/backend/tasks/jobs/attack_paths/config.py @@ -3,12 +3,14 @@ from typing import Callable from uuid import UUID from config.env import env - from tasks.jobs.attack_paths import aws - -# Batch size for Neo4j operations +# Batch size for Neo4j write operations (resource labeling, cleanup) BATCH_SIZE = env.int("ATTACK_PATHS_BATCH_SIZE", 1000) +# Batch size for Postgres findings fetch (keyset pagination page size) +FINDINGS_BATCH_SIZE = env.int("ATTACK_PATHS_FINDINGS_BATCH_SIZE", 500) +# Batch size for temp-to-tenant graph sync (nodes and relationships per cursor page) +SYNC_BATCH_SIZE = env.int("ATTACK_PATHS_SYNC_BATCH_SIZE", 250) # Neo4j internal labels (Prowler-specific, not provider-specific) # - `Internet`: Singleton node representing external internet access for exposed-resource queries diff --git a/api/src/backend/tasks/jobs/attack_paths/findings.py b/api/src/backend/tasks/jobs/attack_paths/findings.py index b4534fb8de..9a9f365911 100644 --- a/api/src/backend/tasks/jobs/attack_paths/findings.py +++ b/api/src/backend/tasks/jobs/attack_paths/findings.py @@ -9,22 +9,15 @@ This module handles: """ from collections import defaultdict -from dataclasses import asdict, dataclass, fields from typing import Any, Generator from uuid import UUID import neo4j - from cartography.config import Config as CartographyConfig from celery.utils.log import get_task_logger - -from api.db_router import READ_REPLICA_ALIAS -from api.db_utils import rls_transaction -from api.models import Finding as FindingModel -from api.models import Provider, ResourceFindingMapping -from prowler.config import config as ProwlerConfig from tasks.jobs.attack_paths.config import ( BATCH_SIZE, + FINDINGS_BATCH_SIZE, get_node_uid_field, get_provider_resource_label, get_root_node_label, @@ -37,75 +30,54 @@ from tasks.jobs.attack_paths.queries import ( render_cypher_template, ) +from api.db_router import READ_REPLICA_ALIAS +from api.db_utils import rls_transaction +from api.models import Finding as FindingModel +from api.models import Provider, ResourceFindingMapping +from prowler.config import config as ProwlerConfig + logger = get_task_logger(__name__) -# Type Definitions -# ----------------- - -# Maps dataclass field names to Django ORM query field names -_DB_FIELD_MAP: dict[str, str] = { - "check_title": "check_metadata__checktitle", -} +# Django ORM field names for `.values()` queries +# Most map 1:1 to Neo4j property names, exceptions are remapped in `_to_neo4j_dict` +_DB_QUERY_FIELDS = [ + "id", + "uid", + "inserted_at", + "updated_at", + "first_seen_at", + "scan_id", + "delta", + "status", + "status_extended", + "severity", + "check_id", + "check_metadata__checktitle", + "muted", + "muted_reason", +] -@dataclass(slots=True) -class Finding: - """ - Finding data for Neo4j ingestion. - - Can be created from a Django .values() query result using from_db_record(). - """ - - id: str - uid: str - inserted_at: str - updated_at: str - first_seen_at: str - scan_id: str - delta: str - status: str - status_extended: str - severity: str - check_id: str - check_title: str - muted: bool - muted_reason: str | None - resource_uid: str | None = None - - @classmethod - def get_db_query_fields(cls) -> tuple[str, ...]: - """Get field names for Django .values() query.""" - return tuple( - _DB_FIELD_MAP.get(f.name, f.name) - for f in fields(cls) - if f.name != "resource_uid" - ) - - @classmethod - def from_db_record(cls, record: dict[str, Any], resource_uid: str) -> "Finding": - """Create a Finding from a Django .values() query result.""" - return cls( - id=str(record["id"]), - uid=record["uid"], - inserted_at=record["inserted_at"], - updated_at=record["updated_at"], - first_seen_at=record["first_seen_at"], - scan_id=str(record["scan_id"]), - delta=record["delta"], - status=record["status"], - status_extended=record["status_extended"], - severity=record["severity"], - check_id=str(record["check_id"]), - check_title=record["check_metadata__checktitle"], - muted=record["muted"], - muted_reason=record["muted_reason"], - resource_uid=resource_uid, - ) - - def to_dict(self) -> dict[str, Any]: - """Convert to dict for Neo4j ingestion.""" - return asdict(self) +def _to_neo4j_dict(record: dict[str, Any], resource_uid: str) -> dict[str, Any]: + """Transform a Django `.values()` record into a `dict` ready for Neo4j ingestion.""" + return { + "id": str(record["id"]), + "uid": record["uid"], + "inserted_at": record["inserted_at"], + "updated_at": record["updated_at"], + "first_seen_at": record["first_seen_at"], + "scan_id": str(record["scan_id"]), + "delta": record["delta"], + "status": record["status"], + "status_extended": record["status_extended"], + "severity": record["severity"], + "check_id": str(record["check_id"]), + "check_title": record["check_metadata__checktitle"], + "muted": record["muted"], + "muted_reason": record["muted_reason"], + "resource_uid": resource_uid, + } # Public API @@ -180,7 +152,7 @@ def add_resource_label( def load_findings( neo4j_session: neo4j.Session, - findings_batches: Generator[list[Finding], None, None], + findings_batches: Generator[list[dict[str, Any]], None, None], prowler_api_provider: Provider, config: CartographyConfig, ) -> None: @@ -209,7 +181,7 @@ def load_findings( batch_size = len(batch) total_records += batch_size - parameters["findings_data"] = [f.to_dict() for f in batch] + parameters["findings_data"] = batch logger.info(f"Loading findings batch {batch_num} ({batch_size} records)") neo4j_session.run(query, parameters) @@ -247,16 +219,17 @@ def cleanup_findings( def stream_findings_with_resources( prowler_api_provider: Provider, scan_id: str, -) -> Generator[list[Finding], None, None]: +) -> Generator[list[dict[str, Any]], None, None]: """ Stream findings with their associated resources in batches. Uses keyset pagination for efficient traversal of large datasets. - Memory efficient: yields one batch at a time, never holds all findings in memory. + Memory efficient: yields one batch at a time as dicts ready for Neo4j ingestion, + never holds all findings in memory. """ logger.info( f"Starting findings stream for scan {scan_id} " - f"(tenant {prowler_api_provider.tenant_id}) with batch size {BATCH_SIZE}" + f"(tenant {prowler_api_provider.tenant_id}) with batch size {FINDINGS_BATCH_SIZE}" ) tenant_id = prowler_api_provider.tenant_id @@ -305,15 +278,14 @@ def _fetch_findings_batch( Uses read replica and RLS-scoped transaction. """ with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS): - # Use all_objects to avoid the ActiveProviderManager's implicit JOIN - # through Scan -> Provider (to check is_deleted=False). - # The provider is already validated as active in this context. + # Use `all_objects` to get `Findings` even on soft-deleted `Providers` + # But even the provider is already validated as active in this context qs = FindingModel.all_objects.filter(scan_id=scan_id).order_by("id") if after_id is not None: qs = qs.filter(id__gt=after_id) - return list(qs.values(*Finding.get_db_query_fields())[:BATCH_SIZE]) + return list(qs.values(*_DB_QUERY_FIELDS)[:FINDINGS_BATCH_SIZE]) # Batch Enrichment @@ -323,7 +295,7 @@ def _fetch_findings_batch( def _enrich_batch_with_resources( findings_batch: list[dict[str, Any]], tenant_id: str, -) -> list[Finding]: +) -> list[dict[str, Any]]: """ Enrich findings with their resource UIDs. @@ -334,7 +306,7 @@ def _enrich_batch_with_resources( resource_map = _build_finding_resource_map(finding_ids, tenant_id) return [ - Finding.from_db_record(finding, resource_uid) + _to_neo4j_dict(finding, resource_uid) for finding in findings_batch for resource_uid in resource_map.get(finding["id"], []) ] diff --git a/api/src/backend/tasks/jobs/attack_paths/sync.py b/api/src/backend/tasks/jobs/attack_paths/sync.py index c0b1799b9a..870aba6fa8 100644 --- a/api/src/backend/tasks/jobs/attack_paths/sync.py +++ b/api/src/backend/tasks/jobs/attack_paths/sync.py @@ -8,13 +8,14 @@ to the tenant database, adding provider isolation labels and properties. from collections import defaultdict from typing import Any +import neo4j from celery.utils.log import get_task_logger from api.attack_paths import database as graph_database from tasks.jobs.attack_paths.config import ( - BATCH_SIZE, PROVIDER_ISOLATION_PROPERTIES, PROVIDER_RESOURCE_LABEL, + SYNC_BATCH_SIZE, get_provider_label, get_tenant_label, ) @@ -82,40 +83,32 @@ def sync_nodes( Adds `_ProviderResource` label and `_provider_id` property to all nodes. Also adds dynamic `_Tenant_{id}` and `_Provider_{id}` isolation labels. + + Source and target sessions are opened sequentially per batch to avoid + holding two Bolt connections simultaneously for the entire sync duration. """ last_id = -1 total_synced = 0 - with ( - graph_database.get_session(source_database) as source_session, - graph_database.get_session(target_database) as target_session, - ): - while True: - rows = list( - source_session.run( - NODE_FETCH_QUERY, - {"last_id": last_id, "batch_size": BATCH_SIZE}, - ) + while True: + grouped: dict[tuple[str, ...], list[dict[str, Any]]] = defaultdict(list) + batch_count = 0 + + with graph_database.get_session(source_database) as source_session: + result = source_session.run( + NODE_FETCH_QUERY, + {"last_id": last_id, "batch_size": SYNC_BATCH_SIZE}, ) + for record in result: + batch_count += 1 + last_id = record["internal_id"] + key, value = _node_to_sync_dict(record, provider_id) + grouped[key].append(value) - if not rows: - break - - last_id = rows[-1]["internal_id"] - - grouped: dict[tuple[str, ...], list[dict[str, Any]]] = defaultdict(list) - for row in rows: - labels = tuple(sorted(set(row["labels"] or []))) - props = dict(row["props"] or {}) - _strip_internal_properties(props) - provider_element_id = f"{provider_id}:{row['element_id']}" - grouped[labels].append( - { - "provider_element_id": provider_element_id, - "props": props, - } - ) + if batch_count == 0: + break + with graph_database.get_session(target_database) as target_session: for labels, batch in grouped.items(): label_set = set(labels) label_set.add(PROVIDER_RESOURCE_LABEL) @@ -134,10 +127,10 @@ def sync_nodes( }, ) - total_synced += len(rows) - logger.info( - f"Synced {total_synced} nodes from {source_database} to {target_database}" - ) + total_synced += batch_count + logger.info( + f"Synced {total_synced} nodes from {source_database} to {target_database}" + ) return total_synced @@ -151,41 +144,32 @@ def sync_relationships( Sync relationships from source to target database. Adds `_provider_id` property to all relationships. + + Source and target sessions are opened sequentially per batch to avoid + holding two Bolt connections simultaneously for the entire sync duration. """ last_id = -1 total_synced = 0 - with ( - graph_database.get_session(source_database) as source_session, - graph_database.get_session(target_database) as target_session, - ): - while True: - rows = list( - source_session.run( - RELATIONSHIPS_FETCH_QUERY, - {"last_id": last_id, "batch_size": BATCH_SIZE}, - ) + while True: + grouped: dict[str, list[dict[str, Any]]] = defaultdict(list) + batch_count = 0 + + with graph_database.get_session(source_database) as source_session: + result = source_session.run( + RELATIONSHIPS_FETCH_QUERY, + {"last_id": last_id, "batch_size": SYNC_BATCH_SIZE}, ) + for record in result: + batch_count += 1 + last_id = record["internal_id"] + key, value = _rel_to_sync_dict(record, provider_id) + grouped[key].append(value) - if not rows: - break - - last_id = rows[-1]["internal_id"] - - grouped: dict[str, list[dict[str, Any]]] = defaultdict(list) - for row in rows: - props = dict(row["props"] or {}) - _strip_internal_properties(props) - rel_type = row["rel_type"] - grouped[rel_type].append( - { - "start_element_id": f"{provider_id}:{row['start_element_id']}", - "end_element_id": f"{provider_id}:{row['end_element_id']}", - "provider_element_id": f"{provider_id}:{rel_type}:{row['internal_id']}", - "props": props, - } - ) + if batch_count == 0: + break + with graph_database.get_session(target_database) as target_session: for rel_type, batch in grouped.items(): query = render_cypher_template( RELATIONSHIP_SYNC_TEMPLATE, {"__REL_TYPE__": rel_type} @@ -198,14 +182,42 @@ def sync_relationships( }, ) - total_synced += len(rows) - logger.info( - f"Synced {total_synced} relationships from {source_database} to {target_database}" - ) + total_synced += batch_count + logger.info( + f"Synced {total_synced} relationships from {source_database} to {target_database}" + ) return total_synced +def _node_to_sync_dict( + record: neo4j.Record, provider_id: str +) -> tuple[tuple[str, ...], dict[str, Any]]: + """Transform a source node record into a (grouping_key, sync_dict) pair.""" + props = dict(record["props"] or {}) + _strip_internal_properties(props) + labels = tuple(sorted(set(record["labels"] or []))) + return labels, { + "provider_element_id": f"{provider_id}:{record['element_id']}", + "props": props, + } + + +def _rel_to_sync_dict( + record: neo4j.Record, provider_id: str +) -> tuple[str, dict[str, Any]]: + """Transform a source relationship record into a (grouping_key, sync_dict) pair.""" + props = dict(record["props"] or {}) + _strip_internal_properties(props) + rel_type = record["rel_type"] + return rel_type, { + "start_element_id": f"{provider_id}:{record['start_element_id']}", + "end_element_id": f"{provider_id}:{record['end_element_id']}", + "provider_element_id": f"{provider_id}:{rel_type}:{record['internal_id']}", + "props": props, + } + + def _strip_internal_properties(props: dict[str, Any]) -> None: """Remove provider isolation properties before the += spread in sync templates.""" for key in PROVIDER_ISOLATION_PROPERTIES: diff --git a/api/src/backend/tasks/tests/test_attack_paths_scan.py b/api/src/backend/tasks/tests/test_attack_paths_scan.py index 43a2420321..dc621f173b 100644 --- a/api/src/backend/tasks/tests/test_attack_paths_scan.py +++ b/api/src/backend/tasks/tests/test_attack_paths_scan.py @@ -1279,16 +1279,10 @@ class TestAttackPathsFindingsHelpers: provider.provider = Provider.ProviderChoices.AWS provider.save() - # Create mock Finding objects with to_dict() method - mock_finding_1 = MagicMock() - mock_finding_1.to_dict.return_value = {"id": "1", "resource_uid": "r-1"} - mock_finding_2 = MagicMock() - mock_finding_2.to_dict.return_value = {"id": "2", "resource_uid": "r-2"} - - # Create a generator that yields two batches of Finding instances + # Create a generator that yields two batches of dicts (pre-converted) def findings_generator(): - yield [mock_finding_1] - yield [mock_finding_2] + yield [{"id": "1", "resource_uid": "r-1"}] + yield [{"id": "2", "resource_uid": "r-2"}] config = SimpleNamespace(update_tag=12345) mock_session = MagicMock() @@ -1435,17 +1429,17 @@ class TestAttackPathsFindingsHelpers: assert len(findings_data) == 1 finding_result = findings_data[0] - assert finding_result.id == str(finding.id) - assert finding_result.resource_uid == resource.uid - assert finding_result.check_title == "Check title" - assert finding_result.scan_id == str(latest_scan.id) + assert finding_result["id"] == str(finding.id) + assert finding_result["resource_uid"] == resource.uid + assert finding_result["check_title"] == "Check title" + assert finding_result["scan_id"] == str(latest_scan.id) def test_enrich_batch_with_resources_single_resource( self, tenants_fixture, providers_fixture, ): - """One finding + one resource = one output Finding instance""" + """One finding + one resource = one output dict""" tenant = tenants_fixture[0] provider = providers_fixture[0] provider.provider = Provider.ProviderChoices.AWS @@ -1519,16 +1513,16 @@ class TestAttackPathsFindingsHelpers: ) assert len(result) == 1 - assert result[0].resource_uid == resource.uid - assert result[0].id == str(finding.id) - assert result[0].status == "FAIL" + assert result[0]["resource_uid"] == resource.uid + assert result[0]["id"] == str(finding.id) + assert result[0]["status"] == "FAIL" def test_enrich_batch_with_resources_multiple_resources( self, tenants_fixture, providers_fixture, ): - """One finding + three resources = three output Finding instances""" + """One finding + three resources = three output dicts""" tenant = tenants_fixture[0] provider = providers_fixture[0] provider.provider = Provider.ProviderChoices.AWS @@ -1607,13 +1601,13 @@ class TestAttackPathsFindingsHelpers: ) assert len(result) == 3 - result_resource_uids = {r.resource_uid for r in result} + result_resource_uids = {r["resource_uid"] for r in result} assert result_resource_uids == {r.uid for r in resources} # All should have same finding data for r in result: - assert r.id == str(finding.id) - assert r.status == "FAIL" + assert r["id"] == str(finding.id) + assert r["status"] == "FAIL" def test_enrich_batch_with_resources_no_resources_skips( self, @@ -1690,16 +1684,12 @@ class TestAttackPathsFindingsHelpers: provider.save() scan_id = "some-scan-id" - with ( - patch("tasks.jobs.attack_paths.findings.rls_transaction") as mock_rls, - patch("tasks.jobs.attack_paths.findings.Finding") as mock_finding, - ): + with patch("tasks.jobs.attack_paths.findings.rls_transaction") as mock_rls: # Create generator but don't iterate findings_module.stream_findings_with_resources(provider, scan_id) # Nothing should be called yet mock_rls.assert_not_called() - mock_finding.objects.filter.assert_not_called() def test_load_findings_empty_generator(self, providers_fixture): """Empty generator should not call neo4j""" @@ -1752,41 +1742,226 @@ class TestAddResourceLabel: assert "AWSResource" not in query.replace("_AWSResource", "") +def _make_session_ctx(session, call_order=None, name=None): + """Create a mock context manager wrapping a mock session.""" + ctx = MagicMock() + if call_order is not None and name is not None: + ctx.__enter__ = MagicMock( + side_effect=lambda: (call_order.append(f"{name}:enter"), session)[1] + ) + ctx.__exit__ = MagicMock( + side_effect=lambda *a: (call_order.append(f"{name}:exit"), False)[1] + ) + else: + ctx.__enter__ = MagicMock(return_value=session) + ctx.__exit__ = MagicMock(return_value=False) + return ctx + + class TestSyncNodes: def test_sync_nodes_adds_private_label(self): - mock_source_session = MagicMock() - mock_target_session = MagicMock() - row = { "internal_id": 1, "element_id": "elem-1", "labels": ["SomeLabel"], "props": {"key": "value"}, } - mock_source_session.run.side_effect = [[row], []] - source_ctx = MagicMock() - source_ctx.__enter__ = MagicMock(return_value=mock_source_session) - source_ctx.__exit__ = MagicMock(return_value=False) - - target_ctx = MagicMock() - target_ctx.__enter__ = MagicMock(return_value=mock_target_session) - target_ctx.__exit__ = MagicMock(return_value=False) + mock_source_1 = MagicMock() + mock_source_1.run.return_value = [row] + mock_target = MagicMock() + mock_source_2 = MagicMock() + mock_source_2.run.return_value = [] with patch( "tasks.jobs.attack_paths.sync.graph_database.get_session", - side_effect=[source_ctx, target_ctx], + side_effect=[ + _make_session_ctx(mock_source_1), + _make_session_ctx(mock_target), + _make_session_ctx(mock_source_2), + ], ): total = sync_module.sync_nodes( "source-db", "target-db", "tenant-1", "prov-1" ) assert total == 1 - query = mock_target_session.run.call_args.args[0] + query = mock_target.run.call_args.args[0] assert "_ProviderResource" in query assert "_Tenant_tenant1" in query assert "_Provider_prov1" in query + def test_sync_nodes_source_closes_before_target_opens(self): + row = { + "internal_id": 1, + "element_id": "elem-1", + "labels": ["SomeLabel"], + "props": {"key": "value"}, + } + + call_order = [] + + src_1 = MagicMock() + src_1.run.return_value = [row] + tgt = MagicMock() + src_2 = MagicMock() + src_2.run.return_value = [] + + with patch( + "tasks.jobs.attack_paths.sync.graph_database.get_session", + side_effect=[ + _make_session_ctx(src_1, call_order, "source1"), + _make_session_ctx(tgt, call_order, "target"), + _make_session_ctx(src_2, call_order, "source2"), + ], + ): + sync_module.sync_nodes("src-db", "tgt-db", "t-1", "p-1") + + assert call_order.index("source1:exit") < call_order.index("target:enter") + + def test_sync_nodes_pagination_with_batch_size_1(self): + row_a = { + "internal_id": 1, + "element_id": "elem-1", + "labels": ["LabelA"], + "props": {"a": 1}, + } + row_b = { + "internal_id": 2, + "element_id": "elem-2", + "labels": ["LabelB"], + "props": {"b": 2}, + } + + src_1 = MagicMock() + src_1.run.return_value = [row_a] + src_2 = MagicMock() + src_2.run.return_value = [row_b] + src_3 = MagicMock() + src_3.run.return_value = [] + tgt_1 = MagicMock() + tgt_2 = MagicMock() + + with ( + patch( + "tasks.jobs.attack_paths.sync.graph_database.get_session", + side_effect=[ + _make_session_ctx(src_1), + _make_session_ctx(tgt_1), + _make_session_ctx(src_2), + _make_session_ctx(tgt_2), + _make_session_ctx(src_3), + ], + ), + patch("tasks.jobs.attack_paths.sync.SYNC_BATCH_SIZE", 1), + ): + total = sync_module.sync_nodes("src", "tgt", "t-1", "p-1") + + assert total == 2 + assert src_1.run.call_args.args[1]["last_id"] == -1 + assert src_2.run.call_args.args[1]["last_id"] == 1 + + def test_sync_nodes_empty_source_returns_zero(self): + src = MagicMock() + src.run.return_value = [] + + with patch( + "tasks.jobs.attack_paths.sync.graph_database.get_session", + side_effect=[_make_session_ctx(src)], + ) as mock_get_session: + total = sync_module.sync_nodes("src", "tgt", "t-1", "p-1") + + assert total == 0 + assert mock_get_session.call_count == 1 + + +class TestSyncRelationships: + def test_sync_relationships_source_closes_before_target_opens(self): + row = { + "internal_id": 1, + "rel_type": "HAS", + "start_element_id": "s-1", + "end_element_id": "e-1", + "props": {}, + } + + call_order = [] + + src_1 = MagicMock() + src_1.run.return_value = [row] + tgt = MagicMock() + src_2 = MagicMock() + src_2.run.return_value = [] + + with patch( + "tasks.jobs.attack_paths.sync.graph_database.get_session", + side_effect=[ + _make_session_ctx(src_1, call_order, "source1"), + _make_session_ctx(tgt, call_order, "target"), + _make_session_ctx(src_2, call_order, "source2"), + ], + ): + sync_module.sync_relationships("src", "tgt", "p-1") + + assert call_order.index("source1:exit") < call_order.index("target:enter") + + def test_sync_relationships_pagination_with_batch_size_1(self): + row_a = { + "internal_id": 1, + "rel_type": "HAS", + "start_element_id": "s-1", + "end_element_id": "e-1", + "props": {"a": 1}, + } + row_b = { + "internal_id": 2, + "rel_type": "CONNECTS", + "start_element_id": "s-2", + "end_element_id": "e-2", + "props": {"b": 2}, + } + + src_1 = MagicMock() + src_1.run.return_value = [row_a] + src_2 = MagicMock() + src_2.run.return_value = [row_b] + src_3 = MagicMock() + src_3.run.return_value = [] + tgt_1 = MagicMock() + tgt_2 = MagicMock() + + with ( + patch( + "tasks.jobs.attack_paths.sync.graph_database.get_session", + side_effect=[ + _make_session_ctx(src_1), + _make_session_ctx(tgt_1), + _make_session_ctx(src_2), + _make_session_ctx(tgt_2), + _make_session_ctx(src_3), + ], + ), + patch("tasks.jobs.attack_paths.sync.SYNC_BATCH_SIZE", 1), + ): + total = sync_module.sync_relationships("src", "tgt", "p-1") + + assert total == 2 + assert src_1.run.call_args.args[1]["last_id"] == -1 + assert src_2.run.call_args.args[1]["last_id"] == 1 + + def test_sync_relationships_empty_source_returns_zero(self): + src = MagicMock() + src.run.return_value = [] + + with patch( + "tasks.jobs.attack_paths.sync.graph_database.get_session", + side_effect=[_make_session_ctx(src)], + ) as mock_get_session: + total = sync_module.sync_relationships("src", "tgt", "p-1") + + assert total == 0 + assert mock_get_session.call_count == 1 + class TestInternetAnalysis: def _make_provider_and_config(self):