fix(scans): avoid duplicate scheduled scans (#9883)

Co-authored-by: Víctor Fernández Poyatos <victor@prowler.com>
This commit is contained in:
Prowler Bot
2026-01-26 13:36:49 +01:00
committed by GitHub
parent 09ee0698d4
commit c509482954
4 changed files with 336 additions and 71 deletions

View File

@@ -15,6 +15,7 @@ All notable changes to the **Prowler API** are documented in this file.
- Lazy load Neo4j driver for workers only [(#9872)](https://github.com/prowler-cloud/prowler/pull/9872) - Lazy load Neo4j driver for workers only [(#9872)](https://github.com/prowler-cloud/prowler/pull/9872)
- Improve Cypher query for inserting Findings into Attack Paths scan graphs [(#9874)](https://github.com/prowler-cloud/prowler/pull/9874) - Improve Cypher query for inserting Findings into Attack Paths scan graphs [(#9874)](https://github.com/prowler-cloud/prowler/pull/9874)
- Clear Neo4j database cache after Attack Paths scan and each API query [(#9877)](https://github.com/prowler-cloud/prowler/pull/9877) - Clear Neo4j database cache after Attack Paths scan and each API query [(#9877)](https://github.com/prowler-cloud/prowler/pull/9877)
- Deduplicated scheduled scans for long-running providers [(#9829)](https://github.com/prowler-cloud/prowler/pull/9829)
## [1.18.0] (Prowler v5.17.0) ## [1.18.0] (Prowler v5.17.0)

View File

@@ -1,25 +1,13 @@
import os import os
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from pathlib import Path from pathlib import Path
from shutil import rmtree from shutil import rmtree
from celery import chain, group, shared_task from celery import chain, group, shared_task
from celery.utils.log import get_task_logger from celery.utils.log import get_task_logger
from django_celery_beat.models import PeriodicTask
from api.compliance import get_compliance_frameworks
from api.db_router import READ_REPLICA_ALIAS
from api.db_utils import rls_transaction
from api.decorators import handle_provider_deletion, set_tenant
from api.models import Finding, Integration, Provider, Scan, ScanSummary, StateChoices
from api.utils import initialize_prowler_provider
from api.v1.serializers import ScanTaskSerializer
from config.celery import RLSTask from config.celery import RLSTask
from config.django.base import DJANGO_FINDINGS_BATCH_SIZE, DJANGO_TMP_OUTPUT_DIRECTORY from config.django.base import DJANGO_FINDINGS_BATCH_SIZE, DJANGO_TMP_OUTPUT_DIRECTORY
from prowler.lib.check.compliance_models import Compliance from django_celery_beat.models import PeriodicTask
from prowler.lib.outputs.compliance.generic.generic import GenericCompliance
from prowler.lib.outputs.finding import Finding as FindingOutput
from tasks.jobs.attack_paths import ( from tasks.jobs.attack_paths import (
attack_paths_scan, attack_paths_scan,
can_provider_run_attack_paths_scan, can_provider_run_attack_paths_scan,
@@ -64,7 +52,22 @@ from tasks.jobs.scan import (
perform_prowler_scan, perform_prowler_scan,
update_provider_compliance_scores, update_provider_compliance_scores,
) )
from tasks.utils import batched, get_next_execution_datetime from tasks.utils import (
_get_or_create_scheduled_scan,
batched,
get_next_execution_datetime,
)
from api.compliance import get_compliance_frameworks
from api.db_router import READ_REPLICA_ALIAS
from api.db_utils import rls_transaction
from api.decorators import handle_provider_deletion, set_tenant
from api.models import Finding, Integration, Provider, Scan, ScanSummary, StateChoices
from api.utils import initialize_prowler_provider
from api.v1.serializers import ScanTaskSerializer
from prowler.lib.check.compliance_models import Compliance
from prowler.lib.outputs.compliance.generic.generic import GenericCompliance
from prowler.lib.outputs.finding import Finding as FindingOutput
logger = get_task_logger(__name__) logger = get_task_logger(__name__)
@@ -275,44 +278,38 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
periodic_task_instance = PeriodicTask.objects.get( periodic_task_instance = PeriodicTask.objects.get(
name=f"scan-perform-scheduled-{provider_id}" name=f"scan-perform-scheduled-{provider_id}"
) )
executing_scan = (
executed_scan = Scan.objects.filter(
tenant_id=tenant_id,
provider_id=provider_id,
task__task_runner_task__task_id=task_id,
).order_by("completed_at")
if (
Scan.objects.filter( Scan.objects.filter(
tenant_id=tenant_id, tenant_id=tenant_id,
provider_id=provider_id, provider_id=provider_id,
trigger=Scan.TriggerChoices.SCHEDULED, trigger=Scan.TriggerChoices.SCHEDULED,
state=StateChoices.EXECUTING, state=StateChoices.EXECUTING,
scheduler_task_id=periodic_task_instance.id, )
scheduled_at__date=datetime.now(timezone.utc).date(), .order_by("-started_at")
).exists() .first()
or executed_scan.exists() )
): if executing_scan:
# Duplicated task execution due to visibility timeout or scan is already running logger.warning(
logger.warning(f"Duplicated scheduled scan for provider {provider_id}.") f"Scheduled scan already executing for provider {provider_id}. Skipping."
try: )
affected_scan = executed_scan.first() return ScanTaskSerializer(instance=executing_scan).data
if not affected_scan:
raise ValueError(
"Error retrieving affected scan details after detecting duplicated scheduled "
"scan."
)
# Return the affected scan details to avoid losing data
serializer = ScanTaskSerializer(instance=affected_scan)
except Exception as duplicated_scan_exception:
logger.error(
f"Duplicated scheduled scan for provider {provider_id}. Error retrieving affected scan details: "
f"{str(duplicated_scan_exception)}"
)
raise duplicated_scan_exception
return serializer.data
executed_scan = Scan.objects.filter(
tenant_id=tenant_id,
provider_id=provider_id,
task__task_runner_task__task_id=task_id,
).first()
if executed_scan:
# Duplicated task execution due to visibility timeout
logger.warning(f"Duplicated scheduled scan for provider {provider_id}.")
return ScanTaskSerializer(instance=executed_scan).data
interval = periodic_task_instance.interval
next_scan_datetime = get_next_execution_datetime(task_id, provider_id) next_scan_datetime = get_next_execution_datetime(task_id, provider_id)
current_scan_datetime = next_scan_datetime - timedelta(
**{interval.period: interval.every}
)
# TEMPORARY WORKAROUND: Clean up orphan scans from transaction isolation issue # TEMPORARY WORKAROUND: Clean up orphan scans from transaction isolation issue
_cleanup_orphan_scheduled_scans( _cleanup_orphan_scheduled_scans(
@@ -321,19 +318,12 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
scheduler_task_id=periodic_task_instance.id, scheduler_task_id=periodic_task_instance.id,
) )
scan_instance, _ = Scan.objects.get_or_create( scan_instance = _get_or_create_scheduled_scan(
tenant_id=tenant_id, tenant_id=tenant_id,
provider_id=provider_id, provider_id=provider_id,
trigger=Scan.TriggerChoices.SCHEDULED,
state__in=(StateChoices.SCHEDULED, StateChoices.AVAILABLE),
scheduler_task_id=periodic_task_instance.id, scheduler_task_id=periodic_task_instance.id,
defaults={ scheduled_at=current_scan_datetime,
"state": StateChoices.SCHEDULED,
"name": "Daily scheduled scan",
"scheduled_at": next_scan_datetime - timedelta(days=1),
},
) )
scan_instance.task_id = task_id scan_instance.task_id = task_id
scan_instance.save() scan_instance.save()
@@ -343,18 +333,19 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
scan_id=str(scan_instance.id), scan_id=str(scan_instance.id),
provider_id=provider_id, provider_id=provider_id,
) )
except Exception as e:
raise e
finally: finally:
with rls_transaction(tenant_id): with rls_transaction(tenant_id):
Scan.objects.get_or_create( now = datetime.now(timezone.utc)
if next_scan_datetime <= now:
interval_delta = timedelta(**{interval.period: interval.every})
while next_scan_datetime <= now:
next_scan_datetime += interval_delta
_get_or_create_scheduled_scan(
tenant_id=tenant_id, tenant_id=tenant_id,
name="Daily scheduled scan",
provider_id=provider_id, provider_id=provider_id,
trigger=Scan.TriggerChoices.SCHEDULED,
state=StateChoices.SCHEDULED,
scheduled_at=next_scan_datetime,
scheduler_task_id=periodic_task_instance.id, scheduler_task_id=periodic_task_instance.id,
scheduled_at=next_scan_datetime,
update_state=True,
) )
_perform_scan_complete_tasks(tenant_id, str(scan_instance.id), provider_id) _perform_scan_complete_tasks(tenant_id, str(scan_instance.id), provider_id)

View File

@@ -1,21 +1,13 @@
import uuid import uuid
from contextlib import contextmanager from contextlib import contextmanager
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import openai import openai
import pytest import pytest
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
from django_celery_beat.models import IntervalSchedule, PeriodicTask from django_celery_beat.models import IntervalSchedule, PeriodicTask
from django_celery_results.models import TaskResult
from api.models import (
Integration,
LighthouseProviderConfiguration,
LighthouseProviderModels,
Scan,
StateChoices,
)
from tasks.jobs.lighthouse_providers import ( from tasks.jobs.lighthouse_providers import (
_create_bedrock_client, _create_bedrock_client,
_extract_bedrock_credentials, _extract_bedrock_credentials,
@@ -27,11 +19,21 @@ from tasks.tasks import (
check_lighthouse_provider_connection_task, check_lighthouse_provider_connection_task,
generate_outputs_task, generate_outputs_task,
perform_attack_paths_scan_task, perform_attack_paths_scan_task,
perform_scheduled_scan_task,
refresh_lighthouse_provider_models_task, refresh_lighthouse_provider_models_task,
s3_integration_task, s3_integration_task,
security_hub_integration_task, security_hub_integration_task,
) )
from api.models import (
Integration,
LighthouseProviderConfiguration,
LighthouseProviderModels,
Scan,
StateChoices,
Task,
)
@pytest.mark.django_db @pytest.mark.django_db
class TestExtractBedrockCredentials: class TestExtractBedrockCredentials:
@@ -2137,3 +2139,215 @@ class TestCleanupOrphanScheduledScans:
assert not Scan.objects.filter(id=orphan_scan.id).exists() assert not Scan.objects.filter(id=orphan_scan.id).exists()
assert Scan.objects.filter(id=scheduled_scan.id).exists() assert Scan.objects.filter(id=scheduled_scan.id).exists()
assert Scan.objects.filter(id=available_scan_other_task.id).exists() assert Scan.objects.filter(id=available_scan_other_task.id).exists()
@pytest.mark.django_db
class TestPerformScheduledScanTask:
    """Unit tests for perform_scheduled_scan_task."""
    # NOTE(review): tenants_fixture / providers_fixture are assumed to come from
    # the project's conftest and yield persisted Tenant/Provider rows — confirm.
    @staticmethod
    @contextmanager
    def _override_task_request(task, **attrs):
        """Temporarily set attributes on the Celery task's request object.

        On exit, previously-existing attributes are restored and attributes
        that did not exist before are deleted, so tests do not leak state
        onto the shared task instance.
        """
        request = task.request
        # Sentinel distinguishes "attribute was absent" from "attribute was None".
        sentinel = object()
        previous = {key: getattr(request, key, sentinel) for key in attrs}
        for key, value in attrs.items():
            setattr(request, key, value)
        try:
            yield
        finally:
            for key, prev in previous.items():
                if prev is sentinel:
                    if hasattr(request, key):
                        delattr(request, key)
                else:
                    setattr(request, key, prev)
    def _create_periodic_task(self, provider_id, tenant_id, interval_hours=24):
        """Create the beat PeriodicTask the task under test looks up by name.

        The name must match the f"scan-perform-scheduled-{provider_id}" pattern
        used by perform_scheduled_scan_task.
        """
        interval, _ = IntervalSchedule.objects.get_or_create(
            every=interval_hours, period="hours"
        )
        return PeriodicTask.objects.create(
            name=f"scan-perform-scheduled-{provider_id}",
            task="scan-perform-scheduled",
            interval=interval,
            kwargs=f'{{"tenant_id": "{tenant_id}", "provider_id": "{provider_id}"}}',
            enabled=True,
        )
    def _create_task_result(self, tenant_id, task_id):
        """Create the TaskResult plus API Task row linking task_id to a runner task."""
        task_result = TaskResult.objects.create(
            task_id=task_id,
            task_name="scan-perform-scheduled",
            status="STARTED",
            date_created=datetime.now(timezone.utc),
        )
        Task.objects.create(
            id=task_id, task_runner_task=task_result, tenant_id=tenant_id
        )
        return task_result
    def test_skip_when_scheduled_scan_executing(
        self, tenants_fixture, providers_fixture
    ):
        """Skip a scheduled run when another scheduled scan is already executing."""
        tenant = tenants_fixture[0]
        provider = providers_fixture[0]
        periodic_task = self._create_periodic_task(provider.id, tenant.id)
        task_id = str(uuid.uuid4())
        self._create_task_result(tenant.id, task_id)
        # Pre-existing EXECUTING scan should make the task short-circuit.
        executing_scan = Scan.objects.create(
            tenant_id=tenant.id,
            provider=provider,
            name="Daily scheduled scan",
            trigger=Scan.TriggerChoices.SCHEDULED,
            state=StateChoices.EXECUTING,
            scheduler_task_id=periodic_task.id,
        )
        with (
            patch("tasks.tasks.perform_prowler_scan") as mock_scan,
            patch("tasks.tasks._perform_scan_complete_tasks") as mock_complete_tasks,
            self._override_task_request(perform_scheduled_scan_task, id=task_id),
        ):
            result = perform_scheduled_scan_task.run(
                tenant_id=str(tenant.id), provider_id=str(provider.id)
            )
        # Neither the scan nor its completion pipeline must run.
        mock_scan.assert_not_called()
        mock_complete_tasks.assert_not_called()
        # The task returns the already-executing scan's serialized data.
        assert result["id"] == str(executing_scan.id)
        assert result["state"] == StateChoices.EXECUTING
        # No new SCHEDULED scan may have been created while skipping.
        assert (
            Scan.objects.filter(
                tenant_id=tenant.id,
                provider=provider,
                trigger=Scan.TriggerChoices.SCHEDULED,
                state=StateChoices.SCHEDULED,
            ).count()
            == 0
        )
    def test_creates_next_scheduled_scan_after_completion(
        self, tenants_fixture, providers_fixture
    ):
        """Create a next scheduled scan after a successful run completes."""
        tenant = tenants_fixture[0]
        provider = providers_fixture[0]
        self._create_periodic_task(provider.id, tenant.id)
        task_id = str(uuid.uuid4())
        self._create_task_result(tenant.id, task_id)
        def _complete_scan(tenant_id, scan_id, provider_id):
            # While the scan is "running", no other SCHEDULED scan may exist yet:
            # the next one is only created in the task's finally block.
            other_scheduled = Scan.objects.filter(
                tenant_id=tenant_id,
                provider_id=provider_id,
                trigger=Scan.TriggerChoices.SCHEDULED,
                state=StateChoices.SCHEDULED,
            ).exclude(id=scan_id)
            assert not other_scheduled.exists()
            scan_instance = Scan.objects.get(id=scan_id)
            scan_instance.state = StateChoices.COMPLETED
            scan_instance.save()
            return {"status": "ok"}
        with (
            patch("tasks.tasks.perform_prowler_scan", side_effect=_complete_scan),
            patch("tasks.tasks._perform_scan_complete_tasks"),
            self._override_task_request(perform_scheduled_scan_task, id=task_id),
        ):
            perform_scheduled_scan_task.run(
                tenant_id=str(tenant.id), provider_id=str(provider.id)
            )
        # Exactly one follow-up scan, scheduled in the future.
        scheduled_scans = Scan.objects.filter(
            tenant_id=tenant.id,
            provider=provider,
            trigger=Scan.TriggerChoices.SCHEDULED,
            state=StateChoices.SCHEDULED,
        )
        assert scheduled_scans.count() == 1
        assert scheduled_scans.first().scheduled_at > datetime.now(timezone.utc)
        assert (
            Scan.objects.filter(
                tenant_id=tenant.id,
                provider=provider,
                trigger=Scan.TriggerChoices.SCHEDULED,
                state__in=(StateChoices.SCHEDULED, StateChoices.AVAILABLE),
            ).count()
            == 1
        )
        # The executed scan itself ended up COMPLETED.
        assert (
            Scan.objects.filter(
                tenant_id=tenant.id,
                provider=provider,
                trigger=Scan.TriggerChoices.SCHEDULED,
                state=StateChoices.COMPLETED,
            ).count()
            == 1
        )
    def test_dedupes_multiple_scheduled_scans_before_run(
        self, tenants_fixture, providers_fixture
    ):
        """Ensure duplicated scheduled scans are removed before executing."""
        tenant = tenants_fixture[0]
        provider = providers_fixture[0]
        periodic_task = self._create_periodic_task(provider.id, tenant.id)
        task_id = str(uuid.uuid4())
        self._create_task_result(tenant.id, task_id)
        # Two pending scans for the same scheduler task: the earlier SCHEDULED
        # one should be kept, the AVAILABLE duplicate deleted.
        scheduled_scan = Scan.objects.create(
            tenant_id=tenant.id,
            provider=provider,
            name="Daily scheduled scan",
            trigger=Scan.TriggerChoices.SCHEDULED,
            state=StateChoices.SCHEDULED,
            scheduled_at=datetime.now(timezone.utc),
            scheduler_task_id=periodic_task.id,
        )
        duplicate_scan = Scan.objects.create(
            tenant_id=tenant.id,
            provider=provider,
            name="Daily scheduled scan",
            trigger=Scan.TriggerChoices.SCHEDULED,
            state=StateChoices.AVAILABLE,
            scheduled_at=scheduled_scan.scheduled_at,
            scheduler_task_id=periodic_task.id,
        )
        def _complete_scan(tenant_id, scan_id, provider_id):
            # By the time the scan runs, dedup must already have happened.
            other_scheduled = Scan.objects.filter(
                tenant_id=tenant_id,
                provider_id=provider_id,
                trigger=Scan.TriggerChoices.SCHEDULED,
                state__in=(StateChoices.SCHEDULED, StateChoices.AVAILABLE),
            ).exclude(id=scan_id)
            assert not other_scheduled.exists()
            scan_instance = Scan.objects.get(id=scan_id)
            scan_instance.state = StateChoices.COMPLETED
            scan_instance.save()
            return {"status": "ok"}
        with (
            patch("tasks.tasks.perform_prowler_scan", side_effect=_complete_scan),
            patch("tasks.tasks._perform_scan_complete_tasks"),
            self._override_task_request(perform_scheduled_scan_task, id=task_id),
        ):
            perform_scheduled_scan_task.run(
                tenant_id=str(tenant.id), provider_id=str(provider.id)
            )
        assert not Scan.objects.filter(id=duplicate_scan.id).exists()
        assert Scan.objects.filter(id=scheduled_scan.id).exists()
        assert (
            Scan.objects.filter(
                tenant_id=tenant.id,
                provider=provider,
                trigger=Scan.TriggerChoices.SCHEDULED,
                state__in=(StateChoices.SCHEDULED, StateChoices.AVAILABLE),
            ).count()
            == 1
        )

View File

@@ -5,6 +5,10 @@ from enum import Enum
from django_celery_beat.models import PeriodicTask from django_celery_beat.models import PeriodicTask
from django_celery_results.models import TaskResult from django_celery_results.models import TaskResult
from api.models import Scan, StateChoices
SCHEDULED_SCAN_NAME = "Daily scheduled scan"
class CustomEncoder(json.JSONEncoder): class CustomEncoder(json.JSONEncoder):
def default(self, o): def default(self, o):
@@ -71,3 +75,58 @@ def batched(iterable, batch_size):
batch = [] batch = []
yield batch, True yield batch, True
def _get_or_create_scheduled_scan(
    tenant_id: str,
    provider_id: str,
    scheduler_task_id: int,
    scheduled_at: datetime,
    update_state: bool = False,
) -> Scan:
    """
    Return the single pending scheduled scan for a provider, creating one if absent.

    Duplicate SCHEDULED/AVAILABLE scans for the same scheduler task are deleted,
    keeping only the earliest one (ordered by scheduled_at, then inserted_at).

    Args:
        tenant_id: The tenant ID.
        provider_id: The provider ID.
        scheduler_task_id: The PeriodicTask ID.
        scheduled_at: The scheduled datetime for the scan.
        update_state: If True, also reset state to SCHEDULED when updating.

    Returns:
        The scan instance to use.
    """
    candidates = list(
        Scan.objects.filter(
            tenant_id=tenant_id,
            provider_id=provider_id,
            trigger=Scan.TriggerChoices.SCHEDULED,
            state__in=(StateChoices.SCHEDULED, StateChoices.AVAILABLE),
            scheduler_task_id=scheduler_task_id,
        ).order_by("scheduled_at", "inserted_at")
    )
    if not candidates:
        # No pending scan for this scheduler task: create a fresh one.
        return Scan.objects.create(
            tenant_id=tenant_id,
            name=SCHEDULED_SCAN_NAME,
            provider_id=provider_id,
            trigger=Scan.TriggerChoices.SCHEDULED,
            state=StateChoices.SCHEDULED,
            scheduled_at=scheduled_at,
            scheduler_task_id=scheduler_task_id,
        )
    keeper, *duplicates = candidates
    if duplicates:
        # Drop every duplicate in one query; only the earliest scan survives.
        Scan.objects.filter(id__in=[dup.id for dup in duplicates]).delete()
    dirty = keeper.scheduled_at != scheduled_at
    if update_state and keeper.state != StateChoices.SCHEDULED:
        keeper.state = StateChoices.SCHEDULED
        keeper.name = SCHEDULED_SCAN_NAME
        dirty = True
    if dirty:
        # Persist only when something actually changed to avoid useless writes.
        keeper.scheduled_at = scheduled_at
        keeper.save()
    return keeper