mirror of
https://github.com/prowler-cloud/prowler.git
synced 2026-03-21 18:58:04 +00:00
fix(api): clean up temp Neo4j databases on scan failure and provider deletion (#10101)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -23,6 +23,10 @@ All notable changes to the **Prowler API** are documented in this file.
|
||||
- Attack Paths: Remove legacy per-scan `graph_database` and `is_graph_database_deleted` fields from AttackPathsScan model [(#10077)](https://github.com/prowler-cloud/prowler/pull/10077)
|
||||
- Attack Paths: Add `graph_data_ready` field to decouple query availability from scan state [(#10089)](https://github.com/prowler-cloud/prowler/pull/10089)
|
||||
|
||||
### 🐞 Fixed
|
||||
|
||||
- Attack Paths: Orphaned temporary Neo4j databases are now cleaned up on scan failure and provider deletion [(#10101)](https://github.com/prowler-cloud/prowler/pull/10101)
|
||||
|
||||
### 🔐 Security
|
||||
|
||||
- Bump `Pillow` to 12.1.1 (CVE-2021-25289) [(#10027)](https://github.com/prowler-cloud/prowler/pull/10027)
|
||||
|
||||
@@ -2,7 +2,9 @@ from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from cartography.config import Config as CartographyConfig
|
||||
from celery.utils.log import get_task_logger
|
||||
|
||||
from api.attack_paths import database as graph_database
|
||||
from api.db_utils import rls_transaction
|
||||
from api.models import (
|
||||
AttackPathsScan as ProwlerAPIAttackPathsScan,
|
||||
@@ -11,6 +13,8 @@ from api.models import (
|
||||
)
|
||||
from tasks.jobs.attack_paths.config import is_provider_available
|
||||
|
||||
logger = get_task_logger(__name__)
|
||||
|
||||
|
||||
def can_provider_run_attack_paths_scan(tenant_id: str, provider_id: int) -> bool:
|
||||
with rls_transaction(tenant_id):
|
||||
@@ -165,6 +169,17 @@ def fail_attack_paths_scan(
|
||||
StateChoices.COMPLETED,
|
||||
StateChoices.FAILED,
|
||||
):
|
||||
tmp_db_name = graph_database.get_database_name(
|
||||
attack_paths_scan.id, temporary=True
|
||||
)
|
||||
try:
|
||||
graph_database.drop_database(tmp_db_name)
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed to drop temp database {tmp_db_name} during failure handling"
|
||||
)
|
||||
|
||||
finish_attack_paths_scan(
|
||||
attack_paths_scan,
|
||||
StateChoices.FAILED,
|
||||
|
||||
@@ -211,6 +211,7 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
|
||||
# Handling databases changes
|
||||
try:
|
||||
graph_database.drop_database(tmp_cartography_config.neo4j_database)
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed to drop temporary Neo4j database {tmp_cartography_config.neo4j_database} during cleanup"
|
||||
|
||||
@@ -27,23 +27,24 @@ def delete_provider(tenant_id: str, pk: str):
|
||||
|
||||
Returns:
|
||||
dict: A dictionary with the count of deleted objects per model,
|
||||
including related models.
|
||||
|
||||
Raises:
|
||||
Provider.DoesNotExist: If no instance with the provided primary key exists.
|
||||
including related models. Returns an empty dict if the provider
|
||||
was already deleted.
|
||||
"""
|
||||
# Delete the Attack Paths' graph data related to the provider
|
||||
tenant_database_name = graph_database.get_database_name(tenant_id)
|
||||
try:
|
||||
graph_database.drop_subgraph(tenant_database_name, str(pk))
|
||||
|
||||
except graph_database.GraphDatabaseQueryException as gdb_error:
|
||||
logger.error(f"Error deleting Provider graph data: {gdb_error}")
|
||||
raise
|
||||
|
||||
# Get all provider related data and delete them in batches
|
||||
# Get all provider related data to delete them in batches
|
||||
with rls_transaction(tenant_id):
|
||||
instance = Provider.all_objects.get(pk=pk)
|
||||
try:
|
||||
instance = Provider.all_objects.get(pk=pk)
|
||||
except Provider.DoesNotExist:
|
||||
logger.info(f"Provider `{pk}` already deleted, skipping")
|
||||
return {}
|
||||
|
||||
attack_paths_scan_ids = list(
|
||||
AttackPathsScan.all_objects.filter(provider=instance).values_list(
|
||||
"id", flat=True
|
||||
)
|
||||
)
|
||||
|
||||
deletion_steps = [
|
||||
("Scan Summaries", ScanSummary.all_objects.filter(scan__provider=instance)),
|
||||
("Findings", Finding.all_objects.filter(scan__provider=instance)),
|
||||
@@ -52,6 +53,25 @@ def delete_provider(tenant_id: str, pk: str):
|
||||
("AttackPathsScans", AttackPathsScan.all_objects.filter(provider=instance)),
|
||||
]
|
||||
|
||||
# Drop orphaned temporary Neo4j databases
|
||||
for aps_id in attack_paths_scan_ids:
|
||||
tmp_db_name = graph_database.get_database_name(aps_id, temporary=True)
|
||||
try:
|
||||
graph_database.drop_database(tmp_db_name)
|
||||
|
||||
except graph_database.GraphDatabaseQueryException:
|
||||
logger.warning(f"Failed to drop temp database {tmp_db_name}, continuing")
|
||||
|
||||
# Delete the Attack Paths' graph data related to the provider from the tenant database
|
||||
tenant_database_name = graph_database.get_database_name(tenant_id)
|
||||
try:
|
||||
graph_database.drop_subgraph(tenant_database_name, str(pk))
|
||||
|
||||
except graph_database.GraphDatabaseQueryException as gdb_error:
|
||||
logger.error(f"Error deleting Provider graph data: {gdb_error}")
|
||||
raise
|
||||
|
||||
# Delete related data in batches
|
||||
deletion_summary = {}
|
||||
for step_name, queryset in deletion_steps:
|
||||
try:
|
||||
@@ -61,6 +81,7 @@ def delete_provider(tenant_id: str, pk: str):
|
||||
logger.error(f"Error deleting {step_name}: {db_error}")
|
||||
raise
|
||||
|
||||
# Delete the provider instance itself
|
||||
try:
|
||||
with rls_transaction(tenant_id):
|
||||
_, provider_summary = instance.delete()
|
||||
@@ -85,7 +106,9 @@ def delete_tenant(pk: str):
|
||||
"""
|
||||
deletion_summary = {}
|
||||
|
||||
for provider in Provider.objects.using(MainRouter.admin_db).filter(tenant_id=pk):
|
||||
for provider in Provider.all_objects.using(MainRouter.admin_db).filter(
|
||||
tenant_id=pk
|
||||
):
|
||||
summary = delete_provider(pk, provider.id)
|
||||
deletion_summary.update(summary)
|
||||
|
||||
|
||||
@@ -11,8 +11,8 @@ from django_celery_beat.models import PeriodicTask
|
||||
from tasks.jobs.attack_paths import (
|
||||
attack_paths_scan,
|
||||
can_provider_run_attack_paths_scan,
|
||||
db_utils as attack_paths_db_utils,
|
||||
)
|
||||
from tasks.jobs.attack_paths import db_utils as attack_paths_db_utils
|
||||
from tasks.jobs.backfill import (
|
||||
backfill_compliance_summaries,
|
||||
backfill_daily_severity_summaries,
|
||||
|
||||
@@ -441,6 +441,9 @@ class TestFailAttackPathsScan:
|
||||
"tasks.jobs.attack_paths.db_utils.retrieve_attack_paths_scan",
|
||||
return_value=attack_paths_scan,
|
||||
) as mock_retrieve,
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.db_utils.graph_database.drop_database"
|
||||
) as mock_drop_db,
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.db_utils.finish_attack_paths_scan"
|
||||
) as mock_finish,
|
||||
@@ -448,6 +451,51 @@ class TestFailAttackPathsScan:
|
||||
fail_attack_paths_scan(str(tenant.id), str(scan.id), "setup exploded")
|
||||
|
||||
mock_retrieve.assert_called_once_with(str(tenant.id), str(scan.id))
|
||||
expected_tmp_db = f"db-tmp-scan-{str(attack_paths_scan.id).lower()}"
|
||||
mock_drop_db.assert_called_once_with(expected_tmp_db)
|
||||
mock_finish.assert_called_once_with(
|
||||
attack_paths_scan,
|
||||
StateChoices.FAILED,
|
||||
{"global_error": "setup exploded"},
|
||||
)
|
||||
|
||||
def test_drops_temp_database_even_when_drop_fails(
|
||||
self, tenants_fixture, providers_fixture, scans_fixture
|
||||
):
|
||||
from tasks.jobs.attack_paths.db_utils import (
|
||||
fail_attack_paths_scan,
|
||||
)
|
||||
|
||||
tenant = tenants_fixture[0]
|
||||
provider = providers_fixture[0]
|
||||
provider.provider = Provider.ProviderChoices.AWS
|
||||
provider.save()
|
||||
scan = scans_fixture[0]
|
||||
scan.provider = provider
|
||||
scan.save()
|
||||
|
||||
attack_paths_scan = AttackPathsScan.objects.create(
|
||||
tenant_id=tenant.id,
|
||||
provider=provider,
|
||||
scan=scan,
|
||||
state=StateChoices.EXECUTING,
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.db_utils.retrieve_attack_paths_scan",
|
||||
return_value=attack_paths_scan,
|
||||
),
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.db_utils.graph_database.drop_database",
|
||||
side_effect=Exception("Neo4j unreachable"),
|
||||
),
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.db_utils.finish_attack_paths_scan"
|
||||
) as mock_finish,
|
||||
):
|
||||
fail_attack_paths_scan(str(tenant.id), str(scan.id), "setup exploded")
|
||||
|
||||
mock_finish.assert_called_once_with(
|
||||
attack_paths_scan,
|
||||
StateChoices.FAILED,
|
||||
@@ -481,12 +529,16 @@ class TestFailAttackPathsScan:
|
||||
"tasks.jobs.attack_paths.db_utils.retrieve_attack_paths_scan",
|
||||
return_value=attack_paths_scan,
|
||||
),
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.db_utils.graph_database.drop_database"
|
||||
) as mock_drop_db,
|
||||
patch(
|
||||
"tasks.jobs.attack_paths.db_utils.finish_attack_paths_scan"
|
||||
) as mock_finish,
|
||||
):
|
||||
fail_attack_paths_scan(str(tenant.id), str(scan.id), "setup exploded")
|
||||
|
||||
mock_drop_db.assert_not_called()
|
||||
mock_finish.assert_not_called()
|
||||
|
||||
def test_skips_when_no_scan_found(self, tenants_fixture):
|
||||
|
||||
@@ -4,6 +4,7 @@ import pytest
|
||||
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
|
||||
from api.attack_paths import database as graph_database
|
||||
from api.models import Provider, Tenant
|
||||
from tasks.jobs.deletion import delete_provider, delete_tenant
|
||||
|
||||
@@ -47,14 +48,61 @@ class TestDeleteProvider:
|
||||
tenant_id = str(tenants_fixture[0].id)
|
||||
non_existent_pk = "babf6796-cfcc-4fd3-9dcf-88d012247645"
|
||||
|
||||
with pytest.raises(ObjectDoesNotExist):
|
||||
delete_provider(tenant_id, non_existent_pk)
|
||||
result = delete_provider(tenant_id, non_existent_pk)
|
||||
|
||||
mock_get_database_name.assert_called_once_with(tenant_id)
|
||||
mock_drop_subgraph.assert_called_once_with(
|
||||
"tenant-db",
|
||||
non_existent_pk,
|
||||
)
|
||||
assert result == {}
|
||||
mock_get_database_name.assert_not_called()
|
||||
mock_drop_subgraph.assert_not_called()
|
||||
|
||||
def test_delete_provider_drops_temp_attack_paths_databases(
|
||||
self, providers_fixture, create_attack_paths_scan
|
||||
):
|
||||
instance = providers_fixture[0]
|
||||
tenant_id = str(instance.tenant_id)
|
||||
|
||||
aps1 = create_attack_paths_scan(instance)
|
||||
aps2 = create_attack_paths_scan(instance)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"tasks.jobs.deletion.graph_database.drop_subgraph",
|
||||
),
|
||||
patch(
|
||||
"tasks.jobs.deletion.graph_database.drop_database",
|
||||
) as mock_drop_database,
|
||||
):
|
||||
result = delete_provider(tenant_id, instance.id)
|
||||
|
||||
assert result
|
||||
expected_tmp_calls = [
|
||||
call(f"db-tmp-scan-{str(aps1.id).lower()}"),
|
||||
call(f"db-tmp-scan-{str(aps2.id).lower()}"),
|
||||
]
|
||||
mock_drop_database.assert_has_calls(expected_tmp_calls, any_order=True)
|
||||
|
||||
def test_delete_provider_continues_when_temp_db_drop_fails(
|
||||
self, providers_fixture, create_attack_paths_scan
|
||||
):
|
||||
instance = providers_fixture[0]
|
||||
tenant_id = str(instance.tenant_id)
|
||||
|
||||
create_attack_paths_scan(instance)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"tasks.jobs.deletion.graph_database.drop_subgraph",
|
||||
),
|
||||
patch(
|
||||
"tasks.jobs.deletion.graph_database.drop_database",
|
||||
side_effect=graph_database.GraphDatabaseQueryException(
|
||||
"Neo4j unreachable"
|
||||
),
|
||||
),
|
||||
):
|
||||
result = delete_provider(tenant_id, instance.id)
|
||||
|
||||
assert result
|
||||
assert not Provider.all_objects.filter(pk=instance.id).exists()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@@ -142,3 +190,56 @@ class TestDeleteTenant:
|
||||
mock_get_database_name.assert_called_once_with(tenant.id)
|
||||
mock_drop_subgraph.assert_not_called()
|
||||
mock_drop_database.assert_called_once_with("tenant-db")
|
||||
|
||||
def test_delete_tenant_includes_soft_deleted_providers(self, tenants_fixture):
|
||||
tenant = tenants_fixture[0]
|
||||
provider = Provider.objects.create(
|
||||
provider="aws",
|
||||
uid="999999999999",
|
||||
alias="soft_deleted_provider",
|
||||
tenant_id=tenant.id,
|
||||
)
|
||||
# Soft-delete the provider so ActiveProviderManager would skip it
|
||||
Provider.all_objects.filter(pk=provider.id).update(is_deleted=True)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"tasks.jobs.deletion.graph_database.get_database_name",
|
||||
return_value="tenant-db",
|
||||
),
|
||||
patch(
|
||||
"tasks.jobs.deletion.graph_database.drop_subgraph"
|
||||
) as mock_drop_subgraph,
|
||||
patch("tasks.jobs.deletion.graph_database.drop_database"),
|
||||
):
|
||||
delete_tenant(tenant.id)
|
||||
|
||||
mock_drop_subgraph.assert_any_call("tenant-db", str(provider.id))
|
||||
|
||||
def test_delete_tenant_handles_concurrently_deleted_provider(self, tenants_fixture):
|
||||
tenant = tenants_fixture[0]
|
||||
Provider.objects.create(
|
||||
provider="aws",
|
||||
uid="111111111111",
|
||||
alias="vanishing_provider",
|
||||
tenant_id=tenant.id,
|
||||
)
|
||||
|
||||
def drop_subgraph_side_effect(_db_name, provider_id):
|
||||
# Simulate concurrent deletion by another process
|
||||
Provider.all_objects.filter(pk=provider_id).delete()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"tasks.jobs.deletion.graph_database.get_database_name",
|
||||
return_value="tenant-db",
|
||||
),
|
||||
patch(
|
||||
"tasks.jobs.deletion.graph_database.drop_subgraph",
|
||||
side_effect=drop_subgraph_side_effect,
|
||||
),
|
||||
patch("tasks.jobs.deletion.graph_database.drop_database"),
|
||||
):
|
||||
deletion_summary = delete_tenant(tenant.id)
|
||||
|
||||
assert deletion_summary is not None
|
||||
|
||||
Reference in New Issue
Block a user