fix(api): handle deleted scans during progress saves (#11697)

Co-authored-by: Josema Camacho <josema@prowler.com>
This commit is contained in:
Prowler Bot
2026-06-25 15:34:52 +02:00
committed by GitHub
parent 3d8cd467d6
commit c35ab7e91a
3 changed files with 147 additions and 33 deletions
+8
View File
@@ -2,6 +2,14 @@
All notable changes to the **Prowler API** are documented in this file.
## [1.32.2] (Prowler UNRELEASED)
### 🐞 Fixed
- `scan-perform` no longer reports an error when a provider is deleted during a running scan [(#11696)](https://github.com/prowler-cloud/prowler/pull/11696)
---
## [1.32.1] (Prowler v5.31.1)
### 🐞 Fixed
+55 -19
View File
@@ -14,7 +14,7 @@ from celery.utils.log import get_task_logger
from config.django.base import DJANGO_FINDINGS_BATCH_SIZE
from config.env import env
from config.settings.celery import CELERY_DEADLOCK_ATTEMPTS
from django.db import IntegrityError, OperationalError
from django.db import DatabaseError, IntegrityError, OperationalError, transaction
from django.db.models import (
Case,
Count,
@@ -43,7 +43,7 @@ from api.db_utils import (
psycopg_connection,
rls_transaction,
)
from api.exceptions import ProviderConnectionError
from api.exceptions import ProviderConnectionError, ProviderDeletedException
from api.models import (
AttackSurfaceOverview,
ComplianceOverviewSummary,
@@ -118,6 +118,20 @@ ATTACK_SURFACE_PROVIDER_COMPATIBILITY = {
_ATTACK_SURFACE_MAPPING_CACHE: dict[str, dict] = {}
def _save_scan_instance(
scan_instance: Scan, provider_id: str, update_fields: list[str]
) -> None:
try:
with transaction.atomic(): # Savepoint for not killing the `rls_transaction`
scan_instance.save(update_fields=update_fields)
except DatabaseError:
if Scan.objects.filter(pk=scan_instance.id).exists():
raise
raise ProviderDeletedException(
f"Provider '{provider_id}' for scan '{scan_instance.id}' was deleted during the scan"
) from None
def aggregate_category_counts(
categories: list[str],
severity: str,
@@ -1030,13 +1044,18 @@ def perform_prowler_scan(
group_resources_cache: dict[str, set] = {}
start_time = time.time()
exc = None
skip_final_scan_update = False
with rls_transaction(tenant_id):
provider_instance = Provider.objects.get(pk=provider_id)
scan_instance = Scan.objects.get(pk=scan_id)
scan_instance.state = StateChoices.EXECUTING
scan_instance.started_at = datetime.now(tz=timezone.utc)
scan_instance.save(update_fields=["state", "started_at", "updated_at"])
_save_scan_instance(
scan_instance,
provider_id,
["state", "started_at", "updated_at"],
)
# Find the mutelist processor if it exists
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
@@ -1104,7 +1123,7 @@ def perform_prowler_scan(
# Throttle scan_instance progress writes to avoid hammering the writer:
# only persist when progress moves by at least `PROGRESS_THROTTLE_DELTA`
# OR `PROGRESS_THROTTLE_SECONDS` have elapsed. The final progress (1.0)
# OR `PROGRESS_THROTTLE_SECONDS` have elapsed. The final progress (100)
# always persists in the `finally` block below.
last_persisted_progress = -1.0
last_persisted_progress_at = 0.0
@@ -1146,7 +1165,11 @@ def perform_prowler_scan(
):
with rls_transaction(tenant_id):
scan_instance.progress = progress
scan_instance.save(update_fields=["progress", "updated_at"])
_save_scan_instance(
scan_instance,
provider_id,
["progress", "updated_at"],
)
last_persisted_progress = progress
last_persisted_progress_at = now
@@ -1173,26 +1196,39 @@ def perform_prowler_scan(
batch_size=SCAN_DB_BATCH_SIZE,
)
except ProviderDeletedException as e:
logger.warning(str(e))
exception = e
skip_final_scan_update = True
except Exception as e:
logger.error(f"Error performing scan {scan_id}: {e}")
exception = e
scan_instance.state = StateChoices.FAILED
finally:
with rls_transaction(tenant_id):
scan_instance.duration = time.time() - start_time
scan_instance.completed_at = datetime.now(tz=timezone.utc)
scan_instance.unique_resource_count = len(unique_resources)
scan_instance.save(
update_fields=[
"state",
"duration",
"completed_at",
"unique_resource_count",
"progress",
"updated_at",
]
)
if not skip_final_scan_update:
try:
with rls_transaction(tenant_id):
scan_instance.duration = time.time() - start_time
scan_instance.completed_at = datetime.now(tz=timezone.utc)
scan_instance.unique_resource_count = len(unique_resources)
if exception is None:
scan_instance.progress = 100
_save_scan_instance(
scan_instance,
provider_id,
[
"state",
"duration",
"completed_at",
"unique_resource_count",
"progress",
"updated_at",
],
)
except ProviderDeletedException as e:
logger.warning(str(e))
exception = e
if exception is not None:
raise exception
+84 -14
View File
@@ -8,6 +8,19 @@ from io import StringIO
from unittest.mock import MagicMock, patch
import pytest
from api.db_router import MainRouter
from api.exceptions import ProviderConnectionError, ProviderDeletedException
from api.models import (
Finding,
MuteRule,
Provider,
Resource,
ResourceScanSummary,
Scan,
ScanSummary,
StateChoices,
StatusChoices,
)
from tasks.jobs.scan import (
_ATTACK_SURFACE_MAPPING_CACHE,
_aggregate_findings_by_region,
@@ -29,19 +42,6 @@ from tasks.jobs.scan import (
)
from tasks.utils import CustomEncoder
from api.db_router import MainRouter
from api.exceptions import ProviderConnectionError
from api.models import (
Finding,
MuteRule,
Provider,
Resource,
ResourceScanSummary,
Scan,
ScanSummary,
StateChoices,
StatusChoices,
)
from prowler.lib.check.models import Severity
from prowler.lib.outputs.finding import Status
@@ -263,6 +263,75 @@ class TestPerformScan:
assert provider.connected is False
assert isinstance(provider.connection_last_checked_at, datetime)
def test_perform_prowler_scan_provider_deleted_during_progress_update(
self,
tenants_fixture,
scans_fixture,
providers_fixture,
):
tenant = tenants_fixture[0]
scan = scans_fixture[0]
provider = providers_fixture[0]
tenant_id = str(tenant.id)
scan_id = str(scan.id)
provider_id = str(provider.id)
def scan_results():
Provider.objects.filter(pk=provider_id).delete()
yield 50, []
with (
patch(
"tasks.jobs.scan.initialize_prowler_provider",
return_value=MagicMock(),
),
patch("tasks.jobs.scan.ProwlerScan") as mock_prowler_scan_class,
patch("tasks.jobs.scan.logger.error") as mock_logger_error,
):
mock_prowler_scan_instance = MagicMock()
mock_prowler_scan_instance.scan.return_value = scan_results()
mock_prowler_scan_class.return_value = mock_prowler_scan_instance
with pytest.raises(ProviderDeletedException):
perform_prowler_scan(tenant_id, scan_id, provider_id, [])
mock_logger_error.assert_not_called()
assert not Scan.objects.filter(pk=scan_id).exists()
def test_perform_prowler_scan_sets_final_progress_when_progress_updates_are_throttled(
self,
tenants_fixture,
scans_fixture,
providers_fixture,
):
tenant = tenants_fixture[0]
scan = scans_fixture[0]
provider = providers_fixture[0]
tenant_id = str(tenant.id)
scan_id = str(scan.id)
provider_id = str(provider.id)
with (
patch(
"tasks.jobs.scan.initialize_prowler_provider",
return_value=MagicMock(),
),
patch("tasks.jobs.scan.ProwlerScan") as mock_prowler_scan_class,
patch("tasks.jobs.scan.PROGRESS_THROTTLE_DELTA", 200),
patch("tasks.jobs.scan.PROGRESS_THROTTLE_SECONDS", 3600),
):
mock_prowler_scan_instance = MagicMock()
mock_prowler_scan_instance.scan.return_value = [(99, []), (100, [])]
mock_prowler_scan_class.return_value = mock_prowler_scan_instance
perform_prowler_scan(tenant_id, scan_id, provider_id, [])
scan.refresh_from_db()
assert scan.state == StateChoices.COMPLETED
assert scan.progress == 100
@pytest.mark.parametrize(
"last_status, new_status, expected_delta",
[
@@ -4591,9 +4660,10 @@ class TestScanIsFullScope:
# If the SDK adds a new filter, this test still passes via the
# introspection-driven derivation; if it adds a non-filter kwarg
# (e.g. provider-like), keep the exclusion list in sync in models.py.
from prowler.lib.scan.scan import Scan as ProwlerScan
import inspect
from prowler.lib.scan.scan import Scan as ProwlerScan
expected = tuple(
name
for name in inspect.signature(ProwlerScan.__init__).parameters