diff --git a/api/src/backend/config/guniconf.py b/api/src/backend/config/guniconf.py index 041eed3645..8b17ad669d 100644 --- a/api/src/backend/config/guniconf.py +++ b/api/src/backend/config/guniconf.py @@ -101,7 +101,11 @@ def post_fork(_server, worker): try: graph_database.close_driver() except Exception: # pragma: no cover - best-effort cleanup - pass + gunicorn_logger.debug( + "Failed to close inherited Neo4j driver in post_fork for worker pid=%s", + worker.pid, + exc_info=True, + ) graph_database.init_driver() gunicorn_logger.info(f"Attack-paths drivers initialized for worker {worker.pid}") diff --git a/api/src/backend/tasks/jobs/attack_paths/provider_config.py b/api/src/backend/tasks/jobs/attack_paths/provider_config.py index a5c9d12369..7d834e6aff 100644 --- a/api/src/backend/tasks/jobs/attack_paths/provider_config.py +++ b/api/src/backend/tasks/jobs/attack_paths/provider_config.py @@ -271,6 +271,24 @@ AWS_NORMALIZED_LISTS: list[NormalizedList] = [ "LaunchTemplateVersionSecurityGroupsItem", "HAS_SECURITY_GROUPS", ), + NormalizedList( + "AWSVpcEndpoint", + "route_table_ids", + "AWSVpcEndpointRouteTableIdsItem", + "HAS_ROUTE_TABLE_IDS", + ), + NormalizedList( + "AWSVpcEndpoint", + "network_interface_ids", + "AWSVpcEndpointNetworkInterfaceIdsItem", + "HAS_NETWORK_INTERFACE_IDS", + ), + NormalizedList( + "AWSVpcEndpoint", + "subnet_ids", + "AWSVpcEndpointSubnetIdsItem", + "HAS_SUBNET_IDS", + ), NormalizedList( "ELBListener", "policy_names", "ELBListenerPolicyNamesItem", "HAS_POLICY_NAMES" ), diff --git a/api/src/backend/tasks/jobs/attack_paths/sync.py b/api/src/backend/tasks/jobs/attack_paths/sync.py index adb8ce9b9c..7b73fa21e2 100644 --- a/api/src/backend/tasks/jobs/attack_paths/sync.py +++ b/api/src/backend/tasks/jobs/attack_paths/sync.py @@ -18,6 +18,7 @@ added to the catalog. import json import time from collections import defaultdict +from collections.abc import Iterator from typing import Any import neo4j @@ -153,20 +154,21 @@ def sync_nodes( break for labels, batch in parent_groups.items(): - sink.write_nodes( - target_database, _render_labels(labels, extra_labels), batch - ) + rendered_labels = _render_labels(labels, extra_labels) + for sink_batch in _iter_sink_batches(batch): + sink.write_nodes(target_database, rendered_labels, sink_batch) for child_label, batch in child_groups.items(): - sink.write_nodes( - target_database, - _render_labels((child_label,), extra_labels), - batch, - ) + rendered_labels = _render_labels((child_label,), extra_labels) + for sink_batch in _iter_sink_batches(batch): + sink.write_nodes(target_database, rendered_labels, sink_batch) children_synced += len(batch) for rel_type, batch in rel_groups.items(): - sink.write_relationships(target_database, rel_type, provider_id, batch) + for sink_batch in _iter_sink_batches(batch): + sink.write_relationships( + target_database, rel_type, provider_id, sink_batch + ) parent_child_rels += len(batch) parents_synced += batch_count @@ -226,7 +228,10 @@ def sync_relationships( break for rel_type, batch in grouped.items(): - sink.write_relationships(target_database, rel_type, provider_id, batch) + for sink_batch in _iter_sink_batches(batch): + sink.write_relationships( + target_database, rel_type, provider_id, sink_batch + ) total_synced += batch_count batch_dt = time.perf_counter() - tb @@ -239,6 +244,19 @@ def sync_relationships( return total_synced +def _iter_sink_batches( + rows: list[dict[str, Any]], + batch_size: int | None = None, +) -> Iterator[list[dict[str, Any]]]: + """Yield final sink write batches after source rows have been transformed.""" + batch_size = SYNC_BATCH_SIZE if batch_size is None else batch_size + if batch_size <= 0: + raise ValueError("Sink batch size must be greater than zero") + + for index in range(0, len(rows), batch_size): + yield rows[index : index + batch_size] + + def _node_to_sync_dict( record: neo4j.Record, provider_id: str, diff --git a/api/src/backend/tasks/tests/test_attack_paths_provider_config.py b/api/src/backend/tasks/tests/test_attack_paths_provider_config.py new file mode 100644 index 0000000000..41ec2847d4 --- /dev/null +++ b/api/src/backend/tasks/tests/test_attack_paths_provider_config.py @@ -0,0 +1,30 @@ +from tasks.jobs.attack_paths.provider_config import AWS_NORMALIZED_LISTS +from tasks.jobs.attack_paths.sync import _build_catalog_index, _node_to_sync_dict + + +def test_aws_vpc_endpoint_id_lists_are_normalized(): + catalog = _build_catalog_index(AWS_NORMALIZED_LISTS) + record = { + "element_id": "node-1", + "labels": ["AWSVpcEndpoint"], + "props": { + "id": "vpce-123", + "route_table_ids": ["rtb-1"], + "network_interface_ids": ["eni-1"], + "subnet_ids": ["subnet-1"], + }, + } + + _, parent, children, rels = _node_to_sync_dict(record, "provider-id", catalog) + + assert parent["props"] == {"id": "vpce-123"} + assert {child["_child_label"] for child in children} == { + "AWSVpcEndpointRouteTableIdsItem", + "AWSVpcEndpointNetworkInterfaceIdsItem", + "AWSVpcEndpointSubnetIdsItem", + } + assert {rel["rel_type"] for rel in rels} == { + "HAS_ROUTE_TABLE_IDS", + "HAS_NETWORK_INTERFACE_IDS", + "HAS_SUBNET_IDS", + } diff --git a/api/src/backend/tasks/tests/test_attack_paths_scan.py b/api/src/backend/tasks/tests/test_attack_paths_scan.py index 4768d243de..ab77a5af67 100644 --- a/api/src/backend/tasks/tests/test_attack_paths_scan.py +++ b/api/src/backend/tasks/tests/test_attack_paths_scan.py @@ -1835,6 +1835,12 @@ def _make_session_ctx(session, call_order=None, name=None): class TestSyncNodes: + def test_iter_sink_batches_rejects_zero_batch_size(self): + with pytest.raises( + ValueError, match="Sink batch size must be greater than zero" + ): + list(sync_module._iter_sink_batches([], batch_size=0)) + def test_sync_nodes_passes_isolation_labels_to_sink(self): row = { "internal_id": 1, @@ -1940,6 +1946,51 @@ class TestSyncNodes: assert src_1.run.call_args.args[1]["last_id"] == -1 assert src_2.run.call_args.args[1]["last_id"] == 1 + def test_sync_nodes_chunks_expanded_list_rows_before_sink_write(self): + row = { + "internal_id": 1, + "element_id": "elem-1", + "labels": ["SomeLabel"], + "props": {"values": ["a", "b", "c", "d", "e"]}, + } + normalized_lists = [ + sync_module.NormalizedList( + "SomeLabel", + "values", + "SomeLabelValuesItem", + "HAS_VALUES", + ) + ] + + src_1 = MagicMock() + src_1.run.return_value = [row] + src_2 = MagicMock() + src_2.run.return_value = [] + sink = MagicMock() + + with ( + patch( + "tasks.jobs.attack_paths.sync.graph_database.get_session", + side_effect=[ + _make_session_ctx(src_1), + _make_session_ctx(src_2), + ], + ), + patch("tasks.jobs.attack_paths.sync.SYNC_BATCH_SIZE", 2), + ): + result = sync_module.sync_nodes( + "src", "tgt", "t-1", "p-1", sink, normalized_lists + ) + + assert result == {"parents": 1, "children": 5, "parent_child_rels": 5} + assert [ + len(call_args.args[2]) for call_args in sink.write_nodes.call_args_list[1:] + ] == [2, 2, 1] + assert [ + len(call_args.args[3]) + for call_args in sink.write_relationships.call_args_list + ] == [2, 2, 1] + def test_sync_nodes_empty_source_returns_zero(self): src = MagicMock() src.run.return_value = [] @@ -2030,6 +2081,42 @@ class TestSyncRelationships: assert src_1.run.call_args.args[1]["last_id"] == -1 assert src_2.run.call_args.args[1]["last_id"] == 1 + def test_sync_relationships_chunks_grouped_rows_before_sink_write(self): + rows = [ + { + "internal_id": idx, + "rel_type": "HAS", + "start_element_id": f"s-{idx}", + "end_element_id": f"e-{idx}", + "props": {}, + } + for idx in range(1, 6) + ] + + src_1 = MagicMock() + src_1.run.return_value = rows + src_2 = MagicMock() + src_2.run.return_value = [] + sink = MagicMock() + + with ( + patch( + "tasks.jobs.attack_paths.sync.graph_database.get_session", + side_effect=[ + _make_session_ctx(src_1), + _make_session_ctx(src_2), + ], + ), + patch("tasks.jobs.attack_paths.sync.SYNC_BATCH_SIZE", 2), + ): + total = sync_module.sync_relationships("src", "tgt", "p-1", sink) + + assert total == 5 + assert [ + len(call_args.args[3]) + for call_args in sink.write_relationships.call_args_list + ] == [2, 2, 1] + def test_sync_relationships_empty_source_returns_zero(self): src = MagicMock() src.run.return_value = []