mirror of
https://github.com/prowler-cloud/prowler.git
synced 2026-03-21 18:58:04 +00:00
perf(attack-paths): reduce sync and findings memory usage with smaller batches and cursor iteration (#10359)
This commit is contained in:
@@ -12,6 +12,7 @@ All notable changes to the **Prowler API** are documented in this file.
|
||||
|
||||
- Attack Paths: Complete migration to private graph labels and properties, removing deprecated dual-write support [(#10268)](https://github.com/prowler-cloud/prowler/pull/10268)
|
||||
- Attack Paths: Added tenant and provider related labels to the nodes so they can be easily filtered on custom queries [(#10308)](https://github.com/prowler-cloud/prowler/pull/10308)
|
||||
- Attack Paths: Reduce sync and findings memory usage with smaller batches, cursor iteration, and sequential sessions [(#10359)](https://github.com/prowler-cloud/prowler/pull/10359)
|
||||
|
||||
### 🐞 Fixed
|
||||
|
||||
|
||||
@@ -3,12 +3,14 @@ from typing import Callable
|
||||
from uuid import UUID
|
||||
|
||||
from config.env import env
|
||||
|
||||
from tasks.jobs.attack_paths import aws
|
||||
|
||||
|
||||
# Batch size for Neo4j operations
|
||||
# Batch size for Neo4j write operations (resource labeling, cleanup)
|
||||
BATCH_SIZE = env.int("ATTACK_PATHS_BATCH_SIZE", 1000)
|
||||
# Batch size for Postgres findings fetch (keyset pagination page size)
|
||||
FINDINGS_BATCH_SIZE = env.int("ATTACK_PATHS_FINDINGS_BATCH_SIZE", 500)
|
||||
# Batch size for temp-to-tenant graph sync (nodes and relationships per cursor page)
|
||||
SYNC_BATCH_SIZE = env.int("ATTACK_PATHS_SYNC_BATCH_SIZE", 250)
|
||||
|
||||
# Neo4j internal labels (Prowler-specific, not provider-specific)
|
||||
# - `Internet`: Singleton node representing external internet access for exposed-resource queries
|
||||
|
||||
@@ -9,22 +9,15 @@ This module handles:
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import asdict, dataclass, fields
|
||||
from typing import Any, Generator
|
||||
from uuid import UUID
|
||||
|
||||
import neo4j
|
||||
|
||||
from cartography.config import Config as CartographyConfig
|
||||
from celery.utils.log import get_task_logger
|
||||
|
||||
from api.db_router import READ_REPLICA_ALIAS
|
||||
from api.db_utils import rls_transaction
|
||||
from api.models import Finding as FindingModel
|
||||
from api.models import Provider, ResourceFindingMapping
|
||||
from prowler.config import config as ProwlerConfig
|
||||
from tasks.jobs.attack_paths.config import (
|
||||
BATCH_SIZE,
|
||||
FINDINGS_BATCH_SIZE,
|
||||
get_node_uid_field,
|
||||
get_provider_resource_label,
|
||||
get_root_node_label,
|
||||
@@ -37,75 +30,54 @@ from tasks.jobs.attack_paths.queries import (
|
||||
render_cypher_template,
|
||||
)
|
||||
|
||||
from api.db_router import READ_REPLICA_ALIAS
|
||||
from api.db_utils import rls_transaction
|
||||
from api.models import Finding as FindingModel
|
||||
from api.models import Provider, ResourceFindingMapping
|
||||
from prowler.config import config as ProwlerConfig
|
||||
|
||||
logger = get_task_logger(__name__)
|
||||
|
||||
|
||||
# Type Definitions
|
||||
# -----------------
|
||||
|
||||
# Maps dataclass field names to Django ORM query field names
|
||||
_DB_FIELD_MAP: dict[str, str] = {
|
||||
"check_title": "check_metadata__checktitle",
|
||||
}
|
||||
# Django ORM field names for `.values()` queries
|
||||
# Most map 1:1 to Neo4j property names, exceptions are remapped in `_to_neo4j_dict`
|
||||
_DB_QUERY_FIELDS = [
|
||||
"id",
|
||||
"uid",
|
||||
"inserted_at",
|
||||
"updated_at",
|
||||
"first_seen_at",
|
||||
"scan_id",
|
||||
"delta",
|
||||
"status",
|
||||
"status_extended",
|
||||
"severity",
|
||||
"check_id",
|
||||
"check_metadata__checktitle",
|
||||
"muted",
|
||||
"muted_reason",
|
||||
]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class Finding:
|
||||
"""
|
||||
Finding data for Neo4j ingestion.
|
||||
|
||||
Can be created from a Django .values() query result using from_db_record().
|
||||
"""
|
||||
|
||||
id: str
|
||||
uid: str
|
||||
inserted_at: str
|
||||
updated_at: str
|
||||
first_seen_at: str
|
||||
scan_id: str
|
||||
delta: str
|
||||
status: str
|
||||
status_extended: str
|
||||
severity: str
|
||||
check_id: str
|
||||
check_title: str
|
||||
muted: bool
|
||||
muted_reason: str | None
|
||||
resource_uid: str | None = None
|
||||
|
||||
@classmethod
|
||||
def get_db_query_fields(cls) -> tuple[str, ...]:
|
||||
"""Get field names for Django .values() query."""
|
||||
return tuple(
|
||||
_DB_FIELD_MAP.get(f.name, f.name)
|
||||
for f in fields(cls)
|
||||
if f.name != "resource_uid"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_db_record(cls, record: dict[str, Any], resource_uid: str) -> "Finding":
|
||||
"""Create a Finding from a Django .values() query result."""
|
||||
return cls(
|
||||
id=str(record["id"]),
|
||||
uid=record["uid"],
|
||||
inserted_at=record["inserted_at"],
|
||||
updated_at=record["updated_at"],
|
||||
first_seen_at=record["first_seen_at"],
|
||||
scan_id=str(record["scan_id"]),
|
||||
delta=record["delta"],
|
||||
status=record["status"],
|
||||
status_extended=record["status_extended"],
|
||||
severity=record["severity"],
|
||||
check_id=str(record["check_id"]),
|
||||
check_title=record["check_metadata__checktitle"],
|
||||
muted=record["muted"],
|
||||
muted_reason=record["muted_reason"],
|
||||
resource_uid=resource_uid,
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dict for Neo4j ingestion."""
|
||||
return asdict(self)
|
||||
def _to_neo4j_dict(record: dict[str, Any], resource_uid: str) -> dict[str, Any]:
|
||||
"""Transform a Django `.values()` record into a `dict` ready for Neo4j ingestion."""
|
||||
return {
|
||||
"id": str(record["id"]),
|
||||
"uid": record["uid"],
|
||||
"inserted_at": record["inserted_at"],
|
||||
"updated_at": record["updated_at"],
|
||||
"first_seen_at": record["first_seen_at"],
|
||||
"scan_id": str(record["scan_id"]),
|
||||
"delta": record["delta"],
|
||||
"status": record["status"],
|
||||
"status_extended": record["status_extended"],
|
||||
"severity": record["severity"],
|
||||
"check_id": str(record["check_id"]),
|
||||
"check_title": record["check_metadata__checktitle"],
|
||||
"muted": record["muted"],
|
||||
"muted_reason": record["muted_reason"],
|
||||
"resource_uid": resource_uid,
|
||||
}
|
||||
|
||||
|
||||
# Public API
|
||||
@@ -180,7 +152,7 @@ def add_resource_label(
|
||||
|
||||
def load_findings(
|
||||
neo4j_session: neo4j.Session,
|
||||
findings_batches: Generator[list[Finding], None, None],
|
||||
findings_batches: Generator[list[dict[str, Any]], None, None],
|
||||
prowler_api_provider: Provider,
|
||||
config: CartographyConfig,
|
||||
) -> None:
|
||||
@@ -209,7 +181,7 @@ def load_findings(
|
||||
batch_size = len(batch)
|
||||
total_records += batch_size
|
||||
|
||||
parameters["findings_data"] = [f.to_dict() for f in batch]
|
||||
parameters["findings_data"] = batch
|
||||
|
||||
logger.info(f"Loading findings batch {batch_num} ({batch_size} records)")
|
||||
neo4j_session.run(query, parameters)
|
||||
@@ -247,16 +219,17 @@ def cleanup_findings(
|
||||
def stream_findings_with_resources(
|
||||
prowler_api_provider: Provider,
|
||||
scan_id: str,
|
||||
) -> Generator[list[Finding], None, None]:
|
||||
) -> Generator[list[dict[str, Any]], None, None]:
|
||||
"""
|
||||
Stream findings with their associated resources in batches.
|
||||
|
||||
Uses keyset pagination for efficient traversal of large datasets.
|
||||
Memory efficient: yields one batch at a time, never holds all findings in memory.
|
||||
Memory efficient: yields one batch at a time as dicts ready for Neo4j ingestion,
|
||||
never holds all findings in memory.
|
||||
"""
|
||||
logger.info(
|
||||
f"Starting findings stream for scan {scan_id} "
|
||||
f"(tenant {prowler_api_provider.tenant_id}) with batch size {BATCH_SIZE}"
|
||||
f"(tenant {prowler_api_provider.tenant_id}) with batch size {FINDINGS_BATCH_SIZE}"
|
||||
)
|
||||
|
||||
tenant_id = prowler_api_provider.tenant_id
|
||||
@@ -305,15 +278,14 @@ def _fetch_findings_batch(
|
||||
Uses read replica and RLS-scoped transaction.
|
||||
"""
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
# Use all_objects to avoid the ActiveProviderManager's implicit JOIN
|
||||
# through Scan -> Provider (to check is_deleted=False).
|
||||
# The provider is already validated as active in this context.
|
||||
# Use `all_objects` to get `Findings` even on soft-deleted `Providers`
|
||||
# But even the provider is already validated as active in this context
|
||||
qs = FindingModel.all_objects.filter(scan_id=scan_id).order_by("id")
|
||||
|
||||
if after_id is not None:
|
||||
qs = qs.filter(id__gt=after_id)
|
||||
|
||||
return list(qs.values(*Finding.get_db_query_fields())[:BATCH_SIZE])
|
||||
return list(qs.values(*_DB_QUERY_FIELDS)[:FINDINGS_BATCH_SIZE])
|
||||
|
||||
|
||||
# Batch Enrichment
|
||||
@@ -323,7 +295,7 @@ def _fetch_findings_batch(
|
||||
def _enrich_batch_with_resources(
|
||||
findings_batch: list[dict[str, Any]],
|
||||
tenant_id: str,
|
||||
) -> list[Finding]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Enrich findings with their resource UIDs.
|
||||
|
||||
@@ -334,7 +306,7 @@ def _enrich_batch_with_resources(
|
||||
resource_map = _build_finding_resource_map(finding_ids, tenant_id)
|
||||
|
||||
return [
|
||||
Finding.from_db_record(finding, resource_uid)
|
||||
_to_neo4j_dict(finding, resource_uid)
|
||||
for finding in findings_batch
|
||||
for resource_uid in resource_map.get(finding["id"], [])
|
||||
]
|
||||
|
||||
@@ -8,13 +8,14 @@ to the tenant database, adding provider isolation labels and properties.
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
import neo4j
|
||||
from celery.utils.log import get_task_logger
|
||||
|
||||
from api.attack_paths import database as graph_database
|
||||
from tasks.jobs.attack_paths.config import (
|
||||
BATCH_SIZE,
|
||||
PROVIDER_ISOLATION_PROPERTIES,
|
||||
PROVIDER_RESOURCE_LABEL,
|
||||
SYNC_BATCH_SIZE,
|
||||
get_provider_label,
|
||||
get_tenant_label,
|
||||
)
|
||||
@@ -82,40 +83,32 @@ def sync_nodes(
|
||||
|
||||
Adds `_ProviderResource` label and `_provider_id` property to all nodes.
|
||||
Also adds dynamic `_Tenant_{id}` and `_Provider_{id}` isolation labels.
|
||||
|
||||
Source and target sessions are opened sequentially per batch to avoid
|
||||
holding two Bolt connections simultaneously for the entire sync duration.
|
||||
"""
|
||||
last_id = -1
|
||||
total_synced = 0
|
||||
|
||||
with (
|
||||
graph_database.get_session(source_database) as source_session,
|
||||
graph_database.get_session(target_database) as target_session,
|
||||
):
|
||||
while True:
|
||||
rows = list(
|
||||
source_session.run(
|
||||
NODE_FETCH_QUERY,
|
||||
{"last_id": last_id, "batch_size": BATCH_SIZE},
|
||||
)
|
||||
while True:
|
||||
grouped: dict[tuple[str, ...], list[dict[str, Any]]] = defaultdict(list)
|
||||
batch_count = 0
|
||||
|
||||
with graph_database.get_session(source_database) as source_session:
|
||||
result = source_session.run(
|
||||
NODE_FETCH_QUERY,
|
||||
{"last_id": last_id, "batch_size": SYNC_BATCH_SIZE},
|
||||
)
|
||||
for record in result:
|
||||
batch_count += 1
|
||||
last_id = record["internal_id"]
|
||||
key, value = _node_to_sync_dict(record, provider_id)
|
||||
grouped[key].append(value)
|
||||
|
||||
if not rows:
|
||||
break
|
||||
|
||||
last_id = rows[-1]["internal_id"]
|
||||
|
||||
grouped: dict[tuple[str, ...], list[dict[str, Any]]] = defaultdict(list)
|
||||
for row in rows:
|
||||
labels = tuple(sorted(set(row["labels"] or [])))
|
||||
props = dict(row["props"] or {})
|
||||
_strip_internal_properties(props)
|
||||
provider_element_id = f"{provider_id}:{row['element_id']}"
|
||||
grouped[labels].append(
|
||||
{
|
||||
"provider_element_id": provider_element_id,
|
||||
"props": props,
|
||||
}
|
||||
)
|
||||
if batch_count == 0:
|
||||
break
|
||||
|
||||
with graph_database.get_session(target_database) as target_session:
|
||||
for labels, batch in grouped.items():
|
||||
label_set = set(labels)
|
||||
label_set.add(PROVIDER_RESOURCE_LABEL)
|
||||
@@ -134,10 +127,10 @@ def sync_nodes(
|
||||
},
|
||||
)
|
||||
|
||||
total_synced += len(rows)
|
||||
logger.info(
|
||||
f"Synced {total_synced} nodes from {source_database} to {target_database}"
|
||||
)
|
||||
total_synced += batch_count
|
||||
logger.info(
|
||||
f"Synced {total_synced} nodes from {source_database} to {target_database}"
|
||||
)
|
||||
|
||||
return total_synced
|
||||
|
||||
@@ -151,41 +144,32 @@ def sync_relationships(
|
||||
Sync relationships from source to target database.
|
||||
|
||||
Adds `_provider_id` property to all relationships.
|
||||
|
||||
Source and target sessions are opened sequentially per batch to avoid
|
||||
holding two Bolt connections simultaneously for the entire sync duration.
|
||||
"""
|
||||
last_id = -1
|
||||
total_synced = 0
|
||||
|
||||
with (
|
||||
graph_database.get_session(source_database) as source_session,
|
||||
graph_database.get_session(target_database) as target_session,
|
||||
):
|
||||
while True:
|
||||
rows = list(
|
||||
source_session.run(
|
||||
RELATIONSHIPS_FETCH_QUERY,
|
||||
{"last_id": last_id, "batch_size": BATCH_SIZE},
|
||||
)
|
||||
while True:
|
||||
grouped: dict[str, list[dict[str, Any]]] = defaultdict(list)
|
||||
batch_count = 0
|
||||
|
||||
with graph_database.get_session(source_database) as source_session:
|
||||
result = source_session.run(
|
||||
RELATIONSHIPS_FETCH_QUERY,
|
||||
{"last_id": last_id, "batch_size": SYNC_BATCH_SIZE},
|
||||
)
|
||||
for record in result:
|
||||
batch_count += 1
|
||||
last_id = record["internal_id"]
|
||||
key, value = _rel_to_sync_dict(record, provider_id)
|
||||
grouped[key].append(value)
|
||||
|
||||
if not rows:
|
||||
break
|
||||
|
||||
last_id = rows[-1]["internal_id"]
|
||||
|
||||
grouped: dict[str, list[dict[str, Any]]] = defaultdict(list)
|
||||
for row in rows:
|
||||
props = dict(row["props"] or {})
|
||||
_strip_internal_properties(props)
|
||||
rel_type = row["rel_type"]
|
||||
grouped[rel_type].append(
|
||||
{
|
||||
"start_element_id": f"{provider_id}:{row['start_element_id']}",
|
||||
"end_element_id": f"{provider_id}:{row['end_element_id']}",
|
||||
"provider_element_id": f"{provider_id}:{rel_type}:{row['internal_id']}",
|
||||
"props": props,
|
||||
}
|
||||
)
|
||||
if batch_count == 0:
|
||||
break
|
||||
|
||||
with graph_database.get_session(target_database) as target_session:
|
||||
for rel_type, batch in grouped.items():
|
||||
query = render_cypher_template(
|
||||
RELATIONSHIP_SYNC_TEMPLATE, {"__REL_TYPE__": rel_type}
|
||||
@@ -198,14 +182,42 @@ def sync_relationships(
|
||||
},
|
||||
)
|
||||
|
||||
total_synced += len(rows)
|
||||
logger.info(
|
||||
f"Synced {total_synced} relationships from {source_database} to {target_database}"
|
||||
)
|
||||
total_synced += batch_count
|
||||
logger.info(
|
||||
f"Synced {total_synced} relationships from {source_database} to {target_database}"
|
||||
)
|
||||
|
||||
return total_synced
|
||||
|
||||
|
||||
def _node_to_sync_dict(
|
||||
record: neo4j.Record, provider_id: str
|
||||
) -> tuple[tuple[str, ...], dict[str, Any]]:
|
||||
"""Transform a source node record into a (grouping_key, sync_dict) pair."""
|
||||
props = dict(record["props"] or {})
|
||||
_strip_internal_properties(props)
|
||||
labels = tuple(sorted(set(record["labels"] or [])))
|
||||
return labels, {
|
||||
"provider_element_id": f"{provider_id}:{record['element_id']}",
|
||||
"props": props,
|
||||
}
|
||||
|
||||
|
||||
def _rel_to_sync_dict(
|
||||
record: neo4j.Record, provider_id: str
|
||||
) -> tuple[str, dict[str, Any]]:
|
||||
"""Transform a source relationship record into a (grouping_key, sync_dict) pair."""
|
||||
props = dict(record["props"] or {})
|
||||
_strip_internal_properties(props)
|
||||
rel_type = record["rel_type"]
|
||||
return rel_type, {
|
||||
"start_element_id": f"{provider_id}:{record['start_element_id']}",
|
||||
"end_element_id": f"{provider_id}:{record['end_element_id']}",
|
||||
"provider_element_id": f"{provider_id}:{rel_type}:{record['internal_id']}",
|
||||
"props": props,
|
||||
}
|
||||
|
||||
|
||||
def _strip_internal_properties(props: dict[str, Any]) -> None:
|
||||
"""Remove provider isolation properties before the += spread in sync templates."""
|
||||
for key in PROVIDER_ISOLATION_PROPERTIES:
|
||||
|
||||
@@ -1279,16 +1279,10 @@ class TestAttackPathsFindingsHelpers:
|
||||
provider.provider = Provider.ProviderChoices.AWS
|
||||
provider.save()
|
||||
|
||||
# Create mock Finding objects with to_dict() method
|
||||
mock_finding_1 = MagicMock()
|
||||
mock_finding_1.to_dict.return_value = {"id": "1", "resource_uid": "r-1"}
|
||||
mock_finding_2 = MagicMock()
|
||||
mock_finding_2.to_dict.return_value = {"id": "2", "resource_uid": "r-2"}
|
||||
|
||||
# Create a generator that yields two batches of Finding instances
|
||||
# Create a generator that yields two batches of dicts (pre-converted)
|
||||
def findings_generator():
|
||||
yield [mock_finding_1]
|
||||
yield [mock_finding_2]
|
||||
yield [{"id": "1", "resource_uid": "r-1"}]
|
||||
yield [{"id": "2", "resource_uid": "r-2"}]
|
||||
|
||||
config = SimpleNamespace(update_tag=12345)
|
||||
mock_session = MagicMock()
|
||||
@@ -1435,17 +1429,17 @@ class TestAttackPathsFindingsHelpers:
|
||||
|
||||
assert len(findings_data) == 1
|
||||
finding_result = findings_data[0]
|
||||
assert finding_result.id == str(finding.id)
|
||||
assert finding_result.resource_uid == resource.uid
|
||||
assert finding_result.check_title == "Check title"
|
||||
assert finding_result.scan_id == str(latest_scan.id)
|
||||
assert finding_result["id"] == str(finding.id)
|
||||
assert finding_result["resource_uid"] == resource.uid
|
||||
assert finding_result["check_title"] == "Check title"
|
||||
assert finding_result["scan_id"] == str(latest_scan.id)
|
||||
|
||||
def test_enrich_batch_with_resources_single_resource(
|
||||
self,
|
||||
tenants_fixture,
|
||||
providers_fixture,
|
||||
):
|
||||
"""One finding + one resource = one output Finding instance"""
|
||||
"""One finding + one resource = one output dict"""
|
||||
tenant = tenants_fixture[0]
|
||||
provider = providers_fixture[0]
|
||||
provider.provider = Provider.ProviderChoices.AWS
|
||||
@@ -1519,16 +1513,16 @@ class TestAttackPathsFindingsHelpers:
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].resource_uid == resource.uid
|
||||
assert result[0].id == str(finding.id)
|
||||
assert result[0].status == "FAIL"
|
||||
assert result[0]["resource_uid"] == resource.uid
|
||||
assert result[0]["id"] == str(finding.id)
|
||||
assert result[0]["status"] == "FAIL"
|
||||
|
||||
def test_enrich_batch_with_resources_multiple_resources(
|
||||
self,
|
||||
tenants_fixture,
|
||||
providers_fixture,
|
||||
):
|
||||
"""One finding + three resources = three output Finding instances"""
|
||||
"""One finding + three resources = three output dicts"""
|
||||
tenant = tenants_fixture[0]
|
||||
provider = providers_fixture[0]
|
||||
provider.provider = Provider.ProviderChoices.AWS
|
||||
@@ -1607,13 +1601,13 @@ class TestAttackPathsFindingsHelpers:
|
||||
)
|
||||
|
||||
assert len(result) == 3
|
||||
result_resource_uids = {r.resource_uid for r in result}
|
||||
result_resource_uids = {r["resource_uid"] for r in result}
|
||||
assert result_resource_uids == {r.uid for r in resources}
|
||||
|
||||
# All should have same finding data
|
||||
for r in result:
|
||||
assert r.id == str(finding.id)
|
||||
assert r.status == "FAIL"
|
||||
assert r["id"] == str(finding.id)
|
||||
assert r["status"] == "FAIL"
|
||||
|
||||
def test_enrich_batch_with_resources_no_resources_skips(
|
||||
self,
|
||||
@@ -1690,16 +1684,12 @@ class TestAttackPathsFindingsHelpers:
|
||||
provider.save()
|
||||
scan_id = "some-scan-id"
|
||||
|
||||
with (
|
||||
patch("tasks.jobs.attack_paths.findings.rls_transaction") as mock_rls,
|
||||
patch("tasks.jobs.attack_paths.findings.Finding") as mock_finding,
|
||||
):
|
||||
with patch("tasks.jobs.attack_paths.findings.rls_transaction") as mock_rls:
|
||||
# Create generator but don't iterate
|
||||
findings_module.stream_findings_with_resources(provider, scan_id)
|
||||
|
||||
# Nothing should be called yet
|
||||
mock_rls.assert_not_called()
|
||||
mock_finding.objects.filter.assert_not_called()
|
||||
|
||||
def test_load_findings_empty_generator(self, providers_fixture):
|
||||
"""Empty generator should not call neo4j"""
|
||||
@@ -1752,41 +1742,226 @@ class TestAddResourceLabel:
|
||||
assert "AWSResource" not in query.replace("_AWSResource", "")
|
||||
|
||||
|
||||
def _make_session_ctx(session, call_order=None, name=None):
|
||||
"""Create a mock context manager wrapping a mock session."""
|
||||
ctx = MagicMock()
|
||||
if call_order is not None and name is not None:
|
||||
ctx.__enter__ = MagicMock(
|
||||
side_effect=lambda: (call_order.append(f"{name}:enter"), session)[1]
|
||||
)
|
||||
ctx.__exit__ = MagicMock(
|
||||
side_effect=lambda *a: (call_order.append(f"{name}:exit"), False)[1]
|
||||
)
|
||||
else:
|
||||
ctx.__enter__ = MagicMock(return_value=session)
|
||||
ctx.__exit__ = MagicMock(return_value=False)
|
||||
return ctx
|
||||
|
||||
|
||||
class TestSyncNodes:
|
||||
def test_sync_nodes_adds_private_label(self):
|
||||
mock_source_session = MagicMock()
|
||||
mock_target_session = MagicMock()
|
||||
|
||||
row = {
|
||||
"internal_id": 1,
|
||||
"element_id": "elem-1",
|
||||
"labels": ["SomeLabel"],
|
||||
"props": {"key": "value"},
|
||||
}
|
||||
mock_source_session.run.side_effect = [[row], []]
|
||||
|
||||
source_ctx = MagicMock()
|
||||
source_ctx.__enter__ = MagicMock(return_value=mock_source_session)
|
||||
source_ctx.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
target_ctx = MagicMock()
|
||||
target_ctx.__enter__ = MagicMock(return_value=mock_target_session)
|
||||
target_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_source_1 = MagicMock()
|
||||
mock_source_1.run.return_value = [row]
|
||||
mock_target = MagicMock()
|
||||
mock_source_2 = MagicMock()
|
||||
mock_source_2.run.return_value = []
|
||||
|
||||
with patch(
|
||||
"tasks.jobs.attack_paths.sync.graph_database.get_session",
|
||||
side_effect=[source_ctx, target_ctx],
|
||||
side_effect=[
|
||||
_make_session_ctx(mock_source_1),
|
||||
_make_session_ctx(mock_target),
|
||||
_make_session_ctx(mock_source_2),
|
||||
],
|
||||
):
|
||||
total = sync_module.sync_nodes(
|
||||
"source-db", "target-db", "tenant-1", "prov-1"
|
||||
)
|
||||
|
||||
assert total == 1
|
||||
query = mock_target_session.run.call_args.args[0]
|
||||
query = mock_target.run.call_args.args[0]
|
||||
assert "_ProviderResource" in query
|
||||
assert "_Tenant_tenant1" in query
|
||||
assert "_Provider_prov1" in query
|
||||
|
||||
def test_sync_nodes_source_closes_before_target_opens(self):
|
||||
row = {
|
||||
"internal_id": 1,
|
||||
"element_id": "elem-1",
|
||||
"labels": ["SomeLabel"],
|
||||
"props": {"key": "value"},
|
||||
}
|
||||
|
||||
call_order = []
|
||||
|
||||
src_1 = MagicMock()
|
||||
src_1.run.return_value = [row]
|
||||
tgt = MagicMock()
|
||||
src_2 = MagicMock()
|
||||
src_2.run.return_value = []
|
||||
|
||||
with patch(
|
||||
"tasks.jobs.attack_paths.sync.graph_database.get_session",
|
||||
side_effect=[
|
||||
_make_session_ctx(src_1, call_order, "source1"),
|
||||
_make_session_ctx(tgt, call_order, "target"),
|
||||
_make_session_ctx(src_2, call_order, "source2"),
|
||||
],
|
||||
):
|
||||
sync_module.sync_nodes("src-db", "tgt-db", "t-1", "p-1")
|
||||
|
||||
assert call_order.index("source1:exit") < call_order.index("target:enter")
|
||||
|
||||
def test_sync_nodes_pagination_with_batch_size_1(self):
|
||||
row_a = {
|
||||
"internal_id": 1,
|
||||
"element_id": "elem-1",
|
||||
"labels": ["LabelA"],
|
||||
"props": {"a": 1},
|
||||
}
|
||||
row_b = {
|
||||
"internal_id": 2,
|
||||
"element_id": "elem-2",
|
||||
"labels": ["LabelB"],
|
||||
"props": {"b": 2},
|
||||
}
|
||||
|
||||
src_1 = MagicMock()
|
||||
src_1.run.return_value = [row_a]
|
||||
src_2 = MagicMock()
|
||||
src_2.run.return_value = [row_b]
|
||||
src_3 = MagicMock()
|
||||
src_3.run.return_value = []
|
||||
tgt_1 = MagicMock()
|
||||
tgt_2 = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.sync.graph_database.get_session",
|
||||
side_effect=[
|
||||
_make_session_ctx(src_1),
|
||||
_make_session_ctx(tgt_1),
|
||||
_make_session_ctx(src_2),
|
||||
_make_session_ctx(tgt_2),
|
||||
_make_session_ctx(src_3),
|
||||
],
|
||||
),
|
||||
patch("tasks.jobs.attack_paths.sync.SYNC_BATCH_SIZE", 1),
|
||||
):
|
||||
total = sync_module.sync_nodes("src", "tgt", "t-1", "p-1")
|
||||
|
||||
assert total == 2
|
||||
assert src_1.run.call_args.args[1]["last_id"] == -1
|
||||
assert src_2.run.call_args.args[1]["last_id"] == 1
|
||||
|
||||
def test_sync_nodes_empty_source_returns_zero(self):
|
||||
src = MagicMock()
|
||||
src.run.return_value = []
|
||||
|
||||
with patch(
|
||||
"tasks.jobs.attack_paths.sync.graph_database.get_session",
|
||||
side_effect=[_make_session_ctx(src)],
|
||||
) as mock_get_session:
|
||||
total = sync_module.sync_nodes("src", "tgt", "t-1", "p-1")
|
||||
|
||||
assert total == 0
|
||||
assert mock_get_session.call_count == 1
|
||||
|
||||
|
||||
class TestSyncRelationships:
|
||||
def test_sync_relationships_source_closes_before_target_opens(self):
|
||||
row = {
|
||||
"internal_id": 1,
|
||||
"rel_type": "HAS",
|
||||
"start_element_id": "s-1",
|
||||
"end_element_id": "e-1",
|
||||
"props": {},
|
||||
}
|
||||
|
||||
call_order = []
|
||||
|
||||
src_1 = MagicMock()
|
||||
src_1.run.return_value = [row]
|
||||
tgt = MagicMock()
|
||||
src_2 = MagicMock()
|
||||
src_2.run.return_value = []
|
||||
|
||||
with patch(
|
||||
"tasks.jobs.attack_paths.sync.graph_database.get_session",
|
||||
side_effect=[
|
||||
_make_session_ctx(src_1, call_order, "source1"),
|
||||
_make_session_ctx(tgt, call_order, "target"),
|
||||
_make_session_ctx(src_2, call_order, "source2"),
|
||||
],
|
||||
):
|
||||
sync_module.sync_relationships("src", "tgt", "p-1")
|
||||
|
||||
assert call_order.index("source1:exit") < call_order.index("target:enter")
|
||||
|
||||
def test_sync_relationships_pagination_with_batch_size_1(self):
|
||||
row_a = {
|
||||
"internal_id": 1,
|
||||
"rel_type": "HAS",
|
||||
"start_element_id": "s-1",
|
||||
"end_element_id": "e-1",
|
||||
"props": {"a": 1},
|
||||
}
|
||||
row_b = {
|
||||
"internal_id": 2,
|
||||
"rel_type": "CONNECTS",
|
||||
"start_element_id": "s-2",
|
||||
"end_element_id": "e-2",
|
||||
"props": {"b": 2},
|
||||
}
|
||||
|
||||
src_1 = MagicMock()
|
||||
src_1.run.return_value = [row_a]
|
||||
src_2 = MagicMock()
|
||||
src_2.run.return_value = [row_b]
|
||||
src_3 = MagicMock()
|
||||
src_3.run.return_value = []
|
||||
tgt_1 = MagicMock()
|
||||
tgt_2 = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.sync.graph_database.get_session",
|
||||
side_effect=[
|
||||
_make_session_ctx(src_1),
|
||||
_make_session_ctx(tgt_1),
|
||||
_make_session_ctx(src_2),
|
||||
_make_session_ctx(tgt_2),
|
||||
_make_session_ctx(src_3),
|
||||
],
|
||||
),
|
||||
patch("tasks.jobs.attack_paths.sync.SYNC_BATCH_SIZE", 1),
|
||||
):
|
||||
total = sync_module.sync_relationships("src", "tgt", "p-1")
|
||||
|
||||
assert total == 2
|
||||
assert src_1.run.call_args.args[1]["last_id"] == -1
|
||||
assert src_2.run.call_args.args[1]["last_id"] == 1
|
||||
|
||||
def test_sync_relationships_empty_source_returns_zero(self):
|
||||
src = MagicMock()
|
||||
src.run.return_value = []
|
||||
|
||||
with patch(
|
||||
"tasks.jobs.attack_paths.sync.graph_database.get_session",
|
||||
side_effect=[_make_session_ctx(src)],
|
||||
) as mock_get_session:
|
||||
total = sync_module.sync_relationships("src", "tgt", "p-1")
|
||||
|
||||
assert total == 0
|
||||
assert mock_get_session.call_count == 1
|
||||
|
||||
|
||||
class TestInternetAnalysis:
|
||||
def _make_provider_and_config(self):
|
||||
|
||||
Reference in New Issue
Block a user