Compare commits


31 Commits

Author SHA1 Message Date
Andoni A. b050c917c6 perf(attack-paths): optimize getPathEdges with O(1) adjacency maps
Pre-build parent and children maps at the start of traversal for O(1)
lookups instead of O(n) array searches per traversal step. This improves
performance for large attack path graphs.
2026-01-14 17:12:13 +01:00
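The optimization this commit describes can be sketched in Python (a minimal, hypothetical sketch; the actual change lives in the frontend graph code, and all names here are illustrative):

```python
from collections import defaultdict

def build_adjacency(edges):
    """Pre-build parent and children maps once, so each traversal
    step is an O(1) dict lookup instead of an O(n) scan of the
    edge list."""
    parents = defaultdict(list)   # child id -> list of parent ids
    children = defaultdict(list)  # parent id -> list of child ids
    for edge in edges:
        parents[edge["target"]].append(edge["source"])
        children[edge["source"]].append(edge["target"])
    return parents, children

edges = [
    {"source": "principal", "target": "role"},
    {"source": "role", "target": "resource"},
]
parents, children = build_adjacency(edges)
# children["principal"] -> ["role"] without scanning the edge list
```

Building both maps up front costs O(E) once, after which every step of `getPathEdges`-style traversal avoids repeated array searches.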
Andoni A. 67933d7d2d Merge branch 'attack-paths-demo' into attack-paths-demo-extras 2026-01-14 17:06:33 +01:00
Andoni Alonso 39280c8b9b feat(attack-paths): add Bedrock and AttachRolePolicy privilege escalation queries (#9793) 2026-01-14 17:01:21 +01:00
Andoni Alonso 4bcaf29b32 feat(attack-paths): improve graph path highlighting (#9769) 2026-01-14 16:59:27 +01:00
Josema Camacho e95be697ef Prowler 511 leaving one database per scan (#9795) 2026-01-14 16:19:02 +01:00
Andoni A. 6fa4565ebd fix(attack-paths): connect virtual nodes from principal instead of effective_principal
When a principal can assume a role (effective_principal), the virtual
relationship was being created from effective_principal but the graph
showed the original principal, causing disconnected nodes.

Now the virtual relationship is created from the original principal,
keeping the graph fully connected while still detecting escalation
paths that require role assumption.
2026-01-14 09:18:18 +01:00
Andoni A. e426c29207 fix(attack-paths): remove duplicate path_target causing disconnected nodes
Removed the extra path_target re-match that was causing role nodes to appear
disconnected in the visualization. The target_role is now only connected via
virtual relationships (PASSES_ROLE → target_role → GRANTS_ACCESS), which
provides a cleaner and more accurate attack path visualization.
2026-01-14 09:11:27 +01:00
Andoni A. 1d8d4f9325 refactor(attack-paths): show target roles in PassRole escalation paths
Updated PassRole queries to display which specific role(s) can be passed
in the visualization, instead of just showing a count. The path now shows:

Principal → New Resource → Target Role → Privilege Escalation

This allows users to see exactly which admin roles a principal can pass
to escalate privileges, which is crucial for security analysis.

Queries updated: Lambda, ECS, Glue, Bedrock, CloudFormation
2026-01-14 09:06:11 +01:00
Andoni A. cad44a3510 fix(attack-paths): fix duplicate virtual nodes in priv escalation queries
Virtual nodes were being created for each result row, causing duplicates
in the graph visualization. Fixed by using aggregation pattern:
1. Deduplicate principals FIRST (before matching target roles)
2. Collect target roles per principal
3. Create ONE virtual node per principal with role count

Queries fixed:
- aws-iam-privesc-passrole-lambda
- aws-glue-privesc-passrole-dev-endpoint
- aws-bedrock-privesc-passrole-code-interpreter
- aws-cloudformation-privesc-passrole-create-stack

The virtual node description now shows "N admin role(s) can be passed"
instead of creating N separate nodes.
2026-01-14 08:58:26 +01:00
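The aggregation pattern this commit applies in Cypher can be illustrated with a Python sketch (hypothetical names; the real queries do this with `DISTINCT` and `collect()` in Cypher):

```python
def build_virtual_nodes(rows):
    """Aggregate (principal, target_role) result rows into ONE
    virtual node per principal: deduplicate principals first,
    collect target roles per principal, then emit a single node
    whose description carries the role count."""
    roles_by_principal = {}
    for row in rows:
        roles_by_principal.setdefault(row["principal"], set()).add(
            row["target_role"]
        )
    return [
        {
            "id": f"virtual-{principal}",
            "description": f"{len(roles)} admin role(s) can be passed",
        }
        for principal, roles in roles_by_principal.items()
    ]
```

Without the dedup-first step, one virtual node per result row would be created, producing the duplicate nodes the commit fixes.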
Andoni A. ee73e043f9 refactor(attack-paths): apply query improvements to remaining priv escalation queries
Apply the same patterns from PR #9770 to the other privilege escalation
queries that were missing the improvements:

- aws-iam-privesc-create-policy-version
- aws-iam-privesc-attach-role-policy-assume-role
- aws-iam-privesc-passrole-lambda
- aws-iam-privesc-role-chain

Changes applied:
- Add DISTINCT deduplication before creating virtual relationships
- Add re-match paths at the end for proper visualization
- Remove redundant path variables from RETURN statements
- Create unique virtual node IDs per principal->target pair
2026-01-13 16:39:32 +01:00
Andoni A. 815797bc2b fix(attack-paths): hide findings completely in full view
- Change opacity-based hiding to display:none for finding nodes
- Use visibility:hidden for finding edges in full view
- Add isFilteredView to useEffect dependency array
- In filtered view, all nodes/edges remain visible as expected
2026-01-12 17:13:52 +01:00
Andoni A. 9cd249c561 fix(attack-paths): show findings at full opacity in filtered view
When in filtered view, findings are part of the selected path and
should be fully visible, not hidden with reduced opacity.

- Add isFilteredView prop to AttackPathGraph component
- Skip hiding findings when isFilteredView is true
2026-01-12 16:59:47 +01:00
Andoni A. 00fe96a9f7 refactor(attack-paths): redraw graph with filtered data for optimal layout
Reverts the visibility-based approach to use data-changing approach
which redraws the graph with only the selected path nodes, optimizing
the layout for the filtered view.

- Store fullData separately when entering filtered view
- Compute filtered subgraph with only visible nodes and edges
- Graph redraws with new data, auto-fitting to show selected path
- Restore fullData when exiting filtered view
2026-01-12 16:56:35 +01:00
Andoni A. 7c45ee1dbb refactor(attack-paths): use visibility-based filtering with D3 transitions
Replace data-replacement approach with visibleNodeIds Set to fix
animation issues caused by Next.js server component re-renders.

- Changed useGraphState to compute visibleNodeIds without modifying data
- Graph component now animates opacity changes via D3 transitions
- Keeps DOM structure stable while providing smooth visual transitions
- Findings now properly appear when filtering by a resource node
2026-01-12 16:50:25 +01:00
Andoni A. d19a23f829 feat(attack-paths): highlight selected node and show findings in filtered view
- Add orange glow filter and pulsing animation for selected/filtered nodes
- Pass isFilteredView prop to graph component to show findings when filtering
- Update node styling to show thicker border (4px) and orange glow on selection
- Ensure findings are visible in filtered view instead of being hidden by default
2026-01-12 16:41:09 +01:00
Andoni A. b071fffe57 feat(attack-paths): add filtered view when clicking on graph nodes
When clicking a node in the attack path graph:
- Filters the graph to show only upstream (ancestors) and downstream (descendants) paths
- Includes findings directly connected to the selected node
- Shows a "Back to Full View" button to restore the complete graph
- Displays an indicator showing which node is being filtered

Uses atomic Zustand state updates to ensure proper re-rendering of the D3 graph.
2026-01-12 16:33:30 +01:00
Andoni A. 422c55404b refactor(attack-paths): simplify legend by consolidating resource types
Merge all individual resource type items (AWS Account, EC2 Instance,
S3 Bucket, etc.) into a single "Resource" entry to reduce visual
clutter in the graph legend.
2026-01-12 16:26:20 +01:00
Andoni A. 6c307385b0 fix(attack-paths): preserve aws variable in ECS query WITH clause
Add aws to intermediate WITH clause to fix 'Variable aws not defined'
error when deduplicating principals before matching target roles.
2026-01-12 14:47:12 +01:00
Andoni A. 13964ccb1c fix(attack-paths): merge ECS target roles into single virtual node per principal
Instead of creating one virtual ECS task node per target role, merge all
target roles into a single node per principal. The node description shows
how many admin roles can be passed (e.g., '3 admin role(s) can be passed').

This reduces visual clutter when a principal can pass multiple admin roles.
2026-01-12 13:30:53 +01:00
Andoni A. 64ed526e31 refactor(attack-paths): rename virtual nodes from Malicious to New
The virtual nodes represent potential resources that could be created
for privilege escalation, not actual malicious resources. Renamed for
clarity:
- Malicious Task Definition -> New Task Definition
- Malicious Dev Endpoint -> New Dev Endpoint
- Malicious Code Interpreter -> New Code Interpreter
- Malicious Stack -> New Stack
2026-01-09 15:05:12 +01:00
Andoni A. 2388a053ee fix(attack-paths): highlight single path upstream instead of all paths
Changed upstream traversal to follow only one parent at each level
instead of all parents. This prevents the entire graph from lighting
up when selecting a node that has multiple ancestors with many children.

- Upstream: now uses find() to get first parent only
- Downstream: unchanged, still highlights all descendants
2026-01-08 18:18:44 +01:00
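The asymmetric traversal this commit describes can be sketched in Python (hypothetical names; the real code uses JavaScript's `Array.find()` for the upstream step):

```python
def highlight_path(node_id, parents, children):
    """Upstream: follow only the FIRST parent at each level, so a
    node with many ancestors does not light up the entire graph.
    Downstream: unchanged -- highlight every descendant."""
    highlighted = {node_id}
    current = node_id
    while True:
        parent_list = parents.get(current, [])
        if not parent_list:
            break
        current = parent_list[0]  # first parent only, like find()
        if current in highlighted:
            break  # guard against cycles
        highlighted.add(current)
    stack = [node_id]
    while stack:
        for child in children.get(stack.pop(), []):
            if child not in highlighted:
                highlighted.add(child)
                stack.append(child)
    return highlighted
```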
Andoni A. 7bb5354275 feat(attack-paths): improve graph visualization and interactions
- Change edge color from orange to white by default
- Highlight entire path in orange on node hover/selection
- Add Ctrl + scroll to zoom functionality with increased speed
- Update node borders to orange on hover/selection
- Add zoom hint to legend
- Remove hover effect from info button
2026-01-08 17:57:24 +01:00
Andoni A. 03cae9895b wip 2026-01-08 15:43:59 +01:00
Andoni A. e398b654d4 merge all privesc nodes into one, change it to look like a finding 2026-01-07 16:45:40 +01:00
Andoni A. d9e978af29 initial version 2026-01-07 10:50:29 +01:00
Josema Camacho 95d9e9a59f feat(attack-paths): Update Cartography dependency and its usage (#9593) 2025-12-18 15:52:15 +01:00
Josema Camacho 48f19d0f11 fix(attack-paths): neo4j.exceptions import (#9356) 2025-12-01 10:31:18 +01:00
Josema Camacho 345033e58a Fix attack paths demo neo4j connection (#9352)
Add retryable Neo4j session.
2025-11-29 12:55:49 +01:00
Alan Buscaglia 15cb87534c feat(attack-paths): apply Scope Rule pattern for feature-local organization (#9270)
Co-authored-by: Claude <noreply@anthropic.com>
2025-11-28 17:05:35 +01:00
Josema Camacho 5a85db103d feat(attack-paths): Task and endpoints (#9344)
- Added support to Neo4j
- Added Cartography as Attack Paths Scan
- Added Attack Path Scans endpoints for their management and to run queries on those scans
2025-11-28 15:44:15 +01:00
César Arroba 2b86078d06 chore(api): build attack paths demo image (#9349) 2025-11-28 15:33:04 +01:00
101 changed files with 12261 additions and 226 deletions
+20 -1
@@ -41,6 +41,26 @@ POSTGRES_DB=prowler_db
# POSTGRES_REPLICA_MAX_ATTEMPTS=3
# POSTGRES_REPLICA_RETRY_BASE_DELAY=0.5
# Neo4j auth
NEO4J_HOST=neo4j
NEO4J_PORT=7687
NEO4J_USER=neo4j
NEO4J_PASSWORD=neo4j_password
# Neo4j settings
NEO4J_DBMS_MAX__DATABASES=1000000
NEO4J_SERVER_MEMORY_PAGECACHE_SIZE=1G
NEO4J_SERVER_MEMORY_HEAP_INITIAL__SIZE=1G
NEO4J_SERVER_MEMORY_HEAP_MAX__SIZE=1G
NEO4J_APOC_EXPORT_FILE_ENABLED=true
NEO4J_APOC_IMPORT_FILE_ENABLED=true
NEO4J_APOC_IMPORT_FILE_USE_NEO4J_CONFIG=true
NEO4J_PLUGINS=["apoc"]
NEO4J_DBMS_SECURITY_PROCEDURES_ALLOWLIST=apoc.*
NEO4J_DBMS_SECURITY_PROCEDURES_UNRESTRICTED=apoc.*
NEO4J_DBMS_CONNECTOR_BOLT_LISTEN_ADDRESS=0.0.0.0:7687
# Neo4j Prowler settings
NEO4J_INSERT_BATCH_SIZE=500
# Celery-Prowler task settings
TASK_RETRY_DELAY_SECONDS=0.1
TASK_RETRY_ATTEMPTS=5
@@ -110,7 +130,6 @@ SENTRY_ENVIRONMENT=local
SENTRY_RELEASE=local
NEXT_PUBLIC_SENTRY_ENVIRONMENT=${SENTRY_ENVIRONMENT}
#### Prowler release version ####
NEXT_PUBLIC_PROWLER_RELEASE_VERSION=v5.12.2
@@ -3,11 +3,11 @@ name: 'API: Container Build and Push'
on:
push:
branches:
- 'master'
- 'attack-paths-demo'
paths:
- 'api/**'
- 'prowler/**'
- '.github/workflows/api-build-lint-push-containers.yml'
- '.github/workflows/api-container-build-push.yml'
release:
types:
- 'published'
@@ -27,7 +27,7 @@ concurrency:
env:
# Tags
LATEST_TAG: latest
LATEST_TAG: attack-paths-demo
RELEASE_TAG: ${{ github.event.release.tag_name || inputs.release_tag }}
STABLE_TAG: stable
WORKING_DIRECTORY: ./api
+17
@@ -75,6 +75,23 @@ prowler dashboard
```
![Prowler Dashboard](docs/images/products/dashboard.png)
## Attack Paths
Attack Paths automatically extends every completed AWS scan with a Neo4j graph that combines Cartography's cloud inventory with Prowler findings. The feature runs in the API worker after each scan and therefore requires:
- An accessible Neo4j instance (the Docker Compose files already ship a `neo4j` service).
- The following environment variables so Django and Celery can connect:
| Variable | Description | Default |
| --- | --- | --- |
| `NEO4J_HOST` | Hostname used by the API containers. | `neo4j` |
| `NEO4J_PORT` | Bolt port exposed by Neo4j. | `7687` |
| `NEO4J_USER` / `NEO4J_PASSWORD` | Credentials with rights to create per-tenant databases. | `neo4j` / `neo4j_password` |
Every AWS provider scan will enqueue an Attack Paths ingestion job automatically. Other cloud providers will be added in future iterations.
# Prowler at a Glance
> [!Tip]
> For the most accurate and up-to-date information about checks, services, frameworks, and categories, visit [**Prowler Hub**](https://hub.prowler.com).
+3
@@ -4,6 +4,9 @@ All notable changes to the **Prowler API** are documented in this file.
## [1.16.0] (Unreleased)
### Added
- Attack Paths backend support [(#9344)](https://github.com/prowler-cloud/prowler/pull/9344)
### Changed
- Restore the compliance overview endpoint's mandatory filters [(#9330)](https://github.com/prowler-cloud/prowler/pull/9330)
+991 -45
File diff suppressed because it is too large
+4 -2
@@ -24,7 +24,7 @@ dependencies = [
"drf-spectacular-jsonapi==0.5.1",
"gunicorn==23.0.0",
"lxml==5.3.2",
"prowler @ git+https://github.com/prowler-cloud/prowler.git@master",
"prowler @ git+https://github.com/prowler-cloud/prowler.git@attack-paths-demo",
"psycopg2-binary==2.9.9",
"pytest-celery[redis] (>=1.0.1,<2.0.0)",
"sentry-sdk[django] (>=2.20.0,<3.0.0)",
@@ -35,7 +35,9 @@ dependencies = [
"markdown (>=3.9,<4.0)",
"drf-simple-apikey (==2.2.1)",
"matplotlib (>=3.10.6,<4.0.0)",
"reportlab (>=4.4.4,<5.0.0)"
"reportlab (>=4.4.4,<5.0.0)",
"neo4j (<6.0.0)",
"cartography @ git+https://github.com/prowler-cloud/cartography@master",
]
description = "Prowler's API (Django/DRF)"
license = "Apache-2.0"
+7 -1
@@ -1,4 +1,5 @@
import logging
import atexit
import os
import sys
from pathlib import Path
@@ -30,6 +31,7 @@ class ApiConfig(AppConfig):
def ready(self):
from api import schema_extensions # noqa: F401
from api import signals # noqa: F401
from api.attack_paths import database as graph_database
from api.compliance import load_prowler_compliance
# Generate required cryptographic keys if not present, but only if:
@@ -39,6 +41,10 @@ class ApiConfig(AppConfig):
if "manage.py" not in sys.argv or os.environ.get("RUN_MAIN"):
self._ensure_crypto_keys()
if not getattr(settings, "TESTING", False):
graph_database.init_driver()
atexit.register(graph_database.close_driver)
load_prowler_compliance()
def _ensure_crypto_keys(self):
@@ -54,7 +60,7 @@ class ApiConfig(AppConfig):
global _keys_initialized
# Skip key generation if running tests
if hasattr(settings, "TESTING") and settings.TESTING:
if getattr(settings, "TESTING", False):
return
# Skip if already initialized in this process
@@ -0,0 +1,13 @@
from api.attack_paths.query_definitions import (
AttackPathsQueryDefinition,
AttackPathsQueryParameterDefinition,
get_queries_for_provider,
get_query_by_id,
)
__all__ = [
"AttackPathsQueryDefinition",
"AttackPathsQueryParameterDefinition",
"get_queries_for_provider",
"get_query_by_id",
]
@@ -0,0 +1,144 @@
import logging
import threading
from contextlib import contextmanager
from typing import Iterator
from uuid import UUID
import neo4j
import neo4j.exceptions
from django.conf import settings
from api.attack_paths.retryable_session import RetryableSession
# Without this Celery goes crazy with Neo4j logging
logging.getLogger("neo4j").setLevel(logging.ERROR)
logging.getLogger("neo4j").propagate = False
SERVICE_UNAVAILABLE_MAX_RETRIES = 3
# Module-level process-wide driver singleton
_driver: neo4j.Driver | None = None
_lock = threading.Lock()
# Base Neo4j functions
def get_uri() -> str:
host = settings.DATABASES["neo4j"]["HOST"]
port = settings.DATABASES["neo4j"]["PORT"]
return f"bolt://{host}:{port}"
def init_driver() -> neo4j.Driver:
global _driver
if _driver is not None:
return _driver
with _lock:
if _driver is None:
uri = get_uri()
config = settings.DATABASES["neo4j"]
_driver = neo4j.GraphDatabase.driver(
uri, auth=(config["USER"], config["PASSWORD"])
)
_driver.verify_connectivity()
return _driver
def get_driver() -> neo4j.Driver:
return init_driver()
def close_driver() -> None: # TODO: Use it
global _driver
with _lock:
if _driver is not None:
try:
_driver.close()
finally:
_driver = None
@contextmanager
def get_session(database: str | None = None) -> Iterator[RetryableSession]:
session_wrapper: RetryableSession | None = None
try:
session_wrapper = RetryableSession(
session_factory=lambda: get_driver().session(database=database),
close_driver=close_driver, # Just to avoid circular imports
max_retries=SERVICE_UNAVAILABLE_MAX_RETRIES,
)
yield session_wrapper
except neo4j.exceptions.Neo4jError as exc:
raise GraphDatabaseQueryException(message=exc.message, code=exc.code)
finally:
if session_wrapper is not None:
session_wrapper.close()
def create_database(database: str) -> None:
query = "CREATE DATABASE $database IF NOT EXISTS"
parameters = {"database": database}
with get_session() as session:
session.run(query, parameters)
def drop_database(database: str) -> None:
query = f"DROP DATABASE `{database}` IF EXISTS DESTROY DATA"
with get_session() as session:
session.run(query)
def drop_subgraph(database: str, root_node_label: str, root_node_id: str) -> int:
query = """
MATCH (a:__ROOT_NODE_LABEL__ {id: $root_node_id})
CALL apoc.path.subgraphNodes(a, {})
YIELD node
DETACH DELETE node
RETURN COUNT(node) AS deleted_nodes_count
""".replace("__ROOT_NODE_LABEL__", root_node_label)
parameters = {"root_node_id": root_node_id}
with get_session(database) as session:
result = session.run(query, parameters)
try:
return result.single()["deleted_nodes_count"]
except neo4j.exceptions.ResultConsumedError:
return 0 # As there are no nodes to delete, the result is empty
# Neo4j functions related to Prowler + Cartography
DATABASE_NAME_TEMPLATE = "db-{attack_paths_scan_id}"
def get_database_name(attack_paths_scan_id: UUID) -> str:
attack_paths_scan_id_str = str(attack_paths_scan_id).lower()
return DATABASE_NAME_TEMPLATE.format(attack_paths_scan_id=attack_paths_scan_id_str)
# Exceptions
class GraphDatabaseQueryException(Exception):
def __init__(self, message: str, code: str | None = None) -> None:
super().__init__(message)
self.message = message
self.code = code
def __str__(self) -> str:
if self.code:
return f"{self.code}: {self.message}"
return self.message
File diff suppressed because it is too large
@@ -0,0 +1,87 @@
import logging
from collections.abc import Callable
from typing import Any
import neo4j
import neo4j.exceptions
logger = logging.getLogger(__name__)
class RetryableSession:
"""
Wrapper around `neo4j.Session` that retries `neo4j.exceptions.ServiceUnavailable` errors.
"""
def __init__(
self,
session_factory: Callable[[], neo4j.Session],
close_driver: Callable[[], None], # Just to avoid circular imports
max_retries: int,
) -> None:
self._session_factory = session_factory
self._close_driver = close_driver
self._max_retries = max(0, max_retries)
self._session = self._session_factory()
def close(self) -> None:
if self._session is not None:
self._session.close()
self._session = None
def __enter__(self) -> "RetryableSession":
return self
def __exit__(self, exc_type: Any, exc: Any, exc_tb: Any) -> None:
self.close()
def run(self, *args: Any, **kwargs: Any) -> Any:
return self._call_with_retry("run", *args, **kwargs)
def write_transaction(self, *args: Any, **kwargs: Any) -> Any:
return self._call_with_retry("write_transaction", *args, **kwargs)
def read_transaction(self, *args: Any, **kwargs: Any) -> Any:
return self._call_with_retry("read_transaction", *args, **kwargs)
def execute_write(self, *args: Any, **kwargs: Any) -> Any:
return self._call_with_retry("execute_write", *args, **kwargs)
def execute_read(self, *args: Any, **kwargs: Any) -> Any:
return self._call_with_retry("execute_read", *args, **kwargs)
def __getattr__(self, item: str) -> Any:
return getattr(self._session, item)
def _call_with_retry(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
attempt = 0
last_exc: neo4j.exceptions.ServiceUnavailable | None = None
while attempt <= self._max_retries:
try:
method = getattr(self._session, method_name)
return method(*args, **kwargs)
except (
neo4j.exceptions.ServiceUnavailable
) as exc: # pragma: no cover - depends on infra
last_exc = exc
attempt += 1
if attempt > self._max_retries:
raise
logger.warning(
f"Neo4j session {method_name} failed with ServiceUnavailable ({attempt}/{self._max_retries} attempts). Retrying..."
)
self._refresh_session()
raise last_exc if last_exc else RuntimeError("Unexpected retry loop exit")
def _refresh_session(self) -> None:
if self._session is not None:
self._session.close()
self._close_driver()
self._session = self._session_factory()
@@ -0,0 +1,143 @@
import logging
from typing import Any
from rest_framework.exceptions import APIException, ValidationError
from api.attack_paths import database as graph_database, AttackPathsQueryDefinition
from api.models import AttackPathsScan
from config.custom_logging import BackendLogger
logger = logging.getLogger(BackendLogger.API)
def normalize_run_payload(raw_data):
if not isinstance(raw_data, dict): # Let the serializer handle this
return raw_data
if "data" in raw_data and isinstance(raw_data.get("data"), dict):
data_section = raw_data.get("data") or {}
attributes = data_section.get("attributes") or {}
payload = {
"id": attributes.get("id", data_section.get("id")),
"parameters": attributes.get("parameters"),
}
# Remove `None` parameters to allow defaults downstream
if payload.get("parameters") is None:
payload.pop("parameters")
return payload
return raw_data
def prepare_query_parameters(
definition: AttackPathsQueryDefinition,
provided_parameters: dict[str, Any],
provider_uid: str,
) -> dict[str, Any]:
parameters = dict(provided_parameters or {})
expected_names = {parameter.name for parameter in definition.parameters}
provided_names = set(parameters.keys())
unexpected = provided_names - expected_names
if unexpected:
raise ValidationError(
{"parameters": f"Unknown parameter(s): {', '.join(sorted(unexpected))}"}
)
missing = expected_names - provided_names
if missing:
raise ValidationError(
{
"parameters": f"Missing required parameter(s): {', '.join(sorted(missing))}"
}
)
clean_parameters = {
"provider_uid": str(provider_uid),
}
for definition_parameter in definition.parameters:
raw_value = provided_parameters[definition_parameter.name]
try:
casted_value = definition_parameter.cast(raw_value)
except (ValueError, TypeError) as exc:
raise ValidationError(
{
"parameters": (
f"Invalid value for parameter `{definition_parameter.name}`: {str(exc)}"
)
}
)
clean_parameters[definition_parameter.name] = casted_value
return clean_parameters
def execute_attack_paths_query(
attack_paths_scan: AttackPathsScan,
definition: AttackPathsQueryDefinition,
parameters: dict[str, Any],
) -> dict[str, Any]:
try:
with graph_database.get_session(attack_paths_scan.graph_database) as session:
result = session.run(definition.cypher, parameters)
return _serialize_graph(result.graph())
except graph_database.GraphDatabaseQueryException as exc:
logger.error(f"Query failed for Attack Paths query `{definition.id}`: {exc}")
raise APIException(
"Attack Paths query execution failed due to a database error"
)
def _serialize_graph(graph):
nodes = []
for node in graph.nodes:
nodes.append(
{
"id": node.element_id,
"labels": list(node.labels),
"properties": _serialize_properties(node._properties),
},
)
relationships = []
for relationship in graph.relationships:
relationships.append(
{
"id": relationship.element_id,
"label": relationship.type,
"source": relationship.start_node.element_id,
"target": relationship.end_node.element_id,
"properties": _serialize_properties(relationship._properties),
},
)
return {
"nodes": nodes,
"relationships": relationships,
}
def _serialize_properties(properties: dict[str, Any]) -> dict[str, Any]:
"""Convert Neo4j property values into JSON-serializable primitives."""
def _serialize_value(value: Any) -> Any:
# Neo4j temporal and spatial values expose `to_native` returning Python primitives
if hasattr(value, "to_native") and callable(value.to_native):
return _serialize_value(value.to_native())
if isinstance(value, (list, tuple)):
return [_serialize_value(item) for item in value]
if isinstance(value, dict):
return {key: _serialize_value(val) for key, val in value.items()}
return value
return {key: _serialize_value(val) for key, val in properties.items()}
+18
@@ -27,6 +27,7 @@ from api.models import (
Finding,
Integration,
Invitation,
AttackPathsScan,
LighthouseProviderConfiguration,
LighthouseProviderModels,
Membership,
@@ -330,6 +331,23 @@ class ScanFilter(ProviderRelationshipFilterSet):
}
class AttackPathsScanFilter(ProviderRelationshipFilterSet):
inserted_at = DateFilter(field_name="inserted_at", lookup_expr="date")
completed_at = DateFilter(field_name="completed_at", lookup_expr="date")
started_at = DateFilter(field_name="started_at", lookup_expr="date")
state = ChoiceFilter(choices=StateChoices.choices)
state__in = ChoiceInFilter(
field_name="state", choices=StateChoices.choices, lookup_expr="in"
)
class Meta:
model = AttackPathsScan
fields = {
"provider": ["exact", "in"],
"scan": ["exact", "in"],
}
class TaskFilter(FilterSet):
name = CharFilter(field_name="task_runner_task__task_name", lookup_expr="exact")
name__icontains = CharFilter(
@@ -0,0 +1,41 @@
[
{
"model": "api.attackpathsscan",
"pk": "a7f0f6de-6f8e-4b3a-8cbe-3f6dd9012345",
"fields": {
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
"provider": "b85601a8-4b45-4194-8135-03fb980ef428",
"scan": "01920573-aa9c-73c9-bcda-f2e35c9b19d2",
"state": "completed",
"progress": 100,
"update_tag": 1693586667,
"graph_database": "db-a7f0f6de-6f8e-4b3a-8cbe-3f6dd9012345",
"is_graph_database_deleted": false,
"task": null,
"inserted_at": "2024-09-01T17:24:37Z",
"updated_at": "2024-09-01T17:44:37Z",
"started_at": "2024-09-01T17:34:37Z",
"completed_at": "2024-09-01T17:44:37Z",
"duration": 269,
"ingestion_exceptions": {}
}
},
{
"model": "api.attackpathsscan",
"pk": "4a2fb2af-8a60-4d7d-9cae-4ca65e098765",
"fields": {
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
"provider": "15fce1fa-ecaa-433f-a9dc-62553f3a2555",
"scan": "01929f3b-ed2e-7623-ad63-7c37cd37828f",
"state": "executing",
"progress": 48,
"update_tag": 1697625000,
"graph_database": "db-4a2fb2af-8a60-4d7d-9cae-4ca65e098765",
"is_graph_database_deleted": false,
"task": null,
"inserted_at": "2024-10-18T10:55:57Z",
"updated_at": "2024-10-18T10:56:15Z",
"started_at": "2024-10-18T10:56:05Z"
}
}
]
@@ -0,0 +1,154 @@
# Generated by Django 5.1.13 on 2025-11-06 16:20
import django.db.models.deletion
from django.db import migrations, models
from uuid6 import uuid7
import api.rls
class Migration(migrations.Migration):
dependencies = [
("api", "0059_compliance_overview_summary"),
]
operations = [
migrations.CreateModel(
name="AttackPathsScan",
fields=[
(
"id",
models.UUIDField(
default=uuid7,
editable=False,
primary_key=True,
serialize=False,
),
),
("inserted_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
(
"state",
api.db_utils.StateEnumField(
choices=[
("available", "Available"),
("scheduled", "Scheduled"),
("executing", "Executing"),
("completed", "Completed"),
("failed", "Failed"),
("cancelled", "Cancelled"),
],
default="available",
),
),
("progress", models.IntegerField(default=0)),
("started_at", models.DateTimeField(blank=True, null=True)),
("completed_at", models.DateTimeField(blank=True, null=True)),
(
"duration",
models.IntegerField(
blank=True, help_text="Duration in seconds", null=True
),
),
(
"update_tag",
models.BigIntegerField(
blank=True,
help_text="Cartography update tag (epoch)",
null=True,
),
),
(
"graph_database",
models.CharField(blank=True, max_length=63, null=True),
),
(
"is_graph_database_deleted",
models.BooleanField(default=False),
),
(
"ingestion_exceptions",
models.JSONField(blank=True, default=dict, null=True),
),
(
"provider",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="attack_paths_scans",
related_query_name="attack_paths_scan",
to="api.provider",
),
),
(
"scan",
models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="attack_paths_scans",
related_query_name="attack_paths_scan",
to="api.scan",
),
),
(
"task",
models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="attack_paths_scans",
related_query_name="attack_paths_scan",
to="api.task",
),
),
(
"tenant",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
),
),
],
options={
"db_table": "attack_paths_scans",
"abstract": False,
"indexes": [
models.Index(
fields=["tenant_id", "provider_id", "-inserted_at"],
name="aps_prov_ins_desc_idx",
),
models.Index(
fields=["tenant_id", "state", "-inserted_at"],
name="aps_state_ins_desc_idx",
),
models.Index(
fields=["tenant_id", "scan_id"],
name="aps_scan_lookup_idx",
),
models.Index(
fields=["tenant_id", "provider_id"],
name="aps_active_graph_idx",
include=["graph_database", "id"],
condition=models.Q(("is_graph_database_deleted", False)),
),
models.Index(
fields=["tenant_id", "provider_id", "-completed_at"],
name="aps_completed_graph_idx",
include=["graph_database", "id"],
condition=models.Q(
("state", "completed"),
("is_graph_database_deleted", False),
),
),
],
},
),
migrations.AddConstraint(
model_name="attackpathsscan",
constraint=api.rls.RowLevelSecurityConstraint(
"tenant_id",
name="rls_on_attackpathsscan",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
),
),
]
+95
@@ -616,6 +616,101 @@ class Scan(RowLevelSecurityProtectedModel):
resource_name = "scans"
class AttackPathsScan(RowLevelSecurityProtectedModel):
objects = ActiveProviderManager()
all_objects = models.Manager()
id = models.UUIDField(primary_key=True, default=uuid7, editable=False)
inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
updated_at = models.DateTimeField(auto_now=True, editable=False)
state = StateEnumField(choices=StateChoices.choices, default=StateChoices.AVAILABLE)
progress = models.IntegerField(default=0)
# Timing
started_at = models.DateTimeField(null=True, blank=True)
completed_at = models.DateTimeField(null=True, blank=True)
duration = models.IntegerField(
null=True, blank=True, help_text="Duration in seconds"
)
# Relationship to the provider and optional prowler Scan and celery Task
provider = models.ForeignKey(
"Provider",
on_delete=models.CASCADE,
related_name="attack_paths_scans",
related_query_name="attack_paths_scan",
)
scan = models.ForeignKey(
"Scan",
on_delete=models.SET_NULL,
null=True,
blank=True,
related_name="attack_paths_scans",
related_query_name="attack_paths_scan",
)
task = models.ForeignKey(
"Task",
on_delete=models.SET_NULL,
null=True,
blank=True,
related_name="attack_paths_scans",
related_query_name="attack_paths_scan",
)
# Cartography specific metadata
update_tag = models.BigIntegerField(
null=True, blank=True, help_text="Cartography update tag (epoch)"
)
graph_database = models.CharField(max_length=63, null=True, blank=True)
is_graph_database_deleted = models.BooleanField(default=False)
ingestion_exceptions = models.JSONField(default=dict, null=True, blank=True)
class Meta(RowLevelSecurityProtectedModel.Meta):
db_table = "attack_paths_scans"
constraints = [
RowLevelSecurityConstraint(
field="tenant_id",
name="rls_on_%(class)s",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
),
]
indexes = [
models.Index(
fields=["tenant_id", "provider_id", "-inserted_at"],
name="aps_prov_ins_desc_idx",
),
models.Index(
fields=["tenant_id", "state", "-inserted_at"],
name="aps_state_ins_desc_idx",
),
models.Index(
fields=["tenant_id", "scan_id"],
name="aps_scan_lookup_idx",
),
models.Index(
fields=["tenant_id", "provider_id"],
name="aps_active_graph_idx",
include=["graph_database", "id"],
condition=Q(is_graph_database_deleted=False),
),
models.Index(
fields=["tenant_id", "provider_id", "-completed_at"],
name="aps_completed_graph_idx",
include=["graph_database", "id"],
condition=Q(
state=StateChoices.COMPLETED,
is_graph_database_deleted=False,
),
),
]
class JSONAPIMeta:
resource_name = "attack-paths-scans"
class ResourceTag(RowLevelSecurityProtectedModel):
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
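The partial indexes above (`aps_active_graph_idx`, `aps_completed_graph_idx`) back latest-scan-per-provider lookups, and the list endpoint tested later in this diff returns one entry per provider. A minimal pure-Python sketch of that dedup rule (illustrative only, not the project's actual queryset):

```python
from datetime import datetime, timezone

def latest_scan_per_provider(scans):
    """Keep only the most recently inserted scan for each provider.

    `scans` is an iterable of dicts with "provider_id" and "inserted_at"
    keys (an illustrative shape, not the Django model itself).
    """
    latest = {}
    for scan in scans:
        current = latest.get(scan["provider_id"])
        if current is None or scan["inserted_at"] > current["inserted_at"]:
            latest[scan["provider_id"]] = scan
    return list(latest.values())

scans = [
    {"id": "older", "provider_id": "p1",
     "inserted_at": datetime(2026, 1, 13, tzinfo=timezone.utc)},
    {"id": "latest", "provider_id": "p1",
     "inserted_at": datetime(2026, 1, 14, tzinfo=timezone.utc)},
    {"id": "other", "provider_id": "p2",
     "inserted_at": datetime(2026, 1, 12, tzinfo=timezone.utc)},
]
result = latest_scan_per_provider(scans)
# One entry per provider; the older p1 scan is dropped.
```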
@@ -280,6 +280,435 @@ paths:
schema:
$ref: '#/components/schemas/OpenApiResponseResponse'
description: API key was successfully revoked
/api/v1/attack-paths-scans:
get:
operationId: attack_paths_scans_list
description: Retrieve Attack Paths scans for the tenant with support for filtering,
ordering, and pagination.
summary: List Attack Paths scans
parameters:
- in: query
name: fields[attack-paths-scans]
schema:
type: array
items:
type: string
enum:
- state
- progress
- provider
- provider_alias
- provider_type
- provider_uid
- scan
- task
- inserted_at
- started_at
- completed_at
- duration
description: endpoint returns only specific fields in the response on a per-type
basis by including a fields[TYPE] query parameter.
explode: false
- in: query
name: filter[completed_at]
schema:
type: string
format: date
- in: query
name: filter[inserted_at]
schema:
type: string
format: date
- in: query
name: filter[provider]
schema:
type: string
format: uuid
- in: query
name: filter[provider__in]
schema:
type: array
items:
type: string
format: uuid
description: Multiple values may be separated by commas.
explode: false
style: form
- in: query
name: filter[provider_alias]
schema:
type: string
- in: query
name: filter[provider_alias__icontains]
schema:
type: string
- in: query
name: filter[provider_alias__in]
schema:
type: array
items:
type: string
description: Multiple values may be separated by commas.
explode: false
style: form
- in: query
name: filter[provider_type]
schema:
type: string
x-spec-enum-id: eca8c51e6bd28935
enum:
- aws
- azure
- gcp
- github
- iac
- kubernetes
- m365
- mongodbatlas
- oraclecloud
description: |-
* `aws` - AWS
* `azure` - Azure
* `gcp` - GCP
* `kubernetes` - Kubernetes
* `m365` - M365
* `github` - GitHub
* `mongodbatlas` - MongoDB Atlas
* `iac` - IaC
* `oraclecloud` - Oracle Cloud Infrastructure
- in: query
name: filter[provider_type__in]
schema:
type: array
items:
type: string
x-spec-enum-id: eca8c51e6bd28935
enum:
- aws
- azure
- gcp
- github
- iac
- kubernetes
- m365
- mongodbatlas
- oraclecloud
description: |-
Multiple values may be separated by commas.
* `aws` - AWS
* `azure` - Azure
* `gcp` - GCP
* `kubernetes` - Kubernetes
* `m365` - M365
* `github` - GitHub
* `mongodbatlas` - MongoDB Atlas
* `iac` - IaC
* `oraclecloud` - Oracle Cloud Infrastructure
explode: false
style: form
- in: query
name: filter[provider_uid]
schema:
type: string
- in: query
name: filter[provider_uid__icontains]
schema:
type: string
- in: query
name: filter[provider_uid__in]
schema:
type: array
items:
type: string
description: Multiple values may be separated by commas.
explode: false
style: form
- in: query
name: filter[scan]
schema:
type: string
format: uuid
- in: query
name: filter[scan__in]
schema:
type: array
items:
type: string
format: uuid
description: Multiple values may be separated by commas.
explode: false
style: form
- name: filter[search]
required: false
in: query
description: A search term.
schema:
type: string
- in: query
name: filter[started_at]
schema:
type: string
format: date
- in: query
name: filter[state]
schema:
type: string
x-spec-enum-id: d38ba07264e1ed34
enum:
- available
- cancelled
- completed
- executing
- failed
- scheduled
description: |-
* `available` - Available
* `scheduled` - Scheduled
* `executing` - Executing
* `completed` - Completed
* `failed` - Failed
* `cancelled` - Cancelled
- in: query
name: filter[state__in]
schema:
type: array
items:
type: string
x-spec-enum-id: d38ba07264e1ed34
enum:
- available
- cancelled
- completed
- executing
- failed
- scheduled
description: |-
Multiple values may be separated by commas.
* `available` - Available
* `scheduled` - Scheduled
* `executing` - Executing
* `completed` - Completed
* `failed` - Failed
* `cancelled` - Cancelled
explode: false
style: form
- in: query
name: include
schema:
type: array
items:
type: string
enum:
- provider
- scan
- task
description: The include query parameter allows the client to customize which
related resources should be returned.
explode: false
- name: page[number]
required: false
in: query
description: A page number within the paginated result set.
schema:
type: integer
- name: page[size]
required: false
in: query
description: Number of results to return per page.
schema:
type: integer
- name: sort
required: false
in: query
description: '[list of fields to sort by](https://jsonapi.org/format/#fetching-sorting)'
schema:
type: array
items:
type: string
enum:
- inserted_at
- -inserted_at
- started_at
- -started_at
explode: false
tags:
- Attack Paths
security:
- JWT or API Key: []
responses:
'200':
content:
application/vnd.api+json:
schema:
$ref: '#/components/schemas/PaginatedAttackPathsScanList'
description: ''
/api/v1/attack-paths-scans/{id}:
get:
operationId: attack_paths_scans_retrieve
description: Fetch full details for a specific Attack Paths scan.
summary: Retrieve Attack Paths scan details
parameters:
- in: query
name: fields[attack-paths-scans]
schema:
type: array
items:
type: string
enum:
- state
- progress
- provider
- provider_alias
- provider_type
- provider_uid
- scan
- task
- inserted_at
- started_at
- completed_at
- duration
description: endpoint returns only specific fields in the response on a per-type
basis by including a fields[TYPE] query parameter.
explode: false
- in: path
name: id
schema:
type: string
format: uuid
description: A UUID string identifying this attack paths scan.
required: true
- in: query
name: include
schema:
type: array
items:
type: string
enum:
- provider
- scan
- task
description: The include query parameter allows the client to customize which
related resources should be returned.
explode: false
tags:
- Attack Paths
security:
- JWT or API Key: []
responses:
'200':
content:
application/vnd.api+json:
schema:
$ref: '#/components/schemas/AttackPathsScanResponse'
description: ''
/api/v1/attack-paths-scans/{id}/queries:
get:
operationId: attack_paths_scans_queries_retrieve
description: Retrieve the catalog of Attack Paths queries available for this
Attack Paths scan.
summary: List attack paths queries
parameters:
- in: query
name: fields[attack-paths-scans]
schema:
type: array
items:
type: string
enum:
- state
- progress
- provider
- provider_alias
- provider_type
- provider_uid
- scan
- task
- inserted_at
- started_at
- completed_at
- duration
description: endpoint returns only specific fields in the response on a per-type
basis by including a fields[TYPE] query parameter.
explode: false
- in: path
name: id
schema:
type: string
format: uuid
description: A UUID string identifying this attack paths scan.
required: true
- in: query
name: include
schema:
type: array
items:
type: string
enum:
- provider
- scan
- task
description: The include query parameter allows the client to customize which
related resources should be returned.
explode: false
tags:
- Attack Paths
security:
- JWT or API Key: []
responses:
'200':
content:
application/vnd.api+json:
schema:
$ref: '#/components/schemas/PaginatedAttackPathsQueryList'
description: ''
'404':
description: No queries found for the selected provider
/api/v1/attack-paths-scans/{id}/queries/run:
post:
operationId: attack_paths_scans_queries_run_create
description: Execute the selected Attack Paths query against the Attack Paths
graph and return the resulting subgraph.
summary: Execute an Attack Paths query
parameters:
- in: path
name: id
schema:
type: string
format: uuid
description: A UUID string identifying this attack paths scan.
required: true
tags:
- Attack Paths
requestBody:
content:
application/vnd.api+json:
schema:
$ref: '#/components/schemas/AttackPathsQueryRunRequestRequest'
application/x-www-form-urlencoded:
schema:
$ref: '#/components/schemas/AttackPathsQueryRunRequestRequest'
multipart/form-data:
schema:
$ref: '#/components/schemas/AttackPathsQueryRunRequestRequest'
required: true
security:
- JWT or API Key: []
responses:
'200':
content:
application/vnd.api+json:
schema:
$ref: '#/components/schemas/OpenApiResponseResponse'
description: ''
'400':
description: Bad request (e.g., Unknown Attack Paths query for the selected
provider)
'404':
description: No attack paths found for the given query and parameters
'500':
description: Attack Paths query execution failed due to a database error
/api/v1/compliance-overviews:
get:
operationId: compliance_overviews_list
@@ -10618,6 +11047,349 @@ paths:
description: ''
components:
schemas:
AttackPathsNode:
type: object
required:
- type
additionalProperties: false
properties:
type:
type: string
description: The [type](https://jsonapi.org/format/#document-resource-object-identification)
member is used to describe resource objects that share common attributes
and relationships.
enum:
- attack-paths-query-result-node
attributes:
type: object
properties:
id:
type: string
labels:
type: array
items:
type: string
properties:
type: object
additionalProperties: {}
required:
- id
- labels
- properties
AttackPathsQuery:
type: object
required:
- type
- id
additionalProperties: false
properties:
type:
type: string
description: The [type](https://jsonapi.org/format/#document-resource-object-identification)
member is used to describe resource objects that share common attributes
and relationships.
enum:
- attack-paths-query
id: {}
attributes:
type: object
properties:
id:
type: string
name:
type: string
description:
type: string
provider:
type: string
parameters:
type: array
items:
$ref: '#/components/schemas/AttackPathsQueryParameter'
required:
- id
- name
- description
- provider
- parameters
AttackPathsQueryParameter:
type: object
required:
- type
- id
additionalProperties: false
properties:
type:
type: string
description: The [type](https://jsonapi.org/format/#document-resource-object-identification)
member is used to describe resource objects that share common attributes
and relationships.
enum:
- attack-paths-query-parameter
id: {}
attributes:
type: object
properties:
name:
type: string
label:
type: string
data_type:
type: string
default: string
description:
type: string
nullable: true
placeholder:
type: string
nullable: true
required:
- name
- label
AttackPathsQueryResult:
type: object
required:
- type
additionalProperties: false
properties:
type:
type: string
description: The [type](https://jsonapi.org/format/#document-resource-object-identification)
member is used to describe resource objects that share common attributes
and relationships.
enum:
- attack-paths-query-result
attributes:
type: object
properties:
nodes:
type: array
items:
$ref: '#/components/schemas/AttackPathsNode'
relationships:
type: array
items:
$ref: '#/components/schemas/AttackPathsRelationship'
required:
- nodes
- relationships
AttackPathsQueryRunRequestRequest:
type: object
properties:
data:
type: object
required:
- type
additionalProperties: false
properties:
type:
type: string
description: The [type](https://jsonapi.org/format/#document-resource-object-identification)
member is used to describe resource objects that share common attributes
and relationships.
enum:
- attack-paths-query-run-request
attributes:
type: object
properties:
id:
type: string
minLength: 1
parameters:
type: object
additionalProperties: {}
required:
- id
required:
- data
AttackPathsRelationship:
type: object
required:
- type
additionalProperties: false
properties:
type:
type: string
description: The [type](https://jsonapi.org/format/#document-resource-object-identification)
member is used to describe resource objects that share common attributes
and relationships.
enum:
- attack-paths-query-result-relationship
attributes:
type: object
properties:
id:
type: string
label:
type: string
source:
type: string
target:
type: string
properties:
type: object
additionalProperties: {}
required:
- id
- label
- source
- target
- properties
AttackPathsScan:
type: object
required:
- type
- id
additionalProperties: false
properties:
type:
type: string
description: The [type](https://jsonapi.org/format/#document-resource-object-identification)
member is used to describe resource objects that share common attributes
and relationships.
enum:
- attack-paths-scans
id:
type: string
format: uuid
attributes:
type: object
properties:
state:
enum:
- available
- scheduled
- executing
- completed
- failed
- cancelled
type: string
description: |-
* `available` - Available
* `scheduled` - Scheduled
* `executing` - Executing
* `completed` - Completed
* `failed` - Failed
* `cancelled` - Cancelled
x-spec-enum-id: d38ba07264e1ed34
readOnly: true
progress:
type: integer
maximum: 2147483647
minimum: -2147483648
provider_alias:
type: string
readOnly: true
provider_type:
type: string
readOnly: true
provider_uid:
type: string
readOnly: true
inserted_at:
type: string
format: date-time
readOnly: true
started_at:
type: string
format: date-time
nullable: true
completed_at:
type: string
format: date-time
nullable: true
duration:
type: integer
maximum: 2147483647
minimum: -2147483648
nullable: true
description: Duration in seconds
relationships:
type: object
properties:
provider:
type: object
properties:
data:
type: object
properties:
id:
type: string
format: uuid
type:
type: string
enum:
- providers
title: Resource Type Name
description: The [type](https://jsonapi.org/format/#document-resource-object-identification)
member is used to describe resource objects that share common
attributes and relationships.
required:
- id
- type
required:
- data
description: The identifier of the related object.
title: Resource Identifier
scan:
type: object
properties:
data:
type: object
properties:
id:
type: string
format: uuid
type:
type: string
enum:
- scans
title: Resource Type Name
description: The [type](https://jsonapi.org/format/#document-resource-object-identification)
member is used to describe resource objects that share common
attributes and relationships.
required:
- id
- type
required:
- data
description: The identifier of the related object.
title: Resource Identifier
nullable: true
task:
type: object
properties:
data:
type: object
properties:
id:
type: string
format: uuid
type:
type: string
enum:
- tasks
title: Resource Type Name
description: The [type](https://jsonapi.org/format/#document-resource-object-identification)
member is used to describe resource objects that share common
attributes and relationships.
required:
- id
- type
required:
- data
description: The identifier of the related object.
title: Resource Identifier
nullable: true
required:
- provider
AttackPathsScanResponse:
type: object
properties:
data:
$ref: '#/components/schemas/AttackPathsScan'
required:
- data
ComplianceOverview:
type: object
required:
@@ -13576,6 +14348,24 @@ components:
$ref: '#/components/schemas/OverviewSeverity'
required:
- data
PaginatedAttackPathsQueryList:
type: object
properties:
data:
type: array
items:
$ref: '#/components/schemas/AttackPathsQuery'
required:
- data
PaginatedAttackPathsScanList:
type: object
properties:
data:
type: array
items:
$ref: '#/components/schemas/AttackPathsScan'
required:
- data
PaginatedComplianceOverviewAttributesList:
type: object
properties:
@@ -19660,6 +20450,8 @@ tags:
revoking tasks that have not started.
- name: Scan
description: Endpoints for triggering manual scans and viewing scan results.
- name: Attack Paths
  description: Endpoints for checking Attack Paths scan status and executing Attack Paths queries.
- name: Schedule
description: Endpoints for managing scan schedules, allowing configuration of automated
scans with different scheduling options.
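The run endpoint (`POST /api/v1/attack-paths-scans/{id}/queries/run`) expects a body shaped by the `AttackPathsQueryRunRequestRequest` schema above. A minimal example payload; the query id `aws-rds` and the `ip` parameter are illustrative values only:

```python
import json

# JSON:API request body for the queries/run endpoint; send it with
# Content-Type: application/vnd.api+json.
payload = {
    "data": {
        "type": "attack-paths-query-run-request",
        "attributes": {
            "id": "aws-rds",                  # illustrative query id
            "parameters": {"ip": "192.0.2.0"},  # illustrative parameter
        },
    }
}
body = json.dumps(payload)
```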
@@ -0,0 +1,172 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from rest_framework.exceptions import APIException, ValidationError
from api.attack_paths import database as graph_database
from api.attack_paths import views_helpers
def test_normalize_run_payload_extracts_attributes_section():
payload = {
"data": {
"id": "ignored",
"attributes": {
"id": "aws-rds",
"parameters": {"ip": "192.0.2.0"},
},
}
}
result = views_helpers.normalize_run_payload(payload)
assert result == {"id": "aws-rds", "parameters": {"ip": "192.0.2.0"}}
def test_normalize_run_payload_passthrough_for_non_dict():
sentinel = "not-a-dict"
assert views_helpers.normalize_run_payload(sentinel) is sentinel
def test_prepare_query_parameters_includes_provider_and_casts(
attack_paths_query_definition_factory,
):
definition = attack_paths_query_definition_factory(cast_type=int)
result = views_helpers.prepare_query_parameters(
definition,
{"limit": "5"},
provider_uid="123456789012",
)
assert result["provider_uid"] == "123456789012"
assert result["limit"] == 5
@pytest.mark.parametrize(
"provided,expected_message",
[
({}, "Missing required parameter"),
({"limit": 10, "extra": True}, "Unknown parameter"),
],
)
def test_prepare_query_parameters_validates_names(
attack_paths_query_definition_factory, provided, expected_message
):
definition = attack_paths_query_definition_factory()
with pytest.raises(ValidationError) as exc:
views_helpers.prepare_query_parameters(definition, provided, provider_uid="1")
assert expected_message in str(exc.value)
def test_prepare_query_parameters_validates_cast(
attack_paths_query_definition_factory,
):
definition = attack_paths_query_definition_factory(cast_type=int)
with pytest.raises(ValidationError) as exc:
views_helpers.prepare_query_parameters(
definition,
{"limit": "not-an-int"},
provider_uid="1",
)
assert "Invalid value" in str(exc.value)
def test_execute_attack_paths_query_serializes_graph(
attack_paths_query_definition_factory, attack_paths_graph_stub_classes
):
definition = attack_paths_query_definition_factory(
id="aws-rds",
name="RDS",
description="",
cypher="MATCH (n) RETURN n",
parameters=[],
)
parameters = {"provider_uid": "123"}
attack_paths_scan = SimpleNamespace(graph_database="tenant-db")
node = attack_paths_graph_stub_classes.Node(
element_id="node-1",
labels=["AWSAccount"],
properties={
"name": "account",
"complex": {
"items": [
attack_paths_graph_stub_classes.NativeValue("value"),
{"nested": 1},
]
},
},
)
relationship = attack_paths_graph_stub_classes.Relationship(
element_id="rel-1",
rel_type="OWNS",
start_node=node,
end_node=attack_paths_graph_stub_classes.Node("node-2", ["RDSInstance"], {}),
properties={"weight": 1},
)
graph = SimpleNamespace(nodes=[node], relationships=[relationship])
run_result = MagicMock()
run_result.graph.return_value = graph
session = MagicMock()
session.run.return_value = run_result
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = False
with patch(
"api.attack_paths.views_helpers.graph_database.get_session",
return_value=session_ctx,
) as mock_get_session:
result = views_helpers.execute_attack_paths_query(
attack_paths_scan, definition, parameters
)
mock_get_session.assert_called_once_with("tenant-db")
session.run.assert_called_once_with(definition.cypher, parameters)
assert result["nodes"][0]["id"] == "node-1"
assert result["nodes"][0]["properties"]["complex"]["items"][0] == "value"
assert result["relationships"][0]["label"] == "OWNS"
def test_execute_attack_paths_query_wraps_graph_errors(
attack_paths_query_definition_factory,
):
definition = attack_paths_query_definition_factory(
id="aws-rds",
name="RDS",
description="",
cypher="MATCH (n) RETURN n",
parameters=[],
)
attack_paths_scan = SimpleNamespace(graph_database="tenant-db")
parameters = {"provider_uid": "123"}
class ExplodingContext:
def __enter__(self):
raise graph_database.GraphDatabaseQueryException("boom")
def __exit__(self, exc_type, exc, tb):
return False
with (
patch(
"api.attack_paths.views_helpers.graph_database.get_session",
return_value=ExplodingContext(),
),
patch("api.attack_paths.views_helpers.logger") as mock_logger,
):
with pytest.raises(APIException):
views_helpers.execute_attack_paths_query(
attack_paths_scan, definition, parameters
)
mock_logger.error.assert_called_once()
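The serialization test above expects nested driver-native values (the `NativeValue` stub) to be flattened into JSON-safe types. A minimal sketch of such a converter, hypothetical and not the project's actual implementation:

```python
def to_json_safe(value):
    """Recursively convert driver-native values into JSON-safe Python types.

    Hypothetical helper mirroring what the test asserts: objects exposing a
    `.value` attribute are unwrapped, containers are walked recursively, and
    plain scalars pass through unchanged.
    """
    if hasattr(value, "value"):
        return to_json_safe(value.value)
    if isinstance(value, dict):
        return {key: to_json_safe(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [to_json_safe(item) for item in value]
    return value

class NativeValue:
    """Stand-in for a graph-driver wrapper type (as in the test stubs)."""
    def __init__(self, value):
        self.value = value

props = {"complex": {"items": [NativeValue("value"), {"nested": 1}]}}
safe = to_json_safe(props)
# safe["complex"]["items"][0] == "value"
```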
@@ -32,6 +32,10 @@ from django_celery_results.models import TaskResult
from rest_framework import status
from rest_framework.response import Response
from api.attack_paths import (
AttackPathsQueryDefinition,
AttackPathsQueryParameterDefinition,
)
from api.compliance import get_compliance_frameworks
from api.db_router import MainRouter
from api.models import (
@@ -3522,6 +3526,420 @@ class TestTaskViewSet:
assert response.status_code == status.HTTP_400_BAD_REQUEST
@pytest.mark.django_db
class TestAttackPathsScanViewSet:
@staticmethod
def _run_payload(query_id="aws-rds", parameters=None):
return {
"data": {
"type": "attack-paths-query-run-request",
"attributes": {
"id": query_id,
"parameters": parameters or {},
},
}
}
def test_attack_paths_scans_list_returns_latest_entry_per_provider(
self,
authenticated_client,
providers_fixture,
scans_fixture,
create_attack_paths_scan,
):
provider = providers_fixture[0]
other_provider = providers_fixture[1]
older_scan = create_attack_paths_scan(
provider,
scan=scans_fixture[0],
state=StateChoices.AVAILABLE,
progress=10,
)
latest_scan = create_attack_paths_scan(
provider,
scan=scans_fixture[0],
state=StateChoices.COMPLETED,
progress=95,
)
other_provider_scan = create_attack_paths_scan(
other_provider,
scan=scans_fixture[2],
state=StateChoices.FAILED,
progress=50,
)
response = authenticated_client.get(reverse("attack-paths-scans-list"))
assert response.status_code == status.HTTP_200_OK
data = response.json()["data"]
ids = {item["id"] for item in data}
assert ids == {str(latest_scan.id), str(other_provider_scan.id)}
assert str(older_scan.id) not in ids
provider_entry = next(
item
for item in data
if item["relationships"]["provider"]["data"]["id"] == str(provider.id)
)
first_attributes = provider_entry["attributes"]
assert first_attributes["provider_alias"] == provider.alias
assert first_attributes["provider_type"] == provider.provider
assert first_attributes["provider_uid"] == provider.uid
def test_attack_paths_scans_list_respects_provider_group_visibility(
self,
authenticated_client_no_permissions_rbac,
providers_fixture,
create_attack_paths_scan,
):
client = authenticated_client_no_permissions_rbac
limited_user = client.user
membership = Membership.objects.filter(user=limited_user).first()
tenant = membership.tenant
allowed_provider = providers_fixture[0]
denied_provider = providers_fixture[1]
allowed_scan = create_attack_paths_scan(allowed_provider)
create_attack_paths_scan(denied_provider)
provider_group = ProviderGroup.objects.create(
name="limited-group",
tenant_id=tenant.id,
)
ProviderGroupMembership.objects.create(
tenant_id=tenant.id,
provider_group=provider_group,
provider=allowed_provider,
)
limited_role = limited_user.roles.first()
RoleProviderGroupRelationship.objects.create(
tenant_id=tenant.id,
role=limited_role,
provider_group=provider_group,
)
response = client.get(reverse("attack-paths-scans-list"))
assert response.status_code == status.HTTP_200_OK
data = response.json()["data"]
assert len(data) == 1
assert data[0]["id"] == str(allowed_scan.id)
def test_attack_paths_scan_retrieve(
self,
authenticated_client,
providers_fixture,
scans_fixture,
create_attack_paths_scan,
):
provider = providers_fixture[0]
attack_paths_scan = create_attack_paths_scan(
provider,
scan=scans_fixture[0],
state=StateChoices.COMPLETED,
progress=80,
)
response = authenticated_client.get(
reverse("attack-paths-scans-detail", kwargs={"pk": attack_paths_scan.id})
)
assert response.status_code == status.HTTP_200_OK
data = response.json()["data"]
assert data["id"] == str(attack_paths_scan.id)
assert data["relationships"]["provider"]["data"]["id"] == str(provider.id)
assert data["attributes"]["state"] == StateChoices.COMPLETED
def test_attack_paths_scan_retrieve_not_found_for_foreign_tenant(
self, authenticated_client, create_attack_paths_scan
):
other_tenant = Tenant.objects.create(name="Foreign AttackPaths Tenant")
foreign_provider = Provider.objects.create(
provider="aws",
uid="333333333333",
alias="foreign",
tenant_id=other_tenant.id,
)
foreign_scan = create_attack_paths_scan(foreign_provider)
response = authenticated_client.get(
reverse("attack-paths-scans-detail", kwargs={"pk": foreign_scan.id})
)
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_attack_paths_queries_returns_catalog(
self,
authenticated_client,
providers_fixture,
scans_fixture,
create_attack_paths_scan,
):
provider = providers_fixture[0]
attack_paths_scan = create_attack_paths_scan(
provider,
scan=scans_fixture[0],
)
definitions = [
AttackPathsQueryDefinition(
id="aws-rds",
name="RDS inventory",
description="List account RDS assets",
provider=provider.provider,
cypher="MATCH (n) RETURN n",
parameters=[
AttackPathsQueryParameterDefinition(name="ip", label="IP address")
],
)
]
with patch(
"api.v1.views.get_queries_for_provider", return_value=definitions
) as mock_get_queries:
response = authenticated_client.get(
reverse(
"attack-paths-scans-queries", kwargs={"pk": attack_paths_scan.id}
)
)
assert response.status_code == status.HTTP_200_OK
mock_get_queries.assert_called_once_with(provider.provider)
payload = response.json()["data"]
assert len(payload) == 1
assert payload[0]["id"] == "aws-rds"
assert payload[0]["attributes"]["name"] == "RDS inventory"
assert payload[0]["attributes"]["parameters"][0]["name"] == "ip"
def test_attack_paths_queries_returns_404_when_catalog_missing(
self,
authenticated_client,
providers_fixture,
scans_fixture,
create_attack_paths_scan,
):
provider = providers_fixture[0]
attack_paths_scan = create_attack_paths_scan(provider, scan=scans_fixture[0])
with patch("api.v1.views.get_queries_for_provider", return_value=[]):
response = authenticated_client.get(
reverse(
"attack-paths-scans-queries", kwargs={"pk": attack_paths_scan.id}
)
)
assert response.status_code == status.HTTP_404_NOT_FOUND
assert "No queries found" in str(response.json())
def test_run_attack_paths_query_returns_graph(
self,
authenticated_client,
providers_fixture,
scans_fixture,
create_attack_paths_scan,
):
provider = providers_fixture[0]
attack_paths_scan = create_attack_paths_scan(
provider,
scan=scans_fixture[0],
graph_database="tenant-db",
)
query_definition = AttackPathsQueryDefinition(
id="aws-rds",
name="RDS inventory",
description="List account RDS assets",
provider=provider.provider,
cypher="MATCH (n) RETURN n",
parameters=[],
)
prepared_parameters = {"provider_uid": provider.uid}
graph_payload = {
"nodes": [
{
"id": "node-1",
"labels": ["AWSAccount"],
"properties": {"name": "root"},
}
],
"relationships": [
{
"id": "rel-1",
"label": "OWNS",
"source": "node-1",
"target": "node-2",
"properties": {},
}
],
}
with (
patch(
"api.v1.views.get_query_by_id", return_value=query_definition
) as mock_get_query,
patch(
"api.v1.views.attack_paths_views_helpers.prepare_query_parameters",
return_value=prepared_parameters,
) as mock_prepare,
patch(
"api.v1.views.attack_paths_views_helpers.execute_attack_paths_query",
return_value=graph_payload,
) as mock_execute,
):
response = authenticated_client.post(
reverse(
"attack-paths-scans-queries-run",
kwargs={"pk": attack_paths_scan.id},
),
data=self._run_payload("aws-rds"),
content_type=API_JSON_CONTENT_TYPE,
)
assert response.status_code == status.HTTP_200_OK
mock_get_query.assert_called_once_with("aws-rds")
mock_prepare.assert_called_once_with(
query_definition,
{},
attack_paths_scan.provider.uid,
)
mock_execute.assert_called_once_with(
attack_paths_scan,
query_definition,
prepared_parameters,
)
result = response.json()["data"]
attributes = result["attributes"]
assert attributes["nodes"] == graph_payload["nodes"]
assert attributes["relationships"] == graph_payload["relationships"]
def test_run_attack_paths_query_requires_completed_scan(
self,
authenticated_client,
providers_fixture,
scans_fixture,
create_attack_paths_scan,
):
provider = providers_fixture[0]
attack_paths_scan = create_attack_paths_scan(
provider,
scan=scans_fixture[0],
state=StateChoices.EXECUTING,
)
response = authenticated_client.post(
reverse(
"attack-paths-scans-queries-run", kwargs={"pk": attack_paths_scan.id}
),
data=self._run_payload(),
content_type=API_JSON_CONTENT_TYPE,
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert "must be completed" in response.json()["errors"][0]["detail"]
def test_run_attack_paths_query_requires_graph_database(
self,
authenticated_client,
providers_fixture,
scans_fixture,
create_attack_paths_scan,
):
provider = providers_fixture[0]
attack_paths_scan = create_attack_paths_scan(
provider,
scan=scans_fixture[0],
graph_database=None,
)
response = authenticated_client.post(
reverse(
"attack-paths-scans-queries-run", kwargs={"pk": attack_paths_scan.id}
),
data=self._run_payload(),
content_type=API_JSON_CONTENT_TYPE,
)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert "does not reference a graph database" in str(response.json())
def test_run_attack_paths_query_unknown_query(
self,
authenticated_client,
providers_fixture,
scans_fixture,
create_attack_paths_scan,
):
provider = providers_fixture[0]
attack_paths_scan = create_attack_paths_scan(
provider,
scan=scans_fixture[0],
)
with patch("api.v1.views.get_query_by_id", return_value=None):
response = authenticated_client.post(
reverse(
"attack-paths-scans-queries-run",
kwargs={"pk": attack_paths_scan.id},
),
data=self._run_payload("unknown-query"),
content_type=API_JSON_CONTENT_TYPE,
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert "Unknown Attack Paths query" in response.json()["errors"][0]["detail"]
def test_run_attack_paths_query_returns_404_when_no_nodes_found(
self,
authenticated_client,
providers_fixture,
scans_fixture,
create_attack_paths_scan,
):
provider = providers_fixture[0]
attack_paths_scan = create_attack_paths_scan(
provider,
scan=scans_fixture[0],
)
query_definition = AttackPathsQueryDefinition(
id="aws-empty",
name="empty",
description="",
provider=provider.provider,
cypher="MATCH (n) RETURN n",
)
with (
patch("api.v1.views.get_query_by_id", return_value=query_definition),
patch(
"api.v1.views.attack_paths_views_helpers.prepare_query_parameters",
return_value={"provider_uid": provider.uid},
),
patch(
"api.v1.views.attack_paths_views_helpers.execute_attack_paths_query",
return_value={"nodes": [], "relationships": []},
),
):
response = authenticated_client.post(
reverse(
"attack-paths-scans-queries-run",
kwargs={"pk": attack_paths_scan.id},
),
data=self._run_payload("aws-empty"),
content_type=API_JSON_CONTENT_TYPE,
)
assert response.status_code == status.HTTP_404_NOT_FOUND
payload = response.json()
if "data" in payload:
attributes = payload["data"].get("attributes", {})
assert attributes.get("nodes") == []
assert attributes.get("relationships") == []
else:
assert "errors" in payload
@pytest.mark.django_db
class TestResourceViewSet:
def test_resources_list_none(self, authenticated_client):
@@ -21,6 +21,7 @@ from rest_framework_simplejwt.tokens import RefreshToken
from api.db_router import MainRouter
from api.exceptions import ConflictException
from api.models import (
AttackPathsScan,
Finding,
Integration,
IntegrationProviderRelationship,
@@ -1127,6 +1128,109 @@ class ScanComplianceReportSerializer(serializers.Serializer):
fields = ["id", "name"]
class AttackPathsScanSerializer(RLSSerializer):
state = StateEnumSerializerField(read_only=True)
provider_alias = serializers.SerializerMethodField(read_only=True)
provider_type = serializers.SerializerMethodField(read_only=True)
provider_uid = serializers.SerializerMethodField(read_only=True)
class Meta:
model = AttackPathsScan
fields = [
"id",
"state",
"progress",
"provider",
"provider_alias",
"provider_type",
"provider_uid",
"scan",
"task",
"inserted_at",
"started_at",
"completed_at",
"duration",
]
included_serializers = {
"provider": "api.v1.serializers.ProviderIncludeSerializer",
"scan": "api.v1.serializers.ScanIncludeSerializer",
"task": "api.v1.serializers.TaskSerializer",
}
def get_provider_alias(self, obj):
provider = getattr(obj, "provider", None)
return provider.alias if provider else None
def get_provider_type(self, obj):
provider = getattr(obj, "provider", None)
return provider.provider if provider else None
def get_provider_uid(self, obj):
provider = getattr(obj, "provider", None)
return provider.uid if provider else None
class AttackPathsQueryParameterSerializer(serializers.Serializer):
name = serializers.CharField()
label = serializers.CharField()
data_type = serializers.CharField(default="string")
description = serializers.CharField(allow_null=True, required=False)
placeholder = serializers.CharField(allow_null=True, required=False)
class JSONAPIMeta:
resource_name = "attack-paths-query-parameter"
class AttackPathsQuerySerializer(serializers.Serializer):
id = serializers.CharField()
name = serializers.CharField()
description = serializers.CharField()
provider = serializers.CharField()
parameters = AttackPathsQueryParameterSerializer(many=True)
class JSONAPIMeta:
resource_name = "attack-paths-query"
class AttackPathsQueryRunRequestSerializer(serializers.Serializer):
id = serializers.CharField()
parameters = serializers.DictField(
child=serializers.JSONField(), allow_empty=True, required=False
)
class JSONAPIMeta:
resource_name = "attack-paths-query-run-request"
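A request body accepted by `AttackPathsQueryRunRequestSerializer` is just an `id` plus an optional `parameters` mapping. A minimal stand-in for that validation (the query id and parameter names below are illustrative, not entries from the real catalog):

```python
# Hypothetical run-request attributes; "aws-privilege-escalation" and
# "limit" are made-up examples, not guaranteed catalog values.
run_request = {
    "id": "aws-privilege-escalation",
    "parameters": {"limit": 25},  # optional; values are arbitrary JSON
}


def validate_run_request(payload: dict) -> dict:
    """Minimal sketch of the serializer's rules: required string id,
    optional dict of JSON parameters defaulting to empty."""
    if not isinstance(payload.get("id"), str) or not payload["id"]:
        raise ValueError("id must be a non-empty string")
    parameters = payload.get("parameters", {})
    if not isinstance(parameters, dict):
        raise ValueError("parameters must be a mapping")
    return {"id": payload["id"], "parameters": parameters}


validated = validate_run_request(run_request)
```

The real serializer additionally participates in JSON:API envelope handling (`normalize_run_payload` in the view strips that envelope first).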
class AttackPathsNodeSerializer(serializers.Serializer):
id = serializers.CharField()
labels = serializers.ListField(child=serializers.CharField())
properties = serializers.DictField(child=serializers.JSONField())
class JSONAPIMeta:
resource_name = "attack-paths-query-result-node"
class AttackPathsRelationshipSerializer(serializers.Serializer):
id = serializers.CharField()
label = serializers.CharField()
source = serializers.CharField()
target = serializers.CharField()
properties = serializers.DictField(child=serializers.JSONField())
class JSONAPIMeta:
resource_name = "attack-paths-query-result-relationship"
class AttackPathsQueryResultSerializer(serializers.Serializer):
nodes = AttackPathsNodeSerializer(many=True)
relationships = AttackPathsRelationshipSerializer(many=True)
class JSONAPIMeta:
resource_name = "attack-paths-query-result"
class ResourceTagSerializer(RLSSerializer):
"""
Serializer for the ResourceTag model
@@ -4,6 +4,7 @@ from drf_spectacular.views import SpectacularRedocView
from rest_framework_nested import routers
from api.v1.views import (
AttackPathsScanViewSet,
ComplianceOverviewViewSet,
CustomSAMLLoginView,
CustomTokenObtainView,
@@ -53,6 +54,9 @@ router.register(r"tenants", TenantViewSet, basename="tenant")
router.register(r"providers", ProviderViewSet, basename="provider")
router.register(r"provider-groups", ProviderGroupViewSet, basename="providergroup")
router.register(r"scans", ScanViewSet, basename="scan")
router.register(
r"attack-paths-scans", AttackPathsScanViewSet, basename="attack-paths-scans"
)
router.register(r"tasks", TaskViewSet, basename="task")
router.register(r"resources", ResourceViewSet, basename="resource")
router.register(r"findings", FindingViewSet, basename="finding")
@@ -3,6 +3,7 @@ import glob
import json
import logging
import os
from collections import defaultdict
from copy import deepcopy
from datetime import datetime, timedelta, timezone
@@ -10,6 +11,7 @@ from decimal import ROUND_HALF_UP, Decimal, InvalidOperation
from urllib.parse import urljoin
import sentry_sdk
from allauth.socialaccount.models import SocialAccount, SocialApp
from allauth.socialaccount.providers.github.views import GitHubOAuth2Adapter
from allauth.socialaccount.providers.google.views import GoogleOAuth2Adapter
@@ -41,8 +43,9 @@ from django.db.models import (
Sum,
Value,
When,
Window,
)
from django.db.models.functions import Coalesce
from django.db.models.functions import Coalesce, RowNumber
from django.http import HttpResponse, QueryDict
from django.shortcuts import redirect
from django.urls import reverse
@@ -72,22 +75,12 @@ from rest_framework.generics import GenericAPIView, get_object_or_404
from rest_framework.permissions import SAFE_METHODS
from rest_framework_json_api.views import RelationshipView, Response
from rest_framework_simplejwt.exceptions import InvalidToken, TokenError
from tasks.beat import schedule_provider_scan
from tasks.jobs.export import get_s3_client
from tasks.tasks import (
backfill_scan_resource_summaries_task,
check_integration_connection_task,
check_lighthouse_connection_task,
check_lighthouse_provider_connection_task,
check_provider_connection_task,
delete_provider_task,
delete_tenant_task,
jira_integration_task,
mute_historical_findings_task,
perform_scan_task,
refresh_lighthouse_provider_models_task,
)
from api.attack_paths import (
get_queries_for_provider,
get_query_by_id,
views_helpers as attack_paths_views_helpers,
)
from api.base_views import BaseRLSViewSet, BaseTenantViewset, BaseUserViewset
from api.compliance import (
PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE,
@@ -105,6 +98,7 @@ from api.filters import (
InvitationFilter,
LatestFindingFilter,
LatestResourceFilter,
AttackPathsScanFilter,
LighthouseProviderConfigFilter,
LighthouseProviderModelsFilter,
MembershipFilter,
@@ -129,6 +123,7 @@ from api.models import (
Finding,
Integration,
Invitation,
AttackPathsScan,
LighthouseConfiguration,
LighthouseProviderConfiguration,
LighthouseProviderModels,
@@ -170,6 +165,10 @@ from api.utils import (
from api.uuid_utils import datetime_to_uuid7, uuid7_start
from api.v1.mixins import DisablePaginationMixin, PaginateByPkMixin, TaskManagementMixin
from api.v1.serializers import (
AttackPathsQueryRunRequestSerializer,
AttackPathsQuerySerializer,
AttackPathsQueryResultSerializer,
AttackPathsScanSerializer,
ComplianceOverviewAttributesSerializer,
ComplianceOverviewDetailSerializer,
ComplianceOverviewDetailThreatscoreSerializer,
@@ -247,6 +246,22 @@ from api.v1.serializers import (
UserSerializer,
UserUpdateSerializer,
)
from tasks.beat import schedule_provider_scan
from tasks.jobs.attack_paths import db_utils as attack_paths_db_utils
from tasks.jobs.export import get_s3_client
from tasks.tasks import (
backfill_scan_resource_summaries_task,
check_integration_connection_task,
check_lighthouse_connection_task,
check_lighthouse_provider_connection_task,
check_provider_connection_task,
delete_provider_task,
delete_tenant_task,
jira_integration_task,
mute_historical_findings_task,
perform_scan_task,
refresh_lighthouse_provider_models_task,
)
logger = logging.getLogger(BackendLogger.API)
@@ -390,6 +405,10 @@ class SchemaView(SpectacularAPIView):
"name": "Scan",
"description": "Endpoints for triggering manual scans and viewing scan results.",
},
{
"name": "Attack Paths",
"description": "Endpoints for Attack Paths scan status and executing Attack Paths queries.",
},
{
"name": "Schedule",
"description": "Endpoints for managing scan schedules, allowing configuration of automated "
@@ -2140,6 +2159,12 @@ class ScanViewSet(BaseRLSViewSet):
},
)
attack_paths_db_utils.create_attack_paths_scan(
tenant_id=self.request.tenant_id,
scan_id=str(scan.id),
provider_id=str(scan.provider_id),
)
prowler_task = Task.objects.get(id=task.id)
scan.task_id = task.id
scan.save(update_fields=["task_id"])
@@ -2220,6 +2245,187 @@ class TaskViewSet(BaseRLSViewSet):
)
@extend_schema_view(
list=extend_schema(
tags=["Attack Paths"],
summary="List Attack Paths scans",
description="Retrieve Attack Paths scans for the tenant with support for filtering, ordering, and pagination.",
),
retrieve=extend_schema(
tags=["Attack Paths"],
summary="Retrieve Attack Paths scan details",
description="Fetch full details for a specific Attack Paths scan.",
),
attack_paths_queries=extend_schema(
tags=["Attack Paths"],
summary="List attack paths queries",
description="Retrieve the catalog of Attack Paths queries available for this Attack Paths scan.",
responses={
200: OpenApiResponse(AttackPathsQuerySerializer(many=True)),
404: OpenApiResponse(
description="No queries found for the selected provider"
),
},
),
run_attack_paths_query=extend_schema(
tags=["Attack Paths"],
summary="Execute an Attack Paths query",
description="Execute the selected Attack Paths query against the Attack Paths graph and return the resulting subgraph.",
request=AttackPathsQueryRunRequestSerializer,
responses={
200: OpenApiResponse(AttackPathsQueryResultSerializer),
400: OpenApiResponse(
description="Bad request (e.g., Unknown Attack Paths query for the selected provider)"
),
404: OpenApiResponse(
description="No attack paths found for the given query and parameters"
),
500: OpenApiResponse(
description="Attack Paths query execution failed due to a database error"
),
},
),
)
class AttackPathsScanViewSet(BaseRLSViewSet):
queryset = AttackPathsScan.objects.all()
serializer_class = AttackPathsScanSerializer
http_method_names = ["get", "post"]
filterset_class = AttackPathsScanFilter
ordering = ["-inserted_at"]
ordering_fields = [
"inserted_at",
"started_at",
]
# RBAC required permissions
required_permissions = [Permissions.MANAGE_SCANS]
def set_required_permissions(self):
if self.request.method in SAFE_METHODS:
self.required_permissions = []
else:
self.required_permissions = [Permissions.MANAGE_SCANS]
def get_serializer_class(self):
if self.action == "run_attack_paths_query":
return AttackPathsQueryRunRequestSerializer
return super().get_serializer_class()
def get_queryset(self):
user_roles = get_role(self.request.user)
base_queryset = AttackPathsScan.objects.filter(tenant_id=self.request.tenant_id)
if user_roles.unlimited_visibility:
queryset = base_queryset
else:
queryset = base_queryset.filter(provider__in=get_providers(user_roles))
return queryset.select_related("provider", "scan", "task")
def list(self, request, *args, **kwargs):
queryset = self.filter_queryset(self.get_queryset())
latest_per_provider = queryset.annotate(
latest_scan_rank=Window(
expression=RowNumber(),
partition_by=[F("provider_id")],
order_by=[F("inserted_at").desc()],
)
).filter(latest_scan_rank=1)
page = self.paginate_queryset(latest_per_provider)
if page is not None:
serializer = self.get_serializer(page, many=True)
return self.get_paginated_response(serializer.data)
serializer = self.get_serializer(latest_per_provider, many=True)
return Response(serializer.data)
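The `RowNumber()` window annotation above keeps only the most recent Attack Paths scan per provider. A pure-Python sketch of the same selection over hypothetical in-memory rows (not the ORM call itself):

```python
from datetime import datetime


def latest_per_provider(scans: list[dict]) -> list[dict]:
    """Keep the newest scan (by inserted_at) for each provider_id,
    mirroring RowNumber() partitioned by provider, ordered desc."""
    newest: dict[str, dict] = {}
    for scan in scans:
        current = newest.get(scan["provider_id"])
        if current is None or scan["inserted_at"] > current["inserted_at"]:
            newest[scan["provider_id"]] = scan
    # Order like the viewset default ordering: newest first
    return sorted(newest.values(), key=lambda s: s["inserted_at"], reverse=True)


rows = [
    {"provider_id": "p1", "inserted_at": datetime(2026, 1, 10)},
    {"provider_id": "p1", "inserted_at": datetime(2026, 1, 14)},
    {"provider_id": "p2", "inserted_at": datetime(2026, 1, 12)},
]
result = latest_per_provider(rows)  # one row per provider
```

Doing this with a window function keeps the dedup in SQL, so pagination still operates on the already-reduced set.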
@extend_schema(exclude=True)
def create(self, request, *args, **kwargs):
raise MethodNotAllowed(method="POST")
@extend_schema(exclude=True)
def destroy(self, request, *args, **kwargs):
raise MethodNotAllowed(method="DELETE")
@action(
detail=True,
methods=["get"],
url_path="queries",
url_name="queries",
)
def attack_paths_queries(self, request, pk=None):
attack_paths_scan = self.get_object()
queries = get_queries_for_provider(attack_paths_scan.provider.provider)
if not queries:
return Response(
{"detail": "No queries found for the selected provider"},
status=status.HTTP_404_NOT_FOUND,
)
serializer = AttackPathsQuerySerializer(queries, many=True)
return Response(serializer.data, status=status.HTTP_200_OK)
@action(
detail=True,
methods=["post"],
url_path="queries/run",
url_name="queries-run",
)
def run_attack_paths_query(self, request, pk=None):
attack_paths_scan = self.get_object()
if attack_paths_scan.state != StateChoices.COMPLETED:
raise ValidationError(
{
"detail": "The Attack Paths scan must be completed before running Attack Paths queries"
}
)
if not attack_paths_scan.graph_database:
logger.error(
f"The Attack Paths Scan {attack_paths_scan.id} does not reference a graph database"
)
return Response(
{"detail": "The Attack Paths scan does not reference a graph database"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
payload = attack_paths_views_helpers.normalize_run_payload(request.data)
serializer = AttackPathsQueryRunRequestSerializer(data=payload)
serializer.is_valid(raise_exception=True)
query_definition = get_query_by_id(serializer.validated_data["id"])
if (
query_definition is None
or query_definition.provider != attack_paths_scan.provider.provider
):
raise ValidationError(
{"id": "Unknown Attack Paths query for the selected provider"}
)
parameters = attack_paths_views_helpers.prepare_query_parameters(
query_definition,
serializer.validated_data.get("parameters", {}),
attack_paths_scan.provider.uid,
)
graph = attack_paths_views_helpers.execute_attack_paths_query(
attack_paths_scan, query_definition, parameters
)
status_code = status.HTTP_200_OK
if not graph.get("nodes"):
status_code = status.HTTP_404_NOT_FOUND
response_serializer = AttackPathsQueryResultSerializer(graph)
return Response(response_serializer.data, status=status_code)
@extend_schema_view(
list=extend_schema(
tags=["Resource"],
@@ -5143,7 +5349,7 @@ class TenantApiKeyViewSet(BaseRLSViewSet):
@extend_schema(exclude=True)
def destroy(self, request, *args, **kwargs):
raise MethodNotAllowed(method="DESTROY")
raise MethodNotAllowed(method="DELETE")
@action(detail=True, methods=["delete"])
def revoke(self, request, *args, **kwargs):
@@ -1,6 +1,7 @@
import warnings
from celery import Celery, Task
from config.env import env
# Suppress specific warnings from django-rest-auth: https://github.com/iMerica/dj-rest-auth/issues/684
@@ -36,6 +36,12 @@ DATABASES = {
"HOST": env("POSTGRES_REPLICA_HOST", default=default_db_host),
"PORT": env("POSTGRES_REPLICA_PORT", default=default_db_port),
},
"neo4j": {
"HOST": env.str("NEO4J_HOST", "neo4j"),
"PORT": env.str("NEO4J_PORT", "7687"),
"USER": env.str("NEO4J_USER", "neo4j"),
"PASSWORD": env.str("NEO4J_PASSWORD", "neo4j_password"),
},
}
DATABASES["default"] = DATABASES["prowler_user"]
@@ -37,6 +37,12 @@ DATABASES = {
"HOST": env("POSTGRES_REPLICA_HOST", default=default_db_host),
"PORT": env("POSTGRES_REPLICA_PORT", default=default_db_port),
},
"neo4j": {
"HOST": env.str("NEO4J_HOST"),
"PORT": env.str("NEO4J_PORT"),
"USER": env.str("NEO4J_USER"),
"PASSWORD": env.str("NEO4J_PASSWORD"),
},
}
DATABASES["default"] = DATABASES["prowler_user"]
@@ -1,8 +1,11 @@
import logging
from types import SimpleNamespace
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, patch
import pytest
from allauth.socialaccount.models import SocialLogin
from django.conf import settings
from django.db import connection as django_connection
@@ -11,10 +14,14 @@ from django.urls import reverse
from django_celery_results.models import TaskResult
from rest_framework import status
from rest_framework.test import APIClient
from tasks.jobs.backfill import backfill_resource_scan_summaries
from api.attack_paths import (
AttackPathsQueryDefinition,
AttackPathsQueryParameterDefinition,
)
from api.db_utils import rls_transaction
from api.models import (
AttackPathsScan,
ComplianceOverview,
ComplianceRequirementOverview,
Finding,
@@ -47,6 +54,7 @@ from api.rls import Tenant
from api.v1.serializers import TokenSerializer
from prowler.lib.check.models import Severity
from prowler.lib.outputs.finding import Status
from tasks.jobs.backfill import backfill_resource_scan_summaries
TODAY = str(datetime.today().date())
API_JSON_CONTENT_TYPE = "application/vnd.api+json"
@@ -159,22 +167,20 @@ def create_test_user_rbac_no_roles(django_db_setup, django_db_blocker, tenants_f
@pytest.fixture(scope="function")
def create_test_user_rbac_limited(django_db_setup, django_db_blocker):
def create_test_user_rbac_limited(django_db_setup, django_db_blocker, tenants_fixture):
with django_db_blocker.unblock():
user = User.objects.create_user(
name="testing_limited",
email="rbac_limited@rbac.com",
password=TEST_PASSWORD,
)
tenant = Tenant.objects.create(
name="Tenant Test",
)
tenant = tenants_fixture[0]
Membership.objects.create(
user=user,
tenant=tenant,
role=Membership.RoleChoices.OWNER,
)
Role.objects.create(
role = Role.objects.create(
name="limited",
tenant_id=tenant.id,
manage_users=False,
@@ -187,7 +193,7 @@ def create_test_user_rbac_limited(django_db_setup, django_db_blocker):
)
UserRoleRelationship.objects.create(
user=user,
role=Role.objects.get(name="limited"),
role=role,
tenant_id=tenant.id,
)
return user
@@ -1469,6 +1475,104 @@ def mute_rules_fixture(tenants_fixture, create_test_user, findings_fixture):
return mute_rule1, mute_rule2
@pytest.fixture
def create_attack_paths_scan():
"""Factory fixture to create Attack Paths scans for tests."""
def _create(
provider,
*,
scan=None,
state=StateChoices.COMPLETED,
progress=0,
graph_database="tenant-db",
**extra_fields,
):
scan_instance = scan or Scan.objects.create(
name=extra_fields.pop("scan_name", "Attack Paths Supporting Scan"),
provider=provider,
trigger=Scan.TriggerChoices.MANUAL,
state=extra_fields.pop("scan_state", StateChoices.COMPLETED),
tenant_id=provider.tenant_id,
)
payload = {
"tenant_id": provider.tenant_id,
"provider": provider,
"scan": scan_instance,
"state": state,
"progress": progress,
"graph_database": graph_database,
}
payload.update(extra_fields)
return AttackPathsScan.objects.create(**payload)
return _create
@pytest.fixture
def attack_paths_query_definition_factory():
"""Factory fixture for building Attack Paths query definitions."""
def _create(**overrides):
cast_type = overrides.pop("cast_type", str)
parameters = overrides.pop(
"parameters",
[
AttackPathsQueryParameterDefinition(
name="limit",
label="Limit",
cast=cast_type,
)
],
)
definition_payload = {
"id": "aws-test",
"name": "Attack Paths Test Query",
"description": "Synthetic Attack Paths definition for tests.",
"provider": "aws",
"cypher": "RETURN 1",
"parameters": parameters,
}
definition_payload.update(overrides)
return AttackPathsQueryDefinition(**definition_payload)
return _create
@pytest.fixture
def attack_paths_graph_stub_classes():
"""Provide lightweight graph element stubs for Attack Paths serialization tests."""
class AttackPathsNativeValue:
def __init__(self, value):
self._value = value
def to_native(self):
return self._value
class AttackPathsNode:
def __init__(self, element_id, labels, properties):
self.element_id = element_id
self.labels = labels
self._properties = properties
class AttackPathsRelationship:
def __init__(self, element_id, rel_type, start_node, end_node, properties):
self.element_id = element_id
self.type = rel_type
self.start_node = start_node
self.end_node = end_node
self._properties = properties
return SimpleNamespace(
NativeValue=AttackPathsNativeValue,
Node=AttackPathsNode,
Relationship=AttackPathsRelationship,
)
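The stubs above mimic neo4j graph elements (`element_id`, `labels`, `_properties`). A hypothetical sketch of how such an element could be flattened into the shape `AttackPathsNodeSerializer` expects; the real conversion lives in `attack_paths_views_helpers` and may differ:

```python
class Node:
    """Same shape as the AttackPathsNode stub used in the tests."""

    def __init__(self, element_id, labels, properties):
        self.element_id = element_id
        self.labels = labels
        self._properties = properties


def node_to_payload(node) -> dict:
    # Hypothetical flattening into {id, labels, properties},
    # matching the fields declared on AttackPathsNodeSerializer.
    return {
        "id": node.element_id,
        "labels": list(node.labels),
        "properties": dict(node._properties),
    }


payload = node_to_payload(
    Node("4:abc:0", ["AWSRole"], {"arn": "arn:aws:iam::1:role/x"})
)
```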
def get_authorization_header(access_token: str) -> dict:
return {"Authorization": f"Bearer {access_token}"}
@@ -7,6 +7,7 @@ from tasks.tasks import perform_scheduled_scan_task
from api.db_utils import rls_transaction
from api.exceptions import ConflictException
from api.models import Provider, Scan, StateChoices
from tasks.jobs.attack_paths import db_utils as attack_paths_db_utils
def schedule_provider_scan(provider_instance: Provider):
@@ -39,6 +40,12 @@ def schedule_provider_scan(provider_instance: Provider):
scheduled_at=datetime.now(timezone.utc),
)
attack_paths_db_utils.create_attack_paths_scan(
tenant_id=tenant_id,
scan_id=str(scheduled_scan.id),
provider_id=provider_id,
)
# Schedule the task
periodic_task_instance = PeriodicTask.objects.create(
interval=schedule,
@@ -0,0 +1,5 @@
from tasks.jobs.attack_paths.scan import run as attack_paths_scan
__all__ = [
"attack_paths_scan",
]
@@ -0,0 +1,237 @@
# Portions of this file are based on code from the Cartography project
# (https://github.com/cartography-cncf/cartography), which is licensed under the Apache 2.0 License.
from typing import Any
import aioboto3
import boto3
import neo4j
from cartography.config import Config as CartographyConfig
from cartography.intel import aws as cartography_aws
from celery.utils.log import get_task_logger
from api.models import (
AttackPathsScan as ProwlerAPIAttackPathsScan,
Provider as ProwlerAPIProvider,
)
from prowler.providers.common.provider import Provider as ProwlerSDKProvider
from tasks.jobs.attack_paths import db_utils, utils
logger = get_task_logger(__name__)
def start_aws_ingestion(
neo4j_session: neo4j.Session,
cartography_config: CartographyConfig,
prowler_api_provider: ProwlerAPIProvider,
prowler_sdk_provider: ProwlerSDKProvider,
attack_paths_scan: ProwlerAPIAttackPathsScan,
) -> dict[str, dict[str, str]]:
"""
Code based on Cartography version 0.122.0, specifically on `cartography.intel.aws.__init__.py`.
For the scan progress updates:
- The caller of this function (`tasks.jobs.attack_paths.scan.run`) has set it to 2.
- When the control returns to the caller, it will be set to 95.
"""
# Initialize variables common to all jobs
common_job_parameters = {
"UPDATE_TAG": cartography_config.update_tag,
"permission_relationships_file": cartography_config.permission_relationships_file,
"aws_guardduty_severity_threshold": cartography_config.aws_guardduty_severity_threshold,
"aws_cloudtrail_management_events_lookback_hours": cartography_config.aws_cloudtrail_management_events_lookback_hours,
"experimental_aws_inspector_batch": cartography_config.experimental_aws_inspector_batch,
}
boto3_session = get_boto3_session(prowler_api_provider, prowler_sdk_provider)
regions: list[str] = list(prowler_sdk_provider._enabled_regions)
requested_syncs = list(cartography_aws.RESOURCE_FUNCTIONS.keys())
sync_args = cartography_aws._build_aws_sync_kwargs(
neo4j_session,
boto3_session,
regions,
prowler_api_provider.uid,
cartography_config.update_tag,
common_job_parameters,
)
# Starting with sync functions
cartography_aws.organizations.sync(
neo4j_session,
{prowler_api_provider.alias: prowler_api_provider.uid},
cartography_config.update_tag,
common_job_parameters,
)
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 3)
# Adding an extra field
common_job_parameters["AWS_ID"] = prowler_api_provider.uid
cartography_aws._autodiscover_accounts(
neo4j_session,
boto3_session,
prowler_api_provider.uid,
cartography_config.update_tag,
common_job_parameters,
)
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 4)
failed_syncs = sync_aws_account(
prowler_api_provider, requested_syncs, sync_args, attack_paths_scan
)
if "permission_relationships" in requested_syncs:
cartography_aws.RESOURCE_FUNCTIONS["permission_relationships"](**sync_args)
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 88)
if "resourcegroupstaggingapi" in requested_syncs:
cartography_aws.RESOURCE_FUNCTIONS["resourcegroupstaggingapi"](**sync_args)
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 89)
cartography_aws.run_scoped_analysis_job(
"aws_ec2_iaminstanceprofile.json",
neo4j_session,
common_job_parameters,
)
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 90)
cartography_aws.run_analysis_job(
"aws_lambda_ecr.json",
neo4j_session,
common_job_parameters,
)
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 91)
cartography_aws.merge_module_sync_metadata(
neo4j_session,
group_type="AWSAccount",
group_id=prowler_api_provider.uid,
synced_type="AWSAccount",
update_tag=cartography_config.update_tag,
stat_handler=cartography_aws.stat_handler,
)
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 92)
# Removing the added extra field
del common_job_parameters["AWS_ID"]
cartography_aws.run_cleanup_job(
"aws_post_ingestion_principals_cleanup.json",
neo4j_session,
common_job_parameters,
)
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 93)
cartography_aws._perform_aws_analysis(
requested_syncs, neo4j_session, common_job_parameters
)
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 94)
return failed_syncs
def get_boto3_session(
prowler_api_provider: ProwlerAPIProvider, prowler_sdk_provider: ProwlerSDKProvider
) -> boto3.Session:
boto3_session = prowler_sdk_provider.session.current_session
aws_accounts_from_session = cartography_aws.organizations.get_aws_account_default(
boto3_session
)
if not aws_accounts_from_session:
raise Exception(
"No valid AWS credentials could be found. No AWS accounts can be synced."
)
aws_account_id_from_session = list(aws_accounts_from_session.values())[0]
if prowler_api_provider.uid != aws_account_id_from_session:
raise Exception(
f"Provider {prowler_api_provider.uid} doesn't match AWS account {aws_account_id_from_session}."
)
if boto3_session.region_name is None:
global_region = prowler_sdk_provider.get_global_region()
boto3_session._session.set_config_variable("region", global_region)
return boto3_session
def get_aioboto3_session(boto3_session: boto3.Session) -> aioboto3.Session:
return aioboto3.Session(botocore_session=boto3_session._session)
def sync_aws_account(
prowler_api_provider: ProwlerAPIProvider,
requested_syncs: list[str],
sync_args: dict[str, Any],
attack_paths_scan: ProwlerAPIAttackPathsScan,
) -> dict[str, str]:
current_progress = 4 # `cartography_aws._autodiscover_accounts`
max_progress = (
87 # `cartography_aws.RESOURCE_FUNCTIONS["permission_relationships"]` - 1
)
n_steps = (
len(requested_syncs) - 2
) # Excluding `permission_relationships` and `resourcegroupstaggingapi`
progress_step = (max_progress - current_progress) / n_steps
failed_syncs = {}
for func_name in requested_syncs:
if func_name in cartography_aws.RESOURCE_FUNCTIONS:
logger.info(
f"Syncing function {func_name} for AWS account {prowler_api_provider.uid}"
)
# Updating progress, not really the right place but good enough
current_progress += progress_step
db_utils.update_attack_paths_scan_progress(
attack_paths_scan, int(current_progress)
)
try:
# `ecr:image_layers` uses `aioboto3_session` instead of `boto3_session`
if func_name == "ecr:image_layers":
cartography_aws.RESOURCE_FUNCTIONS[func_name](
neo4j_session=sync_args.get("neo4j_session"),
aioboto3_session=get_aioboto3_session(
sync_args.get("boto3_session")
),
regions=sync_args.get("regions"),
current_aws_account_id=sync_args.get("current_aws_account_id"),
update_tag=sync_args.get("update_tag"),
common_job_parameters=sync_args.get("common_job_parameters"),
)
# Skip permission relationships and tags for now because they rely on data already being in the graph
elif func_name in [
"permission_relationships",
"resourcegroupstaggingapi",
]:
continue
else:
cartography_aws.RESOURCE_FUNCTIONS[func_name](**sync_args)
except Exception as e:
exception_message = utils.stringify_exception(
e, f"Exception for AWS sync function: {func_name}"
)
failed_syncs[func_name] = exception_message
logger.warning(
f"Caught exception syncing function {func_name} from AWS account {prowler_api_provider.uid}. We "
"are continuing on to the next AWS sync function.",
)
continue
else:
raise ValueError(
f'AWS sync function "{func_name}" was specified but does not exist. Did you misspell it?'
)
return failed_syncs
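The progress bookkeeping in `sync_aws_account` spreads the 4→87 range evenly over the sync functions, excluding the two deferred ones. Isolated with hypothetical step counts, the reported values look like this:

```python
def progress_values(n_requested: int, start: int = 4, end: int = 87) -> list[int]:
    """Replicates the interpolation in sync_aws_account: two syncs
    (permission_relationships, resourcegroupstaggingapi) are excluded,
    and the float progress is truncated to int on each update."""
    n_steps = n_requested - 2
    step = (end - start) / n_steps
    current = float(start)
    values = []
    for _ in range(n_steps):
        current += step
        values.append(int(current))
    return values


# With five requested syncs, three progress updates are emitted
steps = progress_values(5)  # [31, 59, 87]
```

Note that the last interpolated value lands on 87, just below the fixed 88/89 checkpoints used for the two deferred syncs.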
@@ -0,0 +1,158 @@
from datetime import datetime, timezone
from typing import Any
from uuid import UUID
from cartography.config import Config as CartographyConfig
from api.db_utils import rls_transaction
from api.models import (
AttackPathsScan as ProwlerAPIAttackPathsScan,
Provider as ProwlerAPIProvider,
StateChoices,
)
from tasks.jobs.attack_paths.providers import is_provider_available
def create_attack_paths_scan(
tenant_id: str,
scan_id: str,
provider_id: int,
) -> ProwlerAPIAttackPathsScan | None:
with rls_transaction(tenant_id):
prowler_api_provider = ProwlerAPIProvider.objects.get(id=provider_id)
if not is_provider_available(prowler_api_provider.provider):
return None
with rls_transaction(tenant_id):
attack_paths_scan = ProwlerAPIAttackPathsScan.objects.create(
tenant_id=tenant_id,
provider_id=provider_id,
scan_id=scan_id,
state=StateChoices.SCHEDULED,
started_at=datetime.now(tz=timezone.utc),
)
attack_paths_scan.save()
return attack_paths_scan
def retrieve_attack_paths_scan(
tenant_id: str,
scan_id: str,
) -> ProwlerAPIAttackPathsScan | None:
try:
with rls_transaction(tenant_id):
attack_paths_scan = ProwlerAPIAttackPathsScan.objects.get(
scan_id=scan_id,
)
return attack_paths_scan
except ProwlerAPIAttackPathsScan.DoesNotExist:
return None
def starting_attack_paths_scan(
attack_paths_scan: ProwlerAPIAttackPathsScan,
task_id: str,
cartography_config: CartographyConfig,
) -> None:
with rls_transaction(attack_paths_scan.tenant_id):
attack_paths_scan.task_id = task_id
attack_paths_scan.state = StateChoices.EXECUTING
attack_paths_scan.started_at = datetime.now(tz=timezone.utc)
attack_paths_scan.update_tag = cartography_config.update_tag
attack_paths_scan.graph_database = cartography_config.neo4j_database
attack_paths_scan.save(
update_fields=[
"task_id",
"state",
"started_at",
"update_tag",
"graph_database",
]
)
def finish_attack_paths_scan(
attack_paths_scan: ProwlerAPIAttackPathsScan,
state: StateChoices,
ingestion_exceptions: dict[str, Any],
) -> None:
with rls_transaction(attack_paths_scan.tenant_id):
now = datetime.now(tz=timezone.utc)
duration = int((now - attack_paths_scan.started_at).total_seconds())
attack_paths_scan.state = state
attack_paths_scan.progress = 100
attack_paths_scan.completed_at = now
attack_paths_scan.duration = duration
attack_paths_scan.ingestion_exceptions = ingestion_exceptions
attack_paths_scan.save(
update_fields=[
"state",
"progress",
"completed_at",
"duration",
"ingestion_exceptions",
]
)
def update_attack_paths_scan_progress(
attack_paths_scan: ProwlerAPIAttackPathsScan,
progress: int,
) -> None:
with rls_transaction(attack_paths_scan.tenant_id):
attack_paths_scan.progress = progress
attack_paths_scan.save(update_fields=["progress"])
def get_old_attack_paths_scans(
tenant_id: str,
provider_id: str,
attack_paths_scan_id: str,
) -> list[ProwlerAPIAttackPathsScan]:
"""
An `old_attack_paths_scan` is any `completed` Attack Paths scan for the same provider,
with its graph database not deleted, excluding the current Attack Paths scan.
"""
with rls_transaction(tenant_id):
completed_scans_qs = (
ProwlerAPIAttackPathsScan.objects.filter(
provider_id=provider_id,
state=StateChoices.COMPLETED,
is_graph_database_deleted=False,
)
.exclude(id=attack_paths_scan_id)
.all()
)
return list(completed_scans_qs)
def update_old_attack_paths_scan(
old_attack_paths_scan: ProwlerAPIAttackPathsScan,
) -> None:
with rls_transaction(old_attack_paths_scan.tenant_id):
old_attack_paths_scan.is_graph_database_deleted = True
old_attack_paths_scan.save(update_fields=["is_graph_database_deleted"])
def get_provider_graph_database_names(tenant_id: str, provider_id: str) -> list[str]:
"""
Return existing graph database names for a tenant/provider.
Note: For accessing the `AttackPathsScan` we need to use the `all_objects` manager because the provider is soft-deleted.
"""
with rls_transaction(tenant_id):
graph_databases_names_qs = ProwlerAPIAttackPathsScan.all_objects.filter(
provider_id=provider_id,
is_graph_database_deleted=False,
).values_list("graph_database", flat=True)
return list(graph_databases_names_qs)
@@ -0,0 +1,23 @@
AVAILABLE_PROVIDERS: list[str] = [
"aws",
]
ROOT_NODE_LABELS: dict[str, str] = {
"aws": "AWSAccount",
}
NODE_UID_FIELDS: dict[str, str] = {
"aws": "arn",
}
def is_provider_available(provider_type: str) -> bool:
return provider_type in AVAILABLE_PROVIDERS
def get_root_node_label(provider_type: str) -> str:
return ROOT_NODE_LABELS.get(provider_type, "UnknownProviderAccount")
def get_node_uid_field(provider_type: str) -> str:
return NODE_UID_FIELDS.get(provider_type, "UnknownProviderUID")
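Taken together, the three lookup tables give unknown providers a sentinel fallback instead of raising. A self-contained sketch mirroring the helpers above (tables copied verbatim):

```python
# Mirrors tasks.jobs.attack_paths.providers
AVAILABLE_PROVIDERS: list[str] = ["aws"]
ROOT_NODE_LABELS: dict[str, str] = {"aws": "AWSAccount"}
NODE_UID_FIELDS: dict[str, str] = {"aws": "arn"}


def is_provider_available(provider_type: str) -> bool:
    return provider_type in AVAILABLE_PROVIDERS


def get_root_node_label(provider_type: str) -> str:
    # Unknown providers fall back to a sentinel label rather than raising
    return ROOT_NODE_LABELS.get(provider_type, "UnknownProviderAccount")


def get_node_uid_field(provider_type: str) -> str:
    return NODE_UID_FIELDS.get(provider_type, "UnknownProviderUID")


label = get_root_node_label("gcp")  # sentinel for an unsupported provider
```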
@@ -0,0 +1,205 @@
import neo4j
from cartography.client.core.tx import run_write_query
from cartography.config import Config as CartographyConfig
from celery.utils.log import get_task_logger
from api.db_utils import rls_transaction
from api.models import Provider, ResourceFindingMapping
from config.env import env
from prowler.config import config as ProwlerConfig
from tasks.jobs.attack_paths.providers import get_node_uid_field, get_root_node_label
logger = get_task_logger(__name__)
BATCH_SIZE = env.int("NEO4J_INSERT_BATCH_SIZE", 500)
INDEX_STATEMENTS = [
"CREATE INDEX prowler_finding_id IF NOT EXISTS FOR (n:ProwlerFinding) ON (n.id);",
"CREATE INDEX prowler_finding_provider_uid IF NOT EXISTS FOR (n:ProwlerFinding) ON (n.provider_uid);",
"CREATE INDEX prowler_finding_lastupdated IF NOT EXISTS FOR (n:ProwlerFinding) ON (n.lastupdated);",
"CREATE INDEX prowler_finding_status IF NOT EXISTS FOR (n:ProwlerFinding) ON (n.status);",
]
INSERT_STATEMENT_TEMPLATE = """
UNWIND $findings_data AS finding_data
MATCH (account:__ROOT_NODE_LABEL__ {id: $provider_uid})
MATCH (account)-->(resource)
WHERE resource.__NODE_UID_FIELD__ = finding_data.resource_uid
OR resource.id = finding_data.resource_uid
MERGE (finding:ProwlerFinding {id: finding_data.id})
ON CREATE SET
finding.id = finding_data.id,
finding.uid = finding_data.uid,
finding.inserted_at = finding_data.inserted_at,
finding.updated_at = finding_data.updated_at,
finding.first_seen_at = finding_data.first_seen_at,
finding.scan_id = finding_data.scan_id,
finding.delta = finding_data.delta,
finding.status = finding_data.status,
finding.status_extended = finding_data.status_extended,
finding.severity = finding_data.severity,
finding.check_id = finding_data.check_id,
finding.check_title = finding_data.check_title,
finding.muted = finding_data.muted,
finding.muted_reason = finding_data.muted_reason,
finding.provider_uid = $provider_uid,
finding.firstseen = timestamp(),
finding.lastupdated = $last_updated,
finding._module_name = 'cartography:prowler',
finding._module_version = $prowler_version
ON MATCH SET
finding.status = finding_data.status,
finding.status_extended = finding_data.status_extended,
finding.lastupdated = $last_updated
MERGE (resource)-[rel:HAS_FINDING]->(finding)
ON CREATE SET
rel.provider_uid = $provider_uid,
rel.firstseen = timestamp(),
rel.lastupdated = $last_updated,
rel._module_name = 'cartography:prowler',
rel._module_version = $prowler_version
ON MATCH SET
rel.lastupdated = $last_updated
"""
CLEANUP_STATEMENT = """
MATCH (finding:ProwlerFinding {provider_uid: $provider_uid})
WHERE finding.lastupdated < $last_updated
WITH finding LIMIT $batch_size
DETACH DELETE finding
RETURN COUNT(finding) AS deleted_findings_count
"""
def create_indexes(neo4j_session: neo4j.Session) -> None:
"""
Code based on Cartography version 0.122.0, specifically on `cartography.intel.create_indexes.run`.
"""
logger.info("Creating indexes for Prowler node types.")
for statement in INDEX_STATEMENTS:
logger.debug("Executing statement: %s", statement)
run_write_query(neo4j_session, statement)
def analysis(
neo4j_session: neo4j.Session,
prowler_api_provider: Provider,
scan_id: str,
config: CartographyConfig,
) -> None:
findings_data = get_provider_last_scan_findings(prowler_api_provider, scan_id)
load_findings(neo4j_session, findings_data, prowler_api_provider, config)
cleanup_findings(neo4j_session, prowler_api_provider, config)
def get_provider_last_scan_findings(
prowler_api_provider: Provider,
scan_id: str,
) -> list[dict[str, str]]:
with rls_transaction(prowler_api_provider.tenant_id):
resource_finding_qs = ResourceFindingMapping.objects.filter(
finding__scan_id=scan_id,
).values(
"resource__uid",
"finding__id",
"finding__uid",
"finding__inserted_at",
"finding__updated_at",
"finding__first_seen_at",
"finding__scan_id",
"finding__delta",
"finding__status",
"finding__status_extended",
"finding__severity",
"finding__check_id",
"finding__check_metadata__checktitle",
"finding__muted",
"finding__muted_reason",
)
findings = []
for resource_finding in resource_finding_qs:
findings.append(
{
"resource_uid": str(resource_finding["resource__uid"]),
"id": str(resource_finding["finding__id"]),
"uid": resource_finding["finding__uid"],
"inserted_at": resource_finding["finding__inserted_at"],
"updated_at": resource_finding["finding__updated_at"],
"first_seen_at": resource_finding["finding__first_seen_at"],
"scan_id": str(resource_finding["finding__scan_id"]),
"delta": resource_finding["finding__delta"],
"status": resource_finding["finding__status"],
"status_extended": resource_finding["finding__status_extended"],
"severity": resource_finding["finding__severity"],
"check_id": str(resource_finding["finding__check_id"]),
"check_title": resource_finding[
"finding__check_metadata__checktitle"
],
"muted": resource_finding["finding__muted"],
"muted_reason": resource_finding["finding__muted_reason"],
}
)
return findings
def load_findings(
neo4j_session: neo4j.Session,
findings_data: list[dict[str, str]],
prowler_api_provider: Provider,
config: CartographyConfig,
) -> None:
replacements = {
"__ROOT_NODE_LABEL__": get_root_node_label(prowler_api_provider.provider),
"__NODE_UID_FIELD__": get_node_uid_field(prowler_api_provider.provider),
}
query = INSERT_STATEMENT_TEMPLATE
for replace_key, replace_value in replacements.items():
query = query.replace(replace_key, replace_value)
parameters = {
"provider_uid": str(prowler_api_provider.uid),
"last_updated": config.update_tag,
"prowler_version": ProwlerConfig.prowler_version,
}
total_length = len(findings_data)
for i in range(0, total_length, BATCH_SIZE):
parameters["findings_data"] = findings_data[i : i + BATCH_SIZE]
logger.info(
f"Loading findings batch {i // BATCH_SIZE + 1} / {(total_length + BATCH_SIZE - 1) // BATCH_SIZE}"
)
neo4j_session.run(query, parameters)
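The `__ROOT_NODE_LABEL__` and `__NODE_UID_FIELD__` placeholders are filled in with plain string replacement rather than Cypher parameters, because Neo4j does not allow parameters in label or property-name positions. A minimal standalone sketch of the technique (the template and values here are illustrative, not the module's actual query):

```python
# Sentinel-token substitution as used by load_findings: node labels and
# property names cannot be Cypher query parameters ($...), so they are
# spliced into the query text before execution. Values are illustrative.
template = (
    "MATCH (account:__ROOT_NODE_LABEL__ {id: $provider_uid}) "
    "MATCH (account)-->(resource) "
    "WHERE resource.__NODE_UID_FIELD__ = $resource_uid"
)
replacements = {
    "__ROOT_NODE_LABEL__": "AWSAccount",
    "__NODE_UID_FIELD__": "arn",
}
query = template
for replace_key, replace_value in replacements.items():
    query = query.replace(replace_key, replace_value)
print(query)
```

Only the structural parts are spliced in; runtime values such as `$provider_uid` still travel as proper query parameters.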
def cleanup_findings(
neo4j_session: neo4j.Session,
prowler_api_provider: Provider,
config: CartographyConfig,
) -> None:
parameters = {
"provider_uid": str(prowler_api_provider.uid),
"last_updated": config.update_tag,
"batch_size": BATCH_SIZE,
}
batch = 1
deleted_count = 1  # Sentinel so the loop body runs at least once
while deleted_count > 0:
logger.info(f"Cleaning findings batch {batch}")
result = neo4j_session.run(CLEANUP_STATEMENT, parameters)
deleted_count = result.single().get("deleted_findings_count", 0)
batch += 1
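The cleanup loop is driven by the deleted-count each batch returns: it keeps issuing batches until one deletes nothing. A self-contained sketch of that control flow, with a stub standing in for the Neo4j call (all names here are illustrative):

```python
# Stand-in for neo4j_session.run(CLEANUP_STATEMENT, ...): removes up to
# batch_size stale findings and reports how many it actually deleted.
remaining = {"stale_findings": 7}

def delete_stale_batch(batch_size: int) -> int:
    deleted = min(batch_size, remaining["stale_findings"])
    remaining["stale_findings"] -= deleted
    return deleted

batches = 0
deleted_count = 1  # sentinel so the loop body runs at least once
while deleted_count > 0:
    deleted_count = delete_stale_batch(3)
    batches += 1

# 7 stale findings with batch_size 3 take batches of 3 + 3 + 1, plus one
# final empty batch that terminates the loop.
print(batches)  # 4
```

Terminating on an empty batch, instead of pre-counting rows, avoids a second query and stays correct if rows are added or removed concurrently.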
@@ -0,0 +1,183 @@
import asyncio
import logging
import time
from typing import Any, Callable
from cartography.config import Config as CartographyConfig
from cartography.intel import analysis as cartography_analysis
from cartography.intel import create_indexes as cartography_create_indexes
from cartography.intel import ontology as cartography_ontology
from celery.utils.log import get_task_logger
from api.attack_paths import database as graph_database
from api.db_utils import rls_transaction
from api.models import (
Provider as ProwlerAPIProvider,
StateChoices,
)
from api.utils import initialize_prowler_provider
from tasks.jobs.attack_paths import aws, db_utils, prowler, utils
# Without this, Celery's logs are flooded with Cartography output
logging.getLogger("cartography").setLevel(logging.ERROR)
logging.getLogger("neo4j").propagate = False
logger = get_task_logger(__name__)
CARTOGRAPHY_INGESTION_FUNCTIONS: dict[str, Callable] = {
"aws": aws.start_aws_ingestion,
}
def get_cartography_ingestion_function(provider_type: str) -> Callable | None:
return CARTOGRAPHY_INGESTION_FUNCTIONS.get(provider_type)
def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
"""
Code based on Cartography version 0.122.0, specifically on `cartography.cli.main`, `cartography.cli.CLI.main`,
`cartography.sync.run_with_config` and `cartography.sync.Sync.run`.
"""
ingestion_exceptions = {} # This will hold any exceptions raised during ingestion
# Prowler necessary objects
with rls_transaction(tenant_id):
prowler_api_provider = ProwlerAPIProvider.objects.get(scan__pk=scan_id)
prowler_sdk_provider = initialize_prowler_provider(prowler_api_provider)
# Attack Paths Scan necessary objects
cartography_ingestion_function = get_cartography_ingestion_function(
prowler_api_provider.provider
)
attack_paths_scan = db_utils.retrieve_attack_paths_scan(tenant_id, scan_id)
# Checks before starting the scan
if not cartography_ingestion_function:
ingestion_exceptions = {
"global_error": f"Provider {prowler_api_provider.provider} is not supported for Attack Paths scans"
}
if attack_paths_scan:
db_utils.finish_attack_paths_scan(
attack_paths_scan, StateChoices.COMPLETED, ingestion_exceptions
)
logger.warning(
f"Provider {prowler_api_provider.provider} is not supported for Attack Paths scans"
)
return ingestion_exceptions
else:
if not attack_paths_scan:
logger.warning(
f"No Attack Paths Scan found for scan {scan_id} and tenant {tenant_id}, let's create it then"
)
attack_paths_scan = db_utils.create_attack_paths_scan(
tenant_id, scan_id, prowler_api_provider.id
)
# Note: the `neo4j_user` and `neo4j_password` attributes are not needed in this config object
cartography_config = CartographyConfig(
neo4j_uri=graph_database.get_uri(),
neo4j_database=graph_database.get_database_name(attack_paths_scan.id),
update_tag=int(time.time()),
)
# Starting the Attack Paths scan
db_utils.starting_attack_paths_scan(attack_paths_scan, task_id, cartography_config)
try:
logger.info(
f"Creating Neo4j database {cartography_config.neo4j_database} for tenant {prowler_api_provider.tenant_id}"
)
graph_database.create_database(cartography_config.neo4j_database)
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 1)
logger.info(
f"Starting Cartography ({attack_paths_scan.id}) for "
f"{prowler_api_provider.provider.upper()} provider {prowler_api_provider.id}"
)
with graph_database.get_session(
cartography_config.neo4j_database
) as neo4j_session:
# Indexes creation
cartography_create_indexes.run(neo4j_session, cartography_config)
prowler.create_indexes(neo4j_session)
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 2)
# The real scan, which iterates over the cloud services
ingestion_exceptions = _call_within_event_loop(
cartography_ingestion_function,
neo4j_session,
cartography_config,
prowler_api_provider,
prowler_sdk_provider,
attack_paths_scan,
)
# Post-processing: kept to stay compliant with the standard Cartography sync
cartography_ontology.run(neo4j_session, cartography_config)
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 95)
cartography_analysis.run(neo4j_session, cartography_config)
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 96)
# Adding Prowler nodes and relationships
prowler.analysis(
neo4j_session, prowler_api_provider, scan_id, cartography_config
)
logger.info(
f"Completed Cartography ({attack_paths_scan.id}) for "
f"{prowler_api_provider.provider.upper()} provider {prowler_api_provider.id}"
)
# Handling databases changes
old_attack_paths_scans = db_utils.get_old_attack_paths_scans(
prowler_api_provider.tenant_id,
prowler_api_provider.id,
attack_paths_scan.id,
)
for old_attack_paths_scan in old_attack_paths_scans:
graph_database.drop_database(old_attack_paths_scan.graph_database)
db_utils.update_old_attack_paths_scan(old_attack_paths_scan)
db_utils.finish_attack_paths_scan(
attack_paths_scan, StateChoices.COMPLETED, ingestion_exceptions
)
return ingestion_exceptions
except Exception as e:
exception_message = utils.stringify_exception(e, "Cartography failed")
logger.error(exception_message)
ingestion_exceptions["global_cartography_error"] = exception_message
# Handling databases changes
graph_database.drop_database(cartography_config.neo4j_database)
db_utils.finish_attack_paths_scan(
attack_paths_scan, StateChoices.FAILED, ingestion_exceptions
)
raise
def _call_within_event_loop(fn, *args, **kwargs):
"""
Cartography needs a running event loop, so assuming there is none (Celery task or even regular DRF endpoint),
let's create a new one and set it as the current event loop for this thread.
"""
loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(loop)
return fn(*args, **kwargs)
finally:
try:
loop.run_until_complete(loop.shutdown_asyncgens())
except Exception:
pass
loop.close()
asyncio.set_event_loop(None)
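The wrapper above can be exercised outside Celery with any callable that expects a current event loop. A runnable sketch of the same pattern, where `run_coro` is a made-up callable for illustration:

```python
import asyncio

def call_within_event_loop(fn, *args, **kwargs):
    # Create a fresh loop and make it the thread's current loop while fn runs,
    # then tear it down so no loop leaks into subsequent work on this thread.
    loop = asyncio.new_event_loop()
    try:
        asyncio.set_event_loop(loop)
        return fn(*args, **kwargs)
    finally:
        loop.run_until_complete(loop.shutdown_asyncgens())
        loop.close()
        asyncio.set_event_loop(None)

def run_coro():
    # Synchronous code that expects a current event loop to be set.
    loop = asyncio.get_event_loop()
    return loop.run_until_complete(asyncio.sleep(0, result="ok"))

result = call_within_event_loop(run_coro)
print(result)  # ok
```

Setting the loop back to `None` in `finally` mirrors the production code's cleanup, so a failed ingestion cannot leave a closed loop registered as current.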
@@ -0,0 +1,10 @@
import traceback
from datetime import datetime, timezone
def stringify_exception(exception: Exception, context: str) -> str:
timestamp = datetime.now(tz=timezone.utc)
exception_traceback = traceback.TracebackException.from_exception(exception)
traceback_string = "".join(exception_traceback.format())
return f"{timestamp} - {context}\n{traceback_string}"
@@ -1,9 +1,19 @@
from celery.utils.log import get_task_logger
from django.db import DatabaseError
from api.attack_paths import database as graph_database
from api.db_router import MainRouter
from api.db_utils import batch_delete, rls_transaction
from api.models import Finding, Provider, Resource, Scan, ScanSummary, Tenant
from api.models import (
AttackPathsScan,
Finding,
Provider,
Resource,
Scan,
ScanSummary,
Tenant,
)
from tasks.jobs.attack_paths.db_utils import get_provider_graph_database_names
logger = get_task_logger(__name__)
@@ -23,16 +33,27 @@ def delete_provider(tenant_id: str, pk: str):
Raises:
Provider.DoesNotExist: If no instance with the provided primary key exists.
"""
# Delete the Attack Paths' graph databases related to the provider
graph_database_names = get_provider_graph_database_names(tenant_id, pk)
try:
for graph_database_name in graph_database_names:
graph_database.drop_database(graph_database_name)
except graph_database.GraphDatabaseQueryException as gdb_error:
logger.error(f"Error deleting Provider databases: {gdb_error}")
raise
# Get all provider related data and delete them in batches
with rls_transaction(tenant_id):
instance = Provider.all_objects.get(pk=pk)
deletion_steps = [
("Scan Summaries", ScanSummary.all_objects.filter(scan__provider=instance)),
("Findings", Finding.all_objects.filter(scan__provider=instance)),
("Resources", Resource.all_objects.filter(provider=instance)),
("Scans", Scan.all_objects.filter(provider=instance)),
("AttackPathsScans", AttackPathsScan.all_objects.filter(provider=instance)),
]
deletion_summary = {}
for step_name, queryset in deletion_steps:
try:
_, step_summary = batch_delete(tenant_id, queryset)
@@ -48,6 +69,7 @@ def delete_provider(tenant_id: str, pk: str):
except DatabaseError as db_error:
logger.error(f"Error deleting Provider: {db_error}")
raise
return deletion_summary
@@ -1,13 +1,26 @@
import os
from datetime import datetime, timedelta, timezone
from pathlib import Path
from shutil import rmtree
from celery import chain, group, shared_task
from celery.utils.log import get_task_logger
from django_celery_beat.models import PeriodicTask
from api.compliance import get_compliance_frameworks
from api.db_router import READ_REPLICA_ALIAS
from api.db_utils import rls_transaction
from api.decorators import set_tenant
from api.models import Finding, Integration, Provider, Scan, ScanSummary, StateChoices
from api.utils import initialize_prowler_provider
from api.v1.serializers import ScanTaskSerializer
from config.celery import RLSTask
from config.django.base import DJANGO_FINDINGS_BATCH_SIZE, DJANGO_TMP_OUTPUT_DIRECTORY
from prowler.lib.check.compliance_models import Compliance
from prowler.lib.outputs.compliance.generic.generic import GenericCompliance
from prowler.lib.outputs.finding import Finding as FindingOutput
from tasks.jobs.attack_paths import attack_paths_scan
from tasks.jobs.backfill import (
backfill_compliance_summaries,
backfill_resource_scan_summaries,
@@ -43,17 +56,6 @@ from tasks.jobs.scan import (
)
from tasks.utils import batched, get_next_execution_datetime
logger = get_task_logger(__name__)
@@ -86,6 +88,9 @@ def _perform_scan_complete_tasks(tenant_id: str, scan_id: str, provider_id: str)
),
),
).apply_async()
perform_attack_paths_scan_task.apply_async(
kwargs={"tenant_id": tenant_id, "scan_id": scan_id}
)
@shared_task(base=RLSTask, name="provider-connection-check")
@@ -281,6 +286,25 @@ def perform_scan_summary_task(tenant_id: str, scan_id: str):
return aggregate_findings(tenant_id=tenant_id, scan_id=scan_id)
# TODO: This task must be queued at the `attack-paths` queue, don't forget to add it to the `docker-entrypoint.sh` file
@shared_task(base=RLSTask, bind=True, name="attack-paths-scan-perform", queue="scans")
def perform_attack_paths_scan_task(self, tenant_id: str, scan_id: str):
"""
Execute an Attack Paths scan for the given provider within the current tenant RLS context.
Args:
self: The task instance (automatically passed when bind=True).
tenant_id (str): The tenant identifier for RLS context.
scan_id (str): The Prowler scan identifier for obtaining the tenant and provider context.
Returns:
Any: The result from `attack_paths_scan`, including any per-scan failure details.
"""
return attack_paths_scan(
tenant_id=tenant_id, scan_id=scan_id, task_id=self.request.id
)
@shared_task(name="tenant-deletion", queue="deletion", autoretry_for=(Exception,))
def delete_tenant_task(tenant_id: str):
return delete_tenant(pk=tenant_id)
@@ -0,0 +1,416 @@
from contextlib import nullcontext
from types import SimpleNamespace
from unittest.mock import MagicMock, call, patch
import pytest
from api.models import (
AttackPathsScan,
Finding,
Provider,
Resource,
ResourceFindingMapping,
Scan,
StateChoices,
StatusChoices,
)
from prowler.lib.check.models import Severity
from tasks.jobs.attack_paths import prowler as prowler_module
from tasks.jobs.attack_paths.scan import run as attack_paths_run
@pytest.mark.django_db
class TestAttackPathsRun:
def test_run_success_flow(self, tenants_fixture, providers_fixture, scans_fixture):
tenant = tenants_fixture[0]
provider = providers_fixture[0]
provider.provider = Provider.ProviderChoices.AWS
provider.save()
scan = scans_fixture[0]
scan.provider = provider
scan.save()
attack_paths_scan = AttackPathsScan.objects.create(
tenant_id=tenant.id,
provider=provider,
scan=scan,
state=StateChoices.SCHEDULED,
)
mock_session = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = mock_session
session_ctx.__exit__.return_value = False
ingestion_result = {"organizations": "warning"}
ingestion_fn = MagicMock(return_value=ingestion_result)
with (
patch(
"tasks.jobs.attack_paths.scan.rls_transaction",
new=lambda *args, **kwargs: nullcontext(),
),
patch(
"tasks.jobs.attack_paths.scan.initialize_prowler_provider",
return_value=MagicMock(_enabled_regions=["us-east-1"]),
),
patch(
"tasks.jobs.attack_paths.scan.graph_database.get_uri",
return_value="bolt://neo4j",
),
patch(
"tasks.jobs.attack_paths.scan.graph_database.get_database_name",
return_value="db-scan-id",
) as mock_get_db_name,
patch(
"tasks.jobs.attack_paths.scan.graph_database.create_database"
) as mock_create_db,
patch(
"tasks.jobs.attack_paths.scan.graph_database.get_session",
return_value=session_ctx,
) as mock_get_session,
patch(
"tasks.jobs.attack_paths.scan.cartography_create_indexes.run"
) as mock_cartography_indexes,
patch(
"tasks.jobs.attack_paths.scan.cartography_analysis.run"
) as mock_cartography_analysis,
patch(
"tasks.jobs.attack_paths.scan.cartography_ontology.run"
) as mock_cartography_ontology,
patch(
"tasks.jobs.attack_paths.scan.prowler.create_indexes"
) as mock_prowler_indexes,
patch(
"tasks.jobs.attack_paths.scan.prowler.analysis"
) as mock_prowler_analysis,
patch(
"tasks.jobs.attack_paths.scan.db_utils.retrieve_attack_paths_scan",
return_value=attack_paths_scan,
) as mock_retrieve_scan,
patch(
"tasks.jobs.attack_paths.scan.db_utils.starting_attack_paths_scan"
) as mock_starting,
patch(
"tasks.jobs.attack_paths.scan.db_utils.update_attack_paths_scan_progress"
) as mock_update_progress,
patch(
"tasks.jobs.attack_paths.scan.db_utils.finish_attack_paths_scan"
) as mock_finish,
patch(
"tasks.jobs.attack_paths.scan.get_cartography_ingestion_function",
return_value=ingestion_fn,
) as mock_get_ingestion,
patch(
"tasks.jobs.attack_paths.scan._call_within_event_loop",
side_effect=lambda fn, *a, **kw: fn(*a, **kw),
) as mock_event_loop,
):
result = attack_paths_run(str(tenant.id), str(scan.id), "task-123")
assert result == ingestion_result
mock_retrieve_scan.assert_called_once_with(str(tenant.id), str(scan.id))
mock_starting.assert_called_once()
config = mock_starting.call_args[0][2]
assert config.neo4j_database == "db-scan-id"
mock_create_db.assert_called_once_with("db-scan-id")
mock_get_session.assert_called_once_with("db-scan-id")
mock_cartography_indexes.assert_called_once_with(mock_session, config)
mock_prowler_indexes.assert_called_once_with(mock_session)
mock_cartography_analysis.assert_called_once_with(mock_session, config)
mock_cartography_ontology.assert_called_once_with(mock_session, config)
mock_prowler_analysis.assert_called_once_with(
mock_session,
provider,
str(scan.id),
config,
)
assert mock_get_ingestion.call_args_list == [
call(provider.provider),
call(provider.provider),
]
mock_event_loop.assert_called_once()
mock_update_progress.assert_any_call(attack_paths_scan, 1)
mock_update_progress.assert_any_call(attack_paths_scan, 2)
mock_update_progress.assert_any_call(attack_paths_scan, 95)
mock_finish.assert_called_once_with(
attack_paths_scan, StateChoices.COMPLETED, ingestion_result
)
mock_get_db_name.assert_called_once_with(attack_paths_scan.id)
def test_run_failure_marks_scan_failed(
self, tenants_fixture, providers_fixture, scans_fixture
):
tenant = tenants_fixture[0]
provider = providers_fixture[0]
provider.provider = Provider.ProviderChoices.AWS
provider.save()
scan = scans_fixture[0]
scan.provider = provider
scan.save()
attack_paths_scan = AttackPathsScan.objects.create(
tenant_id=tenant.id,
provider=provider,
scan=scan,
state=StateChoices.SCHEDULED,
)
mock_session = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = mock_session
session_ctx.__exit__.return_value = False
ingestion_fn = MagicMock(side_effect=RuntimeError("ingestion boom"))
with (
patch(
"tasks.jobs.attack_paths.scan.rls_transaction",
new=lambda *args, **kwargs: nullcontext(),
),
patch(
"tasks.jobs.attack_paths.scan.initialize_prowler_provider",
return_value=MagicMock(_enabled_regions=["us-east-1"]),
),
patch("tasks.jobs.attack_paths.scan.graph_database.get_uri"),
patch(
"tasks.jobs.attack_paths.scan.graph_database.get_database_name",
return_value="db-scan-id",
),
patch("tasks.jobs.attack_paths.scan.graph_database.create_database"),
patch(
"tasks.jobs.attack_paths.scan.graph_database.get_session",
return_value=session_ctx,
),
patch("tasks.jobs.attack_paths.scan.cartography_create_indexes.run"),
patch("tasks.jobs.attack_paths.scan.cartography_analysis.run"),
patch("tasks.jobs.attack_paths.scan.prowler.create_indexes"),
patch("tasks.jobs.attack_paths.scan.prowler.analysis"),
patch(
"tasks.jobs.attack_paths.scan.db_utils.retrieve_attack_paths_scan",
return_value=attack_paths_scan,
),
patch("tasks.jobs.attack_paths.scan.db_utils.starting_attack_paths_scan"),
patch(
"tasks.jobs.attack_paths.scan.db_utils.update_attack_paths_scan_progress"
),
patch(
"tasks.jobs.attack_paths.scan.db_utils.finish_attack_paths_scan"
) as mock_finish,
patch(
"tasks.jobs.attack_paths.scan.get_cartography_ingestion_function",
return_value=ingestion_fn,
),
patch(
"tasks.jobs.attack_paths.scan._call_within_event_loop",
side_effect=lambda fn, *a, **kw: fn(*a, **kw),
),
patch(
"tasks.jobs.attack_paths.scan.utils.stringify_exception",
return_value="Cartography failed: ingestion boom",
),
):
with pytest.raises(RuntimeError, match="ingestion boom"):
attack_paths_run(str(tenant.id), str(scan.id), "task-456")
failure_args = mock_finish.call_args[0]
assert failure_args[0] is attack_paths_scan
assert failure_args[1] == StateChoices.FAILED
assert failure_args[2] == {
"global_cartography_error": "Cartography failed: ingestion boom"
}
def test_run_returns_early_for_unsupported_provider(self, tenants_fixture):
tenant = tenants_fixture[0]
provider = Provider.objects.create(
provider=Provider.ProviderChoices.GCP,
uid="gcp-account",
alias="gcp",
tenant_id=tenant.id,
)
scan = Scan.objects.create(
name="GCP Scan",
provider=provider,
trigger=Scan.TriggerChoices.MANUAL,
state=StateChoices.AVAILABLE,
tenant_id=tenant.id,
)
with (
patch(
"tasks.jobs.attack_paths.scan.rls_transaction",
new=lambda *args, **kwargs: nullcontext(),
),
patch(
"tasks.jobs.attack_paths.scan.initialize_prowler_provider",
return_value=MagicMock(),
),
patch(
"tasks.jobs.attack_paths.scan.get_cartography_ingestion_function",
return_value=None,
) as mock_get_ingestion,
patch(
"tasks.jobs.attack_paths.scan.db_utils.retrieve_attack_paths_scan"
) as mock_retrieve,
):
result = attack_paths_run(str(tenant.id), str(scan.id), "task-789")
assert result == {}
mock_get_ingestion.assert_called_once_with(provider.provider)
mock_retrieve.assert_not_called()
@pytest.mark.django_db
class TestAttackPathsProwlerHelpers:
def test_create_indexes_executes_all_statements(self):
mock_session = MagicMock()
with patch("tasks.jobs.attack_paths.prowler.run_write_query") as mock_run_write:
prowler_module.create_indexes(mock_session)
assert mock_run_write.call_count == len(prowler_module.INDEX_STATEMENTS)
mock_run_write.assert_has_calls(
[call(mock_session, stmt) for stmt in prowler_module.INDEX_STATEMENTS]
)
def test_load_findings_batches_requests(self, providers_fixture):
provider = providers_fixture[0]
provider.provider = Provider.ProviderChoices.AWS
provider.save()
findings = [
{"id": "1", "resource_uid": "r-1"},
{"id": "2", "resource_uid": "r-2"},
]
config = SimpleNamespace(update_tag=12345)
mock_session = MagicMock()
with (
patch.object(prowler_module, "BATCH_SIZE", 1),
patch(
"tasks.jobs.attack_paths.prowler.get_root_node_label",
return_value="AWSAccount",
),
patch(
"tasks.jobs.attack_paths.prowler.get_node_uid_field",
return_value="arn",
),
):
prowler_module.load_findings(mock_session, findings, provider, config)
assert mock_session.run.call_count == 2
for call_args in mock_session.run.call_args_list:
params = call_args.args[1]
assert params["provider_uid"] == str(provider.uid)
assert params["last_updated"] == config.update_tag
assert "findings_data" in params
def test_cleanup_findings_runs_batches(self, providers_fixture):
provider = providers_fixture[0]
config = SimpleNamespace(update_tag=1024)
mock_session = MagicMock()
first_batch = MagicMock()
first_batch.single.return_value = {"deleted_findings_count": 3}
second_batch = MagicMock()
second_batch.single.return_value = {"deleted_findings_count": 0}
mock_session.run.side_effect = [first_batch, second_batch]
prowler_module.cleanup_findings(mock_session, provider, config)
assert mock_session.run.call_count == 2
params = mock_session.run.call_args.args[1]
assert params["provider_uid"] == str(provider.uid)
assert params["last_updated"] == config.update_tag
def test_get_provider_last_scan_findings_returns_latest_scan_data(
self,
tenants_fixture,
providers_fixture,
):
tenant = tenants_fixture[0]
provider = providers_fixture[0]
provider.provider = Provider.ProviderChoices.AWS
provider.save()
resource = Resource.objects.create(
tenant_id=tenant.id,
provider=provider,
uid="resource-uid",
name="Resource",
region="us-east-1",
service="ec2",
type="instance",
)
older_scan = Scan.objects.create(
name="Older",
provider=provider,
trigger=Scan.TriggerChoices.MANUAL,
state=StateChoices.COMPLETED,
tenant_id=tenant.id,
)
old_finding = Finding.objects.create(
tenant_id=tenant.id,
uid="older-finding",
scan=older_scan,
delta=Finding.DeltaChoices.NEW,
status=StatusChoices.PASS,
status_extended="ok",
severity=Severity.low,
impact=Severity.low,
impact_extended="",
raw_result={},
check_id="check-old",
check_metadata={"checktitle": "Old"},
first_seen_at=older_scan.inserted_at,
)
ResourceFindingMapping.objects.create(
tenant_id=tenant.id,
resource=resource,
finding=old_finding,
)
latest_scan = Scan.objects.create(
name="Latest",
provider=provider,
trigger=Scan.TriggerChoices.MANUAL,
state=StateChoices.COMPLETED,
tenant_id=tenant.id,
)
finding = Finding.objects.create(
tenant_id=tenant.id,
uid="finding-uid",
scan=latest_scan,
delta=Finding.DeltaChoices.NEW,
status=StatusChoices.FAIL,
status_extended="failed",
severity=Severity.high,
impact=Severity.high,
impact_extended="",
raw_result={},
check_id="check-1",
check_metadata={"checktitle": "Check title"},
first_seen_at=latest_scan.inserted_at,
)
ResourceFindingMapping.objects.create(
tenant_id=tenant.id,
resource=resource,
finding=finding,
)
latest_scan.refresh_from_db()
with patch(
"tasks.jobs.attack_paths.prowler.rls_transaction",
new=lambda *args, **kwargs: nullcontext(),
):
findings_data = prowler_module.get_provider_last_scan_findings(
provider,
str(latest_scan.id),
)
assert len(findings_data) == 1
finding_dict = findings_data[0]
assert finding_dict["id"] == str(finding.id)
assert finding_dict["resource_uid"] == resource.uid
assert finding_dict["check_title"] == "Check title"
assert finding_dict["scan_id"] == str(latest_scan.id)
@@ -1,27 +1,60 @@
from unittest.mock import call, patch
import pytest
from django.core.exceptions import ObjectDoesNotExist
from api.models import Provider, Tenant
from tasks.jobs.deletion import delete_provider, delete_tenant
@pytest.mark.django_db
class TestDeleteProvider:
def test_delete_provider_success(self, providers_fixture):
with patch(
"tasks.jobs.deletion.get_provider_graph_database_names"
) as mock_get_provider_graph_database_names, patch(
"tasks.jobs.deletion.graph_database.drop_database"
) as mock_drop_database:
graph_db_names = ["graph-db-1", "graph-db-2"]
mock_get_provider_graph_database_names.return_value = graph_db_names
instance = providers_fixture[0]
tenant_id = str(instance.tenant_id)
result = delete_provider(tenant_id, instance.id)
assert result
with pytest.raises(ObjectDoesNotExist):
Provider.objects.get(pk=instance.id)
mock_get_provider_graph_database_names.assert_called_once_with(
tenant_id, instance.id
)
mock_drop_database.assert_has_calls(
[call(graph_db_name) for graph_db_name in graph_db_names]
)
def test_delete_provider_does_not_exist(self, tenants_fixture):
with patch(
"tasks.jobs.deletion.get_provider_graph_database_names"
) as mock_get_provider_graph_database_names, patch(
"tasks.jobs.deletion.graph_database.drop_database"
) as mock_drop_database:
graph_db_names = ["graph-db-1"]
mock_get_provider_graph_database_names.return_value = graph_db_names
tenant_id = str(tenants_fixture[0].id)
non_existent_pk = "babf6796-cfcc-4fd3-9dcf-88d012247645"
with pytest.raises(ObjectDoesNotExist):
delete_provider(tenant_id, non_existent_pk)
mock_get_provider_graph_database_names.assert_called_once_with(
tenant_id, non_existent_pk
)
mock_drop_database.assert_has_calls(
[call(graph_db_name) for graph_db_name in graph_db_names]
)
@pytest.mark.django_db
@@ -30,33 +63,68 @@ class TestDeleteTenant:
"""
Test successful deletion of a tenant and its related data.
"""
with patch(
"tasks.jobs.deletion.get_provider_graph_database_names"
) as mock_get_provider_graph_database_names, patch(
"tasks.jobs.deletion.graph_database.drop_database"
) as mock_drop_database:
tenant = tenants_fixture[0]
providers = list(Provider.objects.filter(tenant_id=tenant.id))
# Ensure the tenant and related providers exist before deletion
assert Tenant.objects.filter(id=tenant.id).exists()
assert providers
graph_db_names_per_provider = [
[f"graph-db-{provider.id}"] for provider in providers
]
mock_get_provider_graph_database_names.side_effect = (
graph_db_names_per_provider
)
# Call the function and validate the result
deletion_summary = delete_tenant(tenant.id)
assert deletion_summary is not None
assert not Tenant.objects.filter(id=tenant.id).exists()
assert not Provider.objects.filter(tenant_id=tenant.id).exists()
expected_calls = [
call(provider.tenant_id, provider.id) for provider in providers
]
mock_get_provider_graph_database_names.assert_has_calls(
expected_calls, any_order=True
)
assert mock_get_provider_graph_database_names.call_count == len(
expected_calls
)
expected_drop_calls = [
call(graph_db_name[0]) for graph_db_name in graph_db_names_per_provider
]
mock_drop_database.assert_has_calls(expected_drop_calls, any_order=True)
assert mock_drop_database.call_count == len(expected_drop_calls)
def test_delete_tenant_with_no_providers(self, tenants_fixture):
"""
Test deletion of a tenant with no related providers.
"""
tenant = tenants_fixture[1] # Assume this tenant has no providers
providers = Provider.objects.filter(tenant_id=tenant.id)
with patch(
"tasks.jobs.deletion.get_provider_graph_database_names"
) as mock_get_provider_graph_database_names, patch(
"tasks.jobs.deletion.graph_database.drop_database"
) as mock_drop_database:
tenant = tenants_fixture[1] # Assume this tenant has no providers
providers = Provider.objects.filter(tenant_id=tenant.id)
# Ensure the tenant exists but has no related providers
assert Tenant.objects.filter(id=tenant.id).exists()
assert not providers.exists()
# Ensure the tenant exists but has no related providers
assert Tenant.objects.filter(id=tenant.id).exists()
assert not providers.exists()
# Call the function and validate the result
deletion_summary = delete_tenant(tenant.id)
# Call the function and validate the result
deletion_summary = delete_tenant(tenant.id)
assert deletion_summary == {} # No providers, so empty summary
assert not Tenant.objects.filter(id=tenant.id).exists()
assert deletion_summary == {} # No providers, so empty summary
assert not Tenant.objects.filter(id=tenant.id).exists()
mock_get_provider_graph_database_names.assert_not_called()
mock_drop_database.assert_not_called()
+77 -9
@@ -1,24 +1,28 @@
import uuid
from contextlib import contextmanager
from unittest.mock import MagicMock, patch
import openai
import pytest
from botocore.exceptions import ClientError
from tasks.tasks import (
_perform_scan_complete_tasks,
check_integrations_task,
check_lighthouse_provider_connection_task,
generate_outputs_task,
refresh_lighthouse_provider_models_task,
s3_integration_task,
security_hub_integration_task,
)
from api.models import (
Integration,
LighthouseProviderConfiguration,
LighthouseProviderModels,
)
from tasks.tasks import (
_perform_scan_complete_tasks,
check_integrations_task,
check_lighthouse_provider_connection_task,
generate_outputs_task,
perform_attack_paths_scan_task,
refresh_lighthouse_provider_models_task,
s3_integration_task,
security_hub_integration_task,
)
# TODO Move this to outputs/reports jobs
@@ -529,6 +533,7 @@ class TestGenerateOutputs:
class TestScanCompleteTasks:
@patch("tasks.tasks.perform_attack_paths_scan_task.apply_async")
@patch("tasks.tasks.create_compliance_requirements_task.apply_async")
@patch("tasks.tasks.perform_scan_summary_task.si")
@patch("tasks.tasks.generate_outputs_task.si")
@@ -541,6 +546,7 @@ class TestScanCompleteTasks:
mock_outputs_task,
mock_scan_summary_task,
mock_compliance_requirements_task,
mock_attack_paths_task,
):
"""Test that scan complete tasks are properly orchestrated with optimized reports."""
_perform_scan_complete_tasks("tenant-id", "scan-id", "provider-id")
@@ -577,6 +583,68 @@ class TestScanCompleteTasks:
scan_id="scan-id",
)
mock_attack_paths_task.assert_called_once_with(
kwargs={"tenant_id": "tenant-id", "scan_id": "scan-id"}
)
class TestAttackPathsTasks:
@staticmethod
@contextmanager
def _override_task_request(task, **attrs):
request = task.request
sentinel = object()
previous = {key: getattr(request, key, sentinel) for key in attrs}
for key, value in attrs.items():
setattr(request, key, value)
try:
yield
finally:
for key, prev in previous.items():
if prev is sentinel:
if hasattr(request, key):
delattr(request, key)
else:
setattr(request, key, prev)
def test_perform_attack_paths_scan_task_calls_runner(self):
with (
patch("tasks.tasks.attack_paths_scan") as mock_attack_paths_scan,
self._override_task_request(
perform_attack_paths_scan_task, id="celery-task-id"
),
):
mock_attack_paths_scan.return_value = {"status": "ok"}
result = perform_attack_paths_scan_task.run(
tenant_id="tenant-id", scan_id="scan-id"
)
mock_attack_paths_scan.assert_called_once_with(
tenant_id="tenant-id", scan_id="scan-id", task_id="celery-task-id"
)
assert result == {"status": "ok"}
def test_perform_attack_paths_scan_task_propagates_exception(self):
with (
patch(
"tasks.tasks.attack_paths_scan",
side_effect=RuntimeError("Exception to propagate"),
) as mock_attack_paths_scan,
self._override_task_request(
perform_attack_paths_scan_task, id="celery-task-error"
),
):
with pytest.raises(RuntimeError, match="Exception to propagate"):
perform_attack_paths_scan_task.run(
tenant_id="tenant-id", scan_id="scan-id"
)
mock_attack_paths_scan.assert_called_once_with(
tenant_id="tenant-id", scan_id="scan-id", task_id="celery-task-error"
)
@pytest.mark.django_db
class TestCheckIntegrationsTask:
+46 -1
@@ -1,6 +1,7 @@
services:
api-dev:
hostname: "prowler-api"
# image: prowler-api-dev
build:
context: ./api
dockerfile: Dockerfile
@@ -24,6 +25,8 @@ services:
condition: service_healthy
valkey:
condition: service_healthy
neo4j:
condition: service_healthy
entrypoint:
- "/home/prowler/docker-entrypoint.sh"
- "dev"
@@ -78,7 +81,41 @@ services:
timeout: 5s
retries: 3
neo4j:
image: graphstack/dozerdb:5.26.3.0
hostname: "neo4j"
volumes:
- ./_data/neo4j:/data
environment:
# We can't add our .env file because some of our current variables are not compatible with Neo4j env vars
# Auth
- NEO4J_AUTH=${NEO4J_USER}/${NEO4J_PASSWORD}
# Memory limits
- NEO4J_dbms_max__databases=${NEO4J_DBMS_MAX__DATABASES:-1000000}
- NEO4J_server_memory_pagecache_size=${NEO4J_SERVER_MEMORY_PAGECACHE_SIZE:-1G}
- NEO4J_server_memory_heap_initial__size=${NEO4J_SERVER_MEMORY_HEAP_INITIAL__SIZE:-1G}
- NEO4J_server_memory_heap_max__size=${NEO4J_SERVER_MEMORY_HEAP_MAX__SIZE:-1G}
# APOC
- apoc.export.file.enabled=${NEO4J_APOC_EXPORT_FILE_ENABLED:-true}
- apoc.import.file.enabled=${NEO4J_APOC_IMPORT_FILE_ENABLED:-true}
- apoc.import.file.use_neo4j_config=${NEO4J_APOC_IMPORT_FILE_USE_NEO4J_CONFIG:-true}
- "NEO4J_PLUGINS=${NEO4J_PLUGINS:-[\"apoc\"]}"
- "NEO4J_dbms_security_procedures_allowlist=${NEO4J_DBMS_SECURITY_PROCEDURES_ALLOWLIST:-apoc.*}"
- "NEO4J_dbms_security_procedures_unrestricted=${NEO4J_DBMS_SECURITY_PROCEDURES_UNRESTRICTED:-apoc.*}"
# Networking
- "dbms.connector.bolt.listen_address=${NEO4J_DBMS_CONNECTOR_BOLT_LISTEN_ADDRESS:-0.0.0.0:7687}"
# 7474 is the UI port
ports:
- 7474:7474
- ${NEO4J_PORT:-7687}:7687
healthcheck:
test: ["CMD", "wget", "--no-verbose", "http://localhost:7474"]
interval: 10s
timeout: 10s
retries: 10
worker-dev:
# image: prowler-api-dev
build:
context: ./api
dockerfile: Dockerfile
@@ -89,17 +126,23 @@ services:
- path: .env
required: false
volumes:
- "outputs:/tmp/prowler_api_output"
- ./api/src/backend:/home/prowler/backend
- ./api/pyproject.toml:/home/prowler/pyproject.toml
- ./api/docker-entrypoint.sh:/home/prowler/docker-entrypoint.sh
- outputs:/tmp/prowler_api_output
depends_on:
valkey:
condition: service_healthy
postgres:
condition: service_healthy
neo4j:
condition: service_healthy
entrypoint:
- "/home/prowler/docker-entrypoint.sh"
- "worker"
worker-beat:
# image: prowler-api-dev
build:
context: ./api
dockerfile: Dockerfile
@@ -114,6 +157,8 @@ services:
condition: service_healthy
postgres:
condition: service_healthy
neo4j:
condition: service_healthy
entrypoint:
- "../docker-entrypoint.sh"
- "beat"
+31
@@ -63,6 +63,37 @@ services:
timeout: 5s
retries: 3
neo4j:
image: graphstack/dozerdb:5.26.3.0
hostname: "neo4j"
volumes:
- ./_data/neo4j:/data
environment:
# We can't add our .env file because some of our current variables are not compatible with Neo4j env vars
# Auth
- NEO4J_AUTH=${NEO4J_USER}/${NEO4J_PASSWORD}
# Memory limits
- NEO4J_dbms_max__databases=${NEO4J_DBMS_MAX__DATABASES:-1000000}
- NEO4J_server_memory_pagecache_size=${NEO4J_SERVER_MEMORY_PAGECACHE_SIZE:-1G}
- NEO4J_server_memory_heap_initial__size=${NEO4J_SERVER_MEMORY_HEAP_INITIAL__SIZE:-1G}
- NEO4J_server_memory_heap_max__size=${NEO4J_SERVER_MEMORY_HEAP_MAX__SIZE:-1G}
# APOC
- apoc.export.file.enabled=${NEO4J_APOC_EXPORT_FILE_ENABLED:-true}
- apoc.import.file.enabled=${NEO4J_APOC_IMPORT_FILE_ENABLED:-true}
- apoc.import.file.use_neo4j_config=${NEO4J_APOC_IMPORT_FILE_USE_NEO4J_CONFIG:-true}
- "NEO4J_PLUGINS=${NEO4J_PLUGINS:-[\"apoc\"]}"
- "NEO4J_dbms_security_procedures_allowlist=${NEO4J_DBMS_SECURITY_PROCEDURES_ALLOWLIST:-apoc.*}"
- "NEO4J_dbms_security_procedures_unrestricted=${NEO4J_DBMS_SECURITY_PROCEDURES_UNRESTRICTED:-apoc.*}"
# Networking
- "dbms.connector.bolt.listen_address=${NEO4J_DBMS_CONNECTOR_BOLT_LISTEN_ADDRESS:-0.0.0.0:7687}"
ports:
- ${NEO4J_PORT:-7687}:7687
healthcheck:
test: ["CMD", "wget", "--no-verbose", "http://localhost:7474"]
interval: 10s
timeout: 10s
retries: 10
worker:
image: prowlercloud/prowler-api:${PROWLER_API_VERSION:-stable}
env_file:
+26 -19
@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 2.1.4 and should not be changed by hand.
[[package]]
name = "about-time"
@@ -938,34 +938,34 @@ files = [
[[package]]
name = "boto3"
version = "1.39.15"
version = "1.40.61"
description = "The AWS SDK for Python"
optional = false
python-versions = ">=3.9"
groups = ["main", "dev"]
files = [
{file = "boto3-1.39.15-py3-none-any.whl", hash = "sha256:38fc54576b925af0075636752de9974e172c8a2cf7133400e3e09b150d20fb6a"},
{file = "boto3-1.39.15.tar.gz", hash = "sha256:b4483625f0d8c35045254dee46cd3c851bbc0450814f20b9b25bee1b5c0d8409"},
{file = "boto3-1.40.61-py3-none-any.whl", hash = "sha256:6b9c57b2a922b5d8c17766e29ed792586a818098efe84def27c8f582b33f898c"},
{file = "boto3-1.40.61.tar.gz", hash = "sha256:d6c56277251adf6c2bdd25249feae625abe4966831676689ff23b4694dea5b12"},
]
[package.dependencies]
botocore = ">=1.39.15,<1.40.0"
botocore = ">=1.40.61,<1.41.0"
jmespath = ">=0.7.1,<2.0.0"
s3transfer = ">=0.13.0,<0.14.0"
s3transfer = ">=0.14.0,<0.15.0"
[package.extras]
crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
[[package]]
name = "botocore"
version = "1.39.15"
version = "1.40.61"
description = "Low-level, data-driven core of boto 3."
optional = false
python-versions = ">=3.9"
groups = ["main", "dev"]
files = [
{file = "botocore-1.39.15-py3-none-any.whl", hash = "sha256:eb9cfe918ebfbfb8654e1b153b29f0c129d586d2c0d7fb4032731d49baf04cff"},
{file = "botocore-1.39.15.tar.gz", hash = "sha256:2aa29a717f14f8c7ca058c2e297aaed0aa10ecea24b91514eee802814d1b7600"},
{file = "botocore-1.40.61-py3-none-any.whl", hash = "sha256:17ebae412692fd4824f99cde0f08d50126dc97954008e5ba2b522eb049238aa7"},
{file = "botocore-1.40.61.tar.gz", hash = "sha256:a2487ad69b090f9cccd64cf07c7021cd80ee9c0655ad974f87045b02f3ef52cd"},
]
[package.dependencies]
@@ -977,7 +977,7 @@ urllib3 = [
]
[package.extras]
crt = ["awscrt (==0.23.8)"]
crt = ["awscrt (==0.27.6)"]
[[package]]
name = "cachetools"
@@ -2366,6 +2366,8 @@ python-versions = "*"
groups = ["dev"]
files = [
{file = "jsonpath-ng-1.7.0.tar.gz", hash = "sha256:f6f5f7fd4e5ff79c785f1573b394043b39849fb2bb47bcead935d12b00beab3c"},
{file = "jsonpath_ng-1.7.0-py2-none-any.whl", hash = "sha256:898c93fc173f0c336784a3fa63d7434297544b7198124a68f9a3ef9597b0ae6e"},
{file = "jsonpath_ng-1.7.0-py3-none-any.whl", hash = "sha256:f3d7f9e848cba1b6da28c55b1c26ff915dc9e0b1ba7e752a53d6da8d5cbd00b6"},
]
[package.dependencies]
@@ -4884,6 +4886,7 @@ files = [
{file = "ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f66efbc1caa63c088dead1c4170d148eabc9b80d95fb75b6c92ac0aad2437d76"},
{file = "ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:22353049ba4181685023b25b5b51a574bce33e7f51c759371a7422dcae5402a6"},
{file = "ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:932205970b9f9991b34f55136be327501903f7c66830e9760a8ffb15b07f05cd"},
{file = "ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a52d48f4e7bf9005e8f0a89209bf9a73f7190ddf0489eee5eb51377385f59f2a"},
{file = "ruamel.yaml.clib-0.2.12-cp310-cp310-win32.whl", hash = "sha256:3eac5a91891ceb88138c113f9db04f3cebdae277f5d44eaa3651a4f573e6a5da"},
{file = "ruamel.yaml.clib-0.2.12-cp310-cp310-win_amd64.whl", hash = "sha256:ab007f2f5a87bd08ab1499bdf96f3d5c6ad4dcfa364884cb4549aa0154b13a28"},
{file = "ruamel.yaml.clib-0.2.12-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:4a6679521a58256a90b0d89e03992c15144c5f3858f40d7c18886023d7943db6"},
@@ -4892,6 +4895,7 @@ files = [
{file = "ruamel.yaml.clib-0.2.12-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:811ea1594b8a0fb466172c384267a4e5e367298af6b228931f273b111f17ef52"},
{file = "ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cf12567a7b565cbf65d438dec6cfbe2917d3c1bdddfce84a9930b7d35ea59642"},
{file = "ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:7dd5adc8b930b12c8fc5b99e2d535a09889941aa0d0bd06f4749e9a9397c71d2"},
{file = "ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1492a6051dab8d912fc2adeef0e8c72216b24d57bd896ea607cb90bb0c4981d3"},
{file = "ruamel.yaml.clib-0.2.12-cp311-cp311-win32.whl", hash = "sha256:bd0a08f0bab19093c54e18a14a10b4322e1eacc5217056f3c063bd2f59853ce4"},
{file = "ruamel.yaml.clib-0.2.12-cp311-cp311-win_amd64.whl", hash = "sha256:a274fb2cb086c7a3dea4322ec27f4cb5cc4b6298adb583ab0e211a4682f241eb"},
{file = "ruamel.yaml.clib-0.2.12-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:20b0f8dc160ba83b6dcc0e256846e1a02d044e13f7ea74a3d1d56ede4e48c632"},
@@ -4900,6 +4904,7 @@ files = [
{file = "ruamel.yaml.clib-0.2.12-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:749c16fcc4a2b09f28843cda5a193e0283e47454b63ec4b81eaa2242f50e4ccd"},
{file = "ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bf165fef1f223beae7333275156ab2022cffe255dcc51c27f066b4370da81e31"},
{file = "ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:32621c177bbf782ca5a18ba4d7af0f1082a3f6e517ac2a18b3974d4edf349680"},
{file = "ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:b82a7c94a498853aa0b272fd5bc67f29008da798d4f93a2f9f289feb8426a58d"},
{file = "ruamel.yaml.clib-0.2.12-cp312-cp312-win32.whl", hash = "sha256:e8c4ebfcfd57177b572e2040777b8abc537cdef58a2120e830124946aa9b42c5"},
{file = "ruamel.yaml.clib-0.2.12-cp312-cp312-win_amd64.whl", hash = "sha256:0467c5965282c62203273b838ae77c0d29d7638c8a4e3a1c8bdd3602c10904e4"},
{file = "ruamel.yaml.clib-0.2.12-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:4c8c5d82f50bb53986a5e02d1b3092b03622c02c2eb78e29bec33fd9593bae1a"},
@@ -4908,6 +4913,7 @@ files = [
{file = "ruamel.yaml.clib-0.2.12-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:96777d473c05ee3e5e3c3e999f5d23c6f4ec5b0c38c098b3a5229085f74236c6"},
{file = "ruamel.yaml.clib-0.2.12-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:3bc2a80e6420ca8b7d3590791e2dfc709c88ab9152c00eeb511c9875ce5778bf"},
{file = "ruamel.yaml.clib-0.2.12-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:e188d2699864c11c36cdfdada94d781fd5d6b0071cd9c427bceb08ad3d7c70e1"},
{file = "ruamel.yaml.clib-0.2.12-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:4f6f3eac23941b32afccc23081e1f50612bdbe4e982012ef4f5797986828cd01"},
{file = "ruamel.yaml.clib-0.2.12-cp313-cp313-win32.whl", hash = "sha256:6442cb36270b3afb1b4951f060eccca1ce49f3d087ca1ca4563a6eb479cb3de6"},
{file = "ruamel.yaml.clib-0.2.12-cp313-cp313-win_amd64.whl", hash = "sha256:e5b8daf27af0b90da7bb903a876477a9e6d7270be6146906b276605997c7e9a3"},
{file = "ruamel.yaml.clib-0.2.12-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:fc4b630cd3fa2cf7fce38afa91d7cfe844a9f75d7f0f36393fa98815e911d987"},
@@ -4916,6 +4922,7 @@ files = [
{file = "ruamel.yaml.clib-0.2.12-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2f1c3765db32be59d18ab3953f43ab62a761327aafc1594a2a1fbe038b8b8a7"},
{file = "ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:d85252669dc32f98ebcd5d36768f5d4faeaeaa2d655ac0473be490ecdae3c285"},
{file = "ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e143ada795c341b56de9418c58d028989093ee611aa27ffb9b7f609c00d813ed"},
{file = "ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:2c59aa6170b990d8d2719323e628aaf36f3bfbc1c26279c0eeeb24d05d2d11c7"},
{file = "ruamel.yaml.clib-0.2.12-cp39-cp39-win32.whl", hash = "sha256:beffaed67936fbbeffd10966a4eb53c402fafd3d6833770516bf7314bc6ffa12"},
{file = "ruamel.yaml.clib-0.2.12-cp39-cp39-win_amd64.whl", hash = "sha256:040ae85536960525ea62868b642bdb0c2cc6021c9f9d507810c0c604e66f5a7b"},
{file = "ruamel.yaml.clib-0.2.12.tar.gz", hash = "sha256:6c8fbb13ec503f99a91901ab46e0b07ae7941cd527393187039aec586fdfd36f"},
@@ -4923,14 +4930,14 @@ files = [
[[package]]
name = "s3transfer"
version = "0.13.1"
version = "0.14.0"
description = "An Amazon S3 Transfer Manager"
optional = false
python-versions = ">=3.9"
groups = ["main", "dev"]
files = [
{file = "s3transfer-0.13.1-py3-none-any.whl", hash = "sha256:a981aa7429be23fe6dfc13e80e4020057cbab622b08c0315288758d67cabc724"},
{file = "s3transfer-0.13.1.tar.gz", hash = "sha256:c3fdba22ba1bd367922f27ec8032d6a1cf5f10c934fb5d68cf60fd5a23d936cf"},
{file = "s3transfer-0.14.0-py3-none-any.whl", hash = "sha256:ea3b790c7077558ed1f02a3072fb3cb992bbbd253392f4b6e9e8976941c7d456"},
{file = "s3transfer-0.14.0.tar.gz", hash = "sha256:eff12264e7c8b4985074ccce27a3b38a485bb7f7422cc8046fee9be4983e4125"},
]
[package.dependencies]
@@ -5075,18 +5082,18 @@ files = [
[[package]]
name = "slack-sdk"
version = "3.34.0"
version = "3.39.0"
description = "The Slack API Platform SDK for Python"
optional = false
python-versions = ">=3.6"
python-versions = ">=3.7"
groups = ["main"]
files = [
{file = "slack_sdk-3.34.0-py2.py3-none-any.whl", hash = "sha256:c61f57f310d85be83466db5a98ab6ae3bb2e5587437b54fa0daa8fae6a0feffa"},
{file = "slack_sdk-3.34.0.tar.gz", hash = "sha256:ff61db7012160eed742285ea91f11c72b7a38a6500a7f6c5335662b4bc6b853d"},
{file = "slack_sdk-3.39.0-py2.py3-none-any.whl", hash = "sha256:b1556b2f5b8b12b94e5ea3f56c4f2c7f04462e4e1013d325c5764ff118044fa8"},
{file = "slack_sdk-3.39.0.tar.gz", hash = "sha256:6a56be10dc155c436ff658c6b776e1c082e29eae6a771fccf8b0a235822bbcb1"},
]
[package.extras]
optional = ["SQLAlchemy (>=1.4,<3)", "aiodns (>1.0)", "aiohttp (>=3.7.3,<4)", "boto3 (<=2)", "websocket-client (>=1,<2)", "websockets (>=9.1,<15)"]
optional = ["SQLAlchemy (>=1.4,<3)", "aiodns (>1.0)", "aiohttp (>=3.7.3,<4)", "boto3 (<=2)", "websocket-client (>=1,<2)", "websockets (>=9.1,<16)"]
[[package]]
name = "sniffio"
@@ -5688,4 +5695,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.1"
python-versions = ">3.9.1,<3.13"
content-hash = "a367e65bc43c0a16495a3d0f6eab8b356cc49b509e329b61c6641cd87f374ff4"
content-hash = "82015f7b4b08e419ac5d28eab1a2d4b563b1980c84679e020ed3d42d3b4e9b85"
+3 -3
@@ -40,8 +40,8 @@ dependencies = [
"azure-mgmt-loganalytics==12.0.0",
"azure-monitor-query==2.0.0",
"azure-storage-blob==12.24.1",
"boto3==1.39.15",
"botocore==1.39.15",
"boto3==1.40.61",
"botocore==1.40.61",
"colorama==0.4.6",
"cryptography==44.0.1",
"dash==3.1.1",
@@ -64,7 +64,7 @@ dependencies = [
"pytz==2025.1",
"schema==0.7.5",
"shodan==1.31.0",
"slack-sdk==3.34.0",
"slack-sdk==3.39.0",
"tabulate==0.9.0",
"tzlocal==5.3.1",
"py-iam-expand==0.1.0",
+10 -3
@@ -37,8 +37,8 @@ CODE_REVIEW_ENABLED=$(echo "$CODE_REVIEW_ENABLED" | tr '[:upper:]' '[:lower:]')
echo -e "${BLUE}Code Review Status: ${CODE_REVIEW_ENABLED}${NC}"
echo ""
# Get staged files (what will be committed)
STAGED_FILES=$(git diff --cached --name-only --diff-filter=ACM | grep -E '\.(tsx?|jsx?)$' || true)
# Get staged files in the UI folder only (what will be committed)
STAGED_FILES=$(git diff --cached --name-only --diff-filter=ACM -- 'ui/**' | grep -E '\.(tsx?|jsx?)$' || true)
if [ "$CODE_REVIEW_ENABLED" = "true" ]; then
if [ -z "$STAGED_FILES" ]; then
@@ -135,7 +135,14 @@ else
echo ""
fi
# Run healthcheck (typecheck and lint check)
# Check if there are any UI files to validate
if [ -z "$STAGED_FILES" ] && [ "$CODE_REVIEW_ENABLED" = "true" ]; then
echo -e "${YELLOW}⏭️ No UI files to validate, skipping healthcheck${NC}"
echo ""
exit 0
fi
# Run healthcheck (typecheck and lint check) only if there are UI changes
echo -e "${BLUE}🏥 Running healthcheck...${NC}"
echo ""
+1
@@ -24,6 +24,7 @@ All notable changes to the **Prowler UI** are documented in this file.
- PDF reporting for NIS2 compliance framework [(#9170)](https://github.com/prowler-cloud/prowler/pull/9170)
- External resource link to IaC findings for direct navigation to source code in Git repositories [(#9151)](https://github.com/prowler-cloud/prowler/pull/9151)
- New Overview page and new app styles [(#9234)](https://github.com/prowler-cloud/prowler/pull/9234)
- Attack Paths feature with query execution and graph visualization [(#PROWLER-383)](https://github.com/prowler-cloud/prowler/pull/9270)
- Use branch name as region for IaC findings [(#9296)](https://github.com/prowler-cloud/prowler/pull/9296)
### 🔄 Changed
+4
@@ -0,0 +1,4 @@
export * from "./queries";
export * from "./queries.adapter";
export * from "./scans";
export * from "./scans.adapter";
@@ -0,0 +1,55 @@
import { MetaDataProps } from "@/types";
import {
AttackPathQueriesResponse,
AttackPathQuery,
} from "@/types/attack-paths";
/**
* Adapts raw query API responses to enriched domain models
* - Enriches queries with metadata and computed properties
* - Co-locates related data for better performance
* - Preserves pagination metadata for list operations
*
* Uses plugin architecture for extensibility:
* - Handles query-specific response transformation
* - Can be composed with backend service plugins
* - Maintains separation of concerns between API layer and business logic
*/
/**
* Adapt attack path queries response with enriched data
*
* @param response - Raw API response from attack-paths-scans/{id}/queries endpoint
* @returns Enriched queries data with metadata
*/
export function adaptAttackPathQueriesResponse(
response: AttackPathQueriesResponse | undefined,
): {
data: AttackPathQuery[];
metadata?: MetaDataProps;
} {
if (!response?.data) {
return { data: [] };
}
// Enrich query data with computed properties
const enrichedData = response.data.map((query) => ({
...query,
// Can add computed properties here, e.g.:
// parameterCount: query.attributes.parameters.length,
// requiredParameters: query.attributes.parameters.filter(p => p.required),
// hasParameters: query.attributes.parameters.length > 0,
}));
const metadata: MetaDataProps | undefined = {
pagination: {
page: 1,
pages: 1,
count: enrichedData.length,
itemsPerPage: [10, 25, 50, 100],
},
version: "1.0",
};
return { data: enrichedData, metadata };
}
+97
@@ -0,0 +1,97 @@
"use server";
import { z } from "zod";
import { apiBaseUrl, getAuthHeaders } from "@/lib";
import { handleApiResponse } from "@/lib/server-actions-helper";
import {
AttackPathQueriesResponse,
AttackPathQuery,
AttackPathQueryResult,
ExecuteQueryRequest,
} from "@/types/attack-paths";
import { adaptAttackPathQueriesResponse } from "./queries.adapter";
// Validation schema for UUID - RFC 9562/4122 compliant
const UUIDSchema = z.uuid();
/**
* Fetch available queries for a specific attack path scan
*/
export const getAvailableQueries = async (
scanId: string,
): Promise<{ data: AttackPathQuery[] } | undefined> => {
// Validate scanId is a valid UUID format to prevent request forgery
const validatedScanId = UUIDSchema.safeParse(scanId);
if (!validatedScanId.success) {
console.error("Invalid scan ID format");
return undefined;
}
const headers = await getAuthHeaders({ contentType: false });
try {
const response = await fetch(
`${apiBaseUrl}/attack-paths-scans/${validatedScanId.data}/queries`,
{
headers,
method: "GET",
},
);
const apiResponse = (await handleApiResponse(
response,
)) as AttackPathQueriesResponse;
const adaptedData = adaptAttackPathQueriesResponse(apiResponse);
return { data: adaptedData.data };
} catch (error) {
console.error("Error fetching available queries for scan:", error);
return undefined;
}
};
/**
* Execute a query on an attack path scan
*/
export const executeQuery = async (
scanId: string,
queryId: string,
parameters?: Record<string, string | number | boolean>,
): Promise<AttackPathQueryResult | undefined> => {
// Validate scanId is a valid UUID format to prevent request forgery
const validatedScanId = UUIDSchema.safeParse(scanId);
if (!validatedScanId.success) {
console.error("Invalid scan ID format");
return undefined;
}
const headers = await getAuthHeaders({ contentType: true });
const requestBody: ExecuteQueryRequest = {
data: {
type: "attack-paths-query-run-request",
attributes: {
id: queryId,
...(parameters && { parameters }),
},
},
};
try {
const response = await fetch(
`${apiBaseUrl}/attack-paths-scans/${validatedScanId.data}/queries/run`,
{
headers,
method: "POST",
body: JSON.stringify(requestBody),
},
);
return handleApiResponse(response);
} catch (error) {
console.error("Error executing query on scan:", error);
return undefined;
}
};
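For reference, the JSON:API payload that `executeQuery` builds has this shape; the query id and parameter values below are hypothetical, purely to illustrate the structure:

```typescript
// Sketch of the request body constructed by executeQuery above.
// "privilege-escalation-passrole" and max_depth are hypothetical values;
// `parameters` is only spread into attributes when the caller provides it.
const requestBody = {
  data: {
    type: "attack-paths-query-run-request",
    attributes: {
      id: "privilege-escalation-passrole",
      parameters: { max_depth: 5 },
    },
  },
};

console.log(JSON.stringify(requestBody));
```

The resulting JSON string is what gets POSTed to the `queries/run` endpoint.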
@@ -0,0 +1,164 @@
import {
AttackPathGraphData,
GraphEdge,
GraphNodeProperties,
GraphNodePropertyValue,
GraphRelationship,
} from "@/types/attack-paths";
/**
* Normalizes property values to ensure they are primitives
* Arrays are converted to comma-separated strings
*
* @param value - The property value to normalize
* @returns Normalized primitive value
*/
function normalizePropertyValue(
value:
| GraphNodePropertyValue
| GraphNodePropertyValue[]
| Record<string, unknown>,
): string | number | boolean | null | undefined {
if (value === null || value === undefined) {
return value;
}
if (Array.isArray(value)) {
// Convert arrays to comma-separated strings
return value.join(", ");
}
if (
typeof value === "string" ||
typeof value === "number" ||
typeof value === "boolean"
) {
return value;
}
// For any other type, convert to string
return String(value);
}
/**
* Normalizes all properties in an object to ensure they are primitives
*
* @param properties - The properties object to normalize
* @returns Normalized properties object
*/
function normalizeProperties(
properties: Record<
string,
GraphNodePropertyValue | GraphNodePropertyValue[] | Record<string, unknown>
>,
): GraphNodeProperties {
const normalized: GraphNodeProperties = {};
for (const [key, value] of Object.entries(properties)) {
normalized[key] = normalizePropertyValue(value);
}
return normalized;
}
/**
* Adapts graph query result data for D3 visualization
* Transforms relationships array into edges array for D3 force-directed graph
*
* The adapter handles:
* - Converting relationship objects to edge objects compatible with D3
* - Mapping relationship labels to edge types for graph styling
* - Normalizing array properties to strings (e.g., anonymous_actions: ["s3:GetObject"] -> "s3:GetObject")
* - Preserving node and relationship data structure
* - Adding findings array to each node based on HAS_FINDING edges
* - Adding resources array to finding nodes based on HAS_FINDING edges (reverse relationship)
*
* @param graphData - Raw graph data with nodes and relationships from API
* @returns Graph data with edges array formatted for D3 visualization and findings/resources on nodes
*/
export function adaptQueryResultToGraphData(
graphData: AttackPathGraphData,
): AttackPathGraphData {
// Normalize node properties to ensure all values are primitives
const normalizedNodes = graphData.nodes.map((node) => ({
...node,
properties: normalizeProperties(
node.properties as Record<
string,
GraphNodePropertyValue | GraphNodePropertyValue[]
>,
),
findings: [] as string[], // Will be populated below
resources: [] as string[], // Will be populated below for finding nodes
}));
// Transform relationships into D3-compatible edges if relationships exist
// Also handle case where edges are already provided (e.g., from mock data)
let edges: GraphEdge[] = [];
if (graphData.relationships) {
edges = (graphData.relationships as GraphRelationship[]).map(
(relationship) => ({
id: relationship.id,
source: relationship.source,
target: relationship.target,
type: relationship.label, // D3 uses 'type' for styling edge appearance
properties: relationship.properties
? normalizeProperties(
relationship.properties as Record<
string,
GraphNodePropertyValue | GraphNodePropertyValue[]
>,
)
: undefined,
}),
);
} else if (graphData.edges) {
// If edges are already provided, just normalize their properties
edges = (graphData.edges as GraphEdge[]).map((edge) => ({
...edge,
properties: edge.properties
? normalizeProperties(
edge.properties as Record<
string,
GraphNodePropertyValue | GraphNodePropertyValue[]
>,
)
: undefined,
}));
}
// Populate findings and resources based on HAS_FINDING edges
edges.forEach((edge) => {
if (edge.type === "HAS_FINDING") {
const sourceId =
typeof edge.source === "string"
? edge.source
: (edge.source as { id?: string })?.id;
const targetId =
typeof edge.target === "string"
? edge.target
: (edge.target as { id?: string })?.id;
if (sourceId && targetId) {
// Add finding to source node (resource -> finding)
const sourceNode = normalizedNodes.find((n) => n.id === sourceId);
if (sourceNode) {
sourceNode.findings.push(targetId);
}
// Add resource to target node (finding <- resource)
const targetNode = normalizedNodes.find((n) => n.id === targetId);
if (targetNode) {
targetNode.resources.push(sourceId);
}
}
}
});
return {
nodes: normalizedNodes,
edges,
relationships: graphData.relationships, // Preserve original relationships data
};
}
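The HAS_FINDING pass above can be sketched in isolation. This is a simplified version with stand-in types: the real adapter looks nodes up with `Array.find` and tolerates object-valued edge endpoints, while this sketch assumes string ids and uses a map for lookup:

```typescript
// Simplified stand-ins for the real graph types (illustrative only).
type GNode = { id: string; findings: string[]; resources: string[] };
type GEdge = { source: string; target: string; type: string };

// Mirrors the HAS_FINDING population pass: each edge links a resource
// node to a finding node, and both sides record the other's id.
function linkFindings(nodes: GNode[], edges: GEdge[]): GNode[] {
  const byId = new Map(nodes.map((n) => [n.id, n] as const));
  for (const edge of edges) {
    if (edge.type !== "HAS_FINDING") continue;
    byId.get(edge.source)?.findings.push(edge.target);
    byId.get(edge.target)?.resources.push(edge.source);
  }
  return nodes;
}

const nodes: GNode[] = [
  { id: "bucket-1", findings: [], resources: [] },
  { id: "finding-1", findings: [], resources: [] },
];
linkFindings(nodes, [
  { source: "bucket-1", target: "finding-1", type: "HAS_FINDING" },
]);
console.log(nodes[0].findings); // the bucket now references its finding
console.log(nodes[1].resources); // the finding now references its resource
```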
+89

@@ -0,0 +1,89 @@
import { MetaDataProps } from "@/types";
import { AttackPathScan, AttackPathScansResponse } from "@/types/attack-paths";
/**
* Adapts raw scan API responses to enriched domain models
* - Transforms raw scan data with computed properties
* - Co-locates related data for better performance
* - Preserves pagination metadata for list operations
*
* Uses plugin architecture for extensibility:
* - Handles scan-specific response transformation
* - Can be composed with backend service plugins
* - Maintains separation of concerns between API layer and business logic
*/
/**
* Adapt attack path scans response with enriched data
*
* @param response - Raw API response from attack-paths-scans endpoint
* @returns Enriched scans data with metadata and computed properties
*/
export function adaptAttackPathScansResponse(
response: AttackPathScansResponse | undefined,
): {
data: AttackPathScan[];
metadata?: MetaDataProps;
} {
if (!response?.data) {
return { data: [] };
}
// Enrich scan data with computed properties
const enrichedData = response.data.map((scan) => ({
...scan,
attributes: {
...scan.attributes,
// Format duration for display
durationLabel: scan.attributes.duration
? formatDuration(scan.attributes.duration)
: null,
// Check if scan is recent (completed within last 24 hours)
isRecent: isRecentScan(scan.attributes.completed_at),
},
}));
// Transform links to MetaDataProps format if pagination exists
const metadata: MetaDataProps | undefined = response.links
? {
pagination: {
// Links-based pagination doesn't have traditional page numbers
// but we preserve the structure for consistency
page: 1,
pages: 1,
count: enrichedData.length,
itemsPerPage: [10, 25, 50, 100],
},
version: "1.0",
}
: undefined;
return { data: enrichedData, metadata };
}
/**
* Format duration in seconds to human-readable format
*
* @param seconds - Duration in seconds
* @returns Formatted duration string (e.g., "2m 30s")
*/
function formatDuration(seconds: number): string {
const minutes = Math.floor(seconds / 60);
const remainingSeconds = seconds % 60;
return `${minutes}m ${remainingSeconds}s`;
}
/**
* Check if a scan is recent (completed within last 24 hours)
*
* @param completedAt - Completion timestamp
* @returns true if scan completed within last 24 hours
*/
function isRecentScan(completedAt: string | null): boolean {
if (!completedAt) return false;
const completionTime = new Date(completedAt).getTime();
const oneDayAgo = Date.now() - 24 * 60 * 60 * 1000;
return completionTime > oneDayAgo;
}
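Both helpers are pure functions, so their behavior is easy to pin down. A standalone restatement (copied from the adapter above so the examples can be checked in isolation) with a few worked values:

```typescript
// Restated from the adapter above so the examples are self-contained.
function formatDuration(seconds: number): string {
  const minutes = Math.floor(seconds / 60);
  const remainingSeconds = seconds % 60;
  return `${minutes}m ${remainingSeconds}s`;
}

function isRecentScan(completedAt: string | null): boolean {
  if (!completedAt) return false;
  const completionTime = new Date(completedAt).getTime();
  const oneDayAgo = Date.now() - 24 * 60 * 60 * 1000;
  return completionTime > oneDayAgo;
}

// formatDuration(150) → "2m 30s"
// formatDuration(45)  → "0m 45s"
// isRecentScan(null)  → false
```

Note that a fractional `seconds` value leaks into the output unrounded (90.5 becomes "1m 30.5s"), so callers should round first if the API can return fractional durations — an assumption worth verifying against the actual endpoint.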
@@ -0,0 +1,69 @@
"use server";
import { z } from "zod";
import { apiBaseUrl, getAuthHeaders } from "@/lib";
import { handleApiResponse } from "@/lib/server-actions-helper";
import { AttackPathScan, AttackPathScansResponse } from "@/types/attack-paths";
import { adaptAttackPathScansResponse } from "./scans.adapter";
// Validation schema for UUID - RFC 9562/4122 compliant
const UUIDSchema = z.uuid();
/**
* Fetch list of attack path scans (latest scan for each provider)
*/
export const getAttackPathScans = async (): Promise<
{ data: AttackPathScan[] } | undefined
> => {
const headers = await getAuthHeaders({ contentType: false });
try {
const response = await fetch(`${apiBaseUrl}/attack-paths-scans`, {
headers,
method: "GET",
});
const apiResponse = (await handleApiResponse(
response,
)) as AttackPathScansResponse;
const adaptedData = adaptAttackPathScansResponse(apiResponse);
return { data: adaptedData.data };
} catch (error) {
console.error("Error fetching attack path scans:", error);
return undefined;
}
};
/**
* Fetch detail of a specific attack path scan
*/
export const getAttackPathScanDetail = async (
scanId: string,
): Promise<{ data: AttackPathScan } | undefined> => {
// Validate scanId is a valid UUID format to prevent request forgery
const validatedScanId = UUIDSchema.safeParse(scanId);
if (!validatedScanId.success) {
console.error("Invalid scan ID format");
return undefined;
}
const headers = await getAuthHeaders({ contentType: false });
try {
const response = await fetch(
`${apiBaseUrl}/attack-paths-scans/${validatedScanId.data}`,
{
headers,
method: "GET",
},
);
return handleApiResponse(response);
} catch (error) {
console.error("Error fetching attack path scan detail:", error);
return undefined;
}
};
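The UUID guard above keeps untrusted input out of the request URL, so a malformed `scanId` can never smuggle extra path segments into the fetch. A dependency-free sketch of the same check (the zod schema remains the source of truth; the regex and the `isValidScanId` name here are only illustrative):

```typescript
// Illustrative RFC 4122/9562-style check, mirroring the zod UUIDSchema above.
const UUID_RE =
  /^[0-9a-f]{8}-[0-9a-f]{4}-[1-8][0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$/i;

function isValidScanId(scanId: string): boolean {
  return UUID_RE.test(scanId);
}

// isValidScanId("123e4567-e89b-12d3-a456-426614174000") → true
// isValidScanId("../other-endpoint")                    → false
```

Only validated values are interpolated into the URL, which is why the action returns early with `undefined` instead of proceeding on parse failure.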
@@ -0,0 +1,2 @@
export { VerticalSteps } from "./vertical-steps";
export { WorkflowAttackPaths } from "./workflow-attack-paths";
@@ -0,0 +1,299 @@
"use client";
import { useControlledState } from "@react-stately/utils";
import { domAnimation, LazyMotion, m } from "framer-motion";
import type {
ComponentProps,
CSSProperties,
HTMLAttributes,
ReactNode,
} from "react";
import { forwardRef } from "react";
import { cn } from "@/lib/utils";
export type VerticalStepProps = {
className?: string;
description?: ReactNode;
title?: ReactNode;
};
const STEP_COLORS = {
primary: "primary",
secondary: "secondary",
success: "success",
warning: "warning",
danger: "danger",
default: "default",
} as const;
type StepColor = (typeof STEP_COLORS)[keyof typeof STEP_COLORS];
export interface VerticalStepsProps extends HTMLAttributes<HTMLButtonElement> {
/**
* An array of steps.
*
* @default []
*/
steps?: VerticalStepProps[];
/**
* The color of the steps.
*
* @default "primary"
*/
color?: StepColor;
/**
* The current step index.
*/
currentStep?: number;
/**
* The default step index.
*
* @default 0
*/
defaultStep?: number;
/**
* Whether to hide the progress bars.
*
* @default false
*/
hideProgressBars?: boolean;
/**
* The custom class for the steps wrapper.
*/
className?: string;
/**
* The custom class for the step.
*/
stepClassName?: string;
/**
* Callback function when the step index changes.
*/
onStepChange?: (stepIndex: number) => void;
}
function CheckIcon(props: ComponentProps<"svg">) {
return (
<svg
{...props}
fill="none"
stroke="currentColor"
strokeWidth={2}
viewBox="0 0 24 24"
>
<m.path
animate={{ pathLength: 1 }}
d="M5 13l4 4L19 7"
initial={{ pathLength: 0 }}
strokeLinecap="round"
strokeLinejoin="round"
transition={{
delay: 0.2,
type: "tween",
ease: "easeOut",
duration: 0.3,
}}
/>
</svg>
);
}
export const VerticalSteps = forwardRef<HTMLButtonElement, VerticalStepsProps>(
(
{
color = "primary",
steps = [],
defaultStep = 0,
onStepChange,
currentStep: currentStepProp,
hideProgressBars = false,
stepClassName,
className,
...props
},
ref,
) => {
const [currentStep, setCurrentStep] = useControlledState(
currentStepProp,
defaultStep,
onStepChange,
);
let userColor;
let fgColor;
const colorsVars = [
"[--active-fg-color:var(--step-fg-color)]",
"[--active-border-color:var(--step-color)]",
"[--active-color:var(--step-color)]",
"[--complete-background-color:var(--step-color)]",
"[--complete-border-color:var(--step-color)]",
"[--inactive-border-color:hsl(var(--heroui-default-300))]",
"[--inactive-color:hsl(var(--heroui-default-300))]",
];
switch (color) {
case "primary":
userColor = "[--step-color:hsl(var(--heroui-primary))]";
fgColor = "[--step-fg-color:hsl(var(--heroui-primary-foreground))]";
break;
case "secondary":
userColor = "[--step-color:hsl(var(--heroui-secondary))]";
fgColor = "[--step-fg-color:hsl(var(--heroui-secondary-foreground))]";
break;
case "success":
userColor = "[--step-color:hsl(var(--heroui-success))]";
fgColor = "[--step-fg-color:hsl(var(--heroui-success-foreground))]";
break;
case "warning":
userColor = "[--step-color:hsl(var(--heroui-warning))]";
fgColor = "[--step-fg-color:hsl(var(--heroui-warning-foreground))]";
break;
case "danger":
userColor = "[--step-color:hsl(var(--heroui-error))]";
fgColor = "[--step-fg-color:hsl(var(--heroui-error-foreground))]";
break;
case "default":
userColor = "[--step-color:hsl(var(--heroui-default))]";
fgColor = "[--step-fg-color:hsl(var(--heroui-default-foreground))]";
break;
default:
userColor = "[--step-color:hsl(var(--heroui-primary))]";
fgColor = "[--step-fg-color:hsl(var(--heroui-primary-foreground))]";
break;
}
if (!className?.includes("--step-fg-color")) colorsVars.unshift(fgColor);
if (!className?.includes("--step-color")) colorsVars.unshift(userColor);
if (!className?.includes("--inactive-bar-color"))
colorsVars.push("[--inactive-bar-color:hsl(var(--heroui-default-300))]");
return (
<nav aria-label="Progress" className="max-w-fit">
<ol className={cn("flex flex-col gap-y-3", colorsVars, className)}>
{steps?.map((step, stepIdx) => {
const status =
currentStep === stepIdx
? "active"
: currentStep < stepIdx
? "inactive"
: "complete";
return (
<li key={stepIdx} className="relative">
<div className="flex w-full max-w-full items-center">
<button
ref={ref}
aria-current={status === "active" ? "step" : undefined}
className={cn(
"group rounded-large flex w-full cursor-pointer items-center justify-center gap-4 px-3 py-2.5",
stepClassName,
)}
onClick={() => setCurrentStep(stepIdx)}
{...props}
>
<div className="flex h-full items-center">
<LazyMotion features={domAnimation}>
<div className="relative">
<m.div
animate={status}
className={cn(
"border-medium text-large text-default-foreground relative flex h-[34px] w-[34px] items-center justify-center rounded-full font-semibold",
{
"shadow-lg": status === "complete",
},
)}
data-status={status}
initial={false}
transition={{ duration: 0.25 }}
variants={{
inactive: {
backgroundColor: "transparent",
borderColor: "var(--inactive-border-color)",
color: "var(--inactive-color)",
},
active: {
backgroundColor: "transparent",
borderColor: "var(--active-border-color)",
color: "var(--active-color)",
},
complete: {
backgroundColor:
"var(--complete-background-color)",
borderColor: "var(--complete-border-color)",
},
}}
>
<div className="flex items-center justify-center">
{status === "complete" ? (
<CheckIcon className="h-6 w-6 text-(--active-fg-color)" />
) : (
<span>{stepIdx + 1}</span>
)}
</div>
</m.div>
</div>
</LazyMotion>
</div>
<div className="flex-1 text-left">
<div>
<div
className={cn(
"text-medium text-default-foreground font-medium transition-[color,opacity] duration-300 group-active:opacity-70",
{
"text-default-500": status === "inactive",
},
)}
>
{step.title}
</div>
<div
className={cn(
"text-tiny text-default-600 lg:text-small transition-[color,opacity] duration-300 group-active:opacity-70",
{
"text-default-500": status === "inactive",
},
)}
>
{step.description}
</div>
</div>
</div>
</button>
</div>
{stepIdx < steps.length - 1 && !hideProgressBars && (
<div
aria-hidden="true"
className={cn(
"pointer-events-none absolute top-[calc(64px*var(--idx)+1)] left-3 flex h-1/2 -translate-y-1/3 items-center px-4",
)}
style={
{
"--idx": stepIdx,
} as CSSProperties
}
>
<div
className={cn(
"relative h-full w-0.5 bg-(--inactive-bar-color) transition-colors duration-300",
"after:absolute after:block after:h-0 after:w-full after:bg-(--active-border-color) after:transition-[height] after:duration-300 after:content-['']",
{
"after:h-full": stepIdx < currentStep,
},
)}
/>
</div>
)}
</li>
);
})}
</ol>
</nav>
);
},
);
VerticalSteps.displayName = "VerticalSteps";
@@ -0,0 +1,49 @@
"use client";
import { usePathname } from "next/navigation";
import { VerticalSteps } from "./vertical-steps";
/**
* Workflow steps component for Attack Paths wizard
* Shows progress and navigation steps for the two-step process
*/
export const WorkflowAttackPaths = () => {
const pathname = usePathname();
// Determine current step based on pathname
const isQueryBuilderStep = pathname.includes("query-builder");
const currentStep = isQueryBuilderStep ? 1 : 0; // 0-indexed
const steps = [
{
title: "Select Attack Paths Scan",
description: "Choose an AWS account and its latest Attack Paths scan",
},
{
title: "Build Query & Visualize",
description: "Create a query and view the Attack Paths graph",
},
];
const progressPercentage = (currentStep / (steps.length - 1)) * 100;
return (
<section className="flex flex-col gap-6">
<div>
<div className="bg-bg-neutral-tertiary mb-4 h-2 w-full overflow-hidden rounded-full">
<div
className="bg-success-primary h-full transition-all duration-300"
style={{ width: `${progressPercentage}%` }}
/>
</div>
<h3 className="dark:text-prowler-theme-pale/90 text-sm font-semibold">
Step {currentStep + 1} of {steps.length}
</h3>
</div>
<VerticalSteps currentStep={currentStep} steps={steps} color="success" />
</section>
);
};
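The progress bar width follows directly from the zero-indexed step. A minimal sketch of the mapping (the component inlines this expression; the `progressPercentage` helper name here is hypothetical):

```typescript
// Maps a zero-indexed step onto 0–100% of the progress bar width.
function progressPercentage(currentStep: number, totalSteps: number): number {
  return (currentStep / (totalSteps - 1)) * 100;
}

// progressPercentage(0, 2) → 0   (scan-selection step)
// progressPercentage(1, 2) → 100 (query-builder step)
```

With a single step the divisor is zero; the two-step wizard never hits that case, but a guard is worth adding if other workflows reuse this pattern.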
@@ -0,0 +1,21 @@
import { Navbar } from "@/components/ui/nav-bar/navbar";
/**
* Workflow layout for Attack Paths
* Displays content with navbar
*/
export default function AttackPathsWorkflowLayout({
children,
}: {
children: React.ReactNode;
}) {
return (
<>
<Navbar title="Attack Paths Analysis" icon="" />
<div className="px-6 py-4 sm:px-8 xl:px-10">
{/* Content */}
<div>{children}</div>
</div>
</>
);
}
@@ -0,0 +1,34 @@
"use client";
import { Play } from "lucide-react";
import { Button } from "@/components/shadcn";
interface ExecuteButtonProps {
isLoading: boolean;
isDisabled: boolean;
onExecute: () => void;
}
/**
* Execute query button component
* Triggers query execution with loading state
*/
export const ExecuteButton = ({
isLoading,
isDisabled,
onExecute,
}: ExecuteButtonProps) => {
return (
<Button
variant="default"
size="lg"
disabled={isDisabled || isLoading}
onClick={onExecute}
className="w-full gap-2 font-semibold sm:w-auto"
>
{!isLoading && <Play size={18} />}
{isLoading ? "Executing Query..." : "Execute Query"}
</Button>
);
};
@@ -0,0 +1,93 @@
"use client";
import { Download, Minimize2, ZoomIn, ZoomOut } from "lucide-react";
import { Button } from "@/components/shadcn";
import {
Tooltip,
TooltipContent,
TooltipProvider,
TooltipTrigger,
} from "@/components/shadcn/tooltip";
interface GraphControlsProps {
onZoomIn: () => void;
onZoomOut: () => void;
onFitToScreen: () => void;
onExport: () => void;
}
/**
 * Controls for graph visualization (zoom in/out, fit to screen, export)
 * Positioned as a floating toolbar above the graph
 */
export const GraphControls = ({
onZoomIn,
onZoomOut,
onFitToScreen,
onExport,
}: GraphControlsProps) => {
return (
<div className="flex items-center">
<div className="border-border-neutral-primary bg-bg-neutral-tertiary flex gap-1 rounded-lg border p-1">
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<Button
variant="ghost"
size="sm"
onClick={onZoomIn}
className="h-8 w-8 p-0"
>
<ZoomIn size={18} />
</Button>
</TooltipTrigger>
<TooltipContent>Zoom in</TooltipContent>
</Tooltip>
<Tooltip>
<TooltipTrigger asChild>
<Button
variant="ghost"
size="sm"
onClick={onZoomOut}
className="h-8 w-8 p-0"
>
<ZoomOut size={18} />
</Button>
</TooltipTrigger>
<TooltipContent>Zoom out</TooltipContent>
</Tooltip>
<Tooltip>
<TooltipTrigger asChild>
<Button
variant="ghost"
size="sm"
onClick={onFitToScreen}
className="h-8 w-8 p-0"
>
<Minimize2 size={18} />
</Button>
</TooltipTrigger>
<TooltipContent>Fit graph to view</TooltipContent>
</Tooltip>
<Tooltip>
<TooltipTrigger asChild>
<Button
variant="ghost"
size="sm"
onClick={onExport}
className="h-8 w-8 p-0"
>
<Download size={18} />
</Button>
</TooltipTrigger>
<TooltipContent>Export graph</TooltipContent>
</Tooltip>
</TooltipProvider>
</div>
</div>
);
};
@@ -0,0 +1,493 @@
"use client";
import { Card, CardContent } from "@/components/shadcn";
import {
Tooltip,
TooltipContent,
TooltipProvider,
TooltipTrigger,
} from "@/components/shadcn/tooltip";
import type { AttackPathGraphData } from "@/types/attack-paths";
import {
getNodeBorderColor,
getNodeColor,
GRAPH_EDGE_COLOR,
GRAPH_NODE_BORDER_COLORS,
GRAPH_NODE_COLORS,
} from "../../_lib/graph-colors";
interface LegendItem {
label: string;
color: string;
borderColor: string;
description: string;
shape: "rectangle" | "hexagon" | "cloud";
}
// Map node labels to human-readable names and descriptions
const nodeTypeDescriptions: Record<
string,
{ name: string; description: string }
> = {
// Findings
ProwlerFinding: {
name: "Finding",
description: "Security findings from Prowler scans",
},
// AWS Account
AWSAccount: {
name: "AWS Account",
description: "AWS account root node",
},
// Compute
EC2Instance: {
name: "EC2 Instance",
description: "Elastic Compute Cloud instance",
},
LambdaFunction: {
name: "Lambda Function",
description: "AWS Lambda serverless function",
},
// Storage
S3Bucket: {
name: "S3 Bucket",
description: "Simple Storage Service bucket",
},
// IAM
IAMRole: {
name: "IAM Role",
description: "Identity and Access Management role",
},
IAMPolicy: {
name: "IAM Policy",
description: "Identity and Access Management policy",
},
AWSRole: {
name: "AWS Role",
description: "AWS IAM role",
},
AWSPolicy: {
name: "AWS Policy",
description: "AWS IAM policy",
},
AWSInlinePolicy: {
name: "AWS Inline Policy",
description: "AWS IAM inline policy",
},
AWSPolicyStatement: {
name: "AWS Policy Statement",
description: "AWS IAM policy statement",
},
AWSPrincipal: {
name: "AWS Principal",
description: "AWS IAM principal entity",
},
// Networking
SecurityGroup: {
name: "Security Group",
description: "AWS security group for network access control",
},
EC2SecurityGroup: {
name: "EC2 Security Group",
description: "EC2 security group for network access control",
},
IpPermissionInbound: {
name: "IP Permission Inbound",
description: "Inbound IP permission rule",
},
IpRule: {
name: "IP Rule",
description: "IP address rule",
},
Internet: {
name: "Internet",
description: "Internet gateway or public access",
},
// Tags
AWSTag: {
name: "AWS Tag",
description: "AWS resource tag",
},
Tag: {
name: "Tag",
description: "Resource tag",
},
};
/**
* Extract unique node types from graph data
*/
function extractNodeTypes(
nodes: AttackPathGraphData["nodes"] | undefined,
): string[] {
if (!nodes) return [];
const nodeTypes = new Set<string>();
nodes.forEach((node) => {
node.labels.forEach((label) => {
nodeTypes.add(label);
});
});
return Array.from(nodeTypes).sort();
}
/**
* Severity legend items - colors work in both light and dark themes
*/
const severityLegendItems: LegendItem[] = [
{
label: "Critical",
color: GRAPH_NODE_COLORS.critical,
borderColor: GRAPH_NODE_BORDER_COLORS.critical,
description: "Critical severity finding",
shape: "hexagon",
},
{
label: "High",
color: GRAPH_NODE_COLORS.high,
borderColor: GRAPH_NODE_BORDER_COLORS.high,
description: "High severity finding",
shape: "hexagon",
},
{
label: "Medium",
color: GRAPH_NODE_COLORS.medium,
borderColor: GRAPH_NODE_BORDER_COLORS.medium,
description: "Medium severity finding",
shape: "hexagon",
},
{
label: "Low",
color: GRAPH_NODE_COLORS.low,
borderColor: GRAPH_NODE_BORDER_COLORS.low,
description: "Low severity finding",
shape: "hexagon",
},
];
/**
* Generate legend items from graph data
*/
function generateLegendItems(
nodeTypes: string[],
hasFindings: boolean,
): LegendItem[] {
const items: LegendItem[] = [];
// Add severity items if there are findings
if (hasFindings) {
items.push(...severityLegendItems);
}
// Check for Internet node
const hasInternet = nodeTypes.some(
(type) => type.toLowerCase() === "internet",
);
// Check for any resource nodes (non-finding, non-internet)
const hasResources = nodeTypes.some((type) => {
const isFinding = type.toLowerCase().includes("finding");
const isPrivilegeEscalation = type === "PrivilegeEscalation";
const isInternet = type.toLowerCase() === "internet";
return !isFinding && !isPrivilegeEscalation && !isInternet;
});
// Add a single "Resource" item for all resource types
if (hasResources) {
items.push({
label: "Resource",
color: GRAPH_NODE_COLORS.default,
borderColor: GRAPH_NODE_BORDER_COLORS.default,
description: "Cloud infrastructure resource",
shape: "rectangle",
});
}
// Add Internet node if present
if (hasInternet) {
items.push({
label: "Internet",
color: getNodeColor(["Internet"]),
borderColor: getNodeBorderColor(["Internet"]),
description: "Internet gateway or public access",
shape: "cloud",
});
}
return items;
}
/**
* Hexagon shape component for legend
*/
const HexagonShape = ({
color,
borderColor,
}: {
color: string;
borderColor: string;
}) => (
<svg width="32" height="22" viewBox="0 0 32 22" aria-hidden="true">
<defs>
<filter id="legendGlow" x="-50%" y="-50%" width="200%" height="200%">
<feGaussianBlur stdDeviation="1" result="coloredBlur" />
<feMerge>
<feMergeNode in="coloredBlur" />
<feMergeNode in="SourceGraphic" />
</feMerge>
</filter>
</defs>
<path
d="M5 1 L27 1 L31 11 L27 21 L5 21 L1 11 Z"
fill={color}
fillOpacity={0.85}
stroke={borderColor}
strokeWidth={1.5}
filter="url(#legendGlow)"
/>
</svg>
);
/**
* Pill shape component for legend
*/
const PillShape = ({
color,
borderColor,
}: {
color: string;
borderColor: string;
}) => (
<svg width="36" height="20" viewBox="0 0 36 20" aria-hidden="true">
<defs>
<filter id="legendGlow2" x="-50%" y="-50%" width="200%" height="200%">
<feGaussianBlur stdDeviation="1" result="coloredBlur" />
<feMerge>
<feMergeNode in="coloredBlur" />
<feMergeNode in="SourceGraphic" />
</feMerge>
</filter>
</defs>
<rect
x="1"
y="1"
width="34"
height="18"
rx="9"
ry="9"
fill={color}
fillOpacity={0.85}
stroke={borderColor}
strokeWidth={1.5}
filter="url(#legendGlow2)"
/>
</svg>
);
/**
* Globe shape component for legend (used for Internet nodes)
*/
const GlobeShape = ({
color,
borderColor,
}: {
color: string;
borderColor: string;
}) => (
<svg width="24" height="24" viewBox="0 0 24 24" aria-hidden="true">
<defs>
<filter id="legendGlow3" x="-50%" y="-50%" width="200%" height="200%">
<feGaussianBlur stdDeviation="1" result="coloredBlur" />
<feMerge>
<feMergeNode in="coloredBlur" />
<feMergeNode in="SourceGraphic" />
</feMerge>
</filter>
</defs>
{/* Globe circle */}
<circle
cx="12"
cy="12"
r="10"
fill={color}
fillOpacity={0.85}
stroke={borderColor}
strokeWidth={1.5}
filter="url(#legendGlow3)"
/>
{/* Horizontal line */}
<ellipse
cx="12"
cy="12"
rx="10"
ry="4"
fill="none"
stroke={borderColor}
strokeWidth={1}
strokeOpacity={0.6}
/>
{/* Vertical ellipse */}
<ellipse
cx="12"
cy="12"
rx="4"
ry="10"
fill="none"
stroke={borderColor}
strokeWidth={1}
strokeOpacity={0.6}
/>
</svg>
);
/**
* Edge line component for legend
*/
const EdgeLine = ({ dashed }: { dashed: boolean }) => (
<svg
width="60"
height="20"
viewBox="0 0 60 20"
aria-hidden="true"
style={{ overflow: "visible" }}
>
{/* Line */}
<line
x1="4"
y1="10"
x2="44"
y2="10"
stroke={GRAPH_EDGE_COLOR}
strokeWidth={3}
strokeLinecap="round"
strokeDasharray={dashed ? "8,6" : undefined}
/>
{/* Arrow head */}
<polygon points="44,5 56,10 44,15" fill={GRAPH_EDGE_COLOR} />
</svg>
);
interface GraphLegendProps {
data?: AttackPathGraphData;
}
/**
* Legend for attack path graph node types and edge styles
*/
export const GraphLegend = ({ data }: GraphLegendProps) => {
const nodeTypes = extractNodeTypes(data?.nodes);
// Check if there are any findings or privilege escalations in the data
const hasFindings = nodeTypes.some(
(type) =>
type.toLowerCase().includes("finding") || type === "PrivilegeEscalation",
);
const legendItems = generateLegendItems(nodeTypes, hasFindings);
if (legendItems.length === 0) {
return null;
}
return (
<Card className="w-fit border-0">
<CardContent className="gap-3 p-4">
<div className="flex flex-col gap-4">
{/* Node types section */}
<div className="flex flex-col items-start gap-3 lg:flex-row lg:flex-wrap lg:items-center">
<TooltipProvider>
{legendItems.map((item) => (
<Tooltip key={item.label}>
<TooltipTrigger asChild>
<div
className="flex cursor-help items-center gap-2"
role="img"
aria-label={`${item.label}: ${item.description}`}
>
{item.shape === "hexagon" ? (
<HexagonShape
color={item.color}
borderColor={item.borderColor}
/>
) : item.shape === "cloud" ? (
<GlobeShape
color={item.color}
borderColor={item.borderColor}
/>
) : (
<PillShape
color={item.color}
borderColor={item.borderColor}
/>
)}
<span className="text-text-neutral-secondary text-xs">
{item.label}
</span>
</div>
</TooltipTrigger>
<TooltipContent>{item.description}</TooltipContent>
</Tooltip>
))}
</TooltipProvider>
</div>
{/* Edge types section */}
<div className="border-border-neutral-primary flex flex-col items-start gap-3 border-t pt-3 lg:flex-row lg:flex-wrap lg:items-center">
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<div
className="flex cursor-help items-center gap-2"
role="img"
aria-label="Solid line: Resource connection"
>
<EdgeLine dashed={false} />
<span className="text-text-neutral-secondary text-xs">
Resource Connection
</span>
</div>
</TooltipTrigger>
<TooltipContent>
Connection between infrastructure resources
</TooltipContent>
</Tooltip>
{hasFindings && (
<Tooltip>
<TooltipTrigger asChild>
<div
className="flex cursor-help items-center gap-2"
role="img"
aria-label="Dashed line: Finding connection"
>
<EdgeLine dashed={true} />
<span className="text-text-neutral-secondary text-xs">
Finding Connection
</span>
</div>
</TooltipTrigger>
<TooltipContent>
Connection to a security finding
</TooltipContent>
</Tooltip>
)}
</TooltipProvider>
</div>
{/* Zoom control hint */}
<div className="border-border-neutral-primary flex items-center gap-2 border-t pt-3">
<kbd className="bg-bg-neutral-tertiary text-text-neutral-secondary rounded px-1.5 py-0.5 text-xs font-medium">
Ctrl
</kbd>
<span className="text-text-neutral-secondary text-xs">+</span>
<span className="text-text-neutral-secondary text-xs">
Scroll to zoom
</span>
</div>
</div>
</CardContent>
</Card>
);
};
@@ -0,0 +1,24 @@
"use client";
import { Skeleton } from "@/components/shadcn/skeleton/skeleton";
/**
* Loading skeleton for graph visualization
* Shows while graph data is being fetched and processed
*/
export const GraphLoading = () => {
return (
<div className="dark:bg-prowler-blue-400 flex h-96 items-center justify-center rounded-lg bg-gray-50">
<div className="flex flex-col items-center gap-3">
<div className="flex gap-2">
<Skeleton className="h-3 w-3 rounded-full" />
<Skeleton className="h-3 w-3 rounded-full" />
<Skeleton className="h-3 w-3 rounded-full" />
</div>
<p className="text-sm text-gray-600 dark:text-gray-400">
Loading Attack Paths graph...
</p>
</div>
</div>
);
};
@@ -0,0 +1,5 @@
export type { AttackPathGraphRef } from "./attack-path-graph";
export { AttackPathGraph } from "./attack-path-graph";
export { GraphControls } from "./graph-controls";
export { GraphLegend } from "./graph-legend";
export { GraphLoading } from "./graph-loading";
@@ -0,0 +1,7 @@
export { ExecuteButton } from "./execute-button";
export * from "./graph";
export * from "./node-detail";
export { QueryParametersForm } from "./query-parameters-form";
export { QuerySelector } from "./query-selector";
export { ScanListTable } from "./scan-list-table";
export { ScanStatusBadge } from "./scan-status-badge";
@@ -0,0 +1,4 @@
export { NodeDetailContent, NodeDetailPanel } from "./node-detail-panel";
export { NodeOverview } from "./node-overview";
export { NodeRelationships } from "./node-relationships";
export { NodeRemediation } from "./node-remediation";
@@ -0,0 +1,132 @@
"use client";
import { Button, Card, CardContent } from "@/components/shadcn";
import {
Sheet,
SheetContent,
SheetDescription,
SheetHeader,
SheetTitle,
} from "@/components/ui/sheet/sheet";
import type { GraphNode } from "@/types/attack-paths";
import { NodeFindings } from "./node-findings";
import { NodeOverview } from "./node-overview";
import { NodeResources } from "./node-resources";
interface NodeDetailPanelProps {
node: GraphNode | null;
allNodes?: GraphNode[];
onClose?: () => void;
}
/**
* Node details content component (reusable)
*/
export const NodeDetailContent = ({
node,
allNodes = [],
}: {
node: GraphNode;
allNodes?: GraphNode[];
}) => {
const isProwlerFinding = node.labels.some((label) =>
label.toLowerCase().includes("finding"),
);
return (
<div className="flex flex-col gap-6">
{/* Node Overview Section */}
<Card className="border-border-neutral-secondary">
<CardContent className="flex flex-col gap-3 p-4">
<h3 className="dark:text-prowler-theme-pale/90 text-sm font-semibold">
Node Overview
</h3>
<NodeOverview node={node} />
</CardContent>
</Card>
{/* Related Findings Section - Only show for non-Finding nodes */}
{!isProwlerFinding && (
<Card className="border-border-neutral-secondary">
<CardContent className="flex flex-col gap-3 p-4">
<h3 className="dark:text-prowler-theme-pale/90 text-sm font-semibold">
Related Findings
</h3>
<div className="text-text-neutral-secondary dark:text-text-neutral-secondary text-xs">
Findings connected to this node
</div>
<NodeFindings node={node} allNodes={allNodes} />
</CardContent>
</Card>
)}
{/* Affected Resources Section - Only show for Finding nodes */}
{isProwlerFinding && (
<Card className="border-border-neutral-secondary">
<CardContent className="flex flex-col gap-3 p-4">
<h3 className="dark:text-prowler-theme-pale/90 text-sm font-semibold">
Affected Resources
</h3>
<div className="text-text-neutral-secondary dark:text-text-neutral-secondary text-xs">
Resources affected by this finding
</div>
<NodeResources node={node} allNodes={allNodes} />
</CardContent>
</Card>
)}
</div>
);
};
/**
* Right-side sheet panel for node details
* Shows comprehensive information about selected graph node
* Uses shadcn Sheet component for sliding panel from right
*/
export const NodeDetailPanel = ({
node,
allNodes = [],
onClose,
}: NodeDetailPanelProps) => {
const isOpen = node !== null;
const isProwlerFinding = node?.labels.some((label) =>
label.toLowerCase().includes("finding"),
);
return (
<Sheet open={isOpen} onOpenChange={(open) => !open && onClose?.()}>
<SheetContent className="dark:bg-prowler-theme-midnight my-4 max-h-[calc(100vh-2rem)] max-w-[95vw] overflow-y-auto rounded-l-xl pt-10 md:my-8 md:max-h-[calc(100vh-4rem)] md:max-w-[55vw]">
<SheetHeader>
<div className="flex items-start justify-between gap-2">
<div className="flex-1">
<SheetTitle>Node Details</SheetTitle>
<SheetDescription>
{String(node?.properties?.name || node?.id.substring(0, 20))}
</SheetDescription>
</div>
{node && isProwlerFinding && (
<Button asChild variant="default" size="sm" className="mt-1">
<a
href={`/findings?id=${String(node.properties?.id || node.id)}`}
target="_blank"
rel="noopener noreferrer"
aria-label={`View finding ${String(node.properties?.id || node.id)}`}
>
View Finding
</a>
</Button>
)}
</div>
</SheetHeader>
{node && (
<div className="pt-6">
<NodeDetailContent node={node} allNodes={allNodes} />
</div>
)}
</SheetContent>
</Sheet>
);
};
@@ -0,0 +1,102 @@
"use client";
import { SeverityBadge } from "@/components/ui/table/severity-badge";
import type { GraphNode } from "@/types/attack-paths";
const SEVERITY_LEVELS = {
informational: "informational",
low: "low",
medium: "medium",
high: "high",
critical: "critical",
} as const;
type Severity = (typeof SEVERITY_LEVELS)[keyof typeof SEVERITY_LEVELS];
interface NodeFindingsProps {
node: GraphNode;
allNodes?: GraphNode[];
}
/**
* Node findings section showing related findings for the selected node
* Displays findings that are connected to the node via HAS_FINDING edges
*/
export const NodeFindings = ({ node, allNodes = [] }: NodeFindingsProps) => {
// Get finding IDs from the node's findings array (populated by adapter)
const findingIds = node.findings || [];
// Get the actual finding nodes
const findingNodes = allNodes.filter((n) => findingIds.includes(n.id));
if (findingNodes.length === 0) {
return null;
}
const normalizeSeverity = (
severity?: string | number | boolean | string[] | number[] | null,
): Severity => {
const sev = String(
Array.isArray(severity) ? severity[0] : severity || "",
).toLowerCase();
if (sev in SEVERITY_LEVELS) {
return sev as Severity;
}
return "informational";
};
return (
<ul className="flex flex-col gap-3">
{findingNodes.map((finding) => {
// Get the finding name (check_title preferred, then name)
const findingName = String(
finding.properties?.check_title ||
finding.properties?.name ||
finding.properties?.finding_id ||
"Unknown Finding",
);
// Use properties.id for display, fallback to graph node id
const findingId = String(finding.properties?.id || finding.id);
return (
<li
key={finding.id}
className="border-border-neutral-secondary rounded-lg border p-3"
>
<div className="flex items-start justify-between gap-2">
<div className="flex-1">
<div className="flex items-center gap-2">
{finding.properties?.severity && (
<SeverityBadge
severity={normalizeSeverity(finding.properties.severity)}
/>
)}
<h5 className="dark:text-prowler-theme-pale/90 text-sm font-medium">
{findingName}
</h5>
</div>
<p className="text-text-neutral-tertiary dark:text-text-neutral-tertiary mt-1 text-xs">
ID: {findingId}
</p>
</div>
<a
href={`/findings?id=${findingId}`}
target="_blank"
rel="noopener noreferrer"
aria-label={`View full finding for ${findingName}`}
className="text-text-info dark:text-text-info h-auto shrink-0 p-0 text-xs font-medium hover:underline"
>
View Full Finding
</a>
</div>
{finding.properties?.description && (
<div className="text-text-neutral-secondary dark:text-text-neutral-secondary mt-2 text-xs">
{String(finding.properties.description)}
</div>
)}
</li>
);
})}
</ul>
);
};
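`normalizeSeverity` accepts the loose property types coming off graph nodes (strings, arrays, numbers, `null`) and collapses them to a known severity, defaulting to `"informational"`. A standalone restatement (copied from the component above so the examples are self-contained):

```typescript
type Severity = "informational" | "low" | "medium" | "high" | "critical";

const SEVERITY_LEVELS: Record<Severity, Severity> = {
  informational: "informational",
  low: "low",
  medium: "medium",
  high: "high",
  critical: "critical",
};

// Restated from the NodeFindings component above.
function normalizeSeverity(
  severity?: string | number | boolean | string[] | number[] | null,
): Severity {
  const sev = String(
    Array.isArray(severity) ? severity[0] : severity || "",
  ).toLowerCase();
  return sev in SEVERITY_LEVELS ? (sev as Severity) : "informational";
}

// normalizeSeverity("HIGH")       → "high"
// normalizeSeverity(["critical"]) → "critical"
// normalizeSeverity(undefined)    → "informational"
// normalizeSeverity(42)           → "informational"
```

Unknown or unparseable values fall through to `"informational"` rather than throwing, which keeps the badge rendering resilient to odd graph data.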
@@ -0,0 +1,109 @@
"use client";
import { CodeSnippet } from "@/components/ui/code-snippet/code-snippet";
import { InfoField } from "@/components/ui/entities";
import { DateWithTime } from "@/components/ui/entities/date-with-time";
import type { GraphNode, GraphNodePropertyValue } from "@/types/attack-paths";
import { formatNodeLabels } from "../../_lib";
interface NodeOverviewProps {
node: GraphNode;
}
/**
* Node overview section showing basic node information
*/
export const NodeOverview = ({ node }: NodeOverviewProps) => {
const renderValue = (value: GraphNodePropertyValue) => {
if (value === null || value === undefined || value === "") {
return "-";
}
if (Array.isArray(value)) {
return value.join(", ");
}
return String(value);
};
const isFinding = node.labels.some((label) =>
label.toLowerCase().includes("finding"),
);
return (
<div className="flex flex-col gap-4">
<div className="grid grid-cols-1 gap-4 md:grid-cols-2">
<InfoField label="Type">{formatNodeLabels(node.labels)}</InfoField>
{isFinding && node.properties.check_title && (
<InfoField label="Check Title">
{String(node.properties.check_title)}
</InfoField>
)}
{isFinding && node.properties.id && (
<InfoField label="Finding ID" variant="simple">
<CodeSnippet value={String(node.properties.id)} />
</InfoField>
)}
</div>
{/* Display all properties */}
<div className="mt-4 border-t border-gray-200 pt-4 dark:border-gray-700">
<h4 className="dark:text-prowler-theme-pale/90 mb-3 text-sm font-semibold">
Properties
</h4>
<div className="grid grid-cols-1 gap-3 md:grid-cols-2">
{Object.entries(node.properties).map(([key, value]) => {
// Skip internal properties
if (key.startsWith("_")) {
return null;
}
// Skip check_title and id for findings as they're shown prominently above
if (isFinding && (key === "check_title" || key === "id")) {
return null;
}
// Format timestamp values; match "_at" as a suffix so keys like "state"
// or "path" (which merely contain "at") are not treated as timestamps
const isTimestamp =
key.includes("date") ||
key.includes("time") ||
key.endsWith("_at") ||
key.includes("seen");
return (
<InfoField key={key} label={formatPropertyName(key)}>
{isTimestamp && typeof value === "number" ? (
<DateWithTime
inline
dateTime={new Date(value).toISOString()}
/>
) : isTimestamp &&
typeof value === "string" &&
value.match(/^\d+$/) ? (
<DateWithTime
inline
dateTime={new Date(parseInt(value, 10)).toISOString()}
/>
) : typeof value === "object" ? (
<code className="text-xs">
{JSON.stringify(value).length > 50
? `${JSON.stringify(value).slice(0, 50)}...`
: JSON.stringify(value)}
</code>
) : (
renderValue(value)
)}
</InfoField>
);
})}
</div>
</div>
</div>
);
};
// Helper function to format property names
function formatPropertyName(name: string): string {
return name
.replace(/([A-Z])/g, " $1")
.replace(/_/g, " ")
.replace(/\b\w/g, (l) => l.toUpperCase())
.trim();
}
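// Example (illustrative): formatPropertyName("check_title") -> "Check Title",
// formatPropertyName("lastSeen") -> "Last Seen".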
@@ -0,0 +1,105 @@
"use client";
import { cn } from "@/lib/utils";
import type { GraphEdge } from "@/types/attack-paths";
interface NodeRelationshipsProps {
incomingEdges: GraphEdge[];
outgoingEdges: GraphEdge[];
}
/**
* Format edge type to human-readable label
* e.g., "HAS_FINDING" -> "Has Finding"
*/
function formatEdgeType(edgeType: string): string {
return edgeType
.split("_")
.map((word) => word.charAt(0).toUpperCase() + word.slice(1).toLowerCase())
.join(" ");
}
interface EdgeItemProps {
edge: GraphEdge;
isOutgoing: boolean;
}
/**
* Reusable edge item component
*/
function EdgeItem({ edge, isOutgoing }: EdgeItemProps) {
const targetId =
typeof edge.target === "string" ? edge.target : String(edge.target);
const sourceId =
typeof edge.source === "string" ? edge.source : String(edge.source);
const displayId = (isOutgoing ? targetId : sourceId).substring(0, 30);
return (
<div
className="border-border-neutral-tertiary dark:border-border-neutral-tertiary flex items-center justify-between rounded border p-2"
>
<code className="text-text-neutral-secondary dark:text-text-neutral-secondary text-xs">
{displayId}
</code>
<span
className={cn(
"rounded px-2 py-1 text-xs font-medium",
isOutgoing
? "bg-bg-data-info text-text-neutral-primary dark:text-text-neutral-primary"
: "bg-bg-pass-primary text-text-neutral-primary dark:text-text-neutral-primary",
)}
>
{formatEdgeType(edge.type)}
</span>
</div>
);
}
/**
* Node relationships section showing incoming and outgoing edges
*/
export const NodeRelationships = ({
incomingEdges,
outgoingEdges,
}: NodeRelationshipsProps) => {
return (
<div className="flex flex-col gap-6">
{/* Outgoing Relationships */}
<div>
<h4 className="dark:text-prowler-theme-pale/90 mb-3 text-sm font-semibold">
Outgoing Relationships ({outgoingEdges.length})
</h4>
{outgoingEdges.length > 0 ? (
<div className="space-y-2">
{outgoingEdges.map((edge) => (
<EdgeItem key={edge.id} edge={edge} isOutgoing />
))}
</div>
) : (
<p className="text-text-neutral-tertiary dark:text-text-neutral-tertiary text-xs">
No outgoing relationships
</p>
)}
</div>
{/* Incoming Relationships */}
<div className="border-border-neutral-tertiary dark:border-border-neutral-tertiary border-t pt-6">
<h4 className="dark:text-prowler-theme-pale/90 mb-3 text-sm font-semibold">
Incoming Relationships ({incomingEdges.length})
</h4>
{incomingEdges.length > 0 ? (
<div className="space-y-2">
{incomingEdges.map((edge) => (
<EdgeItem key={edge.id} edge={edge} isOutgoing={false} />
))}
</div>
) : (
<p className="text-text-neutral-tertiary dark:text-text-neutral-tertiary text-xs">
No incoming relationships
</p>
)}
</div>
</div>
);
};
@@ -0,0 +1,83 @@
"use client";
import Link from "next/link";
import { Badge } from "@/components/shadcn/badge/badge";
interface Finding {
id: string;
title: string;
severity: "critical" | "high" | "medium" | "low" | "info";
status: "PASS" | "FAIL" | "MANUAL";
}
interface NodeRemediationProps {
findings: Finding[];
}
/**
* Node remediation section showing related Prowler findings
*/
export const NodeRemediation = ({ findings }: NodeRemediationProps) => {
const getSeverityVariant = (severity: string) => {
switch (severity) {
case "critical":
return "destructive";
case "high":
return "default";
case "medium":
return "secondary";
case "low":
return "outline";
default:
return "default";
}
};
const getStatusVariant = (status: string) => {
if (status === "PASS") return "default";
if (status === "FAIL") return "destructive";
return "secondary";
};
return (
<div className="flex flex-col gap-3">
{findings.map((finding) => (
<div
key={finding.id}
className="rounded-lg border border-gray-200 p-3 dark:border-gray-700"
>
<div className="flex items-start justify-between gap-2">
<div className="flex-1">
<h5 className="dark:text-prowler-theme-pale/90 text-sm font-medium">
{finding.title}
</h5>
<p className="mt-1 text-xs text-gray-500 dark:text-gray-400">
ID: {finding.id.substring(0, 12)}...
</p>
</div>
<div className="flex flex-col gap-1">
<Badge variant={getSeverityVariant(finding.severity)}>
{finding.severity}
</Badge>
<Badge variant={getStatusVariant(finding.status)}>
{finding.status}
</Badge>
</div>
</div>
<div className="mt-2">
<Link
href={`/findings?id=${finding.id}`}
target="_blank"
rel="noopener noreferrer"
aria-label={`View full finding for ${finding.title}`}
className="text-text-info dark:text-text-info text-sm transition-all hover:opacity-80 dark:hover:opacity-80"
>
View Full Finding
</Link>
</div>
</div>
))}
</div>
);
};
@@ -0,0 +1,85 @@
"use client";
import { Badge } from "@/components/shadcn/badge/badge";
import { cn } from "@/lib/utils";
import type { GraphNode } from "@/types/attack-paths";
interface NodeResourcesProps {
node: GraphNode;
allNodes?: GraphNode[];
}
/**
* Node resources section showing affected resources for the selected finding node
* Displays resources that are connected to the finding node via HAS_FINDING edges
*/
export const NodeResources = ({ node, allNodes = [] }: NodeResourcesProps) => {
// Get resource IDs from the node's resources array (populated by adapter)
const resourceIds = node.resources || [];
// Get the actual resource nodes (Set lookup instead of repeated Array.includes)
const resourceIdSet = new Set(resourceIds);
const resourceNodes = allNodes.filter((n) => resourceIdSet.has(n.id));
if (resourceNodes.length === 0) {
return null;
}
const getResourceTypeColor = (labels: string[]): string => {
const label = (labels[0] || "").toLowerCase();
switch (label) {
case "s3bucket":
case "awsaccount":
case "ec2instance":
case "iamrole":
case "lambdafunction":
case "securitygroup":
return "bg-bg-data-aws";
default:
return "bg-bg-data-muted";
}
};
return (
<ul className="flex flex-col gap-3">
{resourceNodes.map((resource) => {
// Use properties.id for display, fallback to graph node id
const resourceId = String(resource.properties?.id || resource.id);
return (
<li
key={resource.id}
className="border-border-neutral-secondary rounded-lg border p-3"
>
<div className="flex items-start justify-between gap-2">
<div className="flex-1">
<div className="flex items-center gap-2">
{resource.labels && (
<Badge
className={cn(
getResourceTypeColor(resource.labels),
"text-text-neutral-primary",
)}
>
{resource.labels[0]}
</Badge>
)}
<h5 className="dark:text-prowler-theme-pale/90 text-sm font-medium">
{String(resource.properties?.name || resourceId)}
</h5>
</div>
<p className="text-text-neutral-tertiary dark:text-text-neutral-tertiary mt-1 text-xs">
ID: {resourceId}
</p>
</div>
</div>
{resource.properties?.arn && (
<div className="text-text-neutral-secondary dark:text-text-neutral-secondary mt-2 text-xs">
ARN: {String(resource.properties.arn)}
</div>
)}
</li>
);
})}
</ul>
);
};
@@ -0,0 +1,122 @@
"use client";
import { Controller, useFormContext } from "react-hook-form";
import type { AttackPathQuery } from "@/types/attack-paths";
interface QueryParametersFormProps {
selectedQuery: AttackPathQuery | null | undefined;
}
/**
* Dynamic form component for query parameters
* Renders form fields based on selected query's parameters
*/
export const QueryParametersForm = ({
selectedQuery,
}: QueryParametersFormProps) => {
const {
control,
formState: { errors },
} = useFormContext();
if (!selectedQuery || !selectedQuery.attributes.parameters.length) {
return (
<div className="rounded-lg bg-blue-50 p-4 dark:bg-blue-950/20">
<p className="text-sm text-blue-700 dark:text-blue-300">
This query requires no parameters. Click &quot;Execute Query&quot; to
proceed.
</p>
</div>
);
}
return (
<div className="flex flex-col gap-4">
<h3 className="dark:text-prowler-theme-pale/90 text-sm font-semibold">
Query Parameters
</h3>
{selectedQuery.attributes.parameters.map((param) => (
<Controller
key={param.name}
name={param.name}
control={control}
render={({ field }) => {
if (param.data_type === "boolean") {
return (
<div className="flex flex-col gap-2">
<label className="flex cursor-pointer items-center gap-3">
<input
type="checkbox"
id={param.name}
checked={field.value === true || field.value === "true"}
onChange={(e) => field.onChange(e.target.checked)}
aria-label={param.label}
className="border-border-neutral-secondary bg-bg-neutral-primary text-text-primary focus:ring-primary dark:border-border-neutral-secondary dark:bg-bg-neutral-primary dark:text-text-primary h-4 w-4 rounded border focus:ring-2"
/>
<div className="flex flex-col gap-1">
<span className="text-sm font-medium text-gray-900 dark:text-gray-100">
{param.label}
</span>
{param.description && (
<span className="text-xs text-gray-600 dark:text-gray-400">
{param.description}
</span>
)}
</div>
</label>
</div>
);
}
const errorMessage = (() => {
const error = errors[param.name];
if (error && typeof error.message === "string") {
return error.message;
}
return undefined;
})();
const descriptionId = `${param.name}-description`;
return (
<div className="flex flex-col gap-2">
<label
htmlFor={param.name}
className="text-sm font-medium text-gray-700 dark:text-gray-300"
>
{param.label}
{param.required && <span className="text-red-500"> *</span>}
</label>
<input
{...field}
id={param.name}
type={param.data_type === "number" ? "number" : "text"}
placeholder={
param.placeholder || `Enter ${param.label.toLowerCase()}`
}
value={field.value ?? ""}
aria-describedby={
param.description ? descriptionId : undefined
}
className="border-border-neutral-secondary bg-bg-neutral-primary text-text-neutral-primary placeholder-text-neutral-secondary focus:border-border-primary focus:ring-primary dark:border-border-neutral-secondary dark:bg-bg-neutral-primary dark:text-text-neutral-primary dark:placeholder-text-neutral-secondary dark:focus:border-border-primary rounded-md border px-3 py-2 text-sm focus:ring-1 focus:outline-none"
/>
{param.description && (
<span
id={descriptionId}
className="text-xs text-gray-600 dark:text-gray-400"
>
{param.description}
</span>
)}
{errorMessage && (
<span className="text-xs text-red-500">{errorMessage}</span>
)}
</div>
);
}}
/>
))}
</div>
);
};
@@ -0,0 +1,46 @@
"use client";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/shadcn";
import type { AttackPathQuery } from "@/types/attack-paths";
interface QuerySelectorProps {
queries: AttackPathQuery[];
selectedQueryId: string | null;
onQueryChange: (queryId: string) => void;
}
/**
* Query selector dropdown component
* Allows users to select from available Attack Paths queries
*/
export const QuerySelector = ({
queries,
selectedQueryId,
onQueryChange,
}: QuerySelectorProps) => {
return (
<Select value={selectedQueryId || ""} onValueChange={onQueryChange}>
<SelectTrigger className="w-full text-left">
<SelectValue placeholder="Choose a query..." />
</SelectTrigger>
<SelectContent>
{queries.map((query) => (
<SelectItem key={query.id} value={query.id}>
<div className="flex flex-col gap-1">
<span className="font-medium">{query.attributes.name}</span>
<span className="text-xs text-gray-500">
{query.attributes.description}
</span>
</div>
</SelectItem>
))}
</SelectContent>
</Select>
);
};
@@ -0,0 +1,350 @@
"use client";
import {
ChevronLeftIcon,
ChevronRightIcon,
DoubleArrowLeftIcon,
DoubleArrowRightIcon,
} from "@radix-ui/react-icons";
import Link from "next/link";
import { usePathname, useRouter, useSearchParams } from "next/navigation";
import { useState } from "react";
import { Button } from "@/components/shadcn/button/button";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/shadcn/select/select";
import { DateWithTime } from "@/components/ui/entities/date-with-time";
import { EntityInfo } from "@/components/ui/entities/entity-info";
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/ui/table";
import { cn } from "@/lib/utils";
import type { ProviderType } from "@/types";
import type { AttackPathScan } from "@/types/attack-paths";
import { SCAN_STATES } from "@/types/attack-paths";
import { ScanStatusBadge } from "./scan-status-badge";
interface ScanListTableProps {
scans: AttackPathScan[];
}
const TABLE_COLUMN_COUNT = 6;
const DEFAULT_PAGE_SIZE = 5;
const PAGE_SIZE_OPTIONS = [2, 5, 10, 15];
const baseLinkClass =
"relative block rounded border-0 bg-transparent px-3 py-1.5 text-button-primary outline-none transition-all duration-300 hover:bg-bg-neutral-tertiary hover:text-text-neutral-primary focus:shadow-none dark:hover:bg-bg-neutral-secondary dark:hover:text-text-neutral-primary";
const disabledLinkClass =
"text-border-neutral-secondary dark:text-border-neutral-secondary hover:bg-transparent hover:text-border-neutral-secondary dark:hover:text-border-neutral-secondary cursor-default pointer-events-none";
/**
* Table displaying AWS account Attack Paths scans
* Shows scan metadata and allows selection of completed scans
*/
export const ScanListTable = ({ scans }: ScanListTableProps) => {
const pathname = usePathname();
const searchParams = useSearchParams();
const router = useRouter();
const selectedScanId = searchParams.get("scanId");
const currentPage = parseInt(searchParams.get("scanPage") ?? "1", 10);
const pageSize = parseInt(
searchParams.get("scanPageSize") ?? String(DEFAULT_PAGE_SIZE),
10,
);
const [selectedPageSize, setSelectedPageSize] = useState(String(pageSize));
const totalPages = Math.ceil(scans.length / pageSize);
const startIndex = (currentPage - 1) * pageSize;
const endIndex = startIndex + pageSize;
const paginatedScans = scans.slice(startIndex, endIndex);
const handleSelectScan = (scanId: string) => {
const params = new URLSearchParams(searchParams);
params.set("scanId", scanId);
router.push(`${pathname}?${params.toString()}`);
};
const isSelectDisabled = (scan: AttackPathScan) => {
return (
scan.attributes.state !== SCAN_STATES.COMPLETED ||
selectedScanId === scan.id
);
};
const getSelectButtonLabel = (scan: AttackPathScan) => {
if (selectedScanId === scan.id) {
return "Selected";
}
if (scan.attributes.state === SCAN_STATES.SCHEDULED) {
return "Scheduled";
}
if (scan.attributes.state === SCAN_STATES.EXECUTING) {
return "Waiting...";
}
if (scan.attributes.state === SCAN_STATES.FAILED) {
return "Failed";
}
return "Select";
};
const createPageUrl = (pageNumber: number | string) => {
// Copying searchParams preserves existing params such as scanId
const params = new URLSearchParams(searchParams);
if (+pageNumber > totalPages) {
return `${pathname}?${params.toString()}`;
}
params.set("scanPage", pageNumber.toString());
return `${pathname}?${params.toString()}`;
};
const isFirstPage = currentPage === 1;
const isLastPage = currentPage === totalPages;
return (
<>
<div className="minimal-scrollbar rounded-large shadow-small border-border-neutral-secondary bg-bg-neutral-secondary relative z-0 flex w-full flex-col gap-4 overflow-auto border p-4">
<Table aria-label="Attack Paths scans table listing provider accounts, scan dates, status, progress, and duration">
<TableHeader>
<TableRow>
<TableHead>Provider / Account</TableHead>
<TableHead>Last Scan Date</TableHead>
<TableHead>Status</TableHead>
<TableHead>Progress</TableHead>
<TableHead>Duration</TableHead>
<TableHead className="text-right">Action</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{scans.length === 0 ? (
<TableRow>
<TableCell
colSpan={TABLE_COLUMN_COUNT}
className="h-24 text-center"
>
No Attack Paths scans available.
</TableCell>
</TableRow>
) : (
paginatedScans.map((scan) => {
const isDisabled = isSelectDisabled(scan);
const isSelected = selectedScanId === scan.id;
const duration = scan.attributes.duration
? `${Math.floor(scan.attributes.duration / 60)}m ${scan.attributes.duration % 60}s`
: "-";
return (
<TableRow
key={scan.id}
className={
isSelected
? "bg-button-primary/10 dark:bg-button-primary/10"
: ""
}
>
<TableCell className="font-medium">
<EntityInfo
cloudProvider={
scan.attributes.provider_type as ProviderType
}
entityAlias={scan.attributes.provider_alias}
entityId={scan.attributes.provider_uid}
/>
</TableCell>
<TableCell>
{scan.attributes.completed_at ? (
<DateWithTime
inline
dateTime={scan.attributes.completed_at}
/>
) : (
"-"
)}
</TableCell>
<TableCell>
<ScanStatusBadge
status={scan.attributes.state}
progress={scan.attributes.progress}
/>
</TableCell>
<TableCell>
<span className="text-sm">
{scan.attributes.progress}%
</span>
</TableCell>
<TableCell>
<span className="text-sm">{duration}</span>
</TableCell>
<TableCell className="text-right">
<Button
type="button"
aria-label="Select scan"
disabled={isDisabled}
variant={isDisabled ? "secondary" : "default"}
onClick={() => handleSelectScan(scan.id)}
className="w-full max-w-24"
>
{getSelectButtonLabel(scan)}
</Button>
</TableCell>
</TableRow>
);
})
)}
</TableBody>
</Table>
{/* Pagination Controls */}
{scans.length > 0 && (
<div className="flex w-full flex-col-reverse items-center justify-between gap-4 overflow-auto p-1 sm:flex-row sm:gap-8">
<div className="text-sm whitespace-nowrap">
{scans.length} scans in total
</div>
{scans.length > DEFAULT_PAGE_SIZE && (
<div className="flex flex-col-reverse items-center gap-4 sm:flex-row sm:gap-6 lg:gap-8">
{/* Rows per page selector */}
<div className="flex items-center gap-2">
<p className="text-sm font-medium whitespace-nowrap">
Rows per page
</p>
<Select
value={selectedPageSize}
onValueChange={(value) => {
setSelectedPageSize(value);
// Copying searchParams preserves existing params such as scanId
const params = new URLSearchParams(searchParams);
params.set("scanPageSize", value);
params.set("scanPage", "1");
router.push(`${pathname}?${params.toString()}`);
}}
>
<SelectTrigger className="h-8 w-18">
<SelectValue />
</SelectTrigger>
<SelectContent side="top">
{PAGE_SIZE_OPTIONS.map((size) => (
<SelectItem
key={size}
value={`${size}`}
className="cursor-pointer"
>
{size}
</SelectItem>
))}
</SelectContent>
</Select>
</div>
<div className="flex items-center justify-center text-sm font-medium">
Page {currentPage} of {totalPages}
</div>
<div className="flex items-center gap-2">
<Link
aria-label="Go to first page"
className={cn(
baseLinkClass,
isFirstPage && disabledLinkClass,
)}
href={
isFirstPage
? pathname + "?" + searchParams.toString()
: createPageUrl(1)
}
aria-disabled={isFirstPage}
onClick={(e) => isFirstPage && e.preventDefault()}
>
<DoubleArrowLeftIcon
className="size-4"
aria-hidden="true"
/>
</Link>
<Link
aria-label="Go to previous page"
className={cn(
baseLinkClass,
isFirstPage && disabledLinkClass,
)}
href={
isFirstPage
? pathname + "?" + searchParams.toString()
: createPageUrl(currentPage - 1)
}
aria-disabled={isFirstPage}
onClick={(e) => isFirstPage && e.preventDefault()}
>
<ChevronLeftIcon className="size-4" aria-hidden="true" />
</Link>
<Link
aria-label="Go to next page"
className={cn(
baseLinkClass,
isLastPage && disabledLinkClass,
)}
href={
isLastPage
? pathname + "?" + searchParams.toString()
: createPageUrl(currentPage + 1)
}
aria-disabled={isLastPage}
onClick={(e) => isLastPage && e.preventDefault()}
>
<ChevronRightIcon className="size-4" aria-hidden="true" />
</Link>
<Link
aria-label="Go to last page"
className={cn(
baseLinkClass,
isLastPage && disabledLinkClass,
)}
href={
isLastPage
? pathname + "?" + searchParams.toString()
: createPageUrl(totalPages)
}
aria-disabled={isLastPage}
onClick={(e) => isLastPage && e.preventDefault()}
>
<DoubleArrowRightIcon
className="size-4"
aria-hidden="true"
/>
</Link>
</div>
</div>
)}
</div>
)}
</div>
<p className="text-text-neutral-secondary dark:text-text-neutral-secondary mt-6 text-xs">
Only Attack Paths scans with &quot;Completed&quot; status can be
selected. Scans in progress will update automatically.
</p>
</>
);
};
@@ -0,0 +1,59 @@
"use client";
import { Loader2 } from "lucide-react";
import { Badge } from "@/components/shadcn/badge/badge";
import type { ScanState } from "@/types/attack-paths";
interface ScanStatusBadgeProps {
status: ScanState;
progress?: number;
}
/**
* Status badge for attack path scan status
* Shows visual indicator and text for scan progress
*/
export const ScanStatusBadge = ({
status,
progress = 0,
}: ScanStatusBadgeProps) => {
if (status === "scheduled") {
return (
<Badge className="bg-bg-neutral-tertiary text-text-neutral-primary gap-2">
<span>Scheduled</span>
</Badge>
);
}
if (status === "available") {
return (
<Badge className="bg-bg-neutral-tertiary text-text-neutral-primary gap-2">
<span>Queued</span>
</Badge>
);
}
if (status === "executing") {
return (
<Badge className="bg-bg-warning-secondary text-text-neutral-primary gap-2">
<Loader2 size={14} className="animate-spin" />
<span>In Progress ({progress}%)</span>
</Badge>
);
}
if (status === "completed") {
return (
<Badge className="bg-bg-pass-secondary text-text-success-primary gap-2">
<span>Completed</span>
</Badge>
);
}
return (
<Badge className="bg-bg-fail-secondary text-text-error-primary gap-2">
<span>Failed</span>
</Badge>
);
};
@@ -0,0 +1,3 @@
export { useGraphState } from "./use-graph-state";
export { useQueryBuilder } from "./use-query-builder";
export { useWizardState } from "./use-wizard-state";
@@ -0,0 +1,276 @@
"use client";
import { create } from "zustand";
import type {
AttackPathGraphData,
GraphNode,
GraphState,
} from "@/types/attack-paths";
interface FilteredViewState {
isFilteredView: boolean;
filteredNodeId: string | null;
fullData: AttackPathGraphData | null; // Original data before filtering
}
interface GraphStore extends GraphState, FilteredViewState {
setGraphData: (data: AttackPathGraphData) => void;
setSelectedNodeId: (nodeId: string | null) => void;
setLoading: (loading: boolean) => void;
setError: (error: string | null) => void;
setZoom: (zoomLevel: number) => void;
setPan: (panX: number, panY: number) => void;
setFilteredView: (
isFiltered: boolean,
nodeId: string | null,
filteredData: AttackPathGraphData | null,
fullData: AttackPathGraphData | null,
) => void;
reset: () => void;
}
const initialState: GraphState & FilteredViewState = {
data: null,
selectedNodeId: null,
loading: false,
error: null,
zoomLevel: 1,
panX: 0,
panY: 0,
isFilteredView: false,
filteredNodeId: null,
fullData: null,
};
const useGraphStore = create<GraphStore>((set) => ({
...initialState,
setGraphData: (data) =>
set({
data,
fullData: null,
error: null,
isFilteredView: false,
filteredNodeId: null,
}),
setSelectedNodeId: (nodeId) => set({ selectedNodeId: nodeId }),
setLoading: (loading) => set({ loading }),
setError: (error) => set({ error }),
setZoom: (zoomLevel) => set({ zoomLevel }),
setPan: (panX, panY) => set({ panX, panY }),
setFilteredView: (isFiltered, nodeId, filteredData, fullData) =>
set({
isFilteredView: isFiltered,
filteredNodeId: nodeId,
data: filteredData,
fullData,
selectedNodeId: nodeId,
}),
reset: () => set(initialState),
}));
/**
* Helper to get edge source/target ID from string or object
*/
function getEdgeNodeId(nodeRef: string | object): string {
if (typeof nodeRef === "string") {
return nodeRef;
}
return (nodeRef as GraphNode).id;
}
/**
* Compute a filtered subgraph containing only the path through the target node.
* This follows the directed graph structure of attack paths:
* - Upstream: traces back to the root (AWS Account)
* - Downstream: traces forward to leaf nodes
* - Also includes findings connected to the selected node
*/
function computeFilteredSubgraph(
fullData: AttackPathGraphData,
targetNodeId: string,
): AttackPathGraphData {
const nodes = fullData.nodes;
const edges = fullData.edges || [];
// Build directed adjacency lists
const forwardEdges = new Map<string, Set<string>>(); // source -> targets
const backwardEdges = new Map<string, Set<string>>(); // target -> sources
nodes.forEach((node) => {
forwardEdges.set(node.id, new Set());
backwardEdges.set(node.id, new Set());
});
edges.forEach((edge) => {
const sourceId = getEdgeNodeId(edge.source);
const targetId = getEdgeNodeId(edge.target);
forwardEdges.get(sourceId)?.add(targetId);
backwardEdges.get(targetId)?.add(sourceId);
});
const visibleNodeIds = new Set<string>();
visibleNodeIds.add(targetNodeId);
// Traverse upstream (backward) - find all ancestors
const traverseUpstream = (nodeId: string) => {
const sources = backwardEdges.get(nodeId);
if (sources) {
sources.forEach((sourceId) => {
if (!visibleNodeIds.has(sourceId)) {
visibleNodeIds.add(sourceId);
traverseUpstream(sourceId);
}
});
}
};
// Traverse downstream (forward) - find all descendants
const traverseDownstream = (nodeId: string) => {
const targets = forwardEdges.get(nodeId);
if (targets) {
targets.forEach((targetId) => {
if (!visibleNodeIds.has(targetId)) {
visibleNodeIds.add(targetId);
traverseDownstream(targetId);
}
});
}
};
// Start traversal from the target node
traverseUpstream(targetNodeId);
traverseDownstream(targetNodeId);
// Also include findings directly connected to the selected node.
// A node lookup map avoids an O(n) nodes.find per edge.
const nodeById = new Map(nodes.map((n) => [n.id, n]));
const isFindingNode = (n: GraphNode | undefined) =>
n?.labels.some((l) => l.toLowerCase().includes("finding")) ?? false;
edges.forEach((edge) => {
const sourceId = getEdgeNodeId(edge.source);
const targetId = getEdgeNodeId(edge.target);
// Include findings connected to the selected node
if (sourceId === targetNodeId && isFindingNode(nodeById.get(targetId))) {
visibleNodeIds.add(targetId);
}
if (targetId === targetNodeId && isFindingNode(nodeById.get(sourceId))) {
visibleNodeIds.add(sourceId);
}
});
// Filter nodes and edges to only include visible ones
const filteredNodes = nodes.filter((node) => visibleNodeIds.has(node.id));
const filteredEdges = edges.filter((edge) => {
const sourceId = getEdgeNodeId(edge.source);
const targetId = getEdgeNodeId(edge.target);
return visibleNodeIds.has(sourceId) && visibleNodeIds.has(targetId);
});
return {
nodes: filteredNodes,
edges: filteredEdges,
};
}
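// Example (illustrative): for a chain Account -> Role -> Policy with a
// Finding attached to Role, computeFilteredSubgraph(data, roleId) keeps all
// four nodes: ancestors (Account), descendants (Policy), and the Finding
// connected directly to the selected node, plus the edges between them.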
/**
* Custom hook for managing graph visualization state
* Handles graph data, node selection, zoom/pan, loading states, and filtered view
*/
export const useGraphState = () => {
const store = useGraphStore();
// Zustand store methods are stable, no need to memoize
const updateGraphData = (data: AttackPathGraphData) => {
store.setGraphData(data);
};
const selectNode = (nodeId: string | null) => {
store.setSelectedNodeId(nodeId);
};
const getSelectedNode = (): GraphNode | null => {
if (!store.data?.nodes || !store.selectedNodeId) return null;
return (
store.data.nodes.find((node) => node.id === store.selectedNodeId) || null
);
};
const startLoading = () => {
store.setLoading(true);
};
const stopLoading = () => {
store.setLoading(false);
};
const setError = (error: string | null) => {
store.setError(error);
};
const updateZoomAndPan = (zoomLevel: number, panX: number, panY: number) => {
store.setZoom(zoomLevel);
store.setPan(panX, panY);
};
const resetGraph = () => {
store.reset();
};
const clearGraph = () => {
store.setGraphData({ nodes: [], edges: [] });
store.setSelectedNodeId(null);
store.setFilteredView(false, null, null, null);
};
/**
* Enter filtered view mode - redraws graph with only the selected path
* Stores full data so we can restore it when exiting filtered view
*/
const enterFilteredView = (nodeId: string) => {
if (!store.data) return;
// Use fullData if we're already in filtered view, otherwise use current data
const sourceData = store.fullData || store.data;
const filteredData = computeFilteredSubgraph(sourceData, nodeId);
store.setFilteredView(true, nodeId, filteredData, sourceData);
};
/**
* Exit filtered view mode - restore full graph data
*/
const exitFilteredView = () => {
if (!store.isFilteredView || !store.fullData) return;
store.setFilteredView(false, null, store.fullData, null);
};
/**
* Get the node that was used to filter the view
*/
const getFilteredNode = (): GraphNode | null => {
if (!store.isFilteredView || !store.filteredNodeId) return null;
// Look in fullData since that's where the original node data is
const sourceData = store.fullData || store.data;
if (!sourceData) return null;
return (
sourceData.nodes.find((node) => node.id === store.filteredNodeId) || null
);
};
return {
data: store.data,
fullData: store.fullData,
selectedNodeId: store.selectedNodeId,
selectedNode: getSelectedNode(),
loading: store.loading,
error: store.error,
zoomLevel: store.zoomLevel,
panX: store.panX,
panY: store.panY,
isFilteredView: store.isFilteredView,
filteredNodeId: store.filteredNodeId,
filteredNode: getFilteredNode(),
updateGraphData,
selectNode,
startLoading,
stopLoading,
setError,
updateZoomAndPan,
resetGraph,
clearGraph,
enterFilteredView,
exitFilteredView,
};
};
@@ -0,0 +1,98 @@
"use client";
import { zodResolver } from "@hookform/resolvers/zod";
import { useEffect, useState } from "react";
import { useForm } from "react-hook-form";
import { z } from "zod";
import type { AttackPathQuery } from "@/types/attack-paths";
/**
* Custom hook for managing query builder form state
* Handles query selection, parameter validation, and form submission
*/
export const useQueryBuilder = (availableQueries: AttackPathQuery[]) => {
const [selectedQuery, setSelectedQuery] = useState<string | null>(null);
// Generate dynamic Zod schema based on selected query parameters
const getValidationSchema = (queryId: string | null) => {
const schemaObject: Record<string, z.ZodTypeAny> = {};
if (queryId) {
const query = availableQueries.find((q) => q.id === queryId);
if (query) {
query.attributes.parameters.forEach((param) => {
let fieldSchema: z.ZodTypeAny = z
.string()
.min(1, `${param.label} is required`);
if (param.data_type === "number") {
fieldSchema = z.coerce
.number()
.min(0, `${param.label} must be a non-negative number`);
} else if (param.data_type === "boolean") {
fieldSchema = z.boolean().default(false);
}
schemaObject[param.name] = fieldSchema;
});
}
}
return z.object(schemaObject);
};
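// Example (illustrative, hypothetical parameter names): a query with a
// required text parameter "role_name" and a numeric "max_depth" yields
// roughly z.object({ role_name: z.string().min(1, "..."),
// max_depth: z.coerce.number()... }).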
const getDefaultValues = (queryId: string | null) => {
const defaults: Record<string, unknown> = {};
const query = availableQueries.find((q) => q.id === queryId);
if (query) {
query.attributes.parameters.forEach((param) => {
defaults[param.name] = param.data_type === "boolean" ? false : "";
});
}
return defaults;
};
const form = useForm({
resolver: zodResolver(getValidationSchema(selectedQuery)),
mode: "onChange",
defaultValues: getDefaultValues(selectedQuery),
});
// Update form when selectedQuery changes
useEffect(() => {
form.reset(getDefaultValues(selectedQuery), {
keepDirtyValues: false,
});
}, [selectedQuery]); // eslint-disable-line react-hooks/exhaustive-deps
const selectedQueryData = availableQueries.find(
(q) => q.id === selectedQuery,
);
const handleQueryChange = (queryId: string) => {
setSelectedQuery(queryId);
form.reset();
};
const getQueryParameters = () => {
return form.getValues();
};
const isFormValid = () => {
return form.formState.isValid;
};
return {
selectedQuery,
selectedQueryData,
availableQueries,
form,
handleQueryChange,
getQueryParameters,
isFormValid,
};
};
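The per-parameter schema building above can be mimicked without zod to see the shape of the dynamic validation. A hypothetical, dependency-free sketch (the `Param` type and messages mirror the hook's rules; names are illustrative, not part of the codebase):

```typescript
// Hypothetical, zod-free sketch of the same idea: derive one validator
// per declared parameter, then check form values against them.
type Param = {
  name: string;
  label: string;
  data_type: "string" | "number" | "boolean";
};

function buildValidators(params: Param[]) {
  const validators: Record<string, (v: unknown) => string | null> = {};
  for (const p of params) {
    validators[p.name] = (v) => {
      if (p.data_type === "boolean") return null; // defaults to false, always valid
      if (p.data_type === "number") {
        // Mirrors z.coerce.number().refine(v => v >= 0)
        const n = Number(v);
        return Number.isFinite(n) && n >= 0
          ? null
          : `${p.label} must be a non-negative number`;
      }
      // Mirrors z.string().min(1)
      return typeof v === "string" && v.length > 0
        ? null
        : `${p.label} is required`;
    };
  }
  return validators;
}

const validators = buildValidators([
  { name: "maxHops", label: "Max hops", data_type: "number" },
  { name: "roleName", label: "Role name", data_type: "string" },
]);
console.log(validators.maxHops(-1)); // "Max hops must be a non-negative number"
console.log(validators.roleName("AdminRole")); // null (valid)
```

The real hook gets the same behavior for free from zod's coercion and refinement; the sketch only makes the per-`data_type` branching explicit.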
@@ -0,0 +1,91 @@
"use client";
import { useRouter } from "next/navigation";
import { useCallback } from "react";
import { create } from "zustand";
import type { WizardState } from "@/types/attack-paths";
interface WizardStore extends WizardState {
setCurrentStep: (step: 1 | 2) => void;
setSelectedScanId: (scanId: string) => void;
setSelectedQuery: (queryId: string) => void;
setQueryParameters: (
parameters: Record<string, string | number | boolean>,
) => void;
reset: () => void;
}
const initialState: WizardState = {
currentStep: 1,
selectedScanId: null,
selectedQuery: null,
queryParameters: {},
};
const useWizardStore = create<WizardStore>((set) => ({
...initialState,
setCurrentStep: (step) => set({ currentStep: step }),
setSelectedScanId: (scanId) => set({ selectedScanId: scanId }),
setSelectedQuery: (queryId) => set({ selectedQuery: queryId }),
setQueryParameters: (parameters) => set({ queryParameters: parameters }),
reset: () => set(initialState),
}));
/**
* Custom hook for managing Attack Paths wizard state
* Handles step navigation, scan selection, and query configuration
*/
export const useWizardState = () => {
const router = useRouter();
const store = useWizardStore();
// Derive current step from URL path (client-only; SSR falls back to step 1,
// so the first client render after hydration may correct it)
const currentStep: 1 | 2 =
typeof window !== "undefined"
? window.location.pathname.includes("query-builder")
? 2
: 1
: 1;
const goToSelectScan = useCallback(() => {
store.setCurrentStep(1);
router.push("/attack-paths/select-scan");
}, [router, store]);
const goToQueryBuilder = useCallback(
(scanId: string) => {
store.setSelectedScanId(scanId);
store.setCurrentStep(2);
router.push(`/attack-paths/query-builder?scanId=${scanId}`);
},
[router, store],
);
const updateQueryParameters = useCallback(
(parameters: Record<string, string | number | boolean>) => {
store.setQueryParameters(parameters);
},
[store],
);
const getScanIdFromUrl = useCallback(() => {
const params = new URLSearchParams(
typeof window !== "undefined" ? window.location.search : "",
);
return params.get("scanId") || store.selectedScanId;
}, [store.selectedScanId]);
return {
currentStep,
selectedScanId: store.selectedScanId || getScanIdFromUrl(),
selectedQuery: store.selectedQuery,
queryParameters: store.queryParameters,
goToSelectScan,
goToQueryBuilder,
setSelectedQuery: store.setSelectedQuery,
updateQueryParameters,
reset: store.reset,
};
};
@@ -0,0 +1,145 @@
/**
* Export utilities for attack path graphs
* Handles exporting graph visualization to various formats
*/
/**
* Helper function to download a blob as a file
* @param blob The blob to download
* @param filename The name of the file
*/
const downloadBlob = (blob: Blob, filename: string) => {
const url = URL.createObjectURL(blob);
const link = document.createElement("a");
link.href = url;
link.download = filename;
document.body.appendChild(link);
link.click();
document.body.removeChild(link);
URL.revokeObjectURL(url);
};
/**
* Export graph as SVG image
* @param svgElement The SVG element to export
* @param filename The name of the file to download
*/
export const exportGraphAsSVG = (
svgElement: SVGSVGElement | null,
filename: string = "attack-path-graph.svg",
) => {
if (!svgElement) return;
try {
// Clone the SVG element to avoid modifying the original
const clonedSvg = svgElement.cloneNode(true) as SVGSVGElement;
// Find the main container group (first g element with transform)
const containerGroup = clonedSvg.querySelector("g");
if (!containerGroup) {
throw new Error("Could not find graph container");
}
// Get the bounding box of the actual graph content
// We need to get it from the original SVG since cloned elements don't have computed geometry
const originalContainer = svgElement.querySelector("g");
if (!originalContainer) {
throw new Error("Could not find original graph container");
}
const bbox = originalContainer.getBBox();
// Add padding around the content
const padding = 50;
const contentWidth = bbox.width + padding * 2;
const contentHeight = bbox.height + padding * 2;
// Set the SVG dimensions to fit the content
clonedSvg.setAttribute("width", `${contentWidth}`);
clonedSvg.setAttribute("height", `${contentHeight}`);
clonedSvg.setAttribute(
"viewBox",
`${bbox.x - padding} ${bbox.y - padding} ${contentWidth} ${contentHeight}`,
);
// Remove the zoom transform from the container - the viewBox now handles positioning
containerGroup.removeAttribute("transform");
// Add an opaque background matching the app's dark theme
const bgRect = document.createElementNS(
"http://www.w3.org/2000/svg",
"rect",
);
bgRect.setAttribute("x", `${bbox.x - padding}`);
bgRect.setAttribute("y", `${bbox.y - padding}`);
bgRect.setAttribute("width", `${contentWidth}`);
bgRect.setAttribute("height", `${contentHeight}`);
bgRect.setAttribute("fill", "#1c1917"); // Dark background matching the app
clonedSvg.insertBefore(bgRect, clonedSvg.firstChild);
const svgData = new XMLSerializer().serializeToString(clonedSvg);
const blob = new Blob([svgData], { type: "image/svg+xml" });
downloadBlob(blob, filename);
} catch (error) {
console.error("Failed to export graph as SVG:", error);
throw new Error("Failed to export graph");
}
};
/**
* Export graph as PNG image
* @param svgElement The SVG element to export
* @param filename The name of the file to download
*/
export const exportGraphAsPNG = async (
svgElement: SVGSVGElement | null,
filename: string = "attack-path-graph.png",
) => {
if (!svgElement) return;
try {
const svgData = new XMLSerializer().serializeToString(svgElement);
const canvas = document.createElement("canvas");
const ctx = canvas.getContext("2d");
if (!ctx) throw new Error("Could not get canvas context");
const svg = new Image();
svg.onload = () => {
// Fall back to the rendered size if the image reports no intrinsic size
canvas.width = svg.width || svgElement.clientWidth;
canvas.height = svg.height || svgElement.clientHeight;
ctx.drawImage(svg, 0, 0);
canvas.toBlob((blob) => {
if (blob) {
downloadBlob(blob, filename);
}
});
};
svg.onerror = () => {
// A throw here would escape the enclosing try/catch; report it instead
console.error("Failed to load SVG for PNG conversion");
};
// btoa only handles Latin-1; encode first so non-ASCII labels don't throw
svg.src = `data:image/svg+xml;base64,${btoa(unescape(encodeURIComponent(svgData)))}`;
} catch (error) {
console.error("Failed to export graph as PNG:", error);
throw new Error("Failed to export graph");
}
};
/**
* Export graph data as JSON
* @param graphData The graph data to export
* @param filename The name of the file to download
*/
export const exportGraphAsJSON = (
graphData: Record<string, unknown>,
filename: string = "attack-path-graph.json",
) => {
try {
const jsonString = JSON.stringify(graphData, null, 2);
const blob = new Blob([jsonString], { type: "application/json" });
downloadBlob(blob, filename);
} catch (error) {
console.error("Failed to export graph as JSON:", error);
throw new Error("Failed to export graph");
}
};
@@ -0,0 +1,25 @@
/**
* Formatting utilities for attack path graph nodes
*/
/**
* Format camelCase labels to space-separated text
* e.g., "ProwlerFinding" -> "Prowler Finding", "AWSAccount" -> "Aws Account"
*/
export function formatNodeLabel(label: string): string {
return label
.replace(/([A-Z]+)([A-Z][a-z])/g, "$1 $2")
.replace(/([a-z\d])([A-Z])/g, "$1 $2")
.trim()
.split(" ")
.map((word) => word.charAt(0).toUpperCase() + word.slice(1).toLowerCase())
.join(" ");
}
/**
* Format multiple node labels into a readable string
* e.g., ["ProwlerFinding"] -> "Prowler Finding"
*/
export function formatNodeLabels(labels: string[]): string {
return labels.map(formatNodeLabel).join(", ");
}
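The acronym handling in `formatNodeLabel` comes entirely from the order of the two regex passes. A standalone restatement for a quick check (not new behavior, just the helper above made runnable on its own):

```typescript
// Standalone restatement of formatNodeLabel: split acronym boundaries
// first, then lower/digit-to-upper edges, then re-case each word.
function formatNodeLabel(label: string): string {
  return label
    .replace(/([A-Z]+)([A-Z][a-z])/g, "$1 $2") // "AWSAccount" -> "AWS Account"
    .replace(/([a-z\d])([A-Z])/g, "$1 $2") // "ProwlerFinding" -> "Prowler Finding"
    .trim()
    .split(" ")
    .map((w) => w.charAt(0).toUpperCase() + w.slice(1).toLowerCase())
    .join(" ");
}

console.log(formatNodeLabel("ProwlerFinding")); // "Prowler Finding"
console.log(formatNodeLabel("AWSAccount")); // "Aws Account" (acronym casing is lost)
console.log(formatNodeLabel("EC2Instance")); // "Ec2 Instance"
```

Note the trade-off the docstring already hints at: the final lowercasing pass normalizes words, so all-caps acronyms come out title-cased ("Aws", "Ec2").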
@@ -0,0 +1,143 @@
/**
* Color constants for attack path graph visualization
* Colors chosen to work well in both light and dark themes
*/
/**
* Node fill colors - darker versions of design system severity colors
* Darkened to ensure white text has proper contrast (WCAG AA)
*/
export const GRAPH_NODE_COLORS = {
// Finding severities - darkened versions for white text readability
critical: "#cc0055", // Darker pink (from #ff006a)
high: "#c45a3a", // Darker coral (from #f77852)
medium: "#b8860b", // Dark goldenrod (from #fec94d)
low: "#8b9a3e", // Olive/dark yellow-green (from #fdfbd4)
info: "#2563eb", // Darker blue (from #3c8dff)
// Node types
prowlerFinding: "#ea580c",
awsAccount: "#f59e0b", // Amber 500 - AWS orange
attackPattern: "#16a34a",
summary: "#16a34a",
// Infrastructure
ec2Instance: "#0891b2", // Cyan 600
s3Bucket: "#0284c7", // Sky 600
iamRole: "#7c3aed", // Violet 600
iamPolicy: "#7c3aed",
lambdaFunction: "#d97706", // Amber 600
securityGroup: "#0891b2",
default: "#0891b2",
} as const;
/**
* Node border colors - using original design system colors as borders (lighter than fill)
*/
export const GRAPH_NODE_BORDER_COLORS = {
critical: "#ff006a", // Original --bg-data-critical
high: "#f77852", // Original --bg-data-high
medium: "#fec94d", // Original --bg-data-medium
low: "#c4d4a0", // Lighter olive
info: "#3c8dff", // Original --bg-data-info
prowlerFinding: "#fb923c",
awsAccount: "#fbbf24", // Amber 400
attackPattern: "#4ade80",
summary: "#4ade80",
ec2Instance: "#22d3ee", // Cyan 400
s3Bucket: "#38bdf8", // Sky 400
iamRole: "#a78bfa", // Violet 400
iamPolicy: "#a78bfa",
lambdaFunction: "#fbbf24",
securityGroup: "#22d3ee",
default: "#22d3ee",
} as const;
export const GRAPH_EDGE_COLOR = "#ffffff"; // White (default)
export const GRAPH_EDGE_HIGHLIGHT_COLOR = "#f97316"; // Orange 500 (on hover)
export const GRAPH_EDGE_GLOW_COLOR = "#fb923c";
export const GRAPH_SELECTION_COLOR = "#ffffff";
export const GRAPH_BORDER_COLOR = "#374151";
export const GRAPH_ALERT_BORDER_COLOR = "#ef4444"; // Red 500 - for resources with findings
/**
* Get node fill color based on labels and properties
*/
export const getNodeColor = (
labels: string[],
properties?: Record<string, unknown>,
): string => {
const isFinding = labels.some((l) => l.toLowerCase().includes("finding"));
const isPrivilegeEscalation = labels.some((l) => l === "PrivilegeEscalation");
if ((isFinding || isPrivilegeEscalation) && properties?.severity) {
const severity = String(properties.severity).toLowerCase();
if (severity === "critical") return GRAPH_NODE_COLORS.critical;
if (severity === "high") return GRAPH_NODE_COLORS.high;
if (severity === "medium") return GRAPH_NODE_COLORS.medium;
if (severity === "low") return GRAPH_NODE_COLORS.low;
if (severity === "informational" || severity === "info")
return GRAPH_NODE_COLORS.info;
return GRAPH_NODE_COLORS.prowlerFinding;
}
if (labels.some((l) => l.toLowerCase().includes("attackpattern")))
return GRAPH_NODE_COLORS.attackPattern;
if (labels.includes("AWSAccount")) return GRAPH_NODE_COLORS.awsAccount;
if (labels.includes("EC2Instance")) return GRAPH_NODE_COLORS.ec2Instance;
if (labels.includes("S3Bucket")) return GRAPH_NODE_COLORS.s3Bucket;
if (labels.includes("IAMRole")) return GRAPH_NODE_COLORS.iamRole;
if (labels.includes("IAMPolicy")) return GRAPH_NODE_COLORS.iamPolicy;
if (labels.includes("LambdaFunction"))
return GRAPH_NODE_COLORS.lambdaFunction;
if (labels.includes("SecurityGroup")) return GRAPH_NODE_COLORS.securityGroup;
return GRAPH_NODE_COLORS.default;
};
/**
* Get node border color based on labels and properties
*/
export const getNodeBorderColor = (
labels: string[],
properties?: Record<string, unknown>,
): string => {
const isFinding = labels.some((l) => l.toLowerCase().includes("finding"));
const isPrivilegeEscalation = labels.some((l) => l === "PrivilegeEscalation");
if ((isFinding || isPrivilegeEscalation) && properties?.severity) {
const severity = String(properties.severity).toLowerCase();
if (severity === "critical") return GRAPH_NODE_BORDER_COLORS.critical;
if (severity === "high") return GRAPH_NODE_BORDER_COLORS.high;
if (severity === "medium") return GRAPH_NODE_BORDER_COLORS.medium;
if (severity === "low") return GRAPH_NODE_BORDER_COLORS.low;
if (severity === "informational" || severity === "info")
return GRAPH_NODE_BORDER_COLORS.info;
return GRAPH_NODE_BORDER_COLORS.prowlerFinding;
}
if (labels.some((l) => l.toLowerCase().includes("attackpattern")))
return GRAPH_NODE_BORDER_COLORS.attackPattern;
if (labels.includes("AWSAccount")) return GRAPH_NODE_BORDER_COLORS.awsAccount;
if (labels.includes("EC2Instance"))
return GRAPH_NODE_BORDER_COLORS.ec2Instance;
if (labels.includes("S3Bucket")) return GRAPH_NODE_BORDER_COLORS.s3Bucket;
if (labels.includes("IAMRole")) return GRAPH_NODE_BORDER_COLORS.iamRole;
if (labels.includes("IAMPolicy")) return GRAPH_NODE_BORDER_COLORS.iamPolicy;
if (labels.includes("LambdaFunction"))
return GRAPH_NODE_BORDER_COLORS.lambdaFunction;
if (labels.includes("SecurityGroup"))
return GRAPH_NODE_BORDER_COLORS.securityGroup;
return GRAPH_NODE_BORDER_COLORS.default;
};
/**
* Check if a background color is light (for determining text color)
*/
export const isLightBackground = (backgroundColor: string): boolean => {
const hex = backgroundColor.replace("#", "");
const r = parseInt(hex.substring(0, 2), 16);
const g = parseInt(hex.substring(2, 4), 16);
const b = parseInt(hex.substring(4, 6), 16);
const luminance = (0.299 * r + 0.587 * g + 0.114 * b) / 255;
return luminance > 0.5;
};
@@ -0,0 +1,69 @@
/**
* Utility functions for attack path graph operations
*/
/**
* Find edges in the path from a given node.
* Upstream: follows only ONE parent path (first parent at each level) to avoid lighting up siblings
* Downstream: follows ALL children recursively
*
* Uses pre-built adjacency maps for O(1) lookups instead of O(n) array searches per traversal step.
*
* @param nodeId - The starting node ID
* @param edges - Array of edges with sourceId and targetId
* @returns Set of edge IDs in the format "sourceId-targetId"
*/
export const getPathEdges = (
nodeId: string,
edges: Array<{ sourceId: string; targetId: string }>,
): Set<string> => {
// Build adjacency maps once - O(n)
const parentMap = new Map<string, { sourceId: string; targetId: string }>();
const childrenMap = new Map<
string,
Array<{ sourceId: string; targetId: string }>
>();
edges.forEach((edge) => {
// First parent only (matches original behavior of find())
if (!parentMap.has(edge.targetId)) {
parentMap.set(edge.targetId, edge);
}
const children = childrenMap.get(edge.sourceId) || [];
children.push(edge);
childrenMap.set(edge.sourceId, children);
});
const pathEdgeIds = new Set<string>();
const visitedNodes = new Set<string>();
// Traverse upstream - only follow ONE parent at each level (first found)
// This creates a single path to the root, not all paths
const traverseUpstream = (currentNodeId: string) => {
if (visitedNodes.has(`up-${currentNodeId}`)) return;
visitedNodes.add(`up-${currentNodeId}`);
const parentEdge = parentMap.get(currentNodeId); // O(1) lookup
if (parentEdge) {
pathEdgeIds.add(`${parentEdge.sourceId}-${parentEdge.targetId}`);
traverseUpstream(parentEdge.sourceId);
}
};
// Traverse downstream (find ALL targets from this node)
const traverseDownstream = (currentNodeId: string) => {
if (visitedNodes.has(`down-${currentNodeId}`)) return;
visitedNodes.add(`down-${currentNodeId}`);
const children = childrenMap.get(currentNodeId) || []; // O(1) lookup
children.forEach((edge) => {
pathEdgeIds.add(`${edge.sourceId}-${edge.targetId}`);
traverseDownstream(edge.targetId);
});
};
traverseUpstream(nodeId);
traverseDownstream(nodeId);
return pathEdgeIds;
};
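On a small graph, the asymmetry of the traversal (single parent chain upstream, full fan-out downstream) is easy to see. A condensed restatement of `getPathEdges` with hypothetical node IDs:

```typescript
type Edge = { sourceId: string; targetId: string };

// Condensed restatement of getPathEdges: adjacency maps built once,
// one parent chain upstream, all children downstream.
function getPathEdges(nodeId: string, edges: Edge[]): Set<string> {
  const parentMap = new Map<string, Edge>();
  const childrenMap = new Map<string, Edge[]>();
  for (const edge of edges) {
    // First parent only, matching the original find() semantics
    if (!parentMap.has(edge.targetId)) parentMap.set(edge.targetId, edge);
    const children = childrenMap.get(edge.sourceId) ?? [];
    children.push(edge);
    childrenMap.set(edge.sourceId, children);
  }
  const pathEdgeIds = new Set<string>();
  const visited = new Set<string>();
  const up = (id: string) => {
    if (visited.has(`up-${id}`)) return;
    visited.add(`up-${id}`);
    const parent = parentMap.get(id);
    if (parent) {
      pathEdgeIds.add(`${parent.sourceId}-${parent.targetId}`);
      up(parent.sourceId);
    }
  };
  const down = (id: string) => {
    if (visited.has(`down-${id}`)) return;
    visited.add(`down-${id}`);
    for (const edge of childrenMap.get(id) ?? []) {
      pathEdgeIds.add(`${edge.sourceId}-${edge.targetId}`);
      down(edge.targetId);
    }
  };
  up(nodeId);
  down(nodeId);
  return pathEdgeIds;
}

// Hypothetical graph: two principals reach "role", which fans out to two findings.
const edges: Edge[] = [
  { sourceId: "userA", targetId: "role" },
  { sourceId: "userB", targetId: "role" },
  { sourceId: "role", targetId: "finding1" },
  { sourceId: "role", targetId: "finding2" },
];

// Upstream keeps only the first parent (userA); downstream keeps both findings.
console.log([...getPathEdges("role", edges)].sort());
// ["role-finding1", "role-finding2", "userA-role"]
```

Note that `userB-role` is deliberately excluded: highlighting every upstream path would light up sibling principals that share a node with the selected one.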
@@ -0,0 +1,17 @@
export {
exportGraphAsJSON,
exportGraphAsPNG,
exportGraphAsSVG,
} from "./export";
export { formatNodeLabel, formatNodeLabels } from "./format";
export { getPathEdges } from "./graph-utils";
export {
getNodeBorderColor,
getNodeColor,
GRAPH_ALERT_BORDER_COLOR,
GRAPH_EDGE_COLOR,
GRAPH_EDGE_HIGHLIGHT_COLOR,
GRAPH_NODE_BORDER_COLORS,
GRAPH_NODE_COLORS,
GRAPH_SELECTION_COLOR,
} from "./graph-colors";
@@ -0,0 +1,626 @@
"use client";
import { ArrowLeft, Maximize2, X } from "lucide-react";
import { useSearchParams } from "next/navigation";
import { Suspense, useCallback, useEffect, useRef, useState } from "react";
import { FormProvider } from "react-hook-form";
import {
executeQuery,
getAttackPathScans,
getAvailableQueries,
} from "@/actions/attack-paths";
import { adaptQueryResultToGraphData } from "@/actions/attack-paths/query-result.adapter";
import { AutoRefresh } from "@/components/scans";
import { Button, Card, CardContent } from "@/components/shadcn";
import {
Dialog,
DialogContent,
DialogHeader,
DialogTitle,
DialogTrigger,
useToast,
} from "@/components/ui";
import type {
AttackPathQuery,
AttackPathScan,
GraphNode,
} from "@/types/attack-paths";
import {
AttackPathGraph,
ExecuteButton,
GraphControls,
GraphLegend,
GraphLoading,
NodeDetailContent,
QueryParametersForm,
QuerySelector,
ScanListTable,
} from "./_components";
import type { AttackPathGraphRef } from "./_components/graph/attack-path-graph";
import { useGraphState } from "./_hooks/use-graph-state";
import { useQueryBuilder } from "./_hooks/use-query-builder";
import { exportGraphAsSVG, formatNodeLabel } from "./_lib";
/**
* Attack Paths Analysis
* Allows users to select a scan, build a query, and visualize the Attack Paths graph
*/
export default function AttackPathAnalysisPage() {
const searchParams = useSearchParams();
const scanId = searchParams.get("scanId");
const graphState = useGraphState();
const { toast } = useToast();
const [scansLoading, setScansLoading] = useState(true);
const [scans, setScans] = useState<AttackPathScan[]>([]);
const [queriesLoading, setQueriesLoading] = useState(true);
const [queriesError, setQueriesError] = useState<string | null>(null);
const [isFullscreenOpen, setIsFullscreenOpen] = useState(false);
const graphRef = useRef<AttackPathGraphRef>(null);
const fullscreenGraphRef = useRef<AttackPathGraphRef>(null);
const hasResetRef = useRef(false);
const nodeDetailsRef = useRef<HTMLDivElement>(null);
const graphContainerRef = useRef<HTMLDivElement>(null);
const [queries, setQueries] = useState<AttackPathQuery[]>([]);
// Use custom hook for query builder form state and validation
const queryBuilder = useQueryBuilder(queries);
// Reset graph state when component mounts
useEffect(() => {
if (!hasResetRef.current) {
hasResetRef.current = true;
graphState.resetGraph();
}
}, [graphState]);
// Load available scans on mount
useEffect(() => {
const loadScans = async () => {
setScansLoading(true);
try {
const scansData = await getAttackPathScans();
if (scansData?.data) {
setScans(scansData.data);
} else {
setScans([]);
}
} catch (error) {
console.error("Failed to load scans:", error);
setScans([]);
} finally {
setScansLoading(false);
}
};
loadScans();
}, []);
// Check if there's an executing scan for auto-refresh
const hasExecutingScan = scans.some(
(scan) =>
scan.attributes.state === "executing" ||
scan.attributes.state === "scheduled",
);
// Callback to refresh scans (used by AutoRefresh component)
const refreshScans = useCallback(async () => {
try {
const scansData = await getAttackPathScans();
if (scansData?.data) {
setScans(scansData.data);
}
} catch (error) {
console.error("Failed to refresh scans:", error);
}
}, []);
// Load available queries on mount
useEffect(() => {
const loadQueries = async () => {
if (!scanId) {
setQueriesError("No scan selected");
setQueriesLoading(false);
return;
}
setQueriesLoading(true);
try {
const queriesData = await getAvailableQueries(scanId);
if (queriesData?.data) {
setQueries(queriesData.data);
setQueriesError(null);
} else {
setQueriesError("Failed to load available queries");
toast({
title: "Error",
description: "Failed to load queries for this scan",
variant: "destructive",
});
}
} catch (error) {
const errorMsg =
error instanceof Error ? error.message : "Unknown error";
setQueriesError(errorMsg);
toast({
title: "Error",
description: "Failed to load queries",
variant: "destructive",
});
} finally {
setQueriesLoading(false);
}
};
loadQueries();
}, [scanId, toast]);
const handleQueryChange = (queryId: string) => {
queryBuilder.handleQueryChange(queryId);
};
const showErrorToast = (title: string, description: string) => {
toast({
title,
description,
variant: "destructive",
});
};
const handleExecuteQuery = async () => {
if (!scanId || !queryBuilder.selectedQuery) {
showErrorToast("Error", "Please select both a scan and a query");
return;
}
// Validate form before executing query
const isValid = await queryBuilder.form.trigger();
if (!isValid) {
showErrorToast(
"Validation Error",
"Please fill in all required parameters",
);
return;
}
graphState.startLoading();
graphState.setError(null);
try {
const parameters = queryBuilder.getQueryParameters() as Record<
string,
string | number | boolean
>;
const result = await executeQuery(
scanId,
queryBuilder.selectedQuery,
parameters,
);
if (result?.data?.attributes) {
const graphData = adaptQueryResultToGraphData(result.data.attributes);
graphState.updateGraphData(graphData);
toast({
title: "Success",
description: "Query executed successfully",
variant: "default",
});
// Scroll to graph after successful query execution
setTimeout(() => {
graphContainerRef.current?.scrollIntoView({
behavior: "smooth",
block: "start",
});
}, 100);
} else {
graphState.resetGraph();
graphState.setError("No data returned from query");
showErrorToast("Error", "Query returned no data");
}
} catch (error) {
const errorMsg =
error instanceof Error ? error.message : "Failed to execute query";
graphState.resetGraph();
graphState.setError(errorMsg);
showErrorToast("Error", errorMsg);
} finally {
graphState.stopLoading();
}
};
const handleNodeClick = (node: GraphNode) => {
// Enter filtered view showing only paths containing this node
graphState.enterFilteredView(node.id);
// For findings, also scroll to the details section
const isFinding = node.labels.some((label) =>
label.toLowerCase().includes("finding"),
);
if (isFinding) {
setTimeout(() => {
nodeDetailsRef.current?.scrollIntoView({
behavior: "smooth",
block: "nearest",
});
}, 100);
}
};
const handleBackToFullView = () => {
graphState.exitFilteredView();
};
const handleCloseDetails = () => {
graphState.selectNode(null);
};
const handleGraphExport = (svgElement: SVGSVGElement | null) => {
try {
if (svgElement) {
exportGraphAsSVG(svgElement, "attack-path-graph.svg");
toast({
title: "Success",
description: "Graph exported as SVG",
variant: "default",
});
} else {
throw new Error("Could not find graph element");
}
} catch (error) {
toast({
title: "Error",
description:
error instanceof Error ? error.message : "Failed to export graph",
variant: "destructive",
});
}
};
return (
<div className="flex flex-col gap-6">
{/* Auto-refresh scans when there's an executing scan */}
<AutoRefresh
hasExecutingScan={hasExecutingScan}
onRefresh={refreshScans}
/>
{/* Header */}
<div>
<h2 className="dark:text-prowler-theme-pale/90 text-xl font-semibold">
Attack Paths Analysis
</h2>
<p className="text-text-neutral-secondary dark:text-text-neutral-secondary mt-2 text-sm">
Select a scan, build a query, and visualize Attack Paths in your
infrastructure.
</p>
</div>
{/* Top Section - Scans Table and Query Builder (2 columns) */}
<div className="grid grid-cols-1 gap-8 xl:grid-cols-2">
{/* Scans Table Section - Left Column */}
<div>
{scansLoading ? (
<div className="minimal-scrollbar rounded-large shadow-small border-border-neutral-secondary bg-bg-neutral-secondary relative z-0 flex w-full flex-col gap-4 overflow-auto border p-4">
<p className="text-sm">Loading scans...</p>
</div>
) : scans.length === 0 ? (
<div className="minimal-scrollbar rounded-large shadow-small border-border-neutral-secondary bg-bg-neutral-secondary relative z-0 flex w-full flex-col gap-4 overflow-auto border p-4">
<p className="text-sm">No scans available</p>
</div>
) : (
<Suspense fallback={<div>Loading scans...</div>}>
<ScanListTable scans={scans} />
</Suspense>
)}
</div>
{/* Query Builder Section - Right Column */}
<div className="minimal-scrollbar rounded-large shadow-small border-border-neutral-secondary bg-bg-neutral-secondary relative z-0 flex w-full flex-col gap-4 overflow-auto border p-4">
{!scanId ? (
<p className="text-text-info dark:text-text-info text-sm">
Select a scan from the table on the left to begin.
</p>
) : queriesLoading ? (
<p className="text-sm">Loading queries...</p>
) : queriesError ? (
<p className="text-text-danger dark:text-text-danger text-sm">
{queriesError}
</p>
) : (
<>
<FormProvider {...queryBuilder.form}>
<QuerySelector
queries={queries}
selectedQueryId={queryBuilder.selectedQuery}
onQueryChange={handleQueryChange}
/>
{queryBuilder.selectedQuery && (
<QueryParametersForm
selectedQuery={queryBuilder.selectedQueryData}
/>
)}
</FormProvider>
<div className="flex gap-3">
<ExecuteButton
isLoading={graphState.loading}
isDisabled={!queryBuilder.selectedQuery}
onExecute={handleExecuteQuery}
/>
</div>
{graphState.error && (
<div className="bg-bg-danger-secondary text-text-danger dark:bg-bg-danger-secondary dark:text-text-danger rounded p-3 text-sm">
{graphState.error}
</div>
)}
</>
)}
</div>
</div>
{/* Bottom Section - Graph Visualization (Full Width) */}
<div className="minimal-scrollbar rounded-large shadow-small border-border-neutral-secondary bg-bg-neutral-secondary relative z-0 flex w-full flex-col gap-4 overflow-auto border p-4">
{graphState.loading ? (
<GraphLoading />
) : graphState.data &&
graphState.data.nodes &&
graphState.data.nodes.length > 0 ? (
<>
{/* Info message and controls */}
<div className="flex flex-col gap-3 sm:flex-row sm:items-center sm:justify-between">
{graphState.isFilteredView ? (
<div className="flex items-center gap-3">
<Button
onClick={handleBackToFullView}
variant="outline"
size="sm"
className="gap-2"
aria-label="Return to full graph view"
>
<ArrowLeft size={16} />
Back to Full View
</Button>
<div
className="bg-bg-info-secondary text-text-info inline-flex cursor-default items-center gap-2 rounded-md px-3 py-2 text-xs font-medium shadow-sm sm:px-4 sm:text-sm"
role="status"
aria-label="Filtered view active"
>
<span className="flex-shrink-0" aria-hidden="true">
🔍
</span>
<span className="flex-1">
Showing paths for:{" "}
<strong>
{graphState.filteredNode?.properties?.name ||
graphState.filteredNode?.properties?.id ||
"Selected node"}
</strong>
</span>
</div>
</div>
) : (
<div
className="bg-button-primary inline-flex cursor-default items-center gap-2 rounded-md px-3 py-2 text-xs font-medium text-black shadow-sm sm:px-4 sm:text-sm"
role="status"
aria-label="Graph interaction instructions"
>
<span className="flex-shrink-0" aria-hidden="true">
💡
</span>
<span className="flex-1">
Click on any node to filter and view its connected paths
</span>
</div>
)}
{/* Graph controls and fullscreen button together */}
<div className="flex items-center gap-2">
<GraphControls
onZoomIn={() => graphRef.current?.zoomIn()}
onZoomOut={() => graphRef.current?.zoomOut()}
onFitToScreen={() => graphRef.current?.resetZoom()}
onExport={() =>
handleGraphExport(graphRef.current?.getSVGElement() || null)
}
/>
{/* Fullscreen button */}
<div className="border-border-neutral-primary bg-bg-neutral-tertiary flex gap-1 rounded-lg border p-1">
<Dialog
open={isFullscreenOpen}
onOpenChange={setIsFullscreenOpen}
>
<DialogTrigger asChild>
<Button
variant="ghost"
size="sm"
className="h-8 w-8 p-0"
aria-label="Fullscreen"
>
<Maximize2 size={18} />
</Button>
</DialogTrigger>
<DialogContent className="flex h-full max-h-screen w-full max-w-full flex-col gap-0 p-0">
<DialogHeader className="px-4 pt-4 sm:px-6 sm:pt-6">
<DialogTitle className="text-lg">
Graph Fullscreen View
</DialogTitle>
</DialogHeader>
<div className="px-4 pt-4 pb-4 sm:px-6 sm:pt-6">
<GraphControls
onZoomIn={() => fullscreenGraphRef.current?.zoomIn()}
onZoomOut={() =>
fullscreenGraphRef.current?.zoomOut()
}
onFitToScreen={() =>
fullscreenGraphRef.current?.resetZoom()
}
onExport={() =>
handleGraphExport(
fullscreenGraphRef.current?.getSVGElement() ||
null,
)
}
/>
</div>
<div className="flex flex-1 gap-4 overflow-hidden px-4 pb-4 sm:px-6 sm:pb-6">
<div className="flex flex-1 items-center justify-center">
<AttackPathGraph
ref={fullscreenGraphRef}
data={graphState.data}
onNodeClick={handleNodeClick}
selectedNodeId={graphState.selectedNodeId}
isFilteredView={graphState.isFilteredView}
/>
</div>
{/* Node Detail Panel - Side by side */}
{graphState.selectedNode && (
<section aria-labelledby="node-details-heading">
<Card className="w-96 overflow-y-auto">
<CardContent className="p-4">
<div className="mb-4 flex items-center justify-between">
<h3
id="node-details-heading"
className="text-sm font-semibold"
>
Node Details
</h3>
<Button
onClick={handleCloseDetails}
variant="ghost"
size="sm"
className="h-6 w-6 p-0"
aria-label="Close node details"
>
<X size={16} />
</Button>
</div>
<p className="text-text-neutral-secondary dark:text-text-neutral-secondary mb-4 text-xs">
{graphState.selectedNode?.labels.some(
(label) =>
label.toLowerCase().includes("finding"),
)
? graphState.selectedNode?.properties
?.check_title ||
graphState.selectedNode?.properties?.id ||
"Unknown Finding"
: graphState.selectedNode?.properties
?.name ||
graphState.selectedNode?.properties?.id ||
"Unknown Resource"}
</p>
<div className="flex flex-col gap-4">
<div>
<h4 className="mb-2 text-xs font-semibold">
Type
</h4>
<p className="text-text-neutral-secondary dark:text-text-neutral-secondary text-xs">
{graphState.selectedNode?.labels
.map(formatNodeLabel)
.join(", ")}
</p>
</div>
</div>
</CardContent>
</Card>
</section>
)}
</div>
</DialogContent>
</Dialog>
</div>
</div>
</div>
{/* Graph in the middle */}
<div ref={graphContainerRef} className="h-[calc(100vh-22rem)]">
<AttackPathGraph
ref={graphRef}
data={graphState.data}
onNodeClick={handleNodeClick}
selectedNodeId={graphState.selectedNodeId}
isFilteredView={graphState.isFilteredView}
/>
</div>
{/* Legend below */}
<div className="hidden justify-center lg:flex">
<GraphLegend data={graphState.data} />
</div>
</>
) : (
<div className="flex flex-1 items-center justify-center text-center">
<p className="text-text-neutral-secondary dark:text-text-neutral-secondary text-sm">
Select a query and click &quot;Execute Query&quot; to visualize
the Attack Paths graph
</p>
</div>
)}
</div>
{/* Node Detail Panel - Below Graph */}
{graphState.selectedNode && graphState.data && (
<div
ref={nodeDetailsRef}
className="minimal-scrollbar rounded-large shadow-small border-border-neutral-secondary bg-bg-neutral-secondary relative z-0 flex w-full flex-col gap-4 overflow-auto border p-4"
>
<div className="flex items-center justify-between">
<div className="flex-1">
<h3 className="text-lg font-semibold">Node Details</h3>
<p className="text-text-neutral-secondary dark:text-text-neutral-secondary mt-1 text-sm">
{String(
graphState.selectedNode.labels.some((label) =>
label.toLowerCase().includes("finding"),
)
? graphState.selectedNode.properties?.check_title ||
graphState.selectedNode.properties?.id ||
"Unknown Finding"
: graphState.selectedNode.properties?.name ||
graphState.selectedNode.properties?.id ||
"Unknown Resource",
)}
</p>
</div>
<div className="flex items-center gap-2">
{graphState.selectedNode.labels.some((label) =>
label.toLowerCase().includes("finding"),
) && (
<Button asChild variant="default" size="sm">
<a
href={`/findings?id=${String(graphState.selectedNode.properties?.id || graphState.selectedNode.id)}`}
target="_blank"
rel="noopener noreferrer"
aria-label={`View finding ${String(graphState.selectedNode.properties?.id || graphState.selectedNode.id)}`}
>
View Finding
</a>
</Button>
)}
<Button
onClick={handleCloseDetails}
variant="ghost"
size="sm"
className="h-8 w-8 p-0"
aria-label="Close node details"
>
<X size={16} />
</Button>
</div>
</div>
<NodeDetailContent
node={graphState.selectedNode}
allNodes={graphState.data.nodes}
/>
</div>
)}
</div>
);
}
@@ -0,0 +1,9 @@
import { redirect } from "next/navigation";
/**
* Landing page for Attack Paths feature
* Redirects to the integrated attack path analysis view
*/
export default function AttackPathsPage() {
redirect("/attack-paths/query-builder");
}
@@ -2,6 +2,7 @@ import { Spacer } from "@heroui/spacer";
import { Suspense } from "react";
import {
getFindingById,
getFindings,
getLatestFindings,
getLatestMetadataInfo,
@@ -9,6 +10,7 @@ import {
} from "@/actions/findings";
import { getProviders } from "@/actions/providers";
import { getScans } from "@/actions/scans";
import { FindingDetailsSheet } from "@/components/findings";
import { FindingsFilters } from "@/components/findings/findings-filters";
import {
ColumnFindings,
@@ -43,15 +45,79 @@ export default async function Findings({
// Check if the searchParams contain any date or scan filter
const hasDateOrScan = hasDateOrScanFilter(resolvedSearchParams);
const [metadataInfoData, providersData, scansData] = await Promise.all([
(hasDateOrScan ? getMetadataInfo : getLatestMetadataInfo)({
query,
sort: encodedSort,
filters,
}),
getProviders({ pageSize: 50 }),
getScans({ pageSize: 50 }),
]);
// Check if there's a specific finding ID to fetch
const findingId = resolvedSearchParams.id?.toString();
const [metadataInfoData, providersData, scansData, findingByIdData] =
await Promise.all([
(hasDateOrScan ? getMetadataInfo : getLatestMetadataInfo)({
query,
sort: encodedSort,
filters,
}),
getProviders({ pageSize: 50 }),
getScans({ pageSize: 50 }),
findingId
? getFindingById(findingId, "resources,scan.provider")
: Promise.resolve(null),
]);
// Process the finding data to match the expected structure
const processedFinding = findingByIdData?.data
? (() => {
const finding = findingByIdData.data;
const included = findingByIdData.included || [];
// Build dictionaries from included data
type IncludedItem = {
type: string;
id: string;
attributes: Record<string, unknown>;
relationships?: {
provider?: { data?: { id: string } };
};
};
const resourceDict: Record<string, unknown> = {};
const scanDict: Record<string, IncludedItem> = {};
const providerDict: Record<string, unknown> = {};
included.forEach((item: IncludedItem) => {
if (item.type === "resources") {
resourceDict[item.id] = {
id: item.id,
attributes: item.attributes,
};
} else if (item.type === "scans") {
scanDict[item.id] = item;
} else if (item.type === "providers") {
providerDict[item.id] = {
id: item.id,
attributes: item.attributes,
};
}
});
const scanId = finding.relationships?.scan?.data?.id;
const resourceId = finding.relationships?.resources?.data?.[0]?.id;
const scan = scanId ? scanDict[scanId] : undefined;
const providerId = scan?.relationships?.provider?.data?.id;
const resource = resourceId ? resourceDict[resourceId] : undefined;
const provider = providerId ? providerDict[providerId] : undefined;
return {
...finding,
relationships: {
scan: scan
? { data: scan, attributes: scan.attributes }
: undefined,
resource: resource,
provider: provider,
},
} as FindingProps;
})()
: null;
// Extract unique regions and services from the new endpoint
const uniqueRegions = metadataInfoData?.data?.attributes?.regions || [];
@@ -98,6 +164,7 @@ export default async function Findings({
<Suspense key={searchParamsKey} fallback={<SkeletonTableFindings />}>
<SSRDataTable searchParams={resolvedSearchParams} />
</Suspense>
{processedFinding && <FindingDetailsSheet finding={processedFinding} />}
</ContentLayout>
);
}
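The finding-processing block above resolves JSON:API relationships by first indexing the `included` array into per-type dictionaries, then following relationship ids through them. A reduced sketch of that indexing step, with illustrative shapes rather than the real API types:

```typescript
// Illustrative JSON:API included-item shape.
type IncludedItem = {
  type: string;
  id: string;
  attributes: Record<string, unknown>;
};

// Index included items by type, then by id, so each relationship lookup
// is a constant-time dictionary access instead of an array scan.
function indexIncluded(included: IncludedItem[]) {
  const byType: Record<string, Record<string, IncludedItem>> = {};
  for (const item of included) {
    (byType[item.type] ??= {})[item.id] = item;
  }
  return byType;
}

const dicts = indexIncluded([
  { type: "scans", id: "s1", attributes: { name: "daily" } },
  { type: "resources", id: "r1", attributes: { arn: "arn:aws:s3:::demo" } },
]);
console.log(dicts["scans"]["s1"].attributes.name); // "daily"
```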
@@ -0,0 +1,46 @@
"use client";
import { usePathname, useRouter, useSearchParams } from "next/navigation";
import {
Sheet,
SheetContent,
SheetDescription,
SheetHeader,
SheetTitle,
} from "@/components/ui/sheet";
import { FindingProps } from "@/types/components";
import { FindingDetail } from "./table/finding-detail";
interface FindingDetailsSheetProps {
finding: FindingProps;
}
export const FindingDetailsSheet = ({ finding }: FindingDetailsSheetProps) => {
const router = useRouter();
const pathname = usePathname();
const searchParams = useSearchParams();
const handleOpenChange = (open: boolean) => {
if (!open) {
const params = new URLSearchParams(searchParams.toString());
params.delete("id");
router.push(`${pathname}?${params.toString()}`, { scroll: false });
}
};
return (
<Sheet open={true} onOpenChange={handleOpenChange}>
<SheetContent className="my-4 max-h-[calc(100vh-2rem)] max-w-[95vw] overflow-y-auto pt-10 md:my-8 md:max-h-[calc(100vh-4rem)] md:max-w-[55vw]">
<SheetHeader>
<SheetTitle className="sr-only">Finding Details</SheetTitle>
<SheetDescription className="sr-only">
View the finding details
</SheetDescription>
</SheetHeader>
<FindingDetail findingDetails={finding} />
</SheetContent>
</Sheet>
);
};
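The sheet's open state is driven entirely by the `id` search param: closing it rewrites the URL without that param, and the next render no longer mounts the sheet. A small sketch of that URL manipulation (`urlWithoutId` is a hypothetical helper, not part of the diff):

```typescript
// Removing the "id" search param is what dismisses the sheet above.
function urlWithoutId(pathname: string, search: string): string {
  const params = new URLSearchParams(search);
  params.delete("id");
  const qs = params.toString();
  return qs ? `${pathname}?${qs}` : pathname;
}

console.log(urlWithoutId("/findings", "id=abc&status=FAIL"));
// "/findings?status=FAIL"
```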
+1
@@ -1 +1,2 @@
export * from "./finding-details-sheet";
export * from "./muted";
@@ -2,7 +2,7 @@
import { ColumnDef } from "@tanstack/react-table";
import { Database } from "lucide-react";
import { useSearchParams } from "next/navigation";
import { usePathname, useRouter, useSearchParams } from "next/navigation";
import { DataTableRowDetails } from "@/components/findings/table";
import { DataTableRowActions } from "@/components/findings/table/data-table-row-actions";
@@ -51,13 +51,18 @@ const getProviderData = (
);
};
const FindingDetailsCell = ({ row }: { row: any }) => {
const FindingDetailsCell = ({ row }: { row: { original: FindingProps } }) => {
const router = useRouter();
const pathname = usePathname();
const searchParams = useSearchParams();
const findingId = searchParams.get("id");
const isOpen = findingId === row.original.id;
const findingIdFromUrl = searchParams.get("id");
// If there's an id in the URL, the sheet is controlled by FindingDetailsSheet component
// so we don't open a local sheet for any row
const isUrlControlled = !!findingIdFromUrl;
const handleOpenChange = (open: boolean) => {
const params = new URLSearchParams(searchParams);
const params = new URLSearchParams(searchParams.toString());
if (open) {
params.set("id", row.original.id);
@@ -65,7 +70,7 @@ const FindingDetailsCell = ({ row }: { row: any }) => {
params.delete("id");
}
window.history.pushState({}, "", `?${params.toString()}`);
router.push(`${pathname}?${params.toString()}`, { scroll: false });
};
return (
@@ -76,7 +81,7 @@ const FindingDetailsCell = ({ row }: { row: any }) => {
}
title="Finding Details"
description="View the finding details"
defaultOpen={isOpen}
open={isUrlControlled ? false : undefined}
onOpenChange={handleOpenChange}
>
<DataTableRowDetails
+40 -28
@@ -1,7 +1,5 @@
"use client";
import { Snippet } from "@heroui/snippet";
import { Tooltip } from "@heroui/tooltip";
import { ExternalLink, Link } from "lucide-react";
import ReactMarkdown from "react-markdown";
@@ -11,6 +9,9 @@ import {
CardContent,
CardHeader,
CardTitle,
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/shadcn";
import { CodeSnippet } from "@/components/ui/code-snippet/code-snippet";
import { CustomLink } from "@/components/ui/custom/custom-link";
@@ -18,6 +19,7 @@ import { EntityInfo, InfoField } from "@/components/ui/entities";
import { DateWithTime } from "@/components/ui/entities/date-with-time";
import { SeverityBadge } from "@/components/ui/table/severity-badge";
import { buildGitFileUrl, extractLineRangeFromUid } from "@/lib/iac-utils";
import { cn } from "@/lib/utils";
import { FindingProps, ProviderType } from "@/types";
import { Muted } from "../muted";
@@ -83,14 +85,17 @@ export const FindingDetail = ({
<div>
<h2 className="dark:text-prowler-theme-pale/90 line-clamp-2 flex items-center gap-2 text-lg leading-tight font-medium text-gray-800">
{renderValue(attributes.check_metadata.checktitle)}
<Tooltip content="Copy finding link to clipboard" size="sm">
<button
onClick={() => navigator.clipboard.writeText(url)}
className="text-bg-data-info inline-flex cursor-pointer transition-opacity hover:opacity-80"
aria-label="Copy finding link to clipboard"
>
<Link size={16} />
</button>
<Tooltip>
<TooltipTrigger asChild>
<button
onClick={() => navigator.clipboard.writeText(url)}
className="text-text-info inline-flex cursor-pointer transition-opacity hover:opacity-80"
aria-label="Copy finding link to clipboard"
>
<Link size={16} />
</button>
</TooltipTrigger>
<TooltipContent>Copy finding link to clipboard</TooltipContent>
</Tooltip>
</h2>
</div>
@@ -164,16 +169,16 @@ export const FindingDetail = ({
{attributes.status === "FAIL" && (
<InfoField label="Risk" variant="simple">
<Snippet
className="max-w-full py-2"
color="danger"
hideCopyButton
hideSymbol
<div
className={cn(
"max-w-full rounded-md border p-2",
"border-border-error-primary bg-bg-fail-secondary",
)}
>
<MarkdownContainer>
{attributes.check_metadata.risk}
</MarkdownContainer>
</Snippet>
</div>
</InfoField>
)}
@@ -223,11 +228,13 @@ export const FindingDetail = ({
{/* CLI Command section */}
{attributes.check_metadata.remediation.code.cli && (
<InfoField label="CLI Command" variant="simple">
<Snippet>
<div
className={cn("rounded-md p-2", "bg-bg-neutral-tertiary")}
>
<span className="text-xs whitespace-pre-line">
{attributes.check_metadata.remediation.code.cli}
</span>
</Snippet>
</div>
</InfoField>
)}
@@ -276,16 +283,21 @@ export const FindingDetail = ({
<CardTitle>Resource Details</CardTitle>
{providerDetails.provider === "iac" && gitUrl && (
<CardAction>
<Tooltip content="Go to Resource in the Repository" size="sm">
<a
href={gitUrl}
target="_blank"
rel="noopener noreferrer"
className="text-bg-data-info inline-flex cursor-pointer"
aria-label="Open resource in repository"
>
<ExternalLink size={16} className="inline" />
</a>
<Tooltip>
<TooltipTrigger asChild>
<a
href={gitUrl}
target="_blank"
rel="noopener noreferrer"
className="text-text-info inline-flex cursor-pointer"
aria-label="Open resource in repository"
>
<ExternalLink size={16} className="inline" />
</a>
</TooltipTrigger>
<TooltipContent>
Go to Resource in the Repository
</TooltipContent>
</Tooltip>
</CardAction>
)}
+11 -3
@@ -5,9 +5,11 @@ import { useEffect } from "react";
interface AutoRefreshProps {
hasExecutingScan: boolean;
/** Optional callback for client-side refresh (used when data is managed in local state) */
onRefresh?: () => void | Promise<void>;
}
export function AutoRefresh({ hasExecutingScan }: AutoRefreshProps) {
export function AutoRefresh({ hasExecutingScan, onRefresh }: AutoRefreshProps) {
const router = useRouter();
const searchParams = useSearchParams();
@@ -19,11 +21,17 @@ export function AutoRefresh({ hasExecutingScan }: AutoRefreshProps) {
if (scanId) return;
const interval = setInterval(() => {
router.refresh();
if (onRefresh) {
// Use custom refresh callback for client-side state management
onRefresh();
} else {
// Default: trigger server-side refresh
router.refresh();
}
}, 5000);
return () => clearInterval(interval);
}, [hasExecutingScan, router, searchParams]);
}, [hasExecutingScan, router, searchParams, onRefresh]);
return null;
}
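The AutoRefresh change above adds an optional `onRefresh` callback: while a scan is executing, the component polls every 5 seconds, preferring the caller's callback over `router.refresh()`. A DOM-free sketch of that contract, assuming a hypothetical `startAutoRefresh` name (the real code lives in a `useEffect`):

```typescript
// Poll while a scan is executing; prefer the caller's onRefresh, else fall
// back to the default refresher. Returns a cleanup function, mirroring the
// useEffect teardown in the component above.
function startAutoRefresh(
  hasExecutingScan: boolean,
  defaultRefresh: () => void,
  onRefresh?: () => void | Promise<void>,
  intervalMs = 5000,
): (() => void) | undefined {
  if (!hasExecutingScan) return undefined;
  const id = setInterval(() => {
    if (onRefresh) {
      void onRefresh(); // client-side state refresh
    } else {
      defaultRefresh(); // server-side refresh fallback
    }
  }, intervalMs);
  return () => clearInterval(id);
}
```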
@@ -54,6 +54,7 @@ export function BreadcrumbNavigation({
"/manage-groups": "lucide:users-2",
"/services": "lucide:server",
"/workloads": "lucide:layers",
"/attack-paths": "lucide:git-branch",
};
const pathSegments = pathname
@@ -156,6 +157,7 @@ export function BreadcrumbNavigation({
>
{breadcrumb.icon && typeof breadcrumb.icon === "string" ? (
<Icon
aria-hidden="true"
className="text-text-neutral-primary"
height={24}
icon={breadcrumb.icon}
@@ -177,6 +179,7 @@ export function BreadcrumbNavigation({
>
{breadcrumb.icon && typeof breadcrumb.icon === "string" ? (
<Icon
aria-hidden="true"
className="text-text-neutral-primary"
height={24}
icon={breadcrumb.icon}
@@ -195,6 +198,7 @@ export function BreadcrumbNavigation({
<div className="flex items-center gap-2">
{breadcrumb.icon && typeof breadcrumb.icon === "string" ? (
<Icon
aria-hidden="true"
className="text-default-500"
height={24}
icon={breadcrumb.icon}
+21 -3
@@ -20,6 +20,7 @@ interface MenuItemProps {
target?: string;
tooltip?: string;
isOpen: boolean;
highlight?: boolean;
}
export const MenuItem = ({
@@ -30,6 +31,7 @@ export const MenuItem = ({
target,
tooltip,
isOpen,
highlight,
}: MenuItemProps) => {
const pathname = usePathname();
const isActive = active !== undefined ? active : pathname.startsWith(href);
@@ -44,15 +46,31 @@ export const MenuItem = ({
variant={isActive ? "menu-active" : "menu-inactive"}
className={cn(
isOpen ? "w-full justify-start" : "w-14 justify-center",
highlight &&
"relative overflow-hidden before:absolute before:inset-0 before:rounded-lg before:bg-gradient-to-r before:from-emerald-500/20 before:via-teal-400/20 before:to-emerald-300/20 before:opacity-70",
)}
asChild
>
<Link href={href} target={target}>
<div className="flex items-center">
<span className={cn(isOpen ? "mr-4" : "")}>
<div className="relative z-10 flex items-center">
<span
className={cn(
isOpen ? "mr-4" : "",
highlight && "text-button-primary",
)}
>
<Icon size={18} />
</span>
{isOpen && <p className="max-w-[200px] truncate">{label}</p>}
{isOpen && (
<p className="max-w-[200px] truncate">
{label}
{highlight && (
<span className="ml-2 rounded-sm bg-emerald-500 px-1.5 py-0.5 text-[10px] font-semibold text-white">
NEW
</span>
)}
</p>
)}
</div>
</Link>
</Button>
+1
@@ -119,6 +119,7 @@ export const Menu = ({ isOpen }: { isOpen: boolean }) => {
target={menu.target}
tooltip={menu.tooltip}
isOpen={isOpen}
highlight={menu.highlight}
/>
)}
</div>
+21 -9
@@ -1,16 +1,18 @@
import { Chip } from "@heroui/chip";
import clsx from "clsx";
import React from "react";
import { SpinnerIcon } from "@/components/icons";
export type Status =
| "available"
| "scheduled"
| "executing"
| "completed"
| "failed"
| "cancelled";
const STATUS = {
available: "available",
scheduled: "scheduled",
executing: "executing",
completed: "completed",
failed: "failed",
cancelled: "cancelled",
} as const;
export type Status = (typeof STATUS)[keyof typeof STATUS];
const statusColorMap: Record<
Status,
@@ -24,6 +26,15 @@ const statusColorMap: Record<
cancelled: "danger",
};
const statusDisplayMap: Record<Status, string> = {
available: "queued",
scheduled: "scheduled",
executing: "executing",
completed: "completed",
failed: "failed",
cancelled: "cancelled",
};
export const StatusBadge = ({
status,
size = "sm",
@@ -37,6 +48,7 @@ export const StatusBadge = ({
className?: string;
}) => {
const color = statusColorMap[status as keyof typeof statusColorMap];
const displayLabel = statusDisplayMap[status] || status;
return (
<Chip
@@ -59,7 +71,7 @@ export const StatusBadge = ({
<span>executing</span>
</div>
) : (
<span className="flex items-center justify-center">{status}</span>
<span className="flex items-center justify-center">{displayLabel}</span>
)}
</Chip>
);
+29 -5
@@ -42,10 +42,10 @@
{
"section": "dependencies",
"name": "@langchain/core",
"from": "0.3.77",
"to": "0.3.78",
"from": "0.3.78",
"to": "0.3.77",
"strategy": "installed",
"generatedAt": "2025-11-03T07:43:34.628Z"
"generatedAt": "2026-01-07T08:46:39.109Z"
},
{
"section": "dependencies",
@@ -125,7 +125,7 @@
"from": "1.1.15",
"to": "1.1.15",
"strategy": "installed",
"generatedAt": "2025-11-20T08:20:16.313Z"
"generatedAt": "2025-11-19T12:28:39.510Z"
},
{
"section": "dependencies",
@@ -207,6 +207,14 @@
"strategy": "installed",
"generatedAt": "2025-10-22T12:36:37.962Z"
},
{
"section": "dependencies",
"name": "@types/dagre",
"from": "0.7.53",
"to": "0.7.53",
"strategy": "installed",
"generatedAt": "2025-11-27T11:47:22.908Z"
},
{
"section": "dependencies",
"name": "@types/js-yaml",
@@ -253,7 +261,7 @@
"from": "1.1.1",
"to": "1.1.1",
"strategy": "installed",
"generatedAt": "2025-11-20T08:20:16.313Z"
"generatedAt": "2025-11-19T12:28:39.510Z"
},
{
"section": "dependencies",
@@ -263,6 +271,14 @@
"strategy": "installed",
"generatedAt": "2025-10-22T12:36:37.962Z"
},
{
"section": "dependencies",
"name": "dagre",
"from": "0.8.5",
"to": "0.8.5",
"strategy": "installed",
"generatedAt": "2025-11-27T11:47:22.908Z"
},
{
"section": "dependencies",
"name": "date-fns",
@@ -399,6 +415,14 @@
"strategy": "installed",
"generatedAt": "2025-10-22T12:36:37.962Z"
},
{
"section": "dependencies",
"name": "require-in-the-middle",
"from": "8.0.1",
"to": "8.0.1",
"strategy": "installed",
"generatedAt": "2026-01-07T12:09:03.204Z"
},
{
"section": "dependencies",
"name": "rss-parser",
+14
@@ -1,6 +1,7 @@
import {
CloudCog,
Cog,
GitBranch,
Group,
Mail,
MessageCircleQuestion,
@@ -75,6 +76,19 @@ export const getMenuList = ({
},
],
},
{
groupLabel: "",
menus: [
{
href: "/attack-paths",
label: "Attack Paths",
icon: GitBranch,
active: pathname.startsWith("/attack-paths"),
highlight: true,
},
],
},
{
groupLabel: "",
menus: [
+75 -7
@@ -35,6 +35,7 @@
"@tailwindcss/postcss": "4.1.13",
"@tailwindcss/typography": "0.5.16",
"@tanstack/react-table": "8.21.3",
"@types/dagre": "0.7.53",
"@types/js-yaml": "4.0.9",
"ai": "5.0.59",
"alert": "6.0.2",
@@ -42,6 +43,7 @@
"clsx": "2.1.1",
"cmdk": "1.1.1",
"d3": "7.9.0",
"dagre": "0.8.5",
"date-fns": "4.1.0",
"framer-motion": "11.18.2",
"intl-messageformat": "10.7.16",
@@ -59,6 +61,7 @@
"react-hook-form": "7.62.0",
"react-markdown": "10.1.0",
"recharts": "2.15.4",
"require-in-the-middle": "8.0.1",
"rss-parser": "3.13.0",
"server-only": "0.0.1",
"sharp": "0.33.5",
@@ -5271,7 +5274,6 @@
"resolved": "https://registry.npmjs.org/@opentelemetry/instrumentation/-/instrumentation-0.203.0.tgz",
"integrity": "sha512-ke1qyM+3AK2zPuBPb6Hk/GCsc5ewbLvPNkEuELx/JmANeEp6ZjnZ+wypPAJSucTw0wvCGrUaibDSdcrGFoWxKQ==",
"license": "Apache-2.0",
"peer": true,
"dependencies": {
"@opentelemetry/api-logs": "0.203.0",
"import-in-the-middle": "^1.8.1",
@@ -5493,6 +5495,20 @@
"@opentelemetry/api": "^1.3.0"
}
},
"node_modules/@opentelemetry/instrumentation-ioredis/node_modules/require-in-the-middle": {
"version": "7.5.2",
"resolved": "https://registry.npmjs.org/require-in-the-middle/-/require-in-the-middle-7.5.2.tgz",
"integrity": "sha512-gAZ+kLqBdHarXB64XpAe2VCjB7rIRv+mU8tfRWziHRJ5umKsIHN2tLLv6EtMw7WCdP19S0ERVMldNvxYCHnhSQ==",
"license": "MIT",
"dependencies": {
"debug": "^4.3.5",
"module-details-from-path": "^1.0.3",
"resolve": "^1.22.8"
},
"engines": {
"node": ">=8.6.0"
}
},
"node_modules/@opentelemetry/instrumentation-kafkajs": {
"version": "0.13.0",
"resolved": "https://registry.npmjs.org/@opentelemetry/instrumentation-kafkajs/-/instrumentation-kafkajs-0.13.0.tgz",
@@ -5694,6 +5710,20 @@
"@opentelemetry/api": "^1.7.0"
}
},
"node_modules/@opentelemetry/instrumentation/node_modules/require-in-the-middle": {
"version": "7.5.2",
"resolved": "https://registry.npmjs.org/require-in-the-middle/-/require-in-the-middle-7.5.2.tgz",
"integrity": "sha512-gAZ+kLqBdHarXB64XpAe2VCjB7rIRv+mU8tfRWziHRJ5umKsIHN2tLLv6EtMw7WCdP19S0ERVMldNvxYCHnhSQ==",
"license": "MIT",
"dependencies": {
"debug": "^4.3.5",
"module-details-from-path": "^1.0.3",
"resolve": "^1.22.8"
},
"engines": {
"node": ">=8.6.0"
}
},
"node_modules/@opentelemetry/redis-common": {
"version": "0.38.2",
"resolved": "https://registry.npmjs.org/@opentelemetry/redis-common/-/redis-common-0.38.2.tgz",
@@ -5846,6 +5876,20 @@
"@opentelemetry/api": "^1.3.0"
}
},
"node_modules/@prisma/instrumentation/node_modules/require-in-the-middle": {
"version": "7.5.2",
"resolved": "https://registry.npmjs.org/require-in-the-middle/-/require-in-the-middle-7.5.2.tgz",
"integrity": "sha512-gAZ+kLqBdHarXB64XpAe2VCjB7rIRv+mU8tfRWziHRJ5umKsIHN2tLLv6EtMw7WCdP19S0ERVMldNvxYCHnhSQ==",
"license": "MIT",
"dependencies": {
"debug": "^4.3.5",
"module-details-from-path": "^1.0.3",
"resolve": "^1.22.8"
},
"engines": {
"node": ">=8.6.0"
}
},
"node_modules/@prisma/instrumentation/node_modules/semver": {
"version": "7.7.3",
"resolved": "https://registry.npmjs.org/semver/-/semver-7.7.3.tgz",
@@ -11800,6 +11844,12 @@
"@types/d3-selection": "*"
}
},
"node_modules/@types/dagre": {
"version": "0.7.53",
"resolved": "https://registry.npmjs.org/@types/dagre/-/dagre-0.7.53.tgz",
"integrity": "sha512-f4gkWqzPZvYmKhOsDnhq/R8mO4UMcKdxZo+i5SCkOU1wvGeHJeUXGIHeE9pnwGyPMDof1Vx5ZQo4nxpeg2TTVQ==",
"license": "MIT"
},
"node_modules/@types/debug": {
"version": "4.1.12",
"resolved": "https://registry.npmjs.org/@types/debug/-/debug-4.1.12.tgz",
@@ -14534,6 +14584,16 @@
"node": ">=12"
}
},
"node_modules/dagre": {
"version": "0.8.5",
"resolved": "https://registry.npmjs.org/dagre/-/dagre-0.8.5.tgz",
"integrity": "sha512-/aTqmnRta7x7MCCpExk7HQL2O4owCT2h8NT//9I1OQ9vt29Pa0BzSAkR5lwFUcQ7491yVi/3CXU9jQ5o0Mn2Sw==",
"license": "MIT",
"dependencies": {
"graphlib": "^2.1.8",
"lodash": "^4.17.15"
}
},
"node_modules/dagre-d3-es": {
"version": "7.0.13",
"resolved": "https://registry.npmjs.org/dagre-d3-es/-/dagre-d3-es-7.0.13.tgz",
@@ -16872,6 +16932,15 @@
"dev": true,
"license": "MIT"
},
"node_modules/graphlib": {
"version": "2.1.8",
"resolved": "https://registry.npmjs.org/graphlib/-/graphlib-2.1.8.tgz",
"integrity": "sha512-jcLLfkpoVGmH7/InMC/1hIvOPSUh38oJtGhvrOFGzioE1DZ+0YW16RgmOJhHiuWTvGiJQ9Z1Ik43JvkRPRvE+A==",
"license": "MIT",
"dependencies": {
"lodash": "^4.17.15"
}
},
"node_modules/graphql": {
"version": "16.12.0",
"resolved": "https://registry.npmjs.org/graphql/-/graphql-16.12.0.tgz",
@@ -22854,17 +22923,16 @@
}
},
"node_modules/require-in-the-middle": {
"version": "7.5.2",
"resolved": "https://registry.npmjs.org/require-in-the-middle/-/require-in-the-middle-7.5.2.tgz",
"integrity": "sha512-gAZ+kLqBdHarXB64XpAe2VCjB7rIRv+mU8tfRWziHRJ5umKsIHN2tLLv6EtMw7WCdP19S0ERVMldNvxYCHnhSQ==",
"version": "8.0.1",
"resolved": "https://registry.npmjs.org/require-in-the-middle/-/require-in-the-middle-8.0.1.tgz",
"integrity": "sha512-QT7FVMXfWOYFbeRBF6nu+I6tr2Tf3u0q8RIEjNob/heKY/nh7drD/k7eeMFmSQgnTtCzLDcCu/XEnpW2wk4xCQ==",
"license": "MIT",
"dependencies": {
"debug": "^4.3.5",
"module-details-from-path": "^1.0.3",
"resolve": "^1.22.8"
"module-details-from-path": "^1.0.3"
},
"engines": {
"node": ">=8.6.0"
"node": ">=9.3.0 || >=8.10.0 <9.0.0"
}
},
"node_modules/resolve": {
+3
@@ -49,6 +49,7 @@
"@tailwindcss/postcss": "4.1.13",
"@tailwindcss/typography": "0.5.16",
"@tanstack/react-table": "8.21.3",
"@types/dagre": "0.7.53",
"@types/js-yaml": "4.0.9",
"ai": "5.0.59",
"alert": "6.0.2",
@@ -56,6 +57,7 @@
"clsx": "2.1.1",
"cmdk": "1.1.1",
"d3": "7.9.0",
"dagre": "0.8.5",
"date-fns": "4.1.0",
"framer-motion": "11.18.2",
"intl-messageformat": "10.7.16",
@@ -73,6 +75,7 @@
"react-hook-form": "7.62.0",
"react-markdown": "10.1.0",
"recharts": "2.15.4",
"require-in-the-middle": "8.0.1",
"rss-parser": "3.13.0",
"server-only": "0.0.1",
"sharp": "0.33.5",
+3
@@ -54,6 +54,7 @@
--bg-pass-primary: var(--color-emerald-400);
--bg-pass-secondary: var(--color-emerald-50);
--bg-warning-primary: var(--color-orange-500);
--bg-warning-secondary: var(--color-orange-50);
--bg-fail-primary: var(--color-rose-500);
--bg-fail-secondary: var(--color-rose-50);
@@ -123,6 +124,7 @@
--bg-pass-primary: var(--color-green-400);
--bg-pass-secondary: var(--color-emerald-900);
--bg-warning-primary: var(--color-orange-400);
--bg-warning-secondary: var(--color-orange-900);
--bg-fail-primary: var(--color-rose-500);
--bg-fail-secondary: #432232;
@@ -209,6 +211,7 @@
--color-bg-pass: var(--bg-pass-primary);
--color-bg-pass-secondary: var(--bg-pass-secondary);
--color-bg-warning: var(--bg-warning-primary);
--color-bg-warning-secondary: var(--bg-warning-secondary);
--color-bg-fail: var(--bg-fail-primary);
--color-bg-fail-secondary: var(--bg-fail-secondary);
}
+245
@@ -0,0 +1,245 @@
/**
* Attack Paths Feature Types
* Defines all TypeScript interfaces for the Attack Paths visualization feature
*/
// Scan state constants
export const SCAN_STATES = {
AVAILABLE: "available",
SCHEDULED: "scheduled",
EXECUTING: "executing",
COMPLETED: "completed",
FAILED: "failed",
} as const;
export type ScanState = (typeof SCAN_STATES)[keyof typeof SCAN_STATES];
// Attack Path Scan - Relationship Data
export interface RelationshipData {
type: string;
id: string;
}
export interface RelationshipWrapper {
data: RelationshipData;
}
export interface ScanRelationships {
provider: RelationshipWrapper;
scan: RelationshipWrapper;
task: RelationshipWrapper;
}
// Provider type constants
export const PROVIDER_TYPES = {
AWS: "aws",
AZURE: "azure",
GCP: "gcp",
} as const;
export type ProviderType = (typeof PROVIDER_TYPES)[keyof typeof PROVIDER_TYPES];
// Attack Path Scan Response
export interface AttackPathScanAttributes {
state: ScanState;
progress: number;
provider_alias: string;
provider_type: ProviderType;
provider_uid: string;
inserted_at: string;
started_at: string;
completed_at: string | null;
duration: number | null;
}
export interface AttackPathScan {
type: "attack-paths-scans";
id: string;
attributes: AttackPathScanAttributes;
relationships: ScanRelationships;
}
export interface PaginationLinks {
first: string;
last: string;
next: string | null;
prev: string | null;
}
export interface AttackPathScansResponse {
data: AttackPathScan[];
links: PaginationLinks;
}
// Data type constants
const DATA_TYPES = {
STRING: "string",
NUMBER: "number",
BOOLEAN: "boolean",
} as const;
type DataType = (typeof DATA_TYPES)[keyof typeof DATA_TYPES];
// Query Types
export interface AttackPathQueryParameter {
name: string;
label: string;
data_type: DataType;
description: string;
placeholder?: string;
required?: boolean;
}
export interface AttackPathQueryAttributes {
name: string;
description: string;
provider: string;
parameters: AttackPathQueryParameter[];
}
export interface AttackPathQuery {
type: "attack-paths-scans";
id: string;
attributes: AttackPathQueryAttributes;
}
export interface AttackPathQueriesResponse {
data: AttackPathQuery[];
}
// Graph Data Types
// Property values from graph nodes can be any primitive type or arrays
export type GraphNodePropertyValue =
| string
| number
| boolean
| null
| undefined
| string[]
| number[];
export interface GraphNodeProperties {
[key: string]: GraphNodePropertyValue;
}
export interface GraphNode {
id: string;
labels: string[]; // e.g., ["S3Bucket"], ["EC2Instance"], ["ProwlerFinding"]
properties: GraphNodeProperties;
findings?: string[]; // IDs of finding nodes connected via HAS_FINDING edges
resources?: string[]; // IDs of resource nodes connected via HAS_FINDING edges
}
export interface GraphEdge {
id: string;
source: string | object;
target: string | object;
type: string;
properties?: GraphNodeProperties;
}
export interface GraphRelationship {
id: string;
label: string;
source: string;
target: string;
properties?: GraphNodeProperties;
}
export interface AttackPathGraphData {
nodes: GraphNode[];
edges?: GraphEdge[];
relationships?: GraphRelationship[];
}
export interface QueryResultAttributes {
nodes: GraphNode[];
relationships?: GraphRelationship[];
}
export interface QueryResultData {
type: "attack-paths-query-run-request";
id: string | null;
attributes: QueryResultAttributes;
}
export interface AttackPathQueryResult {
data: QueryResultData;
}
// Finding severity and status constants
const FINDING_SEVERITIES = {
CRITICAL: "critical",
HIGH: "high",
MEDIUM: "medium",
LOW: "low",
INFO: "info",
} as const;
type FindingSeverity =
(typeof FINDING_SEVERITIES)[keyof typeof FINDING_SEVERITIES];
const FINDING_STATUSES = {
PASS: "PASS",
FAIL: "FAIL",
MANUAL: "MANUAL",
} as const;
type FindingStatus = (typeof FINDING_STATUSES)[keyof typeof FINDING_STATUSES];
export interface RelatedFinding {
id: string;
title: string;
severity: FindingSeverity;
status: FindingStatus;
}
// Node Detail Types
export interface NodeDetailData extends GraphNode {
relatedFindings?: RelatedFinding[];
incomingEdges?: GraphEdge[];
outgoingEdges?: GraphEdge[];
}
// Wizard State Types
export interface WizardState {
currentStep: 1 | 2;
selectedScanId: string | null;
selectedQuery: string | null;
queryParameters: Record<string, string | number | boolean>;
}
// Graph State Types
export interface GraphState {
data: AttackPathGraphData | null;
selectedNodeId: string | null;
loading: boolean;
error: string | null;
zoomLevel: number;
panX: number;
panY: number;
}
// Provider Integration
export interface ProviderWithScanStatus {
id: string;
alias: string;
provider: string;
scan: AttackPathScan;
connected: boolean;
}
// API Request/Response Helpers
export interface QueryRequestAttributes {
id: string;
parameters?: Record<string, string | number | boolean>;
}
export interface ExecuteQueryRequestData {
type: "attack-paths-query-run-request";
attributes: QueryRequestAttributes;
}
export interface ExecuteQueryRequest {
data: ExecuteQueryRequestData;
}
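Several of the types above (SCAN_STATES, PROVIDER_TYPES, the StatusBadge STATUS map) derive their string unions from `as const` objects instead of maintaining a parallel hand-written union. A minimal sketch of the pattern and the runtime guard it enables (names suffixed `Demo` to mark them as illustrative):

```typescript
// One source of truth: runtime values plus a derived compile-time union.
const SCAN_STATES_DEMO = {
  AVAILABLE: "available",
  COMPLETED: "completed",
} as const;

type ScanStateDemo = (typeof SCAN_STATES_DEMO)[keyof typeof SCAN_STATES_DEMO];
// -> "available" | "completed"

// The const object also doubles as runtime data for a type guard.
function isScanState(value: string): value is ScanStateDemo {
  return (Object.values(SCAN_STATES_DEMO) as string[]).includes(value);
}

console.log(isScanState("completed")); // true
console.log(isScanState("paused")); // false
```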

Some files were not shown because too many files have changed in this diff.