feat(exception): Add decorator for deleted providers during scans (#9414)

This commit is contained in:
Víctor Fernández Poyatos
2025-12-03 09:46:59 +01:00
committed by GitHub
parent f5c2146d19
commit 29a1034658
6 changed files with 214 additions and 4 deletions

View File

@@ -6,6 +6,7 @@ All notable changes to the **Prowler API** are documented in this file.
### Added
- New endpoint to retrieve an overview of the attack surfaces [(#9309)](https://github.com/prowler-cloud/prowler/pull/9309)
- Exception handler for provider deletions during scans [(#9414)](https://github.com/prowler-cloud/prowler/pull/9414)
### Changed
- Restore the compliance overview endpoint's mandatory filters [(#9330)](https://github.com/prowler-cloud/prowler/pull/9330)

View File

@@ -1,10 +1,14 @@
import uuid
from functools import wraps
from django.db import connection, transaction
from django.core.exceptions import ObjectDoesNotExist
from django.db import IntegrityError, connection, transaction
from rest_framework_json_api.serializers import ValidationError
from api.db_utils import POSTGRES_TENANT_VAR, SET_CONFIG_QUERY
from api.db_router import READ_REPLICA_ALIAS
from api.db_utils import POSTGRES_TENANT_VAR, SET_CONFIG_QUERY, rls_transaction
from api.exceptions import ProviderDeletedException
from api.models import Provider, Scan
def set_tenant(func=None, *, keep_tenant=False):
@@ -66,3 +70,49 @@ def set_tenant(func=None, *, keep_tenant=False):
return decorator
else:
return decorator(func)
def handle_provider_deletion(func):
"""
Decorator that raises ProviderDeletedException if provider was deleted during execution.
Catches ObjectDoesNotExist and IntegrityError, checks if provider still exists,
and raises ProviderDeletedException if not. Otherwise, re-raises original exception.
Requires tenant_id and provider_id in kwargs.
Example:
@shared_task
@handle_provider_deletion
def scan_task(scan_id, tenant_id, provider_id):
...
"""
@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except (ObjectDoesNotExist, IntegrityError):
tenant_id = kwargs.get("tenant_id")
provider_id = kwargs.get("provider_id")
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
if provider_id is None:
scan_id = kwargs.get("scan_id")
if scan_id is None:
raise AssertionError(
"This task does not have provider or scan in the kwargs"
)
scan = Scan.objects.filter(pk=scan_id).first()
if scan is None:
raise ProviderDeletedException(
f"Provider for scan '{scan_id}' was deleted during the scan"
) from None
provider_id = str(scan.provider_id)
if not Provider.objects.filter(pk=provider_id).exists():
raise ProviderDeletedException(
f"Provider '{provider_id}' was deleted during the scan"
) from None
raise
return wrapper

View File

@@ -66,6 +66,10 @@ class ProviderConnectionError(Exception):
"""Base exception for provider connection errors."""
class ProviderDeletedException(Exception):
"""Raised when a provider has been deleted during scan/task execution."""
def custom_exception_handler(exc, context):
if isinstance(exc, django_validation_error):
if hasattr(exc, "error_dict"):

View File

@@ -2,9 +2,12 @@ import uuid
from unittest.mock import call, patch
import pytest
from django.core.exceptions import ObjectDoesNotExist
from django.db import IntegrityError
from api.db_utils import POSTGRES_TENANT_VAR, SET_CONFIG_QUERY
from api.decorators import set_tenant
from api.decorators import handle_provider_deletion, set_tenant
from api.exceptions import ProviderDeletedException
@pytest.mark.django_db
@@ -34,3 +37,142 @@ class TestSetTenantDecorator:
with pytest.raises(KeyError):
random_func("test_arg")
@pytest.mark.django_db
class TestHandleProviderDeletionDecorator:
def test_success_no_exception(self, tenants_fixture, providers_fixture):
"""Decorated function runs normally when no exception is raised."""
tenant = tenants_fixture[0]
provider = providers_fixture[0]
@handle_provider_deletion
def task_func(**kwargs):
return "success"
result = task_func(
tenant_id=str(tenant.id),
provider_id=str(provider.id),
)
assert result == "success"
@patch("api.decorators.rls_transaction")
@patch("api.decorators.Provider.objects.filter")
def test_provider_deleted_with_provider_id(
self, mock_filter, mock_rls, tenants_fixture
):
"""Raises ProviderDeletedException when provider_id provided and provider deleted."""
tenant = tenants_fixture[0]
deleted_provider_id = str(uuid.uuid4())
mock_rls.return_value.__enter__ = lambda s: None
mock_rls.return_value.__exit__ = lambda s, *args: None
mock_filter.return_value.exists.return_value = False
@handle_provider_deletion
def task_func(**kwargs):
raise ObjectDoesNotExist("Some object not found")
with pytest.raises(ProviderDeletedException) as exc_info:
task_func(tenant_id=str(tenant.id), provider_id=deleted_provider_id)
assert deleted_provider_id in str(exc_info.value)
@patch("api.decorators.rls_transaction")
@patch("api.decorators.Provider.objects.filter")
@patch("api.decorators.Scan.objects.filter")
def test_provider_deleted_with_scan_id(
self, mock_scan_filter, mock_provider_filter, mock_rls, tenants_fixture
):
"""Raises ProviderDeletedException when scan exists but provider deleted."""
tenant = tenants_fixture[0]
scan_id = str(uuid.uuid4())
provider_id = str(uuid.uuid4())
mock_rls.return_value.__enter__ = lambda s: None
mock_rls.return_value.__exit__ = lambda s, *args: None
mock_scan = type("MockScan", (), {"provider_id": provider_id})()
mock_scan_filter.return_value.first.return_value = mock_scan
mock_provider_filter.return_value.exists.return_value = False
@handle_provider_deletion
def task_func(**kwargs):
raise ObjectDoesNotExist("Some object not found")
with pytest.raises(ProviderDeletedException) as exc_info:
task_func(tenant_id=str(tenant.id), scan_id=scan_id)
assert provider_id in str(exc_info.value)
@patch("api.decorators.rls_transaction")
@patch("api.decorators.Scan.objects.filter")
def test_scan_deleted_cascade(self, mock_scan_filter, mock_rls, tenants_fixture):
"""Raises ProviderDeletedException when scan was deleted (CASCADE from provider)."""
tenant = tenants_fixture[0]
scan_id = str(uuid.uuid4())
mock_rls.return_value.__enter__ = lambda s: None
mock_rls.return_value.__exit__ = lambda s, *args: None
mock_scan_filter.return_value.first.return_value = None
@handle_provider_deletion
def task_func(**kwargs):
raise ObjectDoesNotExist("Some object not found")
with pytest.raises(ProviderDeletedException) as exc_info:
task_func(tenant_id=str(tenant.id), scan_id=scan_id)
assert scan_id in str(exc_info.value)
@patch("api.decorators.rls_transaction")
@patch("api.decorators.Provider.objects.filter")
def test_provider_exists_reraises_original(
self, mock_filter, mock_rls, tenants_fixture, providers_fixture
):
"""Re-raises original exception when provider still exists."""
tenant = tenants_fixture[0]
provider = providers_fixture[0]
mock_rls.return_value.__enter__ = lambda s: None
mock_rls.return_value.__exit__ = lambda s, *args: None
mock_filter.return_value.exists.return_value = True
@handle_provider_deletion
def task_func(**kwargs):
raise ObjectDoesNotExist("Actual object missing")
with pytest.raises(ObjectDoesNotExist):
task_func(tenant_id=str(tenant.id), provider_id=str(provider.id))
@patch("api.decorators.rls_transaction")
@patch("api.decorators.Provider.objects.filter")
def test_integrity_error_provider_deleted(
self, mock_filter, mock_rls, tenants_fixture
):
"""Raises ProviderDeletedException on IntegrityError when provider deleted."""
tenant = tenants_fixture[0]
deleted_provider_id = str(uuid.uuid4())
mock_rls.return_value.__enter__ = lambda s: None
mock_rls.return_value.__exit__ = lambda s, *args: None
mock_filter.return_value.exists.return_value = False
@handle_provider_deletion
def task_func(**kwargs):
raise IntegrityError("FK constraint violation")
with pytest.raises(ProviderDeletedException):
task_func(tenant_id=str(tenant.id), provider_id=deleted_provider_id)
def test_missing_provider_and_scan_raises_assertion(self, tenants_fixture):
"""Raises AssertionError when neither provider_id nor scan_id in kwargs."""
@handle_provider_deletion
def task_func(**kwargs):
raise ObjectDoesNotExist("Some object not found")
with pytest.raises(AssertionError) as exc_info:
task_func(tenant_id=str(tenants_fixture[0].id))
assert "provider or scan" in str(exc_info.value)

View File

@@ -5,6 +5,8 @@ IGNORED_EXCEPTIONS = [
# Provider is not connected due to credentials errors
"is not connected",
"ProviderConnectionError",
# Provider was deleted during a scan
"ProviderDeletedException",
# Authentication Errors from AWS
"InvalidToken",
"AccessDeniedException",

View File

@@ -47,7 +47,7 @@ from tasks.utils import batched, get_next_execution_datetime
from api.compliance import get_compliance_frameworks
from api.db_router import READ_REPLICA_ALIAS
from api.db_utils import rls_transaction
from api.decorators import set_tenant
from api.decorators import handle_provider_deletion, set_tenant
from api.models import Finding, Integration, Provider, Scan, ScanSummary, StateChoices
from api.utils import initialize_prowler_provider
from api.v1.serializers import ScanTaskSerializer
@@ -144,6 +144,7 @@ def delete_provider_task(provider_id: str, tenant_id: str):
@shared_task(base=RLSTask, name="scan-perform", queue="scans")
@handle_provider_deletion
def perform_scan_task(
tenant_id: str, scan_id: str, provider_id: str, checks_to_execute: list[str] = None
):
@@ -176,6 +177,7 @@ def perform_scan_task(
@shared_task(base=RLSTask, bind=True, name="scan-perform-scheduled", queue="scans")
@handle_provider_deletion
def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
"""
Task to perform a scheduled Prowler scan on a given provider.
@@ -281,6 +283,7 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
@shared_task(name="scan-summary", queue="overview")
@handle_provider_deletion
def perform_scan_summary_task(tenant_id: str, scan_id: str):
return aggregate_findings(tenant_id=tenant_id, scan_id=scan_id)
@@ -296,6 +299,7 @@ def delete_tenant_task(tenant_id: str):
queue="scan-reports",
)
@set_tenant(keep_tenant=True)
@handle_provider_deletion
def generate_outputs_task(scan_id: str, provider_id: str, tenant_id: str):
"""
Process findings in batches and generate output files in multiple formats.
@@ -491,6 +495,7 @@ def generate_outputs_task(scan_id: str, provider_id: str, tenant_id: str):
@shared_task(name="backfill-scan-resource-summaries", queue="backfill")
@handle_provider_deletion
def backfill_scan_resource_summaries_task(tenant_id: str, scan_id: str):
"""
Tries to backfill the resource scan summaries table for a given scan.
@@ -503,6 +508,7 @@ def backfill_scan_resource_summaries_task(tenant_id: str, scan_id: str):
@shared_task(name="backfill-compliance-summaries", queue="backfill")
@handle_provider_deletion
def backfill_compliance_summaries_task(tenant_id: str, scan_id: str):
"""
Tries to backfill compliance overview summaries for a completed scan.
@@ -518,6 +524,7 @@ def backfill_compliance_summaries_task(tenant_id: str, scan_id: str):
@shared_task(base=RLSTask, name="scan-compliance-overviews", queue="compliance")
@handle_provider_deletion
def create_compliance_requirements_task(tenant_id: str, scan_id: str):
"""
Creates detailed compliance requirement records for a scan.
@@ -534,6 +541,7 @@ def create_compliance_requirements_task(tenant_id: str, scan_id: str):
@shared_task(name="scan-attack-surface-overviews", queue="overview")
@handle_provider_deletion
def aggregate_attack_surface_task(tenant_id: str, scan_id: str):
"""
Creates attack surface overview records for a scan.
@@ -586,6 +594,7 @@ def refresh_lighthouse_provider_models_task(
@shared_task(name="integration-check")
@handle_provider_deletion
def check_integrations_task(tenant_id: str, provider_id: str, scan_id: str = None):
"""
Check and execute all configured integrations for a provider.
@@ -650,6 +659,7 @@ def check_integrations_task(tenant_id: str, provider_id: str, scan_id: str = Non
name="integration-s3",
queue="integrations",
)
@handle_provider_deletion
def s3_integration_task(
tenant_id: str,
provider_id: str,
@@ -709,6 +719,7 @@ def jira_integration_task(
name="scan-compliance-reports",
queue="scan-reports",
)
@handle_provider_deletion
def generate_compliance_reports_task(tenant_id: str, scan_id: str, provider_id: str):
"""
Optimized task to generate ThreatScore, ENS, and NIS2 reports with shared queries.