feat(api): make Attack Paths sink selectable between Neo4j and Neptune (#11524)

This commit is contained in:
Josema Camacho
2026-06-26 10:22:29 +02:00
committed by GitHub
parent 9b8b77cec0
commit 5793cd7e38
48 changed files with 9928 additions and 3210 deletions
+4
View File
@@ -169,3 +169,7 @@ GEMINI.md
# Claude Code # Claude Code
.claude/* .claude/*
# Docker
docker-compose.override.yml
docker-compose-dev.override.yml
+22 -3
View File
@@ -83,10 +83,18 @@ prowler dashboard
## Attack Paths ## Attack Paths
Attack Paths automatically extends every completed AWS scan with a Neo4j graph that combines Cartography's cloud inventory with Prowler findings. The feature runs in the API worker after each scan and therefore requires: Attack Paths automatically extends every completed AWS scan with a graph that combines Cartography's cloud inventory with Prowler findings. The feature runs in the API worker after each scan.
- An accessible Neo4j instance (the Docker Compose files already ships a `neo4j` service). Two graph backends are supported as the long-lived sink:
- The following environment variables so Django and Celery can connect:
- **Neo4j** (default; the Docker Compose files already ship a `neo4j` service).
- **Amazon Neptune** (cloud-managed; opt-in).
Select the sink with `ATTACK_PATHS_SINK_DATABASE` (`neo4j` or `neptune`; default `neo4j`).
> Note: Cartography ingestion always uses a temporary Neo4j database, regardless of the configured sink. The `NEO4J_*` variables below must remain set even when `ATTACK_PATHS_SINK_DATABASE=neptune`.
### Neo4j sink
| Variable | Description | Default | | Variable | Description | Default |
| --- | --- | --- | | --- | --- | --- |
@@ -94,6 +102,17 @@ Attack Paths automatically extends every completed AWS scan with a Neo4j graph t
| `NEO4J_PORT` | Bolt port exposed by Neo4j. | `7687` | | `NEO4J_PORT` | Bolt port exposed by Neo4j. | `7687` |
| `NEO4J_USER` / `NEO4J_PASSWORD` | Credentials with rights to create per-tenant databases. | `neo4j` / `neo4j_password` | | `NEO4J_USER` / `NEO4J_PASSWORD` | Credentials with rights to create per-tenant databases. | `neo4j` / `neo4j_password` |
### Neptune sink
| Variable | Description | Default |
| --- | --- | --- |
| `NEPTUNE_WRITER_ENDPOINT` | Bolt host for the Neptune writer instance. Required when sink is `neptune`. | _empty_ |
| `NEPTUNE_READER_ENDPOINT` | Optional reader endpoint for read-only queries. Falls back to the writer when unset. | _empty_ |
| `NEPTUNE_PORT` | Bolt port exposed by Neptune. | `8182` |
| `AWS_REGION` | Region the Neptune cluster lives in. Required when sink is `neptune`. | _empty_ |
Neptune authenticates with SigV4 using the standard boto3 credential chain. The worker's IAM role (or `AWS_ACCESS_KEY_ID` / `AWS_SECRET_ACCESS_KEY`) supplies the credentials. There is no Neptune password variable.
Every AWS provider scan will enqueue an Attack Paths ingestion job automatically. Other cloud providers will be added in future iterations. Every AWS provider scan will enqueue an Attack Paths ingestion job automatically. Other cloud providers will be added in future iterations.
+8
View File
@@ -2,6 +2,14 @@
All notable changes to the **Prowler API** are documented in this file. All notable changes to the **Prowler API** are documented in this file.
## [1.33.0] (Prowler UNRELEASED)
### 🔄 Changed
- Attack Paths: AWS Neptune is now supported as a persistent sink database, selectable via `ATTACK_PATHS_SINK_DATABASE=neptune` (default `neo4j`), Cartography's (bumped to 0.138.1) per-scan ingest database stays on Neo4j [(#11524)](https://github.com/prowler-cloud/prowler/pull/11524)
---
## [1.32.2] (Prowler UNRELEASED) ## [1.32.2] (Prowler UNRELEASED)
### 🐞 Fixed ### 🐞 Fixed
+11 -5
View File
@@ -58,7 +58,7 @@ dependencies = [
"matplotlib (==3.10.8)", "matplotlib (==3.10.8)",
"reportlab (==4.4.10)", "reportlab (==4.4.10)",
"neo4j (==6.1.0)", "neo4j (==6.1.0)",
"cartography (==0.135.0)", "cartography (==0.138.1)",
"gevent (==25.9.1)", "gevent (==25.9.1)",
"werkzeug (==3.1.7)", "werkzeug (==3.1.7)",
"sqlparse (==0.5.5)", "sqlparse (==0.5.5)",
@@ -193,7 +193,7 @@ constraint-dependencies = [
"blinker==1.9.0", "blinker==1.9.0",
"boto3==1.40.61", "boto3==1.40.61",
"botocore==1.40.61", "botocore==1.40.61",
"cartography==0.135.0", "cartography==0.138.1",
"celery==5.6.2", "celery==5.6.2",
"certifi==2026.1.4", "certifi==2026.1.4",
"cffi==2.0.0", "cffi==2.0.0",
@@ -447,7 +447,7 @@ constraint-dependencies = [
"wcwidth==0.5.3", "wcwidth==0.5.3",
"websocket-client==1.9.0", "websocket-client==1.9.0",
"werkzeug==3.1.7", "werkzeug==3.1.7",
"workos==6.0.4", "workos==6.0.8",
"wrapt==1.17.3", "wrapt==1.17.3",
"xlsxwriter==3.2.9", "xlsxwriter==3.2.9",
"xmlsec==1.3.17", "xmlsec==1.3.17",
@@ -458,8 +458,13 @@ constraint-dependencies = [
"zope-interface==8.2", "zope-interface==8.2",
"zstd==1.5.7.3" "zstd==1.5.7.3"
] ]
# prowler@master needs okta==3.4.2; cartography 0.135.0 declares okta<1.0.0 for an # prowler@master needs okta==3.4.2, but cartography 0.138.1 requires okta<1.0.0.
# integration prowler does not import. # Attack Paths does not ingest Okta today, so override the Cartography
# dependency to the Prowler pin.
#
# prowler@master needs azure-mgmt-containerservice==34.1.0, but cartography
# 0.138.1 requires azure-mgmt-containerservice>=41.0.0. Attack Paths does not
# ingest Azure today, so override the Cartography dependency to the Prowler pin.
# #
# prowler@master hard-pins microsoft-kiota-abstractions==1.9.2 in [project.dependencies]. # prowler@master hard-pins microsoft-kiota-abstractions==1.9.2 in [project.dependencies].
# The microsoft-kiota-http security bump to 1.9.9 (GHSA-7j59-v9qr-6fq9) requires # The microsoft-kiota-http security bump to 1.9.9 (GHSA-7j59-v9qr-6fq9) requires
@@ -475,6 +480,7 @@ constraint-dependencies = [
# that request pyjwt[crypto] and leave cryptography (needed for RS256) only transitive. # that request pyjwt[crypto] and leave cryptography (needed for RS256) only transitive.
override-dependencies = [ override-dependencies = [
"okta==3.4.2", "okta==3.4.2",
"azure-mgmt-containerservice==34.1.0",
"microsoft-kiota-abstractions==1.9.9", "microsoft-kiota-abstractions==1.9.9",
"dulwich==1.2.5", "dulwich==1.2.5",
"pyjwt[crypto]==2.13.0" "pyjwt[crypto]==2.13.0"
-3
View File
@@ -42,9 +42,6 @@ class ApiConfig(AppConfig):
): ):
self._ensure_crypto_keys() self._ensure_crypto_keys()
# Neo4j driver is created lazily on first use (see api.attack_paths.database).
# App init never contacts Neo4j, so a Neo4j outage cannot block API startup.
def _ensure_crypto_keys(self): def _ensure_crypto_keys(self):
""" """
Orchestrator method that ensures all required cryptographic keys are present. Orchestrator method that ensures all required cryptographic keys are present.
@@ -4,10 +4,10 @@ Cypher sanitizer for custom (user-supplied) Attack Paths queries.
Two responsibilities: Two responsibilities:
1. **Validation** - reject queries containing SSRF or dangerous procedure 1. **Validation** - reject queries containing SSRF or dangerous procedure
patterns (defense-in-depth; the primary control is ``neo4j.READ_ACCESS``). patterns (defense-in-depth; the primary control is `neo4j.READ_ACCESS`).
2. **Provider-scoped label injection** - inject a dynamic 2. **Provider-scoped label injection** - inject a dynamic
``_Provider_{uuid}`` label into every node pattern so the database can `_Provider_{uuid}` label into every node pattern so the database can
use its native label index for provider isolation. use its native label index for provider isolation.
Label-injection pipeline: Label-injection pipeline:
@@ -25,13 +25,13 @@ from rest_framework.exceptions import ValidationError
from tasks.jobs.attack_paths.config import get_provider_label from tasks.jobs.attack_paths.config import get_provider_label
# Step 1 - String / comment protection # Step 1 - String / comment protection
# Single combined regex: strings first, then line comments. # Single combined regex: strings first, then line comments
# The regex engine finds the leftmost match, so a string like 'https://prowler.com' # The regex engine finds the leftmost match, so a string like 'https://prowler.com'
# is consumed as a string before the // inside it can match as a comment. # is consumed as a string before the // inside it can match as a comment
_PROTECTED_RE = re.compile(r"'(?:[^'\\]|\\.)*'|\"(?:[^\"\\]|\\.)*\"|//[^\n]*") _PROTECTED_RE = re.compile(r"'(?:[^'\\]|\\.)*'|\"(?:[^\"\\]|\\.)*\"|//[^\n]*")
# Step 2 - Clause splitting # Step 2 - Clause splitting
# OPTIONAL MATCH must come before MATCH to avoid partial matching. # `OPTIONAL MATCH` must come before `MATCH` to avoid partial matching
_CLAUSE_RE = re.compile( _CLAUSE_RE = re.compile(
r"\b(OPTIONAL\s+MATCH|MATCH|WHERE|RETURN|WITH|ORDER\s+BY" r"\b(OPTIONAL\s+MATCH|MATCH|WHERE|RETURN|WITH|ORDER\s+BY"
r"|SKIP|LIMIT|UNION|UNWIND|CALL)\b", r"|SKIP|LIMIT|UNION|UNWIND|CALL)\b",
@@ -39,10 +39,10 @@ _CLAUSE_RE = re.compile(
) )
# Pass A - Labeled node patterns (all segments) # Pass A - Labeled node patterns (all segments)
# Matches node patterns that have at least one :Label. # Matches node patterns that have at least one `:Label`
# (?<!\w)\( - open paren NOT preceded by a word char (excludes function calls). # `(?<!\w)\(` - open paren NOT preceded by a word char, excludes function calls
# Group 1: optional variable + one or more :Label # Group 1: optional variable + one or more `:Label`
# Group 2: optional {properties} + closing paren # Group 2: optional `{`properties`}` + closing paren
_LABELED_NODE_RE = re.compile( _LABELED_NODE_RE = re.compile(
r"(?<!\w)\(" r"(?<!\w)\("
r"(" r"("
@@ -55,9 +55,9 @@ _LABELED_NODE_RE = re.compile(
r")" r")"
) )
# Pass B - Bare node patterns (MATCH segments only) # Pass B - Bare node patterns (`MATCH` segments only)
# Matches (identifier) or (identifier {properties}) without any :Label. # Matches (identifier) or (identifier {properties}) without any `:Label`
# Only applied in MATCH/OPTIONAL MATCH segments. # Only applied in `MATCH` / `OPTIONAL MATCH` segments
_BARE_NODE_RE = re.compile( _BARE_NODE_RE = re.compile(
r"(?<!\w)\(" r"(\s*[a-zA-Z_]\w*)" r"(\s*(?:\{[^}]*\})?)" r"\s*\)" r"(?<!\w)\(" r"(\s*[a-zA-Z_]\w*)" r"(\s*(?:\{[^}]*\})?)" r"\s*\)"
) )
@@ -134,9 +134,7 @@ def inject_provider_label(cypher: str, provider_id: str) -> str:
return work return work
# ---------------------------------------------------------------------------
# Validation # Validation
# ---------------------------------------------------------------------------
# Patterns that indicate SSRF or dangerous procedure calls # Patterns that indicate SSRF or dangerous procedure calls
# Defense-in-depth layer - the primary control is `neo4j.READ_ACCESS` # Defense-in-depth layer - the primary control is `neo4j.READ_ACCESS`
+170 -251
View File
@@ -1,261 +1,32 @@
import atexit """Backwards-compatible facade over the ingest and sink modules.
import logging
import threading Historically this module owned a single Neo4j driver used for both the
from collections.abc import Iterator cartography temp database and the per-tenant sink database. The port to AWS
from contextlib import contextmanager Neptune split those roles: the cartography ingest (temp) database is always
Neo4j and lives in `api.attack_paths.ingest`; the sink is configurable
(Neo4j or Neptune) and lives in `api.attack_paths.sink`. This shim preserves
the public API that `tasks/` and `api/v1/views.py` already depend on, and
dispatches to the right module by database-name prefix.
A database name starting with `db-tmp-scan-` is a cartography temp DB and
routes to ingest. Everything else routes to the configured sink.
"""
from contextlib import AbstractContextManager
from typing import Any from typing import Any
from uuid import UUID from uuid import UUID
import neo4j import neo4j # noqa: F401 - kept for tests that patch api.attack_paths.database.neo4j
import neo4j.exceptions from api.attack_paths import ingest
from api.attack_paths.retryable_session import RetryableSession from api.attack_paths import sink as sink_module
from config.env import env from config.env import env
from django.conf import settings from django.conf import (
from tasks.jobs.attack_paths.config import ( settings, # noqa: F401 - kept for tests that patch ...database.settings
BATCH_SIZE,
PROVIDER_RESOURCE_LABEL,
get_provider_label,
) )
# Without this Celery goes crazy with Neo4j logging
logging.getLogger("neo4j").setLevel(logging.ERROR)
logging.getLogger("neo4j").propagate = False
SERVICE_UNAVAILABLE_MAX_RETRIES = env.int(
"ATTACK_PATHS_SERVICE_UNAVAILABLE_MAX_RETRIES", default=3
)
READ_QUERY_TIMEOUT_SECONDS = env.int(
"ATTACK_PATHS_READ_QUERY_TIMEOUT_SECONDS", default=30
)
MAX_CUSTOM_QUERY_NODES = env.int("ATTACK_PATHS_MAX_CUSTOM_QUERY_NODES", default=250) MAX_CUSTOM_QUERY_NODES = env.int("ATTACK_PATHS_MAX_CUSTOM_QUERY_NODES", default=250)
# Shorter than CONN_ACQUISITION_TIMEOUT — the driver requires acquisition to be
# the longer of the two (it may include opening a new connection).
CONNECTION_TIMEOUT = env.int("NEO4J_CONNECTION_TIMEOUT", default=5)
CONN_ACQUISITION_TIMEOUT = env.int("NEO4J_CONN_ACQUISITION_TIMEOUT", default=15)
READ_EXCEPTION_CODES = [
"Neo.ClientError.Statement.AccessMode",
"Neo.ClientError.Procedure.ProcedureNotFound",
]
CLIENT_STATEMENT_EXCEPTION_PREFIX = "Neo.ClientError.Statement."
# Module-level process-wide driver singleton TEMP_DB_PREFIX = "db-tmp-scan-"
_driver: neo4j.Driver | None = None
_lock = threading.Lock()
# Base Neo4j functions
def get_uri() -> str:
host = settings.DATABASES["neo4j"]["HOST"]
port = settings.DATABASES["neo4j"]["PORT"]
return f"bolt://{host}:{port}"
def init_driver() -> neo4j.Driver:
global _driver
if _driver is not None:
return _driver
with _lock:
if _driver is None:
uri = get_uri()
config = settings.DATABASES["neo4j"]
driver = neo4j.GraphDatabase.driver(
uri,
auth=(config["USER"], config["PASSWORD"]),
keep_alive=True,
max_connection_lifetime=7200,
connection_timeout=CONNECTION_TIMEOUT,
connection_acquisition_timeout=CONN_ACQUISITION_TIMEOUT,
max_connection_pool_size=50,
)
# Publish the singleton only after connectivity is verified so a
# failed probe does not leave an unverified driver behind. Close the
# driver on failure so a repeatedly-probed outage cannot leak pools.
try:
driver.verify_connectivity()
except Exception:
driver.close()
raise
_driver = driver
# Register cleanup handler (only runs once since we're inside the _driver is None block)
atexit.register(close_driver)
return _driver
def get_driver() -> neo4j.Driver:
return init_driver()
def close_driver() -> None: # TODO: Use it
global _driver
with _lock:
if _driver is not None:
try:
_driver.close()
finally:
_driver = None
@contextmanager
def get_session(
database: str | None = None, default_access_mode: str | None = None
) -> Iterator[RetryableSession]:
session_wrapper: RetryableSession | None = None
try:
session_wrapper = RetryableSession(
session_factory=lambda: get_driver().session(
database=database, default_access_mode=default_access_mode
),
max_retries=SERVICE_UNAVAILABLE_MAX_RETRIES,
)
yield session_wrapper
except neo4j.exceptions.Neo4jError as exc:
if (
default_access_mode == neo4j.READ_ACCESS
and exc.code
and exc.code in READ_EXCEPTION_CODES
):
message = "Read query not allowed"
code = READ_EXCEPTION_CODES[0]
raise WriteQueryNotAllowedException(message=message, code=code)
message = exc.message if exc.message is not None else str(exc)
if exc.code and exc.code.startswith(CLIENT_STATEMENT_EXCEPTION_PREFIX):
raise ClientStatementException(message=message, code=exc.code)
raise GraphDatabaseQueryException(message=message, code=exc.code)
finally:
if session_wrapper is not None:
session_wrapper.close()
def execute_read_query(
database: str,
cypher: str,
parameters: dict[str, Any] | None = None,
) -> neo4j.graph.Graph:
with get_session(database, default_access_mode=neo4j.READ_ACCESS) as session:
def _run(tx: neo4j.ManagedTransaction) -> neo4j.graph.Graph:
result = tx.run(
cypher, parameters or {}, timeout=READ_QUERY_TIMEOUT_SECONDS
)
return result.graph()
return session.execute_read(_run)
def create_database(database: str) -> None:
query = "CREATE DATABASE $database IF NOT EXISTS"
parameters = {"database": database}
with get_session() as session:
session.run(query, parameters)
def drop_database(database: str) -> None:
query = f"DROP DATABASE `{database}` IF EXISTS DESTROY DATA"
with get_session() as session:
session.run(query)
def drop_subgraph(database: str, provider_id: str) -> int:
"""
Delete all nodes for a provider from the tenant database.
Deletes relationships then nodes in batches (not `DETACH DELETE`) so a dense
provider's graph cannot exceed Neo4j's transaction memory limit.
Silently returns 0 if the database doesn't exist.
"""
provider_label = get_provider_label(provider_id)
deleted_nodes = 0
try:
with get_session(database) as session:
# Phase 1: delete relationships incident to provider nodes in batches.
deleted_count = 1
while deleted_count > 0:
result = session.run(
f"""
MATCH (:`{provider_label}`)-[r]-()
WITH DISTINCT r LIMIT $batch_size
DELETE r
RETURN COUNT(r) AS deleted_rels_count
""",
{"batch_size": BATCH_SIZE},
)
deleted_count = result.single().get("deleted_rels_count", 0)
# Phase 2: delete the now relationship-free nodes in batches.
deleted_count = 1
while deleted_count > 0:
result = session.run(
f"""
MATCH (n:{PROVIDER_RESOURCE_LABEL}:`{provider_label}`)
WITH n LIMIT $batch_size
DELETE n
RETURN COUNT(n) AS deleted_nodes_count
""",
{"batch_size": BATCH_SIZE},
)
deleted_count = result.single().get("deleted_nodes_count", 0)
deleted_nodes += deleted_count
except GraphDatabaseQueryException as exc:
if exc.code == "Neo.ClientError.Database.DatabaseNotFound":
return 0
raise
return deleted_nodes
def has_provider_data(database: str, provider_id: str) -> bool:
"""
Check if any ProviderResource node exists for this provider.
Returns `False` if the database doesn't exist.
"""
provider_label = get_provider_label(provider_id)
query = f"MATCH (n:{PROVIDER_RESOURCE_LABEL}:`{provider_label}`) RETURN 1 LIMIT 1"
try:
with get_session(database, default_access_mode=neo4j.READ_ACCESS) as session:
result = session.run(query)
return result.single() is not None
except GraphDatabaseQueryException as exc:
if exc.code == "Neo.ClientError.Database.DatabaseNotFound":
return False
raise
def clear_cache(database: str) -> None:
query = "CALL db.clearQueryCaches()"
try:
with get_session(database) as session:
session.run(query)
except GraphDatabaseQueryException as exc:
logging.warning(f"Failed to clear query cache for database `{database}`: {exc}")
# Neo4j functions related to Prowler + Cartography
def get_database_name(entity_id: str | UUID, temporary: bool = False) -> str:
prefix = "tmp-scan" if temporary else "tenant"
return f"db-{prefix}-{str(entity_id).lower()}"
# Exceptions # Exceptions
@@ -270,7 +41,6 @@ class GraphDatabaseQueryException(Exception):
def __str__(self) -> str: def __str__(self) -> str:
if self.code: if self.code:
return f"{self.code}: {self.message}" return f"{self.code}: {self.message}"
return self.message return self.message
@@ -280,3 +50,152 @@ class WriteQueryNotAllowedException(GraphDatabaseQueryException):
class ClientStatementException(GraphDatabaseQueryException): class ClientStatementException(GraphDatabaseQueryException):
pass pass
# Routing
def _is_ingest_database(database: str | None) -> bool:
return bool(database) and database.startswith(TEMP_DB_PREFIX)
# Driver lifecycle
def init_driver() -> Any:
"""Initialize the configured sink backend.
The ingest driver (Neo4j for cartography temp DBs) stays lazy: it is
only initialized when a temp-DB operation actually runs, which never
happens on API pods.
"""
return sink_module.init()
def close_driver() -> None:
"""Close every driver held by this process."""
sink_module.close()
ingest.close_driver()
def get_driver() -> neo4j.Driver:
"""Return the sink backend's underlying driver.
Only meaningful for the Neo4j sink (where the backend has a single Neo4j
driver). On Neptune this returns the writer driver. Kept for tests and
legacy call-sites; prefer `get_session` for new code.
"""
backend = sink_module.get_backend()
# Neo4jSink exposes get_driver(); NeptuneSink exposes get_writer()
if hasattr(backend, "get_driver"):
return backend.get_driver()
if hasattr(backend, "get_writer"):
return backend.get_writer()
raise RuntimeError("Active sink backend does not expose a driver handle")
def verify_connectivity() -> None:
"""Raise if the configured graph database is unreachable on the API read path.
Backend-agnostic entry point for the readiness probe: Neo4j verifies its
driver, Neptune verifies the reader endpoint.
"""
sink_module.get_backend().verify_connectivity()
def get_uri() -> str:
"""Return the sink URI. Retained for backwards compatibility."""
if settings.ATTACK_PATHS_SINK_DATABASE == "neptune":
cfg = settings.DATABASES["neptune"]
return f"bolt+s://{cfg['WRITER_ENDPOINT']}:{cfg['PORT']}"
cfg = settings.DATABASES["neo4j"]
return f"bolt://{cfg['HOST']}:{cfg['PORT']}"
def get_ingest_uri() -> str:
"""Neo4j URI for the cartography temp (ingest) database, which is always
Neo4j regardless of the configured sink."""
return ingest.get_uri()
# Session API
def get_session(
database: str | None = None,
default_access_mode: str | None = None,
) -> AbstractContextManager:
"""Return a session against the right backend.
- `database` names starting with `db-tmp-scan-` always go to ingest.
- No database name → ingest (used for CREATE / DROP DATABASE admin ops).
- Any other name → sink.
"""
if _is_ingest_database(database) or database is None:
return ingest.get_session(
database=database, default_access_mode=default_access_mode
)
return sink_module.get_backend().get_session(
database=database, default_access_mode=default_access_mode
)
def execute_read_query(
database: str,
cypher: str,
parameters: dict[str, Any] | None = None,
) -> neo4j.graph.Graph:
"""Read-only query against the sink."""
return sink_module.get_backend().execute_read_query(database, cypher, parameters)
def create_database(database: str) -> None:
"""Create a database. Temp DBs always land on ingest (Neo4j).
On the Neo4j sink, tenant DBs also route to ingest because both drivers
connect to the same Neo4j cluster. On the Neptune sink, tenant DB creates
are no-ops.
"""
if _is_ingest_database(database):
ingest.create_database(database)
return
sink_module.get_backend().create_database(database)
def drop_database(database: str) -> None:
"""Drop a database. Mirrors `create_database` routing."""
if _is_ingest_database(database):
ingest.drop_database(database)
return
sink_module.get_backend().drop_database(database)
def drop_subgraph(database: str, provider_id: str) -> int:
return sink_module.get_backend().drop_subgraph(database, provider_id)
def has_provider_data(database: str, provider_id: str) -> bool:
return sink_module.get_backend().has_provider_data(database, provider_id)
def clear_cache(database: str) -> None:
if _is_ingest_database(database):
ingest.clear_cache(database)
return
sink_module.get_backend().clear_cache(database)
# Name helper
def get_database_name(entity_id: str | UUID, temporary: bool = False) -> str:
prefix = "tmp-scan" if temporary else "tenant"
return f"db-{prefix}-{str(entity_id).lower()}"
@@ -0,0 +1,29 @@
"""Cartography ingest layer.
Public surface for the per-scan Neo4j temp database driver. Implementation
lives in `api.attack_paths.ingest.driver`.
"""
from api.attack_paths.ingest.driver import (
clear_cache,
close_driver,
create_database,
drop_database,
get_driver,
get_session,
get_uri,
init_driver,
run_cypher,
)
__all__ = [
"clear_cache",
"close_driver",
"create_database",
"drop_database",
"get_driver",
"get_session",
"get_uri",
"init_driver",
"run_cypher",
]
@@ -0,0 +1,187 @@
"""Cartography ingest driver: per-scan throw-away Neo4j database.
Cartography writes each scan's graph into a throw-away Neo4j database named
`db-tmp-scan-{scan_uuid}`. This is always Neo4j, regardless of the configured
sink: Neptune is single-database and cannot host per-scan throw-away
databases. This module owns the Neo4j driver used for those temp DBs and the
admin ops they need (CREATE / DROP DATABASE).
"""
import atexit
import logging
import threading
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any
import neo4j
import neo4j.exceptions
from api.attack_paths.retryable_session import RetryableSession
from config.env import env
from django.conf import settings
logging.getLogger("neo4j").setLevel(logging.ERROR)
logging.getLogger("neo4j").propagate = False
SERVICE_UNAVAILABLE_MAX_RETRIES = env.int(
"ATTACK_PATHS_SERVICE_UNAVAILABLE_MAX_RETRIES", default=3
)
CONN_ACQUISITION_TIMEOUT = env.int("NEO4J_CONN_ACQUISITION_TIMEOUT", default=15)
# TCP connect timeout, ordered below the acquisition timeout so an unreachable
# host can't pin a worker on a temp-DB op longer than this.
CONNECTION_TIMEOUT = env.int("NEO4J_CONNECTION_TIMEOUT", default=5)
MAX_CONNECTION_LIFETIME = env.int("NEO4J_MAX_CONNECTION_LIFETIME", default=7200)
MAX_CONNECTION_POOL_SIZE = env.int("NEO4J_MAX_CONNECTION_POOL_SIZE", default=50)
_driver: neo4j.Driver | None = None
_lock = threading.Lock()
def _neo4j_config() -> dict:
return settings.DATABASES["neo4j"]
def get_uri() -> str:
"""Bolt URI for the Neo4j temp (ingest) database. Always Neo4j."""
config = _neo4j_config()
host = config["HOST"]
port = config["PORT"]
if not host or not port:
raise RuntimeError(
"NEO4J_HOST / NEO4J_PORT must be set to use the attack-paths "
"temp database. Workers require Neo4j env even when the sink is Neptune."
)
return f"bolt://{host}:{port}"
def init_driver() -> neo4j.Driver:
"""Initialize the temp-database Neo4j driver. Idempotent."""
global _driver
if _driver is not None:
return _driver
with _lock:
if _driver is None:
config = _neo4j_config()
_driver = neo4j.GraphDatabase.driver(
get_uri(),
auth=(config["USER"], config["PASSWORD"]),
keep_alive=True,
max_connection_lifetime=MAX_CONNECTION_LIFETIME,
connection_timeout=CONNECTION_TIMEOUT,
connection_acquisition_timeout=CONN_ACQUISITION_TIMEOUT,
max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
)
# Best-effort connectivity check: a Neo4j that is down at boot must
# not crash the worker. The driver reconnects lazily on first use.
try:
_driver.verify_connectivity()
except Exception:
logging.warning(
"Neo4j temp-database unreachable at init; continuing with a "
"lazily-reconnecting driver",
exc_info=True,
)
atexit.register(close_driver)
return _driver
def get_driver() -> neo4j.Driver:
return init_driver()
def close_driver() -> None:
global _driver
with _lock:
if _driver is not None:
try:
_driver.close()
finally:
_driver = None
@contextmanager
def get_session(
database: str | None = None,
default_access_mode: str | None = None,
) -> Iterator[RetryableSession]:
"""Session against the Neo4j temp-database cluster. Used for temp DB sessions
and for admin operations (CREATE / DROP DATABASE) when `database` is None."""
from api.attack_paths.database import (
ClientStatementException,
GraphDatabaseQueryException,
WriteQueryNotAllowedException,
)
READ_EXCEPTION_CODES = [
"Neo.ClientError.Statement.AccessMode",
"Neo.ClientError.Procedure.ProcedureNotFound",
]
CLIENT_STATEMENT_EXCEPTION_PREFIX = "Neo.ClientError.Statement."
session_wrapper: RetryableSession | None = None
try:
session_wrapper = RetryableSession(
session_factory=lambda: get_driver().session(
database=database, default_access_mode=default_access_mode
),
max_retries=SERVICE_UNAVAILABLE_MAX_RETRIES,
)
yield session_wrapper
except neo4j.exceptions.Neo4jError as exc:
if (
default_access_mode == neo4j.READ_ACCESS
and exc.code
and exc.code in READ_EXCEPTION_CODES
):
raise WriteQueryNotAllowedException(
message="Read query not allowed", code=READ_EXCEPTION_CODES[0]
)
message = exc.message if exc.message is not None else str(exc)
if exc.code and exc.code.startswith(CLIENT_STATEMENT_EXCEPTION_PREFIX):
raise ClientStatementException(message=message, code=exc.code)
raise GraphDatabaseQueryException(message=message, code=exc.code)
finally:
if session_wrapper is not None:
session_wrapper.close()
def create_database(database: str) -> None:
"""Create a database on the Neo4j cluster. Used for temp scan DBs."""
with get_session() as session:
session.run("CREATE DATABASE $database IF NOT EXISTS", {"database": database})
def drop_database(database: str) -> None:
"""Drop a database on the Neo4j cluster. Used for temp scan DBs."""
with get_session() as session:
session.run(f"DROP DATABASE `{database}` IF EXISTS DESTROY DATA")
def clear_cache(database: str) -> None:
"""Best-effort cache clear for a Neo4j database."""
from api.attack_paths.database import GraphDatabaseQueryException
try:
with get_session(database) as session:
session.run("CALL db.clearQueryCaches()")
except GraphDatabaseQueryException as exc:
logging.warning(f"Failed to clear query cache for database `{database}`: {exc}")
def run_cypher(
database: str | None,
cypher: str,
parameters: dict[str, Any] | None = None,
) -> Any:
"""Execute Cypher directly without the context manager. Thin helper."""
with get_session(database) as session:
return session.run(cypher, parameters or {})
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -1,12 +1,14 @@
from api.attack_paths.queries.aws import AWS_QUERIES from api.attack_paths.queries.aws import AWS_QUERIES
# TODO: drop after Neptune cutover
from api.attack_paths.queries.aws_deprecated import AWS_DEPRECATED_QUERIES
from api.attack_paths.queries.types import AttackPathsQueryDefinition from api.attack_paths.queries.types import AttackPathsQueryDefinition
# Query definitions organized by provider # Query definitions for scans synced with the current schema.
_QUERY_DEFINITIONS: dict[str, list[AttackPathsQueryDefinition]] = { _QUERY_DEFINITIONS: dict[str, list[AttackPathsQueryDefinition]] = {
"aws": AWS_QUERIES, "aws": AWS_QUERIES,
} }
# Flat lookup by query ID for O(1) access
_QUERIES_BY_ID: dict[str, AttackPathsQueryDefinition] = { _QUERIES_BY_ID: dict[str, AttackPathsQueryDefinition] = {
definition.id: definition definition.id: definition
for definitions in _QUERY_DEFINITIONS.values() for definitions in _QUERY_DEFINITIONS.values()
@@ -14,11 +16,45 @@ _QUERIES_BY_ID: dict[str, AttackPathsQueryDefinition] = {
} }
def get_queries_for_provider(provider: str) -> list[AttackPathsQueryDefinition]: # TODO: drop after Neptune cutover
"""Get all attack path queries for a specific provider.""" #
return _QUERY_DEFINITIONS.get(provider, []) # Query definitions for pre-cutover scans (`AttackPathsScan.is_migrated=False`)
# whose graph data was written under the previous schema. Both maps expose the
# same query IDs so the API contract is identical regardless of which set is
# routed to.
_DEPRECATED_QUERY_DEFINITIONS: dict[str, list[AttackPathsQueryDefinition]] = {
"aws": AWS_DEPRECATED_QUERIES,
}
_DEPRECATED_QUERIES_BY_ID: dict[str, AttackPathsQueryDefinition] = {
definition.id: definition
for definitions in _DEPRECATED_QUERY_DEFINITIONS.values()
for definition in definitions
}
def get_query_by_id(query_id: str) -> AttackPathsQueryDefinition | None: def get_queries_for_provider(
"""Get a specific attack path query by its ID.""" provider: str,
return _QUERIES_BY_ID.get(query_id) is_migrated: bool = True,
) -> list[AttackPathsQueryDefinition]:
"""Get all attack path queries for a provider.
`is_migrated` selects the catalog: True for scans synced with the current
schema, False for pre-cutover scans still using the legacy graph shape.
# TODO: drop the `is_migrated` parameter after Neptune cutover
"""
catalog = _QUERY_DEFINITIONS if is_migrated else _DEPRECATED_QUERY_DEFINITIONS
return catalog.get(provider, [])
def get_query_by_id(
query_id: str,
is_migrated: bool = True,
) -> AttackPathsQueryDefinition | None:
"""Get a specific attack path query by ID.
`is_migrated` selects the catalog (see `get_queries_for_provider`).
# TODO: drop the `is_migrated` parameter after Neptune cutover
"""
by_id = _QUERIES_BY_ID if is_migrated else _DEPRECATED_QUERIES_BY_ID
return by_id.get(query_id)
@@ -0,0 +1,28 @@
"""Attack-paths sink database layer.
The sink is the persistent store where attack-paths graphs live after a scan
finishes. Currently selectable between Neo4j (OSS / local dev default) and
AWS Neptune (hosted dev/staging/prod). Backend is picked by the
`ATTACK_PATHS_SINK_DATABASE` setting at process init.
This package exposes the public factory API; the implementation lives in
`api.attack_paths.sink.factory`.
"""
from api.attack_paths.sink.factory import (
SinkBackend,
close,
get_backend,
get_backend_for_name,
get_backend_for_scan,
init,
)
__all__ = [
"SinkBackend",
"close",
"get_backend",
"get_backend_for_name",
"get_backend_for_scan",
"init",
]
@@ -0,0 +1,92 @@
"""Protocol every sink backend must implement."""
from contextlib import AbstractContextManager
from typing import Any, Protocol
import neo4j
class SinkDatabase(Protocol):
"""Contract for the persistent attack-paths graph store.
The `database` argument is an opaque identifier passed through from the
legacy `database.py` API surface. On Neo4j it is the per-tenant database
name (e.g. `db-tenant-{uuid}`). On Neptune it is ignored (the cluster
has a single graph, and isolation is label-based).
"""
def init(self) -> None: ...
def close(self) -> None: ...
def verify_connectivity(self) -> None:
"""Raise if the backend the API read path uses is unreachable.
Neo4j verifies its single driver. Neptune verifies the reader
driver (the endpoint the API serves reads from); on single-endpoint
clusters the reader aliases the writer, so that path is covered too.
Used by the readiness probe; must not block longer than the caller's
probe budget.
"""
...
def get_session(
self,
database: str | None = None,
default_access_mode: str | None = None,
) -> AbstractContextManager: ...
def execute_read_query(
self,
database: str,
cypher: str,
parameters: dict[str, Any] | None = None,
) -> neo4j.graph.Graph: ...
def create_database(self, database: str) -> None: ...
def drop_database(self, database: str) -> None: ...
def drop_subgraph(self, database: str, provider_id: str) -> int: ...
def has_provider_data(self, database: str, provider_id: str) -> bool: ...
def clear_cache(self, database: str) -> None: ...
def ensure_sync_indexes(self, database: str) -> None:
"""Create any index needed for the sync write path.
Called once at the start of each provider sync; must be idempotent.
Neo4j creates a `_provider_element_id` index on `_ProviderResource`;
Neptune is a no-op (its `~id` lookup needs no index).
"""
...
def write_nodes(
self,
database: str,
labels: str,
rows: list[dict[str, Any]],
) -> None:
"""Upsert a batch of nodes into the sink.
`labels` is a pre-rendered Cypher label string ready to drop after
the node variable (e.g. `` `AWSUser`:`_ProviderResource`:`_Tenant_x` ``).
Each row carries `provider_element_id` and `props`.
"""
...
def write_relationships(
self,
database: str,
rel_type: str,
provider_id: str,
rows: list[dict[str, Any]],
) -> None:
"""Upsert a batch of relationships into the sink.
Each row carries `start_element_id`, `end_element_id`,
`provider_element_id` and `props`. `rel_type` is the relationship
type (already a valid Cypher identifier).
"""
...
@@ -0,0 +1,134 @@
"""Sink backend factory and process-wide handle cache.
Picks the active backend from `settings.ATTACK_PATHS_SINK_DATABASE` at first
use, holds the active backend plus any secondary backends needed to serve
scans written under the previous configuration, and tears them all down on
process shutdown. Imported via `from api.attack_paths import sink as
sink_module`.
"""
import threading
from enum import StrEnum, auto
from api.attack_paths.sink.base import SinkDatabase
from api.models import AttackPathsScan
from django.conf import settings
# Backend names
class SinkBackend(StrEnum):
NEO4J = auto()
NEPTUNE = auto()
# Backend cache
_backend: SinkDatabase | None = None
_secondary_backends: dict[SinkBackend, SinkDatabase] = {}
_lock = threading.Lock()
def _resolve_setting() -> SinkBackend:
raw = settings.ATTACK_PATHS_SINK_DATABASE.lower()
try:
return SinkBackend(raw)
except ValueError:
valid = sorted(b.value for b in SinkBackend)
raise RuntimeError(
f"ATTACK_PATHS_SINK_DATABASE must be one of {valid}; got {raw!r}"
)
def _build_backend(name: SinkBackend) -> SinkDatabase:
if name is SinkBackend.NEO4J:
from api.attack_paths.sink.neo4j import Neo4jSink
return Neo4jSink()
if name is SinkBackend.NEPTUNE:
from api.attack_paths.sink.neptune import NeptuneSink
return NeptuneSink()
raise RuntimeError(f"Unknown sink backend {name!r}")
# Lifecycle
def init(name: SinkBackend | str | None = None) -> SinkDatabase:
"""Initialize the configured sink backend. Idempotent."""
global _backend
if _backend is not None:
return _backend
with _lock:
if _backend is None:
resolved = SinkBackend(name) if name else _resolve_setting()
backend = _build_backend(resolved)
backend.init()
_backend = backend
return _backend
def close() -> None:
"""Close the active backend and every cached secondary backend."""
global _backend
with _lock:
backends = [
b for b in (_backend, *_secondary_backends.values()) if b is not None
]
_backend = None
_secondary_backends.clear()
for backend in backends:
try:
backend.close()
except Exception: # pragma: no cover - best-effort
pass
def get_backend() -> SinkDatabase:
"""Return the active sink. Initializes on first call."""
return init()
# Per-scan routing
def get_backend_for_scan(scan: AttackPathsScan) -> SinkDatabase:
"""Route reads by the sink that stores this scan's graph."""
raw_backend = getattr(scan, "sink_backend", SinkBackend.NEO4J.value)
if not isinstance(raw_backend, str):
raw_backend = SinkBackend.NEO4J.value
return get_backend_for_name(raw_backend)
def get_backend_for_name(name: SinkBackend | str) -> SinkDatabase:
"""Return the backend named by persisted scan metadata."""
resolved = SinkBackend(name)
if resolved is _resolve_setting():
return get_backend()
return _build_backend_cached(resolved)
def _build_backend_cached(name: SinkBackend) -> SinkDatabase:
# TODO: drop after Neptune cutover
# Needed only during cutover to serve Neo4j-written scans from a Neptune-
# configured API pod (and vice versa). Once every scan is on Neptune,
# `get_backend_for_scan` becomes a one-liner returning `get_backend()`.
if name in _secondary_backends:
return _secondary_backends[name]
with _lock:
if name not in _secondary_backends:
backend = _build_backend(name)
backend.init()
_secondary_backends[name] = backend
return _secondary_backends[name]
@@ -0,0 +1,454 @@
"""Neo4j sink implementation.
Owns a Neo4j driver independent from the staging driver. On OSS and local dev
this is the only sink; on hosted deployments it runs only as a legacy read
path while phase-1 drains tenant DBs.
"""
import atexit
import logging
import threading
import time
from collections.abc import Iterator
from contextlib import AbstractContextManager, contextmanager
from typing import Any
import neo4j
import neo4j.exceptions
from api.attack_paths.retryable_session import RetryableSession
from api.attack_paths.sink.base import SinkDatabase
from config.env import env
from django.conf import settings
logging.getLogger("neo4j").setLevel(logging.ERROR)
logging.getLogger("neo4j").propagate = False
logger = logging.getLogger(__name__)
SERVICE_UNAVAILABLE_MAX_RETRIES = env.int(
"ATTACK_PATHS_SERVICE_UNAVAILABLE_MAX_RETRIES", default=3
)
READ_QUERY_TIMEOUT_SECONDS = env.int(
"ATTACK_PATHS_READ_QUERY_TIMEOUT_SECONDS", default=30
)
CONN_ACQUISITION_TIMEOUT = env.int("NEO4J_CONN_ACQUISITION_TIMEOUT", default=15)
# TCP connect timeout, ordered below the acquisition timeout so an unreachable
# host can't pin a request or the readiness probe longer than this.
CONNECTION_TIMEOUT = env.int("NEO4J_CONNECTION_TIMEOUT", default=5)
MAX_CONNECTION_LIFETIME = env.int("NEO4J_MAX_CONNECTION_LIFETIME", default=7200)
MAX_CONNECTION_POOL_SIZE = env.int("NEO4J_MAX_CONNECTION_POOL_SIZE", default=50)
READ_EXCEPTION_CODES = [
"Neo.ClientError.Statement.AccessMode",
"Neo.ClientError.Procedure.ProcedureNotFound",
]
CLIENT_STATEMENT_EXCEPTION_PREFIX = "Neo.ClientError.Statement."
DATABASE_NOT_FOUND_CODE = "Neo.ClientError.Database.DatabaseNotFound"
class Neo4jSink(SinkDatabase):
"""Neo4j-backed sink. Multi-database cluster; tenant isolation is physical."""
def __init__(self) -> None:
self._driver: neo4j.Driver | None = None
self._lock = threading.Lock()
self._atexit_registered = False
# Driver
def _config(self) -> dict:
return settings.DATABASES["neo4j"]
def _uri(self) -> str:
cfg = self._config()
host = cfg["HOST"]
port = cfg["PORT"]
if not host or not port:
raise RuntimeError(
"NEO4J_HOST / NEO4J_PORT must be set when ATTACK_PATHS_SINK_DATABASE=neo4j"
)
return f"bolt://{host}:{port}"
def init(self) -> neo4j.Driver:
if self._driver is not None:
return self._driver
with self._lock:
if self._driver is None:
cfg = self._config()
self._driver = neo4j.GraphDatabase.driver(
self._uri(),
auth=(cfg["USER"], cfg["PASSWORD"]),
keep_alive=True,
max_connection_lifetime=MAX_CONNECTION_LIFETIME,
connection_timeout=CONNECTION_TIMEOUT,
connection_acquisition_timeout=CONN_ACQUISITION_TIMEOUT,
max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
)
# Eager connectivity check is best-effort:
# A Neo4j that is down at boot must not crash the process, same degradation model as Postgres
# The driver reconnects lazily on first use
# /health/ready surfaces the outage until it recovers
try:
self._driver.verify_connectivity()
except Exception:
logger.warning(
"Neo4j sink unreachable at init; continuing with a lazily-reconnecting driver",
exc_info=True,
)
if not self._atexit_registered:
atexit.register(self.close)
self._atexit_registered = True
return self._driver
def _get_driver(self) -> neo4j.Driver:
return self.init()
def verify_connectivity(self) -> None:
self._get_driver().verify_connectivity()
def close(self) -> None:
with self._lock:
if self._driver is not None:
try:
self._driver.close()
finally:
self._driver = None
# Sessions
@contextmanager
def get_session(
self,
database: str | None = None,
default_access_mode: str | None = None,
) -> Iterator[RetryableSession]:
from api.attack_paths.database import (
ClientStatementException,
GraphDatabaseQueryException,
WriteQueryNotAllowedException,
)
session_wrapper: RetryableSession | None = None
try:
session_wrapper = RetryableSession(
session_factory=lambda: self._get_driver().session(
database=database, default_access_mode=default_access_mode
),
max_retries=SERVICE_UNAVAILABLE_MAX_RETRIES,
)
yield session_wrapper
except neo4j.exceptions.Neo4jError as exc:
if (
default_access_mode == neo4j.READ_ACCESS
and exc.code
and exc.code in READ_EXCEPTION_CODES
):
raise WriteQueryNotAllowedException(
message="Read query not allowed", code=READ_EXCEPTION_CODES[0]
)
message = exc.message if exc.message is not None else str(exc)
if exc.code and exc.code.startswith(CLIENT_STATEMENT_EXCEPTION_PREFIX):
raise ClientStatementException(message=message, code=exc.code)
raise GraphDatabaseQueryException(message=message, code=exc.code)
finally:
if session_wrapper is not None:
session_wrapper.close()
# Operations
def execute_read_query(
self,
database: str,
cypher: str,
parameters: dict[str, Any] | None = None,
) -> neo4j.graph.Graph:
with self.get_session(
database, default_access_mode=neo4j.READ_ACCESS
) as session:
def _run(tx: neo4j.ManagedTransaction) -> neo4j.graph.Graph:
result = tx.run(
cypher, parameters or {}, timeout=READ_QUERY_TIMEOUT_SECONDS
)
return result.graph()
return session.execute_read(_run)
def create_database(self, database: str) -> None:
with self.get_session() as session:
session.run(
"CREATE DATABASE $database IF NOT EXISTS", {"database": database}
)
def drop_database(self, database: str) -> None:
with self.get_session() as session:
session.run(f"DROP DATABASE `{database}` IF EXISTS DESTROY DATA")
def drop_subgraph(self, database: str, provider_id: str) -> int:
"""Delete all nodes for a provider from a tenant database, batched.
Deletes relationships then nodes in batches (not `DETACH DELETE`) so a
dense provider's graph cannot exceed Neo4j's transaction memory limit.
Silently returns 0 if the database doesn't exist.
"""
from api.attack_paths.database import GraphDatabaseQueryException
from tasks.jobs.attack_paths.config import (
BATCH_SIZE,
PROVIDER_RESOURCE_LABEL,
get_provider_label,
)
provider_label = get_provider_label(provider_id)
deleted_nodes = 0
deleted_relationships = 0
relationship_batches = 0
node_batches = 0
drop_t0 = time.perf_counter()
logger.info(
"Dropping provider graph from Neo4j sink database %s "
"(provider=%s, provider_label=%s)",
database,
provider_id,
provider_label,
)
try:
logger.info(
"Opening Neo4j sink session for provider graph drop "
"(database=%s, provider=%s)",
database,
provider_id,
)
with self.get_session(database) as session:
logger.info(
"Opened Neo4j sink session for provider graph drop "
"(database=%s, provider=%s)",
database,
provider_id,
)
# Phase 1: delete relationships incident to provider nodes in
# batches. The undirected pattern matches an edge between two
# provider nodes from both ends, so `DISTINCT r` dedupes it to
# delete a full batch of unique relationships each round.
deleted_count = 1
while deleted_count > 0:
next_batch = relationship_batches + 1
logger.info(
"Deleting relationship batch from Neo4j sink database %s "
"(provider=%s, batch=%s, total_rels=%s, elapsed=%.3fs)",
database,
provider_id,
next_batch,
deleted_relationships,
time.perf_counter() - drop_t0,
)
result = session.run(
f"""
MATCH (:`{provider_label}`)-[r]-()
WITH DISTINCT r LIMIT $batch_size
DELETE r
RETURN COUNT(r) AS deleted_rels_count
""",
{"batch_size": BATCH_SIZE},
)
deleted_count = result.single().get("deleted_rels_count", 0)
if deleted_count > 0:
relationship_batches += 1
deleted_relationships += deleted_count
logger.info(
"Deleted relationship batch from Neo4j sink database %s "
"(provider=%s, batch=%s, deleted_rels=%s, "
"total_rels=%s, elapsed=%.3fs)",
database,
provider_id,
relationship_batches,
deleted_count,
deleted_relationships,
time.perf_counter() - drop_t0,
)
# Phase 2: delete the now relationship-free nodes in batches.
deleted_count = 1
while deleted_count > 0:
next_batch = node_batches + 1
logger.info(
"Deleting node batch from Neo4j sink database %s "
"(provider=%s, batch=%s, total_nodes=%s, elapsed=%.3fs)",
database,
provider_id,
next_batch,
deleted_nodes,
time.perf_counter() - drop_t0,
)
result = session.run(
f"""
MATCH (n:{PROVIDER_RESOURCE_LABEL}:`{provider_label}`)
WITH n LIMIT $batch_size
DELETE n
RETURN COUNT(n) AS deleted_nodes_count
""",
{"batch_size": BATCH_SIZE},
)
deleted_count = result.single().get("deleted_nodes_count", 0)
if deleted_count > 0:
node_batches += 1
deleted_nodes += deleted_count
logger.info(
"Deleted node batch from Neo4j sink database %s "
"(provider=%s, batch=%s, deleted_nodes=%s, "
"total_nodes=%s, elapsed=%.3fs)",
database,
provider_id,
node_batches,
deleted_count,
deleted_nodes,
time.perf_counter() - drop_t0,
)
except GraphDatabaseQueryException as exc:
if exc.code == DATABASE_NOT_FOUND_CODE:
logger.info(
"Skipped provider graph drop from Neo4j sink database %s "
"(provider=%s, reason=database_not_found, elapsed=%.3fs)",
database,
provider_id,
time.perf_counter() - drop_t0,
)
return 0
raise
logger.info(
"Finished dropping provider graph from Neo4j sink database %s "
"(provider=%s, relationship_batches=%s, deleted_rels=%s, "
"node_batches=%s, deleted_nodes=%s, elapsed=%.3fs)",
database,
provider_id,
relationship_batches,
deleted_relationships,
node_batches,
deleted_nodes,
time.perf_counter() - drop_t0,
)
return deleted_nodes
def has_provider_data(self, database: str, provider_id: str) -> bool:
from api.attack_paths.database import GraphDatabaseQueryException
from tasks.jobs.attack_paths.config import (
PROVIDER_RESOURCE_LABEL,
get_provider_label,
)
provider_label = get_provider_label(provider_id)
query = (
f"MATCH (n:{PROVIDER_RESOURCE_LABEL}:`{provider_label}`) RETURN 1 LIMIT 1"
)
try:
with self.get_session(
database, default_access_mode=neo4j.READ_ACCESS
) as session:
result = session.run(query)
return result.single() is not None
except GraphDatabaseQueryException as exc:
if exc.code == DATABASE_NOT_FOUND_CODE:
return False
raise
def clear_cache(self, database: str) -> None:
from api.attack_paths.database import GraphDatabaseQueryException
try:
with self.get_session(database) as session:
session.run("CALL db.clearQueryCaches()")
except GraphDatabaseQueryException as exc:
logger.warning(
f"Failed to clear query cache for database `{database}`: {exc}"
)
# Sync write path
def ensure_sync_indexes(self, database: str) -> None:
"""Create the `_provider_element_id` lookup index on `_ProviderResource`.
Every synced node carries the `_ProviderResource` label, so a single
index covers both node-upserts and relationship endpoint MATCHes.
Without this index the rel sync degrades to a label scan per row and
large provider syncs become unworkable.
"""
from tasks.jobs.attack_paths.config import (
PROVIDER_ELEMENT_ID_PROPERTY,
PROVIDER_RESOURCE_LABEL,
)
query = (
f"CREATE INDEX provider_element_id_idx IF NOT EXISTS "
f"FOR (n:`{PROVIDER_RESOURCE_LABEL}`) "
f"ON (n.`{PROVIDER_ELEMENT_ID_PROPERTY}`)"
)
with self.get_session(database) as session:
session.run(query).consume()
def write_nodes(
self,
database: str,
labels: str,
rows: list[dict[str, Any]],
) -> None:
if not rows:
return
from tasks.jobs.attack_paths.config import (
PROVIDER_ELEMENT_ID_PROPERTY,
PROVIDER_RESOURCE_LABEL,
)
query = f"""
UNWIND $rows AS row
MERGE (n:`{PROVIDER_RESOURCE_LABEL}` {{`{PROVIDER_ELEMENT_ID_PROPERTY}`: row.provider_element_id}})
SET n:{labels}
SET n += row.props
"""
with self.get_session(database) as session:
session.run(query, {"rows": rows}).consume()
def write_relationships(
self,
database: str,
rel_type: str,
provider_id: str,
rows: list[dict[str, Any]],
) -> None:
if not rows:
return
from tasks.jobs.attack_paths.config import (
PROVIDER_ELEMENT_ID_PROPERTY,
PROVIDER_RESOURCE_LABEL,
get_provider_label,
)
provider_label = get_provider_label(provider_id)
query = f"""
UNWIND $rows AS row
MATCH (s:`{PROVIDER_RESOURCE_LABEL}`:`{provider_label}` {{`{PROVIDER_ELEMENT_ID_PROPERTY}`: row.start_element_id}})
MATCH (t:`{PROVIDER_RESOURCE_LABEL}`:`{provider_label}` {{`{PROVIDER_ELEMENT_ID_PROPERTY}`: row.end_element_id}})
MERGE (s)-[r:`{rel_type}` {{`{PROVIDER_ELEMENT_ID_PROPERTY}`: row.provider_element_id}}]->(t)
SET r += row.props
"""
with self.get_session(database) as session:
session.run(query, {"rows": rows}).consume()
# For compatibility with test harnesses that patch the concrete driver
def get_driver(self) -> neo4j.Driver:
return self._get_driver()
# Helper for tests / external callers that want a writer session specifically
def get_read_session(
sink: Neo4jSink, database: str
) -> AbstractContextManager[RetryableSession]:
return sink.get_session(database, default_access_mode=neo4j.READ_ACCESS)
@@ -0,0 +1,524 @@
"""AWS Neptune sink implementation.
Dual Bolt drivers: one against the writer endpoint for workers, one against
the reader endpoint for the API read path. If `NEPTUNE_READER_ENDPOINT` is
unset the reader falls back to the writer driver so single-node clusters work.
Neptune is single-database. The `database` argument on the SinkDatabase
protocol is ignored; tenant / provider isolation is enforced by labels that
the sync step already writes on every node (see tasks/jobs/attack_paths/sync.py).
SigV4 auth lives at the bottom of this file as `neptune_auth_provider`. The
neo4j driver invokes the returned callable on each token refresh.
"""
import atexit
import datetime
import json
import logging
import threading
import time
from collections.abc import Callable, Iterator
from contextlib import contextmanager
from typing import Any
from urllib.parse import urlsplit
import neo4j
import neo4j.exceptions
from api.attack_paths.retryable_session import RetryableSession
from api.attack_paths.sink.base import SinkDatabase
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.session import Session as BotoSession
from config.env import env
from django.conf import settings
from neo4j.auth_management import AuthManagers, ExpiringAuth
logging.getLogger("neo4j").setLevel(logging.ERROR)
logging.getLogger("neo4j").propagate = False
logger = logging.getLogger(__name__)
SERVICE_UNAVAILABLE_MAX_RETRIES = env.int(
"ATTACK_PATHS_SERVICE_UNAVAILABLE_MAX_RETRIES", default=3
)
READ_QUERY_TIMEOUT_SECONDS = env.int(
"ATTACK_PATHS_READ_QUERY_TIMEOUT_SECONDS", default=30
)
# Neptune serverless cold-start can be >30s; give the driver room
CONN_ACQUISITION_TIMEOUT = env.int("NEPTUNE_CONN_ACQUISITION_TIMEOUT", default=60)
# TCP connect timeout, ordered below the acquisition timeout so an unreachable
# endpoint can't pin a request or the readiness probe longer than this. Kept
# generous: cold-start delays query execution, not the socket connect.
CONNECTION_TIMEOUT = env.int("NEPTUNE_CONNECTION_TIMEOUT", default=10)
# Roll connections hourly so SigV4 rotations and cert refreshes don't strand long-lived pool entries
MAX_CONNECTION_LIFETIME = env.int("NEPTUNE_MAX_CONNECTION_LIFETIME", default=3600)
MAX_CONNECTION_POOL_SIZE = env.int("NEPTUNE_MAX_CONNECTION_POOL_SIZE", default=50)
READ_EXCEPTION_CODES = [
"Neo.ClientError.Statement.AccessMode",
"Neo.ClientError.Procedure.ProcedureNotFound",
]
CLIENT_STATEMENT_EXCEPTION_PREFIX = "Neo.ClientError.Statement."
# Refresh 60s before the 5-minute SigV4 window closes
SIGV4_TOKEN_LIFETIME_MINUTES = 4
class NeptuneSink(SinkDatabase):
"""Neptune-backed sink. Single database; isolation is label-based."""
def __init__(self) -> None:
self._writer: neo4j.Driver | None = None
self._reader: neo4j.Driver | None = None
self._lock = threading.Lock()
self._atexit_registered = False
# Config
def _config(self) -> dict:
return settings.DATABASES["neptune"]
def _bolt_uri(self, endpoint: str, port: str) -> str:
return f"bolt+s://{endpoint}:{port}"
def _https_url(self, endpoint: str, port: str) -> str:
return f"https://{endpoint}:{port}"
def _build_driver(self, endpoint: str) -> neo4j.Driver:
cfg = self._config()
port = cfg["PORT"]
region = cfg["REGION"]
if not endpoint or not region:
raise RuntimeError(
"NEPTUNE_WRITER_ENDPOINT and AWS_REGION must be set when "
"ATTACK_PATHS_SINK_DATABASE=neptune"
)
return neo4j.GraphDatabase.driver(
self._bolt_uri(endpoint, port),
auth=AuthManagers.bearer(
neptune_auth_provider(region, self._https_url(endpoint, port))
),
keep_alive=True,
max_connection_lifetime=MAX_CONNECTION_LIFETIME,
connection_timeout=CONNECTION_TIMEOUT,
connection_acquisition_timeout=CONN_ACQUISITION_TIMEOUT,
max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
max_transaction_retry_time=0,
)
# Lifecycle
def init(self) -> None:
if self._writer is not None:
return
with self._lock:
if self._writer is None:
cfg = self._config()
writer_endpoint = cfg["WRITER_ENDPOINT"]
reader_endpoint = cfg["READER_ENDPOINT"] or writer_endpoint
# Eager connectivity checks are best-effort
# A Neptune that is down at boot must not crash the process, same degradation model as Postgres
# Drivers reconnect lazily on first use
# /health/ready surfaces the outage until it recovers
self._writer = self._build_driver(writer_endpoint)
self._verify_best_effort(self._writer, "writer")
if reader_endpoint == writer_endpoint:
self._reader = self._writer
else:
self._reader = self._build_driver(reader_endpoint)
self._verify_best_effort(self._reader, "reader")
if not self._atexit_registered:
atexit.register(self.close)
self._atexit_registered = True
def close(self) -> None:
with self._lock:
# `Driver.close()` is idempotent, so closing the same driver twice
# (when reader aliases writer on single-endpoint configs) is safe
for driver in (self._reader, self._writer):
if driver is None:
continue
try:
driver.close()
except Exception: # pragma: no cover - best-effort
pass
self._writer = None
self._reader = None
# Sessions
def _get_writer(self) -> neo4j.Driver:
self.init()
assert self._writer is not None
return self._writer
def _get_reader(self) -> neo4j.Driver:
self.init()
assert self._reader is not None
return self._reader
@staticmethod
def _verify_best_effort(driver: neo4j.Driver, role: str) -> None:
try:
driver.verify_connectivity()
except Exception:
logger.warning(
"Neptune %s endpoint unreachable at init; continuing with a lazily-reconnecting driver",
role,
exc_info=True,
)
def verify_connectivity(self) -> None:
# The API read path uses the reader driver
# On single-endpoint clusters it aliases the writer, so this also covers the writer
# A writer-only outage is a workers' concern (no HTTP probe there) and deliberately does not fail API readiness
self._get_reader().verify_connectivity()
@contextmanager
def get_session(
self,
database: str | None = None, # noqa: ARG002 - ignored on Neptune
default_access_mode: str | None = None,
) -> Iterator[RetryableSession]:
from api.attack_paths.database import (
ClientStatementException,
GraphDatabaseQueryException,
WriteQueryNotAllowedException,
)
driver = (
self._get_reader()
if default_access_mode == neo4j.READ_ACCESS
else self._get_writer()
)
session_wrapper: RetryableSession | None = None
try:
session_wrapper = RetryableSession(
session_factory=lambda: driver.session(
default_access_mode=default_access_mode
),
max_retries=SERVICE_UNAVAILABLE_MAX_RETRIES,
)
yield session_wrapper
except neo4j.exceptions.Neo4jError as exc:
if (
default_access_mode == neo4j.READ_ACCESS
and exc.code
and exc.code in READ_EXCEPTION_CODES
):
raise WriteQueryNotAllowedException(
message="Read query not allowed", code=READ_EXCEPTION_CODES[0]
)
message = exc.message if exc.message is not None else str(exc)
if exc.code and exc.code.startswith(CLIENT_STATEMENT_EXCEPTION_PREFIX):
raise ClientStatementException(message=message, code=exc.code)
raise GraphDatabaseQueryException(message=message, code=exc.code)
finally:
if session_wrapper is not None:
session_wrapper.close()
# Operations
def execute_read_query(
self,
database: str, # noqa: ARG002 - ignored on Neptune
cypher: str,
parameters: dict[str, Any] | None = None,
) -> neo4j.graph.Graph:
with self.get_session(default_access_mode=neo4j.READ_ACCESS) as session:
def _run(tx: neo4j.ManagedTransaction) -> neo4j.graph.Graph:
result = tx.run(
cypher, parameters or {}, timeout=READ_QUERY_TIMEOUT_SECONDS
)
return result.graph()
return session.execute_read(_run)
def create_database(self, database: str) -> None: # noqa: ARG002
# Neptune clusters are single-database; there is nothing to create.
return None
def drop_database(self, database: str) -> None: # noqa: ARG002
# Neptune clusters are single-database; there is nothing to drop.
return None
def drop_subgraph(self, database: str, provider_id: str) -> int: # noqa: ARG002
"""Delete a provider's subgraph in two bounded phases.
Neptune write transactions are capped at ~2 minutes. A naive
`DETACH DELETE` on a label-scanned batch grows unbounded with graph
density (one node can drag thousands of relationships into the same
transaction). Instead:
1. Delete relationships incident to provider nodes, one fixed-size
batch per transaction.
2. Delete the now-orphaned nodes, one fixed-size batch per transaction.
Each transaction does work proportional to `batch_size`, never to the
graph's branching factor.
"""
from tasks.jobs.attack_paths.config import (
BATCH_SIZE,
PROVIDER_RESOURCE_LABEL,
get_provider_label,
)
provider_label = get_provider_label(provider_id)
deleted_relationships = 0
relationship_batches = 0
node_batches = 0
drop_t0 = time.perf_counter()
logger.info(
"Dropping provider graph from Neptune sink "
"(provider=%s, provider_label=%s)",
provider_id,
provider_label,
)
logger.info(
"Opening Neptune writer session for provider graph drop (provider=%s)",
provider_id,
)
with self.get_session() as session:
logger.info(
"Opened Neptune writer session for provider graph drop (provider=%s)",
provider_id,
)
while True:
next_batch = relationship_batches + 1
logger.info(
"Deleting relationship batch from Neptune sink "
"(provider=%s, batch=%s, total_rels=%s, elapsed=%.3fs)",
provider_id,
next_batch,
deleted_relationships,
time.perf_counter() - drop_t0,
)
result = session.run(
f"""
MATCH (:`{provider_label}`)-[r]-()
WITH DISTINCT r LIMIT $batch_size
DELETE r
RETURN COUNT(r) AS deleted_rels_count
""",
{"batch_size": BATCH_SIZE},
)
record = result.single()
deleted_rels = (record["deleted_rels_count"] if record else 0) or 0
if deleted_rels == 0:
break
relationship_batches += 1
deleted_relationships += deleted_rels
logger.info(
"Deleted relationship batch from Neptune sink "
"(provider=%s, batch=%s, deleted_rels=%s, total_rels=%s, "
"elapsed=%.3fs)",
provider_id,
relationship_batches,
deleted_rels,
deleted_relationships,
time.perf_counter() - drop_t0,
)
deleted_nodes = 0
while True:
next_batch = node_batches + 1
logger.info(
"Deleting node batch from Neptune sink "
"(provider=%s, batch=%s, total_nodes=%s, elapsed=%.3fs)",
provider_id,
next_batch,
deleted_nodes,
time.perf_counter() - drop_t0,
)
result = session.run(
f"""
MATCH (n:`{PROVIDER_RESOURCE_LABEL}`:`{provider_label}`)
WITH n LIMIT $batch_size
DELETE n
RETURN COUNT(n) AS deleted_nodes_count
""",
{"batch_size": BATCH_SIZE},
)
record = result.single()
deleted = (record["deleted_nodes_count"] if record else 0) or 0
if deleted == 0:
break
node_batches += 1
deleted_nodes += deleted
logger.info(
"Deleted node batch from Neptune sink "
"(provider=%s, batch=%s, deleted_nodes=%s, total_nodes=%s, "
"elapsed=%.3fs)",
provider_id,
node_batches,
deleted,
deleted_nodes,
time.perf_counter() - drop_t0,
)
logger.info(
"Finished dropping provider graph from Neptune sink "
"(provider=%s, relationship_batches=%s, deleted_rels=%s, "
"node_batches=%s, deleted_nodes=%s, elapsed=%.3fs)",
provider_id,
relationship_batches,
deleted_relationships,
node_batches,
deleted_nodes,
time.perf_counter() - drop_t0,
)
return deleted_nodes
def has_provider_data(self, database: str, provider_id: str) -> bool: # noqa: ARG002
from tasks.jobs.attack_paths.config import (
PROVIDER_RESOURCE_LABEL,
get_provider_label,
)
provider_label = get_provider_label(provider_id)
query = (
f"MATCH (n:{PROVIDER_RESOURCE_LABEL}:`{provider_label}`) RETURN 1 LIMIT 1"
)
with self.get_session(default_access_mode=neo4j.READ_ACCESS) as session:
result = session.run(query)
return result.single() is not None
def clear_cache(self, database: str) -> None: # noqa: ARG002
# Neptune has no user-facing cache-clear procedure; no-op.
return None
# Sync write path
def ensure_sync_indexes(self, database: str) -> None: # noqa: ARG002
# Neptune routes node and relationship lookups through `~id`, which is the cluster's primary key
# No additional index is needed or supported
return None
def write_nodes(
self,
database: str, # noqa: ARG002
labels: str,
rows: list[dict[str, Any]],
) -> None:
if not rows:
return
from tasks.jobs.attack_paths.config import (
PROVIDER_ELEMENT_ID_PROPERTY,
PROVIDER_RESOURCE_LABEL,
)
# MERGE on `~id` is the documented and engine-optimized idempotent
# upsert pattern for Neptune openCypher. The label inside the MERGE
# matters: Neptune assigns a default `vertex` label to any node
# created without an explicit one, so we pin `_ProviderResource`
# (which every synced node carries anyway) at MERGE-time. Additional
# labels are added after
#
# We also write `_provider_element_id` as a regular property so
# non-sync code (drop_subgraph, query helpers) keeps a stable contract
# that doesn't know about `~id`
query = f"""
UNWIND $rows AS row
MERGE (n:`{PROVIDER_RESOURCE_LABEL}` {{`~id`: row.provider_element_id}})
SET n:{labels}
SET n += row.props
SET n.`{PROVIDER_ELEMENT_ID_PROPERTY}` = row.provider_element_id
"""
with self.get_session() as session:
session.run(query, {"rows": rows}).consume()
def write_relationships(
self,
database: str, # noqa: ARG002
rel_type: str,
provider_id: str, # noqa: ARG002 - encoded in start/end `~id` already
rows: list[dict[str, Any]],
) -> None:
if not rows:
return
from tasks.jobs.attack_paths.config import PROVIDER_ELEMENT_ID_PROPERTY
# `id(n) = $value` is Neptune's parameterized fast path; both endpoint
# MATCHes resolve in O(1) via the system `~id`, so per-row work stays
# bounded regardless of batch size
query = f"""
UNWIND $rows AS row
MATCH (s) WHERE id(s) = row.start_element_id
MATCH (e) WHERE id(e) = row.end_element_id
MERGE (s)-[r:`{rel_type}` {{`{PROVIDER_ELEMENT_ID_PROPERTY}`: row.provider_element_id}}]->(e)
SET r += row.props
"""
with self.get_session() as session:
session.run(query, {"rows": rows}).consume()
# Test helpers
def get_writer(self) -> neo4j.Driver:
return self._get_writer()
def get_reader(self) -> neo4j.Driver:
return self._get_reader()
# SigV4 auth provider
class _NeptuneAuthToken(neo4j.Auth):
"""Neo4j Auth backed by a SigV4-signed GET to `/opencypher`."""
def __init__(self, region: str, url: str) -> None:
session = BotoSession()
credentials = session.get_credentials()
if credentials is None:
raise RuntimeError(
"No AWS credentials available for Neptune SigV4 signing. "
"Ensure the boto3 credential chain can resolve."
)
credentials = credentials.get_frozen_credentials()
request = AWSRequest(method="GET", url=url + "/opencypher")
# SigV4 canonical Host must carry the real `host:port`
# Neptune runs on a non-default port (8182), so `.hostname` would drop it and break signing
request.headers.add_header("Host", urlsplit(url).netloc)
SigV4Auth(credentials, "neptune-db", region).add_auth(request)
auth_obj = {
header: request.headers[header]
for header in (
"Authorization",
"X-Amz-Date",
"X-Amz-Security-Token",
"Host",
)
if header in request.headers
}
auth_obj["HttpMethod"] = "GET"
super().__init__("basic", "username", json.dumps(auth_obj))
def neptune_auth_provider(region: str, https_url: str) -> Callable[[], ExpiringAuth]:
"""Return a callable the neo4j driver can invoke to refresh credentials."""
def _provider() -> ExpiringAuth:
token = _NeptuneAuthToken(region, https_url)
expires_at = (
datetime.datetime.now(datetime.UTC)
+ datetime.timedelta(minutes=SIGV4_TOKEN_LIFETIME_MINUTES)
).timestamp()
return ExpiringAuth(auth=token, expires_at=expires_at)
return _provider
@@ -5,6 +5,7 @@ from typing import Any
import neo4j import neo4j
from api.attack_paths import AttackPathsQueryDefinition from api.attack_paths import AttackPathsQueryDefinition
from api.attack_paths import database as graph_database from api.attack_paths import database as graph_database
from api.attack_paths import sink as sink_module
from api.attack_paths.cypher_sanitizer import ( from api.attack_paths.cypher_sanitizer import (
inject_provider_label, inject_provider_label,
validate_custom_query, validate_custom_query,
@@ -14,7 +15,9 @@ from api.attack_paths.queries.schema import (
RAW_SCHEMA_URL, RAW_SCHEMA_URL,
get_cartography_schema_query, get_cartography_schema_query,
) )
from api.models import AttackPathsScan
from config.custom_logging import BackendLogger from config.custom_logging import BackendLogger
from config.env import env
from rest_framework.exceptions import APIException, PermissionDenied, ValidationError from rest_framework.exceptions import APIException, PermissionDenied, ValidationError
from tasks.jobs.attack_paths.config import ( from tasks.jobs.attack_paths.config import (
INTERNAL_LABELS, INTERNAL_LABELS,
@@ -26,6 +29,10 @@ from tasks.jobs.attack_paths.config import (
logger = logging.getLogger(BackendLogger.API) logger = logging.getLogger(BackendLogger.API)
def _custom_query_timeout_ms() -> int:
return env.int("ATTACK_PATHS_READ_QUERY_TIMEOUT_SECONDS", default=30) * 1000
# Predefined query helpers # Predefined query helpers
@@ -102,13 +109,13 @@ def execute_query(
definition: AttackPathsQueryDefinition, definition: AttackPathsQueryDefinition,
parameters: dict[str, Any], parameters: dict[str, Any],
provider_id: str, provider_id: str,
scan: AttackPathsScan,
) -> dict[str, Any]: ) -> dict[str, Any]:
try: try:
graph = graph_database.execute_read_query( # TODO: drop after Neptune cutover
database=database_name, # Route reads by the scan row's recorded sink, not by current settings.
cypher=definition.cypher, backend = sink_module.get_backend_for_scan(scan)
parameters=parameters, graph = backend.execute_read_query(database_name, definition.cypher, parameters)
)
return _serialize_graph(graph, provider_id) return _serialize_graph(graph, provider_id)
except graph_database.WriteQueryNotAllowedException: except graph_database.WriteQueryNotAllowedException:
@@ -142,22 +149,31 @@ def execute_custom_query(
database_name: str, database_name: str,
cypher: str, cypher: str,
provider_id: str, provider_id: str,
scan: AttackPathsScan,
) -> dict[str, Any]: ) -> dict[str, Any]:
# Defense-in-depth for custom queries: # Defense-in-depth for custom queries:
# 1. neo4j.READ_ACCESS — prevents mutations at the driver level # 1. `neo4j.READ_ACCESS` — prevents mutations at the driver level
# 2. inject_provider_label() — regex-based label injection scopes node patterns # 2. `inject_provider_label()` — regex-based label injection scopes node patterns
# 3. _serialize_graph() — post-query filter drops nodes without the provider label # 3. `_serialize_graph()` — post-query filter drops nodes without the provider label
# 4. `USING QUERY:TIMEOUTMILLISECONDS` on Neptune — server-side runaway cutoff
# #
# Layer 2 is best-effort (regex can't fully parse Cypher); # Layer 2 is best-effort (regex can't fully parse Cypher);
# layer 3 is the safety net that guarantees provider isolation. # layer 3 is the safety net that guarantees provider isolation.
validate_custom_query(cypher) validate_custom_query(cypher)
cypher = inject_provider_label(cypher, provider_id) cypher = inject_provider_label(cypher, provider_id)
# TODO: drop after Neptune cutover
backend = sink_module.get_backend_for_scan(scan)
# Neptune enforces a cluster-level query timeout; prepending the hint
# makes the limit explicit and matches the client-side read timeout.
# Applies only when the scan's graph lives in Neptune.
if getattr(scan, "sink_backend", None) == "neptune":
timeout_ms = _custom_query_timeout_ms()
cypher = f"USING QUERY:TIMEOUTMILLISECONDS {timeout_ms}\n{cypher}"
try: try:
graph = graph_database.execute_read_query( graph = backend.execute_read_query(database_name, cypher, None)
database=database_name,
cypher=cypher,
)
serialized = _serialize_graph(graph, provider_id) serialized = _serialize_graph(graph, provider_id)
return _truncate_graph(serialized) return _truncate_graph(serialized)
@@ -180,10 +196,11 @@ def execute_custom_query(
def get_cartography_schema( def get_cartography_schema(
database_name: str, provider_id: str database_name: str, provider_id: str, scan: AttackPathsScan
) -> dict[str, str] | None: ) -> dict[str, str] | None:
try: try:
with graph_database.get_session( backend = sink_module.get_backend_for_scan(scan)
with backend.get_session(
database_name, default_access_mode=neo4j.READ_ACCESS database_name, default_access_mode=neo4j.READ_ACCESS
) as session: ) as session:
result = session.run(get_cartography_schema_query(provider_id)) result = session.run(get_cartography_schema_query(provider_id))
+53 -14
View File
@@ -2,8 +2,9 @@
Format (draft-inadarei-api-health-check-06). Format (draft-inadarei-api-health-check-06).
Liveness reports only process status. Readiness verifies that PostgreSQL, Liveness reports only process status. Readiness verifies that PostgreSQL,
Valkey and Neo4j are reachable and returns per-dependency detail when any Valkey and the attack-paths graph store (Neo4j or Neptune, per
of them is unreachable. ``ATTACK_PATHS_SINK_DATABASE``) are reachable and returns per-dependency
detail when any of them is unreachable.
""" """
from __future__ import annotations from __future__ import annotations
@@ -11,6 +12,8 @@ from __future__ import annotations
import logging import logging
import threading import threading
import time import time
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FuturesTimeoutError
from contextlib import suppress from contextlib import suppress
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any from typing import Any
@@ -37,9 +40,28 @@ STATUS_FAIL = "fail"
STATUS_WARN = "warn" STATUS_WARN = "warn"
# Short socket timeout so a stuck Valkey cannot stall the probe. # Short socket timeout so a stuck Valkey cannot stall the probe.
# Neo4j inherits its driver-level ``connection_acquisition_timeout``.
VALKEY_PROBE_TIMEOUT_SECONDS = 2 VALKEY_PROBE_TIMEOUT_SECONDS = 2
# Probe-scoped budget for the graph database.
# ``Driver.verify_connectivity()`` takes no timeout; its only bound is the
# driver-level ``connection_acquisition_timeout`` (60s on Neptune). The
# probe needs its own budget, independent of the workload driver, so a
# graph-database outage cannot pin a worker thread (and the readiness lock)
# for a minute.
GRAPH_DB_PROBE_TIMEOUT_SECONDS = 5
# Bounded pool that enforces ``GRAPH_DB_PROBE_TIMEOUT_SECONDS``. If the
# graph database is unreachable the probe call blocks until the driver's
# own acquisition timeout fires; we abandon the future after the budget and
# report ``fail``. Orphaned tasks are capped by ``max_workers`` plus the 3s
# readiness cache plus the per-IP throttle, so they cannot pile up: worst
# case during a graph-database outage is every readiness call failing fast
# in ``GRAPH_DB_PROBE_TIMEOUT_SECONDS`` with at most 2 background threads
# stuck for <= the driver acquisition timeout.
_graph_db_probe_executor = ThreadPoolExecutor(
max_workers=2, thread_name_prefix="health-graph-db-probe"
)
# Brief cache window so high-frequency probes (ALB target groups, scrapers) # Brief cache window so high-frequency probes (ALB target groups, scrapers)
# do not stampede the actual dependency checks. # do not stampede the actual dependency checks.
CACHE_CONTROL_HEADER = "max-age=3, must-revalidate" CACHE_CONTROL_HEADER = "max-age=3, must-revalidate"
@@ -109,11 +131,24 @@ def _probe_valkey() -> None:
client.close() client.close()
def _probe_neo4j() -> None: def _graph_db_component_id() -> str:
# Lazy import: avoids pulling attack_paths into the boot import graph. """Return the active graph database name for the ``componentId`` field."""
from api.attack_paths.database import get_driver return settings.ATTACK_PATHS_SINK_DATABASE.strip().lower()
get_driver().verify_connectivity()
def _probe_graph_db() -> None:
# Lazy import: avoids pulling attack_paths into the boot import graph
from api.attack_paths.database import verify_connectivity
future = _graph_db_probe_executor.submit(verify_connectivity)
try:
future.result(timeout=GRAPH_DB_PROBE_TIMEOUT_SECONDS)
except FuturesTimeoutError as exc:
# Do not wait for the abandoned task; it ends when the driver's own acquisition timeout fires
future.cancel()
raise TimeoutError(
f"graph-db probe exceeded {GRAPH_DB_PROBE_TIMEOUT_SECONDS}s"
) from exc
def _build_check_entry( def _build_check_entry(
@@ -176,14 +211,18 @@ def _readiness_payload() -> tuple[dict[str, Any], int]:
): ):
return snapshot[1], snapshot[2] return snapshot[1], snapshot[2]
graph_db_component_id = _graph_db_component_id()
postgres_result, postgres_ms = _measure("postgres", _probe_postgres) postgres_result, postgres_ms = _measure("postgres", _probe_postgres)
valkey_result, valkey_ms = _measure("valkey", _probe_valkey) valkey_result, valkey_ms = _measure("valkey", _probe_valkey)
neo4j_result, neo4j_ms = _measure("neo4j", _probe_neo4j) graph_db_result, graph_db_ms = _measure(graph_db_component_id, _probe_graph_db)
entries = [ entries = [
_build_check_entry("postgres", "datastore", postgres_result, postgres_ms), _build_check_entry("postgres", "datastore", postgres_result, postgres_ms),
_build_check_entry("valkey", "datastore", valkey_result, valkey_ms), _build_check_entry("valkey", "datastore", valkey_result, valkey_ms),
_build_check_entry("neo4j", "datastore", neo4j_result, neo4j_ms), _build_check_entry(
graph_db_component_id, "datastore", graph_db_result, graph_db_ms
),
] ]
overall = _aggregate_status(entries) overall = _aggregate_status(entries)
@@ -191,7 +230,7 @@ def _readiness_payload() -> tuple[dict[str, Any], int]:
payload["checks"] = { payload["checks"] = {
"postgres:responseTime": [entries[0]], "postgres:responseTime": [entries[0]],
"valkey:responseTime": [entries[1]], "valkey:responseTime": [entries[1]],
"neo4j:responseTime": [entries[2]], "graphdb:responseTime": [entries[2]],
} }
http_status = ( http_status = (
@@ -233,10 +272,10 @@ class LivenessView(APIView):
class ReadinessView(APIView): class ReadinessView(APIView):
"""Readiness probe. """Readiness probe.
Returns 200 when PostgreSQL, Valkey and Neo4j all respond, or 503 with Returns 200 when PostgreSQL, Valkey and the attack-paths graph store
per-dependency detail when any of them is unreachable. Per-IP throttle all respond, or 503 with per-dependency detail when any of them is
plus the short in-process result cache cap the real dependency hits unreachable. Per-IP throttle plus the short in-process result cache cap
regardless of inbound traffic shape. the real dependency hits regardless of inbound traffic shape.
""" """
authentication_classes: list = [] authentication_classes: list = []
@@ -0,0 +1,24 @@
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("api", "0095_reconcile_orphan_tasks_periodic_task"),
]
operations = [
migrations.AddField(
model_name="attackpathsscan",
name="is_migrated",
field=models.BooleanField(default=False),
),
migrations.AddField(
model_name="attackpathsscan",
name="sink_backend",
field=models.CharField(
choices=[("neo4j", "Neo4j"), ("neptune", "Neptune")],
default="neo4j",
max_length=16,
),
),
]
+16
View File
@@ -757,6 +757,10 @@ class Scan(RowLevelSecurityProtectedModel):
class AttackPathsScan(RowLevelSecurityProtectedModel): class AttackPathsScan(RowLevelSecurityProtectedModel):
class SinkBackendChoices(models.TextChoices):
NEO4J = "neo4j", "Neo4j"
NEPTUNE = "neptune", "Neptune"
objects = ActiveProviderManager() objects = ActiveProviderManager()
all_objects = models.Manager() all_objects = models.Manager()
@@ -805,6 +809,18 @@ class AttackPathsScan(RowLevelSecurityProtectedModel):
) )
ingestion_exceptions = models.JSONField(default=dict, null=True, blank=True) ingestion_exceptions = models.JSONField(default=dict, null=True, blank=True)
# True when the scan was synced with the current schema (list-typed
# properties materialised as child item nodes). False for pre-cutover scans
# still using the previous graph shape. Query catalog selection uses this
# flag; physical read routing uses sink_backend below.
# TODO: drop after Neptune cutover
is_migrated = models.BooleanField(default=False)
sink_backend = models.CharField(
choices=SinkBackendChoices.choices,
default=SinkBackendChoices.NEO4J,
max_length=16,
)
class Meta(RowLevelSecurityProtectedModel.Meta): class Meta(RowLevelSecurityProtectedModel.Meta):
db_table = "attack_paths_scans" db_table = "attack_paths_scans"
+128 -61
View File
@@ -92,7 +92,9 @@ def test_prepare_parameters_validates_cast(
def test_execute_query_serializes_graph( def test_execute_query_serializes_graph(
attack_paths_query_definition_factory, attack_paths_graph_stub_classes attack_paths_query_definition_factory,
attack_paths_graph_stub_classes,
sink_backend_stub,
): ):
definition = attack_paths_query_definition_factory( definition = attack_paths_query_definition_factory(
id="aws-rds", id="aws-rds",
@@ -135,18 +137,17 @@ def test_execute_query_serializes_graph(
database_name = "db-tenant-test-tenant-id" database_name = "db-tenant-test-tenant-id"
with patch( sink_backend_stub.execute_read_query.return_value = graph_result
"api.attack_paths.views_helpers.graph_database.execute_read_query",
return_value=graph_result,
) as mock_execute_read_query:
result = views_helpers.execute_query( result = views_helpers.execute_query(
database_name, definition, parameters, provider_id=provider_id database_name,
definition,
parameters,
provider_id=provider_id,
scan=MagicMock(is_migrated=False, sink_backend="neo4j"),
) )
mock_execute_read_query.assert_called_once_with( sink_backend_stub.execute_read_query.assert_called_once_with(
database=database_name, database_name, definition.cypher, parameters
cypher=definition.cypher,
parameters=parameters,
) )
assert result["nodes"][0]["id"] == "node-1" assert result["nodes"][0]["id"] == "node-1"
assert result["nodes"][0]["properties"]["complex"]["items"][0] == "value" assert result["nodes"][0]["properties"]["complex"]["items"][0] == "value"
@@ -155,6 +156,7 @@ def test_execute_query_serializes_graph(
def test_execute_query_wraps_graph_errors( def test_execute_query_wraps_graph_errors(
attack_paths_query_definition_factory, attack_paths_query_definition_factory,
sink_backend_stub,
): ):
definition = attack_paths_query_definition_factory( definition = attack_paths_query_definition_factory(
id="aws-rds", id="aws-rds",
@@ -167,16 +169,17 @@ def test_execute_query_wraps_graph_errors(
database_name = "db-tenant-test-tenant-id" database_name = "db-tenant-test-tenant-id"
parameters = {"provider_uid": "123"} parameters = {"provider_uid": "123"}
with ( sink_backend_stub.execute_read_query.side_effect = (
patch( graph_database.GraphDatabaseQueryException("boom")
"api.attack_paths.views_helpers.graph_database.execute_read_query", )
side_effect=graph_database.GraphDatabaseQueryException("boom"), with patch("api.attack_paths.views_helpers.logger") as mock_logger:
),
patch("api.attack_paths.views_helpers.logger") as mock_logger,
):
with pytest.raises(APIException): with pytest.raises(APIException):
views_helpers.execute_query( views_helpers.execute_query(
database_name, definition, parameters, provider_id="test-provider-123" database_name,
definition,
parameters,
provider_id="test-provider-123",
scan=MagicMock(is_migrated=False, sink_backend="neo4j"),
) )
mock_logger.error.assert_called_once() mock_logger.error.assert_called_once()
@@ -184,6 +187,7 @@ def test_execute_query_wraps_graph_errors(
def test_execute_query_raises_permission_denied_on_read_only( def test_execute_query_raises_permission_denied_on_read_only(
attack_paths_query_definition_factory, attack_paths_query_definition_factory,
sink_backend_stub,
): ):
definition = attack_paths_query_definition_factory( definition = attack_paths_query_definition_factory(
id="aws-rds", id="aws-rds",
@@ -196,16 +200,19 @@ def test_execute_query_raises_permission_denied_on_read_only(
database_name = "db-tenant-test-tenant-id" database_name = "db-tenant-test-tenant-id"
parameters = {"provider_uid": "123"} parameters = {"provider_uid": "123"}
with patch( sink_backend_stub.execute_read_query.side_effect = (
"api.attack_paths.views_helpers.graph_database.execute_read_query", graph_database.WriteQueryNotAllowedException(
side_effect=graph_database.WriteQueryNotAllowedException(
message="Read query not allowed", message="Read query not allowed",
code="Neo.ClientError.Statement.AccessMode", code="Neo.ClientError.Statement.AccessMode",
), )
): )
with pytest.raises(PermissionDenied): with pytest.raises(PermissionDenied):
views_helpers.execute_query( views_helpers.execute_query(
database_name, definition, parameters, provider_id="test-provider-123" database_name,
definition,
parameters,
provider_id="test-provider-123",
scan=MagicMock(is_migrated=False, sink_backend="neo4j"),
) )
@@ -440,6 +447,7 @@ def test_normalize_custom_query_payload_passthrough_for_flat_dict():
def test_execute_custom_query_serializes_graph( def test_execute_custom_query_serializes_graph(
attack_paths_graph_stub_classes, attack_paths_graph_stub_classes,
sink_backend_stub,
): ):
provider_id = "test-provider-123" provider_id = "test-provider-123"
plabel = get_provider_label(provider_id) plabel = get_provider_label(provider_id)
@@ -453,50 +461,73 @@ def test_execute_custom_query_serializes_graph(
graph_result.nodes = [node_1, node_2] graph_result.nodes = [node_1, node_2]
graph_result.relationships = [relationship] graph_result.relationships = [relationship]
with patch( sink_backend_stub.execute_read_query.return_value = graph_result
"api.attack_paths.views_helpers.graph_database.execute_read_query",
return_value=graph_result,
) as mock_execute:
result = views_helpers.execute_custom_query( result = views_helpers.execute_custom_query(
"db-tenant-test", "MATCH (n) RETURN n", provider_id "db-tenant-test",
"MATCH (n) RETURN n",
provider_id,
scan=MagicMock(is_migrated=False, sink_backend="neo4j"),
) )
mock_execute.assert_called_once() sink_backend_stub.execute_read_query.assert_called_once()
call_kwargs = mock_execute.call_args[1] call_args = sink_backend_stub.execute_read_query.call_args[0]
assert call_kwargs["database"] == "db-tenant-test" assert call_args[0] == "db-tenant-test"
# The cypher is rewritten with the provider label injection # The cypher is rewritten with the provider label injection
assert plabel in call_kwargs["cypher"] assert plabel in call_args[1]
assert len(result["nodes"]) == 2 assert len(result["nodes"]) == 2
assert result["relationships"][0]["label"] == "OWNS" assert result["relationships"][0]["label"] == "OWNS"
assert result["truncated"] is False assert result["truncated"] is False
assert result["total_nodes"] == 2 assert result["total_nodes"] == 2
def test_execute_custom_query_raises_permission_denied_on_write(): def test_execute_custom_query_adds_timeout_for_neptune_scan(sink_backend_stub):
graph_result = MagicMock()
graph_result.nodes = []
graph_result.relationships = []
sink_backend_stub.execute_read_query.return_value = graph_result
with patch( with patch(
"api.attack_paths.views_helpers.graph_database.execute_read_query", "api.attack_paths.views_helpers.sink_module.get_backend_for_scan",
side_effect=graph_database.WriteQueryNotAllowedException( return_value=sink_backend_stub,
):
views_helpers.execute_custom_query(
"db-tenant-test",
"MATCH (n) RETURN n",
"provider-1",
scan=MagicMock(is_migrated=True, sink_backend="neptune"),
)
cypher = sink_backend_stub.execute_read_query.call_args[0][1]
assert cypher.startswith("USING QUERY:TIMEOUTMILLISECONDS")
def test_execute_custom_query_raises_permission_denied_on_write(sink_backend_stub):
sink_backend_stub.execute_read_query.side_effect = (
graph_database.WriteQueryNotAllowedException(
message="Read query not allowed", message="Read query not allowed",
code="Neo.ClientError.Statement.AccessMode", code="Neo.ClientError.Statement.AccessMode",
), )
): )
with pytest.raises(PermissionDenied): with pytest.raises(PermissionDenied):
views_helpers.execute_custom_query( views_helpers.execute_custom_query(
"db-tenant-test", "CREATE (n) RETURN n", "provider-1" "db-tenant-test",
"CREATE (n) RETURN n",
"provider-1",
scan=MagicMock(is_migrated=False, sink_backend="neo4j"),
) )
def test_execute_custom_query_wraps_graph_errors(): def test_execute_custom_query_wraps_graph_errors(sink_backend_stub):
with ( sink_backend_stub.execute_read_query.side_effect = (
patch( graph_database.GraphDatabaseQueryException("boom")
"api.attack_paths.views_helpers.graph_database.execute_read_query", )
side_effect=graph_database.GraphDatabaseQueryException("boom"), with patch("api.attack_paths.views_helpers.logger") as mock_logger:
),
patch("api.attack_paths.views_helpers.logger") as mock_logger,
):
with pytest.raises(APIException): with pytest.raises(APIException):
views_helpers.execute_custom_query( views_helpers.execute_custom_query(
"db-tenant-test", "MATCH (n) RETURN n", "provider-1" "db-tenant-test",
"MATCH (n) RETURN n",
"provider-1",
scan=MagicMock(is_migrated=False, sink_backend="neo4j"),
) )
mock_logger.error.assert_called_once() mock_logger.error.assert_called_once()
@@ -561,13 +592,33 @@ def test_truncate_graph_empty_graph():
@pytest.fixture @pytest.fixture
def mock_neo4j_session(): def mock_neo4j_session():
"""Mock the Neo4j driver so execute_read_query uses a fake session.""" """Install a Neo4jSink with a mocked Bolt driver into the sink factory.
The yielded mock is the `neo4j.Session` that the Neo4jSink will obtain via
`driver.session(...)`. Tests configure `mock_neo4j_session.execute_read`
return values / side effects to exercise the read-mode error translation
path on the real `Neo4jSink.execute_read_query` and `get_session` code.
"""
from api.attack_paths.sink import factory
from api.attack_paths.sink.neo4j import Neo4jSink
mock_session = MagicMock(spec=neo4j.Session) mock_session = MagicMock(spec=neo4j.Session)
mock_driver = MagicMock(spec=neo4j.Driver) mock_driver = MagicMock(spec=neo4j.Driver)
mock_driver.session.return_value = mock_session mock_driver.session.return_value = mock_session
with patch("api.attack_paths.database.get_driver", return_value=mock_driver): sink = Neo4jSink()
sink._driver = mock_driver
previous_backend = factory._backend
previous_secondary = dict(factory._secondary_backends)
factory._backend = sink
factory._secondary_backends.clear()
try:
yield mock_session yield mock_session
finally:
factory._backend = previous_backend
factory._secondary_backends.clear()
factory._secondary_backends.update(previous_secondary)
def test_execute_read_query_succeeds_with_select(mock_neo4j_session): def test_execute_read_query_succeeds_with_select(mock_neo4j_session):
@@ -663,16 +714,20 @@ def test_execute_read_query_rejects_apoc_real_create(mock_neo4j_session, cypher)
@pytest.fixture @pytest.fixture
def mock_schema_session(): def mock_schema_session():
"""Mock get_session for cartography schema tests.""" """Mock the routed sink backend session for cartography schema tests."""
mock_result = MagicMock() mock_result = MagicMock()
mock_session = MagicMock() mock_session = MagicMock()
mock_session.run.return_value = mock_result mock_session.run.return_value = mock_result
mock_backend = MagicMock()
with patch( with patch(
"api.attack_paths.views_helpers.graph_database.get_session" "api.attack_paths.views_helpers.sink_module.get_backend_for_scan",
) as mock_get_session: return_value=mock_backend,
mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session) ):
mock_get_session.return_value.__exit__ = MagicMock(return_value=False) mock_backend.get_session.return_value.__enter__ = MagicMock(
return_value=mock_session
)
mock_backend.get_session.return_value.__exit__ = MagicMock(return_value=False)
yield mock_session, mock_result yield mock_session, mock_result
@@ -683,7 +738,9 @@ def test_get_cartography_schema_returns_urls(mock_schema_session):
"module_version": "0.129.0", "module_version": "0.129.0",
} }
result = views_helpers.get_cartography_schema("db-tenant-test", "provider-123") result = views_helpers.get_cartography_schema(
"db-tenant-test", "provider-123", MagicMock(sink_backend="neo4j")
)
mock_session.run.assert_called_once() mock_session.run.assert_called_once()
assert result["id"] == "aws-0.129.0" assert result["id"] == "aws-0.129.0"
@@ -699,7 +756,9 @@ def test_get_cartography_schema_returns_none_when_no_data(mock_schema_session):
_, mock_result = mock_schema_session _, mock_result = mock_schema_session
mock_result.single.return_value = None mock_result.single.return_value = None
result = views_helpers.get_cartography_schema("db-tenant-test", "provider-123") result = views_helpers.get_cartography_schema(
"db-tenant-test", "provider-123", MagicMock(sink_backend="neo4j")
)
assert result is None assert result is None
@@ -721,21 +780,29 @@ def test_get_cartography_schema_extracts_provider(
"module_version": "1.0.0", "module_version": "1.0.0",
} }
result = views_helpers.get_cartography_schema("db-tenant-test", "provider-123") result = views_helpers.get_cartography_schema(
"db-tenant-test", "provider-123", MagicMock(sink_backend="neo4j")
)
assert result["id"] == f"{expected_provider}-1.0.0" assert result["id"] == f"{expected_provider}-1.0.0"
assert result["provider"] == expected_provider assert result["provider"] == expected_provider
def test_get_cartography_schema_wraps_database_error(): def test_get_cartography_schema_wraps_database_error():
mock_backend = MagicMock()
mock_backend.get_session.side_effect = graph_database.GraphDatabaseQueryException(
"boom"
)
with ( with (
patch( patch(
"api.attack_paths.views_helpers.graph_database.get_session", "api.attack_paths.views_helpers.sink_module.get_backend_for_scan",
side_effect=graph_database.GraphDatabaseQueryException("boom"), return_value=mock_backend,
), ),
patch("api.attack_paths.views_helpers.logger") as mock_logger, patch("api.attack_paths.views_helpers.logger") as mock_logger,
): ):
with pytest.raises(APIException): with pytest.raises(APIException):
views_helpers.get_cartography_schema("db-tenant-test", "provider-123") views_helpers.get_cartography_schema(
"db-tenant-test", "provider-123", MagicMock(sink_backend="neo4j")
)
mock_logger.error.assert_called_once() mock_logger.error.assert_called_once()
@@ -1,623 +1,174 @@
""" """Tests for the attack-paths database facade.
Tests for Neo4j database lazy initialization.
The Neo4j driver is created on first use for every process type; app startup After the Neptune port, `api.attack_paths.database` is a thin routing shim
never contacts Neo4j. These tests validate the database module behavior itself. over `api.attack_paths.ingest` (cartography temp DB, always Neo4j) and
`api.attack_paths.sink` (configurable Neo4j or Neptune). The facade's
contract is routing by database-name prefix and the public exception
hierarchy; sink-internal behavior is exercised in `test_sink.py`.
""" """
import threading
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import api.attack_paths.database as db_module import api.attack_paths.database as db_module
import neo4j
import neo4j.exceptions
import pytest
class TestLazyInitialization: class TestDatabaseNameHelper:
"""Test that Neo4j driver is initialized lazily on first use.""" def test_tenant_name_lowercases_uuid(self):
@pytest.fixture(autouse=True)
def reset_module_state(self):
"""Reset module-level singleton state before each test."""
original_driver = db_module._driver
db_module._driver = None
yield
db_module._driver = original_driver
def test_driver_not_initialized_at_import(self):
"""Driver should be None after module import (no eager connection)."""
assert db_module._driver is None
@patch("api.attack_paths.database.settings")
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
def test_init_driver_creates_connection_on_first_call(
self, mock_driver_factory, mock_settings
):
"""init_driver() should create connection only when called."""
mock_driver = MagicMock()
mock_driver_factory.return_value = mock_driver
mock_settings.DATABASES = {
"neo4j": {
"HOST": "localhost",
"PORT": 7687,
"USER": "neo4j",
"PASSWORD": "password",
}
}
assert db_module._driver is None
result = db_module.init_driver()
mock_driver_factory.assert_called_once()
mock_driver.verify_connectivity.assert_called_once()
assert result is mock_driver
assert db_module._driver is mock_driver
@patch("api.attack_paths.database.settings")
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
def test_init_driver_leaves_driver_none_when_verify_fails(
self, mock_driver_factory, mock_settings
):
"""A failed verify_connectivity() must not publish or leak the driver."""
mock_driver = MagicMock()
mock_driver.verify_connectivity.side_effect = (
neo4j.exceptions.ServiceUnavailable("down")
)
mock_driver_factory.return_value = mock_driver
mock_settings.DATABASES = {
"neo4j": {
"HOST": "localhost",
"PORT": 7687,
"USER": "neo4j",
"PASSWORD": "password",
}
}
with pytest.raises(neo4j.exceptions.ServiceUnavailable):
db_module.init_driver()
assert db_module._driver is None
mock_driver.close.assert_called_once()
@patch("api.attack_paths.database.settings")
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
def test_init_driver_returns_cached_driver_on_subsequent_calls(
self, mock_driver_factory, mock_settings
):
"""Subsequent calls should return cached driver without reconnecting."""
mock_driver = MagicMock()
mock_driver_factory.return_value = mock_driver
mock_settings.DATABASES = {
"neo4j": {
"HOST": "localhost",
"PORT": 7687,
"USER": "neo4j",
"PASSWORD": "password",
}
}
first_result = db_module.init_driver()
second_result = db_module.init_driver()
third_result = db_module.init_driver()
# Only one connection attempt
assert mock_driver_factory.call_count == 1
assert mock_driver.verify_connectivity.call_count == 1
# All calls return same instance
assert first_result is second_result is third_result
@patch("api.attack_paths.database.settings")
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
def test_get_driver_delegates_to_init_driver(
self, mock_driver_factory, mock_settings
):
"""get_driver() should use init_driver() for lazy initialization."""
mock_driver = MagicMock()
mock_driver_factory.return_value = mock_driver
mock_settings.DATABASES = {
"neo4j": {
"HOST": "localhost",
"PORT": 7687,
"USER": "neo4j",
"PASSWORD": "password",
}
}
result = db_module.get_driver()
assert result is mock_driver
mock_driver_factory.assert_called_once()
class TestConnectionAcquisitionTimeout:
"""Test that the connection acquisition timeout is configurable."""
@pytest.fixture(autouse=True)
def reset_module_state(self):
original_driver = db_module._driver
original_acq_timeout = db_module.CONN_ACQUISITION_TIMEOUT
original_conn_timeout = db_module.CONNECTION_TIMEOUT
db_module._driver = None
yield
db_module._driver = original_driver
db_module.CONN_ACQUISITION_TIMEOUT = original_acq_timeout
db_module.CONNECTION_TIMEOUT = original_conn_timeout
@patch("api.attack_paths.database.settings")
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
def test_driver_receives_configured_timeout(
self, mock_driver_factory, mock_settings
):
"""init_driver() should pass the configured timeouts to the neo4j driver."""
mock_driver_factory.return_value = MagicMock()
mock_settings.DATABASES = {
"neo4j": {
"HOST": "localhost",
"PORT": 7687,
"USER": "neo4j",
"PASSWORD": "password",
}
}
db_module.CONN_ACQUISITION_TIMEOUT = 42
db_module.CONNECTION_TIMEOUT = 7
db_module.init_driver()
_, kwargs = mock_driver_factory.call_args
assert kwargs["connection_acquisition_timeout"] == 42
assert kwargs["connection_timeout"] == 7
class TestAtexitRegistration:
"""Test that atexit cleanup handler is registered correctly."""
@pytest.fixture(autouse=True)
def reset_module_state(self):
"""Reset module-level singleton state before each test."""
original_driver = db_module._driver
db_module._driver = None
yield
db_module._driver = original_driver
@patch("api.attack_paths.database.settings")
@patch("api.attack_paths.database.atexit.register")
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
def test_atexit_registered_on_first_init(
self, mock_driver_factory, mock_atexit_register, mock_settings
):
"""atexit.register should be called on first initialization."""
mock_driver_factory.return_value = MagicMock()
mock_settings.DATABASES = {
"neo4j": {
"HOST": "localhost",
"PORT": 7687,
"USER": "neo4j",
"PASSWORD": "password",
}
}
db_module.init_driver()
mock_atexit_register.assert_called_once_with(db_module.close_driver)
@patch("api.attack_paths.database.settings")
@patch("api.attack_paths.database.atexit.register")
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
def test_atexit_registered_only_once(
self, mock_driver_factory, mock_atexit_register, mock_settings
):
"""atexit.register should only be called once across multiple inits.
The double-checked locking on _driver ensures the atexit registration
block only executes once (when _driver is first created).
"""
mock_driver_factory.return_value = MagicMock()
mock_settings.DATABASES = {
"neo4j": {
"HOST": "localhost",
"PORT": 7687,
"USER": "neo4j",
"PASSWORD": "password",
}
}
db_module.init_driver()
db_module.init_driver()
db_module.init_driver()
# Only registered once because subsequent calls hit the fast path
assert mock_atexit_register.call_count == 1
class TestCloseDriver:
"""Test driver cleanup functionality."""
@pytest.fixture(autouse=True)
def reset_module_state(self):
"""Reset module-level singleton state before each test."""
original_driver = db_module._driver
db_module._driver = None
yield
db_module._driver = original_driver
def test_close_driver_closes_and_clears_driver(self):
"""close_driver() should close the driver and set it to None."""
mock_driver = MagicMock()
db_module._driver = mock_driver
db_module.close_driver()
mock_driver.close.assert_called_once()
assert db_module._driver is None
def test_close_driver_handles_none_driver(self):
"""close_driver() should handle case where driver is None."""
db_module._driver = None
# Should not raise
db_module.close_driver()
assert db_module._driver is None
def test_close_driver_clears_driver_even_on_close_error(self):
"""Driver should be cleared even if close() raises an exception."""
mock_driver = MagicMock()
mock_driver.close.side_effect = Exception("Connection error")
db_module._driver = mock_driver
with pytest.raises(Exception, match="Connection error"):
db_module.close_driver()
# Driver should still be cleared
assert db_module._driver is None
class TestExecuteReadQuery:
"""Test read query execution helper."""
def test_execute_read_query_calls_read_session_and_returns_result(self):
tx = MagicMock()
expected_graph = MagicMock()
run_result = MagicMock()
run_result.graph.return_value = expected_graph
tx.run.return_value = run_result
session = MagicMock()
def execute_read_side_effect(fn):
return fn(tx)
session.execute_read.side_effect = execute_read_side_effect
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = False
with patch(
"api.attack_paths.database.get_session",
return_value=session_ctx,
) as mock_get_session:
result = db_module.execute_read_query(
"db-tenant-test-tenant-id",
"MATCH (n) RETURN n",
{"provider_uid": "123"},
)
mock_get_session.assert_called_once_with(
"db-tenant-test-tenant-id",
default_access_mode=neo4j.READ_ACCESS,
)
session.execute_read.assert_called_once()
tx.run.assert_called_once_with(
"MATCH (n) RETURN n",
{"provider_uid": "123"},
timeout=db_module.READ_QUERY_TIMEOUT_SECONDS,
)
run_result.graph.assert_called_once_with()
assert result is expected_graph
def test_execute_read_query_defaults_parameters_to_empty_dict(self):
tx = MagicMock()
run_result = MagicMock()
run_result.graph.return_value = MagicMock()
tx.run.return_value = run_result
session = MagicMock()
session.execute_read.side_effect = lambda fn: fn(tx)
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = False
with patch(
"api.attack_paths.database.get_session",
return_value=session_ctx,
):
db_module.execute_read_query(
"db-tenant-test-tenant-id",
"MATCH (n) RETURN n",
)
tx.run.assert_called_once_with(
"MATCH (n) RETURN n",
{},
timeout=db_module.READ_QUERY_TIMEOUT_SECONDS,
)
run_result.graph.assert_called_once_with()
class TestGetSessionReadOnly:
"""Test that get_session translates Neo4j read-mode errors."""
@pytest.fixture(autouse=True)
def reset_module_state(self):
original_driver = db_module._driver
db_module._driver = None
yield
db_module._driver = original_driver
@pytest.mark.parametrize(
"neo4j_code",
[
"Neo.ClientError.Statement.AccessMode",
"Neo.ClientError.Procedure.ProcedureNotFound",
],
)
def test_get_session_raises_write_query_not_allowed(self, neo4j_code):
"""Read-mode Neo4j errors should raise `WriteQueryNotAllowedException`."""
mock_session = MagicMock()
neo4j_error = neo4j.exceptions.Neo4jError._hydrate_neo4j(
code=neo4j_code,
message="Write operations are not allowed",
)
mock_session.run.side_effect = neo4j_error
mock_driver = MagicMock()
mock_driver.session.return_value = mock_session
db_module._driver = mock_driver
with pytest.raises(db_module.WriteQueryNotAllowedException):
with db_module.get_session(
default_access_mode=neo4j.READ_ACCESS
) as session:
session.run("CREATE (n) RETURN n")
def test_get_session_raises_generic_exception_for_other_errors(self):
"""Non-read-mode Neo4j errors should raise GraphDatabaseQueryException."""
mock_session = MagicMock()
neo4j_error = neo4j.exceptions.Neo4jError._hydrate_neo4j(
code="Neo.ClientError.Statement.SyntaxError",
message="Invalid syntax",
)
mock_session.run.side_effect = neo4j_error
mock_driver = MagicMock()
mock_driver.session.return_value = mock_session
db_module._driver = mock_driver
with pytest.raises(db_module.GraphDatabaseQueryException):
with db_module.get_session(
default_access_mode=neo4j.READ_ACCESS
) as session:
session.run("INVALID CYPHER")
class TestThreadSafety:
"""Test thread-safe initialization."""
@pytest.fixture(autouse=True)
def reset_module_state(self):
"""Reset module-level singleton state before each test."""
original_driver = db_module._driver
db_module._driver = None
yield
db_module._driver = original_driver
@patch("api.attack_paths.database.settings")
@patch("api.attack_paths.database.neo4j.GraphDatabase.driver")
def test_concurrent_init_creates_single_driver(
self, mock_driver_factory, mock_settings
):
"""Multiple threads calling init_driver() should create only one driver."""
mock_driver = MagicMock()
mock_driver_factory.return_value = mock_driver
mock_settings.DATABASES = {
"neo4j": {
"HOST": "localhost",
"PORT": 7687,
"USER": "neo4j",
"PASSWORD": "password",
}
}
results = []
errors = []
def call_init():
try:
result = db_module.init_driver()
results.append(result)
except Exception as e:
errors.append(e)
threads = [threading.Thread(target=call_init) for _ in range(10)]
for t in threads:
t.start()
for t in threads:
t.join()
assert not errors, f"Threads raised errors: {errors}"
# Only one driver created
assert mock_driver_factory.call_count == 1
# All threads got the same driver instance
assert all(r is mock_driver for r in results)
assert len(results) == 10
class TestHasProviderData:
"""Test has_provider_data helper for checking provider nodes in Neo4j."""
def test_returns_true_when_nodes_exist(self):
mock_session = MagicMock()
mock_result = MagicMock()
mock_result.single.return_value = MagicMock() # non-None record
mock_session.run.return_value = mock_result
session_ctx = MagicMock()
session_ctx.__enter__.return_value = mock_session
session_ctx.__exit__.return_value = False
with patch(
"api.attack_paths.database.get_session",
return_value=session_ctx,
):
assert db_module.has_provider_data("db-tenant-abc", "provider-123") is True
mock_session.run.assert_called_once()
def test_returns_false_when_no_nodes(self):
mock_session = MagicMock()
mock_result = MagicMock()
mock_result.single.return_value = None
mock_session.run.return_value = mock_result
session_ctx = MagicMock()
session_ctx.__enter__.return_value = mock_session
session_ctx.__exit__.return_value = False
with patch(
"api.attack_paths.database.get_session",
return_value=session_ctx,
):
assert db_module.has_provider_data("db-tenant-abc", "provider-123") is False
def test_returns_false_when_database_not_found(self):
session_ctx = MagicMock()
session_ctx.__enter__.side_effect = db_module.GraphDatabaseQueryException(
message="Database does not exist",
code="Neo.ClientError.Database.DatabaseNotFound",
)
with patch(
"api.attack_paths.database.get_session",
return_value=session_ctx,
):
assert ( assert (
db_module.has_provider_data("db-tenant-gone", "provider-123") is False db_module.get_database_name("ABC-123", temporary=False)
== "db-tenant-abc-123"
) )
def test_raises_on_other_errors(self): def test_temporary_name_uses_tmp_scan_prefix(self):
session_ctx = MagicMock() assert (
session_ctx.__enter__.side_effect = db_module.GraphDatabaseQueryException( db_module.get_database_name("XYZ-789", temporary=True)
message="Connection refused", == "db-tmp-scan-xyz-789"
code="Neo.TransientError.General.UnknownError",
) )
with patch(
"api.attack_paths.database.get_session",
return_value=session_ctx,
):
with pytest.raises(db_module.GraphDatabaseQueryException):
db_module.has_provider_data("db-tenant-abc", "provider-123")
class TestExceptionHierarchy:
"""`tasks/` and `api/v1/views.py` import these from the facade."""
class TestDropSubgraph: def test_write_query_is_graph_database_exception(self):
"""Test drop_subgraph two-phase batched deletion of a provider's graph.""" assert issubclass(
db_module.WriteQueryNotAllowedException,
@staticmethod db_module.GraphDatabaseQueryException,
def _result(count):
result = MagicMock()
result.single.return_value.get.return_value = count
return result
@staticmethod
def _session_ctx(session):
ctx = MagicMock()
ctx.__enter__.return_value = session
ctx.__exit__.return_value = False
return ctx
def test_deletes_relationships_then_nodes_in_batches(self):
session = MagicMock()
# Phase 1 (relationships): one full batch then empty.
# Phase 2 (nodes): one full batch then empty.
session.run.side_effect = [
self._result(1000),
self._result(0),
self._result(1000),
self._result(0),
]
with patch(
"api.attack_paths.database.get_session",
return_value=self._session_ctx(session),
):
deleted = db_module.drop_subgraph("db-tenant-abc", "provider-123")
# Only phase-2 node counts contribute to the return value.
assert deleted == 1000
assert session.run.call_count == 4
queries = [call.args[0] for call in session.run.call_args_list]
# Regression guard: the memory blow-up was caused by DETACH DELETE.
assert all("DETACH DELETE" not in query for query in queries)
rel_queries = [query for query in queries if "DELETE r" in query]
node_queries = [query for query in queries if "DELETE n" in query]
assert rel_queries and node_queries
# DISTINCT avoids double-counting relationships matched from both ends.
assert all("DISTINCT r" in query for query in rel_queries)
# Relationships must be fully drained before nodes are deleted.
first_node = next(i for i, q in enumerate(queries) if "DELETE n" in q)
last_rel = max(i for i, q in enumerate(queries) if "DELETE r" in q)
assert last_rel < first_node
def test_returns_zero_when_database_not_found(self):
session_ctx = MagicMock()
session_ctx.__enter__.side_effect = db_module.GraphDatabaseQueryException(
message="Database does not exist",
code="Neo.ClientError.Database.DatabaseNotFound",
) )
with patch( def test_client_statement_is_graph_database_exception(self):
"api.attack_paths.database.get_session", assert issubclass(
return_value=session_ctx, db_module.ClientStatementException, db_module.GraphDatabaseQueryException
):
assert db_module.drop_subgraph("db-tenant-gone", "provider-123") == 0
def test_raises_on_other_errors(self):
session_ctx = MagicMock()
session_ctx.__enter__.side_effect = db_module.GraphDatabaseQueryException(
message="Connection refused",
code="Neo.TransientError.General.UnknownError",
) )
with patch( def test_exception_str_includes_code_when_set(self):
"api.attack_paths.database.get_session", exc = db_module.GraphDatabaseQueryException(
return_value=session_ctx, message="boom", code="Neo.ClientError.X.Y"
): )
with pytest.raises(db_module.GraphDatabaseQueryException): assert str(exc) == "Neo.ClientError.X.Y: boom"
db_module.drop_subgraph("db-tenant-abc", "provider-123")
def test_exception_str_falls_back_to_message_without_code(self):
exc = db_module.GraphDatabaseQueryException(message="boom")
assert str(exc) == "boom"
class TestExecuteReadQueryRoutes:
def test_execute_read_query_delegates_to_sink(self, sink_backend_stub):
sink_backend_stub.execute_read_query.return_value = "graph"
result = db_module.execute_read_query(
"db-tenant-abc", "MATCH (n) RETURN n", {"provider_uid": "123"}
)
sink_backend_stub.execute_read_query.assert_called_once_with(
"db-tenant-abc", "MATCH (n) RETURN n", {"provider_uid": "123"}
)
assert result == "graph"
def test_execute_read_query_defaults_parameters_to_none(self, sink_backend_stub):
db_module.execute_read_query("db-tenant-abc", "MATCH (n) RETURN n")
sink_backend_stub.execute_read_query.assert_called_once_with(
"db-tenant-abc", "MATCH (n) RETURN n", None
)
class TestSinkOperationsDelegation:
def test_has_provider_data_delegates_to_sink(self, sink_backend_stub):
sink_backend_stub.has_provider_data.return_value = True
assert db_module.has_provider_data("db-tenant-abc", "provider-123") is True
sink_backend_stub.has_provider_data.assert_called_once_with(
"db-tenant-abc", "provider-123"
)
def test_drop_subgraph_delegates_to_sink(self, sink_backend_stub):
sink_backend_stub.drop_subgraph.return_value = 42
assert db_module.drop_subgraph("db-tenant-abc", "provider-123") == 42
sink_backend_stub.drop_subgraph.assert_called_once_with(
"db-tenant-abc", "provider-123"
)
class TestRoutingByDatabasePrefix:
"""`db-tmp-scan-*` and `None` route to ingest; everything else to sink."""
def test_create_database_routes_temp_to_ingest(self, sink_backend_stub):
with patch("api.attack_paths.database.ingest") as mock_ingest:
db_module.create_database("db-tmp-scan-uuid-1")
mock_ingest.create_database.assert_called_once_with("db-tmp-scan-uuid-1")
sink_backend_stub.create_database.assert_not_called()
def test_create_database_routes_tenant_to_sink(self, sink_backend_stub):
with patch("api.attack_paths.database.ingest") as mock_ingest:
db_module.create_database("db-tenant-abc")
sink_backend_stub.create_database.assert_called_once_with("db-tenant-abc")
mock_ingest.create_database.assert_not_called()
def test_drop_database_routes_temp_to_ingest(self, sink_backend_stub):
with patch("api.attack_paths.database.ingest") as mock_ingest:
db_module.drop_database("db-tmp-scan-uuid-1")
mock_ingest.drop_database.assert_called_once_with("db-tmp-scan-uuid-1")
sink_backend_stub.drop_database.assert_not_called()
def test_drop_database_routes_tenant_to_sink(self, sink_backend_stub):
with patch("api.attack_paths.database.ingest") as mock_ingest:
db_module.drop_database("db-tenant-abc")
sink_backend_stub.drop_database.assert_called_once_with("db-tenant-abc")
mock_ingest.drop_database.assert_not_called()
def test_clear_cache_routes_temp_to_ingest(self, sink_backend_stub):
with patch("api.attack_paths.database.ingest") as mock_ingest:
db_module.clear_cache("db-tmp-scan-uuid-1")
mock_ingest.clear_cache.assert_called_once_with("db-tmp-scan-uuid-1")
sink_backend_stub.clear_cache.assert_not_called()
def test_clear_cache_routes_tenant_to_sink(self, sink_backend_stub):
with patch("api.attack_paths.database.ingest") as mock_ingest:
db_module.clear_cache("db-tenant-abc")
sink_backend_stub.clear_cache.assert_called_once_with("db-tenant-abc")
mock_ingest.clear_cache.assert_not_called()
def test_get_session_routes_temp_to_ingest(self, sink_backend_stub):
sentinel = MagicMock()
with patch("api.attack_paths.database.ingest") as mock_ingest:
mock_ingest.get_session.return_value = sentinel
result = db_module.get_session("db-tmp-scan-uuid-1")
assert result is sentinel
mock_ingest.get_session.assert_called_once()
sink_backend_stub.get_session.assert_not_called()
def test_get_session_routes_none_to_ingest(self, sink_backend_stub):
sentinel = MagicMock()
with patch("api.attack_paths.database.ingest") as mock_ingest:
mock_ingest.get_session.return_value = sentinel
result = db_module.get_session(None)
assert result is sentinel
sink_backend_stub.get_session.assert_not_called()
def test_get_ingest_uri_delegates_to_ingest(self, sink_backend_stub):
with patch("api.attack_paths.database.ingest") as mock_ingest:
mock_ingest.get_uri.return_value = "bolt://neo4j:7687"
assert db_module.get_ingest_uri() == "bolt://neo4j:7687"
mock_ingest.get_uri.assert_called_once_with()
def test_get_session_routes_tenant_to_sink(self, sink_backend_stub):
sentinel = MagicMock()
sink_backend_stub.get_session.return_value = sentinel
with patch("api.attack_paths.database.ingest") as mock_ingest:
result = db_module.get_session("db-tenant-abc")
assert result is sentinel
mock_ingest.get_session.assert_not_called()
+71 -31
View File
@@ -67,7 +67,7 @@ class TestLivenessEndpoint:
with ( with (
patch("api.health._probe_postgres") as mock_pg, patch("api.health._probe_postgres") as mock_pg,
patch("api.health._probe_valkey") as mock_vk, patch("api.health._probe_valkey") as mock_vk,
patch("api.health._probe_neo4j") as mock_neo, patch("api.health._probe_graph_db") as mock_neo,
): ):
response = api_client.get(reverse("health-live")) response = api_client.get(reverse("health-live"))
@@ -83,14 +83,14 @@ class TestReadinessEndpoint:
return ( return (
patch("api.health._probe_postgres", return_value=None), patch("api.health._probe_postgres", return_value=None),
patch("api.health._probe_valkey", return_value=None), patch("api.health._probe_valkey", return_value=None),
patch("api.health._probe_neo4j", return_value=None), patch("api.health._probe_graph_db", return_value=None),
) )
def test_returns_200_and_pass_when_all_dependencies_healthy(self, api_client): def test_returns_200_and_pass_when_all_dependencies_healthy(self, api_client):
with ( with (
patch("api.health._probe_postgres"), patch("api.health._probe_postgres"),
patch("api.health._probe_valkey"), patch("api.health._probe_valkey"),
patch("api.health._probe_neo4j"), patch("api.health._probe_graph_db"),
): ):
response = api_client.get(reverse("health-ready")) response = api_client.get(reverse("health-ready"))
@@ -107,7 +107,7 @@ class TestReadinessEndpoint:
assert set(body["checks"].keys()) == { assert set(body["checks"].keys()) == {
"postgres:responseTime", "postgres:responseTime",
"valkey:responseTime", "valkey:responseTime",
"neo4j:responseTime", "graphdb:responseTime",
} }
for key in body["checks"]: for key in body["checks"]:
entries = body["checks"][key] entries = body["checks"][key]
@@ -122,6 +122,23 @@ class TestReadinessEndpoint:
# `output` must not leak when the check passed. # `output` must not leak when the check passed.
assert "output" not in entry assert "output" not in entry
@pytest.mark.parametrize("sink", ["neo4j", "neptune"])
def test_graphdb_component_id_reflects_active_sink(self, api_client, sink):
from django.test import override_settings
with (
override_settings(ATTACK_PATHS_SINK_DATABASE=sink),
patch("api.health._probe_postgres"),
patch("api.health._probe_valkey"),
patch("api.health._probe_graph_db"),
):
response = api_client.get(reverse("health-ready"))
assert response.status_code == status.HTTP_200_OK
entry = response.json()["checks"]["graphdb:responseTime"][0]
# Stable key, but the concrete store is named in componentId.
assert entry["componentId"] == sink
def test_returns_503_and_fail_when_postgres_is_down(self, api_client): def test_returns_503_and_fail_when_postgres_is_down(self, api_client):
with ( with (
patch( patch(
@@ -129,7 +146,7 @@ class TestReadinessEndpoint:
side_effect=RuntimeError("connection refused"), side_effect=RuntimeError("connection refused"),
), ),
patch("api.health._probe_valkey"), patch("api.health._probe_valkey"),
patch("api.health._probe_neo4j"), patch("api.health._probe_graph_db"),
): ):
response = api_client.get(reverse("health-ready")) response = api_client.get(reverse("health-ready"))
@@ -141,13 +158,13 @@ class TestReadinessEndpoint:
# Exception detail is never echoed in the response, only logged. # Exception detail is never echoed in the response, only logged.
assert "output" not in pg_entry assert "output" not in pg_entry
assert body["checks"]["valkey:responseTime"][0]["status"] == "pass" assert body["checks"]["valkey:responseTime"][0]["status"] == "pass"
assert body["checks"]["neo4j:responseTime"][0]["status"] == "pass" assert body["checks"]["graphdb:responseTime"][0]["status"] == "pass"
def test_returns_503_and_fail_when_valkey_is_down(self, api_client): def test_returns_503_and_fail_when_valkey_is_down(self, api_client):
with ( with (
patch("api.health._probe_postgres"), patch("api.health._probe_postgres"),
patch("api.health._probe_valkey", side_effect=ConnectionError("timeout")), patch("api.health._probe_valkey", side_effect=ConnectionError("timeout")),
patch("api.health._probe_neo4j"), patch("api.health._probe_graph_db"),
): ):
response = api_client.get(reverse("health-ready")) response = api_client.get(reverse("health-ready"))
@@ -158,12 +175,12 @@ class TestReadinessEndpoint:
assert vk_entry["status"] == "fail" assert vk_entry["status"] == "fail"
assert "output" not in vk_entry assert "output" not in vk_entry
def test_returns_503_and_fail_when_neo4j_is_down(self, api_client): def test_returns_503_and_fail_when_graph_db_is_down(self, api_client):
with ( with (
patch("api.health._probe_postgres"), patch("api.health._probe_postgres"),
patch("api.health._probe_valkey"), patch("api.health._probe_valkey"),
patch( patch(
"api.health._probe_neo4j", "api.health._probe_graph_db",
side_effect=RuntimeError("ServiceUnavailable"), side_effect=RuntimeError("ServiceUnavailable"),
), ),
): ):
@@ -172,15 +189,15 @@ class TestReadinessEndpoint:
assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
body = response.json() body = response.json()
assert body["status"] == "fail" assert body["status"] == "fail"
neo_entry = body["checks"]["neo4j:responseTime"][0] graph_db_entry = body["checks"]["graphdb:responseTime"][0]
assert neo_entry["status"] == "fail" assert graph_db_entry["status"] == "fail"
assert "output" not in neo_entry assert "output" not in graph_db_entry
def test_reports_all_failures_simultaneously(self, api_client): def test_reports_all_failures_simultaneously(self, api_client):
with ( with (
patch("api.health._probe_postgres", side_effect=RuntimeError("pg down")), patch("api.health._probe_postgres", side_effect=RuntimeError("pg down")),
patch("api.health._probe_valkey", side_effect=RuntimeError("vk down")), patch("api.health._probe_valkey", side_effect=RuntimeError("vk down")),
patch("api.health._probe_neo4j", side_effect=RuntimeError("neo down")), patch("api.health._probe_graph_db", side_effect=RuntimeError("neo down")),
): ):
response = api_client.get(reverse("health-ready")) response = api_client.get(reverse("health-ready"))
@@ -190,7 +207,7 @@ class TestReadinessEndpoint:
for key in ( for key in (
"postgres:responseTime", "postgres:responseTime",
"valkey:responseTime", "valkey:responseTime",
"neo4j:responseTime", "graphdb:responseTime",
): ):
entry = body["checks"][key][0] entry = body["checks"][key][0]
assert entry["status"] == "fail" assert entry["status"] == "fail"
@@ -209,7 +226,7 @@ class TestReadinessEndpoint:
with ( with (
patch("api.health._probe_postgres", side_effect=RuntimeError(sensitive)), patch("api.health._probe_postgres", side_effect=RuntimeError(sensitive)),
patch("api.health._probe_valkey"), patch("api.health._probe_valkey"),
patch("api.health._probe_neo4j"), patch("api.health._probe_graph_db"),
): ):
response = api_client.get(reverse("health-ready")) response = api_client.get(reverse("health-ready"))
@@ -229,7 +246,7 @@ class TestReadinessEndpoint:
with ( with (
patch("api.health._probe_postgres"), patch("api.health._probe_postgres"),
patch("api.health._probe_valkey"), patch("api.health._probe_valkey"),
patch("api.health._probe_neo4j"), patch("api.health._probe_graph_db"),
): ):
api_client.credentials() api_client.credentials()
response = api_client.get(reverse("health-ready")) response = api_client.get(reverse("health-ready"))
@@ -244,7 +261,7 @@ class TestReadinessCache:
with ( with (
patch("api.health._probe_postgres") as pg, patch("api.health._probe_postgres") as pg,
patch("api.health._probe_valkey") as vk, patch("api.health._probe_valkey") as vk,
patch("api.health._probe_neo4j") as neo, patch("api.health._probe_graph_db") as neo,
): ):
r1 = api_client.get(reverse("health-ready")) r1 = api_client.get(reverse("health-ready"))
r2 = api_client.get(reverse("health-ready")) r2 = api_client.get(reverse("health-ready"))
@@ -262,7 +279,7 @@ class TestReadinessCache:
with ( with (
patch("api.health._probe_postgres") as pg, patch("api.health._probe_postgres") as pg,
patch("api.health._probe_valkey"), patch("api.health._probe_valkey"),
patch("api.health._probe_neo4j"), patch("api.health._probe_graph_db"),
): ):
api_client.get(reverse("health-ready")) api_client.get(reverse("health-ready"))
assert pg.call_count == 1 assert pg.call_count == 1
@@ -286,7 +303,7 @@ class TestReadinessCache:
with ( with (
patch("api.health._probe_postgres", side_effect=RuntimeError("down")) as pg, patch("api.health._probe_postgres", side_effect=RuntimeError("down")) as pg,
patch("api.health._probe_valkey"), patch("api.health._probe_valkey"),
patch("api.health._probe_neo4j"), patch("api.health._probe_graph_db"),
): ):
r1 = api_client.get(reverse("health-ready")) r1 = api_client.get(reverse("health-ready"))
r2 = api_client.get(reverse("health-ready")) r2 = api_client.get(reverse("health-ready"))
@@ -320,7 +337,7 @@ class TestRateLimiting:
with ( with (
patch("api.health._probe_postgres"), patch("api.health._probe_postgres"),
patch("api.health._probe_valkey"), patch("api.health._probe_valkey"),
patch("api.health._probe_neo4j"), patch("api.health._probe_graph_db"),
patch.object(ScopedRateThrottle, "parse_rate", return_value=(2, 60)), patch.object(ScopedRateThrottle, "parse_rate", return_value=(2, 60)),
): ):
statuses = [ statuses = [
@@ -414,19 +431,42 @@ class TestProbeImplementations:
with pytest.raises(RuntimeError, match="bug"): with pytest.raises(RuntimeError, match="bug"):
health._probe_valkey() health._probe_valkey()
def test_neo4j_probe_calls_verify_connectivity(self): def test_graph_db_probe_calls_verify_connectivity(self):
with patch("api.attack_paths.database.get_driver") as mock_get_driver: with patch("api.attack_paths.database.verify_connectivity") as mock_verify:
mock_get_driver.return_value.verify_connectivity.return_value = None mock_verify.return_value = None
assert health._probe_neo4j() is None assert health._probe_graph_db() is None
mock_get_driver.return_value.verify_connectivity.assert_called_once_with() mock_verify.assert_called_once_with()
def test_neo4j_probe_propagates_driver_errors(self): def test_graph_db_probe_propagates_errors(self):
with patch("api.attack_paths.database.get_driver") as mock_get_driver: with patch(
mock_get_driver.return_value.verify_connectivity.side_effect = RuntimeError( "api.attack_paths.database.verify_connectivity",
"unreachable" side_effect=RuntimeError("unreachable"),
) ):
with pytest.raises(RuntimeError, match="unreachable"): with pytest.raises(RuntimeError, match="unreachable"):
health._probe_neo4j() health._probe_graph_db()
def test_graph_db_probe_times_out_when_check_exceeds_budget(self):
# A sink whose connectivity check blocks past the probe budget must
# surface as a failure fast, not pin the request thread for the
# driver's full acquisition timeout.
import time as _time
def _hang() -> None:
_time.sleep(2)
with (
patch("api.health.GRAPH_DB_PROBE_TIMEOUT_SECONDS", 0.2),
patch(
"api.attack_paths.database.verify_connectivity",
side_effect=_hang,
),
):
started = _time.perf_counter()
with pytest.raises(TimeoutError):
health._probe_graph_db()
elapsed = _time.perf_counter() - started
assert elapsed < health.GRAPH_DB_PROBE_TIMEOUT_SECONDS + 1
class TestStatusAggregation: class TestStatusAggregation:
+626
View File
@@ -0,0 +1,626 @@
"""Tests for the attack-paths sink factory and Neo4j sink.
The sink module picks a backend per ``settings.ATTACK_PATHS_SINK_DATABASE``.
Neo4j is the default and preserves today's behavior; Neptune is opt-in and
builds dual writer/reader Bolt drivers.
"""
import json
from importlib import import_module
from unittest.mock import MagicMock, patch
import pytest
# Prime patch-target resolution. `api.attack_paths.sink/__init__.py` doesn't
# eagerly import these submodules (they're loaded on demand inside the
# factory), so `mock.patch("api.attack_paths.sink.<sub>.…")` would fail with
# AttributeError on first call. Importing here registers them as attributes
# of the package before any decorator runs.
import_module("api.attack_paths.sink.neo4j")
import_module("api.attack_paths.sink.neptune")
@pytest.fixture(autouse=True)
def reset_sink_state():
"""Reset the module-level backend singletons around each test.
The cache lives in `api.attack_paths.sink.factory`, not on the package.
"""
from api.attack_paths.sink import factory
original_backend = factory._backend
original_secondary = dict(factory._secondary_backends)
factory._backend = None
factory._secondary_backends.clear()
yield
factory._backend = original_backend
factory._secondary_backends.clear()
factory._secondary_backends.update(original_secondary)
class TestSinkFactory:
def test_default_resolves_to_neo4j(self, settings):
from api.attack_paths.sink import factory
settings.ATTACK_PATHS_SINK_DATABASE = "neo4j"
assert factory._resolve_setting() == "neo4j"
def test_neptune_resolves_correctly(self, settings):
from api.attack_paths.sink import factory
settings.ATTACK_PATHS_SINK_DATABASE = "neptune"
assert factory._resolve_setting() == "neptune"
def test_invalid_value_raises(self, settings):
from api.attack_paths.sink import factory
settings.ATTACK_PATHS_SINK_DATABASE = "foo"
with pytest.raises(RuntimeError, match="ATTACK_PATHS_SINK_DATABASE"):
factory._resolve_setting()
@patch("api.attack_paths.sink.neo4j.neo4j.GraphDatabase.driver")
def test_init_builds_neo4j_backend_by_default(self, mock_driver, settings):
from api.attack_paths import sink as sink_module
from api.attack_paths.sink.neo4j import Neo4jSink
settings.ATTACK_PATHS_SINK_DATABASE = "neo4j"
settings.DATABASES = {
**settings.DATABASES,
"neo4j": {
"HOST": "localhost",
"PORT": "7687",
"USER": "neo4j",
"PASSWORD": "pw",
},
}
mock_driver.return_value = MagicMock()
backend = sink_module.init()
assert isinstance(backend, Neo4jSink)
mock_driver.assert_called_once()
@patch("api.attack_paths.sink.neptune.neptune_auth_provider")
@patch("api.attack_paths.sink.neptune.neo4j.GraphDatabase.driver")
def test_init_builds_neptune_backend(
self, mock_driver, mock_auth_provider, settings
):
from api.attack_paths import sink as sink_module
from api.attack_paths.sink.neptune import NeptuneSink
settings.ATTACK_PATHS_SINK_DATABASE = "neptune"
settings.DATABASES = {
**settings.DATABASES,
"neptune": {
"WRITER_ENDPOINT": "writer.example",
"READER_ENDPOINT": "reader.example",
"PORT": "8182",
"REGION": "eu-west-1",
},
}
mock_driver.return_value = MagicMock()
mock_auth_provider.return_value = lambda: None
backend = sink_module.init()
assert isinstance(backend, NeptuneSink)
# Writer + reader endpoints both trigger driver construction
assert mock_driver.call_count == 2
writer_uri = mock_driver.call_args_list[0][0][0]
reader_uri = mock_driver.call_args_list[1][0][0]
assert writer_uri == "bolt+s://writer.example:8182"
assert reader_uri == "bolt+s://reader.example:8182"
@patch("api.attack_paths.sink.neptune.neptune_auth_provider")
@patch("api.attack_paths.sink.neptune.neo4j.GraphDatabase.driver")
def test_neptune_reader_falls_back_to_writer(
self, mock_driver, mock_auth_provider, settings
):
from api.attack_paths import sink as sink_module
settings.ATTACK_PATHS_SINK_DATABASE = "neptune"
settings.DATABASES = {
**settings.DATABASES,
"neptune": {
"WRITER_ENDPOINT": "writer.example",
"READER_ENDPOINT": "",
"PORT": "8182",
"REGION": "eu-west-1",
},
}
mock_driver.return_value = MagicMock()
mock_auth_provider.return_value = lambda: None
sink_module.init()
# Only one driver call — reader aliases writer
assert mock_driver.call_count == 1
class TestGetBackendForScan:
"""``get_backend_for_scan`` routes by the row's recorded sink backend."""
@patch("api.attack_paths.sink.neo4j.neo4j.GraphDatabase.driver")
def test_legacy_scan_in_neo4j_process_uses_active_backend(
self, mock_driver, settings
):
from api.attack_paths import sink as sink_module
settings.ATTACK_PATHS_SINK_DATABASE = "neo4j"
settings.DATABASES = {
**settings.DATABASES,
"neo4j": {
"HOST": "localhost",
"PORT": "7687",
"USER": "neo4j",
"PASSWORD": "pw",
},
}
mock_driver.return_value = MagicMock()
scan = MagicMock(sink_backend="neo4j")
backend = sink_module.get_backend_for_scan(scan)
assert backend is sink_module.get_backend()
def test_neptune_scan_on_neo4j_process_uses_neptune_secondary(self, settings):
from api.attack_paths.sink import factory
settings.ATTACK_PATHS_SINK_DATABASE = "neo4j"
active_neo4j = MagicMock(name="neo4j-active")
factory._backend = active_neo4j
secondary_neptune = MagicMock(name="neptune-secondary")
with patch.object(factory, "_build_backend", return_value=secondary_neptune):
scan = MagicMock(sink_backend="neptune")
backend = factory.get_backend_for_scan(scan)
assert backend is secondary_neptune
assert backend is not active_neo4j
def _session_ctx(session: MagicMock) -> MagicMock:
ctx = MagicMock()
ctx.__enter__ = MagicMock(return_value=session)
ctx.__exit__ = MagicMock(return_value=False)
return ctx
class TestNeo4jSinkSyncWrites:
def test_ensure_sync_indexes_runs_create_index_idempotent(self):
from api.attack_paths.sink.neo4j import Neo4jSink
sink = Neo4jSink()
session = MagicMock()
session.run.return_value = MagicMock()
with patch.object(sink, "get_session", return_value=_session_ctx(session)):
sink.ensure_sync_indexes("db-tenant-x")
query = session.run.call_args.args[0]
assert "CREATE INDEX" in query
assert "IF NOT EXISTS" in query
assert "`_ProviderResource`" in query
assert "`_provider_element_id`" in query
def test_write_nodes_skips_empty_batch(self):
from api.attack_paths.sink.neo4j import Neo4jSink
sink = Neo4jSink()
with patch.object(sink, "get_session") as get_session:
sink.write_nodes("db-tenant-x", "`AWSUser`", [])
get_session.assert_not_called()
def test_write_nodes_merges_on_provider_resource_label(self):
from api.attack_paths.sink.neo4j import Neo4jSink
sink = Neo4jSink()
session = MagicMock()
with patch.object(sink, "get_session", return_value=_session_ctx(session)):
sink.write_nodes(
"db-tenant-x",
"`AWSUser`:`_ProviderResource`",
[{"provider_element_id": "p:e", "props": {"k": "v"}}],
)
query, params = session.run.call_args.args
assert "MERGE (n:`_ProviderResource`" in query
assert "`_provider_element_id`: row.provider_element_id" in query
assert "SET n:`AWSUser`:`_ProviderResource`" in query
assert params == {"rows": [{"provider_element_id": "p:e", "props": {"k": "v"}}]}
def test_write_relationships_scopes_endpoints_by_provider_label(self):
from api.attack_paths.sink.neo4j import Neo4jSink
sink = Neo4jSink()
session = MagicMock()
provider_id = "00000000-0000-0000-0000-000000000abc"
with patch.object(sink, "get_session", return_value=_session_ctx(session)):
sink.write_relationships(
"db-tenant-x",
"RESOURCE",
provider_id,
[
{
"start_element_id": "s",
"end_element_id": "e",
"provider_element_id": "pe",
"props": {},
}
],
)
query = session.run.call_args.args[0]
assert ":`_Provider_00000000000000000000000000000abc`" in query
assert ":RESOURCE" in query.replace("`", "")
assert "MERGE (s)-[r:`RESOURCE`" in query
class TestNeptuneSinkSyncWrites:
def test_ensure_sync_indexes_is_noop(self):
from api.attack_paths.sink.neptune import NeptuneSink
sink = NeptuneSink()
with patch.object(sink, "get_session") as get_session:
sink.ensure_sync_indexes("ignored")
get_session.assert_not_called()
def test_write_nodes_merges_on_neptune_id_with_provider_resource_label(self):
from api.attack_paths.sink.neptune import NeptuneSink
sink = NeptuneSink()
session = MagicMock()
with patch.object(sink, "get_session", return_value=_session_ctx(session)):
sink.write_nodes(
"ignored",
"`AWSUser`",
[{"provider_element_id": "p:e", "props": {"k": "v"}}],
)
query = session.run.call_args.args[0]
# Neptune assigns a default `vertex` label to any unlabeled node,
# so the MERGE must pin a real label at creation time.
assert "MERGE (n:`_ProviderResource` {`~id`: row.provider_element_id})" in query
assert "SET n:`AWSUser`" in query
assert "SET n.`_provider_element_id` = row.provider_element_id" in query
def test_write_relationships_matches_endpoints_by_id(self):
from api.attack_paths.sink.neptune import NeptuneSink
sink = NeptuneSink()
session = MagicMock()
with patch.object(sink, "get_session", return_value=_session_ctx(session)):
sink.write_relationships(
"ignored",
"RESOURCE",
"provider-1",
[
{
"start_element_id": "s",
"end_element_id": "e",
"provider_element_id": "pe",
"props": {},
}
],
)
query = session.run.call_args.args[0]
assert "MATCH (s) WHERE id(s) = row.start_element_id" in query
assert "MATCH (e) WHERE id(e) = row.end_element_id" in query
assert "MERGE (s)-[r:`RESOURCE`" in query
class TestNeptuneSinkDropSubgraph:
def test_drop_subgraph_deletes_rels_before_nodes_in_bounded_batches(self):
from api.attack_paths.sink.neptune import NeptuneSink
sink = NeptuneSink()
session = MagicMock()
rel_record_first = MagicMock()
rel_record_first.__getitem__ = lambda _self, key: 50
rel_record_drain = MagicMock()
rel_record_drain.__getitem__ = lambda _self, key: 0
node_record_first = MagicMock()
node_record_first.__getitem__ = lambda _self, key: 10
node_record_drain = MagicMock()
node_record_drain.__getitem__ = lambda _self, key: 0
run_results = [
MagicMock(single=MagicMock(return_value=rel_record_first)),
MagicMock(single=MagicMock(return_value=rel_record_drain)),
MagicMock(single=MagicMock(return_value=node_record_first)),
MagicMock(single=MagicMock(return_value=node_record_drain)),
]
session.run.side_effect = run_results
with patch.object(sink, "get_session", return_value=_session_ctx(session)):
deleted = sink.drop_subgraph("ignored", "provider-1")
assert deleted == 10
first_query = session.run.call_args_list[0].args[0]
assert "DELETE r" in first_query
assert "DETACH DELETE" not in first_query
# DISTINCT avoids double-counting relationships matched from both ends.
assert "DISTINCT r" in first_query
third_query = session.run.call_args_list[2].args[0]
assert "DELETE n" in third_query
class TestNeo4jSinkDropSubgraph:
"""Neo4j drop deletes relationships then nodes in batches (no ``DETACH DELETE``)."""
def test_drop_subgraph_deletes_rels_before_nodes_in_bounded_batches(self):
from api.attack_paths.sink.neo4j import Neo4jSink
sink = Neo4jSink()
session = MagicMock()
rel_first = MagicMock()
rel_first.get = lambda key, default=0: 50
rel_drain = MagicMock()
rel_drain.get = lambda key, default=0: 0
node_first = MagicMock()
node_first.get = lambda key, default=0: 10
node_drain = MagicMock()
node_drain.get = lambda key, default=0: 0
session.run.side_effect = [
MagicMock(single=MagicMock(return_value=rel_first)),
MagicMock(single=MagicMock(return_value=rel_drain)),
MagicMock(single=MagicMock(return_value=node_first)),
MagicMock(single=MagicMock(return_value=node_drain)),
]
provider_id = "00000000-0000-0000-0000-000000000abc"
with patch.object(sink, "get_session", return_value=_session_ctx(session)):
deleted = sink.drop_subgraph("db-tenant-x", provider_id)
# Only phase-2 node counts contribute to the return value.
assert deleted == 10
assert session.run.call_count == 4
queries = [call.args[0] for call in session.run.call_args_list]
# Regression guard: the memory blow-up was caused by DETACH DELETE.
assert all("DETACH DELETE" not in query for query in queries)
first_query = queries[0]
assert "DELETE r" in first_query
# DISTINCT avoids double-counting relationships matched from both ends.
assert "DISTINCT r" in first_query
assert ":`_Provider_00000000000000000000000000000abc`" in first_query
assert "DELETE n" in queries[2]
# Relationships must be fully drained before nodes are deleted.
first_node = next(i for i, q in enumerate(queries) if "DELETE n" in q)
last_rel = max(i for i, q in enumerate(queries) if "DELETE r" in q)
assert last_rel < first_node
def test_drop_subgraph_returns_zero_when_database_does_not_exist(self):
from api.attack_paths.database import GraphDatabaseQueryException
from api.attack_paths.sink.neo4j import DATABASE_NOT_FOUND_CODE, Neo4jSink
sink = Neo4jSink()
session = MagicMock()
session.run.side_effect = GraphDatabaseQueryException(
message="db missing", code=DATABASE_NOT_FOUND_CODE
)
with patch.object(sink, "get_session", return_value=_session_ctx(session)):
deleted = sink.drop_subgraph("db-tenant-missing", "provider-1")
assert deleted == 0
class TestSinkHasProviderData:
"""``has_provider_data`` is the read-path probe used by API views."""
def test_neo4j_returns_true_when_provider_node_exists(self):
from api.attack_paths.sink.neo4j import Neo4jSink
sink = Neo4jSink()
session = MagicMock()
session.run.return_value.single.return_value = MagicMock()
with patch.object(sink, "get_session", return_value=_session_ctx(session)):
present = sink.has_provider_data(
"db-tenant-x", "00000000-0000-0000-0000-000000000abc"
)
assert present is True
query = session.run.call_args.args[0]
assert ":`_Provider_00000000000000000000000000000abc`" in query
def test_neo4j_returns_false_when_database_does_not_exist(self):
from api.attack_paths.database import GraphDatabaseQueryException
from api.attack_paths.sink.neo4j import DATABASE_NOT_FOUND_CODE, Neo4jSink
sink = Neo4jSink()
session = MagicMock()
session.run.side_effect = GraphDatabaseQueryException(
message="db missing", code=DATABASE_NOT_FOUND_CODE
)
with patch.object(sink, "get_session", return_value=_session_ctx(session)):
present = sink.has_provider_data("db-tenant-missing", "provider-1")
assert present is False
def test_neptune_returns_true_when_provider_node_exists(self):
from api.attack_paths.sink.neptune import NeptuneSink
sink = NeptuneSink()
session = MagicMock()
session.run.return_value.single.return_value = MagicMock()
with patch.object(sink, "get_session", return_value=_session_ctx(session)):
present = sink.has_provider_data("ignored", "provider-1")
assert present is True
class TestGetBackendForScanCutover:
"""``get_backend_for_scan`` keeps old-sink scans queryable after cutover."""
def test_legacy_scan_on_neptune_process_uses_neo4j_secondary(self, settings):
from api.attack_paths.sink import factory
settings.ATTACK_PATHS_SINK_DATABASE = "neptune"
active_neptune = MagicMock(name="neptune-active")
factory._backend = active_neptune
secondary_neo4j = MagicMock(name="neo4j-secondary")
with patch.object(factory, "_build_backend", return_value=secondary_neo4j):
scan = MagicMock(sink_backend="neo4j")
backend = factory.get_backend_for_scan(scan)
assert backend is secondary_neo4j
assert backend is not active_neptune
class TestSinkVerifyConnectivity:
"""The readiness probe calls ``verify_connectivity`` through the shim.
Neo4j checks its single driver; Neptune checks the reader (the API read
path), which on single-endpoint clusters aliases the writer.
"""
@patch("api.attack_paths.sink.neo4j.neo4j.GraphDatabase.driver")
def test_neo4j_verifies_its_driver(self, mock_driver, settings):
from api.attack_paths.sink.neo4j import Neo4jSink
settings.DATABASES = {
**settings.DATABASES,
"neo4j": {
"HOST": "localhost",
"PORT": "7687",
"USER": "neo4j",
"PASSWORD": "pw",
},
}
driver = MagicMock()
mock_driver.return_value = driver
sink = Neo4jSink()
sink.init()
driver.verify_connectivity.reset_mock() # ignore the eager init check
sink.verify_connectivity()
driver.verify_connectivity.assert_called_once_with()
@patch("api.attack_paths.sink.neptune.neptune_auth_provider")
@patch("api.attack_paths.sink.neptune.neo4j.GraphDatabase.driver")
def test_neptune_verifies_reader_not_writer(
self, mock_driver, mock_auth_provider, settings
):
from api.attack_paths.sink.neptune import NeptuneSink
settings.DATABASES = {
**settings.DATABASES,
"neptune": {
"WRITER_ENDPOINT": "writer.example",
"READER_ENDPOINT": "reader.example",
"PORT": "8182",
"REGION": "eu-west-1",
},
}
writer, reader = MagicMock(name="writer"), MagicMock(name="reader")
mock_driver.side_effect = [writer, reader]
mock_auth_provider.return_value = lambda: None
sink = NeptuneSink()
sink.init()
writer.verify_connectivity.reset_mock()
reader.verify_connectivity.reset_mock()
sink.verify_connectivity()
reader.verify_connectivity.assert_called_once_with()
writer.verify_connectivity.assert_not_called()
class TestSinkInitToleratesUnreachableSink:
"""Init must not crash the process when the sink is down at boot.
Same degradation model as Postgres: the driver is retained and
reconnects lazily; /health/ready surfaces the outage until it recovers.
"""
@patch("api.attack_paths.sink.neo4j.neo4j.GraphDatabase.driver")
def test_neo4j_init_continues_when_verify_fails(self, mock_driver, settings):
from api.attack_paths.sink.neo4j import Neo4jSink
settings.DATABASES = {
**settings.DATABASES,
"neo4j": {
"HOST": "localhost",
"PORT": "7687",
"USER": "neo4j",
"PASSWORD": "pw",
},
}
driver = MagicMock()
driver.verify_connectivity.side_effect = RuntimeError("unreachable")
mock_driver.return_value = driver
sink = Neo4jSink()
# Must not raise.
assert sink.init() is driver
assert sink._driver is driver
@patch("api.attack_paths.sink.neptune.neptune_auth_provider")
@patch("api.attack_paths.sink.neptune.neo4j.GraphDatabase.driver")
def test_neptune_init_continues_when_verify_fails(
self, mock_driver, mock_auth_provider, settings
):
from api.attack_paths.sink.neptune import NeptuneSink
settings.DATABASES = {
**settings.DATABASES,
"neptune": {
"WRITER_ENDPOINT": "writer.example",
"READER_ENDPOINT": "reader.example",
"PORT": "8182",
"REGION": "eu-west-1",
},
}
driver = MagicMock()
driver.verify_connectivity.side_effect = RuntimeError("unreachable")
mock_driver.return_value = driver
mock_auth_provider.return_value = lambda: None
sink = NeptuneSink()
# Must not raise; both drivers retained.
sink.init()
assert sink._writer is not None
assert sink._reader is not None
class TestNeptuneAdminNoOps:
"""Neptune is single-database; admin DDL has no work to do."""
@pytest.mark.parametrize("method", ["create_database", "drop_database"])
def test_admin_ops_return_none_without_touching_a_session(self, method):
from api.attack_paths.sink.neptune import NeptuneSink
sink = NeptuneSink()
with patch.object(sink, "get_session") as get_session:
assert getattr(sink, method)("ignored") is None
get_session.assert_not_called()
class TestNeptuneAuthToken:
"""SigV4 signing for the Neptune Bolt endpoint."""
@patch("api.attack_paths.sink.neptune.SigV4Auth")
@patch("api.attack_paths.sink.neptune.BotoSession")
def test_host_header_includes_non_default_port(self, mock_boto, mock_sigv4):
# Neptune runs on 8182; the SigV4 canonical Host must keep the port or
# the signature is rejected.
from api.attack_paths.sink.neptune import _NeptuneAuthToken
credentials = MagicMock()
credentials.get_frozen_credentials.return_value = MagicMock()
mock_boto.return_value.get_credentials.return_value = credentials
token = _NeptuneAuthToken("eu-west-1", "https://writer.example:8182")
auth_obj = json.loads(token.credentials)
assert auth_obj["Host"] == "writer.example:8182"
+68 -5
View File
@@ -4754,6 +4754,64 @@ class TestAttackPathsScanViewSet:
assert first_attributes["provider_type"] == provider.provider assert first_attributes["provider_type"] == provider.provider
assert first_attributes["provider_uid"] == provider.uid assert first_attributes["provider_uid"] == provider.uid
def test_attack_paths_scans_list_prefers_active_sink_scan_on_rollback(
self,
authenticated_client,
providers_fixture,
scans_fixture,
create_attack_paths_scan,
settings,
):
settings.ATTACK_PATHS_SINK_DATABASE = "neo4j"
provider = providers_fixture[0]
neo4j_scan = create_attack_paths_scan(
provider,
scan=scans_fixture[0],
state=StateChoices.COMPLETED,
graph_data_ready=True,
sink_backend="neo4j",
)
neptune_scan = create_attack_paths_scan(
provider,
scan=scans_fixture[0],
state=StateChoices.COMPLETED,
graph_data_ready=True,
sink_backend="neptune",
)
response = authenticated_client.get(reverse("attack-paths-scans-list"))
assert response.status_code == status.HTTP_200_OK
ids = {item["id"] for item in response.json()["data"]}
assert str(neo4j_scan.id) in ids
assert str(neptune_scan.id) not in ids
def test_attack_paths_scans_list_falls_back_when_active_sink_has_no_scan(
self,
authenticated_client,
providers_fixture,
scans_fixture,
create_attack_paths_scan,
settings,
):
settings.ATTACK_PATHS_SINK_DATABASE = "neptune"
provider = providers_fixture[0]
legacy_scan = create_attack_paths_scan(
provider,
scan=scans_fixture[0],
state=StateChoices.COMPLETED,
graph_data_ready=True,
sink_backend="neo4j",
)
response = authenticated_client.get(reverse("attack-paths-scans-list"))
assert response.status_code == status.HTTP_200_OK
ids = {item["id"] for item in response.json()["data"]}
assert str(legacy_scan.id) in ids
def test_attack_paths_scans_list_respects_provider_group_visibility( def test_attack_paths_scans_list_respects_provider_group_visibility(
self, self,
authenticated_client_no_permissions_rbac, authenticated_client_no_permissions_rbac,
@@ -4874,7 +4932,8 @@ class TestAttackPathsScanViewSet:
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
mock_get_queries.assert_called_once_with(provider.provider) # TODO: drop the is_migrated argument after Neptune cutover
mock_get_queries.assert_called_once_with(provider.provider, is_migrated=False)
payload = response.json()["data"] payload = response.json()["data"]
assert len(payload) == 1 assert len(payload) == 1
assert payload[0]["id"] == "aws-rds" assert payload[0]["id"] == "aws-rds"
@@ -4974,7 +5033,8 @@ class TestAttackPathsScanViewSet:
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
mock_get_query.assert_called_once_with("aws-rds") # TODO: drop the is_migrated argument after Neptune cutover
mock_get_query.assert_called_once_with("aws-rds", is_migrated=False)
mock_get_db_name.assert_called_once_with(attack_paths_scan.provider.tenant_id) mock_get_db_name.assert_called_once_with(attack_paths_scan.provider.tenant_id)
provider_id = str(attack_paths_scan.provider_id) provider_id = str(attack_paths_scan.provider_id)
mock_prepare.assert_called_once_with( mock_prepare.assert_called_once_with(
@@ -4988,6 +5048,7 @@ class TestAttackPathsScanViewSet:
query_definition, query_definition,
prepared_parameters, prepared_parameters,
provider_id, provider_id,
scan=attack_paths_scan,
) )
result = response.json()["data"] result = response.json()["data"]
attributes = result["attributes"] attributes = result["attributes"]
@@ -5339,6 +5400,7 @@ class TestAttackPathsScanViewSet:
"db-test", "db-test",
"MATCH (n) RETURN n", "MATCH (n) RETURN n",
str(attack_paths_scan.provider_id), str(attack_paths_scan.provider_id),
scan=attack_paths_scan,
) )
attributes = response.json()["data"]["attributes"] attributes = response.json()["data"]["attributes"]
assert len(attributes["nodes"]) == 1 assert len(attributes["nodes"]) == 1
@@ -5875,9 +5937,10 @@ class TestAttackPathsScanViewSet:
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
mock_get_schema.assert_called_once_with( mock_get_schema.assert_called_once()
"db-test", str(attack_paths_scan.provider_id) schema_args = mock_get_schema.call_args[0]
) assert schema_args[:2] == ("db-test", str(attack_paths_scan.provider_id))
assert schema_args[2].id == attack_paths_scan.id
attributes = response.json()["data"]["attributes"] attributes = response.json()["data"]["attributes"]
assert attributes["provider"] == "aws" assert attributes["provider"] == "aws"
assert attributes["cartography_version"] == "0.129.0" assert attributes["cartography_version"] == "0.129.0"
+24 -5
View File
@@ -2876,13 +2876,22 @@ class AttackPathsScanViewSet(BaseRLSViewSet):
def list(self, request, *args, **kwargs): def list(self, request, *args, **kwargs):
queryset = self.filter_queryset(self.get_queryset()) queryset = self.filter_queryset(self.get_queryset())
active_sink_backend = django_settings.ATTACK_PATHS_SINK_DATABASE
latest_per_provider = queryset.annotate( latest_per_provider = queryset.annotate(
active_sink_rank=Case(
When(sink_backend=active_sink_backend, then=Value(0)),
default=Value(1),
output_field=IntegerField(),
),
latest_scan_rank=Window( latest_scan_rank=Window(
expression=RowNumber(), expression=RowNumber(),
partition_by=[F("provider_id")], partition_by=[F("provider_id")],
order_by=[F("inserted_at").desc()], order_by=[
) F("active_sink_rank").asc(),
F("inserted_at").desc(),
],
),
).filter(latest_scan_rank=1) ).filter(latest_scan_rank=1)
page = self.paginate_queryset(latest_per_provider) page = self.paginate_queryset(latest_per_provider)
@@ -2909,7 +2918,11 @@ class AttackPathsScanViewSet(BaseRLSViewSet):
) )
def attack_paths_queries(self, request, pk=None): def attack_paths_queries(self, request, pk=None):
attack_paths_scan = self.get_object() attack_paths_scan = self.get_object()
queries = get_queries_for_provider(attack_paths_scan.provider.provider) # TODO: drop the is_migrated argument after Neptune cutover
queries = get_queries_for_provider(
attack_paths_scan.provider.provider,
is_migrated=attack_paths_scan.is_migrated,
)
if not queries: if not queries:
return Response( return Response(
@@ -2942,7 +2955,11 @@ class AttackPathsScanViewSet(BaseRLSViewSet):
serializer = AttackPathsQueryRunRequestSerializer(data=payload) serializer = AttackPathsQueryRunRequestSerializer(data=payload)
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
query_definition = get_query_by_id(serializer.validated_data["id"]) # TODO: drop the is_migrated argument after Neptune cutover
query_definition = get_query_by_id(
serializer.validated_data["id"],
is_migrated=attack_paths_scan.is_migrated,
)
if ( if (
query_definition is None query_definition is None
or query_definition.provider != attack_paths_scan.provider.provider or query_definition.provider != attack_paths_scan.provider.provider
@@ -2968,6 +2985,7 @@ class AttackPathsScanViewSet(BaseRLSViewSet):
query_definition, query_definition,
parameters, parameters,
provider_id, provider_id,
scan=attack_paths_scan,
) )
query_duration = time.monotonic() - start query_duration = time.monotonic() - start
@@ -3035,6 +3053,7 @@ class AttackPathsScanViewSet(BaseRLSViewSet):
database_name, database_name,
serializer.validated_data["query"], serializer.validated_data["query"],
provider_id, provider_id,
scan=attack_paths_scan,
) )
query_duration = time.monotonic() - start query_duration = time.monotonic() - start
@@ -3091,7 +3110,7 @@ class AttackPathsScanViewSet(BaseRLSViewSet):
provider_id = str(attack_paths_scan.provider_id) provider_id = str(attack_paths_scan.provider_id)
schema = attack_paths_views_helpers.get_cartography_schema( schema = attack_paths_views_helpers.get_cartography_schema(
database_name, provider_id database_name, provider_id, attack_paths_scan
) )
if not schema: if not schema:
return Response( return Response(
+5
View File
@@ -311,6 +311,11 @@ ATTACK_PATHS_SCAN_STALE_THRESHOLD_MINUTES = env.int(
"ATTACK_PATHS_SCAN_STALE_THRESHOLD_MINUTES", 2880 "ATTACK_PATHS_SCAN_STALE_THRESHOLD_MINUTES", 2880
) # 48h ) # 48h
# Selects where the persistent attack-paths graph is stored. The scan
# temporary database is always Neo4j; only the sink is configurable.
# Valid values: "neo4j" (default, OSS and local dev), "neptune" (hosted).
ATTACK_PATHS_SINK_DATABASE = env.str("ATTACK_PATHS_SINK_DATABASE", default="neo4j")
# Orphan task recovery feature flags. The master switch is OFF by default, so task # Orphan task recovery feature flags. The master switch is OFF by default, so task
# recovery is opt-in; enable it with DJANGO_TASK_RECOVERY_ENABLED=true. The per-group # recovery is opt-in; enable it with DJANGO_TASK_RECOVERY_ENABLED=true. The per-group
# toggles default to enabled, so once the master is on every group recovers unless a # toggles default to enabled, so once the master is on every group recovers unless a
+6
View File
@@ -50,6 +50,12 @@ DATABASES = {
"USER": env.str("NEO4J_USER", "neo4j"), "USER": env.str("NEO4J_USER", "neo4j"),
"PASSWORD": env.str("NEO4J_PASSWORD", "neo4j_password"), "PASSWORD": env.str("NEO4J_PASSWORD", "neo4j_password"),
}, },
"neptune": {
"WRITER_ENDPOINT": env.str("NEPTUNE_WRITER_ENDPOINT", ""),
"READER_ENDPOINT": env.str("NEPTUNE_READER_ENDPOINT", ""),
"PORT": env.str("NEPTUNE_PORT", "8182"),
"REGION": env.str("AWS_REGION", ""),
},
} }
DATABASES["default"] = DATABASES["prowler_user"] DATABASES["default"] = DATABASES["prowler_user"]
@@ -49,12 +49,19 @@ DATABASES = {
"HOST": env("POSTGRES_REPLICA_HOST", default=default_db_host), "HOST": env("POSTGRES_REPLICA_HOST", default=default_db_host),
"PORT": env("POSTGRES_REPLICA_PORT", default=default_db_port), "PORT": env("POSTGRES_REPLICA_PORT", default=default_db_port),
}, },
# TODO: drop after Neptune cutover just loosen defaults to `""`
"neo4j": { "neo4j": {
"HOST": env.str("NEO4J_HOST"), "HOST": env.str("NEO4J_HOST"),
"PORT": env.str("NEO4J_PORT"), "PORT": env.str("NEO4J_PORT"),
"USER": env.str("NEO4J_USER"), "USER": env.str("NEO4J_USER"),
"PASSWORD": env.str("NEO4J_PASSWORD"), "PASSWORD": env.str("NEO4J_PASSWORD"),
}, },
"neptune": {
"WRITER_ENDPOINT": env.str("NEPTUNE_WRITER_ENDPOINT", default=""),
"READER_ENDPOINT": env.str("NEPTUNE_READER_ENDPOINT", default=""),
"PORT": env.str("NEPTUNE_PORT", default="8182"),
"REGION": env.str("AWS_REGION", default=""),
},
} }
DATABASES["default"] = DATABASES["prowler_user"] DATABASES["default"] = DATABASES["prowler_user"]
+20 -4
View File
@@ -83,12 +83,28 @@ def _warm_compliance_caches_in_background():
def post_fork(_server, worker): def post_fork(_server, worker):
"""Warm compliance caches after each worker fork. """Re-initialize attack-paths drivers and warm compliance caches per worker.
Warm compliance caches in a background thread so the worker becomes ready Neo4j / Neptune drivers spawn background IO threads that do not survive
immediately. A request for a not-yet-warmed provider lazily loads just that ``fork()``. When the gunicorn master runs with ``preload_app=True``, the
provider, which stays well under the worker timeout. child inherits driver objects whose pool references dead threads and
hangs on the first ``pool.acquire`` call until the watchdog kills the
worker. Re-initializing per worker guarantees each child owns its own
live threads. See GUNICORN_WORKER_TIMEOUTS_ANALYSIS.md for detail.
Compliance caches are then warmed in a background thread so the worker
becomes ready immediately. A request for a not-yet-warmed provider lazily
loads just that provider, which stays well under the worker timeout.
""" """
from api.attack_paths import database as graph_database
try:
graph_database.close_driver()
except Exception: # pragma: no cover - best-effort cleanup
pass
graph_database.init_driver()
gunicorn_logger.info(f"Attack-paths drivers initialized for worker {worker.pid}")
threading.Thread( threading.Thread(
target=_warm_compliance_caches_in_background, target=_warm_compliance_caches_in_background,
name="warm-compliance-caches", name="warm-compliance-caches",
+30
View File
@@ -1821,6 +1821,36 @@ def attack_paths_query_definition_factory():
return _create return _create
@pytest.fixture
def sink_backend_stub():
"""Install a stub `SinkDatabase` into the sink factory for the test's duration.
The sink factory caches a process-wide backend and lazily initializes it
against `settings.DATABASES["neo4j"]` / `["neptune"]`. Tests that don't
want to stand up a real Bolt driver can yield this fixture's mock and
configure its return values directly:
sink_backend_stub.execute_read_query.return_value = some_graph
Both the active backend and the secondary-backend cache are restored on
teardown so tests stay isolated.
"""
from api.attack_paths.sink import factory
from api.attack_paths.sink.base import SinkDatabase
stub = MagicMock(spec=SinkDatabase)
previous_backend = factory._backend
previous_secondary = dict(factory._secondary_backends)
factory._backend = stub
factory._secondary_backends.clear()
try:
yield stub
finally:
factory._backend = previous_backend
factory._secondary_backends.clear()
factory._secondary_backends.update(previous_secondary)
@pytest.fixture @pytest.fixture
def attack_paths_graph_stub_classes(): def attack_paths_graph_stub_classes():
"""Provide lightweight graph element stubs for Attack Paths serialization tests.""" """Provide lightweight graph element stubs for Attack Paths serialization tests."""
+20 -4
View File
@@ -6,6 +6,7 @@ from typing import Any
import aioboto3 import aioboto3
import boto3 import boto3
import botocore
import neo4j import neo4j
from api.models import ( from api.models import (
AttackPathsScan as ProwlerAPIAttackPathsScan, AttackPathsScan as ProwlerAPIAttackPathsScan,
@@ -73,13 +74,28 @@ def start_aws_ingestion(
# Adding an extra field # Adding an extra field
common_job_parameters["AWS_ID"] = prowler_api_provider.uid common_job_parameters["AWS_ID"] = prowler_api_provider.uid
cartography_aws._autodiscover_accounts( # AWS Organizations account autodiscovery. Inlined from Cartography's removed
# `_autodiscover_accounts` (deleted in `0.137.0`), as `load_aws_accounts` is still public.
try:
org_client = boto3_session.client("organizations")
paginator = org_client.get_paginator("list_accounts")
discovered = []
for page in paginator.paginate():
discovered.extend(page["Accounts"])
active_accounts = {
a["Name"]: a["Id"] for a in discovered if a["Status"] == "ACTIVE"
}
cartography_aws.organizations.load_aws_accounts(
neo4j_session, neo4j_session,
boto3_session, active_accounts,
prowler_api_provider.uid,
cartography_config.update_tag, cartography_config.update_tag,
common_job_parameters, common_job_parameters,
) )
except botocore.exceptions.ClientError:
logger.warning(
f"Account {prowler_api_provider.uid} lacks permissions for AWS "
"Organizations autodiscovery."
)
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 4) db_utils.update_attack_paths_scan_progress(attack_paths_scan, 4)
failed_syncs = sync_aws_account( failed_syncs = sync_aws_account(
@@ -277,7 +293,7 @@ def sync_aws_account(
sync_args: dict[str, Any], sync_args: dict[str, Any],
attack_paths_scan: ProwlerAPIAttackPathsScan, attack_paths_scan: ProwlerAPIAttackPathsScan,
) -> dict[str, str]: ) -> dict[str, str]:
current_progress = 4 # `cartography_aws._autodiscover_accounts` current_progress = 4 # AWS Organizations account autodiscovery
max_progress = ( max_progress = (
87 # `cartography_aws.RESOURCE_FUNCTIONS["permission_relationships"]` - 1 87 # `cartography_aws.RESOURCE_FUNCTIONS["permission_relationships"]` - 1
) )
@@ -8,7 +8,7 @@ from celery import states
from celery.utils.log import get_task_logger from celery.utils.log import get_task_logger
from config.django.base import ATTACK_PATHS_SCAN_STALE_THRESHOLD_MINUTES from config.django.base import ATTACK_PATHS_SCAN_STALE_THRESHOLD_MINUTES
from tasks.jobs.attack_paths.db_utils import ( from tasks.jobs.attack_paths.db_utils import (
_mark_scan_finished, mark_scan_finished,
recover_graph_data_ready, recover_graph_data_ready,
) )
from tasks.jobs.orphan_recovery import is_worker_alive as _is_worker_alive from tasks.jobs.orphan_recovery import is_worker_alive as _is_worker_alive
@@ -87,7 +87,7 @@ def _cleanup_stale_executing_scans(cutoff: datetime) -> list[str]:
else: else:
reason = "Worker dead — cleaned up by periodic task" reason = "Worker dead — cleaned up by periodic task"
else: else:
# No worker recorded time-based heuristic only # No worker recorded, time-based heuristic only
if scan.started_at and scan.started_at >= cutoff: if scan.started_at and scan.started_at >= cutoff:
continue continue
reason = ( reason = (
@@ -160,7 +160,7 @@ def _cleanup_scan(scan, task_result, reason: str) -> bool:
""" """
scan_id_str = str(scan.id) scan_id_str = str(scan.id)
# 1. Drop temp Neo4j database # Drop temp Neo4j database
tmp_db_name = graph_database.get_database_name(scan.id, temporary=True) tmp_db_name = graph_database.get_database_name(scan.id, temporary=True)
try: try:
graph_database.drop_database(tmp_db_name) graph_database.drop_database(tmp_db_name)
@@ -225,6 +225,6 @@ def _finalize_failed_scan(scan, expected_state: str, reason: str):
logger.info(f"Scan {scan_id_str} is now {fresh_scan.state}, skipping") logger.info(f"Scan {scan_id_str} is now {fresh_scan.state}, skipping")
return None return None
_mark_scan_finished(fresh_scan, StateChoices.FAILED, {"global_error": reason}) mark_scan_finished(fresh_scan, StateChoices.FAILED, {"global_error": reason})
return fresh_scan return fresh_scan
@@ -1,9 +1,14 @@
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass
from uuid import UUID from uuid import UUID
from config.env import env from config.env import env
from tasks.jobs.attack_paths import aws from tasks.jobs.attack_paths import provider_config as _provider_config
# Re-export provider config objects so existing imports keep working.
AWS_CONFIG = _provider_config.AWS_CONFIG
NormalizedList = _provider_config.NormalizedList
PROVIDER_CONFIGS = _provider_config.PROVIDER_CONFIGS
ProviderConfig = _provider_config.ProviderConfig
# Batch size for Neo4j write operations (resource labeling, cleanup) # Batch size for Neo4j write operations (resource labeling, cleanup)
BATCH_SIZE = env.int("ATTACK_PATHS_BATCH_SIZE", 1000) BATCH_SIZE = env.int("ATTACK_PATHS_BATCH_SIZE", 1000)
@@ -21,42 +26,12 @@ PROWLER_FINDING_LABEL = "ProwlerFinding"
PROVIDER_RESOURCE_LABEL = "_ProviderResource" PROVIDER_RESOURCE_LABEL = "_ProviderResource"
# Dynamic isolation labels that contain entity UUIDs and are added to every synced node during sync # Dynamic isolation labels that contain entity UUIDs and are added to every synced node during sync
# Format: _Tenant_{uuid_no_hyphens}, _Provider_{uuid_no_hyphens} # Format: `_Tenant_{uuid_no_hyphens}`, `_Provider_{uuid_no_hyphens}`
TENANT_LABEL_PREFIX = "_Tenant_" TENANT_LABEL_PREFIX = "_Tenant_"
PROVIDER_LABEL_PREFIX = "_Provider_" PROVIDER_LABEL_PREFIX = "_Provider_"
DYNAMIC_ISOLATION_PREFIXES = [TENANT_LABEL_PREFIX, PROVIDER_LABEL_PREFIX] DYNAMIC_ISOLATION_PREFIXES = [TENANT_LABEL_PREFIX, PROVIDER_LABEL_PREFIX]
@dataclass(frozen=True)
class ProviderConfig:
"""Configuration for a cloud provider's Attack Paths integration."""
name: str
root_node_label: str # e.g., "AWSAccount"
uid_field: str # e.g., "arn"
# Label for resources connected to the account node, enabling indexed finding lookups.
resource_label: str # e.g., "_AWSResource"
ingestion_function: Callable
# Maps a Postgres resource UID (e.g. full ARN) to the short-id form Cartography stores on some node types (e.g. `i-xxx` for EC2Instance).
short_uid_extractor: Callable[[str], str]
# Provider Configurations
# -----------------------
AWS_CONFIG = ProviderConfig(
name="aws",
root_node_label="AWSAccount",
uid_field="arn",
resource_label="_AWSResource",
ingestion_function=aws.start_aws_ingestion,
short_uid_extractor=aws.extract_short_uid,
)
PROVIDER_CONFIGS: dict[str, ProviderConfig] = {
"aws": AWS_CONFIG,
}
# Labels added by Prowler that should be filtered from API responses # Labels added by Prowler that should be filtered from API responses
# Derived from provider configs + common internal labels # Derived from provider configs + common internal labels
INTERNAL_LABELS: list[str] = [ INTERNAL_LABELS: list[str] = [
@@ -87,7 +62,6 @@ INTERNAL_PROPERTIES: list[str] = [
# Provider Config Accessors # Provider Config Accessors
# -------------------------
def is_provider_available(provider_type: str) -> bool: def is_provider_available(provider_type: str) -> bool:
@@ -135,7 +109,6 @@ def get_short_uid_extractor(provider_type: str) -> Callable[[str], str]:
# Dynamic Isolation Label Helpers # Dynamic Isolation Label Helpers
# --------------------------------
def _normalize_uuid(value: str | UUID) -> str: def _normalize_uuid(value: str | UUID) -> str:
@@ -8,6 +8,8 @@ from api.models import Provider as ProwlerAPIProvider
from api.models import StateChoices from api.models import StateChoices
from cartography.config import Config as CartographyConfig from cartography.config import Config as CartographyConfig
from celery.utils.log import get_task_logger from celery.utils.log import get_task_logger
from django.conf import settings
from django.db.models import Case, IntegerField, Value, When
from tasks.jobs.attack_paths.config import is_provider_available from tasks.jobs.attack_paths.config import is_provider_available
logger = get_task_logger(__name__) logger = get_task_logger(__name__)
@@ -29,13 +31,33 @@ def create_attack_paths_scan(
return None return None
with rls_transaction(tenant_id): with rls_transaction(tenant_id):
# Inherit graph_data_ready from the previous scan for this provider, # Inherit metadata from the previous ready scan for this provider so
# so queries remain available while the new scan runs. # queries remain available while the new scan runs. The new row only
previous_data_ready = ProwlerAPIAttackPathsScan.objects.filter( # flips to the target sink after its own graph sync succeeds.
active_sink_backend = settings.ATTACK_PATHS_SINK_DATABASE
previous_ready = (
ProwlerAPIAttackPathsScan.objects.filter(
tenant_id=tenant_id, tenant_id=tenant_id,
provider_id=provider_id, provider_id=provider_id,
graph_data_ready=True, graph_data_ready=True,
).exists() )
.annotate(
active_sink_rank=Case(
When(sink_backend=active_sink_backend, then=Value(0)),
default=Value(1),
output_field=IntegerField(),
)
)
.order_by("active_sink_rank", "-inserted_at")
.first()
)
previous_data_ready = previous_ready is not None
inherited_is_migrated = previous_ready.is_migrated if previous_ready else False
inherited_sink_backend = (
previous_ready.sink_backend
if previous_ready
else ProwlerAPIAttackPathsScan.SinkBackendChoices.NEO4J
)
attack_paths_scan = ProwlerAPIAttackPathsScan.objects.create( attack_paths_scan = ProwlerAPIAttackPathsScan.objects.create(
tenant_id=tenant_id, tenant_id=tenant_id,
@@ -44,6 +66,8 @@ def create_attack_paths_scan(
state=StateChoices.SCHEDULED, state=StateChoices.SCHEDULED,
started_at=datetime.now(tz=UTC), started_at=datetime.now(tz=UTC),
graph_data_ready=previous_data_ready, graph_data_ready=previous_data_ready,
is_migrated=inherited_is_migrated,
sink_backend=inherited_sink_backend,
) )
attack_paths_scan.save() attack_paths_scan.save()
@@ -114,7 +138,7 @@ def starting_attack_paths_scan(
return True return True
def _mark_scan_finished( def mark_scan_finished(
attack_paths_scan: ProwlerAPIAttackPathsScan, attack_paths_scan: ProwlerAPIAttackPathsScan,
state: StateChoices, state: StateChoices,
ingestion_exceptions: dict[str, Any], ingestion_exceptions: dict[str, Any],
@@ -148,7 +172,7 @@ def finish_attack_paths_scan(
ingestion_exceptions: dict[str, Any], ingestion_exceptions: dict[str, Any],
) -> None: ) -> None:
with rls_transaction(attack_paths_scan.tenant_id): with rls_transaction(attack_paths_scan.tenant_id):
_mark_scan_finished(attack_paths_scan, state, ingestion_exceptions) mark_scan_finished(attack_paths_scan, state, ingestion_exceptions)
def update_attack_paths_scan_progress( def update_attack_paths_scan_progress(
@@ -169,19 +193,45 @@ def set_graph_data_ready(
attack_paths_scan.save(update_fields=["graph_data_ready"]) attack_paths_scan.save(update_fields=["graph_data_ready"])
def set_scan_migrated(
attack_paths_scan: ProwlerAPIAttackPathsScan,
migrated: bool,
sink_backend: str | None = None,
) -> None:
"""Mark the scan as written with the current (migrated) schema.
Called after a successful sync so the read catalog and sink backend only
switch once the new graph is actually live.
# TODO: drop after Neptune cutover
"""
with rls_transaction(attack_paths_scan.tenant_id):
attack_paths_scan.is_migrated = migrated
update_fields = ["is_migrated"]
if sink_backend is not None:
attack_paths_scan.sink_backend = sink_backend
update_fields.append("sink_backend")
attack_paths_scan.save(update_fields=update_fields)
def set_provider_graph_data_ready( def set_provider_graph_data_ready(
attack_paths_scan: ProwlerAPIAttackPathsScan, attack_paths_scan: ProwlerAPIAttackPathsScan,
ready: bool, ready: bool,
sink_backend: str | None = None,
) -> None: ) -> None:
""" """
Set `graph_data_ready` for ALL scans of the same provider. Set `graph_data_ready` for scans of the same provider in one sink.
Used before drop/sync so that older scan IDs cannot bypass the query gate while the graph is being replaced. Used before drop/sync so that older scan IDs in the target sink cannot
bypass the query gate while that sink's graph is being replaced. Scans
preserved in another sink stay queryable for rollback.
""" """
target_sink_backend = sink_backend or attack_paths_scan.sink_backend
with rls_transaction(attack_paths_scan.tenant_id): with rls_transaction(attack_paths_scan.tenant_id):
ProwlerAPIAttackPathsScan.objects.filter( ProwlerAPIAttackPathsScan.objects.filter(
tenant_id=attack_paths_scan.tenant_id, tenant_id=attack_paths_scan.tenant_id,
provider_id=attack_paths_scan.provider_id, provider_id=attack_paths_scan.provider_id,
sink_backend=target_sink_backend,
).update(graph_data_ready=ready) ).update(graph_data_ready=ready)
attack_paths_scan.refresh_from_db(fields=["graph_data_ready"]) attack_paths_scan.refresh_from_db(fields=["graph_data_ready"])
@@ -202,10 +252,15 @@ def recover_graph_data_ready(
next successful scan) is a worse outcome for the user. next successful scan) is a worse outcome for the user.
""" """
try: try:
from api.attack_paths import sink as sink_module
tenant_db = graph_database.get_database_name(attack_paths_scan.tenant_id) tenant_db = graph_database.get_database_name(attack_paths_scan.tenant_id)
if graph_database.has_provider_data( # TODO: drop after Neptune cutover
tenant_db, str(attack_paths_scan.provider_id) # Check the backend that actually holds this scan's data, not the
): # currently configured sink, a stale `EXECUTING` scan from before a
# backend switch must still be recoverable
backend = sink_module.get_backend_for_scan(attack_paths_scan)
if backend.has_provider_data(tenant_db, str(attack_paths_scan.provider_id)):
set_provider_graph_data_ready(attack_paths_scan, True) set_provider_graph_data_ready(attack_paths_scan, True)
logger.info( logger.info(
f"Recovered `graph_data_ready` for provider {attack_paths_scan.provider_id}" f"Recovered `graph_data_ready` for provider {attack_paths_scan.provider_id}"
@@ -247,6 +302,6 @@ def fail_attack_paths_scan(
return return
if fresh.state in (StateChoices.COMPLETED, StateChoices.FAILED): if fresh.state in (StateChoices.COMPLETED, StateChoices.FAILED):
return return
_mark_scan_finished(fresh, StateChoices.FAILED, {"global_error": error}) mark_scan_finished(fresh, StateChoices.FAILED, {"global_error": error})
recover_graph_data_ready(fresh) recover_graph_data_ready(fresh)
@@ -82,7 +82,6 @@ def _to_neo4j_dict(
# Public API # Public API
# ----------
def analysis( def analysis(
@@ -196,7 +195,6 @@ def load_findings(
# Findings Streaming (Generator-based) # Findings Streaming (Generator-based)
# -------------------------------------
def stream_findings_with_resources( def stream_findings_with_resources(
@@ -275,7 +273,6 @@ def _fetch_findings_batch(
# Batch Enrichment # Batch Enrichment
# -----------------
def _enrich_batch_with_resources( def _enrich_batch_with_resources(
@@ -1,5 +1,6 @@
import neo4j import neo4j
from cartography.client.core.tx import run_write_query from cartography.client.core.tx import run_write_query
from cartography.intel import create_indexes as cartography_create_indexes
from celery.utils.log import get_task_logger from celery.utils.log import get_task_logger
from tasks.jobs.attack_paths.config import ( from tasks.jobs.attack_paths.config import (
INTERNET_NODE_LABEL, INTERNET_NODE_LABEL,
@@ -30,14 +31,34 @@ SYNC_INDEX_STATEMENTS = [
def create_findings_indexes(neo4j_session: neo4j.Session) -> None: def create_findings_indexes(neo4j_session: neo4j.Session) -> None:
"""Create indexes for Prowler findings and resource lookups.""" """Create indexes for Prowler findings and resource lookups.
Runs `CREATE INDEX`, so the caller must only invoke this against a Neo4j
session (the temp ingest DB or a Neo4j sink). Neptune auto-manages indexes
and rejects `CREATE INDEX`, so callers skip it for the Neptune sink.
"""
logger.info("Creating indexes for Prowler Findings node types") logger.info("Creating indexes for Prowler Findings node types")
for statement in FINDINGS_INDEX_STATEMENTS: for statement in FINDINGS_INDEX_STATEMENTS:
run_write_query(neo4j_session, statement) run_write_query(neo4j_session, statement)
def create_cartography_indexes(neo4j_session: neo4j.Session, config) -> None:
"""Create Cartography's standard indexes for the session's database.
Runs `CREATE INDEX`, so the caller must only invoke this against a Neo4j
session (the temp ingest DB or a Neo4j sink). Neptune auto-manages indexes
and rejects `CREATE INDEX`, so callers skip it for the Neptune sink.
"""
cartography_create_indexes.run(neo4j_session, config)
def create_sync_indexes(neo4j_session: neo4j.Session) -> None: def create_sync_indexes(neo4j_session: neo4j.Session) -> None:
"""Create indexes for provider resource sync operations.""" """Create indexes for provider resource sync operations.
Runs `CREATE INDEX`, so the caller must only invoke this against a Neo4j
session (the temp ingest DB or a Neo4j sink). Neptune auto-manages indexes
and rejects `CREATE INDEX`, so callers skip it for the Neptune sink.
"""
logger.info("Ensuring ProviderResource indexes exist") logger.info("Ensuring ProviderResource indexes exist")
for statement in SYNC_INDEX_STATEMENTS: for statement in SYNC_INDEX_STATEMENTS:
neo4j_session.run(statement) neo4j_session.run(statement)
@@ -0,0 +1,413 @@
"""
Provider-level Attack Paths configuration.
Each `ProviderConfig` carries the cloud provider's ingestion entry point and
the catalog of list-typed node properties (`normalized_lists`). The sync
layer reads this catalog and materialises each list element as a child node
connected to the parent by a typed edge, so queries traverse the graph
instead of working on serialised list values. Both Neo4j and Neptune sinks
write the same shape and queries are portable across them.
"""
from collections.abc import Callable
from dataclasses import dataclass, field
from tasks.jobs.attack_paths import aws
@dataclass(frozen=True)
class NormalizedList:
"""Catalog entry for a list-typed node property.
Describes how the sync layer materialises a parent node's list-typed
property as a set of child item nodes connected by a typed edge.
Conventions (mechanical, do not invent):
- `child_label`: `<SourceLabel><PropertyPascal>Item`
e.g. AWSPolicyStatement.resource -> AWSPolicyStatementResourceItem
- `rel_type`: `HAS_<PROPERTY_UPPER>`
e.g. resource -> HAS_RESOURCE
- child node property:
* `field_map = []` (scalar list, ~95% case) -> child stores `value: str`
* `field_map = [(src_key, child_field), ...]` (list of dicts, rare)
-> child stores those fields
"""
source_label: str
source_property: str
child_label: str
rel_type: str
field_map: list[tuple[str, str]] = field(default_factory=list)
def __post_init__(self) -> None:
if self.field_map:
child_fields = [dst for _, dst in self.field_map]
if "value" in child_fields:
raise ValueError(
f"NormalizedList {self.source_label}.{self.source_property}: "
"`value` is reserved for scalar mode; do not map a source key to it"
)
src_keys = [src for src, _ in self.field_map]
if len(set(src_keys)) != len(src_keys):
raise ValueError(
f"NormalizedList {self.source_label}.{self.source_property}: "
"duplicate source key in field_map"
)
if len(set(child_fields)) != len(child_fields):
raise ValueError(
f"NormalizedList {self.source_label}.{self.source_property}: "
"duplicate child field in field_map"
)
@dataclass(frozen=True)
class ProviderConfig:
"""Configuration for a cloud provider's Attack Paths integration."""
name: str
root_node_label: str # e.g., "AWSAccount"
uid_field: str # e.g., "arn"
# Label for resources connected to the account node, enabling indexed finding lookups
resource_label: str # e.g., "_AWSResource"
ingestion_function: Callable
# Maps a Postgres resource UID (e.g. full ARN) to the short-id form Cartography stores on some node types (e.g. `i-xxx` for EC2Instance)
short_uid_extractor: Callable[[str], str]
# List-typed properties to materialise as child nodes + edges at sync time.
# Mandatory (may be []). Without an entry here, a list-typed property falls
# back to comma-string flatten and emits a one-time warning.
normalized_lists: list[NormalizedList]
# AWS list-typed property catalog.
# One entry per Cartography node property whose runtime value is a list. The
# sync layer materialises each element as a `<child_label>` node and links it
# to the parent with a `<rel_type>` edge; see the `NormalizedList` docstring
# above for the naming conventions.
AWS_NORMALIZED_LISTS: list[NormalizedList] = [
# AWSPolicyStatement - the hot path driving the 53-query perf fix.
NormalizedList(
"AWSPolicyStatement", "action", "AWSPolicyStatementActionItem", "HAS_ACTION"
),
NormalizedList(
"AWSPolicyStatement",
"notaction",
"AWSPolicyStatementNotactionItem",
"HAS_NOTACTION",
),
NormalizedList(
"AWSPolicyStatement",
"resource",
"AWSPolicyStatementResourceItem",
"HAS_RESOURCE",
),
NormalizedList(
"AWSPolicyStatement",
"notresource",
"AWSPolicyStatementNotresourceItem",
"HAS_NOTRESOURCE",
),
# S3PolicyStatement - same shape as IAM policies; AWS allows list or string.
NormalizedList(
"S3PolicyStatement", "action", "S3PolicyStatementActionItem", "HAS_ACTION"
),
NormalizedList(
"S3PolicyStatement", "resource", "S3PolicyStatementResourceItem", "HAS_RESOURCE"
),
# IAM / Cognito / KMS / Secrets
NormalizedList(
"CognitoIdentityPool", "roles", "CognitoIdentityPoolRolesItem", "HAS_ROLES"
),
NormalizedList(
"KMSKey",
"encryption_algorithms",
"KMSKeyEncryptionAlgorithmsItem",
"HAS_ENCRYPTION_ALGORITHMS",
),
NormalizedList(
"KMSKey",
"signing_algorithms",
"KMSKeySigningAlgorithmsItem",
"HAS_SIGNING_ALGORITHMS",
),
NormalizedList(
"KMSKey",
"anonymous_actions",
"KMSKeyAnonymousActionsItem",
"HAS_ANONYMOUS_ACTIONS",
),
NormalizedList(
"KMSGrant", "operations", "KMSGrantOperationsItem", "HAS_OPERATIONS"
),
NormalizedList(
"SecretsManagerSecretVersion",
"version_stages",
"SecretsManagerSecretVersionVersionStagesItem",
"HAS_VERSION_STAGES",
),
NormalizedList(
"SecretsManagerSecretVersion",
"kms_key_ids",
"SecretsManagerSecretVersionKmsKeyIdsItem",
"HAS_KMS_KEY_IDS",
),
NormalizedList(
"SecretsManagerSecretVersion",
"tags",
"SecretsManagerSecretVersionTagsItem",
"HAS_TAGS",
field_map=[("Key", "key"), ("Value", "value_")],
# `value` is reserved for scalar mode; map `Value` to `value_` to keep dict shape.
),
# Lambda / Compute
NormalizedList(
"AWSLambda", "architectures", "AWSLambdaArchitecturesItem", "HAS_ARCHITECTURES"
),
NormalizedList(
"AWSLambda",
"anonymous_actions",
"AWSLambdaAnonymousActionsItem",
"HAS_ANONYMOUS_ACTIONS",
),
NormalizedList(
"CodeBuildProject",
"environment_variables",
"CodeBuildProjectEnvironmentVariablesItem",
"HAS_ENVIRONMENT_VARIABLES",
),
# ECS family
NormalizedList(
"ECSCluster",
"capacity_providers",
"ECSClusterCapacityProvidersItem",
"HAS_CAPACITY_PROVIDERS",
),
NormalizedList(
"ECSTaskDefinition",
"compatibilities",
"ECSTaskDefinitionCompatibilitiesItem",
"HAS_COMPATIBILITIES",
),
NormalizedList(
"ECSTaskDefinition",
"requires_compatibilities",
"ECSTaskDefinitionRequiresCompatibilitiesItem",
"HAS_REQUIRES_COMPATIBILITIES",
),
NormalizedList(
"ECSContainerDefinition",
"links",
"ECSContainerDefinitionLinksItem",
"HAS_LINKS",
),
NormalizedList(
"ECSContainerDefinition",
"entry_point",
"ECSContainerDefinitionEntryPointItem",
"HAS_ENTRY_POINT",
),
NormalizedList(
"ECSContainerDefinition",
"command",
"ECSContainerDefinitionCommandItem",
"HAS_COMMAND",
),
NormalizedList(
"ECSContainerDefinition",
"dns_servers",
"ECSContainerDefinitionDnsServersItem",
"HAS_DNS_SERVERS",
),
NormalizedList(
"ECSContainerDefinition",
"dns_search_domains",
"ECSContainerDefinitionDnsSearchDomainsItem",
"HAS_DNS_SEARCH_DOMAINS",
),
NormalizedList(
"ECSContainerDefinition",
"docker_security_options",
"ECSContainerDefinitionDockerSecurityOptionsItem",
"HAS_DOCKER_SECURITY_OPTIONS",
),
NormalizedList("ECSContainer", "gpu_ids", "ECSContainerGpuIdsItem", "HAS_GPU_IDS"),
# ECR
NormalizedList(
"ECRImage", "layer_diff_ids", "ECRImageLayerDiffIdsItem", "HAS_LAYER_DIFF_IDS"
),
NormalizedList(
"ECRImage",
"child_image_digests",
"ECRImageChildImageDigestsItem",
"HAS_CHILD_IMAGE_DIGESTS",
),
# EC2 / Networking
NormalizedList(
"EC2Instance",
"exposed_internet_type",
"EC2InstanceExposedInternetTypeItem",
"HAS_EXPOSED_INTERNET_TYPE",
),
NormalizedList(
"AutoScalingGroup",
"exposed_internet_type",
"AutoScalingGroupExposedInternetTypeItem",
"HAS_EXPOSED_INTERNET_TYPE",
),
NormalizedList(
"LaunchConfiguration",
"security_groups",
"LaunchConfigurationSecurityGroupsItem",
"HAS_SECURITY_GROUPS",
),
NormalizedList(
"LaunchTemplateVersion",
"security_group_ids",
"LaunchTemplateVersionSecurityGroupIdsItem",
"HAS_SECURITY_GROUP_IDS",
),
NormalizedList(
"LaunchTemplateVersion",
"security_groups",
"LaunchTemplateVersionSecurityGroupsItem",
"HAS_SECURITY_GROUPS",
),
NormalizedList(
"ELBListener", "policy_names", "ELBListenerPolicyNamesItem", "HAS_POLICY_NAMES"
),
# CloudFront / Route53 / CloudWatch / CloudTrail
NormalizedList(
"CloudFrontDistribution",
"aliases",
"CloudFrontDistributionAliasesItem",
"HAS_ALIASES",
),
NormalizedList(
"CloudFrontDistribution",
"geo_restriction_locations",
"CloudFrontDistributionGeoRestrictionLocationsItem",
"HAS_GEO_RESTRICTION_LOCATIONS",
),
NormalizedList(
"CloudWatchLogGroup",
"inherited_properties",
"CloudWatchLogGroupInheritedPropertiesItem",
"HAS_INHERITED_PROPERTIES",
),
# RDS / Storage
NormalizedList(
"RDSCluster",
"availability_zones",
"RDSClusterAvailabilityZonesItem",
"HAS_AVAILABILITY_ZONES",
),
NormalizedList(
"RDSEventSubscription",
"event_categories",
"RDSEventSubscriptionEventCategoriesItem",
"HAS_EVENT_CATEGORIES",
),
NormalizedList(
"RDSEventSubscription",
"source_ids",
"RDSEventSubscriptionSourceIdsItem",
"HAS_SOURCE_IDS",
),
NormalizedList(
"S3Bucket",
"anonymous_actions",
"S3BucketAnonymousActionsItem",
"HAS_ANONYMOUS_ACTIONS",
),
# Inspector / Config / SSM / ACM / APIGateway / Glue / SageMaker / Bedrock
NormalizedList(
"AWSInspectorFinding",
"referenceurls",
"AWSInspectorFindingReferenceurlsItem",
"HAS_REFERENCEURLS",
),
NormalizedList(
"AWSInspectorFinding",
"relatedvulnerabilities",
"AWSInspectorFindingRelatedvulnerabilitiesItem",
"HAS_RELATEDVULNERABILITIES",
),
NormalizedList(
"AWSInspectorFinding",
"vulnerablepackageids",
"AWSInspectorFindingVulnerablepackageidsItem",
"HAS_VULNERABLEPACKAGEIDS",
),
NormalizedList(
"AWSConfigurationRecorder",
"recording_group_resource_types",
"AWSConfigurationRecorderRecordingGroupResourceTypesItem",
"HAS_RECORDING_GROUP_RESOURCE_TYPES",
),
NormalizedList(
"AWSConfigRule",
"scope_compliance_resource_types",
"AWSConfigRuleScopeComplianceResourceTypesItem",
"HAS_SCOPE_COMPLIANCE_RESOURCE_TYPES",
),
NormalizedList(
"AWSConfigRule",
"source_details",
"AWSConfigRuleSourceDetailsItem",
"HAS_SOURCE_DETAILS",
),
NormalizedList(
"SSMInstancePatch", "cve_ids", "SSMInstancePatchCveIdsItem", "HAS_CVE_IDS"
),
NormalizedList(
"ACMCertificate", "in_use_by", "ACMCertificateInUseByItem", "HAS_IN_USE_BY"
),
NormalizedList(
"APIGatewayRestAPI",
"anonymous_actions",
"APIGatewayRestAPIAnonymousActionsItem",
"HAS_ANONYMOUS_ACTIONS",
),
NormalizedList(
"GlueJob", "connections", "GlueJobConnectionsItem", "HAS_CONNECTIONS"
),
NormalizedList(
"AWSBedrockFoundationModel",
"input_modalities",
"AWSBedrockFoundationModelInputModalitiesItem",
"HAS_INPUT_MODALITIES",
),
NormalizedList(
"AWSBedrockFoundationModel",
"output_modalities",
"AWSBedrockFoundationModelOutputModalitiesItem",
"HAS_OUTPUT_MODALITIES",
),
NormalizedList(
"AWSBedrockFoundationModel",
"customizations_supported",
"AWSBedrockFoundationModelCustomizationsSupportedItem",
"HAS_CUSTOMIZATIONS_SUPPORTED",
),
NormalizedList(
"AWSBedrockFoundationModel",
"inference_types_supported",
"AWSBedrockFoundationModelInferenceTypesSupportedItem",
"HAS_INFERENCE_TYPES_SUPPORTED",
),
]
AWS_CONFIG = ProviderConfig(
name="aws",
root_node_label="AWSAccount",
uid_field="arn",
resource_label="_AWSResource",
ingestion_function=aws.start_aws_ingestion,
short_uid_extractor=aws.extract_short_uid,
normalized_lists=AWS_NORMALIZED_LISTS,
)
PROVIDER_CONFIGS: dict[str, ProviderConfig] = {
"aws": AWS_CONFIG,
}
@@ -1,8 +1,6 @@
# Cypher query templates for Attack Paths operations # Cypher query templates for Attack Paths operations
from tasks.jobs.attack_paths.config import ( from tasks.jobs.attack_paths.config import (
INTERNET_NODE_LABEL, INTERNET_NODE_LABEL,
PROVIDER_ELEMENT_ID_PROPERTY,
PROVIDER_RESOURCE_LABEL,
PROWLER_FINDING_LABEL, PROWLER_FINDING_LABEL,
) )
@@ -21,7 +19,6 @@ def render_cypher_template(template: str, replacements: dict[str, str]) -> str:
# Findings queries (used by findings.py) # Findings queries (used by findings.py)
# ---------------------------------------
ADD_RESOURCE_LABEL_TEMPLATE = """ ADD_RESOURCE_LABEL_TEMPLATE = """
MATCH (account:__ROOT_LABEL__ {id: $provider_uid})-->(r) MATCH (account:__ROOT_LABEL__ {id: $provider_uid})-->(r)
@@ -88,7 +85,6 @@ INSERT_FINDING_TEMPLATE = f"""
""" """
# Internet queries (used by internet.py) # Internet queries (used by internet.py)
# ---------------------------------------
CREATE_INTERNET_NODE = f""" CREATE_INTERNET_NODE = f"""
MERGE (internet:{INTERNET_NODE_LABEL} {{id: 'Internet'}}) MERGE (internet:{INTERNET_NODE_LABEL} {{id: 'Internet'}})
@@ -118,8 +114,8 @@ CREATE_CAN_ACCESS_RELATIONSHIPS_TEMPLATE = f"""
RETURN COUNT(r) AS relationships_merged RETURN COUNT(r) AS relationships_merged
""" """
# Sync queries (used by sync.py) # Sync queries (used by sync.py to fetch from the cartography temp Neo4j DB)
# ------------------------------- # The write side of sync lives in each sink (`api/attack_paths/sink/`).
NODE_FETCH_QUERY = """ NODE_FETCH_QUERY = """
MATCH (n) MATCH (n)
@@ -143,17 +139,3 @@ RELATIONSHIPS_FETCH_QUERY = """
ORDER BY internal_id ORDER BY internal_id
LIMIT $batch_size LIMIT $batch_size
""" """
NODE_SYNC_TEMPLATE = f"""
UNWIND $rows AS row
MERGE (n:__NODE_LABELS__ {{{PROVIDER_ELEMENT_ID_PROPERTY}: row.provider_element_id}})
SET n += row.props
"""
RELATIONSHIP_SYNC_TEMPLATE = f"""
UNWIND $rows AS row
MATCH (s:{PROVIDER_RESOURCE_LABEL} {{{PROVIDER_ELEMENT_ID_PROPERTY}: row.start_element_id}})
MATCH (t:{PROVIDER_RESOURCE_LABEL} {{{PROVIDER_ELEMENT_ID_PROPERTY}: row.end_element_id}})
MERGE (s)-[r:__REL_TYPE__ {{{PROVIDER_ELEMENT_ID_PROPERTY}: row.provider_element_id}}]->(t)
SET r += row.props
"""
+85 -26
View File
@@ -39,8 +39,8 @@ Pipeline steps:
7. Sync the temp database into the tenant database: 7. Sync the temp database into the tenant database:
- Drop the old provider subgraph (matched by dynamic _Provider_{uuid} label). - Drop the old provider subgraph (matched by dynamic _Provider_{uuid} label).
graph_data_ready is set to False for all scans of this provider while graph_data_ready is set to False for scans of this provider in the
the swap happens so the API doesn't serve partial data. target sink while the swap happens so the API doesn't serve partial data.
- Copy nodes and relationships in batches. Every synced node gets a - Copy nodes and relationships in batches. Every synced node gets a
_ProviderResource label and dynamic _Tenant_{uuid} / _Provider_{uuid} _ProviderResource label and dynamic _Tenant_{uuid} / _Provider_{uuid}
isolation labels, plus a _provider_element_id property for MERGE keys. isolation labels, plus a _provider_element_id property for MERGE keys.
@@ -64,10 +64,17 @@ from api.models import StateChoices
from api.utils import initialize_prowler_provider from api.utils import initialize_prowler_provider
from cartography.config import Config as CartographyConfig from cartography.config import Config as CartographyConfig
from cartography.intel import analysis as cartography_analysis from cartography.intel import analysis as cartography_analysis
from cartography.intel import create_indexes as cartography_create_indexes
from cartography.intel import ontology as cartography_ontology from cartography.intel import ontology as cartography_ontology
from celery.utils.log import get_task_logger from celery.utils.log import get_task_logger
from tasks.jobs.attack_paths import db_utils, findings, indexes, internet, sync, utils from django.conf import settings
from tasks.jobs.attack_paths import (
db_utils,
findings,
indexes,
internet,
sync,
utils,
)
from tasks.jobs.attack_paths.config import get_cartography_ingestion_function from tasks.jobs.attack_paths.config import get_cartography_ingestion_function
# Without this Celery goes crazy with Cartography logging # Without this Celery goes crazy with Cartography logging
@@ -96,7 +103,7 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
attack_paths_scan = db_utils.retrieve_attack_paths_scan(tenant_id, scan_id) attack_paths_scan = db_utils.retrieve_attack_paths_scan(tenant_id, scan_id)
# Idempotency guard: cleanup may have flipped this row to a terminal state # Idempotency guard: cleanup may have flipped this row to a terminal state
# while the message was still in flight. Bail out before touching state. # while the message was still in flight. Bail out before touching state
if attack_paths_scan and attack_paths_scan.state in ( if attack_paths_scan and attack_paths_scan.state in (
StateChoices.FAILED, StateChoices.FAILED,
StateChoices.COMPLETED, StateChoices.COMPLETED,
@@ -125,7 +132,7 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
else: else:
if not attack_paths_scan: if not attack_paths_scan:
# Safety net for in-flight messages or direct task invocations; dispatcher normally pre-creates the row. # Safety net for in-flight messages or direct task invocations; dispatcher normally pre-creates the row
logger.warning( logger.warning(
f"No Attack Paths Scan found for scan {scan_id} and tenant {tenant_id}, let's create it then" f"No Attack Paths Scan found for scan {scan_id} and tenant {tenant_id}, let's create it then"
) )
@@ -143,10 +150,18 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
tenant_database_name = graph_database.get_database_name( tenant_database_name = graph_database.get_database_name(
prowler_api_provider.tenant_id prowler_api_provider.tenant_id
) )
target_sink_backend = settings.ATTACK_PATHS_SINK_DATABASE
target_description = (
f"tenant Neo4j database {tenant_database_name}"
if target_sink_backend == "neo4j"
else f"{target_sink_backend} sink"
)
# While creating the Cartography configuration, attributes `neo4j_user` and `neo4j_password` are not really needed in this config object # While creating the Cartography configuration, attributes `neo4j_user` and `neo4j_password` are not really needed in this config object
tmp_cartography_config = CartographyConfig( tmp_cartography_config = CartographyConfig(
neo4j_uri=graph_database.get_uri(), # The temp ingest database is always Neo4j, so use the ingest URI here
# rather than the sink URI (which points at Neptune when configured).
neo4j_uri=graph_database.get_ingest_uri(),
neo4j_database=tmp_database_name, neo4j_database=tmp_database_name,
update_tag=int(time.time()), update_tag=int(time.time()),
) )
@@ -169,6 +184,7 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
logger.info( logger.info(
f"Starting Attack Paths scan ({attack_paths_scan.id}) for " f"Starting Attack Paths scan ({attack_paths_scan.id}) for "
f"{prowler_api_provider.provider.upper()} provider {prowler_api_provider.id} " f"{prowler_api_provider.provider.upper()} provider {prowler_api_provider.id} "
f"(staging=Neo4j database {tmp_database_name}, target={target_description})"
) )
subgraph_dropped = False subgraph_dropped = False
@@ -177,7 +193,8 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
try: try:
logger.info( logger.info(
f"Creating Neo4j database {tmp_cartography_config.neo4j_database} for tenant {prowler_api_provider.tenant_id}" f"Creating staging Neo4j database {tmp_cartography_config.neo4j_database} "
f"for tenant {prowler_api_provider.tenant_id}"
) )
graph_database.create_database(tmp_cartography_config.neo4j_database) graph_database.create_database(tmp_cartography_config.neo4j_database)
@@ -191,7 +208,9 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
tmp_cartography_config.neo4j_database tmp_cartography_config.neo4j_database
) as tmp_neo4j_session: ) as tmp_neo4j_session:
# Indexes creation # Indexes creation
cartography_create_indexes.run(tmp_neo4j_session, tmp_cartography_config) indexes.create_cartography_indexes(
tmp_neo4j_session, tmp_cartography_config
)
indexes.create_findings_indexes(tmp_neo4j_session) indexes.create_findings_indexes(tmp_neo4j_session)
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 2) db_utils.update_attack_paths_scan_progress(attack_paths_scan, 2)
@@ -223,7 +242,7 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
cartography_analysis.run(tmp_neo4j_session, tmp_cartography_config) cartography_analysis.run(tmp_neo4j_session, tmp_cartography_config)
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 95) db_utils.update_attack_paths_scan_progress(attack_paths_scan, 95)
# Creating Internet node and CAN_ACCESS relationships # Creating Internet node and `CAN_ACCESS` relationships
logger.info( logger.info(
f"Creating Internet graph for AWS account {prowler_api_provider.uid}" f"Creating Internet graph for AWS account {prowler_api_provider.uid}"
) )
@@ -247,23 +266,41 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 97) db_utils.update_attack_paths_scan_progress(attack_paths_scan, 97)
logger.info( logger.info(
f"Clearing Neo4j cache for database {tmp_cartography_config.neo4j_database}" f"Clearing Neo4j cache for staging database {tmp_cartography_config.neo4j_database}"
) )
graph_database.clear_cache(tmp_cartography_config.neo4j_database) graph_database.clear_cache(tmp_cartography_config.neo4j_database)
t0 = time.perf_counter()
logger.info( logger.info(
f"Ensuring tenant database {tenant_database_name}, and its indexes, exists for tenant {prowler_api_provider.tenant_id}" f"Preparing target {target_description} for tenant {prowler_api_provider.tenant_id}"
) )
graph_database.create_database(tenant_database_name) graph_database.create_database(tenant_database_name)
with graph_database.get_session(tenant_database_name) as tenant_neo4j_session: # Sink-side index creation: Neptune auto-manages indexes and rejects
cartography_create_indexes.run( # `CREATE INDEX`, so only run it when the sink is Neo4j
# The temp ingest DB is always Neo4j and is always indexed above
if target_sink_backend != "neptune":
logger.info(f"Ensuring indexes exist for {target_description}")
with graph_database.get_session(
tenant_database_name
) as tenant_neo4j_session:
indexes.create_cartography_indexes(
tenant_neo4j_session, tenant_cartography_config tenant_neo4j_session, tenant_cartography_config
) )
indexes.create_findings_indexes(tenant_neo4j_session) indexes.create_findings_indexes(tenant_neo4j_session)
indexes.create_sync_indexes(tenant_neo4j_session) indexes.create_sync_indexes(tenant_neo4j_session)
else:
logger.info("Skipping tenant database indexes for neptune sink")
logger.info(
f"Prepared target {target_description} in {time.perf_counter() - t0:.3f}s"
)
logger.info(f"Deleting existing provider graph in {tenant_database_name}") logger.info(
db_utils.set_provider_graph_data_ready(attack_paths_scan, False) f"Deleting existing provider graph from {target_description} "
f"(tenant={prowler_api_provider.tenant_id}, provider={prowler_api_provider.id})"
)
db_utils.set_provider_graph_data_ready(
attack_paths_scan, False, target_sink_backend
)
provider_gated = True provider_gated = True
t0 = time.perf_counter() t0 = time.perf_counter()
@@ -272,14 +309,17 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
provider_id=str(prowler_api_provider.id), provider_id=str(prowler_api_provider.id),
) )
logger.info( logger.info(
f"Deleted existing provider graph in {time.perf_counter() - t0:.3f}s " f"Deleted existing provider graph from {target_description} "
f"(deleted_nodes={deleted_nodes})" f"in {time.perf_counter() - t0:.3f}s (deleted_nodes={deleted_nodes})"
) )
subgraph_dropped = True subgraph_dropped = True
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 98) db_utils.update_attack_paths_scan_progress(attack_paths_scan, 98)
logger.info( logger.info(
f"Syncing graph from {tmp_database_name} into {tenant_database_name}" f"Syncing staging graph {tmp_database_name} into {target_description} "
f"for provider {prowler_api_provider.id} "
f"(tenant {prowler_api_provider.tenant_id}, "
f"type {prowler_api_provider.provider})"
) )
t0 = time.perf_counter() t0 = time.perf_counter()
sync_result = sync.sync_graph( sync_result = sync.sync_graph(
@@ -287,16 +327,33 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
target_database=tenant_database_name, target_database=tenant_database_name,
tenant_id=str(prowler_api_provider.tenant_id), tenant_id=str(prowler_api_provider.tenant_id),
provider_id=str(prowler_api_provider.id), provider_id=str(prowler_api_provider.id),
provider_type=prowler_api_provider.provider,
) )
elapsed = time.perf_counter() - t0
total_nodes = sync_result["nodes"] + sync_result["child_nodes"]
elements = total_nodes + sync_result["relationships"]
rate = elements / elapsed if elapsed else 0
logger.info( logger.info(
f"Synced graph in {time.perf_counter() - t0:.3f}s " f"Synced staging graph into {target_description} in {elapsed:.3f}s - "
f"(nodes={sync_result['nodes']}, relationships={sync_result['relationships']})" f"nodes={total_nodes} (source={sync_result['nodes']}, "
f"items={sync_result['child_nodes']}), "
f"relationships={sync_result['relationships']} "
f"(structural={sync_result['structural_relationships']}, "
f"items={sync_result['item_relationships']}), "
f"~{rate:.0f} elem/s"
) )
sync_completed = True sync_completed = True
# Flip metadata only now: the new schema is live in the target sink, so
# reads can switch to the current catalog/backend. The target-sink gate
# is already closed, so the switch is atomic from the API's view.
db_utils.set_scan_migrated(attack_paths_scan, True, target_sink_backend)
db_utils.set_graph_data_ready(attack_paths_scan, True) db_utils.set_graph_data_ready(attack_paths_scan, True)
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 99) db_utils.update_attack_paths_scan_progress(attack_paths_scan, 99)
logger.info(f"Clearing Neo4j cache for database {tenant_database_name}") if target_sink_backend == "neptune":
logger.info("Skipping cache clear for neptune sink")
else:
logger.info(f"Clearing Neo4j cache for target {target_description}")
graph_database.clear_cache(tenant_database_name) graph_database.clear_cache(tenant_database_name)
logger.info(f"Dropping temporary Neo4j database {tmp_database_name}") logger.info(f"Dropping temporary Neo4j database {tmp_database_name}")
@@ -316,14 +373,16 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
logger.exception(exception_message) logger.exception(exception_message)
ingestion_exceptions["global_error"] = exception_message ingestion_exceptions["global_error"] = exception_message
# Recover graph_data_ready based on how far the swap got. # Recover `graph_data_ready` based on how far the swap got
# Partial drop (mid-batch failure) may leave `subgraph_dropped=False` # Partial drop (mid-batch failure) may leave `subgraph_dropped=False` with data partially deleted,
# with data partially deleted, so we prefer that over permanently blocked queries. # so we prefer that over permanently blocked queries
try: try:
if sync_completed: if sync_completed:
db_utils.set_graph_data_ready(attack_paths_scan, True) db_utils.set_graph_data_ready(attack_paths_scan, True)
elif provider_gated and not subgraph_dropped: elif provider_gated and not subgraph_dropped:
db_utils.set_provider_graph_data_ready(attack_paths_scan, True) db_utils.set_provider_graph_data_ready(
attack_paths_scan, True, target_sink_backend
)
except Exception: except Exception:
logger.error( logger.error(
+355 -43
View File
@@ -1,40 +1,57 @@
""" """
Graph sync operations for Attack Paths. Graph sync operations for Attack Paths.
This module handles syncing graph data from temporary scan databases Reads nodes and relationships out of the cartography temp database (always
to the tenant database, adding provider isolation labels and properties. Neo4j) and hands them to the configured sink (Neo4j or Neptune) in batches.
Backend-specific Cypher (MERGE shape, ID strategy, indexes) lives in each
sink; this module owns the source read loop, per-batch grouping, and the
list-property materialisation policy (see `NormalizedList`).
Each list-typed node property that appears in the provider's
`normalized_lists` catalog becomes a set of child item nodes connected to
the parent by a typed edge. A list-typed property that is not in the
catalog is serialised to a comma-delimited string and emits a one-time
warning per (label, property), surfacing Cartography fields that should be
added to the catalog.
""" """
import json
import time import time
from collections import defaultdict from collections import defaultdict
from typing import Any from typing import Any
import neo4j import neo4j
from api.attack_paths import database as graph_database from api.attack_paths import database as graph_database
from api.attack_paths import sink as sink_module
from celery.utils.log import get_task_logger from celery.utils.log import get_task_logger
from tasks.jobs.attack_paths.config import ( from tasks.jobs.attack_paths.config import (
PROVIDER_CONFIGS,
PROVIDER_ISOLATION_PROPERTIES, PROVIDER_ISOLATION_PROPERTIES,
PROVIDER_RESOURCE_LABEL, PROVIDER_RESOURCE_LABEL,
SYNC_BATCH_SIZE, SYNC_BATCH_SIZE,
NormalizedList,
get_provider_label, get_provider_label,
get_tenant_label, get_tenant_label,
) )
from tasks.jobs.attack_paths.queries import ( from tasks.jobs.attack_paths.queries import (
NODE_FETCH_QUERY, NODE_FETCH_QUERY,
NODE_SYNC_TEMPLATE,
RELATIONSHIP_SYNC_TEMPLATE,
RELATIONSHIPS_FETCH_QUERY, RELATIONSHIPS_FETCH_QUERY,
render_cypher_template,
) )
logger = get_task_logger(__name__) logger = get_task_logger(__name__)
# (label, property) tuples for which we've already emitted the
# "unnormalised list" warning. Module-level so the warning fires once per
# process, not once per node.
_WARNED_UNNORMALIZED: set[tuple[str, str]] = set()
def sync_graph( def sync_graph(
source_database: str, source_database: str,
target_database: str, target_database: str,
tenant_id: str, tenant_id: str,
provider_id: str, provider_id: str,
provider_type: str,
) -> dict[str, int]: ) -> dict[str, int]:
""" """
Sync all nodes and relationships from source to target database. Sync all nodes and relationships from source to target database.
@@ -44,25 +61,38 @@ def sync_graph(
`target_database`: The tenant database `target_database`: The tenant database
`tenant_id`: The tenant ID for isolation `tenant_id`: The tenant ID for isolation
`provider_id`: The provider ID for isolation `provider_id`: The provider ID for isolation
`provider_type`: Provider type key (e.g. "aws"), used to resolve the
`NormalizedList` catalog from `PROVIDER_CONFIGS`.
Returns: Returns:
Dict with counts of synced nodes and relationships Dict with counts of synced nodes, child item nodes, and relationships.
""" """
nodes_synced = sync_nodes( sink = sink_module.get_backend()
sink.ensure_sync_indexes(target_database)
normalized_lists = _resolve_normalized_lists(provider_type)
node_result = sync_nodes(
source_database, source_database,
target_database, target_database,
tenant_id, tenant_id,
provider_id, provider_id,
sink,
normalized_lists,
) )
relationships_synced = sync_relationships( relationships_synced = sync_relationships(
source_database, source_database,
target_database, target_database,
provider_id, provider_id,
sink,
) )
return { return {
"nodes": nodes_synced, "nodes": node_result["parents"],
"relationships": relationships_synced, "child_nodes": node_result["children"],
"relationships": relationships_synced + node_result["parent_child_rels"],
"structural_relationships": relationships_synced,
"item_relationships": node_result["parent_child_rels"],
} }
@@ -71,22 +101,35 @@ def sync_nodes(
target_database: str, target_database: str,
tenant_id: str, tenant_id: str,
provider_id: str, provider_id: str,
) -> int: sink: Any,
normalized_lists: list[NormalizedList],
) -> dict[str, int]:
""" """
Sync nodes from source to target database. Sync nodes from source to target database, exploding catalogued list
properties into child nodes + parent->child edges.
Adds `_ProviderResource` label and dynamic `_Tenant_{id}` and `_Provider_{id}` Adds `_ProviderResource` label and dynamic `_Tenant_{id}` and `_Provider_{id}`
isolation labels to all nodes. isolation labels to all nodes (parents and children alike).
Source and target sessions are opened sequentially per batch to avoid Source and target sessions are opened sequentially per batch to avoid
holding two Bolt connections simultaneously for the entire sync duration. holding two Bolt connections simultaneously for the entire sync duration.
""" """
t0 = time.perf_counter() t0 = time.perf_counter()
last_id = -1 last_id = -1
total_synced = 0 parents_synced = 0
children_synced = 0
parent_child_rels = 0
catalog = _build_catalog_index(normalized_lists)
extra_labels = _build_extra_labels(tenant_id, provider_id)
while True: while True:
grouped: dict[tuple[str, ...], list[dict[str, Any]]] = defaultdict(list) tb = time.perf_counter()
prev_children = children_synced
prev_rels = parent_child_rels
parent_groups: dict[tuple[str, ...], list[dict[str, Any]]] = defaultdict(list)
child_groups: dict[str, list[dict[str, Any]]] = defaultdict(list)
rel_groups: dict[str, list[dict[str, Any]]] = defaultdict(list)
batch_count = 0 batch_count = 0
with graph_database.get_session(source_database) as source_session: with graph_database.get_session(source_database) as source_session:
@@ -97,43 +140,65 @@ def sync_nodes(
for record in result: for record in result:
batch_count += 1 batch_count += 1
last_id = record["internal_id"] last_id = record["internal_id"]
key, value = _node_to_sync_dict(record, provider_id) key, parent_dict, children, rels = _node_to_sync_dict(
grouped[key].append(value) record, provider_id, catalog
)
parent_groups[key].append(parent_dict)
for child in children:
child_groups[child["_child_label"]].append(child["row"])
for rel in rels:
rel_groups[rel["rel_type"]].append(rel["row"])
if batch_count == 0: if batch_count == 0:
break break
with graph_database.get_session(target_database) as target_session: for labels, batch in parent_groups.items():
for labels, batch in grouped.items(): sink.write_nodes(
label_set = set(labels) target_database, _render_labels(labels, extra_labels), batch
label_set.add(PROVIDER_RESOURCE_LABEL)
label_set.add(get_tenant_label(tenant_id))
label_set.add(get_provider_label(provider_id))
node_labels = ":".join(f"`{label}`" for label in sorted(label_set))
query = render_cypher_template(
NODE_SYNC_TEMPLATE, {"__NODE_LABELS__": node_labels}
) )
target_session.run(query, {"rows": batch})
total_synced += batch_count for child_label, batch in child_groups.items():
sink.write_nodes(
target_database,
_render_labels((child_label,), extra_labels),
batch,
)
children_synced += len(batch)
for rel_type, batch in rel_groups.items():
sink.write_relationships(target_database, rel_type, provider_id, batch)
parent_child_rels += len(batch)
parents_synced += batch_count
batch_dt = time.perf_counter() - tb
batch_elements = (
batch_count
+ (children_synced - prev_children)
+ (parent_child_rels - prev_rels)
)
rate = batch_elements / batch_dt if batch_dt else 0
logger.info( logger.info(
f"Synced {total_synced} nodes from {source_database} to {target_database} in {time.perf_counter() - t0:.3f}s" f"[sync nodes] {parents_synced} source (+{children_synced} items, "
f"+{parent_child_rels} item rels) · batch {batch_dt:.1f}s · "
f"elapsed {time.perf_counter() - t0:.1f}s · ~{rate:.0f} elem/s"
) )
return total_synced return {
"parents": parents_synced,
"children": children_synced,
"parent_child_rels": parent_child_rels,
}
def sync_relationships( def sync_relationships(
source_database: str, source_database: str,
target_database: str, target_database: str,
provider_id: str, provider_id: str,
sink: Any,
) -> int: ) -> int:
""" """
Sync relationships from source to target database. Sync relationships from source to target database.
Matches source and target nodes by `_provider_element_id` in the tenant database.
Source and target sessions are opened sequentially per batch to avoid Source and target sessions are opened sequentially per batch to avoid
holding two Bolt connections simultaneously for the entire sync duration. holding two Bolt connections simultaneously for the entire sync duration.
""" """
@@ -142,6 +207,7 @@ def sync_relationships(
total_synced = 0 total_synced = 0
while True: while True:
tb = time.perf_counter()
grouped: dict[str, list[dict[str, Any]]] = defaultdict(list) grouped: dict[str, list[dict[str, Any]]] = defaultdict(list)
batch_count = 0 batch_count = 0
@@ -159,32 +225,197 @@ def sync_relationships(
if batch_count == 0: if batch_count == 0:
break break
with graph_database.get_session(target_database) as target_session:
for rel_type, batch in grouped.items(): for rel_type, batch in grouped.items():
query = render_cypher_template( sink.write_relationships(target_database, rel_type, provider_id, batch)
RELATIONSHIP_SYNC_TEMPLATE, {"__REL_TYPE__": rel_type}
)
target_session.run(query, {"rows": batch})
total_synced += batch_count total_synced += batch_count
batch_dt = time.perf_counter() - tb
rate = batch_count / batch_dt if batch_dt else 0
logger.info( logger.info(
f"Synced {total_synced} relationships from {source_database} to {target_database} in {time.perf_counter() - t0:.3f}s" f"[sync rels] {total_synced} structural · batch {batch_dt:.1f}s · "
f"elapsed {time.perf_counter() - t0:.1f}s · ~{rate:.0f}/s"
) )
return total_synced return total_synced
def _node_to_sync_dict( def _node_to_sync_dict(
record: neo4j.Record, provider_id: str record: neo4j.Record,
) -> tuple[tuple[str, ...], dict[str, Any]]: provider_id: str,
"""Transform a source node record into a (grouping_key, sync_dict) pair.""" catalog: dict[tuple[str, str], NormalizedList],
) -> tuple[
tuple[str, ...],
dict[str, Any],
list[dict[str, Any]],
list[dict[str, Any]],
]:
"""Transform a source node record into a (grouping_key, sync_dict, children, rels) tuple.
Catalogued list properties are popped from `props` and emitted as child
nodes + parent->child relationships.
"""
props = dict(record["props"] or {}) props = dict(record["props"] or {})
_strip_internal_properties(props) _strip_internal_properties(props)
labels = tuple(sorted(set(record["labels"] or []))) labels = tuple(sorted(set(record["labels"] or [])))
return labels, { parent_element_id = f"{provider_id}:{record['element_id']}"
"provider_element_id": f"{provider_id}:{record['element_id']}",
children, rels = _explode_catalogued_lists(
labels, props, catalog, provider_id, parent_element_id
)
_normalize_sink_properties(props, labels)
parent = {
"provider_element_id": parent_element_id,
"props": props, "props": props,
} }
return labels, parent, children, rels
def _explode_catalogued_lists(
labels: tuple[str, ...],
props: dict[str, Any],
catalog: dict[tuple[str, str], NormalizedList],
provider_id: str,
parent_element_id: str,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
"""Pop catalogued list properties from `props` and produce child + rel emits.
A node may carry multiple labels (e.g. `AWSPolicyStatement` plus
`_AWSResource`); we check each label for catalog matches independently.
Returns:
- children: list of {"_child_label": str, "row": <node row>} dicts.
- rels: list of {"rel_type": str, "row": <rel row>} dicts.
"""
children: list[dict[str, Any]] = []
rels: list[dict[str, Any]] = []
for label in labels:
for key in list(props.keys()):
spec = catalog.get((label, key))
if spec is None:
continue
value = props.pop(key)
if value is None:
continue
if not isinstance(value, list):
# Catalogued but not actually a list this scan - fall back to
# the generic normaliser so we don't lose the value.
props[key] = value
continue
for item in value:
child_value_key, child_props = _build_child_props(spec, item)
if child_value_key is None:
continue
child_element_id = _build_child_id(
provider_id, spec.child_label, child_value_key
)
children.append(
{
"_child_label": spec.child_label,
"row": {
"provider_element_id": child_element_id,
"props": child_props,
},
}
)
rels.append(
{
"rel_type": spec.rel_type,
"row": {
"start_element_id": parent_element_id,
"end_element_id": child_element_id,
"provider_element_id": (
f"{parent_element_id}::{spec.rel_type}::"
f"{child_element_id}"
),
"props": {},
},
}
)
return children, rels
def _build_child_props(
spec: NormalizedList, item: Any
) -> tuple[str | None, dict[str, Any]]:
"""Translate one list element into a child node's prop dict.
Returns (dedup_key, props). The dedup_key is what makes two child nodes
equal within (tenant, provider) - used to build `_provider_element_id`.
For scalar mode, the dedup key is the value itself. For dict mode it is
a stable concatenation of the mapped fields in `field_map` order.
"""
if not spec.field_map:
if isinstance(item, (dict, list)):
# Defensive: caller marked this list as scalar but elements are
# structured. Convert to a stable string so the value survives.
value_str = json.dumps(item, sort_keys=True, default=str)
else:
value_str = str(item)
return value_str, {"value": value_str}
if not isinstance(item, dict):
# Catalogued as dict-shape but got a scalar. Skip - caller will see
# the value go missing and can fix the field_map.
return None, {}
props: dict[str, Any] = {}
dedup_parts: list[str] = []
for src_key, child_field in spec.field_map:
raw = item.get(src_key)
value_str = _to_sink_property_value(raw) if raw is not None else ""
props[child_field] = value_str
dedup_parts.append(f"{child_field}={value_str}")
return "::".join(dedup_parts), props
def _build_child_id(provider_id: str, child_label: str, value_key: str) -> str:
"""Deterministic `_provider_element_id` for a list-item child node.
Dedupes within (tenant, provider): multiple parents referencing the same
value share one child node via the existing MERGE-on-_provider_element_id
index in both sinks.
"""
return f"{provider_id}::{child_label}::{value_key}"
def _build_catalog_index(
normalized_lists: list[NormalizedList],
) -> dict[tuple[str, str], NormalizedList]:
"""Index the catalog by (source_label, source_property) for O(1) lookup."""
return {
(spec.source_label, spec.source_property): spec for spec in normalized_lists
}
def _build_extra_labels(tenant_id: str, provider_id: str) -> tuple[str, ...]:
return (
PROVIDER_RESOURCE_LABEL,
get_tenant_label(tenant_id),
get_provider_label(provider_id),
)
def _render_labels(base_labels: tuple[str, ...], extra_labels: tuple[str, ...]) -> str:
"""Render the Cypher label string for a node-write batch."""
label_set = set(base_labels) | set(extra_labels)
return ":".join(f"`{label}`" for label in sorted(label_set))
def _resolve_normalized_lists(provider_type: str) -> list[NormalizedList]:
config = PROVIDER_CONFIGS.get(provider_type)
if config is None:
# Unknown provider: empty catalog. Any list-typed property will be
# serialised to a comma-delimited string with one warning per
# (label, property).
logger.warning(
"Provider type %s not in PROVIDER_CONFIGS; no normalized_lists active",
provider_type,
)
return []
return config.normalized_lists
def _rel_to_sync_dict( def _rel_to_sync_dict(
@@ -193,7 +424,11 @@ def _rel_to_sync_dict(
"""Transform a source relationship record into a (grouping_key, sync_dict) pair.""" """Transform a source relationship record into a (grouping_key, sync_dict) pair."""
props = dict(record["props"] or {}) props = dict(record["props"] or {})
_strip_internal_properties(props) _strip_internal_properties(props)
# Relationship properties go through the same primitive coercion as
# nodes; catalog-driven materialisation applies to node properties only.
_normalize_sink_properties(props, labels=None)
rel_type = record["rel_type"] rel_type = record["rel_type"]
return rel_type, { return rel_type, {
"start_element_id": f"{provider_id}:{record['start_element_id']}", "start_element_id": f"{provider_id}:{record['start_element_id']}",
"end_element_id": f"{provider_id}:{record['end_element_id']}", "end_element_id": f"{provider_id}:{record['end_element_id']}",
@@ -206,3 +441,80 @@ def _strip_internal_properties(props: dict[str, Any]) -> None:
"""Remove provider isolation properties before the += spread in sync templates.""" """Remove provider isolation properties before the += spread in sync templates."""
for key in PROVIDER_ISOLATION_PROPERTIES: for key in PROVIDER_ISOLATION_PROPERTIES:
props.pop(key, None) props.pop(key, None)
def _normalize_sink_properties(
props: dict[str, Any], labels: tuple[str, ...] | None
) -> None:
"""Normalize property values to primitive Cypher literals for either sink.
Attack-paths node and relationship properties are written as primitive
scalars regardless of the active sink (Neo4j or Neptune). The convention
is driven by Neptune's openCypher type restrictions, which reject list,
map, temporal and spatial property values, but it is applied uniformly
so that custom and predefined queries are portable across sinks without
runtime rewriting.
Concretely:
- Temporal values (neo4j.time.{DateTime,Date,Time,Duration}) become
their ISO-8601 string representation.
- Spatial values (neo4j.spatial.Point and subclasses) become their
WKT-style string representation.
- Maps / dicts become a JSON-encoded string, read back with `CONTAINS`
substring checks inside queries.
- Lists become a comma-delimited string. Catalogued list properties
are materialised as child item nodes upstream in
`_explode_catalogued_lists` and never reach this point; any list
seen here is uncatalogued, so we log a one-time warning per
(label, property) to surface Cartography fields that should be
added to the catalog.
`labels` is only used for the warning message; pass `None` for
relationship props (no label context).
"""
for key, value in list(props.items()):
if isinstance(value, list) and labels is not None:
_warn_unnormalized_list(labels, key)
props[key] = _to_sink_property_value(value)
def _warn_unnormalized_list(labels: tuple[str, ...], key: str) -> None:
"""Warn once per (label, property), on the real label(s) only.
Every synced node also carries internal isolation labels (`_AWSResource`,
`_ProviderResource`, `_Tenant_*`, `_Provider_*`); warning on those just
doubles the noise, so skip them and point at the actionable Cartography
label. Falls back to all labels if only internal ones are present.
"""
real_labels = [label for label in labels if not label.startswith("_")]
for label in real_labels or labels:
token = (label, key)
if token in _WARNED_UNNORMALIZED:
continue
_WARNED_UNNORMALIZED.add(token)
logger.warning(
"Unnormalized list property %s.%s reached sink as comma-string; "
"add a NormalizedList entry to the provider catalog to explode it",
label,
key,
)
def _to_sink_property_value(value: Any) -> Any:
if hasattr(value, "iso_format") and callable(value.iso_format):
return value.iso_format()
if type(value).__module__.startswith("neo4j.spatial"):
return str(value)
if isinstance(value, dict):
# openCypher `SET` rejects map property values: encode as JSON so the structured payload
# survives the round-trip and is queryable with `CONTAINS` substring checks
return json.dumps(value, sort_keys=True, default=str)
if isinstance(value, list):
# openCypher `SET` rejects list/array property values: encode as a
# delimited string read back with split() inside queries
return ",".join(str(_to_sink_property_value(v)) for v in value)
return value
+13
View File
@@ -1,4 +1,5 @@
from api.attack_paths import database as graph_database from api.attack_paths import database as graph_database
from api.attack_paths import sink as sink_module
from api.db_router import MainRouter from api.db_router import MainRouter
from api.db_utils import batch_delete, rls_transaction from api.db_utils import batch_delete, rls_transaction
from api.models import ( from api.models import (
@@ -76,6 +77,12 @@ def delete_provider(tenant_id: str, pk: str):
"id", flat=True "id", flat=True
) )
) )
attack_paths_sink_backends = list(
AttackPathsScan.all_objects.filter(provider=instance)
.values_list("sink_backend", flat=True)
.distinct()
.order_by("sink_backend")
)
deletion_steps = [ deletion_steps = [
("Scan Summaries", ScanSummary.all_objects.filter(scan__provider=instance)), ("Scan Summaries", ScanSummary.all_objects.filter(scan__provider=instance)),
@@ -97,6 +104,12 @@ def delete_provider(tenant_id: str, pk: str):
# Delete the Attack Paths' graph data related to the provider from the tenant database # Delete the Attack Paths' graph data related to the provider from the tenant database
tenant_database_name = graph_database.get_database_name(tenant_id) tenant_database_name = graph_database.get_database_name(tenant_id)
try: try:
if attack_paths_sink_backends:
for sink_backend in attack_paths_sink_backends:
sink_module.get_backend_for_name(sink_backend).drop_subgraph(
tenant_database_name, str(pk)
)
else:
graph_database.drop_subgraph(tenant_database_name, str(pk)) graph_database.drop_subgraph(tenant_database_name, str(pk))
except graph_database.GraphDatabaseQueryException as gdb_error: except graph_database.GraphDatabaseQueryException as gdb_error:
@@ -23,6 +23,14 @@ from tasks.jobs.attack_paths import internet as internet_module
from tasks.jobs.attack_paths import sync as sync_module from tasks.jobs.attack_paths import sync as sync_module
from tasks.jobs.attack_paths.scan import run as attack_paths_run from tasks.jobs.attack_paths.scan import run as attack_paths_run
SYNC_RESULT_EMPTY = {
"nodes": 0,
"child_nodes": 0,
"relationships": 0,
"structural_relationships": 0,
"item_relationships": 0,
}
@pytest.mark.django_db @pytest.mark.django_db
class TestAttackPathsRun: class TestAttackPathsRun:
@@ -32,6 +40,7 @@ class TestAttackPathsRun:
"tasks.jobs.attack_paths.scan.utils.call_within_event_loop", "tasks.jobs.attack_paths.scan.utils.call_within_event_loop",
side_effect=lambda fn, *a, **kw: fn(*a, **kw), side_effect=lambda fn, *a, **kw: fn(*a, **kw),
) )
@patch("tasks.jobs.attack_paths.scan.db_utils.set_scan_migrated")
@patch("tasks.jobs.attack_paths.scan.db_utils.set_graph_data_ready") @patch("tasks.jobs.attack_paths.scan.db_utils.set_graph_data_ready")
@patch("tasks.jobs.attack_paths.scan.db_utils.set_provider_graph_data_ready") @patch("tasks.jobs.attack_paths.scan.db_utils.set_provider_graph_data_ready")
@patch("tasks.jobs.attack_paths.scan.db_utils.finish_attack_paths_scan") @patch("tasks.jobs.attack_paths.scan.db_utils.finish_attack_paths_scan")
@@ -39,7 +48,7 @@ class TestAttackPathsRun:
@patch("tasks.jobs.attack_paths.scan.db_utils.starting_attack_paths_scan") @patch("tasks.jobs.attack_paths.scan.db_utils.starting_attack_paths_scan")
@patch( @patch(
"tasks.jobs.attack_paths.scan.sync.sync_graph", "tasks.jobs.attack_paths.scan.sync.sync_graph",
return_value={"nodes": 0, "relationships": 0}, return_value=SYNC_RESULT_EMPTY,
) )
@patch("tasks.jobs.attack_paths.scan.graph_database.drop_subgraph", return_value=0) @patch("tasks.jobs.attack_paths.scan.graph_database.drop_subgraph", return_value=0)
@patch("tasks.jobs.attack_paths.scan.indexes.create_sync_indexes") @patch("tasks.jobs.attack_paths.scan.indexes.create_sync_indexes")
@@ -48,11 +57,11 @@ class TestAttackPathsRun:
@patch("tasks.jobs.attack_paths.scan.indexes.create_findings_indexes") @patch("tasks.jobs.attack_paths.scan.indexes.create_findings_indexes")
@patch("tasks.jobs.attack_paths.scan.cartography_ontology.run") @patch("tasks.jobs.attack_paths.scan.cartography_ontology.run")
@patch("tasks.jobs.attack_paths.scan.cartography_analysis.run") @patch("tasks.jobs.attack_paths.scan.cartography_analysis.run")
@patch("tasks.jobs.attack_paths.scan.cartography_create_indexes.run") @patch("tasks.jobs.attack_paths.indexes.cartography_create_indexes.run")
@patch("tasks.jobs.attack_paths.scan.graph_database.clear_cache") @patch("tasks.jobs.attack_paths.scan.graph_database.clear_cache")
@patch("tasks.jobs.attack_paths.scan.graph_database.create_database") @patch("tasks.jobs.attack_paths.scan.graph_database.create_database")
@patch( @patch(
"tasks.jobs.attack_paths.scan.graph_database.get_uri", "tasks.jobs.attack_paths.scan.graph_database.get_ingest_uri",
return_value="bolt://neo4j", return_value="bolt://neo4j",
) )
@patch( @patch(
@@ -66,7 +75,7 @@ class TestAttackPathsRun:
def test_run_success_flow( def test_run_success_flow(
self, self,
mock_init_provider, mock_init_provider,
mock_get_uri, mock_get_ingest_uri,
mock_create_db, mock_create_db,
mock_clear_cache, mock_clear_cache,
mock_cartography_indexes, mock_cartography_indexes,
@@ -83,6 +92,7 @@ class TestAttackPathsRun:
mock_finish, mock_finish,
mock_set_provider_graph_data_ready, mock_set_provider_graph_data_ready,
mock_set_graph_data_ready, mock_set_graph_data_ready,
mock_set_scan_migrated,
mock_event_loop, mock_event_loop,
mock_drop_db, mock_drop_db,
tenants_fixture, tenants_fixture,
@@ -159,6 +169,7 @@ class TestAttackPathsRun:
target_database="tenant-db", target_database="tenant-db",
tenant_id=str(provider.tenant_id), tenant_id=str(provider.tenant_id),
provider_id=str(provider.id), provider_id=str(provider.id),
provider_type="aws",
) )
mock_get_ingestion.assert_called_once_with(provider.provider) mock_get_ingestion.assert_called_once_with(provider.provider)
mock_event_loop.assert_called_once() mock_event_loop.assert_called_once()
@@ -172,9 +183,12 @@ class TestAttackPathsRun:
attack_paths_scan, StateChoices.COMPLETED, ingestion_result attack_paths_scan, StateChoices.COMPLETED, ingestion_result
) )
mock_set_provider_graph_data_ready.assert_called_once_with( mock_set_provider_graph_data_ready.assert_called_once_with(
attack_paths_scan, False attack_paths_scan, False, "neo4j"
) )
mock_set_graph_data_ready.assert_called_once_with(attack_paths_scan, True) mock_set_graph_data_ready.assert_called_once_with(attack_paths_scan, True)
# is_migrated is flipped to True only after the sync succeeds, so reads
# don't switch to the new catalog/sink before the graph is live.
mock_set_scan_migrated.assert_called_once_with(attack_paths_scan, True, "neo4j")
@patch( @patch(
"tasks.jobs.attack_paths.scan.utils.stringify_exception", "tasks.jobs.attack_paths.scan.utils.stringify_exception",
@@ -194,13 +208,13 @@ class TestAttackPathsRun:
@patch("tasks.jobs.attack_paths.scan.internet.analysis") @patch("tasks.jobs.attack_paths.scan.internet.analysis")
@patch("tasks.jobs.attack_paths.scan.indexes.create_findings_indexes") @patch("tasks.jobs.attack_paths.scan.indexes.create_findings_indexes")
@patch("tasks.jobs.attack_paths.scan.cartography_analysis.run") @patch("tasks.jobs.attack_paths.scan.cartography_analysis.run")
@patch("tasks.jobs.attack_paths.scan.cartography_create_indexes.run") @patch("tasks.jobs.attack_paths.indexes.cartography_create_indexes.run")
@patch("tasks.jobs.attack_paths.scan.graph_database.create_database") @patch("tasks.jobs.attack_paths.scan.graph_database.create_database")
@patch( @patch(
"tasks.jobs.attack_paths.scan.graph_database.get_database_name", "tasks.jobs.attack_paths.scan.graph_database.get_database_name",
return_value="db-scan-id", return_value="db-scan-id",
) )
@patch("tasks.jobs.attack_paths.scan.graph_database.get_uri") @patch("tasks.jobs.attack_paths.scan.graph_database.get_ingest_uri")
@patch( @patch(
"tasks.jobs.attack_paths.scan.initialize_prowler_provider", "tasks.jobs.attack_paths.scan.initialize_prowler_provider",
return_value=MagicMock(_enabled_regions=["us-east-1"]), return_value=MagicMock(_enabled_regions=["us-east-1"]),
@@ -212,7 +226,7 @@ class TestAttackPathsRun:
def test_run_failure_marks_scan_failed( def test_run_failure_marks_scan_failed(
self, self,
mock_init_provider, mock_init_provider,
mock_get_uri, mock_get_ingest_uri,
mock_get_db_name, mock_get_db_name,
mock_create_db, mock_create_db,
mock_cartography_indexes, mock_cartography_indexes,
@@ -293,13 +307,13 @@ class TestAttackPathsRun:
@patch("tasks.jobs.attack_paths.scan.internet.analysis") @patch("tasks.jobs.attack_paths.scan.internet.analysis")
@patch("tasks.jobs.attack_paths.scan.indexes.create_findings_indexes") @patch("tasks.jobs.attack_paths.scan.indexes.create_findings_indexes")
@patch("tasks.jobs.attack_paths.scan.cartography_analysis.run") @patch("tasks.jobs.attack_paths.scan.cartography_analysis.run")
@patch("tasks.jobs.attack_paths.scan.cartography_create_indexes.run") @patch("tasks.jobs.attack_paths.indexes.cartography_create_indexes.run")
@patch("tasks.jobs.attack_paths.scan.graph_database.create_database") @patch("tasks.jobs.attack_paths.scan.graph_database.create_database")
@patch( @patch(
"tasks.jobs.attack_paths.scan.graph_database.get_database_name", "tasks.jobs.attack_paths.scan.graph_database.get_database_name",
return_value="db-scan-id", return_value="db-scan-id",
) )
@patch("tasks.jobs.attack_paths.scan.graph_database.get_uri") @patch("tasks.jobs.attack_paths.scan.graph_database.get_ingest_uri")
@patch( @patch(
"tasks.jobs.attack_paths.scan.initialize_prowler_provider", "tasks.jobs.attack_paths.scan.initialize_prowler_provider",
return_value=MagicMock(_enabled_regions=["us-east-1"]), return_value=MagicMock(_enabled_regions=["us-east-1"]),
@@ -311,7 +325,7 @@ class TestAttackPathsRun:
def test_failure_before_gate_does_not_flip_graph_data_ready_true( def test_failure_before_gate_does_not_flip_graph_data_ready_true(
self, self,
mock_init_provider, mock_init_provider,
mock_get_uri, mock_get_ingest_uri,
mock_get_db_name, mock_get_db_name,
mock_create_db, mock_create_db,
mock_cartography_indexes, mock_cartography_indexes,
@@ -396,13 +410,13 @@ class TestAttackPathsRun:
@patch("tasks.jobs.attack_paths.scan.internet.analysis") @patch("tasks.jobs.attack_paths.scan.internet.analysis")
@patch("tasks.jobs.attack_paths.scan.indexes.create_findings_indexes") @patch("tasks.jobs.attack_paths.scan.indexes.create_findings_indexes")
@patch("tasks.jobs.attack_paths.scan.cartography_analysis.run") @patch("tasks.jobs.attack_paths.scan.cartography_analysis.run")
@patch("tasks.jobs.attack_paths.scan.cartography_create_indexes.run") @patch("tasks.jobs.attack_paths.indexes.cartography_create_indexes.run")
@patch("tasks.jobs.attack_paths.scan.graph_database.create_database") @patch("tasks.jobs.attack_paths.scan.graph_database.create_database")
@patch( @patch(
"tasks.jobs.attack_paths.scan.graph_database.get_database_name", "tasks.jobs.attack_paths.scan.graph_database.get_database_name",
return_value="db-scan-id", return_value="db-scan-id",
) )
@patch("tasks.jobs.attack_paths.scan.graph_database.get_uri") @patch("tasks.jobs.attack_paths.scan.graph_database.get_ingest_uri")
@patch( @patch(
"tasks.jobs.attack_paths.scan.initialize_prowler_provider", "tasks.jobs.attack_paths.scan.initialize_prowler_provider",
return_value=MagicMock(_enabled_regions=["us-east-1"]), return_value=MagicMock(_enabled_regions=["us-east-1"]),
@@ -414,7 +428,7 @@ class TestAttackPathsRun:
def test_run_failure_marks_scan_failed_even_when_drop_database_fails( def test_run_failure_marks_scan_failed_even_when_drop_database_fails(
self, self,
mock_init_provider, mock_init_provider,
mock_get_uri, mock_get_ingest_uri,
mock_get_db_name, mock_get_db_name,
mock_create_db, mock_create_db,
mock_cartography_indexes, mock_cartography_indexes,
@@ -493,7 +507,7 @@ class TestAttackPathsRun:
@patch("tasks.jobs.attack_paths.scan.db_utils.starting_attack_paths_scan") @patch("tasks.jobs.attack_paths.scan.db_utils.starting_attack_paths_scan")
@patch( @patch(
"tasks.jobs.attack_paths.scan.sync.sync_graph", "tasks.jobs.attack_paths.scan.sync.sync_graph",
return_value={"nodes": 0, "relationships": 0}, return_value=SYNC_RESULT_EMPTY,
) )
@patch( @patch(
"tasks.jobs.attack_paths.scan.graph_database.drop_subgraph", "tasks.jobs.attack_paths.scan.graph_database.drop_subgraph",
@@ -505,11 +519,11 @@ class TestAttackPathsRun:
@patch("tasks.jobs.attack_paths.scan.indexes.create_findings_indexes") @patch("tasks.jobs.attack_paths.scan.indexes.create_findings_indexes")
@patch("tasks.jobs.attack_paths.scan.cartography_ontology.run") @patch("tasks.jobs.attack_paths.scan.cartography_ontology.run")
@patch("tasks.jobs.attack_paths.scan.cartography_analysis.run") @patch("tasks.jobs.attack_paths.scan.cartography_analysis.run")
@patch("tasks.jobs.attack_paths.scan.cartography_create_indexes.run") @patch("tasks.jobs.attack_paths.indexes.cartography_create_indexes.run")
@patch("tasks.jobs.attack_paths.scan.graph_database.clear_cache") @patch("tasks.jobs.attack_paths.scan.graph_database.clear_cache")
@patch("tasks.jobs.attack_paths.scan.graph_database.create_database") @patch("tasks.jobs.attack_paths.scan.graph_database.create_database")
@patch( @patch(
"tasks.jobs.attack_paths.scan.graph_database.get_uri", "tasks.jobs.attack_paths.scan.graph_database.get_ingest_uri",
return_value="bolt://neo4j", return_value="bolt://neo4j",
) )
@patch( @patch(
@@ -523,7 +537,7 @@ class TestAttackPathsRun:
def test_failure_after_gate_before_drop_restores_graph_data_ready( def test_failure_after_gate_before_drop_restores_graph_data_ready(
self, self,
mock_init_provider, mock_init_provider,
mock_get_uri, mock_get_ingest_uri,
mock_create_db, mock_create_db,
mock_clear_cache, mock_clear_cache,
mock_cartography_indexes, mock_cartography_indexes,
@@ -589,8 +603,8 @@ class TestAttackPathsRun:
attack_paths_run(str(tenant.id), str(scan.id), "task-456") attack_paths_run(str(tenant.id), str(scan.id), "task-456")
assert mock_set_provider_graph_data_ready.call_args_list == [ assert mock_set_provider_graph_data_ready.call_args_list == [
call(attack_paths_scan, False), call(attack_paths_scan, False, "neo4j"),
call(attack_paths_scan, True), call(attack_paths_scan, True, "neo4j"),
] ]
@patch( @patch(
@@ -618,11 +632,11 @@ class TestAttackPathsRun:
@patch("tasks.jobs.attack_paths.scan.indexes.create_findings_indexes") @patch("tasks.jobs.attack_paths.scan.indexes.create_findings_indexes")
@patch("tasks.jobs.attack_paths.scan.cartography_ontology.run") @patch("tasks.jobs.attack_paths.scan.cartography_ontology.run")
@patch("tasks.jobs.attack_paths.scan.cartography_analysis.run") @patch("tasks.jobs.attack_paths.scan.cartography_analysis.run")
@patch("tasks.jobs.attack_paths.scan.cartography_create_indexes.run") @patch("tasks.jobs.attack_paths.indexes.cartography_create_indexes.run")
@patch("tasks.jobs.attack_paths.scan.graph_database.clear_cache") @patch("tasks.jobs.attack_paths.scan.graph_database.clear_cache")
@patch("tasks.jobs.attack_paths.scan.graph_database.create_database") @patch("tasks.jobs.attack_paths.scan.graph_database.create_database")
@patch( @patch(
"tasks.jobs.attack_paths.scan.graph_database.get_uri", "tasks.jobs.attack_paths.scan.graph_database.get_ingest_uri",
return_value="bolt://neo4j", return_value="bolt://neo4j",
) )
@patch( @patch(
@@ -636,7 +650,7 @@ class TestAttackPathsRun:
def test_failure_after_drop_before_sync_leaves_graph_data_ready_false( def test_failure_after_drop_before_sync_leaves_graph_data_ready_false(
self, self,
mock_init_provider, mock_init_provider,
mock_get_uri, mock_get_ingest_uri,
mock_create_db, mock_create_db,
mock_clear_cache, mock_clear_cache,
mock_cartography_indexes, mock_cartography_indexes,
@@ -703,7 +717,7 @@ class TestAttackPathsRun:
# Only called with False (gate), never with True (no recovery for partial data) # Only called with False (gate), never with True (no recovery for partial data)
mock_set_provider_graph_data_ready.assert_called_once_with( mock_set_provider_graph_data_ready.assert_called_once_with(
attack_paths_scan, False attack_paths_scan, False, "neo4j"
) )
@patch( @patch(
@@ -716,6 +730,7 @@ class TestAttackPathsRun:
) )
@patch("tasks.jobs.attack_paths.scan.graph_database.drop_database") @patch("tasks.jobs.attack_paths.scan.graph_database.drop_database")
@patch("tasks.jobs.attack_paths.scan.db_utils.finish_attack_paths_scan") @patch("tasks.jobs.attack_paths.scan.db_utils.finish_attack_paths_scan")
@patch("tasks.jobs.attack_paths.scan.db_utils.set_scan_migrated")
@patch( @patch(
"tasks.jobs.attack_paths.scan.db_utils.set_graph_data_ready", "tasks.jobs.attack_paths.scan.db_utils.set_graph_data_ready",
side_effect=[RuntimeError("flag failed"), None], side_effect=[RuntimeError("flag failed"), None],
@@ -725,7 +740,7 @@ class TestAttackPathsRun:
@patch("tasks.jobs.attack_paths.scan.db_utils.starting_attack_paths_scan") @patch("tasks.jobs.attack_paths.scan.db_utils.starting_attack_paths_scan")
@patch( @patch(
"tasks.jobs.attack_paths.scan.sync.sync_graph", "tasks.jobs.attack_paths.scan.sync.sync_graph",
return_value={"nodes": 0, "relationships": 0}, return_value=SYNC_RESULT_EMPTY,
) )
@patch("tasks.jobs.attack_paths.scan.graph_database.drop_subgraph") @patch("tasks.jobs.attack_paths.scan.graph_database.drop_subgraph")
@patch("tasks.jobs.attack_paths.scan.indexes.create_sync_indexes") @patch("tasks.jobs.attack_paths.scan.indexes.create_sync_indexes")
@@ -734,11 +749,11 @@ class TestAttackPathsRun:
@patch("tasks.jobs.attack_paths.scan.indexes.create_findings_indexes") @patch("tasks.jobs.attack_paths.scan.indexes.create_findings_indexes")
@patch("tasks.jobs.attack_paths.scan.cartography_ontology.run") @patch("tasks.jobs.attack_paths.scan.cartography_ontology.run")
@patch("tasks.jobs.attack_paths.scan.cartography_analysis.run") @patch("tasks.jobs.attack_paths.scan.cartography_analysis.run")
@patch("tasks.jobs.attack_paths.scan.cartography_create_indexes.run") @patch("tasks.jobs.attack_paths.indexes.cartography_create_indexes.run")
@patch("tasks.jobs.attack_paths.scan.graph_database.clear_cache") @patch("tasks.jobs.attack_paths.scan.graph_database.clear_cache")
@patch("tasks.jobs.attack_paths.scan.graph_database.create_database") @patch("tasks.jobs.attack_paths.scan.graph_database.create_database")
@patch( @patch(
"tasks.jobs.attack_paths.scan.graph_database.get_uri", "tasks.jobs.attack_paths.scan.graph_database.get_ingest_uri",
return_value="bolt://neo4j", return_value="bolt://neo4j",
) )
@patch( @patch(
@@ -752,7 +767,7 @@ class TestAttackPathsRun:
def test_failure_after_sync_restores_graph_data_ready( def test_failure_after_sync_restores_graph_data_ready(
self, self,
mock_init_provider, mock_init_provider,
mock_get_uri, mock_get_ingest_uri,
mock_create_db, mock_create_db,
mock_clear_cache, mock_clear_cache,
mock_cartography_indexes, mock_cartography_indexes,
@@ -768,6 +783,7 @@ class TestAttackPathsRun:
mock_update_progress, mock_update_progress,
mock_set_provider_graph_data_ready, mock_set_provider_graph_data_ready,
mock_set_graph_data_ready, mock_set_graph_data_ready,
mock_set_scan_migrated,
mock_finish, mock_finish,
mock_drop_db, mock_drop_db,
mock_event_loop, mock_event_loop,
@@ -824,8 +840,11 @@ class TestAttackPathsRun:
] ]
# set_provider_graph_data_ready only called once with False (the gate) # set_provider_graph_data_ready only called once with False (the gate)
mock_set_provider_graph_data_ready.assert_called_once_with( mock_set_provider_graph_data_ready.assert_called_once_with(
attack_paths_scan, False attack_paths_scan, False, "neo4j"
) )
# is_migrated is flipped once after the sync and is not touched again by
# the failure-recovery branch
mock_set_scan_migrated.assert_called_once_with(attack_paths_scan, True, "neo4j")
@patch( @patch(
"tasks.jobs.attack_paths.scan.utils.stringify_exception", "tasks.jobs.attack_paths.scan.utils.stringify_exception",
@@ -843,7 +862,7 @@ class TestAttackPathsRun:
@patch("tasks.jobs.attack_paths.scan.db_utils.starting_attack_paths_scan") @patch("tasks.jobs.attack_paths.scan.db_utils.starting_attack_paths_scan")
@patch( @patch(
"tasks.jobs.attack_paths.scan.sync.sync_graph", "tasks.jobs.attack_paths.scan.sync.sync_graph",
return_value={"nodes": 0, "relationships": 0}, return_value=SYNC_RESULT_EMPTY,
) )
@patch( @patch(
"tasks.jobs.attack_paths.scan.graph_database.drop_subgraph", "tasks.jobs.attack_paths.scan.graph_database.drop_subgraph",
@@ -855,11 +874,11 @@ class TestAttackPathsRun:
@patch("tasks.jobs.attack_paths.scan.indexes.create_findings_indexes") @patch("tasks.jobs.attack_paths.scan.indexes.create_findings_indexes")
@patch("tasks.jobs.attack_paths.scan.cartography_ontology.run") @patch("tasks.jobs.attack_paths.scan.cartography_ontology.run")
@patch("tasks.jobs.attack_paths.scan.cartography_analysis.run") @patch("tasks.jobs.attack_paths.scan.cartography_analysis.run")
@patch("tasks.jobs.attack_paths.scan.cartography_create_indexes.run") @patch("tasks.jobs.attack_paths.indexes.cartography_create_indexes.run")
@patch("tasks.jobs.attack_paths.scan.graph_database.clear_cache") @patch("tasks.jobs.attack_paths.scan.graph_database.clear_cache")
@patch("tasks.jobs.attack_paths.scan.graph_database.create_database") @patch("tasks.jobs.attack_paths.scan.graph_database.create_database")
@patch( @patch(
"tasks.jobs.attack_paths.scan.graph_database.get_uri", "tasks.jobs.attack_paths.scan.graph_database.get_ingest_uri",
return_value="bolt://neo4j", return_value="bolt://neo4j",
) )
@patch( @patch(
@@ -873,7 +892,7 @@ class TestAttackPathsRun:
def test_recovery_failure_does_not_suppress_original_exception( def test_recovery_failure_does_not_suppress_original_exception(
self, self,
mock_init_provider, mock_init_provider,
mock_get_uri, mock_get_ingest_uri,
mock_create_db, mock_create_db,
mock_clear_cache, mock_clear_cache,
mock_cartography_indexes, mock_cartography_indexes,
@@ -1116,7 +1135,7 @@ class TestFailAttackPathsScan:
fail_attack_paths_scan(str(tenant.id), "nonexistent", "setup exploded") fail_attack_paths_scan(str(tenant.id), "nonexistent", "setup exploded")
def test_fail_recovers_graph_data_ready_when_data_exists( def test_fail_recovers_graph_data_ready_when_data_exists(
self, tenants_fixture, providers_fixture, scans_fixture self, tenants_fixture, providers_fixture, scans_fixture, sink_backend_stub
): ):
from tasks.jobs.attack_paths.db_utils import fail_attack_paths_scan from tasks.jobs.attack_paths.db_utils import fail_attack_paths_scan
@@ -1135,16 +1154,18 @@ class TestFailAttackPathsScan:
state=StateChoices.EXECUTING, state=StateChoices.EXECUTING,
) )
# `recover_graph_data_ready` routes `has_provider_data` through
# `sink_module.get_backend_for_scan(scan)`. With `is_migrated=False`
# and the default `ATTACK_PATHS_SINK_DATABASE=neo4j`, the factory
# returns the active backend, which `sink_backend_stub` replaces.
sink_backend_stub.has_provider_data.return_value = True
with ( with (
patch( patch(
"tasks.jobs.attack_paths.db_utils.retrieve_attack_paths_scan", "tasks.jobs.attack_paths.db_utils.retrieve_attack_paths_scan",
return_value=attack_paths_scan, return_value=attack_paths_scan,
), ),
patch("tasks.jobs.attack_paths.db_utils.graph_database.drop_database"), patch("tasks.jobs.attack_paths.db_utils.graph_database.drop_database"),
patch(
"tasks.jobs.attack_paths.db_utils.graph_database.has_provider_data",
return_value=True,
),
patch( patch(
"tasks.jobs.attack_paths.db_utils.set_provider_graph_data_ready" "tasks.jobs.attack_paths.db_utils.set_provider_graph_data_ready"
) as mock_set_ready, ) as mock_set_ready,
@@ -1154,7 +1175,7 @@ class TestFailAttackPathsScan:
mock_set_ready.assert_called_once_with(attack_paths_scan, True) mock_set_ready.assert_called_once_with(attack_paths_scan, True)
def test_fail_leaves_graph_data_ready_false_when_no_data( def test_fail_leaves_graph_data_ready_false_when_no_data(
self, tenants_fixture, providers_fixture, scans_fixture self, tenants_fixture, providers_fixture, scans_fixture, sink_backend_stub
): ):
from tasks.jobs.attack_paths.db_utils import fail_attack_paths_scan from tasks.jobs.attack_paths.db_utils import fail_attack_paths_scan
@@ -1173,16 +1194,14 @@ class TestFailAttackPathsScan:
state=StateChoices.EXECUTING, state=StateChoices.EXECUTING,
) )
sink_backend_stub.has_provider_data.return_value = False
with ( with (
patch( patch(
"tasks.jobs.attack_paths.db_utils.retrieve_attack_paths_scan", "tasks.jobs.attack_paths.db_utils.retrieve_attack_paths_scan",
return_value=attack_paths_scan, return_value=attack_paths_scan,
), ),
patch("tasks.jobs.attack_paths.db_utils.graph_database.drop_database"), patch("tasks.jobs.attack_paths.db_utils.graph_database.drop_database"),
patch(
"tasks.jobs.attack_paths.db_utils.graph_database.has_provider_data",
return_value=False,
),
patch( patch(
"tasks.jobs.attack_paths.db_utils.set_provider_graph_data_ready" "tasks.jobs.attack_paths.db_utils.set_provider_graph_data_ready"
) as mock_set_ready, ) as mock_set_ready,
@@ -1271,6 +1290,20 @@ class TestAttackPathsFindingsHelpers:
[call(mock_session, stmt) for stmt in FINDINGS_INDEX_STATEMENTS] [call(mock_session, stmt) for stmt in FINDINGS_INDEX_STATEMENTS]
) )
def test_create_findings_indexes_runs_even_when_sink_is_neptune(self, settings):
# The index helpers run against the temp ingest DB, which is always
# Neo4j regardless of the configured sink. A Neptune sink must not
# suppress index creation on that DB (regression for the dropped
# in-helper sink gate).
settings.ATTACK_PATHS_SINK_DATABASE = "neptune"
mock_session = MagicMock()
with patch("tasks.jobs.attack_paths.indexes.run_write_query") as mock_run_write:
indexes_module.create_findings_indexes(mock_session)
from tasks.jobs.attack_paths.indexes import FINDINGS_INDEX_STATEMENTS
assert mock_run_write.call_count == len(FINDINGS_INDEX_STATEMENTS)
def test_load_findings_batches_requests(self, providers_fixture): def test_load_findings_batches_requests(self, providers_fixture):
provider = providers_fixture[0] provider = providers_fixture[0]
provider.provider = Provider.ProviderChoices.AWS provider.provider = Provider.ProviderChoices.AWS
@@ -1802,7 +1835,7 @@ def _make_session_ctx(session, call_order=None, name=None):
class TestSyncNodes: class TestSyncNodes:
def test_sync_nodes_adds_private_label(self): def test_sync_nodes_passes_isolation_labels_to_sink(self):
row = { row = {
"internal_id": 1, "internal_id": 1,
"element_id": "elem-1", "element_id": "elem-1",
@@ -1812,29 +1845,32 @@ class TestSyncNodes:
mock_source_1 = MagicMock() mock_source_1 = MagicMock()
mock_source_1.run.return_value = [row] mock_source_1.run.return_value = [row]
mock_target = MagicMock()
mock_source_2 = MagicMock() mock_source_2 = MagicMock()
mock_source_2.run.return_value = [] mock_source_2.run.return_value = []
sink = MagicMock()
with patch( with patch(
"tasks.jobs.attack_paths.sync.graph_database.get_session", "tasks.jobs.attack_paths.sync.graph_database.get_session",
side_effect=[ side_effect=[
_make_session_ctx(mock_source_1), _make_session_ctx(mock_source_1),
_make_session_ctx(mock_target),
_make_session_ctx(mock_source_2), _make_session_ctx(mock_source_2),
], ],
): ):
total = sync_module.sync_nodes( result = sync_module.sync_nodes(
"source-db", "target-db", "tenant-1", "prov-1" "source-db", "target-db", "tenant-1", "prov-1", sink, []
) )
assert total == 1 assert result["parents"] == 1
query = mock_target.run.call_args.args[0] sink.write_nodes.assert_called_once()
assert "_ProviderResource" in query target_db, labels, batch = sink.write_nodes.call_args.args
assert "_Tenant_tenant1" in query assert target_db == "target-db"
assert "_Provider_prov1" in query assert "_ProviderResource" in labels
assert "_Tenant_tenant1" in labels
assert "_Provider_prov1" in labels
assert batch[0]["provider_element_id"] == "prov-1:elem-1"
assert batch[0]["props"] == {"key": "value"}
def test_sync_nodes_source_closes_before_target_opens(self): def test_sync_nodes_writes_after_source_session_closes(self):
row = { row = {
"internal_id": 1, "internal_id": 1,
"element_id": "elem-1", "element_id": "elem-1",
@@ -1846,21 +1882,23 @@ class TestSyncNodes:
src_1 = MagicMock() src_1 = MagicMock()
src_1.run.return_value = [row] src_1.run.return_value = [row]
tgt = MagicMock()
src_2 = MagicMock() src_2 = MagicMock()
src_2.run.return_value = [] src_2.run.return_value = []
sink = MagicMock()
sink.write_nodes.side_effect = lambda *_a, **_kw: call_order.append(
"sink:write"
)
with patch( with patch(
"tasks.jobs.attack_paths.sync.graph_database.get_session", "tasks.jobs.attack_paths.sync.graph_database.get_session",
side_effect=[ side_effect=[
_make_session_ctx(src_1, call_order, "source1"), _make_session_ctx(src_1, call_order, "source1"),
_make_session_ctx(tgt, call_order, "target"),
_make_session_ctx(src_2, call_order, "source2"), _make_session_ctx(src_2, call_order, "source2"),
], ],
): ):
sync_module.sync_nodes("src-db", "tgt-db", "t-1", "p-1") sync_module.sync_nodes("src-db", "tgt-db", "t-1", "p-1", sink, [])
assert call_order.index("source1:exit") < call_order.index("target:enter") assert call_order.index("source1:exit") < call_order.index("sink:write")
def test_sync_nodes_pagination_with_batch_size_1(self): def test_sync_nodes_pagination_with_batch_size_1(self):
row_a = { row_a = {
@@ -1882,44 +1920,44 @@ class TestSyncNodes:
src_2.run.return_value = [row_b] src_2.run.return_value = [row_b]
src_3 = MagicMock() src_3 = MagicMock()
src_3.run.return_value = [] src_3.run.return_value = []
tgt_1 = MagicMock() sink = MagicMock()
tgt_2 = MagicMock()
with ( with (
patch( patch(
"tasks.jobs.attack_paths.sync.graph_database.get_session", "tasks.jobs.attack_paths.sync.graph_database.get_session",
side_effect=[ side_effect=[
_make_session_ctx(src_1), _make_session_ctx(src_1),
_make_session_ctx(tgt_1),
_make_session_ctx(src_2), _make_session_ctx(src_2),
_make_session_ctx(tgt_2),
_make_session_ctx(src_3), _make_session_ctx(src_3),
], ],
), ),
patch("tasks.jobs.attack_paths.sync.SYNC_BATCH_SIZE", 1), patch("tasks.jobs.attack_paths.sync.SYNC_BATCH_SIZE", 1),
): ):
total = sync_module.sync_nodes("src", "tgt", "t-1", "p-1") result = sync_module.sync_nodes("src", "tgt", "t-1", "p-1", sink, [])
assert total == 2 assert result["parents"] == 2
assert sink.write_nodes.call_count == 2
assert src_1.run.call_args.args[1]["last_id"] == -1 assert src_1.run.call_args.args[1]["last_id"] == -1
assert src_2.run.call_args.args[1]["last_id"] == 1 assert src_2.run.call_args.args[1]["last_id"] == 1
def test_sync_nodes_empty_source_returns_zero(self): def test_sync_nodes_empty_source_returns_zero(self):
src = MagicMock() src = MagicMock()
src.run.return_value = [] src.run.return_value = []
sink = MagicMock()
with patch( with patch(
"tasks.jobs.attack_paths.sync.graph_database.get_session", "tasks.jobs.attack_paths.sync.graph_database.get_session",
side_effect=[_make_session_ctx(src)], side_effect=[_make_session_ctx(src)],
) as mock_get_session: ) as mock_get_session:
total = sync_module.sync_nodes("src", "tgt", "t-1", "p-1") result = sync_module.sync_nodes("src", "tgt", "t-1", "p-1", sink, [])
assert total == 0 assert result["parents"] == 0
assert mock_get_session.call_count == 1 assert mock_get_session.call_count == 1
sink.write_nodes.assert_not_called()
class TestSyncRelationships: class TestSyncRelationships:
def test_sync_relationships_source_closes_before_target_opens(self): def test_sync_relationships_writes_after_source_session_closes(self):
row = { row = {
"internal_id": 1, "internal_id": 1,
"rel_type": "HAS", "rel_type": "HAS",
@@ -1932,21 +1970,23 @@ class TestSyncRelationships:
src_1 = MagicMock() src_1 = MagicMock()
src_1.run.return_value = [row] src_1.run.return_value = [row]
tgt = MagicMock()
src_2 = MagicMock() src_2 = MagicMock()
src_2.run.return_value = [] src_2.run.return_value = []
sink = MagicMock()
sink.write_relationships.side_effect = lambda *_a, **_kw: call_order.append(
"sink:write"
)
with patch( with patch(
"tasks.jobs.attack_paths.sync.graph_database.get_session", "tasks.jobs.attack_paths.sync.graph_database.get_session",
side_effect=[ side_effect=[
_make_session_ctx(src_1, call_order, "source1"), _make_session_ctx(src_1, call_order, "source1"),
_make_session_ctx(tgt, call_order, "target"),
_make_session_ctx(src_2, call_order, "source2"), _make_session_ctx(src_2, call_order, "source2"),
], ],
): ):
sync_module.sync_relationships("src", "tgt", "p-1") sync_module.sync_relationships("src", "tgt", "p-1", sink)
assert call_order.index("source1:exit") < call_order.index("target:enter") assert call_order.index("source1:exit") < call_order.index("sink:write")
def test_sync_relationships_pagination_with_batch_size_1(self): def test_sync_relationships_pagination_with_batch_size_1(self):
row_a = { row_a = {
@@ -1970,40 +2010,40 @@ class TestSyncRelationships:
src_2.run.return_value = [row_b] src_2.run.return_value = [row_b]
src_3 = MagicMock() src_3 = MagicMock()
src_3.run.return_value = [] src_3.run.return_value = []
tgt_1 = MagicMock() sink = MagicMock()
tgt_2 = MagicMock()
with ( with (
patch( patch(
"tasks.jobs.attack_paths.sync.graph_database.get_session", "tasks.jobs.attack_paths.sync.graph_database.get_session",
side_effect=[ side_effect=[
_make_session_ctx(src_1), _make_session_ctx(src_1),
_make_session_ctx(tgt_1),
_make_session_ctx(src_2), _make_session_ctx(src_2),
_make_session_ctx(tgt_2),
_make_session_ctx(src_3), _make_session_ctx(src_3),
], ],
), ),
patch("tasks.jobs.attack_paths.sync.SYNC_BATCH_SIZE", 1), patch("tasks.jobs.attack_paths.sync.SYNC_BATCH_SIZE", 1),
): ):
total = sync_module.sync_relationships("src", "tgt", "p-1") total = sync_module.sync_relationships("src", "tgt", "p-1", sink)
assert total == 2 assert total == 2
assert sink.write_relationships.call_count == 2
assert src_1.run.call_args.args[1]["last_id"] == -1 assert src_1.run.call_args.args[1]["last_id"] == -1
assert src_2.run.call_args.args[1]["last_id"] == 1 assert src_2.run.call_args.args[1]["last_id"] == 1
def test_sync_relationships_empty_source_returns_zero(self): def test_sync_relationships_empty_source_returns_zero(self):
src = MagicMock() src = MagicMock()
src.run.return_value = [] src.run.return_value = []
sink = MagicMock()
with patch( with patch(
"tasks.jobs.attack_paths.sync.graph_database.get_session", "tasks.jobs.attack_paths.sync.graph_database.get_session",
side_effect=[_make_session_ctx(src)], side_effect=[_make_session_ctx(src)],
) as mock_get_session: ) as mock_get_session:
total = sync_module.sync_relationships("src", "tgt", "p-1") total = sync_module.sync_relationships("src", "tgt", "p-1", sink)
assert total == 0 assert total == 0
assert mock_get_session.call_count == 1 assert mock_get_session.call_count == 1
sink.write_relationships.assert_not_called()
class TestInternetAnalysis: class TestInternetAnalysis:
@@ -2075,6 +2115,8 @@ class TestAttackPathsDbUtilsGraphDataReady:
assert attack_paths_scan is not None assert attack_paths_scan is not None
assert attack_paths_scan.graph_data_ready is False assert attack_paths_scan.graph_data_ready is False
assert attack_paths_scan.is_migrated is False
assert attack_paths_scan.sink_backend == "neo4j"
def test_create_attack_paths_scan_inherits_true_from_previous( def test_create_attack_paths_scan_inherits_true_from_previous(
self, tenants_fixture, providers_fixture, scans_fixture self, tenants_fixture, providers_fixture, scans_fixture
@@ -2095,6 +2137,8 @@ class TestAttackPathsDbUtilsGraphDataReady:
scan=scan, scan=scan,
state=StateChoices.COMPLETED, state=StateChoices.COMPLETED,
graph_data_ready=True, graph_data_ready=True,
is_migrated=True,
sink_backend="neptune",
) )
new_scan = Scan.objects.create( new_scan = Scan.objects.create(
@@ -2115,6 +2159,109 @@ class TestAttackPathsDbUtilsGraphDataReady:
assert attack_paths_scan is not None assert attack_paths_scan is not None
assert attack_paths_scan.graph_data_ready is True assert attack_paths_scan.graph_data_ready is True
# is_migrated tracks the data being served: inherited from the ready scan
assert attack_paths_scan.is_migrated is True
assert attack_paths_scan.sink_backend == "neptune"
def test_create_attack_paths_scan_prefers_active_sink_ready_scan(
self, tenants_fixture, providers_fixture, scans_fixture, settings
):
from tasks.jobs.attack_paths.db_utils import create_attack_paths_scan
settings.ATTACK_PATHS_SINK_DATABASE = "neo4j"
tenant = tenants_fixture[0]
provider = providers_fixture[0]
provider.provider = Provider.ProviderChoices.AWS
provider.save()
scan = scans_fixture[0]
scan.provider = provider
scan.save()
AttackPathsScan.objects.create(
tenant_id=tenant.id,
provider=provider,
scan=scan,
state=StateChoices.COMPLETED,
graph_data_ready=True,
is_migrated=False,
sink_backend="neo4j",
)
AttackPathsScan.objects.create(
tenant_id=tenant.id,
provider=provider,
scan=scan,
state=StateChoices.COMPLETED,
graph_data_ready=True,
is_migrated=True,
sink_backend="neptune",
)
new_scan = Scan.objects.create(
name="New Scan",
provider=provider,
trigger=Scan.TriggerChoices.MANUAL,
state=StateChoices.AVAILABLE,
tenant_id=tenant.id,
)
with patch(
"tasks.jobs.attack_paths.db_utils.rls_transaction",
new=lambda *args, **kwargs: nullcontext(),
):
attack_paths_scan = create_attack_paths_scan(
str(tenant.id), str(new_scan.id), provider.id
)
assert attack_paths_scan is not None
assert attack_paths_scan.graph_data_ready is True
assert attack_paths_scan.is_migrated is False
assert attack_paths_scan.sink_backend == "neo4j"
def test_create_attack_paths_scan_inherits_is_migrated_false_from_legacy_ready(
self, tenants_fixture, providers_fixture, scans_fixture
):
from tasks.jobs.attack_paths.db_utils import create_attack_paths_scan
tenant = tenants_fixture[0]
provider = providers_fixture[0]
provider.provider = Provider.ProviderChoices.AWS
provider.save()
scan = scans_fixture[0]
scan.provider = provider
scan.save()
# Previous scan is ready but pre-cutover (legacy Neo4j graph shape)
AttackPathsScan.objects.create(
tenant_id=tenant.id,
provider=provider,
scan=scan,
state=StateChoices.COMPLETED,
graph_data_ready=True,
is_migrated=False,
sink_backend="neo4j",
)
new_scan = Scan.objects.create(
name="New Scan",
provider=provider,
trigger=Scan.TriggerChoices.MANUAL,
state=StateChoices.AVAILABLE,
tenant_id=tenant.id,
)
with patch(
"tasks.jobs.attack_paths.db_utils.rls_transaction",
new=lambda *args, **kwargs: nullcontext(),
):
attack_paths_scan = create_attack_paths_scan(
str(tenant.id), str(new_scan.id), provider.id
)
assert attack_paths_scan is not None
assert attack_paths_scan.graph_data_ready is True
# Reads stay on the legacy catalog/backend until this scan's own sync
assert attack_paths_scan.is_migrated is False
assert attack_paths_scan.sink_backend == "neo4j"
def test_create_attack_paths_scan_inherits_false_when_no_previous_ready( def test_create_attack_paths_scan_inherits_false_when_no_previous_ready(
self, tenants_fixture, providers_fixture, scans_fixture self, tenants_fixture, providers_fixture, scans_fixture
@@ -2135,6 +2282,7 @@ class TestAttackPathsDbUtilsGraphDataReady:
scan=scan, scan=scan,
state=StateChoices.FAILED, state=StateChoices.FAILED,
graph_data_ready=False, graph_data_ready=False,
sink_backend="neptune",
) )
new_scan = Scan.objects.create( new_scan = Scan.objects.create(
@@ -2155,6 +2303,8 @@ class TestAttackPathsDbUtilsGraphDataReady:
assert attack_paths_scan is not None assert attack_paths_scan is not None
assert attack_paths_scan.graph_data_ready is False assert attack_paths_scan.graph_data_ready is False
assert attack_paths_scan.is_migrated is False
assert attack_paths_scan.sink_backend == "neo4j"
def test_set_graph_data_ready_updates_field( def test_set_graph_data_ready_updates_field(
self, tenants_fixture, providers_fixture, scans_fixture self, tenants_fixture, providers_fixture, scans_fixture
@@ -2261,7 +2411,7 @@ class TestAttackPathsDbUtilsGraphDataReady:
assert attack_paths_scan.state == StateChoices.FAILED assert attack_paths_scan.state == StateChoices.FAILED
assert attack_paths_scan.graph_data_ready is True assert attack_paths_scan.graph_data_ready is True
def test_set_provider_graph_data_ready_updates_all_scans_for_provider( def test_set_provider_graph_data_ready_updates_all_scans_for_provider_sink(
self, tenants_fixture, providers_fixture, scans_fixture self, tenants_fixture, providers_fixture, scans_fixture
): ):
from tasks.jobs.attack_paths.db_utils import set_provider_graph_data_ready from tasks.jobs.attack_paths.db_utils import set_provider_graph_data_ready
@@ -2289,6 +2439,7 @@ class TestAttackPathsDbUtilsGraphDataReady:
scan=scan_a, scan=scan_a,
state=StateChoices.COMPLETED, state=StateChoices.COMPLETED,
graph_data_ready=True, graph_data_ready=True,
sink_backend="neptune",
) )
new_ap_scan = AttackPathsScan.objects.create( new_ap_scan = AttackPathsScan.objects.create(
tenant_id=tenant.id, tenant_id=tenant.id,
@@ -2296,6 +2447,7 @@ class TestAttackPathsDbUtilsGraphDataReady:
scan=scan_b, scan=scan_b,
state=StateChoices.EXECUTING, state=StateChoices.EXECUTING,
graph_data_ready=True, graph_data_ready=True,
sink_backend="neptune",
) )
with patch( with patch(
@@ -2309,6 +2461,48 @@ class TestAttackPathsDbUtilsGraphDataReady:
assert old_ap_scan.graph_data_ready is False assert old_ap_scan.graph_data_ready is False
assert new_ap_scan.graph_data_ready is False assert new_ap_scan.graph_data_ready is False
def test_set_provider_graph_data_ready_preserves_other_sink_scans(
self, tenants_fixture, providers_fixture, scans_fixture
):
from tasks.jobs.attack_paths.db_utils import set_provider_graph_data_ready
tenant = tenants_fixture[0]
provider = providers_fixture[0]
provider.provider = Provider.ProviderChoices.AWS
provider.save()
scan = scans_fixture[0]
scan.provider = provider
scan.save()
legacy_scan = AttackPathsScan.objects.create(
tenant_id=tenant.id,
provider=provider,
scan=scan,
state=StateChoices.COMPLETED,
graph_data_ready=True,
sink_backend="neo4j",
)
neptune_scan = AttackPathsScan.objects.create(
tenant_id=tenant.id,
provider=provider,
scan=scan,
state=StateChoices.EXECUTING,
graph_data_ready=True,
sink_backend="neptune",
)
with patch(
"tasks.jobs.attack_paths.db_utils.rls_transaction",
new=lambda *args, **kwargs: nullcontext(),
):
set_provider_graph_data_ready(neptune_scan, False)
legacy_scan.refresh_from_db()
neptune_scan.refresh_from_db()
assert legacy_scan.graph_data_ready is True
assert neptune_scan.graph_data_ready is False
def test_set_provider_graph_data_ready_does_not_affect_other_providers( def test_set_provider_graph_data_ready_does_not_affect_other_providers(
self, tenants_fixture, providers_fixture, scans_fixture self, tenants_fixture, providers_fixture, scans_fixture
): ):
@@ -2871,3 +3065,57 @@ class TestCleanupStaleAttackPathsScans:
ap_scan.refresh_from_db() ap_scan.refresh_from_db()
assert ap_scan.state == StateChoices.SCHEDULED assert ap_scan.state == StateChoices.SCHEDULED
mock_revoke.assert_not_called() mock_revoke.assert_not_called()
class TestNormalizeSinkProperties:
"""Coerce Cartography-emitted property values into sink-portable primitives.
Lists become comma-strings, dicts become JSON strings, temporals become
ISO strings, spatials become their stringified form. The same coercion
runs regardless of the active sink so queries are portable.
"""
@pytest.mark.parametrize(
"raw, expected",
[
(
{"a": "x", "b": 1, "c": 1.5, "d": True, "e": None},
{"a": "x", "b": 1, "c": 1.5, "d": True, "e": None},
),
(
{"actions": ["s3:GetObject", "s3:PutObject"], "tags": []},
{"actions": "s3:GetObject,s3:PutObject", "tags": ""},
),
(
{"condition": {"StringEquals": {"aws:SourceAccount": "123456789012"}}},
{
"condition": '{"StringEquals": {"aws:SourceAccount": "123456789012"}}'
},
),
],
)
def test_primitive_list_and_dict_branches(self, raw, expected):
sync_module._normalize_sink_properties(raw, labels=None)
assert raw == expected
def test_temporal_and_spatial_become_strings(self):
class FakeDateTime:
def iso_format(self) -> str:
return "2026-05-13T10:00:00+00:00"
class FakeSpatialPoint:
def __str__(self) -> str:
return "POINT(1.0 2.0)"
# The spatial branch is detected by module prefix, not by base class.
FakeSpatialPoint.__module__ = "neo4j.spatial.fake"
props = {
"created_at": FakeDateTime(),
"location": FakeSpatialPoint(),
}
sync_module._normalize_sink_properties(props, labels=None)
assert props == {
"created_at": "2026-05-13T10:00:00+00:00",
"location": "POINT(1.0 2.0)",
}
+50 -3
View File
@@ -1,4 +1,4 @@
from unittest.mock import call, patch from unittest.mock import MagicMock, call, patch
import pytest import pytest
from api.attack_paths import database as graph_database from api.attack_paths import database as graph_database
@@ -60,10 +60,12 @@ class TestDeleteProvider:
aps1 = create_attack_paths_scan(instance) aps1 = create_attack_paths_scan(instance)
aps2 = create_attack_paths_scan(instance) aps2 = create_attack_paths_scan(instance)
backend = MagicMock()
with ( with (
patch( patch(
"tasks.jobs.deletion.graph_database.drop_subgraph", "tasks.jobs.deletion.sink_module.get_backend_for_name",
return_value=backend,
), ),
patch( patch(
"tasks.jobs.deletion.graph_database.drop_database", "tasks.jobs.deletion.graph_database.drop_database",
@@ -72,12 +74,55 @@ class TestDeleteProvider:
result = delete_provider(tenant_id, instance.id) result = delete_provider(tenant_id, instance.id)
assert result assert result
backend.drop_subgraph.assert_called_once_with(
graph_database.get_database_name(tenant_id), str(instance.id)
)
expected_tmp_calls = [ expected_tmp_calls = [
call(f"db-tmp-scan-{str(aps1.id).lower()}"), call(f"db-tmp-scan-{str(aps1.id).lower()}"),
call(f"db-tmp-scan-{str(aps2.id).lower()}"), call(f"db-tmp-scan-{str(aps2.id).lower()}"),
] ]
mock_drop_database.assert_has_calls(expected_tmp_calls, any_order=True) mock_drop_database.assert_has_calls(expected_tmp_calls, any_order=True)
def test_delete_provider_drops_graph_data_from_all_recorded_sinks(
self, providers_fixture, create_attack_paths_scan
):
instance = providers_fixture[0]
tenant_id = str(instance.tenant_id)
create_attack_paths_scan(instance, sink_backend="neo4j")
create_attack_paths_scan(instance, sink_backend="neptune")
neo4j_backend = MagicMock()
neptune_backend = MagicMock()
def get_backend_for_name(name):
return {
"neo4j": neo4j_backend,
"neptune": neptune_backend,
}[name]
with (
patch(
"tasks.jobs.deletion.graph_database.get_database_name",
return_value="tenant-db",
),
patch(
"tasks.jobs.deletion.sink_module.get_backend_for_name",
side_effect=get_backend_for_name,
) as mock_get_backend_for_name,
patch("tasks.jobs.deletion.graph_database.drop_database"),
):
result = delete_provider(tenant_id, instance.id)
assert result
mock_get_backend_for_name.assert_has_calls(
[call("neo4j"), call("neptune")], any_order=True
)
neo4j_backend.drop_subgraph.assert_called_once_with(
"tenant-db", str(instance.id)
)
neptune_backend.drop_subgraph.assert_called_once_with(
"tenant-db", str(instance.id)
)
def test_delete_provider_continues_when_temp_db_drop_fails( def test_delete_provider_continues_when_temp_db_drop_fails(
self, providers_fixture, create_attack_paths_scan self, providers_fixture, create_attack_paths_scan
): ):
@@ -85,10 +130,12 @@ class TestDeleteProvider:
tenant_id = str(instance.tenant_id) tenant_id = str(instance.tenant_id)
create_attack_paths_scan(instance) create_attack_paths_scan(instance)
backend = MagicMock()
with ( with (
patch( patch(
"tasks.jobs.deletion.graph_database.drop_subgraph", "tasks.jobs.deletion.sink_module.get_backend_for_name",
return_value=backend,
), ),
patch( patch(
"tasks.jobs.deletion.graph_database.drop_database", "tasks.jobs.deletion.graph_database.drop_database",
Generated
+234 -9
View File
@@ -110,7 +110,7 @@ constraints = [
{ name = "blinker", specifier = "==1.9.0" }, { name = "blinker", specifier = "==1.9.0" },
{ name = "boto3", specifier = "==1.40.61" }, { name = "boto3", specifier = "==1.40.61" },
{ name = "botocore", specifier = "==1.40.61" }, { name = "botocore", specifier = "==1.40.61" },
{ name = "cartography", specifier = "==0.135.0" }, { name = "cartography", specifier = "==0.138.1" },
{ name = "celery", specifier = "==5.6.2" }, { name = "celery", specifier = "==5.6.2" },
{ name = "certifi", specifier = "==2026.1.4" }, { name = "certifi", specifier = "==2026.1.4" },
{ name = "cffi", specifier = "==2.0.0" }, { name = "cffi", specifier = "==2.0.0" },
@@ -364,7 +364,7 @@ constraints = [
{ name = "wcwidth", specifier = "==0.5.3" }, { name = "wcwidth", specifier = "==0.5.3" },
{ name = "websocket-client", specifier = "==1.9.0" }, { name = "websocket-client", specifier = "==1.9.0" },
{ name = "werkzeug", specifier = "==3.1.7" }, { name = "werkzeug", specifier = "==3.1.7" },
{ name = "workos", specifier = "==6.0.4" }, { name = "workos", specifier = "==6.0.8" },
{ name = "wrapt", specifier = "==1.17.3" }, { name = "wrapt", specifier = "==1.17.3" },
{ name = "xlsxwriter", specifier = "==3.2.9" }, { name = "xlsxwriter", specifier = "==3.2.9" },
{ name = "xmlsec", specifier = "==1.3.17" }, { name = "xmlsec", specifier = "==1.3.17" },
@@ -376,6 +376,7 @@ constraints = [
{ name = "zstd", specifier = "==1.5.7.3" }, { name = "zstd", specifier = "==1.5.7.3" },
] ]
overrides = [ overrides = [
{ name = "azure-mgmt-containerservice", specifier = "==34.1.0" },
{ name = "dulwich", specifier = "==1.2.5" }, { name = "dulwich", specifier = "==1.2.5" },
{ name = "microsoft-kiota-abstractions", specifier = "==1.9.9" }, { name = "microsoft-kiota-abstractions", specifier = "==1.9.9" },
{ name = "okta", specifier = "==3.4.2" }, { name = "okta", specifier = "==3.4.2" },
@@ -1407,6 +1408,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/3d/66/0d8ae9ca4d75e57746026a1f9a10a7e25029511c128cf20166fce516bda9/azure_mgmt_logic-10.0.0-py3-none-any.whl", hash = "sha256:525c78afedf3edb35eb0a16152c8beba89769ee1bc6af01bcdc42842a551e443", size = 235433, upload-time = "2022-06-13T01:38:27.333Z" }, { url = "https://files.pythonhosted.org/packages/3d/66/0d8ae9ca4d75e57746026a1f9a10a7e25029511c128cf20166fce516bda9/azure_mgmt_logic-10.0.0-py3-none-any.whl", hash = "sha256:525c78afedf3edb35eb0a16152c8beba89769ee1bc6af01bcdc42842a551e443", size = 235433, upload-time = "2022-06-13T01:38:27.333Z" },
] ]
[[package]]
name = "azure-mgmt-managementgroups"
version = "1.1.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "azure-mgmt-core" },
{ name = "isodate" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/fd/73/ac5e064ed7343e1b3172f32f09be3efca906087218d3046b5038f2f394ed/azure_mgmt_managementgroups-1.1.0.tar.gz", hash = "sha256:e6199baf118890ba2bda35dda83a88861c0b1bbef126311b20ec12eed9681951", size = 60101, upload-time = "2026-02-13T03:45:45.439Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/92/bc/993158de03cc0a49f2cf8192615ffedbc508c417cb3522e88f6652b714cc/azure_mgmt_managementgroups-1.1.0-py3-none-any.whl", hash = "sha256:140934589559ef6afcac6f1d24f995588a1965aaa89d47851c1cc639fafb1942", size = 83586, upload-time = "2026-02-13T03:45:46.836Z" },
]
[[package]] [[package]]
name = "azure-mgmt-monitor" name = "azure-mgmt-monitor"
version = "6.0.2" version = "6.0.2"
@@ -1726,7 +1741,7 @@ wheels = [
[[package]] [[package]]
name = "cartography" name = "cartography"
version = "0.135.0" version = "0.138.1"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "adal" }, { name = "adal" },
@@ -1746,6 +1761,7 @@ dependencies = [
{ name = "azure-mgmt-eventhub" }, { name = "azure-mgmt-eventhub" },
{ name = "azure-mgmt-keyvault" }, { name = "azure-mgmt-keyvault" },
{ name = "azure-mgmt-logic" }, { name = "azure-mgmt-logic" },
{ name = "azure-mgmt-managementgroups" },
{ name = "azure-mgmt-monitor" }, { name = "azure-mgmt-monitor" },
{ name = "azure-mgmt-network" }, { name = "azure-mgmt-network" },
{ name = "azure-mgmt-resource" }, { name = "azure-mgmt-resource" },
@@ -1754,6 +1770,7 @@ dependencies = [
{ name = "azure-mgmt-storage" }, { name = "azure-mgmt-storage" },
{ name = "azure-mgmt-synapse" }, { name = "azure-mgmt-synapse" },
{ name = "azure-mgmt-web" }, { name = "azure-mgmt-web" },
{ name = "azure-storage-blob" },
{ name = "azure-synapse-artifacts" }, { name = "azure-synapse-artifacts" },
{ name = "backoff" }, { name = "backoff" },
{ name = "boto3" }, { name = "boto3" },
@@ -1765,8 +1782,12 @@ dependencies = [
{ name = "duo-client" }, { name = "duo-client" },
{ name = "google-api-python-client" }, { name = "google-api-python-client" },
{ name = "google-auth" }, { name = "google-auth" },
{ name = "google-cloud-aiplatform" },
{ name = "google-cloud-artifact-registry" },
{ name = "google-cloud-asset" }, { name = "google-cloud-asset" },
{ name = "google-cloud-resource-manager" }, { name = "google-cloud-resource-manager" },
{ name = "google-cloud-run" },
{ name = "google-cloud-storage" },
{ name = "httpx" }, { name = "httpx" },
{ name = "kubernetes" }, { name = "kubernetes" },
{ name = "marshmallow" }, { name = "marshmallow" },
@@ -1792,9 +1813,9 @@ dependencies = [
{ name = "workos" }, { name = "workos" },
{ name = "xmltodict" }, { name = "xmltodict" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/39/47/606851d2403a983b63813b9e95427a5dd896e49bc5a501868c041262e9a5/cartography-0.135.0.tar.gz", hash = "sha256:3f500cd22c3b392d00e8b49f62acc95fd4dcd559ce514aafe2eb8101133c7a49", size = 9106458, upload-time = "2026-04-10T16:25:34.898Z" } sdist = { url = "https://files.pythonhosted.org/packages/51/cd/0eb6a5a3c89cc179801d902ade9719af1a583c516c00f50d72b8207db1eb/cartography-0.138.1.tar.gz", hash = "sha256:356e946a0bcac899cba293d57803c71bd35fdeabe623f5f67d9405d7a643af9f", size = 9756966, upload-time = "2026-06-19T22:11:32.411Z" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/b1/e1/99a26b3e662202be77961aba73338e1448623490710b81783e53a4bbef15/cartography-0.135.0-py3-none-any.whl", hash = "sha256:c62c32a6917b8f23a8b98fe2b6c7c4a918b50f55918482966c4dae1cf5f538e1", size = 1590545, upload-time = "2026-04-10T16:25:37.669Z" }, { url = "https://files.pythonhosted.org/packages/a8/15/4447ec968825b2a19cba26ecb74964208aa3f941d9181a7782572e30b43d/cartography-0.138.1-py3-none-any.whl", hash = "sha256:88ec0898ea1a1b3f4653be9a3e7e61144f5cee20384b9040e92039617d39f029", size = 2014725, upload-time = "2026-06-19T22:11:29.886Z" },
] ]
[[package]] [[package]]
@@ -2511,6 +2532,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e3/26/57c6fb270950d476074c087527a558ccb6f4436657314bfb6cdf484114c4/docker-7.1.0-py3-none-any.whl", hash = "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0", size = 147774, upload-time = "2024-05-23T11:13:55.01Z" }, { url = "https://files.pythonhosted.org/packages/e3/26/57c6fb270950d476074c087527a558ccb6f4436657314bfb6cdf484114c4/docker-7.1.0-py3-none-any.whl", hash = "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0", size = 147774, upload-time = "2024-05-23T11:13:55.01Z" },
] ]
[[package]]
name = "docstring-parser"
version = "0.18.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/e0/4d/f332313098c1de1b2d2ff91cf2674415cc7cddab2ca1b01ae29774bd5fdf/docstring_parser-0.18.0.tar.gz", hash = "sha256:292510982205c12b1248696f44959db3cdd1740237a968ea1e2e7a900eeb2015", size = 29341, upload-time = "2026-04-14T04:09:19.867Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/a7/5f/ed01f9a3cdffbd5a008556fc7b2a08ddb1cc6ace7effa7340604b1d16699/docstring_parser-0.18.0-py3-none-any.whl", hash = "sha256:b3fcbed555c47d8479be0796ef7e19c2670d428d72e96da63f3a40122860374b", size = 22484, upload-time = "2026-04-14T04:09:18.638Z" },
]
[[package]] [[package]]
name = "dogpile-cache" name = "dogpile-cache"
version = "1.5.0" version = "1.5.0"
@@ -2851,6 +2881,11 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/83/1d/d6466de3a5249d35e832a52834115ca9d1d0de6abc22065f049707516d47/google_auth-2.48.0-py3-none-any.whl", hash = "sha256:2e2a537873d449434252a9632c28bfc268b0adb1e53f9fb62afc5333a975903f", size = 236499, upload-time = "2026-01-26T19:22:45.099Z" }, { url = "https://files.pythonhosted.org/packages/83/1d/d6466de3a5249d35e832a52834115ca9d1d0de6abc22065f049707516d47/google_auth-2.48.0-py3-none-any.whl", hash = "sha256:2e2a537873d449434252a9632c28bfc268b0adb1e53f9fb62afc5333a975903f", size = 236499, upload-time = "2026-01-26T19:22:45.099Z" },
] ]
[package.optional-dependencies]
requests = [
{ name = "requests" },
]
[[package]] [[package]]
name = "google-auth-httplib2" name = "google-auth-httplib2"
version = "0.2.0" version = "0.2.0"
@@ -2877,6 +2912,46 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/ca/94/24b010493660dd55e2d9769ae7ef44164aebd7e1f4a9266cf9459affd687/google_cloud_access_context_manager-0.3.0-py3-none-any.whl", hash = "sha256:5d15ad51547f06c281e35f16b4ffcb3e98bb2d898b01470f88b94edfb2eeb0a3", size = 58852, upload-time = "2025-10-17T02:30:33.768Z" }, { url = "https://files.pythonhosted.org/packages/ca/94/24b010493660dd55e2d9769ae7ef44164aebd7e1f4a9266cf9459affd687/google_cloud_access_context_manager-0.3.0-py3-none-any.whl", hash = "sha256:5d15ad51547f06c281e35f16b4ffcb3e98bb2d898b01470f88b94edfb2eeb0a3", size = 58852, upload-time = "2025-10-17T02:30:33.768Z" },
] ]
[[package]]
name = "google-cloud-aiplatform"
version = "1.153.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "docstring-parser" },
{ name = "google-api-core", extra = ["grpc"] },
{ name = "google-auth" },
{ name = "google-cloud-bigquery" },
{ name = "google-cloud-resource-manager" },
{ name = "google-cloud-storage" },
{ name = "google-genai" },
{ name = "packaging" },
{ name = "proto-plus" },
{ name = "protobuf" },
{ name = "pydantic" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/d5/97/1779e66ab845550bc602364311ea093ba156cb805a1c31b7c4d6f25b5863/google_cloud_aiplatform-1.153.1.tar.gz", hash = "sha256:445b6c683d5c630f174d81ae1f69f7da9e27e4d4ec5b70c5fe96de5c1247cfbc", size = 11011349, upload-time = "2026-05-15T06:34:14.851Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/16/01/8a1900e7a742ed480e6037ac4f6541466cb981d81bd4cbd34a9d46204ea1/google_cloud_aiplatform-1.153.1-py2.py3-none-any.whl", hash = "sha256:033fa1595a7e8ed1d97066e261e630f38fbc60e10c98c6487cf228fe9c7ec151", size = 9170782, upload-time = "2026-05-15T06:34:10.887Z" },
]
[[package]]
name = "google-cloud-artifact-registry"
version = "1.21.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "google-api-core", extra = ["grpc"] },
{ name = "google-auth" },
{ name = "grpc-google-iam-v1" },
{ name = "grpcio" },
{ name = "proto-plus" },
{ name = "protobuf" },
]
sdist = { url = "https://files.pythonhosted.org/packages/13/2b/24e6956789bc1244efb18143aa4f124e03d870228e5bfd065c04d38a4d6b/google_cloud_artifact_registry-1.21.0.tar.gz", hash = "sha256:546e51eb5d463a6e5c668be6727d14f8ec82bc798031398006b2213d703e184c", size = 315219, upload-time = "2026-03-30T22:50:38.875Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e1/8c/a5c68031728f38d3306bad5ac10c0ca670cbdf414db308ddefa2c47f2b34/google_cloud_artifact_registry-1.21.0-py3-none-any.whl", hash = "sha256:a07079035438fd0f2e7264d4318b388650495f011db575405c18c9881449025c", size = 250544, upload-time = "2026-03-30T22:48:49.345Z" },
]
[[package]] [[package]]
name = "google-cloud-asset" name = "google-cloud-asset"
version = "4.2.0" version = "4.2.0"
@@ -2897,6 +2972,37 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/05/88/9a43fae1d2fed94d7f5f46b6f4c44bd15e5ea0e8657632108b5ec5f53d9d/google_cloud_asset-4.2.0-py3-none-any.whl", hash = "sha256:fd7ea04c64948a4779790343204cd5b41d4772d6ab1d05a9125e28a637ac0862", size = 282707, upload-time = "2026-01-09T14:53:03.081Z" }, { url = "https://files.pythonhosted.org/packages/05/88/9a43fae1d2fed94d7f5f46b6f4c44bd15e5ea0e8657632108b5ec5f53d9d/google_cloud_asset-4.2.0-py3-none-any.whl", hash = "sha256:fd7ea04c64948a4779790343204cd5b41d4772d6ab1d05a9125e28a637ac0862", size = 282707, upload-time = "2026-01-09T14:53:03.081Z" },
] ]
[[package]]
name = "google-cloud-bigquery"
version = "3.41.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "google-api-core", extra = ["grpc"] },
{ name = "google-auth" },
{ name = "google-cloud-core" },
{ name = "google-resumable-media" },
{ name = "packaging" },
{ name = "python-dateutil" },
{ name = "requests" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ce/13/6515c7aab55a4a0cf708ffd309fb9af5bab54c13e32dc22c5acd6497193c/google_cloud_bigquery-3.41.0.tar.gz", hash = "sha256:2217e488b47ed576360c9b2cc07d59d883a54b83167c0ef37f915c26b01a06fe", size = 513434, upload-time = "2026-03-30T22:50:55.347Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/40/33/1d3902efadef9194566d499d61507e1f038454e0b55499d2d7f8ab2a4fee/google_cloud_bigquery-3.41.0-py3-none-any.whl", hash = "sha256:2a5b5a737b401cbd824a6e5eac7554100b878668d908e6548836b5d8aaa4dcaa", size = 262343, upload-time = "2026-03-30T22:48:45.444Z" },
]
[[package]]
name = "google-cloud-core"
version = "2.6.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "google-api-core" },
{ name = "google-auth" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a8/dd/1eef226e470369b26824a505c34482c0b493bc35fe8e0c6b003b5feca21a/google_cloud_core-2.6.0.tar.gz", hash = "sha256:e76149739f90fac1fc6757c09f47eaccb3145b54adbd7759b0f7c4b235f46c83", size = 36001, upload-time = "2026-05-07T08:04:04.124Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/84/4a/98da8930ab109c73d9a5d13782a9ebb81ea8c111f6d534a567b71d23e52b/google_cloud_core-2.6.0-py3-none-any.whl", hash = "sha256:6d63ac8e5eca6d9e4319d0a1e2265fadcd7f1049904378caecfa01cf52dd869e", size = 29390, upload-time = "2026-05-07T08:02:34.672Z" },
]
[[package]] [[package]]
name = "google-cloud-org-policy" name = "google-cloud-org-policy"
version = "1.16.0" version = "1.16.0"
@@ -2946,6 +3052,93 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/94/ff/4b28bcc791d9d7e4ac8fea00fbd90ccb236afda56746a3b4564d2ae45df3/google_cloud_resource_manager-1.16.0-py3-none-any.whl", hash = "sha256:fb9a2ad2b5053c508e1c407ac31abfd1a22e91c32876c1892830724195819a28", size = 400218, upload-time = "2026-01-15T13:02:47.378Z" }, { url = "https://files.pythonhosted.org/packages/94/ff/4b28bcc791d9d7e4ac8fea00fbd90ccb236afda56746a3b4564d2ae45df3/google_cloud_resource_manager-1.16.0-py3-none-any.whl", hash = "sha256:fb9a2ad2b5053c508e1c407ac31abfd1a22e91c32876c1892830724195819a28", size = 400218, upload-time = "2026-01-15T13:02:47.378Z" },
] ]
[[package]]
name = "google-cloud-run"
version = "0.16.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "google-api-core", extra = ["grpc"] },
{ name = "google-auth" },
{ name = "grpc-google-iam-v1" },
{ name = "grpcio" },
{ name = "proto-plus" },
{ name = "protobuf" },
]
sdist = { url = "https://files.pythonhosted.org/packages/b7/89/dcaf0dc97e39b41e446456ceb60657ab025de79cfccd39cbd739d1a9849e/google_cloud_run-0.16.0.tar.gz", hash = "sha256:d52cf4e6ad3702ae48caccf6abcab543afee6f61c2a6ec753cc62a31e5b629f1", size = 514452, upload-time = "2026-03-26T22:17:05.589Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/fa/c7/46153dc13713b5e4276d86f28ff4563332f9e4bae5ebc83abc5bfd994801/google_cloud_run-0.16.0-py3-none-any.whl", hash = "sha256:d7d2dd7307130fde2a0ce27e96d580dd23b7b2d973b6484b94d902e6b2618860", size = 459112, upload-time = "2026-03-26T22:16:00.018Z" },
]
[[package]]
name = "google-cloud-storage"
version = "3.10.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "google-api-core" },
{ name = "google-auth" },
{ name = "google-cloud-core" },
{ name = "google-crc32c" },
{ name = "google-resumable-media" },
{ name = "requests" },
]
sdist = { url = "https://files.pythonhosted.org/packages/4c/47/205eb8e9a1739b5345843e5a425775cbdc472cc38e7eda082ba5b8d02450/google_cloud_storage-3.10.1.tar.gz", hash = "sha256:97db9aa4460727982040edd2bd13ff3d5e2260b5331ad22895802da1fc2a5286", size = 17309950, upload-time = "2026-03-23T09:35:23.409Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ad/ff/ca9ab2417fa913d75aae38bf40bf856bb2749a604b2e0f701b37cfcd23cc/google_cloud_storage-3.10.1-py3-none-any.whl", hash = "sha256:a72f656759b7b99bda700f901adcb3425a828d4a29f911bc26b3ea79c5b1217f", size = 324453, upload-time = "2026-03-23T09:35:21.368Z" },
]
[[package]]
name = "google-crc32c"
version = "1.8.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/03/41/4b9c02f99e4c5fb477122cd5437403b552873f014616ac1d19ac8221a58d/google_crc32c-1.8.0.tar.gz", hash = "sha256:a428e25fb7691024de47fecfbff7ff957214da51eddded0da0ae0e0f03a2cf79", size = 14192, upload-time = "2025-12-16T00:35:25.142Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/5d/ef/21ccfaab3d5078d41efe8612e0ed0bfc9ce22475de074162a91a25f7980d/google_crc32c-1.8.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:014a7e68d623e9a4222d663931febc3033c5c7c9730785727de2a81f87d5bab8", size = 31298, upload-time = "2025-12-16T00:20:32.241Z" },
{ url = "https://files.pythonhosted.org/packages/c5/b8/f8413d3f4b676136e965e764ceedec904fe38ae8de0cdc52a12d8eb1096e/google_crc32c-1.8.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:86cfc00fe45a0ac7359e5214a1704e51a99e757d0272554874f419f79838c5f7", size = 30872, upload-time = "2025-12-16T00:33:58.785Z" },
{ url = "https://files.pythonhosted.org/packages/f6/fd/33aa4ec62b290477181c55bb1c9302c9698c58c0ce9a6ab4874abc8b0d60/google_crc32c-1.8.0-cp311-cp311-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:19b40d637a54cb71e0829179f6cb41835f0fbd9e8eb60552152a8b52c36cbe15", size = 33243, upload-time = "2025-12-16T00:40:21.46Z" },
{ url = "https://files.pythonhosted.org/packages/71/03/4820b3bd99c9653d1a5210cb32f9ba4da9681619b4d35b6a052432df4773/google_crc32c-1.8.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:17446feb05abddc187e5441a45971b8394ea4c1b6efd88ab0af393fd9e0a156a", size = 33608, upload-time = "2025-12-16T00:40:22.204Z" },
{ url = "https://files.pythonhosted.org/packages/7c/43/acf61476a11437bf9733fb2f70599b1ced11ec7ed9ea760fdd9a77d0c619/google_crc32c-1.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:71734788a88f551fbd6a97be9668a0020698e07b2bf5b3aa26a36c10cdfb27b2", size = 34439, upload-time = "2025-12-16T00:35:20.458Z" },
{ url = "https://files.pythonhosted.org/packages/e9/5f/7307325b1198b59324c0fa9807cafb551afb65e831699f2ce211ad5c8240/google_crc32c-1.8.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:4b8286b659c1335172e39563ab0a768b8015e88e08329fa5321f774275fc3113", size = 31300, upload-time = "2025-12-16T00:21:56.723Z" },
{ url = "https://files.pythonhosted.org/packages/21/8e/58c0d5d86e2220e6a37befe7e6a94dd2f6006044b1a33edf1ff6d9f7e319/google_crc32c-1.8.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:2a3dc3318507de089c5384cc74d54318401410f82aa65b2d9cdde9d297aca7cb", size = 30867, upload-time = "2025-12-16T00:38:31.302Z" },
{ url = "https://files.pythonhosted.org/packages/ce/a9/a780cc66f86335a6019f557a8aaca8fbb970728f0efd2430d15ff1beae0e/google_crc32c-1.8.0-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:14f87e04d613dfa218d6135e81b78272c3b904e2a7053b841481b38a7d901411", size = 33364, upload-time = "2025-12-16T00:40:22.96Z" },
{ url = "https://files.pythonhosted.org/packages/21/3f/3457ea803db0198c9aaca2dd373750972ce28a26f00544b6b85088811939/google_crc32c-1.8.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cb5c869c2923d56cb0c8e6bcdd73c009c36ae39b652dbe46a05eb4ef0ad01454", size = 33740, upload-time = "2025-12-16T00:40:23.96Z" },
{ url = "https://files.pythonhosted.org/packages/df/c0/87c2073e0c72515bb8733d4eef7b21548e8d189f094b5dad20b0ecaf64f6/google_crc32c-1.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:3cc0c8912038065eafa603b238abf252e204accab2a704c63b9e14837a854962", size = 34437, upload-time = "2025-12-16T00:35:21.395Z" },
{ url = "https://files.pythonhosted.org/packages/52/c5/c171e4d8c44fec1422d801a6d2e5d7ddabd733eeda505c79730ee9607f07/google_crc32c-1.8.0-pp311-pypy311_pp73-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:87fa445064e7db928226b2e6f0d5304ab4cd0339e664a4e9a25029f384d9bb93", size = 28615, upload-time = "2025-12-16T00:40:29.298Z" },
{ url = "https://files.pythonhosted.org/packages/9c/97/7d75fe37a7a6ed171a2cf17117177e7aab7e6e0d115858741b41e9dd4254/google_crc32c-1.8.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f639065ea2042d5c034bf258a9f085eaa7af0cd250667c0635a3118e8f92c69c", size = 28800, upload-time = "2025-12-16T00:40:30.322Z" },
]
[[package]]
name = "google-genai"
version = "1.68.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
{ name = "distro" },
{ name = "google-auth", extra = ["requests"] },
{ name = "httpx" },
{ name = "pydantic" },
{ name = "requests" },
{ name = "sniffio" },
{ name = "tenacity" },
{ name = "typing-extensions" },
{ name = "websockets" },
]
sdist = { url = "https://files.pythonhosted.org/packages/9c/2c/f059982dbcb658cc535c81bbcbe7e2c040d675f4b563b03cdb01018a4bc3/google_genai-1.68.0.tar.gz", hash = "sha256:ac30c0b8bc630f9372993a97e4a11dae0e36f2e10d7c55eacdca95a9fa14ca96", size = 511285, upload-time = "2026-03-18T01:03:18.243Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/84/de/7d3ee9c94b74c3578ea4f88d45e8de9405902f857932334d81e89bce3dfa/google_genai-1.68.0-py3-none-any.whl", hash = "sha256:a1bc9919c0e2ea2907d1e319b65471d3d6d58c54822039a249fe1323e4178d15", size = 750912, upload-time = "2026-03-18T01:03:15.983Z" },
]
[[package]]
name = "google-resumable-media"
version = "2.9.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "google-crc32c" },
]
sdist = { url = "https://files.pythonhosted.org/packages/00/4b/0b235beccc310d0a48adbc7246b719d173cca6c88c572dfa4b090e39143c/google_resumable_media-2.9.0.tar.gz", hash = "sha256:f7cfb224846a9dd444d125115dfbe8ef02a2b893e78f087762fe716a255a734b", size = 2164534, upload-time = "2026-05-07T08:04:44.236Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/07/73/3518e63deb1667c5409a4579e28daf5e84479a87a72c547e0487f7883dcd/google_resumable_media-2.9.0-py3-none-any.whl", hash = "sha256:c8901e88e389af8bed64d9696c74d8bad961865eb2236e13e0bfca9bb0a65ca3", size = 81507, upload-time = "2026-05-07T08:03:23.809Z" },
]
[[package]] [[package]]
name = "googleapis-common-protos" name = "googleapis-common-protos"
version = "1.72.0" version = "1.72.0"
@@ -4606,7 +4799,7 @@ dev = [
[package.metadata] [package.metadata]
requires-dist = [ requires-dist = [
{ name = "cartography", specifier = "==0.135.0" }, { name = "cartography", specifier = "==0.138.1" },
{ name = "celery", specifier = "==5.6.2" }, { name = "celery", specifier = "==5.6.2" },
{ name = "defusedxml", specifier = "==0.7.1" }, { name = "defusedxml", specifier = "==0.7.1" },
{ name = "dj-rest-auth", extras = ["with-social", "jwt"], specifier = "==7.0.1" }, { name = "dj-rest-auth", extras = ["with-social", "jwt"], specifier = "==7.0.1" },
@@ -5931,6 +6124,38 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/34/db/b10e48aa8fff7407e67470363eac595018441cf32d5e1001567a7aeba5d2/websocket_client-1.9.0-py3-none-any.whl", hash = "sha256:af248a825037ef591efbf6ed20cc5faa03d3b47b9e5a2230a529eeee1c1fc3ef", size = 82616, upload-time = "2025-10-07T21:16:34.951Z" }, { url = "https://files.pythonhosted.org/packages/34/db/b10e48aa8fff7407e67470363eac595018441cf32d5e1001567a7aeba5d2/websocket_client-1.9.0-py3-none-any.whl", hash = "sha256:af248a825037ef591efbf6ed20cc5faa03d3b47b9e5a2230a529eeee1c1fc3ef", size = 82616, upload-time = "2025-10-07T21:16:34.951Z" },
] ]
[[package]]
name = "websockets"
version = "16.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/04/24/4b2031d72e840ce4c1ccb255f693b15c334757fc50023e4db9537080b8c4/websockets-16.0.tar.gz", hash = "sha256:5f6261a5e56e8d5c42a4497b364ea24d94d9563e8fbd44e78ac40879c60179b5", size = 179346, upload-time = "2026-01-10T09:23:47.181Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f2/db/de907251b4ff46ae804ad0409809504153b3f30984daf82a1d84a9875830/websockets-16.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:31a52addea25187bde0797a97d6fc3d2f92b6f72a9370792d65a6e84615ac8a8", size = 177340, upload-time = "2026-01-10T09:22:34.539Z" },
{ url = "https://files.pythonhosted.org/packages/f3/fa/abe89019d8d8815c8781e90d697dec52523fb8ebe308bf11664e8de1877e/websockets-16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:417b28978cdccab24f46400586d128366313e8a96312e4b9362a4af504f3bbad", size = 175022, upload-time = "2026-01-10T09:22:36.332Z" },
{ url = "https://files.pythonhosted.org/packages/58/5d/88ea17ed1ded2079358b40d31d48abe90a73c9e5819dbcde1606e991e2ad/websockets-16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:af80d74d4edfa3cb9ed973a0a5ba2b2a549371f8a741e0800cb07becdd20f23d", size = 175319, upload-time = "2026-01-10T09:22:37.602Z" },
{ url = "https://files.pythonhosted.org/packages/d2/ae/0ee92b33087a33632f37a635e11e1d99d429d3d323329675a6022312aac2/websockets-16.0-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:08d7af67b64d29823fed316505a89b86705f2b7981c07848fb5e3ea3020c1abe", size = 184631, upload-time = "2026-01-10T09:22:38.789Z" },
{ url = "https://files.pythonhosted.org/packages/c8/c5/27178df583b6c5b31b29f526ba2da5e2f864ecc79c99dae630a85d68c304/websockets-16.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7be95cfb0a4dae143eaed2bcba8ac23f4892d8971311f1b06f3c6b78952ee70b", size = 185870, upload-time = "2026-01-10T09:22:39.893Z" },
{ url = "https://files.pythonhosted.org/packages/87/05/536652aa84ddc1c018dbb7e2c4cbcd0db884580bf8e95aece7593fde526f/websockets-16.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d6297ce39ce5c2e6feb13c1a996a2ded3b6832155fcfc920265c76f24c7cceb5", size = 185361, upload-time = "2026-01-10T09:22:41.016Z" },
{ url = "https://files.pythonhosted.org/packages/6d/e2/d5332c90da12b1e01f06fb1b85c50cfc489783076547415bf9f0a659ec19/websockets-16.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1c1b30e4f497b0b354057f3467f56244c603a79c0d1dafce1d16c283c25f6e64", size = 184615, upload-time = "2026-01-10T09:22:42.442Z" },
{ url = "https://files.pythonhosted.org/packages/77/fb/d3f9576691cae9253b51555f841bc6600bf0a983a461c79500ace5a5b364/websockets-16.0-cp311-cp311-win32.whl", hash = "sha256:5f451484aeb5cafee1ccf789b1b66f535409d038c56966d6101740c1614b86c6", size = 178246, upload-time = "2026-01-10T09:22:43.654Z" },
{ url = "https://files.pythonhosted.org/packages/54/67/eaff76b3dbaf18dcddabc3b8c1dba50b483761cccff67793897945b37408/websockets-16.0-cp311-cp311-win_amd64.whl", hash = "sha256:8d7f0659570eefb578dacde98e24fb60af35350193e4f56e11190787bee77dac", size = 178684, upload-time = "2026-01-10T09:22:44.941Z" },
{ url = "https://files.pythonhosted.org/packages/84/7b/bac442e6b96c9d25092695578dda82403c77936104b5682307bd4deb1ad4/websockets-16.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:71c989cbf3254fbd5e84d3bff31e4da39c43f884e64f2551d14bb3c186230f00", size = 177365, upload-time = "2026-01-10T09:22:46.787Z" },
{ url = "https://files.pythonhosted.org/packages/b0/fe/136ccece61bd690d9c1f715baaeefd953bb2360134de73519d5df19d29ca/websockets-16.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8b6e209ffee39ff1b6d0fa7bfef6de950c60dfb91b8fcead17da4ee539121a79", size = 175038, upload-time = "2026-01-10T09:22:47.999Z" },
{ url = "https://files.pythonhosted.org/packages/40/1e/9771421ac2286eaab95b8575b0cb701ae3663abf8b5e1f64f1fd90d0a673/websockets-16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:86890e837d61574c92a97496d590968b23c2ef0aeb8a9bc9421d174cd378ae39", size = 175328, upload-time = "2026-01-10T09:22:49.809Z" },
{ url = "https://files.pythonhosted.org/packages/18/29/71729b4671f21e1eaa5d6573031ab810ad2936c8175f03f97f3ff164c802/websockets-16.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:9b5aca38b67492ef518a8ab76851862488a478602229112c4b0d58d63a7a4d5c", size = 184915, upload-time = "2026-01-10T09:22:51.071Z" },
{ url = "https://files.pythonhosted.org/packages/97/bb/21c36b7dbbafc85d2d480cd65df02a1dc93bf76d97147605a8e27ff9409d/websockets-16.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e0334872c0a37b606418ac52f6ab9cfd17317ac26365f7f65e203e2d0d0d359f", size = 186152, upload-time = "2026-01-10T09:22:52.224Z" },
{ url = "https://files.pythonhosted.org/packages/4a/34/9bf8df0c0cf88fa7bfe36678dc7b02970c9a7d5e065a3099292db87b1be2/websockets-16.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a0b31e0b424cc6b5a04b8838bbaec1688834b2383256688cf47eb97412531da1", size = 185583, upload-time = "2026-01-10T09:22:53.443Z" },
{ url = "https://files.pythonhosted.org/packages/47/88/4dd516068e1a3d6ab3c7c183288404cd424a9a02d585efbac226cb61ff2d/websockets-16.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:485c49116d0af10ac698623c513c1cc01c9446c058a4e61e3bf6c19dff7335a2", size = 184880, upload-time = "2026-01-10T09:22:55.033Z" },
{ url = "https://files.pythonhosted.org/packages/91/d6/7d4553ad4bf1c0421e1ebd4b18de5d9098383b5caa1d937b63df8d04b565/websockets-16.0-cp312-cp312-win32.whl", hash = "sha256:eaded469f5e5b7294e2bdca0ab06becb6756ea86894a47806456089298813c89", size = 178261, upload-time = "2026-01-10T09:22:56.251Z" },
{ url = "https://files.pythonhosted.org/packages/c3/f0/f3a17365441ed1c27f850a80b2bc680a0fa9505d733fe152fdf5e98c1c0b/websockets-16.0-cp312-cp312-win_amd64.whl", hash = "sha256:5569417dc80977fc8c2d43a86f78e0a5a22fee17565d78621b6bb264a115d4ea", size = 178693, upload-time = "2026-01-10T09:22:57.478Z" },
{ url = "https://files.pythonhosted.org/packages/72/07/c98a68571dcf256e74f1f816b8cc5eae6eb2d3d5cfa44d37f801619d9166/websockets-16.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:349f83cd6c9a415428ee1005cadb5c2c56f4389bc06a9af16103c3bc3dcc8b7d", size = 174947, upload-time = "2026-01-10T09:23:36.166Z" },
{ url = "https://files.pythonhosted.org/packages/7e/52/93e166a81e0305b33fe416338be92ae863563fe7bce446b0f687b9df5aea/websockets-16.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:4a1aba3340a8dca8db6eb5a7986157f52eb9e436b74813764241981ca4888f03", size = 175260, upload-time = "2026-01-10T09:23:37.409Z" },
{ url = "https://files.pythonhosted.org/packages/56/0c/2dbf513bafd24889d33de2ff0368190a0e69f37bcfa19009ef819fe4d507/websockets-16.0-pp311-pypy311_pp73-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f4a32d1bd841d4bcbffdcb3d2ce50c09c3909fbead375ab28d0181af89fd04da", size = 176071, upload-time = "2026-01-10T09:23:39.158Z" },
{ url = "https://files.pythonhosted.org/packages/a5/8f/aea9c71cc92bf9b6cc0f7f70df8f0b420636b6c96ef4feee1e16f80f75dd/websockets-16.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0298d07ee155e2e9fda5be8a9042200dd2e3bb0b8a38482156576f863a9d457c", size = 176968, upload-time = "2026-01-10T09:23:41.031Z" },
{ url = "https://files.pythonhosted.org/packages/9a/3f/f70e03f40ffc9a30d817eef7da1be72ee4956ba8d7255c399a01b135902a/websockets-16.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:a653aea902e0324b52f1613332ddf50b00c06fdaf7e92624fbf8c77c78fa5767", size = 178735, upload-time = "2026-01-10T09:23:42.259Z" },
{ url = "https://files.pythonhosted.org/packages/6f/28/258ebab549c2bf3e64d2b0217b973467394a9cea8c42f70418ca2c5d0d2e/websockets-16.0-py3-none-any.whl", hash = "sha256:1637db62fad1dc833276dded54215f2c7fa46912301a24bd94d45d46a011ceec", size = 171598, upload-time = "2026-01-10T09:23:45.395Z" },
]
[[package]] [[package]]
name = "werkzeug" name = "werkzeug"
version = "3.1.7" version = "3.1.7"
@@ -5945,16 +6170,16 @@ wheels = [
[[package]] [[package]]
name = "workos" name = "workos"
version = "6.0.4" version = "6.0.8"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "cryptography" }, { name = "cryptography" },
{ name = "httpx" }, { name = "httpx" },
{ name = "pyjwt", extra = ["crypto"] }, { name = "pyjwt", extra = ["crypto"] },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/3c/2f/99fb8718274116c5c146c745755620fd5c5943f78ca52ca9b17e94348286/workos-6.0.4.tar.gz", hash = "sha256:b0bfe8fd212b8567422c4ea3732eb33608794033eb3a69900c6b04db183c32d6", size = 172217, upload-time = "2026-04-16T03:09:28.583Z" } sdist = { url = "https://files.pythonhosted.org/packages/ca/0d/0a7f78912657f99412c788932ea1f3f4089916e77bdef7d2463842febe08/workos-6.0.8.tar.gz", hash = "sha256:43aa3f1992a0a4ca8933d9b6e5ada846dd3b1fe0ee10e64c876ee2000fc6090d", size = 178137, upload-time = "2026-04-24T18:48:03.203Z" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/fa/f1/d2ab661e6dc2828a4c73e38f12630c3b109cfe2bc664ab70631c04f0db4b/workos-6.0.4-py3-none-any.whl", hash = "sha256:548668b3702673536f853ba72a7b5bbbc269e467aaf9ac4f477b6e0177df5e21", size = 511418, upload-time = "2026-04-16T03:09:27.098Z" }, { url = "https://files.pythonhosted.org/packages/b2/3f/3d96da80d650b2f97d58af626053354584f619dbb769051e118bd9cd1ca5/workos-6.0.8-py3-none-any.whl", hash = "sha256:a00dd4930333aded2babbba824f8032eea05c5ca8c44d04a3fa068cf6be6e21a", size = 524505, upload-time = "2026-04-24T18:48:01.389Z" },
] ]
[[package]] [[package]]
@@ -3,13 +3,13 @@ title: "Attack Paths"
description: "Identify privilege escalation chains and security misconfigurations across cloud environments using graph-based analysis." description: "Identify privilege escalation chains and security misconfigurations across cloud environments using graph-based analysis."
--- ---
import { VersionBadge } from "/snippets/version-badge.mdx" import { VersionBadge } from "/snippets/version-badge.mdx";
<VersionBadge version="5.17.0" /> <VersionBadge version="5.17.0" />
Attack Paths analyzes relationships between cloud resources, permissions, and security findings to detect how privileges can be escalated and how misconfigurations can be exploited by threat actors. Attack Paths analyzes relationships between cloud resources, permissions, and security findings to detect how privileges can be escalated and how misconfigurations can be exploited by threat actors.
By mapping these relationships as a graph, Attack Paths reveals risks that individual security checks cannot detect on their own such as an IAM role that can escalate its own permissions, or a chain of policies that grants unintended access to sensitive resources. By mapping these relationships as a graph, Attack Paths reveals risks that individual security checks cannot detect on their own, such as an IAM role that can escalate its own permissions, or a chain of policies that grants unintended access to sensitive resources.
<Note> <Note>
Attack Paths is currently available for **AWS** providers. Support for Attack Paths is currently available for **AWS** providers. Support for
@@ -21,7 +21,7 @@ By mapping these relationships as a graph, Attack Paths reveals risks that indiv
The following prerequisites are required for Attack Paths: The following prerequisites are required for Attack Paths:
- **An AWS provider is configured** with valid credentials in Prowler App. For setup instructions, see [Getting Started with AWS](/user-guide/providers/aws/getting-started-aws). - **An AWS provider is configured** with valid credentials in Prowler App. For setup instructions, see [Getting Started with AWS](/user-guide/providers/aws/getting-started-aws).
- **At least one scan has completed** on the configured AWS provider. Attack Paths scans run automatically alongside regular security scans no separate configuration is required. - **At least one scan has completed** on the configured AWS provider. Attack Paths scans run automatically alongside regular security scans, no separate configuration is required.
## How Attack Paths Scans Work ## How Attack Paths Scans Work
@@ -145,11 +145,10 @@ LIMIT 25
**IAM principals with wildcard Allow statements:** **IAM principals with wildcard Allow statements:**
```cypher ```cypher
MATCH (principal:AWSPrincipal)--(policy:AWSPolicy)--(stmt:AWSPolicyStatement) MATCH (principal:AWSPrincipal)-[:POLICY]->(policy:AWSPolicy)-[:STATEMENT]->(stmt:AWSPolicyStatement {effect: 'Allow'})
WHERE stmt.effect = 'Allow' MATCH (stmt)-[:HAS_ACTION]->(a:AWSPolicyStatementActionItem)
AND ANY(action IN stmt.action WHERE action = '*') WHERE a.value = '*'
RETURN principal.arn AS principal, policy.arn AS policy, RETURN DISTINCT principal.arn AS principal, policy.arn AS policy
stmt.action AS actions, stmt.resource AS resources
LIMIT 25 LIMIT 25
``` ```
@@ -173,218 +172,89 @@ RETURN r.name AS role_name, r.arn AS role_arn, p.arn AS trusted_service
LIMIT 25 LIMIT 25
``` ```
### Advanced Attack Path Scenarios ### Working with List-Typed Properties
The following scenarios show how to compose graph traversals into real attack-path stories. Each query can be pasted directly into the custom query box: the API auto-scopes them to the selected provider and injects tenant/provider isolation, so there is no need to include account identifiers or `$provider_uid` in the text. All queries are openCypher v9 (Neo4j and Neptune compatible). Some Cartography node properties carry a list of values, such as `action`, `resource`, `notaction`, and `notresource` on `AWSPolicyStatement` nodes, the algorithms on `KMSKey`, the container-definition lists on `ECSContainerDefinition`, and many others. The Attack Paths graph models each such property as a set of child item nodes connected to the parent by a typed edge. To read the values, traverse the edge; the parent does not carry the list as a single field.
#### 1. Live attacker on the box that owns the keys The naming convention for any list-typed property on a parent label is:
**Query story:** Finds an internet-exposed EC2 under an active GuardDuty SSH brute-force whose instance role can assume a higher-privileged role that can read a sensitive S3 bucket. - **Child label:** `<ParentLabel><PropertyPascal>Item`. Example: `AWSPolicyStatement.resource` resolves to `AWSPolicyStatementResourceItem`.
- **Edge type:** `HAS_<PROPERTY_UPPER>`. Example: `resource` resolves to `HAS_RESOURCE`.
- **Child property:** `value` for scalar lists (one string per list element). List-of-dict properties (rare; for example `SecretsManagerSecretVersion.tags`) carry the original dict keys as named fields on the child node.
To express "at least one item in the list satisfies a predicate", traverse the `HAS_*` edge in its own `MATCH` clause and apply the predicate in the attached `WHERE`. `RETURN DISTINCT` collapses duplicate parent rows produced when multiple child items satisfy the filter:
```cypher ```cypher
MATCH path_ec2 = (acct:AWSAccount)--(ec2:EC2Instance) MATCH (stmt:AWSPolicyStatement {effect: 'Allow'})
WHERE ec2.exposed_internet = true MATCH (stmt)-[:HAS_ACTION]->(a:AWSPolicyStatementActionItem)
MATCH p0 = (gd:GuardDutyFinding)-[:AFFECTS]->(ec2) WHERE toLower(a.value) STARTS WITH 's3:get'
MATCH p1 = (ec2)-[:INSTANCE_PROFILE]->(prof:AWSInstanceProfile)-[:ASSOCIATED_WITH]->(low:AWSRole) OR toLower(a.value) STARTS WITH 's3:list'
MATCH p2 = (low)-[:STS_ASSUMEROLE_ALLOW]-(high:AWSRole) RETURN DISTINCT stmt
MATCH p3 = (high)--(pol:AWSPolicy)--(stmt:AWSPolicyStatement)
OPTIONAL MATCH path_net = (internet:Internet)-[:CAN_ACCESS]->(ec2)
MATCH path_s3 = (acct)--(s3:S3Bucket)
WHERE high <> low
AND stmt.effect = 'Allow'
AND size([a IN stmt.action WHERE
toLower(a) STARTS WITH 's3:getobject'
OR toLower(a) STARTS WITH 's3:listbucket'
OR toLower(a) IN ['s3:*']
]) > 0
AND size([r IN stmt.resource WHERE
r CONTAINS s3.name
]) > 0
RETURN path_net, path_ec2, p0, p1, p2, p3, path_s3
```
**How it's built:**
- `path_ec2` anchors the graph on the account node and its internet-exposed EC2 instance, via a real account-to-resource edge. This is the visible spine that keeps everything connected.
- `p0` ties a `GuardDutyFinding` to that instance through the `AFFECTS` edge (the live SSH brute-force alert).
- `p1` walks the real graph edges from the instance to its instance profile to the role it runs as.
- `p2` follows the `STS_ASSUMEROLE_ALLOW` edge to the higher-privileged role the low role can assume. It is undirected so it works regardless of how the assume edge was ingested. `high <> low` stops a role matching itself.
- `p3` walks that role into its policy and policy statement.
- `path_net` is the optional `Internet -[:CAN_ACCESS]-> instance` edge. It makes "from the internet" literal on screen. Optional so a missing `Internet` node never breaks the query live.
- `path_s3` connects the sensitive bucket to the same account node, so it draws connected instead of floating. There is no physical edge from a role to a bucket; the grant is logical, enforced in the `WHERE`: the statement must allow an S3 read action (list comprehension over the `action` array) and its resource must cover the bucket (`CONTAINS s3.name`). The account is the shared hub; the bucket hanging off it next to the role chain is the teaching moment — the access exists only in IAM.
#### 2. Who can read the crown jewels
**Query story:** The sensitive bucket from the previous scenario seen from the data side: every role whose IAM policy can read it, regardless of how the role is reached.
```cypher
MATCH (s3:S3Bucket)
WHERE toLower(s3.name) CONTAINS 'sensitive'
MATCH (role:AWSRole)--(pol:AWSPolicy)--(stmt:AWSPolicyStatement)
WHERE stmt.effect = 'Allow'
AND size([a IN stmt.action WHERE
toLower(a) STARTS WITH 's3:get'
OR toLower(a) STARTS WITH 's3:list'
OR toLower(a) IN ['s3:*']
]) > 0
AND size([r IN stmt.resource WHERE
r CONTAINS s3.name
]) > 0
WITH DISTINCT s3, role
LIMIT 25 LIMIT 25
MATCH path_s3 = (acct:AWSAccount)--(s3)
MATCH path_role = (acct)--(role)
RETURN path_s3, path_role
``` ```
**How it's built:** data-centric, not attacker-centric — the same bucket the previous kill chain exfiltrates, approached from the other direction. To check whether every item in the list satisfies a predicate, count the counter-examples and require zero, together with a guard that ensures at least one item is attached. This is the one case where the pattern-comprehension form is the right tool:
- The `S3Bucket` is bound first by name (one node), so everything else filters against it.
- `(role:AWSRole)--(pol:AWSPolicy)--(stmt:AWSPolicyStatement)` reaches statements only *through a role*, never via a global statement scan. A blanket `AWSPolicyStatement` scan also hits resource-policy statements whose shape differs and makes the list comprehension fail outright.
- The `WHERE` filters in place: an S3 read action plus a resource that names that bucket.
- `WITH DISTINCT s3, role LIMIT 25` collapses undirected-traversal duplicates and hard-caps the result.
- `path_s3` and `path_role` attach the account hubs only after the cap, against at most 25 rows, so the bucket and role(s) draw connected through the account instead of floating.
- No internet or EC2 here; this answers "who has the keys" instead of "how would an attacker get in."
#### 3. Lateral reach from an internet-exposed instance
**Query story:** The wide-angle view of the live-attacker scenario: every internet-exposed EC2, the role it runs as, and every role that role can assume. The first scenario is one specific exfiltration path inside this reach, under live attack.
```cypher ```cypher
MATCH path_ec2 = (acct:AWSAccount)--(ec2:EC2Instance) MATCH (stmt:AWSPolicyStatement)
WHERE ec2.exposed_internet = true WHERE size([
MATCH p1 = (ec2)-[:INSTANCE_PROFILE]->(prof:AWSInstanceProfile)-[:ASSOCIATED_WITH]->(low:AWSRole) (stmt)-[:HAS_ACTION]->(a:AWSPolicyStatementActionItem)
MATCH p2 = (low)-[:STS_ASSUMEROLE_ALLOW]-(high:AWSRole) WHERE NOT toLower(a.value) STARTS WITH 's3:'
OPTIONAL MATCH path_net = (internet:Internet)-[:CAN_ACCESS]->(ec2) | a
WHERE high <> low ]) = 0
RETURN path_net, path_ec2, p1, p2 AND size([(stmt)-[:HAS_ACTION]->(a:AWSPolicyStatementActionItem) | a]) > 0
RETURN stmt
LIMIT 25
``` ```
**How it's built:** widens the lens instead of filtering down. It stops at the assume-role hop and shows every role reachable from any internet-exposed instance, without filtering down to a specific S3 leg. For the "is any item of this list a substring of a dynamic value" case, such as "does any resource pattern in this policy match a target role ARN", add the `HAS_*` traversal as its own `MATCH` and check the substring relationship between the item value and the dynamic node in `WHERE`:
- `path_ec2` is the account-to-instance spine.
- `p1` walks to the instance role.
- `p2` fans out to every role that role can assume.
- `path_net` adds the optional `Internet -[:CAN_ACCESS]->` edge.
- The first scenario is the specific exfiltration path under live attack; this is the broader privilege reach an attacker inherits the moment they land on the box.
#### 4. Role-chain privilege escalation
**Query story:** A pure-IAM escalation, no compromised instance: a role that can assume a second role whose policy lets it assume a third, admin-level role.
```cypher ```cypher
MATCH path_root = (acct:AWSAccount)--(r1:AWSRole) MATCH (role:AWSRole)
MATCH p1 = (r1)-[:STS_ASSUMEROLE_ALLOW]-(r2:AWSRole) WHERE role.name = 'Admin'
MATCH p2 = (r2)--(pol:AWSPolicy)--(stmt:AWSPolicyStatement) MATCH (principal:AWSPrincipal)-[:POLICY]->(:AWSPolicy)-[:STATEMENT]->(stmt:AWSPolicyStatement {effect: 'Allow'})
MATCH path_admin = (acct)--(admin:AWSRole) MATCH (stmt)-[:HAS_RESOURCE]->(r:AWSPolicyStatementResourceItem)
WHERE r1 <> r2 AND r1 <> admin AND r2 <> admin WHERE r.value = '*'
AND stmt.effect = 'Allow' OR r.value CONTAINS role.name
AND size([a IN stmt.action WHERE OR role.arn CONTAINS r.value
toLower(a) IN ['sts:*', 'sts:assumerole'] RETURN DISTINCT principal.arn AS principal, stmt, role
]) > 0 LIMIT 25
AND size([res IN stmt.resource WHERE
res CONTAINS admin.name
]) > 0
RETURN path_root, p1, p2, path_admin
``` ```
**How it's built:** To return the list of values directly, collect them from the child items:
- `path_root` anchors role 1 to the account node, the spine that keeps the picture connected.
- `p1` is the one real assume edge in the chain (role 1 to role 2).
- `p2` walks role 2 into its policy and statement.
- `path_admin` connects the target admin role to the same account node so it draws connected. The third hop is not a graph edge: it exists only as `sts:AssumeRole` on that role's ARN inside the statement. The query proves it the same way the first scenario proves S3 access — the statement action must include an assume-role action and its resource list must reference the admin role's name.
- The three `<>` guards stop a role matching itself at any position.
#### 5. External identity trust map
**Query story:** Finds external identity providers (SSO, GitHub, GitLab, Terraform Cloud) and the AWS roles they are trusted to assume.
```cypher ```cypher
MATCH p = (role:AWSRole)-[:TRUSTS_AWS_PRINCIPAL]-(idp:AWSPrincipal) MATCH (stmt:AWSPolicyStatement {effect: 'Allow'})
WHERE idp.arn CONTAINS 'saml-provider' OPTIONAL MATCH (stmt)-[:HAS_ACTION]->(a:AWSPolicyStatementActionItem)
OR idp.arn CONTAINS 'oidc-provider' RETURN stmt, collect(a.value) AS actions
MATCH path_role = (acct:AWSAccount)--(role) LIMIT 25
RETURN p, path_role
``` ```
**How it's built:** federated principals are stored as `AWSPrincipal` nodes whose ARN contains `saml-provider` (SSO) or `oidc-provider` (GitHub, GitLab, Terraform Cloud). ### Working with JSON-Encoded Properties
- `p` matches the trust edge undirected. It is written `(AWSRole)-[:TRUSTS_AWS_PRINCIPAL]->(AWSPrincipal)`, role to principal, so a directed `principal -> role` match returns nothing; undirected matches regardless of ingest direction. Some Cartography properties represent nested objects, most notably `condition` on `AWSPolicyStatement` and `S3PolicyStatement` nodes. In the Attack Paths graph, object-typed properties are stored as JSON-encoded strings to keep the schema portable across graph backends. The value looks like:
- The `WHERE` keeps only SAML or OIDC providers, drawing a fan-out from each external identity provider to every role it can assume (including reserved SSO admin roles).
- `path_role` ties every trusted role to the account node so the provider stars share one spine instead of drawing as separate islands.
#### 6. Federated SSO roles flagged as admin or privesc ```
'{"StringEquals":{"aws:SourceAccount":"123456789012"}}'
```
**Query story:** The dangerous subset of the trust map above — externally-federated SSO roles that Prowler also flags for AdministratorAccess or privilege escalation. There is no JSON parser available at query time, so use `CONTAINS` for substring checks against keys or known values:
```cypher ```cypher
MATCH (idp:AWSPrincipal)-[:TRUSTS_AWS_PRINCIPAL]-(role:AWSRole) MATCH (stmt:AWSPolicyStatement)
WHERE idp.arn CONTAINS 'saml-provider' WHERE stmt.effect = 'Allow'
OR idp.arn CONTAINS 'oidc-provider' AND stmt.condition CONTAINS '"aws:SourceAccount"'
MATCH (role)-[:HAS_FINDING]-(pf:ProwlerFinding) RETURN stmt
WHERE pf.status = 'FAIL' LIMIT 25
AND pf.check_id IN [
'iam_inline_policy_allows_privilege_escalation',
'iam_role_administratoraccess_policy',
'iam_inline_policy_no_administrative_privileges',
'iam_user_administrator_access_policy'
]
WITH DISTINCT idp, role, pf
LIMIT 60
MATCH path_root = (acct:AWSAccount)--(role)
MATCH p_trust = (idp)-[:TRUSTS_AWS_PRINCIPAL]-(role)
MATCH p_find = (role)-[:HAS_FINDING]-(pf)
RETURN path_root, p_trust, p_find
``` ```
**How it's built:** a plain "list every flagged identity" query is a wide fan that draws as a column, and `ProwlerFinding` nodes accumulate across scans with no scan filter available in custom queries. When a query needs to inspect the structured members of a condition (for example, evaluate every operator and key), fetch the rows first and parse the JSON in application code. Cypher cannot navigate JSON object keys or values.
- The first MATCH plus `WHERE` keeps only roles trusted by a SAML or OIDC provider (trust edge undirected, so direction does not matter).
- The second MATCH plus `check_id IN [...]` keeps only those carrying one of the four privilege-escalation or admin checks.
- `WITH DISTINCT ... LIMIT 60` collapses duplicate finding nodes and hard-caps the result.
- `p_trust`, `p_find`, and `path_root` draw it connected three ways: provider to role through the trust edge, role to its finding, and role to the account.
- The previous scenario shows who can walk in; this shows which of those roles Prowler already flags as over-privileged.
#### 7. World-readable S3 buckets
**Query story:** Unlike the IAM-gated sensitive bucket in scenarios 1 and 2, these buckets are open to anyone on the internet with no credentials at all.
```cypher
MATCH path_s3 = (acct:AWSAccount)--(s3:S3Bucket)
WHERE s3.anonymous_access = true
OPTIONAL MATCH p = (s3)--(stmt:S3PolicyStatement)
RETURN path_s3, p
```
**How it's built:** the counterpoint to scenarios 1 and 2 — there the sensitive bucket is reachable only through an IAM role chain; here the bucket needs no role at all.
- `path_s3` connects each public bucket to its account node so they draw connected. Cartography sets `anonymous_access = true` when a bucket's policy or ACL allows public access.
- `p` is an optional match that pulls in the `S3PolicyStatement` granting the access where one exists, so the public grant is visible next to the bucket. Buckets that are public via ACL only still show, connected to the account.
#### 8. Internet exposure surface
**Query story:** The raw external attack surface behind scenarios 1 and 3: every internet-exposed EC2 instance with its security groups and the exact inbound ports left open.
```cypher
MATCH path_ec2 = (acct:AWSAccount)--(ec2:EC2Instance)
WHERE ec2.exposed_internet = true
MATCH p1 = (ec2)--(sg:EC2SecurityGroup)--(rule:IpPermissionInbound)
OPTIONAL MATCH path_net = (internet:Internet)-[:CAN_ACCESS]->(ec2)
OPTIONAL MATCH p2 = (ec2)-[:INSTANCE_PROFILE]->(:AWSInstanceProfile)-[:ASSOCIATED_WITH]->(:AWSRole)
RETURN path_net, path_ec2, p1, p2
```
**How it's built:** `exposed_internet = true` is Cartography's computed reachability flag.
- `path_ec2` hubs all exposed instances on the account node so they draw as one picture.
- `p1` joins each instance to its security groups and inbound rules so the open ports are on screen.
- `path_net` adds the optional `Internet -[:CAN_ACCESS]->` edge so the external reachability is explicit.
- `p2` optionally adds the instance role, which connects this surface view back to the kill chains in scenarios 1 and 3.
### Tips for Writing Queries ### Tips for Writing Queries
- Start small with `LIMIT` to inspect the shape of the data before broadening the pattern. - Start small with `LIMIT` to inspect the shape of the data before broadening the pattern.
- Traverse `HAS_*` edges to reach list-typed property values (for example `action`, `resource`). The parent node does not carry the list as a single field; see [Working with List-Typed Properties](#working-with-list-typed-properties) for the patterns.
- On large scans, avoid broad disconnected patterns such as `MATCH (a:Label), (b:OtherLabel)`. Bind one side with a selective predicate first, and use `WITH DISTINCT` between expanding traversals when duplicates are possible.
- Use `RETURN` projections (`RETURN n.name, n.region`) instead of returning whole nodes to keep responses compact. - Use `RETURN` projections (`RETURN n.name, n.region`) instead of returning whole nodes to keep responses compact.
- Combine resource nodes with `ProwlerFinding` nodes via `HAS_FINDING` to correlate misconfigurations with the affected resources. - Combine resource nodes with `ProwlerFinding` nodes via `HAS_FINDING` to correlate misconfigurations with the affected resources.
- When a query times out or returns no rows, simplify the pattern step by step until the first variant runs successfully, then add constraints back. - When a query times out or returns no rows, simplify the pattern step by step until the first variant runs successfully, then add constraints back.
@@ -401,6 +271,8 @@ In addition to the upstream schema, Prowler enriches the graph with:
- **`ProwlerFinding`** nodes representing Prowler check results, linked to affected resources via `HAS_FINDING` relationships. - **`ProwlerFinding`** nodes representing Prowler check results, linked to affected resources via `HAS_FINDING` relationships.
- **`Internet`** nodes used to model exposure paths from the public internet to internal resources. - **`Internet`** nodes used to model exposure paths from the public internet to internal resources.
- **List-typed properties** such as `action` or `resource` on `AWSPolicyStatement`, the algorithm lists on `KMSKey`, and similar lists on other node types are modeled as child item nodes linked by typed `HAS_*` edges. See [Working with List-Typed Properties](#working-with-list-typed-properties) for the read pattern.
- **Object-typed properties** such as `condition` on `AWSPolicyStatement` are stored as JSON-encoded strings. See [Working with JSON-Encoded Properties](#working-with-json-encoded-properties) for the read pattern.
<Note> <Note>
AI assistants connected through Prowler MCP Server can fetch the exact AI assistants connected through Prowler MCP Server can fetch the exact
@@ -540,13 +412,13 @@ Attack Paths currently supports the following built-in queries for AWS:
#### Custom Attack Path Queries #### Custom Attack Path Queries
| Query | Description | | Query | Description |
|---|---| | ------------------------------------------------- | ---------------------------------------------------------------------------------------- |
| **Internet-Exposed EC2 with Sensitive S3 Access** | Find SSH-exposed EC2 instances that can assume roles to read tagged sensitive S3 buckets | | **Internet-Exposed EC2 with Sensitive S3 Access** | Find SSH-exposed EC2 instances that can assume roles to read tagged sensitive S3 buckets |
#### Basic Resource Queries #### Basic Resource Queries
| Query | Description | | Query | Description |
|---|---| | ------------------------------------------- | ------------------------------------------------------------------- |
| **RDS Instances Inventory** | List all provisioned RDS database instances in the account | | **RDS Instances Inventory** | List all provisioned RDS database instances in the account |
| **Unencrypted RDS Instances** | Find RDS instances with storage encryption disabled | | **Unencrypted RDS Instances** | Find RDS instances with storage encryption disabled |
| **S3 Buckets with Anonymous Access** | Find S3 buckets that allow anonymous access | | **S3 Buckets with Anonymous Access** | Find S3 buckets that allow anonymous access |
@@ -557,7 +429,7 @@ Attack Paths currently supports the following built-in queries for AWS:
#### Network Exposure Queries #### Network Exposure Queries
| Query | Description | | Query | Description |
|---|---| | ----------------------------------------------------- | ----------------------------------------------------------------------------------- |
| **Internet-Exposed EC2 Instances** | Find EC2 instances flagged as exposed to the internet | | **Internet-Exposed EC2 Instances** | Find EC2 instances flagged as exposed to the internet |
| **Open Security Groups on Internet-Facing Resources** | Find internet-facing resources with security groups allowing inbound from 0.0.0.0/0 | | **Open Security Groups on Internet-Facing Resources** | Find internet-facing resources with security groups allowing inbound from 0.0.0.0/0 |
| **Internet-Exposed Classic Load Balancers** | Find Classic Load Balancers exposed to the internet with their listeners | | **Internet-Exposed Classic Load Balancers** | Find Classic Load Balancers exposed to the internet with their listeners |
@@ -569,7 +441,7 @@ Attack Paths currently supports the following built-in queries for AWS:
These queries are based on research from [pathfinding.cloud](https://pathfinding.cloud) by Datadog. These queries are based on research from [pathfinding.cloud](https://pathfinding.cloud) by Datadog.
| Query | Description | | Query | Description |
|---|---| | -------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| **App Runner Service Creation with Privileged Role (APPRUNNER-001)** | Create an App Runner service with a privileged IAM role to gain its permissions | | **App Runner Service Creation with Privileged Role (APPRUNNER-001)** | Create an App Runner service with a privileged IAM role to gain its permissions |
| **App Runner Service Update for Role Access (APPRUNNER-002)** | Update an existing App Runner service to leverage its already-attached privileged role | | **App Runner Service Update for Role Access (APPRUNNER-002)** | Update an existing App Runner service to leverage its already-attached privileged role |
| **Bedrock Code Interpreter with Privileged Role (BEDROCK-001)** | Create a Bedrock AgentCore Code Interpreter with a privileged role attached | | **Bedrock Code Interpreter with Privileged Role (BEDROCK-001)** | Create a Bedrock AgentCore Code Interpreter with a privileged role attached |
@@ -638,6 +510,7 @@ These queries are based on research from [pathfinding.cloud](https://pathfinding
| **Role Assumption for Privilege Escalation (STS-001)** | Assume IAM roles with elevated permissions by exploiting bidirectional trust between the starting principal and the target role | | **Role Assumption for Privilege Escalation (STS-001)** | Assume IAM roles with elevated permissions by exploiting bidirectional trust between the starting principal and the target role |
These tools enable workflows such as: These tools enable workflows such as:
- Asking an AI assistant to identify privilege escalation paths in a specific AWS account - Asking an AI assistant to identify privilege escalation paths in a specific AWS account
- Automating attack path analysis across multiple scans - Automating attack path analysis across multiple scans
- Combining attack path data with findings and compliance information for comprehensive security reports - Combining attack path data with findings and compliance information for comprehensive security reports
+238 -236
View File
@@ -2,13 +2,14 @@
name: prowler-attack-paths-query name: prowler-attack-paths-query
description: > description: >
Creates Prowler Attack Paths openCypher queries using the Cartography schema as the source of truth Creates Prowler Attack Paths openCypher queries using the Cartography schema as the source of truth
for node labels, properties, and relationships. Also covers Prowler-specific additions (Internet node, for node labels, properties, and relationships. Covers Prowler-specific additions (Internet node,
ProwlerFinding, internal isolation labels) and $provider_uid scoping for predefined queries. ProwlerFinding, internal isolation labels), $provider_uid scoping, and list-property item nodes
with typed `HAS_*` edges that run efficiently on both Neo4j and Amazon Neptune sinks.
Trigger: When creating or updating Attack Paths queries. Trigger: When creating or updating Attack Paths queries.
license: Apache-2.0 license: Apache-2.0
metadata: metadata:
author: prowler-cloud author: prowler-cloud
version: "2.0" version: "3.0"
scope: [root, api] scope: [root, api]
auto_invoke: auto_invoke:
- "Creating Attack Paths queries" - "Creating Attack Paths queries"
@@ -19,36 +20,30 @@ allowed-tools: Read, Edit, Write, Glob, Grep, Bash, WebFetch, Task
## Overview ## Overview
Attack Paths queries are openCypher queries that analyze cloud infrastructure graphs (ingested via Cartography) to detect security risks like privilege escalation paths, network exposure, and misconfigurations. Attack Paths queries are read-only openCypher queries over a Cartography-ingested cloud graph that detect privilege escalation chains, network exposure, and other graph-shaped security risks. Queries are written in openCypher Version 9 so they run on both Neo4j and Amazon Neptune sinks.
Queries are written in **openCypher Version 9** for compatibility with both Neo4j and Amazon Neptune.
--- ---
## Two query audiences ## Two query audiences
This skill covers two types of queries with different isolation mechanisms:
| | Predefined queries | Custom queries | | | Predefined queries | Custom queries |
|---|---|---| | ------------------ | ----------------------------------------------------------- | --------------------------------------------------------------------- |
| **Where they live** | `api/src/backend/api/attack_paths/queries/{provider}.py` | User/LLM-supplied via the custom query API endpoint | | Where they live | `api/src/backend/api/attack_paths/queries/{provider}.py` | User-supplied via the custom query API endpoint |
| **Provider isolation** | `AWSAccount {id: $provider_uid}` anchor + path connectivity | Automatic `_Provider_{uuid}` label injection via `cypher_sanitizer.py` | | Provider isolation | `AWSAccount {id: $provider_uid}` anchor + path connectivity | Automatic `_Provider_{uuid}` label injection by `cypher_sanitizer.py` |
| **What to write** | Chain every MATCH from the `aws` variable | Plain Cypher, no isolation boilerplate needed | | What to write | Chain every MATCH from the `aws` variable | Plain Cypher, no isolation boilerplate |
| **Internal labels** | Never use (`_ProviderResource`, `_Tenant_*`, `_Provider_*`) | Never use (injected automatically by the system) | | Internal labels | Never use | Never use (system-injected) |
**For predefined queries**: every node must be reachable from the `AWSAccount` root via graph traversal. This is the isolation boundary. **Predefined queries**: every node must be reachable from the `AWSAccount` root via graph traversal. That is the isolation boundary.
**For custom queries**: write natural Cypher without isolation concerns. The query runner injects a `_Provider_{uuid}` label into every node pattern before execution, and a post-query filter catches edge cases. **Custom queries**: write natural Cypher. The runner injects a `_Provider_{uuid}` label into every node pattern, and a post-query filter handles edge cases.
--- ---
## Input Sources ## Input sources
Queries can be created from: Two sources for new queries:
1. **pathfinding.cloud ID** (e.g., `ECS-001`, `GLUE-001`) 1. **pathfinding.cloud ID** (e.g. `ECS-001`, `GLUE-001`), the Datadog research catalogue. The aggregated `paths.json` is too large for WebFetch:
- Reference: https://github.com/DataDog/pathfinding.cloud
- The aggregated `paths.json` is too large for WebFetch. Use Bash:
```bash ```bash
# Fetch a single path by ID # Fetch a single path by ID
@@ -64,28 +59,24 @@ Queries can be created from:
| jq -r '.[] | select(.id | startswith("ecs")) | "\(.id): \(.name)"' | jq -r '.[] | select(.id | startswith("ecs")) | "\(.id): \(.name)"'
``` ```
If `jq` is not available, use `python3 -c "import json,sys; ..."` as a fallback. If `jq` is unavailable, use `python3 -c "import json,sys; ..."`.
2. **Natural language description** from the user 2. **Natural language description** from the requester.
--- ---
## Query Structure ## Query structure
### Provider scoping parameter ### Provider scoping parameter
One parameter is injected automatically by the query runner: | Parameter | Property | Used on | Purpose |
| --------------- | -------- | ------------ | -------------------------------------- |
| `$provider_uid` | `id` | `AWSAccount` | Scopes the query to a specific account |
| Parameter | Property it matches | Used on | Purpose | The runner binds `$provider_uid` automatically. Every other node is isolated by path connectivity from the `AWSAccount` anchor.
| --------------- | ------------------- | ------------ | -------------------------------- |
| `$provider_uid` | `id` | `AWSAccount` | Scopes to a specific AWS account |
All other nodes are isolated by path connectivity from the `AWSAccount` anchor.
### Imports ### Imports
All query files start with these imports:
```python ```python
from api.attack_paths.queries.types import ( from api.attack_paths.queries.types import (
AttackPathsQueryAttribution, AttackPathsQueryAttribution,
@@ -95,29 +86,33 @@ from api.attack_paths.queries.types import (
from tasks.jobs.attack_paths.config import PROWLER_FINDING_LABEL from tasks.jobs.attack_paths.config import PROWLER_FINDING_LABEL
``` ```
The `PROWLER_FINDING_LABEL` constant (value: `"ProwlerFinding"`) is used via f-string interpolation in all queries. Never hardcode the label string. Always use `PROWLER_FINDING_LABEL` via f-string interpolation, never hardcode `"ProwlerFinding"`.
### Privilege escalation sub-patterns ### Definition fields
There are four distinct privilege escalation patterns. Choose based on the attack type: - **id**: kebab-case `{provider}-{description}`, e.g. `aws-ec2-privesc-passrole-iam`.
- **name**: short, human-friendly label. Sourced queries append the reference ID: `"EC2 Instance Launch with Privileged Role (EC2-001)"`.
- **short_description**: one sentence, no technical permissions.
- **description**: full technical explanation, plain text.
- **provider**: `aws`, `azure`, `gcp`, `kubernetes`, or `github`.
- **cypher**: f-string Cypher body. Literal `{` / `}` are escaped as `{{` / `}}`.
- **parameters**: `parameters=[]` if none.
- **attribution**: optional `AttackPathsQueryAttribution(text, link)` for sourced queries. `link` uses the lowercase ID.
| Sub-pattern | Target | `path_target` shape | Example | Append the constant to the `{PROVIDER}_QUERIES` list at the bottom of the provider file.
|---|---|---|---|
| Self-escalation | Principal's own policies | `(aws)--(target_policy:AWSPolicy)--(principal)` | IAM-001 |
| Lateral to user | Other IAM users | `(aws)--(target_user:AWSUser)` | IAM-002 |
| Assume-role lateral | Assumable roles | `(aws)--(target_role:AWSRole)<-[:STS_ASSUMEROLE_ALLOW]-(principal)` | IAM-014 |
| PassRole + service | Service-trusting roles | `(aws)--(target_role:AWSRole)-[:TRUSTS_AWS_PRINCIPAL]->(...)` | EC2-001 |
#### Self-escalation (e.g., IAM-001) ---
The principal modifies resources attached to itself. `path_target` loops back to `principal`: ## Predefined query template
The canonical shape combines a principal walk, an optional target walk, deduplicated nodes, and a typed finding overlay:
```python ```python
AWS_{QUERY_NAME} = AttackPathsQueryDefinition( AWS_{QUERY_NAME} = AttackPathsQueryDefinition(
id="aws-{kebab-case-name}", id="aws-{kebab-case-name}",
name="{Human-friendly label} ({REFERENCE_ID})", name="{Label} ({REFERENCE_ID})",
short_description="{Brief explanation, no technical permissions.}", short_description="{One sentence.}",
description="{Detailed description of the attack vector and impact.}", description="{Full technical explanation.}",
attribution=AttackPathsQueryAttribution( attribution=AttackPathsQueryAttribution(
text="pathfinding.cloud - {REFERENCE_ID} - {permission}", text="pathfinding.cloud - {REFERENCE_ID} - {permission}",
link="https://pathfinding.cloud/paths/{reference_id_lowercase}", link="https://pathfinding.cloud/paths/{reference_id_lowercase}",
@@ -125,29 +120,27 @@ AWS_{QUERY_NAME} = AttackPathsQueryDefinition(
provider="aws", provider="aws",
cypher=f""" cypher=f"""
// Find principals with {permission} // Find principals with {permission}
MATCH path_principal = (aws:AWSAccount {{id: $provider_uid}})--(principal:AWSPrincipal)--(policy:AWSPolicy)--(stmt:AWSPolicyStatement) MATCH path_principal = (aws:AWSAccount {{id: $provider_uid}})--(principal:AWSPrincipal)-[:POLICY]->(policy:AWSPolicy)-[:STATEMENT]->(stmt:AWSPolicyStatement {{effect: 'Allow'}})
WHERE stmt.effect = 'Allow' MATCH (stmt)-[:HAS_ACTION]->(act:AWSPolicyStatementActionItem)
AND any(action IN stmt.action WHERE WHERE toLower(act.value) IN ['{permission_lowercase}', '{service}:*']
toLower(action) = '{permission_lowercase}' OR act.value = '*'
OR toLower(action) = '{service}:*' WITH DISTINCT aws, principal, stmt, path_principal
OR action = '*'
)
// Find target resources attached to the same principal // Target resources attached to the same principal (sub-patterns below)
MATCH path_target = (aws)--(target_policy:AWSPolicy)--(principal) MATCH path_target = (aws)--(target_policy:AWSPolicy)--(principal)
WHERE target_policy.arn CONTAINS $provider_uid WHERE target_policy.arn CONTAINS $provider_uid
AND any(resource IN stmt.resource WHERE MATCH (stmt)-[:HAS_RESOURCE]->(res:AWSPolicyStatementResourceItem)
resource = '*' WHERE res.value = '*'
OR target_policy.arn CONTAINS resource OR target_policy.arn CONTAINS res.value
)
WITH DISTINCT path_principal, path_target
WITH collect(path_principal) + collect(path_target) AS paths WITH collect(path_principal) + collect(path_target) AS paths
UNWIND paths AS p UNWIND paths AS p
UNWIND nodes(p) AS n UNWIND nodes(p) AS n
WITH paths, collect(DISTINCT n) AS unique_nodes WITH paths, collect(DISTINCT n) AS unique_nodes
UNWIND unique_nodes AS n UNWIND unique_nodes AS n
OPTIONAL MATCH (n)-[pfr]-(pf:{PROWLER_FINDING_LABEL} {{status: 'FAIL'}}) OPTIONAL MATCH (n)-[pfr:HAS_FINDING]-(pf:{PROWLER_FINDING_LABEL} {{status: 'FAIL'}})
RETURN paths, collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr RETURN paths, collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr
""", """,
@@ -155,39 +148,67 @@ AWS_{QUERY_NAME} = AttackPathsQueryDefinition(
) )
``` ```
#### Other sub-pattern `path_target` shapes Key points:
The other 3 sub-patterns share the same `path_principal`, deduplication tail, and RETURN as self-escalation. Only the `path_target` MATCH differs: - The principal walk types the `POLICY` and `STATEMENT` hops. Both are low-fan-out (each principal has a handful of policies; each policy a handful of statements), so the typed edge lets the planner cost a cheap inline filter.
- The `(aws)--` hub hops stay anonymous. `AWSAccount` is a high-degree node that fans out to every principal, role, policy, and resource in the account; typing those edges forces the planner to enumerate from the hub and collapses performance on multi-tenant Neptune.
- Other relationship types appear only where the file's existing queries already use one (`TRUSTS_AWS_PRINCIPAL`, `STS_ASSUMEROLE_ALLOW`, `MEMBER_AWS_GROUP`, `HAS_EXECUTION_ROLE`).
- The finding probe is typed `:HAS_FINDING` and left undirected. The type lets Neptune apply an inline edge filter; the lack of direction matches the convention of the rest of the file.
- Collapse duplicate rows after each permission gate with `WITH DISTINCT`, carrying only the variables needed by later clauses.
- Each `HAS_*` traversal is its own `MATCH` clause with a `WHERE` on the child item node. `WITH DISTINCT path_principal, path_target` precedes `collect(path...)` to dedupe the row multiplication produced by the joins.
- The `RETURN` shape `paths, dpf, dpfr` is the contract the serializer and visualiser depend on. Do not change it.
---
## Privilege escalation sub-patterns
Four `path_target` shapes cover the common attack types. Each shares the canonical template's `path_principal`, deduplication tail, and `RETURN`; only the `path_target` MATCH and its resource predicate differ.
| Sub-pattern | Target | `path_target` shape | Example |
| ------------------- | ------------------------ | ------------------------------------------------------------------------------------------------------- | ------- |
| Self-escalation | Principal's own policies | `(aws)--(target_policy:AWSPolicy)--(principal)` | IAM-001 |
| Lateral to user | Other IAM users | `(aws)--(target_user:AWSUser)` | IAM-002 |
| Assume-role lateral | Assumable roles | `(aws)--(target_role:AWSRole)-[:STS_ASSUMEROLE_ALLOW]-(principal)` | IAM-014 |
| PassRole + service | Service-trusting roles | `(aws)--(target_role:AWSRole)-[:TRUSTS_AWS_PRINCIPAL]-(:AWSPrincipal {arn: '{service}.amazonaws.com'})` | EC2-001 |
**Multi-permission queries** (e.g. PassRole plus a service-create action) add permission gates before `path_target`. Reuse the per-query counter for new variables (`act2`, `policy2`, `stmt2`) and collapse rows after each gate:
```cypher ```cypher
// Lateral to user (e.g., IAM-002) - targets other IAM users MATCH (principal)-[:POLICY]->(policy2:AWSPolicy)-[:STATEMENT]->(stmt2:AWSPolicyStatement {effect: 'Allow'})
MATCH path_target = (aws)--(target_user:AWSUser) MATCH (stmt2)-[:HAS_ACTION]->(act2:AWSPolicyStatementActionItem)
WHERE any(resource IN stmt.resource WHERE resource = '*' OR target_user.arn CONTAINS resource OR resource CONTAINS target_user.name) WHERE toLower(act2.value) IN ['service:*', 'service:createsomething']
OR act2.value = '*'
// Assume-role lateral (e.g., IAM-014) - targets roles the principal can assume WITH DISTINCT aws, principal, stmt, stmt2, path_principal
MATCH path_target = (aws)--(target_role:AWSRole)<-[:STS_ASSUMEROLE_ALLOW]-(principal)
WHERE any(resource IN stmt.resource WHERE resource = '*' OR target_role.arn CONTAINS resource OR resource CONTAINS target_role.name)
// PassRole + service (e.g., EC2-001) - targets roles trusting a service
MATCH path_target = (aws)--(target_role:AWSRole)-[:TRUSTS_AWS_PRINCIPAL]->(:AWSPrincipal {arn: '{service}.amazonaws.com'})
WHERE any(resource IN stmt.resource WHERE resource = '*' OR target_role.arn CONTAINS resource OR resource CONTAINS target_role.name)
``` ```
**Multi-permission**: PassRole queries require a second permission. Add `MATCH (principal)--(policy2:AWSPolicy)--(stmt2:AWSPolicyStatement)` with its own WHERE before `path_target`, then check BOTH `stmt.resource` AND `stmt2.resource` against the target. See IAM-015 or EC2-001 in `aws.py` for examples. If a permission is an existence-only gate whose statement resource is not checked later, keep the policy and statement anonymous and carry only the variables still needed:
### Network exposure pattern ```cypher
MATCH (principal)-[:POLICY]->(:AWSPolicy)-[:STATEMENT]->(:AWSPolicyStatement {effect: 'Allow'})-[:HAS_ACTION]->(act3:AWSPolicyStatementActionItem)
WHERE toLower(act3.value) IN ['service:*', 'service:othersomething']
OR act3.value = '*'
WITH DISTINCT aws, principal, stmt, path_principal
```
The Internet node is reached via `CAN_ACCESS` through the already-scoped resource, not via a standalone lookup: When all matching principals can target the same independent resource set, collect principal paths before expanding targets instead of creating one row per principal-target pair:
```cypher
WITH aws, collect(DISTINCT path_principal) AS principal_paths
MATCH path_target = (aws)--(target)
WITH principal_paths + collect(DISTINCT path_target) AS paths
```
Statements that constrain a target are still checked via `HAS_RESOURCE` traversals (`res`, `res2`). See IAM-015 or EC2-001 in `aws.py`.
---
## Network exposure pattern
The Internet node is reached via `CAN_ACCESS` through an already-scoped resource, never as a standalone lookup:
```python ```python
AWS_{QUERY_NAME} = AttackPathsQueryDefinition(
id="aws-{kebab-case-name}",
name="{Human-friendly label}",
short_description="{Brief explanation.}",
description="{Detailed description.}",
provider="aws",
cypher=f""" cypher=f"""
// Match exposed resources (MUST chain from `aws`) // Resource scoped through the account anchor
MATCH path = (aws:AWSAccount {{id: $provider_uid}})--(resource:EC2Instance) MATCH path = (aws:AWSAccount {{id: $provider_uid}})--(resource:EC2Instance)
WHERE resource.exposed_internet = true WHERE resource.exposed_internet = true
@@ -200,113 +221,72 @@ AWS_{QUERY_NAME} = AttackPathsQueryDefinition(
WITH paths, internet, can_access, collect(DISTINCT n) AS unique_nodes WITH paths, internet, can_access, collect(DISTINCT n) AS unique_nodes
UNWIND unique_nodes AS n UNWIND unique_nodes AS n
OPTIONAL MATCH (n)-[pfr]-(pf:{PROWLER_FINDING_LABEL} {{status: 'FAIL'}}) OPTIONAL MATCH (n)-[pfr:HAS_FINDING]-(pf:{PROWLER_FINDING_LABEL} {{status: 'FAIL'}})
RETURN paths, collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr, RETURN paths, collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr,
internet, can_access internet, can_access
""", """
parameters=[],
)
``` ```
### Register in query list The `CAN_ACCESS` edge stays typed and directed (`-[:CAN_ACCESS]->`); that is its canonical sync-time orientation.
Add to the `{PROVIDER}_QUERIES` list at the bottom of the file:
```python
AWS_QUERIES: list[AttackPathsQueryDefinition] = [
# ... existing queries ...
AWS_{NEW_QUERY_NAME}, # Add here
]
```
--- ---
## Step-by-step creation process ## List-typed properties as child nodes
### 1. Read the queries module Some Cartography node properties carry a list of values: `AWSPolicyStatement.action`, `AWSPolicyStatement.resource`, `KMSKey.encryption_algorithms`, `CloudFrontDistribution.aliases`, and many others. The graph models each such property as a set of child item nodes connected to the parent by a typed edge. Queries reach the values by traversing the edge; the parent does not carry the list as a single field.
**FIRST**, read all files in the queries module to understand the structure, type definitions, registration, and existing style: ### Naming convention
```text For a list-typed parent property the sink stores:
api/src/backend/api/attack_paths/queries/
├── __init__.py # Module exports - **Child label**: `<ParentLabel><PropertyPascal>Item`. Example: `AWSPolicyStatement.resource``AWSPolicyStatementResourceItem`.
├── types.py # AttackPathsQueryDefinition, AttackPathsQueryParameterDefinition - **Edge type**: `HAS_<PROPERTY_UPPER>`. Example: `resource``HAS_RESOURCE`.
├── registry.py # Query registry logic - **Child property**: `value` (a single scalar string) for scalar-list properties. For list-of-dict properties (rare; for example `SecretsManagerSecretVersion.tags`) the child carries the dict keys as named fields per the catalog's `field_map`.
└── {provider}.py # Provider-specific queries (e.g., aws.py)
### Variable naming for child-item matches
`aws.py` uses a per-query counter for each `HAS_*` traversal so chained matches stay unambiguous:
| Edge | First | Second | Third |
| ----------------- | ------ | ------- | ------- |
| `HAS_ACTION` | `act` | `act2` | `act3` |
| `HAS_RESOURCE` | `res` | `res2` | `res3` |
| `HAS_NOTACTION` | `nact` | `nact2` | `nact3` |
| `HAS_NOTRESOURCE` | `nres` | `nres2` | `nres3` |
The counter resets at the top of every query.
### Example - action match
Find statements that grant `iam:PassRole`, `iam:*`, or `*`. Traverse the `HAS_ACTION` edge in its own `MATCH` clause and apply the predicate in the attached `WHERE`:
```cypher
MATCH (stmt:AWSPolicyStatement {effect: 'Allow'})
MATCH (stmt)-[:HAS_ACTION]->(act:AWSPolicyStatementActionItem)
WHERE toLower(act.value) IN ['iam:passrole', 'iam:*']
OR act.value = '*'
``` ```
**DO NOT** use generic templates. Match the exact style of existing queries in the file. The literal-action list is case-folded with `toLower(act.value)` because IAM authors mix case (`iam:PassRole`, `iam:passrole`); the `*` wildcard never lower-cases.
### 2. Fetch and consult the Cartography schema ### Example - resource ARN match
**This is the most important step.** Every node label, property, and relationship in the query must exist in the Cartography schema for the pinned version. Do not guess or rely on memory. Find statements whose resource can target a specific role:
Check `api/pyproject.toml` for the Cartography dependency, then fetch the schema: ```cypher
MATCH path_target = (aws)--(target_role:AWSRole)
```bash MATCH (stmt)-[:HAS_RESOURCE]->(res:AWSPolicyStatementResourceItem)
grep cartography api/pyproject.toml WHERE res.value = '*'
OR res.value CONTAINS target_role.name
OR target_role.arn CONTAINS res.value
``` ```
Build the schema URL (ALWAYS use the specific tag, not master/main): Three predicates cover the cases: full wildcard (`*`), pattern containing the role name (`arn:aws:iam::*:role/admin*`), and pattern that is a prefix or component of the actual ARN.
```text ### Catalog of list properties
# Git dependency (prowler-cloud/cartography@0.126.1):
https://raw.githubusercontent.com/prowler-cloud/cartography/refs/tags/0.126.1/docs/root/modules/{provider}/schema.md
# PyPI dependency (cartography = "^0.126.0"): The provider catalog lives in `api/src/backend/tasks/jobs/attack_paths/provider_config.py` (`AWS_NORMALIZED_LISTS`). Beyond policy statements it includes KMS algorithms, ECS container-definition lists (`entry_point`, `command`, `links`, `dns_servers`, ...), CloudFront aliases, Inspector finding URL and vulnerability lists, RDS event-subscription categories, and others. To query a list property that is not in the catalog, add an entry there first so the sync layer materialises it.
https://raw.githubusercontent.com/cartography-cncf/cartography/refs/tags/0.126.0/docs/root/modules/{provider}/schema.md
```
Read the schema to discover available node labels, properties, and relationships for the target resources. Internal labels (`_ProviderResource`, `_AWSResource`, `_Tenant_*`, `_Provider_*`) exist for isolation but should never appear in queries.
### 4. Create query definition
Use the appropriate pattern (privilege escalation or network exposure) with:
- **id**: `{provider}-{kebab-case-description}`
- **name**: Short, human-friendly label. For sourced queries, append the reference ID: `"EC2 Instance Launch with Privileged Role (EC2-001)"`.
- **short_description**: Brief explanation, no technical permissions.
- **description**: Full technical explanation. Plain text only.
- **provider**: Provider identifier (aws, azure, gcp, kubernetes, github)
- **cypher**: The openCypher query with proper escaping
- **parameters**: Optional list of user-provided parameters (`parameters=[]` if none)
- **attribution**: Optional `AttackPathsQueryAttribution(text, link)` for sourced queries. The `text` includes source, reference ID, and permissions. The `link` uses a lowercase ID. Omit for non-sourced queries.
### 5. Add query to provider list
Add the constant to the `{PROVIDER}_QUERIES` list.
---
## Query naming conventions
### Query ID
```text
{provider}-{category}-{description}
```
Examples: `aws-ec2-privesc-passrole-iam`, `aws-ec2-instances-internet-exposed`
### Query constant name
```text
{PROVIDER}_{CATEGORY}_{DESCRIPTION}
```
Examples: `AWS_EC2_PRIVESC_PASSROLE_IAM`, `AWS_EC2_INSTANCES_INTERNET_EXPOSED`
---
## Query categories
| Category | Description | Example |
| -------------------- | ------------------------------ | ------------------------- |
| Basic Resource | List resources with properties | RDS instances, S3 buckets |
| Network Exposure | Internet-exposed resources | EC2 with public IPs |
| Privilege Escalation | IAM privilege escalation paths | PassRole + RunInstances |
| Data Access | Access to sensitive data | EC2 with S3 access |
--- ---
@@ -315,53 +295,42 @@ Examples: `AWS_EC2_PRIVESC_PASSROLE_IAM`, `AWS_EC2_INSTANCES_INTERNET_EXPOSED`
### Match account and principal ### Match account and principal
```cypher ```cypher
MATCH path_principal = (aws:AWSAccount {id: $provider_uid})--(principal:AWSPrincipal)--(policy:AWSPolicy)--(stmt:AWSPolicyStatement) MATCH path_principal = (aws:AWSAccount {id: $provider_uid})--(principal:AWSPrincipal)-[:POLICY]->(policy:AWSPolicy)-[:STATEMENT]->(stmt:AWSPolicyStatement {effect: 'Allow'})
``` ```
### Check IAM action permissions The `(aws)--(principal)` hop stays anonymous; the `POLICY` and `STATEMENT` hops are typed.
### Roles trusting a service
```cypher ```cypher
WHERE stmt.effect = 'Allow' MATCH path_target = (aws)--(target_role:AWSRole)-[:TRUSTS_AWS_PRINCIPAL]-(:AWSPrincipal {arn: 'ec2.amazonaws.com'})
AND any(action IN stmt.action WHERE
toLower(action) = 'iam:passrole'
OR toLower(action) = 'iam:*'
OR action = '*'
)
``` ```
### Find roles trusting a service ### Roles a principal can assume
```cypher ```cypher
MATCH path_target = (aws)--(target_role:AWSRole)-[:TRUSTS_AWS_PRINCIPAL]->(:AWSPrincipal {arn: 'ec2.amazonaws.com'}) MATCH path_target = (aws)--(target_role:AWSRole)-[:STS_ASSUMEROLE_ALLOW]-(principal)
``` ```
### Find roles the principal can assume ### JSON-encoded properties
Note the arrow direction - `STS_ASSUMEROLE_ALLOW` points from the role to the principal: Object-typed Cartography properties (most notably `condition` on `AWSPolicyStatement` and `S3PolicyStatement`) are stored as JSON-encoded strings, e.g. `'{"StringEquals":{"aws:SourceAccount":"123456789012"}}'`. There is no JSON parser at query time, so use `CONTAINS` for substring checks:
```cypher ```cypher
MATCH path_target = (aws)--(target_role:AWSRole)<-[:STS_ASSUMEROLE_ALLOW]-(principal) WHERE stmt.condition CONTAINS '"aws:SourceAccount"'
``` ```
### Check resource scope For structured inspection, fetch the rows and parse in Python. Cypher cannot navigate JSON object keys.
```cypher
WHERE any(resource IN stmt.resource WHERE
resource = '*'
OR target_role.arn CONTAINS resource
OR resource CONTAINS target_role.name
)
```
### Internet node via path connectivity ### Internet node via path connectivity
The Internet node is reached through `CAN_ACCESS` relationships to already-scoped resources. No standalone lookup needed:
```cypher ```cypher
OPTIONAL MATCH (internet:Internet)-[can_access:CAN_ACCESS]->(resource) OPTIONAL MATCH (internet:Internet)-[can_access:CAN_ACCESS]->(resource)
``` ```
### Multi-label OR (match multiple resource types) `resource` must already be bound by the account-anchored pattern above.
### Multi-label OR (multiple resource types)
```cypher ```cypher
MATCH path = (aws:AWSAccount {id: $provider_uid})-[r]-(x)-[q]-(y) MATCH path = (aws:AWSAccount {id: $provider_uid})-[r]-(x)-[q]-(y)
@@ -373,7 +342,7 @@ WHERE (x:EC2PrivateIp AND x.public_ip = $ip)
### Include Prowler findings ### Include Prowler findings
Deduplicate nodes before the ProwlerFinding lookup to avoid redundant OPTIONAL MATCH calls on nodes that appear in multiple paths: Deduplicate nodes before the typed finding probe to avoid one `OPTIONAL MATCH` per path-occurrence of the same node:
```cypher ```cypher
WITH collect(path_principal) + collect(path_target) AS paths WITH collect(path_principal) + collect(path_target) AS paths
@@ -382,12 +351,12 @@ UNWIND nodes(p) AS n
WITH paths, collect(DISTINCT n) AS unique_nodes WITH paths, collect(DISTINCT n) AS unique_nodes
UNWIND unique_nodes AS n UNWIND unique_nodes AS n
OPTIONAL MATCH (n)-[pfr]-(pf:{PROWLER_FINDING_LABEL} {{status: 'FAIL'}}) OPTIONAL MATCH (n)-[pfr:HAS_FINDING]-(pf:{PROWLER_FINDING_LABEL} {{status: 'FAIL'}})
RETURN paths, collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr RETURN paths, collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr
``` ```
For network exposure queries, aggregate the internet node and relationship alongside paths: For network-exposure queries, aggregate the Internet node and its edge alongside paths:
```cypher ```cypher
WITH collect(path) AS paths, head(collect(internet)) AS internet, collect(can_access) AS can_access WITH collect(path) AS paths, head(collect(internet)) AS internet, collect(can_access) AS can_access
@@ -396,7 +365,7 @@ UNWIND nodes(p) AS n
WITH paths, internet, can_access, collect(DISTINCT n) AS unique_nodes WITH paths, internet, can_access, collect(DISTINCT n) AS unique_nodes
UNWIND unique_nodes AS n UNWIND unique_nodes AS n
OPTIONAL MATCH (n)-[pfr]-(pf:{PROWLER_FINDING_LABEL} {{status: 'FAIL'}}) OPTIONAL MATCH (n)-[pfr:HAS_FINDING]-(pf:{PROWLER_FINDING_LABEL} {{status: 'FAIL'}})
RETURN paths, collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr, RETURN paths, collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr,
internet, can_access internet, can_access
@@ -406,22 +375,22 @@ RETURN paths, collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr,
## Prowler-specific labels and relationships ## Prowler-specific labels and relationships
These are added by the sync task, not part of the Cartography schema. For all other node labels, properties, and relationships, **always consult the Cartography schema** (see step 2 below). Added by the sync task, not part of the Cartography schema. For everything else, consult the pinned Cartography schema (see "Creation steps").
| Label / Relationship | Description | | Label / Relationship | Description |
| ---------------------- | -------------------------------------------------- | | ---------------------- | ----------------------------------------------------------- |
| `ProwlerFinding` | Finding node (`status`, `severity`, `check_id`) | | `ProwlerFinding` | Finding node (`status`, `severity`, `check_id`) |
| `Internet` | Internet sentinel node | | `Internet` | Internet sentinel node |
| `CAN_ACCESS` | Internet-to-resource exposure (relationship) | | `CAN_ACCESS` | `(Internet)-[:CAN_ACCESS]->(resource)` exposure edge |
| `HAS_FINDING` | Resource-to-finding link (relationship) | | `HAS_FINDING` | `(resource)-[:HAS_FINDING]->(:ProwlerFinding)` finding link |
| `TRUSTS_AWS_PRINCIPAL` | Role trust relationship | | `TRUSTS_AWS_PRINCIPAL` | Role trust relationship |
| `STS_ASSUMEROLE_ALLOW` | Can assume role (direction: role -> principal) | | `STS_ASSUMEROLE_ALLOW` | Can assume role |
--- ---
## Parameters ## Parameters
For queries requiring user input: For queries that take user input:
```python ```python
parameters=[ parameters=[
@@ -438,50 +407,83 @@ parameters=[
--- ---
## Best practices ## openCypher compatibility
1. **Chain all MATCHes from the root account node**: Every `MATCH` clause must connect to the `aws` variable (or another variable already bound to the account's subgraph). An unanchored `MATCH` would return nodes from all providers. Queries must run on both Neo4j and Amazon Neptune. Avoid these constructs:
```cypher | Feature | Use instead |
// WRONG: matches ALL AWSRoles across all providers | --------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------- |
MATCH (role:AWSRole) WHERE role.name = 'admin' | APOC procedures (`apoc.*`) | Real nodes and relationships in the graph |
| Neptune extensions | Standard openCypher |
| `reduce()` | `UNWIND` + `collect()` |
| `FOREACH` | `WITH` + `UNWIND` + `SET` |
| Regex `=~` | `toLower()` + exact match, or `STARTS WITH` / `CONTAINS` |
| `CALL () { UNION }` | Multi-label `OR` in `WHERE` (see pattern above) |
| `any(x IN list ...)` | `size([x IN list WHERE pred]) > 0` |
| `all(x IN list ...)` | `size([x IN list WHERE pred]) = size(list)` |
| `none(x IN list ...)` | `size([x IN list WHERE pred]) = 0` |
| `EXISTS { MATCH (pattern) WHERE pred }` | Standalone `MATCH (pattern)` + `WHERE pred`; precede the downstream `collect(path...)` with `WITH DISTINCT <path-vars>` to dedupe the joins |
// CORRECT: scoped to the specific account's subgraph For list-typed properties in the catalog (action, resource, and so on), traverse the `HAS_*` edges to the child item nodes via the multi-`MATCH` shape shown in "List-typed properties as child nodes". The parent node does not carry the list as a single field, so `split(...)` and comma-string predicates do not apply.
MATCH (aws)--(role:AWSRole) WHERE role.name = 'admin'
```
**Exception**: A second-permission MATCH like `MATCH (principal)--(policy2:AWSPolicy)--(stmt2:AWSPolicyStatement)` is safe because `principal` is already bound to the account's subgraph by the first MATCH. It does not need to chain from `aws` again.
2. **Include Prowler findings**: Always add `OPTIONAL MATCH (n)-[pfr]-(pf:{PROWLER_FINDING_LABEL} {{status: 'FAIL'}})` with `collect(DISTINCT pf)`.
3. **Comment the query purpose**: Add inline comments explaining each MATCH clause.
4. **Never use internal labels in queries**: `_ProviderResource`, `_AWSResource`, `_Tenant_*`, `_Provider_*` are for system isolation. They should never appear in predefined or custom query text.
6. **Internet node uses path connectivity**: Reach it via `OPTIONAL MATCH (internet:Internet)-[can_access:CAN_ACCESS]->(resource)` where `resource` is already scoped by the account anchor. No standalone lookup.
--- ---
## openCypher compatibility ## Best practices
Queries must be written in **openCypher Version 9** for compatibility with both Neo4j and Amazon Neptune. 1. **Chain every MATCH from the account anchor.** An unanchored `MATCH (role:AWSRole)` returns roles from every provider in the graph; `MATCH (aws)--(role:AWSRole)` is scoped. A second-permission MATCH like `MATCH (principal)--(policy2:AWSPolicy)--(stmt2:AWSPolicyStatement)` is safe because `principal` is already bound to the account's subgraph.
2. **Type the finding probe.** Always `OPTIONAL MATCH (n)-[pfr:HAS_FINDING]-(pf:{PROWLER_FINDING_LABEL} {{status: 'FAIL'}})`. The type lets Neptune apply an inline edge filter; an untyped probe scans every incident edge of high-degree nodes.
3. **Comment each MATCH.** One inline `// ...` line per clause explaining its role.
4. **Never use internal labels.** `_ProviderResource`, `_AWSResource`, `_Tenant_*`, `_Provider_*` are system isolation labels and must not appear in query text (predefined or custom).
5. **Reach the Internet node through path connectivity** via `(internet:Internet)-[:CAN_ACCESS]->(resource)`, never as a standalone match.
6. **Preserve the `RETURN` contract.** `paths, dpf, dpfr` for the standard shape; add `internet, can_access` for network-exposure queries. The serializer and visualiser depend on these names.
### Avoid these (not in openCypher spec) ---
| Feature | Use instead | ## Naming conventions
| -------------------------- | ------------------------------------------------------ |
| APOC procedures (`apoc.*`) | Real nodes and relationships in the graph | - **ID**: kebab-case `{provider}-{category}-{description}`, e.g. `aws-ec2-privesc-passrole-iam`.
| Neptune extensions | Standard openCypher | - **Constant**: SHOUTING*SNAKE_CASE `{PROVIDER}*{CATEGORY}\_{DESCRIPTION}`, e.g. `AWS_EC2_PRIVESC_PASSROLE_IAM`.
| `reduce()` function | `UNWIND` + `collect()` |
| `FOREACH` clause | `WITH` + `UNWIND` + `SET` | ---
| Regex operator (`=~`) | `toLower()` + exact match, or `CONTAINS`/`STARTS WITH`. One legacy query uses `=~` - do not add new usages |
| `CALL () { UNION }` | Multi-label OR in WHERE (see patterns section) | ## Creation steps
1. **Read the queries module first** to match the existing style:
```text
api/src/backend/api/attack_paths/queries/
├── __init__.py
├── types.py # dataclass definitions
├── registry.py
└── {provider}.py
```
2. **Fetch the Cartography schema for the pinned version.** Do not guess labels, properties, or relationships. Read the dependency pin:
```bash
grep cartography api/pyproject.toml
```
Then fetch the schema for that exact tag:
```text
# Git pin (prowler-cloud/cartography@<TAG>):
https://raw.githubusercontent.com/prowler-cloud/cartography/refs/tags/<TAG>/docs/root/modules/{provider}/schema.md
# PyPI pin (cartography==<TAG>):
https://raw.githubusercontent.com/cartography-cncf/cartography/refs/tags/<TAG>/docs/root/modules/{provider}/schema.md
```
3. **Build the query** using the canonical predefined template plus the appropriate sub-pattern (privilege escalation or network exposure). For list-typed properties (action/resource/etc.), traverse the exploded child nodes via `[:HAS_ACTION]->(:AWSPolicyStatementActionItem)` etc. (see "List-typed properties as child nodes" and the `AWS_NORMALIZED_LISTS` catalog).
4. **Register** the constant in the `{PROVIDER}_QUERIES` list at the bottom of the provider file.
--- ---
## Reference ## Reference
- **pathfinding.cloud**: https://github.com/DataDog/pathfinding.cloud (use `curl | jq`, not WebFetch) - **pathfinding.cloud**: https://github.com/DataDog/pathfinding.cloud (use `curl | jq`; the aggregated `paths.json` is too large for WebFetch).
- **Cartography schema**: `https://raw.githubusercontent.com/{org}/cartography/refs/tags/{version}/docs/root/modules/{provider}/schema.md` - **Cartography schema** (per pinned tag): `https://raw.githubusercontent.com/{org}/cartography/refs/tags/{tag}/docs/root/modules/{provider}/schema.md`.
- **Neptune openCypher compliance**: https://docs.aws.amazon.com/neptune/latest/userguide/feature-opencypher-compliance.html - **Neptune openCypher compliance**: https://docs.aws.amazon.com/neptune/latest/userguide/feature-opencypher-compliance.html.
- **openCypher spec**: https://github.com/opencypher/openCypher - **openCypher spec**: https://github.com/opencypher/openCypher.
- **Sync converter** (`tasks/jobs/attack_paths/sync.py`): list-typed node properties listed in `tasks/jobs/attack_paths/provider_config.py::AWS_NORMALIZED_LISTS` are materialised as child item nodes + `HAS_*` edges. Properties that are not in the catalog are serialised to a comma-delimited string and emit a one-time warning. Dict-typed properties become JSON strings. Same shape on both sinks.