feat(api): private labels and properties in Attack Paths graph - phase 1 (#10124)

This commit is contained in:
Josema Camacho
2026-02-23 11:30:26 +01:00
committed by GitHub
parent b5d2a75151
commit 75c7f61513
8 changed files with 133 additions and 28 deletions

View File

@@ -25,6 +25,7 @@ All notable changes to the **Prowler API** are documented in this file.
- AI agent guidelines with TDD and testing skills references [(#9925)](https://github.com/prowler-cloud/prowler/pull/9925)
- Attack Paths: Upgrade Cartography from fork 0.126.1 to upstream 0.129.0 and Neo4j driver from 5.x to 6.x [(#10110)](https://github.com/prowler-cloud/prowler/pull/10110)
- Attack Paths: Query results now filtered by provider, preventing future cross-tenant and cross-provider data leakage [(#10118)](https://github.com/prowler-cloud/prowler/pull/10118)
- Attack Paths: Add private labels and properties to Attack Paths graphs to avoid future overlap with Cartography's own labels and properties [(#10124)](https://github.com/prowler-cloud/prowler/pull/10124)
### 🐞 Fixed

View File

@@ -12,7 +12,10 @@ import neo4j.exceptions
from django.conf import settings
from api.attack_paths.retryable_session import RetryableSession
from tasks.jobs.attack_paths.config import BATCH_SIZE, PROVIDER_RESOURCE_LABEL
from tasks.jobs.attack_paths.config import (
BATCH_SIZE,
DEPRECATED_PROVIDER_RESOURCE_LABEL,
)
# Without this Celery goes crazy with Neo4j logging
logging.getLogger("neo4j").setLevel(logging.ERROR)
@@ -128,7 +131,7 @@ def drop_subgraph(database: str, provider_id: str) -> int:
while deleted_count > 0:
result = session.run(
f"""
MATCH (n:{PROVIDER_RESOURCE_LABEL} {{provider_id: $provider_id}})
MATCH (n:{DEPRECATED_PROVIDER_RESOURCE_LABEL} {{provider_id: $provider_id}})
WITH n LIMIT $batch_size
DETACH DELETE n
RETURN COUNT(n) AS deleted_nodes_count

View File

@@ -10,13 +10,17 @@ from tasks.jobs.attack_paths import aws
BATCH_SIZE = env.int("ATTACK_PATHS_BATCH_SIZE", 1000)
# Neo4j internal labels (Prowler-specific, not provider-specific)
# - `ProwlerFinding`: Label for finding nodes created by Prowler and linked to cloud resources.
# - `ProviderResource`: Added to ALL synced nodes for provider isolation and drop/query ops.
# - `Internet`: Singleton node representing external internet access for exposed-resource queries.
# - `ProwlerFinding`: Label for finding nodes created by Prowler and linked to cloud resources
# - `_ProviderResource`: Added to ALL synced nodes for provider isolation and drop/query ops
# - `Internet`: Singleton node representing external internet access for exposed-resource queries
PROWLER_FINDING_LABEL = "ProwlerFinding"
PROVIDER_RESOURCE_LABEL = "ProviderResource"
PROVIDER_RESOURCE_LABEL = "_ProviderResource"
INTERNET_NODE_LABEL = "Internet"
# Phase 1 dual-write: deprecated label kept for drop_subgraph and infrastructure queries
# Remove in Phase 2 once all nodes use the private label exclusively
DEPRECATED_PROVIDER_RESOURCE_LABEL = "ProviderResource"
@dataclass(frozen=True)
class ProviderConfig:
@@ -26,7 +30,8 @@ class ProviderConfig:
root_node_label: str # e.g., "AWSAccount"
uid_field: str # e.g., "arn"
# Label for resources connected to the account node, enabling indexed finding lookups.
resource_label: str # e.g., "AWSResource"
resource_label: str # e.g., "_AWSResource"
deprecated_resource_label: str # e.g., "AWSResource"
ingestion_function: Callable
@@ -37,7 +42,8 @@ AWS_CONFIG = ProviderConfig(
name="aws",
root_node_label="AWSAccount",
uid_field="arn",
resource_label="AWSResource",
resource_label="_AWSResource",
deprecated_resource_label="AWSResource",
ingestion_function=aws.start_aws_ingestion,
)
@@ -48,10 +54,12 @@ PROVIDER_CONFIGS: dict[str, ProviderConfig] = {
# Labels added by Prowler that should be filtered from API responses
# Derived from provider configs + common internal labels
INTERNAL_LABELS: list[str] = [
    # Originates from Cartography rather than Prowler, but it is filtered from
    # API responses together with the Prowler-added labels below.
    "Tenant",
    # Provider-isolation labels: private (current) and deprecated (Phase 1 dual-write)
    PROVIDER_RESOURCE_LABEL,
    DEPRECATED_PROVIDER_RESOURCE_LABEL,
    # Provider-specific resource labels, again in both private and deprecated forms
    *[config.resource_label for config in PROVIDER_CONFIGS.values()],
    *[config.deprecated_resource_label for config in PROVIDER_CONFIGS.values()],
]
@@ -83,6 +91,12 @@ def get_node_uid_field(provider_type: str) -> str:
def get_provider_resource_label(provider_type: str) -> str:
    """Get the resource label for a provider type (e.g., `_AWSResource`).

    Returns the `resource_label` of the matching entry in `PROVIDER_CONFIGS`,
    or `_UnknownProviderResource` when `provider_type` has no entry there.
    """
    config = PROVIDER_CONFIGS.get(provider_type)
    return config.resource_label if config else "_UnknownProviderResource"
def get_deprecated_provider_resource_label(provider_type: str) -> str:
    """Return the deprecated (non-private) resource label for a provider type.

    For example, `AWSResource`. When `provider_type` has no entry in
    `PROVIDER_CONFIGS`, `UnknownProviderResource` is returned instead.
    """
    provider_config = PROVIDER_CONFIGS.get(provider_type)
    if not provider_config:
        return "UnknownProviderResource"
    return provider_config.deprecated_resource_label

View File

@@ -25,6 +25,7 @@ from api.models import Provider, ResourceFindingMapping
from prowler.config import config as ProwlerConfig
from tasks.jobs.attack_paths.config import (
BATCH_SIZE,
get_deprecated_provider_resource_label,
get_node_uid_field,
get_provider_resource_label,
get_root_node_label,
@@ -152,6 +153,9 @@ def add_resource_label(
{
"__ROOT_LABEL__": get_root_node_label(provider_type),
"__RESOURCE_LABEL__": get_provider_resource_label(provider_type),
"__DEPRECATED_RESOURCE_LABEL__": get_deprecated_provider_resource_label(
provider_type
),
},
)

View File

@@ -6,6 +6,7 @@ from cartography.client.core.tx import run_write_query
from celery.utils.log import get_task_logger
from tasks.jobs.attack_paths.config import (
DEPRECATED_PROVIDER_RESOURCE_LABEL,
INTERNET_NODE_LABEL,
PROWLER_FINDING_LABEL,
PROVIDER_RESOURCE_LABEL,
@@ -23,9 +24,11 @@ class IndexType(Enum):
# Indexes for Prowler findings and resource lookups
FINDINGS_INDEX_STATEMENTS = [
# Resources indexes for quick Prowler Finding lookups
"CREATE INDEX aws_resource_arn IF NOT EXISTS FOR (n:AWSResource) ON (n.arn);",
"CREATE INDEX aws_resource_id IF NOT EXISTS FOR (n:AWSResource) ON (n.id);",
# Resource indexes for Prowler Finding lookups
"CREATE INDEX aws_resource_arn IF NOT EXISTS FOR (n:_AWSResource) ON (n.arn);",
"CREATE INDEX aws_resource_id IF NOT EXISTS FOR (n:_AWSResource) ON (n.id);",
"CREATE INDEX deprecated_aws_resource_arn IF NOT EXISTS FOR (n:AWSResource) ON (n.arn);",
"CREATE INDEX deprecated_aws_resource_id IF NOT EXISTS FOR (n:AWSResource) ON (n.id);",
# Prowler Finding indexes
f"CREATE INDEX prowler_finding_id IF NOT EXISTS FOR (n:{PROWLER_FINDING_LABEL}) ON (n.id);",
f"CREATE INDEX prowler_finding_provider_uid IF NOT EXISTS FOR (n:{PROWLER_FINDING_LABEL}) ON (n.provider_uid);",
@@ -37,8 +40,10 @@ FINDINGS_INDEX_STATEMENTS = [
# Indexes for provider resource sync operations
SYNC_INDEX_STATEMENTS = [
    # Private label and properties.
    # These deliberately use index names that did not exist before the private
    # label was introduced: `CREATE INDEX <name> IF NOT EXISTS` silently does
    # nothing when an index with the same name already exists, even if that
    # index covers a different label/property. Reusing the historical names
    # would therefore skip creating these indexes on databases that were
    # already indexed under the deprecated label.
    f"CREATE INDEX private_provider_element_id IF NOT EXISTS FOR (n:{PROVIDER_RESOURCE_LABEL}) ON (n._provider_element_id);",
    f"CREATE INDEX private_provider_resource_provider_id IF NOT EXISTS FOR (n:{PROVIDER_RESOURCE_LABEL}) ON (n._provider_id);",
    # Deprecated label and properties (Phase 1 dual-write).
    # They keep their original index names so that the statements stay no-ops
    # on databases where these indexes were created earlier.
    f"CREATE INDEX provider_element_id IF NOT EXISTS FOR (n:{DEPRECATED_PROVIDER_RESOURCE_LABEL}) ON (n.provider_element_id);",
    f"CREATE INDEX provider_resource_provider_id IF NOT EXISTS FOR (n:{DEPRECATED_PROVIDER_RESOURCE_LABEL}) ON (n.provider_id);",
]

View File

@@ -26,7 +26,7 @@ ADD_RESOURCE_LABEL_TEMPLATE = """
MATCH (account:__ROOT_LABEL__ {id: $provider_uid})-->(r)
WHERE NOT r:__ROOT_LABEL__ AND NOT r:__RESOURCE_LABEL__
WITH r LIMIT $batch_size
SET r:__RESOURCE_LABEL__
SET r:__RESOURCE_LABEL__:__DEPRECATED_RESOURCE_LABEL__
RETURN COUNT(r) AS labeled_count
"""
@@ -151,16 +151,20 @@ RELATIONSHIPS_FETCH_QUERY = """
NODE_SYNC_TEMPLATE = """
UNWIND $rows AS row
MERGE (n:__NODE_LABELS__ {provider_element_id: row.provider_element_id})
MERGE (n:__NODE_LABELS__ {_provider_element_id: row.provider_element_id})
SET n += row.props
SET n._provider_id = $provider_id
SET n.provider_element_id = row.provider_element_id
SET n.provider_id = $provider_id
"""
""" # The last two lines are deprecated properties
RELATIONSHIP_SYNC_TEMPLATE = f"""
UNWIND $rows AS row
MATCH (s:{PROVIDER_RESOURCE_LABEL} {{provider_element_id: row.start_element_id}})
MATCH (t:{PROVIDER_RESOURCE_LABEL} {{provider_element_id: row.end_element_id}})
MERGE (s)-[r:__REL_TYPE__ {{provider_element_id: row.provider_element_id}}]->(t)
MATCH (s:{PROVIDER_RESOURCE_LABEL} {{_provider_element_id: row.start_element_id}})
MATCH (t:{PROVIDER_RESOURCE_LABEL} {{_provider_element_id: row.end_element_id}})
MERGE (s)-[r:__REL_TYPE__ {{_provider_element_id: row.provider_element_id}}]->(t)
SET r += row.props
SET r._provider_id = $provider_id
SET r.provider_element_id = row.provider_element_id
SET r.provider_id = $provider_id
"""
""" # The last two lines are deprecated properties

View File

@@ -11,7 +11,11 @@ from typing import Any
from celery.utils.log import get_task_logger
from api.attack_paths import database as graph_database
from tasks.jobs.attack_paths.config import BATCH_SIZE, PROVIDER_RESOURCE_LABEL
from tasks.jobs.attack_paths.config import (
BATCH_SIZE,
DEPRECATED_PROVIDER_RESOURCE_LABEL,
PROVIDER_RESOURCE_LABEL,
)
from tasks.jobs.attack_paths.indexes import IndexType, create_indexes
from tasks.jobs.attack_paths.queries import (
NODE_FETCH_QUERY,
@@ -70,7 +74,7 @@ def sync_nodes(
"""
Sync nodes from source to target database.
Adds `ProviderResource` label and `provider_id` property to all nodes.
Adds `_ProviderResource` label and `_provider_id` property to all nodes.
"""
last_id = -1
total_synced = 0
@@ -108,6 +112,7 @@ def sync_nodes(
for labels, batch in grouped.items():
label_set = set(labels)
label_set.add(PROVIDER_RESOURCE_LABEL)
label_set.add(DEPRECATED_PROVIDER_RESOURCE_LABEL)
node_labels = ":".join(f"`{label}`" for label in sorted(label_set))
query = render_cypher_template(
@@ -137,7 +142,7 @@ def sync_relationships(
"""
Sync relationships from source to target database.
Adds `provider_id` property to all relationships.
Adds `_provider_id` property to all relationships.
"""
last_id = -1
total_synced = 0
@@ -196,7 +201,9 @@ def sync_relationships(
def _strip_internal_properties(props: dict[str, Any]) -> None:
    """Remove internal properties that shouldn't be copied during sync.

    The dictionary is modified in place; keys that are not present are
    ignored. Both the private (`_`-prefixed) and the deprecated property
    names are removed.
    """
    for key in (
        "_provider_element_id",
        "_provider_id",
        "provider_element_id",  # Deprecated
        "provider_id",  # Deprecated
    ):
        props.pop(key, None)

View File

@@ -5,6 +5,10 @@ from unittest.mock import MagicMock, call, patch
import pytest
from tasks.jobs.attack_paths import findings as findings_module
from tasks.jobs.attack_paths import internet as internet_module
from tasks.jobs.attack_paths import sync as sync_module
from tasks.jobs.attack_paths.config import (
get_deprecated_provider_resource_label,
)
from tasks.jobs.attack_paths.scan import run as attack_paths_run
from api.models import (
@@ -1073,6 +1077,69 @@ class TestAttackPathsFindingsHelpers:
mock_session.run.assert_not_called()
class TestProviderConfigAccessors:
    """Checks for the deprecated provider resource label accessor."""

    def test_get_deprecated_provider_resource_label_known_provider(self):
        # A configured provider yields its legacy, non-private label.
        label = get_deprecated_provider_resource_label("aws")

        assert label == "AWSResource"

    def test_get_deprecated_provider_resource_label_unknown_provider(self):
        # Provider types without a config fall back to the generic legacy label.
        label = get_deprecated_provider_resource_label("unknown")

        assert label == "UnknownProviderResource"
class TestAddResourceLabel:
    """Checks that `add_resource_label` writes the private and deprecated labels."""

    def test_add_resource_label_applies_both_labels(self):
        mock_session = MagicMock()

        # First batch labels 5 nodes, second batch labels none.
        first_result = MagicMock()
        first_result.single.return_value = {"labeled_count": 5}
        second_result = MagicMock()
        second_result.single.return_value = {"labeled_count": 0}
        mock_session.run.side_effect = [first_result, second_result]

        total = findings_module.add_resource_label(mock_session, "aws", "123456789012")

        assert total == 5
        assert mock_session.run.call_count == 2

        query = mock_session.run.call_args_list[0].args[0]
        assert "_AWSResource" in query
        # "AWSResource" is a substring of "_AWSResource", so a plain `in` check
        # would pass even if the deprecated label were missing. Remove every
        # occurrence of the private label before looking for the deprecated one.
        assert "AWSResource" in query.replace("_AWSResource", "")
class TestSyncNodes:
    """Checks that `sync_nodes` writes the private and deprecated provider labels."""

    def test_sync_nodes_adds_both_labels(self):
        mock_source_session = MagicMock()
        mock_target_session = MagicMock()

        row = {
            "internal_id": 1,
            "element_id": "elem-1",
            "labels": ["SomeLabel"],
            "props": {"key": "value"},
        }
        # One batch with a single node, then an empty batch.
        mock_source_session.run.side_effect = [[row], []]

        source_ctx = MagicMock()
        source_ctx.__enter__ = MagicMock(return_value=mock_source_session)
        source_ctx.__exit__ = MagicMock(return_value=False)

        target_ctx = MagicMock()
        target_ctx.__enter__ = MagicMock(return_value=mock_target_session)
        target_ctx.__exit__ = MagicMock(return_value=False)

        with patch(
            "tasks.jobs.attack_paths.sync.graph_database.get_session",
            side_effect=[source_ctx, target_ctx],
        ):
            total = sync_module.sync_nodes("source-db", "target-db", "prov-1")

        assert total == 1

        query = mock_target_session.run.call_args.args[0]
        assert "_ProviderResource" in query
        # "ProviderResource" is a substring of "_ProviderResource", so a plain
        # `in` check would pass even if the deprecated label were missing.
        # Remove every occurrence of the private label before looking for it.
        assert "ProviderResource" in query.replace("_ProviderResource", "")
class TestInternetAnalysis:
def _make_provider_and_config(self):
provider = MagicMock()