mirror of
https://github.com/prowler-cloud/prowler.git
synced 2026-03-22 03:08:23 +00:00
feat(api): add accept header text/plain to attack paths query endpoints for support llm-friendly output (#10162)
Co-authored-by: Adrián Jesús Peña Rodríguez <adrianjpr@gmail.com>
This commit is contained in:
@@ -30,6 +30,7 @@ All notable changes to the **Prowler API** are documented in this file.
|
||||
- Attack Paths: Query results now filtered by provider, preventing future cross-tenant and cross-provider data leakage [(#10118)](https://github.com/prowler-cloud/prowler/pull/10118)
|
||||
- Attack Paths: Add private labels and properties in Attack Paths graphs for avoiding future overlapping with Cartography's ones [(#10124)](https://github.com/prowler-cloud/prowler/pull/10124)
|
||||
- Attack Paths: Query endpoint executes them in read only mode [(#10140)](https://github.com/prowler-cloud/prowler/pull/10140)
|
||||
- Attack Paths: `Accept` header query endpoints also accepts `text/plain`, supporting compact plain-text format for LLM consumption [(#10162)](https://github.com/prowler-cloud/prowler/pull/10162)
|
||||
|
||||
### 🐞 Fixed
|
||||
|
||||
|
||||
2
api/poetry.lock
generated
2
api/poetry.lock
generated
@@ -6745,7 +6745,7 @@ tzlocal = "5.3.1"
|
||||
type = "git"
|
||||
url = "https://github.com/prowler-cloud/prowler.git"
|
||||
reference = "master"
|
||||
resolved_reference = "ceb4691c3657e7db3d178896bfc241d14f194295"
|
||||
resolved_reference = "6962622fd21401886371add25463f77228cd9c1f"
|
||||
|
||||
[[package]]
|
||||
name = "psutil"
|
||||
|
||||
@@ -12,7 +12,7 @@ from api.attack_paths.queries.schema import (
|
||||
RAW_SCHEMA_URL,
|
||||
)
|
||||
from config.custom_logging import BackendLogger
|
||||
from tasks.jobs.attack_paths.config import INTERNAL_LABELS
|
||||
from tasks.jobs.attack_paths.config import INTERNAL_LABELS, INTERNAL_PROPERTIES
|
||||
|
||||
logger = logging.getLogger(BackendLogger.API)
|
||||
|
||||
@@ -125,7 +125,7 @@ def normalize_custom_query_payload(raw_data):
|
||||
if "data" in raw_data and isinstance(raw_data.get("data"), dict):
|
||||
data_section = raw_data.get("data") or {}
|
||||
attributes = data_section.get("attributes") or {}
|
||||
return {"cypher": attributes.get("cypher")}
|
||||
return {"query": attributes.get("query")}
|
||||
|
||||
return raw_data
|
||||
|
||||
@@ -261,7 +261,11 @@ def _filter_labels(labels: Iterable[str]) -> list[str]:
|
||||
|
||||
|
||||
def _serialize_properties(properties: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Convert Neo4j property values into JSON-serializable primitives."""
|
||||
"""Convert Neo4j property values into JSON-serializable primitives.
|
||||
|
||||
Filters out internal properties (Cartography metadata and provider
|
||||
isolation fields) defined in INTERNAL_PROPERTIES.
|
||||
"""
|
||||
|
||||
def _serialize_value(value: Any) -> Any:
|
||||
# Neo4j temporal and spatial values expose `to_native` returning Python primitives
|
||||
@@ -276,4 +280,176 @@ def _serialize_properties(properties: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
return value
|
||||
|
||||
return {key: _serialize_value(val) for key, val in properties.items()}
|
||||
return {
|
||||
key: _serialize_value(val)
|
||||
for key, val in properties.items()
|
||||
if key not in INTERNAL_PROPERTIES
|
||||
}
|
||||
|
||||
|
||||
# Text serialization
|
||||
|
||||
|
||||
def serialize_graph_as_text(graph: dict[str, Any]) -> str:
|
||||
"""
|
||||
Convert a serialized graph dict into a compact text format for LLM consumption.
|
||||
|
||||
Follows the incident-encoding pattern (nodes with context + sequential edges)
|
||||
which research shows is optimal for LLM path-reasoning tasks.
|
||||
|
||||
Example::
|
||||
|
||||
>>> serialize_graph_as_text({
|
||||
... "nodes": [
|
||||
... {"id": "n1", "labels": ["AWSAccount"], "properties": {"name": "prod"}},
|
||||
... {"id": "n2", "labels": ["EC2Instance"], "properties": {}},
|
||||
... ],
|
||||
... "relationships": [
|
||||
... {"id": "r1", "label": "RESOURCE", "source": "n1", "target": "n2", "properties": {}},
|
||||
... ],
|
||||
... "total_nodes": 2, "truncated": False,
|
||||
... })
|
||||
## Nodes (2)
|
||||
- AWSAccount "n1" (name: "prod")
|
||||
- EC2Instance "n2"
|
||||
|
||||
## Relationships (1)
|
||||
- AWSAccount "n1" -[RESOURCE]-> EC2Instance "n2"
|
||||
|
||||
## Summary
|
||||
- Total nodes: 2
|
||||
- Truncated: false
|
||||
"""
|
||||
nodes = graph.get("nodes", [])
|
||||
relationships = graph.get("relationships", [])
|
||||
|
||||
node_lookup = {node["id"]: node for node in nodes}
|
||||
|
||||
lines = [f"## Nodes ({len(nodes)})"]
|
||||
for node in nodes:
|
||||
lines.append(f"- {_format_node_signature(node)}")
|
||||
|
||||
lines.append("")
|
||||
lines.append(f"## Relationships ({len(relationships)})")
|
||||
for rel in relationships:
|
||||
lines.append(f"- {_format_relationship(rel, node_lookup)}")
|
||||
|
||||
lines.append("")
|
||||
lines.append("## Summary")
|
||||
lines.append(f"- Total nodes: {graph.get('total_nodes', len(nodes))}")
|
||||
lines.append(f"- Truncated: {str(graph.get('truncated', False)).lower()}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _format_node_signature(node: dict[str, Any]) -> str:
|
||||
"""
|
||||
Format a node as its reference followed by its properties.
|
||||
|
||||
Example::
|
||||
|
||||
>>> _format_node_signature({"id": "n1", "labels": ["AWSRole"], "properties": {"name": "admin"}})
|
||||
'AWSRole "n1" (name: "admin")'
|
||||
>>> _format_node_signature({"id": "n2", "labels": ["AWSAccount"], "properties": {}})
|
||||
'AWSAccount "n2"'
|
||||
"""
|
||||
reference = _format_node_reference(node)
|
||||
properties = _format_properties(node.get("properties", {}))
|
||||
|
||||
if properties:
|
||||
return f"{reference} {properties}"
|
||||
|
||||
return reference
|
||||
|
||||
|
||||
def _format_node_reference(node: dict[str, Any]) -> str:
|
||||
"""
|
||||
Format a node as labels + quoted id (no properties).
|
||||
|
||||
Example::
|
||||
|
||||
>>> _format_node_reference({"id": "n1", "labels": ["EC2Instance", "NetworkExposed"]})
|
||||
'EC2Instance, NetworkExposed "n1"'
|
||||
"""
|
||||
labels = ", ".join(node.get("labels", []))
|
||||
return f'{labels} "{node["id"]}"'
|
||||
|
||||
|
||||
def _format_relationship(rel: dict[str, Any], node_lookup: dict[str, dict]) -> str:
|
||||
"""
|
||||
Format a relationship as source -[LABEL (props)]-> target.
|
||||
|
||||
Example::
|
||||
|
||||
>>> _format_relationship(
|
||||
... {"id": "r1", "label": "STS_ASSUMEROLE_ALLOW", "source": "n1", "target": "n2",
|
||||
... "properties": {"weight": 1}},
|
||||
... {"n1": {"id": "n1", "labels": ["AWSRole"]},
|
||||
... "n2": {"id": "n2", "labels": ["AWSRole"]}},
|
||||
... )
|
||||
'AWSRole "n1" -[STS_ASSUMEROLE_ALLOW (weight: 1)]-> AWSRole "n2"'
|
||||
"""
|
||||
source = _format_node_reference(node_lookup[rel["source"]])
|
||||
target = _format_node_reference(node_lookup[rel["target"]])
|
||||
|
||||
props = _format_properties(rel.get("properties", {}))
|
||||
label = f"{rel['label']} {props}" if props else rel["label"]
|
||||
|
||||
return f"{source} -[{label}]-> {target}"
|
||||
|
||||
|
||||
def _format_properties(properties: dict[str, Any]) -> str:
|
||||
"""
|
||||
Format properties as a parenthesized key-value list.
|
||||
|
||||
Returns an empty string when no properties are present.
|
||||
|
||||
Example::
|
||||
|
||||
>>> _format_properties({"name": "prod", "account_id": "123456789012"})
|
||||
'(name: "prod", account_id: "123456789012")'
|
||||
>>> _format_properties({})
|
||||
''
|
||||
"""
|
||||
if not properties:
|
||||
return ""
|
||||
|
||||
parts = [f"{k}: {_format_value(v)}" for k, v in properties.items()]
|
||||
return f"({', '.join(parts)})"
|
||||
|
||||
|
||||
def _format_value(value: Any) -> str:
|
||||
"""
|
||||
Format a value using Cypher-style syntax (unquoted dict keys, lowercase bools).
|
||||
|
||||
Example::
|
||||
|
||||
>>> _format_value("prod")
|
||||
'"prod"'
|
||||
>>> _format_value(True)
|
||||
'true'
|
||||
>>> _format_value([80, 443])
|
||||
'[80, 443]'
|
||||
>>> _format_value({"env": "prod"})
|
||||
'{env: "prod"}'
|
||||
>>> _format_value(None)
|
||||
'null'
|
||||
"""
|
||||
if isinstance(value, str):
|
||||
return f'"{value}"'
|
||||
|
||||
if isinstance(value, bool):
|
||||
return str(value).lower()
|
||||
|
||||
if isinstance(value, (list, tuple)):
|
||||
inner = ", ".join(_format_value(v) for v in value)
|
||||
return f"[{inner}]"
|
||||
|
||||
if isinstance(value, dict):
|
||||
inner = ", ".join(f"{k}: {_format_value(v)}" for k, v in value.items())
|
||||
return f"{{{inner}}}"
|
||||
|
||||
if value is None:
|
||||
return "null"
|
||||
|
||||
return str(value)
|
||||
|
||||
@@ -1,15 +1,29 @@
|
||||
from contextlib import nullcontext
|
||||
|
||||
from rest_framework.renderers import BaseRenderer
|
||||
from rest_framework_json_api.renderers import JSONRenderer
|
||||
|
||||
from api.db_utils import rls_transaction
|
||||
|
||||
|
||||
class PlainTextRenderer(BaseRenderer):
|
||||
media_type = "text/plain"
|
||||
format = "text"
|
||||
|
||||
def render(self, data, accepted_media_type=None, renderer_context=None):
|
||||
encoding = self.charset or "utf-8"
|
||||
if isinstance(data, str):
|
||||
return data.encode(encoding)
|
||||
if data is None:
|
||||
return b""
|
||||
return str(data).encode(encoding)
|
||||
|
||||
|
||||
class APIJSONRenderer(JSONRenderer):
|
||||
"""JSONRenderer override to apply tenant RLS when there are included resources in the request."""
|
||||
|
||||
def render(self, data, accepted_media_type=None, renderer_context=None):
|
||||
request = renderer_context.get("request")
|
||||
request = renderer_context.get("request") if renderer_context else None
|
||||
tenant_id = getattr(request, "tenant_id", None) if request else None
|
||||
db_alias = getattr(request, "db_alias", None) if request else None
|
||||
include_param_present = "include" in request.query_params if request else False
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -242,22 +242,181 @@ def test_serialize_graph_filters_by_provider_id(attack_paths_graph_stub_classes)
|
||||
assert result["relationships"][0]["id"] == "r1"
|
||||
|
||||
|
||||
# -- serialize_graph_as_text -------------------------------------------------------
|
||||
|
||||
|
||||
def test_serialize_graph_as_text_renders_nodes_and_relationships():
|
||||
graph = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "n1",
|
||||
"labels": ["AWSAccount"],
|
||||
"properties": {"account_id": "123456789012", "name": "prod"},
|
||||
},
|
||||
{
|
||||
"id": "n2",
|
||||
"labels": ["EC2Instance", "NetworkExposed"],
|
||||
"properties": {"name": "web-server-1", "exposed_internet": True},
|
||||
},
|
||||
],
|
||||
"relationships": [
|
||||
{
|
||||
"id": "r1",
|
||||
"label": "RESOURCE",
|
||||
"source": "n1",
|
||||
"target": "n2",
|
||||
"properties": {},
|
||||
},
|
||||
],
|
||||
"total_nodes": 2,
|
||||
"truncated": False,
|
||||
}
|
||||
|
||||
result = views_helpers.serialize_graph_as_text(graph)
|
||||
|
||||
assert result.startswith("## Nodes (2)")
|
||||
assert '- AWSAccount "n1" (account_id: "123456789012", name: "prod")' in result
|
||||
assert (
|
||||
'- EC2Instance, NetworkExposed "n2" (name: "web-server-1", exposed_internet: true)'
|
||||
in result
|
||||
)
|
||||
assert "## Relationships (1)" in result
|
||||
assert '- AWSAccount "n1" -[RESOURCE]-> EC2Instance, NetworkExposed "n2"' in result
|
||||
assert "## Summary" in result
|
||||
assert "- Total nodes: 2" in result
|
||||
assert "- Truncated: false" in result
|
||||
|
||||
|
||||
def test_serialize_graph_as_text_empty_graph():
|
||||
graph = {
|
||||
"nodes": [],
|
||||
"relationships": [],
|
||||
"total_nodes": 0,
|
||||
"truncated": False,
|
||||
}
|
||||
|
||||
result = views_helpers.serialize_graph_as_text(graph)
|
||||
|
||||
assert "## Nodes (0)" in result
|
||||
assert "## Relationships (0)" in result
|
||||
assert "- Total nodes: 0" in result
|
||||
assert "- Truncated: false" in result
|
||||
|
||||
|
||||
def test_serialize_graph_as_text_truncated_flag():
|
||||
graph = {
|
||||
"nodes": [{"id": "n1", "labels": ["Node"], "properties": {}}],
|
||||
"relationships": [],
|
||||
"total_nodes": 500,
|
||||
"truncated": True,
|
||||
}
|
||||
|
||||
result = views_helpers.serialize_graph_as_text(graph)
|
||||
|
||||
assert "- Total nodes: 500" in result
|
||||
assert "- Truncated: true" in result
|
||||
|
||||
|
||||
def test_serialize_graph_as_text_relationship_with_properties():
|
||||
graph = {
|
||||
"nodes": [
|
||||
{"id": "n1", "labels": ["AWSRole"], "properties": {"name": "role-a"}},
|
||||
{"id": "n2", "labels": ["AWSRole"], "properties": {"name": "role-b"}},
|
||||
],
|
||||
"relationships": [
|
||||
{
|
||||
"id": "r1",
|
||||
"label": "STS_ASSUMEROLE_ALLOW",
|
||||
"source": "n1",
|
||||
"target": "n2",
|
||||
"properties": {"weight": 1, "reason": "trust-policy"},
|
||||
},
|
||||
],
|
||||
"total_nodes": 2,
|
||||
"truncated": False,
|
||||
}
|
||||
|
||||
result = views_helpers.serialize_graph_as_text(graph)
|
||||
|
||||
assert '-[STS_ASSUMEROLE_ALLOW (weight: 1, reason: "trust-policy")]->' in result
|
||||
|
||||
|
||||
def test_serialize_properties_filters_internal_fields():
|
||||
properties = {
|
||||
"name": "prod",
|
||||
# Cartography metadata
|
||||
"lastupdated": 1234567890,
|
||||
"firstseen": 1234567800,
|
||||
"_module_name": "cartography:aws",
|
||||
"_module_version": "0.98.0",
|
||||
# Provider isolation
|
||||
"_provider_id": "42",
|
||||
"_provider_element_id": "42:abc123",
|
||||
"provider_id": "42",
|
||||
"provider_element_id": "42:abc123",
|
||||
}
|
||||
|
||||
result = views_helpers._serialize_properties(properties)
|
||||
|
||||
assert result == {"name": "prod"}
|
||||
|
||||
|
||||
def test_serialize_graph_as_text_node_without_properties():
|
||||
graph = {
|
||||
"nodes": [{"id": "n1", "labels": ["AWSAccount"], "properties": {}}],
|
||||
"relationships": [],
|
||||
"total_nodes": 1,
|
||||
"truncated": False,
|
||||
}
|
||||
|
||||
result = views_helpers.serialize_graph_as_text(graph)
|
||||
|
||||
assert '- AWSAccount "n1"' in result
|
||||
# No trailing parentheses when no properties
|
||||
assert '- AWSAccount "n1" (' not in result
|
||||
|
||||
|
||||
def test_serialize_graph_as_text_complex_property_values():
|
||||
graph = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "n1",
|
||||
"labels": ["SecurityGroup"],
|
||||
"properties": {
|
||||
"ports": [80, 443],
|
||||
"tags": {"env": "prod"},
|
||||
"enabled": None,
|
||||
},
|
||||
},
|
||||
],
|
||||
"relationships": [],
|
||||
"total_nodes": 1,
|
||||
"truncated": False,
|
||||
}
|
||||
|
||||
result = views_helpers.serialize_graph_as_text(graph)
|
||||
|
||||
assert "ports: [80, 443]" in result
|
||||
assert 'tags: {env: "prod"}' in result
|
||||
assert "enabled: null" in result
|
||||
|
||||
|
||||
# -- normalize_custom_query_payload ------------------------------------------------
|
||||
|
||||
|
||||
def test_normalize_custom_query_payload_extracts_cypher():
|
||||
def test_normalize_custom_query_payload_extracts_query():
|
||||
payload = {
|
||||
"data": {
|
||||
"type": "attack-paths-custom-query-run-requests",
|
||||
"attributes": {
|
||||
"cypher": "MATCH (n) RETURN n",
|
||||
"query": "MATCH (n) RETURN n",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
result = views_helpers.normalize_custom_query_payload(payload)
|
||||
|
||||
assert result == {"cypher": "MATCH (n) RETURN n"}
|
||||
assert result == {"query": "MATCH (n) RETURN n"}
|
||||
|
||||
|
||||
def test_normalize_custom_query_payload_passthrough_for_non_dict():
|
||||
@@ -266,11 +425,11 @@ def test_normalize_custom_query_payload_passthrough_for_non_dict():
|
||||
|
||||
|
||||
def test_normalize_custom_query_payload_passthrough_for_flat_dict():
|
||||
payload = {"cypher": "MATCH (n) RETURN n"}
|
||||
payload = {"query": "MATCH (n) RETURN n"}
|
||||
|
||||
result = views_helpers.normalize_custom_query_payload(payload)
|
||||
|
||||
assert result == {"cypher": "MATCH (n) RETURN n"}
|
||||
assert result == {"query": "MATCH (n) RETURN n"}
|
||||
|
||||
|
||||
# -- execute_custom_query ----------------------------------------------
|
||||
|
||||
@@ -4049,6 +4049,74 @@ class TestAttackPathsScanViewSet:
|
||||
assert attributes["nodes"] == graph_payload["nodes"]
|
||||
assert attributes["relationships"] == graph_payload["relationships"]
|
||||
|
||||
def test_run_attack_paths_query_returns_text_when_accept_text_plain(
|
||||
self,
|
||||
authenticated_client,
|
||||
providers_fixture,
|
||||
scans_fixture,
|
||||
create_attack_paths_scan,
|
||||
):
|
||||
provider = providers_fixture[0]
|
||||
attack_paths_scan = create_attack_paths_scan(
|
||||
provider,
|
||||
scan=scans_fixture[0],
|
||||
graph_data_ready=True,
|
||||
)
|
||||
query_definition = AttackPathsQueryDefinition(
|
||||
id="aws-rds",
|
||||
name="RDS inventory",
|
||||
short_description="List account RDS assets.",
|
||||
description="List account RDS assets",
|
||||
provider=provider.provider,
|
||||
cypher="MATCH (n) RETURN n",
|
||||
parameters=[],
|
||||
)
|
||||
graph_payload = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node-1",
|
||||
"labels": ["AWSAccount"],
|
||||
"properties": {"name": "root"},
|
||||
}
|
||||
],
|
||||
"relationships": [],
|
||||
"total_nodes": 1,
|
||||
"truncated": False,
|
||||
}
|
||||
|
||||
with (
|
||||
patch("api.v1.views.get_query_by_id", return_value=query_definition),
|
||||
patch(
|
||||
"api.v1.views.graph_database.get_database_name",
|
||||
return_value="db-test",
|
||||
),
|
||||
patch(
|
||||
"api.v1.views.attack_paths_views_helpers.prepare_parameters",
|
||||
return_value={"provider_uid": provider.uid},
|
||||
),
|
||||
patch(
|
||||
"api.v1.views.attack_paths_views_helpers.execute_query",
|
||||
return_value=graph_payload,
|
||||
),
|
||||
patch("api.v1.views.graph_database.clear_cache"),
|
||||
):
|
||||
response = authenticated_client.post(
|
||||
reverse(
|
||||
"attack-paths-scans-queries-run",
|
||||
kwargs={"pk": attack_paths_scan.id},
|
||||
),
|
||||
data=self._run_payload("aws-rds"),
|
||||
content_type=API_JSON_CONTENT_TYPE,
|
||||
HTTP_ACCEPT="text/plain",
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response["Content-Type"] == "text/plain"
|
||||
body = response.content.decode()
|
||||
assert "## Nodes (1)" in body
|
||||
assert "## Relationships (0)" in body
|
||||
assert "## Summary" in body
|
||||
|
||||
def test_run_attack_paths_query_blocks_when_graph_data_not_ready(
|
||||
self,
|
||||
authenticated_client,
|
||||
@@ -4272,11 +4340,11 @@ class TestAttackPathsScanViewSet:
|
||||
# -- run_custom_attack_paths_query action ------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _custom_query_payload(cypher="MATCH (n) RETURN n"):
|
||||
def _custom_query_payload(query="MATCH (n) RETURN n"):
|
||||
return {
|
||||
"data": {
|
||||
"type": "attack-paths-custom-query-run-requests",
|
||||
"attributes": {"cypher": cypher},
|
||||
"attributes": {"query": query},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4339,6 +4407,61 @@ class TestAttackPathsScanViewSet:
|
||||
assert attributes["total_nodes"] == 1
|
||||
assert attributes["truncated"] is False
|
||||
|
||||
@pytest.mark.skip(reason="Endpoint temporarily blocked")
|
||||
def test_run_custom_query_returns_text_when_accept_text_plain(
|
||||
self,
|
||||
authenticated_client,
|
||||
providers_fixture,
|
||||
scans_fixture,
|
||||
create_attack_paths_scan,
|
||||
):
|
||||
provider = providers_fixture[0]
|
||||
attack_paths_scan = create_attack_paths_scan(
|
||||
provider,
|
||||
scan=scans_fixture[0],
|
||||
graph_data_ready=True,
|
||||
)
|
||||
graph_payload = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node-1",
|
||||
"labels": ["AWSAccount"],
|
||||
"properties": {"name": "root"},
|
||||
}
|
||||
],
|
||||
"relationships": [],
|
||||
"total_nodes": 1,
|
||||
"truncated": False,
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"api.v1.views.attack_paths_views_helpers.execute_custom_query",
|
||||
return_value=graph_payload,
|
||||
),
|
||||
patch(
|
||||
"api.v1.views.graph_database.get_database_name",
|
||||
return_value="db-test",
|
||||
),
|
||||
patch("api.v1.views.graph_database.clear_cache"),
|
||||
):
|
||||
response = authenticated_client.post(
|
||||
reverse(
|
||||
"attack-paths-scans-queries-custom",
|
||||
kwargs={"pk": attack_paths_scan.id},
|
||||
),
|
||||
data=self._custom_query_payload(),
|
||||
content_type=API_JSON_CONTENT_TYPE,
|
||||
HTTP_ACCEPT="text/plain",
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response["Content-Type"] == "text/plain"
|
||||
body = response.content.decode()
|
||||
assert "## Nodes (1)" in body
|
||||
assert "## Relationships (0)" in body
|
||||
assert "## Summary" in body
|
||||
|
||||
@pytest.mark.skip(reason="Endpoint temporarily blocked")
|
||||
def test_run_custom_query_returns_404_when_no_nodes(
|
||||
self,
|
||||
|
||||
@@ -1220,7 +1220,7 @@ class AttackPathsQueryRunRequestSerializer(BaseSerializerV1):
|
||||
|
||||
|
||||
class AttackPathsCustomQueryRunRequestSerializer(BaseSerializerV1):
|
||||
cypher = serializers.CharField()
|
||||
query = serializers.CharField()
|
||||
|
||||
class JSONAPIMeta:
|
||||
resource_name = "attack-paths-custom-query-run-requests"
|
||||
|
||||
@@ -3,6 +3,7 @@ import glob
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from datetime import datetime, timedelta, timezone
|
||||
@@ -10,6 +11,7 @@ from decimal import ROUND_HALF_UP, Decimal, InvalidOperation
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import sentry_sdk
|
||||
|
||||
from allauth.socialaccount.models import SocialAccount, SocialApp
|
||||
from allauth.socialaccount.providers.github.views import GitHubOAuth2Adapter
|
||||
from allauth.socialaccount.providers.google.views import GoogleOAuth2Adapter
|
||||
@@ -97,6 +99,7 @@ from api.attack_paths import database as graph_database
|
||||
from api.attack_paths import get_queries_for_provider, get_query_by_id
|
||||
from api.attack_paths import views_helpers as attack_paths_views_helpers
|
||||
from api.base_views import BaseRLSViewSet, BaseTenantViewset, BaseUserViewset
|
||||
from api.renderers import APIJSONRenderer, PlainTextRenderer
|
||||
from api.compliance import (
|
||||
PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE,
|
||||
get_compliance_frameworks,
|
||||
@@ -2402,7 +2405,7 @@ class TaskViewSet(BaseRLSViewSet):
|
||||
),
|
||||
run_custom_attack_paths_query=extend_schema(
|
||||
tags=["Attack Paths"],
|
||||
summary="Execute a custom Cypher query",
|
||||
summary="Execute a custom openCypher query",
|
||||
description="Execute a raw openCypher query against the Attack Paths graph. "
|
||||
"Results are filtered to the scan's provider and truncated to a maximum node count.",
|
||||
request=AttackPathsCustomQueryRunRequestSerializer,
|
||||
@@ -2525,11 +2528,13 @@ class AttackPathsScanViewSet(BaseRLSViewSet):
|
||||
serializer = AttackPathsQuerySerializer(queries, many=True)
|
||||
return Response(serializer.data, status=status.HTTP_200_OK)
|
||||
|
||||
@extend_schema(parameters=[OpenApiParameter("format", exclude=True)])
|
||||
@action(
|
||||
detail=True,
|
||||
methods=["post"],
|
||||
url_path="queries/run",
|
||||
url_name="queries-run",
|
||||
renderer_classes=[APIJSONRenderer, PlainTextRenderer],
|
||||
)
|
||||
def run_attack_paths_query(self, request, pk=None):
|
||||
attack_paths_scan = self.get_object()
|
||||
@@ -2577,14 +2582,20 @@ class AttackPathsScanViewSet(BaseRLSViewSet):
|
||||
if not graph.get("nodes"):
|
||||
status_code = status.HTTP_404_NOT_FOUND
|
||||
|
||||
if isinstance(request.accepted_renderer, PlainTextRenderer):
|
||||
text = attack_paths_views_helpers.serialize_graph_as_text(graph)
|
||||
return Response(text, status=status_code, content_type="text/plain")
|
||||
|
||||
response_serializer = AttackPathsQueryResultSerializer(graph)
|
||||
return Response(response_serializer.data, status=status_code)
|
||||
|
||||
@extend_schema(parameters=[OpenApiParameter("format", exclude=True)])
|
||||
@action(
|
||||
detail=True,
|
||||
methods=["post"],
|
||||
url_path="queries/custom",
|
||||
url_name="queries-custom",
|
||||
renderer_classes=[APIJSONRenderer, PlainTextRenderer],
|
||||
)
|
||||
def run_custom_attack_paths_query(self, request, pk=None):
|
||||
attack_paths_scan = self.get_object()
|
||||
@@ -2609,7 +2620,7 @@ class AttackPathsScanViewSet(BaseRLSViewSet):
|
||||
|
||||
graph = attack_paths_views_helpers.execute_custom_query(
|
||||
database_name,
|
||||
serializer.validated_data["cypher"],
|
||||
serializer.validated_data["query"],
|
||||
provider_id,
|
||||
)
|
||||
graph_database.clear_cache(database_name)
|
||||
@@ -2618,6 +2629,10 @@ class AttackPathsScanViewSet(BaseRLSViewSet):
|
||||
if not graph.get("nodes"):
|
||||
status_code = status.HTTP_404_NOT_FOUND
|
||||
|
||||
if isinstance(request.accepted_renderer, PlainTextRenderer):
|
||||
text = attack_paths_views_helpers.serialize_graph_as_text(graph)
|
||||
return Response(text, status=status_code, content_type="text/plain")
|
||||
|
||||
response_serializer = AttackPathsQueryResultSerializer(graph)
|
||||
return Response(response_serializer.data, status=status_code)
|
||||
|
||||
|
||||
@@ -62,6 +62,27 @@ INTERNAL_LABELS: list[str] = [
|
||||
*[config.deprecated_resource_label for config in PROVIDER_CONFIGS.values()],
|
||||
]
|
||||
|
||||
# Provider isolation properties
|
||||
PROVIDER_ISOLATION_PROPERTIES: list[str] = [
|
||||
"_provider_id",
|
||||
"_provider_element_id",
|
||||
"provider_id",
|
||||
"provider_element_id",
|
||||
]
|
||||
|
||||
# Cartography bookkeeping metadata
|
||||
CARTOGRAPHY_METADATA_PROPERTIES: list[str] = [
|
||||
"lastupdated",
|
||||
"firstseen",
|
||||
"_module_name",
|
||||
"_module_version",
|
||||
]
|
||||
|
||||
INTERNAL_PROPERTIES: list[str] = [
|
||||
*PROVIDER_ISOLATION_PROPERTIES,
|
||||
*CARTOGRAPHY_METADATA_PROPERTIES,
|
||||
]
|
||||
|
||||
|
||||
# Provider Config Accessors
|
||||
# -------------------------
|
||||
|
||||
@@ -14,6 +14,7 @@ from api.attack_paths import database as graph_database
|
||||
from tasks.jobs.attack_paths.config import (
|
||||
BATCH_SIZE,
|
||||
DEPRECATED_PROVIDER_RESOURCE_LABEL,
|
||||
PROVIDER_ISOLATION_PROPERTIES,
|
||||
PROVIDER_RESOURCE_LABEL,
|
||||
)
|
||||
from tasks.jobs.attack_paths.indexes import IndexType, create_indexes
|
||||
@@ -199,11 +200,6 @@ def sync_relationships(
|
||||
|
||||
|
||||
def _strip_internal_properties(props: dict[str, Any]) -> None:
|
||||
"""Remove internal properties that shouldn't be copied during sync."""
|
||||
for key in [
|
||||
"_provider_element_id",
|
||||
"_provider_id",
|
||||
"provider_element_id", # Deprecated
|
||||
"provider_id", # Deprecated
|
||||
]:
|
||||
"""Remove provider isolation properties before the += spread in sync templates."""
|
||||
for key in PROVIDER_ISOLATION_PROPERTIES:
|
||||
props.pop(key, None)
|
||||
|
||||
Reference in New Issue
Block a user