mirror of
https://github.com/prowler-cloud/prowler.git
synced 2026-07-04 19:21:51 +00:00
fix(api): cap attack paths sink sync batches (#11724)
This commit is contained in:
@@ -101,7 +101,11 @@ def post_fork(_server, worker):
|
||||
try:
|
||||
graph_database.close_driver()
|
||||
except Exception: # pragma: no cover - best-effort cleanup
|
||||
pass
|
||||
gunicorn_logger.debug(
|
||||
"Failed to close inherited Neo4j driver in post_fork for worker pid=%s",
|
||||
worker.pid,
|
||||
exc_info=True,
|
||||
)
|
||||
graph_database.init_driver()
|
||||
gunicorn_logger.info(f"Attack-paths drivers initialized for worker {worker.pid}")
|
||||
|
||||
|
||||
@@ -271,6 +271,24 @@ AWS_NORMALIZED_LISTS: list[NormalizedList] = [
|
||||
"LaunchTemplateVersionSecurityGroupsItem",
|
||||
"HAS_SECURITY_GROUPS",
|
||||
),
|
||||
NormalizedList(
|
||||
"AWSVpcEndpoint",
|
||||
"route_table_ids",
|
||||
"AWSVpcEndpointRouteTableIdsItem",
|
||||
"HAS_ROUTE_TABLE_IDS",
|
||||
),
|
||||
NormalizedList(
|
||||
"AWSVpcEndpoint",
|
||||
"network_interface_ids",
|
||||
"AWSVpcEndpointNetworkInterfaceIdsItem",
|
||||
"HAS_NETWORK_INTERFACE_IDS",
|
||||
),
|
||||
NormalizedList(
|
||||
"AWSVpcEndpoint",
|
||||
"subnet_ids",
|
||||
"AWSVpcEndpointSubnetIdsItem",
|
||||
"HAS_SUBNET_IDS",
|
||||
),
|
||||
NormalizedList(
|
||||
"ELBListener", "policy_names", "ELBListenerPolicyNamesItem", "HAS_POLICY_NAMES"
|
||||
),
|
||||
|
||||
@@ -18,6 +18,7 @@ added to the catalog.
|
||||
import json
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
|
||||
import neo4j
|
||||
@@ -153,20 +154,21 @@ def sync_nodes(
|
||||
break
|
||||
|
||||
for labels, batch in parent_groups.items():
|
||||
sink.write_nodes(
|
||||
target_database, _render_labels(labels, extra_labels), batch
|
||||
)
|
||||
rendered_labels = _render_labels(labels, extra_labels)
|
||||
for sink_batch in _iter_sink_batches(batch):
|
||||
sink.write_nodes(target_database, rendered_labels, sink_batch)
|
||||
|
||||
for child_label, batch in child_groups.items():
|
||||
sink.write_nodes(
|
||||
target_database,
|
||||
_render_labels((child_label,), extra_labels),
|
||||
batch,
|
||||
)
|
||||
rendered_labels = _render_labels((child_label,), extra_labels)
|
||||
for sink_batch in _iter_sink_batches(batch):
|
||||
sink.write_nodes(target_database, rendered_labels, sink_batch)
|
||||
children_synced += len(batch)
|
||||
|
||||
for rel_type, batch in rel_groups.items():
|
||||
sink.write_relationships(target_database, rel_type, provider_id, batch)
|
||||
for sink_batch in _iter_sink_batches(batch):
|
||||
sink.write_relationships(
|
||||
target_database, rel_type, provider_id, sink_batch
|
||||
)
|
||||
parent_child_rels += len(batch)
|
||||
|
||||
parents_synced += batch_count
|
||||
@@ -226,7 +228,10 @@ def sync_relationships(
|
||||
break
|
||||
|
||||
for rel_type, batch in grouped.items():
|
||||
sink.write_relationships(target_database, rel_type, provider_id, batch)
|
||||
for sink_batch in _iter_sink_batches(batch):
|
||||
sink.write_relationships(
|
||||
target_database, rel_type, provider_id, sink_batch
|
||||
)
|
||||
|
||||
total_synced += batch_count
|
||||
batch_dt = time.perf_counter() - tb
|
||||
@@ -239,6 +244,19 @@ def sync_relationships(
|
||||
return total_synced
|
||||
|
||||
|
||||
def _iter_sink_batches(
|
||||
rows: list[dict[str, Any]],
|
||||
batch_size: int | None = None,
|
||||
) -> Iterator[list[dict[str, Any]]]:
|
||||
"""Yield final sink write batches after source rows have been transformed."""
|
||||
batch_size = SYNC_BATCH_SIZE if batch_size is None else batch_size
|
||||
if batch_size <= 0:
|
||||
raise ValueError("Sink batch size must be greater than zero")
|
||||
|
||||
for index in range(0, len(rows), batch_size):
|
||||
yield rows[index : index + batch_size]
|
||||
|
||||
|
||||
def _node_to_sync_dict(
|
||||
record: neo4j.Record,
|
||||
provider_id: str,
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
from tasks.jobs.attack_paths.provider_config import AWS_NORMALIZED_LISTS
|
||||
from tasks.jobs.attack_paths.sync import _build_catalog_index, _node_to_sync_dict
|
||||
|
||||
|
||||
def test_aws_vpc_endpoint_id_lists_are_normalized():
|
||||
catalog = _build_catalog_index(AWS_NORMALIZED_LISTS)
|
||||
record = {
|
||||
"element_id": "node-1",
|
||||
"labels": ["AWSVpcEndpoint"],
|
||||
"props": {
|
||||
"id": "vpce-123",
|
||||
"route_table_ids": ["rtb-1"],
|
||||
"network_interface_ids": ["eni-1"],
|
||||
"subnet_ids": ["subnet-1"],
|
||||
},
|
||||
}
|
||||
|
||||
_, parent, children, rels = _node_to_sync_dict(record, "provider-id", catalog)
|
||||
|
||||
assert parent["props"] == {"id": "vpce-123"}
|
||||
assert {child["_child_label"] for child in children} == {
|
||||
"AWSVpcEndpointRouteTableIdsItem",
|
||||
"AWSVpcEndpointNetworkInterfaceIdsItem",
|
||||
"AWSVpcEndpointSubnetIdsItem",
|
||||
}
|
||||
assert {rel["rel_type"] for rel in rels} == {
|
||||
"HAS_ROUTE_TABLE_IDS",
|
||||
"HAS_NETWORK_INTERFACE_IDS",
|
||||
"HAS_SUBNET_IDS",
|
||||
}
|
||||
@@ -1835,6 +1835,12 @@ def _make_session_ctx(session, call_order=None, name=None):
|
||||
|
||||
|
||||
class TestSyncNodes:
|
||||
def test_iter_sink_batches_rejects_zero_batch_size(self):
|
||||
with pytest.raises(
|
||||
ValueError, match="Sink batch size must be greater than zero"
|
||||
):
|
||||
list(sync_module._iter_sink_batches([], batch_size=0))
|
||||
|
||||
def test_sync_nodes_passes_isolation_labels_to_sink(self):
|
||||
row = {
|
||||
"internal_id": 1,
|
||||
@@ -1940,6 +1946,51 @@ class TestSyncNodes:
|
||||
assert src_1.run.call_args.args[1]["last_id"] == -1
|
||||
assert src_2.run.call_args.args[1]["last_id"] == 1
|
||||
|
||||
def test_sync_nodes_chunks_expanded_list_rows_before_sink_write(self):
|
||||
row = {
|
||||
"internal_id": 1,
|
||||
"element_id": "elem-1",
|
||||
"labels": ["SomeLabel"],
|
||||
"props": {"values": ["a", "b", "c", "d", "e"]},
|
||||
}
|
||||
normalized_lists = [
|
||||
sync_module.NormalizedList(
|
||||
"SomeLabel",
|
||||
"values",
|
||||
"SomeLabelValuesItem",
|
||||
"HAS_VALUES",
|
||||
)
|
||||
]
|
||||
|
||||
src_1 = MagicMock()
|
||||
src_1.run.return_value = [row]
|
||||
src_2 = MagicMock()
|
||||
src_2.run.return_value = []
|
||||
sink = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.sync.graph_database.get_session",
|
||||
side_effect=[
|
||||
_make_session_ctx(src_1),
|
||||
_make_session_ctx(src_2),
|
||||
],
|
||||
),
|
||||
patch("tasks.jobs.attack_paths.sync.SYNC_BATCH_SIZE", 2),
|
||||
):
|
||||
result = sync_module.sync_nodes(
|
||||
"src", "tgt", "t-1", "p-1", sink, normalized_lists
|
||||
)
|
||||
|
||||
assert result == {"parents": 1, "children": 5, "parent_child_rels": 5}
|
||||
assert [
|
||||
len(call_args.args[2]) for call_args in sink.write_nodes.call_args_list[1:]
|
||||
] == [2, 2, 1]
|
||||
assert [
|
||||
len(call_args.args[3])
|
||||
for call_args in sink.write_relationships.call_args_list
|
||||
] == [2, 2, 1]
|
||||
|
||||
def test_sync_nodes_empty_source_returns_zero(self):
|
||||
src = MagicMock()
|
||||
src.run.return_value = []
|
||||
@@ -2030,6 +2081,42 @@ class TestSyncRelationships:
|
||||
assert src_1.run.call_args.args[1]["last_id"] == -1
|
||||
assert src_2.run.call_args.args[1]["last_id"] == 1
|
||||
|
||||
def test_sync_relationships_chunks_grouped_rows_before_sink_write(self):
|
||||
rows = [
|
||||
{
|
||||
"internal_id": idx,
|
||||
"rel_type": "HAS",
|
||||
"start_element_id": f"s-{idx}",
|
||||
"end_element_id": f"e-{idx}",
|
||||
"props": {},
|
||||
}
|
||||
for idx in range(1, 6)
|
||||
]
|
||||
|
||||
src_1 = MagicMock()
|
||||
src_1.run.return_value = rows
|
||||
src_2 = MagicMock()
|
||||
src_2.run.return_value = []
|
||||
sink = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.sync.graph_database.get_session",
|
||||
side_effect=[
|
||||
_make_session_ctx(src_1),
|
||||
_make_session_ctx(src_2),
|
||||
],
|
||||
),
|
||||
patch("tasks.jobs.attack_paths.sync.SYNC_BATCH_SIZE", 2),
|
||||
):
|
||||
total = sync_module.sync_relationships("src", "tgt", "p-1", sink)
|
||||
|
||||
assert total == 5
|
||||
assert [
|
||||
len(call_args.args[3])
|
||||
for call_args in sink.write_relationships.call_args_list
|
||||
] == [2, 2, 1]
|
||||
|
||||
def test_sync_relationships_empty_source_returns_zero(self):
|
||||
src = MagicMock()
|
||||
src.run.return_value = []
|
||||
|
||||
Reference in New Issue
Block a user