feat(api): filter Attack Paths query results by provider_id (#10118)

This commit is contained in:
Josema Camacho
2026-02-23 11:06:30 +01:00
committed by GitHub
parent c12f27413d
commit b5d2a75151
6 changed files with 90 additions and 21 deletions

View File

@@ -24,6 +24,7 @@ All notable changes to the **Prowler API** are documented in this file.
- Attack Paths: Add `graph_data_ready` field to decouple query availability from scan state [(#10089)](https://github.com/prowler-cloud/prowler/pull/10089)
- AI agent guidelines with TDD and testing skills references [(#9925)](https://github.com/prowler-cloud/prowler/pull/9925)
- Attack Paths: Upgrade Cartography from fork 0.126.1 to upstream 0.129.0 and Neo4j driver from 5.x to 6.x [(#10110)](https://github.com/prowler-cloud/prowler/pull/10110)
- Attack Paths: Query results now filtered by provider, preventing future cross-tenant and cross-provider data leakage [(#10118)](https://github.com/prowler-cloud/prowler/pull/10118)
### 🐞 Fixed

View File

@@ -16,7 +16,7 @@ AWS_INTERNET_EXPOSED_EC2_SENSITIVE_S3_ACCESS = AttackPathsQueryDefinition(
description="Detect EC2 instances with SSH exposed to the internet that can assume higher-privileged roles to read tagged sensitive S3 buckets despite bucket-level public access blocks.",
provider="aws",
cypher=f"""
CALL apoc.create.vNode(['Internet'], {{id: 'Internet', name: 'Internet'}})
CALL apoc.create.vNode(['Internet'], {{id: 'Internet', name: 'Internet', provider_id: $provider_id}})
YIELD node AS internet
MATCH path_s3 = (aws:AWSAccount {{id: $provider_uid}})--(s3:S3Bucket)--(t:AWSTag)
@@ -32,7 +32,7 @@ AWS_INTERNET_EXPOSED_EC2_SENSITIVE_S3_ACCESS = AttackPathsQueryDefinition(
MATCH path_assume_role = (ec2)-[p:STS_ASSUMEROLE_ALLOW*1..9]-(r:AWSRole)
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {{}}, ec2)
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {{provider_id: $provider_id}}, ec2)
YIELD rel AS can_access
UNWIND nodes(path_s3) + nodes(path_ec2) + nodes(path_role) + nodes(path_assume_role) as n
@@ -181,13 +181,13 @@ AWS_EC2_INSTANCES_INTERNET_EXPOSED = AttackPathsQueryDefinition(
description="Find EC2 instances flagged as exposed to the internet within the selected account.",
provider="aws",
cypher=f"""
CALL apoc.create.vNode(['Internet'], {{id: 'Internet', name: 'Internet'}})
CALL apoc.create.vNode(['Internet'], {{id: 'Internet', name: 'Internet', provider_id: $provider_id}})
YIELD node AS internet
MATCH path = (aws:AWSAccount {{id: $provider_uid}})--(ec2:EC2Instance)
WHERE ec2.exposed_internet = true
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {{}}, ec2)
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {{provider_id: $provider_id}}, ec2)
YIELD rel AS can_access
UNWIND nodes(path) as n
@@ -205,7 +205,7 @@ AWS_SECURITY_GROUPS_OPEN_INTERNET_FACING = AttackPathsQueryDefinition(
description="Find internet-facing resources associated with security groups that allow inbound access from '0.0.0.0/0'.",
provider="aws",
cypher=f"""
CALL apoc.create.vNode(['Internet'], {{id: 'Internet', name: 'Internet'}})
CALL apoc.create.vNode(['Internet'], {{id: 'Internet', name: 'Internet', provider_id: $provider_id}})
YIELD node AS internet
// Match EC2 instances that are internet-exposed with open security groups (0.0.0.0/0)
@@ -213,7 +213,7 @@ AWS_SECURITY_GROUPS_OPEN_INTERNET_FACING = AttackPathsQueryDefinition(
WHERE ec2.exposed_internet = true
AND ir.range = "0.0.0.0/0"
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {{}}, ec2)
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {{provider_id: $provider_id}}, ec2)
YIELD rel AS can_access
UNWIND nodes(path_ec2) as n
@@ -231,13 +231,13 @@ AWS_CLASSIC_ELB_INTERNET_EXPOSED = AttackPathsQueryDefinition(
description="Find Classic Load Balancers exposed to the internet along with their listeners.",
provider="aws",
cypher=f"""
CALL apoc.create.vNode(['Internet'], {{id: 'Internet', name: 'Internet'}})
CALL apoc.create.vNode(['Internet'], {{id: 'Internet', name: 'Internet', provider_id: $provider_id}})
YIELD node AS internet
MATCH path = (aws:AWSAccount {{id: $provider_uid}})--(elb:LoadBalancer)--(listener:ELBListener)
WHERE elb.exposed_internet = true
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {{}}, elb)
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {{provider_id: $provider_id}}, elb)
YIELD rel AS can_access
UNWIND nodes(path) as n
@@ -255,13 +255,13 @@ AWS_ELBV2_INTERNET_EXPOSED = AttackPathsQueryDefinition(
description="Find ELBv2 load balancers exposed to the internet along with their listeners.",
provider="aws",
cypher=f"""
CALL apoc.create.vNode(['Internet'], {{id: 'Internet', name: 'Internet'}})
CALL apoc.create.vNode(['Internet'], {{id: 'Internet', name: 'Internet', provider_id: $provider_id}})
YIELD node AS internet
MATCH path = (aws:AWSAccount {{id: $provider_uid}})--(elbv2:LoadBalancerV2)--(listener:ELBV2Listener)
WHERE elbv2.exposed_internet = true
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {{}}, elbv2)
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {{provider_id: $provider_id}}, elbv2)
YIELD rel AS can_access
UNWIND nodes(path) as n
@@ -279,7 +279,7 @@ AWS_PUBLIC_IP_RESOURCE_LOOKUP = AttackPathsQueryDefinition(
description="Given a public IP address, find the related AWS resource and its adjacent node within the selected account.",
provider="aws",
cypher=f"""
CALL apoc.create.vNode(['Internet'], {{id: 'Internet', name: 'Internet'}})
CALL apoc.create.vNode(['Internet'], {{id: 'Internet', name: 'Internet', provider_id: $provider_id}})
YIELD node AS internet
CALL () {{
@@ -302,7 +302,7 @@ AWS_PUBLIC_IP_RESOURCE_LOOKUP = AttackPathsQueryDefinition(
WITH path, x, internet
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {{}}, x)
CALL apoc.create.vRelationship(internet, 'CAN_ACCESS', {{provider_id: $provider_id}}, x)
YIELD rel AS can_access
UNWIND nodes(path) as n

View File

@@ -35,6 +35,7 @@ def prepare_query_parameters(
definition: AttackPathsQueryDefinition,
provided_parameters: dict[str, Any],
provider_uid: str,
provider_id: str,
) -> dict[str, Any]:
parameters = dict(provided_parameters or {})
expected_names = {parameter.name for parameter in definition.parameters}
@@ -56,6 +57,7 @@ def prepare_query_parameters(
clean_parameters = {
"provider_uid": str(provider_uid),
"provider_id": str(provider_id),
}
for definition_parameter in definition.parameters:
@@ -82,11 +84,12 @@ def execute_attack_paths_query(
database_name: str,
definition: AttackPathsQueryDefinition,
parameters: dict[str, Any],
provider_id: str,
) -> dict[str, Any]:
try:
with graph_database.get_session(database_name) as session:
result = session.run(definition.cypher, parameters)
return _serialize_graph(result.graph())
return _serialize_graph(result.graph(), provider_id)
except graph_database.GraphDatabaseQueryException as exc:
logger.error(f"Query failed for Attack Paths query `{definition.id}`: {exc}")
@@ -95,9 +98,14 @@ def execute_attack_paths_query(
)
def _serialize_graph(graph):
def _serialize_graph(graph, provider_id: str):
nodes = []
kept_node_ids = set()
for node in graph.nodes:
if node._properties.get("provider_id") != provider_id:
continue
kept_node_ids.add(node.element_id)
nodes.append(
{
"id": node.element_id,
@@ -108,6 +116,15 @@ def _serialize_graph(graph):
relationships = []
for relationship in graph.relationships:
if relationship._properties.get("provider_id") != provider_id:
continue
if (
relationship.start_node.element_id not in kept_node_ids
or relationship.end_node.element_id not in kept_node_ids
):
continue
relationships.append(
{
"id": relationship.element_id,

View File

@@ -38,9 +38,11 @@ def test_prepare_query_parameters_includes_provider_and_casts(
definition,
{"limit": "5"},
provider_uid="123456789012",
provider_id="test-provider-id",
)
assert result["provider_uid"] == "123456789012"
assert result["provider_id"] == "test-provider-id"
assert result["limit"] == 5
@@ -57,7 +59,9 @@ def test_prepare_query_parameters_validates_names(
definition = attack_paths_query_definition_factory()
with pytest.raises(ValidationError) as exc:
views_helpers.prepare_query_parameters(definition, provided, provider_uid="1")
views_helpers.prepare_query_parameters(
definition, provided, provider_uid="1", provider_id="p1"
)
assert expected_message in str(exc.value)
@@ -72,6 +76,7 @@ def test_prepare_query_parameters_validates_cast(
definition,
{"limit": "not-an-int"},
provider_uid="1",
provider_id="p1",
)
assert "Invalid value" in str(exc.value)
@@ -90,11 +95,13 @@ def test_execute_attack_paths_query_serializes_graph(
)
parameters = {"provider_uid": "123"}
provider_id = "test-provider-123"
node = attack_paths_graph_stub_classes.Node(
element_id="node-1",
labels=["AWSAccount"],
properties={
"name": "account",
"provider_id": provider_id,
"complex": {
"items": [
attack_paths_graph_stub_classes.NativeValue("value"),
@@ -103,14 +110,17 @@ def test_execute_attack_paths_query_serializes_graph(
},
},
)
node_2 = attack_paths_graph_stub_classes.Node(
"node-2", ["RDSInstance"], {"provider_id": provider_id}
)
relationship = attack_paths_graph_stub_classes.Relationship(
element_id="rel-1",
rel_type="OWNS",
start_node=node,
end_node=attack_paths_graph_stub_classes.Node("node-2", ["RDSInstance"], {}),
properties={"weight": 1},
end_node=node_2,
properties={"weight": 1, "provider_id": provider_id},
)
graph = SimpleNamespace(nodes=[node], relationships=[relationship])
graph = SimpleNamespace(nodes=[node, node_2], relationships=[relationship])
run_result = MagicMock()
run_result.graph.return_value = graph
@@ -129,7 +139,7 @@ def test_execute_attack_paths_query_serializes_graph(
return_value=session_ctx,
) as mock_get_session:
result = views_helpers.execute_attack_paths_query(
database_name, definition, parameters
database_name, definition, parameters, provider_id=provider_id
)
mock_get_session.assert_called_once_with(database_name)
@@ -169,7 +179,40 @@ def test_execute_attack_paths_query_wraps_graph_errors(
):
with pytest.raises(APIException):
views_helpers.execute_attack_paths_query(
database_name, definition, parameters
database_name, definition, parameters, provider_id="test-provider-123"
)
mock_logger.error.assert_called_once()
def test_serialize_graph_filters_by_provider_id(attack_paths_graph_stub_classes):
provider_id = "provider-keep"
node_keep = attack_paths_graph_stub_classes.Node(
"n1", ["AWSAccount"], {"provider_id": provider_id}
)
node_drop = attack_paths_graph_stub_classes.Node(
"n2", ["AWSAccount"], {"provider_id": "provider-other"}
)
rel_keep = attack_paths_graph_stub_classes.Relationship(
"r1", "OWNS", node_keep, node_keep, {"provider_id": provider_id}
)
rel_drop_by_provider = attack_paths_graph_stub_classes.Relationship(
"r2", "OWNS", node_keep, node_drop, {"provider_id": "provider-other"}
)
rel_drop_orphaned = attack_paths_graph_stub_classes.Relationship(
"r3", "OWNS", node_keep, node_drop, {"provider_id": provider_id}
)
graph = SimpleNamespace(
nodes=[node_keep, node_drop],
relationships=[rel_keep, rel_drop_by_provider, rel_drop_orphaned],
)
result = views_helpers._serialize_graph(graph, provider_id)
assert len(result["nodes"]) == 1
assert result["nodes"][0]["id"] == "n1"
assert len(result["relationships"]) == 1
assert result["relationships"][0]["id"] == "r1"

View File

@@ -3995,15 +3995,18 @@ class TestAttackPathsScanViewSet:
assert response.status_code == status.HTTP_200_OK
mock_get_query.assert_called_once_with("aws-rds")
mock_get_db_name.assert_called_once_with(attack_paths_scan.provider.tenant_id)
provider_id = str(attack_paths_scan.provider_id)
mock_prepare.assert_called_once_with(
query_definition,
{},
attack_paths_scan.provider.uid,
provider_id,
)
mock_execute.assert_called_once_with(
expected_db_name,
query_definition,
prepared_parameters,
provider_id,
)
mock_clear_cache.assert_called_once_with(expected_db_name)
result = response.json()["data"]

View File

@@ -2505,14 +2505,19 @@ class AttackPathsScanViewSet(BaseRLSViewSet):
database_name = graph_database.get_database_name(
attack_paths_scan.provider.tenant_id
)
provider_id = str(attack_paths_scan.provider_id)
parameters = attack_paths_views_helpers.prepare_query_parameters(
query_definition,
serializer.validated_data.get("parameters", {}),
attack_paths_scan.provider.uid,
provider_id,
)
graph = attack_paths_views_helpers.execute_attack_paths_query(
database_name, query_definition, parameters
database_name,
query_definition,
parameters,
provider_id,
)
graph_database.clear_cache(database_name)