def _save_scan_instance(
    scan_instance: Scan, provider_id: str, update_fields: list[str]
) -> None:
    """Persist selected fields of a scan, detecting a scan row that vanished.

    The write runs inside its own ``transaction.atomic()`` block so that a
    failed save is rolled back to a savepoint instead of killing the
    enclosing ``rls_transaction``.

    Args:
        scan_instance: Scan whose fields are written.
        provider_id: Provider identifier; only used to build the error message.
        update_fields: Field names forwarded to ``scan_instance.save``.

    Raises:
        ProviderDeletedException: The save raised ``DatabaseError`` and the
            scan row can no longer be found.
        DatabaseError: The save failed while the scan row still exists; the
            original error propagates unchanged.
    """
    try:
        with transaction.atomic():
            scan_instance.save(update_fields=update_fields)
    except DatabaseError:
        scan_row_gone = not Scan.objects.filter(pk=scan_instance.id).exists()
        if scan_row_gone:
            # NOTE(review): a missing scan row is reported as a provider
            # deletion; presumably the scan is removed together with its
            # provider. Confirm against the model definitions.
            raise ProviderDeletedException(
                f"Provider '{provider_id}' for scan '{scan_instance.id}' was deleted during the scan"
            ) from None
        # The scan is still there, so this is a genuine database failure.
        raise
list[str], severity: str, @@ -1029,13 +1043,18 @@ def perform_prowler_scan( group_resources_cache: dict[str, set] = {} start_time = time.time() exc = None + skip_final_scan_update = False with rls_transaction(tenant_id): provider_instance = Provider.objects.get(pk=provider_id) scan_instance = Scan.objects.get(pk=scan_id) scan_instance.state = StateChoices.EXECUTING scan_instance.started_at = datetime.now(tz=UTC) - scan_instance.save(update_fields=["state", "started_at", "updated_at"]) + _save_scan_instance( + scan_instance, + provider_id, + ["state", "started_at", "updated_at"], + ) # Find the mutelist processor if it exists with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS): @@ -1101,7 +1120,7 @@ def perform_prowler_scan( # Throttle scan_instance progress writes to avoid hammering the writer: # only persist when progress moves by at least `PROGRESS_THROTTLE_DELTA` - # OR `PROGRESS_THROTTLE_SECONDS` have elapsed. The final progress (1.0) + # OR `PROGRESS_THROTTLE_SECONDS` have elapsed. The final progress (100) # always persists in the `finally` block below. 
last_persisted_progress = -1.0 last_persisted_progress_at = 0.0 @@ -1143,7 +1162,11 @@ def perform_prowler_scan( ): with rls_transaction(tenant_id): scan_instance.progress = progress - scan_instance.save(update_fields=["progress", "updated_at"]) + _save_scan_instance( + scan_instance, + provider_id, + ["progress", "updated_at"], + ) last_persisted_progress = progress last_persisted_progress_at = now @@ -1170,26 +1193,39 @@ def perform_prowler_scan( batch_size=SCAN_DB_BATCH_SIZE, ) + except ProviderDeletedException as e: + logger.warning(str(e)) + exception = e + skip_final_scan_update = True except Exception as e: logger.error(f"Error performing scan {scan_id}: {e}") exception = e scan_instance.state = StateChoices.FAILED finally: - with rls_transaction(tenant_id): - scan_instance.duration = time.time() - start_time - scan_instance.completed_at = datetime.now(tz=UTC) - scan_instance.unique_resource_count = len(unique_resources) - scan_instance.save( - update_fields=[ - "state", - "duration", - "completed_at", - "unique_resource_count", - "progress", - "updated_at", - ] - ) + if not skip_final_scan_update: + try: + with rls_transaction(tenant_id): + scan_instance.duration = time.time() - start_time + scan_instance.completed_at = datetime.now(tz=UTC) + scan_instance.unique_resource_count = len(unique_resources) + if exception is None: + scan_instance.progress = 100 + _save_scan_instance( + scan_instance, + provider_id, + [ + "state", + "duration", + "completed_at", + "unique_resource_count", + "progress", + "updated_at", + ], + ) + except ProviderDeletedException as e: + logger.warning(str(e)) + exception = e if exception is not None: raise exception diff --git a/api/src/backend/tasks/tests/test_scan.py b/api/src/backend/tasks/tests/test_scan.py index 17b3b65fc4..2d251b247f 100644 --- a/api/src/backend/tasks/tests/test_scan.py +++ b/api/src/backend/tasks/tests/test_scan.py @@ -9,7 +9,7 @@ from unittest.mock import MagicMock, patch import pytest from api.db_router 
def test_perform_prowler_scan_provider_deleted_during_progress_update(
    self,
    tenants_fixture,
    scans_fixture,
    providers_fixture,
):
    """A provider deleted mid-scan surfaces as ``ProviderDeletedException``.

    The scan generator deletes the provider right before reporting its first
    progress value. The test expects the task to raise
    ``ProviderDeletedException``, to never call ``logger.error``, and to
    leave no scan row behind.
    """
    tenant_id = str(tenants_fixture[0].id)
    scan_id = str(scans_fixture[0].id)
    provider_id = str(providers_fixture[0].id)

    def delete_provider_then_report_progress():
        # Generator body runs lazily: the delete happens when the scan loop
        # pulls the first item, i.e. while the scan is already executing.
        Provider.objects.filter(pk=provider_id).delete()
        yield 50, []

    prowler_scan = MagicMock()
    prowler_scan.scan.return_value = delete_provider_then_report_progress()

    with (
        patch(
            "tasks.jobs.scan.initialize_prowler_provider",
            return_value=MagicMock(),
        ),
        patch("tasks.jobs.scan.ProwlerScan", return_value=prowler_scan),
        patch("tasks.jobs.scan.logger.error") as error_log,
        pytest.raises(ProviderDeletedException),
    ):
        perform_prowler_scan(tenant_id, scan_id, provider_id, [])

    error_log.assert_not_called()
    assert not Scan.objects.filter(pk=scan_id).exists()
mock_prowler_scan_instance = MagicMock() + mock_prowler_scan_instance.scan.return_value = [(99, []), (100, [])] + mock_prowler_scan_class.return_value = mock_prowler_scan_instance + + perform_prowler_scan(tenant_id, scan_id, provider_id, []) + + scan.refresh_from_db() + assert scan.state == StateChoices.COMPLETED + assert scan.progress == 100 + @pytest.mark.parametrize( "last_status, new_status, expected_delta", [