fix(scans): scheduled scans duplicates (#9883)

Co-authored-by: Víctor Fernández Poyatos <victor@prowler.com>
This commit is contained in:
Prowler Bot
2026-01-26 13:36:49 +01:00
committed by GitHub
parent 09ee0698d4
commit c509482954
4 changed files with 336 additions and 71 deletions

View File

@@ -15,6 +15,7 @@ All notable changes to the **Prowler API** are documented in this file.
- Lazy load Neo4j driver for workers only [(#9872)](https://github.com/prowler-cloud/prowler/pull/9872)
- Improve Cypher query for inserting Findings into Attack Paths scan graphs [(#9874)](https://github.com/prowler-cloud/prowler/pull/9874)
- Clear Neo4j database cache after Attack Paths scan and each API query [(#9877)](https://github.com/prowler-cloud/prowler/pull/9877)
- Deduplicate scheduled scans for long-running providers [(#9829)](https://github.com/prowler-cloud/prowler/pull/9829)
## [1.18.0] (Prowler v5.17.0)

View File

@@ -1,25 +1,13 @@
import os
from datetime import datetime, timedelta, timezone
from pathlib import Path
from shutil import rmtree
from celery import chain, group, shared_task
from celery.utils.log import get_task_logger
from django_celery_beat.models import PeriodicTask
from api.compliance import get_compliance_frameworks
from api.db_router import READ_REPLICA_ALIAS
from api.db_utils import rls_transaction
from api.decorators import handle_provider_deletion, set_tenant
from api.models import Finding, Integration, Provider, Scan, ScanSummary, StateChoices
from api.utils import initialize_prowler_provider
from api.v1.serializers import ScanTaskSerializer
from config.celery import RLSTask
from config.django.base import DJANGO_FINDINGS_BATCH_SIZE, DJANGO_TMP_OUTPUT_DIRECTORY
from prowler.lib.check.compliance_models import Compliance
from prowler.lib.outputs.compliance.generic.generic import GenericCompliance
from prowler.lib.outputs.finding import Finding as FindingOutput
from django_celery_beat.models import PeriodicTask
from tasks.jobs.attack_paths import (
attack_paths_scan,
can_provider_run_attack_paths_scan,
@@ -64,7 +52,22 @@ from tasks.jobs.scan import (
perform_prowler_scan,
update_provider_compliance_scores,
)
from tasks.utils import batched, get_next_execution_datetime
from tasks.utils import (
_get_or_create_scheduled_scan,
batched,
get_next_execution_datetime,
)
from api.compliance import get_compliance_frameworks
from api.db_router import READ_REPLICA_ALIAS
from api.db_utils import rls_transaction
from api.decorators import handle_provider_deletion, set_tenant
from api.models import Finding, Integration, Provider, Scan, ScanSummary, StateChoices
from api.utils import initialize_prowler_provider
from api.v1.serializers import ScanTaskSerializer
from prowler.lib.check.compliance_models import Compliance
from prowler.lib.outputs.compliance.generic.generic import GenericCompliance
from prowler.lib.outputs.finding import Finding as FindingOutput
logger = get_task_logger(__name__)
@@ -275,44 +278,38 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
periodic_task_instance = PeriodicTask.objects.get(
name=f"scan-perform-scheduled-{provider_id}"
)
executed_scan = Scan.objects.filter(
tenant_id=tenant_id,
provider_id=provider_id,
task__task_runner_task__task_id=task_id,
).order_by("completed_at")
if (
executing_scan = (
Scan.objects.filter(
tenant_id=tenant_id,
provider_id=provider_id,
trigger=Scan.TriggerChoices.SCHEDULED,
state=StateChoices.EXECUTING,
scheduler_task_id=periodic_task_instance.id,
scheduled_at__date=datetime.now(timezone.utc).date(),
).exists()
or executed_scan.exists()
):
# Duplicated task execution due to visibility timeout or scan is already running
logger.warning(f"Duplicated scheduled scan for provider {provider_id}.")
try:
affected_scan = executed_scan.first()
if not affected_scan:
raise ValueError(
"Error retrieving affected scan details after detecting duplicated scheduled "
"scan."
)
# Return the affected scan details to avoid losing data
serializer = ScanTaskSerializer(instance=affected_scan)
except Exception as duplicated_scan_exception:
logger.error(
f"Duplicated scheduled scan for provider {provider_id}. Error retrieving affected scan details: "
f"{str(duplicated_scan_exception)}"
.order_by("-started_at")
.first()
)
raise duplicated_scan_exception
return serializer.data
if executing_scan:
logger.warning(
f"Scheduled scan already executing for provider {provider_id}. Skipping."
)
return ScanTaskSerializer(instance=executing_scan).data
executed_scan = Scan.objects.filter(
tenant_id=tenant_id,
provider_id=provider_id,
task__task_runner_task__task_id=task_id,
).first()
if executed_scan:
# Duplicated task execution due to visibility timeout
logger.warning(f"Duplicated scheduled scan for provider {provider_id}.")
return ScanTaskSerializer(instance=executed_scan).data
interval = periodic_task_instance.interval
next_scan_datetime = get_next_execution_datetime(task_id, provider_id)
current_scan_datetime = next_scan_datetime - timedelta(
**{interval.period: interval.every}
)
# TEMPORARY WORKAROUND: Clean up orphan scans from transaction isolation issue
_cleanup_orphan_scheduled_scans(
@@ -321,19 +318,12 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
scheduler_task_id=periodic_task_instance.id,
)
scan_instance, _ = Scan.objects.get_or_create(
scan_instance = _get_or_create_scheduled_scan(
tenant_id=tenant_id,
provider_id=provider_id,
trigger=Scan.TriggerChoices.SCHEDULED,
state__in=(StateChoices.SCHEDULED, StateChoices.AVAILABLE),
scheduler_task_id=periodic_task_instance.id,
defaults={
"state": StateChoices.SCHEDULED,
"name": "Daily scheduled scan",
"scheduled_at": next_scan_datetime - timedelta(days=1),
},
scheduled_at=current_scan_datetime,
)
scan_instance.task_id = task_id
scan_instance.save()
@@ -343,18 +333,19 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
scan_id=str(scan_instance.id),
provider_id=provider_id,
)
except Exception as e:
raise e
finally:
with rls_transaction(tenant_id):
Scan.objects.get_or_create(
now = datetime.now(timezone.utc)
if next_scan_datetime <= now:
interval_delta = timedelta(**{interval.period: interval.every})
while next_scan_datetime <= now:
next_scan_datetime += interval_delta
_get_or_create_scheduled_scan(
tenant_id=tenant_id,
name="Daily scheduled scan",
provider_id=provider_id,
trigger=Scan.TriggerChoices.SCHEDULED,
state=StateChoices.SCHEDULED,
scheduled_at=next_scan_datetime,
scheduler_task_id=periodic_task_instance.id,
scheduled_at=next_scan_datetime,
update_state=True,
)
_perform_scan_complete_tasks(tenant_id, str(scan_instance.id), provider_id)

View File

@@ -1,21 +1,13 @@
import uuid
from contextlib import contextmanager
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import openai
import pytest
from botocore.exceptions import ClientError
from django_celery_beat.models import IntervalSchedule, PeriodicTask
from api.models import (
Integration,
LighthouseProviderConfiguration,
LighthouseProviderModels,
Scan,
StateChoices,
)
from django_celery_results.models import TaskResult
from tasks.jobs.lighthouse_providers import (
_create_bedrock_client,
_extract_bedrock_credentials,
@@ -27,11 +19,21 @@ from tasks.tasks import (
check_lighthouse_provider_connection_task,
generate_outputs_task,
perform_attack_paths_scan_task,
perform_scheduled_scan_task,
refresh_lighthouse_provider_models_task,
s3_integration_task,
security_hub_integration_task,
)
from api.models import (
Integration,
LighthouseProviderConfiguration,
LighthouseProviderModels,
Scan,
StateChoices,
Task,
)
@pytest.mark.django_db
class TestExtractBedrockCredentials:
@@ -2137,3 +2139,215 @@ class TestCleanupOrphanScheduledScans:
assert not Scan.objects.filter(id=orphan_scan.id).exists()
assert Scan.objects.filter(id=scheduled_scan.id).exists()
assert Scan.objects.filter(id=available_scan_other_task.id).exists()
@pytest.mark.django_db
class TestPerformScheduledScanTask:
    """Unit tests for perform_scheduled_scan_task."""

    @staticmethod
    @contextmanager
    def _override_task_request(task, **attrs):
        # Temporarily set attributes (e.g. the Celery task id) on the bound
        # task's request object, restoring the original values — or deleting
        # attributes that did not previously exist — on exit.
        request = task.request
        sentinel = object()
        previous = {key: getattr(request, key, sentinel) for key in attrs}
        for key, value in attrs.items():
            setattr(request, key, value)
        try:
            yield
        finally:
            for key, prev in previous.items():
                if prev is sentinel:
                    if hasattr(request, key):
                        delattr(request, key)
                else:
                    setattr(request, key, prev)

    def _create_periodic_task(self, provider_id, tenant_id, interval_hours=24):
        # Create the django-celery-beat PeriodicTask that
        # perform_scheduled_scan_task looks up by its conventional name.
        interval, _ = IntervalSchedule.objects.get_or_create(
            every=interval_hours, period="hours"
        )
        return PeriodicTask.objects.create(
            name=f"scan-perform-scheduled-{provider_id}",
            task="scan-perform-scheduled",
            interval=interval,
            kwargs=f'{{"tenant_id": "{tenant_id}", "provider_id": "{provider_id}"}}',
            enabled=True,
        )

    def _create_task_result(self, tenant_id, task_id):
        # Create the TaskResult/Task pair that backs the given Celery task id,
        # so the task under test can resolve its own execution record.
        task_result = TaskResult.objects.create(
            task_id=task_id,
            task_name="scan-perform-scheduled",
            status="STARTED",
            date_created=datetime.now(timezone.utc),
        )
        Task.objects.create(
            id=task_id, task_runner_task=task_result, tenant_id=tenant_id
        )
        return task_result

    def test_skip_when_scheduled_scan_executing(
        self, tenants_fixture, providers_fixture
    ):
        """Skip a scheduled run when another scheduled scan is already executing."""
        tenant = tenants_fixture[0]
        provider = providers_fixture[0]
        periodic_task = self._create_periodic_task(provider.id, tenant.id)
        task_id = str(uuid.uuid4())
        self._create_task_result(tenant.id, task_id)
        # Pre-existing EXECUTING scan should cause the new run to short-circuit.
        executing_scan = Scan.objects.create(
            tenant_id=tenant.id,
            provider=provider,
            name="Daily scheduled scan",
            trigger=Scan.TriggerChoices.SCHEDULED,
            state=StateChoices.EXECUTING,
            scheduler_task_id=periodic_task.id,
        )
        with (
            patch("tasks.tasks.perform_prowler_scan") as mock_scan,
            patch("tasks.tasks._perform_scan_complete_tasks") as mock_complete_tasks,
            self._override_task_request(perform_scheduled_scan_task, id=task_id),
        ):
            result = perform_scheduled_scan_task.run(
                tenant_id=str(tenant.id), provider_id=str(provider.id)
            )
        # No scan work should have happened; the executing scan is returned.
        mock_scan.assert_not_called()
        mock_complete_tasks.assert_not_called()
        assert result["id"] == str(executing_scan.id)
        assert result["state"] == StateChoices.EXECUTING
        # No extra SCHEDULED scan may have been created by the skipped run.
        assert (
            Scan.objects.filter(
                tenant_id=tenant.id,
                provider=provider,
                trigger=Scan.TriggerChoices.SCHEDULED,
                state=StateChoices.SCHEDULED,
            ).count()
            == 0
        )

    def test_creates_next_scheduled_scan_after_completion(
        self, tenants_fixture, providers_fixture
    ):
        """Create a next scheduled scan after a successful run completes."""
        tenant = tenants_fixture[0]
        provider = providers_fixture[0]
        self._create_periodic_task(provider.id, tenant.id)
        task_id = str(uuid.uuid4())
        self._create_task_result(tenant.id, task_id)

        def _complete_scan(tenant_id, scan_id, provider_id):
            # Stand-in for perform_prowler_scan: verify no sibling SCHEDULED
            # scan exists during execution, then mark this scan COMPLETED.
            other_scheduled = Scan.objects.filter(
                tenant_id=tenant_id,
                provider_id=provider_id,
                trigger=Scan.TriggerChoices.SCHEDULED,
                state=StateChoices.SCHEDULED,
            ).exclude(id=scan_id)
            assert not other_scheduled.exists()
            scan_instance = Scan.objects.get(id=scan_id)
            scan_instance.state = StateChoices.COMPLETED
            scan_instance.save()
            return {"status": "ok"}

        with (
            patch("tasks.tasks.perform_prowler_scan", side_effect=_complete_scan),
            patch("tasks.tasks._perform_scan_complete_tasks"),
            self._override_task_request(perform_scheduled_scan_task, id=task_id),
        ):
            perform_scheduled_scan_task.run(
                tenant_id=str(tenant.id), provider_id=str(provider.id)
            )
        # Exactly one future SCHEDULED scan must exist after completion.
        scheduled_scans = Scan.objects.filter(
            tenant_id=tenant.id,
            provider=provider,
            trigger=Scan.TriggerChoices.SCHEDULED,
            state=StateChoices.SCHEDULED,
        )
        assert scheduled_scans.count() == 1
        assert scheduled_scans.first().scheduled_at > datetime.now(timezone.utc)
        # No duplicated pending (SCHEDULED/AVAILABLE) scans remain.
        assert (
            Scan.objects.filter(
                tenant_id=tenant.id,
                provider=provider,
                trigger=Scan.TriggerChoices.SCHEDULED,
                state__in=(StateChoices.SCHEDULED, StateChoices.AVAILABLE),
            ).count()
            == 1
        )
        # The run itself finished as exactly one COMPLETED scan.
        assert (
            Scan.objects.filter(
                tenant_id=tenant.id,
                provider=provider,
                trigger=Scan.TriggerChoices.SCHEDULED,
                state=StateChoices.COMPLETED,
            ).count()
            == 1
        )

    def test_dedupes_multiple_scheduled_scans_before_run(
        self, tenants_fixture, providers_fixture
    ):
        """Ensure duplicated scheduled scans are removed before executing."""
        tenant = tenants_fixture[0]
        provider = providers_fixture[0]
        periodic_task = self._create_periodic_task(provider.id, tenant.id)
        task_id = str(uuid.uuid4())
        self._create_task_result(tenant.id, task_id)
        # Oldest pending scan — expected to survive deduplication.
        scheduled_scan = Scan.objects.create(
            tenant_id=tenant.id,
            provider=provider,
            name="Daily scheduled scan",
            trigger=Scan.TriggerChoices.SCHEDULED,
            state=StateChoices.SCHEDULED,
            scheduled_at=datetime.now(timezone.utc),
            scheduler_task_id=periodic_task.id,
        )
        # Duplicate with the same scheduled_at — expected to be deleted.
        duplicate_scan = Scan.objects.create(
            tenant_id=tenant.id,
            provider=provider,
            name="Daily scheduled scan",
            trigger=Scan.TriggerChoices.SCHEDULED,
            state=StateChoices.AVAILABLE,
            scheduled_at=scheduled_scan.scheduled_at,
            scheduler_task_id=periodic_task.id,
        )

        def _complete_scan(tenant_id, scan_id, provider_id):
            # Stand-in for perform_prowler_scan: assert deduplication already
            # happened (no other pending scan), then mark the scan COMPLETED.
            other_scheduled = Scan.objects.filter(
                tenant_id=tenant_id,
                provider_id=provider_id,
                trigger=Scan.TriggerChoices.SCHEDULED,
                state__in=(StateChoices.SCHEDULED, StateChoices.AVAILABLE),
            ).exclude(id=scan_id)
            assert not other_scheduled.exists()
            scan_instance = Scan.objects.get(id=scan_id)
            scan_instance.state = StateChoices.COMPLETED
            scan_instance.save()
            return {"status": "ok"}

        with (
            patch("tasks.tasks.perform_prowler_scan", side_effect=_complete_scan),
            patch("tasks.tasks._perform_scan_complete_tasks"),
            self._override_task_request(perform_scheduled_scan_task, id=task_id),
        ):
            perform_scheduled_scan_task.run(
                tenant_id=str(tenant.id), provider_id=str(provider.id)
            )
        # The duplicate was deleted and the original survived.
        assert not Scan.objects.filter(id=duplicate_scan.id).exists()
        assert Scan.objects.filter(id=scheduled_scan.id).exists()
        # Exactly one pending scheduled scan remains (the rescheduled next run).
        assert (
            Scan.objects.filter(
                tenant_id=tenant.id,
                provider=provider,
                trigger=Scan.TriggerChoices.SCHEDULED,
                state__in=(StateChoices.SCHEDULED, StateChoices.AVAILABLE),
            ).count()
            == 1
        )

View File

@@ -5,6 +5,10 @@ from enum import Enum
from django_celery_beat.models import PeriodicTask
from django_celery_results.models import TaskResult
from api.models import Scan, StateChoices
# Display name assigned to scans created by the daily scheduler.
SCHEDULED_SCAN_NAME = "Daily scheduled scan"
class CustomEncoder(json.JSONEncoder):
def default(self, o):
@@ -71,3 +75,58 @@ def batched(iterable, batch_size):
batch = []
yield batch, True
def _get_or_create_scheduled_scan(
    tenant_id: str,
    provider_id: str,
    scheduler_task_id: int,
    scheduled_at: datetime,
    update_state: bool = False,
) -> Scan:
    """
    Get or create a scheduled scan, cleaning up duplicates if found.

    Args:
        tenant_id: The tenant ID.
        provider_id: The provider ID.
        scheduler_task_id: The PeriodicTask ID.
        scheduled_at: The scheduled datetime for the scan.
        update_state: If True, also reset state to SCHEDULED when updating.

    Returns:
        The scan instance to use.
    """
    # Collect every pending scheduled scan (SCHEDULED or AVAILABLE) tied to
    # this provider + periodic task, oldest first, so the earliest one wins.
    scheduled_scans = list(
        Scan.objects.filter(
            tenant_id=tenant_id,
            provider_id=provider_id,
            trigger=Scan.TriggerChoices.SCHEDULED,
            state__in=(StateChoices.SCHEDULED, StateChoices.AVAILABLE),
            scheduler_task_id=scheduler_task_id,
        ).order_by("scheduled_at", "inserted_at")
    )
    if scheduled_scans:
        scan_instance = scheduled_scans[0]
        if len(scheduled_scans) > 1:
            # Deduplicate: delete every pending scan except the oldest.
            Scan.objects.filter(id__in=[s.id for s in scheduled_scans[1:]]).delete()
        needs_update = scan_instance.scheduled_at != scheduled_at
        if update_state and scan_instance.state != StateChoices.SCHEDULED:
            # Reset the surviving scan back to SCHEDULED (used when queuing
            # the next run after the current one finishes) and normalize its
            # name to the scheduler default.
            scan_instance.state = StateChoices.SCHEDULED
            scan_instance.name = SCHEDULED_SCAN_NAME
            needs_update = True
        if needs_update:
            # Single save covers both the state reset and the new schedule.
            scan_instance.scheduled_at = scheduled_at
            scan_instance.save()
        return scan_instance
    # No pending scan exists yet: create a fresh one in SCHEDULED state.
    return Scan.objects.create(
        tenant_id=tenant_id,
        name=SCHEDULED_SCAN_NAME,
        provider_id=provider_id,
        trigger=Scan.TriggerChoices.SCHEDULED,
        state=StateChoices.SCHEDULED,
        scheduled_at=scheduled_at,
        scheduler_task_id=scheduler_task_id,
    )