fix(api): cap attack paths sink sync batches (#11724)

This commit is contained in:
Josema Camacho
2026-06-29 17:48:02 +02:00
committed by GitHub
parent 5404863a3e
commit 4e7e2f7eab
5 changed files with 168 additions and 11 deletions
+5 -1
View File
@@ -101,7 +101,11 @@ def post_fork(_server, worker):
try:
graph_database.close_driver()
except Exception: # pragma: no cover - best-effort cleanup
pass
gunicorn_logger.debug(
"Failed to close inherited Neo4j driver in post_fork for worker pid=%s",
worker.pid,
exc_info=True,
)
graph_database.init_driver()
gunicorn_logger.info(f"Attack-paths drivers initialized for worker {worker.pid}")
@@ -271,6 +271,24 @@ AWS_NORMALIZED_LISTS: list[NormalizedList] = [
"LaunchTemplateVersionSecurityGroupsItem",
"HAS_SECURITY_GROUPS",
),
NormalizedList(
"AWSVpcEndpoint",
"route_table_ids",
"AWSVpcEndpointRouteTableIdsItem",
"HAS_ROUTE_TABLE_IDS",
),
NormalizedList(
"AWSVpcEndpoint",
"network_interface_ids",
"AWSVpcEndpointNetworkInterfaceIdsItem",
"HAS_NETWORK_INTERFACE_IDS",
),
NormalizedList(
"AWSVpcEndpoint",
"subnet_ids",
"AWSVpcEndpointSubnetIdsItem",
"HAS_SUBNET_IDS",
),
NormalizedList(
"ELBListener", "policy_names", "ELBListenerPolicyNamesItem", "HAS_POLICY_NAMES"
),
+28 -10
View File
@@ -18,6 +18,7 @@ added to the catalog.
import json
import time
from collections import defaultdict
from collections.abc import Iterator
from typing import Any
import neo4j
@@ -153,20 +154,21 @@ def sync_nodes(
break
for labels, batch in parent_groups.items():
sink.write_nodes(
target_database, _render_labels(labels, extra_labels), batch
)
rendered_labels = _render_labels(labels, extra_labels)
for sink_batch in _iter_sink_batches(batch):
sink.write_nodes(target_database, rendered_labels, sink_batch)
for child_label, batch in child_groups.items():
sink.write_nodes(
target_database,
_render_labels((child_label,), extra_labels),
batch,
)
rendered_labels = _render_labels((child_label,), extra_labels)
for sink_batch in _iter_sink_batches(batch):
sink.write_nodes(target_database, rendered_labels, sink_batch)
children_synced += len(batch)
for rel_type, batch in rel_groups.items():
sink.write_relationships(target_database, rel_type, provider_id, batch)
for sink_batch in _iter_sink_batches(batch):
sink.write_relationships(
target_database, rel_type, provider_id, sink_batch
)
parent_child_rels += len(batch)
parents_synced += batch_count
@@ -226,7 +228,10 @@ def sync_relationships(
break
for rel_type, batch in grouped.items():
sink.write_relationships(target_database, rel_type, provider_id, batch)
for sink_batch in _iter_sink_batches(batch):
sink.write_relationships(
target_database, rel_type, provider_id, sink_batch
)
total_synced += batch_count
batch_dt = time.perf_counter() - tb
@@ -239,6 +244,19 @@ def sync_relationships(
return total_synced
def _iter_sink_batches(
rows: list[dict[str, Any]],
batch_size: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
"""Yield final sink write batches after source rows have been transformed."""
batch_size = SYNC_BATCH_SIZE if batch_size is None else batch_size
if batch_size <= 0:
raise ValueError("Sink batch size must be greater than zero")
for index in range(0, len(rows), batch_size):
yield rows[index : index + batch_size]
def _node_to_sync_dict(
record: neo4j.Record,
provider_id: str,
@@ -0,0 +1,30 @@
from tasks.jobs.attack_paths.provider_config import AWS_NORMALIZED_LISTS
from tasks.jobs.attack_paths.sync import _build_catalog_index, _node_to_sync_dict
def test_aws_vpc_endpoint_id_lists_are_normalized():
catalog = _build_catalog_index(AWS_NORMALIZED_LISTS)
record = {
"element_id": "node-1",
"labels": ["AWSVpcEndpoint"],
"props": {
"id": "vpce-123",
"route_table_ids": ["rtb-1"],
"network_interface_ids": ["eni-1"],
"subnet_ids": ["subnet-1"],
},
}
_, parent, children, rels = _node_to_sync_dict(record, "provider-id", catalog)
assert parent["props"] == {"id": "vpce-123"}
assert {child["_child_label"] for child in children} == {
"AWSVpcEndpointRouteTableIdsItem",
"AWSVpcEndpointNetworkInterfaceIdsItem",
"AWSVpcEndpointSubnetIdsItem",
}
assert {rel["rel_type"] for rel in rels} == {
"HAS_ROUTE_TABLE_IDS",
"HAS_NETWORK_INTERFACE_IDS",
"HAS_SUBNET_IDS",
}
@@ -1835,6 +1835,12 @@ def _make_session_ctx(session, call_order=None, name=None):
class TestSyncNodes:
def test_iter_sink_batches_rejects_zero_batch_size(self):
with pytest.raises(
ValueError, match="Sink batch size must be greater than zero"
):
list(sync_module._iter_sink_batches([], batch_size=0))
def test_sync_nodes_passes_isolation_labels_to_sink(self):
row = {
"internal_id": 1,
@@ -1940,6 +1946,51 @@ class TestSyncNodes:
assert src_1.run.call_args.args[1]["last_id"] == -1
assert src_2.run.call_args.args[1]["last_id"] == 1
def test_sync_nodes_chunks_expanded_list_rows_before_sink_write(self):
row = {
"internal_id": 1,
"element_id": "elem-1",
"labels": ["SomeLabel"],
"props": {"values": ["a", "b", "c", "d", "e"]},
}
normalized_lists = [
sync_module.NormalizedList(
"SomeLabel",
"values",
"SomeLabelValuesItem",
"HAS_VALUES",
)
]
src_1 = MagicMock()
src_1.run.return_value = [row]
src_2 = MagicMock()
src_2.run.return_value = []
sink = MagicMock()
with (
patch(
"tasks.jobs.attack_paths.sync.graph_database.get_session",
side_effect=[
_make_session_ctx(src_1),
_make_session_ctx(src_2),
],
),
patch("tasks.jobs.attack_paths.sync.SYNC_BATCH_SIZE", 2),
):
result = sync_module.sync_nodes(
"src", "tgt", "t-1", "p-1", sink, normalized_lists
)
assert result == {"parents": 1, "children": 5, "parent_child_rels": 5}
assert [
len(call_args.args[2]) for call_args in sink.write_nodes.call_args_list[1:]
] == [2, 2, 1]
assert [
len(call_args.args[3])
for call_args in sink.write_relationships.call_args_list
] == [2, 2, 1]
def test_sync_nodes_empty_source_returns_zero(self):
src = MagicMock()
src.run.return_value = []
@@ -2030,6 +2081,42 @@ class TestSyncRelationships:
assert src_1.run.call_args.args[1]["last_id"] == -1
assert src_2.run.call_args.args[1]["last_id"] == 1
def test_sync_relationships_chunks_grouped_rows_before_sink_write(self):
rows = [
{
"internal_id": idx,
"rel_type": "HAS",
"start_element_id": f"s-{idx}",
"end_element_id": f"e-{idx}",
"props": {},
}
for idx in range(1, 6)
]
src_1 = MagicMock()
src_1.run.return_value = rows
src_2 = MagicMock()
src_2.run.return_value = []
sink = MagicMock()
with (
patch(
"tasks.jobs.attack_paths.sync.graph_database.get_session",
side_effect=[
_make_session_ctx(src_1),
_make_session_ctx(src_2),
],
),
patch("tasks.jobs.attack_paths.sync.SYNC_BATCH_SIZE", 2),
):
total = sync_module.sync_relationships("src", "tgt", "p-1", sink)
assert total == 5
assert [
len(call_args.args[3])
for call_args in sink.write_relationships.call_args_list
] == [2, 2, 1]
def test_sync_relationships_empty_source_returns_zero(self):
src = MagicMock()
src.run.return_value = []