mirror of
https://github.com/prowler-cloud/prowler.git
synced 2026-05-06 16:58:19 +00:00
Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ab5c70a62a | |||
| 42d692119b | |||
| b2cc15d1cf | |||
| 2c9efdc2ca | |||
| b5f6ae5264 | |||
| bc3cdd492a | |||
| 0366325539 | |||
| 071d6476e2 | |||
| 8573efc53b |
@@ -2,6 +2,24 @@
|
||||
|
||||
All notable changes to the **Prowler API** are documented in this file.
|
||||
|
||||
## [1.21.0] (Prowler UNRELEASED)
|
||||
|
||||
### 🔄 Changed
|
||||
|
||||
- `POST /api/v1/providers` returns `409 Conflict` if already exists [(#10293)](https://github.com/prowler-cloud/prowler/pull/10293)
|
||||
|
||||
---
|
||||
|
||||
## [1.20.1] (Prowler UNRELEASED)
|
||||
|
||||
### 🐞 Fixed
|
||||
|
||||
- Attack Paths: Security hardening for custom query endpoint (Cypher blocklist, input validation, rate limiting, Helm lockdown) [(#10238)](https://github.com/prowler-cloud/prowler/pull/10238)
|
||||
- Attack Paths: Add missing logging for query execution and exception details in scan error handling [(#10269)](https://github.com/prowler-cloud/prowler/pull/10269)
|
||||
- Attack Paths: Upgrade Cartography from 0.129.0 to 0.132.0, fixing `exposed_internet` not set on ELB/ELBv2 nodes [(#10272)](https://github.com/prowler-cloud/prowler/pull/10272)
|
||||
|
||||
---
|
||||
|
||||
## [1.20.0] (Prowler v5.19.0)
|
||||
|
||||
### 🚀 Added
|
||||
|
||||
@@ -24,13 +24,6 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
python3-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Cartography depends on `dockerfile` which has no pre-built arm64 wheel and requires Go to compile
|
||||
# hadolint ignore=DL3008
|
||||
RUN if [ "$(uname -m)" = "aarch64" ]; then \
|
||||
apt-get update && apt-get install -y --no-install-recommends golang-go \
|
||||
&& rm -rf /var/lib/apt/lists/* ; \
|
||||
fi
|
||||
|
||||
# Install PowerShell
|
||||
RUN ARCH=$(uname -m) && \
|
||||
if [ "$ARCH" = "x86_64" ]; then \
|
||||
|
||||
Generated
+387
-32
File diff suppressed because it is too large
Load Diff
+3
-3
@@ -24,7 +24,7 @@ dependencies = [
|
||||
"drf-spectacular-jsonapi==0.5.1",
|
||||
"gunicorn==23.0.0",
|
||||
"lxml==5.3.2",
|
||||
"prowler @ git+https://github.com/prowler-cloud/prowler.git@master",
|
||||
"prowler @ git+https://github.com/prowler-cloud/prowler.git@v5.19",
|
||||
"psycopg2-binary==2.9.9",
|
||||
"pytest-celery[redis] (>=1.0.1,<2.0.0)",
|
||||
"sentry-sdk[django] (>=2.20.0,<3.0.0)",
|
||||
@@ -37,7 +37,7 @@ dependencies = [
|
||||
"matplotlib (>=3.10.6,<4.0.0)",
|
||||
"reportlab (>=4.4.4,<5.0.0)",
|
||||
"neo4j (>=6.0.0,<7.0.0)",
|
||||
"cartography (==0.129.0)",
|
||||
"cartography (==0.132.0)",
|
||||
"gevent (>=25.9.1,<26.0.0)",
|
||||
"werkzeug (>=3.1.4)",
|
||||
"sqlparse (>=0.5.4)",
|
||||
@@ -49,7 +49,7 @@ name = "prowler-api"
|
||||
package-mode = false
|
||||
# Needed for the SDK compatibility
|
||||
requires-python = ">=3.11,<3.13"
|
||||
version = "1.20.0"
|
||||
version = "1.20.1"
|
||||
|
||||
[project.scripts]
|
||||
celery = "src.backend.config.settings.celery"
|
||||
|
||||
@@ -35,6 +35,7 @@ READ_EXCEPTION_CODES = [
|
||||
"Neo.ClientError.Statement.AccessMode",
|
||||
"Neo.ClientError.Procedure.ProcedureNotFound",
|
||||
]
|
||||
CLIENT_STATEMENT_EXCEPTION_PREFIX = "Neo.ClientError.Statement."
|
||||
|
||||
# Module-level process-wide driver singleton
|
||||
_driver: neo4j.Driver | None = None
|
||||
@@ -108,6 +109,7 @@ def get_session(
|
||||
except neo4j.exceptions.Neo4jError as exc:
|
||||
if (
|
||||
default_access_mode == neo4j.READ_ACCESS
|
||||
and exc.code
|
||||
and exc.code in READ_EXCEPTION_CODES
|
||||
):
|
||||
message = "Read query not allowed"
|
||||
@@ -115,6 +117,10 @@ def get_session(
|
||||
raise WriteQueryNotAllowedException(message=message, code=code)
|
||||
|
||||
message = exc.message if exc.message is not None else str(exc)
|
||||
|
||||
if exc.code and exc.code.startswith(CLIENT_STATEMENT_EXCEPTION_PREFIX):
|
||||
raise ClientStatementException(message=message, code=exc.code)
|
||||
|
||||
raise GraphDatabaseQueryException(message=message, code=exc.code)
|
||||
|
||||
finally:
|
||||
@@ -227,3 +233,7 @@ class GraphDatabaseQueryException(Exception):
|
||||
|
||||
class WriteQueryNotAllowedException(GraphDatabaseQueryException):
|
||||
pass
|
||||
|
||||
|
||||
class ClientStatementException(GraphDatabaseQueryException):
|
||||
pass
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import re
|
||||
|
||||
from typing import Any, Iterable
|
||||
|
||||
@@ -117,6 +118,38 @@ def execute_query(
|
||||
|
||||
# Custom query helpers
|
||||
|
||||
# Patterns that indicate SSRF or dangerous procedure calls
|
||||
# Defense-in-depth layer - the primary control is `neo4j.READ_ACCESS`
|
||||
_BLOCKED_PATTERNS = [
|
||||
re.compile(r"\bLOAD\s+CSV\b", re.IGNORECASE),
|
||||
re.compile(r"\bapoc\.load\b", re.IGNORECASE),
|
||||
re.compile(r"\bapoc\.import\b", re.IGNORECASE),
|
||||
re.compile(r"\bapoc\.export\b", re.IGNORECASE),
|
||||
re.compile(r"\bapoc\.cypher\b", re.IGNORECASE),
|
||||
re.compile(r"\bapoc\.systemdb\b", re.IGNORECASE),
|
||||
re.compile(r"\bapoc\.config\b", re.IGNORECASE),
|
||||
re.compile(r"\bapoc\.periodic\b", re.IGNORECASE),
|
||||
re.compile(r"\bapoc\.do\b", re.IGNORECASE),
|
||||
re.compile(r"\bapoc\.trigger\b", re.IGNORECASE),
|
||||
re.compile(r"\bapoc\.custom\b", re.IGNORECASE),
|
||||
]
|
||||
|
||||
# Strip string literals so patterns inside quotes don't cause false positives
|
||||
# Handles escaped quotes (\' and \") inside strings
|
||||
_STRING_LITERALS = re.compile(r"'(?:[^'\\]|\\.)*'|\"(?:[^\"\\]|\\.)*\"")
|
||||
|
||||
|
||||
def validate_custom_query(cypher: str) -> None:
|
||||
"""Reject queries containing known SSRF or dangerous procedure patterns.
|
||||
|
||||
Raises ValidationError if a blocked pattern is found.
|
||||
String literals are stripped before matching to avoid false positives.
|
||||
"""
|
||||
stripped = _STRING_LITERALS.sub("", cypher)
|
||||
for pattern in _BLOCKED_PATTERNS:
|
||||
if pattern.search(stripped):
|
||||
raise ValidationError({"query": "Query contains a blocked operation"})
|
||||
|
||||
|
||||
def normalize_custom_query_payload(raw_data):
|
||||
if not isinstance(raw_data, dict):
|
||||
@@ -135,6 +168,8 @@ def execute_custom_query(
|
||||
cypher: str,
|
||||
provider_id: str,
|
||||
) -> dict[str, Any]:
|
||||
validate_custom_query(cypher)
|
||||
|
||||
try:
|
||||
graph = graph_database.execute_read_query(
|
||||
database=database_name,
|
||||
@@ -143,6 +178,9 @@ def execute_custom_query(
|
||||
serialized = _serialize_graph(graph, provider_id)
|
||||
return _truncate_graph(serialized)
|
||||
|
||||
except graph_database.ClientStatementException as exc:
|
||||
raise ValidationError({"query": exc.message})
|
||||
|
||||
except graph_database.WriteQueryNotAllowedException:
|
||||
raise PermissionDenied(
|
||||
"Attack Paths query execution failed: read-only queries are enforced"
|
||||
@@ -227,6 +265,12 @@ def _serialize_graph(graph, provider_id: str) -> dict[str, Any]:
|
||||
},
|
||||
)
|
||||
|
||||
filtered_count = len(graph.nodes) - len(nodes)
|
||||
if filtered_count > 0:
|
||||
logger.debug(
|
||||
f"Filtered {filtered_count} nodes without matching provider_id={provider_id}"
|
||||
)
|
||||
|
||||
relationships = []
|
||||
for relationship in graph.relationships:
|
||||
if relationship._properties.get("provider_id") != provider_id:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
openapi: 3.0.3
|
||||
info:
|
||||
title: Prowler API
|
||||
version: 1.20.0
|
||||
version: 1.20.1
|
||||
description: |-
|
||||
Prowler API specification.
|
||||
|
||||
|
||||
@@ -501,6 +501,72 @@ def test_execute_custom_query_wraps_graph_errors():
|
||||
mock_logger.error.assert_called_once()
|
||||
|
||||
|
||||
# -- validate_custom_query ------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cypher",
|
||||
[
|
||||
"LOAD CSV FROM 'http://169.254.169.254/' AS x RETURN x",
|
||||
"load csv from 'http://evil.com' as row return row",
|
||||
"CALL apoc.load.json('http://evil.com/') YIELD value RETURN value",
|
||||
"CALL apoc.load.csvParams('http://evil.com/', {}, null) YIELD list RETURN list",
|
||||
"CALL apoc.import.csv([{fileName: 'f'}], [], {}) YIELD node RETURN node",
|
||||
"CALL apoc.export.csv.all('file.csv', {})",
|
||||
"CALL apoc.cypher.run('CREATE (n)', {}) YIELD value RETURN value",
|
||||
"CALL apoc.systemdb.graph() YIELD nodes RETURN nodes",
|
||||
"CALL apoc.config.list() YIELD key, value RETURN key, value",
|
||||
"CALL apoc.periodic.iterate('MATCH (n) RETURN n', 'DELETE n', {batchSize: 100})",
|
||||
"CALL apoc.do.when(true, 'CREATE (n) RETURN n', '', {}) YIELD value RETURN value",
|
||||
"CALL apoc.trigger.add('t', 'RETURN 1', {phase: 'before'})",
|
||||
"CALL apoc.custom.asProcedure('myProc', 'RETURN 1')",
|
||||
],
|
||||
ids=[
|
||||
"LOAD_CSV",
|
||||
"LOAD_CSV_lowercase",
|
||||
"apoc.load.json",
|
||||
"apoc.load.csvParams",
|
||||
"apoc.import.csv",
|
||||
"apoc.export.csv",
|
||||
"apoc.cypher.run",
|
||||
"apoc.systemdb.graph",
|
||||
"apoc.config.list",
|
||||
"apoc.periodic.iterate",
|
||||
"apoc.do.when",
|
||||
"apoc.trigger.add",
|
||||
"apoc.custom.asProcedure",
|
||||
],
|
||||
)
|
||||
def test_validate_custom_query_rejects_blocked_patterns(cypher):
|
||||
with pytest.raises(ValidationError) as exc:
|
||||
views_helpers.validate_custom_query(cypher)
|
||||
|
||||
assert "blocked operation" in str(exc.value.detail)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cypher",
|
||||
[
|
||||
"MATCH (n:AWSAccount) RETURN n LIMIT 10",
|
||||
"MATCH (a)-[r]->(b) RETURN a, r, b",
|
||||
"MATCH (n) WHERE n.name CONTAINS 'load' RETURN n",
|
||||
"CALL apoc.create.vNode(['Label'], {}) YIELD node RETURN node",
|
||||
"MATCH (n) WHERE n.name = 'apoc.load.json' RETURN n",
|
||||
'MATCH (n) WHERE n.description = "LOAD CSV is cool" RETURN n',
|
||||
],
|
||||
ids=[
|
||||
"simple_match",
|
||||
"traversal",
|
||||
"contains_load_substring",
|
||||
"apoc_virtual_node",
|
||||
"apoc_load_inside_single_quotes",
|
||||
"load_csv_inside_double_quotes",
|
||||
],
|
||||
)
|
||||
def test_validate_custom_query_allows_clean_queries(cypher):
|
||||
views_helpers.validate_custom_query(cypher)
|
||||
|
||||
|
||||
# -- _truncate_graph ----------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
@@ -3747,6 +3747,12 @@ class TestTaskViewSet:
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestAttackPathsScanViewSet:
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_throttle_cache(self):
|
||||
from django.core.cache import cache
|
||||
|
||||
cache.clear()
|
||||
|
||||
@staticmethod
|
||||
def _run_payload(query_id="aws-rds", parameters=None):
|
||||
return {
|
||||
@@ -4348,8 +4354,6 @@ class TestAttackPathsScanViewSet:
|
||||
}
|
||||
}
|
||||
|
||||
# TODO: Remove skip once queries/custom and schema endpoints are unblocked
|
||||
@pytest.mark.skip(reason="Endpoint temporarily blocked")
|
||||
def test_run_custom_query_returns_graph(
|
||||
self,
|
||||
authenticated_client,
|
||||
@@ -4407,7 +4411,6 @@ class TestAttackPathsScanViewSet:
|
||||
assert attributes["total_nodes"] == 1
|
||||
assert attributes["truncated"] is False
|
||||
|
||||
@pytest.mark.skip(reason="Endpoint temporarily blocked")
|
||||
def test_run_custom_query_returns_text_when_accept_text_plain(
|
||||
self,
|
||||
authenticated_client,
|
||||
@@ -4462,7 +4465,6 @@ class TestAttackPathsScanViewSet:
|
||||
assert "## Relationships (0)" in body
|
||||
assert "## Summary" in body
|
||||
|
||||
@pytest.mark.skip(reason="Endpoint temporarily blocked")
|
||||
def test_run_custom_query_returns_404_when_no_nodes(
|
||||
self,
|
||||
authenticated_client,
|
||||
@@ -4504,7 +4506,6 @@ class TestAttackPathsScanViewSet:
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
@pytest.mark.skip(reason="Endpoint temporarily blocked")
|
||||
def test_run_custom_query_returns_400_when_graph_not_ready(
|
||||
self,
|
||||
authenticated_client,
|
||||
@@ -4531,7 +4532,6 @@ class TestAttackPathsScanViewSet:
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert "not available" in response.json()["errors"][0]["detail"]
|
||||
|
||||
@pytest.mark.skip(reason="Endpoint temporarily blocked")
|
||||
def test_run_custom_query_returns_403_for_write_query(
|
||||
self,
|
||||
authenticated_client,
|
||||
@@ -4569,9 +4569,343 @@ class TestAttackPathsScanViewSet:
|
||||
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
# -- SSRF blocklist (HTTP level) ----------------------------------------------
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cypher",
|
||||
[
|
||||
"LOAD CSV FROM 'http://169.254.169.254/' AS x RETURN x",
|
||||
"CALL apoc.load.json('http://evil.com/') YIELD value RETURN value",
|
||||
"CALL apoc.import.csv([{fileName: 'f'}], [], {}) YIELD node RETURN node",
|
||||
"CALL apoc.export.csv.all('file.csv', {})",
|
||||
"CALL apoc.cypher.run('CREATE (n)', {}) YIELD value RETURN value",
|
||||
"CALL apoc.systemdb.graph() YIELD nodes RETURN nodes",
|
||||
],
|
||||
ids=[
|
||||
"LOAD_CSV",
|
||||
"apoc.load",
|
||||
"apoc.import",
|
||||
"apoc.export",
|
||||
"apoc.cypher.run",
|
||||
"apoc.systemdb",
|
||||
],
|
||||
)
|
||||
def test_run_custom_query_rejects_ssrf_patterns(
|
||||
self,
|
||||
authenticated_client,
|
||||
providers_fixture,
|
||||
scans_fixture,
|
||||
create_attack_paths_scan,
|
||||
cypher,
|
||||
):
|
||||
provider = providers_fixture[0]
|
||||
attack_paths_scan = create_attack_paths_scan(
|
||||
provider,
|
||||
scan=scans_fixture[0],
|
||||
graph_data_ready=True,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"api.v1.views.graph_database.get_database_name",
|
||||
return_value="db-test",
|
||||
):
|
||||
response = authenticated_client.post(
|
||||
reverse(
|
||||
"attack-paths-scans-queries-custom",
|
||||
kwargs={"pk": attack_paths_scan.id},
|
||||
),
|
||||
data=self._custom_query_payload(cypher),
|
||||
content_type=API_JSON_CONTENT_TYPE,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert "blocked" in response.json()["errors"][0]["detail"].lower()
|
||||
|
||||
# -- Cross-tenant isolation ---------------------------------------------------
|
||||
|
||||
def test_run_custom_query_returns_404_for_foreign_tenant(
|
||||
self,
|
||||
authenticated_client,
|
||||
create_attack_paths_scan,
|
||||
):
|
||||
from api.models import Provider, Tenant
|
||||
|
||||
foreign_tenant = Tenant.objects.create(name="foreign-tenant")
|
||||
foreign_provider = Provider.objects.create(
|
||||
tenant=foreign_tenant,
|
||||
provider="aws",
|
||||
uid="123456789999",
|
||||
)
|
||||
attack_paths_scan = create_attack_paths_scan(
|
||||
foreign_provider,
|
||||
graph_data_ready=True,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"api.v1.views.graph_database.get_database_name",
|
||||
return_value="db-test",
|
||||
):
|
||||
response = authenticated_client.post(
|
||||
reverse(
|
||||
"attack-paths-scans-queries-custom",
|
||||
kwargs={"pk": attack_paths_scan.id},
|
||||
),
|
||||
data=self._custom_query_payload(),
|
||||
content_type=API_JSON_CONTENT_TYPE,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
def test_cartography_schema_returns_404_for_foreign_tenant(
|
||||
self,
|
||||
authenticated_client,
|
||||
create_attack_paths_scan,
|
||||
):
|
||||
from api.models import Provider, Tenant
|
||||
|
||||
foreign_tenant = Tenant.objects.create(name="foreign-tenant-schema")
|
||||
foreign_provider = Provider.objects.create(
|
||||
tenant=foreign_tenant,
|
||||
provider="aws",
|
||||
uid="123456789998",
|
||||
)
|
||||
attack_paths_scan = create_attack_paths_scan(
|
||||
foreign_provider,
|
||||
graph_data_ready=True,
|
||||
)
|
||||
|
||||
response = authenticated_client.get(
|
||||
reverse(
|
||||
"attack-paths-scans-schema",
|
||||
kwargs={"pk": attack_paths_scan.id},
|
||||
)
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
# -- Authentication / authorization -------------------------------------------
|
||||
|
||||
def test_run_custom_query_returns_401_unauthenticated(
|
||||
self,
|
||||
providers_fixture,
|
||||
scans_fixture,
|
||||
create_attack_paths_scan,
|
||||
):
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
provider = providers_fixture[0]
|
||||
attack_paths_scan = create_attack_paths_scan(
|
||||
provider,
|
||||
scan=scans_fixture[0],
|
||||
graph_data_ready=True,
|
||||
)
|
||||
|
||||
unauthenticated = APIClient()
|
||||
response = unauthenticated.post(
|
||||
reverse(
|
||||
"attack-paths-scans-queries-custom",
|
||||
kwargs={"pk": attack_paths_scan.id},
|
||||
),
|
||||
data=self._custom_query_payload(),
|
||||
content_type=API_JSON_CONTENT_TYPE,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_cartography_schema_returns_401_unauthenticated(
|
||||
self,
|
||||
providers_fixture,
|
||||
scans_fixture,
|
||||
create_attack_paths_scan,
|
||||
):
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
provider = providers_fixture[0]
|
||||
attack_paths_scan = create_attack_paths_scan(
|
||||
provider,
|
||||
scan=scans_fixture[0],
|
||||
graph_data_ready=True,
|
||||
)
|
||||
|
||||
unauthenticated = APIClient()
|
||||
response = unauthenticated.get(
|
||||
reverse(
|
||||
"attack-paths-scans-schema",
|
||||
kwargs={"pk": attack_paths_scan.id},
|
||||
)
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_run_custom_query_returns_403_no_manage_scans(
|
||||
self,
|
||||
authenticated_client_no_permissions_rbac,
|
||||
providers_fixture,
|
||||
scans_fixture,
|
||||
create_attack_paths_scan,
|
||||
):
|
||||
provider = providers_fixture[0]
|
||||
attack_paths_scan = create_attack_paths_scan(
|
||||
provider,
|
||||
scan=scans_fixture[0],
|
||||
graph_data_ready=True,
|
||||
)
|
||||
|
||||
response = authenticated_client_no_permissions_rbac.post(
|
||||
reverse(
|
||||
"attack-paths-scans-queries-custom",
|
||||
kwargs={"pk": attack_paths_scan.id},
|
||||
),
|
||||
data=self._custom_query_payload(),
|
||||
content_type=API_JSON_CONTENT_TYPE,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
# -- Error leakage ------------------------------------------------------------
|
||||
|
||||
def test_run_custom_query_does_not_leak_internals_on_error(
|
||||
self,
|
||||
authenticated_client,
|
||||
providers_fixture,
|
||||
scans_fixture,
|
||||
create_attack_paths_scan,
|
||||
):
|
||||
from rest_framework.exceptions import APIException
|
||||
|
||||
provider = providers_fixture[0]
|
||||
attack_paths_scan = create_attack_paths_scan(
|
||||
provider,
|
||||
scan=scans_fixture[0],
|
||||
graph_data_ready=True,
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"api.v1.views.attack_paths_views_helpers.execute_custom_query",
|
||||
side_effect=APIException(
|
||||
"Attack Paths query execution failed due to a database error"
|
||||
),
|
||||
),
|
||||
patch(
|
||||
"api.v1.views.graph_database.get_database_name",
|
||||
return_value="db-test",
|
||||
),
|
||||
):
|
||||
response = authenticated_client.post(
|
||||
reverse(
|
||||
"attack-paths-scans-queries-custom",
|
||||
kwargs={"pk": attack_paths_scan.id},
|
||||
),
|
||||
data=self._custom_query_payload(),
|
||||
content_type=API_JSON_CONTENT_TYPE,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
body = json.dumps(response.json()).lower()
|
||||
for forbidden_term in ["neo4j", "bolt://", "syntaxerror", "db-tenant-"]:
|
||||
assert forbidden_term not in body
|
||||
|
||||
# -- Rate limiting (throttle) -------------------------------------------------
|
||||
|
||||
def test_run_custom_query_throttled_after_limit(
|
||||
self,
|
||||
authenticated_client,
|
||||
providers_fixture,
|
||||
scans_fixture,
|
||||
create_attack_paths_scan,
|
||||
):
|
||||
provider = providers_fixture[0]
|
||||
attack_paths_scan = create_attack_paths_scan(
|
||||
provider,
|
||||
scan=scans_fixture[0],
|
||||
graph_data_ready=True,
|
||||
)
|
||||
|
||||
mock_graph = {
|
||||
"nodes": [{"id": "n1", "labels": ["Test"], "properties": {}}],
|
||||
"relationships": [],
|
||||
"total_nodes": 1,
|
||||
"truncated": False,
|
||||
}
|
||||
|
||||
url = reverse(
|
||||
"attack-paths-scans-queries-custom",
|
||||
kwargs={"pk": attack_paths_scan.id},
|
||||
)
|
||||
payload = self._custom_query_payload()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"api.v1.views.attack_paths_views_helpers.execute_custom_query",
|
||||
return_value=mock_graph,
|
||||
),
|
||||
patch(
|
||||
"api.v1.views.graph_database.get_database_name",
|
||||
return_value="db-test",
|
||||
),
|
||||
patch(
|
||||
"api.v1.views.graph_database.clear_cache",
|
||||
),
|
||||
):
|
||||
for i in range(11):
|
||||
response = authenticated_client.post(
|
||||
url,
|
||||
data=payload,
|
||||
content_type=API_JSON_CONTENT_TYPE,
|
||||
)
|
||||
if i < 10:
|
||||
assert (
|
||||
response.status_code == status.HTTP_200_OK
|
||||
), f"Request {i + 1} should succeed with 200 OK, got {response.status_code}"
|
||||
else:
|
||||
assert (
|
||||
response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
), f"Request {i + 1} should be throttled"
|
||||
|
||||
# -- Timeout simulation -------------------------------------------------------
|
||||
|
||||
def test_run_custom_query_returns_500_on_database_timeout(
|
||||
self,
|
||||
authenticated_client,
|
||||
providers_fixture,
|
||||
scans_fixture,
|
||||
create_attack_paths_scan,
|
||||
):
|
||||
from rest_framework.exceptions import APIException
|
||||
|
||||
provider = providers_fixture[0]
|
||||
attack_paths_scan = create_attack_paths_scan(
|
||||
provider,
|
||||
scan=scans_fixture[0],
|
||||
graph_data_ready=True,
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"api.v1.views.attack_paths_views_helpers.execute_custom_query",
|
||||
side_effect=APIException(
|
||||
"Attack Paths query execution failed due to a database error"
|
||||
),
|
||||
),
|
||||
patch(
|
||||
"api.v1.views.graph_database.get_database_name",
|
||||
return_value="db-test",
|
||||
),
|
||||
):
|
||||
response = authenticated_client.post(
|
||||
reverse(
|
||||
"attack-paths-scans-queries-custom",
|
||||
kwargs={"pk": attack_paths_scan.id},
|
||||
),
|
||||
data=self._custom_query_payload(),
|
||||
content_type=API_JSON_CONTENT_TYPE,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
|
||||
# -- cartography_schema action ------------------------------------------------
|
||||
|
||||
@pytest.mark.skip(reason="Endpoint temporarily blocked")
|
||||
def test_cartography_schema_returns_urls(
|
||||
self,
|
||||
authenticated_client,
|
||||
@@ -4621,7 +4955,6 @@ class TestAttackPathsScanViewSet:
|
||||
assert "schema.md" in attributes["schema_url"]
|
||||
assert "raw.githubusercontent.com" in attributes["raw_schema_url"]
|
||||
|
||||
@pytest.mark.skip(reason="Endpoint temporarily blocked")
|
||||
def test_cartography_schema_returns_404_when_no_metadata(
|
||||
self,
|
||||
authenticated_client,
|
||||
@@ -4656,7 +4989,6 @@ class TestAttackPathsScanViewSet:
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert "No cartography schema metadata" in str(response.json())
|
||||
|
||||
@pytest.mark.skip(reason="Endpoint temporarily blocked")
|
||||
def test_cartography_schema_returns_400_when_graph_not_ready(
|
||||
self,
|
||||
authenticated_client,
|
||||
|
||||
@@ -1220,7 +1220,7 @@ class AttackPathsQueryRunRequestSerializer(BaseSerializerV1):
|
||||
|
||||
|
||||
class AttackPathsCustomQueryRunRequestSerializer(BaseSerializerV1):
|
||||
query = serializers.CharField()
|
||||
query = serializers.CharField(max_length=10000, min_length=1, trim_whitespace=True)
|
||||
|
||||
class JSONAPIMeta:
|
||||
resource_name = "attack-paths-custom-query-run-requests"
|
||||
|
||||
@@ -51,6 +51,13 @@ from api.v1.views import (
|
||||
)
|
||||
|
||||
|
||||
# This helper view is used to block any endpoints that should not be available
|
||||
# To use it, add a new entry in the `urlpatterns` list, for example (old but real one):
|
||||
# path(
|
||||
# "attack-paths-scans/<uuid:pk>/queries/custom",
|
||||
# _blocked_endpoint,
|
||||
# name="attack-paths-scans-queries-custom-blocked",
|
||||
# ),
|
||||
@csrf_exempt
|
||||
def _blocked_endpoint(request, *args, **kwargs):
|
||||
return JsonResponse(
|
||||
@@ -209,17 +216,6 @@ urlpatterns = [
|
||||
path("tokens/saml", SAMLTokenValidateView.as_view(), name="token-saml"),
|
||||
path("tokens/google", GoogleSocialLoginView.as_view(), name="token-google"),
|
||||
path("tokens/github", GithubSocialLoginView.as_view(), name="token-github"),
|
||||
# TODO: Remove these blocked endpoints once they are properly tested
|
||||
path(
|
||||
"attack-paths-scans/<uuid:pk>/queries/custom",
|
||||
_blocked_endpoint,
|
||||
name="attack-paths-scans-queries-custom-blocked",
|
||||
),
|
||||
path(
|
||||
"attack-paths-scans/<uuid:pk>/schema",
|
||||
_blocked_endpoint,
|
||||
name="attack-paths-scans-schema-blocked",
|
||||
),
|
||||
path("", include(router.urls)),
|
||||
path("", include(tenants_router.urls)),
|
||||
path("", include(users_router.urls)),
|
||||
|
||||
@@ -3,6 +3,7 @@ import glob
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
@@ -407,7 +408,7 @@ class SchemaView(SpectacularAPIView):
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
spectacular_settings.TITLE = "Prowler API"
|
||||
spectacular_settings.VERSION = "1.20.0"
|
||||
spectacular_settings.VERSION = "1.20.1"
|
||||
spectacular_settings.DESCRIPTION = (
|
||||
"Prowler API specification.\n\nThis file is auto-generated."
|
||||
)
|
||||
@@ -2451,6 +2452,11 @@ class AttackPathsScanViewSet(BaseRLSViewSet):
|
||||
# RBAC required permissions
|
||||
required_permissions = [Permissions.MANAGE_SCANS]
|
||||
|
||||
def get_throttles(self):
|
||||
if self.action == "run_custom_attack_paths_query":
|
||||
self.throttle_scope = "attack-paths-custom-query"
|
||||
return super().get_throttles()
|
||||
|
||||
def set_required_permissions(self):
|
||||
if self.request.method in SAFE_METHODS:
|
||||
self.required_permissions = []
|
||||
@@ -2570,14 +2576,35 @@ class AttackPathsScanViewSet(BaseRLSViewSet):
|
||||
provider_id,
|
||||
)
|
||||
|
||||
start = time.monotonic()
|
||||
graph = attack_paths_views_helpers.execute_query(
|
||||
database_name,
|
||||
query_definition,
|
||||
parameters,
|
||||
provider_id,
|
||||
)
|
||||
query_duration = time.monotonic() - start
|
||||
graph_database.clear_cache(database_name)
|
||||
|
||||
result_nodes = len(graph.get("nodes", []))
|
||||
result_relationships = len(graph.get("relationships", []))
|
||||
logger.info(
|
||||
"attack_paths_query_run",
|
||||
extra={
|
||||
"user_id": str(request.user.id),
|
||||
"tenant_id": str(attack_paths_scan.provider.tenant_id),
|
||||
"metadata": {
|
||||
"query_id": query_definition.id,
|
||||
"provider": query_definition.provider,
|
||||
"scan_id": pk,
|
||||
"provider_id": provider_id,
|
||||
"result_nodes": result_nodes,
|
||||
"result_relationships": result_relationships,
|
||||
"query_duration": round(query_duration, 3),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
status_code = status.HTTP_200_OK
|
||||
if not graph.get("nodes"):
|
||||
status_code = status.HTTP_404_NOT_FOUND
|
||||
@@ -2618,13 +2645,35 @@ class AttackPathsScanViewSet(BaseRLSViewSet):
|
||||
)
|
||||
provider_id = str(attack_paths_scan.provider_id)
|
||||
|
||||
start = time.monotonic()
|
||||
graph = attack_paths_views_helpers.execute_custom_query(
|
||||
database_name,
|
||||
serializer.validated_data["query"],
|
||||
provider_id,
|
||||
)
|
||||
query_duration = time.monotonic() - start
|
||||
graph_database.clear_cache(database_name)
|
||||
|
||||
query_length = len(serializer.validated_data["query"])
|
||||
result_nodes = len(graph.get("nodes", []))
|
||||
result_relationships = len(graph.get("relationships", []))
|
||||
logger.info(
|
||||
"attack_paths_custom_query_run",
|
||||
extra={
|
||||
"user_id": str(request.user.id),
|
||||
"tenant_id": str(attack_paths_scan.provider.tenant_id),
|
||||
"metadata": {
|
||||
"provider": attack_paths_scan.provider.provider,
|
||||
"scan_id": pk,
|
||||
"provider_id": provider_id,
|
||||
"query_length": query_length,
|
||||
"result_nodes": result_nodes,
|
||||
"result_relationships": result_relationships,
|
||||
"query_duration": round(query_duration, 3),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
status_code = status.HTTP_200_OK
|
||||
if not graph.get("nodes"):
|
||||
status_code = status.HTTP_404_NOT_FOUND
|
||||
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
import logging
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
from config.env import env
|
||||
from django_guid.log_filters import CorrelationId
|
||||
|
||||
@@ -62,6 +63,8 @@ class NDJSONFormatter(logging.Formatter):
|
||||
log_record["duration"] = record.duration
|
||||
if hasattr(record, "status_code"):
|
||||
log_record["status_code"] = record.status_code
|
||||
if hasattr(record, "metadata"):
|
||||
log_record["metadata"] = record.metadata
|
||||
|
||||
if record.exc_info:
|
||||
log_record["exc_info"] = self.formatException(record.exc_info)
|
||||
@@ -107,6 +110,8 @@ class HumanReadableFormatter(logging.Formatter):
|
||||
log_components.append(f"done in {record.duration}s:")
|
||||
if hasattr(record, "status_code"):
|
||||
log_components.append(f"{record.status_code}")
|
||||
if hasattr(record, "metadata"):
|
||||
log_components.append(f"metadata={record.metadata}")
|
||||
|
||||
if record.exc_info:
|
||||
log_components.append(self.formatException(record.exc_info))
|
||||
|
||||
@@ -113,8 +113,11 @@ REST_FRAMEWORK = {
|
||||
"rest_framework.throttling.ScopedRateThrottle",
|
||||
],
|
||||
"DEFAULT_THROTTLE_RATES": {
|
||||
"token-obtain": env("DJANGO_THROTTLE_TOKEN_OBTAIN", default=None),
|
||||
"dj_rest_auth": None,
|
||||
"token-obtain": env("DJANGO_THROTTLE_TOKEN_OBTAIN", default=None),
|
||||
"attack-paths-custom-query": env(
|
||||
"DJANGO_THROTTLE_ATTACK_PATHS_CUSTOM_QUERY", default="10/min"
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -43,6 +43,7 @@ def start_aws_ingestion(
|
||||
"aws_guardduty_severity_threshold": cartography_config.aws_guardduty_severity_threshold,
|
||||
"aws_cloudtrail_management_events_lookback_hours": cartography_config.aws_cloudtrail_management_events_lookback_hours,
|
||||
"experimental_aws_inspector_batch": cartography_config.experimental_aws_inspector_batch,
|
||||
"aws_tagging_api_cleanup_batch": cartography_config.aws_tagging_api_cleanup_batch,
|
||||
}
|
||||
|
||||
boto3_session = get_boto3_session(prowler_api_provider, prowler_sdk_provider)
|
||||
@@ -116,6 +117,30 @@ def start_aws_ingestion(
|
||||
neo4j_session,
|
||||
common_job_parameters,
|
||||
)
|
||||
|
||||
if all(
|
||||
s in requested_syncs
|
||||
for s in ["ecs", "ec2:load_balancer_v2", "ec2:load_balancer_v2:expose"]
|
||||
):
|
||||
logger.info(
|
||||
f"Syncing lb_container_exposure scoped analysis for AWS account {prowler_api_provider.uid}"
|
||||
)
|
||||
cartography_aws.run_scoped_analysis_job(
|
||||
"aws_lb_container_exposure.json",
|
||||
neo4j_session,
|
||||
common_job_parameters,
|
||||
)
|
||||
|
||||
if all(s in requested_syncs for s in ["ec2:network_acls", "ec2:load_balancer_v2"]):
|
||||
logger.info(
|
||||
f"Syncing lb_nacl_direct scoped analysis for AWS account {prowler_api_provider.uid}"
|
||||
)
|
||||
cartography_aws.run_scoped_analysis_job(
|
||||
"aws_lb_nacl_direct.json",
|
||||
neo4j_session,
|
||||
common_job_parameters,
|
||||
)
|
||||
|
||||
db_utils.update_attack_paths_scan_progress(attack_paths_scan, 91)
|
||||
|
||||
logger.info(f"Syncing metadata for AWS account {prowler_api_provider.uid}")
|
||||
@@ -239,8 +264,9 @@ def sync_aws_account(
|
||||
failed_syncs[func_name] = exception_message
|
||||
|
||||
logger.warning(
|
||||
f"Caught exception syncing function {func_name} from AWS account {prowler_api_provider.uid}. We "
|
||||
"are continuing on to the next AWS sync function.",
|
||||
f"Caught exception syncing function {func_name} from AWS account {prowler_api_provider.uid}: {e}. "
|
||||
"Continuing to the next AWS sync function.",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
continue
|
||||
|
||||
@@ -212,18 +212,20 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
|
||||
try:
|
||||
graph_database.drop_database(tmp_cartography_config.neo4j_database)
|
||||
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to drop temporary Neo4j database {tmp_cartography_config.neo4j_database} during cleanup"
|
||||
f"Failed to drop temporary Neo4j database {tmp_cartography_config.neo4j_database} during cleanup: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
try:
|
||||
db_utils.finish_attack_paths_scan(
|
||||
attack_paths_scan, StateChoices.FAILED, ingestion_exceptions
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"Could not mark attack paths scan {attack_paths_scan.id} as FAILED (row may have been deleted)"
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Could not mark attack paths scan {attack_paths_scan.id} as FAILED (row may have been deleted): {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
raise
|
||||
|
||||
@@ -558,9 +558,9 @@ neo4j:
|
||||
# Neo4j Configuration (yaml format)
|
||||
config:
|
||||
dbms_security_procedures_allowlist: "apoc.*"
|
||||
dbms_security_procedures_unrestricted: "apoc.*"
|
||||
dbms_security_procedures_unrestricted: ""
|
||||
|
||||
apoc_config:
|
||||
apoc.export.file.enabled: "true"
|
||||
apoc.import.file.enabled: "true"
|
||||
apoc.export.file.enabled: "false"
|
||||
apoc.import.file.enabled: "false"
|
||||
apoc.import.file.use_neo4j_config: "true"
|
||||
|
||||
@@ -121,8 +121,8 @@ To update the environment file:
|
||||
Edit the `.env` file and change version values:
|
||||
|
||||
```env
|
||||
PROWLER_UI_VERSION="5.18.0"
|
||||
PROWLER_API_VERSION="5.18.0"
|
||||
PROWLER_UI_VERSION="5.19.0"
|
||||
PROWLER_API_VERSION="5.19.0"
|
||||
```
|
||||
|
||||
<Note>
|
||||
|
||||
@@ -38,7 +38,7 @@ class _MutableTimestamp:
|
||||
|
||||
timestamp = _MutableTimestamp(datetime.today())
|
||||
timestamp_utc = _MutableTimestamp(datetime.now(timezone.utc))
|
||||
prowler_version = "5.19.0"
|
||||
prowler_version = "5.19.1"
|
||||
html_logo_url = "https://github.com/prowler-cloud/prowler/"
|
||||
square_logo_img = "https://raw.githubusercontent.com/prowler-cloud/prowler/dc7d2d5aeb92fdf12e8604f42ef6472cd3e8e889/docs/img/prowler-logo-black.png"
|
||||
aws_logo = "https://user-images.githubusercontent.com/38561120/235953920-3e3fba08-0795-41dc-b480-9bea57db9f2e.png"
|
||||
|
||||
+1
-1
@@ -94,7 +94,7 @@ maintainers = [{name = "Prowler Engineering", email = "engineering@prowler.com"}
|
||||
name = "prowler"
|
||||
readme = "README.md"
|
||||
requires-python = ">3.9.1,<3.13"
|
||||
version = "5.19.0"
|
||||
version = "5.19.1"
|
||||
|
||||
[project.scripts]
|
||||
prowler = "prowler.__main__:prowler"
|
||||
|
||||
@@ -2,6 +2,20 @@
|
||||
|
||||
All notable changes to the **Prowler UI** are documented in this file.
|
||||
|
||||
|
||||
## [1.19.1] (Prowler v5.19.1 UNRELEASED)
|
||||
|
||||
### 🐞 Fixed
|
||||
|
||||
- Provider wizard now closes after updating credentials instead of incorrectly advancing to the Launch Scan step, which caused API errors for providers with existing scheduled scans [(#10278)](https://github.com/prowler-cloud/prowler/pull/10278)
|
||||
- Attack Paths query builder sending stale parameters from previous query selections due to validation schema and default values being recreated on every render [(#10306)](https://github.com/prowler-cloud/prowler/pull/10306)
|
||||
|
||||
### 🔐 Security
|
||||
|
||||
- npm transitive dependencies patched to resolve 11 Dependabot alerts (6 HIGH, 4 MEDIUM, 1 LOW): hono, @hono/node-server, fast-xml-parser, serialize-javascript, minimatch [(#10267)](https://github.com/prowler-cloud/prowler/pull/10267)
|
||||
|
||||
---
|
||||
|
||||
## [1.19.0] (Prowler v5.19.0)
|
||||
|
||||
### 🚀 Added
|
||||
|
||||
+9
-12
@@ -1,6 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import { Skeleton } from "@/components/shadcn/skeleton/skeleton";
|
||||
import { TreeSpinner } from "@/components/shadcn/tree-view/tree-spinner";
|
||||
|
||||
/**
|
||||
* Loading skeleton for graph visualization
|
||||
@@ -8,17 +8,14 @@ import { Skeleton } from "@/components/shadcn/skeleton/skeleton";
|
||||
*/
|
||||
export const GraphLoading = () => {
|
||||
return (
|
||||
<div className="dark:bg-prowler-blue-400 flex h-96 items-center justify-center rounded-lg bg-gray-50">
|
||||
<div className="flex flex-col items-center gap-3">
|
||||
<div className="flex gap-2">
|
||||
<Skeleton className="h-3 w-3 rounded-full" />
|
||||
<Skeleton className="h-3 w-3 rounded-full" />
|
||||
<Skeleton className="h-3 w-3 rounded-full" />
|
||||
</div>
|
||||
<p className="text-sm text-gray-600 dark:text-gray-400">
|
||||
Loading Attack Paths graph...
|
||||
</p>
|
||||
</div>
|
||||
<div
|
||||
data-testid="graph-loading"
|
||||
className="flex min-h-[320px] flex-col items-center justify-center gap-4 text-center"
|
||||
>
|
||||
<TreeSpinner className="size-6" />
|
||||
<p className="text-muted-foreground text-sm">
|
||||
Loading Attack Paths graph...
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
+74
-84
@@ -2,6 +2,7 @@
|
||||
|
||||
import { Controller, useFormContext } from "react-hook-form";
|
||||
|
||||
import { Input } from "@/components/shadcn";
|
||||
import type { AttackPathQuery } from "@/types/attack-paths";
|
||||
|
||||
interface QueryParametersFormProps {
|
||||
@@ -21,14 +22,7 @@ export const QueryParametersForm = ({
|
||||
} = useFormContext();
|
||||
|
||||
if (!selectedQuery || !selectedQuery.attributes.parameters.length) {
|
||||
return (
|
||||
<div className="rounded-lg bg-blue-50 p-4 dark:bg-blue-950/20">
|
||||
<p className="text-sm text-blue-700 dark:text-blue-300">
|
||||
This query requires no parameters. Click "Execute Query" to
|
||||
proceed.
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
@@ -37,86 +31,82 @@ export const QueryParametersForm = ({
|
||||
Query Parameters
|
||||
</h3>
|
||||
|
||||
{selectedQuery.attributes.parameters.map((param) => (
|
||||
<Controller
|
||||
key={param.name}
|
||||
name={param.name}
|
||||
control={control}
|
||||
render={({ field }) => {
|
||||
if (param.data_type === "boolean") {
|
||||
return (
|
||||
<div className="flex flex-col gap-2">
|
||||
<label className="flex cursor-pointer items-center gap-3">
|
||||
<input
|
||||
type="checkbox"
|
||||
id={param.name}
|
||||
checked={field.value === true || field.value === "true"}
|
||||
onChange={(e) => field.onChange(e.target.checked)}
|
||||
aria-label={param.label}
|
||||
className="border-border-neutral-secondary bg-bg-neutral-primary text-text-primary focus:ring-primary dark:border-border-neutral-secondary dark:bg-bg-neutral-primary dark:text-text-primary h-4 w-4 rounded border focus:ring-2"
|
||||
/>
|
||||
<div className="flex flex-col gap-1">
|
||||
<span className="text-sm font-medium text-gray-900 dark:text-gray-100">
|
||||
{param.label}
|
||||
</span>
|
||||
{param.description && (
|
||||
<span className="text-xs text-gray-600 dark:text-gray-400">
|
||||
{param.description}
|
||||
<div
|
||||
data-testid="query-parameters-grid"
|
||||
className="grid grid-cols-1 gap-4 md:grid-cols-2"
|
||||
>
|
||||
{selectedQuery.attributes.parameters.map((param) => (
|
||||
<Controller
|
||||
key={param.name}
|
||||
name={param.name}
|
||||
control={control}
|
||||
render={({ field }) => {
|
||||
if (param.data_type === "boolean") {
|
||||
return (
|
||||
<div className="flex flex-col gap-2">
|
||||
<label className="flex cursor-pointer items-center gap-3">
|
||||
<input
|
||||
type="checkbox"
|
||||
id={param.name}
|
||||
checked={field.value === true || field.value === "true"}
|
||||
onChange={(e) => field.onChange(e.target.checked)}
|
||||
aria-label={param.label}
|
||||
className="border-border-neutral-secondary bg-bg-neutral-primary text-text-primary focus:ring-primary dark:border-border-neutral-secondary dark:bg-bg-neutral-primary dark:text-text-primary h-4 w-4 rounded border focus:ring-2"
|
||||
/>
|
||||
<div className="flex flex-col gap-1">
|
||||
<span className="text-sm font-medium text-gray-900 dark:text-gray-100">
|
||||
{param.label}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
{param.description && (
|
||||
<span className="text-xs text-gray-600 dark:text-gray-400">
|
||||
{param.description}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</label>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const errorMessage = (() => {
|
||||
const error = errors[param.name];
|
||||
if (error && typeof error.message === "string") {
|
||||
return error.message;
|
||||
}
|
||||
return undefined;
|
||||
})();
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-1.5">
|
||||
<label
|
||||
htmlFor={param.name}
|
||||
className="text-text-neutral-tertiary text-xs font-medium"
|
||||
>
|
||||
{param.label}
|
||||
{param.required && (
|
||||
<span className="text-text-error-primary">*</span>
|
||||
)}
|
||||
</label>
|
||||
<Input
|
||||
{...field}
|
||||
id={param.name}
|
||||
type={param.data_type === "number" ? "number" : "text"}
|
||||
placeholder={
|
||||
param.description ||
|
||||
param.placeholder ||
|
||||
`Enter ${param.label.toLowerCase()}`
|
||||
}
|
||||
value={field.value ?? ""}
|
||||
/>
|
||||
{errorMessage && (
|
||||
<span className="text-xs text-red-500">{errorMessage}</span>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const errorMessage = (() => {
|
||||
const error = errors[param.name];
|
||||
if (error && typeof error.message === "string") {
|
||||
return error.message;
|
||||
}
|
||||
return undefined;
|
||||
})();
|
||||
|
||||
const descriptionId = `${param.name}-description`;
|
||||
return (
|
||||
<div className="flex flex-col gap-2">
|
||||
<label
|
||||
htmlFor={param.name}
|
||||
className="text-sm font-medium text-gray-700 dark:text-gray-300"
|
||||
>
|
||||
{param.label}
|
||||
{param.required && <span className="text-red-500"> *</span>}
|
||||
</label>
|
||||
<input
|
||||
{...field}
|
||||
id={param.name}
|
||||
type={param.data_type === "number" ? "number" : "text"}
|
||||
placeholder={
|
||||
param.placeholder || `Enter ${param.label.toLowerCase()}`
|
||||
}
|
||||
value={field.value ?? ""}
|
||||
aria-describedby={
|
||||
param.description ? descriptionId : undefined
|
||||
}
|
||||
className="border-border-neutral-secondary bg-bg-neutral-primary text-text-neutral-primary placeholder-text-neutral-secondary focus:border-border-primary focus:ring-primary dark:border-border-neutral-secondary dark:bg-bg-neutral-primary dark:text-text-neutral-primary dark:placeholder-text-neutral-secondary dark:focus:border-border-primary rounded-md border px-3 py-2 text-sm focus:ring-1 focus:outline-none"
|
||||
/>
|
||||
{param.description && (
|
||||
<span
|
||||
id={descriptionId}
|
||||
className="text-xs text-gray-600 dark:text-gray-400"
|
||||
>
|
||||
{param.description}
|
||||
</span>
|
||||
)}
|
||||
{errorMessage && (
|
||||
<span className="text-xs text-red-500">{errorMessage}</span>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
+58
-51
@@ -7,6 +7,38 @@ import { z } from "zod";
|
||||
|
||||
import type { AttackPathQuery } from "@/types/attack-paths";
|
||||
|
||||
const getValidationSchema = (query?: AttackPathQuery) => {
|
||||
const schemaObject: Record<string, z.ZodTypeAny> = {};
|
||||
|
||||
query?.attributes.parameters.forEach((param) => {
|
||||
let fieldSchema: z.ZodTypeAny = z
|
||||
.string()
|
||||
.min(1, `${param.label} is required`);
|
||||
|
||||
if (param.data_type === "number") {
|
||||
fieldSchema = z.coerce.number().refine((val) => val >= 0, {
|
||||
message: `${param.label} must be a non-negative number`,
|
||||
});
|
||||
} else if (param.data_type === "boolean") {
|
||||
fieldSchema = z.boolean().default(false);
|
||||
}
|
||||
|
||||
schemaObject[param.name] = fieldSchema;
|
||||
});
|
||||
|
||||
return z.object(schemaObject);
|
||||
};
|
||||
|
||||
const getDefaultValues = (query?: AttackPathQuery) => {
|
||||
const defaults: Record<string, unknown> = {};
|
||||
|
||||
query?.attributes.parameters.forEach((param) => {
|
||||
defaults[param.name] = param.data_type === "boolean" ? false : "";
|
||||
});
|
||||
|
||||
return defaults;
|
||||
};
|
||||
|
||||
/**
|
||||
* Custom hook for managing query builder form state
|
||||
* Handles query selection, parameter validation, and form submission
|
||||
@@ -14,72 +46,47 @@ import type { AttackPathQuery } from "@/types/attack-paths";
|
||||
export const useQueryBuilder = (availableQueries: AttackPathQuery[]) => {
|
||||
const [selectedQuery, setSelectedQuery] = useState<string | null>(null);
|
||||
|
||||
// Generate dynamic Zod schema based on selected query parameters
|
||||
const getValidationSchema = (queryId: string | null) => {
|
||||
const schemaObject: Record<string, z.ZodTypeAny> = {};
|
||||
|
||||
if (queryId) {
|
||||
const query = availableQueries.find((q) => q.id === queryId);
|
||||
|
||||
if (query) {
|
||||
query.attributes.parameters.forEach((param) => {
|
||||
let fieldSchema: z.ZodTypeAny = z
|
||||
.string()
|
||||
.min(1, `${param.label} is required`);
|
||||
|
||||
if (param.data_type === "number") {
|
||||
fieldSchema = z.coerce.number().refine((val) => val >= 0, {
|
||||
message: `${param.label} must be a non-negative number`,
|
||||
});
|
||||
} else if (param.data_type === "boolean") {
|
||||
fieldSchema = z.boolean().default(false);
|
||||
}
|
||||
|
||||
schemaObject[param.name] = fieldSchema;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return z.object(schemaObject);
|
||||
};
|
||||
|
||||
const getDefaultValues = (queryId: string | null) => {
|
||||
const defaults: Record<string, unknown> = {};
|
||||
|
||||
const query = availableQueries.find((q) => q.id === queryId);
|
||||
if (query) {
|
||||
query.attributes.parameters.forEach((param) => {
|
||||
defaults[param.name] = param.data_type === "boolean" ? false : "";
|
||||
});
|
||||
}
|
||||
|
||||
return defaults;
|
||||
};
|
||||
const getQueryById = (queryId: string | null) =>
|
||||
availableQueries.find((query) => query.id === queryId);
|
||||
const selectedQueryData = getQueryById(selectedQuery);
|
||||
|
||||
const form = useForm({
|
||||
resolver: zodResolver(getValidationSchema(selectedQuery)),
|
||||
resolver: zodResolver(getValidationSchema(selectedQueryData)),
|
||||
mode: "onChange",
|
||||
defaultValues: getDefaultValues(selectedQuery),
|
||||
defaultValues: getDefaultValues(selectedQueryData),
|
||||
shouldUnregister: true,
|
||||
});
|
||||
|
||||
// Update form when selectedQuery changes
|
||||
useEffect(() => {
|
||||
form.reset(getDefaultValues(selectedQuery), {
|
||||
form.reset(getDefaultValues(selectedQueryData), {
|
||||
keepDirtyValues: false,
|
||||
});
|
||||
}, [selectedQuery]); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
const selectedQueryData = availableQueries.find(
|
||||
(q) => q.id === selectedQuery,
|
||||
);
|
||||
}, [form, selectedQueryData]);
|
||||
|
||||
const handleQueryChange = (queryId: string) => {
|
||||
setSelectedQuery(queryId);
|
||||
form.reset();
|
||||
};
|
||||
|
||||
const getQueryParameters = () => {
|
||||
return form.getValues();
|
||||
if (!selectedQueryData?.attributes.parameters.length) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const values = form.getValues() as Record<
|
||||
string,
|
||||
string | number | boolean
|
||||
>;
|
||||
|
||||
return selectedQueryData.attributes.parameters.reduce<
|
||||
Record<string, string | number | boolean>
|
||||
>((parameters, parameter) => {
|
||||
const value = values[parameter.name];
|
||||
if (value !== undefined) {
|
||||
parameters[parameter.name] = value;
|
||||
}
|
||||
return parameters;
|
||||
}, {});
|
||||
};
|
||||
|
||||
const isFormValid = () => {
|
||||
|
||||
@@ -121,7 +121,7 @@ describe("useProviderWizardController", () => {
|
||||
expect(onOpenChange).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("moves to launch step after a successful connection test in update mode", async () => {
|
||||
it("closes the modal after a successful connection test in update mode", async () => {
|
||||
// Given
|
||||
const onOpenChange = vi.fn();
|
||||
const { result } = renderHook(() =>
|
||||
@@ -149,9 +149,8 @@ describe("useProviderWizardController", () => {
|
||||
result.current.handleTestSuccess();
|
||||
});
|
||||
|
||||
// Then
|
||||
expect(result.current.currentStep).toBe(PROVIDER_WIZARD_STEP.LAUNCH);
|
||||
expect(onOpenChange).not.toHaveBeenCalled();
|
||||
// Then — update mode should close the modal, not advance to launch
|
||||
expect(onOpenChange).toHaveBeenCalledWith(false);
|
||||
});
|
||||
|
||||
it("does not override launch footer config in the controller", () => {
|
||||
|
||||
@@ -173,6 +173,10 @@ export function useProviderWizardController({
|
||||
};
|
||||
|
||||
const handleTestSuccess = () => {
|
||||
if (mode === PROVIDER_WIZARD_MODE.UPDATE) {
|
||||
handleClose();
|
||||
return;
|
||||
}
|
||||
setCurrentStep(PROVIDER_WIZARD_STEP.LAUNCH);
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user