fix(api): clean up temp Neo4j databases on scan failure and provider deletion (#10101)

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Josema Camacho
2026-02-18 10:18:34 +01:00
committed by GitHub
parent 338d514197
commit be3be3eb62
7 changed files with 219 additions and 23 deletions

View File

@@ -23,6 +23,10 @@ All notable changes to the **Prowler API** are documented in this file.
- Attack Paths: Remove legacy per-scan `graph_database` and `is_graph_database_deleted` fields from AttackPathsScan model [(#10077)](https://github.com/prowler-cloud/prowler/pull/10077)
- Attack Paths: Add `graph_data_ready` field to decouple query availability from scan state [(#10089)](https://github.com/prowler-cloud/prowler/pull/10089)
### 🐞 Fixed
- Attack Paths: Orphaned temporary Neo4j databases are now cleaned up on scan failure and provider deletion [(#10101)](https://github.com/prowler-cloud/prowler/pull/10101)
### 🔐 Security
- Bump `Pillow` to 12.1.1 (CVE-2021-25289) [(#10027)](https://github.com/prowler-cloud/prowler/pull/10027)

View File

@@ -2,7 +2,9 @@ from datetime import datetime, timezone
from typing import Any
from cartography.config import Config as CartographyConfig
from celery.utils.log import get_task_logger
from api.attack_paths import database as graph_database
from api.db_utils import rls_transaction
from api.models import (
AttackPathsScan as ProwlerAPIAttackPathsScan,
@@ -11,6 +13,8 @@ from api.models import (
)
from tasks.jobs.attack_paths.config import is_provider_available
logger = get_task_logger(__name__)
def can_provider_run_attack_paths_scan(tenant_id: str, provider_id: int) -> bool:
with rls_transaction(tenant_id):
@@ -165,6 +169,17 @@ def fail_attack_paths_scan(
StateChoices.COMPLETED,
StateChoices.FAILED,
):
tmp_db_name = graph_database.get_database_name(
attack_paths_scan.id, temporary=True
)
try:
graph_database.drop_database(tmp_db_name)
except Exception:
logger.exception(
f"Failed to drop temp database {tmp_db_name} during failure handling"
)
finish_attack_paths_scan(
attack_paths_scan,
StateChoices.FAILED,

View File

@@ -211,6 +211,7 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
# Handling databases changes
try:
graph_database.drop_database(tmp_cartography_config.neo4j_database)
except Exception:
logger.exception(
f"Failed to drop temporary Neo4j database {tmp_cartography_config.neo4j_database} during cleanup"

View File

@@ -27,23 +27,24 @@ def delete_provider(tenant_id: str, pk: str):
Returns:
dict: A dictionary with the count of deleted objects per model,
including related models.
Raises:
Provider.DoesNotExist: If no instance with the provided primary key exists.
including related models. Returns an empty dict if the provider
was already deleted.
"""
# Delete the Attack Paths' graph data related to the provider
tenant_database_name = graph_database.get_database_name(tenant_id)
try:
graph_database.drop_subgraph(tenant_database_name, str(pk))
except graph_database.GraphDatabaseQueryException as gdb_error:
logger.error(f"Error deleting Provider graph data: {gdb_error}")
raise
# Get all provider related data and delete them in batches
# Get all provider related data to delete them in batches
with rls_transaction(tenant_id):
instance = Provider.all_objects.get(pk=pk)
try:
instance = Provider.all_objects.get(pk=pk)
except Provider.DoesNotExist:
logger.info(f"Provider `{pk}` already deleted, skipping")
return {}
attack_paths_scan_ids = list(
AttackPathsScan.all_objects.filter(provider=instance).values_list(
"id", flat=True
)
)
deletion_steps = [
("Scan Summaries", ScanSummary.all_objects.filter(scan__provider=instance)),
("Findings", Finding.all_objects.filter(scan__provider=instance)),
@@ -52,6 +53,25 @@ def delete_provider(tenant_id: str, pk: str):
("AttackPathsScans", AttackPathsScan.all_objects.filter(provider=instance)),
]
# Drop orphaned temporary Neo4j databases
for aps_id in attack_paths_scan_ids:
tmp_db_name = graph_database.get_database_name(aps_id, temporary=True)
try:
graph_database.drop_database(tmp_db_name)
except graph_database.GraphDatabaseQueryException:
logger.warning(f"Failed to drop temp database {tmp_db_name}, continuing")
# Delete the Attack Paths' graph data related to the provider from the tenant database
tenant_database_name = graph_database.get_database_name(tenant_id)
try:
graph_database.drop_subgraph(tenant_database_name, str(pk))
except graph_database.GraphDatabaseQueryException as gdb_error:
logger.error(f"Error deleting Provider graph data: {gdb_error}")
raise
# Delete related data in batches
deletion_summary = {}
for step_name, queryset in deletion_steps:
try:
@@ -61,6 +81,7 @@ def delete_provider(tenant_id: str, pk: str):
logger.error(f"Error deleting {step_name}: {db_error}")
raise
# Delete the provider instance itself
try:
with rls_transaction(tenant_id):
_, provider_summary = instance.delete()
@@ -85,7 +106,9 @@ def delete_tenant(pk: str):
"""
deletion_summary = {}
for provider in Provider.objects.using(MainRouter.admin_db).filter(tenant_id=pk):
for provider in Provider.all_objects.using(MainRouter.admin_db).filter(
tenant_id=pk
):
summary = delete_provider(pk, provider.id)
deletion_summary.update(summary)

View File

@@ -11,8 +11,8 @@ from django_celery_beat.models import PeriodicTask
from tasks.jobs.attack_paths import (
attack_paths_scan,
can_provider_run_attack_paths_scan,
db_utils as attack_paths_db_utils,
)
from tasks.jobs.attack_paths import db_utils as attack_paths_db_utils
from tasks.jobs.backfill import (
backfill_compliance_summaries,
backfill_daily_severity_summaries,

View File

@@ -441,6 +441,9 @@ class TestFailAttackPathsScan:
"tasks.jobs.attack_paths.db_utils.retrieve_attack_paths_scan",
return_value=attack_paths_scan,
) as mock_retrieve,
patch(
"tasks.jobs.attack_paths.db_utils.graph_database.drop_database"
) as mock_drop_db,
patch(
"tasks.jobs.attack_paths.db_utils.finish_attack_paths_scan"
) as mock_finish,
@@ -448,6 +451,51 @@ class TestFailAttackPathsScan:
fail_attack_paths_scan(str(tenant.id), str(scan.id), "setup exploded")
mock_retrieve.assert_called_once_with(str(tenant.id), str(scan.id))
expected_tmp_db = f"db-tmp-scan-{str(attack_paths_scan.id).lower()}"
mock_drop_db.assert_called_once_with(expected_tmp_db)
mock_finish.assert_called_once_with(
attack_paths_scan,
StateChoices.FAILED,
{"global_error": "setup exploded"},
)
def test_drops_temp_database_even_when_drop_fails(
self, tenants_fixture, providers_fixture, scans_fixture
):
from tasks.jobs.attack_paths.db_utils import (
fail_attack_paths_scan,
)
tenant = tenants_fixture[0]
provider = providers_fixture[0]
provider.provider = Provider.ProviderChoices.AWS
provider.save()
scan = scans_fixture[0]
scan.provider = provider
scan.save()
attack_paths_scan = AttackPathsScan.objects.create(
tenant_id=tenant.id,
provider=provider,
scan=scan,
state=StateChoices.EXECUTING,
)
with (
patch(
"tasks.jobs.attack_paths.db_utils.retrieve_attack_paths_scan",
return_value=attack_paths_scan,
),
patch(
"tasks.jobs.attack_paths.db_utils.graph_database.drop_database",
side_effect=Exception("Neo4j unreachable"),
),
patch(
"tasks.jobs.attack_paths.db_utils.finish_attack_paths_scan"
) as mock_finish,
):
fail_attack_paths_scan(str(tenant.id), str(scan.id), "setup exploded")
mock_finish.assert_called_once_with(
attack_paths_scan,
StateChoices.FAILED,
@@ -481,12 +529,16 @@ class TestFailAttackPathsScan:
"tasks.jobs.attack_paths.db_utils.retrieve_attack_paths_scan",
return_value=attack_paths_scan,
),
patch(
"tasks.jobs.attack_paths.db_utils.graph_database.drop_database"
) as mock_drop_db,
patch(
"tasks.jobs.attack_paths.db_utils.finish_attack_paths_scan"
) as mock_finish,
):
fail_attack_paths_scan(str(tenant.id), str(scan.id), "setup exploded")
mock_drop_db.assert_not_called()
mock_finish.assert_not_called()
def test_skips_when_no_scan_found(self, tenants_fixture):

View File

@@ -4,6 +4,7 @@ import pytest
from django.core.exceptions import ObjectDoesNotExist
from api.attack_paths import database as graph_database
from api.models import Provider, Tenant
from tasks.jobs.deletion import delete_provider, delete_tenant
@@ -47,14 +48,61 @@ class TestDeleteProvider:
tenant_id = str(tenants_fixture[0].id)
non_existent_pk = "babf6796-cfcc-4fd3-9dcf-88d012247645"
with pytest.raises(ObjectDoesNotExist):
delete_provider(tenant_id, non_existent_pk)
result = delete_provider(tenant_id, non_existent_pk)
mock_get_database_name.assert_called_once_with(tenant_id)
mock_drop_subgraph.assert_called_once_with(
"tenant-db",
non_existent_pk,
)
assert result == {}
mock_get_database_name.assert_not_called()
mock_drop_subgraph.assert_not_called()
def test_delete_provider_drops_temp_attack_paths_databases(
self, providers_fixture, create_attack_paths_scan
):
instance = providers_fixture[0]
tenant_id = str(instance.tenant_id)
aps1 = create_attack_paths_scan(instance)
aps2 = create_attack_paths_scan(instance)
with (
patch(
"tasks.jobs.deletion.graph_database.drop_subgraph",
),
patch(
"tasks.jobs.deletion.graph_database.drop_database",
) as mock_drop_database,
):
result = delete_provider(tenant_id, instance.id)
assert result
expected_tmp_calls = [
call(f"db-tmp-scan-{str(aps1.id).lower()}"),
call(f"db-tmp-scan-{str(aps2.id).lower()}"),
]
mock_drop_database.assert_has_calls(expected_tmp_calls, any_order=True)
def test_delete_provider_continues_when_temp_db_drop_fails(
self, providers_fixture, create_attack_paths_scan
):
instance = providers_fixture[0]
tenant_id = str(instance.tenant_id)
create_attack_paths_scan(instance)
with (
patch(
"tasks.jobs.deletion.graph_database.drop_subgraph",
),
patch(
"tasks.jobs.deletion.graph_database.drop_database",
side_effect=graph_database.GraphDatabaseQueryException(
"Neo4j unreachable"
),
),
):
result = delete_provider(tenant_id, instance.id)
assert result
assert not Provider.all_objects.filter(pk=instance.id).exists()
@pytest.mark.django_db
@@ -142,3 +190,56 @@ class TestDeleteTenant:
mock_get_database_name.assert_called_once_with(tenant.id)
mock_drop_subgraph.assert_not_called()
mock_drop_database.assert_called_once_with("tenant-db")
def test_delete_tenant_includes_soft_deleted_providers(self, tenants_fixture):
tenant = tenants_fixture[0]
provider = Provider.objects.create(
provider="aws",
uid="999999999999",
alias="soft_deleted_provider",
tenant_id=tenant.id,
)
# Soft-delete the provider so ActiveProviderManager would skip it
Provider.all_objects.filter(pk=provider.id).update(is_deleted=True)
with (
patch(
"tasks.jobs.deletion.graph_database.get_database_name",
return_value="tenant-db",
),
patch(
"tasks.jobs.deletion.graph_database.drop_subgraph"
) as mock_drop_subgraph,
patch("tasks.jobs.deletion.graph_database.drop_database"),
):
delete_tenant(tenant.id)
mock_drop_subgraph.assert_any_call("tenant-db", str(provider.id))
def test_delete_tenant_handles_concurrently_deleted_provider(self, tenants_fixture):
tenant = tenants_fixture[0]
Provider.objects.create(
provider="aws",
uid="111111111111",
alias="vanishing_provider",
tenant_id=tenant.id,
)
def drop_subgraph_side_effect(_db_name, provider_id):
# Simulate concurrent deletion by another process
Provider.all_objects.filter(pk=provider_id).delete()
with (
patch(
"tasks.jobs.deletion.graph_database.get_database_name",
return_value="tenant-db",
),
patch(
"tasks.jobs.deletion.graph_database.drop_subgraph",
side_effect=drop_subgraph_side_effect,
),
patch("tasks.jobs.deletion.graph_database.drop_database"),
):
deletion_summary = delete_tenant(tenant.id)
assert deletion_summary is not None