From cf9beb82342887e46364bb524fba6a9f03272d9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Pe=C3=B1a?= Date: Tue, 2 Jun 2026 14:00:17 +0200 Subject: [PATCH] feat(api): recover orphaned background tasks and make task re-runs idempotent (#11416) --- api/CHANGELOG.md | 18 + api/docker-entrypoint.sh | 8 +- api/docs/orphan-task-recovery.md | 86 ++++ .../commands/reconcile_orphan_tasks.py | 49 +++ .../migrations/0094_scan_recovery_count.py | 17 + ...95_reconcile_orphan_tasks_periodic_task.py | 49 +++ .../api/migrations/0096_jiraissuedispatch.py | 64 +++ api/src/backend/api/models.py | 32 ++ api/src/backend/config/celery.py | 55 +++ .../tasks/jobs/attack_paths/cleanup.py | 30 +- api/src/backend/tasks/jobs/deletion.py | 9 + api/src/backend/tasks/jobs/integrations.py | 151 ++++--- api/src/backend/tasks/jobs/orphan_recovery.py | 397 ++++++++++++++++++ api/src/backend/tasks/jobs/scan.py | 28 +- api/src/backend/tasks/tasks.py | 7 + api/src/backend/tasks/tests/test_deletion.py | 40 +- .../backend/tasks/tests/test_integrations.py | 179 +++++++- .../tasks/tests/test_orphan_recovery.py | 372 ++++++++++++++++ api/src/backend/tasks/tests/test_scan.py | 184 ++++++++ api/src/backend/tasks/tests/test_tasks.py | 33 ++ docker-compose-dev.yml | 2 + docker-compose.yml | 2 + 22 files changed, 1722 insertions(+), 90 deletions(-) create mode 100644 api/docs/orphan-task-recovery.md create mode 100644 api/src/backend/api/management/commands/reconcile_orphan_tasks.py create mode 100644 api/src/backend/api/migrations/0094_scan_recovery_count.py create mode 100644 api/src/backend/api/migrations/0095_reconcile_orphan_tasks_periodic_task.py create mode 100644 api/src/backend/api/migrations/0096_jiraissuedispatch.py create mode 100644 api/src/backend/tasks/jobs/orphan_recovery.py create mode 100644 api/src/backend/tasks/tests/test_orphan_recovery.py diff --git a/api/CHANGELOG.md b/api/CHANGELOG.md index 8f400fd012..bf95d3c569 100644 --- a/api/CHANGELOG.md +++ b/api/CHANGELOG.md @@ -2,6 
+2,24 @@ All notable changes to the **Prowler API** are documented in this file. +## [1.31.0] (Prowler v5.30.0) + +### 🚀 Added + +- Automatic recovery of allowlisted idempotent background tasks whose worker died during a deploy or crash: stuck scan and summary tasks are detected and re-run instead of staying pending forever, with a `reconcile_orphan_tasks` management command for on-demand recovery [(#11416)](https://github.com/prowler-cloud/prowler/pull/11416) +- Jira integration no longer creates duplicate issues on a retried send; findings already ticketed are skipped [(#11416)](https://github.com/prowler-cloud/prowler/pull/11416) + +### 🔄 Changed + +- Allowlisted idempotent background tasks are no longer lost when a worker is stopped or crashes mid-task; tasks with external side effects are marked terminal instead of blindly re-running [(#11416)](https://github.com/prowler-cloud/prowler/pull/11416) +- A recovered scan rewrites its findings, summaries, attack surface, and compliance data instead of appending to the previous run, so recovery never leaves stale or duplicate materialized rows [(#11416)](https://github.com/prowler-cloud/prowler/pull/11416) + +### 🐞 Fixed + +- Workers now shut down gracefully on deploy or restart, finishing or re-queueing in-flight tasks instead of being force-killed and leaving them stuck [(#11416)](https://github.com/prowler-cloud/prowler/pull/11416) + +--- + ## [1.30.1] (Prowler v5.29.1) ### 🐞 Fixed diff --git a/api/docker-entrypoint.sh b/api/docker-entrypoint.sh index 4535fb34e2..e6313f459a 100755 --- a/api/docker-entrypoint.sh +++ b/api/docker-entrypoint.sh @@ -22,12 +22,12 @@ apply_fixtures() { start_dev_server() { echo "Starting the development server..." - uv run python manage.py runserver 0.0.0.0:"${DJANGO_PORT:-8080}" + exec uv run python manage.py runserver 0.0.0.0:"${DJANGO_PORT:-8080}" } start_prod_server() { echo "Starting the Gunicorn server..." 
- uv run gunicorn -c config/guniconf.py config.wsgi:application + exec uv run gunicorn -c config/guniconf.py config.wsgi:application } resolve_worker_hostname() { @@ -47,7 +47,7 @@ resolve_worker_hostname() { start_worker() { echo "Starting the worker..." - uv run python -m celery -A config.celery worker \ + exec uv run python -m celery -A config.celery worker \ -n "$(resolve_worker_hostname)" \ -l "${DJANGO_LOGGING_LEVEL:-info}" \ -Q celery,scans,scan-reports,deletion,backfill,overview,integrations,compliance,attack-paths-scans \ @@ -56,7 +56,7 @@ start_worker() { start_worker_beat() { echo "Starting the worker-beat..." - uv run python -m celery -A config.celery beat -l "${DJANGO_LOGGING_LEVEL:-info}" --scheduler django_celery_beat.schedulers:DatabaseScheduler + exec uv run python -m celery -A config.celery beat -l "${DJANGO_LOGGING_LEVEL:-info}" --scheduler django_celery_beat.schedulers:DatabaseScheduler } manage_db_partitions() { diff --git a/api/docs/orphan-task-recovery.md b/api/docs/orphan-task-recovery.md new file mode 100644 index 0000000000..38b1546bae --- /dev/null +++ b/api/docs/orphan-task-recovery.md @@ -0,0 +1,86 @@ +# Orphan Celery task recovery + +When a worker is terminated mid-task (a deploy, an OOM kill, a node eviction), the +task it was running can be left non-terminal forever: the `Scan` stays `EXECUTING`, +the `TaskResult` stays `STARTED`, and nothing re-runs it. This page describes the +mechanisms that detect and recover allowlisted idempotent orphans so users never +see a stuck scan and pending-task alerts do not fire. + +## How recovery works + +1. **Durable delivery.** The broker is configured so a task message is acknowledged + only after the task finishes (`task_acks_late`), one task is reserved at a time + (`worker_prefetch_multiplier = 1`), and an abruptly-lost worker re-queues its task + (`task_reject_on_worker_lost`). 
On `SIGTERM` the worker is given a soft-shutdown + window (`worker_soft_shutdown_timeout`) to finish or re-queue in-flight work + before it is force-killed. + +2. **Periodic watchdog.** A Beat task, `reconcile-orphan-tasks`, runs every couple of + minutes (a `django_celery_beat` periodic task created by migration). For each + in-flight task result with an allowlisted idempotent task name, it pings the + worker recorded on the task's `TaskResult`: + - worker responds -> the task is still running, leave it alone; + - worker is gone (and the scan started before a short grace window) -> it is a + real orphan: the stale task is revoked and marked terminal (clearing the + pending/started alert), and the scan is re-enqueued from scratch. + + The re-run is safe because only tasks with proven idempotency are allowlisted. + Scan persistence, for example, clears the scan's prior findings and materialized + summary/compliance rows before re-writing them. Jira sends are allowlisted too: + each finding is reserved in a dispatch table before the external call, so a re-run + skips already-ticketed findings (the worst case is one finding missed if a worker + is hard-killed mid-send, never a duplicate issue). Other external side effects stay + terminal: the S3 upload rebuilds from worker-local files that do not survive a + crash, and report/Security Hub recovery is out of scope. + +3. **Recovery cap.** Each automatic re-enqueue increments `Scan.recovery_count`. + After `--max-attempts` recoveries (default 3) the scan is marked `FAILED` instead + of re-enqueued, so a task that repeatedly kills its worker cannot loop forever. + +A Postgres advisory lock ensures that, even with multiple API/worker replicas, only +one reconciliation runs at a time; the others no-op. 
+ +## On-demand command + +The same logic is available as a management command, useful right after a deploy or +for manual intervention: + +```bash +python manage.py reconcile_orphan_tasks # recover now +python manage.py reconcile_orphan_tasks --dry-run # report orphans, change nothing +python manage.py reconcile_orphan_tasks --grace-minutes 5 --max-attempts 3 +``` + +## Configuration + +All settings have safe defaults; override via environment variables. + +| Env var | Default | Purpose | +| --- | --- | --- | +| `DJANGO_CELERY_WORKER_PREFETCH_MULTIPLIER` | `1` | Tasks reserved per worker process. | +| `DJANGO_CELERY_WORKER_SOFT_SHUTDOWN_TIMEOUT` | `60` | Seconds the worker drains/re-queues on `SIGTERM` before force-kill. | +| `DJANGO_CELERY_TASK_TIME_LIMIT` | `21600` (6h) | Hard limit for most tasks; connection checks are capped at 120s. | +| `DJANGO_CELERY_TASK_SOFT_TIME_LIMIT` | hard - 600 | Soft limit; raises `SoftTimeLimitExceeded` for cleanup. | +| `DJANGO_CELERY_LONG_TASK_TIME_LIMIT` | `172800` (48h) | Hard limit for scans and provider/tenant deletions, which can legitimately run for more than a day. | +| `DJANGO_CELERY_LONG_TASK_SOFT_TIME_LIMIT` | long hard - 600 | Soft limit for the long-running tasks above. | + +`task_acks_late` and `task_reject_on_worker_lost` are enabled in `config/celery.py`. + +## Deployment requirement + +Two conditions must both hold for the soft shutdown to actually drain work: + +1. **The worker must receive `SIGTERM`.** The container entrypoint `exec`s the + Celery process so it runs as PID 1; otherwise `SIGTERM` from `docker stop`/ECS + hits the entrypoint shell, never reaches Celery, and the worker is hard-killed + (SIGKILL) at the grace deadline without draining. Custom entrypoints must + preserve the `exec`. +2. **The orchestrator must give the worker enough time** before force-killing it. 
class Command(BaseCommand):
    """Run one orphan-task reconciliation pass on demand.

    Thin CLI wrapper: it forwards its options to ``reconcile_orphans`` and prints
    a one-line summary of the returned dict.
    """

    help = (
        "Recover orphaned allowlisted Celery tasks whose worker is gone and mark "
        "other stale task results terminal. Single-flight via a Postgres advisory lock."
    )

    def add_arguments(self, parser):
        """Register the two integer tuning flags and the dry-run switch."""
        integer_flags = (
            (
                "--grace-minutes",
                2,
                "Skip tasks started within this window (worker may still register).",
            ),
            (
                "--max-attempts",
                3,
                "Give up re-running a task after this many recovery attempts (scans are marked FAILED).",
            ),
        )
        for flag, fallback, description in integer_flags:
            parser.add_argument(flag, type=int, default=fallback, help=description)

        parser.add_argument(
            "--dry-run",
            action="store_true",
            help="Detect and report orphans without revoking or re-enqueuing.",
        )

    def handle(self, *args, **options):
        """Invoke the reconciliation and report either the counts or the skip."""
        summary = reconcile_orphans(
            grace_minutes=options["grace_minutes"],
            max_attempts=options["max_attempts"],
            dry_run=options["dry_run"],
        )

        if summary.get("acquired"):
            # Each bucket is a list in the summary; only its size is reported.
            recovered = len(summary.get("recovered", []))
            failed = len(summary.get("failed", []))
            in_flight = len(summary.get("skipped", []))
            message = (
                "Orphan reconcile complete: "
                f"recovered={recovered} "
                f"failed={failed} "
                f"skipped(in-flight)={in_flight}"
            )
            self.stdout.write(self.style.SUCCESS(message))
            return

        # A falsy/missing "acquired" means the sweep did not run in this process.
        self.stdout.write("Reconcile skipped: another run holds the lock.")
def create_periodic_task(apps, schema_editor):
    """Create or refresh the Beat entry for the orphan-task watchdog.

    Uses the historical ``django_celery_beat`` models from ``apps``. The interval
    row is looked up or created first, then the periodic task named ``TASK_NAME``
    is upserted so it points at that interval and is enabled.
    """
    interval_model = apps.get_model("django_celery_beat", "IntervalSchedule")
    periodic_task_model = apps.get_model("django_celery_beat", "PeriodicTask")

    # get_or_create returns (instance, created); only the instance is needed.
    schedule = interval_model.objects.get_or_create(
        every=INTERVAL_MINUTES,
        period="minutes",
    )[0]

    task_defaults = {
        "task": TASK_NAME,
        "interval": schedule,
        "enabled": True,
    }
    periodic_task_model.objects.update_or_create(
        name=TASK_NAME,
        defaults=task_defaults,
    )
("finding_id", models.UUIDField()), + ( + "integration", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="jira_dispatches", + to="api.integration", + ), + ), + ( + "tenant", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, to="api.tenant" + ), + ), + ], + options={ + "db_table": "jira_issue_dispatches", + "abstract": False, + }, + ), + migrations.AddConstraint( + model_name="jiraissuedispatch", + constraint=models.UniqueConstraint( + fields=("tenant_id", "integration_id", "finding_id"), + name="unique_jira_issue_dispatch", + ), + ), + migrations.AddConstraint( + model_name="jiraissuedispatch", + constraint=api.rls.RowLevelSecurityConstraint( + "tenant_id", + name="rls_on_jiraissuedispatch", + statements=["SELECT", "INSERT", "UPDATE", "DELETE"], + ), + ), + ] diff --git a/api/src/backend/api/models.py b/api/src/backend/api/models.py index 3d9a26698e..8adbc2cf9e 100644 --- a/api/src/backend/api/models.py +++ b/api/src/backend/api/models.py @@ -666,6 +666,9 @@ class Scan(RowLevelSecurityProtectedModel): state = StateEnumField(choices=StateChoices.choices, default=StateChoices.AVAILABLE) unique_resource_count = models.IntegerField(default=0) progress = models.IntegerField(default=0) + # Incremented by the scan-specific orphan-recovery path each time this scan is + # re-pointed to a fresh task; for observability (the retry cap is a Valkey counter). + recovery_count = models.IntegerField(default=0) scanner_args = models.JSONField(default=dict) duration = models.IntegerField(null=True, blank=True) scheduled_at = models.DateTimeField(null=True, blank=True) @@ -1998,6 +2001,35 @@ class IntegrationProviderRelationship(RowLevelSecurityProtectedModel): ] +class JiraIssueDispatch(RowLevelSecurityProtectedModel): + """Tracks findings already sent to a Jira integration. + + Lets the Jira task be re-run safely (e.g. by orphan recovery): findings with + an existing dispatch row are skipped, so no duplicate issues are created. 
+ """ + + id = models.UUIDField(primary_key=True, default=uuid4, editable=False) + inserted_at = models.DateTimeField(auto_now_add=True, editable=False) + integration = models.ForeignKey( + Integration, on_delete=models.CASCADE, related_name="jira_dispatches" + ) + finding_id = models.UUIDField() + + class Meta(RowLevelSecurityProtectedModel.Meta): + db_table = "jira_issue_dispatches" + constraints = [ + models.UniqueConstraint( + fields=["tenant_id", "integration_id", "finding_id"], + name="unique_jira_issue_dispatch", + ), + RowLevelSecurityConstraint( + field="tenant_id", + name="rls_on_%(class)s", + statements=["SELECT", "INSERT", "UPDATE", "DELETE"], + ), + ] + + class SAMLToken(models.Model): id = models.UUIDField(primary_key=True, default=uuid4, editable=False) inserted_at = models.DateTimeField(auto_now_add=True, editable=False) diff --git a/api/src/backend/config/celery.py b/api/src/backend/config/celery.py index c46c8a426c..5d246395a5 100644 --- a/api/src/backend/config/celery.py +++ b/api/src/backend/config/celery.py @@ -26,6 +26,61 @@ celery_app.conf.result_backend_transport_options = { } celery_app.conf.visibility_timeout = BROKER_VISIBILITY_TIMEOUT +# Durable delivery: keep the message until the task finishes, so a worker killed +# mid-task (deploy/OOM/eviction) does not silently drop it. Reserve one task at a +# time so a crash exposes at most one extra reserved message. +celery_app.conf.task_acks_late = True +celery_app.conf.task_reject_on_worker_lost = True +celery_app.conf.worker_prefetch_multiplier = env.int( + "DJANGO_CELERY_WORKER_PREFETCH_MULTIPLIER", default=1 +) +# On SIGTERM, give the worker time to finish or re-queue in-flight tasks before +# it is forcefully killed (Celery 5.5+ soft shutdown). +celery_app.conf.worker_soft_shutdown_timeout = env.int( + "DJANGO_CELERY_WORKER_SOFT_SHUTDOWN_TIMEOUT", default=60 +) +# Bound execution so a blocked task cannot pin a worker forever. 
Connection +# checks get a tight limit; scans and provider/tenant deletions can legitimately +# run for more than a day on large tenants, so they get a much higher cap. +# The default for every other task is set as the global limit, not as a "*" +# annotation: Celery applies the "*" entry AFTER the per-task one, so a "*" in +# task_annotations would silently overwrite every specific limit defined below. +_TASK_HARD_LIMIT = env.int("DJANGO_CELERY_TASK_TIME_LIMIT", default=6 * 60 * 60) +_TASK_SOFT_LIMIT = env.int( + "DJANGO_CELERY_TASK_SOFT_TIME_LIMIT", default=_TASK_HARD_LIMIT - 600 +) +_LONG_TASK_HARD_LIMIT = env.int( + "DJANGO_CELERY_LONG_TASK_TIME_LIMIT", default=48 * 60 * 60 +) +_LONG_TASK_SOFT_LIMIT = env.int( + "DJANGO_CELERY_LONG_TASK_SOFT_TIME_LIMIT", default=_LONG_TASK_HARD_LIMIT - 600 +) +celery_app.conf.task_time_limit = _TASK_HARD_LIMIT +celery_app.conf.task_soft_time_limit = _TASK_SOFT_LIMIT +celery_app.conf.task_annotations = { + **{ + name: {"soft_time_limit": 60, "time_limit": 120} + for name in ( + "provider-connection-check", + "integration-connection-check", + "lighthouse-connection-check", + "lighthouse-provider-connection-check", + ) + }, + **{ + name: { + "soft_time_limit": _LONG_TASK_SOFT_LIMIT, + "time_limit": _LONG_TASK_HARD_LIMIT, + } + for name in ( + "scan-perform", + "scan-perform-scheduled", + "provider-deletion", + "tenant-deletion", + ) + }, +} + celery_app.autodiscover_tasks(["api"]) diff --git a/api/src/backend/tasks/jobs/attack_paths/cleanup.py b/api/src/backend/tasks/jobs/attack_paths/cleanup.py index 65ba583a3e..fa7670afaa 100644 --- a/api/src/backend/tasks/jobs/attack_paths/cleanup.py +++ b/api/src/backend/tasks/jobs/attack_paths/cleanup.py @@ -1,12 +1,14 @@ from datetime import datetime, timedelta, timezone -from celery import current_app, states +from celery import states from celery.utils.log import get_task_logger from config.django.base import ATTACK_PATHS_SCAN_STALE_THRESHOLD_MINUTES from tasks.jobs.attack_paths.db_utils 
import ( _mark_scan_finished, recover_graph_data_ready, ) +from tasks.jobs.orphan_recovery import is_worker_alive as _is_worker_alive +from tasks.jobs.orphan_recovery import revoke_task as _revoke_task from api.attack_paths import database as graph_database from api.db_router import MainRouter @@ -150,32 +152,6 @@ def _cleanup_stale_scheduled_scans(cutoff: datetime) -> list[str]: return cleaned_up -def _is_worker_alive(worker: str) -> bool: - """Ping a specific Celery worker. Returns `True` if it responds or on error.""" - try: - response = current_app.control.inspect(destination=[worker], timeout=1.0).ping() - return response is not None and worker in response - except Exception: - logger.exception(f"Failed to ping worker {worker}, treating as alive") - return True - - -def _revoke_task(task_result, terminate: bool = True) -> None: - """Revoke a Celery task. Non-fatal on failure. - - `terminate=True` SIGTERMs the worker if the task is mid-execution; use - for EXECUTING cleanup. `terminate=False` only marks the task id revoked - across workers, so any worker pulling the queued message discards it; - use for SCHEDULED cleanup where the task hasn't run yet. 
- """ - try: - kwargs = {"terminate": True, "signal": "SIGTERM"} if terminate else {} - current_app.control.revoke(task_result.task_id, **kwargs) - logger.info(f"Revoked task {task_result.task_id}") - except Exception: - logger.exception(f"Failed to revoke task {task_result.task_id}") - - def _cleanup_scan(scan, task_result, reason: str) -> bool: """ Clean up a single stale `AttackPathsScan`: diff --git a/api/src/backend/tasks/jobs/deletion.py b/api/src/backend/tasks/jobs/deletion.py index f9ead01897..7540f72c7e 100644 --- a/api/src/backend/tasks/jobs/deletion.py +++ b/api/src/backend/tasks/jobs/deletion.py @@ -11,6 +11,7 @@ from api.db_utils import batch_delete, rls_transaction from api.models import ( AttackPathsScan, Finding, + JiraIssueDispatch, Provider, ProviderComplianceScore, Resource, @@ -80,6 +81,14 @@ def delete_provider(tenant_id: str, pk: str): deletion_steps = [ ("Scan Summaries", ScanSummary.all_objects.filter(scan__provider=instance)), + ( + "Jira Issue Dispatches", + JiraIssueDispatch.objects.filter( + finding_id__in=Finding.all_objects.filter( + scan__provider=instance + ).values_list("id", flat=True) + ), + ), ("Findings", Finding.all_objects.filter(scan__provider=instance)), ("Resources", Resource.all_objects.filter(provider=instance)), ("Scans", Scan.all_objects.filter(provider=instance)), diff --git a/api/src/backend/tasks/jobs/integrations.py b/api/src/backend/tasks/jobs/integrations.py index 5ca94057da..55c1205169 100644 --- a/api/src/backend/tasks/jobs/integrations.py +++ b/api/src/backend/tasks/jobs/integrations.py @@ -9,7 +9,7 @@ from tasks.utils import batched from api.db_router import READ_REPLICA_ALIAS, MainRouter from api.db_utils import REPLICA_MAX_ATTEMPTS, REPLICA_RETRY_BASE_DELAY, rls_transaction -from api.models import Finding, Integration, Provider +from api.models import Finding, Integration, JiraIssueDispatch, Provider from api.utils import initialize_prowler_integration, initialize_prowler_provider from 
prowler.lib.outputs.asff.asff import ASFF from prowler.lib.outputs.compliance.generic.generic import GenericCompliance @@ -482,66 +482,115 @@ def send_findings_to_jira( with rls_transaction(tenant_id): integration = Integration.objects.get(id=integration_id) jira_integration = initialize_prowler_integration(integration) + # Idempotency: findings already ticketed for this integration must not be + # sent again on a re-run (e.g. orphan recovery), to avoid duplicate issues + already_sent = { + str(fid) + for fid in JiraIssueDispatch.objects.filter( + integration_id=integration_id, finding_id__in=finding_ids + ).values_list("finding_id", flat=True) + } num_tickets_created = 0 + skipped_count = 0 for finding_id in finding_ids: + if str(finding_id) in already_sent: + skipped_count += 1 + continue + + # Reserve the finding BEFORE the external call. The unique constraint on + # (tenant, integration, finding) makes the dispatch row the single source of + # truth, so a concurrent run or a retry that raced past the bulk pre-check + # cannot create a duplicate issue: created=False means another run already + # claimed it. The reservation is released below if the send does not succeed. 
with rls_transaction(tenant_id): - finding_instance = ( - Finding.all_objects.select_related("scan__provider") - .prefetch_related("resources") - .get(id=finding_id) + _, created = JiraIssueDispatch.objects.get_or_create( + tenant_id=tenant_id, + integration_id=integration_id, + finding_id=finding_id, ) + if not created: + skipped_count += 1 + continue - # Extract resource information - resource = ( - finding_instance.resources.first() - if finding_instance.resources.exists() - else None - ) - resource_uid = resource.uid if resource else "" - resource_name = resource.name if resource else "" - resource_tags = {} - if resource and hasattr(resource, "tags"): - resource_tags = resource.get_tags(tenant_id) + sent = False + try: + with rls_transaction(tenant_id): + finding_instance = ( + Finding.all_objects.select_related("scan__provider") + .prefetch_related("resources") + .get(id=finding_id) + ) - # Get region - region = resource.region if resource and resource.region else "" + # Extract resource information + resource = ( + finding_instance.resources.first() + if finding_instance.resources.exists() + else None + ) + resource_uid = resource.uid if resource else "" + resource_name = resource.name if resource else "" + resource_tags = {} + if resource and hasattr(resource, "tags"): + resource_tags = resource.get_tags(tenant_id) - # Extract remediation information from check_metadata - check_metadata = finding_instance.check_metadata - remediation = check_metadata.get("remediation", {}) - recommendation = remediation.get("recommendation", {}) - remediation_code = remediation.get("code", {}) + # Get region + region = resource.region if resource and resource.region else "" - # Send the individual finding to Jira - result = jira_integration.send_finding( - check_id=finding_instance.check_id, - check_title=check_metadata.get("checktitle", ""), - severity=finding_instance.severity, - status=finding_instance.status, - status_extended=finding_instance.status_extended or "", - 
provider=finding_instance.scan.provider.provider, - region=region, - resource_uid=resource_uid, - resource_name=resource_name, - risk=check_metadata.get("risk", ""), - recommendation_text=recommendation.get("text", ""), - recommendation_url=recommendation.get("url", ""), - remediation_code_native_iac=remediation_code.get("nativeiac", ""), - remediation_code_terraform=remediation_code.get("terraform", ""), - remediation_code_cli=remediation_code.get("cli", ""), - remediation_code_other=remediation_code.get("other", ""), - resource_tags=resource_tags, - compliance=finding_instance.compliance or {}, - project_key=project_key, - issue_type=issue_type, - ) - if result: - num_tickets_created += 1 - else: - logger.error(f"Failed to send finding {finding_id} to Jira") + # Extract remediation information from check_metadata + check_metadata = finding_instance.check_metadata + remediation = check_metadata.get("remediation", {}) + recommendation = remediation.get("recommendation", {}) + remediation_code = remediation.get("code", {}) + + # Send the individual finding to Jira + sent = bool( + jira_integration.send_finding( + check_id=finding_instance.check_id, + check_title=check_metadata.get("checktitle", ""), + severity=finding_instance.severity, + status=finding_instance.status, + status_extended=finding_instance.status_extended or "", + provider=finding_instance.scan.provider.provider, + region=region, + resource_uid=resource_uid, + resource_name=resource_name, + risk=check_metadata.get("risk", ""), + recommendation_text=recommendation.get("text", ""), + recommendation_url=recommendation.get("url", ""), + remediation_code_native_iac=remediation_code.get( + "nativeiac", "" + ), + remediation_code_terraform=remediation_code.get( + "terraform", "" + ), + remediation_code_cli=remediation_code.get("cli", ""), + remediation_code_other=remediation_code.get("other", ""), + resource_tags=resource_tags, + compliance=finding_instance.compliance or {}, + project_key=project_key, + 
issue_type=issue_type, + ) + ) + finally: + if not sent: + # Release the reservation so a later run can retry this finding: it + # was not ticketed (send failed or raised), so the row must not block + # a future legitimate send. + with rls_transaction(tenant_id): + JiraIssueDispatch.objects.filter( + tenant_id=tenant_id, + integration_id=integration_id, + finding_id=finding_id, + ).delete() + + if sent: + num_tickets_created += 1 + else: + logger.error(f"Failed to send finding {finding_id} to Jira") return { "created_count": num_tickets_created, - "failed_count": len(finding_ids) - num_tickets_created, + "failed_count": len(finding_ids) - num_tickets_created - skipped_count, + "skipped_count": skipped_count, } diff --git a/api/src/backend/tasks/jobs/orphan_recovery.py b/api/src/backend/tasks/jobs/orphan_recovery.py new file mode 100644 index 0000000000..d884c3fc8b --- /dev/null +++ b/api/src/backend/tasks/jobs/orphan_recovery.py @@ -0,0 +1,397 @@ +"""Detect and recover orphaned Celery tasks. + +A task is "orphaned" when its result row is non-terminal (STARTED/RECEIVED) but the +worker that was running it is gone (deploy, OOM, eviction). We tell a real orphan +from a still-running task by pinging the worker recorded on its `TaskResult`: + +- worker responds -> the task is in flight, leave it alone (never double-run); +- worker is gone -> real orphan: mark the stale result terminal (so pending/started + alerts clear), then re-enqueue the task from its stored name + kwargs. + +This recovers only allowlisted tasks with local, proven idempotency. Celery's +`result_extended=True` gives us the stored `task_name`/`task_kwargs`/`worker` once +the task starts, but external side-effect tasks are failed instead of blindly +re-run. A small recovery cap stops a task that repeatedly kills its worker from +looping forever. + +This is the shared engine behind both the periodic Beat watchdog and the +`reconcile_orphan_tasks` management command. 
+""" + +import ast +import json +from contextlib import contextmanager +from datetime import datetime, timedelta, timezone +from uuid import uuid4 + +from celery import current_app, states +from celery.utils.log import get_task_logger +from django.db import connections + +logger = get_task_logger(__name__) + +# Arbitrary constant key for pg_try_advisory_lock so only one reconciliation +# runs at a time across replicas / the watchdog / the command. +ORPHAN_RECOVERY_LOCK_KEY = 0x70726F77 # "prow" + +# Non-terminal states that mean "a worker had this and may have died with it". +IN_FLIGHT_STATES = (states.STARTED, states.RECEIVED) + +# Scan tasks are recovered by re-running scan-perform on the EXISTING scan row, +# not by re-enqueuing the original task: re-enqueuing scan-perform-scheduled would +# hit its "a scan is already executing" guard and no-op, leaving the scan stuck. +_SCAN_TASKS = ("scan-perform", "scan-perform-scheduled") + +# Tasks with proven idempotency are auto re-enqueued. Scans/summaries clear and +# rewrite their own rows. integration-jira is safe too: each finding is reserved in +# JiraIssueDispatch before the external call, so a re-run skips already-ticketed +# findings (worst case one finding missed on a mid-send crash, never a duplicate). +# Other external side effects stay terminal: integration-s3 rebuilds its upload from +# worker-local files that do not survive a crash, and report/Security Hub recovery is +# out of scope. +REENQUEUEABLE_TASKS = { + *_SCAN_TASKS, + "provider-deletion", + "tenant-deletion", + "scan-summary", + "scan-compliance-overviews", + "scan-provider-compliance-scores", + "scan-daily-severity", + "scan-finding-group-summaries", + "scan-reset-ephemeral-resources", + "integration-jira", +} + +# Tasks excluded from generic recovery: attack-paths scans are handled by their own +# stale-cleanup (which also drops the temp Neo4j db), and the maintenance tasks must +# not self-recover (they run again on their own schedule). 
+_SKIP_RECOVERY = { + "attack-paths-scan-perform", + "attack-paths-cleanup-stale-scans", + "reconcile-orphan-tasks", +} + + +@contextmanager +def advisory_lock(key: int = ORPHAN_RECOVERY_LOCK_KEY, using: str = "default"): + """Yield True if this session won a Postgres advisory lock, else False. + + Non-blocking: losers get False and should no-op. The lock is released on + exit (and implicitly if the session dies). + """ + with connections[using].cursor() as cursor: + cursor.execute("SELECT pg_try_advisory_lock(%s)", [key]) + acquired = bool(cursor.fetchone()[0]) + try: + yield acquired + finally: + if acquired: + cursor.execute("SELECT pg_advisory_unlock(%s)", [key]) + + +def is_worker_alive(worker: str, timeout: float = 1.0) -> bool: + """Ping a specific Celery worker. Returns True if it responds, or on error. + + Erring on the side of "alive" means an unreachable control bus never causes + a still-running task to be re-enqueued. + """ + try: + response = current_app.control.inspect( + destination=[worker], timeout=timeout + ).ping() + return response is not None and worker in response + except Exception: + logger.exception(f"Failed to ping worker {worker}, treating as alive") + return True + + +def revoke_task(task_result, terminate: bool = True) -> None: + """Revoke a Celery task by its TaskResult. Non-fatal on failure. + + terminate=True SIGTERMs the worker if the task is mid-execution; terminate=False + only marks the id revoked so any worker pulling the queued message discards it + (use before re-enqueuing, so a later broker redelivery of the stale message is + dropped). 
+ """ + try: + kwargs = {"terminate": True, "signal": "SIGTERM"} if terminate else {} + current_app.control.revoke(task_result.task_id, **kwargs) + logger.info(f"Revoked task {task_result.task_id}") + except Exception: + logger.exception(f"Failed to revoke task {task_result.task_id}") + + +def _decode_celery_field(value, default): + """Decode django-celery-results' stored task_args/task_kwargs to a Python object. + + The backend stores them as a (sometimes double-encoded) repr/JSON string. An + empty or missing field returns ``default``; a non-empty value that cannot be + decoded raises ``ValueError`` so the caller can avoid re-enqueuing a task with + the wrong arguments. + """ + obj = value + for _ in range(2): # values can be double-encoded (a string holding a repr) + if not isinstance(obj, str): + break + text = obj.strip() + if not text: + return default + parsed = None + for parser in (ast.literal_eval, json.loads): + try: + parsed = parser(text) + break + except (ValueError, SyntaxError, TypeError): + continue + if parsed is None: + raise ValueError(f"undecodable celery field: {text[:120]!r}") + obj = parsed + return default if obj is None else obj + + +def reconcile_orphans( + grace_minutes: int = 2, + max_attempts: int = 3, + window_hours: int = 6, + dry_run: bool = False, +) -> dict: + """Run the full orphan sweep under a single-flight advisory lock. + + Recovers any orphaned in-flight task and delegates attack-paths scans that + never reached a worker to their existing stale-cleanup. Returns a summary; + a no-op (lock not won) is reported too. + """ + with advisory_lock() as acquired: + if not acquired: + logger.info("Orphan reconcile skipped: another run holds the lock") + return {"acquired": False} + + # Populate the task registry so we can re-enqueue any task by name. 
+ import tasks.tasks # noqa: F401 + + result = _reconcile_task_results( + grace_minutes=grace_minutes, + max_attempts=max_attempts, + window_hours=window_hours, + dry_run=dry_run, + ) + + if not dry_run: + from tasks.jobs.attack_paths.cleanup import cleanup_stale_attack_paths_scans + + result["attack_paths"] = cleanup_stale_attack_paths_scans() + + return {"acquired": True, **result} + + +def _reconcile_task_results( + grace_minutes: int, max_attempts: int, window_hours: int, dry_run: bool +) -> dict: + from django_celery_results.models import TaskResult + + cutoff = datetime.now(tz=timezone.utc) - timedelta(minutes=grace_minutes) + candidates = list( + TaskResult.objects.filter(status__in=IN_FLIGHT_STATES, date_created__lt=cutoff) + .exclude(worker__isnull=True) + .exclude(worker="") + .exclude(task_name__in=_SKIP_RECOVERY) + ) + + # Ping each distinct worker at most once. + worker_alive = {w: is_worker_alive(w) for w in {tr.worker for tr in candidates}} + + recovered, failed, skipped = [], [], [] + for task_result in candidates: + if worker_alive.get(task_result.worker, True): + skipped.append(task_result.task_id) # in flight, do not double-run + continue + if dry_run: + recovered.append(task_result.task_id) + continue + outcome = _recover_task(task_result, max_attempts, window_hours) + (recovered if outcome == "recovered" else failed).append(task_result.task_id) + + logger.info( + "Orphan reconcile: recovered=%d failed=%d skipped(in-flight)=%d", + len(recovered), + len(failed), + len(skipped), + ) + return {"recovered": recovered, "failed": failed, "skipped": skipped} + + +def _recovery_attempt_count(name: str, kwargs_repr, window_hours: int) -> int: + """Increment and return the recovery count for this (task, kwargs) within the + window. Backed by Valkey so it survives result-row churn (a worker processing + the revoke can blank the TaskResult fields). Fail-open if Valkey is down (the + broker being unreachable means nothing is running anyway). 
+ """ + import hashlib + + from django.conf import settings + + try: + import redis + + client = redis.from_url(settings.CELERY_BROKER_URL) + signature = f"{name}|{kwargs_repr}".encode() + key = ( + "orphan-recovery:" + + hashlib.sha1(signature, usedforsecurity=False).hexdigest() + ) + count = client.incr(key) + if count == 1: + client.expire(key, max(1, window_hours) * 3600) + return int(count) + except Exception: + logger.exception("Recovery-attempt counter unavailable; allowing recovery") + return 1 + + +def _recover_task(task_result, max_attempts: int, window_hours: int) -> str: + """Recover one orphaned task. Returns 'recovered' or 'failed'.""" + # Capture name/args/kwargs now: revoking can let a worker blank the row. + name = task_result.task_name + args_repr = task_result.task_args + kwargs_repr = task_result.task_kwargs + now = datetime.now(tz=timezone.utc) + + # Drop any future broker redelivery of the stale message. + revoke_task(task_result, terminate=False) + + # Mark the stale result terminal so "pending/started forever" alerts clear. + task_result.status = states.REVOKED + task_result.date_done = now + task_result.save(update_fields=["status", "date_done"]) + + attempt = _recovery_attempt_count(name, kwargs_repr, window_hours) + if name not in REENQUEUEABLE_TASKS or attempt > max_attempts: + reason = ( + f"{name} is not allowlisted for auto recovery" + if name not in REENQUEUEABLE_TASKS + else f"recovery cap reached ({attempt}/{max_attempts})" + ) + _fail_domain_row(task_result.task_id, name, now) + logger.warning( + "Orphan %s (%s) not re-enqueued: %s", task_result.task_id, name, reason + ) + return "failed" + + # Scan tasks: re-run the EXISTING scan row directly via scan-perform, so the + # scheduled-scan "already executing" guard cannot turn recovery into a no-op. + # Falls through to the generic path only if no scan is linked yet (e.g. a + # scheduled task that died before creating one), where re-running it creates one. 
+ if name in _SCAN_TASKS: + scan = _scan_for_task(task_result.task_id) + if scan is not None: + if not _reenqueue_scan(task_result.task_id, scan): + return "failed" + logger.info( + "Re-enqueued orphaned scan %s (was task %s)", + scan.id, + task_result.task_id, + ) + return "recovered" + + task_obj = current_app.tasks.get(name) + if task_obj is None: + logger.error( + "Orphan %s: task %s not registered, cannot re-enqueue", + task_result.task_id, + name, + ) + return "failed" + + try: + args = _decode_celery_field(args_repr, []) + kwargs = _decode_celery_field(kwargs_repr, {}) + except ValueError: + logger.error( + "Orphan %s (%s): could not decode stored args/kwargs, not re-enqueuing", + task_result.task_id, + name, + ) + _fail_domain_row(task_result.task_id, name, now) + return "failed" + new_task_id = str(uuid4()) + task_obj.apply_async( + args=list(args) if isinstance(args, (list, tuple)) else [], + kwargs=kwargs if isinstance(kwargs, dict) else {}, + task_id=new_task_id, + ) + logger.info( + "Re-enqueued orphan %s (%s) as %s", task_result.task_id, name, new_task_id + ) + return "recovered" + + +def _scan_for_task(task_id: str): + """Return the Scan linked to a Celery task id, or None (read across tenants).""" + from api.db_router import MainRouter + from api.models import Scan + + return Scan.all_objects.using(MainRouter.admin_db).filter(task_id=task_id).first() + + +def _reenqueue_scan(old_task_id: str, scan) -> bool: + """Re-run an orphaned scan via scan-perform on the existing row. + + Pre-provisions the new task linkage (TaskResult + api.Task) and relinks the + Scan before enqueuing, so the FK is valid and a worker can never outrun the DB. + The relink is conditional on the scan still pointing at the old task, so a stale + orphan can never clobber a newer linkage. 
+ """ + from django_celery_results.models import TaskResult + + from api.db_utils import rls_transaction + from api.models import Scan + from api.models import Task as APITask + from tasks.tasks import perform_scan_task + + tenant_id = str(scan.tenant_id) + new_task_id = str(uuid4()) + with rls_transaction(tenant_id): + locked_scan = Scan.all_objects.select_for_update().filter(id=scan.id).first() + if locked_scan is None or str(locked_scan.task_id) != old_task_id: + logger.info( + "Scan %s no longer points at task %s; skipping recovery re-enqueue", + scan.id, + old_task_id, + ) + return False + task_result_new, _ = TaskResult.objects.get_or_create( + task_id=new_task_id, + defaults={"status": states.PENDING, "task_name": "scan-perform"}, + ) + APITask.objects.update_or_create( + id=new_task_id, + tenant_id=tenant_id, + defaults={"task_runner_task": task_result_new}, + ) + locked_scan.task_id = new_task_id + locked_scan.recovery_count = (locked_scan.recovery_count or 0) + 1 + locked_scan.save(update_fields=["task_id", "recovery_count", "updated_at"]) + + perform_scan_task.apply_async( + kwargs={ + "tenant_id": tenant_id, + "scan_id": str(scan.id), + "provider_id": str(scan.provider_id), + }, + task_id=new_task_id, + ) + return True + + +def _fail_domain_row(old_task_id: str, name: str, now: datetime) -> None: + """Mark a scan terminal when its task is capped/denylisted instead of re-run.""" + from api.db_utils import rls_transaction + from api.models import Scan, StateChoices + + if name in _SCAN_TASKS: + scan = _scan_for_task(old_task_id) + if scan is not None: + with rls_transaction(str(scan.tenant_id)): + Scan.all_objects.filter(id=scan.id, task_id=old_task_id).update( + state=StateChoices.FAILED, completed_at=now + ) diff --git a/api/src/backend/tasks/jobs/scan.py b/api/src/backend/tasks/jobs/scan.py index a772173a7f..298a2b225a 100644 --- a/api/src/backend/tasks/jobs/scan.py +++ b/api/src/backend/tasks/jobs/scan.py @@ -118,6 +118,19 @@ 
ATTACK_SURFACE_PROVIDER_COMPATIBILITY = { _ATTACK_SURFACE_MAPPING_CACHE: dict[str, dict] = {} +def _clear_scan_rerun_state(tenant_id: str, scan_id: str) -> None: + """Remove rows derived from a previous execution of this scan.""" + with rls_transaction(tenant_id): + Finding.all_objects.filter(scan_id=scan_id).delete() + ResourceScanSummary.objects.filter(scan_id=scan_id).delete() + ScanCategorySummary.objects.filter(scan_id=scan_id).delete() + ScanGroupSummary.objects.filter(scan_id=scan_id).delete() + ScanSummary.objects.filter(scan_id=scan_id).delete() + AttackSurfaceOverview.objects.filter(scan_id=scan_id).delete() + ComplianceRequirementOverview.objects.filter(scan_id=scan_id).delete() + ComplianceOverviewSummary.objects.filter(scan_id=scan_id).delete() + + def aggregate_category_counts( categories: list[str], severity: str, @@ -476,9 +489,13 @@ def _create_compliance_summaries( ) ) - # Bulk insert summaries - if summary_objects: - with rls_transaction(tenant_id): + # Idempotent re-run: clear this scan's prior summaries before re-inserting, so + # a recovered scan's summary always reflects its own (re-derived) requirement + # rows rather than keeping a stale row (bulk_create ignore_conflicts alone would + # keep the old one). 
+ with rls_transaction(tenant_id): + ComplianceOverviewSummary.objects.filter(scan_id=scan_id).delete() + if summary_objects: ComplianceOverviewSummary.objects.bulk_create( summary_objects, batch_size=500, ignore_conflicts=True ) @@ -1022,6 +1039,7 @@ def perform_prowler_scan( scan_instance.state = StateChoices.EXECUTING scan_instance.started_at = datetime.now(tz=timezone.utc) scan_instance.save(update_fields=["state", "started_at", "updated_at"]) + _clear_scan_rerun_state(tenant_id, scan_id) # Find the mutelist processor if it exists with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS): @@ -1651,6 +1669,10 @@ def create_compliance_requirements(tenant_id: str, scan_id: str): elif requirement_status == "PASS": requirement_statuses[key]["pass_count"] += 1 + # Idempotent re-run: COPY can't ON CONFLICT, so clear this scan's rows first. + with rls_transaction(tenant_id): + ComplianceRequirementOverview.objects.filter(scan_id=scan_id).delete() + # Bulk create requirement records using PostgreSQL COPY _persist_compliance_requirement_rows(tenant_id, compliance_requirement_rows) diff --git a/api/src/backend/tasks/tasks.py b/api/src/backend/tasks/tasks.py index b35b68893d..92c2604942 100644 --- a/api/src/backend/tasks/tasks.py +++ b/api/src/backend/tasks/tasks.py @@ -46,6 +46,7 @@ from tasks.jobs.lighthouse_providers import ( refresh_lighthouse_provider_models, ) from tasks.jobs.muting import mute_historical_findings +from tasks.jobs.orphan_recovery import reconcile_orphans from tasks.jobs.report import ( STALE_TMP_OUTPUT_MAX_AGE_HOURS, _cleanup_stale_tmp_output_directories, @@ -462,6 +463,12 @@ def cleanup_stale_attack_paths_scans_task(): return cleanup_stale_attack_paths_scans() +@shared_task(name="reconcile-orphan-tasks", queue="celery") +def reconcile_orphan_tasks_task(): + """Periodic watchdog: recover tasks whose worker is gone (deploys, crashes).""" + return reconcile_orphans() + + @shared_task(name="tenant-deletion", queue="deletion", autoretry_for=(Exception,)) 
def delete_tenant_task(tenant_id: str): return delete_tenant(pk=tenant_id) diff --git a/api/src/backend/tasks/tests/test_deletion.py b/api/src/backend/tasks/tests/test_deletion.py index 0ed8c5ddb2..e6cd51aca8 100644 --- a/api/src/backend/tasks/tests/test_deletion.py +++ b/api/src/backend/tasks/tests/test_deletion.py @@ -1,11 +1,12 @@ from unittest.mock import call, patch +from uuid import uuid4 import pytest from django.core.exceptions import ObjectDoesNotExist from tasks.jobs.deletion import delete_provider, delete_tenant from api.attack_paths import database as graph_database -from api.models import Provider, Tenant, TenantComplianceSummary +from api.models import JiraIssueDispatch, Provider, Tenant, TenantComplianceSummary @pytest.mark.django_db @@ -34,6 +35,43 @@ class TestDeleteProvider: str(instance.id), ) + def test_delete_provider_removes_jira_dispatches( + self, + providers_fixture, + findings_fixture, + integrations_fixture, + ): + """Deleting a provider removes JiraIssueDispatch rows for its findings only.""" + instance = providers_fixture[0] + tenant_id = str(instance.tenant_id) + finding = findings_fixture[0] + integration = integrations_fixture[0] + + # Dispatch for one of the provider's findings: must be removed with it. + JiraIssueDispatch.objects.create( + tenant_id=tenant_id, + integration=integration, + finding_id=finding.id, + ) + # Dispatch for an unrelated finding: must survive the provider deletion. 
+ unrelated = JiraIssueDispatch.objects.create( + tenant_id=tenant_id, + integration=integration, + finding_id=uuid4(), + ) + + with ( + patch( + "tasks.jobs.deletion.graph_database.get_database_name", + return_value="tenant-db", + ), + patch("tasks.jobs.deletion.graph_database.drop_subgraph"), + ): + delete_provider(tenant_id, instance.id) + + assert not JiraIssueDispatch.objects.filter(finding_id=finding.id).exists() + assert JiraIssueDispatch.objects.filter(pk=unrelated.pk).exists() + def test_delete_provider_does_not_exist(self, tenants_fixture): with ( patch( diff --git a/api/src/backend/tasks/tests/test_integrations.py b/api/src/backend/tasks/tests/test_integrations.py index e246405cdd..ba8d52c193 100644 --- a/api/src/backend/tasks/tests/test_integrations.py +++ b/api/src/backend/tasks/tests/test_integrations.py @@ -1640,14 +1640,74 @@ class TestJiraIntegration: @patch("tasks.jobs.integrations.Finding") @patch("tasks.jobs.integrations.Integration") @patch("tasks.jobs.integrations.initialize_prowler_integration") + @patch("tasks.jobs.integrations.JiraIssueDispatch") + def test_send_findings_to_jira_skips_already_dispatched( + self, + mock_jira_dispatch, + mock_initialize_integration, + mock_integration_model, + mock_finding_model, + mock_rls_transaction, + ): + """A re-run skips findings already ticketed (no duplicate Jira issues).""" + mock_rls_transaction.return_value.__enter__ = MagicMock() + mock_rls_transaction.return_value.__exit__ = MagicMock() + mock_integration_model.objects.get.return_value = MagicMock() + # finding-1 was already dispatched in a prior run; finding-2 is new. 
+ mock_jira_dispatch.objects.filter.return_value.values_list.return_value = [ + "finding-1" + ] + mock_jira_dispatch.objects.get_or_create.return_value = (MagicMock(), True) + + mock_jira_integration = MagicMock() + mock_jira_integration.send_finding.return_value = True + mock_initialize_integration.return_value = mock_jira_integration + + finding2 = MagicMock() + finding2.id = "finding-2" + finding2.check_id = "check_002" + finding2.severity = "low" + finding2.status = "FAIL" + finding2.status_extended = "" + finding2.compliance = {} + finding2.resources.exists.return_value = False + finding2.resources.first.return_value = None + finding2.scan.provider.provider = "aws" + finding2.check_metadata = { + "checktitle": "C2", + "risk": "", + "remediation": {"recommendation": {}, "code": {}}, + } + mock_finding_model.all_objects.select_related.return_value.prefetch_related.return_value.get.return_value = finding2 + + result = send_findings_to_jira( + "tenant-123", "integration-456", "PROJ", "Task", ["finding-1", "finding-2"] + ) + + # finding-1 skipped (already sent); only finding-2 sent -> no duplicate. 
+ assert result == {"created_count": 1, "failed_count": 0, "skipped_count": 1} + mock_jira_integration.send_finding.assert_called_once() + assert ( + mock_jira_integration.send_finding.call_args.kwargs["check_id"] + == "check_002" + ) + + @patch("tasks.jobs.integrations.rls_transaction") + @patch("tasks.jobs.integrations.Finding") + @patch("tasks.jobs.integrations.Integration") + @patch("tasks.jobs.integrations.initialize_prowler_integration") + @patch("tasks.jobs.integrations.JiraIssueDispatch") def test_send_findings_to_jira_success( self, + mock_jira_dispatch, mock_initialize_integration, mock_integration_model, mock_finding_model, mock_rls_transaction, ): """Test successful sending of findings to Jira using send_finding method""" + mock_jira_dispatch.objects.filter.return_value.values_list.return_value = [] + mock_jira_dispatch.objects.get_or_create.return_value = (MagicMock(), True) tenant_id = "tenant-123" integration_id = "integration-456" project_key = "PROJ" @@ -1739,7 +1799,7 @@ class TestJiraIntegration: ) # Assertions - assert result == {"created_count": 2, "failed_count": 0} + assert result == {"created_count": 2, "failed_count": 0, "skipped_count": 0} # Verify Jira integration was initialized mock_initialize_integration.assert_called_once_with(integration) @@ -1771,8 +1831,10 @@ class TestJiraIntegration: @patch("tasks.jobs.integrations.Integration") @patch("tasks.jobs.integrations.initialize_prowler_integration") @patch("tasks.jobs.integrations.logger") + @patch("tasks.jobs.integrations.JiraIssueDispatch") def test_send_findings_to_jira_partial_failure( self, + mock_jira_dispatch, mock_logger, mock_initialize_integration, mock_integration_model, @@ -1780,6 +1842,8 @@ class TestJiraIntegration: mock_rls_transaction, ): """Test partial failure when sending findings to Jira""" + mock_jira_dispatch.objects.filter.return_value.values_list.return_value = [] + mock_jira_dispatch.objects.get_or_create.return_value = (MagicMock(), True) tenant_id = 
"tenant-123" integration_id = "integration-456" project_key = "PROJ" @@ -1833,23 +1897,35 @@ class TestJiraIntegration: ) # Assertions - assert result == {"created_count": 2, "failed_count": 1} + assert result == {"created_count": 2, "failed_count": 1, "skipped_count": 0} # Verify error was logged for the failed finding mock_logger.error.assert_called_with("Failed to send finding finding-2 to Jira") + # The failed finding's reservation is released so a later run can retry it. + mock_jira_dispatch.objects.filter.assert_any_call( + tenant_id=tenant_id, + integration_id=integration_id, + finding_id="finding-2", + ) + mock_jira_dispatch.objects.filter.return_value.delete.assert_called_once() + @patch("tasks.jobs.integrations.rls_transaction") @patch("tasks.jobs.integrations.Finding") @patch("tasks.jobs.integrations.Integration") @patch("tasks.jobs.integrations.initialize_prowler_integration") + @patch("tasks.jobs.integrations.JiraIssueDispatch") def test_send_findings_to_jira_no_resources( self, + mock_jira_dispatch, mock_initialize_integration, mock_integration_model, mock_finding_model, mock_rls_transaction, ): """Test sending findings to Jira when finding has no resources""" + mock_jira_dispatch.objects.filter.return_value.values_list.return_value = [] + mock_jira_dispatch.objects.get_or_create.return_value = (MagicMock(), True) tenant_id = "tenant-123" integration_id = "integration-456" project_key = "PROJ" @@ -1907,7 +1983,7 @@ class TestJiraIntegration: ) # Assertions - assert result == {"created_count": 1, "failed_count": 0} + assert result == {"created_count": 1, "failed_count": 0, "skipped_count": 0} # Verify send_finding was called with empty resource fields call_kwargs = mock_jira_integration.send_finding.call_args.kwargs @@ -1920,14 +1996,18 @@ class TestJiraIntegration: @patch("tasks.jobs.integrations.Finding") @patch("tasks.jobs.integrations.Integration") @patch("tasks.jobs.integrations.initialize_prowler_integration") + 
@patch("tasks.jobs.integrations.JiraIssueDispatch") def test_send_findings_to_jira_with_empty_check_metadata( self, + mock_jira_dispatch, mock_initialize_integration, mock_integration_model, mock_finding_model, mock_rls_transaction, ): """Test sending findings to Jira when check_metadata is empty or missing fields""" + mock_jira_dispatch.objects.filter.return_value.values_list.return_value = [] + mock_jira_dispatch.objects.get_or_create.return_value = (MagicMock(), True) tenant_id = "tenant-123" integration_id = "integration-456" project_key = "PROJ" @@ -1970,7 +2050,7 @@ class TestJiraIntegration: ) # Assertions - assert result == {"created_count": 1, "failed_count": 0} + assert result == {"created_count": 1, "failed_count": 0, "skipped_count": 0} # Verify send_finding was called with default/empty values call_kwargs = mock_jira_integration.send_finding.call_args.kwargs @@ -1983,3 +2063,94 @@ class TestJiraIntegration: assert call_kwargs["remediation_code_cli"] == "" assert call_kwargs["remediation_code_other"] == "" assert call_kwargs["compliance"] == {} + + @patch("tasks.jobs.integrations.rls_transaction") + @patch("tasks.jobs.integrations.Finding") + @patch("tasks.jobs.integrations.Integration") + @patch("tasks.jobs.integrations.initialize_prowler_integration") + @patch("tasks.jobs.integrations.JiraIssueDispatch") + def test_send_findings_to_jira_reserves_before_sending( + self, + mock_jira_dispatch, + mock_initialize_integration, + mock_integration_model, + mock_finding_model, + mock_rls_transaction, + ): + """The dispatch row is reserved before the external Jira call (reserve-then-act).""" + mock_rls_transaction.return_value.__enter__ = MagicMock() + mock_rls_transaction.return_value.__exit__ = MagicMock() + mock_integration_model.objects.get.return_value = MagicMock() + mock_jira_dispatch.objects.filter.return_value.values_list.return_value = [] + + order = [] + mock_jira_dispatch.objects.get_or_create.side_effect = lambda **kw: ( + order.append(("reserve", 
kw)) or (MagicMock(), True) + ) + + mock_jira_integration = MagicMock() + mock_jira_integration.send_finding.side_effect = lambda **kw: ( + order.append(("send", kw)) or True + ) + mock_initialize_integration.return_value = mock_jira_integration + + finding = MagicMock() + finding.id = "finding-1" + finding.check_id = "check_001" + finding.severity = "low" + finding.status = "FAIL" + finding.status_extended = "" + finding.compliance = {} + finding.resources.exists.return_value = False + finding.resources.first.return_value = None + finding.scan.provider.provider = "aws" + finding.check_metadata = { + "checktitle": "C1", + "risk": "", + "remediation": {"recommendation": {}, "code": {}}, + } + mock_finding_model.all_objects.select_related.return_value.prefetch_related.return_value.get.return_value = finding + + result = send_findings_to_jira( + "tenant-123", "integration-456", "PROJ", "Task", ["finding-1"] + ) + + assert result == {"created_count": 1, "failed_count": 0, "skipped_count": 0} + # Reservation must precede the external send. + assert [entry[0] for entry in order] == ["reserve", "send"] + # A successful send keeps the reservation (no rollback delete). 
+ mock_jira_dispatch.objects.filter.return_value.delete.assert_not_called() + + @patch("tasks.jobs.integrations.rls_transaction") + @patch("tasks.jobs.integrations.Finding") + @patch("tasks.jobs.integrations.Integration") + @patch("tasks.jobs.integrations.initialize_prowler_integration") + @patch("tasks.jobs.integrations.JiraIssueDispatch") + def test_send_findings_to_jira_skips_when_already_reserved( + self, + mock_jira_dispatch, + mock_initialize_integration, + mock_integration_model, + mock_finding_model, + mock_rls_transaction, + ): + """A finding that races past the bulk pre-check but loses the reservation + (created=False) is skipped without a second issue, leaving the row intact.""" + mock_rls_transaction.return_value.__enter__ = MagicMock() + mock_rls_transaction.return_value.__exit__ = MagicMock() + mock_integration_model.objects.get.return_value = MagicMock() + mock_jira_dispatch.objects.filter.return_value.values_list.return_value = [] + # Another concurrent run already created the dispatch row. + mock_jira_dispatch.objects.get_or_create.return_value = (MagicMock(), False) + + mock_jira_integration = MagicMock() + mock_initialize_integration.return_value = mock_jira_integration + + result = send_findings_to_jira( + "tenant-123", "integration-456", "PROJ", "Task", ["finding-1"] + ) + + assert result == {"created_count": 0, "failed_count": 0, "skipped_count": 1} + mock_jira_integration.send_finding.assert_not_called() + # The reservation belongs to the run that won the race; do not delete it. 
+ mock_jira_dispatch.objects.filter.return_value.delete.assert_not_called() diff --git a/api/src/backend/tasks/tests/test_orphan_recovery.py b/api/src/backend/tasks/tests/test_orphan_recovery.py new file mode 100644 index 0000000000..b426297bbd --- /dev/null +++ b/api/src/backend/tasks/tests/test_orphan_recovery.py @@ -0,0 +1,372 @@ +from datetime import datetime, timedelta, timezone +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from celery import states +from django_celery_results.models import TaskResult + +from api.models import Scan, StateChoices +from api.models import Task as APITask +from tasks.jobs.orphan_recovery import ( + _decode_celery_field, + _reconcile_task_results, + _recovery_attempt_count, + _reenqueue_scan, + advisory_lock, + is_worker_alive, +) + + +def _orphan_result(*, name, kwargs, worker, created_minutes_ago, status=states.STARTED): + """Create a TaskResult mimicking an in-flight task, backdated past the grace.""" + tr = TaskResult.objects.create( + task_id=str(uuid4()), + status=status, + task_name=name, + worker=worker, + task_kwargs=repr(kwargs), + task_args=repr([]), + ) + TaskResult.objects.filter(pk=tr.pk).update( + date_created=datetime.now(tz=timezone.utc) + - timedelta(minutes=created_minutes_ago) + ) + tr.refresh_from_db() + return tr + + +@pytest.mark.django_db +class TestDecodeCeleryField: + def test_decodes_single_encoded_repr(self): + assert _decode_celery_field("{'tenant_id': 'abc'}", {}) == {"tenant_id": "abc"} + + def test_decodes_double_encoded(self): + import json + + stored = json.dumps(repr({"tenant_id": "abc", "scan_id": "s1"})) + assert _decode_celery_field(stored, {}) == {"tenant_id": "abc", "scan_id": "s1"} + + def test_empty_returns_default(self): + assert _decode_celery_field(None, {}) == {} + assert _decode_celery_field("", []) == [] + + def test_unparseable_raises(self): + with pytest.raises(ValueError): + _decode_celery_field("<>", {}) + + +@pytest.mark.django_db +class 
TestReconcileTaskResults: + def _patches(self, alive): + """Return patchers for worker liveness, revoke_task and current_app, plus the mocked task.""" + mock_app = MagicMock() + mock_task = MagicMock() + mock_app.tasks.get.return_value = mock_task + return ( + patch("tasks.jobs.orphan_recovery.is_worker_alive", return_value=alive), + patch("tasks.jobs.orphan_recovery.revoke_task"), + patch("tasks.jobs.orphan_recovery.current_app", mock_app), + mock_task, + ) + + def test_recovers_non_scan_task(self, tenants_fixture): + """A NON-scan task (tenant-deletion) left orphaned is re-enqueued too.""" + tenant = tenants_fixture[0] + tr = _orphan_result( + name="tenant-deletion", + kwargs={"tenant_id": str(tenant.id)}, + worker="dead@gone", + created_minutes_ago=60, + ) + p_alive, p_revoke, p_app, mock_task = self._patches(alive=False) + with ( + p_alive, + p_revoke, + p_app, + patch("tasks.jobs.orphan_recovery._recovery_attempt_count", return_value=1), + ): + result = _reconcile_task_results( + grace_minutes=2, max_attempts=3, window_hours=6, dry_run=False + ) + + assert tr.task_id in result["recovered"] + tr.refresh_from_db() + assert tr.status == states.REVOKED # stale result cleared (no pending alert) + mock_task.apply_async.assert_called_once() + call = mock_task.apply_async.call_args.kwargs + assert call["kwargs"] == {"tenant_id": str(tenant.id)} + assert call["task_id"] != tr.task_id # fresh task id + + def test_external_integration_task_is_not_reenqueued_by_default( + self, tenants_fixture + ): + """External side-effect tasks without proven idempotency stay terminal. + + integration-s3 rebuilds its upload from worker-local files that do not + survive the crash, so re-enqueuing it would upload nothing.
+ """ + tr = _orphan_result( + name="integration-s3", + kwargs={ + "tenant_id": str(tenants_fixture[0].id), + "provider_id": str(uuid4()), + "output_directory": "/tmp/gone", + }, + worker="dead@gone", + created_minutes_ago=60, + ) + p_alive, p_revoke, p_app, mock_task = self._patches(alive=False) + with ( + p_alive, + p_revoke, + p_app, + patch("tasks.jobs.orphan_recovery._recovery_attempt_count", return_value=1), + ): + result = _reconcile_task_results( + grace_minutes=2, max_attempts=3, window_hours=6, dry_run=False + ) + + assert tr.task_id in result["failed"] + mock_task.apply_async.assert_not_called() + + def test_jira_integration_task_is_reenqueued(self, tenants_fixture): + """integration-jira is re-enqueued: its JiraIssueDispatch reservation makes a + re-run skip already-ticketed findings, so recovery cannot duplicate issues.""" + tenant = tenants_fixture[0] + kwargs = { + "tenant_id": str(tenant.id), + "integration_id": str(uuid4()), + "project_key": "PROWLER", + "issue_type": "Task", + "finding_ids": [str(uuid4()), str(uuid4())], + } + tr = _orphan_result( + name="integration-jira", + kwargs=kwargs, + worker="dead@gone", + created_minutes_ago=60, + ) + p_alive, p_revoke, p_app, mock_task = self._patches(alive=False) + with ( + p_alive, + p_revoke, + p_app, + patch("tasks.jobs.orphan_recovery._recovery_attempt_count", return_value=1), + ): + result = _reconcile_task_results( + grace_minutes=2, max_attempts=3, window_hours=6, dry_run=False + ) + + assert tr.task_id in result["recovered"] + tr.refresh_from_db() + assert tr.status == states.REVOKED # stale result cleared (no pending alert) + mock_task.apply_async.assert_called_once() + call = mock_task.apply_async.call_args.kwargs + assert call["kwargs"] == kwargs + assert call["task_id"] != tr.task_id # fresh task id + + def test_skips_live_worker(self, tenants_fixture): + tr = _orphan_result( + name="tenant-deletion", + kwargs={"tenant_id": str(tenants_fixture[0].id)}, + worker="alive@host", +
created_minutes_ago=60, + ) + p_alive, p_revoke, p_app, mock_task = self._patches(alive=True) + with p_alive, p_revoke, p_app: + result = _reconcile_task_results( + grace_minutes=2, max_attempts=3, window_hours=6, dry_run=False + ) + + assert tr.task_id in result["skipped"] + mock_task.apply_async.assert_not_called() + + def test_skips_recently_created(self, tenants_fixture): + tr = _orphan_result( + name="tenant-deletion", + kwargs={"tenant_id": str(tenants_fixture[0].id)}, + worker="dead@gone", + created_minutes_ago=0, + ) + p_alive, p_revoke, p_app, mock_task = self._patches(alive=False) + with p_alive, p_revoke, p_app: + result = _reconcile_task_results( + grace_minutes=2, max_attempts=3, window_hours=6, dry_run=False + ) + + # too recent: excluded by the grace window (not even a candidate) + assert tr.task_id not in result["recovered"] + mock_task.apply_async.assert_not_called() + + def test_denylisted_task_failed_not_reenqueued(self, tenants_fixture): + """A non-allowlisted task is marked failed, never blindly re-run.""" + tr = _orphan_result( + name="some-non-idempotent-task", + kwargs={"tenant_id": str(tenants_fixture[0].id)}, + worker="dead@gone", + created_minutes_ago=60, + ) + p_alive, p_revoke, p_app, mock_task = self._patches(alive=False) + with ( + p_alive, + p_revoke, + p_app, + patch("tasks.jobs.orphan_recovery._recovery_attempt_count", return_value=1), + ): + result = _reconcile_task_results( + grace_minutes=2, max_attempts=3, window_hours=6, dry_run=False + ) + + assert tr.task_id in result["failed"] + tr.refresh_from_db() + assert tr.status == states.REVOKED + mock_task.apply_async.assert_not_called() + + def test_recovery_cap_marks_failed(self, tenants_fixture): + """When the recovery counter exceeds the cap, the task is marked failed instead of re-run.""" + tr = _orphan_result( + name="tenant-deletion", + kwargs={"tenant_id": str(tenants_fixture[0].id)}, + worker="dead@gone", + created_minutes_ago=60, + ) + p_alive, p_revoke, p_app, mock_task =
self._patches(alive=False) + with ( + p_alive, + p_revoke, + p_app, + patch("tasks.jobs.orphan_recovery._recovery_attempt_count", return_value=4), + ): + result = _reconcile_task_results( + grace_minutes=2, max_attempts=3, window_hours=6, dry_run=False + ) + + assert tr.task_id in result["failed"] + mock_task.apply_async.assert_not_called() + + +@pytest.mark.django_db +class TestScanRecovery: + """Scans are recovered by re-running scan-perform on the EXISTING scan row, + so even a scheduled-scan orphan (whose own task would no-op on its guard) is + actually re-executed.""" + + def _scan_orphan(self, tenant, provider, name): + old_id = str(uuid4()) + tr = TaskResult.objects.create( + task_id=old_id, + status=states.STARTED, + task_name=name, + worker="dead@gone", + task_kwargs=repr( + {"tenant_id": str(tenant.id), "provider_id": str(provider.id)} + ), + task_args=repr([]), + ) + TaskResult.objects.filter(pk=tr.pk).update( + date_created=datetime.now(tz=timezone.utc) - timedelta(minutes=60) + ) + APITask.objects.create(id=old_id, tenant_id=tenant.id, task_runner_task=tr) + scan = Scan.objects.create( + name="scan-orphan", + provider=provider, + trigger=Scan.TriggerChoices.SCHEDULED, + state=StateChoices.EXECUTING, + tenant_id=tenant.id, + task_id=old_id, + recovery_count=0, + ) + return old_id, scan + + @pytest.mark.parametrize("name", ["scan-perform", "scan-perform-scheduled"]) + def test_scan_recovered_via_scan_perform( + self, tenants_fixture, providers_fixture, name + ): + tenant, provider = tenants_fixture[0], providers_fixture[0] + old_id, scan = self._scan_orphan(tenant, provider, name) + + with ( + patch("tasks.jobs.orphan_recovery.is_worker_alive", return_value=False), + patch("tasks.jobs.orphan_recovery.revoke_task"), + patch("tasks.jobs.orphan_recovery._recovery_attempt_count", return_value=1), + patch("tasks.tasks.perform_scan_task") as mock_scan_task, + ): + result = _reconcile_task_results( + grace_minutes=2, max_attempts=3, window_hours=6,
dry_run=False + ) + + assert old_id in result["recovered"] + scan.refresh_from_db() + assert str(scan.task_id) != old_id # relinked to a fresh task + assert scan.recovery_count == 1 + assert TaskResult.objects.get(task_id=old_id).status == states.REVOKED + # Recovered by re-running scan-perform on the existing scan row (so the + # scheduled guard cannot no-op it), regardless of the original task name. + mock_scan_task.apply_async.assert_called_once() + assert mock_scan_task.apply_async.call_args.kwargs["kwargs"]["scan_id"] == str( + scan.id + ) + + def test_reenqueue_skips_when_scan_already_repointed( + self, tenants_fixture, providers_fixture + ): + # The scan already points at a newer task, so a stale orphan must not relink + # it or launch a second concurrent run against the same scan row. + tenant, provider = tenants_fixture[0], providers_fixture[0] + newer_id = str(uuid4()) + tr = TaskResult.objects.create( + task_id=newer_id, status=states.STARTED, task_name="scan-perform" + ) + APITask.objects.create(id=newer_id, tenant_id=tenant.id, task_runner_task=tr) + scan = Scan.objects.create( + name="scan-orphan", + provider=provider, + trigger=Scan.TriggerChoices.SCHEDULED, + state=StateChoices.EXECUTING, + tenant_id=tenant.id, + task_id=newer_id, + recovery_count=0, + ) + + with patch("tasks.tasks.perform_scan_task") as mock_scan_task: + recovered = _reenqueue_scan(str(uuid4()), scan) + + assert recovered is False + mock_scan_task.apply_async.assert_not_called() + scan.refresh_from_db() + assert scan.recovery_count == 0 + + +@pytest.mark.django_db +class TestOrphanRecoveryHelpers: + def test_advisory_lock_acquires_and_releases(self): + with advisory_lock() as acquired: + assert acquired is True + + def test_is_worker_alive_true_when_responds(self): + inspect = MagicMock() + inspect.ping.return_value = {"w@h": {"ok": "pong"}} + with patch( + "tasks.jobs.orphan_recovery.current_app.control.inspect", + return_value=inspect, + ): + assert is_worker_alive("w@h") is True
+ + def test_is_worker_alive_false_when_silent(self): + inspect = MagicMock() + inspect.ping.return_value = None + with patch( + "tasks.jobs.orphan_recovery.current_app.control.inspect", + return_value=inspect, + ): + assert is_worker_alive("w@h") is False + + def test_recovery_attempt_count_increments(self): + # redis.from_url is patched, so no real Valkey counter is touched: incr() is scripted to return 1, then 2. + kwargs_repr = repr({"probe": str(uuid4())}) + redis_client = MagicMock() + redis_client.incr.side_effect = [1, 2] + with patch("redis.from_url", return_value=redis_client): + assert _recovery_attempt_count("probe-task", kwargs_repr, 6) == 1 + assert _recovery_attempt_count("probe-task", kwargs_repr, 6) == 2 diff --git a/api/src/backend/tasks/tests/test_scan.py b/api/src/backend/tasks/tests/test_scan.py index 0a7193cd4d..2ff7e44e4d 100644 --- a/api/src/backend/tasks/tests/test_scan.py +++ b/api/src/backend/tasks/tests/test_scan.py @@ -32,12 +32,15 @@ from tasks.utils import CustomEncoder from api.db_router import MainRouter from api.exceptions import ProviderConnectionError from api.models import ( + AttackSurfaceOverview, Finding, MuteRule, Provider, Resource, ResourceScanSummary, Scan, + ScanCategorySummary, + ScanGroupSummary, ScanSummary, StateChoices, StatusChoices, @@ -229,6 +232,131 @@ class TestPerformScan: # Assert that failed_findings_count is 0 (finding is PASS and muted) assert scan_resource.failed_findings_count == 0 + def test_perform_prowler_scan_idempotent_on_rerun( + self, + tenants_fixture, + scans_fixture, + providers_fixture, + ): + """Re-running a scan for the same scan_id must not duplicate findings or keep stale summary rows.""" + with ( + patch("api.db_utils.rls_transaction"), + patch( + "tasks.jobs.scan.initialize_prowler_provider" + ) as mock_initialize_prowler_provider, + patch("tasks.jobs.scan.ProwlerScan") as mock_prowler_scan_class, + patch( + "tasks.jobs.scan.PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE", + new_callable=dict, + ), + patch("api.compliance.PROWLER_CHECKS", new_callable=dict) as
mock_checks, + ): + mock_checks["aws"] = {"check1": {"compliance1"}} + + tenant = tenants_fixture[0] + scan = scans_fixture[0] + provider = providers_fixture[0] + provider.provider = Provider.ProviderChoices.AWS + provider.save() + + tenant_id = str(tenant.id) + scan_id = str(scan.id) + provider_id = str(provider.id) + + stale_resource = Resource.objects.create( + tenant_id=tenant.id, + provider=provider, + uid="stale_resource_uid", + name="stale", + region="stale-region", + service="stale-service", + type="stale-type", + ) + ResourceScanSummary.objects.create( + tenant_id=tenant.id, + scan_id=scan.id, + resource_id=stale_resource.id, + service="stale-service", + region="stale-region", + resource_type="stale-type", + ) + ScanCategorySummary.objects.create( + tenant_id=tenant.id, + scan=scan, + category="stale-category", + severity=Severity.medium, + total_findings=1, + ) + ScanGroupSummary.objects.create( + tenant_id=tenant.id, + scan=scan, + resource_group="stale-group", + severity=Severity.medium, + total_findings=1, + ) + ScanSummary.objects.create( + tenant_id=tenant.id, + scan=scan, + check_id="stale_check", + service="stale-service", + severity=Severity.medium, + region="stale-region", + total=1, + ) + AttackSurfaceOverview.objects.create( + tenant_id=tenant.id, + scan=scan, + attack_surface_type=AttackSurfaceOverview.AttackSurfaceTypeChoices.SECRETS, + total_findings=1, + ) + + finding = MagicMock() + finding.uid = "dup_probe_finding" + finding.status = StatusChoices.PASS + finding.status_extended = "x" + finding.severity = Severity.medium + finding.check_id = "check1" + finding.get_metadata.return_value = {"key": "value"} + finding.resource_uid = "resource_uid" + finding.resource_name = "resource_name" + finding.region = "region" + finding.service_name = "service_name" + finding.resource_type = "resource_type" + finding.resource_tags = {} + finding.muted = False + finding.raw = {} + finding.resource_metadata = {} + finding.resource_details = {} +
finding.partition = "partition" + finding.compliance = {} + + mock_scan_instance = MagicMock() + mock_scan_instance.scan.return_value = [(100, [finding])] + mock_prowler_scan_class.return_value = mock_scan_instance + + mock_provider_instance = MagicMock() + mock_provider_instance.get_regions.return_value = ["region"] + mock_initialize_prowler_provider.return_value = mock_provider_instance + + # Run the same scan twice (simulating an orphan-recovery re-run). + perform_prowler_scan(tenant_id, scan_id, provider_id, ["check1"]) + perform_prowler_scan(tenant_id, scan_id, provider_id, ["check1"]) + + # Neither findings nor resources are duplicated by the re-run: findings are + # scope-deleted before re-insert; resources are upserted by (tenant, provider, uid). + assert Finding.objects.filter(scan=scan).count() == 1 + assert Resource.objects.filter(provider=provider).count() == 2 + assert ResourceScanSummary.objects.filter(scan_id=scan.id).count() == 1 + assert not ResourceScanSummary.objects.filter( + scan_id=scan.id, resource_id=stale_resource.id + ).exists() + assert not ScanCategorySummary.objects.filter(scan=scan).exists() + assert not ScanGroupSummary.objects.filter(scan=scan).exists() + assert not ScanSummary.objects.filter( + scan=scan, check_id="stale_check" + ).exists() + assert not AttackSurfaceOverview.objects.filter(scan=scan).exists() + @patch("tasks.jobs.scan.ProwlerScan") @patch( "tasks.jobs.scan.initialize_prowler_provider", @@ -1880,6 +2008,62 @@ class TestCreateComplianceRequirements: assert "requirements_created" in result + @pytest.mark.django_db(transaction=True) + def test_create_compliance_requirements_idempotent_on_rerun( + self, + tenants_fixture, + scans_fixture, + providers_fixture, + findings_fixture, + ): + """Re-running compliance materialization must not raise nor duplicate rows.
+ + Uses transaction=True because the COPY path commits on its own connection, + so the test must use real commits (mirroring production) rather than the + default rollback wrapper. + """ + from api.models import ComplianceRequirementOverview + + with patch( + "tasks.jobs.scan.PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE" + ) as mock_compliance_template: + tenant_id = str(tenants_fixture[0].id) + scan_id = str(scans_fixture[0].id) + + mock_compliance_template.__getitem__.return_value = { + "test_compliance": { + "framework": "Test Framework", + "version": "1.0", + "requirements": { + "req_1": { + "description": "Test Requirement 1", + "checks": {"test_check_id": None}, + "checks_status": { + "pass": 2, + "fail": 1, + "manual": 0, + "total": 3, + }, + "status": "FAIL", + }, + }, + } + } + + create_compliance_requirements(tenant_id, scan_id) + count_after_first = ComplianceRequirementOverview.objects.filter( + scan_id=scan_id + ).count() + + # Second run must not raise and must not duplicate rows. + create_compliance_requirements(tenant_id, scan_id) + count_after_second = ComplianceRequirementOverview.objects.filter( + scan_id=scan_id + ).count() + + assert count_after_first > 0 + assert count_after_second == count_after_first + def test_create_compliance_requirements_kubernetes_provider( self, tenants_fixture, diff --git a/api/src/backend/tasks/tests/test_tasks.py b/api/src/backend/tasks/tests/test_tasks.py index 7a52c27323..f62f5684cc 100644 --- a/api/src/backend/tasks/tests/test_tasks.py +++ b/api/src/backend/tasks/tests/test_tasks.py @@ -2706,3 +2706,36 @@ class TestReaggregateAllFindingGroupSummaries: assert result == {"scans_reaggregated": 0} mock_group.assert_not_called() mock_chain.assert_not_called() + + +class TestTaskTimeLimits: + """The per-task limits in task_annotations must actually take effect. + + Celery applies a "*" annotation after the per-task one, so a "*" entry would + silently overwrite every specific limit and cap long scans at the default.
The + default is set as the global limit instead, and these per-task limits must win. + """ + + def test_long_running_tasks_exceed_the_default_limit(self): + from config.celery import celery_app + + default = celery_app.conf.task_time_limit + for name in ( + "scan-perform", + "scan-perform-scheduled", + "provider-deletion", + "tenant-deletion", + ): + assert celery_app.tasks[name].time_limit > default + + def test_connection_checks_stay_below_the_default_limit(self): + from config.celery import celery_app + + default = celery_app.conf.task_time_limit + for name in ( + "provider-connection-check", + "integration-connection-check", + "lighthouse-connection-check", + "lighthouse-provider-connection-check", + ): + assert celery_app.tasks[name].time_limit < default diff --git a/docker-compose-dev.yml b/docker-compose-dev.yml index 2020c7d21a..522aaf5413 100644 --- a/docker-compose-dev.yml +++ b/docker-compose-dev.yml @@ -139,6 +139,8 @@ services: worker-dev: image: prowler-api-dev + # Give Celery soft shutdown time to drain/re-queue in-flight tasks on stop. + stop_grace_period: 120s build: context: ./api dockerfile: Dockerfile diff --git a/docker-compose.yml b/docker-compose.yml index a9d2e03c3d..1519946afd 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -129,6 +129,8 @@ services: worker: image: prowlercloud/prowler-api:${PROWLER_API_VERSION:-stable} + # Give Celery soft shutdown time to drain/re-queue in-flight tasks on stop. + stop_grace_period: 120s env_file: - path: .env required: false