feat(api): accept `text/plain` in the Accept header of attack paths query endpoints to support LLM-friendly output (#10162)

Co-authored-by: Adrián Jesús Peña Rodríguez <adrianjpr@gmail.com>
This commit is contained in:
Josema Camacho
2026-02-26 12:53:58 +01:00
committed by GitHub
parent 902558f2d4
commit b3a67fa1a0
11 changed files with 677 additions and 70 deletions

View File

@@ -30,6 +30,7 @@ All notable changes to the **Prowler API** are documented in this file.
- Attack Paths: Query results now filtered by provider, preventing future cross-tenant and cross-provider data leakage [(#10118)](https://github.com/prowler-cloud/prowler/pull/10118)
- Attack Paths: Add private labels and properties in Attack Paths graphs for avoiding future overlapping with Cartography's ones [(#10124)](https://github.com/prowler-cloud/prowler/pull/10124)
- Attack Paths: Query endpoint executes them in read only mode [(#10140)](https://github.com/prowler-cloud/prowler/pull/10140)
- Attack Paths: Query endpoints also accept `text/plain` in the `Accept` header, returning a compact plain-text format for LLM consumption [(#10162)](https://github.com/prowler-cloud/prowler/pull/10162)
### 🐞 Fixed

2
api/poetry.lock generated
View File

@@ -6745,7 +6745,7 @@ tzlocal = "5.3.1"
type = "git"
url = "https://github.com/prowler-cloud/prowler.git"
reference = "master"
resolved_reference = "ceb4691c3657e7db3d178896bfc241d14f194295"
resolved_reference = "6962622fd21401886371add25463f77228cd9c1f"
[[package]]
name = "psutil"

View File

@@ -12,7 +12,7 @@ from api.attack_paths.queries.schema import (
RAW_SCHEMA_URL,
)
from config.custom_logging import BackendLogger
from tasks.jobs.attack_paths.config import INTERNAL_LABELS
from tasks.jobs.attack_paths.config import INTERNAL_LABELS, INTERNAL_PROPERTIES
logger = logging.getLogger(BackendLogger.API)
@@ -125,7 +125,7 @@ def normalize_custom_query_payload(raw_data):
if "data" in raw_data and isinstance(raw_data.get("data"), dict):
data_section = raw_data.get("data") or {}
attributes = data_section.get("attributes") or {}
return {"cypher": attributes.get("cypher")}
return {"query": attributes.get("query")}
return raw_data
@@ -261,7 +261,11 @@ def _filter_labels(labels: Iterable[str]) -> list[str]:
def _serialize_properties(properties: dict[str, Any]) -> dict[str, Any]:
"""Convert Neo4j property values into JSON-serializable primitives."""
"""Convert Neo4j property values into JSON-serializable primitives.
Filters out internal properties (Cartography metadata and provider
isolation fields) defined in INTERNAL_PROPERTIES.
"""
def _serialize_value(value: Any) -> Any:
# Neo4j temporal and spatial values expose `to_native` returning Python primitives
@@ -276,4 +280,176 @@ def _serialize_properties(properties: dict[str, Any]) -> dict[str, Any]:
return value
return {key: _serialize_value(val) for key, val in properties.items()}
return {
key: _serialize_value(val)
for key, val in properties.items()
if key not in INTERNAL_PROPERTIES
}
# Text serialization
def serialize_graph_as_text(graph: dict[str, Any]) -> str:
    """
    Render a serialized graph dict as compact, line-oriented text.

    The output has three sections separated by a blank line: one bullet per
    node (labels, quoted id, properties), one bullet per relationship
    (source -[LABEL]-> target), and a summary with the total node count and
    the truncation flag. Missing ``nodes``/``relationships`` keys are treated
    as empty lists; a missing ``total_nodes`` falls back to the number of
    nodes present and a missing ``truncated`` to ``False``.

    Example output for a two-node graph joined by one relationship::

        ## Nodes (2)
        - AWSAccount "n1" (name: "prod")
        - EC2Instance "n2"

        ## Relationships (1)
        - AWSAccount "n1" -[RESOURCE]-> EC2Instance "n2"

        ## Summary
        - Total nodes: 2
        - Truncated: false
    """
    node_items = graph.get("nodes", [])
    rel_items = graph.get("relationships", [])

    # Relationships only carry endpoint ids; index nodes so labels can be shown.
    nodes_by_id = {item["id"]: item for item in node_items}

    node_bullets = [f"- {_format_node_signature(item)}" for item in node_items]
    rel_bullets = [
        f"- {_format_relationship(item, nodes_by_id)}" for item in rel_items
    ]

    total_nodes = graph.get("total_nodes", len(node_items))
    truncated_flag = str(graph.get("truncated", False)).lower()

    output = [
        f"## Nodes ({len(node_items)})",
        *node_bullets,
        "",
        f"## Relationships ({len(rel_items)})",
        *rel_bullets,
        "",
        "## Summary",
        f"- Total nodes: {total_nodes}",
        f"- Truncated: {truncated_flag}",
    ]
    return "\n".join(output)
def _format_node_signature(node: dict[str, Any]) -> str:
    """
    Render a node as its reference, followed by its properties when it has any.

    Example::

        >>> _format_node_signature({"id": "n1", "labels": ["AWSRole"], "properties": {"name": "admin"}})
        'AWSRole "n1" (name: "admin")'
        >>> _format_node_signature({"id": "n2", "labels": ["AWSAccount"], "properties": {}})
        'AWSAccount "n2"'
    """
    head = _format_node_reference(node)
    rendered_props = _format_properties(node.get("properties", {}))
    # An empty property rendering means there is nothing to append.
    return f"{head} {rendered_props}" if rendered_props else head
def _format_node_reference(node: dict[str, Any]) -> str:
"""
Format a node as labels + quoted id (no properties).
Example::
>>> _format_node_reference({"id": "n1", "labels": ["EC2Instance", "NetworkExposed"]})
'EC2Instance, NetworkExposed "n1"'
"""
labels = ", ".join(node.get("labels", []))
return f'{labels} "{node["id"]}"'
def _format_relationship(rel: dict[str, Any], node_lookup: dict[str, dict]) -> str:
    """
    Format a relationship as ``source -[LABEL (props)]-> target``.

    Endpoints are rendered with their labels when the node is present in
    ``node_lookup``. An endpoint id that is not in the lookup is rendered as
    just its quoted id instead of raising ``KeyError``.
    NOTE(review): whether upstream serialization can emit relationships whose
    endpoints are absent from ``nodes`` (e.g. on truncated results) is not
    visible here; the fallback keeps the text endpoint from failing if it does.

    Args:
        rel: Serialized relationship with ``"label"``, ``"source"``,
            ``"target"`` and optional ``"properties"``.
        node_lookup: Mapping of node id to serialized node dict.

    Returns:
        The single-line text form of the relationship.

    Example::

        >>> _format_relationship(
        ...     {"id": "r1", "label": "STS_ASSUMEROLE_ALLOW", "source": "n1", "target": "n2",
        ...      "properties": {"weight": 1}},
        ...     {"n1": {"id": "n1", "labels": ["AWSRole"]},
        ...      "n2": {"id": "n2", "labels": ["AWSRole"]}},
        ... )
        'AWSRole "n1" -[STS_ASSUMEROLE_ALLOW (weight: 1)]-> AWSRole "n2"'
    """

    def _endpoint(node_id: Any) -> str:
        node = node_lookup.get(node_id)
        if node is None:
            # Unknown endpoint: no labels to show, keep the id so the edge stays readable.
            return f'"{node_id}"'
        return _format_node_reference(node)

    source = _endpoint(rel["source"])
    target = _endpoint(rel["target"])
    props = _format_properties(rel.get("properties", {}))
    label = f"{rel['label']} {props}" if props else rel["label"]
    return f"{source} -[{label}]-> {target}"
def _format_properties(properties: dict[str, Any]) -> str:
"""
Format properties as a parenthesized key-value list.
Returns an empty string when no properties are present.
Example::
>>> _format_properties({"name": "prod", "account_id": "123456789012"})
'(name: "prod", account_id: "123456789012")'
>>> _format_properties({})
''
"""
if not properties:
return ""
parts = [f"{k}: {_format_value(v)}" for k, v in properties.items()]
return f"({', '.join(parts)})"
def _format_value(value: Any) -> str:
"""
Format a value using Cypher-style syntax (unquoted dict keys, lowercase bools).
Example::
>>> _format_value("prod")
'"prod"'
>>> _format_value(True)
'true'
>>> _format_value([80, 443])
'[80, 443]'
>>> _format_value({"env": "prod"})
'{env: "prod"}'
>>> _format_value(None)
'null'
"""
if isinstance(value, str):
return f'"{value}"'
if isinstance(value, bool):
return str(value).lower()
if isinstance(value, (list, tuple)):
inner = ", ".join(_format_value(v) for v in value)
return f"[{inner}]"
if isinstance(value, dict):
inner = ", ".join(f"{k}: {_format_value(v)}" for k, v in value.items())
return f"{{{inner}}}"
if value is None:
return "null"
return str(value)

View File

@@ -1,15 +1,29 @@
from contextlib import nullcontext
from rest_framework.renderers import BaseRenderer
from rest_framework_json_api.renderers import JSONRenderer
from api.db_utils import rls_transaction
class PlainTextRenderer(BaseRenderer):
    """Renderer for ``text/plain`` responses: encodes the payload as bytes."""

    media_type = "text/plain"
    format = "text"

    def render(self, data, accepted_media_type=None, renderer_context=None):
        """Return ``data`` as encoded bytes.

        ``None`` becomes an empty body, strings are encoded as-is, and any
        other object is passed through ``str()`` first. The renderer's
        ``charset`` is used, falling back to UTF-8 when it is unset.
        """
        if data is None:
            return b""
        text = data if isinstance(data, str) else str(data)
        return text.encode(self.charset or "utf-8")
class APIJSONRenderer(JSONRenderer):
"""JSONRenderer override to apply tenant RLS when there are included resources in the request."""
def render(self, data, accepted_media_type=None, renderer_context=None):
request = renderer_context.get("request")
request = renderer_context.get("request") if renderer_context else None
tenant_id = getattr(request, "tenant_id", None) if request else None
db_alias = getattr(request, "db_alias", None) if request else None
include_param_present = "include" in request.query_params if request else False

File diff suppressed because it is too large Load Diff

View File

@@ -242,22 +242,181 @@ def test_serialize_graph_filters_by_provider_id(attack_paths_graph_stub_classes)
assert result["relationships"][0]["id"] == "r1"
# -- serialize_graph_as_text -------------------------------------------------------
def test_serialize_graph_as_text_renders_nodes_and_relationships():
    """Nodes, relationships and the summary all appear in the text output."""
    account_node = {
        "id": "n1",
        "labels": ["AWSAccount"],
        "properties": {"account_id": "123456789012", "name": "prod"},
    }
    instance_node = {
        "id": "n2",
        "labels": ["EC2Instance", "NetworkExposed"],
        "properties": {"name": "web-server-1", "exposed_internet": True},
    }
    resource_edge = {
        "id": "r1",
        "label": "RESOURCE",
        "source": "n1",
        "target": "n2",
        "properties": {},
    }

    text = views_helpers.serialize_graph_as_text(
        {
            "nodes": [account_node, instance_node],
            "relationships": [resource_edge],
            "total_nodes": 2,
            "truncated": False,
        }
    )

    assert text.startswith("## Nodes (2)")
    expected_fragments = (
        '- AWSAccount "n1" (account_id: "123456789012", name: "prod")',
        '- EC2Instance, NetworkExposed "n2" (name: "web-server-1", exposed_internet: true)',
        "## Relationships (1)",
        '- AWSAccount "n1" -[RESOURCE]-> EC2Instance, NetworkExposed "n2"',
        "## Summary",
        "- Total nodes: 2",
        "- Truncated: false",
    )
    for fragment in expected_fragments:
        assert fragment in text
def test_serialize_graph_as_text_empty_graph():
    """An empty graph still renders the section headers with zero counts."""
    empty_graph = {
        "nodes": [],
        "relationships": [],
        "total_nodes": 0,
        "truncated": False,
    }

    text = views_helpers.serialize_graph_as_text(empty_graph)

    for fragment in (
        "## Nodes (0)",
        "## Relationships (0)",
        "- Total nodes: 0",
        "- Truncated: false",
    ):
        assert fragment in text
def test_serialize_graph_as_text_truncated_flag():
    """The summary reports the graph's own total and a lowercase truncation flag."""
    single_node = {"id": "n1", "labels": ["Node"], "properties": {}}
    truncated_graph = {
        "nodes": [single_node],
        "relationships": [],
        "total_nodes": 500,
        "truncated": True,
    }

    text = views_helpers.serialize_graph_as_text(truncated_graph)

    assert "- Total nodes: 500" in text
    assert "- Truncated: true" in text
def test_serialize_graph_as_text_relationship_with_properties():
    """Relationship properties are rendered inside the edge brackets."""
    role_a = {"id": "n1", "labels": ["AWSRole"], "properties": {"name": "role-a"}}
    role_b = {"id": "n2", "labels": ["AWSRole"], "properties": {"name": "role-b"}}
    assume_role = {
        "id": "r1",
        "label": "STS_ASSUMEROLE_ALLOW",
        "source": "n1",
        "target": "n2",
        "properties": {"weight": 1, "reason": "trust-policy"},
    }

    text = views_helpers.serialize_graph_as_text(
        {
            "nodes": [role_a, role_b],
            "relationships": [assume_role],
            "total_nodes": 2,
            "truncated": False,
        }
    )

    assert '-[STS_ASSUMEROLE_ALLOW (weight: 1, reason: "trust-policy")]->' in text
def test_serialize_properties_filters_internal_fields():
    """Only user-facing properties survive serialization."""
    cartography_metadata = {
        "lastupdated": 1234567890,
        "firstseen": 1234567800,
        "_module_name": "cartography:aws",
        "_module_version": "0.98.0",
    }
    provider_isolation = {
        "_provider_id": "42",
        "_provider_element_id": "42:abc123",
        "provider_id": "42",
        "provider_element_id": "42:abc123",
    }
    raw_properties = {"name": "prod", **cartography_metadata, **provider_isolation}

    serialized = views_helpers._serialize_properties(raw_properties)

    assert serialized == {"name": "prod"}
def test_serialize_graph_as_text_node_without_properties():
    """A node with no properties is rendered without a parenthesized suffix."""
    bare_node = {"id": "n1", "labels": ["AWSAccount"], "properties": {}}

    text = views_helpers.serialize_graph_as_text(
        {
            "nodes": [bare_node],
            "relationships": [],
            "total_nodes": 1,
            "truncated": False,
        }
    )

    assert '- AWSAccount "n1"' in text
    # The reference must not be followed by an opening parenthesis.
    assert '- AWSAccount "n1" (' not in text
def test_serialize_graph_as_text_complex_property_values():
    """Lists, nested dicts and None are rendered in Cypher-style syntax."""
    security_group = {
        "id": "n1",
        "labels": ["SecurityGroup"],
        "properties": {
            "ports": [80, 443],
            "tags": {"env": "prod"},
            "enabled": None,
        },
    }

    text = views_helpers.serialize_graph_as_text(
        {
            "nodes": [security_group],
            "relationships": [],
            "total_nodes": 1,
            "truncated": False,
        }
    )

    for fragment in ("ports: [80, 443]", 'tags: {env: "prod"}', "enabled: null"):
        assert fragment in text
# -- normalize_custom_query_payload ------------------------------------------------
def test_normalize_custom_query_payload_extracts_cypher():
def test_normalize_custom_query_payload_extracts_query():
payload = {
"data": {
"type": "attack-paths-custom-query-run-requests",
"attributes": {
"cypher": "MATCH (n) RETURN n",
"query": "MATCH (n) RETURN n",
},
}
}
result = views_helpers.normalize_custom_query_payload(payload)
assert result == {"cypher": "MATCH (n) RETURN n"}
assert result == {"query": "MATCH (n) RETURN n"}
def test_normalize_custom_query_payload_passthrough_for_non_dict():
@@ -266,11 +425,11 @@ def test_normalize_custom_query_payload_passthrough_for_non_dict():
def test_normalize_custom_query_payload_passthrough_for_flat_dict():
payload = {"cypher": "MATCH (n) RETURN n"}
payload = {"query": "MATCH (n) RETURN n"}
result = views_helpers.normalize_custom_query_payload(payload)
assert result == {"cypher": "MATCH (n) RETURN n"}
assert result == {"query": "MATCH (n) RETURN n"}
# -- execute_custom_query ----------------------------------------------

View File

@@ -4049,6 +4049,74 @@ class TestAttackPathsScanViewSet:
assert attributes["nodes"] == graph_payload["nodes"]
assert attributes["relationships"] == graph_payload["relationships"]
def test_run_attack_paths_query_returns_text_when_accept_text_plain(
    self,
    authenticated_client,
    providers_fixture,
    scans_fixture,
    create_attack_paths_scan,
):
    """A predefined query run with ``Accept: text/plain`` returns the text rendering."""
    provider = providers_fixture[0]
    attack_paths_scan = create_attack_paths_scan(
        provider,
        scan=scans_fixture[0],
        graph_data_ready=True,
    )
    query_definition = AttackPathsQueryDefinition(
        id="aws-rds",
        name="RDS inventory",
        short_description="List account RDS assets.",
        description="List account RDS assets",
        provider=provider.provider,
        cypher="MATCH (n) RETURN n",
        parameters=[],
    )
    root_account = {
        "id": "node-1",
        "labels": ["AWSAccount"],
        "properties": {"name": "root"},
    }
    graph_payload = {
        "nodes": [root_account],
        "relationships": [],
        "total_nodes": 1,
        "truncated": False,
    }

    # Stub out query lookup and every graph-database interaction.
    with (
        patch("api.v1.views.get_query_by_id", return_value=query_definition),
        patch(
            "api.v1.views.graph_database.get_database_name",
            return_value="db-test",
        ),
        patch(
            "api.v1.views.attack_paths_views_helpers.prepare_parameters",
            return_value={"provider_uid": provider.uid},
        ),
        patch(
            "api.v1.views.attack_paths_views_helpers.execute_query",
            return_value=graph_payload,
        ),
        patch("api.v1.views.graph_database.clear_cache"),
    ):
        response = authenticated_client.post(
            reverse(
                "attack-paths-scans-queries-run",
                kwargs={"pk": attack_paths_scan.id},
            ),
            data=self._run_payload("aws-rds"),
            content_type=API_JSON_CONTENT_TYPE,
            HTTP_ACCEPT="text/plain",
        )

    assert response.status_code == status.HTTP_200_OK
    assert response["Content-Type"] == "text/plain"

    body = response.content.decode()
    for section in ("## Nodes (1)", "## Relationships (0)", "## Summary"):
        assert section in body
def test_run_attack_paths_query_blocks_when_graph_data_not_ready(
self,
authenticated_client,
@@ -4272,11 +4340,11 @@ class TestAttackPathsScanViewSet:
# -- run_custom_attack_paths_query action ------------------------------------
@staticmethod
def _custom_query_payload(cypher="MATCH (n) RETURN n"):
def _custom_query_payload(query="MATCH (n) RETURN n"):
return {
"data": {
"type": "attack-paths-custom-query-run-requests",
"attributes": {"cypher": cypher},
"attributes": {"query": query},
}
}
@@ -4339,6 +4407,61 @@ class TestAttackPathsScanViewSet:
assert attributes["total_nodes"] == 1
assert attributes["truncated"] is False
@pytest.mark.skip(reason="Endpoint temporarily blocked")
def test_run_custom_query_returns_text_when_accept_text_plain(
    self,
    authenticated_client,
    providers_fixture,
    scans_fixture,
    create_attack_paths_scan,
):
    """A custom query run with ``Accept: text/plain`` returns the text rendering."""
    provider = providers_fixture[0]
    attack_paths_scan = create_attack_paths_scan(
        provider,
        scan=scans_fixture[0],
        graph_data_ready=True,
    )
    root_account = {
        "id": "node-1",
        "labels": ["AWSAccount"],
        "properties": {"name": "root"},
    }
    graph_payload = {
        "nodes": [root_account],
        "relationships": [],
        "total_nodes": 1,
        "truncated": False,
    }

    # Stub out custom query execution and the graph-database helpers.
    with (
        patch(
            "api.v1.views.attack_paths_views_helpers.execute_custom_query",
            return_value=graph_payload,
        ),
        patch(
            "api.v1.views.graph_database.get_database_name",
            return_value="db-test",
        ),
        patch("api.v1.views.graph_database.clear_cache"),
    ):
        response = authenticated_client.post(
            reverse(
                "attack-paths-scans-queries-custom",
                kwargs={"pk": attack_paths_scan.id},
            ),
            data=self._custom_query_payload(),
            content_type=API_JSON_CONTENT_TYPE,
            HTTP_ACCEPT="text/plain",
        )

    assert response.status_code == status.HTTP_200_OK
    assert response["Content-Type"] == "text/plain"

    body = response.content.decode()
    for section in ("## Nodes (1)", "## Relationships (0)", "## Summary"):
        assert section in body
@pytest.mark.skip(reason="Endpoint temporarily blocked")
def test_run_custom_query_returns_404_when_no_nodes(
self,

View File

@@ -1220,7 +1220,7 @@ class AttackPathsQueryRunRequestSerializer(BaseSerializerV1):
class AttackPathsCustomQueryRunRequestSerializer(BaseSerializerV1):
cypher = serializers.CharField()
query = serializers.CharField()
class JSONAPIMeta:
resource_name = "attack-paths-custom-query-run-requests"

View File

@@ -3,6 +3,7 @@ import glob
import json
import logging
import os
from collections import defaultdict
from copy import deepcopy
from datetime import datetime, timedelta, timezone
@@ -10,6 +11,7 @@ from decimal import ROUND_HALF_UP, Decimal, InvalidOperation
from urllib.parse import urljoin
import sentry_sdk
from allauth.socialaccount.models import SocialAccount, SocialApp
from allauth.socialaccount.providers.github.views import GitHubOAuth2Adapter
from allauth.socialaccount.providers.google.views import GoogleOAuth2Adapter
@@ -97,6 +99,7 @@ from api.attack_paths import database as graph_database
from api.attack_paths import get_queries_for_provider, get_query_by_id
from api.attack_paths import views_helpers as attack_paths_views_helpers
from api.base_views import BaseRLSViewSet, BaseTenantViewset, BaseUserViewset
from api.renderers import APIJSONRenderer, PlainTextRenderer
from api.compliance import (
PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE,
get_compliance_frameworks,
@@ -2402,7 +2405,7 @@ class TaskViewSet(BaseRLSViewSet):
),
run_custom_attack_paths_query=extend_schema(
tags=["Attack Paths"],
summary="Execute a custom Cypher query",
summary="Execute a custom openCypher query",
description="Execute a raw openCypher query against the Attack Paths graph. "
"Results are filtered to the scan's provider and truncated to a maximum node count.",
request=AttackPathsCustomQueryRunRequestSerializer,
@@ -2525,11 +2528,13 @@ class AttackPathsScanViewSet(BaseRLSViewSet):
serializer = AttackPathsQuerySerializer(queries, many=True)
return Response(serializer.data, status=status.HTTP_200_OK)
@extend_schema(parameters=[OpenApiParameter("format", exclude=True)])
@action(
detail=True,
methods=["post"],
url_path="queries/run",
url_name="queries-run",
renderer_classes=[APIJSONRenderer, PlainTextRenderer],
)
def run_attack_paths_query(self, request, pk=None):
attack_paths_scan = self.get_object()
@@ -2577,14 +2582,20 @@ class AttackPathsScanViewSet(BaseRLSViewSet):
if not graph.get("nodes"):
status_code = status.HTTP_404_NOT_FOUND
if isinstance(request.accepted_renderer, PlainTextRenderer):
text = attack_paths_views_helpers.serialize_graph_as_text(graph)
return Response(text, status=status_code, content_type="text/plain")
response_serializer = AttackPathsQueryResultSerializer(graph)
return Response(response_serializer.data, status=status_code)
@extend_schema(parameters=[OpenApiParameter("format", exclude=True)])
@action(
detail=True,
methods=["post"],
url_path="queries/custom",
url_name="queries-custom",
renderer_classes=[APIJSONRenderer, PlainTextRenderer],
)
def run_custom_attack_paths_query(self, request, pk=None):
attack_paths_scan = self.get_object()
@@ -2609,7 +2620,7 @@ class AttackPathsScanViewSet(BaseRLSViewSet):
graph = attack_paths_views_helpers.execute_custom_query(
database_name,
serializer.validated_data["cypher"],
serializer.validated_data["query"],
provider_id,
)
graph_database.clear_cache(database_name)
@@ -2618,6 +2629,10 @@ class AttackPathsScanViewSet(BaseRLSViewSet):
if not graph.get("nodes"):
status_code = status.HTTP_404_NOT_FOUND
if isinstance(request.accepted_renderer, PlainTextRenderer):
text = attack_paths_views_helpers.serialize_graph_as_text(graph)
return Response(text, status=status_code, content_type="text/plain")
response_serializer = AttackPathsQueryResultSerializer(graph)
return Response(response_serializer.data, status=status_code)

View File

@@ -62,6 +62,27 @@ INTERNAL_LABELS: list[str] = [
*[config.deprecated_resource_label for config in PROVIDER_CONFIGS.values()],
]
# Provider isolation properties
# Property names that tie a graph element to a provider. The sync job strips
# these from property maps before copying them. The un-prefixed names appear
# to be the deprecated spellings of the underscore-prefixed ones (the sync
# code previously marked them "Deprecated").
PROVIDER_ISOLATION_PROPERTIES: list[str] = [
    "_provider_id",
    "_provider_element_id",
    "provider_id",
    "provider_element_id",
]
# Cartography bookkeeping metadata
# Timestamps and module identifiers; treated as internal rather than
# user-facing data.
CARTOGRAPHY_METADATA_PROPERTIES: list[str] = [
    "lastupdated",
    "firstseen",
    "_module_name",
    "_module_version",
]
# Union of both groups above: every property name that the API's property
# serialization filters out of query results.
INTERNAL_PROPERTIES: list[str] = [
    *PROVIDER_ISOLATION_PROPERTIES,
    *CARTOGRAPHY_METADATA_PROPERTIES,
]
# Provider Config Accessors
# -------------------------

View File

@@ -14,6 +14,7 @@ from api.attack_paths import database as graph_database
from tasks.jobs.attack_paths.config import (
BATCH_SIZE,
DEPRECATED_PROVIDER_RESOURCE_LABEL,
PROVIDER_ISOLATION_PROPERTIES,
PROVIDER_RESOURCE_LABEL,
)
from tasks.jobs.attack_paths.indexes import IndexType, create_indexes
@@ -199,11 +200,6 @@ def sync_relationships(
def _strip_internal_properties(props: dict[str, Any]) -> None:
"""Remove internal properties that shouldn't be copied during sync."""
for key in [
"_provider_element_id",
"_provider_id",
"provider_element_id", # Deprecated
"provider_id", # Deprecated
]:
"""Remove provider isolation properties before the += spread in sync templates."""
for key in PROVIDER_ISOLATION_PROPERTIES:
props.pop(key, None)