From f5f1f1ab2d83851c4cc939f4758df00c6e6cef8c Mon Sep 17 00:00:00 2001 From: Josema Camacho Date: Wed, 18 Mar 2026 09:49:45 +0100 Subject: [PATCH] fix(attack-paths): recover graph_data_ready when scan fails during graph swap (#10354) --- api/CHANGELOG.md | 4 + api/src/backend/api/attack_paths/database.py | 35 +- .../api/tests/test_attack_paths_database.py | 75 ++ .../tasks/jobs/attack_paths/db_utils.py | 43 +- .../backend/tasks/jobs/attack_paths/scan.py | 40 +- .../tasks/tests/test_attack_paths_scan.py | 682 +++++++++++++++++- 6 files changed, 843 insertions(+), 36 deletions(-) diff --git a/api/CHANGELOG.md b/api/CHANGELOG.md index 8c5a47ad6b..c7ff2bf99e 100644 --- a/api/CHANGELOG.md +++ b/api/CHANGELOG.md @@ -13,6 +13,10 @@ All notable changes to the **Prowler API** are documented in this file. - Attack Paths: Complete migration to private graph labels and properties, removing deprecated dual-write support [(#10268)](https://github.com/prowler-cloud/prowler/pull/10268) - Attack Paths: Added tenant and provider related labels to the nodes so they can be easily filtered on custom queries [(#10308)](https://github.com/prowler-cloud/prowler/pull/10308) +### 🐞 Fixed + +- Attack Paths: Recover `graph_data_ready` flag when scan fails during graph swap, preventing query endpoints from staying blocked until the next successful scan [(#10354)](https://github.com/prowler-cloud/prowler/pull/10354) + ### 🔐 Security - Use `psycopg2.sql` to safely compose DDL in `PostgresEnumMigration`, preventing SQL injection via f-string interpolation [(#10166)](https://github.com/prowler-cloud/prowler/pull/10166) diff --git a/api/src/backend/api/attack_paths/database.py b/api/src/backend/api/attack_paths/database.py index 25172152e9..02083991ac 100644 --- a/api/src/backend/api/attack_paths/database.py +++ b/api/src/backend/api/attack_paths/database.py @@ -1,26 +1,22 @@ import atexit import logging import threading - -from typing import Any - from contextlib import contextmanager -from typing import Iterator +from typing import Any, Iterator from uuid import UUID import neo4j import neo4j.exceptions - -from django.conf import settings - -from api.attack_paths.retryable_session import RetryableSession from config.env import env +from django.conf import settings from tasks.jobs.attack_paths.config import ( BATCH_SIZE, PROVIDER_ID_PROPERTY, PROVIDER_RESOURCE_LABEL, ) +from api.attack_paths.retryable_session import RetryableSession + # Without this Celery goes crazy with Neo4j logging logging.getLogger("neo4j").setLevel(logging.ERROR) logging.getLogger("neo4j").propagate = False @@ -197,6 +193,29 @@ def drop_subgraph(database: str, provider_id: str) -> int: return deleted_nodes +def has_provider_data(database: str, provider_id: str) -> bool: + """ + Check if any ProviderResource node exists for this provider. + + Returns `False` if the database doesn't exist. + """ + query = ( + f"MATCH (n:{PROVIDER_RESOURCE_LABEL} " + f"{{{PROVIDER_ID_PROPERTY}: $provider_id}}) " + "RETURN 1 LIMIT 1" + ) + + try: + with get_session(database, default_access_mode=neo4j.READ_ACCESS) as session: + result = session.run(query, {"provider_id": provider_id}) + return result.single() is not None + + except GraphDatabaseQueryException as exc: + if exc.code == "Neo.ClientError.Database.DatabaseNotFound": + return False + raise + + def clear_cache(database: str) -> None: query = "CALL db.clearQueryCaches()" diff --git a/api/src/backend/api/tests/test_attack_paths_database.py b/api/src/backend/api/tests/test_attack_paths_database.py index 8b458cb7b7..7e07792a69 100644 --- a/api/src/backend/api/tests/test_attack_paths_database.py +++ b/api/src/backend/api/tests/test_attack_paths_database.py @@ -442,3 +442,78 @@ class TestThreadSafety: # All threads got the same driver instance assert all(r is mock_driver for r in results) assert len(results) == 10 + + +class TestHasProviderData: + """Test has_provider_data helper for checking provider nodes in Neo4j.""" + + def test_returns_true_when_nodes_exist(self): + import api.attack_paths.database as db_module + + mock_session = MagicMock() + mock_result = MagicMock() + mock_result.single.return_value = MagicMock() # non-None record + mock_session.run.return_value = mock_result + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = mock_session + session_ctx.__exit__.return_value = False + + with patch( + "api.attack_paths.database.get_session", + return_value=session_ctx, + ): + assert db_module.has_provider_data("db-tenant-abc", "provider-123") is True + + mock_session.run.assert_called_once() + + def test_returns_false_when_no_nodes(self): + import api.attack_paths.database as db_module + + mock_session = MagicMock() + mock_result = MagicMock() + mock_result.single.return_value = None + mock_session.run.return_value = mock_result + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = mock_session + session_ctx.__exit__.return_value = False + + with patch( + "api.attack_paths.database.get_session", + return_value=session_ctx, + ): + assert db_module.has_provider_data("db-tenant-abc", "provider-123") is False + + def test_returns_false_when_database_not_found(self): + import api.attack_paths.database as db_module + + session_ctx = MagicMock() + session_ctx.__enter__.side_effect = db_module.GraphDatabaseQueryException( + message="Database does not exist", + code="Neo.ClientError.Database.DatabaseNotFound", + ) + + with patch( + "api.attack_paths.database.get_session", + return_value=session_ctx, + ): + assert ( + db_module.has_provider_data("db-tenant-gone", "provider-123") is False + ) + + def test_raises_on_other_errors(self): + import api.attack_paths.database as db_module + + session_ctx = MagicMock() + session_ctx.__enter__.side_effect = db_module.GraphDatabaseQueryException( + message="Connection refused", + code="Neo.TransientError.General.UnknownError", + ) + + with patch( + "api.attack_paths.database.get_session", + return_value=session_ctx, + ): + with pytest.raises(db_module.GraphDatabaseQueryException): + db_module.has_provider_data("db-tenant-abc", "provider-123") diff --git a/api/src/backend/tasks/jobs/attack_paths/db_utils.py b/api/src/backend/tasks/jobs/attack_paths/db_utils.py index 9fb52b0ead..7d17ec07eb 100644 --- a/api/src/backend/tasks/jobs/attack_paths/db_utils.py +++ b/api/src/backend/tasks/jobs/attack_paths/db_utils.py @@ -3,15 +3,13 @@ from typing import Any from cartography.config import Config as CartographyConfig from celery.utils.log import get_task_logger +from tasks.jobs.attack_paths.config import is_provider_available from api.attack_paths import database as graph_database from api.db_utils import rls_transaction -from api.models import ( - AttackPathsScan as ProwlerAPIAttackPathsScan, - Provider as ProwlerAPIProvider, - StateChoices, -) -from tasks.jobs.attack_paths.config import is_provider_available +from api.models import AttackPathsScan as ProwlerAPIAttackPathsScan +from api.models import Provider as ProwlerAPIProvider +from api.models import StateChoices logger = get_task_logger(__name__) @@ -155,6 +153,37 @@ def set_provider_graph_data_ready( attack_paths_scan.refresh_from_db(fields=["graph_data_ready"]) +def recover_graph_data_ready( + attack_paths_scan: ProwlerAPIAttackPathsScan, +) -> None: + """ + Best-effort recovery of `graph_data_ready` after a scan failure. + + Queries Neo4j to check if the provider still has data in the tenant + database. If data exists, restores `graph_data_ready=True` for all scans + of this provider. Never raises. + + Trade-off: if the worker crashed mid-sync, partial data may exist and + this will re-enable queries against it. We accept that because leaving + `graph_data_ready=False` permanently (blocking all queries until the + next successful scan) is a worse outcome for the user. + """ + try: + tenant_db = graph_database.get_database_name(attack_paths_scan.tenant_id) + if graph_database.has_provider_data( + tenant_db, str(attack_paths_scan.provider_id) + ): + set_provider_graph_data_ready(attack_paths_scan, True) + logger.info( + f"Recovered `graph_data_ready` for provider {attack_paths_scan.provider_id}" + ) + + except Exception: + logger.exception( + f"Failed to recover `graph_data_ready` for provider {attack_paths_scan.provider_id}" + ) + + def fail_attack_paths_scan( tenant_id: str, scan_id: str, @@ -185,3 +214,5 @@ def fail_attack_paths_scan( StateChoices.FAILED, {"global_error": error}, ) + + recover_graph_data_ready(attack_paths_scan) diff --git a/api/src/backend/tasks/jobs/attack_paths/scan.py b/api/src/backend/tasks/jobs/attack_paths/scan.py index 6624680a5e..20dba74da8 100644 --- a/api/src/backend/tasks/jobs/attack_paths/scan.py +++ b/api/src/backend/tasks/jobs/attack_paths/scan.py @@ -55,7 +55,6 @@ exception propagates to Celery. import logging import time - from typing import Any from cartography.config import Config as CartographyConfig @@ -63,16 +62,14 @@ from cartography.intel import analysis as cartography_analysis from cartography.intel import create_indexes as cartography_create_indexes from cartography.intel import ontology as cartography_ontology from celery.utils.log import get_task_logger +from tasks.jobs.attack_paths import db_utils, findings, internet, sync, utils +from tasks.jobs.attack_paths.config import get_cartography_ingestion_function from api.attack_paths import database as graph_database from api.db_utils import rls_transaction -from api.models import ( - Provider as ProwlerAPIProvider, - StateChoices, -) +from api.models import Provider as ProwlerAPIProvider +from api.models import StateChoices from api.utils import initialize_prowler_provider -from tasks.jobs.attack_paths import db_utils, findings, internet, sync, utils -from tasks.jobs.attack_paths.config import get_cartography_ingestion_function # Without this Celery goes crazy with Cartography logging logging.getLogger("cartography").setLevel(logging.ERROR) @@ -147,6 +144,10 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]: attack_paths_scan, task_id, tenant_cartography_config ) + subgraph_dropped = False + sync_completed = False + provider_gated = False + try: logger.info( f"Creating Neo4j database {tmp_cartography_config.neo4j_database} for tenant {prowler_api_provider.tenant_id}" @@ -225,10 +226,12 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]: logger.info(f"Deleting existing provider graph in {tenant_database_name}") db_utils.set_provider_graph_data_ready(attack_paths_scan, False) + provider_gated = True graph_database.drop_subgraph( database=tenant_database_name, provider_id=str(prowler_api_provider.id), ) + subgraph_dropped = True db_utils.update_attack_paths_scan_progress(attack_paths_scan, 98) logger.info( @@ -240,6 +243,7 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]: tenant_id=str(prowler_api_provider.tenant_id), provider_id=str(prowler_api_provider.id), ) + sync_completed = True db_utils.set_graph_data_ready(attack_paths_scan, True) db_utils.update_attack_paths_scan_progress(attack_paths_scan, 99) @@ -264,23 +268,39 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]: logger.exception(exception_message) ingestion_exceptions["global_error"] = exception_message - # Handling databases changes + # Recover graph_data_ready based on how far the swap got. + # Partial drop (mid-batch failure) may leave `subgraph_dropped=False` + # with data partially deleted, so we prefer that over permanently blocked queries. + try: + if sync_completed: + db_utils.set_graph_data_ready(attack_paths_scan, True) + elif provider_gated and not subgraph_dropped: + db_utils.set_provider_graph_data_ready(attack_paths_scan, True) + + except Exception: + logger.error( + f"Failed to recover `graph_data_ready` for provider {attack_paths_scan.provider_id}", + exc_info=True, + ) + + # Dropping the temporary database if it still exists try: graph_database.drop_database(tmp_cartography_config.neo4j_database) except Exception as e: logger.error( - f"Failed to drop temporary Neo4j database {tmp_cartography_config.neo4j_database} during cleanup: {e}", + f"Failed to drop temporary Neo4j database `{tmp_cartography_config.neo4j_database}` during cleanup: {e}", exc_info=True, ) + # Set Attack Paths scan state to FAILED try: db_utils.finish_attack_paths_scan( attack_paths_scan, StateChoices.FAILED, ingestion_exceptions ) except Exception as e: logger.error( - f"Could not mark attack paths scan {attack_paths_scan.id} as FAILED (row may have been deleted): {e}", + f"Could not mark Attack Paths scan {attack_paths_scan.id} as `FAILED` (row may have been deleted): {e}", exc_info=True, ) diff --git a/api/src/backend/tasks/tests/test_attack_paths_scan.py b/api/src/backend/tasks/tests/test_attack_paths_scan.py index cd4dfcffd9..43a2420321 100644 --- a/api/src/backend/tasks/tests/test_attack_paths_scan.py +++ b/api/src/backend/tasks/tests/test_attack_paths_scan.py @@ -269,6 +269,106 @@ class TestAttackPathsRun: assert failure_args[1] == StateChoices.FAILED assert failure_args[2] == {"global_error": "Cartography failed: ingestion boom"} + @patch( + "tasks.jobs.attack_paths.scan.utils.stringify_exception", + return_value="Cartography failed: ingestion boom", + ) + @patch( + "tasks.jobs.attack_paths.scan.utils.call_within_event_loop", + side_effect=lambda fn, *a, **kw: fn(*a, **kw), + ) + @patch("tasks.jobs.attack_paths.scan.graph_database.drop_database") + @patch("tasks.jobs.attack_paths.scan.db_utils.finish_attack_paths_scan") + @patch("tasks.jobs.attack_paths.scan.db_utils.set_graph_data_ready") + @patch("tasks.jobs.attack_paths.scan.db_utils.set_provider_graph_data_ready") + @patch("tasks.jobs.attack_paths.scan.db_utils.update_attack_paths_scan_progress") + @patch("tasks.jobs.attack_paths.scan.db_utils.starting_attack_paths_scan") + @patch("tasks.jobs.attack_paths.scan.findings.analysis") + @patch("tasks.jobs.attack_paths.scan.internet.analysis") + @patch("tasks.jobs.attack_paths.scan.findings.create_findings_indexes") + @patch("tasks.jobs.attack_paths.scan.cartography_analysis.run") + @patch("tasks.jobs.attack_paths.scan.cartography_create_indexes.run") + @patch("tasks.jobs.attack_paths.scan.graph_database.create_database") + @patch( + "tasks.jobs.attack_paths.scan.graph_database.get_database_name", + return_value="db-scan-id", + ) + @patch("tasks.jobs.attack_paths.scan.graph_database.get_uri") + @patch( + "tasks.jobs.attack_paths.scan.initialize_prowler_provider", + return_value=MagicMock(_enabled_regions=["us-east-1"]), + ) + @patch( + "tasks.jobs.attack_paths.scan.rls_transaction", + new=lambda *args, **kwargs: nullcontext(), + ) + def test_failure_before_gate_does_not_flip_graph_data_ready_true( + self, + mock_init_provider, + mock_get_uri, + mock_get_db_name, + mock_create_db, + mock_cartography_indexes, + mock_cartography_analysis, + mock_findings_indexes, + mock_internet_analysis, + mock_findings_analysis, + mock_starting, + mock_update_progress, + mock_set_provider_graph_data_ready, + mock_set_graph_data_ready, + mock_finish, + mock_drop_db, + mock_event_loop, + mock_stringify, + tenants_fixture, + providers_fixture, + scans_fixture, + ): + """Failure during ingestion (before set_provider_graph_data_ready(False)) + must NOT flip graph_data_ready to True for providers that never had data.""" + tenant = tenants_fixture[0] + provider = providers_fixture[0] + provider.provider = Provider.ProviderChoices.AWS + provider.save() + scan = scans_fixture[0] + scan.provider = provider + scan.save() + + attack_paths_scan = AttackPathsScan.objects.create( + tenant_id=tenant.id, + provider=provider, + scan=scan, + state=StateChoices.SCHEDULED, + ) + + mock_session = MagicMock() + session_ctx = MagicMock() + session_ctx.__enter__.return_value = mock_session + session_ctx.__exit__.return_value = False + ingestion_fn = MagicMock(side_effect=RuntimeError("ingestion boom")) + + with ( + patch( + "tasks.jobs.attack_paths.scan.graph_database.get_session", + return_value=session_ctx, + ), + patch( + "tasks.jobs.attack_paths.scan.db_utils.retrieve_attack_paths_scan", + return_value=attack_paths_scan, + ), + patch( + "tasks.jobs.attack_paths.scan.get_cartography_ingestion_function", + return_value=ingestion_fn, + ), + ): + with pytest.raises(RuntimeError, match="ingestion boom"): + attack_paths_run(str(tenant.id), str(scan.id), "task-456") + + # Gate was never applied, so recovery must not flip anything to True + mock_set_provider_graph_data_ready.assert_not_called() + mock_set_graph_data_ready.assert_not_called() + @patch( "tasks.jobs.attack_paths.scan.utils.stringify_exception", return_value="Cartography failed: ingestion boom", @@ -371,6 +471,465 @@ class TestAttackPathsRun: assert failure_args[1] == StateChoices.FAILED assert failure_args[2] == {"global_error": "Cartography failed: ingestion boom"} + @patch( + "tasks.jobs.attack_paths.scan.utils.stringify_exception", + return_value="Attack Paths scan failed: drop failed", + ) + @patch( + "tasks.jobs.attack_paths.scan.utils.call_within_event_loop", + side_effect=lambda fn, *a, **kw: fn(*a, **kw), + ) + @patch("tasks.jobs.attack_paths.scan.graph_database.drop_database") + @patch("tasks.jobs.attack_paths.scan.db_utils.finish_attack_paths_scan") + @patch("tasks.jobs.attack_paths.scan.db_utils.set_graph_data_ready") + @patch("tasks.jobs.attack_paths.scan.db_utils.set_provider_graph_data_ready") + @patch("tasks.jobs.attack_paths.scan.db_utils.update_attack_paths_scan_progress") + @patch("tasks.jobs.attack_paths.scan.db_utils.starting_attack_paths_scan") + @patch("tasks.jobs.attack_paths.scan.sync.sync_graph") + @patch( + "tasks.jobs.attack_paths.scan.graph_database.drop_subgraph", + side_effect=RuntimeError("drop failed"), + ) + @patch("tasks.jobs.attack_paths.scan.sync.create_sync_indexes") + @patch("tasks.jobs.attack_paths.scan.internet.analysis") + @patch("tasks.jobs.attack_paths.scan.findings.analysis") + @patch("tasks.jobs.attack_paths.scan.findings.create_findings_indexes") + @patch("tasks.jobs.attack_paths.scan.cartography_ontology.run") + @patch("tasks.jobs.attack_paths.scan.cartography_analysis.run") + @patch("tasks.jobs.attack_paths.scan.cartography_create_indexes.run") + @patch("tasks.jobs.attack_paths.scan.graph_database.clear_cache") + @patch("tasks.jobs.attack_paths.scan.graph_database.create_database") + @patch( + "tasks.jobs.attack_paths.scan.graph_database.get_uri", + return_value="bolt://neo4j", + ) + @patch( + "tasks.jobs.attack_paths.scan.initialize_prowler_provider", + return_value=MagicMock(_enabled_regions=["us-east-1"]), + ) + @patch( + "tasks.jobs.attack_paths.scan.rls_transaction", + new=lambda *args, **kwargs: nullcontext(), + ) + def test_failure_after_gate_before_drop_restores_graph_data_ready( + self, + mock_init_provider, + mock_get_uri, + mock_create_db, + mock_clear_cache, + mock_cartography_indexes, + mock_cartography_analysis, + mock_cartography_ontology, + mock_findings_indexes, + mock_findings_analysis, + mock_internet_analysis, + mock_sync_indexes, + mock_drop_subgraph, + mock_sync, + mock_starting, + mock_update_progress, + mock_set_provider_graph_data_ready, + mock_set_graph_data_ready, + mock_finish, + mock_drop_db, + mock_event_loop, + mock_stringify, + tenants_fixture, + providers_fixture, + scans_fixture, + ): + tenant = tenants_fixture[0] + provider = providers_fixture[0] + provider.provider = Provider.ProviderChoices.AWS + provider.save() + scan = scans_fixture[0] + scan.provider = provider + scan.save() + + attack_paths_scan = AttackPathsScan.objects.create( + tenant_id=tenant.id, + provider=provider, + scan=scan, + state=StateChoices.SCHEDULED, + ) + + mock_session = MagicMock() + session_ctx = MagicMock() + session_ctx.__enter__.return_value = mock_session + session_ctx.__exit__.return_value = False + + with ( + patch( + "tasks.jobs.attack_paths.scan.graph_database.get_database_name", + side_effect=["db-scan-id", "tenant-db"], + ), + patch( + "tasks.jobs.attack_paths.scan.graph_database.get_session", + return_value=session_ctx, + ), + patch( + "tasks.jobs.attack_paths.scan.db_utils.retrieve_attack_paths_scan", + return_value=attack_paths_scan, + ), + patch( + "tasks.jobs.attack_paths.scan.get_cartography_ingestion_function", + return_value=MagicMock(return_value={}), + ), + ): + with pytest.raises(RuntimeError, match="drop failed"): + attack_paths_run(str(tenant.id), str(scan.id), "task-456") + + assert mock_set_provider_graph_data_ready.call_args_list == [ + call(attack_paths_scan, False), + call(attack_paths_scan, True), + ] + + @patch( + "tasks.jobs.attack_paths.scan.utils.stringify_exception", + return_value="Attack Paths scan failed: sync failed", + ) + @patch( + "tasks.jobs.attack_paths.scan.utils.call_within_event_loop", + side_effect=lambda fn, *a, **kw: fn(*a, **kw), + ) + @patch("tasks.jobs.attack_paths.scan.graph_database.drop_database") + @patch("tasks.jobs.attack_paths.scan.db_utils.finish_attack_paths_scan") + @patch("tasks.jobs.attack_paths.scan.db_utils.set_graph_data_ready") + @patch("tasks.jobs.attack_paths.scan.db_utils.set_provider_graph_data_ready") + @patch("tasks.jobs.attack_paths.scan.db_utils.update_attack_paths_scan_progress") + @patch("tasks.jobs.attack_paths.scan.db_utils.starting_attack_paths_scan") + @patch( + "tasks.jobs.attack_paths.scan.sync.sync_graph", + side_effect=RuntimeError("sync failed"), + ) + @patch("tasks.jobs.attack_paths.scan.graph_database.drop_subgraph") + @patch("tasks.jobs.attack_paths.scan.sync.create_sync_indexes") + @patch("tasks.jobs.attack_paths.scan.internet.analysis") + @patch("tasks.jobs.attack_paths.scan.findings.analysis") + @patch("tasks.jobs.attack_paths.scan.findings.create_findings_indexes") + @patch("tasks.jobs.attack_paths.scan.cartography_ontology.run") + @patch("tasks.jobs.attack_paths.scan.cartography_analysis.run") + @patch("tasks.jobs.attack_paths.scan.cartography_create_indexes.run") + @patch("tasks.jobs.attack_paths.scan.graph_database.clear_cache") + @patch("tasks.jobs.attack_paths.scan.graph_database.create_database") + @patch( + "tasks.jobs.attack_paths.scan.graph_database.get_uri", + return_value="bolt://neo4j", + ) + @patch( + "tasks.jobs.attack_paths.scan.initialize_prowler_provider", + return_value=MagicMock(_enabled_regions=["us-east-1"]), + ) + @patch( + "tasks.jobs.attack_paths.scan.rls_transaction", + new=lambda *args, **kwargs: nullcontext(), + ) + def test_failure_after_drop_before_sync_leaves_graph_data_ready_false( + self, + mock_init_provider, + mock_get_uri, + mock_create_db, + mock_clear_cache, + mock_cartography_indexes, + mock_cartography_analysis, + mock_cartography_ontology, + mock_findings_indexes, + mock_findings_analysis, + mock_internet_analysis, + mock_sync_indexes, + mock_drop_subgraph, + mock_sync, + mock_starting, + mock_update_progress, + mock_set_provider_graph_data_ready, + mock_set_graph_data_ready, + mock_finish, + mock_drop_db, + mock_event_loop, + mock_stringify, + tenants_fixture, + providers_fixture, + scans_fixture, + ): + tenant = tenants_fixture[0] + provider = providers_fixture[0] + provider.provider = Provider.ProviderChoices.AWS + provider.save() + scan = scans_fixture[0] + scan.provider = provider + scan.save() + + attack_paths_scan = AttackPathsScan.objects.create( + tenant_id=tenant.id, + provider=provider, + scan=scan, + state=StateChoices.SCHEDULED, + ) + + mock_session = MagicMock() + session_ctx = MagicMock() + session_ctx.__enter__.return_value = mock_session + session_ctx.__exit__.return_value = False + + with ( + patch( + "tasks.jobs.attack_paths.scan.graph_database.get_database_name", + side_effect=["db-scan-id", "tenant-db"], + ), + patch( + "tasks.jobs.attack_paths.scan.graph_database.get_session", + return_value=session_ctx, + ), + patch( + "tasks.jobs.attack_paths.scan.db_utils.retrieve_attack_paths_scan", + return_value=attack_paths_scan, + ), + patch( + "tasks.jobs.attack_paths.scan.get_cartography_ingestion_function", + return_value=MagicMock(return_value={}), + ), + ): + with pytest.raises(RuntimeError, match="sync failed"): + attack_paths_run(str(tenant.id), str(scan.id), "task-456") + + # Only called with False (gate), never with True (no recovery for partial data) + mock_set_provider_graph_data_ready.assert_called_once_with( + attack_paths_scan, False + ) + + @patch( + "tasks.jobs.attack_paths.scan.utils.stringify_exception", + return_value="Attack Paths scan failed: flag failed", + ) + @patch( + "tasks.jobs.attack_paths.scan.utils.call_within_event_loop", + side_effect=lambda fn, *a, **kw: fn(*a, **kw), + ) + @patch("tasks.jobs.attack_paths.scan.graph_database.drop_database") + @patch("tasks.jobs.attack_paths.scan.db_utils.finish_attack_paths_scan") + @patch( + "tasks.jobs.attack_paths.scan.db_utils.set_graph_data_ready", + side_effect=[RuntimeError("flag failed"), None], + ) + @patch("tasks.jobs.attack_paths.scan.db_utils.set_provider_graph_data_ready") + @patch("tasks.jobs.attack_paths.scan.db_utils.update_attack_paths_scan_progress") + @patch("tasks.jobs.attack_paths.scan.db_utils.starting_attack_paths_scan") + @patch("tasks.jobs.attack_paths.scan.sync.sync_graph") + @patch("tasks.jobs.attack_paths.scan.graph_database.drop_subgraph") + @patch("tasks.jobs.attack_paths.scan.sync.create_sync_indexes") + @patch("tasks.jobs.attack_paths.scan.internet.analysis") + @patch("tasks.jobs.attack_paths.scan.findings.analysis") + @patch("tasks.jobs.attack_paths.scan.findings.create_findings_indexes") + @patch("tasks.jobs.attack_paths.scan.cartography_ontology.run") + @patch("tasks.jobs.attack_paths.scan.cartography_analysis.run") + @patch("tasks.jobs.attack_paths.scan.cartography_create_indexes.run") + @patch("tasks.jobs.attack_paths.scan.graph_database.clear_cache") + @patch("tasks.jobs.attack_paths.scan.graph_database.create_database") + @patch( + "tasks.jobs.attack_paths.scan.graph_database.get_uri", + return_value="bolt://neo4j", + ) + @patch( + "tasks.jobs.attack_paths.scan.initialize_prowler_provider", + return_value=MagicMock(_enabled_regions=["us-east-1"]), + ) + @patch( + "tasks.jobs.attack_paths.scan.rls_transaction", + new=lambda *args, **kwargs: nullcontext(), + ) + def test_failure_after_sync_restores_graph_data_ready( + self, + mock_init_provider, + mock_get_uri, + mock_create_db, + mock_clear_cache, + mock_cartography_indexes, + mock_cartography_analysis, + mock_cartography_ontology, + mock_findings_indexes, + mock_findings_analysis, + mock_internet_analysis, + mock_sync_indexes, + mock_drop_subgraph, + mock_sync, + mock_starting, + mock_update_progress, + mock_set_provider_graph_data_ready, + mock_set_graph_data_ready, + mock_finish, + mock_drop_db, + mock_event_loop, + mock_stringify, + tenants_fixture, + providers_fixture, + scans_fixture, + ): + tenant = tenants_fixture[0] + provider = providers_fixture[0] + provider.provider = Provider.ProviderChoices.AWS + provider.save() + scan = scans_fixture[0] + scan.provider = provider + scan.save() + + attack_paths_scan = AttackPathsScan.objects.create( + tenant_id=tenant.id, + provider=provider, + scan=scan, + state=StateChoices.SCHEDULED, + ) + + mock_session = MagicMock() + session_ctx = MagicMock() + session_ctx.__enter__.return_value = mock_session + session_ctx.__exit__.return_value = False + + with ( + patch( + "tasks.jobs.attack_paths.scan.graph_database.get_database_name", + side_effect=["db-scan-id", "tenant-db"], + ), + patch( + "tasks.jobs.attack_paths.scan.graph_database.get_session", + return_value=session_ctx, + ), + patch( + "tasks.jobs.attack_paths.scan.db_utils.retrieve_attack_paths_scan", + return_value=attack_paths_scan, + ), + patch( + "tasks.jobs.attack_paths.scan.get_cartography_ingestion_function", + return_value=MagicMock(return_value={}), + ), + ): + with pytest.raises(RuntimeError, match="flag failed"): + attack_paths_run(str(tenant.id), str(scan.id), "task-456") + + # sync completed: first call (normal path) raised, recovery retried and succeeded + assert mock_set_graph_data_ready.call_args_list == [ + call(attack_paths_scan, True), + call(attack_paths_scan, True), + ] + # set_provider_graph_data_ready only called once with False (the gate) + mock_set_provider_graph_data_ready.assert_called_once_with( + attack_paths_scan, False + ) + + @patch( + "tasks.jobs.attack_paths.scan.utils.stringify_exception", + return_value="Attack Paths scan failed: drop failed", + ) + @patch( + "tasks.jobs.attack_paths.scan.utils.call_within_event_loop", + side_effect=lambda fn, *a, **kw: fn(*a, **kw), + ) + @patch("tasks.jobs.attack_paths.scan.graph_database.drop_database") + @patch("tasks.jobs.attack_paths.scan.db_utils.finish_attack_paths_scan") + @patch("tasks.jobs.attack_paths.scan.db_utils.set_graph_data_ready") + @patch("tasks.jobs.attack_paths.scan.db_utils.set_provider_graph_data_ready") + @patch("tasks.jobs.attack_paths.scan.db_utils.update_attack_paths_scan_progress") + @patch("tasks.jobs.attack_paths.scan.db_utils.starting_attack_paths_scan") + @patch("tasks.jobs.attack_paths.scan.sync.sync_graph") + @patch( + "tasks.jobs.attack_paths.scan.graph_database.drop_subgraph", + side_effect=RuntimeError("drop failed"), + ) + @patch("tasks.jobs.attack_paths.scan.sync.create_sync_indexes") + @patch("tasks.jobs.attack_paths.scan.internet.analysis") + @patch("tasks.jobs.attack_paths.scan.findings.analysis") + @patch("tasks.jobs.attack_paths.scan.findings.create_findings_indexes") + @patch("tasks.jobs.attack_paths.scan.cartography_ontology.run") + @patch("tasks.jobs.attack_paths.scan.cartography_analysis.run") + @patch("tasks.jobs.attack_paths.scan.cartography_create_indexes.run") + @patch("tasks.jobs.attack_paths.scan.graph_database.clear_cache") + @patch("tasks.jobs.attack_paths.scan.graph_database.create_database") + @patch( + "tasks.jobs.attack_paths.scan.graph_database.get_uri", + return_value="bolt://neo4j", + ) + @patch( + "tasks.jobs.attack_paths.scan.initialize_prowler_provider", + return_value=MagicMock(_enabled_regions=["us-east-1"]), + ) + @patch( + "tasks.jobs.attack_paths.scan.rls_transaction", + new=lambda *args, **kwargs: nullcontext(), + ) + def test_recovery_failure_does_not_suppress_original_exception( + self, + mock_init_provider, + mock_get_uri, + mock_create_db, + mock_clear_cache, + mock_cartography_indexes, + mock_cartography_analysis, + mock_cartography_ontology, + mock_findings_indexes, + mock_findings_analysis, + mock_internet_analysis, + mock_sync_indexes, + mock_drop_subgraph, + mock_sync, + mock_starting, + mock_update_progress, + mock_set_provider_graph_data_ready, + mock_set_graph_data_ready, + mock_finish, + mock_drop_db, + mock_event_loop, + mock_stringify, + tenants_fixture, + providers_fixture, + scans_fixture, + ): + tenant = tenants_fixture[0] + provider = providers_fixture[0] + provider.provider = Provider.ProviderChoices.AWS + provider.save() + scan = scans_fixture[0] + scan.provider = provider + scan.save() + + attack_paths_scan = AttackPathsScan.objects.create( + tenant_id=tenant.id, + provider=provider, + scan=scan, + state=StateChoices.SCHEDULED, + ) + + # Recovery itself fails on the second call (True) + mock_set_provider_graph_data_ready.side_effect = [ + None, + RuntimeError("recovery boom"), + ] + + mock_session = MagicMock() + session_ctx = MagicMock() + session_ctx.__enter__.return_value = mock_session + session_ctx.__exit__.return_value = False + + with ( + patch( + "tasks.jobs.attack_paths.scan.graph_database.get_database_name", + side_effect=["db-scan-id", "tenant-db"], + ), + patch( + "tasks.jobs.attack_paths.scan.graph_database.get_session", + return_value=session_ctx, + ), + patch( + "tasks.jobs.attack_paths.scan.db_utils.retrieve_attack_paths_scan", + return_value=attack_paths_scan, + ), + patch( + "tasks.jobs.attack_paths.scan.get_cartography_ingestion_function", + return_value=MagicMock(return_value={}), + ), + ): + # Original exception propagates despite recovery failure + with pytest.raises(RuntimeError, match="drop failed"): + attack_paths_run(str(tenant.id), str(scan.id), "task-456") + def test_run_returns_early_for_unsupported_provider(self, tenants_fixture): tenant = tenants_fixture[0] provider = Provider.objects.create( @@ -419,9 +978,7 @@ class TestFailAttackPathsScan: def test_marks_executing_scan_as_failed( self, tenants_fixture, providers_fixture, scans_fixture ): - from tasks.jobs.attack_paths.db_utils import ( - fail_attack_paths_scan, - ) + from tasks.jobs.attack_paths.db_utils import fail_attack_paths_scan tenant = tenants_fixture[0] provider = providers_fixture[0] @@ -449,6 +1006,7 @@ class TestFailAttackPathsScan: patch( "tasks.jobs.attack_paths.db_utils.finish_attack_paths_scan" ) as mock_finish, + patch("tasks.jobs.attack_paths.db_utils.recover_graph_data_ready"), ): fail_attack_paths_scan(str(tenant.id), str(scan.id), "setup exploded") @@ -464,9 +1022,7 @@ class TestFailAttackPathsScan: def test_drops_temp_database_even_when_drop_fails( self, tenants_fixture, providers_fixture, scans_fixture ): - from tasks.jobs.attack_paths.db_utils import ( - fail_attack_paths_scan, - ) + from tasks.jobs.attack_paths.db_utils import fail_attack_paths_scan tenant = tenants_fixture[0] provider = providers_fixture[0] @@ -495,6 +1051,7 @@ class TestFailAttackPathsScan: patch( "tasks.jobs.attack_paths.db_utils.finish_attack_paths_scan" ) as mock_finish, + patch("tasks.jobs.attack_paths.db_utils.recover_graph_data_ready"), ): fail_attack_paths_scan(str(tenant.id), str(scan.id), "setup exploded") @@ -507,9 +1064,7 @@ class TestFailAttackPathsScan: def test_skips_already_failed_scan( self, tenants_fixture, providers_fixture, scans_fixture ): - from tasks.jobs.attack_paths.db_utils import ( - fail_attack_paths_scan, - ) + from tasks.jobs.attack_paths.db_utils import fail_attack_paths_scan tenant = tenants_fixture[0] provider = providers_fixture[0] @@ -544,9 +1099,7 @@ class TestFailAttackPathsScan: mock_finish.assert_not_called() def test_skips_when_no_scan_found(self, tenants_fixture): - from tasks.jobs.attack_paths.db_utils import ( - fail_attack_paths_scan, - ) + from tasks.jobs.attack_paths.db_utils import fail_attack_paths_scan tenant = tenants_fixture[0] @@ -563,6 +1116,111 @@ class TestFailAttackPathsScan: mock_finish.assert_not_called() + def test_fail_recovers_graph_data_ready_when_data_exists( + self, tenants_fixture, providers_fixture, scans_fixture + ): + from tasks.jobs.attack_paths.db_utils import fail_attack_paths_scan + + tenant = tenants_fixture[0] + provider = providers_fixture[0] + provider.provider = Provider.ProviderChoices.AWS + provider.save() + scan = scans_fixture[0] + scan.provider = provider + scan.save() + + attack_paths_scan = AttackPathsScan.objects.create( + tenant_id=tenant.id, + provider=provider, + scan=scan, + state=StateChoices.EXECUTING, + ) + + with ( + patch( + "tasks.jobs.attack_paths.db_utils.retrieve_attack_paths_scan", + return_value=attack_paths_scan, + ), + patch("tasks.jobs.attack_paths.db_utils.graph_database.drop_database"), + patch("tasks.jobs.attack_paths.db_utils.finish_attack_paths_scan"), + patch( + "tasks.jobs.attack_paths.db_utils.graph_database.has_provider_data", + return_value=True, + ), + patch( + "tasks.jobs.attack_paths.db_utils.set_provider_graph_data_ready" + ) as mock_set_ready, + ): + fail_attack_paths_scan(str(tenant.id), str(scan.id), "worker died") + + mock_set_ready.assert_called_once_with(attack_paths_scan, True) + + def test_fail_leaves_graph_data_ready_false_when_no_data( + self, tenants_fixture, providers_fixture, scans_fixture + ): + from tasks.jobs.attack_paths.db_utils import fail_attack_paths_scan + + tenant = tenants_fixture[0] + provider = providers_fixture[0] + provider.provider = Provider.ProviderChoices.AWS + provider.save() + scan = scans_fixture[0] + scan.provider = provider + scan.save() + + attack_paths_scan = AttackPathsScan.objects.create( + tenant_id=tenant.id, + provider=provider, + scan=scan, + state=StateChoices.EXECUTING, + ) + + with ( + patch( + "tasks.jobs.attack_paths.db_utils.retrieve_attack_paths_scan", + return_value=attack_paths_scan, + ), + patch("tasks.jobs.attack_paths.db_utils.graph_database.drop_database"), + patch("tasks.jobs.attack_paths.db_utils.finish_attack_paths_scan"), + patch( + "tasks.jobs.attack_paths.db_utils.graph_database.has_provider_data", + return_value=False, + ), + patch( + "tasks.jobs.attack_paths.db_utils.set_provider_graph_data_ready" + ) as mock_set_ready, + ): + fail_attack_paths_scan(str(tenant.id), str(scan.id), "worker died") + + mock_set_ready.assert_not_called() + + def test_recover_graph_data_ready_never_raises( + self, tenants_fixture, providers_fixture, scans_fixture + ): + from tasks.jobs.attack_paths.db_utils import recover_graph_data_ready + + tenant = tenants_fixture[0] + provider = providers_fixture[0] + provider.provider = Provider.ProviderChoices.AWS + provider.save() + scan = scans_fixture[0] + scan.provider = provider + scan.save() + + attack_paths_scan = AttackPathsScan.objects.create( + tenant_id=tenant.id, + provider=provider, + scan=scan, + state=StateChoices.EXECUTING, + ) + + with patch( + "tasks.jobs.attack_paths.db_utils.graph_database.has_provider_data", + side_effect=Exception("Neo4j unreachable"), + ): + # Should not raise + recover_graph_data_ready(attack_paths_scan) + class TestAttackPathsScanRLSTaskOnFailure: def test_on_failure_delegates_to_fail_attack_paths_scan(self):