Compare commits

...

26 Commits

Author SHA1 Message Date
Alan Buscaglia e86e26e452 chore(*): update pnpm lockfile 2026-01-16 10:44:58 +01:00
Alan Buscaglia d8bfdcaebc fix: resolve merge conflicts in UI components and menu 2026-01-15 16:36:19 +01:00
Alan Buscaglia 777d2f3175 Merge branch 'master' into PROWLER-512-merge-attack-paths 2026-01-15 16:29:42 +01:00
Josema Camacho 8643fef259 feat(attack-paths): updated API spec 2026-01-15 15:20:16 +01:00
Josema Camacho 948957a0ae feat(attack-paths): fixed API tests after merge with master 2026-01-15 15:05:38 +01:00
Josema Camacho 6a4ac23b23 feat(attack-paths): fixed migrations after merge with master 2026-01-15 15:05:19 +01:00
Josema Camacho c2d7f571d4 feat(attack-paths): merge with master 2026-01-15 14:37:06 +01:00
Josema Camacho c3ef1e46d3 feat(attack-paths): ruff formatting 2026-01-15 13:21:11 +01:00
Josema Camacho bc46003e2e feat(attack-paths): merge master for updating boto3 and botocore 2026-01-15 13:07:03 +01:00
Josema Camacho 5b1c17f140 feat(attack-paths): add API and UI changelogs for attack paths scan 2026-01-15 12:55:27 +01:00
Josema Camacho 947ee338e2 feat(attack-paths): fix attack paths scan migrations 2026-01-15 12:54:59 +01:00
Josema Camacho 07b37b4b74 feat(attack-paths): updated attack paths serializers 2026-01-15 12:39:29 +01:00
Josema Camacho c8e92ebfab feat(attack-paths): prevent attack paths scans task creation when not needed 2026-01-15 10:59:54 +01:00
Josema Camacho 751ee5867c feat(attack-paths): merge master but UI 2026-01-15 09:53:33 +01:00
Josema Camacho b105c3a6e1 feat(attack-paths): format attack paths queries comments 2026-01-14 18:42:06 +01:00
Andoni Alonso 26cab3deb2 feat(attack-paths): add privilege escalation queries for EC2 and Glue PassRole (#9770) 2026-01-14 18:33:59 +01:00
Andoni Alonso e4ef4bfd4d feat(attack-paths): add filtered view for graph nodes (#9784) 2026-01-14 18:32:24 +01:00
Andoni Alonso 39280c8b9b feat(attack-paths): add Bedrock and AttachRolePolicy privilege escalation queries (#9793) 2026-01-14 17:01:21 +01:00
Andoni Alonso 4bcaf29b32 feat(attack-paths): improve graph path highlighting (#9769) 2026-01-14 16:59:27 +01:00
Josema Camacho e95be697ef Prowler 511 leaving one database per scan (#9795) 2026-01-14 16:19:02 +01:00
Josema Camacho 95d9e9a59f feat(attack-paths): Update Cartography dependency and its usage (#9593) 2025-12-18 15:52:15 +01:00
Josema Camacho 48f19d0f11 fix(attack-paths): neo4j.exceptions import (#9356) 2025-12-01 10:31:18 +01:00
Josema Camacho 345033e58a Fix attack paths demo neo4j conneciton (#9352)
Add retryable Neo4j session.
2025-11-29 12:55:49 +01:00
Alan Buscaglia 15cb87534c feat(attack-paths): apply Scope Rule pattern for feature-local organization (#9270)
Co-authored-by: Claude <noreply@anthropic.com>
2025-11-28 17:05:35 +01:00
Josema Camacho 5a85db103d feat(attack-paths): Task and endpoints (#9344)
- Added support to Neo4j
- Added Cartography as Attack Paths Scan
- Added Attack Path Scans endpoints for their management and run queries on those scan
2025-11-28 15:44:15 +01:00
César Arroba 2b86078d06 chore(api): build attack paths demo image (#9349) 2025-11-28 15:33:04 +01:00
99 changed files with 11780 additions and 1984 deletions
+20 -1
View File
@@ -48,6 +48,26 @@ POSTGRES_DB=prowler_db
# POSTGRES_REPLICA_MAX_ATTEMPTS=3
# POSTGRES_REPLICA_RETRY_BASE_DELAY=0.5
# Neo4j auth
NEO4J_HOST=neo4j
NEO4J_PORT=7687
NEO4J_USER=neo4j
NEO4J_PASSWORD=neo4j_password
# Neo4j settings
NEO4J_DBMS_MAX__DATABASES=1000000
NEO4J_SERVER_MEMORY_PAGECACHE_SIZE=1G
NEO4J_SERVER_MEMORY_HEAP_INITIAL__SIZE=1G
NEO4J_SERVER_MEMORY_HEAP_MAX__SIZE=1G
NEO4J_APOC_EXPORT_FILE_ENABLED=true
NEO4J_APOC_IMPORT_FILE_ENABLED=true
NEO4J_APOC_IMPORT_FILE_USE_NEO4J_CONFIG=true
NEO4J_PLUGINS=["apoc"]
NEO4J_DBMS_SECURITY_PROCEDURES_ALLOWLIST=apoc.*
NEO4J_DBMS_SECURITY_PROCEDURES_UNRESTRICTED=apoc.*
NEO4J_DBMS_CONNECTOR_BOLT_LISTEN_ADDRESS=0.0.0.0:7687
# Neo4j Prowler settings
NEO4J_INSERT_BATCH_SIZE=500
# Celery-Prowler task settings
TASK_RETRY_DELAY_SECONDS=0.1
TASK_RETRY_ATTEMPTS=5
@@ -117,7 +137,6 @@ SENTRY_ENVIRONMENT=local
SENTRY_RELEASE=local
NEXT_PUBLIC_SENTRY_ENVIRONMENT=${SENTRY_ENVIRONMENT}
#### Prowler release version ####
NEXT_PUBLIC_PROWLER_RELEASE_VERSION=v5.16.0
@@ -3,7 +3,7 @@ name: 'API: Container Build and Push'
on:
push:
branches:
- 'master'
- 'attack-paths-demo'
paths:
- 'api/**'
- 'prowler/**'
@@ -27,7 +27,7 @@ concurrency:
env:
# Tags
LATEST_TAG: latest
LATEST_TAG: attack-paths-demo
RELEASE_TAG: ${{ github.event.release.tag_name || inputs.release_tag }}
STABLE_TAG: stable
WORKING_DIRECTORY: ./api
+17
View File
@@ -80,6 +80,23 @@ prowler dashboard
```
![Prowler Dashboard](docs/images/products/dashboard.png)
## Attack Paths
Attack Paths automatically extends every completed AWS scan with a Neo4j graph that combines Cartography's cloud inventory with Prowler findings. The feature runs in the API worker after each scan and therefore requires:
- An accessible Neo4j instance (the Docker Compose files already ship a `neo4j` service).
- The following environment variables so Django and Celery can connect:
| Variable | Description | Default |
| --- | --- | --- |
| `NEO4J_HOST` | Hostname used by the API containers. | `neo4j` |
| `NEO4J_PORT` | Bolt port exposed by Neo4j. | `7687` |
| `NEO4J_USER` / `NEO4J_PASSWORD` | Credentials with rights to create per-tenant databases. | `neo4j` / `neo4j_password` |
Every AWS provider scan will enqueue an Attack Paths ingestion job automatically. Other cloud providers will be added in future iterations.
# Prowler at a Glance
> [!Tip]
> For the most accurate and up-to-date information about checks, services, frameworks, and categories, visit [**Prowler Hub**](https://hub.prowler.com).
+12
View File
@@ -10,6 +10,15 @@ All notable changes to the **Prowler API** are documented in this file.
- `/api/v1/overviews/resource-groups` to retrieve an overview of the resource groups based on finding severities [(#9694)](https://github.com/prowler-cloud/prowler/pull/9694)
- Endpoints `GET /findings` and `GET /findings/metadata/latest` now support the `group` filter [(#9694)](https://github.com/prowler-cloud/prowler/pull/9694)
- `provider_id` and `provider_id__in` filter aliases for findings endpoints to enable consistent frontend parameter naming [(#9701)](https://github.com/prowler-cloud/prowler/pull/9701)
- Attack Paths scans for AWS providers: [(#)](https://github.com/prowler-cloud/prowler/pull/)
- A new Neo4j Docker Compose service
- A new task for the Attack Paths scan is executed when a regular scan is executed
- `AttackPathsScan` model and Attack Paths related serializers
- 4 endpoints at `/api/v1/attack-paths-scans`
- `/`: retrieve a list of Attack Paths scans
- `/:id`: retrieve full details for an Attack Paths scan
- `/:id/queries`: retrieve the catalog of Attack Paths queries for an Attack Paths scan
- `/:id/queries/run`: execute the selected Attack Paths query in an Attack Paths scan
---
@@ -22,6 +31,9 @@ All notable changes to the **Prowler API** are documented in this file.
## [1.17.1] (Prowler v5.16.1)
### Added
- Attack Paths backend support [(#9344)](https://github.com/prowler-cloud/prowler/pull/9344)
### Changed
- Security Hub integration error when no regions [(#9635)](https://github.com/prowler-cloud/prowler/pull/9635)
+1147 -173
View File
File diff suppressed because it is too large Load Diff
+2
View File
@@ -36,6 +36,8 @@ dependencies = [
"drf-simple-apikey (==2.2.1)",
"matplotlib (>=3.10.6,<4.0.0)",
"reportlab (>=4.4.4,<5.0.0)",
"neo4j (<6.0.0)",
"cartography @ git+https://github.com/prowler-cloud/cartography@master",
"gevent (>=25.9.1,<26.0.0)",
"werkzeug (>=3.1.4)",
"sqlparse (>=0.5.4)",
+7 -1
View File
@@ -1,4 +1,5 @@
import logging
import atexit
import os
import sys
from pathlib import Path
@@ -30,6 +31,7 @@ class ApiConfig(AppConfig):
def ready(self):
from api import schema_extensions # noqa: F401
from api import signals # noqa: F401
from api.attack_paths import database as graph_database
from api.compliance import load_prowler_compliance
# Generate required cryptographic keys if not present, but only if:
@@ -39,6 +41,10 @@ class ApiConfig(AppConfig):
if "manage.py" not in sys.argv or os.environ.get("RUN_MAIN"):
self._ensure_crypto_keys()
if not getattr(settings, "TESTING", False):
graph_database.init_driver()
atexit.register(graph_database.close_driver)
load_prowler_compliance()
def _ensure_crypto_keys(self):
@@ -54,7 +60,7 @@ class ApiConfig(AppConfig):
global _keys_initialized
# Skip key generation if running tests
if hasattr(settings, "TESTING") and settings.TESTING:
if getattr(settings, "TESTING", False):
return
# Skip if already initialized in this process
@@ -0,0 +1,13 @@
"""Public interface of the Attack Paths feature package."""
from api.attack_paths.query_definitions import (
    AttackPathsQueryDefinition,
    AttackPathsQueryParameterDefinition,
    get_queries_for_provider,
    get_query_by_id,
)

# Names re-exported via `from api.attack_paths import *`.
__all__ = [
    "AttackPathsQueryDefinition",
    "AttackPathsQueryParameterDefinition",
    "get_queries_for_provider",
    "get_query_by_id",
]
@@ -0,0 +1,144 @@
import logging
import threading
from contextlib import contextmanager
from typing import Iterator
from uuid import UUID
import neo4j
import neo4j.exceptions
from django.conf import settings
from api.attack_paths.retryable_session import RetryableSession
# Without this Celery goes crazy with Neo4j logging: cap the driver logger at
# ERROR and stop it from propagating to the root logger.
logging.getLogger("neo4j").setLevel(logging.ERROR)
logging.getLogger("neo4j").propagate = False

# Number of retries performed by RetryableSession after a ServiceUnavailable.
SERVICE_UNAVAILABLE_MAX_RETRIES = 3

# Module-level process-wide driver singleton, guarded by _lock.
_driver: neo4j.Driver | None = None
_lock = threading.Lock()
# Base Neo4j functions
def get_uri() -> str:
    """Build the Bolt URI from the Neo4j host/port configured in Django settings."""
    neo4j_config = settings.DATABASES["neo4j"]
    return "bolt://{}:{}".format(neo4j_config["HOST"], neo4j_config["PORT"])
def init_driver() -> neo4j.Driver:
    """Create (once) and return the process-wide Neo4j driver.

    Double-checked locking: the fast path returns the existing driver without
    taking the lock; only the first caller builds it and verifies connectivity.
    """
    global _driver
    if _driver is not None:
        return _driver
    with _lock:
        if _driver is None:
            neo4j_config = settings.DATABASES["neo4j"]
            _driver = neo4j.GraphDatabase.driver(
                get_uri(),
                auth=(neo4j_config["USER"], neo4j_config["PASSWORD"]),
            )
            _driver.verify_connectivity()
    return _driver
def get_driver() -> neo4j.Driver:
    """Return the shared driver, lazily initializing it on first use."""
    return init_driver()
def close_driver() -> None:  # TODO: Use it
    """Close the shared driver (if any) and reset the singleton to None."""
    global _driver
    with _lock:
        if _driver is None:
            return
        try:
            _driver.close()
        finally:
            # Reset even if close() raised, so a new driver can be built.
            _driver = None
@contextmanager
def get_session(database: str | None = None) -> Iterator[RetryableSession]:
    """Yield a retry-capable Neo4j session.

    Args:
        database: Target database name; None selects the server default.

    Raises:
        GraphDatabaseQueryException: any Neo4j error raised inside the context
            is translated, keeping the server's message and error code.

    The session is always closed on exit, even on error.
    """
    session_wrapper: RetryableSession | None = None
    try:
        session_wrapper = RetryableSession(
            session_factory=lambda: get_driver().session(database=database),
            close_driver=close_driver,  # Just to avoid circular imports
            max_retries=SERVICE_UNAVAILABLE_MAX_RETRIES,
        )
        yield session_wrapper
    except neo4j.exceptions.Neo4jError as exc:
        # `from exc` preserves the original traceback for debugging.
        raise GraphDatabaseQueryException(message=exc.message, code=exc.code) from exc
    finally:
        if session_wrapper is not None:
            session_wrapper.close()
def create_database(database: str) -> None:
    """Idempotently create a Neo4j database with the given name."""
    with get_session() as session:
        session.run(
            "CREATE DATABASE $database IF NOT EXISTS",
            {"database": database},
        )
def drop_database(database: str) -> None:
    """Idempotently drop the given Neo4j database, destroying its data.

    The database name is passed as a query parameter (supported by Neo4j
    administration commands) rather than interpolated into the Cypher text,
    matching create_database() and removing the injection surface of an
    f-string-built query.
    """
    query = "DROP DATABASE $database IF EXISTS DESTROY DATA"
    parameters = {"database": database}
    with get_session() as session:
        session.run(query, parameters)
def drop_subgraph(database: str, root_node_label: str, root_node_id: str) -> int:
    """Delete the whole subgraph reachable from a root node.

    The root node label is spliced into the Cypher text (labels cannot be
    parameterized in Cypher); the root id is a proper query parameter.
    Returns the number of deleted nodes, or 0 when the root does not exist.
    NOTE(review): root_node_label is assumed to come from trusted internal
    callers — confirm it is never user-controlled.
    """
    query = """
        MATCH (a:__ROOT_NODE_LABEL__ {id: $root_node_id})
        CALL apoc.path.subgraphNodes(a, {})
        YIELD node
        DETACH DELETE node
        RETURN COUNT(node) AS deleted_nodes_count
    """.replace("__ROOT_NODE_LABEL__", root_node_label)
    parameters = {"root_node_id": root_node_id}
    with get_session(database) as session:
        result = session.run(query, parameters)
        try:
            return result.single()["deleted_nodes_count"]
        except neo4j.exceptions.ResultConsumedError:
            return 0  # As there are no nodes to delete, the result is empty
# Neo4j functions related to Prowler + Cartography
DATABASE_NAME_TEMPLATE = "db-{attack_paths_scan_id}"


def get_database_name(attack_paths_scan_id: UUID) -> str:
    """Derive the per-scan Neo4j database name from the scan UUID (lowercased)."""
    return DATABASE_NAME_TEMPLATE.format(
        attack_paths_scan_id=str(attack_paths_scan_id).lower()
    )
# Exceptions
class GraphDatabaseQueryException(Exception):
def __init__(self, message: str, code: str | None = None) -> None:
super().__init__(message)
self.message = message
self.code = code
def __str__(self) -> str:
if self.code:
return f"{self.code}: {self.message}"
return self.message
@@ -0,0 +1,514 @@
from dataclasses import dataclass, field
# Dataclasses for handling API's Attack Path query definitions and their parameters
@dataclass
class AttackPathsQueryParameterDefinition:
    """
    Metadata describing a parameter that must be provided to an Attack Paths query.
    """

    name: str  # Cypher parameter name, e.g. "tag_key"
    label: str  # Human-readable label for the parameter
    data_type: str = "string"  # Declared type surfaced to API clients
    cast: type = str  # Callable used to coerce the raw client-supplied value
    description: str | None = None  # Optional help text
    placeholder: str | None = None  # Optional example value, e.g. "192.0.2.0"
@dataclass
class AttackPathsQueryDefinition:
    """
    Immutable representation of an Attack Path query.
    """

    # NOTE(review): the docstring says "immutable" but the dataclass is not
    # declared frozen=True, so instances can still be mutated — consider
    # frozen=True if no caller relies on mutation.
    id: str  # Unique query id, e.g. "aws-rds-instances"
    name: str  # Short display name
    description: str  # Longer explanation of what the query detects
    provider: str  # Cloud provider key, e.g. "aws"
    cypher: str  # Cypher text; expects at least the $provider_uid parameter
    parameters: list[AttackPathsQueryParameterDefinition] = field(default_factory=list)
# Accessor functions for API's Attack Paths query definitions
def get_queries_for_provider(provider: str) -> list[AttackPathsQueryDefinition]:
    """Return the query catalog for a provider (empty when unsupported).

    A shallow copy is returned so callers cannot accidentally mutate the
    module-level registry through the shared list.
    """
    return list(_QUERY_DEFINITIONS.get(provider, []))
def get_query_by_id(query_id: str) -> AttackPathsQueryDefinition | None:
    """Look up a single query definition by its unique id; None when unknown."""
    return _QUERIES_BY_ID.get(query_id)
# API's Attack Paths query definitions
_QUERY_DEFINITIONS: dict[str, list[AttackPathsQueryDefinition]] = {
"aws": [
# Custom query for detecting internet-exposed EC2 instances with sensitive S3 access
AttackPathsQueryDefinition(
id="aws-internet-exposed-ec2-sensitive-s3-access",
name="Identify internet-exposed EC2 instances with sensitive S3 access",
description="Detect EC2 instances with SSH exposed to the internet that can assume higher-privileged roles to read tagged sensitive S3 buckets despite bucket-level public access blocks.",
provider="aws",
cypher="""
CALL apoc.create.vNode(['Internet'], {id: 'Internet', name: 'Internet'})
YIELD node AS internet
MATCH path_s3 = (aws:AWSAccount {id: $provider_uid})--(s3:S3Bucket)--(t:AWSTag)
WHERE toLower(t.key) = toLower($tag_key) AND toLower(t.value) = toLower($tag_value)
MATCH path_ec2 = (aws)--(ec2:EC2Instance)--(sg:EC2SecurityGroup)--(ipi:IpPermissionInbound)
WHERE ec2.exposed_internet = true
AND ipi.toport = 22
MATCH path_role = (r:AWSRole)--(pol:AWSPolicy)--(stmt:AWSPolicyStatement)
WHERE ANY(x IN stmt.resource WHERE x CONTAINS s3.name)
AND ANY(x IN stmt.action WHERE toLower(x) =~ 's3:(listbucket|getobject).*')
MATCH path_assume_role = (ec2)-[p:STS_ASSUMEROLE_ALLOW*1..9]-(r:AWSRole)
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {}, ec2)
YIELD rel AS can_access
UNWIND nodes(path_s3) + nodes(path_ec2) + nodes(path_role) + nodes(path_assume_role) as n
OPTIONAL MATCH (n)-[pfr]-(pf:ProwlerFinding)
WHERE pf.status = 'FAIL'
RETURN path_s3, path_ec2, path_role, path_assume_role, collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr, internet, can_access
""",
parameters=[
AttackPathsQueryParameterDefinition(
name="tag_key",
label="Tag key",
description="Tag key to filter the S3 bucket, e.g. DataClassification.",
placeholder="DataClassification",
),
AttackPathsQueryParameterDefinition(
name="tag_value",
label="Tag value",
description="Tag value to filter the S3 bucket, e.g. Sensitive.",
placeholder="Sensitive",
),
],
),
# Regular Cartography Attack Paths queries
AttackPathsQueryDefinition(
id="aws-rds-instances",
name="Identify provisioned RDS instances",
description="List the selected AWS account alongside the RDS instances it owns.",
provider="aws",
cypher="""
MATCH path = (aws:AWSAccount {id: $provider_uid})--(rds:RDSInstance)
UNWIND nodes(path) as n
OPTIONAL MATCH (n)-[pfr]-(pf:ProwlerFinding)
WHERE pf.status = 'FAIL'
RETURN path, collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr
""",
parameters=[],
),
AttackPathsQueryDefinition(
id="aws-rds-unencrypted-storage",
name="Identify RDS instances without storage encryption",
description="Find RDS instances with storage encryption disabled within the selected account.",
provider="aws",
cypher="""
MATCH path = (aws:AWSAccount {id: $provider_uid})--(rds:RDSInstance)
WHERE rds.storage_encrypted = false
UNWIND nodes(path) as n
OPTIONAL MATCH (n)-[pfr]-(pf:ProwlerFinding)
WHERE pf.status = 'FAIL'
RETURN path, collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr
""",
parameters=[],
),
AttackPathsQueryDefinition(
id="aws-s3-anonymous-access-buckets",
name="Identify S3 buckets with anonymous access",
description="Find S3 buckets that allow anonymous access within the selected account.",
provider="aws",
cypher="""
MATCH path = (aws:AWSAccount {id: $provider_uid})--(s3:S3Bucket)
WHERE s3.anonymous_access = true
UNWIND nodes(path) as n
OPTIONAL MATCH (n)-[pfr]-(pf:ProwlerFinding)
WHERE pf.status = 'FAIL'
RETURN path, collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr
""",
parameters=[],
),
AttackPathsQueryDefinition(
id="aws-iam-statements-allow-all-actions",
name="Identify IAM statements that allow all actions",
description="Find IAM policy statements that allow all actions via '*' within the selected account.",
provider="aws",
cypher="""
MATCH path = (aws:AWSAccount {id: $provider_uid})--(principal:AWSPrincipal)--(pol:AWSPolicy)--(stmt:AWSPolicyStatement)
WHERE stmt.effect = 'Allow'
AND any(x IN stmt.action WHERE x = '*')
UNWIND nodes(path) as n
OPTIONAL MATCH (n)-[pfr]-(pf:ProwlerFinding)
WHERE pf.status = 'FAIL'
RETURN path, collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr
""",
parameters=[],
),
AttackPathsQueryDefinition(
id="aws-iam-statements-allow-delete-policy",
name="Identify IAM statements that allow iam:DeletePolicy",
description="Find IAM policy statements that allow the iam:DeletePolicy action within the selected account.",
provider="aws",
cypher="""
MATCH path = (aws:AWSAccount {id: $provider_uid})--(principal:AWSPrincipal)--(pol:AWSPolicy)--(stmt:AWSPolicyStatement)
WHERE stmt.effect = 'Allow'
AND any(x IN stmt.action WHERE x = "iam:DeletePolicy")
UNWIND nodes(path) as n
OPTIONAL MATCH (n)-[pfr]-(pf:ProwlerFinding)
WHERE pf.status = 'FAIL'
RETURN path, collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr
""",
parameters=[],
),
AttackPathsQueryDefinition(
id="aws-iam-statements-allow-create-actions",
name="Identify IAM statements that allow create actions",
description="Find IAM policy statements that allow actions containing 'create' within the selected account.",
provider="aws",
cypher="""
MATCH path = (aws:AWSAccount {id: $provider_uid})--(principal:AWSPrincipal)--(pol:AWSPolicy)--(stmt:AWSPolicyStatement)
WHERE stmt.effect = "Allow"
AND any(x IN stmt.action WHERE toLower(x) CONTAINS "create")
UNWIND nodes(path) as n
OPTIONAL MATCH (n)-[pfr]-(pf:ProwlerFinding)
WHERE pf.status = 'FAIL'
RETURN path, collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr
""",
parameters=[],
),
AttackPathsQueryDefinition(
id="aws-ec2-instances-internet-exposed",
name="Identify internet-exposed EC2 instances",
description="Find EC2 instances flagged as exposed to the internet within the selected account.",
provider="aws",
cypher="""
CALL apoc.create.vNode(['Internet'], {id: 'Internet', name: 'Internet'})
YIELD node AS internet
MATCH path = (aws:AWSAccount {id: $provider_uid})--(ec2:EC2Instance)
WHERE ec2.exposed_internet = true
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {}, ec2)
YIELD rel AS can_access
UNWIND nodes(path) as n
OPTIONAL MATCH (n)-[pfr]-(pf:ProwlerFinding)
WHERE pf.status = 'FAIL'
RETURN path, collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr, internet, can_access
""",
parameters=[],
),
AttackPathsQueryDefinition(
id="aws-security-groups-open-internet-facing",
name="Identify internet-facing resources with open security groups",
description="Find internet-facing resources associated with security groups that allow inbound access from '0.0.0.0/0'.",
provider="aws",
cypher="""
CALL apoc.create.vNode(['Internet'], {id: 'Internet', name: 'Internet'})
YIELD node AS internet
MATCH path_open = (aws:AWSAccount {id: $provider_uid})-[r0]-(open)
MATCH path_sg = (open)-[r1:MEMBER_OF_EC2_SECURITY_GROUP]-(sg:EC2SecurityGroup)
MATCH path_ip = (sg)-[r2:MEMBER_OF_EC2_SECURITY_GROUP]-(ipi:IpPermissionInbound)
MATCH path_ipi = (ipi)-[r3]-(ir:IpRange)
WHERE ir.range = "0.0.0.0/0"
OPTIONAL MATCH path_dns = (dns:AWSDNSRecord)-[:DNS_POINTS_TO]->(lb)
WHERE open.scheme = 'internet-facing'
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {}, open)
YIELD rel AS can_access
UNWIND nodes(path_open) + nodes(path_sg) + nodes(path_ip) + nodes(path_ipi) + nodes(path_dns) as n
OPTIONAL MATCH (n)-[pfr]-(pf:ProwlerFinding)
WHERE pf.status = 'FAIL'
RETURN path_open, path_sg, path_ip, path_ipi, path_dns, collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr, internet, can_access
""",
parameters=[],
),
AttackPathsQueryDefinition(
id="aws-classic-elb-internet-exposed",
name="Identify internet-exposed Classic Load Balancers",
description="Find Classic Load Balancers exposed to the internet along with their listeners.",
provider="aws",
cypher="""
CALL apoc.create.vNode(['Internet'], {id: 'Internet', name: 'Internet'})
YIELD node AS internet
MATCH path = (aws:AWSAccount {id: $provider_uid})--(elb:LoadBalancer)--(listener:ELBListener)
WHERE elb.exposed_internet = true
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {}, elb)
YIELD rel AS can_access
UNWIND nodes(path) as n
OPTIONAL MATCH (n)-[pfr]-(pf:ProwlerFinding)
WHERE pf.status = 'FAIL'
RETURN path, collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr, internet, can_access
""",
parameters=[],
),
AttackPathsQueryDefinition(
id="aws-elbv2-internet-exposed",
name="Identify internet-exposed ELBv2 load balancers",
description="Find ELBv2 load balancers exposed to the internet along with their listeners.",
provider="aws",
cypher="""
CALL apoc.create.vNode(['Internet'], {id: 'Internet', name: 'Internet'})
YIELD node AS internet
MATCH path = (aws:AWSAccount {id: $provider_uid})--(elbv2:LoadBalancerV2)--(listener:ELBV2Listener)
WHERE elbv2.exposed_internet = true
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {}, elbv2)
YIELD rel AS can_access
UNWIND nodes(path) as n
OPTIONAL MATCH (n)-[pfr]-(pf:ProwlerFinding)
WHERE pf.status = 'FAIL'
RETURN path, collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr, internet, can_access
""",
parameters=[],
),
AttackPathsQueryDefinition(
id="aws-public-ip-resource-lookup",
name="Identify resources by public IP address",
description="Given a public IP address, find the related AWS resource and its adjacent node within the selected account.",
provider="aws",
cypher="""
CALL apoc.create.vNode(['Internet'], {id: 'Internet', name: 'Internet'})
YIELD node AS internet
CALL () {
MATCH path = (aws:AWSAccount {id: $provider_uid})-[r]-(x:EC2PrivateIp)-[q]-(y)
WHERE x.public_ip = $ip
RETURN path, x
UNION MATCH path = (aws:AWSAccount {id: $provider_uid})-[r]-(x:EC2Instance)-[q]-(y)
WHERE x.publicipaddress = $ip
RETURN path, x
UNION MATCH path = (aws:AWSAccount {id: $provider_uid})-[r]-(x:NetworkInterface)-[q]-(y)
WHERE x.public_ip = $ip
RETURN path, x
UNION MATCH path = (aws:AWSAccount {id: $provider_uid})-[r]-(x:ElasticIPAddress)-[q]-(y)
WHERE x.public_ip = $ip
RETURN path, x
}
WITH path, x, internet
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {}, x)
YIELD rel AS can_access
UNWIND nodes(path) as n
OPTIONAL MATCH (n)-[pfr]-(pf:ProwlerFinding)
WHERE pf.status = 'FAIL'
RETURN path, collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr, internet, can_access
""",
parameters=[
AttackPathsQueryParameterDefinition(
name="ip",
label="IP address",
description="Public IP address, e.g. 192.0.2.0.",
placeholder="192.0.2.0",
),
],
),
# Privilege Escalation Queries (based on pathfinding.cloud research): https://github.com/DataDog/pathfinding.cloud
AttackPathsQueryDefinition(
id="aws-iam-privesc-passrole-ec2",
name="Privilege Escalation: iam:PassRole + ec2:RunInstances",
description="Detect principals who can launch EC2 instances with privileged IAM roles attached. This allows gaining the permissions of the passed role by accessing the EC2 instance metadata service. This is a new-passrole escalation path (pathfinding.cloud: ec2-001).",
provider="aws",
cypher="""
// Create a single shared virtual EC2 instance node
CALL apoc.create.vNode(['EC2Instance'], {
id: 'potential-ec2-passrole',
name: 'New EC2 Instance',
description: 'Attacker-controlled EC2 with privileged role'
})
YIELD node AS ec2_node
// Create a single shared virtual escalation outcome node (styled like a finding)
CALL apoc.create.vNode(['PrivilegeEscalation'], {
id: 'effective-administrator-passrole-ec2',
check_title: 'Privilege Escalation',
name: 'Effective Administrator',
status: 'FAIL',
severity: 'critical'
})
YIELD node AS escalation_outcome
WITH ec2_node, escalation_outcome
// Find principals in the account
MATCH path_principal = (aws:AWSAccount {id: $provider_uid})--(principal:AWSPrincipal)
// Find statements granting iam:PassRole
MATCH path_passrole = (principal)--(passrole_policy:AWSPolicy)--(stmt_passrole:AWSPolicyStatement)
WHERE stmt_passrole.effect = 'Allow'
AND any(action IN stmt_passrole.action WHERE
toLower(action) = 'iam:passrole'
OR toLower(action) = 'iam:*'
OR action = '*'
)
// Find statements granting ec2:RunInstances
MATCH path_ec2 = (principal)--(ec2_policy:AWSPolicy)--(stmt_ec2:AWSPolicyStatement)
WHERE stmt_ec2.effect = 'Allow'
AND any(action IN stmt_ec2.action WHERE
toLower(action) = 'ec2:runinstances'
OR toLower(action) = 'ec2:*'
OR action = '*'
)
// Find roles that trust EC2 service (can be passed to EC2)
MATCH path_target = (aws)--(target_role:AWSRole)
WHERE target_role.arn CONTAINS $provider_uid
// Check if principal can pass this role
AND any(resource IN stmt_passrole.resource WHERE
resource = '*'
OR target_role.arn CONTAINS resource
OR resource CONTAINS target_role.name
)
// Check if target role has elevated permissions (optional, for severity assessment)
OPTIONAL MATCH (target_role)--(role_policy:AWSPolicy)--(role_stmt:AWSPolicyStatement)
WHERE role_stmt.effect = 'Allow'
AND (
any(action IN role_stmt.action WHERE action = '*')
OR any(action IN role_stmt.action WHERE toLower(action) = 'iam:*')
)
CALL apoc.create.vRelationship(principal, 'CAN_LAUNCH', {
via: 'ec2:RunInstances + iam:PassRole'
}, ec2_node)
YIELD rel AS launch_rel
CALL apoc.create.vRelationship(ec2_node, 'ASSUMES_ROLE', {}, target_role)
YIELD rel AS assumes_rel
CALL apoc.create.vRelationship(target_role, 'GRANTS_ACCESS', {
reference: 'https://pathfinding.cloud/paths/ec2-001'
}, escalation_outcome)
YIELD rel AS grants_rel
UNWIND nodes(path_principal) + nodes(path_passrole) + nodes(path_ec2) + nodes(path_target) as n
OPTIONAL MATCH (n)-[pfr]-(pf:ProwlerFinding)
WHERE pf.status = 'FAIL'
RETURN path_principal, path_passrole, path_ec2, path_target,
ec2_node, escalation_outcome, launch_rel, assumes_rel, grants_rel,
collect(DISTINCT pf) as dpf, collect(DISTINCT pfr) as dpfr
""",
parameters=[],
),
AttackPathsQueryDefinition(
id="aws-glue-privesc-passrole-dev-endpoint",
name="Privilege Escalation: Glue Dev Endpoint with PassRole",
description="Detect principals that can escalate privileges by passing a role to a Glue development endpoint. The attacker creates a dev endpoint with an arbitrary role attached, then accesses those credentials through the endpoint.",
provider="aws",
cypher="""
CALL apoc.create.vNode(['PrivilegeEscalation'], {
id: 'effective-administrator-glue',
check_title: 'Privilege Escalation',
name: 'Effective Administrator (Glue)',
status: 'FAIL',
severity: 'critical'
})
YIELD node AS escalation_outcome
WITH escalation_outcome
// Find principals in the account
MATCH path_principal = (aws:AWSAccount {id: $provider_uid})--(principal:AWSPrincipal)
// Principal can assume roles (up to 2 hops)
OPTIONAL MATCH path_assume = (principal)-[:STS_ASSUMEROLE_ALLOW*0..2]->(acting_as:AWSRole)
WITH escalation_outcome, principal, path_principal, path_assume,
CASE WHEN path_assume IS NULL THEN principal ELSE acting_as END AS effective_principal
// Find iam:PassRole permission
MATCH path_passrole = (effective_principal)--(passrole_policy:AWSPolicy)--(passrole_stmt:AWSPolicyStatement)
WHERE passrole_stmt.effect = 'Allow'
AND any(action IN passrole_stmt.action WHERE toLower(action) = 'iam:passrole' OR action = '*')
// Find Glue CreateDevEndpoint permission
MATCH (effective_principal)--(glue_policy:AWSPolicy)--(glue_stmt:AWSPolicyStatement)
WHERE glue_stmt.effect = 'Allow'
AND any(action IN glue_stmt.action WHERE toLower(action) = 'glue:createdevendpoint' OR action = '*' OR toLower(action) = 'glue:*')
// Find target role with elevated permissions
MATCH (aws)--(target_role:AWSRole)--(target_policy:AWSPolicy)--(target_stmt:AWSPolicyStatement)
WHERE target_stmt.effect = 'Allow'
AND (
any(action IN target_stmt.action WHERE action = '*')
OR any(action IN target_stmt.action WHERE toLower(action) = 'iam:*')
)
// Deduplicate before creating virtual nodes
WITH DISTINCT escalation_outcome, aws, principal, effective_principal, target_role
// Create virtual Glue endpoint node (one per unique principal->target pair)
CALL apoc.create.vNode(['GlueDevEndpoint'], {
name: 'New Dev Endpoint',
description: 'Glue endpoint with target role attached',
id: effective_principal.arn + '->' + target_role.arn
})
YIELD node AS glue_endpoint
CALL apoc.create.vRelationship(effective_principal, 'CREATES_ENDPOINT', {
permissions: ['iam:PassRole', 'glue:CreateDevEndpoint'],
technique: 'new-passrole'
}, glue_endpoint)
YIELD rel AS create_rel
CALL apoc.create.vRelationship(glue_endpoint, 'RUNS_AS', {}, target_role)
YIELD rel AS runs_rel
CALL apoc.create.vRelationship(target_role, 'GRANTS_ACCESS', {
reference: 'https://pathfinding.cloud/paths/glue-001'
}, escalation_outcome)
YIELD rel AS grants_rel
// Re-match paths for visualization
MATCH path_principal = (aws)--(principal)
MATCH path_target = (aws)--(target_role)
RETURN path_principal, path_target,
glue_endpoint, escalation_outcome, create_rel, runs_rel, grants_rel
""",
parameters=[],
),
],
}
# Flat id -> definition index derived from the per-provider catalog above.
_QUERIES_BY_ID: dict[str, AttackPathsQueryDefinition] = {
    query.id: query
    for provider_queries in _QUERY_DEFINITIONS.values()
    for query in provider_queries
}
@@ -0,0 +1,87 @@
import logging
from collections.abc import Callable
from typing import Any
import neo4j
import neo4j.exceptions
logger = logging.getLogger(__name__)


class RetryableSession:
    """
    Wrapper around `neo4j.Session` that retries `neo4j.exceptions.ServiceUnavailable` errors.

    After each failed attempt the shared driver is torn down and a fresh
    session is created before retrying.
    """

    def __init__(
        self,
        session_factory: Callable[[], neo4j.Session],
        close_driver: Callable[[], None],  # Just to avoid circular imports
        max_retries: int,
    ) -> None:
        self._session_factory = session_factory
        self._close_driver = close_driver
        self._max_retries = max(0, max_retries)  # Negative values mean "no retries"
        self._session = self._session_factory()

    def close(self) -> None:
        """Close the wrapped session; safe to call more than once."""
        if self._session is not None:
            self._session.close()
            self._session = None

    def __enter__(self) -> "RetryableSession":
        return self

    def __exit__(self, exc_type: Any, exc: Any, exc_tb: Any) -> None:
        self.close()

    # Session entry points that go through the retry loop.

    def run(self, *args: Any, **kwargs: Any) -> Any:
        return self._call_with_retry("run", *args, **kwargs)

    def write_transaction(self, *args: Any, **kwargs: Any) -> Any:
        return self._call_with_retry("write_transaction", *args, **kwargs)

    def read_transaction(self, *args: Any, **kwargs: Any) -> Any:
        return self._call_with_retry("read_transaction", *args, **kwargs)

    def execute_write(self, *args: Any, **kwargs: Any) -> Any:
        return self._call_with_retry("execute_write", *args, **kwargs)

    def execute_read(self, *args: Any, **kwargs: Any) -> Any:
        return self._call_with_retry("execute_read", *args, **kwargs)

    def __getattr__(self, item: str) -> Any:
        # Anything not wrapped above is delegated to the raw session.
        return getattr(self._session, item)

    def _call_with_retry(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
        last_exc: neo4j.exceptions.ServiceUnavailable | None = None
        # One initial attempt plus `_max_retries` retries.
        for attempt in range(1, self._max_retries + 2):
            try:
                return getattr(self._session, method_name)(*args, **kwargs)
            except (
                neo4j.exceptions.ServiceUnavailable
            ) as exc:  # pragma: no cover - depends on infra
                last_exc = exc
                if attempt > self._max_retries:
                    raise
                logger.warning(
                    f"Neo4j session {method_name} failed with ServiceUnavailable ({attempt}/{self._max_retries} attempts). Retrying..."
                )
                self._refresh_session()
        raise last_exc if last_exc else RuntimeError("Unexpected retry loop exit")

    def _refresh_session(self) -> None:
        # Drop both the session and the shared driver, then reconnect.
        if self._session is not None:
            self._session.close()
        self._close_driver()
        self._session = self._session_factory()
@@ -0,0 +1,143 @@
import logging
from typing import Any
from rest_framework.exceptions import APIException, ValidationError
from api.attack_paths import database as graph_database, AttackPathsQueryDefinition
from api.models import AttackPathsScan
from config.custom_logging import BackendLogger
logger = logging.getLogger(BackendLogger.API)
def normalize_run_payload(raw_data):
    """Flatten a JSON:API run payload into a plain ``{"id", "parameters"}`` dict.

    Accepts either an already-flat dict or a JSON:API envelope of the form
    ``{"data": {"id": ..., "attributes": {"id": ..., "parameters": ...}}}``.
    Non-dict input is returned unchanged so the serializer can reject it with
    a proper validation error. A malformed (non-dict) ``attributes`` section is
    tolerated and treated as empty instead of raising ``AttributeError``.
    """
    if not isinstance(raw_data, dict):  # Let the serializer handle this
        return raw_data
    data_section = raw_data.get("data")
    if isinstance(data_section, dict):
        attributes = data_section.get("attributes")
        if not isinstance(attributes, dict):
            # Guard against malformed payloads (e.g. attributes as a string).
            attributes = {}
        payload = {
            # Prefer the attribute-level id, falling back to the resource id.
            "id": attributes.get("id", data_section.get("id")),
            "parameters": attributes.get("parameters"),
        }
        # Remove `None` parameters to allow defaults downstream
        if payload.get("parameters") is None:
            payload.pop("parameters")
        return payload
    return raw_data
def prepare_query_parameters(
    definition: AttackPathsQueryDefinition,
    provided_parameters: dict[str, Any],
    provider_uid: str,
) -> dict[str, Any]:
    """Validate and cast user-supplied parameters for an Attack Paths query.

    Ensures the provided parameter names exactly match the definition's
    declared parameters, casts each value with the parameter's ``cast``
    callable, and injects the mandatory ``provider_uid`` binding.

    Raises:
        ValidationError: on unknown/missing parameter names or values the
            definition's cast rejects.
    """
    # Normalize to a mutable dict; tolerate None input.
    parameters = dict(provided_parameters or {})
    expected_names = {parameter.name for parameter in definition.parameters}
    provided_names = set(parameters.keys())
    unexpected = provided_names - expected_names
    if unexpected:
        raise ValidationError(
            {"parameters": f"Unknown parameter(s): {', '.join(sorted(unexpected))}"}
        )
    missing = expected_names - provided_names
    if missing:
        raise ValidationError(
            {
                "parameters": f"Missing required parameter(s): {', '.join(sorted(missing))}"
            }
        )
    clean_parameters = {
        "provider_uid": str(provider_uid),
    }
    for definition_parameter in definition.parameters:
        # Read from the normalized copy, not the raw caller-supplied mapping.
        raw_value = parameters[definition_parameter.name]
        try:
            casted_value = definition_parameter.cast(raw_value)
        except (ValueError, TypeError) as exc:
            # Chain the original cast failure for easier debugging.
            raise ValidationError(
                {
                    "parameters": (
                        f"Invalid value for parameter `{definition_parameter.name}`: {str(exc)}"
                    )
                }
            ) from exc
        clean_parameters[definition_parameter.name] = casted_value
    return clean_parameters
def execute_attack_paths_query(
    attack_paths_scan: AttackPathsScan,
    definition: AttackPathsQueryDefinition,
    parameters: dict[str, Any],
) -> dict[str, Any]:
    """Run the definition's Cypher against the scan's graph database.

    Returns a ``{"nodes": [...], "relationships": [...]}`` payload built from
    the query's result graph.

    Raises:
        APIException: when the graph database layer reports a query failure.
    """
    try:
        with graph_database.get_session(attack_paths_scan.graph_database) as session:
            result = session.run(definition.cypher, parameters)
            return _serialize_graph(result.graph())
    except graph_database.GraphDatabaseQueryException as exc:
        logger.error(f"Query failed for Attack Paths query `{definition.id}`: {exc}")
        # Chain the low-level driver error so it is preserved in tracebacks.
        raise APIException(
            "Attack Paths query execution failed due to a database error"
        ) from exc
def _serialize_graph(graph):
    """Convert a neo4j result graph into plain node/relationship dicts."""
    nodes = [
        {
            "id": node.element_id,
            "labels": list(node.labels),
            "properties": _serialize_properties(node._properties),
        }
        for node in graph.nodes
    ]
    relationships = [
        {
            "id": relationship.element_id,
            "label": relationship.type,
            "source": relationship.start_node.element_id,
            "target": relationship.end_node.element_id,
            "properties": _serialize_properties(relationship._properties),
        }
        for relationship in graph.relationships
    ]
    return {
        "nodes": nodes,
        "relationships": relationships,
    }
def _serialize_properties(properties: dict[str, Any]) -> dict[str, Any]:
"""Convert Neo4j property values into JSON-serializable primitives."""
def _serialize_value(value: Any) -> Any:
# Neo4j temporal and spatial values expose `to_native` returning Python primitives
if hasattr(value, "to_native") and callable(value.to_native):
return _serialize_value(value.to_native())
if isinstance(value, (list, tuple)):
return [_serialize_value(item) for item in value]
if isinstance(value, dict):
return {key: _serialize_value(val) for key, val in value.items()}
return value
return {key: _serialize_value(val) for key, val in properties.items()}
+18
View File
@@ -29,6 +29,7 @@ from api.models import (
Finding,
Integration,
Invitation,
AttackPathsScan,
LighthouseProviderConfiguration,
LighthouseProviderModels,
Membership,
@@ -396,6 +397,23 @@ class ScanFilter(ProviderRelationshipFilterSet):
}
class AttackPathsScanFilter(ProviderRelationshipFilterSet):
    """Filter set for AttackPathsScan list endpoints."""

    # Date-only filters: match on the calendar date, ignoring time-of-day.
    inserted_at = DateFilter(field_name="inserted_at", lookup_expr="date")
    completed_at = DateFilter(field_name="completed_at", lookup_expr="date")
    started_at = DateFilter(field_name="started_at", lookup_expr="date")
    # Lifecycle state, as a single value or an `in` list.
    state = ChoiceFilter(choices=StateChoices.choices)
    state__in = ChoiceInFilter(
        field_name="state", choices=StateChoices.choices, lookup_expr="in"
    )

    class Meta:
        model = AttackPathsScan
        fields = {
            "provider": ["exact", "in"],
            "scan": ["exact", "in"],
        }
class TaskFilter(FilterSet):
name = CharFilter(field_name="task_runner_task__task_name", lookup_expr="exact")
name__icontains = CharFilter(
@@ -0,0 +1,41 @@
[
{
"model": "api.attackpathsscan",
"pk": "a7f0f6de-6f8e-4b3a-8cbe-3f6dd9012345",
"fields": {
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
"provider": "b85601a8-4b45-4194-8135-03fb980ef428",
"scan": "01920573-aa9c-73c9-bcda-f2e35c9b19d2",
"state": "completed",
"progress": 100,
"update_tag": 1693586667,
"graph_database": "db-a7f0f6de-6f8e-4b3a-8cbe-3f6dd9012345",
"is_graph_database_deleted": false,
"task": null,
"inserted_at": "2024-09-01T17:24:37Z",
"updated_at": "2024-09-01T17:44:37Z",
"started_at": "2024-09-01T17:34:37Z",
"completed_at": "2024-09-01T17:44:37Z",
"duration": 269,
"ingestion_exceptions": {}
}
},
{
"model": "api.attackpathsscan",
"pk": "4a2fb2af-8a60-4d7d-9cae-4ca65e098765",
"fields": {
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
"provider": "15fce1fa-ecaa-433f-a9dc-62553f3a2555",
"scan": "01929f3b-ed2e-7623-ad63-7c37cd37828f",
"state": "executing",
"progress": 48,
"update_tag": 1697625000,
"graph_database": "db-4a2fb2af-8a60-4d7d-9cae-4ca65e098765",
"is_graph_database_deleted": false,
"task": null,
"inserted_at": "2024-10-18T10:55:57Z",
"updated_at": "2024-10-18T10:56:15Z",
"started_at": "2024-10-18T10:56:05Z"
}
}
]
@@ -0,0 +1,154 @@
# Generated by Django 5.1.13 on 2025-11-06 16:20
import django.db.models.deletion
from django.db import migrations, models
from uuid6 import uuid7

import api.db_utils
import api.rls
class Migration(migrations.Migration):
    """Create the `attack_paths_scans` table with RLS and covering indexes."""

    dependencies = [
        ("api", "0069_resource_resource_group"),
    ]

    operations = [
        migrations.CreateModel(
            name="AttackPathsScan",
            fields=[
                (
                    "id",
                    models.UUIDField(
                        default=uuid7,
                        editable=False,
                        primary_key=True,
                        serialize=False,
                    ),
                ),
                ("inserted_at", models.DateTimeField(auto_now_add=True)),
                ("updated_at", models.DateTimeField(auto_now=True)),
                # NOTE(review): relies on `api.db_utils` being imported at
                # module level — confirm the migration's import block includes it.
                (
                    "state",
                    api.db_utils.StateEnumField(
                        choices=[
                            ("available", "Available"),
                            ("scheduled", "Scheduled"),
                            ("executing", "Executing"),
                            ("completed", "Completed"),
                            ("failed", "Failed"),
                            ("cancelled", "Cancelled"),
                        ],
                        default="available",
                    ),
                ),
                ("progress", models.IntegerField(default=0)),
                ("started_at", models.DateTimeField(blank=True, null=True)),
                ("completed_at", models.DateTimeField(blank=True, null=True)),
                (
                    "duration",
                    models.IntegerField(
                        blank=True, help_text="Duration in seconds", null=True
                    ),
                ),
                (
                    "update_tag",
                    models.BigIntegerField(
                        blank=True,
                        help_text="Cartography update tag (epoch)",
                        null=True,
                    ),
                ),
                (
                    "graph_database",
                    models.CharField(blank=True, max_length=63, null=True),
                ),
                (
                    "is_graph_database_deleted",
                    models.BooleanField(default=False),
                ),
                (
                    "ingestion_exceptions",
                    models.JSONField(blank=True, default=dict, null=True),
                ),
                (
                    "provider",
                    models.ForeignKey(
                        on_delete=django.db.models.deletion.CASCADE,
                        related_name="attack_paths_scans",
                        related_query_name="attack_paths_scan",
                        to="api.provider",
                    ),
                ),
                (
                    "scan",
                    models.ForeignKey(
                        blank=True,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        related_name="attack_paths_scans",
                        related_query_name="attack_paths_scan",
                        to="api.scan",
                    ),
                ),
                (
                    "task",
                    models.ForeignKey(
                        blank=True,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        related_name="attack_paths_scans",
                        related_query_name="attack_paths_scan",
                        to="api.task",
                    ),
                ),
                (
                    "tenant",
                    models.ForeignKey(
                        on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
                    ),
                ),
            ],
            options={
                "db_table": "attack_paths_scans",
                "abstract": False,
                # Index names mirror AttackPathsScan.Meta.indexes in api/models.
                "indexes": [
                    models.Index(
                        fields=["tenant_id", "provider_id", "-inserted_at"],
                        name="aps_prov_ins_desc_idx",
                    ),
                    models.Index(
                        fields=["tenant_id", "state", "-inserted_at"],
                        name="aps_state_ins_desc_idx",
                    ),
                    models.Index(
                        fields=["tenant_id", "scan_id"],
                        name="aps_scan_lookup_idx",
                    ),
                    models.Index(
                        fields=["tenant_id", "provider_id"],
                        name="aps_active_graph_idx",
                        include=["graph_database", "id"],
                        condition=models.Q(("is_graph_database_deleted", False)),
                    ),
                    models.Index(
                        fields=["tenant_id", "provider_id", "-completed_at"],
                        name="aps_completed_graph_idx",
                        include=["graph_database", "id"],
                        condition=models.Q(
                            ("state", "completed"),
                            ("is_graph_database_deleted", False),
                        ),
                    ),
                ],
            },
        ),
        migrations.AddConstraint(
            model_name="attackpathsscan",
            constraint=api.rls.RowLevelSecurityConstraint(
                "tenant_id",
                name="rls_on_attackpathsscan",
                statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
            ),
        ),
    ]
+95
View File
@@ -626,6 +626,101 @@ class Scan(RowLevelSecurityProtectedModel):
resource_name = "scans"
class AttackPathsScan(RowLevelSecurityProtectedModel):
    """A graph-ingestion (Cartography) scan run for a provider.

    Tracks lifecycle state and progress, plus the Neo4j database name holding
    the graph for this run.
    """

    # NOTE(review): presumably scopes default queries to active providers —
    # confirm ActiveProviderManager semantics.
    objects = ActiveProviderManager()
    # Unfiltered manager escape hatch.
    all_objects = models.Manager()
    id = models.UUIDField(primary_key=True, default=uuid7, editable=False)
    inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
    updated_at = models.DateTimeField(auto_now=True, editable=False)
    state = StateEnumField(choices=StateChoices.choices, default=StateChoices.AVAILABLE)
    progress = models.IntegerField(default=0)
    # Timing
    started_at = models.DateTimeField(null=True, blank=True)
    completed_at = models.DateTimeField(null=True, blank=True)
    duration = models.IntegerField(
        null=True, blank=True, help_text="Duration in seconds"
    )
    # Relationship to the provider and optional prowler Scan and celery Task
    provider = models.ForeignKey(
        "Provider",
        on_delete=models.CASCADE,
        related_name="attack_paths_scans",
        related_query_name="attack_paths_scan",
    )
    scan = models.ForeignKey(
        "Scan",
        on_delete=models.SET_NULL,
        null=True,
        blank=True,
        related_name="attack_paths_scans",
        related_query_name="attack_paths_scan",
    )
    task = models.ForeignKey(
        "Task",
        on_delete=models.SET_NULL,
        null=True,
        blank=True,
        related_name="attack_paths_scans",
        related_query_name="attack_paths_scan",
    )
    # Cartography specific metadata
    update_tag = models.BigIntegerField(
        null=True, blank=True, help_text="Cartography update tag (epoch)"
    )
    # Name of the per-scan Neo4j database (max 63 chars).
    graph_database = models.CharField(max_length=63, null=True, blank=True)
    is_graph_database_deleted = models.BooleanField(default=False)
    ingestion_exceptions = models.JSONField(default=dict, null=True, blank=True)

    class Meta(RowLevelSecurityProtectedModel.Meta):
        db_table = "attack_paths_scans"
        constraints = [
            RowLevelSecurityConstraint(
                field="tenant_id",
                name="rls_on_%(class)s",
                statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
            ),
        ]
        # Index names must stay in sync with the migration that creates them.
        indexes = [
            models.Index(
                fields=["tenant_id", "provider_id", "-inserted_at"],
                name="aps_prov_ins_desc_idx",
            ),
            models.Index(
                fields=["tenant_id", "state", "-inserted_at"],
                name="aps_state_ins_desc_idx",
            ),
            models.Index(
                fields=["tenant_id", "scan_id"],
                name="aps_scan_lookup_idx",
            ),
            models.Index(
                fields=["tenant_id", "provider_id"],
                name="aps_active_graph_idx",
                include=["graph_database", "id"],
                condition=Q(is_graph_database_deleted=False),
            ),
            models.Index(
                fields=["tenant_id", "provider_id", "-completed_at"],
                name="aps_completed_graph_idx",
                include=["graph_database", "id"],
                condition=Q(
                    state=StateChoices.COMPLETED,
                    is_graph_database_deleted=False,
                ),
            ),
        ]

    class JSONAPIMeta:
        resource_name = "attack-paths-scans"
class ResourceTag(RowLevelSecurityProtectedModel):
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,172 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from rest_framework.exceptions import APIException, ValidationError
from api.attack_paths import database as graph_database
from api.attack_paths import views_helpers
def test_normalize_run_payload_extracts_attributes_section():
    """The JSON:API envelope collapses into a flat id/parameters dict."""
    envelope = {
        "data": {
            "id": "ignored",
            "attributes": {
                "id": "aws-rds",
                "parameters": {"ip": "192.0.2.0"},
            },
        }
    }
    normalized = views_helpers.normalize_run_payload(envelope)
    assert normalized == {"id": "aws-rds", "parameters": {"ip": "192.0.2.0"}}
def test_normalize_run_payload_passthrough_for_non_dict():
    """Non-dict payloads come back unchanged (same object identity)."""
    raw_value = "not-a-dict"
    assert views_helpers.normalize_run_payload(raw_value) is raw_value
def test_prepare_query_parameters_includes_provider_and_casts(
    attack_paths_query_definition_factory,
):
    """Values are cast per the definition and `provider_uid` is injected."""
    definition = attack_paths_query_definition_factory(cast_type=int)
    prepared = views_helpers.prepare_query_parameters(
        definition,
        {"limit": "5"},
        provider_uid="123456789012",
    )
    assert prepared["provider_uid"] == "123456789012"
    assert prepared["limit"] == 5
@pytest.mark.parametrize(
    "provided,expected_message",
    [
        ({}, "Missing required parameter"),
        ({"limit": 10, "extra": True}, "Unknown parameter"),
    ],
)
def test_prepare_query_parameters_validates_names(
    attack_paths_query_definition_factory, provided, expected_message
):
    """Unknown or missing parameter names raise a ValidationError."""
    definition = attack_paths_query_definition_factory()
    with pytest.raises(ValidationError) as exc:
        views_helpers.prepare_query_parameters(definition, provided, provider_uid="1")
    assert expected_message in str(exc.value)
def test_prepare_query_parameters_validates_cast(
    attack_paths_query_definition_factory,
):
    """A value the definition's cast rejects surfaces as a ValidationError."""
    definition = attack_paths_query_definition_factory(cast_type=int)
    with pytest.raises(ValidationError) as exc_info:
        views_helpers.prepare_query_parameters(
            definition,
            {"limit": "not-an-int"},
            provider_uid="1",
        )
    assert "Invalid value" in str(exc_info.value)
def test_execute_attack_paths_query_serializes_graph(
    attack_paths_query_definition_factory, attack_paths_graph_stub_classes
):
    """The query result graph is serialized into plain node/relationship dicts."""
    definition = attack_paths_query_definition_factory(
        id="aws-rds",
        name="RDS",
        description="",
        cypher="MATCH (n) RETURN n",
        parameters=[],
    )
    parameters = {"provider_uid": "123"}
    attack_paths_scan = SimpleNamespace(graph_database="tenant-db")
    # Node with nested properties exercising `to_native` unwrapping.
    node = attack_paths_graph_stub_classes.Node(
        element_id="node-1",
        labels=["AWSAccount"],
        properties={
            "name": "account",
            "complex": {
                "items": [
                    attack_paths_graph_stub_classes.NativeValue("value"),
                    {"nested": 1},
                ]
            },
        },
    )
    relationship = attack_paths_graph_stub_classes.Relationship(
        element_id="rel-1",
        rel_type="OWNS",
        start_node=node,
        end_node=attack_paths_graph_stub_classes.Node("node-2", ["RDSInstance"], {}),
        properties={"weight": 1},
    )
    graph = SimpleNamespace(nodes=[node], relationships=[relationship])
    run_result = MagicMock()
    run_result.graph.return_value = graph
    session = MagicMock()
    session.run.return_value = run_result
    # Stub the session context manager so no real database is contacted.
    session_ctx = MagicMock()
    session_ctx.__enter__.return_value = session
    session_ctx.__exit__.return_value = False
    with patch(
        "api.attack_paths.views_helpers.graph_database.get_session",
        return_value=session_ctx,
    ) as mock_get_session:
        result = views_helpers.execute_attack_paths_query(
            attack_paths_scan, definition, parameters
        )
    mock_get_session.assert_called_once_with("tenant-db")
    session.run.assert_called_once_with(definition.cypher, parameters)
    assert result["nodes"][0]["id"] == "node-1"
    assert result["nodes"][0]["properties"]["complex"]["items"][0] == "value"
    assert result["relationships"][0]["label"] == "OWNS"
def test_execute_attack_paths_query_wraps_graph_errors(
    attack_paths_query_definition_factory,
):
    """Graph-database query failures are logged and wrapped in APIException."""
    definition = attack_paths_query_definition_factory(
        id="aws-rds",
        name="RDS",
        description="",
        cypher="MATCH (n) RETURN n",
        parameters=[],
    )
    attack_paths_scan = SimpleNamespace(graph_database="tenant-db")
    parameters = {"provider_uid": "123"}

    class ExplodingContext:
        # Raises on entry to simulate a failing graph-database session.
        def __enter__(self):
            raise graph_database.GraphDatabaseQueryException("boom")

        def __exit__(self, exc_type, exc, tb):
            return False

    with (
        patch(
            "api.attack_paths.views_helpers.graph_database.get_session",
            return_value=ExplodingContext(),
        ),
        patch("api.attack_paths.views_helpers.logger") as mock_logger,
    ):
        with pytest.raises(APIException):
            views_helpers.execute_attack_paths_query(
                attack_paths_scan, definition, parameters
            )
        mock_logger.error.assert_called_once()
+418
View File
@@ -32,6 +32,10 @@ from django_celery_results.models import TaskResult
from rest_framework import status
from rest_framework.response import Response
from api.attack_paths import (
AttackPathsQueryDefinition,
AttackPathsQueryParameterDefinition,
)
from api.compliance import get_compliance_frameworks
from api.db_router import MainRouter
from api.models import (
@@ -3602,6 +3606,420 @@ class TestTaskViewSet:
assert response.status_code == status.HTTP_400_BAD_REQUEST
@pytest.mark.django_db
class TestAttackPathsScanViewSet:
    """Endpoint tests for /attack-paths-scans: list, retrieve, query catalog, and query run."""

    @staticmethod
    def _run_payload(query_id="aws-rds", parameters=None):
        """Build a JSON:API request body for the queries-run action."""
        return {
            "data": {
                "type": "attack-paths-query-run-requests",
                "attributes": {
                    "id": query_id,
                    "parameters": parameters or {},
                },
            }
        }

    def test_attack_paths_scans_list_returns_latest_entry_per_provider(
        self,
        authenticated_client,
        providers_fixture,
        scans_fixture,
        create_attack_paths_scan,
    ):
        """The list endpoint returns only the newest scan per provider."""
        provider = providers_fixture[0]
        other_provider = providers_fixture[1]
        older_scan = create_attack_paths_scan(
            provider,
            scan=scans_fixture[0],
            state=StateChoices.AVAILABLE,
            progress=10,
        )
        latest_scan = create_attack_paths_scan(
            provider,
            scan=scans_fixture[0],
            state=StateChoices.COMPLETED,
            progress=95,
        )
        other_provider_scan = create_attack_paths_scan(
            other_provider,
            scan=scans_fixture[2],
            state=StateChoices.FAILED,
            progress=50,
        )
        response = authenticated_client.get(reverse("attack-paths-scans-list"))
        assert response.status_code == status.HTTP_200_OK
        data = response.json()["data"]
        ids = {item["id"] for item in data}
        assert ids == {str(latest_scan.id), str(other_provider_scan.id)}
        assert str(older_scan.id) not in ids
        provider_entry = next(
            item
            for item in data
            if item["relationships"]["provider"]["data"]["id"] == str(provider.id)
        )
        # Flattened provider attributes are exposed alongside the scan.
        first_attributes = provider_entry["attributes"]
        assert first_attributes["provider_alias"] == provider.alias
        assert first_attributes["provider_type"] == provider.provider
        assert first_attributes["provider_uid"] == provider.uid

    def test_attack_paths_scans_list_respects_provider_group_visibility(
        self,
        authenticated_client_no_permissions_rbac,
        providers_fixture,
        create_attack_paths_scan,
    ):
        """RBAC: a user limited to one provider group sees only that provider's scans."""
        client = authenticated_client_no_permissions_rbac
        limited_user = client.user
        membership = Membership.objects.filter(user=limited_user).first()
        tenant = membership.tenant
        allowed_provider = providers_fixture[0]
        denied_provider = providers_fixture[1]
        allowed_scan = create_attack_paths_scan(allowed_provider)
        create_attack_paths_scan(denied_provider)
        # Grant the limited role visibility over the allowed provider only.
        provider_group = ProviderGroup.objects.create(
            name="limited-group",
            tenant_id=tenant.id,
        )
        ProviderGroupMembership.objects.create(
            tenant_id=tenant.id,
            provider_group=provider_group,
            provider=allowed_provider,
        )
        limited_role = limited_user.roles.first()
        RoleProviderGroupRelationship.objects.create(
            tenant_id=tenant.id,
            role=limited_role,
            provider_group=provider_group,
        )
        response = client.get(reverse("attack-paths-scans-list"))
        assert response.status_code == status.HTTP_200_OK
        data = response.json()["data"]
        assert len(data) == 1
        assert data[0]["id"] == str(allowed_scan.id)

    def test_attack_paths_scan_retrieve(
        self,
        authenticated_client,
        providers_fixture,
        scans_fixture,
        create_attack_paths_scan,
    ):
        """Detail endpoint returns the scan with its provider relationship."""
        provider = providers_fixture[0]
        attack_paths_scan = create_attack_paths_scan(
            provider,
            scan=scans_fixture[0],
            state=StateChoices.COMPLETED,
            progress=80,
        )
        response = authenticated_client.get(
            reverse("attack-paths-scans-detail", kwargs={"pk": attack_paths_scan.id})
        )
        assert response.status_code == status.HTTP_200_OK
        data = response.json()["data"]
        assert data["id"] == str(attack_paths_scan.id)
        assert data["relationships"]["provider"]["data"]["id"] == str(provider.id)
        assert data["attributes"]["state"] == StateChoices.COMPLETED

    def test_attack_paths_scan_retrieve_not_found_for_foreign_tenant(
        self, authenticated_client, create_attack_paths_scan
    ):
        """Tenant isolation: scans from another tenant are a 404, not a 403."""
        other_tenant = Tenant.objects.create(name="Foreign AttackPaths Tenant")
        foreign_provider = Provider.objects.create(
            provider="aws",
            uid="333333333333",
            alias="foreign",
            tenant_id=other_tenant.id,
        )
        foreign_scan = create_attack_paths_scan(foreign_provider)
        response = authenticated_client.get(
            reverse("attack-paths-scans-detail", kwargs={"pk": foreign_scan.id})
        )
        assert response.status_code == status.HTTP_404_NOT_FOUND

    def test_attack_paths_queries_returns_catalog(
        self,
        authenticated_client,
        providers_fixture,
        scans_fixture,
        create_attack_paths_scan,
    ):
        """The queries action lists the catalog for the scan's provider type."""
        provider = providers_fixture[0]
        attack_paths_scan = create_attack_paths_scan(
            provider,
            scan=scans_fixture[0],
        )
        definitions = [
            AttackPathsQueryDefinition(
                id="aws-rds",
                name="RDS inventory",
                description="List account RDS assets",
                provider=provider.provider,
                cypher="MATCH (n) RETURN n",
                parameters=[
                    AttackPathsQueryParameterDefinition(name="ip", label="IP address")
                ],
            )
        ]
        with patch(
            "api.v1.views.get_queries_for_provider", return_value=definitions
        ) as mock_get_queries:
            response = authenticated_client.get(
                reverse(
                    "attack-paths-scans-queries", kwargs={"pk": attack_paths_scan.id}
                )
            )
        assert response.status_code == status.HTTP_200_OK
        mock_get_queries.assert_called_once_with(provider.provider)
        payload = response.json()["data"]
        assert len(payload) == 1
        assert payload[0]["id"] == "aws-rds"
        assert payload[0]["attributes"]["name"] == "RDS inventory"
        assert payload[0]["attributes"]["parameters"][0]["name"] == "ip"

    def test_attack_paths_queries_returns_404_when_catalog_missing(
        self,
        authenticated_client,
        providers_fixture,
        scans_fixture,
        create_attack_paths_scan,
    ):
        """An empty catalog for the provider type yields a 404."""
        provider = providers_fixture[0]
        attack_paths_scan = create_attack_paths_scan(provider, scan=scans_fixture[0])
        with patch("api.v1.views.get_queries_for_provider", return_value=[]):
            response = authenticated_client.get(
                reverse(
                    "attack-paths-scans-queries", kwargs={"pk": attack_paths_scan.id}
                )
            )
        assert response.status_code == status.HTTP_404_NOT_FOUND
        assert "No queries found" in str(response.json())

    def test_run_attack_paths_query_returns_graph(
        self,
        authenticated_client,
        providers_fixture,
        scans_fixture,
        create_attack_paths_scan,
    ):
        """Happy path: run a query and get back serialized nodes/relationships."""
        provider = providers_fixture[0]
        attack_paths_scan = create_attack_paths_scan(
            provider,
            scan=scans_fixture[0],
            graph_database="tenant-db",
        )
        query_definition = AttackPathsQueryDefinition(
            id="aws-rds",
            name="RDS inventory",
            description="List account RDS assets",
            provider=provider.provider,
            cypher="MATCH (n) RETURN n",
            parameters=[],
        )
        prepared_parameters = {"provider_uid": provider.uid}
        graph_payload = {
            "nodes": [
                {
                    "id": "node-1",
                    "labels": ["AWSAccount"],
                    "properties": {"name": "root"},
                }
            ],
            "relationships": [
                {
                    "id": "rel-1",
                    "label": "OWNS",
                    "source": "node-1",
                    "target": "node-2",
                    "properties": {},
                }
            ],
        }
        # Patch the helper pipeline so no graph database is required.
        with (
            patch(
                "api.v1.views.get_query_by_id", return_value=query_definition
            ) as mock_get_query,
            patch(
                "api.v1.views.attack_paths_views_helpers.prepare_query_parameters",
                return_value=prepared_parameters,
            ) as mock_prepare,
            patch(
                "api.v1.views.attack_paths_views_helpers.execute_attack_paths_query",
                return_value=graph_payload,
            ) as mock_execute,
        ):
            response = authenticated_client.post(
                reverse(
                    "attack-paths-scans-queries-run",
                    kwargs={"pk": attack_paths_scan.id},
                ),
                data=self._run_payload("aws-rds"),
                content_type=API_JSON_CONTENT_TYPE,
            )
        assert response.status_code == status.HTTP_200_OK
        mock_get_query.assert_called_once_with("aws-rds")
        mock_prepare.assert_called_once_with(
            query_definition,
            {},
            attack_paths_scan.provider.uid,
        )
        mock_execute.assert_called_once_with(
            attack_paths_scan,
            query_definition,
            prepared_parameters,
        )
        result = response.json()["data"]
        attributes = result["attributes"]
        assert attributes["nodes"] == graph_payload["nodes"]
        assert attributes["relationships"] == graph_payload["relationships"]

    def test_run_attack_paths_query_requires_completed_scan(
        self,
        authenticated_client,
        providers_fixture,
        scans_fixture,
        create_attack_paths_scan,
    ):
        """Running a query against a non-completed scan is a 400."""
        provider = providers_fixture[0]
        attack_paths_scan = create_attack_paths_scan(
            provider,
            scan=scans_fixture[0],
            state=StateChoices.EXECUTING,
        )
        response = authenticated_client.post(
            reverse(
                "attack-paths-scans-queries-run", kwargs={"pk": attack_paths_scan.id}
            ),
            data=self._run_payload(),
            content_type=API_JSON_CONTENT_TYPE,
        )
        assert response.status_code == status.HTTP_400_BAD_REQUEST
        assert "must be completed" in response.json()["errors"][0]["detail"]

    def test_run_attack_paths_query_requires_graph_database(
        self,
        authenticated_client,
        providers_fixture,
        scans_fixture,
        create_attack_paths_scan,
    ):
        """A scan with no graph database reference is a server-side error (500)."""
        provider = providers_fixture[0]
        attack_paths_scan = create_attack_paths_scan(
            provider,
            scan=scans_fixture[0],
            graph_database=None,
        )
        response = authenticated_client.post(
            reverse(
                "attack-paths-scans-queries-run", kwargs={"pk": attack_paths_scan.id}
            ),
            data=self._run_payload(),
            content_type=API_JSON_CONTENT_TYPE,
        )
        assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
        assert "does not reference a graph database" in str(response.json())

    def test_run_attack_paths_query_unknown_query(
        self,
        authenticated_client,
        providers_fixture,
        scans_fixture,
        create_attack_paths_scan,
    ):
        """An id not present in the query catalog is rejected with a 400."""
        provider = providers_fixture[0]
        attack_paths_scan = create_attack_paths_scan(
            provider,
            scan=scans_fixture[0],
        )
        with patch("api.v1.views.get_query_by_id", return_value=None):
            response = authenticated_client.post(
                reverse(
                    "attack-paths-scans-queries-run",
                    kwargs={"pk": attack_paths_scan.id},
                ),
                data=self._run_payload("unknown-query"),
                content_type=API_JSON_CONTENT_TYPE,
            )
        assert response.status_code == status.HTTP_400_BAD_REQUEST
        assert "Unknown Attack Paths query" in response.json()["errors"][0]["detail"]

    def test_run_attack_paths_query_returns_404_when_no_nodes_found(
        self,
        authenticated_client,
        providers_fixture,
        scans_fixture,
        create_attack_paths_scan,
    ):
        """A query that matches nothing yields a 404."""
        provider = providers_fixture[0]
        attack_paths_scan = create_attack_paths_scan(
            provider,
            scan=scans_fixture[0],
        )
        query_definition = AttackPathsQueryDefinition(
            id="aws-empty",
            name="empty",
            description="",
            provider=provider.provider,
            cypher="MATCH (n) RETURN n",
        )
        with (
            patch("api.v1.views.get_query_by_id", return_value=query_definition),
            patch(
                "api.v1.views.attack_paths_views_helpers.prepare_query_parameters",
                return_value={"provider_uid": provider.uid},
            ),
            patch(
                "api.v1.views.attack_paths_views_helpers.execute_attack_paths_query",
                return_value={"nodes": [], "relationships": []},
            ),
        ):
            response = authenticated_client.post(
                reverse(
                    "attack-paths-scans-queries-run",
                    kwargs={"pk": attack_paths_scan.id},
                ),
                data=self._run_payload("aws-empty"),
                content_type=API_JSON_CONTENT_TYPE,
            )
        assert response.status_code == status.HTTP_404_NOT_FOUND
        payload = response.json()
        # Depending on the renderer, the 404 body may still carry an empty document.
        if "data" in payload:
            attributes = payload["data"].get("attributes", {})
            assert attributes.get("nodes") == []
            assert attributes.get("relationships") == []
        else:
            assert "errors" in payload
@pytest.mark.django_db
class TestResourceViewSet:
def test_resources_list_none(self, authenticated_client):
+104
View File
@@ -21,6 +21,7 @@ from rest_framework_simplejwt.tokens import RefreshToken
from api.db_router import MainRouter
from api.exceptions import ConflictException
from api.models import (
AttackPathsScan,
Finding,
Integration,
IntegrationProviderRelationship,
@@ -1132,6 +1133,109 @@ class ScanComplianceReportSerializer(BaseSerializerV1):
fields = ["id", "name"]
class AttackPathsScanSerializer(RLSSerializer):
    """Read serializer for AttackPathsScan rows, with flattened provider info."""

    state = StateEnumSerializerField(read_only=True)
    # Convenience provider attributes so clients can avoid an `include`.
    provider_alias = serializers.SerializerMethodField(read_only=True)
    provider_type = serializers.SerializerMethodField(read_only=True)
    provider_uid = serializers.SerializerMethodField(read_only=True)

    class Meta:
        model = AttackPathsScan
        fields = [
            "id",
            "state",
            "progress",
            "provider",
            "provider_alias",
            "provider_type",
            "provider_uid",
            "scan",
            "task",
            "inserted_at",
            "started_at",
            "completed_at",
            "duration",
        ]
        included_serializers = {
            "provider": "api.v1.serializers.ProviderIncludeSerializer",
            "scan": "api.v1.serializers.ScanIncludeSerializer",
            "task": "api.v1.serializers.TaskSerializer",
        }

    def get_provider_alias(self, obj):
        """Alias of the related provider, or None when no provider is set."""
        provider = getattr(obj, "provider", None)
        return provider.alias if provider else None

    def get_provider_type(self, obj):
        """Provider type (e.g. cloud vendor), or None when no provider is set."""
        provider = getattr(obj, "provider", None)
        return provider.provider if provider else None

    def get_provider_uid(self, obj):
        """Provider UID, or None when no provider is set."""
        provider = getattr(obj, "provider", None)
        return provider.uid if provider else None
class AttackPathsQueryParameterSerializer(BaseSerializerV1):
    """Describes a single input parameter of an Attack Paths query."""

    name = serializers.CharField()
    label = serializers.CharField()
    data_type = serializers.CharField(default="string")
    description = serializers.CharField(allow_null=True, required=False)
    placeholder = serializers.CharField(allow_null=True, required=False)

    class JSONAPIMeta:
        resource_name = "attack-paths-query-parameters"
class AttackPathsQuerySerializer(BaseSerializerV1):
    """Catalog entry for an Attack Paths query, including its parameters."""

    id = serializers.CharField()
    name = serializers.CharField()
    description = serializers.CharField()
    provider = serializers.CharField()
    parameters = AttackPathsQueryParameterSerializer(many=True)

    class JSONAPIMeta:
        resource_name = "attack-paths-queries"
class AttackPathsQueryRunRequestSerializer(BaseSerializerV1):
    """Request body for running a catalog query: query id plus optional parameters."""

    id = serializers.CharField()
    parameters = serializers.DictField(
        child=serializers.JSONField(), allow_empty=True, required=False
    )

    class JSONAPIMeta:
        resource_name = "attack-paths-query-run-requests"
class AttackPathsNodeSerializer(BaseSerializerV1):
    """A serialized graph node: element id, labels, and JSON-safe properties."""

    id = serializers.CharField()
    labels = serializers.ListField(child=serializers.CharField())
    properties = serializers.DictField(child=serializers.JSONField())

    class JSONAPIMeta:
        resource_name = "attack-paths-query-result-nodes"
class AttackPathsRelationshipSerializer(BaseSerializerV1):
    """A serialized graph edge: label plus source/target node element ids."""

    id = serializers.CharField()
    label = serializers.CharField()
    source = serializers.CharField()
    target = serializers.CharField()
    properties = serializers.DictField(child=serializers.JSONField())

    class JSONAPIMeta:
        resource_name = "attack-paths-query-result-relationships"
class AttackPathsQueryResultSerializer(BaseSerializerV1):
    """Result document for a query run: lists of nodes and relationships."""

    nodes = AttackPathsNodeSerializer(many=True)
    relationships = AttackPathsRelationshipSerializer(many=True)

    class JSONAPIMeta:
        resource_name = "attack-paths-query-results"
class ResourceTagSerializer(RLSSerializer):
"""
Serializer for the ResourceTag model
+4
View File
@@ -4,6 +4,7 @@ from drf_spectacular.views import SpectacularRedocView
from rest_framework_nested import routers
from api.v1.views import (
AttackPathsScanViewSet,
ComplianceOverviewViewSet,
CustomSAMLLoginView,
CustomTokenObtainView,
@@ -53,6 +54,9 @@ router.register(r"tenants", TenantViewSet, basename="tenant")
router.register(r"providers", ProviderViewSet, basename="provider")
router.register(r"provider-groups", ProviderGroupViewSet, basename="providergroup")
router.register(r"scans", ScanViewSet, basename="scan")
# JSON:API routes for Attack Paths scans (list/detail plus query sub-actions).
router.register(
    r"attack-paths-scans", AttackPathsScanViewSet, basename="attack-paths-scans"
)
router.register(r"tasks", TaskViewSet, basename="task")
router.register(r"resources", ResourceViewSet, basename="resource")
router.register(r"findings", FindingViewSet, basename="finding")
+224 -18
View File
@@ -3,6 +3,7 @@ import glob
import json
import logging
import os
from collections import defaultdict
from copy import deepcopy
from datetime import datetime, timedelta, timezone
@@ -10,6 +11,7 @@ from decimal import ROUND_HALF_UP, Decimal, InvalidOperation
from urllib.parse import urljoin
import sentry_sdk
from allauth.socialaccount.models import SocialAccount, SocialApp
from allauth.socialaccount.providers.github.views import GitHubOAuth2Adapter
from allauth.socialaccount.providers.google.views import GoogleOAuth2Adapter
@@ -41,8 +43,9 @@ from django.db.models import (
Sum,
Value,
When,
Window,
)
from django.db.models.functions import Coalesce
from django.db.models.functions import Coalesce, RowNumber
from django.http import HttpResponse, QueryDict
from django.shortcuts import redirect
from django.urls import reverse
@@ -72,23 +75,12 @@ from rest_framework.generics import GenericAPIView, get_object_or_404
from rest_framework.permissions import SAFE_METHODS
from rest_framework_json_api.views import RelationshipView, Response
from rest_framework_simplejwt.exceptions import InvalidToken, TokenError
from tasks.beat import schedule_provider_scan
from tasks.jobs.export import get_s3_client
from tasks.tasks import (
backfill_compliance_summaries_task,
backfill_scan_resource_summaries_task,
check_integration_connection_task,
check_lighthouse_connection_task,
check_lighthouse_provider_connection_task,
check_provider_connection_task,
delete_provider_task,
delete_tenant_task,
jira_integration_task,
mute_historical_findings_task,
perform_scan_task,
refresh_lighthouse_provider_models_task,
)
from api.attack_paths import (
get_queries_for_provider,
get_query_by_id,
views_helpers as attack_paths_views_helpers,
)
from api.base_views import BaseRLSViewSet, BaseTenantViewset, BaseUserViewset
from api.compliance import (
PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE,
@@ -110,6 +102,7 @@ from api.filters import (
InvitationFilter,
LatestFindingFilter,
LatestResourceFilter,
AttackPathsScanFilter,
LighthouseProviderConfigFilter,
LighthouseProviderModelsFilter,
MembershipFilter,
@@ -138,6 +131,7 @@ from api.models import (
Finding,
Integration,
Invitation,
AttackPathsScan,
LighthouseConfiguration,
LighthouseProviderConfiguration,
LighthouseProviderModels,
@@ -183,6 +177,10 @@ from api.utils import (
from api.uuid_utils import datetime_to_uuid7, uuid7_start
from api.v1.mixins import DisablePaginationMixin, PaginateByPkMixin, TaskManagementMixin
from api.v1.serializers import (
AttackPathsQueryRunRequestSerializer,
AttackPathsQuerySerializer,
AttackPathsQueryResultSerializer,
AttackPathsScanSerializer,
AttackSurfaceOverviewSerializer,
CategoryOverviewSerializer,
ComplianceOverviewAttributesSerializer,
@@ -265,6 +263,23 @@ from api.v1.serializers import (
UserSerializer,
UserUpdateSerializer,
)
from tasks.beat import schedule_provider_scan
from tasks.jobs.attack_paths import db_utils as attack_paths_db_utils
from tasks.jobs.export import get_s3_client
from tasks.tasks import (
backfill_compliance_summaries_task,
backfill_scan_resource_summaries_task,
check_integration_connection_task,
check_lighthouse_connection_task,
check_lighthouse_provider_connection_task,
check_provider_connection_task,
delete_provider_task,
delete_tenant_task,
jira_integration_task,
mute_historical_findings_task,
perform_scan_task,
refresh_lighthouse_provider_models_task,
)
logger = logging.getLogger(BackendLogger.API)
@@ -408,6 +423,10 @@ class SchemaView(SpectacularAPIView):
"name": "Scan",
"description": "Endpoints for triggering manual scans and viewing scan results.",
},
{
"name": "Attack Paths",
"description": "Endpoints for Attack Paths scan status and executing Attack Paths queries.",
},
{
"name": "Schedule",
"description": "Endpoints for managing scan schedules, allowing configuration of automated "
@@ -2158,6 +2177,12 @@ class ScanViewSet(BaseRLSViewSet):
},
)
attack_paths_db_utils.create_attack_paths_scan(
tenant_id=self.request.tenant_id,
scan_id=str(scan.id),
provider_id=str(scan.provider_id),
)
prowler_task = Task.objects.get(id=task.id)
scan.task_id = task.id
scan.save(update_fields=["task_id"])
@@ -2238,6 +2263,187 @@ class TaskViewSet(BaseRLSViewSet):
)
@extend_schema_view(
    list=extend_schema(
        tags=["Attack Paths"],
        summary="List Attack Paths scans",
        description="Retrieve Attack Paths scans for the tenant with support for filtering, ordering, and pagination.",
    ),
    retrieve=extend_schema(
        tags=["Attack Paths"],
        summary="Retrieve Attack Paths scan details",
        description="Fetch full details for a specific Attack Paths scan.",
    ),
    attack_paths_queries=extend_schema(
        tags=["Attack Paths"],
        summary="List attack paths queries",
        description="Retrieve the catalog of Attack Paths queries available for this Attack Paths scan.",
        responses={
            200: OpenApiResponse(AttackPathsQuerySerializer(many=True)),
            404: OpenApiResponse(
                description="No queries found for the selected provider"
            ),
        },
    ),
    run_attack_paths_query=extend_schema(
        tags=["Attack Paths"],
        summary="Execute an Attack Paths query",
        description="Execute the selected Attack Paths query against the Attack Paths graph and return the resulting subgraph.",
        request=AttackPathsQueryRunRequestSerializer,
        responses={
            200: OpenApiResponse(AttackPathsQueryResultSerializer),
            400: OpenApiResponse(
                description="Bad request (e.g., Unknown Attack Paths query for the selected provider)"
            ),
            404: OpenApiResponse(
                description="No attack paths found for the given query and parameters"
            ),
            500: OpenApiResponse(
                description="Attack Paths query execution failed due to a database error"
            ),
        },
    ),
)
class AttackPathsScanViewSet(BaseRLSViewSet):
    """Tenant-scoped read access to Attack Paths scans plus two custom
    actions: listing the per-provider query catalog and executing a query
    against a completed scan's graph database."""

    queryset = AttackPathsScan.objects.all()
    serializer_class = AttackPathsScanSerializer
    # POST is enabled only for the `run_attack_paths_query` custom action;
    # the generic `create` below is explicitly disallowed.
    http_method_names = ["get", "post"]
    filterset_class = AttackPathsScanFilter
    ordering = ["-inserted_at"]
    ordering_fields = [
        "inserted_at",
        "started_at",
    ]
    # RBAC required permissions
    required_permissions = [Permissions.MANAGE_SCANS]

    def set_required_permissions(self):
        # Reads (GET/HEAD/OPTIONS) need no extra permission; any write
        # requires MANAGE_SCANS.
        if self.request.method in SAFE_METHODS:
            self.required_permissions = []
        else:
            self.required_permissions = [Permissions.MANAGE_SCANS]

    def get_serializer_class(self):
        # The run action validates its request body with a dedicated serializer.
        if self.action == "run_attack_paths_query":
            return AttackPathsQueryRunRequestSerializer
        return super().get_serializer_class()

    def get_queryset(self):
        # Scope to the current tenant, then restrict to the providers visible
        # to the user's role unless the role has unlimited visibility.
        user_roles = get_role(self.request.user)
        base_queryset = AttackPathsScan.objects.filter(tenant_id=self.request.tenant_id)
        if user_roles.unlimited_visibility:
            queryset = base_queryset
        else:
            queryset = base_queryset.filter(provider__in=get_providers(user_roles))
        return queryset.select_related("provider", "scan", "task")

    def list(self, request, *args, **kwargs):
        """Return only the most recent Attack Paths scan per provider."""
        queryset = self.filter_queryset(self.get_queryset())
        # Rank scans within each provider by insertion time (newest first)
        # and keep only the top-ranked row per provider.
        latest_per_provider = queryset.annotate(
            latest_scan_rank=Window(
                expression=RowNumber(),
                partition_by=[F("provider_id")],
                order_by=[F("inserted_at").desc()],
            )
        ).filter(latest_scan_rank=1)
        page = self.paginate_queryset(latest_per_provider)
        if page is not None:
            serializer = self.get_serializer(page, many=True)
            return self.get_paginated_response(serializer.data)
        serializer = self.get_serializer(latest_per_provider, many=True)
        return Response(serializer.data)

    @extend_schema(exclude=True)
    def create(self, request, *args, **kwargs):
        # Attack Paths scans are created internally (alongside regular scans),
        # never directly through the API.
        raise MethodNotAllowed(method="POST")

    @extend_schema(exclude=True)
    def destroy(self, request, *args, **kwargs):
        raise MethodNotAllowed(method="DELETE")

    @action(
        detail=True,
        methods=["get"],
        url_path="queries",
        url_name="queries",
    )
    def attack_paths_queries(self, request, pk=None):
        """List the catalog of queries available for this scan's provider type."""
        attack_paths_scan = self.get_object()
        queries = get_queries_for_provider(attack_paths_scan.provider.provider)
        if not queries:
            return Response(
                {"detail": "No queries found for the selected provider"},
                status=status.HTTP_404_NOT_FOUND,
            )
        serializer = AttackPathsQuerySerializer(queries, many=True)
        return Response(serializer.data, status=status.HTTP_200_OK)

    @action(
        detail=True,
        methods=["post"],
        url_path="queries/run",
        url_name="queries-run",
    )
    def run_attack_paths_query(self, request, pk=None):
        """Validate and execute one catalog query against the scan's graph DB,
        returning the matched subgraph (404 when nothing matched)."""
        attack_paths_scan = self.get_object()
        # Queries only make sense once graph ingestion finished successfully.
        if attack_paths_scan.state != StateChoices.COMPLETED:
            raise ValidationError(
                {
                    "detail": "The Attack Paths scan must be completed before running Attack Paths queries"
                }
            )
        # A completed scan without a graph database reference is an internal
        # inconsistency, hence 500 rather than a client error.
        if not attack_paths_scan.graph_database:
            logger.error(
                f"The Attack Paths Scan {attack_paths_scan.id} does not reference a graph database"
            )
            return Response(
                {"detail": "The Attack Paths scan does not reference a graph database"},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )
        payload = attack_paths_views_helpers.normalize_run_payload(request.data)
        serializer = AttackPathsQueryRunRequestSerializer(data=payload)
        serializer.is_valid(raise_exception=True)
        query_definition = get_query_by_id(serializer.validated_data["id"])
        # The query must exist and belong to the same provider type as the scan.
        if (
            query_definition is None
            or query_definition.provider != attack_paths_scan.provider.provider
        ):
            raise ValidationError(
                {"id": "Unknown Attack Paths query for the selected provider"}
            )
        parameters = attack_paths_views_helpers.prepare_query_parameters(
            query_definition,
            serializer.validated_data.get("parameters", {}),
            attack_paths_scan.provider.uid,
        )
        graph = attack_paths_views_helpers.execute_attack_paths_query(
            attack_paths_scan, query_definition, parameters
        )
        # An empty node set means the query ran but matched nothing: 404.
        status_code = status.HTTP_200_OK
        if not graph.get("nodes"):
            status_code = status.HTTP_404_NOT_FOUND
        response_serializer = AttackPathsQueryResultSerializer(graph)
        return Response(response_serializer.data, status=status_code)
@extend_schema_view(
list=extend_schema(
tags=["Resource"],
@@ -5912,7 +6118,7 @@ class TenantApiKeyViewSet(BaseRLSViewSet):
@extend_schema(exclude=True)
def destroy(self, request, *args, **kwargs):
raise MethodNotAllowed(method="DESTROY")
raise MethodNotAllowed(method="DELETE")
@action(detail=True, methods=["delete"])
def revoke(self, request, *args, **kwargs):
+1
View File
@@ -1,6 +1,7 @@
import warnings
from celery import Celery, Task
from config.env import env
# Suppress specific warnings from django-rest-auth: https://github.com/iMerica/dj-rest-auth/issues/684
+6
View File
@@ -44,6 +44,12 @@ DATABASES = {
"HOST": env("POSTGRES_REPLICA_HOST", default=default_db_host),
"PORT": env("POSTGRES_REPLICA_PORT", default=default_db_port),
},
"neo4j": {
"HOST": env.str("NEO4J_HOST", "neo4j"),
"PORT": env.str("NEO4J_PORT", "7687"),
"USER": env.str("NEO4J_USER", "neo4j"),
"PASSWORD": env.str("NEO4J_PASSWORD", "neo4j_password"),
},
}
DATABASES["default"] = DATABASES["prowler_user"]
@@ -45,6 +45,12 @@ DATABASES = {
"HOST": env("POSTGRES_REPLICA_HOST", default=default_db_host),
"PORT": env("POSTGRES_REPLICA_PORT", default=default_db_port),
},
"neo4j": {
"HOST": env.str("NEO4J_HOST"),
"PORT": env.str("NEO4J_PORT"),
"USER": env.str("NEO4J_USER"),
"PASSWORD": env.str("NEO4J_PASSWORD"),
},
}
DATABASES["default"] = DATABASES["prowler_user"]
+115 -11
View File
@@ -1,8 +1,11 @@
import logging
from types import SimpleNamespace
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, patch
import pytest
from allauth.socialaccount.models import SocialLogin
from django.conf import settings
from django.db import connection as django_connection
@@ -11,14 +14,14 @@ from django.urls import reverse
from django_celery_results.models import TaskResult
from rest_framework import status
from rest_framework.test import APIClient
from tasks.jobs.backfill import (
backfill_resource_scan_summaries,
backfill_scan_category_summaries,
backfill_scan_resource_group_summaries,
)
from api.attack_paths import (
AttackPathsQueryDefinition,
AttackPathsQueryParameterDefinition,
)
from api.db_utils import rls_transaction
from api.models import (
AttackPathsScan,
AttackSurfaceOverview,
ComplianceOverview,
ComplianceRequirementOverview,
@@ -56,6 +59,11 @@ from api.rls import Tenant
from api.v1.serializers import TokenSerializer
from prowler.lib.check.models import Severity
from prowler.lib.outputs.finding import Status
from tasks.jobs.backfill import (
backfill_resource_scan_summaries,
backfill_scan_category_summaries,
backfill_scan_resource_group_summaries,
)
TODAY = str(datetime.today().date())
API_JSON_CONTENT_TYPE = "application/vnd.api+json"
@@ -168,22 +176,20 @@ def create_test_user_rbac_no_roles(django_db_setup, django_db_blocker, tenants_f
@pytest.fixture(scope="function")
def create_test_user_rbac_limited(django_db_setup, django_db_blocker):
def create_test_user_rbac_limited(django_db_setup, django_db_blocker, tenants_fixture):
with django_db_blocker.unblock():
user = User.objects.create_user(
name="testing_limited",
email="rbac_limited@rbac.com",
password=TEST_PASSWORD,
)
tenant = Tenant.objects.create(
name="Tenant Test",
)
tenant = tenants_fixture[0]
Membership.objects.create(
user=user,
tenant=tenant,
role=Membership.RoleChoices.OWNER,
)
Role.objects.create(
role = Role.objects.create(
name="limited",
tenant_id=tenant.id,
manage_users=False,
@@ -196,7 +202,7 @@ def create_test_user_rbac_limited(django_db_setup, django_db_blocker):
)
UserRoleRelationship.objects.create(
user=user,
role=Role.objects.get(name="limited"),
role=role,
tenant_id=tenant.id,
)
return user
@@ -1597,6 +1603,104 @@ def mute_rules_fixture(tenants_fixture, create_test_user, findings_fixture):
return mute_rule1, mute_rule2
@pytest.fixture
def create_attack_paths_scan():
    """Factory fixture to create Attack Paths scans for tests."""

    def _create(
        provider,
        *,
        scan=None,
        state=StateChoices.COMPLETED,
        progress=0,
        graph_database="tenant-db",
        **extra_fields,
    ):
        # Reuse the caller-supplied supporting Scan, or create a minimal one.
        # `scan_name` / `scan_state` are popped so they are not passed on to
        # the AttackPathsScan model below.
        scan_instance = scan or Scan.objects.create(
            name=extra_fields.pop("scan_name", "Attack Paths Supporting Scan"),
            provider=provider,
            trigger=Scan.TriggerChoices.MANUAL,
            state=extra_fields.pop("scan_state", StateChoices.COMPLETED),
            tenant_id=provider.tenant_id,
        )
        payload = {
            "tenant_id": provider.tenant_id,
            "provider": provider,
            "scan": scan_instance,
            "state": state,
            "progress": progress,
            "graph_database": graph_database,
        }
        # Any remaining keyword arguments map directly onto model fields.
        payload.update(extra_fields)
        return AttackPathsScan.objects.create(**payload)

    return _create
@pytest.fixture
def attack_paths_query_definition_factory():
    """Factory fixture for building Attack Paths query definitions."""

    def _create(**overrides):
        # Default to a single "limit" parameter; `cast_type` only affects the
        # default parameter list and is dropped before building the payload.
        cast_type = overrides.pop("cast_type", str)
        default_parameters = [
            AttackPathsQueryParameterDefinition(
                name="limit",
                label="Limit",
                cast=cast_type,
            )
        ]
        parameters = overrides.pop("parameters", default_parameters)
        payload = dict(
            id="aws-test",
            name="Attack Paths Test Query",
            description="Synthetic Attack Paths definition for tests.",
            provider="aws",
            cypher="RETURN 1",
            parameters=parameters,
        )
        # Remaining overrides replace the synthetic defaults one-for-one.
        payload.update(overrides)
        return AttackPathsQueryDefinition(**payload)

    return _create
@pytest.fixture
def attack_paths_graph_stub_classes():
    """Provide lightweight graph element stubs for Attack Paths serialization tests."""

    class AttackPathsNativeValue:
        """Wraps a plain value behind a `to_native()` accessor."""

        def __init__(self, value):
            self._value = value

        def to_native(self):
            return self._value

    class AttackPathsNode:
        """Stub node: element id, labels and a private property map."""

        def __init__(self, element_id, labels, properties):
            self.element_id = element_id
            self.labels = labels
            self._properties = properties

    class AttackPathsRelationship:
        """Stub relationship joining two stub nodes with a type and properties."""

        def __init__(self, element_id, rel_type, start_node, end_node, properties):
            self.element_id = element_id
            self.type = rel_type
            self.start_node = start_node
            self.end_node = end_node
            self._properties = properties

    # Expose the three stub classes under short attribute names.
    return SimpleNamespace(
        NativeValue=AttackPathsNativeValue,
        Node=AttackPathsNode,
        Relationship=AttackPathsRelationship,
    )
@pytest.fixture
def create_attack_surface_overview():
def _create(tenant, scan, attack_surface_type, total=10, failed=5, muted_failed=2):
+7
View File
@@ -7,6 +7,7 @@ from tasks.tasks import perform_scheduled_scan_task
from api.db_utils import rls_transaction
from api.exceptions import ConflictException
from api.models import Provider, Scan, StateChoices
from tasks.jobs.attack_paths import db_utils as attack_paths_db_utils
def schedule_provider_scan(provider_instance: Provider):
@@ -39,6 +40,12 @@ def schedule_provider_scan(provider_instance: Provider):
scheduled_at=datetime.now(timezone.utc),
)
attack_paths_db_utils.create_attack_paths_scan(
tenant_id=tenant_id,
scan_id=str(scheduled_scan.id),
provider_id=provider_id,
)
# Schedule the task
periodic_task_instance = PeriodicTask.objects.create(
interval=schedule,
@@ -0,0 +1,7 @@
from tasks.jobs.attack_paths.db_utils import can_provider_run_attack_paths_scan
from tasks.jobs.attack_paths.scan import run as attack_paths_scan

# Public surface of the attack_paths task package.
__all__ = [
    "attack_paths_scan",
    "can_provider_run_attack_paths_scan",
]
@@ -0,0 +1,237 @@
# Portions of this file are based on code from the Cartography project
# (https://github.com/cartography-cncf/cartography), which is licensed under the Apache 2.0 License.
from typing import Any
import aioboto3
import boto3
import neo4j
from cartography.config import Config as CartographyConfig
from cartography.intel import aws as cartography_aws
from celery.utils.log import get_task_logger
from api.models import (
AttackPathsScan as ProwlerAPIAttackPathsScan,
Provider as ProwlerAPIProvider,
)
from prowler.providers.common.provider import Provider as ProwlerSDKProvider
from tasks.jobs.attack_paths import db_utils, utils
logger = get_task_logger(__name__)
def start_aws_ingestion(
    neo4j_session: neo4j.Session,
    cartography_config: CartographyConfig,
    prowler_api_provider: ProwlerAPIProvider,
    prowler_sdk_provider: ProwlerSDKProvider,
    attack_paths_scan: ProwlerAPIAttackPathsScan,
) -> dict[str, str]:
    """
    Code based on Cartography version 0.122.0, specifically on `cartography.intel.aws.__init__.py`.

    Runs the AWS sync pipeline into Neo4j and returns the failures mapping
    produced by `sync_aws_account` (sync function name -> error message;
    empty when every sync succeeded).

    For the scan progress updates:
    - The caller of this function (`tasks.jobs.attack_paths.scan.run`) has set it to 2.
    - When the control returns to the caller, it will be set to 95.
    """
    # Initialize variables common to all jobs
    common_job_parameters = {
        "UPDATE_TAG": cartography_config.update_tag,
        "permission_relationships_file": cartography_config.permission_relationships_file,
        "aws_guardduty_severity_threshold": cartography_config.aws_guardduty_severity_threshold,
        "aws_cloudtrail_management_events_lookback_hours": cartography_config.aws_cloudtrail_management_events_lookback_hours,
        "experimental_aws_inspector_batch": cartography_config.experimental_aws_inspector_batch,
    }
    boto3_session = get_boto3_session(prowler_api_provider, prowler_sdk_provider)
    # NOTE(review): reads the SDK provider's private `_enabled_regions`;
    # confirm no public accessor exists.
    regions: list[str] = list(prowler_sdk_provider._enabled_regions)
    requested_syncs = list(cartography_aws.RESOURCE_FUNCTIONS.keys())
    sync_args = cartography_aws._build_aws_sync_kwargs(
        neo4j_session,
        boto3_session,
        regions,
        prowler_api_provider.uid,
        cartography_config.update_tag,
        common_job_parameters,
    )
    # Starting with sync functions
    cartography_aws.organizations.sync(
        neo4j_session,
        {prowler_api_provider.alias: prowler_api_provider.uid},
        cartography_config.update_tag,
        common_job_parameters,
    )
    db_utils.update_attack_paths_scan_progress(attack_paths_scan, 3)
    # Adding an extra field
    common_job_parameters["AWS_ID"] = prowler_api_provider.uid
    cartography_aws._autodiscover_accounts(
        neo4j_session,
        boto3_session,
        prowler_api_provider.uid,
        cartography_config.update_tag,
        common_job_parameters,
    )
    db_utils.update_attack_paths_scan_progress(attack_paths_scan, 4)
    failed_syncs = sync_aws_account(
        prowler_api_provider, requested_syncs, sync_args, attack_paths_scan
    )
    # These two syncs were deliberately skipped inside `sync_aws_account`
    # because they rely on data already being present in the graph.
    if "permission_relationships" in requested_syncs:
        cartography_aws.RESOURCE_FUNCTIONS["permission_relationships"](**sync_args)
    db_utils.update_attack_paths_scan_progress(attack_paths_scan, 88)
    if "resourcegroupstaggingapi" in requested_syncs:
        cartography_aws.RESOURCE_FUNCTIONS["resourcegroupstaggingapi"](**sync_args)
    db_utils.update_attack_paths_scan_progress(attack_paths_scan, 89)
    cartography_aws.run_scoped_analysis_job(
        "aws_ec2_iaminstanceprofile.json",
        neo4j_session,
        common_job_parameters,
    )
    db_utils.update_attack_paths_scan_progress(attack_paths_scan, 90)
    cartography_aws.run_analysis_job(
        "aws_lambda_ecr.json",
        neo4j_session,
        common_job_parameters,
    )
    db_utils.update_attack_paths_scan_progress(attack_paths_scan, 91)
    cartography_aws.merge_module_sync_metadata(
        neo4j_session,
        group_type="AWSAccount",
        group_id=prowler_api_provider.uid,
        synced_type="AWSAccount",
        update_tag=cartography_config.update_tag,
        stat_handler=cartography_aws.stat_handler,
    )
    db_utils.update_attack_paths_scan_progress(attack_paths_scan, 92)
    # Removing the added extra field
    del common_job_parameters["AWS_ID"]
    cartography_aws.run_cleanup_job(
        "aws_post_ingestion_principals_cleanup.json",
        neo4j_session,
        common_job_parameters,
    )
    db_utils.update_attack_paths_scan_progress(attack_paths_scan, 93)
    cartography_aws._perform_aws_analysis(
        requested_syncs, neo4j_session, common_job_parameters
    )
    db_utils.update_attack_paths_scan_progress(attack_paths_scan, 94)
    return failed_syncs
def get_boto3_session(
    prowler_api_provider: ProwlerAPIProvider, prowler_sdk_provider: ProwlerSDKProvider
) -> boto3.Session:
    """Return the SDK provider's boto3 session, validated against the API provider.

    Raises:
        Exception: when the session resolves to no AWS account, or when the
            session's account id does not match the provider's UID.
    """
    boto3_session = prowler_sdk_provider.session.current_session
    aws_accounts_from_session = cartography_aws.organizations.get_aws_account_default(
        boto3_session
    )
    if not aws_accounts_from_session:
        raise Exception(
            "No valid AWS credentials could be found. No AWS accounts can be synced."
        )
    aws_account_id_from_session = list(aws_accounts_from_session.values())[0]
    # The credentials must belong to the same AWS account the provider represents.
    if prowler_api_provider.uid != aws_account_id_from_session:
        raise Exception(
            f"Provider {prowler_api_provider.uid} doesn't match AWS account {aws_account_id_from_session}."
        )
    # Cartography requires a region; fall back to the provider's global region
    # when the session has none (touches boto3's private botocore session).
    if boto3_session.region_name is None:
        global_region = prowler_sdk_provider.get_global_region()
        boto3_session._session.set_config_variable("region", global_region)
    return boto3_session
def get_aioboto3_session(boto3_session: boto3.Session) -> aioboto3.Session:
    # Reuse the underlying botocore session so credentials and region config
    # carry over to the async client.
    return aioboto3.Session(botocore_session=boto3_session._session)
def sync_aws_account(
    prowler_api_provider: ProwlerAPIProvider,
    requested_syncs: list[str],
    sync_args: dict[str, Any],
    attack_paths_scan: ProwlerAPIAttackPathsScan,
) -> dict[str, str]:
    """Run every requested Cartography AWS sync function, tracking progress.

    Returns:
        Mapping of sync function name -> stringified exception for syncs that
        failed; successful syncs are absent.

    Raises:
        ValueError: if a requested sync name is not in RESOURCE_FUNCTIONS.
    """
    current_progress = 4  # `cartography_aws._autodiscover_accounts`
    max_progress = (
        87  # `cartography_aws.RESOURCE_FUNCTIONS["permission_relationships"]` - 1
    )
    n_steps = (
        len(requested_syncs) - 2
    )  # Excluding `permission_relationships` and `resourcegroupstaggingapi`
    # Fix: guard against ZeroDivisionError when two or fewer syncs are requested.
    progress_step = (max_progress - current_progress) / n_steps if n_steps > 0 else 0
    failed_syncs = {}
    for func_name in requested_syncs:
        if func_name in cartography_aws.RESOURCE_FUNCTIONS:
            logger.info(
                f"Syncing function {func_name} for AWS account {prowler_api_provider.uid}"
            )
            # Updating progress, not really the right place but good enough
            current_progress += progress_step
            db_utils.update_attack_paths_scan_progress(
                attack_paths_scan, int(current_progress)
            )
            try:
                # `ecr:image_layers` uses `aioboto3_session` instead of `boto3_session`
                if func_name == "ecr:image_layers":
                    cartography_aws.RESOURCE_FUNCTIONS[func_name](
                        neo4j_session=sync_args.get("neo4j_session"),
                        aioboto3_session=get_aioboto3_session(
                            sync_args.get("boto3_session")
                        ),
                        regions=sync_args.get("regions"),
                        current_aws_account_id=sync_args.get("current_aws_account_id"),
                        update_tag=sync_args.get("update_tag"),
                        common_job_parameters=sync_args.get("common_job_parameters"),
                    )
                # Skip permission relationships and tags for now because they rely on data already being in the graph
                elif func_name in [
                    "permission_relationships",
                    "resourcegroupstaggingapi",
                ]:
                    continue
                else:
                    cartography_aws.RESOURCE_FUNCTIONS[func_name](**sync_args)
            except Exception as e:
                # Best-effort: record the failure and keep syncing the rest.
                exception_message = utils.stringify_exception(
                    e, f"Exception for AWS sync function: {func_name}"
                )
                failed_syncs[func_name] = exception_message
                logger.warning(
                    f"Caught exception syncing function {func_name} from AWS account {prowler_api_provider.uid}. We "
                    "are continuing on to the next AWS sync function.",
                )
                continue
        else:
            raise ValueError(
                f'AWS sync function "{func_name}" was specified but does not exist. Did you misspell it?'
            )
    return failed_syncs
@@ -0,0 +1,161 @@
from datetime import datetime, timezone
from typing import Any
from cartography.config import Config as CartographyConfig
from api.db_utils import rls_transaction
from api.models import (
AttackPathsScan as ProwlerAPIAttackPathsScan,
Provider as ProwlerAPIProvider,
StateChoices,
)
from tasks.jobs.attack_paths.providers import is_provider_available
def can_provider_run_attack_paths_scan(tenant_id: str, provider_id: int) -> bool:
    """Whether the provider's type supports Attack Paths scans.

    Looks the provider up under the tenant's RLS context and checks its type
    against the supported list (`tasks.jobs.attack_paths.providers`).
    """
    with rls_transaction(tenant_id):
        prowler_api_provider = ProwlerAPIProvider.objects.get(id=provider_id)
        return is_provider_available(prowler_api_provider.provider)
def create_attack_paths_scan(
    tenant_id: str,
    scan_id: str,
    provider_id: int,
) -> ProwlerAPIAttackPathsScan | None:
    """Create a SCHEDULED Attack Paths scan tied to a regular scan.

    Returns:
        The created scan, or None (nothing created) when the provider's type
        does not support Attack Paths scans.
    """
    if not can_provider_run_attack_paths_scan(tenant_id, provider_id):
        return None
    with rls_transaction(tenant_id):
        # Fix: `objects.create()` already persists the row; the original's
        # extra `save()` issued a redundant second UPDATE.
        attack_paths_scan = ProwlerAPIAttackPathsScan.objects.create(
            tenant_id=tenant_id,
            provider_id=provider_id,
            scan_id=scan_id,
            state=StateChoices.SCHEDULED,
            started_at=datetime.now(tz=timezone.utc),
        )
    return attack_paths_scan
def retrieve_attack_paths_scan(
    tenant_id: str,
    scan_id: str,
) -> ProwlerAPIAttackPathsScan | None:
    """Look up the Attack Paths scan linked to the given scan, or None."""
    try:
        with rls_transaction(tenant_id):
            return ProwlerAPIAttackPathsScan.objects.get(scan_id=scan_id)
    except ProwlerAPIAttackPathsScan.DoesNotExist:
        return None
def starting_attack_paths_scan(
    attack_paths_scan: ProwlerAPIAttackPathsScan,
    task_id: str,
    cartography_config: CartographyConfig,
) -> None:
    """Mark the scan EXECUTING and record the task and graph metadata.

    `started_at` is reset here so the duration computed at finish time
    measures actual execution, not time spent in SCHEDULED state.
    """
    with rls_transaction(attack_paths_scan.tenant_id):
        attack_paths_scan.task_id = task_id
        attack_paths_scan.state = StateChoices.EXECUTING
        attack_paths_scan.started_at = datetime.now(tz=timezone.utc)
        attack_paths_scan.update_tag = cartography_config.update_tag
        attack_paths_scan.graph_database = cartography_config.neo4j_database
        attack_paths_scan.save(
            update_fields=[
                "task_id",
                "state",
                "started_at",
                "update_tag",
                "graph_database",
            ]
        )
def finish_attack_paths_scan(
    attack_paths_scan: ProwlerAPIAttackPathsScan,
    state: StateChoices,
    ingestion_exceptions: dict[str, Any],
) -> None:
    """Finalize the scan: terminal state, 100% progress, duration, exceptions."""
    with rls_transaction(attack_paths_scan.tenant_id):
        now = datetime.now(tz=timezone.utc)
        # Whole seconds elapsed since `started_at` (set at execution start).
        duration = int((now - attack_paths_scan.started_at).total_seconds())
        attack_paths_scan.state = state
        attack_paths_scan.progress = 100
        attack_paths_scan.completed_at = now
        attack_paths_scan.duration = duration
        attack_paths_scan.ingestion_exceptions = ingestion_exceptions
        attack_paths_scan.save(
            update_fields=[
                "state",
                "progress",
                "completed_at",
                "duration",
                "ingestion_exceptions",
            ]
        )
def update_attack_paths_scan_progress(
    attack_paths_scan: ProwlerAPIAttackPathsScan,
    progress: int,
) -> None:
    """Persist a new progress value for the scan (only `progress` is written)."""
    with rls_transaction(attack_paths_scan.tenant_id):
        attack_paths_scan.progress = progress
        attack_paths_scan.save(update_fields=["progress"])
def get_old_attack_paths_scans(
    tenant_id: str,
    provider_id: str,
    attack_paths_scan_id: str,
) -> list[ProwlerAPIAttackPathsScan]:
    """
    An `old_attack_paths_scan` is any `completed` Attack Paths scan for the same provider,
    with its graph database not deleted, excluding the current Attack Paths scan.
    """
    with rls_transaction(tenant_id):
        old_scans = ProwlerAPIAttackPathsScan.objects.filter(
            provider_id=provider_id,
            state=StateChoices.COMPLETED,
            is_graph_database_deleted=False,
        ).exclude(id=attack_paths_scan_id)
        return list(old_scans)
def update_old_attack_paths_scan(
    old_attack_paths_scan: ProwlerAPIAttackPathsScan,
) -> None:
    """Flag the old scan's graph database as deleted (after its cleanup)."""
    with rls_transaction(old_attack_paths_scan.tenant_id):
        old_attack_paths_scan.is_graph_database_deleted = True
        old_attack_paths_scan.save(update_fields=["is_graph_database_deleted"])
def get_provider_graph_database_names(tenant_id: str, provider_id: str) -> list[str]:
    """
    Return existing graph database names for a tenant/provider.

    Note: For accessing the `AttackPathsScan` we need to use the `all_objects`
    manager because the provider is soft-deleted.
    """
    with rls_transaction(tenant_id):
        graph_databases_names_qs = ProwlerAPIAttackPathsScan.all_objects.filter(
            provider_id=provider_id,
            is_graph_database_deleted=False,
        ).values_list("graph_database", flat=True)
        return list(graph_databases_names_qs)
@@ -0,0 +1,23 @@
# Provider types that currently support Attack Paths scans.
AVAILABLE_PROVIDERS: list[str] = [
    "aws",
]

# Graph label of the root (account-level) node, per provider type.
ROOT_NODE_LABELS: dict[str, str] = {
    "aws": "AWSAccount",
}

# Node property carrying the provider-native unique resource id.
NODE_UID_FIELDS: dict[str, str] = {
    "aws": "arn",
}


def is_provider_available(provider_type: str) -> bool:
    """Tell whether Attack Paths scans are supported for this provider type."""
    return provider_type in AVAILABLE_PROVIDERS


def get_root_node_label(provider_type: str) -> str:
    """Root node label for the provider; a placeholder label when unknown."""
    return ROOT_NODE_LABELS.get(provider_type, "UnknownProviderAccount")


def get_node_uid_field(provider_type: str) -> str:
    """Property name holding resource UIDs; a placeholder when unknown."""
    return NODE_UID_FIELDS.get(provider_type, "UnknownProviderUID")
@@ -0,0 +1,205 @@
import neo4j
from cartography.client.core.tx import run_write_query
from cartography.config import Config as CartographyConfig
from celery.utils.log import get_task_logger
from api.db_utils import rls_transaction
from api.models import Provider, ResourceFindingMapping
from config.env import env
from prowler.config import config as ProwlerConfig
from tasks.jobs.attack_paths.providers import get_node_uid_field, get_root_node_label
logger = get_task_logger(__name__)
# Number of findings processed per batched Neo4j write (insert and cleanup).
BATCH_SIZE = env.int("NEO4J_INSERT_BATCH_SIZE", 500)

# Idempotent (`IF NOT EXISTS`) index definitions for ProwlerFinding nodes.
INDEX_STATEMENTS = [
    "CREATE INDEX prowler_finding_id IF NOT EXISTS FOR (n:ProwlerFinding) ON (n.id);",
    "CREATE INDEX prowler_finding_provider_uid IF NOT EXISTS FOR (n:ProwlerFinding) ON (n.provider_uid);",
    "CREATE INDEX prowler_finding_lastupdated IF NOT EXISTS FOR (n:ProwlerFinding) ON (n.lastupdated);",
    # NOTE(review): this index is named `prowler_finding_check_id` but covers
    # `n.status` -- likely a copy/paste leftover. Renaming it would conflict
    # with already-provisioned databases (an equivalent index on the same
    # label/property exists under the old name), so flagging instead of changing.
    "CREATE INDEX prowler_finding_check_id IF NOT EXISTS FOR (n:ProwlerFinding) ON (n.status);",
]
INSERT_STATEMENT_TEMPLATE = """
UNWIND $findings_data AS finding_data
MATCH (account:__ROOT_NODE_LABEL__ {id: $provider_uid})
MATCH (account)-->(resource)
WHERE resource.__NODE_UID_FIELD__ = finding_data.resource_uid
OR resource.id = finding_data.resource_uid
MERGE (finding:ProwlerFinding {id: finding_data.id})
ON CREATE SET
finding.id = finding_data.id,
finding.uid = finding_data.uid,
finding.inserted_at = finding_data.inserted_at,
finding.updated_at = finding_data.updated_at,
finding.first_seen_at = finding_data.first_seen_at,
finding.scan_id = finding_data.scan_id,
finding.delta = finding_data.delta,
finding.status = finding_data.status,
finding.status_extended = finding_data.status_extended,
finding.severity = finding_data.severity,
finding.check_id = finding_data.check_id,
finding.check_title = finding_data.check_title,
finding.muted = finding_data.muted,
finding.muted_reason = finding_data.muted_reason,
finding.provider_uid = $provider_uid,
finding.firstseen = timestamp(),
finding.lastupdated = $last_updated,
finding._module_name = 'cartography:prowler',
finding._module_version = $prowler_version
ON MATCH SET
finding.status = finding_data.status,
finding.status_extended = finding_data.status_extended,
finding.lastupdated = $last_updated
MERGE (resource)-[rel:HAS_FINDING]->(finding)
ON CREATE SET
rel.provider_uid = $provider_uid,
rel.firstseen = timestamp(),
rel.lastupdated = $last_updated,
rel._module_name = 'cartography:prowler',
rel._module_version = $prowler_version
ON MATCH SET
rel.lastupdated = $last_updated
"""
CLEANUP_STATEMENT = """
MATCH (finding:ProwlerFinding {provider_uid: $provider_uid})
WHERE finding.lastupdated < $last_updated
WITH finding LIMIT $batch_size
DETACH DELETE finding
RETURN COUNT(finding) AS deleted_findings_count
"""
def create_indexes(neo4j_session: neo4j.Session) -> None:
    """
    Code based on Cartography version 0.122.0, specifically on `cartography.intel.create_indexes.run`.
    """
    logger.info("Creating indexes for Prowler node types.")
    # Each statement is idempotent (`IF NOT EXISTS`), so re-running is safe.
    for index_statement in INDEX_STATEMENTS:
        logger.debug("Executing statement: %s", index_statement)
        run_write_query(neo4j_session, index_statement)
def analysis(
    neo4j_session: neo4j.Session,
    prowler_api_provider: Provider,
    scan_id: str,
    config: CartographyConfig,
) -> None:
    """Load the scan's findings into the graph, then purge stale ones."""
    findings_data = get_provider_last_scan_findings(prowler_api_provider, scan_id)
    load_findings(neo4j_session, findings_data, prowler_api_provider, config)
    cleanup_findings(neo4j_session, prowler_api_provider, config)
def get_provider_last_scan_findings(
    prowler_api_provider: Provider,
    scan_id: str,
) -> list[dict[str, str]]:
    """Flatten the scan's resource/finding mappings into Neo4j-ready dicts."""
    with rls_transaction(prowler_api_provider.tenant_id):
        mapping_rows = ResourceFindingMapping.objects.filter(
            finding__scan_id=scan_id,
        ).values(
            "resource__uid",
            "finding__id",
            "finding__uid",
            "finding__inserted_at",
            "finding__updated_at",
            "finding__first_seen_at",
            "finding__scan_id",
            "finding__delta",
            "finding__status",
            "finding__status_extended",
            "finding__severity",
            "finding__check_id",
            "finding__check_metadata__checktitle",
            "finding__muted",
            "finding__muted_reason",
        )
        # UUID-like columns are stringified so they serialize cleanly for Neo4j.
        return [
            {
                "resource_uid": str(row["resource__uid"]),
                "id": str(row["finding__id"]),
                "uid": row["finding__uid"],
                "inserted_at": row["finding__inserted_at"],
                "updated_at": row["finding__updated_at"],
                "first_seen_at": row["finding__first_seen_at"],
                "scan_id": str(row["finding__scan_id"]),
                "delta": row["finding__delta"],
                "status": row["finding__status"],
                "status_extended": row["finding__status_extended"],
                "severity": row["finding__severity"],
                "check_id": str(row["finding__check_id"]),
                "check_title": row["finding__check_metadata__checktitle"],
                "muted": row["finding__muted"],
                "muted_reason": row["finding__muted_reason"],
            }
            for row in mapping_rows
        ]
def load_findings(
    neo4j_session: neo4j.Session,
    findings_data: list[dict[str, str]],
    prowler_api_provider: Provider,
    config: CartographyConfig,
) -> None:
    """Insert Prowler findings into the graph in batches of ``BATCH_SIZE``.

    The Cypher template is specialized for the provider's root node label and
    node UID field before execution; each batch is sent as query parameters.
    """
    # Specialize the insert statement for this provider type.
    query = INSERT_STATEMENT_TEMPLATE.replace(
        "__ROOT_NODE_LABEL__", get_root_node_label(prowler_api_provider.provider)
    ).replace("__NODE_UID_FIELD__", get_node_uid_field(prowler_api_provider.provider))

    total_length = len(findings_data)
    batch_count = (total_length + BATCH_SIZE - 1) // BATCH_SIZE
    for batch_number, start in enumerate(range(0, total_length, BATCH_SIZE), start=1):
        logger.info(f"Loading findings batch {batch_number} / {batch_count}")
        neo4j_session.run(
            query,
            {
                "provider_uid": str(prowler_api_provider.uid),
                "last_updated": config.update_tag,
                "prowler_version": ProwlerConfig.prowler_version,
                "findings_data": findings_data[start : start + BATCH_SIZE],
            },
        )
def cleanup_findings(
    neo4j_session: neo4j.Session,
    prowler_api_provider: Provider,
    config: CartographyConfig,
) -> None:
    """Delete findings older than ``config.update_tag`` in bounded batches.

    Re-runs the batched cleanup statement until it reports zero deletions.
    """
    parameters = {
        "provider_uid": str(prowler_api_provider.uid),
        "last_updated": config.update_tag,
        "batch_size": BATCH_SIZE,
    }
    batch_number = 1
    while True:
        logger.info(f"Cleaning findings batch {batch_number}")
        result = neo4j_session.run(CLEANUP_STATEMENT, parameters)
        # Stop as soon as a pass deletes nothing (or the count is missing).
        if result.single().get("deleted_findings_count", 0) <= 0:
            break
        batch_number += 1
@@ -0,0 +1,183 @@
import logging
import time
import asyncio
from typing import Any, Callable
from cartography.config import Config as CartographyConfig
from cartography.intel import analysis as cartography_analysis
from cartography.intel import create_indexes as cartography_create_indexes
from cartography.intel import ontology as cartography_ontology
from celery.utils.log import get_task_logger
from api.attack_paths import database as graph_database
from api.db_utils import rls_transaction
from api.models import (
Provider as ProwlerAPIProvider,
StateChoices,
)
from api.utils import initialize_prowler_provider
from tasks.jobs.attack_paths import aws, db_utils, prowler, utils
# Without this Celery goes crazy with Cartography logging
logging.getLogger("cartography").setLevel(logging.ERROR)
logging.getLogger("neo4j").propagate = False
logger = get_task_logger(__name__)
CARTOGRAPHY_INGESTION_FUNCTIONS: dict[str, Callable] = {
"aws": aws.start_aws_ingestion,
}
def get_cartography_ingestion_function(provider_type: str) -> Callable | None:
    """Look up the Cartography ingestion entry point for ``provider_type``.

    Returns None when the provider type has no Attack Paths support.
    """
    try:
        return CARTOGRAPHY_INGESTION_FUNCTIONS[provider_type]
    except KeyError:
        return None
def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
    """
    Code based on Cartography version 0.122.0, specifically on `cartography.cli.main`, `cartography.cli.CLI.main`,
    `cartography.sync.run_with_config` and `cartography.sync.Sync.run`.

    Executes an Attack Paths scan for the provider that owns `scan_id`:
    provisions a dedicated Neo4j database, runs the provider-specific
    Cartography ingestion, layers Prowler findings on top, and rotates out
    graph databases from previous scans.

    Args:
        tenant_id: Tenant identifier used to establish the RLS context.
        scan_id: Prowler scan identifier; its provider is the scan target.
        task_id: Celery task identifier recorded on the AttackPathsScan row.

    Returns:
        dict[str, Any]: Ingestion errors keyed by scope (empty when clean).

    Raises:
        Exception: Any Cartography failure is re-raised after the scan is
            marked FAILED and the partially built database is dropped.
    """
    ingestion_exceptions = {}  # This will hold any exceptions raised during ingestion

    # Prowler necessary objects
    with rls_transaction(tenant_id):
        prowler_api_provider = ProwlerAPIProvider.objects.get(scan__pk=scan_id)
        prowler_sdk_provider = initialize_prowler_provider(prowler_api_provider)

    # Attack Paths Scan necessary objects
    cartography_ingestion_function = get_cartography_ingestion_function(
        prowler_api_provider.provider
    )
    attack_paths_scan = db_utils.retrieve_attack_paths_scan(tenant_id, scan_id)

    # Checks before starting the scan
    if not cartography_ingestion_function:
        # Unsupported provider type: record the error, close out any existing
        # scan row as COMPLETED, and bail out without touching Neo4j.
        ingestion_exceptions = {
            "global_error": f"Provider {prowler_api_provider.provider} is not supported for Attack Paths scans"
        }
        if attack_paths_scan:
            db_utils.finish_attack_paths_scan(
                attack_paths_scan, StateChoices.COMPLETED, ingestion_exceptions
            )
        logger.warning(
            f"Provider {prowler_api_provider.provider} is not supported for Attack Paths scans"
        )
        return ingestion_exceptions
    else:
        if not attack_paths_scan:
            logger.warning(
                f"No Attack Paths Scan found for scan {scan_id} and tenant {tenant_id}, let's create it then"
            )
            attack_paths_scan = db_utils.create_attack_paths_scan(
                tenant_id, scan_id, prowler_api_provider.id
            )

    # While creating the Cartography configuration, attributes `neo4j_user` and `neo4j_password` are not really needed in this config object
    cartography_config = CartographyConfig(
        neo4j_uri=graph_database.get_uri(),
        neo4j_database=graph_database.get_database_name(attack_paths_scan.id),
        update_tag=int(time.time()),
    )

    # Starting the Attack Paths scan
    db_utils.starting_attack_paths_scan(attack_paths_scan, task_id, cartography_config)
    try:
        logger.info(
            f"Creating Neo4j database {cartography_config.neo4j_database} for tenant {prowler_api_provider.tenant_id}"
        )
        graph_database.create_database(cartography_config.neo4j_database)
        db_utils.update_attack_paths_scan_progress(attack_paths_scan, 1)

        logger.info(
            f"Starting Cartography ({attack_paths_scan.id}) for "
            f"{prowler_api_provider.provider.upper()} provider {prowler_api_provider.id}"
        )
        with graph_database.get_session(
            cartography_config.neo4j_database
        ) as neo4j_session:
            # Indexes creation
            cartography_create_indexes.run(neo4j_session, cartography_config)
            prowler.create_indexes(neo4j_session)
            db_utils.update_attack_paths_scan_progress(attack_paths_scan, 2)

            # The real scan, where iterates over cloud services
            ingestion_exceptions = _call_within_event_loop(
                cartography_ingestion_function,
                neo4j_session,
                cartography_config,
                prowler_api_provider,
                prowler_sdk_provider,
                attack_paths_scan,
            )

            # Post-processing: Just keeping it to be more Cartography compliant
            cartography_ontology.run(neo4j_session, cartography_config)
            db_utils.update_attack_paths_scan_progress(attack_paths_scan, 95)
            cartography_analysis.run(neo4j_session, cartography_config)
            db_utils.update_attack_paths_scan_progress(attack_paths_scan, 96)

            # Adding Prowler nodes and relationships
            prowler.analysis(
                neo4j_session, prowler_api_provider, scan_id, cartography_config
            )

        logger.info(
            f"Completed Cartography ({attack_paths_scan.id}) for "
            f"{prowler_api_provider.provider.upper()} provider {prowler_api_provider.id}"
        )

        # Handling databases changes: drop graph databases of superseded scans
        # for this provider now that the new one is complete.
        old_attack_paths_scans = db_utils.get_old_attack_paths_scans(
            prowler_api_provider.tenant_id,
            prowler_api_provider.id,
            attack_paths_scan.id,
        )
        for old_attack_paths_scan in old_attack_paths_scans:
            graph_database.drop_database(old_attack_paths_scan.graph_database)
            db_utils.update_old_attack_paths_scan(old_attack_paths_scan)

        db_utils.finish_attack_paths_scan(
            attack_paths_scan, StateChoices.COMPLETED, ingestion_exceptions
        )
        return ingestion_exceptions
    except Exception as e:
        exception_message = utils.stringify_exception(e, "Cartography failed")
        logger.error(exception_message)
        ingestion_exceptions["global_cartography_error"] = exception_message

        # Handling databases changes: discard the partially built database so
        # a failed scan never leaves a half-populated graph behind.
        graph_database.drop_database(cartography_config.neo4j_database)
        db_utils.finish_attack_paths_scan(
            attack_paths_scan, StateChoices.FAILED, ingestion_exceptions
        )
        raise
def _call_within_event_loop(fn, *args, **kwargs):
"""
Cartography needs a running event loop, so assuming there is none (Celery task or even regular DRF endpoint),
let's create a new one and set it as the current event loop for this thread.
"""
loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(loop)
return fn(*args, **kwargs)
finally:
try:
loop.run_until_complete(loop.shutdown_asyncgens())
except Exception:
pass
loop.close()
asyncio.set_event_loop(None)
@@ -0,0 +1,10 @@
import traceback
from datetime import datetime, timezone
def stringify_exception(exception: Exception, context: str) -> str:
    """Render an exception as '<utc timestamp> - <context>' plus its traceback.

    Args:
        exception: The exception to format (its full traceback is included).
        context: Short human-readable label for where the failure happened.

    Returns:
        A multi-line string suitable for logging or persisting on a scan row.
    """
    now_utc = datetime.now(tz=timezone.utc)
    formatted_traceback = "".join(
        traceback.TracebackException.from_exception(exception).format()
    )
    return f"{now_utc} - {context}\n{formatted_traceback}"
+24 -2
View File
@@ -1,9 +1,19 @@
from celery.utils.log import get_task_logger
from django.db import DatabaseError
from api.attack_paths import database as graph_database
from api.db_router import MainRouter
from api.db_utils import batch_delete, rls_transaction
from api.models import Finding, Provider, Resource, Scan, ScanSummary, Tenant
from api.models import (
AttackPathsScan,
Finding,
Provider,
Resource,
Scan,
ScanSummary,
Tenant,
)
from tasks.jobs.attack_paths.db_utils import get_provider_graph_database_names
logger = get_task_logger(__name__)
@@ -23,16 +33,27 @@ def delete_provider(tenant_id: str, pk: str):
Raises:
Provider.DoesNotExist: If no instance with the provided primary key exists.
"""
# Delete the Attack Paths' graph databases related to the provider
graph_database_names = get_provider_graph_database_names(tenant_id, pk)
try:
for graph_database_name in graph_database_names:
graph_database.drop_database(graph_database_name)
except graph_database.GraphDatabaseQueryException as gdb_error:
logger.error(f"Error deleting Provider databases: {gdb_error}")
raise
# Get all provider related data and delete them in batches
with rls_transaction(tenant_id):
instance = Provider.all_objects.get(pk=pk)
deletion_summary = {}
deletion_steps = [
("Scan Summaries", ScanSummary.all_objects.filter(scan__provider=instance)),
("Findings", Finding.all_objects.filter(scan__provider=instance)),
("Resources", Resource.all_objects.filter(provider=instance)),
("Scans", Scan.all_objects.filter(provider=instance)),
("AttackPathsScans", AttackPathsScan.all_objects.filter(provider=instance)),
]
deletion_summary = {}
for step_name, queryset in deletion_steps:
try:
_, step_summary = batch_delete(tenant_id, queryset)
@@ -48,6 +69,7 @@ def delete_provider(tenant_id: str, pk: str):
except DatabaseError as db_error:
logger.error(f"Error deleting Provider: {db_error}")
raise
return deletion_summary
+41 -12
View File
@@ -1,13 +1,29 @@
import os
from datetime import datetime, timedelta, timezone
from pathlib import Path
from shutil import rmtree
from celery import chain, group, shared_task
from celery.utils.log import get_task_logger
from django_celery_beat.models import PeriodicTask
from api.compliance import get_compliance_frameworks
from api.db_router import READ_REPLICA_ALIAS
from api.db_utils import rls_transaction
from api.decorators import handle_provider_deletion, set_tenant
from api.models import Finding, Integration, Provider, Scan, ScanSummary, StateChoices
from api.utils import initialize_prowler_provider
from api.v1.serializers import ScanTaskSerializer
from config.celery import RLSTask
from config.django.base import DJANGO_FINDINGS_BATCH_SIZE, DJANGO_TMP_OUTPUT_DIRECTORY
from django_celery_beat.models import PeriodicTask
from prowler.lib.check.compliance_models import Compliance
from prowler.lib.outputs.compliance.generic.generic import GenericCompliance
from prowler.lib.outputs.finding import Finding as FindingOutput
from tasks.jobs.attack_paths import (
attack_paths_scan,
can_provider_run_attack_paths_scan,
)
from tasks.jobs.backfill import (
backfill_compliance_summaries,
backfill_daily_severity_summaries,
@@ -50,17 +66,6 @@ from tasks.jobs.scan import (
)
from tasks.utils import batched, get_next_execution_datetime
from api.compliance import get_compliance_frameworks
from api.db_router import READ_REPLICA_ALIAS
from api.db_utils import rls_transaction
from api.decorators import handle_provider_deletion, set_tenant
from api.models import Finding, Integration, Provider, Scan, ScanSummary, StateChoices
from api.utils import initialize_prowler_provider
from api.v1.serializers import ScanTaskSerializer
from prowler.lib.check.compliance_models import Compliance
from prowler.lib.outputs.compliance.generic.generic import GenericCompliance
from prowler.lib.outputs.finding import Finding as FindingOutput
logger = get_task_logger(__name__)
@@ -153,6 +158,11 @@ def _perform_scan_complete_tasks(tenant_id: str, scan_id: str, provider_id: str)
),
).apply_async()
if can_provider_run_attack_paths_scan(tenant_id, provider_id):
perform_attack_paths_scan_task.apply_async(
kwargs={"tenant_id": tenant_id, "scan_id": scan_id}
)
@shared_task(base=RLSTask, name="provider-connection-check")
@set_tenant
@@ -358,6 +368,25 @@ def perform_scan_summary_task(tenant_id: str, scan_id: str):
return aggregate_findings(tenant_id=tenant_id, scan_id=scan_id)
# TODO: This task must be queued at the `attack-paths` queue, don't forget to add it to the `docker-entrypoint.sh` file
@shared_task(base=RLSTask, bind=True, name="attack-paths-scan-perform", queue="scans")
def perform_attack_paths_scan_task(self, tenant_id: str, scan_id: str):
    """
    Execute an Attack Paths scan for the given provider within the current tenant RLS context.

    Args:
        self: The task instance (automatically passed when bind=True).
        tenant_id (str): The tenant identifier for RLS context.
        scan_id (str): The Prowler scan identifier for obtaining the tenant and provider context.

    Returns:
        Any: The result from `attack_paths_scan`, including any per-scan failure details.
    """
    scan_kwargs = {
        "tenant_id": tenant_id,
        "scan_id": scan_id,
        # The Celery task id is recorded on the AttackPathsScan row.
        "task_id": self.request.id,
    }
    return attack_paths_scan(**scan_kwargs)
@shared_task(name="tenant-deletion", queue="deletion", autoretry_for=(Exception,))
def delete_tenant_task(tenant_id: str):
    """Delete the tenant `tenant_id` and all of its related data.

    Runs on the dedicated `deletion` queue and auto-retries on any exception.
    """
    return delete_tenant(pk=tenant_id)
@@ -0,0 +1,416 @@
from contextlib import nullcontext
from types import SimpleNamespace
from unittest.mock import MagicMock, call, patch
import pytest
from api.models import (
AttackPathsScan,
Finding,
Provider,
Resource,
ResourceFindingMapping,
Scan,
StateChoices,
StatusChoices,
)
from prowler.lib.check.models import Severity
from tasks.jobs.attack_paths import prowler as prowler_module
from tasks.jobs.attack_paths.scan import run as attack_paths_run
@pytest.mark.django_db
class TestAttackPathsRun:
    """Tests for `tasks.jobs.attack_paths.scan.run` with all collaborators patched."""

    def test_run_success_flow(self, tenants_fixture, providers_fixture, scans_fixture):
        """Happy path: verifies collaborator calls, progress updates and COMPLETED finish."""
        tenant = tenants_fixture[0]
        provider = providers_fixture[0]
        provider.provider = Provider.ProviderChoices.AWS
        provider.save()
        scan = scans_fixture[0]
        scan.provider = provider
        scan.save()
        attack_paths_scan = AttackPathsScan.objects.create(
            tenant_id=tenant.id,
            provider=provider,
            scan=scan,
            state=StateChoices.SCHEDULED,
        )
        # Fake Neo4j session wrapped in a context manager, as get_session() returns.
        mock_session = MagicMock()
        session_ctx = MagicMock()
        session_ctx.__enter__.return_value = mock_session
        session_ctx.__exit__.return_value = False
        ingestion_result = {"organizations": "warning"}
        ingestion_fn = MagicMock(return_value=ingestion_result)
        with (
            patch(
                "tasks.jobs.attack_paths.scan.rls_transaction",
                new=lambda *args, **kwargs: nullcontext(),
            ),
            patch(
                "tasks.jobs.attack_paths.scan.initialize_prowler_provider",
                return_value=MagicMock(_enabled_regions=["us-east-1"]),
            ),
            patch(
                "tasks.jobs.attack_paths.scan.graph_database.get_uri",
                return_value="bolt://neo4j",
            ),
            patch(
                "tasks.jobs.attack_paths.scan.graph_database.get_database_name",
                return_value="db-scan-id",
            ) as mock_get_db_name,
            patch(
                "tasks.jobs.attack_paths.scan.graph_database.create_database"
            ) as mock_create_db,
            patch(
                "tasks.jobs.attack_paths.scan.graph_database.get_session",
                return_value=session_ctx,
            ) as mock_get_session,
            patch(
                "tasks.jobs.attack_paths.scan.cartography_create_indexes.run"
            ) as mock_cartography_indexes,
            patch(
                "tasks.jobs.attack_paths.scan.cartography_analysis.run"
            ) as mock_cartography_analysis,
            patch(
                "tasks.jobs.attack_paths.scan.cartography_ontology.run"
            ) as mock_cartography_ontology,
            patch(
                "tasks.jobs.attack_paths.scan.prowler.create_indexes"
            ) as mock_prowler_indexes,
            patch(
                "tasks.jobs.attack_paths.scan.prowler.analysis"
            ) as mock_prowler_analysis,
            patch(
                "tasks.jobs.attack_paths.scan.db_utils.retrieve_attack_paths_scan",
                return_value=attack_paths_scan,
            ) as mock_retrieve_scan,
            patch(
                "tasks.jobs.attack_paths.scan.db_utils.starting_attack_paths_scan"
            ) as mock_starting,
            patch(
                "tasks.jobs.attack_paths.scan.db_utils.update_attack_paths_scan_progress"
            ) as mock_update_progress,
            patch(
                "tasks.jobs.attack_paths.scan.db_utils.finish_attack_paths_scan"
            ) as mock_finish,
            patch(
                "tasks.jobs.attack_paths.scan.get_cartography_ingestion_function",
                return_value=ingestion_fn,
            ) as mock_get_ingestion,
            # Run the ingestion callable inline instead of inside a new loop.
            patch(
                "tasks.jobs.attack_paths.scan._call_within_event_loop",
                side_effect=lambda fn, *a, **kw: fn(*a, **kw),
            ) as mock_event_loop,
        ):
            result = attack_paths_run(str(tenant.id), str(scan.id), "task-123")
        assert result == ingestion_result
        mock_retrieve_scan.assert_called_once_with(str(tenant.id), str(scan.id))
        mock_starting.assert_called_once()
        # The Cartography config is the third positional argument of starting_attack_paths_scan.
        config = mock_starting.call_args[0][2]
        assert config.neo4j_database == "db-scan-id"
        mock_create_db.assert_called_once_with("db-scan-id")
        mock_get_session.assert_called_once_with("db-scan-id")
        mock_cartography_indexes.assert_called_once_with(mock_session, config)
        mock_prowler_indexes.assert_called_once_with(mock_session)
        mock_cartography_analysis.assert_called_once_with(mock_session, config)
        mock_cartography_ontology.assert_called_once_with(mock_session, config)
        mock_prowler_analysis.assert_called_once_with(
            mock_session,
            provider,
            str(scan.id),
            config,
        )
        mock_get_ingestion.assert_called_once_with(provider.provider)
        mock_event_loop.assert_called_once()
        mock_update_progress.assert_any_call(attack_paths_scan, 1)
        mock_update_progress.assert_any_call(attack_paths_scan, 2)
        mock_update_progress.assert_any_call(attack_paths_scan, 95)
        mock_finish.assert_called_once_with(
            attack_paths_scan, StateChoices.COMPLETED, ingestion_result
        )
        mock_get_db_name.assert_called_once_with(attack_paths_scan.id)

    def test_run_failure_marks_scan_failed(
        self, tenants_fixture, providers_fixture, scans_fixture
    ):
        """When ingestion raises, the scan is marked FAILED and the error propagates."""
        tenant = tenants_fixture[0]
        provider = providers_fixture[0]
        provider.provider = Provider.ProviderChoices.AWS
        provider.save()
        scan = scans_fixture[0]
        scan.provider = provider
        scan.save()
        attack_paths_scan = AttackPathsScan.objects.create(
            tenant_id=tenant.id,
            provider=provider,
            scan=scan,
            state=StateChoices.SCHEDULED,
        )
        mock_session = MagicMock()
        session_ctx = MagicMock()
        session_ctx.__enter__.return_value = mock_session
        session_ctx.__exit__.return_value = False
        # Ingestion blows up mid-scan.
        ingestion_fn = MagicMock(side_effect=RuntimeError("ingestion boom"))
        with (
            patch(
                "tasks.jobs.attack_paths.scan.rls_transaction",
                new=lambda *args, **kwargs: nullcontext(),
            ),
            patch(
                "tasks.jobs.attack_paths.scan.initialize_prowler_provider",
                return_value=MagicMock(_enabled_regions=["us-east-1"]),
            ),
            patch("tasks.jobs.attack_paths.scan.graph_database.get_uri"),
            patch(
                "tasks.jobs.attack_paths.scan.graph_database.get_database_name",
                return_value="db-scan-id",
            ),
            patch("tasks.jobs.attack_paths.scan.graph_database.create_database"),
            patch(
                "tasks.jobs.attack_paths.scan.graph_database.get_session",
                return_value=session_ctx,
            ),
            patch("tasks.jobs.attack_paths.scan.cartography_create_indexes.run"),
            patch("tasks.jobs.attack_paths.scan.cartography_analysis.run"),
            patch("tasks.jobs.attack_paths.scan.prowler.create_indexes"),
            patch("tasks.jobs.attack_paths.scan.prowler.analysis"),
            patch(
                "tasks.jobs.attack_paths.scan.db_utils.retrieve_attack_paths_scan",
                return_value=attack_paths_scan,
            ),
            patch("tasks.jobs.attack_paths.scan.db_utils.starting_attack_paths_scan"),
            patch(
                "tasks.jobs.attack_paths.scan.db_utils.update_attack_paths_scan_progress"
            ),
            patch(
                "tasks.jobs.attack_paths.scan.db_utils.finish_attack_paths_scan"
            ) as mock_finish,
            patch(
                "tasks.jobs.attack_paths.scan.get_cartography_ingestion_function",
                return_value=ingestion_fn,
            ),
            patch(
                "tasks.jobs.attack_paths.scan._call_within_event_loop",
                side_effect=lambda fn, *a, **kw: fn(*a, **kw),
            ),
            patch(
                "tasks.jobs.attack_paths.scan.utils.stringify_exception",
                return_value="Cartography failed: ingestion boom",
            ),
        ):
            with pytest.raises(RuntimeError, match="ingestion boom"):
                attack_paths_run(str(tenant.id), str(scan.id), "task-456")
        # The scan must be finished as FAILED with the stringified error attached.
        failure_args = mock_finish.call_args[0]
        assert failure_args[0] is attack_paths_scan
        assert failure_args[1] == StateChoices.FAILED
        assert failure_args[2] == {
            "global_cartography_error": "Cartography failed: ingestion boom"
        }

    def test_run_returns_early_for_unsupported_provider(self, tenants_fixture):
        """Providers without an ingestion function return a global error without scanning."""
        tenant = tenants_fixture[0]
        provider = Provider.objects.create(
            provider=Provider.ProviderChoices.GCP,
            uid="gcp-account",
            alias="gcp",
            tenant_id=tenant.id,
        )
        scan = Scan.objects.create(
            name="GCP Scan",
            provider=provider,
            trigger=Scan.TriggerChoices.MANUAL,
            state=StateChoices.AVAILABLE,
            tenant_id=tenant.id,
        )
        with (
            patch(
                "tasks.jobs.attack_paths.scan.rls_transaction",
                new=lambda *args, **kwargs: nullcontext(),
            ),
            patch(
                "tasks.jobs.attack_paths.scan.initialize_prowler_provider",
                return_value=MagicMock(),
            ),
            patch(
                "tasks.jobs.attack_paths.scan.get_cartography_ingestion_function",
                return_value=None,
            ) as mock_get_ingestion,
            patch(
                "tasks.jobs.attack_paths.scan.db_utils.retrieve_attack_paths_scan"
            ) as mock_retrieve,
        ):
            mock_retrieve.return_value = None
            result = attack_paths_run(str(tenant.id), str(scan.id), "task-789")
        assert result == {
            "global_error": "Provider gcp is not supported for Attack Paths scans"
        }
        mock_get_ingestion.assert_called_once_with(provider.provider)
        mock_retrieve.assert_called_once_with(str(tenant.id), str(scan.id))
@pytest.mark.django_db
class TestAttackPathsProwlerHelpers:
    """Tests for the helper functions in `tasks.jobs.attack_paths.prowler`."""

    def test_create_indexes_executes_all_statements(self):
        """Every statement in INDEX_STATEMENTS is executed exactly once."""
        mock_session = MagicMock()
        with patch("tasks.jobs.attack_paths.prowler.run_write_query") as mock_run_write:
            prowler_module.create_indexes(mock_session)
        assert mock_run_write.call_count == len(prowler_module.INDEX_STATEMENTS)
        mock_run_write.assert_has_calls(
            [call(mock_session, stmt) for stmt in prowler_module.INDEX_STATEMENTS]
        )

    def test_load_findings_batches_requests(self, providers_fixture):
        """With BATCH_SIZE patched to 1, two findings produce two insert calls."""
        provider = providers_fixture[0]
        provider.provider = Provider.ProviderChoices.AWS
        provider.save()
        findings = [
            {"id": "1", "resource_uid": "r-1"},
            {"id": "2", "resource_uid": "r-2"},
        ]
        config = SimpleNamespace(update_tag=12345)
        mock_session = MagicMock()
        with (
            patch.object(prowler_module, "BATCH_SIZE", 1),
            patch(
                "tasks.jobs.attack_paths.prowler.get_root_node_label",
                return_value="AWSAccount",
            ),
            patch(
                "tasks.jobs.attack_paths.prowler.get_node_uid_field",
                return_value="arn",
            ),
        ):
            prowler_module.load_findings(mock_session, findings, provider, config)
        assert mock_session.run.call_count == 2
        # Each batch call carries the shared parameters plus its own slice of findings.
        for call_args in mock_session.run.call_args_list:
            params = call_args.args[1]
            assert params["provider_uid"] == str(provider.uid)
            assert params["last_updated"] == config.update_tag
            assert "findings_data" in params

    def test_cleanup_findings_runs_batches(self, providers_fixture):
        """Cleanup loops until a batch reports zero deleted findings."""
        provider = providers_fixture[0]
        config = SimpleNamespace(update_tag=1024)
        mock_session = MagicMock()
        # First pass deletes 3 rows, second pass deletes none and stops the loop.
        first_batch = MagicMock()
        first_batch.single.return_value = {"deleted_findings_count": 3}
        second_batch = MagicMock()
        second_batch.single.return_value = {"deleted_findings_count": 0}
        mock_session.run.side_effect = [first_batch, second_batch]
        prowler_module.cleanup_findings(mock_session, provider, config)
        assert mock_session.run.call_count == 2
        params = mock_session.run.call_args.args[1]
        assert params["provider_uid"] == str(provider.uid)
        assert params["last_updated"] == config.update_tag

    def test_get_provider_last_scan_findings_returns_latest_scan_data(
        self,
        tenants_fixture,
        providers_fixture,
    ):
        """Only findings belonging to the requested scan are returned, not older ones."""
        tenant = tenants_fixture[0]
        provider = providers_fixture[0]
        provider.provider = Provider.ProviderChoices.AWS
        provider.save()
        resource = Resource.objects.create(
            tenant_id=tenant.id,
            provider=provider,
            uid="resource-uid",
            name="Resource",
            region="us-east-1",
            service="ec2",
            type="instance",
        )
        # An older scan with its own finding that must NOT be returned.
        older_scan = Scan.objects.create(
            name="Older",
            provider=provider,
            trigger=Scan.TriggerChoices.MANUAL,
            state=StateChoices.COMPLETED,
            tenant_id=tenant.id,
        )
        old_finding = Finding.objects.create(
            tenant_id=tenant.id,
            uid="older-finding",
            scan=older_scan,
            delta=Finding.DeltaChoices.NEW,
            status=StatusChoices.PASS,
            status_extended="ok",
            severity=Severity.low,
            impact=Severity.low,
            impact_extended="",
            raw_result={},
            check_id="check-old",
            check_metadata={"checktitle": "Old"},
            first_seen_at=older_scan.inserted_at,
        )
        ResourceFindingMapping.objects.create(
            tenant_id=tenant.id,
            resource=resource,
            finding=old_finding,
        )
        # The latest scan whose finding should be the only one returned.
        latest_scan = Scan.objects.create(
            name="Latest",
            provider=provider,
            trigger=Scan.TriggerChoices.MANUAL,
            state=StateChoices.COMPLETED,
            tenant_id=tenant.id,
        )
        finding = Finding.objects.create(
            tenant_id=tenant.id,
            uid="finding-uid",
            scan=latest_scan,
            delta=Finding.DeltaChoices.NEW,
            status=StatusChoices.FAIL,
            status_extended="failed",
            severity=Severity.high,
            impact=Severity.high,
            impact_extended="",
            raw_result={},
            check_id="check-1",
            check_metadata={"checktitle": "Check title"},
            first_seen_at=latest_scan.inserted_at,
        )
        ResourceFindingMapping.objects.create(
            tenant_id=tenant.id,
            resource=resource,
            finding=finding,
        )
        latest_scan.refresh_from_db()
        with patch(
            "tasks.jobs.attack_paths.prowler.rls_transaction",
            new=lambda *args, **kwargs: nullcontext(),
        ):
            findings_data = prowler_module.get_provider_last_scan_findings(
                provider,
                str(latest_scan.id),
            )
        assert len(findings_data) == 1
        finding_dict = findings_data[0]
        assert finding_dict["id"] == str(finding.id)
        assert finding_dict["resource_uid"] == resource.uid
        assert finding_dict["check_title"] == "Check title"
        assert finding_dict["scan_id"] == str(latest_scan.id)
+98 -30
View File
@@ -1,27 +1,60 @@
from unittest.mock import call, patch
import pytest
from django.core.exceptions import ObjectDoesNotExist
from tasks.jobs.deletion import delete_provider, delete_tenant
from api.models import Provider, Tenant
from tasks.jobs.deletion import delete_provider, delete_tenant
@pytest.mark.django_db
class TestDeleteProvider:
def test_delete_provider_success(self, providers_fixture):
instance = providers_fixture[0]
tenant_id = str(instance.tenant_id)
result = delete_provider(tenant_id, instance.id)
with patch(
"tasks.jobs.deletion.get_provider_graph_database_names"
) as mock_get_provider_graph_database_names, patch(
"tasks.jobs.deletion.graph_database.drop_database"
) as mock_drop_database:
graph_db_names = ["graph-db-1", "graph-db-2"]
mock_get_provider_graph_database_names.return_value = graph_db_names
assert result
with pytest.raises(ObjectDoesNotExist):
Provider.objects.get(pk=instance.id)
instance = providers_fixture[0]
tenant_id = str(instance.tenant_id)
result = delete_provider(tenant_id, instance.id)
assert result
with pytest.raises(ObjectDoesNotExist):
Provider.objects.get(pk=instance.id)
mock_get_provider_graph_database_names.assert_called_once_with(
tenant_id, instance.id
)
mock_drop_database.assert_has_calls(
[call(graph_db_name) for graph_db_name in graph_db_names]
)
def test_delete_provider_does_not_exist(self, tenants_fixture):
tenant_id = str(tenants_fixture[0].id)
non_existent_pk = "babf6796-cfcc-4fd3-9dcf-88d012247645"
with patch(
"tasks.jobs.deletion.get_provider_graph_database_names"
) as mock_get_provider_graph_database_names, patch(
"tasks.jobs.deletion.graph_database.drop_database"
) as mock_drop_database:
graph_db_names = ["graph-db-1"]
mock_get_provider_graph_database_names.return_value = graph_db_names
with pytest.raises(ObjectDoesNotExist):
delete_provider(tenant_id, non_existent_pk)
tenant_id = str(tenants_fixture[0].id)
non_existent_pk = "babf6796-cfcc-4fd3-9dcf-88d012247645"
with pytest.raises(ObjectDoesNotExist):
delete_provider(tenant_id, non_existent_pk)
mock_get_provider_graph_database_names.assert_called_once_with(
tenant_id, non_existent_pk
)
mock_drop_database.assert_has_calls(
[call(graph_db_name) for graph_db_name in graph_db_names]
)
@pytest.mark.django_db
@@ -30,33 +63,68 @@ class TestDeleteTenant:
"""
Test successful deletion of a tenant and its related data.
"""
tenant = tenants_fixture[0]
providers = Provider.objects.filter(tenant_id=tenant.id)
with patch(
"tasks.jobs.deletion.get_provider_graph_database_names"
) as mock_get_provider_graph_database_names, patch(
"tasks.jobs.deletion.graph_database.drop_database"
) as mock_drop_database:
tenant = tenants_fixture[0]
providers = list(Provider.objects.filter(tenant_id=tenant.id))
# Ensure the tenant and related providers exist before deletion
assert Tenant.objects.filter(id=tenant.id).exists()
assert providers.exists()
graph_db_names_per_provider = [
[f"graph-db-{provider.id}"] for provider in providers
]
mock_get_provider_graph_database_names.side_effect = (
graph_db_names_per_provider
)
# Call the function and validate the result
deletion_summary = delete_tenant(tenant.id)
# Ensure the tenant and related providers exist before deletion
assert Tenant.objects.filter(id=tenant.id).exists()
assert providers
assert deletion_summary is not None
assert not Tenant.objects.filter(id=tenant.id).exists()
assert not Provider.objects.filter(tenant_id=tenant.id).exists()
# Call the function and validate the result
deletion_summary = delete_tenant(tenant.id)
assert deletion_summary is not None
assert not Tenant.objects.filter(id=tenant.id).exists()
assert not Provider.objects.filter(tenant_id=tenant.id).exists()
expected_calls = [
call(provider.tenant_id, provider.id) for provider in providers
]
mock_get_provider_graph_database_names.assert_has_calls(
expected_calls, any_order=True
)
assert mock_get_provider_graph_database_names.call_count == len(
expected_calls
)
expected_drop_calls = [
call(graph_db_name[0]) for graph_db_name in graph_db_names_per_provider
]
mock_drop_database.assert_has_calls(expected_drop_calls, any_order=True)
assert mock_drop_database.call_count == len(expected_drop_calls)
def test_delete_tenant_with_no_providers(self, tenants_fixture):
"""
Test deletion of a tenant with no related providers.
"""
tenant = tenants_fixture[1] # Assume this tenant has no providers
providers = Provider.objects.filter(tenant_id=tenant.id)
with patch(
"tasks.jobs.deletion.get_provider_graph_database_names"
) as mock_get_provider_graph_database_names, patch(
"tasks.jobs.deletion.graph_database.drop_database"
) as mock_drop_database:
tenant = tenants_fixture[1] # Assume this tenant has no providers
providers = Provider.objects.filter(tenant_id=tenant.id)
# Ensure the tenant exists but has no related providers
assert Tenant.objects.filter(id=tenant.id).exists()
assert not providers.exists()
# Ensure the tenant exists but has no related providers
assert Tenant.objects.filter(id=tenant.id).exists()
assert not providers.exists()
# Call the function and validate the result
deletion_summary = delete_tenant(tenant.id)
# Call the function and validate the result
deletion_summary = delete_tenant(tenant.id)
assert deletion_summary == {} # No providers, so empty summary
assert not Tenant.objects.filter(id=tenant.id).exists()
assert deletion_summary == {} # No providers, so empty summary
assert not Tenant.objects.filter(id=tenant.id).exists()
mock_get_provider_graph_database_names.assert_not_called()
mock_drop_database.assert_not_called()
+77 -8
View File
@@ -1,10 +1,21 @@
import uuid
from contextlib import contextmanager
from unittest.mock import MagicMock, patch
import openai
import pytest
from botocore.exceptions import ClientError
from django_celery_beat.models import IntervalSchedule, PeriodicTask
from api.models import (
Integration,
LighthouseProviderConfiguration,
LighthouseProviderModels,
Scan,
StateChoices,
)
from tasks.jobs.lighthouse_providers import (
_create_bedrock_client,
_extract_bedrock_credentials,
@@ -15,19 +26,12 @@ from tasks.tasks import (
check_integrations_task,
check_lighthouse_provider_connection_task,
generate_outputs_task,
perform_attack_paths_scan_task,
refresh_lighthouse_provider_models_task,
s3_integration_task,
security_hub_integration_task,
)
from api.models import (
Integration,
LighthouseProviderConfiguration,
LighthouseProviderModels,
Scan,
StateChoices,
)
@pytest.mark.django_db
class TestExtractBedrockCredentials:
@@ -737,8 +741,12 @@ class TestScanCompleteTasks:
@patch("tasks.tasks.generate_outputs_task.si")
@patch("tasks.tasks.generate_compliance_reports_task.si")
@patch("tasks.tasks.check_integrations_task.si")
@patch("tasks.tasks.perform_attack_paths_scan_task.apply_async")
@patch("tasks.tasks.can_provider_run_attack_paths_scan", return_value=False)
def test_scan_complete_tasks(
self,
mock_can_run_attack_paths,
mock_attack_paths_task,
mock_check_integrations_task,
mock_compliance_reports_task,
mock_outputs_task,
@@ -793,6 +801,67 @@ class TestScanCompleteTasks:
scan_id="scan-id",
)
# Attack Paths task should be skipped when provider cannot run it
mock_attack_paths_task.assert_not_called()
class TestAttackPathsTasks:
    """Tests for the `perform_attack_paths_scan_task` Celery wrapper."""

    @staticmethod
    @contextmanager
    def _override_task_request(task, **attrs):
        # Temporarily sets attributes on the task's request object, restoring
        # (or deleting) them afterwards so other tests see the original state.
        request = task.request
        sentinel = object()
        previous = {key: getattr(request, key, sentinel) for key in attrs}
        for key, value in attrs.items():
            setattr(request, key, value)
        try:
            yield
        finally:
            for key, prev in previous.items():
                if prev is sentinel:
                    if hasattr(request, key):
                        delattr(request, key)
                else:
                    setattr(request, key, prev)

    def test_perform_attack_paths_scan_task_calls_runner(self):
        """The task forwards tenant/scan ids plus its own Celery task id to the runner."""
        with (
            patch("tasks.tasks.attack_paths_scan") as mock_attack_paths_scan,
            self._override_task_request(
                perform_attack_paths_scan_task, id="celery-task-id"
            ),
        ):
            mock_attack_paths_scan.return_value = {"status": "ok"}
            result = perform_attack_paths_scan_task.run(
                tenant_id="tenant-id", scan_id="scan-id"
            )
        mock_attack_paths_scan.assert_called_once_with(
            tenant_id="tenant-id", scan_id="scan-id", task_id="celery-task-id"
        )
        assert result == {"status": "ok"}

    def test_perform_attack_paths_scan_task_propagates_exception(self):
        """Runner exceptions are not swallowed by the task wrapper."""
        with (
            patch(
                "tasks.tasks.attack_paths_scan",
                side_effect=RuntimeError("Exception to propagate"),
            ) as mock_attack_paths_scan,
            self._override_task_request(
                perform_attack_paths_scan_task, id="celery-task-error"
            ),
        ):
            with pytest.raises(RuntimeError, match="Exception to propagate"):
                perform_attack_paths_scan_task.run(
                    tenant_id="tenant-id", scan_id="scan-id"
                )
        mock_attack_paths_scan.assert_called_once_with(
            tenant_id="tenant-id", scan_id="scan-id", task_id="celery-task-error"
        )
@pytest.mark.django_db
class TestCheckIntegrationsTask:
+46 -1
View File
@@ -1,6 +1,7 @@
services:
api-dev:
hostname: "prowler-api"
image: prowler-api-dev
build:
context: ./api
dockerfile: Dockerfile
@@ -24,6 +25,8 @@ services:
condition: service_healthy
valkey:
condition: service_healthy
neo4j:
condition: service_healthy
entrypoint:
- "/home/prowler/docker-entrypoint.sh"
- "dev"
@@ -85,7 +88,41 @@ services:
timeout: 5s
retries: 3
neo4j:
  image: graphstack/dozerdb:5.26.3.0
  hostname: "neo4j"
  volumes:
    - ./_data/neo4j:/data
  environment:
    # We can't add our .env file because some of our current variables are not compatible with Neo4j env vars
    # Auth
    - NEO4J_AUTH=${NEO4J_USER}/${NEO4J_PASSWORD}
    # Memory limits
    - NEO4J_dbms_max__databases=${NEO4J_DBMS_MAX__DATABASES:-1000000}
    - NEO4J_server_memory_pagecache_size=${NEO4J_SERVER_MEMORY_PAGECACHE_SIZE:-1G}
    - NEO4J_server_memory_heap_initial__size=${NEO4J_SERVER_MEMORY_HEAP_INITIAL__SIZE:-1G}
    - NEO4J_server_memory_heap_max__size=${NEO4J_SERVER_MEMORY_HEAP_MAX__SIZE:-1G}
    # APOC
    # Fix: host variable was misspelled NEO4J_POC_EXPORT_FILE_ENABLED (missing "A"),
    # inconsistent with the NEO4J_APOC_IMPORT_* variables below.
    - apoc.export.file.enabled=${NEO4J_APOC_EXPORT_FILE_ENABLED:-true}
    - apoc.import.file.enabled=${NEO4J_APOC_IMPORT_FILE_ENABLED:-true}
    - apoc.import.file.use_neo4j_config=${NEO4J_APOC_IMPORT_FILE_USE_NEO4J_CONFIG:-true}
    - "NEO4J_PLUGINS=${NEO4J_PLUGINS:-[\"apoc\"]}"
    - "NEO4J_dbms_security_procedures_allowlist=${NEO4J_DBMS_SECURITY_PROCEDURES_ALLOWLIST:-apoc.*}"
    - "NEO4J_dbms_security_procedures_unrestricted=${NEO4J_DBMS_SECURITY_PROCEDURES_UNRESTRICTED:-apoc.*}"
    # Networking
    - "dbms.connector.bolt.listen_address=${NEO4J_DBMS_CONNECTOR_BOLT_LISTEN_ADDRESS:-0.0.0.0:7687}"
  # 7474 is the UI port
  ports:
    - 7474:7474
    - ${NEO4J_PORT:-7687}:7687
  healthcheck:
    test: ["CMD", "wget", "--no-verbose", "http://localhost:7474"]
    interval: 10s
    timeout: 10s
    retries: 10
worker-dev:
image: prowler-api-dev
build:
context: ./api
dockerfile: Dockerfile
@@ -96,17 +133,23 @@ services:
- path: .env
required: false
volumes:
- "outputs:/tmp/prowler_api_output"
- ./api/src/backend:/home/prowler/backend
- ./api/pyproject.toml:/home/prowler/pyproject.toml
- ./api/docker-entrypoint.sh:/home/prowler/docker-entrypoint.sh
- outputs:/tmp/prowler_api_output
depends_on:
valkey:
condition: service_healthy
postgres:
condition: service_healthy
neo4j:
condition: service_healthy
entrypoint:
- "/home/prowler/docker-entrypoint.sh"
- "worker"
worker-beat:
image: prowler-api-dev
build:
context: ./api
dockerfile: Dockerfile
@@ -121,6 +164,8 @@ services:
condition: service_healthy
postgres:
condition: service_healthy
neo4j:
condition: service_healthy
entrypoint:
- "../docker-entrypoint.sh"
- "beat"
+31
View File
@@ -72,6 +72,37 @@ services:
timeout: 5s
retries: 3
neo4j:
  image: graphstack/dozerdb:5.26.3.0
  hostname: "neo4j"
  volumes:
    - ./_data/neo4j:/data
  environment:
    # We can't add our .env file because some of our current variables are not compatible with Neo4j env vars
    # Auth
    - NEO4J_AUTH=${NEO4J_USER}/${NEO4J_PASSWORD}
    # Memory limits
    - NEO4J_dbms_max__databases=${NEO4J_DBMS_MAX__DATABASES:-1000000}
    - NEO4J_server_memory_pagecache_size=${NEO4J_SERVER_MEMORY_PAGECACHE_SIZE:-1G}
    - NEO4J_server_memory_heap_initial__size=${NEO4J_SERVER_MEMORY_HEAP_INITIAL__SIZE:-1G}
    - NEO4J_server_memory_heap_max__size=${NEO4J_SERVER_MEMORY_HEAP_MAX__SIZE:-1G}
    # APOC
    # Fix: host variable was misspelled NEO4J_POC_EXPORT_FILE_ENABLED (missing "A"),
    # inconsistent with the NEO4J_APOC_IMPORT_* variables below.
    - apoc.export.file.enabled=${NEO4J_APOC_EXPORT_FILE_ENABLED:-true}
    - apoc.import.file.enabled=${NEO4J_APOC_IMPORT_FILE_ENABLED:-true}
    - apoc.import.file.use_neo4j_config=${NEO4J_APOC_IMPORT_FILE_USE_NEO4J_CONFIG:-true}
    - "NEO4J_PLUGINS=${NEO4J_PLUGINS:-[\"apoc\"]}"
    - "NEO4J_dbms_security_procedures_allowlist=${NEO4J_DBMS_SECURITY_PROCEDURES_ALLOWLIST:-apoc.*}"
    - "NEO4J_dbms_security_procedures_unrestricted=${NEO4J_DBMS_SECURITY_PROCEDURES_UNRESTRICTED:-apoc.*}"
    # Networking
    - "dbms.connector.bolt.listen_address=${NEO4J_DBMS_CONNECTOR_BOLT_LISTEN_ADDRESS:-0.0.0.0:7687}"
  ports:
    - ${NEO4J_PORT:-7687}:7687
  healthcheck:
    test: ["CMD", "wget", "--no-verbose", "http://localhost:7474"]
    interval: 10s
    timeout: 10s
    retries: 10
worker:
image: prowlercloud/prowler-api:${PROWLER_API_VERSION:-stable}
env_file:
+10 -3
View File
@@ -37,8 +37,8 @@ CODE_REVIEW_ENABLED=$(echo "$CODE_REVIEW_ENABLED" | tr '[:upper:]' '[:lower:]')
echo -e "${BLUE}️ Code Review Status: ${CODE_REVIEW_ENABLED}${NC}"
echo ""
# Get staged files (what will be committed)
STAGED_FILES=$(git diff --cached --name-only --diff-filter=ACM | grep -E '\.(tsx?|jsx?)$' || true)
# Get staged files in the UI folder only (what will be committed)
STAGED_FILES=$(git diff --cached --name-only --diff-filter=ACM -- 'ui/**' | grep -E '\.(tsx?|jsx?)$' || true)
if [ "$CODE_REVIEW_ENABLED" = "true" ]; then
if [ -z "$STAGED_FILES" ]; then
@@ -135,7 +135,14 @@ else
echo ""
fi
# Run healthcheck (typecheck and lint check)
# Check if there are any UI files to validate
if [ -z "$STAGED_FILES" ] && [ "$CODE_REVIEW_ENABLED" = "true" ]; then
echo -e "${YELLOW}⏭️ No UI files to validate, skipping healthcheck${NC}"
echo ""
exit 0
fi
# Run healthcheck (typecheck and lint check) only if there are UI changes
echo -e "${BLUE}🏥 Running healthcheck...${NC}"
echo ""
+2
View File
@@ -12,6 +12,7 @@ All notable changes to the **Prowler UI** are documented in this file.
- Add ThreatScore pillar breakdown to Compliance Summary page and detail view [(#9773)](https://github.com/prowler-cloud/prowler/pull/9773)
- Add Provider and Group filters to Resources page [(#9492)](https://github.com/prowler-cloud/prowler/pull/9492)
- Compliance Watchlist component in Overview page [(#9786)](https://github.com/prowler-cloud/prowler/pull/9786)
- Add a new main section to list Attack Paths scans, execute queries on them, and view their results as a graph [(#)](https://github.com/prowler-cloud/prowler/pull/)
### 🔄 Changed
@@ -132,6 +133,7 @@ All notable changes to the **Prowler UI** are documented in this file.
- PDF reporting for NIS2 compliance framework [(#9170)](https://github.com/prowler-cloud/prowler/pull/9170)
- External resource link to IaC findings for direct navigation to source code in Git repositories [(#9151)](https://github.com/prowler-cloud/prowler/pull/9151)
- New Overview page and new app styles [(#9234)](https://github.com/prowler-cloud/prowler/pull/9234)
- Attack Paths feature with query execution and graph visualization [(#PROWLER-383)](https://github.com/prowler-cloud/prowler/pull/9270)
- Use branch name as region for IaC findings [(#9296)](https://github.com/prowler-cloud/prowler/pull/9296)
### 🔄 Changed
+4
View File
@@ -0,0 +1,4 @@
export * from "./queries";
export * from "./queries.adapter";
export * from "./scans";
export * from "./scans.adapter";
@@ -0,0 +1,55 @@
import { MetaDataProps } from "@/types";
import {
AttackPathQueriesResponse,
AttackPathQuery,
} from "@/types/attack-paths";
/**
* Adapts raw query API responses to enriched domain models
* - Enriches queries with metadata and computed properties
* - Co-locates related data for better performance
* - Preserves pagination metadata for list operations
*
* Uses plugin architecture for extensibility:
* - Handles query-specific response transformation
* - Can be composed with backend service plugins
* - Maintains separation of concerns between API layer and business logic
*/
/**
* Adapt attack path queries response with enriched data
*
* @param response - Raw API response from attack-paths-scans/{id}/queries endpoint
* @returns Enriched queries data with metadata
*/
export function adaptAttackPathQueriesResponse(
  response: AttackPathQueriesResponse | undefined,
): {
  data: AttackPathQuery[];
  metadata?: MetaDataProps;
} {
  const rawQueries = response?.data;
  if (!rawQueries) {
    return { data: [] };
  }

  // Shallow-copy each query; computed properties can be attached here later
  // (e.g. parameterCount, requiredParameters, hasParameters).
  const queries = rawQueries.map((rawQuery) => ({ ...rawQuery }));

  // Queries are returned as a single, non-paginated collection.
  const metadata: MetaDataProps | undefined = {
    pagination: {
      page: 1,
      pages: 1,
      count: queries.length,
      itemsPerPage: [10, 25, 50, 100],
    },
    version: "1.0",
  };

  return { data: queries, metadata };
}
+97
View File
@@ -0,0 +1,97 @@
"use server";
import { z } from "zod";
import { apiBaseUrl, getAuthHeaders } from "@/lib";
import { handleApiResponse } from "@/lib/server-actions-helper";
import {
AttackPathQueriesResponse,
AttackPathQuery,
AttackPathQueryResult,
ExecuteQueryRequest,
} from "@/types/attack-paths";
import { adaptAttackPathQueriesResponse } from "./queries.adapter";
// Validation schema for UUID - RFC 9562/4122 compliant
const UUIDSchema = z.uuid();
/**
* Fetch available queries for a specific attack path scan
*/
/**
 * Fetch the queries available for one attack path scan.
 *
 * Returns `undefined` on an invalid scan id or any request failure.
 */
export const getAvailableQueries = async (
  scanId: string,
): Promise<{ data: AttackPathQuery[] } | undefined> => {
  // Reject anything that is not a UUID before it reaches the URL path
  // (prevents request forgery through the path segment).
  const parsedScanId = UUIDSchema.safeParse(scanId);
  if (!parsedScanId.success) {
    console.error("Invalid scan ID format");
    return undefined;
  }

  const headers = await getAuthHeaders({ contentType: false });
  const url = `${apiBaseUrl}/attack-paths-scans/${parsedScanId.data}/queries`;

  try {
    const response = await fetch(url, { method: "GET", headers });
    const apiResponse = (await handleApiResponse(
      response,
    )) as AttackPathQueriesResponse;
    const { data } = adaptAttackPathQueriesResponse(apiResponse);
    return { data };
  } catch (error) {
    console.error("Error fetching available queries for scan:", error);
    return undefined;
  }
};
/**
* Execute a query on an attack path scan
*/
/**
 * Execute one query against an attack path scan.
 *
 * @param scanId - UUID of the scan to run the query on
 * @param queryId - Identifier of the query to execute
 * @param parameters - Optional query parameters, omitted from the payload when absent
 * @returns The query result, or `undefined` on invalid input or request failure
 */
export const executeQuery = async (
  scanId: string,
  queryId: string,
  parameters?: Record<string, string | number | boolean>,
): Promise<AttackPathQueryResult | undefined> => {
  // Reject anything that is not a UUID before it reaches the URL path
  // (prevents request forgery through the path segment).
  const parsedScanId = UUIDSchema.safeParse(scanId);
  if (!parsedScanId.success) {
    console.error("Invalid scan ID format");
    return undefined;
  }

  const headers = await getAuthHeaders({ contentType: true });
  const payload: ExecuteQueryRequest = {
    data: {
      type: "attack-paths-query-run-request",
      attributes: {
        id: queryId,
        // Only include `parameters` when the caller supplied them.
        ...(parameters && { parameters }),
      },
    },
  };
  const url = `${apiBaseUrl}/attack-paths-scans/${parsedScanId.data}/queries/run`;

  try {
    const response = await fetch(url, {
      method: "POST",
      headers,
      body: JSON.stringify(payload),
    });
    return handleApiResponse(response);
  } catch (error) {
    console.error("Error executing query on scan:", error);
    return undefined;
  }
};
@@ -0,0 +1,164 @@
import {
AttackPathGraphData,
GraphEdge,
GraphNodeProperties,
GraphNodePropertyValue,
GraphRelationship,
} from "@/types/attack-paths";
/**
* Normalizes property values to ensure they are primitives
* Arrays are converted to comma-separated strings
*
* @param value - The property value to normalize
* @returns Normalized primitive value
*/
/**
 * Normalize one property value to a primitive.
 * Arrays become comma-separated strings; null/undefined pass through;
 * anything else non-primitive is stringified.
 *
 * @param value - The property value to normalize
 * @returns Normalized primitive value
 */
function normalizePropertyValue(
  value:
    | GraphNodePropertyValue
    | GraphNodePropertyValue[]
    | Record<string, unknown>,
): string | number | boolean | null | undefined {
  // Preserve null/undefined exactly as received.
  if (value == null) {
    return value;
  }
  // Arrays collapse to a comma-separated string.
  if (Array.isArray(value)) {
    return value.join(", ");
  }
  switch (typeof value) {
    case "string":
    case "number":
    case "boolean":
      return value;
    default:
      // Objects and anything else are stringified.
      return String(value);
  }
}

/**
 * Normalize every value in a properties object to a primitive.
 *
 * @param properties - The properties object to normalize
 * @returns Normalized properties object
 */
function normalizeProperties(
  properties: Record<
    string,
    GraphNodePropertyValue | GraphNodePropertyValue[] | Record<string, unknown>
  >,
): GraphNodeProperties {
  const result: GraphNodeProperties = {};
  Object.entries(properties).forEach(([propertyKey, rawValue]) => {
    result[propertyKey] = normalizePropertyValue(rawValue);
  });
  return result;
}
/**
* Adapts graph query result data for D3 visualization
* Transforms relationships array into edges array for D3 force-directed graph
*
* The adapter handles:
* - Converting relationship objects to edge objects compatible with D3
* - Mapping relationship labels to edge types for graph styling
* - Normalizing array properties to strings (e.g., anonymous_actions: ["s3:GetObject"] -> "s3:GetObject")
* - Preserving node and relationship data structure
* - Adding findings array to each node based on HAS_FINDING edges
* - Adding resources array to finding nodes based on HAS_FINDING edges (reverse relationship)
*
* @param graphData - Raw graph data with nodes and relationships from API
* @returns Graph data with edges array formatted for D3 visualization and findings/resources on nodes
*/
export function adaptQueryResultToGraphData(
graphData: AttackPathGraphData,
): AttackPathGraphData {
// Normalize node properties to ensure all values are primitives
const normalizedNodes = graphData.nodes.map((node) => ({
...node,
properties: normalizeProperties(
node.properties as Record<
string,
GraphNodePropertyValue | GraphNodePropertyValue[]
>,
),
findings: [] as string[], // Will be populated below
resources: [] as string[], // Will be populated below for finding nodes
}));
// Transform relationships into D3-compatible edges if relationships exist
// Also handle case where edges are already provided (e.g., from mock data)
let edges: GraphEdge[] = [];
if (graphData.relationships) {
edges = (graphData.relationships as GraphRelationship[]).map(
(relationship) => ({
id: relationship.id,
source: relationship.source,
target: relationship.target,
type: relationship.label, // D3 uses 'type' for styling edge appearance
properties: relationship.properties
? normalizeProperties(
relationship.properties as Record<
string,
GraphNodePropertyValue | GraphNodePropertyValue[]
>,
)
: undefined,
}),
);
} else if (graphData.edges) {
// If edges are already provided, just normalize their properties
edges = (graphData.edges as GraphEdge[]).map((edge) => ({
...edge,
properties: edge.properties
? normalizeProperties(
edge.properties as Record<
string,
GraphNodePropertyValue | GraphNodePropertyValue[]
>,
)
: undefined,
}));
}
// Populate findings and resources based on HAS_FINDING edges
edges.forEach((edge) => {
if (edge.type === "HAS_FINDING") {
const sourceId =
typeof edge.source === "string"
? edge.source
: (edge.source as { id?: string })?.id;
const targetId =
typeof edge.target === "string"
? edge.target
: (edge.target as { id?: string })?.id;
if (sourceId && targetId) {
// Add finding to source node (resource -> finding)
const sourceNode = normalizedNodes.find((n) => n.id === sourceId);
if (sourceNode) {
sourceNode.findings.push(targetId);
}
// Add resource to target node (finding <- resource)
const targetNode = normalizedNodes.find((n) => n.id === targetId);
if (targetNode) {
targetNode.resources.push(sourceId);
}
}
}
});
return {
nodes: normalizedNodes,
edges,
relationships: graphData.relationships, // Preserve original relationships data
};
}
+89
View File
@@ -0,0 +1,89 @@
import { MetaDataProps } from "@/types";
import { AttackPathScan, AttackPathScansResponse } from "@/types/attack-paths";
/**
* Adapts raw scan API responses to enriched domain models
* - Transforms raw scan data with computed properties
* - Co-locates related data for better performance
* - Preserves pagination metadata for list operations
*
* Uses plugin architecture for extensibility:
* - Handles scan-specific response transformation
* - Can be composed with backend service plugins
* - Maintains separation of concerns between API layer and business logic
*/
/**
* Adapt attack path scans response with enriched data
*
* @param response - Raw API response from attack-paths-scans endpoint
* @returns Enriched scans data with metadata and computed properties
*/
export function adaptAttackPathScansResponse(
  response: AttackPathScansResponse | undefined,
): {
  data: AttackPathScan[];
  metadata?: MetaDataProps;
} {
  const rawScans = response?.data;
  if (!rawScans) {
    return { data: [] };
  }

  // Enrich each scan with display-oriented computed attributes.
  const scans = rawScans.map((scan) => {
    const { duration, completed_at } = scan.attributes;
    return {
      ...scan,
      attributes: {
        ...scan.attributes,
        // Human-readable duration, or null when no duration is present.
        durationLabel: duration ? formatDuration(duration) : null,
        // Whether the scan completed within the last 24 hours.
        isRecent: isRecentScan(completed_at),
      },
    };
  });

  // Links-based pagination has no page numbers; expose a single-page shape
  // for consistency with other list adapters.
  const metadata: MetaDataProps | undefined = response.links
    ? {
        pagination: {
          page: 1,
          pages: 1,
          count: scans.length,
          itemsPerPage: [10, 25, 50, 100],
        },
        version: "1.0",
      }
    : undefined;

  return { data: scans, metadata };
}
/**
* Format duration in seconds to human-readable format
*
* @param seconds - Duration in seconds
* @returns Formatted duration string (e.g., "2m 30s")
*/
/**
 * Format a duration in seconds as "<m>m <s>s" (e.g. 150 -> "2m 30s").
 *
 * @param seconds - Duration in seconds
 * @returns Formatted duration string
 */
function formatDuration(seconds: number): string {
  const wholeMinutes = Math.floor(seconds / 60);
  const leftoverSeconds = seconds % 60;
  return `${wholeMinutes}m ${leftoverSeconds}s`;
}
/**
* Check if a scan is recent (completed within last 24 hours)
*
* @param completedAt - Completion timestamp
* @returns true if scan completed within last 24 hours
*/
/**
 * Whether a scan completed within the last 24 hours.
 *
 * @param completedAt - Completion timestamp (ISO string) or null
 * @returns true only for a parseable timestamp newer than 24h ago
 */
function isRecentScan(completedAt: string | null): boolean {
  if (!completedAt) {
    return false;
  }
  const dayInMs = 24 * 60 * 60 * 1000;
  const completedMs = new Date(completedAt).getTime();
  // NaN from an unparseable date makes this comparison false.
  return Date.now() - completedMs < dayInMs;
}
+69
View File
@@ -0,0 +1,69 @@
"use server";
import { z } from "zod";
import { apiBaseUrl, getAuthHeaders } from "@/lib";
import { handleApiResponse } from "@/lib/server-actions-helper";
import { AttackPathScan, AttackPathScansResponse } from "@/types/attack-paths";
import { adaptAttackPathScansResponse } from "./scans.adapter";
// Validation schema for UUID - RFC 9562/4122 compliant
const UUIDSchema = z.uuid();
/**
* Fetch list of attack path scans (latest scan for each provider)
*/
/**
 * Fetch the list of attack path scans (latest scan per provider).
 *
 * Returns `undefined` on any request failure.
 */
export const getAttackPathScans = async (): Promise<
  { data: AttackPathScan[] } | undefined
> => {
  const headers = await getAuthHeaders({ contentType: false });

  try {
    const response = await fetch(`${apiBaseUrl}/attack-paths-scans`, {
      method: "GET",
      headers,
    });
    const apiResponse = (await handleApiResponse(
      response,
    )) as AttackPathScansResponse;
    const { data } = adaptAttackPathScansResponse(apiResponse);
    return { data };
  } catch (error) {
    console.error("Error fetching attack path scans:", error);
    return undefined;
  }
};
/**
* Fetch detail of a specific attack path scan
*/
/**
 * Fetch the detail of one attack path scan.
 *
 * Returns `undefined` on an invalid scan id or any request failure.
 */
export const getAttackPathScanDetail = async (
  scanId: string,
): Promise<{ data: AttackPathScan } | undefined> => {
  // Reject anything that is not a UUID before it reaches the URL path
  // (prevents request forgery through the path segment).
  const parsedScanId = UUIDSchema.safeParse(scanId);
  if (!parsedScanId.success) {
    console.error("Invalid scan ID format");
    return undefined;
  }

  const headers = await getAuthHeaders({ contentType: false });
  const url = `${apiBaseUrl}/attack-paths-scans/${parsedScanId.data}`;

  try {
    const response = await fetch(url, { method: "GET", headers });
    return handleApiResponse(response);
  } catch (error) {
    console.error("Error fetching attack path scan detail:", error);
    return undefined;
  }
};
@@ -0,0 +1,2 @@
export { VerticalSteps } from "./vertical-steps";
export { WorkflowAttackPaths } from "./workflow-attack-paths";
@@ -0,0 +1,299 @@
"use client";
import { useControlledState } from "@react-stately/utils";
import { domAnimation, LazyMotion, m } from "framer-motion";
import type {
ComponentProps,
CSSProperties,
HTMLAttributes,
ReactNode,
} from "react";
import { forwardRef } from "react";
import { cn } from "@/lib/utils";
export type VerticalStepProps = {
className?: string;
description?: ReactNode;
title?: ReactNode;
};
const STEP_COLORS = {
primary: "primary",
secondary: "secondary",
success: "success",
warning: "warning",
danger: "danger",
default: "default",
} as const;
type StepColor = (typeof STEP_COLORS)[keyof typeof STEP_COLORS];
export interface VerticalStepsProps extends HTMLAttributes<HTMLButtonElement> {
/**
* An array of steps.
*
* @default []
*/
steps?: VerticalStepProps[];
/**
* The color of the steps.
*
* @default "primary"
*/
color?: StepColor;
/**
* The current step index.
*/
currentStep?: number;
/**
* The default step index.
*
* @default 0
*/
defaultStep?: number;
/**
* Whether to hide the progress bars.
*
* @default false
*/
hideProgressBars?: boolean;
/**
* The custom class for the steps wrapper.
*/
className?: string;
/**
* The custom class for the step.
*/
stepClassName?: string;
/**
* Callback function when the step index changes.
*/
onStepChange?: (stepIndex: number) => void;
}
// Animated check mark used by completed steps: the path draws itself
// (pathLength 0 -> 1) with a short delay when the icon mounts.
function CheckIcon(props: ComponentProps<"svg">) {
  return (
    <svg
      {...props}
      fill="none"
      stroke="currentColor"
      strokeWidth={2}
      viewBox="0 0 24 24"
    >
      <m.path
        animate={{ pathLength: 1 }}
        d="M5 13l4 4L19 7"
        initial={{ pathLength: 0 }}
        strokeLinecap="round"
        strokeLinejoin="round"
        transition={{
          delay: 0.2,
          type: "tween",
          ease: "easeOut",
          duration: 0.3,
        }}
      />
    </svg>
  );
}
/**
 * Vertical stepper component. Each step renders as a clickable button with an
 * animated status circle (number or check mark) plus title/description, and a
 * connecting progress bar between consecutive steps. Theming is driven by CSS
 * custom properties injected into the list's class name.
 */
export const VerticalSteps = forwardRef<HTMLButtonElement, VerticalStepsProps>(
  (
    {
      color = "primary",
      steps = [],
      defaultStep = 0,
      onStepChange,
      currentStep: currentStepProp,
      hideProgressBars = false,
      stepClassName,
      className,
      ...props
    },
    ref,
  ) => {
    // Controlled/uncontrolled step index: uses the `currentStep` prop when
    // provided, otherwise internal state starting at `defaultStep`.
    const [currentStep, setCurrentStep] = useControlledState(
      currentStepProp,
      defaultStep,
      onStepChange,
    );

    let userColor;
    let fgColor;
    // CSS custom properties consumed by the step circle and progress bars.
    const colorsVars = [
      "[--active-fg-color:var(--step-fg-color)]",
      "[--active-border-color:var(--step-color)]",
      "[--active-color:var(--step-color)]",
      "[--complete-background-color:var(--step-color)]",
      "[--complete-border-color:var(--step-color)]",
      "[--inactive-border-color:hsl(var(--heroui-default-300))]",
      "[--inactive-color:hsl(var(--heroui-default-300))]",
    ];

    // Map the `color` prop to HeroUI theme variables. The default branch
    // mirrors the "primary" case for unknown values.
    switch (color) {
      case "primary":
        userColor = "[--step-color:hsl(var(--heroui-primary))]";
        fgColor = "[--step-fg-color:hsl(var(--heroui-primary-foreground))]";
        break;
      case "secondary":
        userColor = "[--step-color:hsl(var(--heroui-secondary))]";
        fgColor = "[--step-fg-color:hsl(var(--heroui-secondary-foreground))]";
        break;
      case "success":
        userColor = "[--step-color:hsl(var(--heroui-success))]";
        fgColor = "[--step-fg-color:hsl(var(--heroui-success-foreground))]";
        break;
      case "warning":
        userColor = "[--step-color:hsl(var(--heroui-warning))]";
        fgColor = "[--step-fg-color:hsl(var(--heroui-warning-foreground))]";
        break;
      case "danger":
        userColor = "[--step-color:hsl(var(--heroui-error))]";
        fgColor = "[--step-fg-color:hsl(var(--heroui-error-foreground))]";
        break;
      case "default":
        userColor = "[--step-color:hsl(var(--heroui-default))]";
        fgColor = "[--step-fg-color:hsl(var(--heroui-default-foreground))]";
        break;
      default:
        userColor = "[--step-color:hsl(var(--heroui-primary))]";
        fgColor = "[--step-fg-color:hsl(var(--heroui-primary-foreground))]";
        break;
    }

    // Only inject defaults when the caller's `className` does not already set
    // the corresponding CSS variable.
    if (!className?.includes("--step-fg-color")) colorsVars.unshift(fgColor);
    if (!className?.includes("--step-color")) colorsVars.unshift(userColor);
    if (!className?.includes("--inactive-bar-color"))
      colorsVars.push("[--inactive-bar-color:hsl(var(--heroui-default-300))]");

    const colors = colorsVars;

    return (
      <nav aria-label="Progress" className="max-w-fit">
        <ol className={cn("flex flex-col gap-y-3", colors, className)}>
          {steps?.map((step, stepIdx) => {
            // "active" = selected step, "inactive" = not yet reached,
            // "complete" = already passed.
            const status =
              currentStep === stepIdx
                ? "active"
                : currentStep < stepIdx
                  ? "inactive"
                  : "complete";

            return (
              <li key={stepIdx} className="relative">
                <div className="flex w-full max-w-full items-center">
                  <button
                    key={stepIdx}
                    ref={ref}
                    aria-current={status === "active" ? "step" : undefined}
                    className={cn(
                      "group rounded-large flex w-full cursor-pointer items-center justify-center gap-4 px-3 py-2.5",
                      stepClassName,
                    )}
                    onClick={() => setCurrentStep(stepIdx)}
                    {...props}
                  >
                    <div className="flex h-full items-center">
                      <LazyMotion features={domAnimation}>
                        <div className="relative">
                          {/* Status circle: animates colors between the three
                              variants; shows check mark when complete. */}
                          <m.div
                            animate={status}
                            className={cn(
                              "border-medium text-large text-default-foreground relative flex h-[34px] w-[34px] items-center justify-center rounded-full font-semibold",
                              {
                                "shadow-lg": status === "complete",
                              },
                            )}
                            data-status={status}
                            initial={false}
                            transition={{ duration: 0.25 }}
                            variants={{
                              inactive: {
                                backgroundColor: "transparent",
                                borderColor: "var(--inactive-border-color)",
                                color: "var(--inactive-color)",
                              },
                              active: {
                                backgroundColor: "transparent",
                                borderColor: "var(--active-border-color)",
                                color: "var(--active-color)",
                              },
                              complete: {
                                backgroundColor:
                                  "var(--complete-background-color)",
                                borderColor: "var(--complete-border-color)",
                              },
                            }}
                          >
                            <div className="flex items-center justify-center">
                              {status === "complete" ? (
                                <CheckIcon className="h-6 w-6 text-(--active-fg-color)" />
                              ) : (
                                <span>{stepIdx + 1}</span>
                              )}
                            </div>
                          </m.div>
                        </div>
                      </LazyMotion>
                    </div>
                    <div className="flex-1 text-left">
                      <div>
                        <div
                          className={cn(
                            "text-medium text-default-foreground font-medium transition-[color,opacity] duration-300 group-active:opacity-70",
                            {
                              "text-default-500": status === "inactive",
                            },
                          )}
                        >
                          {step.title}
                        </div>
                        <div
                          className={cn(
                            "text-tiny text-default-600 lg:text-small transition-[color,opacity] duration-300 group-active:opacity-70",
                            {
                              "text-default-500": status === "inactive",
                            },
                          )}
                        >
                          {step.description}
                        </div>
                      </div>
                    </div>
                  </button>
                </div>
                {/* Connecting bar below every step except the last; fills to
                    full height once the step is complete. */}
                {stepIdx < steps.length - 1 && !hideProgressBars && (
                  <div
                    aria-hidden="true"
                    className={cn(
                      "pointer-events-none absolute top-[calc(64px*var(--idx)+1)] left-3 flex h-1/2 -translate-y-1/3 items-center px-4",
                    )}
                    style={
                      {
                        "--idx": stepIdx,
                      } as CSSProperties
                    }
                  >
                    <div
                      className={cn(
                        "relative h-full w-0.5 bg-(--inactive-bar-color) transition-colors duration-300",
                        "after:absolute after:block after:h-0 after:w-full after:bg-(--active-border-color) after:transition-[height] after:duration-300 after:content-['']",
                        {
                          "after:h-full": stepIdx < currentStep,
                        },
                      )}
                    />
                  </div>
                )}
              </li>
            );
          })}
        </ol>
      </nav>
    );
  },
);

VerticalSteps.displayName = "VerticalSteps";
@@ -0,0 +1,49 @@
"use client";
import { usePathname } from "next/navigation";
import { VerticalSteps } from "./vertical-steps";
/**
* Workflow steps component for Attack Paths wizard
* Shows progress and navigation steps for the two-step process
*/
/**
 * Workflow sidebar for the Attack Paths wizard: a horizontal progress bar,
 * a "Step X of Y" label, and the vertical step list for the two-step flow.
 */
export const WorkflowAttackPaths = () => {
  const pathname = usePathname();

  // Step index derived from the route: the query-builder page is step 2
  // (index 1); anything else is step 1 (index 0).
  const currentStep = pathname.includes("query-builder") ? 1 : 0;

  const steps = [
    {
      title: "Select Attack Paths Scan",
      description: "Choose an AWS account and its latest Attack Paths scan",
    },
    {
      title: "Build Query & Visualize",
      description: "Create a query and view the Attack Paths graph",
    },
  ];

  const progressPct = (currentStep / (steps.length - 1)) * 100;

  return (
    <section className="flex flex-col gap-6">
      <div>
        <div className="bg-bg-neutral-tertiary mb-4 h-2 w-full overflow-hidden rounded-full">
          <div
            className="bg-success-primary h-full transition-all duration-300"
            style={{ width: `${progressPct}%` }}
          />
        </div>
        <h3 className="dark:text-prowler-theme-pale/90 text-sm font-semibold">
          Step {currentStep + 1} of {steps.length}
        </h3>
      </div>
      <VerticalSteps currentStep={currentStep} steps={steps} color="success" />
    </section>
  );
};
@@ -0,0 +1,21 @@
import { Navbar } from "@/components/ui/nav-bar/navbar";
/**
* Workflow layout for Attack Paths
* Displays content with navbar
*/
/**
 * Layout for the Attack Paths workflow pages: renders the navbar followed by
 * the page content inside a padded container.
 */
export default function AttackPathsWorkflowLayout({
  children,
}: {
  children: React.ReactNode;
}) {
  const content = <div>{children}</div>;

  return (
    <>
      <Navbar title="Attack Paths Analysis" icon="" />
      <div className="px-6 py-4 sm:px-8 xl:px-10">
        {/* Content */}
        {content}
      </div>
    </>
  );
}
@@ -0,0 +1,34 @@
"use client";
import { Play } from "lucide-react";
import { Button } from "@/components/shadcn";
interface ExecuteButtonProps {
isLoading: boolean;
isDisabled: boolean;
onExecute: () => void;
}
/**
* Execute query button component
* Triggers query execution with loading state
*/
/**
 * Button that triggers query execution. Disabled while loading or when the
 * caller disables it; swaps its label and hides the play icon while a query
 * is in flight.
 */
export const ExecuteButton = ({
  isLoading,
  isDisabled,
  onExecute,
}: ExecuteButtonProps) => {
  const label = isLoading ? "Executing Query..." : "Execute Query";

  return (
    <Button
      variant="default"
      size="lg"
      disabled={isDisabled || isLoading}
      onClick={onExecute}
      className="w-full gap-2 font-semibold sm:w-auto"
    >
      {!isLoading && <Play size={18} />}
      {label}
    </Button>
  );
};
@@ -0,0 +1,93 @@
"use client";
import { Download, Minimize2, ZoomIn, ZoomOut } from "lucide-react";
import { Button } from "@/components/shadcn";
import {
Tooltip,
TooltipContent,
TooltipProvider,
TooltipTrigger,
} from "@/components/shadcn/tooltip";
interface GraphControlsProps {
onZoomIn: () => void;
onZoomOut: () => void;
onFitToScreen: () => void;
onExport: () => void;
}
/**
* Controls for graph visualization (zoom, pan, export)
* Positioned as floating toolbar above graph
*/
export const GraphControls = ({
onZoomIn,
onZoomOut,
onFitToScreen,
onExport,
}: GraphControlsProps) => {
return (
<div className="flex items-center">
<div className="border-border-neutral-primary bg-bg-neutral-tertiary flex gap-1 rounded-lg border p-1">
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<Button
variant="ghost"
size="sm"
onClick={onZoomIn}
className="h-8 w-8 p-0"
>
<ZoomIn size={18} />
</Button>
</TooltipTrigger>
<TooltipContent>Zoom in</TooltipContent>
</Tooltip>
<Tooltip>
<TooltipTrigger asChild>
<Button
variant="ghost"
size="sm"
onClick={onZoomOut}
className="h-8 w-8 p-0"
>
<ZoomOut size={18} />
</Button>
</TooltipTrigger>
<TooltipContent>Zoom out</TooltipContent>
</Tooltip>
<Tooltip>
<TooltipTrigger asChild>
<Button
variant="ghost"
size="sm"
onClick={onFitToScreen}
className="h-8 w-8 p-0"
>
<Minimize2 size={18} />
</Button>
</TooltipTrigger>
<TooltipContent>Fit graph to view</TooltipContent>
</Tooltip>
<Tooltip>
<TooltipTrigger asChild>
<Button
variant="ghost"
size="sm"
onClick={onExport}
className="h-8 w-8 p-0"
>
<Download size={18} />
</Button>
</TooltipTrigger>
<TooltipContent>Export graph</TooltipContent>
</Tooltip>
</TooltipProvider>
</div>
</div>
);
};
@@ -0,0 +1,508 @@
"use client";
import { Card, CardContent } from "@/components/shadcn";
import {
Tooltip,
TooltipContent,
TooltipProvider,
TooltipTrigger,
} from "@/components/shadcn/tooltip";
import type { AttackPathGraphData } from "@/types/attack-paths";
import {
getNodeBorderColor,
getNodeColor,
GRAPH_EDGE_COLOR,
GRAPH_NODE_BORDER_COLORS,
GRAPH_NODE_COLORS,
} from "../../_lib/graph-colors";
interface LegendItem {
label: string;
color: string;
borderColor: string;
description: string;
shape: "rectangle" | "hexagon" | "cloud";
}
// Map node labels to human-readable names and descriptions
/**
 * Display name and description for each graph node label, used to build the
 * legend. Keys correspond to values found in each node's `labels` array
 * (see extractNodeTypes below); labels without an entry here have no
 * curated description.
 */
const nodeTypeDescriptions: Record<
  string,
  { name: string; description: string }
> = {
  // Findings
  ProwlerFinding: {
    name: "Finding",
    description: "Security findings from Prowler scans",
  },
  // AWS Account
  AWSAccount: {
    name: "AWS Account",
    description: "AWS account root node",
  },
  // Compute
  EC2Instance: {
    name: "EC2 Instance",
    description: "Elastic Compute Cloud instance",
  },
  LambdaFunction: {
    name: "Lambda Function",
    description: "AWS Lambda serverless function",
  },
  // Storage
  S3Bucket: {
    name: "S3 Bucket",
    description: "Simple Storage Service bucket",
  },
  // IAM
  IAMRole: {
    name: "IAM Role",
    description: "Identity and Access Management role",
  },
  IAMPolicy: {
    name: "IAM Policy",
    description: "Identity and Access Management policy",
  },
  AWSRole: {
    name: "AWS Role",
    description: "AWS IAM role",
  },
  AWSPolicy: {
    name: "AWS Policy",
    description: "AWS IAM policy",
  },
  AWSInlinePolicy: {
    name: "AWS Inline Policy",
    description: "AWS IAM inline policy",
  },
  AWSPolicyStatement: {
    name: "AWS Policy Statement",
    description: "AWS IAM policy statement",
  },
  AWSPrincipal: {
    name: "AWS Principal",
    description: "AWS IAM principal entity",
  },
  // Networking
  SecurityGroup: {
    name: "Security Group",
    description: "AWS security group for network access control",
  },
  EC2SecurityGroup: {
    name: "EC2 Security Group",
    description: "EC2 security group for network access control",
  },
  IpPermissionInbound: {
    name: "IP Permission Inbound",
    description: "Inbound IP permission rule",
  },
  IpRule: {
    name: "IP Rule",
    description: "IP address rule",
  },
  Internet: {
    name: "Internet",
    description: "Internet gateway or public access",
  },
  // Tags
  AWSTag: {
    name: "AWS Tag",
    description: "AWS resource tag",
  },
  Tag: {
    name: "Tag",
    description: "Resource tag",
  },
};
/**
 * Extract unique node types from graph data
 *
 * Flattens every node's label list, deduplicates via a Set and
 * returns the labels in alphabetical order. An absent node list
 * yields an empty array.
 */
function extractNodeTypes(
  nodes: AttackPathGraphData["nodes"] | undefined,
): string[] {
  if (!nodes) return [];
  const uniqueLabels = new Set<string>(nodes.flatMap((node) => node.labels));
  return [...uniqueLabels].sort();
}
/**
 * Severity legend items - colors work in both light and dark themes
 *
 * Shown only when the graph contains finding nodes; finding severity
 * is encoded by hexagon color rather than by node type.
 */
const severityLegendItems: LegendItem[] = [
  {
    label: "Critical",
    color: GRAPH_NODE_COLORS.critical,
    borderColor: GRAPH_NODE_BORDER_COLORS.critical,
    description: "Critical severity finding",
    shape: "hexagon",
  },
  {
    label: "High",
    color: GRAPH_NODE_COLORS.high,
    borderColor: GRAPH_NODE_BORDER_COLORS.high,
    description: "High severity finding",
    shape: "hexagon",
  },
  {
    label: "Medium",
    color: GRAPH_NODE_COLORS.medium,
    borderColor: GRAPH_NODE_BORDER_COLORS.medium,
    description: "Medium severity finding",
    shape: "hexagon",
  },
  {
    label: "Low",
    color: GRAPH_NODE_COLORS.low,
    borderColor: GRAPH_NODE_BORDER_COLORS.low,
    description: "Low severity finding",
    shape: "hexagon",
  },
];
/**
 * Generate legend items from graph data
 *
 * Severity entries are prepended when findings are present; finding
 * node types themselves are skipped because severity colors already
 * represent them. Known node types use their curated name/description;
 * unknown ones get a name derived from the CamelCase label.
 *
 * Fix: the previous formatter inserted a space before *every* capital
 * letter first, so "AWSPolicyStatement" became "A W S Policy Statement"
 * and the follow-up "keep AWS/EC2/S3 together" replaces could never
 * match (the pattern "AWS " no longer existed). The regexes below split
 * only at word boundaries, keeping acronym/number runs intact.
 */
function generateLegendItems(
  nodeTypes: string[],
  hasFindings: boolean,
): LegendItem[] {
  const items: LegendItem[] = [];

  // Add severity items if there are findings
  if (hasFindings) {
    items.push(...severityLegendItems);
  }

  // Helper to format unknown node types while keeping acronym runs
  // together (e.g. "AWSPolicyStatement" -> "AWS Policy Statement",
  // "EC2Instance" -> "EC2 Instance", "IpRule" -> "Ip Rule").
  const formatNodeTypeName = (nodeType: string): string => {
    return nodeType
      .replace(/([a-z0-9])([A-Z])/g, "$1 $2") // lower/digit -> Upper boundary
      .replace(/([A-Z]+)([A-Z][a-z])/g, "$1 $2") // acronym -> next Word boundary
      .trim();
  };

  // nodeTypes arrives deduplicated from extractNodeTypes, but guard
  // against duplicates anyway so repeated labels never repeat here.
  const seenTypes = new Set<string>();

  nodeTypes.forEach((nodeType) => {
    if (seenTypes.has(nodeType)) return;
    seenTypes.add(nodeType);

    // Skip findings - we show severity colors instead
    if (nodeType.toLowerCase().includes("finding")) return;

    const description = nodeTypeDescriptions[nodeType];

    // Internet nodes are drawn as a globe/cloud; everything else as a pill.
    const shape: "rectangle" | "hexagon" | "cloud" =
      nodeType.toLowerCase() === "internet" ? "cloud" : "rectangle";

    if (description) {
      items.push({
        label: description.name,
        color: getNodeColor([nodeType]),
        borderColor: getNodeBorderColor([nodeType]),
        description: description.description,
        shape,
      });
    } else {
      // Format unknown node types nicely
      const formattedName = formatNodeTypeName(nodeType);
      items.push({
        label: formattedName,
        color: getNodeColor([nodeType]),
        borderColor: getNodeBorderColor([nodeType]),
        description: `${formattedName} node`,
        shape,
      });
    }
  });

  return items;
}
/**
* Hexagon shape component for legend
*/
const HexagonShape = ({
color,
borderColor,
}: {
color: string;
borderColor: string;
}) => (
<svg width="32" height="22" viewBox="0 0 32 22" aria-hidden="true">
<defs>
<filter id="legendGlow" x="-50%" y="-50%" width="200%" height="200%">
<feGaussianBlur stdDeviation="1" result="coloredBlur" />
<feMerge>
<feMergeNode in="coloredBlur" />
<feMergeNode in="SourceGraphic" />
</feMerge>
</filter>
</defs>
<path
d="M5 1 L27 1 L31 11 L27 21 L5 21 L1 11 Z"
fill={color}
fillOpacity={0.85}
stroke={borderColor}
strokeWidth={1.5}
filter="url(#legendGlow)"
/>
</svg>
);
/**
* Pill shape component for legend
*/
const PillShape = ({
color,
borderColor,
}: {
color: string;
borderColor: string;
}) => (
<svg width="36" height="20" viewBox="0 0 36 20" aria-hidden="true">
<defs>
<filter id="legendGlow2" x="-50%" y="-50%" width="200%" height="200%">
<feGaussianBlur stdDeviation="1" result="coloredBlur" />
<feMerge>
<feMergeNode in="coloredBlur" />
<feMergeNode in="SourceGraphic" />
</feMerge>
</filter>
</defs>
<rect
x="1"
y="1"
width="34"
height="18"
rx="9"
ry="9"
fill={color}
fillOpacity={0.85}
stroke={borderColor}
strokeWidth={1.5}
filter="url(#legendGlow2)"
/>
</svg>
);
/**
* Globe shape component for legend (used for Internet nodes)
*/
const GlobeShape = ({
color,
borderColor,
}: {
color: string;
borderColor: string;
}) => (
<svg width="24" height="24" viewBox="0 0 24 24" aria-hidden="true">
<defs>
<filter id="legendGlow3" x="-50%" y="-50%" width="200%" height="200%">
<feGaussianBlur stdDeviation="1" result="coloredBlur" />
<feMerge>
<feMergeNode in="coloredBlur" />
<feMergeNode in="SourceGraphic" />
</feMerge>
</filter>
</defs>
{/* Globe circle */}
<circle
cx="12"
cy="12"
r="10"
fill={color}
fillOpacity={0.85}
stroke={borderColor}
strokeWidth={1.5}
filter="url(#legendGlow3)"
/>
{/* Horizontal line */}
<ellipse
cx="12"
cy="12"
rx="10"
ry="4"
fill="none"
stroke={borderColor}
strokeWidth={1}
strokeOpacity={0.6}
/>
{/* Vertical ellipse */}
<ellipse
cx="12"
cy="12"
rx="4"
ry="10"
fill="none"
stroke={borderColor}
strokeWidth={1}
strokeOpacity={0.6}
/>
</svg>
);
/**
 * Edge line component for legend
 *
 * Draws a small arrow: a solid shaft represents a resource connection,
 * a dashed shaft represents a connection to a finding.
 */
const EdgeLine = ({ dashed }: { dashed: boolean }) => {
  const dashPattern = dashed ? "8,6" : undefined;
  return (
    <svg
      width="60"
      height="20"
      viewBox="0 0 60 20"
      aria-hidden="true"
      style={{ overflow: "visible" }}
    >
      {/* Shaft */}
      <line
        x1="4"
        y1="10"
        x2="44"
        y2="10"
        stroke={GRAPH_EDGE_COLOR}
        strokeWidth={3}
        strokeLinecap="round"
        strokeDasharray={dashPattern}
      />
      {/* Arrow head */}
      <polygon points="44,5 56,10 44,15" fill={GRAPH_EDGE_COLOR} />
    </svg>
  );
};
// Props for the legend; `data` is the graph currently displayed.
interface GraphLegendProps {
  data?: AttackPathGraphData;
}
/**
 * Legend for attack path graph node types and edge styles
 *
 * Renders three sections: node-type swatches (with severity hexagons
 * when findings exist), edge-style samples, and a zoom hint. Returns
 * null when the graph yields no legend entries at all.
 */
export const GraphLegend = ({ data }: GraphLegendProps) => {
  // Unique, sorted labels present in the current graph.
  const nodeTypes = extractNodeTypes(data?.nodes);
  // Check if there are any findings in the data
  const hasFindings = nodeTypes.some((type) =>
    type.toLowerCase().includes("finding"),
  );
  const legendItems = generateLegendItems(nodeTypes, hasFindings);
  // Nothing to explain — hide the legend entirely.
  if (legendItems.length === 0) {
    return null;
  }
  return (
    <Card className="w-fit border-0">
      <CardContent className="gap-3 p-4">
        <div className="flex flex-col gap-4">
          {/* Node types section */}
          <div className="flex flex-col items-start gap-3 lg:flex-row lg:flex-wrap lg:items-center">
            <TooltipProvider>
              {legendItems.map((item) => (
                <Tooltip key={item.label}>
                  <TooltipTrigger asChild>
                    <div
                      className="flex cursor-help items-center gap-2"
                      role="img"
                      aria-label={`${item.label}: ${item.description}`}
                    >
                      {item.shape === "hexagon" ? (
                        <HexagonShape
                          color={item.color}
                          borderColor={item.borderColor}
                        />
                      ) : item.shape === "cloud" ? (
                        <GlobeShape
                          color={item.color}
                          borderColor={item.borderColor}
                        />
                      ) : (
                        <PillShape
                          color={item.color}
                          borderColor={item.borderColor}
                        />
                      )}
                      <span className="text-text-neutral-secondary text-xs">
                        {item.label}
                      </span>
                    </div>
                  </TooltipTrigger>
                  <TooltipContent>{item.description}</TooltipContent>
                </Tooltip>
              ))}
            </TooltipProvider>
          </div>
          {/* Edge types section */}
          <div className="border-border-neutral-primary flex flex-col items-start gap-3 border-t pt-3 lg:flex-row lg:flex-wrap lg:items-center">
            <TooltipProvider>
              <Tooltip>
                <TooltipTrigger asChild>
                  <div
                    className="flex cursor-help items-center gap-2"
                    role="img"
                    aria-label="Solid line: Resource connection"
                  >
                    <EdgeLine dashed={false} />
                    <span className="text-text-neutral-secondary text-xs">
                      Resource Connection
                    </span>
                  </div>
                </TooltipTrigger>
                <TooltipContent>
                  Connection between infrastructure resources
                </TooltipContent>
              </Tooltip>
              {hasFindings && (
                <Tooltip>
                  <TooltipTrigger asChild>
                    <div
                      className="flex cursor-help items-center gap-2"
                      role="img"
                      aria-label="Dashed line: Finding connection"
                    >
                      <EdgeLine dashed={true} />
                      <span className="text-text-neutral-secondary text-xs">
                        Finding Connection
                      </span>
                    </div>
                  </TooltipTrigger>
                  <TooltipContent>
                    Connection to a security finding
                  </TooltipContent>
                </Tooltip>
              )}
            </TooltipProvider>
          </div>
          {/* Zoom control hint */}
          <div className="border-border-neutral-primary flex items-center gap-2 border-t pt-3">
            <kbd className="bg-bg-neutral-tertiary text-text-neutral-secondary rounded px-1.5 py-0.5 text-xs font-medium">
              Ctrl
            </kbd>
            <span className="text-text-neutral-secondary text-xs">+</span>
            <span className="text-text-neutral-secondary text-xs">
              Scroll to zoom
            </span>
          </div>
        </div>
      </CardContent>
    </Card>
  );
};
@@ -0,0 +1,24 @@
"use client";
import { Skeleton } from "@/components/shadcn/skeleton/skeleton";
/**
* Loading skeleton for graph visualization
* Shows while graph data is being fetched and processed
*/
export const GraphLoading = () => {
return (
<div className="dark:bg-prowler-blue-400 flex h-96 items-center justify-center rounded-lg bg-gray-50">
<div className="flex flex-col items-center gap-3">
<div className="flex gap-2">
<Skeleton className="h-3 w-3 rounded-full" />
<Skeleton className="h-3 w-3 rounded-full" />
<Skeleton className="h-3 w-3 rounded-full" />
</div>
<p className="text-sm text-gray-600 dark:text-gray-400">
Loading Attack Paths graph...
</p>
</div>
</div>
);
};
@@ -0,0 +1,5 @@
export type { AttackPathGraphRef } from "./attack-path-graph";
export { AttackPathGraph } from "./attack-path-graph";
export { GraphControls } from "./graph-controls";
export { GraphLegend } from "./graph-legend";
export { GraphLoading } from "./graph-loading";
@@ -0,0 +1,7 @@
export { ExecuteButton } from "./execute-button";
export * from "./graph";
export * from "./node-detail";
export { QueryParametersForm } from "./query-parameters-form";
export { QuerySelector } from "./query-selector";
export { ScanListTable } from "./scan-list-table";
export { ScanStatusBadge } from "./scan-status-badge";
@@ -0,0 +1,4 @@
export { NodeDetailContent, NodeDetailPanel } from "./node-detail-panel";
export { NodeOverview } from "./node-overview";
export { NodeRelationships } from "./node-relationships";
export { NodeRemediation } from "./node-remediation";
@@ -0,0 +1,132 @@
"use client";
import { Button, Card, CardContent } from "@/components/shadcn";
import {
Sheet,
SheetContent,
SheetDescription,
SheetHeader,
SheetTitle,
} from "@/components/ui/sheet/sheet";
import type { GraphNode } from "@/types/attack-paths";
import { NodeFindings } from "./node-findings";
import { NodeOverview } from "./node-overview";
import { NodeResources } from "./node-resources";
interface NodeDetailPanelProps {
  node: GraphNode | null; // null keeps the panel closed
  allNodes?: GraphNode[]; // full graph, used to resolve related nodes
  onClose?: () => void;
}
/**
 * Node details content component (reusable)
 *
 * Shows the node overview, plus related findings for resource nodes or
 * affected resources for finding nodes. Fix: `node` is non-nullable
 * here, so the misleading optional chaining (`node?.labels`) has been
 * removed.
 */
export const NodeDetailContent = ({
  node,
  allNodes = [],
}: {
  node: GraphNode;
  allNodes?: GraphNode[];
}) => {
  // A node counts as a finding when any of its labels mentions "finding".
  const isProwlerFinding = node.labels.some((label) =>
    label.toLowerCase().includes("finding"),
  );
  return (
    <div className="flex flex-col gap-6">
      {/* Node Overview Section */}
      <Card className="border-border-neutral-secondary">
        <CardContent className="flex flex-col gap-3 p-4">
          <h3 className="dark:text-prowler-theme-pale/90 text-sm font-semibold">
            Node Overview
          </h3>
          <NodeOverview node={node} />
        </CardContent>
      </Card>
      {/* Related Findings Section - Only show for non-Finding nodes */}
      {!isProwlerFinding && (
        <Card className="border-border-neutral-secondary">
          <CardContent className="flex flex-col gap-3 p-4">
            <h3 className="dark:text-prowler-theme-pale/90 text-sm font-semibold">
              Related Findings
            </h3>
            <div className="text-text-neutral-secondary dark:text-text-neutral-secondary text-xs">
              Findings connected to this node
            </div>
            <NodeFindings node={node} allNodes={allNodes} />
          </CardContent>
        </Card>
      )}
      {/* Affected Resources Section - Only show for Finding nodes */}
      {isProwlerFinding && (
        <Card className="border-border-neutral-secondary">
          <CardContent className="flex flex-col gap-3 p-4">
            <h3 className="dark:text-prowler-theme-pale/90 text-sm font-semibold">
              Affected Resources
            </h3>
            <div className="text-text-neutral-secondary dark:text-text-neutral-secondary text-xs">
              Resources affected by this finding
            </div>
            <NodeResources node={node} allNodes={allNodes} />
          </CardContent>
        </Card>
      )}
    </div>
  );
};
/**
 * Right-side sheet panel for node details
 * Shows comprehensive information about selected graph node
 * Uses shadcn Sheet component for sliding panel from right
 *
 * Passing `node = null` keeps the sheet closed; when the user dismisses
 * the sheet, `onClose` is invoked so the caller can clear its selection.
 */
export const NodeDetailPanel = ({
  node,
  allNodes = [],
  onClose,
}: NodeDetailPanelProps) => {
  // The sheet is open exactly while a node is selected.
  const isOpen = node !== null;
  // Finding nodes get an extra "View Finding" link in the header.
  const isProwlerFinding = node?.labels.some((label) =>
    label.toLowerCase().includes("finding"),
  );
  return (
    <Sheet open={isOpen} onOpenChange={(open) => !open && onClose?.()}>
      <SheetContent className="dark:bg-prowler-theme-midnight my-4 max-h-[calc(100vh-2rem)] max-w-[95vw] overflow-y-auto rounded-l-xl pt-10 md:my-8 md:max-h-[calc(100vh-4rem)] md:max-w-[55vw]">
        <SheetHeader>
          <div className="flex items-start justify-between gap-2">
            <div className="flex-1">
              <SheetTitle>Node Details</SheetTitle>
              <SheetDescription>
                {String(node?.properties?.name || node?.id.substring(0, 20))}
              </SheetDescription>
            </div>
            {node && isProwlerFinding && (
              <Button asChild variant="default" size="sm" className="mt-1">
                <a
                  href={`/findings?id=${String(node.properties?.id || node.id)}`}
                  target="_blank"
                  rel="noopener noreferrer"
                  aria-label={`View finding ${String(node.properties?.id || node.id)}`}
                >
                  View Finding
                </a>
              </Button>
            )}
          </div>
        </SheetHeader>
        {node && (
          <div className="pt-6">
            <NodeDetailContent node={node} allNodes={allNodes} />
          </div>
        )}
      </SheetContent>
    </Sheet>
  );
};
@@ -0,0 +1,102 @@
"use client";
import { SeverityBadge } from "@/components/ui/table/severity-badge";
import type { GraphNode } from "@/types/attack-paths";
// Canonical severity values accepted by <SeverityBadge/>.
const SEVERITY_LEVELS = {
  informational: "informational",
  low: "low",
  medium: "medium",
  high: "high",
  critical: "critical",
} as const;
// Union of the values above ("informational" | "low" | ... | "critical").
type Severity = (typeof SEVERITY_LEVELS)[keyof typeof SEVERITY_LEVELS];
interface NodeFindingsProps {
  // The selected (non-finding) graph node.
  node: GraphNode;
  // All graph nodes; used to resolve the node's finding ids to nodes.
  allNodes?: GraphNode[];
}
/**
 * Node findings section showing related findings for the selected node
 * Displays findings that are connected to the node via HAS_FINDING edges
 *
 * Renders nothing when no related finding nodes can be resolved.
 */
export const NodeFindings = ({ node, allNodes = [] }: NodeFindingsProps) => {
  // Get finding IDs from the node's findings array (populated by adapter)
  const findingIds = node.findings || [];
  // Get the actual finding nodes
  const findingNodes = allNodes.filter((n) => findingIds.includes(n.id));
  if (findingNodes.length === 0) {
    return null;
  }
  // Coerce an arbitrary property value into a known severity: arrays use
  // their first element; unrecognized values fall back to "informational".
  const normalizeSeverity = (
    severity?: string | number | boolean | string[] | number[] | null,
  ): Severity => {
    const sev = String(
      Array.isArray(severity) ? severity[0] : severity || "",
    ).toLowerCase();
    if (sev in SEVERITY_LEVELS) {
      return sev as Severity;
    }
    return "informational";
  };
  return (
    <ul className="flex flex-col gap-3">
      {findingNodes.map((finding) => {
        // Get the finding name (check_title preferred, then name)
        const findingName = String(
          finding.properties?.check_title ||
            finding.properties?.name ||
            finding.properties?.finding_id ||
            "Unknown Finding",
        );
        // Use properties.id for display, fallback to graph node id
        const findingId = String(finding.properties?.id || finding.id);
        return (
          <li
            key={finding.id}
            className="border-border-neutral-secondary rounded-lg border p-3"
          >
            <div className="flex items-start justify-between gap-2">
              <div className="flex-1">
                <div className="flex items-center gap-2">
                  {finding.properties?.severity && (
                    <SeverityBadge
                      severity={normalizeSeverity(finding.properties.severity)}
                    />
                  )}
                  <h5 className="dark:text-prowler-theme-pale/90 text-sm font-medium">
                    {findingName}
                  </h5>
                </div>
                <p className="text-text-neutral-tertiary dark:text-text-neutral-tertiary mt-1 text-xs">
                  ID: {findingId}
                </p>
              </div>
              <a
                href={`/findings?id=${findingId}`}
                target="_blank"
                rel="noopener noreferrer"
                aria-label={`View full finding for ${findingName}`}
                className="text-text-info dark:text-text-info h-auto shrink-0 p-0 text-xs font-medium hover:underline"
              >
                View Full Finding
              </a>
            </div>
            {finding.properties?.description && (
              <div className="text-text-neutral-secondary dark:text-text-neutral-secondary mt-2 text-xs">
                {String(finding.properties.description)}
              </div>
            )}
          </li>
        );
      })}
    </ul>
  );
};
@@ -0,0 +1,109 @@
"use client";
import { CodeSnippet } from "@/components/ui/code-snippet/code-snippet";
import { InfoField } from "@/components/ui/entities";
import { DateWithTime } from "@/components/ui/entities/date-with-time";
import type { GraphNode, GraphNodePropertyValue } from "@/types/attack-paths";
import { formatNodeLabels } from "../../_lib";
interface NodeOverviewProps {
  // The graph node whose details should be displayed.
  node: GraphNode;
}
/**
 * Node overview section showing basic node information
 *
 * Shows the node type (and, for findings, the check title and finding
 * id prominently) followed by a generic dump of the remaining node
 * properties.
 */
export const NodeOverview = ({ node }: NodeOverviewProps) => {
  // Render a property value for display: empty/null values become "-",
  // arrays are comma-joined, everything else is stringified.
  const renderValue = (value: GraphNodePropertyValue) => {
    if (value === null || value === undefined || value === "") {
      return "-";
    }
    if (Array.isArray(value)) {
      return value.join(", ");
    }
    return String(value);
  };
  // A node counts as a finding when any label mentions "finding".
  const isFinding = node.labels.some((label) =>
    label.toLowerCase().includes("finding"),
  );
  // NOTE(review): the timestamp detection below is heuristic — any key
  // containing "date", "time", "at" or "seen" is treated as an epoch
  // timestamp in MILLISECONDS. The substring "at" also matches keys like
  // "state" or "template", and epoch-seconds values would render a wrong
  // date — confirm against the adapter's property names and units.
  return (
    <div className="flex flex-col gap-4">
      <div className="grid grid-cols-1 gap-4 md:grid-cols-2">
        <InfoField label="Type">{formatNodeLabels(node.labels)}</InfoField>
        {isFinding && node.properties.check_title && (
          <InfoField label="Check Title">
            {String(node.properties.check_title)}
          </InfoField>
        )}
        {isFinding && node.properties.id && (
          <InfoField label="Finding ID" variant="simple">
            <CodeSnippet value={String(node.properties.id)} />
          </InfoField>
        )}
      </div>
      {/* Display all properties */}
      <div className="mt-4 border-t border-gray-200 pt-4 dark:border-gray-700">
        <h4 className="dark:text-prowler-theme-pale/90 mb-3 text-sm font-semibold">
          Properties
        </h4>
        <div className="grid grid-cols-1 gap-3 md:grid-cols-2">
          {Object.entries(node.properties).map(([key, value]) => {
            // Skip internal properties (underscore-prefixed keys)
            if (key.startsWith("_")) {
              return null;
            }
            // Skip check_title and id for findings as they're shown prominently above
            if (isFinding && (key === "check_title" || key === "id")) {
              return null;
            }
            // Heuristic: keys that look like timestamps (see NOTE above)
            const isTimestamp =
              key.includes("date") ||
              key.includes("time") ||
              key.includes("at") ||
              key.includes("seen");
            return (
              <InfoField key={key} label={formatPropertyName(key)}>
                {isTimestamp && typeof value === "number" ? (
                  <DateWithTime
                    inline
                    dateTime={new Date(value).toISOString()}
                  />
                ) : isTimestamp &&
                  typeof value === "string" &&
                  value.match(/^\d+$/) ? (
                  <DateWithTime
                    inline
                    dateTime={new Date(parseInt(value)).toISOString()}
                  />
                ) : typeof value === "object" ? (
                  <code className="text-xs">
                    {JSON.stringify(value).substring(0, 50)}...
                  </code>
                ) : (
                  renderValue(value)
                )}
              </InfoField>
            );
          })}
        </div>
      </div>
    </div>
  );
};
// Turn a property key into a display label: split camelCase and
// snake_case into words, then capitalize the first letter of each word.
function formatPropertyName(name: string): string {
  const spaced = name.replace(/([A-Z])/g, " $1").replace(/_/g, " ");
  const titleCased = spaced.replace(/\b\w/g, (letter) => letter.toUpperCase());
  return titleCased.trim();
}
@@ -0,0 +1,105 @@
"use client";
import { cn } from "@/lib/utils";
import type { GraphEdge } from "@/types/attack-paths";
interface NodeRelationshipsProps {
  incomingEdges: GraphEdge[];
  outgoingEdges: GraphEdge[];
}
/**
 * Format edge type to human-readable label
 * e.g., "HAS_FINDING" -> "Has Finding"
 */
function formatEdgeType(edgeType: string): string {
  const words = edgeType.split("_");
  const titleCased = words.map(
    (word) => word.charAt(0).toUpperCase() + word.slice(1).toLowerCase(),
  );
  return titleCased.join(" ");
}
interface EdgeItemProps {
  edge: GraphEdge;
  isOutgoing: boolean;
}
/**
 * Reusable edge item component
 *
 * Shows the id of the node on the other end of the edge (truncated)
 * alongside a badge with the human-readable edge type. Fix: removed the
 * ineffective `key={edge.id}` on the root div — React keys only matter
 * on the elements of a list, and the caller already supplies the key on
 * <EdgeItem/> itself.
 */
function EdgeItem({ edge, isOutgoing }: EdgeItemProps) {
  const targetId =
    typeof edge.target === "string" ? edge.target : String(edge.target);
  const sourceId =
    typeof edge.source === "string" ? edge.source : String(edge.source);
  // For outgoing edges the interesting endpoint is the target; for
  // incoming edges it is the source. Long node ids are truncated.
  const displayId = (isOutgoing ? targetId : sourceId).substring(0, 30);
  return (
    <div className="border-border-neutral-tertiary dark:border-border-neutral-tertiary flex items-center justify-between rounded border p-2">
      <code className="text-text-neutral-secondary dark:text-text-neutral-secondary text-xs">
        {displayId}
      </code>
      <span
        className={cn(
          "rounded px-2 py-1 text-xs font-medium",
          isOutgoing
            ? "bg-bg-data-info text-text-neutral-primary dark:text-text-neutral-primary"
            : "bg-bg-pass-primary text-text-neutral-primary dark:text-text-neutral-primary",
        )}
      >
        {formatEdgeType(edge.type)}
      </span>
    </div>
  );
}
/**
 * Node relationships section showing incoming and outgoing edges
 *
 * Renders two symmetric lists (outgoing first), each with a count in
 * the heading and a placeholder message when the list is empty.
 */
export const NodeRelationships = ({
  incomingEdges,
  outgoingEdges,
}: NodeRelationshipsProps) => {
  return (
    <div className="flex flex-col gap-6">
      {/* Outgoing Relationships */}
      <div>
        <h4 className="dark:text-prowler-theme-pale/90 mb-3 text-sm font-semibold">
          Outgoing Relationships ({outgoingEdges.length})
        </h4>
        {outgoingEdges.length > 0 ? (
          <div className="space-y-2">
            {outgoingEdges.map((edge) => (
              <EdgeItem key={edge.id} edge={edge} isOutgoing />
            ))}
          </div>
        ) : (
          <p className="text-text-neutral-tertiary dark:text-text-neutral-tertiary text-xs">
            No outgoing relationships
          </p>
        )}
      </div>
      {/* Incoming Relationships */}
      <div className="border-border-neutral-tertiary dark:border-border-neutral-tertiary border-t pt-6">
        <h4 className="dark:text-prowler-theme-pale/90 mb-3 text-sm font-semibold">
          Incoming Relationships ({incomingEdges.length})
        </h4>
        {incomingEdges.length > 0 ? (
          <div className="space-y-2">
            {incomingEdges.map((edge) => (
              <EdgeItem key={edge.id} edge={edge} isOutgoing={false} />
            ))}
          </div>
        ) : (
          <p className="text-text-neutral-tertiary dark:text-text-neutral-tertiary text-xs">
            No incoming relationships
          </p>
        )}
      </div>
    </div>
  );
};
@@ -0,0 +1,83 @@
"use client";
import Link from "next/link";
import { Badge } from "@/components/shadcn/badge/badge";
// Minimal finding shape consumed by this section.
interface Finding {
  id: string;
  title: string;
  severity: "critical" | "high" | "medium" | "low" | "info";
  status: "PASS" | "FAIL" | "MANUAL";
}
interface NodeRemediationProps {
  findings: Finding[];
}
/**
 * Node remediation section showing related Prowler findings
 *
 * Each finding renders its title, truncated id, severity and status
 * badges, and a link to the full finding page.
 */
export const NodeRemediation = ({ findings }: NodeRemediationProps) => {
  // Badge variant lookups; unknown values fall back like the previous
  // switch defaults did.
  const severityVariants = {
    critical: "destructive",
    high: "default",
    medium: "secondary",
    low: "outline",
  } as const;
  const getSeverityVariant = (severity: string) =>
    severityVariants[severity as keyof typeof severityVariants] ?? "default";

  const statusVariants = {
    PASS: "default",
    FAIL: "destructive",
  } as const;
  const getStatusVariant = (status: string) =>
    statusVariants[status as keyof typeof statusVariants] ?? "secondary";

  return (
    <div className="flex flex-col gap-3">
      {findings.map((finding) => (
        <div
          key={finding.id}
          className="rounded-lg border border-gray-200 p-3 dark:border-gray-700"
        >
          <div className="flex items-start justify-between gap-2">
            <div className="flex-1">
              <h5 className="dark:text-prowler-theme-pale/90 text-sm font-medium">
                {finding.title}
              </h5>
              <p className="mt-1 text-xs text-gray-500 dark:text-gray-400">
                ID: {finding.id.substring(0, 12)}...
              </p>
            </div>
            <div className="flex flex-col gap-1">
              <Badge variant={getSeverityVariant(finding.severity)}>
                {finding.severity}
              </Badge>
              <Badge variant={getStatusVariant(finding.status)}>
                {finding.status}
              </Badge>
            </div>
          </div>
          <div className="mt-2">
            <Link
              href={`/findings?id=${finding.id}`}
              target="_blank"
              rel="noopener noreferrer"
              aria-label={`View full finding for ${finding.title}`}
              className="text-text-info dark:text-text-info text-sm transition-all hover:opacity-80 dark:hover:opacity-80"
            >
              View Full Finding
            </Link>
          </div>
        </div>
      ))}
    </div>
  );
};
@@ -0,0 +1,85 @@
"use client";
import { Badge } from "@/components/shadcn/badge/badge";
import { cn } from "@/lib/utils";
import type { GraphNode } from "@/types/attack-paths";
interface NodeResourcesProps {
  node: GraphNode;
  allNodes?: GraphNode[];
}
/**
 * Node resources section showing affected resources for the selected finding node
 * Displays resources that are connected to the finding node via HAS_FINDING edges
 *
 * Renders nothing when no related resource nodes can be resolved.
 */
export const NodeResources = ({ node, allNodes = [] }: NodeResourcesProps) => {
  // Resource ids are attached to finding nodes by the graph adapter.
  const resourceIds = node.resources || [];
  const resourceNodes = allNodes.filter((candidate) =>
    resourceIds.includes(candidate.id),
  );
  if (resourceNodes.length === 0) {
    return null;
  }
  // Labels that get the AWS badge color; everything else is muted.
  const awsLabels = new Set([
    "s3bucket",
    "awsaccount",
    "ec2instance",
    "iamrole",
    "lambdafunction",
    "securitygroup",
  ]);
  const getResourceTypeColor = (labels: string[]): string =>
    awsLabels.has((labels[0] || "").toLowerCase())
      ? "bg-bg-data-aws"
      : "bg-bg-data-muted";
  return (
    <ul className="flex flex-col gap-3">
      {resourceNodes.map((resource) => {
        // Use properties.id for display, fallback to graph node id
        const resourceId = String(resource.properties?.id || resource.id);
        return (
          <li
            key={resource.id}
            className="border-border-neutral-secondary rounded-lg border p-3"
          >
            <div className="flex items-start justify-between gap-2">
              <div className="flex-1">
                <div className="flex items-center gap-2">
                  {resource.labels && (
                    <Badge
                      className={cn(
                        getResourceTypeColor(resource.labels),
                        "text-text-neutral-primary",
                      )}
                    >
                      {resource.labels[0]}
                    </Badge>
                  )}
                  <h5 className="dark:text-prowler-theme-pale/90 text-sm font-medium">
                    {String(resource.properties?.name || resourceId)}
                  </h5>
                </div>
                <p className="text-text-neutral-tertiary dark:text-text-neutral-tertiary mt-1 text-xs">
                  ID: {resourceId}
                </p>
              </div>
            </div>
            {resource.properties?.arn && (
              <div className="text-text-neutral-secondary dark:text-text-neutral-secondary mt-2 text-xs">
                ARN: {String(resource.properties.arn)}
              </div>
            )}
          </li>
        );
      })}
    </ul>
  );
};
@@ -0,0 +1,122 @@
"use client";
import { Controller, useFormContext } from "react-hook-form";
import type { AttackPathQuery } from "@/types/attack-paths";
interface QueryParametersFormProps {
  // The query whose parameters should be rendered; null/undefined when
  // no query has been selected yet.
  selectedQuery: AttackPathQuery | null | undefined;
}
/**
 * Dynamic form component for query parameters
 * Renders form fields based on selected query's parameters
 *
 * Must be rendered inside a react-hook-form provider — field state is
 * obtained via useFormContext, and each parameter is registered as a
 * controlled field under its name. Booleans render as checkboxes;
 * numbers and strings render as text/number inputs.
 */
export const QueryParametersForm = ({
  selectedQuery,
}: QueryParametersFormProps) => {
  const {
    control,
    formState: { errors },
  } = useFormContext();
  // Nothing to collect: show an informational hint instead of a form.
  if (!selectedQuery || !selectedQuery.attributes.parameters.length) {
    return (
      <div className="rounded-lg bg-blue-50 p-4 dark:bg-blue-950/20">
        <p className="text-sm text-blue-700 dark:text-blue-300">
          This query requires no parameters. Click &quot;Execute Query&quot; to
          proceed.
        </p>
      </div>
    );
  }
  return (
    <div className="flex flex-col gap-4">
      <h3 className="dark:text-prowler-theme-pale/90 text-sm font-semibold">
        Query Parameters
      </h3>
      {selectedQuery.attributes.parameters.map((param) => (
        <Controller
          key={param.name}
          name={param.name}
          control={control}
          render={({ field }) => {
            // Booleans are rendered as a single labeled checkbox. The
            // stored value may be a boolean or the string "true".
            if (param.data_type === "boolean") {
              return (
                <div className="flex flex-col gap-2">
                  <label className="flex cursor-pointer items-center gap-3">
                    <input
                      type="checkbox"
                      id={param.name}
                      checked={field.value === true || field.value === "true"}
                      onChange={(e) => field.onChange(e.target.checked)}
                      aria-label={param.label}
                      className="border-border-neutral-secondary bg-bg-neutral-primary text-text-primary focus:ring-primary dark:border-border-neutral-secondary dark:bg-bg-neutral-primary dark:text-text-primary h-4 w-4 rounded border focus:ring-2"
                    />
                    <div className="flex flex-col gap-1">
                      <span className="text-sm font-medium text-gray-900 dark:text-gray-100">
                        {param.label}
                      </span>
                      {param.description && (
                        <span className="text-xs text-gray-600 dark:text-gray-400">
                          {param.description}
                        </span>
                      )}
                    </div>
                  </label>
                </div>
              );
            }
            // Extract a printable validation message for this field, if any.
            const errorMessage = (() => {
              const error = errors[param.name];
              if (error && typeof error.message === "string") {
                return error.message;
              }
              return undefined;
            })();
            // Links the input to its description for screen readers.
            const descriptionId = `${param.name}-description`;
            return (
              <div className="flex flex-col gap-2">
                <label
                  htmlFor={param.name}
                  className="text-sm font-medium text-gray-700 dark:text-gray-300"
                >
                  {param.label}
                  {param.required && <span className="text-red-500"> *</span>}
                </label>
                <input
                  {...field}
                  id={param.name}
                  type={param.data_type === "number" ? "number" : "text"}
                  placeholder={
                    param.placeholder || `Enter ${param.label.toLowerCase()}`
                  }
                  value={field.value ?? ""}
                  aria-describedby={
                    param.description ? descriptionId : undefined
                  }
                  className="border-border-neutral-secondary bg-bg-neutral-primary text-text-neutral-primary placeholder-text-neutral-secondary focus:border-border-primary focus:ring-primary dark:border-border-neutral-secondary dark:bg-bg-neutral-primary dark:text-text-neutral-primary dark:placeholder-text-neutral-secondary dark:focus:border-border-primary rounded-md border px-3 py-2 text-sm focus:ring-1 focus:outline-none"
                />
                {param.description && (
                  <span
                    id={descriptionId}
                    className="text-xs text-gray-600 dark:text-gray-400"
                  >
                    {param.description}
                  </span>
                )}
                {errorMessage && (
                  <span className="text-xs text-red-500">{errorMessage}</span>
                )}
              </div>
            );
          }}
        />
      ))}
    </div>
  );
};
@@ -0,0 +1,46 @@
"use client";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/shadcn";
import type { AttackPathQuery } from "@/types/attack-paths";
interface QuerySelectorProps {
  queries: AttackPathQuery[];
  selectedQueryId: string | null;
  onQueryChange: (queryId: string) => void;
}
/**
 * Query selector dropdown component
 * Lets the user pick one of the available Attack Paths queries; each
 * option shows the query name with its description underneath.
 */
export const QuerySelector = ({
  queries,
  selectedQueryId,
  onQueryChange,
}: QuerySelectorProps) => {
  // The Select component expects a string value; no selection maps to "".
  const currentValue = selectedQueryId ?? "";
  return (
    <Select value={currentValue} onValueChange={onQueryChange}>
      <SelectTrigger className="w-full text-left">
        <SelectValue placeholder="Choose a query..." />
      </SelectTrigger>
      <SelectContent>
        {queries.map((query) => {
          const { name, description } = query.attributes;
          return (
            <SelectItem key={query.id} value={query.id}>
              <div className="flex flex-col gap-1">
                <span className="font-medium">{name}</span>
                <span className="text-xs text-gray-500">{description}</span>
              </div>
            </SelectItem>
          );
        })}
      </SelectContent>
    </Select>
  );
};
@@ -0,0 +1,350 @@
"use client";
import {
  ChevronLeftIcon,
  ChevronRightIcon,
  DoubleArrowLeftIcon,
  DoubleArrowRightIcon,
} from "@radix-ui/react-icons";
import Link from "next/link";
import { usePathname, useRouter, useSearchParams } from "next/navigation";
import { useState } from "react";
import { Button } from "@/components/shadcn/button/button";
import {
  Select,
  SelectContent,
  SelectItem,
  SelectTrigger,
  SelectValue,
} from "@/components/shadcn/select/select";
import { DateWithTime } from "@/components/ui/entities/date-with-time";
import { EntityInfo } from "@/components/ui/entities/entity-info";
import {
  Table,
  TableBody,
  TableCell,
  TableHead,
  TableHeader,
  TableRow,
} from "@/components/ui/table";
import { cn } from "@/lib/utils";
import type { ProviderType } from "@/types";
import type { AttackPathScan } from "@/types/attack-paths";
import { SCAN_STATES } from "@/types/attack-paths";
import { ScanStatusBadge } from "./scan-status-badge";
interface ScanListTableProps {
  // Full list of scans; pagination below is purely client-side over this array.
  scans: AttackPathScan[];
}
// Number of rendered columns, used for the empty-state colSpan.
const TABLE_COLUMN_COUNT = 6;
const DEFAULT_PAGE_SIZE = 5;
const PAGE_SIZE_OPTIONS = [2, 5, 10, 15];
// Shared styling for the pagination <Link>s; disabledLinkClass is appended
// when navigation in that direction is not possible.
const baseLinkClass =
  "relative block rounded border-0 bg-transparent px-3 py-1.5 text-button-primary outline-none transition-all duration-300 hover:bg-bg-neutral-tertiary hover:text-text-neutral-primary focus:shadow-none dark:hover:bg-bg-neutral-secondary dark:hover:text-text-neutral-primary";
const disabledLinkClass =
  "text-border-neutral-secondary dark:text-border-neutral-secondary hover:bg-transparent hover:text-border-neutral-secondary dark:hover:text-border-neutral-secondary cursor-default pointer-events-none";
/**
 * Table displaying AWS account Attack Paths scans
 * Shows scan metadata and allows selection of completed scans
 *
 * Selection and pagination state are kept in the URL query string
 * (?scanId, ?scanPage, ?scanPageSize) so they survive reloads and
 * navigation.
 */
export const ScanListTable = ({ scans }: ScanListTableProps) => {
  const pathname = usePathname();
  const searchParams = useSearchParams();
  const router = useRouter();
  const selectedScanId = searchParams.get("scanId");
  // NOTE(review): parseInt without a radix and without a NaN guard — a
  // malformed ?scanPage / ?scanPageSize yields NaN, and a page beyond
  // totalPages is not clamped here. Assumes the params are always numeric;
  // TODO confirm upstream sanitization.
  const currentPage = parseInt(searchParams.get("scanPage") ?? "1");
  const pageSize = parseInt(
    searchParams.get("scanPageSize") ?? String(DEFAULT_PAGE_SIZE),
  );
  // Local mirror of the page-size URL param so the Select is controlled.
  const [selectedPageSize, setSelectedPageSize] = useState(String(pageSize));
  const totalPages = Math.ceil(scans.length / pageSize);
  const startIndex = (currentPage - 1) * pageSize;
  const endIndex = startIndex + pageSize;
  const paginatedScans = scans.slice(startIndex, endIndex);
  // Select a scan by writing its id into the URL (keeps other params).
  const handleSelectScan = (scanId: string) => {
    const params = new URLSearchParams(searchParams);
    params.set("scanId", scanId);
    router.push(`${pathname}?${params.toString()}`);
  };
  // Only completed scans that are not already selected can be chosen.
  const isSelectDisabled = (scan: AttackPathScan) => {
    return (
      scan.attributes.state !== SCAN_STATES.COMPLETED ||
      selectedScanId === scan.id
    );
  };
  // Button caption mirrors the scan state; "Select" only for actionable rows.
  const getSelectButtonLabel = (scan: AttackPathScan) => {
    if (selectedScanId === scan.id) {
      return "Selected";
    }
    if (scan.attributes.state === SCAN_STATES.SCHEDULED) {
      return "Scheduled";
    }
    if (scan.attributes.state === SCAN_STATES.EXECUTING) {
      return "Waiting...";
    }
    if (scan.attributes.state === SCAN_STATES.FAILED) {
      return "Failed";
    }
    return "Select";
  };
  // Build a pagination URL; out-of-range pages return the current URL
  // unchanged so the link becomes a no-op.
  const createPageUrl = (pageNumber: number | string) => {
    const params = new URLSearchParams(searchParams);
    // Preserve scanId if it exists
    const scanId = searchParams.get("scanId");
    if (+pageNumber > totalPages) {
      return `${pathname}?${params.toString()}`;
    }
    params.set("scanPage", pageNumber.toString());
    // Ensure that scanId is preserved
    if (scanId) params.set("scanId", scanId);
    return `${pathname}?${params.toString()}`;
  };
  const isFirstPage = currentPage === 1;
  const isLastPage = currentPage === totalPages;
  return (
    <>
      <div className="minimal-scrollbar rounded-large shadow-small border-border-neutral-secondary bg-bg-neutral-secondary relative z-0 flex w-full flex-col gap-4 overflow-auto border p-4">
        <Table aria-label="Attack Paths scans table listing provider accounts, scan dates, status, progress, and duration">
          <TableHeader>
            <TableRow>
              <TableHead>Provider / Account</TableHead>
              <TableHead>Last Scan Date</TableHead>
              <TableHead>Status</TableHead>
              <TableHead>Progress</TableHead>
              <TableHead>Duration</TableHead>
              <TableHead className="text-right">Action</TableHead>
            </TableRow>
          </TableHeader>
          <TableBody>
            {scans.length === 0 ? (
              <TableRow>
                <TableCell
                  colSpan={TABLE_COLUMN_COUNT}
                  className="h-24 text-center"
                >
                  No Attack Paths scans available.
                </TableCell>
              </TableRow>
            ) : (
              paginatedScans.map((scan) => {
                const isDisabled = isSelectDisabled(scan);
                const isSelected = selectedScanId === scan.id;
                {/* duration is stored in seconds; rendered as "<m>m <s>s" */}
                const duration = scan.attributes.duration
                  ? `${Math.floor(scan.attributes.duration / 60)}m ${scan.attributes.duration % 60}s`
                  : "-";
                return (
                  <TableRow
                    key={scan.id}
                    className={
                      isSelected
                        ? "bg-button-primary/10 dark:bg-button-primary/10"
                        : ""
                    }
                  >
                    <TableCell className="font-medium">
                      <EntityInfo
                        cloudProvider={
                          scan.attributes.provider_type as ProviderType
                        }
                        entityAlias={scan.attributes.provider_alias}
                        entityId={scan.attributes.provider_uid}
                      />
                    </TableCell>
                    <TableCell>
                      {scan.attributes.completed_at ? (
                        <DateWithTime
                          inline
                          dateTime={scan.attributes.completed_at}
                        />
                      ) : (
                        "-"
                      )}
                    </TableCell>
                    <TableCell>
                      <ScanStatusBadge
                        status={scan.attributes.state}
                        progress={scan.attributes.progress}
                      />
                    </TableCell>
                    <TableCell>
                      <span className="text-sm">
                        {scan.attributes.progress}%
                      </span>
                    </TableCell>
                    <TableCell>
                      <span className="text-sm">{duration}</span>
                    </TableCell>
                    <TableCell className="text-right">
                      <Button
                        type="button"
                        aria-label="Select scan"
                        disabled={isDisabled}
                        variant={isDisabled ? "secondary" : "default"}
                        onClick={() => handleSelectScan(scan.id)}
                        className="w-full max-w-24"
                      >
                        {getSelectButtonLabel(scan)}
                      </Button>
                    </TableCell>
                  </TableRow>
                );
              })
            )}
          </TableBody>
        </Table>
        {/* Pagination Controls */}
        {scans.length > 0 && (
          <div className="flex w-full flex-col-reverse items-center justify-between gap-4 overflow-auto p-1 sm:flex-row sm:gap-8">
            <div className="text-sm whitespace-nowrap">
              {scans.length} scans in total
            </div>
            {scans.length > DEFAULT_PAGE_SIZE && (
              <div className="flex flex-col-reverse items-center gap-4 sm:flex-row sm:gap-6 lg:gap-8">
                {/* Rows per page selector */}
                <div className="flex items-center gap-2">
                  <p className="text-sm font-medium whitespace-nowrap">
                    Rows per page
                  </p>
                  <Select
                    value={selectedPageSize}
                    onValueChange={(value) => {
                      setSelectedPageSize(value);
                      const params = new URLSearchParams(searchParams);
                      // Preserve scanId if it exists
                      const scanId = searchParams.get("scanId");
                      params.set("scanPageSize", value);
                      params.set("scanPage", "1");
                      // Ensure that scanId is preserved
                      if (scanId) params.set("scanId", scanId);
                      router.push(`${pathname}?${params.toString()}`);
                    }}
                  >
                    <SelectTrigger className="h-8 w-18">
                      <SelectValue />
                    </SelectTrigger>
                    <SelectContent side="top">
                      {PAGE_SIZE_OPTIONS.map((size) => (
                        <SelectItem
                          key={size}
                          value={`${size}`}
                          className="cursor-pointer"
                        >
                          {size}
                        </SelectItem>
                      ))}
                    </SelectContent>
                  </Select>
                </div>
                <div className="flex items-center justify-center text-sm font-medium">
                  Page {currentPage} of {totalPages}
                </div>
                {/* First / previous / next / last page links; disabled ends
                    keep the current URL and preventDefault on click */}
                <div className="flex items-center gap-2">
                  <Link
                    aria-label="Go to first page"
                    className={cn(
                      baseLinkClass,
                      isFirstPage && disabledLinkClass,
                    )}
                    href={
                      isFirstPage
                        ? pathname + "?" + searchParams.toString()
                        : createPageUrl(1)
                    }
                    aria-disabled={isFirstPage}
                    onClick={(e) => isFirstPage && e.preventDefault()}
                  >
                    <DoubleArrowLeftIcon
                      className="size-4"
                      aria-hidden="true"
                    />
                  </Link>
                  <Link
                    aria-label="Go to previous page"
                    className={cn(
                      baseLinkClass,
                      isFirstPage && disabledLinkClass,
                    )}
                    href={
                      isFirstPage
                        ? pathname + "?" + searchParams.toString()
                        : createPageUrl(currentPage - 1)
                    }
                    aria-disabled={isFirstPage}
                    onClick={(e) => isFirstPage && e.preventDefault()}
                  >
                    <ChevronLeftIcon className="size-4" aria-hidden="true" />
                  </Link>
                  <Link
                    aria-label="Go to next page"
                    className={cn(
                      baseLinkClass,
                      isLastPage && disabledLinkClass,
                    )}
                    href={
                      isLastPage
                        ? pathname + "?" + searchParams.toString()
                        : createPageUrl(currentPage + 1)
                    }
                    aria-disabled={isLastPage}
                    onClick={(e) => isLastPage && e.preventDefault()}
                  >
                    <ChevronRightIcon className="size-4" aria-hidden="true" />
                  </Link>
                  <Link
                    aria-label="Go to last page"
                    className={cn(
                      baseLinkClass,
                      isLastPage && disabledLinkClass,
                    )}
                    href={
                      isLastPage
                        ? pathname + "?" + searchParams.toString()
                        : createPageUrl(totalPages)
                    }
                    aria-disabled={isLastPage}
                    onClick={(e) => isLastPage && e.preventDefault()}
                  >
                    <DoubleArrowRightIcon
                      className="size-4"
                      aria-hidden="true"
                    />
                  </Link>
                </div>
              </div>
            )}
          </div>
        )}
      </div>
      <p className="text-text-neutral-secondary dark:text-text-neutral-secondary mt-6 text-xs">
        Only Attack Paths scans with &quot;Completed&quot; status can be
        selected. Scans in progress will update automatically.
      </p>
    </>
  );
};
@@ -0,0 +1,59 @@
"use client";
import { Loader2 } from "lucide-react";
import { Badge } from "@/components/shadcn/badge/badge";
import type { ScanState } from "@/types/attack-paths";
interface ScanStatusBadgeProps {
  // Current state of the attack path scan.
  status: ScanState;
  // Completion percentage; only shown while executing.
  progress?: number;
}
/**
 * Renders a colored badge describing the current state of an attack path scan.
 * The "executing" state additionally shows a spinner and the completion
 * percentage; any unrecognized state renders as "Failed".
 */
export const ScanStatusBadge = ({
  status,
  progress = 0,
}: ScanStatusBadgeProps) => {
  switch (status) {
    case "scheduled":
      return (
        <Badge className="bg-bg-neutral-tertiary text-text-neutral-primary gap-2">
          <span>Scheduled</span>
        </Badge>
      );
    case "available":
      return (
        <Badge className="bg-bg-neutral-tertiary text-text-neutral-primary gap-2">
          <span>Queued</span>
        </Badge>
      );
    case "executing":
      return (
        <Badge className="bg-bg-warning-secondary text-text-neutral-primary gap-2">
          <Loader2 size={14} className="animate-spin" />
          <span>In Progress ({progress}%)</span>
        </Badge>
      );
    case "completed":
      return (
        <Badge className="bg-bg-pass-secondary text-text-success-primary gap-2">
          <span>Completed</span>
        </Badge>
      );
    default:
      // Every other state (e.g. "failed") renders as a failure badge.
      return (
        <Badge className="bg-bg-fail-secondary text-text-error-primary gap-2">
          <span>Failed</span>
        </Badge>
      );
  }
};
@@ -0,0 +1,3 @@
export { useGraphState } from "./use-graph-state";
export { useQueryBuilder } from "./use-query-builder";
export { useWizardState } from "./use-wizard-state";
@@ -0,0 +1,169 @@
"use client";
import { create } from "zustand";
import type {
AttackPathGraphData,
GraphNode,
GraphState,
} from "@/types/attack-paths";
import { computeFilteredSubgraph } from "../_lib";
interface FilteredViewState {
  // True while the graph shows only the filtered path through one node.
  isFilteredView: boolean;
  // Node the current filtered view was computed from; null when unfiltered.
  filteredNodeId: string | null;
  fullData: AttackPathGraphData | null; // Original data before filtering
}
// Store shape: graph view state + filtered-view state + mutation actions.
interface GraphStore extends GraphState, FilteredViewState {
  setGraphData: (data: AttackPathGraphData) => void;
  setSelectedNodeId: (nodeId: string | null) => void;
  setLoading: (loading: boolean) => void;
  setError: (error: string | null) => void;
  setZoom: (zoomLevel: number) => void;
  setPan: (panX: number, panY: number) => void;
  // Atomically swaps displayed data and filter flags in one update.
  setFilteredView: (
    isFiltered: boolean,
    nodeId: string | null,
    filteredData: AttackPathGraphData | null,
    fullData: AttackPathGraphData | null,
  ) => void;
  reset: () => void;
}
// Baseline store values; reused verbatim by reset().
const initialState: GraphState & FilteredViewState = {
  data: null,
  selectedNodeId: null,
  loading: false,
  error: null,
  zoomLevel: 1,
  panX: 0,
  panY: 0,
  isFilteredView: false,
  filteredNodeId: null,
  fullData: null,
};
// Zustand store backing the useGraphState hook.
const useGraphStore = create<GraphStore>((set) => ({
  ...initialState,
  // Replacing the graph data always clears errors and leaves filtered view.
  setGraphData: (data) =>
    set({
      data,
      fullData: null,
      error: null,
      isFilteredView: false,
      filteredNodeId: null,
    }),
  setSelectedNodeId: (nodeId) => set({ selectedNodeId: nodeId }),
  setLoading: (loading) => set({ loading }),
  setError: (error) => set({ error }),
  setZoom: (zoomLevel) => set({ zoomLevel }),
  setPan: (panX, panY) => set({ panX, panY }),
  // Entering/leaving a filtered view also moves the node selection.
  setFilteredView: (isFiltered, nodeId, filteredData, fullData) =>
    set({
      isFilteredView: isFiltered,
      filteredNodeId: nodeId,
      data: filteredData,
      fullData,
      selectedNodeId: nodeId,
    }),
  reset: () => set(initialState),
}));
/**
 * Custom hook for managing graph visualization state
 * Handles graph data, node selection, zoom/pan, loading states, and filtered view
 */
export const useGraphState = () => {
  const store = useGraphStore();
  // Null-safe lookup of a node by id inside a graph payload.
  const findNodeById = (
    source: AttackPathGraphData | null,
    nodeId: string | null,
  ): GraphNode | null => {
    if (!source?.nodes || !nodeId) return null;
    return source.nodes.find((node) => node.id === nodeId) || null;
  };
  /**
   * Enter filtered view mode - redraws graph with only the selected path.
   * Keeps the full dataset around so exiting the filter can restore it.
   */
  const enterFilteredView = (nodeId: string) => {
    if (!store.data) return;
    // If already filtered, filter from the preserved full dataset rather
    // than the currently displayed (already reduced) one.
    const baseline = store.fullData || store.data;
    store.setFilteredView(
      true,
      nodeId,
      computeFilteredSubgraph(baseline, nodeId),
      baseline,
    );
  };
  /**
   * Exit filtered view mode - restore full graph data.
   */
  const exitFilteredView = () => {
    if (!store.isFilteredView || !store.fullData) return;
    store.setFilteredView(false, null, store.fullData, null);
  };
  return {
    data: store.data,
    fullData: store.fullData,
    selectedNodeId: store.selectedNodeId,
    // Resolve the selected node against the currently displayed data.
    selectedNode: findNodeById(store.data, store.selectedNodeId),
    loading: store.loading,
    error: store.error,
    zoomLevel: store.zoomLevel,
    panX: store.panX,
    panY: store.panY,
    isFilteredView: store.isFilteredView,
    filteredNodeId: store.filteredNodeId,
    // The filter node lives in fullData (the pre-filter dataset).
    filteredNode: store.isFilteredView
      ? findNodeById(store.fullData || store.data, store.filteredNodeId)
      : null,
    updateGraphData: (data: AttackPathGraphData) => store.setGraphData(data),
    selectNode: (nodeId: string | null) => store.setSelectedNodeId(nodeId),
    startLoading: () => store.setLoading(true),
    stopLoading: () => store.setLoading(false),
    setError: (error: string | null) => store.setError(error),
    updateZoomAndPan: (zoomLevel: number, panX: number, panY: number) => {
      store.setZoom(zoomLevel);
      store.setPan(panX, panY);
    },
    resetGraph: () => store.reset(),
    // Empty the graph and drop selection/filter without touching zoom/pan.
    clearGraph: () => {
      store.setGraphData({ nodes: [], edges: [] });
      store.setSelectedNodeId(null);
      store.setFilteredView(false, null, null, null);
    },
    enterFilteredView,
    exitFilteredView,
  };
};
@@ -0,0 +1,98 @@
"use client";
import { zodResolver } from "@hookform/resolvers/zod";
import { useEffect, useState } from "react";
import { useForm } from "react-hook-form";
import { z } from "zod";
import type { AttackPathQuery } from "@/types/attack-paths";
/**
 * Custom hook for managing query builder form state
 * Handles query selection, parameter validation, and form submission
 *
 * Validation is derived from the selected query's declared parameters:
 * strings must be non-empty, numbers are coerced and must be >= 0,
 * booleans default to false.
 */
export const useQueryBuilder = (availableQueries: AttackPathQuery[]) => {
  const [selectedQuery, setSelectedQuery] = useState<string | null>(null);
  // Generate dynamic Zod schema based on selected query parameters
  const getValidationSchema = (queryId: string | null) => {
    const schemaObject: Record<string, z.ZodTypeAny> = {};
    if (queryId) {
      const query = availableQueries.find((q) => q.id === queryId);
      if (query) {
        query.attributes.parameters.forEach((param) => {
          // Default: required non-empty string.
          let fieldSchema: z.ZodTypeAny = z
            .string()
            .min(1, `${param.label} is required`);
          if (param.data_type === "number") {
            fieldSchema = z.coerce.number().refine((val) => val >= 0, {
              message: `${param.label} must be a non-negative number`,
            });
          } else if (param.data_type === "boolean") {
            fieldSchema = z.boolean().default(false);
          }
          schemaObject[param.name] = fieldSchema;
        });
      }
    }
    // With no selection this is an empty object schema (always valid).
    return z.object(schemaObject);
  };
  // Initial field values per parameter: false for booleans, "" otherwise.
  const getDefaultValues = (queryId: string | null) => {
    const defaults: Record<string, unknown> = {};
    const query = availableQueries.find((q) => q.id === queryId);
    if (query) {
      query.attributes.parameters.forEach((param) => {
        defaults[param.name] = param.data_type === "boolean" ? false : "";
      });
    }
    return defaults;
  };
  // NOTE(review): the resolver expression is re-evaluated on every render
  // with the latest selectedQuery — assumes react-hook-form applies the
  // latest resolver on re-render; confirm with the installed RHF version.
  const form = useForm({
    resolver: zodResolver(getValidationSchema(selectedQuery)),
    mode: "onChange",
    defaultValues: getDefaultValues(selectedQuery),
  });
  // Update form when selectedQuery changes
  useEffect(() => {
    form.reset(getDefaultValues(selectedQuery), {
      keepDirtyValues: false,
    });
  }, [selectedQuery]); // eslint-disable-line react-hooks/exhaustive-deps
  const selectedQueryData = availableQueries.find(
    (q) => q.id === selectedQuery,
  );
  // Switch the active query. The reset() here clears current values; the
  // effect above then resets again with the new query's defaults.
  const handleQueryChange = (queryId: string) => {
    setSelectedQuery(queryId);
    form.reset();
  };
  // Current form values, keyed by parameter name.
  const getQueryParameters = () => {
    return form.getValues();
  };
  const isFormValid = () => {
    return form.formState.isValid;
  };
  return {
    selectedQuery,
    selectedQueryData,
    availableQueries,
    form,
    handleQueryChange,
    getQueryParameters,
    isFormValid,
  };
};
@@ -0,0 +1,91 @@
"use client";
import { useRouter } from "next/navigation";
import { useCallback } from "react";
import { create } from "zustand";
import type { WizardState } from "@/types/attack-paths";
// Store shape: persisted wizard selections plus mutation actions.
interface WizardStore extends WizardState {
  setCurrentStep: (step: 1 | 2) => void;
  setSelectedScanId: (scanId: string) => void;
  setSelectedQuery: (queryId: string) => void;
  setQueryParameters: (
    parameters: Record<string, string | number | boolean>,
  ) => void;
  reset: () => void;
}
// Baseline wizard values; reused verbatim by reset().
const initialState: WizardState = {
  currentStep: 1,
  selectedScanId: null,
  selectedQuery: null,
  queryParameters: {},
};
const useWizardStore = create<WizardStore>((set) => ({
  ...initialState,
  setCurrentStep: (step) => set({ currentStep: step }),
  setSelectedScanId: (scanId) => set({ selectedScanId: scanId }),
  setSelectedQuery: (queryId) => set({ selectedQuery: queryId }),
  setQueryParameters: (parameters) => set({ queryParameters: parameters }),
  reset: () => set(initialState),
}));
/**
 * Custom hook for managing Attack Paths wizard state
 * Handles step navigation, scan selection, and query configuration
 */
export const useWizardState = () => {
  const router = useRouter();
  const store = useWizardStore();
  // Derive current step from URL path
  // NOTE(review): reads window.location directly, so during SSR this is
  // always 1 and may disagree with store.currentStep (which the navigation
  // helpers below also set) — confirm which is authoritative / hydration-safe.
  const currentStep: 1 | 2 =
    typeof window !== "undefined"
      ? window.location.pathname.includes("query-builder")
        ? 2
        : 1
      : 1;
  // NOTE(review): `store` (the whole snapshot object) appears in the
  // useCallback deps below; if its identity changes on each state update
  // these callbacks are not referentially stable — confirm intended.
  const goToSelectScan = useCallback(() => {
    store.setCurrentStep(1);
    router.push("/attack-paths/select-scan");
  }, [router, store]);
  // Persist the chosen scan and move to step 2 (query builder route).
  const goToQueryBuilder = useCallback(
    (scanId: string) => {
      store.setSelectedScanId(scanId);
      store.setCurrentStep(2);
      router.push(`/attack-paths/query-builder?scanId=${scanId}`);
    },
    [router, store],
  );
  const updateQueryParameters = useCallback(
    (parameters: Record<string, string | number | boolean>) => {
      store.setQueryParameters(parameters);
    },
    [store],
  );
  // Prefer the scanId in the URL; fall back to the stored selection.
  const getScanIdFromUrl = useCallback(() => {
    const params = new URLSearchParams(
      typeof window !== "undefined" ? window.location.search : "",
    );
    return params.get("scanId") || store.selectedScanId;
  }, [store.selectedScanId]);
  return {
    currentStep,
    selectedScanId: store.selectedScanId || getScanIdFromUrl(),
    selectedQuery: store.selectedQuery,
    queryParameters: store.queryParameters,
    goToSelectScan,
    goToQueryBuilder,
    setSelectedQuery: store.setSelectedQuery,
    updateQueryParameters,
    reset: store.reset,
  };
};
@@ -0,0 +1,145 @@
/**
* Export utilities for attack path graphs
* Handles exporting graph visualization to various formats
*/
/**
 * Helper function to download a blob as a file
 * @param blob The blob to download
 * @param filename The name of the file
 */
const downloadBlob = (blob: Blob, filename: string) => {
  const objectUrl = URL.createObjectURL(blob);
  const anchor = document.createElement("a");
  anchor.href = objectUrl;
  anchor.download = filename;
  // The anchor must be in the DOM for click() to trigger a download reliably.
  document.body.appendChild(anchor);
  anchor.click();
  anchor.remove();
  // Release the object URL so the blob can be garbage-collected.
  URL.revokeObjectURL(objectUrl);
};
/**
 * Export graph as SVG image
 * @param svgElement The SVG element to export
 * @param filename The name of the file to download
 * @throws Error ("Failed to export graph") when the container group cannot
 *   be found or serialization fails; the original error is only logged.
 */
export const exportGraphAsSVG = (
  svgElement: SVGSVGElement | null,
  filename: string = "attack-path-graph.svg",
) => {
  if (!svgElement) return;
  try {
    // Clone the SVG element to avoid modifying the original
    const clonedSvg = svgElement.cloneNode(true) as SVGSVGElement;
    // Find the main container group (first g element with transform)
    const containerGroup = clonedSvg.querySelector("g");
    if (!containerGroup) {
      throw new Error("Could not find graph container");
    }
    // Get the bounding box of the actual graph content
    // We need to get it from the original SVG since cloned elements don't have computed geometry
    const originalContainer = svgElement.querySelector("g");
    if (!originalContainer) {
      throw new Error("Could not find original graph container");
    }
    const bbox = originalContainer.getBBox();
    // Add padding around the content
    const padding = 50;
    const contentWidth = bbox.width + padding * 2;
    const contentHeight = bbox.height + padding * 2;
    // Set the SVG dimensions to fit the content
    clonedSvg.setAttribute("width", `${contentWidth}`);
    clonedSvg.setAttribute("height", `${contentHeight}`);
    clonedSvg.setAttribute(
      "viewBox",
      `${bbox.x - padding} ${bbox.y - padding} ${contentWidth} ${contentHeight}`,
    );
    // Remove the zoom transform from the container - the viewBox now handles positioning
    containerGroup.removeAttribute("transform");
    // Add white background for better visibility
    const bgRect = document.createElementNS(
      "http://www.w3.org/2000/svg",
      "rect",
    );
    bgRect.setAttribute("x", `${bbox.x - padding}`);
    bgRect.setAttribute("y", `${bbox.y - padding}`);
    bgRect.setAttribute("width", `${contentWidth}`);
    bgRect.setAttribute("height", `${contentHeight}`);
    bgRect.setAttribute("fill", "#1c1917"); // Dark background matching the app
    clonedSvg.insertBefore(bgRect, clonedSvg.firstChild);
    // Serialize the adjusted clone and hand it to the download helper.
    const svgData = new XMLSerializer().serializeToString(clonedSvg);
    const blob = new Blob([svgData], { type: "image/svg+xml" });
    downloadBlob(blob, filename);
  } catch (error) {
    // NOTE(review): the original error is logged but replaced by a generic
    // one for callers; consider rethrowing with `cause` to keep context.
    console.error("Failed to export graph as SVG:", error);
    throw new Error("Failed to export graph");
  }
};
/**
 * Export graph as PNG image
 *
 * Rasterizes the SVG through an <img> + <canvas> round-trip. The image load
 * and blob creation are awaited, so the returned promise only resolves once
 * the download has been triggered and rejects on failure. (Previously the
 * onload/onerror callbacks threw *outside* the try/catch, so errors escaped
 * unhandled and the function resolved before the export happened.)
 *
 * @param svgElement The SVG element to export
 * @param filename The name of the file to download
 * @throws Error ("Failed to export graph") when rasterization fails
 */
export const exportGraphAsPNG = async (
  svgElement: SVGSVGElement | null,
  filename: string = "attack-path-graph.png",
): Promise<void> => {
  if (!svgElement) return;
  try {
    const svgData = new XMLSerializer().serializeToString(svgElement);
    // Await the image load so failures reject instead of escaping.
    const image = await new Promise<HTMLImageElement>((resolve, reject) => {
      const img = new Image();
      img.onload = () => resolve(img);
      img.onerror = () =>
        reject(new Error("Failed to load SVG for PNG conversion"));
      img.src = `data:image/svg+xml;base64,${btoa(svgData)}`;
    });
    const canvas = document.createElement("canvas");
    // Fall back to the element's on-screen size (then 1px) if the decoded
    // image reports no intrinsic dimensions.
    canvas.width = image.width || svgElement.clientWidth || 1;
    canvas.height = image.height || svgElement.clientHeight || 1;
    const ctx = canvas.getContext("2d");
    if (!ctx) throw new Error("Could not get canvas context");
    ctx.drawImage(image, 0, 0);
    const blob = await new Promise<Blob | null>((resolve) =>
      canvas.toBlob(resolve),
    );
    if (blob) {
      downloadBlob(blob, filename);
    }
  } catch (error) {
    console.error("Failed to export graph as PNG:", error);
    throw new Error("Failed to export graph");
  }
};
/**
 * Export graph data as JSON
 * @param graphData The graph data to export
 * @param filename The name of the file to download
 * @throws Error ("Failed to export graph") when serialization fails
 */
export const exportGraphAsJSON = (
  graphData: Record<string, unknown>,
  filename: string = "attack-path-graph.json",
) => {
  try {
    // Pretty-print with 2-space indentation for readability.
    const payload = JSON.stringify(graphData, null, 2);
    downloadBlob(new Blob([payload], { type: "application/json" }), filename);
  } catch (error) {
    console.error("Failed to export graph as JSON:", error);
    throw new Error("Failed to export graph");
  }
};
@@ -0,0 +1,25 @@
/**
* Formatting utilities for attack path graph nodes
*/
/**
 * Format camelCase labels to space-separated text
 * e.g., "ProwlerFinding" -> "Prowler Finding", "AWSAccount" -> "Aws Account"
 */
export function formatNodeLabel(label: string): string {
  // Insert spaces at acronym boundaries ("AWSAccount" -> "AWS Account") and
  // lower-to-upper transitions ("ProwlerFinding" -> "Prowler Finding").
  const spaced = label
    .replace(/([A-Z]+)([A-Z][a-z])/g, "$1 $2")
    .replace(/([a-z\d])([A-Z])/g, "$1 $2")
    .trim();
  // Title-case every resulting word (acronyms become e.g. "Aws").
  const titled: string[] = [];
  for (const word of spaced.split(" ")) {
    titled.push(word.charAt(0).toUpperCase() + word.slice(1).toLowerCase());
  }
  return titled.join(" ");
}
/**
 * Format multiple node labels into a readable string
 * e.g., ["ProwlerFinding"] -> "Prowler Finding"
 */
export function formatNodeLabels(labels: string[]): string {
  // Format each label individually, then join with a comma separator.
  const formatted: string[] = [];
  for (const label of labels) {
    formatted.push(formatNodeLabel(label));
  }
  return formatted.join(", ");
}
@@ -0,0 +1,139 @@
/**
* Color constants for attack path graph visualization
* Colors chosen to work well in both light and dark themes
*/
/**
 * Node fill colors - darker versions of design system severity colors
 * Darkened to ensure white text has proper contrast (WCAG AA)
 *
 * Keys are matched in getNodeColor(): severity keys for finding nodes,
 * label-based keys for resource nodes, and `default` as the fallback.
 */
export const GRAPH_NODE_COLORS = {
  // Finding severities - darkened versions for white text readability
  critical: "#cc0055", // Darker pink (from #ff006a)
  high: "#c45a3a", // Darker coral (from #f77852)
  medium: "#b8860b", // Dark goldenrod (from #fec94d)
  low: "#8b9a3e", // Olive/dark yellow-green (from #fdfbd4)
  info: "#2563eb", // Darker blue (from #3c8dff)
  // Node types
  prowlerFinding: "#ea580c", // Fallback for findings with unknown severity
  awsAccount: "#f59e0b", // Amber 500 - AWS orange
  attackPattern: "#16a34a",
  summary: "#16a34a",
  // Infrastructure
  ec2Instance: "#0891b2", // Cyan 600
  s3Bucket: "#0284c7", // Sky 600
  iamRole: "#7c3aed", // Violet 600
  iamPolicy: "#7c3aed",
  lambdaFunction: "#d97706", // Amber 600
  securityGroup: "#0891b2",
  default: "#0891b2",
} as const;
/**
 * Node border colors - using original design system colors as borders (lighter than fill)
 *
 * Keys mirror GRAPH_NODE_COLORS and are matched in getNodeBorderColor().
 */
export const GRAPH_NODE_BORDER_COLORS = {
  critical: "#ff006a", // Original --bg-data-critical
  high: "#f77852", // Original --bg-data-high
  medium: "#fec94d", // Original --bg-data-medium
  low: "#c4d4a0", // Lighter olive
  info: "#3c8dff", // Original --bg-data-info
  prowlerFinding: "#fb923c", // Fallback for findings with unknown severity
  awsAccount: "#fbbf24", // Amber 400
  attackPattern: "#4ade80",
  summary: "#4ade80",
  ec2Instance: "#22d3ee", // Cyan 400
  s3Bucket: "#38bdf8", // Sky 400
  iamRole: "#a78bfa", // Violet 400
  iamPolicy: "#a78bfa",
  lambdaFunction: "#fbbf24",
  securityGroup: "#22d3ee",
  default: "#22d3ee",
} as const;
export const GRAPH_EDGE_COLOR = "#ffffff"; // White (default)
export const GRAPH_EDGE_HIGHLIGHT_COLOR = "#f97316"; // Orange 500 (on hover)
export const GRAPH_EDGE_GLOW_COLOR = "#fb923c"; // Orange 400 glow behind highlighted edges
export const GRAPH_SELECTION_COLOR = "#ffffff"; // Outline for the selected node
export const GRAPH_BORDER_COLOR = "#374151"; // Default/neutral border
export const GRAPH_ALERT_BORDER_COLOR = "#ef4444"; // Red 500 - for resources with findings
/**
 * Get node fill color based on labels and properties
 *
 * Finding nodes are colored by their severity property; other nodes by the
 * first matching resource label; anything unrecognized gets the default.
 */
export const getNodeColor = (
  labels: string[],
  properties?: Record<string, unknown>,
): string => {
  // Case-insensitive substring match against any label.
  const hasLabel = (needle: string) =>
    labels.some((label) => label.toLowerCase().includes(needle));
  if (hasLabel("finding") && properties?.severity) {
    switch (String(properties.severity).toLowerCase()) {
      case "critical":
        return GRAPH_NODE_COLORS.critical;
      case "high":
        return GRAPH_NODE_COLORS.high;
      case "medium":
        return GRAPH_NODE_COLORS.medium;
      case "low":
        return GRAPH_NODE_COLORS.low;
      case "informational":
      case "info":
        return GRAPH_NODE_COLORS.info;
      default:
        return GRAPH_NODE_COLORS.prowlerFinding;
    }
  }
  if (hasLabel("attackpattern")) return GRAPH_NODE_COLORS.attackPattern;
  // Exact-match resource labels.
  if (labels.includes("AWSAccount")) return GRAPH_NODE_COLORS.awsAccount;
  if (labels.includes("EC2Instance")) return GRAPH_NODE_COLORS.ec2Instance;
  if (labels.includes("S3Bucket")) return GRAPH_NODE_COLORS.s3Bucket;
  if (labels.includes("IAMRole")) return GRAPH_NODE_COLORS.iamRole;
  if (labels.includes("IAMPolicy")) return GRAPH_NODE_COLORS.iamPolicy;
  if (labels.includes("LambdaFunction"))
    return GRAPH_NODE_COLORS.lambdaFunction;
  if (labels.includes("SecurityGroup")) return GRAPH_NODE_COLORS.securityGroup;
  return GRAPH_NODE_COLORS.default;
};
/**
 * Get node border color based on labels and properties
 *
 * Mirrors getNodeColor's matching rules but resolves against the lighter
 * GRAPH_NODE_BORDER_COLORS palette.
 */
export const getNodeBorderColor = (
  labels: string[],
  properties?: Record<string, unknown>,
): string => {
  // Case-insensitive substring match against any label.
  const hasLabel = (needle: string) =>
    labels.some((label) => label.toLowerCase().includes(needle));
  if (hasLabel("finding") && properties?.severity) {
    switch (String(properties.severity).toLowerCase()) {
      case "critical":
        return GRAPH_NODE_BORDER_COLORS.critical;
      case "high":
        return GRAPH_NODE_BORDER_COLORS.high;
      case "medium":
        return GRAPH_NODE_BORDER_COLORS.medium;
      case "low":
        return GRAPH_NODE_BORDER_COLORS.low;
      case "informational":
      case "info":
        return GRAPH_NODE_BORDER_COLORS.info;
      default:
        return GRAPH_NODE_BORDER_COLORS.prowlerFinding;
    }
  }
  if (hasLabel("attackpattern")) return GRAPH_NODE_BORDER_COLORS.attackPattern;
  // Exact-match resource labels.
  if (labels.includes("AWSAccount")) return GRAPH_NODE_BORDER_COLORS.awsAccount;
  if (labels.includes("EC2Instance"))
    return GRAPH_NODE_BORDER_COLORS.ec2Instance;
  if (labels.includes("S3Bucket")) return GRAPH_NODE_BORDER_COLORS.s3Bucket;
  if (labels.includes("IAMRole")) return GRAPH_NODE_BORDER_COLORS.iamRole;
  if (labels.includes("IAMPolicy")) return GRAPH_NODE_BORDER_COLORS.iamPolicy;
  if (labels.includes("LambdaFunction"))
    return GRAPH_NODE_BORDER_COLORS.lambdaFunction;
  if (labels.includes("SecurityGroup"))
    return GRAPH_NODE_BORDER_COLORS.securityGroup;
  return GRAPH_NODE_BORDER_COLORS.default;
};
/**
 * Check if a background color is light (for determining text color)
 *
 * Accepts 6-digit ("#rrggbb") and 3-digit ("#rgb") hex colors, with or
 * without the leading "#". The shorthand form is expanded before parsing
 * (previously it produced a NaN luminance and was always reported as dark).
 *
 * @param backgroundColor Hex color string
 * @returns true when the perceived luminance exceeds 0.5
 */
export const isLightBackground = (backgroundColor: string): boolean => {
  let hex = backgroundColor.replace("#", "");
  // Expand #rgb shorthand to #rrggbb (e.g. "fff" -> "ffffff").
  if (hex.length === 3) {
    hex = hex
      .split("")
      .map((c) => c + c)
      .join("");
  }
  const r = parseInt(hex.substring(0, 2), 16);
  const g = parseInt(hex.substring(2, 4), 16);
  const b = parseInt(hex.substring(4, 6), 16);
  // Rec. 601 luma coefficients, normalized to 0..1.
  const luminance = (0.299 * r + 0.587 * g + 0.114 * b) / 255;
  return luminance > 0.5;
};
@@ -0,0 +1,185 @@
/**
* Utility functions for attack path graph operations
*/
import type { AttackPathGraphData, GraphNode } from "@/types/attack-paths";
/**
 * Type for edge node reference - can be a string ID or an object with id property
 */
export type EdgeNodeRef = string | { id: string };
/**
 * Helper to get edge source/target ID from string or object
 */
export const getEdgeNodeId = (nodeRef: EdgeNodeRef): string =>
  typeof nodeRef === "string" ? nodeRef : nodeRef.id;
/**
* Compute a filtered subgraph containing only the path through the target node.
* This follows the directed graph structure of attack paths:
* - Upstream: traces back to the root (AWS Account)
* - Downstream: traces forward to leaf nodes
* - Also includes findings connected to the selected node
*/
export const computeFilteredSubgraph = (
fullData: AttackPathGraphData,
targetNodeId: string,
): AttackPathGraphData => {
const nodes = fullData.nodes;
const edges = fullData.edges || [];
// Build directed adjacency lists
const forwardEdges = new Map<string, Set<string>>(); // source -> targets
const backwardEdges = new Map<string, Set<string>>(); // target -> sources
nodes.forEach((node) => {
forwardEdges.set(node.id, new Set());
backwardEdges.set(node.id, new Set());
});
edges.forEach((edge) => {
const sourceId = getEdgeNodeId(edge.source);
const targetId = getEdgeNodeId(edge.target);
forwardEdges.get(sourceId)?.add(targetId);
backwardEdges.get(targetId)?.add(sourceId);
});
const visibleNodeIds = new Set<string>();
visibleNodeIds.add(targetNodeId);
// Traverse upstream (backward) - find all ancestors
const traverseUpstream = (nodeId: string) => {
const sources = backwardEdges.get(nodeId);
if (sources) {
sources.forEach((sourceId) => {
if (!visibleNodeIds.has(sourceId)) {
visibleNodeIds.add(sourceId);
traverseUpstream(sourceId);
}
});
}
};
// Traverse downstream (forward) - find all descendants
const traverseDownstream = (nodeId: string) => {
const targets = forwardEdges.get(nodeId);
if (targets) {
targets.forEach((targetId) => {
if (!visibleNodeIds.has(targetId)) {
visibleNodeIds.add(targetId);
traverseDownstream(targetId);
}
});
}
};
// Start traversal from the target node
traverseUpstream(targetNodeId);
traverseDownstream(targetNodeId);
// Also include findings directly connected to the selected node
edges.forEach((edge) => {
const sourceId = getEdgeNodeId(edge.source);
const targetId = getEdgeNodeId(edge.target);
const sourceNode = nodes.find((n) => n.id === sourceId);
const targetNode = nodes.find((n) => n.id === targetId);
const sourceIsFinding = sourceNode?.labels.some((l) =>
l.toLowerCase().includes("finding"),
);
const targetIsFinding = targetNode?.labels.some((l) =>
l.toLowerCase().includes("finding"),
);
// Include findings connected to the selected node
if (sourceId === targetNodeId && targetIsFinding) {
visibleNodeIds.add(targetId);
}
if (targetId === targetNodeId && sourceIsFinding) {
visibleNodeIds.add(sourceId);
}
});
// Filter nodes and edges to only include visible ones
const filteredNodes = nodes.filter((node) => visibleNodeIds.has(node.id));
const filteredEdges = edges.filter((edge) => {
const sourceId = getEdgeNodeId(edge.source);
const targetId = getEdgeNodeId(edge.target);
return visibleNodeIds.has(sourceId) && visibleNodeIds.has(targetId);
});
return {
nodes: filteredNodes,
edges: filteredEdges,
};
};
/**
* Find edges in the path from a given node.
* Upstream: follows only ONE parent path (first parent at each level) to avoid lighting up siblings
* Downstream: follows ALL children recursively
*
* Uses pre-built adjacency maps for O(1) lookups instead of O(n) array searches per traversal step.
*
* @param nodeId - The starting node ID
* @param edges - Array of edges with sourceId and targetId
* @returns Set of edge IDs in the format "sourceId-targetId"
*/
/**
 * Find edges in the path from a given node.
 * Upstream: follows only ONE parent path (first parent at each level) to avoid lighting up siblings
 * Downstream: follows ALL children recursively
 *
 * Uses pre-built adjacency maps for O(1) lookups instead of O(n) array searches per traversal step.
 * Both traversals are iterative (parent-chain walk upstream, explicit DFS stack downstream) so a
 * very long node chain cannot overflow the call stack, which the previous recursive version could.
 *
 * @param nodeId - The starting node ID
 * @param edges - Array of edges with sourceId and targetId
 * @returns Set of edge IDs in the format "sourceId-targetId"
 */
export const getPathEdges = (
  nodeId: string,
  edges: Array<{ sourceId: string; targetId: string }>,
): Set<string> => {
  // Build adjacency maps once - O(n)
  const parentMap = new Map<string, { sourceId: string; targetId: string }>();
  const childrenMap = new Map<
    string,
    Array<{ sourceId: string; targetId: string }>
  >();
  edges.forEach((edge) => {
    // First parent only (matches original behavior of find())
    if (!parentMap.has(edge.targetId)) {
      parentMap.set(edge.targetId, edge);
    }
    const children = childrenMap.get(edge.sourceId) || [];
    children.push(edge);
    childrenMap.set(edge.sourceId, children);
  });

  const pathEdgeIds = new Set<string>();

  // Traverse upstream - only follow ONE parent at each level (first found).
  // This creates a single path to the root, not all paths.
  // The visited set guards against cycles in the graph.
  const visitedUp = new Set<string>();
  let current: string | undefined = nodeId;
  while (current !== undefined && !visitedUp.has(current)) {
    visitedUp.add(current);
    const parentEdge = parentMap.get(current); // O(1) lookup
    if (!parentEdge) break; // reached a root node
    pathEdgeIds.add(`${parentEdge.sourceId}-${parentEdge.targetId}`);
    current = parentEdge.sourceId;
  }

  // Traverse downstream (find ALL targets from this node) with an explicit
  // DFS stack instead of recursion; visited check prevents re-expansion.
  const visitedDown = new Set<string>();
  const stack: string[] = [nodeId];
  while (stack.length > 0) {
    const node = stack.pop() as string;
    if (visitedDown.has(node)) continue;
    visitedDown.add(node);
    for (const edge of childrenMap.get(node) || []) {
      pathEdgeIds.add(`${edge.sourceId}-${edge.targetId}`);
      stack.push(edge.targetId);
    }
  }

  return pathEdgeIds;
};
@@ -0,0 +1,22 @@
// Barrel file for the attack-paths `_lib` helpers:
// graph export (JSON/PNG/SVG), node label formatting, graph traversal
// utilities, and the shared graph color palette.
export {
  exportGraphAsJSON,
  exportGraphAsPNG,
  exportGraphAsSVG,
} from "./export";
export { formatNodeLabel, formatNodeLabels } from "./format";
export {
  computeFilteredSubgraph,
  getEdgeNodeId,
  getPathEdges,
  type EdgeNodeRef,
} from "./graph-utils";
export {
  getNodeBorderColor,
  getNodeColor,
  GRAPH_ALERT_BORDER_COLOR,
  GRAPH_EDGE_COLOR,
  GRAPH_EDGE_HIGHLIGHT_COLOR,
  GRAPH_NODE_BORDER_COLORS,
  GRAPH_NODE_COLORS,
  GRAPH_SELECTION_COLOR,
} from "./graph-colors";
@@ -0,0 +1,626 @@
"use client";
import { ArrowLeft, Maximize2, X } from "lucide-react";
import { useSearchParams } from "next/navigation";
import { Suspense, useCallback, useEffect, useRef, useState } from "react";
import { FormProvider } from "react-hook-form";
import {
executeQuery,
getAttackPathScans,
getAvailableQueries,
} from "@/actions/attack-paths";
import { adaptQueryResultToGraphData } from "@/actions/attack-paths/query-result.adapter";
import { AutoRefresh } from "@/components/scans";
import { Button, Card, CardContent } from "@/components/shadcn";
import {
Dialog,
DialogContent,
DialogHeader,
DialogTitle,
DialogTrigger,
useToast,
} from "@/components/ui";
import type {
AttackPathQuery,
AttackPathScan,
GraphNode,
} from "@/types/attack-paths";
import {
AttackPathGraph,
ExecuteButton,
GraphControls,
GraphLegend,
GraphLoading,
NodeDetailContent,
QueryParametersForm,
QuerySelector,
ScanListTable,
} from "./_components";
import type { AttackPathGraphRef } from "./_components/graph/attack-path-graph";
import { useGraphState } from "./_hooks/use-graph-state";
import { useQueryBuilder } from "./_hooks/use-query-builder";
import { exportGraphAsSVG, formatNodeLabel } from "./_lib";
/**
 * Attack Paths Analysis
 * Allows users to select a scan, build a query, and visualize the Attack Paths graph
 *
 * The selected scan is read from the `scanId` search param; scan selection
 * itself is handled by <ScanListTable /> (presumably via URL updates —
 * NOTE(review): confirm against ScanListTable's implementation).
 */
export default function AttackPathAnalysisPage() {
  const searchParams = useSearchParams();
  const scanId = searchParams.get("scanId");
  const graphState = useGraphState();
  const { toast } = useToast();

  // Scan list state (left column)
  const [scansLoading, setScansLoading] = useState(true);
  const [scans, setScans] = useState<AttackPathScan[]>([]);
  // Query builder state (right column)
  const [queriesLoading, setQueriesLoading] = useState(true);
  const [queriesError, setQueriesError] = useState<string | null>(null);
  const [isFullscreenOpen, setIsFullscreenOpen] = useState(false);

  // Imperative handles for the inline and fullscreen graph instances
  const graphRef = useRef<AttackPathGraphRef>(null);
  const fullscreenGraphRef = useRef<AttackPathGraphRef>(null);
  // Guards the one-time graph reset on mount (see effect below)
  const hasResetRef = useRef(false);
  const nodeDetailsRef = useRef<HTMLDivElement>(null);
  const graphContainerRef = useRef<HTMLDivElement>(null);
  const [queries, setQueries] = useState<AttackPathQuery[]>([]);
  // Use custom hook for query builder form state and validation
  const queryBuilder = useQueryBuilder(queries);

  // Reset graph state when component mounts.
  // The ref guard keeps this to a single reset even though `graphState` is a
  // dependency and may change identity across renders.
  useEffect(() => {
    if (!hasResetRef.current) {
      hasResetRef.current = true;
      graphState.resetGraph();
    }
  }, [graphState]);

  // Load available scans on mount
  useEffect(() => {
    const loadScans = async () => {
      setScansLoading(true);
      try {
        const scansData = await getAttackPathScans();
        if (scansData?.data) {
          setScans(scansData.data);
        } else {
          setScans([]);
        }
      } catch (error) {
        // Failure to load scans degrades to an empty list, not a hard error
        console.error("Failed to load scans:", error);
        setScans([]);
      } finally {
        setScansLoading(false);
      }
    };
    loadScans();
  }, []);

  // Check if there's an executing scan for auto-refresh
  const hasExecutingScan = scans.some(
    (scan) =>
      scan.attributes.state === "executing" ||
      scan.attributes.state === "scheduled",
  );

  // Callback to refresh scans (used by AutoRefresh component)
  const refreshScans = useCallback(async () => {
    try {
      const scansData = await getAttackPathScans();
      if (scansData?.data) {
        setScans(scansData.data);
      }
    } catch (error) {
      console.error("Failed to refresh scans:", error);
    }
  }, []);

  // Load available queries on mount (and whenever the selected scan changes)
  useEffect(() => {
    const loadQueries = async () => {
      if (!scanId) {
        setQueriesError("No scan selected");
        setQueriesLoading(false);
        return;
      }
      setQueriesLoading(true);
      try {
        const queriesData = await getAvailableQueries(scanId);
        if (queriesData?.data) {
          setQueries(queriesData.data);
          setQueriesError(null);
        } else {
          setQueriesError("Failed to load available queries");
          toast({
            title: "Error",
            description: "Failed to load queries for this scan",
            variant: "destructive",
          });
        }
      } catch (error) {
        const errorMsg =
          error instanceof Error ? error.message : "Unknown error";
        setQueriesError(errorMsg);
        toast({
          title: "Error",
          description: "Failed to load queries",
          variant: "destructive",
        });
      } finally {
        setQueriesLoading(false);
      }
    };
    loadQueries();
  }, [scanId, toast]);

  // Thin wrapper so the selector prop stays a plain function reference
  const handleQueryChange = (queryId: string) => {
    queryBuilder.handleQueryChange(queryId);
  };

  // Shorthand for destructive toasts used by the handlers below
  const showErrorToast = (title: string, description: string) => {
    toast({
      title,
      description,
      variant: "destructive",
    });
  };

  /**
   * Validate the query form, execute the selected query against the selected
   * scan, and load the result into graph state. On any failure the graph is
   * reset and an error is surfaced via state + toast.
   */
  const handleExecuteQuery = async () => {
    if (!scanId || !queryBuilder.selectedQuery) {
      showErrorToast("Error", "Please select both a scan and a query");
      return;
    }
    // Validate form before executing query
    const isValid = await queryBuilder.form.trigger();
    if (!isValid) {
      showErrorToast(
        "Validation Error",
        "Please fill in all required parameters",
      );
      return;
    }
    graphState.startLoading();
    graphState.setError(null);
    try {
      const parameters = queryBuilder.getQueryParameters() as Record<
        string,
        string | number | boolean
      >;
      const result = await executeQuery(
        scanId,
        queryBuilder.selectedQuery,
        parameters,
      );
      if (result?.data?.attributes) {
        const graphData = adaptQueryResultToGraphData(result.data.attributes);
        graphState.updateGraphData(graphData);
        toast({
          title: "Success",
          description: "Query executed successfully",
          variant: "default",
        });
        // Scroll to graph after successful query execution
        // (small delay so the graph section has rendered first)
        setTimeout(() => {
          graphContainerRef.current?.scrollIntoView({
            behavior: "smooth",
            block: "start",
          });
        }, 100);
      } else {
        graphState.resetGraph();
        graphState.setError("No data returned from query");
        showErrorToast("Error", "Query returned no data");
      }
    } catch (error) {
      const errorMsg =
        error instanceof Error ? error.message : "Failed to execute query";
      graphState.resetGraph();
      graphState.setError(errorMsg);
      showErrorToast("Error", errorMsg);
    } finally {
      graphState.stopLoading();
    }
  };

  /** Enter filtered view for the clicked node; scroll to details for findings. */
  const handleNodeClick = (node: GraphNode) => {
    // Enter filtered view showing only paths containing this node
    graphState.enterFilteredView(node.id);
    // For findings, also scroll to the details section
    const isFinding = node.labels.some((label) =>
      label.toLowerCase().includes("finding"),
    );
    if (isFinding) {
      setTimeout(() => {
        nodeDetailsRef.current?.scrollIntoView({
          behavior: "smooth",
          block: "nearest",
        });
      }, 100);
    }
  };

  const handleBackToFullView = () => {
    graphState.exitFilteredView();
  };

  const handleCloseDetails = () => {
    graphState.selectNode(null);
  };

  /**
   * Export the given SVG element as a file, toasting success or failure.
   * A null element (graph ref not mounted) is reported as an export error.
   */
  const handleGraphExport = (svgElement: SVGSVGElement | null) => {
    try {
      if (svgElement) {
        exportGraphAsSVG(svgElement, "attack-path-graph.svg");
        toast({
          title: "Success",
          description: "Graph exported as SVG",
          variant: "default",
        });
      } else {
        throw new Error("Could not find graph element");
      }
    } catch (error) {
      toast({
        title: "Error",
        description:
          error instanceof Error ? error.message : "Failed to export graph",
        variant: "destructive",
      });
    }
  };

  return (
    <div className="flex flex-col gap-6">
      {/* Auto-refresh scans when there's an executing scan */}
      <AutoRefresh
        hasExecutingScan={hasExecutingScan}
        onRefresh={refreshScans}
      />
      {/* Header */}
      <div>
        <h2 className="dark:text-prowler-theme-pale/90 text-xl font-semibold">
          Attack Paths Analysis
        </h2>
        <p className="text-text-neutral-secondary dark:text-text-neutral-secondary mt-2 text-sm">
          Select a scan, build a query, and visualize Attack Paths in your
          infrastructure.
        </p>
      </div>
      {/* Top Section - Scans Table and Query Builder (2 columns) */}
      <div className="grid grid-cols-1 gap-8 xl:grid-cols-2">
        {/* Scans Table Section - Left Column */}
        <div>
          {scansLoading ? (
            <div className="minimal-scrollbar rounded-large shadow-small border-border-neutral-secondary bg-bg-neutral-secondary relative z-0 flex w-full flex-col gap-4 overflow-auto border p-4">
              <p className="text-sm">Loading scans...</p>
            </div>
          ) : scans.length === 0 ? (
            <div className="minimal-scrollbar rounded-large shadow-small border-border-neutral-secondary bg-bg-neutral-secondary relative z-0 flex w-full flex-col gap-4 overflow-auto border p-4">
              <p className="text-sm">No scans available</p>
            </div>
          ) : (
            <Suspense fallback={<div>Loading scans...</div>}>
              <ScanListTable scans={scans} />
            </Suspense>
          )}
        </div>
        {/* Query Builder Section - Right Column */}
        <div className="minimal-scrollbar rounded-large shadow-small border-border-neutral-secondary bg-bg-neutral-secondary relative z-0 flex w-full flex-col gap-4 overflow-auto border p-4">
          {!scanId ? (
            <p className="text-text-info dark:text-text-info text-sm">
              Select a scan from the table on the left to begin.
            </p>
          ) : queriesLoading ? (
            <p className="text-sm">Loading queries...</p>
          ) : queriesError ? (
            <p className="text-text-danger dark:text-text-danger text-sm">
              {queriesError}
            </p>
          ) : (
            <>
              <FormProvider {...queryBuilder.form}>
                <QuerySelector
                  queries={queries}
                  selectedQueryId={queryBuilder.selectedQuery}
                  onQueryChange={handleQueryChange}
                />
                {queryBuilder.selectedQuery && (
                  <QueryParametersForm
                    selectedQuery={queryBuilder.selectedQueryData}
                  />
                )}
              </FormProvider>
              <div className="flex gap-3">
                <ExecuteButton
                  isLoading={graphState.loading}
                  isDisabled={!queryBuilder.selectedQuery}
                  onExecute={handleExecuteQuery}
                />
              </div>
              {graphState.error && (
                <div className="bg-bg-danger-secondary text-text-danger dark:bg-bg-danger-secondary dark:text-text-danger rounded p-3 text-sm">
                  {graphState.error}
                </div>
              )}
            </>
          )}
        </div>
      </div>
      {/* Bottom Section - Graph Visualization (Full Width) */}
      <div className="minimal-scrollbar rounded-large shadow-small border-border-neutral-secondary bg-bg-neutral-secondary relative z-0 flex w-full flex-col gap-4 overflow-auto border p-4">
        {graphState.loading ? (
          <GraphLoading />
        ) : graphState.data &&
          graphState.data.nodes &&
          graphState.data.nodes.length > 0 ? (
          <>
            {/* Info message and controls */}
            <div className="flex flex-col gap-3 sm:flex-row sm:items-center sm:justify-between">
              {graphState.isFilteredView ? (
                <div className="flex items-center gap-3">
                  <Button
                    onClick={handleBackToFullView}
                    variant="outline"
                    size="sm"
                    className="gap-2"
                    aria-label="Return to full graph view"
                  >
                    <ArrowLeft size={16} />
                    Back to Full View
                  </Button>
                  <div
                    className="bg-bg-info-secondary text-text-info inline-flex cursor-default items-center gap-2 rounded-md px-3 py-2 text-xs font-medium shadow-sm sm:px-4 sm:text-sm"
                    role="status"
                    aria-label="Filtered view active"
                  >
                    <span className="flex-shrink-0" aria-hidden="true">
                      🔍
                    </span>
                    <span className="flex-1">
                      Showing paths for:{" "}
                      <strong>
                        {graphState.filteredNode?.properties?.name ||
                          graphState.filteredNode?.properties?.id ||
                          "Selected node"}
                      </strong>
                    </span>
                  </div>
                </div>
              ) : (
                <div
                  className="bg-button-primary inline-flex cursor-default items-center gap-2 rounded-md px-3 py-2 text-xs font-medium text-black shadow-sm sm:px-4 sm:text-sm"
                  role="status"
                  aria-label="Graph interaction instructions"
                >
                  <span className="flex-shrink-0" aria-hidden="true">
                    💡
                  </span>
                  <span className="flex-1">
                    Click on any node to filter and view its connected paths
                  </span>
                </div>
              )}
              {/* Graph controls and fullscreen button together */}
              <div className="flex items-center gap-2">
                <GraphControls
                  onZoomIn={() => graphRef.current?.zoomIn()}
                  onZoomOut={() => graphRef.current?.zoomOut()}
                  onFitToScreen={() => graphRef.current?.resetZoom()}
                  onExport={() =>
                    handleGraphExport(graphRef.current?.getSVGElement() || null)
                  }
                />
                {/* Fullscreen button */}
                <div className="border-border-neutral-primary bg-bg-neutral-tertiary flex gap-1 rounded-lg border p-1">
                  <Dialog
                    open={isFullscreenOpen}
                    onOpenChange={setIsFullscreenOpen}
                  >
                    <DialogTrigger asChild>
                      <Button
                        variant="ghost"
                        size="sm"
                        className="h-8 w-8 p-0"
                        aria-label="Fullscreen"
                      >
                        <Maximize2 size={18} />
                      </Button>
                    </DialogTrigger>
                    <DialogContent className="flex h-full max-h-screen w-full max-w-full flex-col gap-0 p-0">
                      <DialogHeader className="px-4 pt-4 sm:px-6 sm:pt-6">
                        <DialogTitle className="text-lg">
                          Graph Fullscreen View
                        </DialogTitle>
                      </DialogHeader>
                      {/* Fullscreen controls target the fullscreen graph ref */}
                      <div className="px-4 pt-4 pb-4 sm:px-6 sm:pt-6">
                        <GraphControls
                          onZoomIn={() => fullscreenGraphRef.current?.zoomIn()}
                          onZoomOut={() =>
                            fullscreenGraphRef.current?.zoomOut()
                          }
                          onFitToScreen={() =>
                            fullscreenGraphRef.current?.resetZoom()
                          }
                          onExport={() =>
                            handleGraphExport(
                              fullscreenGraphRef.current?.getSVGElement() ||
                                null,
                            )
                          }
                        />
                      </div>
                      <div className="flex flex-1 gap-4 overflow-hidden px-4 pb-4 sm:px-6 sm:pb-6">
                        <div className="flex flex-1 items-center justify-center">
                          <AttackPathGraph
                            ref={fullscreenGraphRef}
                            data={graphState.data}
                            onNodeClick={handleNodeClick}
                            selectedNodeId={graphState.selectedNodeId}
                            isFilteredView={graphState.isFilteredView}
                          />
                        </div>
                        {/* Node Detail Panel - Side by side */}
                        {graphState.selectedNode && (
                          <section aria-labelledby="node-details-heading">
                            <Card className="w-96 overflow-y-auto">
                              <CardContent className="p-4">
                                <div className="mb-4 flex items-center justify-between">
                                  <h3
                                    id="node-details-heading"
                                    className="text-sm font-semibold"
                                  >
                                    Node Details
                                  </h3>
                                  <Button
                                    onClick={handleCloseDetails}
                                    variant="ghost"
                                    size="sm"
                                    className="h-6 w-6 p-0"
                                    aria-label="Close node details"
                                  >
                                    <X size={16} />
                                  </Button>
                                </div>
                                {/* Findings show check_title; resources show name */}
                                <p className="text-text-neutral-secondary dark:text-text-neutral-secondary mb-4 text-xs">
                                  {graphState.selectedNode?.labels.some(
                                    (label) =>
                                      label.toLowerCase().includes("finding"),
                                  )
                                    ? graphState.selectedNode?.properties
                                        ?.check_title ||
                                      graphState.selectedNode?.properties?.id ||
                                      "Unknown Finding"
                                    : graphState.selectedNode?.properties
                                        ?.name ||
                                      graphState.selectedNode?.properties?.id ||
                                      "Unknown Resource"}
                                </p>
                                <div className="flex flex-col gap-4">
                                  <div>
                                    <h4 className="mb-2 text-xs font-semibold">
                                      Type
                                    </h4>
                                    <p className="text-text-neutral-secondary dark:text-text-neutral-secondary text-xs">
                                      {graphState.selectedNode?.labels
                                        .map(formatNodeLabel)
                                        .join(", ")}
                                    </p>
                                  </div>
                                </div>
                              </CardContent>
                            </Card>
                          </section>
                        )}
                      </div>
                    </DialogContent>
                  </Dialog>
                </div>
              </div>
            </div>
            {/* Graph in the middle */}
            <div ref={graphContainerRef} className="h-[calc(100vh-22rem)]">
              <AttackPathGraph
                ref={graphRef}
                data={graphState.data}
                onNodeClick={handleNodeClick}
                selectedNodeId={graphState.selectedNodeId}
                isFilteredView={graphState.isFilteredView}
              />
            </div>
            {/* Legend below */}
            <div className="hidden justify-center lg:flex">
              <GraphLegend data={graphState.data} />
            </div>
          </>
        ) : (
          <div className="flex flex-1 items-center justify-center text-center">
            <p className="text-text-neutral-secondary dark:text-text-neutral-secondary text-sm">
              Select a query and click &quot;Execute Query&quot; to visualize
              the Attack Paths graph
            </p>
          </div>
        )}
      </div>
      {/* Node Detail Panel - Below Graph */}
      {graphState.selectedNode && graphState.data && (
        <div
          ref={nodeDetailsRef}
          className="minimal-scrollbar rounded-large shadow-small border-border-neutral-secondary bg-bg-neutral-secondary relative z-0 flex w-full flex-col gap-4 overflow-auto border p-4"
        >
          <div className="flex items-center justify-between">
            <div className="flex-1">
              <h3 className="text-lg font-semibold">Node Details</h3>
              <p className="text-text-neutral-secondary dark:text-text-neutral-secondary mt-1 text-sm">
                {String(
                  graphState.selectedNode.labels.some((label) =>
                    label.toLowerCase().includes("finding"),
                  )
                    ? graphState.selectedNode.properties?.check_title ||
                        graphState.selectedNode.properties?.id ||
                        "Unknown Finding"
                    : graphState.selectedNode.properties?.name ||
                        graphState.selectedNode.properties?.id ||
                        "Unknown Resource",
                )}
              </p>
            </div>
            <div className="flex items-center gap-2">
              {/* Deep-link to the finding's page when the node is a finding */}
              {graphState.selectedNode.labels.some((label) =>
                label.toLowerCase().includes("finding"),
              ) && (
                <Button asChild variant="default" size="sm">
                  <a
                    href={`/findings?id=${String(graphState.selectedNode.properties?.id || graphState.selectedNode.id)}`}
                    target="_blank"
                    rel="noopener noreferrer"
                    aria-label={`View finding ${String(graphState.selectedNode.properties?.id || graphState.selectedNode.id)}`}
                  >
                    View Finding
                  </a>
                </Button>
              )}
              <Button
                onClick={handleCloseDetails}
                variant="ghost"
                size="sm"
                className="h-8 w-8 p-0"
                aria-label="Close node details"
              >
                <X size={16} />
              </Button>
            </div>
          </div>
          <NodeDetailContent
            node={graphState.selectedNode}
            allNodes={graphState.data.nodes}
          />
        </div>
      )}
    </div>
  );
}
+9
View File
@@ -0,0 +1,9 @@
import { redirect } from "next/navigation";

/**
 * Landing page for Attack Paths feature
 * Redirects to the integrated attack path analysis view
 *
 * `redirect` throws to interrupt rendering, so this component never
 * returns any UI of its own.
 */
export default function AttackPathsPage() {
  redirect("/attack-paths/query-builder");
}
+76 -9
View File
@@ -1,6 +1,7 @@
import { Suspense } from "react";
import {
getFindingById,
getFindings,
getLatestFindings,
getLatestMetadataInfo,
@@ -8,6 +9,7 @@ import {
} from "@/actions/findings";
import { getProviders } from "@/actions/providers";
import { getScans } from "@/actions/scans";
import { FindingDetailsSheet } from "@/components/findings";
import { FindingsFilters } from "@/components/findings/findings-filters";
import {
FindingsTableWithSelection,
@@ -41,15 +43,79 @@ export default async function Findings({
// Check if the searchParams contain any date or scan filter
const hasDateOrScan = hasDateOrScanFilter(resolvedSearchParams);
const [metadataInfoData, providersData, scansData] = await Promise.all([
(hasDateOrScan ? getMetadataInfo : getLatestMetadataInfo)({
query,
sort: encodedSort,
filters,
}),
getProviders({ pageSize: 50 }),
getScans({ pageSize: 50 }),
]);
// Check if there's a specific finding ID to fetch
const findingId = resolvedSearchParams.id?.toString();
const [metadataInfoData, providersData, scansData, findingByIdData] =
await Promise.all([
(hasDateOrScan ? getMetadataInfo : getLatestMetadataInfo)({
query,
sort: encodedSort,
filters,
}),
getProviders({ pageSize: 50 }),
getScans({ pageSize: 50 }),
findingId
? getFindingById(findingId, "resources,scan.provider")
: Promise.resolve(null),
]);
// Process the finding data to match the expected structure
const processedFinding = findingByIdData?.data
? (() => {
const finding = findingByIdData.data;
const included = findingByIdData.included || [];
// Build dictionaries from included data
type IncludedItem = {
type: string;
id: string;
attributes: Record<string, unknown>;
relationships?: {
provider?: { data?: { id: string } };
};
};
const resourceDict: Record<string, unknown> = {};
const scanDict: Record<string, IncludedItem> = {};
const providerDict: Record<string, unknown> = {};
included.forEach((item: IncludedItem) => {
if (item.type === "resources") {
resourceDict[item.id] = {
id: item.id,
attributes: item.attributes,
};
} else if (item.type === "scans") {
scanDict[item.id] = item;
} else if (item.type === "providers") {
providerDict[item.id] = {
id: item.id,
attributes: item.attributes,
};
}
});
const scanId = finding.relationships?.scan?.data?.id;
const resourceId = finding.relationships?.resources?.data?.[0]?.id;
const scan = scanId ? scanDict[scanId] : undefined;
const providerId = scan?.relationships?.provider?.data?.id;
const resource = resourceId ? resourceDict[resourceId] : undefined;
const provider = providerId ? providerDict[providerId] : undefined;
return {
...finding,
relationships: {
scan: scan
? { data: scan, attributes: scan.attributes }
: undefined,
resource: resource,
provider: provider,
},
} as FindingProps;
})()
: null;
// Extract unique regions, services, categories from the new endpoint
const uniqueRegions = metadataInfoData?.data?.attributes?.regions || [];
@@ -98,6 +164,7 @@ export default async function Findings({
<Suspense key={searchParamsKey} fallback={<SkeletonTableFindings />}>
<SSRDataTable searchParams={resolvedSearchParams} />
</Suspense>
{processedFinding && <FindingDetailsSheet finding={processedFinding} />}
</ContentLayout>
);
}
@@ -0,0 +1,46 @@
"use client";
import { usePathname, useRouter, useSearchParams } from "next/navigation";
import {
Sheet,
SheetContent,
SheetDescription,
SheetHeader,
SheetTitle,
} from "@/components/ui/sheet";
import { FindingProps } from "@/types/components";
import { FindingDetail } from "./table/finding-detail";
interface FindingDetailsSheetProps {
finding: FindingProps;
}
/**
 * Slide-over sheet showing the details of a single finding.
 * The sheet is always rendered open; closing it strips the `id` query
 * parameter from the URL (without scrolling) so the deep link is cleared
 * and the parent page stops rendering this sheet.
 */
export const FindingDetailsSheet = ({ finding }: FindingDetailsSheetProps) => {
  const router = useRouter();
  const pathname = usePathname();
  const searchParams = useSearchParams();

  // Only the close transition needs handling; opening is driven by the URL.
  const handleOpenChange = (open: boolean) => {
    if (open) return;
    const nextParams = new URLSearchParams(searchParams.toString());
    nextParams.delete("id");
    router.push(`${pathname}?${nextParams.toString()}`, { scroll: false });
  };

  return (
    <Sheet open={true} onOpenChange={handleOpenChange}>
      <SheetContent className="my-4 max-h-[calc(100vh-2rem)] max-w-[95vw] overflow-y-auto pt-10 md:my-8 md:max-h-[calc(100vh-4rem)] md:max-w-[55vw]">
        <SheetHeader>
          <SheetTitle className="sr-only">Finding Details</SheetTitle>
          <SheetDescription className="sr-only">
            View the finding details
          </SheetDescription>
        </SheetHeader>
        <FindingDetail findingDetails={finding} />
      </SheetContent>
    </Sheet>
  );
};
+1
View File
@@ -1 +1,2 @@
export * from "./finding-details-sheet";
export * from "./muted";
@@ -4,10 +4,8 @@ import { ColumnDef, RowSelectionState } from "@tanstack/react-table";
import { Database } from "lucide-react";
import { useSearchParams } from "next/navigation";
import {
DataTableRowActions,
FindingDetail,
} from "@/components/findings/table";
import { FindingDetail } from "@/components/findings/table";
import { DataTableRowActions } from "@/components/findings/table";
import { Checkbox } from "@/components/shadcn";
import { DateWithTime, SnippetChip } from "@/components/ui/entities";
import {
@@ -57,23 +55,10 @@ const FindingTitleCell = ({ row }: { row: { original: FindingProps } }) => {
const isOpen = findingId === row.original.id;
const { checktitle } = row.original.attributes.check_metadata;
const handleOpenChange = (open: boolean) => {
const params = new URLSearchParams(searchParams);
if (open) {
params.set("id", row.original.id);
} else {
params.delete("id");
}
window.history.pushState({}, "", `?${params.toString()}`);
};
return (
<FindingDetail
findingDetails={row.original}
defaultOpen={isOpen}
onOpenChange={handleOpenChange}
trigger={
<div className="max-w-[500px]">
<p className="text-text-neutral-primary hover:text-button-tertiary cursor-pointer text-left text-sm break-words whitespace-normal hover:underline">
@@ -1,6 +1,5 @@
"use client";
import { Snippet } from "@heroui/snippet";
import { ExternalLink, Link, X } from "lucide-react";
import { usePathname, useSearchParams } from "next/navigation";
import type { ReactNode } from "react";
@@ -33,6 +32,7 @@ import {
StatusFindingBadge,
} from "@/components/ui/table/status-finding-badge";
import { buildGitFileUrl, extractLineRangeFromUid } from "@/lib/iac-utils";
import { cn } from "@/lib/utils";
import { FindingProps, ProviderType } from "@/types";
import { Muted } from "../muted";
@@ -196,16 +196,16 @@ export const FindingDetail = ({
{attributes.status === "FAIL" && (
<InfoField label="Risk" variant="simple">
<Snippet
className="max-w-full py-2"
color="danger"
hideCopyButton
hideSymbol
<div
className={cn(
"max-w-full rounded-md border p-2",
"border-border-error-primary bg-bg-fail-secondary",
)}
>
<MarkdownContainer>
{attributes.check_metadata.risk}
</MarkdownContainer>
</Snippet>
</div>
</InfoField>
)}
@@ -255,11 +255,13 @@ export const FindingDetail = ({
{/* CLI Command section */}
{attributes.check_metadata.remediation.code.cli && (
<InfoField label="CLI Command" variant="simple">
<Snippet>
<div
className={cn("rounded-md p-2", "bg-bg-neutral-tertiary")}
>
<span className="text-xs whitespace-pre-line">
{attributes.check_metadata.remediation.code.cli}
</span>
</Snippet>
</div>
</InfoField>
)}
+11 -3
View File
@@ -5,9 +5,11 @@ import { useEffect } from "react";
interface AutoRefreshProps {
hasExecutingScan: boolean;
/** Optional callback for client-side refresh (used when data is managed in local state) */
onRefresh?: () => void | Promise<void>;
}
export function AutoRefresh({ hasExecutingScan }: AutoRefreshProps) {
export function AutoRefresh({ hasExecutingScan, onRefresh }: AutoRefreshProps) {
const router = useRouter();
const searchParams = useSearchParams();
@@ -19,11 +21,17 @@ export function AutoRefresh({ hasExecutingScan }: AutoRefreshProps) {
if (scanId) return;
const interval = setInterval(() => {
router.refresh();
if (onRefresh) {
// Use custom refresh callback for client-side state management
onRefresh();
} else {
// Default: trigger server-side refresh
router.refresh();
}
}, 5000);
return () => clearInterval(interval);
}, [hasExecutingScan, router, searchParams]);
}, [hasExecutingScan, router, searchParams, onRefresh]);
return null;
}
@@ -54,6 +54,7 @@ export function BreadcrumbNavigation({
"/manage-groups": "lucide:users-2",
"/services": "lucide:server",
"/workloads": "lucide:layers",
"/attack-paths": "lucide:git-branch",
};
const pathSegments = pathname
@@ -156,6 +157,7 @@ export function BreadcrumbNavigation({
>
{breadcrumb.icon && typeof breadcrumb.icon === "string" ? (
<Icon
aria-hidden="true"
className="text-text-neutral-primary"
height={24}
icon={breadcrumb.icon}
@@ -177,6 +179,7 @@ export function BreadcrumbNavigation({
>
{breadcrumb.icon && typeof breadcrumb.icon === "string" ? (
<Icon
aria-hidden="true"
className="text-text-neutral-primary"
height={24}
icon={breadcrumb.icon}
@@ -195,6 +198,7 @@ export function BreadcrumbNavigation({
<div className="flex items-center gap-2">
{breadcrumb.icon && typeof breadcrumb.icon === "string" ? (
<Icon
aria-hidden="true"
className="text-default-500"
height={24}
icon={breadcrumb.icon}
+21 -3
View File
@@ -20,6 +20,7 @@ interface MenuItemProps {
target?: string;
tooltip?: string;
isOpen: boolean;
highlight?: boolean;
}
export const MenuItem = ({
@@ -30,6 +31,7 @@ export const MenuItem = ({
target,
tooltip,
isOpen,
highlight,
}: MenuItemProps) => {
const pathname = usePathname();
const isActive = active !== undefined ? active : pathname.startsWith(href);
@@ -44,15 +46,31 @@ export const MenuItem = ({
variant={isActive ? "menu-active" : "menu-inactive"}
className={cn(
isOpen ? "w-full justify-start" : "w-14 justify-center",
highlight &&
"relative overflow-hidden before:absolute before:inset-0 before:rounded-lg before:bg-gradient-to-r before:from-emerald-500/20 before:via-teal-400/20 before:to-emerald-300/20 before:opacity-70",
)}
asChild
>
<Link href={href} target={target}>
<div className="flex items-center">
<span className={cn(isOpen ? "mr-4" : "")}>
<div className="relative z-10 flex items-center">
<span
className={cn(
isOpen ? "mr-4" : "",
highlight && "text-button-primary",
)}
>
<Icon size={18} />
</span>
{isOpen && <p className="max-w-[200px] truncate">{label}</p>}
{isOpen && (
<p className="max-w-[200px] truncate">
{label}
{highlight && (
<span className="ml-2 rounded-sm bg-emerald-500 px-1.5 py-0.5 text-[10px] font-semibold text-white">
NEW
</span>
)}
</p>
)}
</div>
</Link>
</Button>
+1
View File
@@ -113,6 +113,7 @@ export const Menu = ({ isOpen }: { isOpen: boolean }) => {
target={menu.target}
tooltip={menu.tooltip}
isOpen={isOpen}
highlight={menu.highlight}
/>
)}
</div>
+2 -3
View File
@@ -48,6 +48,7 @@ export const StatusBadge = ({
className?: string;
}) => {
const color = statusColorMap[status as keyof typeof statusColorMap];
const displayLabel = statusDisplayMap[status] || status;
return (
<Chip
@@ -70,9 +71,7 @@ export const StatusBadge = ({
<span>executing</span>
</div>
) : (
<span className="flex items-center justify-center">
{statusDisplayMap[status as keyof typeof statusDisplayMap] || status}
</span>
<span className="flex items-center justify-center">{displayLabel}</span>
)}
</Chip>
);
+18 -2
View File
@@ -149,7 +149,7 @@
"from": "1.1.15",
"to": "1.1.15",
"strategy": "installed",
"generatedAt": "2025-11-20T08:20:16.313Z"
"generatedAt": "2025-11-19T12:28:39.510Z"
},
{
"section": "dependencies",
@@ -295,6 +295,14 @@
"strategy": "installed",
"generatedAt": "2025-10-22T12:36:37.962Z"
},
{
"section": "dependencies",
"name": "@types/dagre",
"from": "0.7.53",
"to": "0.7.53",
"strategy": "installed",
"generatedAt": "2025-11-27T11:47:22.908Z"
},
{
"section": "dependencies",
"name": "@types/js-yaml",
@@ -341,7 +349,7 @@
"from": "1.1.1",
"to": "1.1.1",
"strategy": "installed",
"generatedAt": "2025-11-20T08:20:16.313Z"
"generatedAt": "2025-11-19T12:28:39.510Z"
},
{
"section": "dependencies",
@@ -351,6 +359,14 @@
"strategy": "installed",
"generatedAt": "2025-10-22T12:36:37.962Z"
},
{
"section": "dependencies",
"name": "dagre",
"from": "0.8.5",
"to": "0.8.5",
"strategy": "installed",
"generatedAt": "2025-11-27T11:47:22.908Z"
},
{
"section": "dependencies",
"name": "date-fns",
+14
View File
@@ -1,6 +1,7 @@
import {
CloudCog,
Cog,
GitBranch,
Group,
Mail,
MessageCircleQuestion,
@@ -66,6 +67,19 @@ export const getMenuList = ({ pathname }: MenuListOptions): GroupProps[] => {
},
],
},
{
groupLabel: "",
menus: [
{
href: "/attack-paths",
label: "Attack Paths",
icon: GitBranch,
active: pathname.startsWith("/attack-paths"),
highlight: true,
},
],
},
{
groupLabel: "",
menus: [
+2
View File
@@ -61,6 +61,7 @@
"@tailwindcss/postcss": "4.1.13",
"@tailwindcss/typography": "0.5.16",
"@tanstack/react-table": "8.21.3",
"@types/dagre": "0.7.53",
"@types/js-yaml": "4.0.9",
"ai": "5.0.109",
"alert": "6.0.2",
@@ -68,6 +69,7 @@
"clsx": "2.1.1",
"cmdk": "1.1.1",
"d3": "7.9.0",
"dagre": "0.8.5",
"date-fns": "4.1.0",
"framer-motion": "11.18.2",
"import-in-the-middle": "2.0.0",
+26
View File
@@ -129,6 +129,9 @@ importers:
'@tanstack/react-table':
specifier: 8.21.3
version: 8.21.3(react-dom@19.2.2(react@19.2.2))(react@19.2.2)
'@types/dagre':
specifier: 0.7.53
version: 0.7.53
'@types/js-yaml':
specifier: 4.0.9
version: 4.0.9
@@ -150,6 +153,9 @@ importers:
d3:
specifier: 7.9.0
version: 7.9.0
dagre:
specifier: 0.8.5
version: 0.8.5
date-fns:
specifier: 4.1.0
version: 4.1.0
@@ -4295,6 +4301,9 @@ packages:
'@types/d3@7.4.3':
resolution: {integrity: sha512-lZXZ9ckh5R8uiFVt8ogUNf+pIrK4EsWrx2Np75WvF/eTpJ0FMHNhjXk8CKEx/+gpHbNQyJWehbFaTvqmHWB3ww==}
'@types/dagre@0.7.53':
resolution: {integrity: sha512-f4gkWqzPZvYmKhOsDnhq/R8mO4UMcKdxZo+i5SCkOU1wvGeHJeUXGIHeE9pnwGyPMDof1Vx5ZQo4nxpeg2TTVQ==}
'@types/debug@4.1.12':
resolution: {integrity: sha512-vIChWdVG3LG1SMxEvI/AK+FWJthlrqlTu7fbrlywTkkaONwk/UAGaULXRlf8vkzFBLVm0zkMdCquhL5aOjhXPQ==}
@@ -5205,6 +5214,9 @@ packages:
dagre-d3-es@7.0.13:
resolution: {integrity: sha512-efEhnxpSuwpYOKRm/L5KbqoZmNNukHa/Flty4Wp62JRvgH2ojwVgPgdYyr4twpieZnyRDdIH7PY2mopX26+j2Q==}
dagre@0.8.5:
resolution: {integrity: sha512-/aTqmnRta7x7MCCpExk7HQL2O4owCT2h8NT//9I1OQ9vt29Pa0BzSAkR5lwFUcQ7491yVi/3CXU9jQ5o0Mn2Sw==}
damerau-levenshtein@1.0.8:
resolution: {integrity: sha512-sdQSFB7+llfUcQHUQO3+B8ERRj0Oa4w9POWMI/puGtuf7gFywGmkaLCElnudfTiKZV+NvHqL0ifzdrI8Ro7ESA==}
@@ -5932,6 +5944,9 @@ packages:
graphemer@1.4.0:
resolution: {integrity: sha512-EtKwoO6kxCL9WO5xipiHTZlSzBm7WLT627TqC/uVRd0HKmq8NXyebnNYxDoBi7wt8eTWrUrKXCOVaFq9x1kgag==}
graphlib@2.1.8:
resolution: {integrity: sha512-jcLLfkpoVGmH7/InMC/1hIvOPSUh38oJtGhvrOFGzioE1DZ+0YW16RgmOJhHiuWTvGiJQ9Z1Ik43JvkRPRvE+A==}
graphql@16.12.0:
resolution: {integrity: sha512-DKKrynuQRne0PNpEbzuEdHlYOMksHSUI8Zc9Unei5gTsMNA2/vMpoMz/yKba50pejK56qj98qM0SjYxAKi13gQ==}
engines: {node: ^12.22.0 || ^14.16.0 || ^16.0.0 || >=17.0.0}
@@ -14010,6 +14025,8 @@ snapshots:
'@types/d3-transition': 3.0.9
'@types/d3-zoom': 3.0.8
'@types/dagre@0.7.53': {}
'@types/debug@4.1.12':
dependencies:
'@types/ms': 2.1.0
@@ -14963,6 +14980,11 @@ snapshots:
d3: 7.9.0
lodash-es: 4.17.21
dagre@0.8.5:
dependencies:
graphlib: 2.1.8
lodash: 4.17.21
damerau-levenshtein@1.0.8: {}
data-uri-to-buffer@4.0.1: {}
@@ -15842,6 +15864,10 @@ snapshots:
graphemer@1.4.0: {}
graphlib@2.1.8:
dependencies:
lodash: 4.17.21
graphql@16.12.0: {}
hachure-fill@0.5.2: {}
+3
View File
@@ -55,6 +55,7 @@
--bg-pass-primary: var(--color-emerald-400);
--bg-pass-secondary: var(--color-emerald-50);
--bg-warning-primary: var(--color-orange-500);
--bg-warning-secondary: var(--color-orange-50);
--bg-fail-primary: var(--color-rose-500);
--bg-fail-secondary: var(--color-rose-50);
@@ -129,6 +130,7 @@
--bg-pass-primary: var(--color-green-400);
--bg-pass-secondary: var(--color-emerald-900);
--bg-warning-primary: var(--color-orange-400);
--bg-warning-secondary: var(--color-orange-900);
--bg-fail-primary: var(--color-rose-500);
--bg-fail-secondary: #432232;
@@ -220,6 +222,7 @@
--color-bg-pass: var(--bg-pass-primary);
--color-bg-pass-secondary: var(--bg-pass-secondary);
--color-bg-warning: var(--bg-warning-primary);
--color-bg-warning-secondary: var(--bg-warning-secondary);
--color-bg-fail: var(--bg-fail-primary);
--color-bg-fail-secondary: var(--bg-fail-secondary);
+245
View File
@@ -0,0 +1,245 @@
/**
 * Attack Paths Feature Types
 * Defines all TypeScript interfaces for the Attack Paths visualization feature
 */

// Scan state constants
// Lifecycle states reported by the API for an attack-paths scan.
// NOTE(review): ordering/transitions (available -> scheduled -> executing ->
// completed | failed) inferred from the names — confirm against the API.
export const SCAN_STATES = {
  AVAILABLE: "available",
  SCHEDULED: "scheduled",
  EXECUTING: "executing",
  COMPLETED: "completed",
  FAILED: "failed",
} as const;

// Union of the literal state values: "available" | "scheduled" | ...
export type ScanState = (typeof SCAN_STATES)[keyof typeof SCAN_STATES];
// Attack Path Scan - Relationship Data
// JSON:API resource-identifier shapes used in a scan's `relationships` member.

// A single resource identifier: JSON:API resource type and id.
export interface RelationshipData {
  type: string;
  id: string;
}

// JSON:API relationship object wrapping one resource identifier.
export interface RelationshipWrapper {
  data: RelationshipData;
}

// Relationships attached to an attack-paths scan: the cloud provider,
// the related scan, and the task — presumably the background job running
// the scan; confirm against the API.
export interface ScanRelationships {
  provider: RelationshipWrapper;
  scan: RelationshipWrapper;
  task: RelationshipWrapper;
}
// Provider type constants
// Cloud providers the attack-paths feature knows about.
export const PROVIDER_TYPES = {
  AWS: "aws",
  AZURE: "azure",
  GCP: "gcp",
} as const;

// Union of the provider identifiers: "aws" | "azure" | "gcp"
export type ProviderType = (typeof PROVIDER_TYPES)[keyof typeof PROVIDER_TYPES];
// Attack Path Scan Response

// Attributes of a single attack-paths scan resource.
export interface AttackPathScanAttributes {
  state: ScanState;
  progress: number; // presumably a 0-100 percentage — confirm with the API
  provider_alias: string;
  provider_type: ProviderType;
  provider_uid: string;
  inserted_at: string; // timestamp — presumably ISO-8601; confirm with the API
  started_at: string;
  completed_at: string | null; // null until the scan finishes
  duration: number | null; // null while running; units not visible here — confirm
}

// JSON:API resource for an attack-paths scan.
export interface AttackPathScan {
  type: "attack-paths-scans";
  id: string;
  attributes: AttackPathScanAttributes;
  relationships: ScanRelationships;
}

// JSON:API pagination links; next/prev are nullable (e.g. at list
// boundaries — confirm).
export interface PaginationLinks {
  first: string;
  last: string;
  next: string | null;
  prev: string | null;
}

// Paginated list response for attack-paths scans.
export interface AttackPathScansResponse {
  data: AttackPathScan[];
  links: PaginationLinks;
}
// Data type constants
// Primitive value kinds a query parameter can accept.
// Exported (together with DataType below) because the exported
// AttackPathQueryParameter interface references DataType — without the
// exports, consumers could not name the type or its values.
export const DATA_TYPES = {
  STRING: "string",
  NUMBER: "number",
  BOOLEAN: "boolean",
} as const;

// Union of the literal value kinds: "string" | "number" | "boolean"
export type DataType = (typeof DATA_TYPES)[keyof typeof DATA_TYPES];

// Query Types

// A single user-suppliable parameter of a predefined attack-path query.
export interface AttackPathQueryParameter {
  name: string;
  label: string;
  data_type: DataType;
  description: string;
  placeholder?: string; // optional UI input placeholder
  required?: boolean;
}

// Metadata describing a predefined query and its parameters.
export interface AttackPathQueryAttributes {
  name: string;
  description: string;
  provider: string;
  parameters: AttackPathQueryParameter[];
}

// JSON:API resource for a predefined query.
// NOTE(review): the `type` literal is identical to AttackPathScan's
// ("attack-paths-scans") — verify against the API spec; queries may use a
// distinct resource type.
export interface AttackPathQuery {
  type: "attack-paths-scans";
  id: string;
  attributes: AttackPathQueryAttributes;
}

// List response for predefined queries.
export interface AttackPathQueriesResponse {
  data: AttackPathQuery[];
}
// Graph Data Types

// Property values from graph nodes can be any primitive type or arrays
export type GraphNodePropertyValue =
  | string
  | number
  | boolean
  | null
  | undefined
  | string[]
  | number[];

// Open property bag attached to graph nodes and edges.
export interface GraphNodeProperties {
  [key: string]: GraphNodePropertyValue;
}

// A node in the attack-path graph.
export interface GraphNode {
  id: string;
  labels: string[]; // e.g., ["S3Bucket"], ["EC2Instance"], ["ProwlerFinding"]
  properties: GraphNodeProperties;
  findings?: string[]; // IDs of finding nodes connected via HAS_FINDING edges
  resources?: string[]; // IDs of resource nodes connected via HAS_FINDING edges
}

// An edge between two graph nodes.
// NOTE(review): source/target are `string | object` — presumably because a
// layout library (d3/dagre) replaces id strings with node object references
// in place; confirm, and narrow the `object` side to GraphNode if so.
export interface GraphEdge {
  id: string;
  source: string | object;
  target: string | object;
  type: string;
  properties?: GraphNodeProperties;
}

// Relationship variant whose source/target are always string ids
// (contrast GraphEdge above).
export interface GraphRelationship {
  id: string;
  label: string;
  source: string;
  target: string;
  properties?: GraphNodeProperties;
}

// Complete graph payload; edges and relationships are both optional.
export interface AttackPathGraphData {
  nodes: GraphNode[];
  edges?: GraphEdge[];
  relationships?: GraphRelationship[];
}
// Result payload of executing an attack-path query.
export interface QueryResultAttributes {
  nodes: GraphNode[];
  relationships?: GraphRelationship[];
}

// JSON:API resource wrapping query results.
export interface QueryResultData {
  type: "attack-paths-query-run-request";
  id: string | null; // nullable — presumably the run request is not persisted; confirm
  attributes: QueryResultAttributes;
}

// Top-level response for a query execution.
export interface AttackPathQueryResult {
  data: QueryResultData;
}
// Finding severity and status constants
// Exported (with their derived union types) because the exported
// RelatedFinding interface references FindingSeverity/FindingStatus —
// without the exports, consumers could not name these types or values.
export const FINDING_SEVERITIES = {
  CRITICAL: "critical",
  HIGH: "high",
  MEDIUM: "medium",
  LOW: "low",
  INFO: "info",
} as const;

export type FindingSeverity =
  (typeof FINDING_SEVERITIES)[keyof typeof FINDING_SEVERITIES];

// Note: status values are upper-case on the wire, unlike severities.
export const FINDING_STATUSES = {
  PASS: "PASS",
  FAIL: "FAIL",
  MANUAL: "MANUAL",
} as const;

export type FindingStatus = (typeof FINDING_STATUSES)[keyof typeof FINDING_STATUSES];

// A Prowler finding associated with a graph node.
export interface RelatedFinding {
  id: string;
  title: string;
  severity: FindingSeverity;
  status: FindingStatus;
}
// Node Detail Types

// GraphNode enriched with the extra data shown in a node detail view.
export interface NodeDetailData extends GraphNode {
  relatedFindings?: RelatedFinding[];
  incomingEdges?: GraphEdge[]; // presumably edges whose target is this node — confirm
  outgoingEdges?: GraphEdge[]; // presumably edges whose source is this node — confirm
}
// Wizard State Types

// State of the two-step attack-paths wizard.
// NOTE(review): step semantics (1: pick scan, 2: pick query) inferred from
// the field names — confirm in the wizard component.
export interface WizardState {
  currentStep: 1 | 2;
  selectedScanId: string | null; // null until a scan is chosen
  selectedQuery: string | null; // null until a query is chosen
  queryParameters: Record<string, string | number | boolean>;
}

// Graph State Types

// UI state for the rendered graph: data, selection, load status, viewport.
export interface GraphState {
  data: AttackPathGraphData | null; // null before a query has been executed
  selectedNodeId: string | null; // null when no node is selected
  loading: boolean;
  error: string | null;
  zoomLevel: number;
  panX: number; // viewport pan offsets
  panY: number;
}
// Provider Integration

// A cloud provider joined with an attack-paths scan and connection status.
// NOTE(review): presumably `scan` is the provider's latest scan — confirm
// where this shape is assembled.
export interface ProviderWithScanStatus {
  id: string;
  alias: string;
  provider: string;
  scan: AttackPathScan;
  connected: boolean;
}

// API Request/Response Helpers

// Attributes for a query-run request: the query id plus optional
// user-supplied parameter values.
export interface QueryRequestAttributes {
  id: string;
  parameters?: Record<string, string | number | boolean>;
}

// JSON:API request resource for executing a query.
export interface ExecuteQueryRequestData {
  type: "attack-paths-query-run-request";
  attributes: QueryRequestAttributes;
}

// Top-level request body for the execute-query endpoint.
export interface ExecuteQueryRequest {
  data: ExecuteQueryRequestData;
}
+1
View File
@@ -33,6 +33,7 @@ export type MenuProps = {
defaultOpen?: boolean;
target?: string;
tooltip?: string;
highlight?: boolean;
};
export type GroupProps = {