mirror of
https://github.com/prowler-cloud/prowler.git
synced 2026-03-22 03:08:23 +00:00
fix(scans): scheduled scans duplicates (#9883)
Co-authored-by: Víctor Fernández Poyatos <victor@prowler.com>
This commit is contained in:
@@ -15,6 +15,7 @@ All notable changes to the **Prowler API** are documented in this file.
|
|||||||
- Lazy load Neo4j driver for workers only [(#9872)](https://github.com/prowler-cloud/prowler/pull/9872)
|
- Lazy load Neo4j driver for workers only [(#9872)](https://github.com/prowler-cloud/prowler/pull/9872)
|
||||||
- Improve Cypher query for inserting Findings into Attack Paths scan graphs [(#9874)](https://github.com/prowler-cloud/prowler/pull/9874)
|
- Improve Cypher query for inserting Findings into Attack Paths scan graphs [(#9874)](https://github.com/prowler-cloud/prowler/pull/9874)
|
||||||
- Clear Neo4j database cache after Attack Paths scan and each API query [(#9877)](https://github.com/prowler-cloud/prowler/pull/9877)
|
- Clear Neo4j database cache after Attack Paths scan and each API query [(#9877)](https://github.com/prowler-cloud/prowler/pull/9877)
|
||||||
|
- Deduplicated scheduled scans for long-running providers [(#9829)](https://github.com/prowler-cloud/prowler/pull/9829)
|
||||||
|
|
||||||
## [1.18.0] (Prowler v5.17.0)
|
## [1.18.0] (Prowler v5.17.0)
|
||||||
|
|
||||||
|
|||||||
@@ -1,25 +1,13 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import rmtree
|
from shutil import rmtree
|
||||||
|
|
||||||
from celery import chain, group, shared_task
|
from celery import chain, group, shared_task
|
||||||
from celery.utils.log import get_task_logger
|
from celery.utils.log import get_task_logger
|
||||||
from django_celery_beat.models import PeriodicTask
|
|
||||||
|
|
||||||
from api.compliance import get_compliance_frameworks
|
|
||||||
from api.db_router import READ_REPLICA_ALIAS
|
|
||||||
from api.db_utils import rls_transaction
|
|
||||||
from api.decorators import handle_provider_deletion, set_tenant
|
|
||||||
from api.models import Finding, Integration, Provider, Scan, ScanSummary, StateChoices
|
|
||||||
from api.utils import initialize_prowler_provider
|
|
||||||
from api.v1.serializers import ScanTaskSerializer
|
|
||||||
from config.celery import RLSTask
|
from config.celery import RLSTask
|
||||||
from config.django.base import DJANGO_FINDINGS_BATCH_SIZE, DJANGO_TMP_OUTPUT_DIRECTORY
|
from config.django.base import DJANGO_FINDINGS_BATCH_SIZE, DJANGO_TMP_OUTPUT_DIRECTORY
|
||||||
from prowler.lib.check.compliance_models import Compliance
|
from django_celery_beat.models import PeriodicTask
|
||||||
from prowler.lib.outputs.compliance.generic.generic import GenericCompliance
|
|
||||||
from prowler.lib.outputs.finding import Finding as FindingOutput
|
|
||||||
from tasks.jobs.attack_paths import (
|
from tasks.jobs.attack_paths import (
|
||||||
attack_paths_scan,
|
attack_paths_scan,
|
||||||
can_provider_run_attack_paths_scan,
|
can_provider_run_attack_paths_scan,
|
||||||
@@ -64,7 +52,22 @@ from tasks.jobs.scan import (
|
|||||||
perform_prowler_scan,
|
perform_prowler_scan,
|
||||||
update_provider_compliance_scores,
|
update_provider_compliance_scores,
|
||||||
)
|
)
|
||||||
from tasks.utils import batched, get_next_execution_datetime
|
from tasks.utils import (
|
||||||
|
_get_or_create_scheduled_scan,
|
||||||
|
batched,
|
||||||
|
get_next_execution_datetime,
|
||||||
|
)
|
||||||
|
|
||||||
|
from api.compliance import get_compliance_frameworks
|
||||||
|
from api.db_router import READ_REPLICA_ALIAS
|
||||||
|
from api.db_utils import rls_transaction
|
||||||
|
from api.decorators import handle_provider_deletion, set_tenant
|
||||||
|
from api.models import Finding, Integration, Provider, Scan, ScanSummary, StateChoices
|
||||||
|
from api.utils import initialize_prowler_provider
|
||||||
|
from api.v1.serializers import ScanTaskSerializer
|
||||||
|
from prowler.lib.check.compliance_models import Compliance
|
||||||
|
from prowler.lib.outputs.compliance.generic.generic import GenericCompliance
|
||||||
|
from prowler.lib.outputs.finding import Finding as FindingOutput
|
||||||
|
|
||||||
logger = get_task_logger(__name__)
|
logger = get_task_logger(__name__)
|
||||||
|
|
||||||
@@ -275,44 +278,38 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
|
|||||||
periodic_task_instance = PeriodicTask.objects.get(
|
periodic_task_instance = PeriodicTask.objects.get(
|
||||||
name=f"scan-perform-scheduled-{provider_id}"
|
name=f"scan-perform-scheduled-{provider_id}"
|
||||||
)
|
)
|
||||||
|
executing_scan = (
|
||||||
executed_scan = Scan.objects.filter(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
provider_id=provider_id,
|
|
||||||
task__task_runner_task__task_id=task_id,
|
|
||||||
).order_by("completed_at")
|
|
||||||
|
|
||||||
if (
|
|
||||||
Scan.objects.filter(
|
Scan.objects.filter(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
trigger=Scan.TriggerChoices.SCHEDULED,
|
trigger=Scan.TriggerChoices.SCHEDULED,
|
||||||
state=StateChoices.EXECUTING,
|
state=StateChoices.EXECUTING,
|
||||||
scheduler_task_id=periodic_task_instance.id,
|
)
|
||||||
scheduled_at__date=datetime.now(timezone.utc).date(),
|
.order_by("-started_at")
|
||||||
).exists()
|
.first()
|
||||||
or executed_scan.exists()
|
)
|
||||||
):
|
if executing_scan:
|
||||||
# Duplicated task execution due to visibility timeout or scan is already running
|
logger.warning(
|
||||||
logger.warning(f"Duplicated scheduled scan for provider {provider_id}.")
|
f"Scheduled scan already executing for provider {provider_id}. Skipping."
|
||||||
try:
|
)
|
||||||
affected_scan = executed_scan.first()
|
return ScanTaskSerializer(instance=executing_scan).data
|
||||||
if not affected_scan:
|
|
||||||
raise ValueError(
|
|
||||||
"Error retrieving affected scan details after detecting duplicated scheduled "
|
|
||||||
"scan."
|
|
||||||
)
|
|
||||||
# Return the affected scan details to avoid losing data
|
|
||||||
serializer = ScanTaskSerializer(instance=affected_scan)
|
|
||||||
except Exception as duplicated_scan_exception:
|
|
||||||
logger.error(
|
|
||||||
f"Duplicated scheduled scan for provider {provider_id}. Error retrieving affected scan details: "
|
|
||||||
f"{str(duplicated_scan_exception)}"
|
|
||||||
)
|
|
||||||
raise duplicated_scan_exception
|
|
||||||
return serializer.data
|
|
||||||
|
|
||||||
|
executed_scan = Scan.objects.filter(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
provider_id=provider_id,
|
||||||
|
task__task_runner_task__task_id=task_id,
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if executed_scan:
|
||||||
|
# Duplicated task execution due to visibility timeout
|
||||||
|
logger.warning(f"Duplicated scheduled scan for provider {provider_id}.")
|
||||||
|
return ScanTaskSerializer(instance=executed_scan).data
|
||||||
|
|
||||||
|
interval = periodic_task_instance.interval
|
||||||
next_scan_datetime = get_next_execution_datetime(task_id, provider_id)
|
next_scan_datetime = get_next_execution_datetime(task_id, provider_id)
|
||||||
|
current_scan_datetime = next_scan_datetime - timedelta(
|
||||||
|
**{interval.period: interval.every}
|
||||||
|
)
|
||||||
|
|
||||||
# TEMPORARY WORKAROUND: Clean up orphan scans from transaction isolation issue
|
# TEMPORARY WORKAROUND: Clean up orphan scans from transaction isolation issue
|
||||||
_cleanup_orphan_scheduled_scans(
|
_cleanup_orphan_scheduled_scans(
|
||||||
@@ -321,19 +318,12 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
|
|||||||
scheduler_task_id=periodic_task_instance.id,
|
scheduler_task_id=periodic_task_instance.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
scan_instance, _ = Scan.objects.get_or_create(
|
scan_instance = _get_or_create_scheduled_scan(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
trigger=Scan.TriggerChoices.SCHEDULED,
|
|
||||||
state__in=(StateChoices.SCHEDULED, StateChoices.AVAILABLE),
|
|
||||||
scheduler_task_id=periodic_task_instance.id,
|
scheduler_task_id=periodic_task_instance.id,
|
||||||
defaults={
|
scheduled_at=current_scan_datetime,
|
||||||
"state": StateChoices.SCHEDULED,
|
|
||||||
"name": "Daily scheduled scan",
|
|
||||||
"scheduled_at": next_scan_datetime - timedelta(days=1),
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
scan_instance.task_id = task_id
|
scan_instance.task_id = task_id
|
||||||
scan_instance.save()
|
scan_instance.save()
|
||||||
|
|
||||||
@@ -343,18 +333,19 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
|
|||||||
scan_id=str(scan_instance.id),
|
scan_id=str(scan_instance.id),
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
|
||||||
raise e
|
|
||||||
finally:
|
finally:
|
||||||
with rls_transaction(tenant_id):
|
with rls_transaction(tenant_id):
|
||||||
Scan.objects.get_or_create(
|
now = datetime.now(timezone.utc)
|
||||||
|
if next_scan_datetime <= now:
|
||||||
|
interval_delta = timedelta(**{interval.period: interval.every})
|
||||||
|
while next_scan_datetime <= now:
|
||||||
|
next_scan_datetime += interval_delta
|
||||||
|
_get_or_create_scheduled_scan(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
name="Daily scheduled scan",
|
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
trigger=Scan.TriggerChoices.SCHEDULED,
|
|
||||||
state=StateChoices.SCHEDULED,
|
|
||||||
scheduled_at=next_scan_datetime,
|
|
||||||
scheduler_task_id=periodic_task_instance.id,
|
scheduler_task_id=periodic_task_instance.id,
|
||||||
|
scheduled_at=next_scan_datetime,
|
||||||
|
update_state=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
_perform_scan_complete_tasks(tenant_id, str(scan_instance.id), provider_id)
|
_perform_scan_complete_tasks(tenant_id, str(scan_instance.id), provider_id)
|
||||||
|
|||||||
@@ -1,21 +1,13 @@
|
|||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from datetime import datetime, timezone
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from botocore.exceptions import ClientError
|
from botocore.exceptions import ClientError
|
||||||
from django_celery_beat.models import IntervalSchedule, PeriodicTask
|
from django_celery_beat.models import IntervalSchedule, PeriodicTask
|
||||||
|
from django_celery_results.models import TaskResult
|
||||||
from api.models import (
|
|
||||||
Integration,
|
|
||||||
LighthouseProviderConfiguration,
|
|
||||||
LighthouseProviderModels,
|
|
||||||
Scan,
|
|
||||||
StateChoices,
|
|
||||||
)
|
|
||||||
from tasks.jobs.lighthouse_providers import (
|
from tasks.jobs.lighthouse_providers import (
|
||||||
_create_bedrock_client,
|
_create_bedrock_client,
|
||||||
_extract_bedrock_credentials,
|
_extract_bedrock_credentials,
|
||||||
@@ -27,11 +19,21 @@ from tasks.tasks import (
|
|||||||
check_lighthouse_provider_connection_task,
|
check_lighthouse_provider_connection_task,
|
||||||
generate_outputs_task,
|
generate_outputs_task,
|
||||||
perform_attack_paths_scan_task,
|
perform_attack_paths_scan_task,
|
||||||
|
perform_scheduled_scan_task,
|
||||||
refresh_lighthouse_provider_models_task,
|
refresh_lighthouse_provider_models_task,
|
||||||
s3_integration_task,
|
s3_integration_task,
|
||||||
security_hub_integration_task,
|
security_hub_integration_task,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from api.models import (
|
||||||
|
Integration,
|
||||||
|
LighthouseProviderConfiguration,
|
||||||
|
LighthouseProviderModels,
|
||||||
|
Scan,
|
||||||
|
StateChoices,
|
||||||
|
Task,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
class TestExtractBedrockCredentials:
|
class TestExtractBedrockCredentials:
|
||||||
@@ -2137,3 +2139,215 @@ class TestCleanupOrphanScheduledScans:
|
|||||||
assert not Scan.objects.filter(id=orphan_scan.id).exists()
|
assert not Scan.objects.filter(id=orphan_scan.id).exists()
|
||||||
assert Scan.objects.filter(id=scheduled_scan.id).exists()
|
assert Scan.objects.filter(id=scheduled_scan.id).exists()
|
||||||
assert Scan.objects.filter(id=available_scan_other_task.id).exists()
|
assert Scan.objects.filter(id=available_scan_other_task.id).exists()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
class TestPerformScheduledScanTask:
    """Behavioural tests for the ``perform_scheduled_scan_task`` Celery task."""

    @staticmethod
    @contextmanager
    def _override_task_request(task, **attrs):
        """Temporarily set attributes on ``task.request``.

        On exit, attributes that were absent beforehand are removed again and
        the remaining ones get their earlier value back.
        """
        req = task.request
        missing = object()
        # Snapshot every attribute first, then apply the overrides.
        saved = {name: getattr(req, name, missing) for name in attrs}
        for name in attrs:
            setattr(req, name, attrs[name])

        try:
            yield
        finally:
            for name, old_value in saved.items():
                if old_value is not missing:
                    setattr(req, name, old_value)
                elif hasattr(req, name):
                    delattr(req, name)

    @staticmethod
    def _scheduled_trigger_scans(tenant, provider, **state_lookup):
        """Return the provider's SCHEDULED-trigger scans narrowed by ``state_lookup``."""
        return Scan.objects.filter(
            tenant_id=tenant.id,
            provider=provider,
            trigger=Scan.TriggerChoices.SCHEDULED,
            **state_lookup,
        )

    @staticmethod
    def _completing_scan_stub(**state_lookup):
        """Build a stand-in for ``perform_prowler_scan``.

        The stub asserts that no *other* SCHEDULED-trigger scan matching
        ``state_lookup`` exists for the provider, marks the running scan as
        COMPLETED and returns a small result dict.
        """

        def _complete_scan(tenant_id, scan_id, provider_id):
            competing = Scan.objects.filter(
                tenant_id=tenant_id,
                provider_id=provider_id,
                trigger=Scan.TriggerChoices.SCHEDULED,
                **state_lookup,
            ).exclude(id=scan_id)
            assert not competing.exists()
            current = Scan.objects.get(id=scan_id)
            current.state = StateChoices.COMPLETED
            current.save()
            return {"status": "ok"}

        return _complete_scan

    def _create_periodic_task(self, provider_id, tenant_id, interval_hours=24):
        """Create an enabled hourly-interval PeriodicTask named after the provider."""
        schedule = IntervalSchedule.objects.get_or_create(
            every=interval_hours, period="hours"
        )[0]
        task_kwargs = f'{{"tenant_id": "{tenant_id}", "provider_id": "{provider_id}"}}'
        return PeriodicTask.objects.create(
            name=f"scan-perform-scheduled-{provider_id}",
            task="scan-perform-scheduled",
            interval=schedule,
            kwargs=task_kwargs,
            enabled=True,
        )

    def _create_task_result(self, tenant_id, task_id):
        """Create a STARTED TaskResult and the Task row that references it."""
        runner_result = TaskResult.objects.create(
            task_id=task_id,
            task_name="scan-perform-scheduled",
            status="STARTED",
            date_created=datetime.now(timezone.utc),
        )
        Task.objects.create(
            id=task_id, task_runner_task=runner_result, tenant_id=tenant_id
        )
        return runner_result

    def test_skip_when_scheduled_scan_executing(
        self, tenants_fixture, providers_fixture
    ):
        """The task returns the executing scan and starts nothing new."""
        tenant, provider = tenants_fixture[0], providers_fixture[0]
        periodic_task = self._create_periodic_task(provider.id, tenant.id)
        task_id = str(uuid.uuid4())
        self._create_task_result(tenant.id, task_id)

        running_scan = Scan.objects.create(
            tenant_id=tenant.id,
            provider=provider,
            name="Daily scheduled scan",
            trigger=Scan.TriggerChoices.SCHEDULED,
            state=StateChoices.EXECUTING,
            scheduler_task_id=periodic_task.id,
        )

        with (
            patch("tasks.tasks.perform_prowler_scan") as mock_scan,
            patch("tasks.tasks._perform_scan_complete_tasks") as mock_complete_tasks,
            self._override_task_request(perform_scheduled_scan_task, id=task_id),
        ):
            result = perform_scheduled_scan_task.run(
                tenant_id=str(tenant.id), provider_id=str(provider.id)
            )

        mock_scan.assert_not_called()
        mock_complete_tasks.assert_not_called()
        assert result["id"] == str(running_scan.id)
        assert result["state"] == StateChoices.EXECUTING
        # No follow-up scan may have been queued by the skipped run.
        queued = self._scheduled_trigger_scans(
            tenant, provider, state=StateChoices.SCHEDULED
        )
        assert queued.count() == 0

    def test_creates_next_scheduled_scan_after_completion(
        self, tenants_fixture, providers_fixture
    ):
        """A completed run leaves exactly one future SCHEDULED scan behind."""
        tenant, provider = tenants_fixture[0], providers_fixture[0]
        self._create_periodic_task(provider.id, tenant.id)
        task_id = str(uuid.uuid4())
        self._create_task_result(tenant.id, task_id)

        complete_scan = self._completing_scan_stub(state=StateChoices.SCHEDULED)

        with (
            patch("tasks.tasks.perform_prowler_scan", side_effect=complete_scan),
            patch("tasks.tasks._perform_scan_complete_tasks"),
            self._override_task_request(perform_scheduled_scan_task, id=task_id),
        ):
            perform_scheduled_scan_task.run(
                tenant_id=str(tenant.id), provider_id=str(provider.id)
            )

        upcoming = self._scheduled_trigger_scans(
            tenant, provider, state=StateChoices.SCHEDULED
        )
        assert upcoming.count() == 1
        assert upcoming.first().scheduled_at > datetime.now(timezone.utc)

        pending = self._scheduled_trigger_scans(
            tenant,
            provider,
            state__in=(StateChoices.SCHEDULED, StateChoices.AVAILABLE),
        )
        assert pending.count() == 1

        finished = self._scheduled_trigger_scans(
            tenant, provider, state=StateChoices.COMPLETED
        )
        assert finished.count() == 1

    def test_dedupes_multiple_scheduled_scans_before_run(
        self, tenants_fixture, providers_fixture
    ):
        """A duplicated AVAILABLE scan is deleted; the SCHEDULED one is kept."""
        tenant, provider = tenants_fixture[0], providers_fixture[0]
        periodic_task = self._create_periodic_task(provider.id, tenant.id)
        task_id = str(uuid.uuid4())
        self._create_task_result(tenant.id, task_id)

        scheduled_scan = Scan.objects.create(
            tenant_id=tenant.id,
            provider=provider,
            name="Daily scheduled scan",
            trigger=Scan.TriggerChoices.SCHEDULED,
            state=StateChoices.SCHEDULED,
            scheduled_at=datetime.now(timezone.utc),
            scheduler_task_id=periodic_task.id,
        )
        # Same slot and scheduler, different state: the duplicate under test.
        duplicate_scan = Scan.objects.create(
            tenant_id=tenant.id,
            provider=provider,
            name="Daily scheduled scan",
            trigger=Scan.TriggerChoices.SCHEDULED,
            state=StateChoices.AVAILABLE,
            scheduled_at=scheduled_scan.scheduled_at,
            scheduler_task_id=periodic_task.id,
        )

        complete_scan = self._completing_scan_stub(
            state__in=(StateChoices.SCHEDULED, StateChoices.AVAILABLE)
        )

        with (
            patch("tasks.tasks.perform_prowler_scan", side_effect=complete_scan),
            patch("tasks.tasks._perform_scan_complete_tasks"),
            self._override_task_request(perform_scheduled_scan_task, id=task_id),
        ):
            perform_scheduled_scan_task.run(
                tenant_id=str(tenant.id), provider_id=str(provider.id)
            )

        assert not Scan.objects.filter(id=duplicate_scan.id).exists()
        assert Scan.objects.filter(id=scheduled_scan.id).exists()
        pending = self._scheduled_trigger_scans(
            tenant,
            provider,
            state__in=(StateChoices.SCHEDULED, StateChoices.AVAILABLE),
        )
        assert pending.count() == 1
|
||||||
|
|||||||
@@ -5,6 +5,10 @@ from enum import Enum
|
|||||||
from django_celery_beat.models import PeriodicTask
|
from django_celery_beat.models import PeriodicTask
|
||||||
from django_celery_results.models import TaskResult
|
from django_celery_results.models import TaskResult
|
||||||
|
|
||||||
|
from api.models import Scan, StateChoices
|
||||||
|
|
||||||
|
SCHEDULED_SCAN_NAME = "Daily scheduled scan"
|
||||||
|
|
||||||
|
|
||||||
class CustomEncoder(json.JSONEncoder):
|
class CustomEncoder(json.JSONEncoder):
|
||||||
def default(self, o):
|
def default(self, o):
|
||||||
@@ -71,3 +75,58 @@ def batched(iterable, batch_size):
|
|||||||
batch = []
|
batch = []
|
||||||
|
|
||||||
yield batch, True
|
yield batch, True
|
||||||
|
|
||||||
|
|
||||||
|
def _get_or_create_scheduled_scan(
    tenant_id: str,
    provider_id: str,
    scheduler_task_id: int,
    scheduled_at: datetime,
    update_state: bool = False,
) -> Scan:
    """
    Return the single pending scheduled scan for a provider, creating it if needed.

    Pending scans are SCHEDULED-trigger scans in state SCHEDULED or AVAILABLE
    that belong to the given scheduler task. They are ordered by
    ``scheduled_at`` and then ``inserted_at``; the first one is kept and every
    other match is deleted.

    Args:
        tenant_id: The tenant ID.
        provider_id: The provider ID.
        scheduler_task_id: The PeriodicTask ID.
        scheduled_at: The datetime the kept (or new) scan must be scheduled at.
        update_state: If True and the kept scan is not in state SCHEDULED,
            reset its state to SCHEDULED and its name to the default
            scheduled-scan name.

    Returns:
        The kept scan (saved only when something changed), or a newly created
        SCHEDULED scan when no pending scan exists.
    """
    candidates = list(
        Scan.objects.filter(
            tenant_id=tenant_id,
            provider_id=provider_id,
            trigger=Scan.TriggerChoices.SCHEDULED,
            state__in=(StateChoices.SCHEDULED, StateChoices.AVAILABLE),
            scheduler_task_id=scheduler_task_id,
        ).order_by("scheduled_at", "inserted_at")
    )

    # Nothing pending: create a fresh scheduled scan.
    if not candidates:
        return Scan.objects.create(
            tenant_id=tenant_id,
            name=SCHEDULED_SCAN_NAME,
            provider_id=provider_id,
            trigger=Scan.TriggerChoices.SCHEDULED,
            state=StateChoices.SCHEDULED,
            scheduled_at=scheduled_at,
            scheduler_task_id=scheduler_task_id,
        )

    kept, *duplicates = candidates
    if duplicates:
        # Remove all surplus rows with a single query.
        Scan.objects.filter(id__in=[dup.id for dup in duplicates]).delete()

    dirty = kept.scheduled_at != scheduled_at
    if update_state and kept.state != StateChoices.SCHEDULED:
        kept.state = StateChoices.SCHEDULED
        kept.name = SCHEDULED_SCAN_NAME
        dirty = True

    # Only write to the database when a field actually changed.
    if dirty:
        kept.scheduled_at = scheduled_at
        kept.save()

    return kept
|
||||||
|
|||||||
Reference in New Issue
Block a user