fix(api): add security hardening for Attack Paths custom query endpoint (#10238)

This commit is contained in:
Josema Camacho
2026-03-12 10:46:29 +01:00
committed by GitHub
parent e0d61ba5d1
commit 4dc3765670
10 changed files with 482 additions and 25 deletions
+1
View File
@@ -15,6 +15,7 @@ All notable changes to the **Prowler API** are documented in this file.
### 🐞 Fixed
- Attack Paths: Security hardening for custom query endpoint (Cypher blocklist, input validation, rate limiting, Helm lockdown) [(#10238)](https://github.com/prowler-cloud/prowler/pull/10238)
- Attack Paths: Add missing logging for query execution and exception details in scan error handling [(#10269)](https://github.com/prowler-cloud/prowler/pull/10269)
- Attack Paths: Upgrade Cartography from 0.129.0 to 0.132.0, fixing `exposed_internet` not set on ELB/ELBv2 nodes [(#10272)](https://github.com/prowler-cloud/prowler/pull/10272)
@@ -35,6 +35,7 @@ READ_EXCEPTION_CODES = [
"Neo.ClientError.Statement.AccessMode",
"Neo.ClientError.Procedure.ProcedureNotFound",
]
CLIENT_STATEMENT_EXCEPTION_PREFIX = "Neo.ClientError.Statement."
# Module-level process-wide driver singleton
_driver: neo4j.Driver | None = None
@@ -108,6 +109,7 @@ def get_session(
except neo4j.exceptions.Neo4jError as exc:
if (
default_access_mode == neo4j.READ_ACCESS
and exc.code
and exc.code in READ_EXCEPTION_CODES
):
message = "Read query not allowed"
@@ -115,6 +117,10 @@ def get_session(
raise WriteQueryNotAllowedException(message=message, code=code)
message = exc.message if exc.message is not None else str(exc)
if exc.code and exc.code.startswith(CLIENT_STATEMENT_EXCEPTION_PREFIX):
raise ClientStatementException(message=message, code=exc.code)
raise GraphDatabaseQueryException(message=message, code=exc.code)
finally:
@@ -227,3 +233,7 @@ class GraphDatabaseQueryException(Exception):
class WriteQueryNotAllowedException(GraphDatabaseQueryException):
    """Raised when a read-access session reports one of the READ_EXCEPTION_CODES."""
class ClientStatementException(GraphDatabaseQueryException):
    """Raised for Neo4j errors whose code starts with ``Neo.ClientError.Statement.``."""
@@ -1,4 +1,5 @@
import logging
import re
from typing import Any, Iterable
@@ -117,6 +118,38 @@ def execute_query(
# Custom query helpers
# Patterns that indicate SSRF or dangerous procedure calls
# Defense-in-depth layer - the primary control is `neo4j.READ_ACCESS`
_BLOCKED_APOC_NAMESPACES = (
    "load",
    "import",
    "export",
    "cypher",
    "systemdb",
    "config",
    "periodic",
    "do",
    "trigger",
    "custom",
)

# `\s*` around the dot: after comments are collapsed to whitespace the
# namespace separator may be surrounded by blanks (`apoc . load`,
# `apoc/**/.load`), which a literal `apoc\.load` would miss
_BLOCKED_PATTERNS = [
    re.compile(r"\bLOAD\s+CSV\b", re.IGNORECASE),
    *(
        re.compile(rf"\bapoc\s*\.\s*{namespace}\b", re.IGNORECASE)
        for namespace in _BLOCKED_APOC_NAMESPACES
    ),
]

# Strip string literals so patterns inside quotes don't cause false positives
# Handles escaped quotes (\' and \") inside strings
_STRING_LITERALS = re.compile(r"'(?:[^'\\]|\\.)*'|\"(?:[^\"\\]|\\.)*\"")

# Everything that can hide or split a blocked token, recognised in ONE
# left-to-right pass. Doing it in a single pass matters: a quote inside a
# comment must not be taken for a string opener (it would swallow the code that
# follows), and `//` inside a string must not be taken for a comment
_CYPHER_LEXEMES = re.compile(
    rf"(?P<string>{_STRING_LITERALS.pattern})"
    r"|(?P<block_comment>/\*.*?\*/)"
    r"|(?P<line_comment>//[^\r\n]*)"
    r"|(?P<quoted_identifier>`(?:[^`]|``)*`)",
    re.DOTALL,
)

_WHITESPACE_RUN = re.compile(r"\s+")


def _neutralize_lexeme(match: "re.Match[str]") -> str:
    """Return the text a matched lexeme should be replaced with before scanning.

    - string literal     -> ``''`` (content dropped, token boundary kept)
    - comment            -> a single space (comments separate tokens)
    - quoted identifier  -> its unquoted name, so `` `apoc`.`load` `` is seen
      as ``apoc.load``
    """
    kind = match.lastgroup
    if kind == "string":
        return "''"
    if kind == "quoted_identifier":
        name = match.group()[1:-1].replace("``", "`")
        # A backtick-quoted token is always an identifier, never a keyword, so
        # whitespace inside it can never form `LOAD CSV`; joining it avoids a
        # false positive on names such as `load csv`
        return _WHITESPACE_RUN.sub("_", name)
    return " "


def validate_custom_query(cypher: str) -> None:
    """Reject queries containing known SSRF or dangerous procedure patterns.

    Before matching, the query is normalised so the blocklist cannot be
    side-stepped with lexical tricks: string literals are emptied, comments
    are collapsed to whitespace and backtick-quoted identifiers are unquoted.

    Args:
        cypher: Raw Cypher text supplied by the API caller.

    Raises:
        ValidationError: If any pattern in ``_BLOCKED_PATTERNS`` matches the
            normalised query.

    NOTE(review): `\\uXXXX` escapes are not decoded here - confirm whether the
    deployed Neo4j version expands them outside string literals.
    """
    normalized = _CYPHER_LEXEMES.sub(_neutralize_lexeme, cypher)
    if any(pattern.search(normalized) for pattern in _BLOCKED_PATTERNS):
        raise ValidationError({"query": "Query contains a blocked operation"})
def normalize_custom_query_payload(raw_data):
if not isinstance(raw_data, dict):
@@ -135,6 +168,8 @@ def execute_custom_query(
cypher: str,
provider_id: str,
) -> dict[str, Any]:
validate_custom_query(cypher)
try:
graph = graph_database.execute_read_query(
database=database_name,
@@ -143,6 +178,9 @@ def execute_custom_query(
serialized = _serialize_graph(graph, provider_id)
return _truncate_graph(serialized)
except graph_database.ClientStatementException as exc:
raise ValidationError({"query": exc.message})
except graph_database.WriteQueryNotAllowedException:
raise PermissionDenied(
"Attack Paths query execution failed: read-only queries are enforced"
@@ -227,6 +265,12 @@ def _serialize_graph(graph, provider_id: str) -> dict[str, Any]:
},
)
filtered_count = len(graph.nodes) - len(nodes)
if filtered_count > 0:
logger.debug(
f"Filtered {filtered_count} nodes without matching provider_id={provider_id}"
)
relationships = []
for relationship in graph.relationships:
if relationship._properties.get("provider_id") != provider_id:
@@ -501,6 +501,72 @@ def test_execute_custom_query_wraps_graph_errors():
mock_logger.error.assert_called_once()
# -- validate_custom_query ------------------------------------------------
@pytest.mark.parametrize(
    "cypher",
    [
        pytest.param(
            "LOAD CSV FROM 'http://169.254.169.254/' AS x RETURN x",
            id="LOAD_CSV",
        ),
        pytest.param(
            "load csv from 'http://evil.com' as row return row",
            id="LOAD_CSV_lowercase",
        ),
        pytest.param(
            "CALL apoc.load.json('http://evil.com/') YIELD value RETURN value",
            id="apoc.load.json",
        ),
        pytest.param(
            "CALL apoc.load.csvParams('http://evil.com/', {}, null) YIELD list RETURN list",
            id="apoc.load.csvParams",
        ),
        pytest.param(
            "CALL apoc.import.csv([{fileName: 'f'}], [], {}) YIELD node RETURN node",
            id="apoc.import.csv",
        ),
        pytest.param(
            "CALL apoc.export.csv.all('file.csv', {})",
            id="apoc.export.csv",
        ),
        pytest.param(
            "CALL apoc.cypher.run('CREATE (n)', {}) YIELD value RETURN value",
            id="apoc.cypher.run",
        ),
        pytest.param(
            "CALL apoc.systemdb.graph() YIELD nodes RETURN nodes",
            id="apoc.systemdb.graph",
        ),
        pytest.param(
            "CALL apoc.config.list() YIELD key, value RETURN key, value",
            id="apoc.config.list",
        ),
        pytest.param(
            "CALL apoc.periodic.iterate('MATCH (n) RETURN n', 'DELETE n', {batchSize: 100})",
            id="apoc.periodic.iterate",
        ),
        pytest.param(
            "CALL apoc.do.when(true, 'CREATE (n) RETURN n', '', {}) YIELD value RETURN value",
            id="apoc.do.when",
        ),
        pytest.param(
            "CALL apoc.trigger.add('t', 'RETURN 1', {phase: 'before'})",
            id="apoc.trigger.add",
        ),
        pytest.param(
            "CALL apoc.custom.asProcedure('myProc', 'RETURN 1')",
            id="apoc.custom.asProcedure",
        ),
    ],
)
def test_validate_custom_query_rejects_blocked_patterns(cypher):
    """Each blocklisted construct is refused with a "blocked operation" error."""
    with pytest.raises(ValidationError) as rejection:
        views_helpers.validate_custom_query(cypher)

    assert "blocked operation" in str(rejection.value.detail)
@pytest.mark.parametrize(
    "cypher",
    [
        pytest.param(
            "MATCH (n:AWSAccount) RETURN n LIMIT 10",
            id="simple_match",
        ),
        pytest.param(
            "MATCH (a)-[r]->(b) RETURN a, r, b",
            id="traversal",
        ),
        pytest.param(
            "MATCH (n) WHERE n.name CONTAINS 'load' RETURN n",
            id="contains_load_substring",
        ),
        pytest.param(
            "CALL apoc.create.vNode(['Label'], {}) YIELD node RETURN node",
            id="apoc_virtual_node",
        ),
        pytest.param(
            "MATCH (n) WHERE n.name = 'apoc.load.json' RETURN n",
            id="apoc_load_inside_single_quotes",
        ),
        pytest.param(
            'MATCH (n) WHERE n.description = "LOAD CSV is cool" RETURN n',
            id="load_csv_inside_double_quotes",
        ),
    ],
)
def test_validate_custom_query_allows_clean_queries(cypher):
    """Harmless queries pass validation; the call must simply not raise."""
    views_helpers.validate_custom_query(cypher)
# -- _truncate_graph ----------------------------------------------------------
+341 -9
View File
@@ -3810,6 +3810,12 @@ class TestTaskViewSet:
@pytest.mark.django_db
class TestAttackPathsScanViewSet:
@pytest.fixture(autouse=True)
def _clear_throttle_cache(self):
    """Empty the Django cache before each test in this class."""
    from django.core.cache import cache as django_cache

    django_cache.clear()
@staticmethod
def _run_payload(query_id="aws-rds", parameters=None):
return {
@@ -4411,8 +4417,6 @@ class TestAttackPathsScanViewSet:
}
}
# TODO: Remove skip once queries/custom and schema endpoints are unblocked
@pytest.mark.skip(reason="Endpoint temporarily blocked")
def test_run_custom_query_returns_graph(
self,
authenticated_client,
@@ -4470,7 +4474,6 @@ class TestAttackPathsScanViewSet:
assert attributes["total_nodes"] == 1
assert attributes["truncated"] is False
@pytest.mark.skip(reason="Endpoint temporarily blocked")
def test_run_custom_query_returns_text_when_accept_text_plain(
self,
authenticated_client,
@@ -4525,7 +4528,6 @@ class TestAttackPathsScanViewSet:
assert "## Relationships (0)" in body
assert "## Summary" in body
@pytest.mark.skip(reason="Endpoint temporarily blocked")
def test_run_custom_query_returns_404_when_no_nodes(
self,
authenticated_client,
@@ -4567,7 +4569,6 @@ class TestAttackPathsScanViewSet:
assert response.status_code == status.HTTP_404_NOT_FOUND
@pytest.mark.skip(reason="Endpoint temporarily blocked")
def test_run_custom_query_returns_400_when_graph_not_ready(
self,
authenticated_client,
@@ -4594,7 +4595,6 @@ class TestAttackPathsScanViewSet:
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert "not available" in response.json()["errors"][0]["detail"]
@pytest.mark.skip(reason="Endpoint temporarily blocked")
def test_run_custom_query_returns_403_for_write_query(
self,
authenticated_client,
@@ -4632,9 +4632,343 @@ class TestAttackPathsScanViewSet:
assert response.status_code == status.HTTP_403_FORBIDDEN
# -- SSRF blocklist (HTTP level) ----------------------------------------------
@pytest.mark.parametrize(
    "cypher",
    [
        pytest.param(
            "LOAD CSV FROM 'http://169.254.169.254/' AS x RETURN x",
            id="LOAD_CSV",
        ),
        pytest.param(
            "CALL apoc.load.json('http://evil.com/') YIELD value RETURN value",
            id="apoc.load",
        ),
        pytest.param(
            "CALL apoc.import.csv([{fileName: 'f'}], [], {}) YIELD node RETURN node",
            id="apoc.import",
        ),
        pytest.param(
            "CALL apoc.export.csv.all('file.csv', {})",
            id="apoc.export",
        ),
        pytest.param(
            "CALL apoc.cypher.run('CREATE (n)', {}) YIELD value RETURN value",
            id="apoc.cypher.run",
        ),
        pytest.param(
            "CALL apoc.systemdb.graph() YIELD nodes RETURN nodes",
            id="apoc.systemdb",
        ),
    ],
)
def test_run_custom_query_rejects_ssrf_patterns(
    self,
    authenticated_client,
    providers_fixture,
    scans_fixture,
    create_attack_paths_scan,
    cypher,
):
    """A blocklisted query posted to the endpoint yields 400 with a "blocked" detail."""
    attack_paths_scan = create_attack_paths_scan(
        providers_fixture[0],
        scan=scans_fixture[0],
        graph_data_ready=True,
    )
    url = reverse(
        "attack-paths-scans-queries-custom",
        kwargs={"pk": attack_paths_scan.id},
    )

    database_name = patch(
        "api.v1.views.graph_database.get_database_name",
        return_value="db-test",
    )
    with database_name:
        response = authenticated_client.post(
            url,
            data=self._custom_query_payload(cypher),
            content_type=API_JSON_CONTENT_TYPE,
        )

    assert response.status_code == status.HTTP_400_BAD_REQUEST
    first_error_detail = response.json()["errors"][0]["detail"]
    assert "blocked" in first_error_detail.lower()
# -- Cross-tenant isolation ---------------------------------------------------
def test_run_custom_query_returns_404_for_foreign_tenant(
    self,
    authenticated_client,
    create_attack_paths_scan,
):
    """A scan owned by a different tenant is reported as not found (404)."""
    from api.models import Provider, Tenant

    other_tenant = Tenant.objects.create(name="foreign-tenant")
    other_provider = Provider.objects.create(
        tenant=other_tenant,
        provider="aws",
        uid="123456789999",
    )
    attack_paths_scan = create_attack_paths_scan(
        other_provider,
        graph_data_ready=True,
    )
    url = reverse(
        "attack-paths-scans-queries-custom",
        kwargs={"pk": attack_paths_scan.id},
    )

    database_name = patch(
        "api.v1.views.graph_database.get_database_name",
        return_value="db-test",
    )
    with database_name:
        response = authenticated_client.post(
            url,
            data=self._custom_query_payload(),
            content_type=API_JSON_CONTENT_TYPE,
        )

    assert response.status_code == status.HTTP_404_NOT_FOUND
def test_cartography_schema_returns_404_for_foreign_tenant(
    self,
    authenticated_client,
    create_attack_paths_scan,
):
    """The schema endpoint answers 404 for a scan owned by a different tenant."""
    from api.models import Provider, Tenant

    other_tenant = Tenant.objects.create(name="foreign-tenant-schema")
    other_provider = Provider.objects.create(
        tenant=other_tenant,
        provider="aws",
        uid="123456789998",
    )
    attack_paths_scan = create_attack_paths_scan(
        other_provider,
        graph_data_ready=True,
    )
    url = reverse(
        "attack-paths-scans-schema",
        kwargs={"pk": attack_paths_scan.id},
    )

    response = authenticated_client.get(url)

    assert response.status_code == status.HTTP_404_NOT_FOUND
# -- Authentication / authorization -------------------------------------------
def test_run_custom_query_returns_401_unauthenticated(
    self,
    providers_fixture,
    scans_fixture,
    create_attack_paths_scan,
):
    """Posting a custom query without credentials yields 401."""
    from rest_framework.test import APIClient

    attack_paths_scan = create_attack_paths_scan(
        providers_fixture[0],
        scan=scans_fixture[0],
        graph_data_ready=True,
    )
    url = reverse(
        "attack-paths-scans-queries-custom",
        kwargs={"pk": attack_paths_scan.id},
    )

    anonymous_client = APIClient()
    response = anonymous_client.post(
        url,
        data=self._custom_query_payload(),
        content_type=API_JSON_CONTENT_TYPE,
    )

    assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_cartography_schema_returns_401_unauthenticated(
    self,
    providers_fixture,
    scans_fixture,
    create_attack_paths_scan,
):
    """Requesting the schema endpoint without credentials yields 401."""
    from rest_framework.test import APIClient

    attack_paths_scan = create_attack_paths_scan(
        providers_fixture[0],
        scan=scans_fixture[0],
        graph_data_ready=True,
    )
    url = reverse(
        "attack-paths-scans-schema",
        kwargs={"pk": attack_paths_scan.id},
    )

    anonymous_client = APIClient()
    response = anonymous_client.get(url)

    assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_run_custom_query_returns_403_no_manage_scans(
    self,
    authenticated_client_no_permissions_rbac,
    providers_fixture,
    scans_fixture,
    create_attack_paths_scan,
):
    """A client from the no-permissions RBAC fixture gets 403 on the custom query endpoint."""
    attack_paths_scan = create_attack_paths_scan(
        providers_fixture[0],
        scan=scans_fixture[0],
        graph_data_ready=True,
    )
    url = reverse(
        "attack-paths-scans-queries-custom",
        kwargs={"pk": attack_paths_scan.id},
    )

    response = authenticated_client_no_permissions_rbac.post(
        url,
        data=self._custom_query_payload(),
        content_type=API_JSON_CONTENT_TYPE,
    )

    assert response.status_code == status.HTTP_403_FORBIDDEN
# -- Error leakage ------------------------------------------------------------
def test_run_custom_query_does_not_leak_internals_on_error(
    self,
    authenticated_client,
    providers_fixture,
    scans_fixture,
    create_attack_paths_scan,
):
    """A failing query returns 500 and the body contains no database internals.

    NOTE(review): ``execute_custom_query`` is mocked to raise an APIException
    whose message is already generic, so this only checks that the layers
    above the helper add no internals; the helper's own error wrapping is
    not exercised here.
    """
    from rest_framework.exceptions import APIException

    attack_paths_scan = create_attack_paths_scan(
        providers_fixture[0],
        scan=scans_fixture[0],
        graph_data_ready=True,
    )
    url = reverse(
        "attack-paths-scans-queries-custom",
        kwargs={"pk": attack_paths_scan.id},
    )

    failing_helper = patch(
        "api.v1.views.attack_paths_views_helpers.execute_custom_query",
        side_effect=APIException(
            "Attack Paths query execution failed due to a database error"
        ),
    )
    database_name = patch(
        "api.v1.views.graph_database.get_database_name",
        return_value="db-test",
    )
    with failing_helper, database_name:
        response = authenticated_client.post(
            url,
            data=self._custom_query_payload(),
            content_type=API_JSON_CONTENT_TYPE,
        )

    assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR

    serialized_body = json.dumps(response.json()).lower()
    internal_markers = ("neo4j", "bolt://", "syntaxerror", "db-tenant-")
    leaked = [marker for marker in internal_markers if marker in serialized_body]
    assert not leaked
# -- Rate limiting (throttle) -------------------------------------------------
def test_run_custom_query_throttled_after_limit(
    self,
    authenticated_client,
    providers_fixture,
    scans_fixture,
    create_attack_paths_scan,
):
    """Ten consecutive requests succeed; the eleventh is answered with 429."""
    attack_paths_scan = create_attack_paths_scan(
        providers_fixture[0],
        scan=scans_fixture[0],
        graph_data_ready=True,
    )
    url = reverse(
        "attack-paths-scans-queries-custom",
        kwargs={"pk": attack_paths_scan.id},
    )
    payload = self._custom_query_payload()

    canned_graph = {
        "nodes": [{"id": "n1", "labels": ["Test"], "properties": {}}],
        "relationships": [],
        "total_nodes": 1,
        "truncated": False,
    }
    allowed_requests = 10

    helper = patch(
        "api.v1.views.attack_paths_views_helpers.execute_custom_query",
        return_value=canned_graph,
    )
    database_name = patch(
        "api.v1.views.graph_database.get_database_name",
        return_value="db-test",
    )
    cache_clearing = patch("api.v1.views.graph_database.clear_cache")

    with helper, database_name, cache_clearing:
        # Requests 1..10 stay within the limit
        for index in range(allowed_requests):
            response = authenticated_client.post(
                url,
                data=payload,
                content_type=API_JSON_CONTENT_TYPE,
            )
            assert (
                response.status_code == status.HTTP_200_OK
            ), f"Request {index + 1} should succeed with 200 OK, got {response.status_code}"

        # Request 11 exceeds it
        response = authenticated_client.post(
            url,
            data=payload,
            content_type=API_JSON_CONTENT_TYPE,
        )
        assert (
            response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
        ), f"Request {allowed_requests + 1} should be throttled"
# -- Timeout simulation -------------------------------------------------------
def test_run_custom_query_returns_500_on_database_timeout(
    self,
    authenticated_client,
    providers_fixture,
    scans_fixture,
    create_attack_paths_scan,
):
    """An APIException raised by the query helper surfaces as HTTP 500.

    NOTE(review): no timeout is actually simulated - the helper is mocked to
    raise a generic APIException, the same setup the error-leakage test uses.
    """
    from rest_framework.exceptions import APIException

    attack_paths_scan = create_attack_paths_scan(
        providers_fixture[0],
        scan=scans_fixture[0],
        graph_data_ready=True,
    )
    url = reverse(
        "attack-paths-scans-queries-custom",
        kwargs={"pk": attack_paths_scan.id},
    )

    failing_helper = patch(
        "api.v1.views.attack_paths_views_helpers.execute_custom_query",
        side_effect=APIException(
            "Attack Paths query execution failed due to a database error"
        ),
    )
    database_name = patch(
        "api.v1.views.graph_database.get_database_name",
        return_value="db-test",
    )
    with failing_helper, database_name:
        response = authenticated_client.post(
            url,
            data=self._custom_query_payload(),
            content_type=API_JSON_CONTENT_TYPE,
        )

    assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
# -- cartography_schema action ------------------------------------------------
@pytest.mark.skip(reason="Endpoint temporarily blocked")
def test_cartography_schema_returns_urls(
self,
authenticated_client,
@@ -4684,7 +5018,6 @@ class TestAttackPathsScanViewSet:
assert "schema.md" in attributes["schema_url"]
assert "raw.githubusercontent.com" in attributes["raw_schema_url"]
@pytest.mark.skip(reason="Endpoint temporarily blocked")
def test_cartography_schema_returns_404_when_no_metadata(
self,
authenticated_client,
@@ -4719,7 +5052,6 @@ class TestAttackPathsScanViewSet:
assert response.status_code == status.HTTP_404_NOT_FOUND
assert "No cartography schema metadata" in str(response.json())
@pytest.mark.skip(reason="Endpoint temporarily blocked")
def test_cartography_schema_returns_400_when_graph_not_ready(
self,
authenticated_client,
+1 -1
View File
@@ -1241,7 +1241,7 @@ class AttackPathsQueryRunRequestSerializer(BaseSerializerV1):
class AttackPathsCustomQueryRunRequestSerializer(BaseSerializerV1):
query = serializers.CharField()
query = serializers.CharField(max_length=10000, min_length=1, trim_whitespace=True)
class JSONAPIMeta:
resource_name = "attack-paths-custom-query-run-requests"
+7 -11
View File
@@ -51,6 +51,13 @@ from api.v1.views import (
)
# This helper view blocks endpoints that should not be available.
# To use it, add a new entry to the `urlpatterns` list, for example (an old but real one):
# path(
# "attack-paths-scans/<uuid:pk>/queries/custom",
# _blocked_endpoint,
# name="attack-paths-scans-queries-custom-blocked",
# ),
@csrf_exempt
def _blocked_endpoint(request, *args, **kwargs):
return JsonResponse(
@@ -209,17 +216,6 @@ urlpatterns = [
path("tokens/saml", SAMLTokenValidateView.as_view(), name="token-saml"),
path("tokens/google", GoogleSocialLoginView.as_view(), name="token-google"),
path("tokens/github", GithubSocialLoginView.as_view(), name="token-github"),
# TODO: Remove these blocked endpoints once they are properly tested
path(
"attack-paths-scans/<uuid:pk>/queries/custom",
_blocked_endpoint,
name="attack-paths-scans-queries-custom-blocked",
),
path(
"attack-paths-scans/<uuid:pk>/schema",
_blocked_endpoint,
name="attack-paths-scans-schema-blocked",
),
path("", include(router.urls)),
path("", include(tenants_router.urls)),
path("", include(users_router.urls)),
+5
View File
@@ -2452,6 +2452,11 @@ class AttackPathsScanViewSet(BaseRLSViewSet):
# RBAC required permissions
required_permissions = [Permissions.MANAGE_SCANS]
def get_throttles(self):
    """Select the throttle scope for the custom query action, then defer to the parent.

    Only ``run_custom_attack_paths_query`` is given the
    ``attack-paths-custom-query`` scope; for any other action
    ``throttle_scope`` is left untouched.
    """
    custom_query_requested = self.action == "run_custom_attack_paths_query"
    if custom_query_requested:
        self.throttle_scope = "attack-paths-custom-query"

    return super().get_throttles()
def set_required_permissions(self):
if self.request.method in SAFE_METHODS:
self.required_permissions = []
+4 -1
View File
@@ -113,8 +113,11 @@ REST_FRAMEWORK = {
"rest_framework.throttling.ScopedRateThrottle",
],
"DEFAULT_THROTTLE_RATES": {
"token-obtain": env("DJANGO_THROTTLE_TOKEN_OBTAIN", default=None),
"dj_rest_auth": None,
"token-obtain": env("DJANGO_THROTTLE_TOKEN_OBTAIN", default=None),
"attack-paths-custom-query": env(
"DJANGO_THROTTLE_ATTACK_PATHS_CUSTOM_QUERY", default="10/min"
),
},
}
+3 -3
View File
@@ -558,9 +558,9 @@ neo4j:
# Neo4j Configuration (yaml format)
config:
dbms_security_procedures_allowlist: "apoc.*"
dbms_security_procedures_unrestricted: "apoc.*"
dbms_security_procedures_unrestricted: ""
apoc_config:
apoc.export.file.enabled: "true"
apoc.import.file.enabled: "true"
apoc.export.file.enabled: "false"
apoc.import.file.enabled: "false"
apoc.import.file.use_neo4j_config: "true"