From be3be3eb6219c335e72d593d5379a95e94ea7a90 Mon Sep 17 00:00:00 2001 From: Josema Camacho Date: Wed, 18 Feb 2026 10:18:34 +0100 Subject: [PATCH] fix(api): clean up temp Neo4j databases on scan failure and provider deletion (#10101) Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- api/CHANGELOG.md | 4 + .../tasks/jobs/attack_paths/db_utils.py | 15 +++ .../backend/tasks/jobs/attack_paths/scan.py | 1 + api/src/backend/tasks/jobs/deletion.py | 53 +++++--- api/src/backend/tasks/tasks.py | 2 +- .../tasks/tests/test_attack_paths_scan.py | 52 ++++++++ api/src/backend/tasks/tests/test_deletion.py | 115 ++++++++++++++++-- 7 files changed, 219 insertions(+), 23 deletions(-) diff --git a/api/CHANGELOG.md b/api/CHANGELOG.md index 0378a15d53..96fd682b96 100644 --- a/api/CHANGELOG.md +++ b/api/CHANGELOG.md @@ -23,6 +23,10 @@ All notable changes to the **Prowler API** are documented in this file. - Attack Paths: Remove legacy per-scan `graph_database` and `is_graph_database_deleted` fields from AttackPathsScan model [(#10077)](https://github.com/prowler-cloud/prowler/pull/10077) - Attack Paths: Add `graph_data_ready` field to decouple query availability from scan state [(#10089)](https://github.com/prowler-cloud/prowler/pull/10089) +### 🐞 Fixed + +- Attack Paths: Orphaned temporary Neo4j databases are now cleaned up on scan failure and provider deletion [(#10101)](https://github.com/prowler-cloud/prowler/pull/10101) + ### 🔐 Security - Bump `Pillow` to 12.1.1 (CVE-2021-25289) [(#10027)](https://github.com/prowler-cloud/prowler/pull/10027) diff --git a/api/src/backend/tasks/jobs/attack_paths/db_utils.py b/api/src/backend/tasks/jobs/attack_paths/db_utils.py index 721b1d0483..9fb52b0ead 100644 --- a/api/src/backend/tasks/jobs/attack_paths/db_utils.py +++ b/api/src/backend/tasks/jobs/attack_paths/db_utils.py @@ -2,7 +2,9 @@ from datetime import datetime, timezone from typing import Any from cartography.config import Config as CartographyConfig +from celery.utils.log import get_task_logger +from api.attack_paths import database as graph_database from api.db_utils import rls_transaction from api.models import ( AttackPathsScan as ProwlerAPIAttackPathsScan, @@ -11,6 +13,8 @@ from api.models import ( ) from tasks.jobs.attack_paths.config import is_provider_available +logger = get_task_logger(__name__) + def can_provider_run_attack_paths_scan(tenant_id: str, provider_id: int) -> bool: with rls_transaction(tenant_id): @@ -165,6 +169,17 @@ def fail_attack_paths_scan( StateChoices.COMPLETED, StateChoices.FAILED, ): + tmp_db_name = graph_database.get_database_name( + attack_paths_scan.id, temporary=True + ) + try: + graph_database.drop_database(tmp_db_name) + + except Exception: + logger.exception( + f"Failed to drop temp database {tmp_db_name} during failure handling" + ) + finish_attack_paths_scan( attack_paths_scan, StateChoices.FAILED, diff --git a/api/src/backend/tasks/jobs/attack_paths/scan.py b/api/src/backend/tasks/jobs/attack_paths/scan.py index a5206145d4..da70b77383 100644 --- a/api/src/backend/tasks/jobs/attack_paths/scan.py +++ b/api/src/backend/tasks/jobs/attack_paths/scan.py @@ -211,6 +211,7 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]: # Handling databases changes try: graph_database.drop_database(tmp_cartography_config.neo4j_database) + except Exception: logger.exception( f"Failed to drop temporary Neo4j database {tmp_cartography_config.neo4j_database} during cleanup" diff --git a/api/src/backend/tasks/jobs/deletion.py b/api/src/backend/tasks/jobs/deletion.py index ba59eaeb5f..21c9031bf5 100644 --- a/api/src/backend/tasks/jobs/deletion.py +++ b/api/src/backend/tasks/jobs/deletion.py @@ -27,23 +27,24 @@ def delete_provider(tenant_id: str, pk: str): Returns: dict: A dictionary with the count of deleted objects per model, - including related models. - - Raises: - Provider.DoesNotExist: If no instance with the provided primary key exists. + including related models. Returns an empty dict if the provider + was already deleted. """ - # Delete the Attack Paths' graph data related to the provider - tenant_database_name = graph_database.get_database_name(tenant_id) - try: - graph_database.drop_subgraph(tenant_database_name, str(pk)) - except graph_database.GraphDatabaseQueryException as gdb_error: - logger.error(f"Error deleting Provider graph data: {gdb_error}") - raise - - # Get all provider related data and delete them in batches + # Get all provider related data to delete them in batches with rls_transaction(tenant_id): - instance = Provider.all_objects.get(pk=pk) + try: + instance = Provider.all_objects.get(pk=pk) + except Provider.DoesNotExist: + logger.info(f"Provider `{pk}` already deleted, skipping") + return {} + + attack_paths_scan_ids = list( + AttackPathsScan.all_objects.filter(provider=instance).values_list( + "id", flat=True + ) + ) + deletion_steps = [ ("Scan Summaries", ScanSummary.all_objects.filter(scan__provider=instance)), ("Findings", Finding.all_objects.filter(scan__provider=instance)), @@ -52,6 +53,25 @@ def delete_provider(tenant_id: str, pk: str): ("AttackPathsScans", AttackPathsScan.all_objects.filter(provider=instance)), ] + # Drop orphaned temporary Neo4j databases + for aps_id in attack_paths_scan_ids: + tmp_db_name = graph_database.get_database_name(aps_id, temporary=True) + try: + graph_database.drop_database(tmp_db_name) + + except graph_database.GraphDatabaseQueryException: + logger.warning(f"Failed to drop temp database {tmp_db_name}, continuing") + + # Delete the Attack Paths' graph data related to the provider from the tenant database + tenant_database_name = graph_database.get_database_name(tenant_id) + try: + graph_database.drop_subgraph(tenant_database_name, str(pk)) + + except graph_database.GraphDatabaseQueryException as gdb_error: + logger.error(f"Error deleting Provider graph data: {gdb_error}") + raise + + # Delete related data in batches deletion_summary = {} for step_name, queryset in deletion_steps: try: @@ -61,6 +81,7 @@ def delete_provider(tenant_id: str, pk: str): logger.error(f"Error deleting {step_name}: {db_error}") raise + # Delete the provider instance itself try: with rls_transaction(tenant_id): _, provider_summary = instance.delete() @@ -85,7 +106,9 @@ def delete_tenant(pk: str): """ deletion_summary = {} - for provider in Provider.objects.using(MainRouter.admin_db).filter(tenant_id=pk): + for provider in Provider.all_objects.using(MainRouter.admin_db).filter( + tenant_id=pk + ): summary = delete_provider(pk, provider.id) deletion_summary.update(summary) diff --git a/api/src/backend/tasks/tasks.py b/api/src/backend/tasks/tasks.py index 0cb41d066a..30cc0b09c4 100644 --- a/api/src/backend/tasks/tasks.py +++ b/api/src/backend/tasks/tasks.py @@ -11,8 +11,8 @@ from django_celery_beat.models import PeriodicTask from tasks.jobs.attack_paths import ( attack_paths_scan, can_provider_run_attack_paths_scan, + db_utils as attack_paths_db_utils, ) -from tasks.jobs.attack_paths import db_utils as attack_paths_db_utils from tasks.jobs.backfill import ( backfill_compliance_summaries, backfill_daily_severity_summaries, diff --git a/api/src/backend/tasks/tests/test_attack_paths_scan.py b/api/src/backend/tasks/tests/test_attack_paths_scan.py index aa84b8ce63..a883360f5e 100644 --- a/api/src/backend/tasks/tests/test_attack_paths_scan.py +++ b/api/src/backend/tasks/tests/test_attack_paths_scan.py @@ -441,6 +441,9 @@ class TestFailAttackPathsScan: "tasks.jobs.attack_paths.db_utils.retrieve_attack_paths_scan", return_value=attack_paths_scan, ) as mock_retrieve, + patch( + "tasks.jobs.attack_paths.db_utils.graph_database.drop_database" + ) as mock_drop_db, patch( "tasks.jobs.attack_paths.db_utils.finish_attack_paths_scan" ) as mock_finish, @@ -448,6 +451,51 @@ class TestFailAttackPathsScan: fail_attack_paths_scan(str(tenant.id), str(scan.id), "setup exploded") mock_retrieve.assert_called_once_with(str(tenant.id), str(scan.id)) + expected_tmp_db = f"db-tmp-scan-{str(attack_paths_scan.id).lower()}" + mock_drop_db.assert_called_once_with(expected_tmp_db) + mock_finish.assert_called_once_with( + attack_paths_scan, + StateChoices.FAILED, + {"global_error": "setup exploded"}, + ) + + def test_drops_temp_database_even_when_drop_fails( + self, tenants_fixture, providers_fixture, scans_fixture + ): + from tasks.jobs.attack_paths.db_utils import ( + fail_attack_paths_scan, + ) + + tenant = tenants_fixture[0] + provider = providers_fixture[0] + provider.provider = Provider.ProviderChoices.AWS + provider.save() + scan = scans_fixture[0] + scan.provider = provider + scan.save() + + attack_paths_scan = AttackPathsScan.objects.create( + tenant_id=tenant.id, + provider=provider, + scan=scan, + state=StateChoices.EXECUTING, + ) + + with ( + patch( + "tasks.jobs.attack_paths.db_utils.retrieve_attack_paths_scan", + return_value=attack_paths_scan, + ), + patch( + "tasks.jobs.attack_paths.db_utils.graph_database.drop_database", + side_effect=Exception("Neo4j unreachable"), + ), + patch( + "tasks.jobs.attack_paths.db_utils.finish_attack_paths_scan" + ) as mock_finish, + ): + fail_attack_paths_scan(str(tenant.id), str(scan.id), "setup exploded") + mock_finish.assert_called_once_with( attack_paths_scan, StateChoices.FAILED, @@ -481,12 +529,16 @@ class TestFailAttackPathsScan: "tasks.jobs.attack_paths.db_utils.retrieve_attack_paths_scan", return_value=attack_paths_scan, ), + patch( + "tasks.jobs.attack_paths.db_utils.graph_database.drop_database" + ) as mock_drop_db, patch( "tasks.jobs.attack_paths.db_utils.finish_attack_paths_scan" ) as mock_finish, ): fail_attack_paths_scan(str(tenant.id), str(scan.id), "setup exploded") + mock_drop_db.assert_not_called() mock_finish.assert_not_called() def test_skips_when_no_scan_found(self, tenants_fixture): diff --git a/api/src/backend/tasks/tests/test_deletion.py b/api/src/backend/tasks/tests/test_deletion.py index 843ccb5df8..831ba20641 100644 --- a/api/src/backend/tasks/tests/test_deletion.py +++ b/api/src/backend/tasks/tests/test_deletion.py @@ -4,6 +4,7 @@ import pytest from django.core.exceptions import ObjectDoesNotExist +from api.attack_paths import database as graph_database from api.models import Provider, Tenant from tasks.jobs.deletion import delete_provider, delete_tenant @@ -47,14 +48,61 @@ class TestDeleteProvider: tenant_id = str(tenants_fixture[0].id) non_existent_pk = "babf6796-cfcc-4fd3-9dcf-88d012247645" - with pytest.raises(ObjectDoesNotExist): - delete_provider(tenant_id, non_existent_pk) + result = delete_provider(tenant_id, non_existent_pk) - mock_get_database_name.assert_called_once_with(tenant_id) - mock_drop_subgraph.assert_called_once_with( - "tenant-db", - non_existent_pk, - ) + assert result == {} + mock_get_database_name.assert_not_called() + mock_drop_subgraph.assert_not_called() + + def test_delete_provider_drops_temp_attack_paths_databases( + self, providers_fixture, create_attack_paths_scan + ): + instance = providers_fixture[0] + tenant_id = str(instance.tenant_id) + + aps1 = create_attack_paths_scan(instance) + aps2 = create_attack_paths_scan(instance) + + with ( + patch( + "tasks.jobs.deletion.graph_database.drop_subgraph", + ), + patch( + "tasks.jobs.deletion.graph_database.drop_database", + ) as mock_drop_database, + ): + result = delete_provider(tenant_id, instance.id) + + assert result + expected_tmp_calls = [ + call(f"db-tmp-scan-{str(aps1.id).lower()}"), + call(f"db-tmp-scan-{str(aps2.id).lower()}"), + ] + mock_drop_database.assert_has_calls(expected_tmp_calls, any_order=True) + + def test_delete_provider_continues_when_temp_db_drop_fails( + self, providers_fixture, create_attack_paths_scan + ): + instance = providers_fixture[0] + tenant_id = str(instance.tenant_id) + + create_attack_paths_scan(instance) + + with ( + patch( + "tasks.jobs.deletion.graph_database.drop_subgraph", + ), + patch( + "tasks.jobs.deletion.graph_database.drop_database", + side_effect=graph_database.GraphDatabaseQueryException( + "Neo4j unreachable" + ), + ), + ): + result = delete_provider(tenant_id, instance.id) + + assert result + assert not Provider.all_objects.filter(pk=instance.id).exists() @pytest.mark.django_db @@ -142,3 +190,56 @@ class TestDeleteTenant: mock_get_database_name.assert_called_once_with(tenant.id) mock_drop_subgraph.assert_not_called() mock_drop_database.assert_called_once_with("tenant-db") + + def test_delete_tenant_includes_soft_deleted_providers(self, tenants_fixture): + tenant = tenants_fixture[0] + provider = Provider.objects.create( + provider="aws", + uid="999999999999", + alias="soft_deleted_provider", + tenant_id=tenant.id, + ) + # Soft-delete the provider so ActiveProviderManager would skip it + Provider.all_objects.filter(pk=provider.id).update(is_deleted=True) + + with ( + patch( + "tasks.jobs.deletion.graph_database.get_database_name", + return_value="tenant-db", + ), + patch( + "tasks.jobs.deletion.graph_database.drop_subgraph" + ) as mock_drop_subgraph, + patch("tasks.jobs.deletion.graph_database.drop_database"), + ): + delete_tenant(tenant.id) + + mock_drop_subgraph.assert_any_call("tenant-db", str(provider.id)) + + def test_delete_tenant_handles_concurrently_deleted_provider(self, tenants_fixture): + tenant = tenants_fixture[0] + Provider.objects.create( + provider="aws", + uid="111111111111", + alias="vanishing_provider", + tenant_id=tenant.id, + ) + + def drop_subgraph_side_effect(_db_name, provider_id): + # Simulate concurrent deletion by another process + Provider.all_objects.filter(pk=provider_id).delete() + + with ( + patch( + "tasks.jobs.deletion.graph_database.get_database_name", + return_value="tenant-db", + ), + patch( + "tasks.jobs.deletion.graph_database.drop_subgraph", + side_effect=drop_subgraph_side_effect, + ), + patch("tasks.jobs.deletion.graph_database.drop_database"), + ): + deletion_summary = delete_tenant(tenant.id) + + assert deletion_summary is not None