From c5094829543e1311b3a24d6ed9376bb6411e325d Mon Sep 17 00:00:00 2001 From: Prowler Bot Date: Mon, 26 Jan 2026 13:36:49 +0100 Subject: [PATCH] fix(scans): scheduled scans duplicates (#9883) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Víctor Fernández Poyatos --- api/CHANGELOG.md | 1 + api/src/backend/tasks/tasks.py | 113 +++++------ api/src/backend/tasks/tests/test_tasks.py | 234 +++++++++++++++++++++- api/src/backend/tasks/utils.py | 59 ++++++ 4 files changed, 336 insertions(+), 71 deletions(-) diff --git a/api/CHANGELOG.md b/api/CHANGELOG.md index 19d323bc0b..ce3228f60d 100644 --- a/api/CHANGELOG.md +++ b/api/CHANGELOG.md @@ -15,6 +15,7 @@ All notable changes to the **Prowler API** are documented in this file. - Lazy load Neo4j driver for workers only [(#9872)](https://github.com/prowler-cloud/prowler/pull/9872) - Improve Cypher query for inserting Findings into Attack Paths scan graphs [(#9874)](https://github.com/prowler-cloud/prowler/pull/9874) - Clear Neo4j database cache after Attack Paths scan and each API query [(#9877)](https://github.com/prowler-cloud/prowler/pull/9877) +- Deduplicated scheduled scans for long-running providers [(#9829)](https://github.com/prowler-cloud/prowler/pull/9829) ## [1.18.0] (Prowler v5.17.0) diff --git a/api/src/backend/tasks/tasks.py b/api/src/backend/tasks/tasks.py index 691b505dc6..0dd3b0905b 100644 --- a/api/src/backend/tasks/tasks.py +++ b/api/src/backend/tasks/tasks.py @@ -1,25 +1,13 @@ import os - from datetime import datetime, timedelta, timezone from pathlib import Path from shutil import rmtree from celery import chain, group, shared_task from celery.utils.log import get_task_logger -from django_celery_beat.models import PeriodicTask - -from api.compliance import get_compliance_frameworks -from api.db_router import READ_REPLICA_ALIAS -from api.db_utils import rls_transaction -from api.decorators import handle_provider_deletion, set_tenant -from api.models import Finding, Integration, Provider, Scan, ScanSummary, StateChoices -from api.utils import initialize_prowler_provider -from api.v1.serializers import ScanTaskSerializer from config.celery import RLSTask from config.django.base import DJANGO_FINDINGS_BATCH_SIZE, DJANGO_TMP_OUTPUT_DIRECTORY -from prowler.lib.check.compliance_models import Compliance -from prowler.lib.outputs.compliance.generic.generic import GenericCompliance -from prowler.lib.outputs.finding import Finding as FindingOutput +from django_celery_beat.models import PeriodicTask from tasks.jobs.attack_paths import ( attack_paths_scan, can_provider_run_attack_paths_scan, @@ -64,7 +52,22 @@ from tasks.jobs.scan import ( perform_prowler_scan, update_provider_compliance_scores, ) -from tasks.utils import batched, get_next_execution_datetime +from tasks.utils import ( + _get_or_create_scheduled_scan, + batched, + get_next_execution_datetime, +) + +from api.compliance import get_compliance_frameworks +from api.db_router import READ_REPLICA_ALIAS +from api.db_utils import rls_transaction +from api.decorators import handle_provider_deletion, set_tenant +from api.models import Finding, Integration, Provider, Scan, ScanSummary, StateChoices +from api.utils import initialize_prowler_provider +from api.v1.serializers import ScanTaskSerializer +from prowler.lib.check.compliance_models import Compliance +from prowler.lib.outputs.compliance.generic.generic import GenericCompliance +from prowler.lib.outputs.finding import Finding as FindingOutput logger = get_task_logger(__name__) @@ -275,44 +278,38 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str): periodic_task_instance = PeriodicTask.objects.get( name=f"scan-perform-scheduled-{provider_id}" ) - - executed_scan = Scan.objects.filter( - tenant_id=tenant_id, - provider_id=provider_id, - task__task_runner_task__task_id=task_id, - ).order_by("completed_at") - - if ( + executing_scan = ( Scan.objects.filter( tenant_id=tenant_id, provider_id=provider_id, trigger=Scan.TriggerChoices.SCHEDULED, state=StateChoices.EXECUTING, - scheduler_task_id=periodic_task_instance.id, - scheduled_at__date=datetime.now(timezone.utc).date(), - ).exists() - or executed_scan.exists() - ): - # Duplicated task execution due to visibility timeout or scan is already running - logger.warning(f"Duplicated scheduled scan for provider {provider_id}.") - try: - affected_scan = executed_scan.first() - if not affected_scan: - raise ValueError( - "Error retrieving affected scan details after detecting duplicated scheduled " - "scan." - ) - # Return the affected scan details to avoid losing data - serializer = ScanTaskSerializer(instance=affected_scan) - except Exception as duplicated_scan_exception: - logger.error( - f"Duplicated scheduled scan for provider {provider_id}. Error retrieving affected scan details: " - f"{str(duplicated_scan_exception)}" - ) - raise duplicated_scan_exception - return serializer.data + ) + .order_by("-started_at") + .first() + ) + if executing_scan: + logger.warning( + f"Scheduled scan already executing for provider {provider_id}. Skipping." + ) + return ScanTaskSerializer(instance=executing_scan).data + executed_scan = Scan.objects.filter( + tenant_id=tenant_id, + provider_id=provider_id, + task__task_runner_task__task_id=task_id, + ).first() + + if executed_scan: + # Duplicated task execution due to visibility timeout + logger.warning(f"Duplicated scheduled scan for provider {provider_id}.") + return ScanTaskSerializer(instance=executed_scan).data + + interval = periodic_task_instance.interval next_scan_datetime = get_next_execution_datetime(task_id, provider_id) + current_scan_datetime = next_scan_datetime - timedelta( + **{interval.period: interval.every} + ) # TEMPORARY WORKAROUND: Clean up orphan scans from transaction isolation issue _cleanup_orphan_scheduled_scans( @@ -321,19 +318,12 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str): scheduler_task_id=periodic_task_instance.id, ) - scan_instance, _ = Scan.objects.get_or_create( + scan_instance = _get_or_create_scheduled_scan( tenant_id=tenant_id, provider_id=provider_id, - trigger=Scan.TriggerChoices.SCHEDULED, - state__in=(StateChoices.SCHEDULED, StateChoices.AVAILABLE), scheduler_task_id=periodic_task_instance.id, - defaults={ - "state": StateChoices.SCHEDULED, - "name": "Daily scheduled scan", - "scheduled_at": next_scan_datetime - timedelta(days=1), - }, + scheduled_at=current_scan_datetime, ) - scan_instance.task_id = task_id scan_instance.save() @@ -343,18 +333,19 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str): scan_id=str(scan_instance.id), provider_id=provider_id, ) - except Exception as e: - raise e finally: with rls_transaction(tenant_id): - Scan.objects.get_or_create( + now = datetime.now(timezone.utc) + if next_scan_datetime <= now: + interval_delta = timedelta(**{interval.period: interval.every}) + while next_scan_datetime <= now: + next_scan_datetime += interval_delta + _get_or_create_scheduled_scan( tenant_id=tenant_id, - name="Daily scheduled scan", provider_id=provider_id, - trigger=Scan.TriggerChoices.SCHEDULED, - state=StateChoices.SCHEDULED, - scheduled_at=next_scan_datetime, scheduler_task_id=periodic_task_instance.id, + scheduled_at=next_scan_datetime, + update_state=True, ) _perform_scan_complete_tasks(tenant_id, str(scan_instance.id), provider_id) diff --git a/api/src/backend/tasks/tests/test_tasks.py b/api/src/backend/tasks/tests/test_tasks.py index 3d44cd95bb..3a58118c62 100644 --- a/api/src/backend/tasks/tests/test_tasks.py +++ b/api/src/backend/tasks/tests/test_tasks.py @@ -1,21 +1,13 @@ import uuid - from contextlib import contextmanager +from datetime import datetime, timezone from unittest.mock import MagicMock, patch import openai import pytest - from botocore.exceptions import ClientError from django_celery_beat.models import IntervalSchedule, PeriodicTask - -from api.models import ( - Integration, - LighthouseProviderConfiguration, - LighthouseProviderModels, - Scan, - StateChoices, -) +from django_celery_results.models import TaskResult from tasks.jobs.lighthouse_providers import ( _create_bedrock_client, _extract_bedrock_credentials, @@ -27,11 +19,21 @@ from tasks.tasks import ( check_lighthouse_provider_connection_task, generate_outputs_task, perform_attack_paths_scan_task, + perform_scheduled_scan_task, refresh_lighthouse_provider_models_task, s3_integration_task, security_hub_integration_task, ) +from api.models import ( + Integration, + LighthouseProviderConfiguration, + LighthouseProviderModels, + Scan, + StateChoices, + Task, +) + @pytest.mark.django_db class TestExtractBedrockCredentials: @@ -2137,3 +2139,215 @@ class TestCleanupOrphanScheduledScans: assert not Scan.objects.filter(id=orphan_scan.id).exists() assert Scan.objects.filter(id=scheduled_scan.id).exists() assert Scan.objects.filter(id=available_scan_other_task.id).exists() + + +@pytest.mark.django_db +class TestPerformScheduledScanTask: + """Unit tests for perform_scheduled_scan_task.""" + + @staticmethod + @contextmanager + def _override_task_request(task, **attrs): + request = task.request + sentinel = object() + previous = {key: getattr(request, key, sentinel) for key in attrs} + for key, value in attrs.items(): + setattr(request, key, value) + + try: + yield + finally: + for key, prev in previous.items(): + if prev is sentinel: + if hasattr(request, key): + delattr(request, key) + else: + setattr(request, key, prev) + + def _create_periodic_task(self, provider_id, tenant_id, interval_hours=24): + interval, _ = IntervalSchedule.objects.get_or_create( + every=interval_hours, period="hours" + ) + return PeriodicTask.objects.create( + name=f"scan-perform-scheduled-{provider_id}", + task="scan-perform-scheduled", + interval=interval, + kwargs=f'{{"tenant_id": "{tenant_id}", "provider_id": "{provider_id}"}}', + enabled=True, + ) + + def _create_task_result(self, tenant_id, task_id): + task_result = TaskResult.objects.create( + task_id=task_id, + task_name="scan-perform-scheduled", + status="STARTED", + date_created=datetime.now(timezone.utc), + ) + Task.objects.create( + id=task_id, task_runner_task=task_result, tenant_id=tenant_id + ) + return task_result + + def test_skip_when_scheduled_scan_executing( + self, tenants_fixture, providers_fixture + ): + """Skip a scheduled run when another scheduled scan is already executing.""" + tenant = tenants_fixture[0] + provider = providers_fixture[0] + periodic_task = self._create_periodic_task(provider.id, tenant.id) + task_id = str(uuid.uuid4()) + self._create_task_result(tenant.id, task_id) + + executing_scan = Scan.objects.create( + tenant_id=tenant.id, + provider=provider, + name="Daily scheduled scan", + trigger=Scan.TriggerChoices.SCHEDULED, + state=StateChoices.EXECUTING, + scheduler_task_id=periodic_task.id, + ) + + with ( + patch("tasks.tasks.perform_prowler_scan") as mock_scan, + patch("tasks.tasks._perform_scan_complete_tasks") as mock_complete_tasks, + self._override_task_request(perform_scheduled_scan_task, id=task_id), + ): + result = perform_scheduled_scan_task.run( + tenant_id=str(tenant.id), provider_id=str(provider.id) + ) + + mock_scan.assert_not_called() + mock_complete_tasks.assert_not_called() + assert result["id"] == str(executing_scan.id) + assert result["state"] == StateChoices.EXECUTING + assert ( + Scan.objects.filter( + tenant_id=tenant.id, + provider=provider, + trigger=Scan.TriggerChoices.SCHEDULED, + state=StateChoices.SCHEDULED, + ).count() + == 0 + ) + + def test_creates_next_scheduled_scan_after_completion( + self, tenants_fixture, providers_fixture + ): + """Create a next scheduled scan after a successful run completes.""" + tenant = tenants_fixture[0] + provider = providers_fixture[0] + self._create_periodic_task(provider.id, tenant.id) + task_id = str(uuid.uuid4()) + self._create_task_result(tenant.id, task_id) + + def _complete_scan(tenant_id, scan_id, provider_id): + other_scheduled = Scan.objects.filter( + tenant_id=tenant_id, + provider_id=provider_id, + trigger=Scan.TriggerChoices.SCHEDULED, + state=StateChoices.SCHEDULED, + ).exclude(id=scan_id) + assert not other_scheduled.exists() + scan_instance = Scan.objects.get(id=scan_id) + scan_instance.state = StateChoices.COMPLETED + scan_instance.save() + return {"status": "ok"} + + with ( + patch("tasks.tasks.perform_prowler_scan", side_effect=_complete_scan), + patch("tasks.tasks._perform_scan_complete_tasks"), + self._override_task_request(perform_scheduled_scan_task, id=task_id), + ): + perform_scheduled_scan_task.run( + tenant_id=str(tenant.id), provider_id=str(provider.id) + ) + + scheduled_scans = Scan.objects.filter( + tenant_id=tenant.id, + provider=provider, + trigger=Scan.TriggerChoices.SCHEDULED, + state=StateChoices.SCHEDULED, + ) + assert scheduled_scans.count() == 1 + assert scheduled_scans.first().scheduled_at > datetime.now(timezone.utc) + assert ( + Scan.objects.filter( + tenant_id=tenant.id, + provider=provider, + trigger=Scan.TriggerChoices.SCHEDULED, + state__in=(StateChoices.SCHEDULED, StateChoices.AVAILABLE), + ).count() + == 1 + ) + assert ( + Scan.objects.filter( + tenant_id=tenant.id, + provider=provider, + trigger=Scan.TriggerChoices.SCHEDULED, + state=StateChoices.COMPLETED, + ).count() + == 1 + ) + + def test_dedupes_multiple_scheduled_scans_before_run( + self, tenants_fixture, providers_fixture + ): + """Ensure duplicated scheduled scans are removed before executing.""" + tenant = tenants_fixture[0] + provider = providers_fixture[0] + periodic_task = self._create_periodic_task(provider.id, tenant.id) + task_id = str(uuid.uuid4()) + self._create_task_result(tenant.id, task_id) + + scheduled_scan = Scan.objects.create( + tenant_id=tenant.id, + provider=provider, + name="Daily scheduled scan", + trigger=Scan.TriggerChoices.SCHEDULED, + state=StateChoices.SCHEDULED, + scheduled_at=datetime.now(timezone.utc), + scheduler_task_id=periodic_task.id, + ) + duplicate_scan = Scan.objects.create( + tenant_id=tenant.id, + provider=provider, + name="Daily scheduled scan", + trigger=Scan.TriggerChoices.SCHEDULED, + state=StateChoices.AVAILABLE, + scheduled_at=scheduled_scan.scheduled_at, + scheduler_task_id=periodic_task.id, + ) + + def _complete_scan(tenant_id, scan_id, provider_id): + other_scheduled = Scan.objects.filter( + tenant_id=tenant_id, + provider_id=provider_id, + trigger=Scan.TriggerChoices.SCHEDULED, + state__in=(StateChoices.SCHEDULED, StateChoices.AVAILABLE), + ).exclude(id=scan_id) + assert not other_scheduled.exists() + scan_instance = Scan.objects.get(id=scan_id) + scan_instance.state = StateChoices.COMPLETED + scan_instance.save() + return {"status": "ok"} + + with ( + patch("tasks.tasks.perform_prowler_scan", side_effect=_complete_scan), + patch("tasks.tasks._perform_scan_complete_tasks"), + self._override_task_request(perform_scheduled_scan_task, id=task_id), + ): + perform_scheduled_scan_task.run( + tenant_id=str(tenant.id), provider_id=str(provider.id) + ) + + assert not Scan.objects.filter(id=duplicate_scan.id).exists() + assert Scan.objects.filter(id=scheduled_scan.id).exists() + assert ( + Scan.objects.filter( + tenant_id=tenant.id, + provider=provider, + trigger=Scan.TriggerChoices.SCHEDULED, + state__in=(StateChoices.SCHEDULED, StateChoices.AVAILABLE), + ).count() + == 1 + ) diff --git a/api/src/backend/tasks/utils.py b/api/src/backend/tasks/utils.py index 21e30c9e29..eded5bfb9a 100644 --- a/api/src/backend/tasks/utils.py +++ b/api/src/backend/tasks/utils.py @@ -5,6 +5,10 @@ from enum import Enum from django_celery_beat.models import PeriodicTask from django_celery_results.models import TaskResult +from api.models import Scan, StateChoices + +SCHEDULED_SCAN_NAME = "Daily scheduled scan" + class CustomEncoder(json.JSONEncoder): def default(self, o): @@ -71,3 +75,58 @@ def batched(iterable, batch_size): batch = [] yield batch, True + + +def _get_or_create_scheduled_scan( + tenant_id: str, + provider_id: str, + scheduler_task_id: int, + scheduled_at: datetime, + update_state: bool = False, +) -> Scan: + """ + Get or create a scheduled scan, cleaning up duplicates if found. + + Args: + tenant_id: The tenant ID. + provider_id: The provider ID. + scheduler_task_id: The PeriodicTask ID. + scheduled_at: The scheduled datetime for the scan. + update_state: If True, also reset state to SCHEDULED when updating. + + Returns: + The scan instance to use. + """ + scheduled_scans = list( + Scan.objects.filter( + tenant_id=tenant_id, + provider_id=provider_id, + trigger=Scan.TriggerChoices.SCHEDULED, + state__in=(StateChoices.SCHEDULED, StateChoices.AVAILABLE), + scheduler_task_id=scheduler_task_id, + ).order_by("scheduled_at", "inserted_at") + ) + + if scheduled_scans: + scan_instance = scheduled_scans[0] + if len(scheduled_scans) > 1: + Scan.objects.filter(id__in=[s.id for s in scheduled_scans[1:]]).delete() + needs_update = scan_instance.scheduled_at != scheduled_at + if update_state and scan_instance.state != StateChoices.SCHEDULED: + scan_instance.state = StateChoices.SCHEDULED + scan_instance.name = SCHEDULED_SCAN_NAME + needs_update = True + if needs_update: + scan_instance.scheduled_at = scheduled_at + scan_instance.save() + return scan_instance + + return Scan.objects.create( + tenant_id=tenant_id, + name=SCHEDULED_SCAN_NAME, + provider_id=provider_id, + trigger=Scan.TriggerChoices.SCHEDULED, + state=StateChoices.SCHEDULED, + scheduled_at=scheduled_at, + scheduler_task_id=scheduler_task_id, + )