refactor(api): remove graph_database and is_graph_database_deleted from AttackPathsScan (#10077)

This commit is contained in:
Josema Camacho
2026-02-16 12:46:49 +01:00
committed by GitHub
parent be516f1dfc
commit bb34f6cc3d
15 changed files with 76 additions and 142 deletions
+8 -7
View File
@@ -1,14 +1,14 @@
name: 'API: Security'
name: "API: Security"
on:
push:
branches:
- 'master'
- 'v5.*'
- "master"
- "v5.*"
pull_request:
branches:
- 'master'
- 'v5.*'
- "master"
- "v5.*"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
@@ -26,7 +26,7 @@ jobs:
strategy:
matrix:
python-version:
- '3.12'
- "3.12"
defaults:
run:
working-directory: ./api
@@ -61,8 +61,9 @@ jobs:
- name: Safety
if: steps.check-changes.outputs.any_changed == 'true'
run: poetry run safety check --ignore 79023,79027
run: poetry run safety check --ignore 79023,79027,84420
# TODO: 79023 & 79027 knack ReDoS until `azure-cli-core` (via `cartography`) allows `knack` >=0.13.0
# TODO: 84420 from `azure-core`, that we need fix alltogether with `azure-cli-core` and `knack`
- name: Vulture
if: steps.check-changes.outputs.any_changed == 'true'
+2 -2
View File
@@ -85,7 +85,6 @@ repos:
args: ["--directory=./"]
pass_filenames: false
- repo: https://github.com/hadolint/hadolint
rev: v2.13.0-beta
hooks:
@@ -121,7 +120,8 @@ repos:
description: "Safety is a tool that checks your installed dependencies for known security vulnerabilities"
# TODO: Botocore needs urllib3 1.X so we need to ignore these vulnerabilities 77744,77745. Remove this once we upgrade to urllib3 2.X
# TODO: 79023 & 79027 knack ReDoS until `azure-cli-core` (via `cartography`) allows `knack` >=0.13.0
entry: bash -c 'safety check --ignore 70612,66963,74429,76352,76353,77744,77745,79023,79027'
# TODO: 84420 from `azure-core`, that we need fix alltogether with `azure-cli-core` and `knack`
entry: bash -c 'safety check --ignore 70612,66963,74429,76352,76353,77744,77745,79023,79027,84420'
language: system
- id: vulture
+1
View File
@@ -19,6 +19,7 @@ All notable changes to the **Prowler API** are documented in this file.
- Support CSA CCM 4.0 for the Oracle Cloud provider [(#10057)](https://github.com/prowler-cloud/prowler/pull/10057)
- Support CSA CCM 4.0 for the Alibaba Cloud provider [(#10061)](https://github.com/prowler-cloud/prowler/pull/10061)
- Attack Paths: Mark attack Paths scan as failed when Celery task fails outside job error handling [(#10065)](https://github.com/prowler-cloud/prowler/pull/10065)
- Attack Paths: Remove legacy per-scan `graph_database` and `is_graph_database_deleted` fields from AttackPathsScan model [(#10077)](https://github.com/prowler-cloud/prowler/pull/10077)
### 🔐 Security
@@ -5,7 +5,6 @@ from typing import Any, Iterable
from rest_framework.exceptions import APIException, ValidationError
from api.attack_paths import database as graph_database, AttackPathsQueryDefinition
from api.models import AttackPathsScan
from config.custom_logging import BackendLogger
from tasks.jobs.attack_paths.config import INTERNAL_LABELS
@@ -80,12 +79,12 @@ def prepare_query_parameters(
def execute_attack_paths_query(
attack_paths_scan: AttackPathsScan,
database_name: str,
definition: AttackPathsQueryDefinition,
parameters: dict[str, Any],
) -> dict[str, Any]:
try:
with graph_database.get_session(attack_paths_scan.graph_database) as session:
with graph_database.get_session(database_name) as session:
result = session.run(definition.cypher, parameters)
return _serialize_graph(result.graph())
@@ -9,8 +9,6 @@
"state": "completed",
"progress": 100,
"update_tag": 1693586667,
"graph_database": "db-a7f0f6de-6f8e-4b3a-8cbe-3f6dd9012345",
"is_graph_database_deleted": false,
"task": null,
"inserted_at": "2024-09-01T17:24:37Z",
"updated_at": "2024-09-01T17:44:37Z",
@@ -30,8 +28,6 @@
"state": "executing",
"progress": 48,
"update_tag": 1697625000,
"graph_database": "db-4a2fb2af-8a60-4d7d-9cae-4ca65e098765",
"is_graph_database_deleted": false,
"task": null,
"inserted_at": "2024-10-18T10:55:57Z",
"updated_at": "2024-10-18T10:56:15Z",
@@ -0,0 +1,23 @@
# Generated by Django 5.1.15 on 2026-02-16 09:24
from django.contrib.postgres.operations import RemoveIndexConcurrently
from django.db import migrations
class Migration(migrations.Migration):
atomic = False
dependencies = [
("api", "0076_openstack_provider"),
]
operations = [
RemoveIndexConcurrently(
model_name="attackpathsscan",
name="aps_active_graph_idx",
),
RemoveIndexConcurrently(
model_name="attackpathsscan",
name="aps_completed_graph_idx",
),
]
@@ -0,0 +1,20 @@
# Generated by Django 5.1.15 on 2026-02-16 09:24
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("api", "0077_remove_attackpathsscan_graph_database_indexes"),
]
operations = [
migrations.RemoveField(
model_name="attackpathsscan",
name="graph_database",
),
migrations.RemoveField(
model_name="attackpathsscan",
name="is_graph_database_deleted",
),
]
-17
View File
@@ -691,8 +691,6 @@ class AttackPathsScan(RowLevelSecurityProtectedModel):
update_tag = models.BigIntegerField(
null=True, blank=True, help_text="Cartography update tag (epoch)"
)
graph_database = models.CharField(max_length=63, null=True, blank=True)
is_graph_database_deleted = models.BooleanField(default=False)
ingestion_exceptions = models.JSONField(default=dict, null=True, blank=True)
class Meta(RowLevelSecurityProtectedModel.Meta):
@@ -719,21 +717,6 @@ class AttackPathsScan(RowLevelSecurityProtectedModel):
fields=["tenant_id", "scan_id"],
name="aps_scan_lookup_idx",
),
models.Index(
fields=["tenant_id", "provider_id"],
name="aps_active_graph_idx",
include=["graph_database", "id"],
condition=Q(is_graph_database_deleted=False),
),
models.Index(
fields=["tenant_id", "provider_id", "-completed_at"],
name="aps_completed_graph_idx",
include=["graph_database", "id"],
condition=Q(
state=StateChoices.COMPLETED,
is_graph_database_deleted=False,
),
),
]
class JSONAPIMeta:
@@ -89,7 +89,6 @@ def test_execute_attack_paths_query_serializes_graph(
parameters=[],
)
parameters = {"provider_uid": "123"}
attack_paths_scan = SimpleNamespace(graph_database="tenant-db")
node = attack_paths_graph_stub_classes.Node(
element_id="node-1",
@@ -123,15 +122,17 @@ def test_execute_attack_paths_query_serializes_graph(
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = False
database_name = "db-tenant-test-tenant-id"
with patch(
"api.attack_paths.views_helpers.graph_database.get_session",
return_value=session_ctx,
) as mock_get_session:
result = views_helpers.execute_attack_paths_query(
attack_paths_scan, definition, parameters
database_name, definition, parameters
)
mock_get_session.assert_called_once_with("tenant-db")
mock_get_session.assert_called_once_with(database_name)
session.run.assert_called_once_with(definition.cypher, parameters)
assert result["nodes"][0]["id"] == "node-1"
assert result["nodes"][0]["properties"]["complex"]["items"][0] == "value"
@@ -149,7 +150,7 @@ def test_execute_attack_paths_query_wraps_graph_errors(
cypher="MATCH (n) RETURN n",
parameters=[],
)
attack_paths_scan = SimpleNamespace(graph_database="tenant-db")
database_name = "db-tenant-test-tenant-id"
parameters = {"provider_uid": "123"}
class ExplodingContext:
@@ -168,7 +169,7 @@ def test_execute_attack_paths_query_wraps_graph_errors(
):
with pytest.raises(APIException):
views_helpers.execute_attack_paths_query(
attack_paths_scan, definition, parameters
database_name, definition, parameters
)
mock_logger.error.assert_called_once()
+9 -28
View File
@@ -3922,7 +3922,6 @@ class TestAttackPathsScanViewSet:
attack_paths_scan = create_attack_paths_scan(
provider,
scan=scans_fixture[0],
graph_database="tenant-db",
)
query_definition = AttackPathsQueryDefinition(
id="aws-rds",
@@ -3953,10 +3952,16 @@ class TestAttackPathsScanViewSet:
],
}
expected_db_name = f"db-tenant-{attack_paths_scan.provider.tenant_id}"
with (
patch(
"api.v1.views.get_query_by_id", return_value=query_definition
) as mock_get_query,
patch(
"api.v1.views.graph_database.get_database_name",
return_value=expected_db_name,
) as mock_get_db_name,
patch(
"api.v1.views.attack_paths_views_helpers.prepare_query_parameters",
return_value=prepared_parameters,
@@ -3978,17 +3983,18 @@ class TestAttackPathsScanViewSet:
assert response.status_code == status.HTTP_200_OK
mock_get_query.assert_called_once_with("aws-rds")
mock_get_db_name.assert_called_once_with(attack_paths_scan.provider.tenant_id)
mock_prepare.assert_called_once_with(
query_definition,
{},
attack_paths_scan.provider.uid,
)
mock_execute.assert_called_once_with(
attack_paths_scan,
expected_db_name,
query_definition,
prepared_parameters,
)
mock_clear_cache.assert_called_once_with(attack_paths_scan.graph_database)
mock_clear_cache.assert_called_once_with(expected_db_name)
result = response.json()["data"]
attributes = result["attributes"]
assert attributes["nodes"] == graph_payload["nodes"]
@@ -4019,31 +4025,6 @@ class TestAttackPathsScanViewSet:
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert "must be completed" in response.json()["errors"][0]["detail"]
def test_run_attack_paths_query_requires_graph_database(
self,
authenticated_client,
providers_fixture,
scans_fixture,
create_attack_paths_scan,
):
provider = providers_fixture[0]
attack_paths_scan = create_attack_paths_scan(
provider,
scan=scans_fixture[0],
graph_database=None,
)
response = authenticated_client.post(
reverse(
"attack-paths-scans-queries-run", kwargs={"pk": attack_paths_scan.id}
),
data=self._run_payload(),
content_type=API_JSON_CONTENT_TYPE,
)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert "does not reference a graph database" in str(response.json())
def test_run_attack_paths_query_unknown_query(
self,
authenticated_client,
+5 -11
View File
@@ -2428,15 +2428,6 @@ class AttackPathsScanViewSet(BaseRLSViewSet):
}
)
if not attack_paths_scan.graph_database:
logger.error(
f"The Attack Paths Scan {attack_paths_scan.id} does not reference a graph database"
)
return Response(
{"detail": "The Attack Paths scan does not reference a graph database"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
payload = attack_paths_views_helpers.normalize_run_payload(request.data)
serializer = AttackPathsQueryRunRequestSerializer(data=payload)
serializer.is_valid(raise_exception=True)
@@ -2450,6 +2441,9 @@ class AttackPathsScanViewSet(BaseRLSViewSet):
{"id": "Unknown Attack Paths query for the selected provider"}
)
database_name = graph_database.get_database_name(
attack_paths_scan.provider.tenant_id
)
parameters = attack_paths_views_helpers.prepare_query_parameters(
query_definition,
serializer.validated_data.get("parameters", {}),
@@ -2457,9 +2451,9 @@ class AttackPathsScanViewSet(BaseRLSViewSet):
)
graph = attack_paths_views_helpers.execute_attack_paths_query(
attack_paths_scan, query_definition, parameters
database_name, query_definition, parameters
)
graph_database.clear_cache(attack_paths_scan.graph_database)
graph_database.clear_cache(database_name)
status_code = status.HTTP_200_OK
if not graph.get("nodes"):
-2
View File
@@ -1625,7 +1625,6 @@ def create_attack_paths_scan():
scan=None,
state=StateChoices.COMPLETED,
progress=0,
graph_database="tenant-db",
**extra_fields,
):
scan_instance = scan or Scan.objects.create(
@@ -1642,7 +1641,6 @@ def create_attack_paths_scan():
"scan": scan_instance,
"state": state,
"progress": progress,
"graph_database": graph_database,
}
payload.update(extra_fields)
@@ -66,7 +66,6 @@ def starting_attack_paths_scan(
attack_paths_scan.state = StateChoices.EXECUTING
attack_paths_scan.started_at = datetime.now(tz=timezone.utc)
attack_paths_scan.update_tag = cartography_config.update_tag
attack_paths_scan.graph_database = cartography_config.neo4j_database
attack_paths_scan.save(
update_fields=[
@@ -74,7 +73,6 @@ def starting_attack_paths_scan(
"state",
"started_at",
"update_tag",
"graph_database",
]
)
@@ -118,38 +116,6 @@ def update_attack_paths_scan_progress(
attack_paths_scan.save(update_fields=["progress"])
def get_old_attack_paths_scans(
tenant_id: str,
provider_id: str,
attack_paths_scan_id: str,
) -> list[ProwlerAPIAttackPathsScan]:
"""
An `old_attack_paths_scan` is any `completed` Attack Paths scan for the same provider,
with its graph database not deleted, excluding the current Attack Paths scan.
"""
with rls_transaction(tenant_id):
completed_scans_qs = (
ProwlerAPIAttackPathsScan.objects.filter(
provider_id=provider_id,
state=StateChoices.COMPLETED,
is_graph_database_deleted=False,
)
.exclude(id=attack_paths_scan_id)
.all()
)
return list(completed_scans_qs)
def update_old_attack_paths_scan(
old_attack_paths_scan: ProwlerAPIAttackPathsScan,
) -> None:
with rls_transaction(old_attack_paths_scan.tenant_id):
old_attack_paths_scan.is_graph_database_deleted = True
old_attack_paths_scan.save(update_fields=["is_graph_database_deleted"])
def fail_attack_paths_scan(
tenant_id: str,
scan_id: str,
@@ -193,30 +193,6 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
f"{prowler_api_provider.provider.upper()} provider {prowler_api_provider.id}"
)
# TODO
# This piece of code delete old Neo4j databases for this tenant's provider
# When we clean all of these databases we need to:
# - Delete this block
# - Delete function from `db_utils` the functions get_old_attack_paths_scans` & `update_old_attack_paths_scan`
# - Remove `graph_database` & `is_graph_database_deleted` from the AttackPathsScan model:
# - Check indexes
# - Create migration
# - The use of `attack_paths_scan.graph_database` on `views` and `views_helpers`
# - Tests
old_attack_paths_scans = db_utils.get_old_attack_paths_scans(
prowler_api_provider.tenant_id,
prowler_api_provider.id,
attack_paths_scan.id,
)
for old_attack_paths_scan in old_attack_paths_scans:
old_graph_database = old_attack_paths_scan.graph_database
if old_graph_database and old_graph_database != tenant_database_name:
logger.info(
f"Dropping old Neo4j database {old_graph_database} for provider {prowler_api_provider.id}"
)
graph_database.drop_database(old_graph_database)
db_utils.update_old_attack_paths_scan(old_attack_paths_scan)
logger.info(f"Dropping temporary Neo4j database {tmp_database_name}")
graph_database.drop_database(tmp_database_name)
@@ -28,10 +28,6 @@ class TestAttackPathsRun:
"tasks.jobs.attack_paths.scan.utils.call_within_event_loop",
side_effect=lambda fn, *a, **kw: fn(*a, **kw),
)
@patch(
"tasks.jobs.attack_paths.scan.db_utils.get_old_attack_paths_scans",
return_value=[],
)
@patch("tasks.jobs.attack_paths.scan.db_utils.finish_attack_paths_scan")
@patch("tasks.jobs.attack_paths.scan.db_utils.update_attack_paths_scan_progress")
@patch("tasks.jobs.attack_paths.scan.db_utils.starting_attack_paths_scan")
@@ -76,7 +72,6 @@ class TestAttackPathsRun:
mock_starting,
mock_update_progress,
mock_finish,
mock_get_old_scans,
mock_event_loop,
mock_drop_db,
tenants_fixture,