mirror of
https://github.com/prowler-cloud/prowler.git
synced 2026-07-04 19:21:51 +00:00
fix(api): handle deleted scans during progress saves (#11696)
This commit is contained in:
@@ -2,6 +2,14 @@
|
||||
|
||||
All notable changes to the **Prowler API** are documented in this file.
|
||||
|
||||
## [1.32.2] (Prowler UNRELEASED)
|
||||
|
||||
### 🐞 Fixed
|
||||
|
||||
- `scan-perform` no longer logs an error when a provider is deleted during a running scan; the deletion is now logged as a warning and surfaced as `ProviderDeletedException` [(#11696)](https://github.com/prowler-cloud/prowler/pull/11696)
|
||||
|
||||
---
|
||||
|
||||
## [1.32.1] (Prowler v5.31.1)
|
||||
|
||||
### 🐞 Fixed
|
||||
|
||||
@@ -19,7 +19,7 @@ from api.db_utils import (
|
||||
psycopg_connection,
|
||||
rls_transaction,
|
||||
)
|
||||
from api.exceptions import ProviderConnectionError
|
||||
from api.exceptions import ProviderConnectionError, ProviderDeletedException
|
||||
from api.models import (
|
||||
AttackSurfaceOverview,
|
||||
ComplianceOverviewSummary,
|
||||
@@ -48,7 +48,7 @@ from celery.utils.log import get_task_logger
|
||||
from config.django.base import DJANGO_FINDINGS_BATCH_SIZE
|
||||
from config.env import env
|
||||
from config.settings.celery import CELERY_DEADLOCK_ATTEMPTS
|
||||
from django.db import IntegrityError, OperationalError
|
||||
from django.db import DatabaseError, IntegrityError, OperationalError, transaction
|
||||
from django.db.models import (
|
||||
Case,
|
||||
Count,
|
||||
@@ -117,6 +117,20 @@ ATTACK_SURFACE_PROVIDER_COMPATIBILITY = {
|
||||
_ATTACK_SURFACE_MAPPING_CACHE: dict[str, dict] = {}
|
||||
|
||||
|
||||
def _save_scan_instance(
    scan_instance: Scan, provider_id: str, update_fields: list[str]
) -> None:
    """Persist selected fields of a scan, detecting a scan row that vanished.

    Args:
        scan_instance: Scan whose in-memory values are written.
        provider_id: Provider identifier; only used to build the error message.
        update_fields: Field names forwarded to ``save(update_fields=...)``.

    Raises:
        ProviderDeletedException: The save failed with a ``DatabaseError`` and
            the scan row can no longer be found.
        DatabaseError: The save failed while the scan row still exists; the
            original error is re-raised unchanged.
    """
    try:
        # Nested atomic block: a failed save is rolled back on its own
        # instead of breaking the caller's surrounding `rls_transaction`.
        with transaction.atomic():
            scan_instance.save(update_fields=update_fields)
    except DatabaseError:
        scan_is_gone = not Scan.objects.filter(pk=scan_instance.id).exists()
        if scan_is_gone:
            # Drop the low-level DB error from the traceback on purpose.
            raise ProviderDeletedException(
                f"Provider '{provider_id}' for scan '{scan_instance.id}' was deleted during the scan"
            ) from None
        # The row is still there, so this is a genuine database failure.
        raise
|
||||
|
||||
|
||||
def aggregate_category_counts(
|
||||
categories: list[str],
|
||||
severity: str,
|
||||
@@ -1029,13 +1043,18 @@ def perform_prowler_scan(
|
||||
group_resources_cache: dict[str, set] = {}
|
||||
start_time = time.time()
|
||||
exc = None
|
||||
skip_final_scan_update = False
|
||||
|
||||
with rls_transaction(tenant_id):
|
||||
provider_instance = Provider.objects.get(pk=provider_id)
|
||||
scan_instance = Scan.objects.get(pk=scan_id)
|
||||
scan_instance.state = StateChoices.EXECUTING
|
||||
scan_instance.started_at = datetime.now(tz=UTC)
|
||||
scan_instance.save(update_fields=["state", "started_at", "updated_at"])
|
||||
_save_scan_instance(
|
||||
scan_instance,
|
||||
provider_id,
|
||||
["state", "started_at", "updated_at"],
|
||||
)
|
||||
|
||||
# Find the mutelist processor if it exists
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
@@ -1101,7 +1120,7 @@ def perform_prowler_scan(
|
||||
|
||||
# Throttle scan_instance progress writes to avoid hammering the writer:
|
||||
# only persist when progress moves by at least `PROGRESS_THROTTLE_DELTA`
|
||||
# OR `PROGRESS_THROTTLE_SECONDS` have elapsed. The final progress (1.0)
|
||||
# OR `PROGRESS_THROTTLE_SECONDS` have elapsed. On success, the final progress (100)
|
||||
# always persists in the `finally` block below.
|
||||
last_persisted_progress = -1.0
|
||||
last_persisted_progress_at = 0.0
|
||||
@@ -1143,7 +1162,11 @@ def perform_prowler_scan(
|
||||
):
|
||||
with rls_transaction(tenant_id):
|
||||
scan_instance.progress = progress
|
||||
scan_instance.save(update_fields=["progress", "updated_at"])
|
||||
_save_scan_instance(
|
||||
scan_instance,
|
||||
provider_id,
|
||||
["progress", "updated_at"],
|
||||
)
|
||||
last_persisted_progress = progress
|
||||
last_persisted_progress_at = now
|
||||
|
||||
@@ -1170,26 +1193,39 @@ def perform_prowler_scan(
|
||||
batch_size=SCAN_DB_BATCH_SIZE,
|
||||
)
|
||||
|
||||
except ProviderDeletedException as e:
|
||||
logger.warning(str(e))
|
||||
exception = e
|
||||
skip_final_scan_update = True
|
||||
except Exception as e:
|
||||
logger.error(f"Error performing scan {scan_id}: {e}")
|
||||
exception = e
|
||||
scan_instance.state = StateChoices.FAILED
|
||||
|
||||
finally:
|
||||
with rls_transaction(tenant_id):
|
||||
scan_instance.duration = time.time() - start_time
|
||||
scan_instance.completed_at = datetime.now(tz=UTC)
|
||||
scan_instance.unique_resource_count = len(unique_resources)
|
||||
scan_instance.save(
|
||||
update_fields=[
|
||||
"state",
|
||||
"duration",
|
||||
"completed_at",
|
||||
"unique_resource_count",
|
||||
"progress",
|
||||
"updated_at",
|
||||
]
|
||||
)
|
||||
if not skip_final_scan_update:
|
||||
try:
|
||||
with rls_transaction(tenant_id):
|
||||
scan_instance.duration = time.time() - start_time
|
||||
scan_instance.completed_at = datetime.now(tz=UTC)
|
||||
scan_instance.unique_resource_count = len(unique_resources)
|
||||
if exception is None:
|
||||
scan_instance.progress = 100
|
||||
_save_scan_instance(
|
||||
scan_instance,
|
||||
provider_id,
|
||||
[
|
||||
"state",
|
||||
"duration",
|
||||
"completed_at",
|
||||
"unique_resource_count",
|
||||
"progress",
|
||||
"updated_at",
|
||||
],
|
||||
)
|
||||
except ProviderDeletedException as e:
|
||||
logger.warning(str(e))
|
||||
exception = e
|
||||
|
||||
if exception is not None:
|
||||
raise exception
|
||||
|
||||
@@ -9,7 +9,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from api.db_router import MainRouter
|
||||
from api.exceptions import ProviderConnectionError
|
||||
from api.exceptions import ProviderConnectionError, ProviderDeletedException
|
||||
from api.models import (
|
||||
Finding,
|
||||
MuteRule,
|
||||
@@ -262,6 +262,75 @@ class TestPerformScan:
|
||||
assert provider.connected is False
|
||||
assert isinstance(provider.connection_last_checked_at, datetime)
|
||||
|
||||
def test_perform_prowler_scan_provider_deleted_during_progress_update(
    self,
    tenants_fixture,
    scans_fixture,
    providers_fixture,
):
    """A provider deleted mid-scan raises ProviderDeletedException quietly.

    The mocked scan deletes the provider right before yielding its first
    progress value. The run must raise ``ProviderDeletedException`` without
    calling ``logger.error``, and the scan row must no longer exist.
    """
    tenant_id = str(tenants_fixture[0].id)
    scan_id = str(scans_fixture[0].id)
    provider_id = str(providers_fixture[0].id)

    def delete_provider_then_report_progress():
        # Generator body runs lazily: the delete only happens once the code
        # under test pulls the first item from the scan.
        Provider.objects.filter(pk=provider_id).delete()
        yield 50, []

    with (
        patch(
            "tasks.jobs.scan.initialize_prowler_provider",
            return_value=MagicMock(),
        ),
        patch("tasks.jobs.scan.ProwlerScan") as prowler_scan_cls,
        patch("tasks.jobs.scan.logger.error") as error_log,
    ):
        # ProwlerScan(...).scan(...) hands back the deleting generator.
        prowler_scan_cls.return_value.scan.return_value = (
            delete_provider_then_report_progress()
        )

        with pytest.raises(ProviderDeletedException):
            perform_prowler_scan(tenant_id, scan_id, provider_id, [])

        error_log.assert_not_called()
        assert not Scan.objects.filter(pk=scan_id).exists()
|
||||
|
||||
def test_perform_prowler_scan_sets_final_progress_when_progress_updates_are_throttled(
    self,
    tenants_fixture,
    scans_fixture,
    providers_fixture,
):
    """A successful scan ends COMPLETED at progress 100 despite throttling.

    The throttle constants are patched to very large values and the mocked
    scan yields progress 99 then 100; the stored scan must end up with
    ``progress == 100``.
    """
    scan = scans_fixture[0]
    tenant_id = str(tenants_fixture[0].id)
    scan_id = str(scan.id)
    provider_id = str(providers_fixture[0].id)

    # NOTE(review): presumably these thresholds keep the in-loop progress
    # writes from firing, so only the final update can store 100 - confirm
    # against the throttle condition in `perform_prowler_scan`.
    with (
        patch(
            "tasks.jobs.scan.initialize_prowler_provider",
            return_value=MagicMock(),
        ),
        patch("tasks.jobs.scan.ProwlerScan") as prowler_scan_cls,
        patch("tasks.jobs.scan.PROGRESS_THROTTLE_DELTA", 200),
        patch("tasks.jobs.scan.PROGRESS_THROTTLE_SECONDS", 3600),
    ):
        progress_steps = [(99, []), (100, [])]
        prowler_scan_cls.return_value.scan.return_value = progress_steps

        perform_prowler_scan(tenant_id, scan_id, provider_id, [])

    # Reload from the database to check what was actually persisted.
    scan.refresh_from_db()
    assert scan.state == StateChoices.COMPLETED
    assert scan.progress == 100
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"last_status, new_status, expected_delta",
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user