From 1f7caa63945a44f6f03340136c61e9b09bce4ff5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Pe=C3=B1a?= Date: Tue, 9 Jun 2026 09:16:48 +0200 Subject: [PATCH] feat(api): make orphan-task recovery configurable and drop the Jira idempotency table (#11472) --- api/CHANGELOG.md | 4 +- api/docs/orphan-task-recovery.md | 55 ++-- .../commands/reconcile_orphan_tasks.py | 12 +- .../migrations/0094_scan_recovery_count.py | 17 -- ...95_reconcile_orphan_tasks_periodic_task.py | 2 +- .../api/migrations/0096_jiraissuedispatch.py | 64 ----- api/src/backend/api/models.py | 32 --- api/src/backend/config/django/base.py | 12 + api/src/backend/tasks/jobs/deletion.py | 9 - api/src/backend/tasks/jobs/integrations.py | 151 ++++------- api/src/backend/tasks/jobs/orphan_recovery.py | 206 ++++++--------- api/src/backend/tasks/jobs/scan.py | 21 +- api/src/backend/tasks/tasks.py | 16 +- api/src/backend/tasks/tests/test_deletion.py | 40 +-- .../backend/tasks/tests/test_integrations.py | 179 +------------ .../tasks/tests/test_orphan_recovery.py | 242 ++++++++++-------- api/src/backend/tasks/tests/test_scan.py | 128 --------- 17 files changed, 349 insertions(+), 841 deletions(-) delete mode 100644 api/src/backend/api/migrations/0094_scan_recovery_count.py delete mode 100644 api/src/backend/api/migrations/0096_jiraissuedispatch.py diff --git a/api/CHANGELOG.md b/api/CHANGELOG.md index af4b9e69bf..15466fd5fe 100644 --- a/api/CHANGELOG.md +++ b/api/CHANGELOG.md @@ -6,15 +6,13 @@ All notable changes to the **Prowler API** are documented in this file. ### 🚀 Added -- Automatic recovery of allowlisted idempotent background tasks whose worker died during a deploy or crash: stuck scan and summary tasks are detected and re-run instead of staying pending forever, with a `reconcile_orphan_tasks` management command for on-demand recovery [(#11416)](https://github.com/prowler-cloud/prowler/pull/11416) -- Jira integration no longer creates duplicate issues on a retried send; findings already ticketed are skipped [(#11416)](https://github.com/prowler-cloud/prowler/pull/11416) +- Opt-in automatic recovery of allowlisted idempotent background tasks whose worker died during a deploy or crash: when enabled via `DJANGO_TASK_RECOVERY_ENABLED` (off by default), stuck summary and deletion tasks are detected and re-run instead of staying pending forever (scan and Jira tasks are excluded), with a `reconcile_orphan_tasks` management command for on-demand recovery [(#11416)](https://github.com/prowler-cloud/prowler/pull/11416) - DORA compliance framework support [(#11131)](https://github.com/prowler-cloud/prowler/pull/11131) - Label Postgres connections with `application_name=":"` (component injected per process via `DJANGO_APP_COMPONENT`) so connections are attributable by component in `pg_stat_activity` [(#11494)](https://github.com/prowler-cloud/prowler/pull/11494) ### 🔄 Changed - Allowlisted idempotent background tasks are no longer lost when a worker is stopped or crashes mid-task; tasks with external side effects are marked terminal instead of blindly re-running [(#11416)](https://github.com/prowler-cloud/prowler/pull/11416) -- A recovered scan rewrites its findings, summaries, attack surface, and compliance data instead of appending to the previous run, so recovery never leaves stale or duplicate materialized rows [(#11416)](https://github.com/prowler-cloud/prowler/pull/11416) ### 🐞 Fixed diff --git a/api/docs/orphan-task-recovery.md b/api/docs/orphan-task-recovery.md index 38b1546bae..a47b4f36a9 100644 --- a/api/docs/orphan-task-recovery.md +++ b/api/docs/orphan-task-recovery.md @@ -1,10 +1,11 @@ # Orphan Celery task recovery When a worker is terminated mid-task (a deploy, an OOM kill, a node eviction), the -task it was running can be left non-terminal forever: the `Scan` stays `EXECUTING`, -the `TaskResult` stays `STARTED`, and nothing re-runs it. This page describes the -mechanisms that detect and recover allowlisted idempotent orphans so users never -see a stuck scan and pending-task alerts do not fire. +task it was running can be left non-terminal forever: the `TaskResult` stays +`STARTED` and nothing re-runs it. This page describes the mechanisms that detect and +recover allowlisted idempotent orphans so pending-task alerts do not fire. Scan tasks +are not auto-recovered (re-running a scan is not safe to do automatically); the +watchdog covers the summary/aggregation and deletion tasks. ## How recovery works @@ -13,29 +14,35 @@ see a stuck scan and pending-task alerts do not fire. (`worker_prefetch_multiplier = 1`), and an abruptly-lost worker re-queues its task (`task_reject_on_worker_lost`). On `SIGTERM` the worker is given a soft-shutdown window (`worker_soft_shutdown_timeout`) to finish or re-queue in-flight work - before it is force-killed. + before it is force-killed. `scan-perform`, `scan-perform-scheduled` and + `integration-jira` opt out of redelivery with `acks_late=False`, so a crash drops + them rather than re-running and duplicating findings or Jira issues. Other + non-recovered side-effect tasks keep `acks_late=True`, so the broker can still + re-deliver them after a worker loss: the S3 upload rebuilds from worker-local files + that did not survive the crash and so no-ops, but Security Hub re-reads findings from + the DB and re-sends them to AWS. 2. **Periodic watchdog.** A Beat task, `reconcile-orphan-tasks`, runs every couple of minutes (a `django_celery_beat` periodic task created by migration). For each in-flight task result with an allowlisted idempotent task name, it pings the worker recorded on the task's `TaskResult`: - worker responds -> the task is still running, leave it alone; - - worker is gone (and the scan started before a short grace window) -> it is a + - worker is gone (and the task started before a short grace window) -> it is a real orphan: the stale task is revoked and marked terminal (clearing the - pending/started alert), and the scan is re-enqueued from scratch. + pending/started alert), and the task is re-enqueued from its stored name and + kwargs. - The re-run is safe because only tasks with proven idempotency are allowlisted. - Scan persistence, for example, clears the scan's prior findings and materialized - summary/compliance rows before re-writing them. Jira sends are allowlisted too: - each finding is reserved in a dispatch table before the external call, so a re-run - skips already-ticketed findings (the worst case is one finding missed if a worker - is hard-killed mid-send, never a duplicate issue). Other external side effects stay - terminal: the S3 upload rebuilds from worker-local files that do not survive a - crash, and report/Security Hub recovery is out of scope. + The re-run is safe because only tasks with proven idempotency are allowlisted: the + summary/aggregation tasks clear and re-write their own rows, and deletions are + idempotent. Scan tasks and external side effects are excluded: re-running a scan is + not safe to do automatically, Jira sends would create duplicate issues, the S3 + upload rebuilds from worker-local files that do not survive a crash, and + report/Security Hub recovery is out of scope. -3. **Recovery cap.** Each automatic re-enqueue increments `Scan.recovery_count`. - After `--max-attempts` recoveries (default 3) the scan is marked `FAILED` instead - of re-enqueued, so a task that repeatedly kills its worker cannot loop forever. +3. **Recovery cap.** A per-task Valkey counter limits how often the same task is + re-enqueued. After `--max-attempts` recoveries (default 3) the orphan is marked + terminal instead of re-enqueued, so a task that repeatedly kills its worker cannot + loop forever. A Postgres advisory lock ensures that, even with multiple API/worker replicas, only one reconciliation runs at a time; the others no-op. @@ -63,6 +70,18 @@ All settings have safe defaults; override via environment variables. | `DJANGO_CELERY_TASK_SOFT_TIME_LIMIT` | hard - 600 | Soft limit; raises `SoftTimeLimitExceeded` for cleanup. | | `DJANGO_CELERY_LONG_TASK_TIME_LIMIT` | `172800` (48h) | Hard limit for scans and provider/tenant deletions, which can legitimately run for more than a day. | | `DJANGO_CELERY_LONG_TASK_SOFT_TIME_LIMIT` | long hard - 600 | Soft limit for the long-running tasks above. | +| `DJANGO_TASK_RECOVERY_ENABLED` | `false` | Master switch for orphan-task recovery, disabled by default (opt-in); set to `true` to enable. When off, no orphan is detected, marked terminal, or re-enqueued (attack-paths stale cleanup still runs). | +| `DJANGO_TASK_RECOVERY_SUMMARIES_ENABLED` | `true` | Auto re-enqueue orphaned scan summary/aggregation tasks. | +| `DJANGO_TASK_RECOVERY_DELETIONS_ENABLED` | `true` | Auto re-enqueue orphaned provider/tenant deletion tasks. | + +Recovery is opt-in: with the master flag off (the default) the sweep does nothing. +Once enabled, the per-group flags default to on, so every group recovers unless you +turn one off; a task whose group flag is off is marked terminal instead of +re-enqueued. + +Turning recovery off only disables this watchdog sweep; it does not change Celery's +broker-level redelivery (`task_acks_late`/`task_reject_on_worker_lost`), which still +re-delivers tasks that keep `acks_late=True` on worker loss, independently of this flag. `task_acks_late` and `task_reject_on_worker_lost` are enabled in `config/celery.py`. diff --git a/api/src/backend/api/management/commands/reconcile_orphan_tasks.py b/api/src/backend/api/management/commands/reconcile_orphan_tasks.py index cdfe6b3fda..8ba8f5b342 100644 --- a/api/src/backend/api/management/commands/reconcile_orphan_tasks.py +++ b/api/src/backend/api/management/commands/reconcile_orphan_tasks.py @@ -20,7 +20,7 @@ class Command(BaseCommand): "--max-attempts", type=int, default=3, - help="Give up re-running a task after this many recovery attempts (scans are marked FAILED).", + help="Give up re-running a task after this many recovery attempts; it is then left terminal instead of re-enqueued.", ) parser.add_argument( "--dry-run", @@ -39,6 +39,16 @@ class Command(BaseCommand): self.stdout.write("Reconcile skipped: another run holds the lock.") return + if result.get("enabled") is False: + message = ( + "Task recovery is disabled (DJANGO_TASK_RECOVERY_ENABLED is off); " + "no orphans were recovered." + ) + if result.get("attack_paths") is not None: + message += " Attack-paths stale cleanup still ran." + self.stdout.write(message) + return + self.stdout.write( self.style.SUCCESS( "Orphan reconcile complete: " diff --git a/api/src/backend/api/migrations/0094_scan_recovery_count.py b/api/src/backend/api/migrations/0094_scan_recovery_count.py deleted file mode 100644 index 01f7af42df..0000000000 --- a/api/src/backend/api/migrations/0094_scan_recovery_count.py +++ /dev/null @@ -1,17 +0,0 @@ -# Generated by Django 5.1.15 on 2026-05-30 17:38 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("api", "0093_okta_provider"), - ] - - operations = [ - migrations.AddField( - model_name="scan", - name="recovery_count", - field=models.IntegerField(default=0), - ), - ] diff --git a/api/src/backend/api/migrations/0095_reconcile_orphan_tasks_periodic_task.py b/api/src/backend/api/migrations/0095_reconcile_orphan_tasks_periodic_task.py index ab511a11b1..9d67404258 100644 --- a/api/src/backend/api/migrations/0095_reconcile_orphan_tasks_periodic_task.py +++ b/api/src/backend/api/migrations/0095_reconcile_orphan_tasks_periodic_task.py @@ -40,7 +40,7 @@ def delete_periodic_task(apps, schema_editor): class Migration(migrations.Migration): dependencies = [ - ("api", "0094_scan_recovery_count"), + ("api", "0093_okta_provider"), ("django_celery_beat", "0019_alter_periodictasks_options"), ] diff --git a/api/src/backend/api/migrations/0096_jiraissuedispatch.py b/api/src/backend/api/migrations/0096_jiraissuedispatch.py deleted file mode 100644 index f5a1a9d9a0..0000000000 --- a/api/src/backend/api/migrations/0096_jiraissuedispatch.py +++ /dev/null @@ -1,64 +0,0 @@ -import uuid - -import django.db.models.deletion -from django.db import migrations, models - -import api.rls - - -class Migration(migrations.Migration): - dependencies = [ - ("api", "0095_reconcile_orphan_tasks_periodic_task"), - ] - - operations = [ - migrations.CreateModel( - name="JiraIssueDispatch", - fields=[ - ( - "id", - models.UUIDField( - default=uuid.uuid4, - editable=False, - primary_key=True, - serialize=False, - ), - ), - ("inserted_at", models.DateTimeField(auto_now_add=True)), - ("finding_id", models.UUIDField()), - ( - "integration", - models.ForeignKey( - on_delete=django.db.models.deletion.CASCADE, - related_name="jira_dispatches", - to="api.integration", - ), - ), - ( - "tenant", - models.ForeignKey( - on_delete=django.db.models.deletion.CASCADE, to="api.tenant" - ), - ), - ], - options={ - "db_table": "jira_issue_dispatches", - "abstract": False, - }, - ), - migrations.AddConstraint( - model_name="jiraissuedispatch", - constraint=models.UniqueConstraint( - fields=("tenant_id", "integration_id", "finding_id"), - name="unique_jira_issue_dispatch", - ), - ), - migrations.AddConstraint( - model_name="jiraissuedispatch", - constraint=api.rls.RowLevelSecurityConstraint( - "tenant_id", - name="rls_on_jiraissuedispatch", - statements=["SELECT", "INSERT", "UPDATE", "DELETE"], - ), - ), - ] diff --git a/api/src/backend/api/models.py b/api/src/backend/api/models.py index 8adbc2cf9e..3d9a26698e 100644 --- a/api/src/backend/api/models.py +++ b/api/src/backend/api/models.py @@ -666,9 +666,6 @@ class Scan(RowLevelSecurityProtectedModel): state = StateEnumField(choices=StateChoices.choices, default=StateChoices.AVAILABLE) unique_resource_count = models.IntegerField(default=0) progress = models.IntegerField(default=0) - # Incremented by the scan-specific orphan-recovery path each time this scan is - # re-pointed to a fresh task; for observability (the retry cap is a Valkey counter). - recovery_count = models.IntegerField(default=0) scanner_args = models.JSONField(default=dict) duration = models.IntegerField(null=True, blank=True) scheduled_at = models.DateTimeField(null=True, blank=True) @@ -2001,35 +1998,6 @@ class IntegrationProviderRelationship(RowLevelSecurityProtectedModel): ] -class JiraIssueDispatch(RowLevelSecurityProtectedModel): - """Tracks findings already sent to a Jira integration. - - Lets the Jira task be re-run safely (e.g. by orphan recovery): findings with - an existing dispatch row are skipped, so no duplicate issues are created. - """ - - id = models.UUIDField(primary_key=True, default=uuid4, editable=False) - inserted_at = models.DateTimeField(auto_now_add=True, editable=False) - integration = models.ForeignKey( - Integration, on_delete=models.CASCADE, related_name="jira_dispatches" - ) - finding_id = models.UUIDField() - - class Meta(RowLevelSecurityProtectedModel.Meta): - db_table = "jira_issue_dispatches" - constraints = [ - models.UniqueConstraint( - fields=["tenant_id", "integration_id", "finding_id"], - name="unique_jira_issue_dispatch", - ), - RowLevelSecurityConstraint( - field="tenant_id", - name="rls_on_%(class)s", - statements=["SELECT", "INSERT", "UPDATE", "DELETE"], - ), - ] - - class SAMLToken(models.Model): id = models.UUIDField(primary_key=True, default=uuid4, editable=False) inserted_at = models.DateTimeField(auto_now_add=True, editable=False) diff --git a/api/src/backend/config/django/base.py b/api/src/backend/config/django/base.py index 402b71eb51..38cf047ac2 100644 --- a/api/src/backend/config/django/base.py +++ b/api/src/backend/config/django/base.py @@ -307,6 +307,18 @@ ATTACK_PATHS_SCAN_STALE_THRESHOLD_MINUTES = env.int( "ATTACK_PATHS_SCAN_STALE_THRESHOLD_MINUTES", 2880 ) # 48h +# Orphan task recovery feature flags. The master switch is OFF by default, so task +# recovery is opt-in; enable it with DJANGO_TASK_RECOVERY_ENABLED=true. The per-group +# toggles default to enabled, so once the master is on every group recovers unless a +# group is explicitly turned off. +TASK_RECOVERY_ENABLED = env.bool("DJANGO_TASK_RECOVERY_ENABLED", False) +TASK_RECOVERY_SUMMARIES_ENABLED = env.bool( + "DJANGO_TASK_RECOVERY_SUMMARIES_ENABLED", True +) +TASK_RECOVERY_DELETIONS_ENABLED = env.bool( + "DJANGO_TASK_RECOVERY_DELETIONS_ENABLED", True +) + def label_postgres_connections(databases): """Tag each Postgres connection with ``application_name=":"`` diff --git a/api/src/backend/tasks/jobs/deletion.py b/api/src/backend/tasks/jobs/deletion.py index 7540f72c7e..f9ead01897 100644 --- a/api/src/backend/tasks/jobs/deletion.py +++ b/api/src/backend/tasks/jobs/deletion.py @@ -11,7 +11,6 @@ from api.db_utils import batch_delete, rls_transaction from api.models import ( AttackPathsScan, Finding, - JiraIssueDispatch, Provider, ProviderComplianceScore, Resource, @@ -81,14 +80,6 @@ def delete_provider(tenant_id: str, pk: str): deletion_steps = [ ("Scan Summaries", ScanSummary.all_objects.filter(scan__provider=instance)), - ( - "Jira Issue Dispatches", - JiraIssueDispatch.objects.filter( - finding_id__in=Finding.all_objects.filter( - scan__provider=instance - ).values_list("id", flat=True) - ), - ), ("Findings", Finding.all_objects.filter(scan__provider=instance)), ("Resources", Resource.all_objects.filter(provider=instance)), ("Scans", Scan.all_objects.filter(provider=instance)), diff --git a/api/src/backend/tasks/jobs/integrations.py b/api/src/backend/tasks/jobs/integrations.py index 55c1205169..5ca94057da 100644 --- a/api/src/backend/tasks/jobs/integrations.py +++ b/api/src/backend/tasks/jobs/integrations.py @@ -9,7 +9,7 @@ from tasks.utils import batched from api.db_router import READ_REPLICA_ALIAS, MainRouter from api.db_utils import REPLICA_MAX_ATTEMPTS, REPLICA_RETRY_BASE_DELAY, rls_transaction -from api.models import Finding, Integration, JiraIssueDispatch, Provider +from api.models import Finding, Integration, Provider from api.utils import initialize_prowler_integration, initialize_prowler_provider from prowler.lib.outputs.asff.asff import ASFF from prowler.lib.outputs.compliance.generic.generic import GenericCompliance @@ -482,115 +482,66 @@ def send_findings_to_jira( with rls_transaction(tenant_id): integration = Integration.objects.get(id=integration_id) jira_integration = initialize_prowler_integration(integration) - # Idempotency: findings already ticketed for this integration must not be - # sent again on a re-run (e.g. orphan recovery), to avoid duplicate issues - already_sent = { - str(fid) - for fid in JiraIssueDispatch.objects.filter( - integration_id=integration_id, finding_id__in=finding_ids - ).values_list("finding_id", flat=True) - } num_tickets_created = 0 - skipped_count = 0 for finding_id in finding_ids: - if str(finding_id) in already_sent: - skipped_count += 1 - continue - - # Reserve the finding BEFORE the external call. The unique constraint on - # (tenant, integration, finding) makes the dispatch row the single source of - # truth, so a concurrent run or a retry that raced past the bulk pre-check - # cannot create a duplicate issue: created=False means another run already - # claimed it. The reservation is released below if the send does not succeed. with rls_transaction(tenant_id): - _, created = JiraIssueDispatch.objects.get_or_create( - tenant_id=tenant_id, - integration_id=integration_id, - finding_id=finding_id, + finding_instance = ( + Finding.all_objects.select_related("scan__provider") + .prefetch_related("resources") + .get(id=finding_id) ) - if not created: - skipped_count += 1 - continue - sent = False - try: - with rls_transaction(tenant_id): - finding_instance = ( - Finding.all_objects.select_related("scan__provider") - .prefetch_related("resources") - .get(id=finding_id) - ) + # Extract resource information + resource = ( + finding_instance.resources.first() + if finding_instance.resources.exists() + else None + ) + resource_uid = resource.uid if resource else "" + resource_name = resource.name if resource else "" + resource_tags = {} + if resource and hasattr(resource, "tags"): + resource_tags = resource.get_tags(tenant_id) - # Extract resource information - resource = ( - finding_instance.resources.first() - if finding_instance.resources.exists() - else None - ) - resource_uid = resource.uid if resource else "" - resource_name = resource.name if resource else "" - resource_tags = {} - if resource and hasattr(resource, "tags"): - resource_tags = resource.get_tags(tenant_id) + # Get region + region = resource.region if resource and resource.region else "" - # Get region - region = resource.region if resource and resource.region else "" + # Extract remediation information from check_metadata + check_metadata = finding_instance.check_metadata + remediation = check_metadata.get("remediation", {}) + recommendation = remediation.get("recommendation", {}) + remediation_code = remediation.get("code", {}) - # Extract remediation information from check_metadata - check_metadata = finding_instance.check_metadata - remediation = check_metadata.get("remediation", {}) - recommendation = remediation.get("recommendation", {}) - remediation_code = remediation.get("code", {}) - - # Send the individual finding to Jira - sent = bool( - jira_integration.send_finding( - check_id=finding_instance.check_id, - check_title=check_metadata.get("checktitle", ""), - severity=finding_instance.severity, - status=finding_instance.status, - status_extended=finding_instance.status_extended or "", - provider=finding_instance.scan.provider.provider, - region=region, - resource_uid=resource_uid, - resource_name=resource_name, - risk=check_metadata.get("risk", ""), - recommendation_text=recommendation.get("text", ""), - recommendation_url=recommendation.get("url", ""), - remediation_code_native_iac=remediation_code.get( - "nativeiac", "" - ), - remediation_code_terraform=remediation_code.get( - "terraform", "" - ), - remediation_code_cli=remediation_code.get("cli", ""), - remediation_code_other=remediation_code.get("other", ""), - resource_tags=resource_tags, - compliance=finding_instance.compliance or {}, - project_key=project_key, - issue_type=issue_type, - ) - ) - finally: - if not sent: - # Release the reservation so a later run can retry this finding: it - # was not ticketed (send failed or raised), so the row must not block - # a future legitimate send. - with rls_transaction(tenant_id): - JiraIssueDispatch.objects.filter( - tenant_id=tenant_id, - integration_id=integration_id, - finding_id=finding_id, - ).delete() - - if sent: - num_tickets_created += 1 - else: - logger.error(f"Failed to send finding {finding_id} to Jira") + # Send the individual finding to Jira + result = jira_integration.send_finding( + check_id=finding_instance.check_id, + check_title=check_metadata.get("checktitle", ""), + severity=finding_instance.severity, + status=finding_instance.status, + status_extended=finding_instance.status_extended or "", + provider=finding_instance.scan.provider.provider, + region=region, + resource_uid=resource_uid, + resource_name=resource_name, + risk=check_metadata.get("risk", ""), + recommendation_text=recommendation.get("text", ""), + recommendation_url=recommendation.get("url", ""), + remediation_code_native_iac=remediation_code.get("nativeiac", ""), + remediation_code_terraform=remediation_code.get("terraform", ""), + remediation_code_cli=remediation_code.get("cli", ""), + remediation_code_other=remediation_code.get("other", ""), + resource_tags=resource_tags, + compliance=finding_instance.compliance or {}, + project_key=project_key, + issue_type=issue_type, + ) + if result: + num_tickets_created += 1 + else: + logger.error(f"Failed to send finding {finding_id} to Jira") return { "created_count": num_tickets_created, - "failed_count": len(finding_ids) - num_tickets_created - skipped_count, - "skipped_count": skipped_count, + "failed_count": len(finding_ids) - num_tickets_created, } diff --git a/api/src/backend/tasks/jobs/orphan_recovery.py b/api/src/backend/tasks/jobs/orphan_recovery.py index d884c3fc8b..1bf5c95df2 100644 --- a/api/src/backend/tasks/jobs/orphan_recovery.py +++ b/api/src/backend/tasks/jobs/orphan_recovery.py @@ -37,35 +37,52 @@ ORPHAN_RECOVERY_LOCK_KEY = 0x70726F77 # "prow" # Non-terminal states that mean "a worker had this and may have died with it". IN_FLIGHT_STATES = (states.STARTED, states.RECEIVED) -# Scan tasks are recovered by re-running scan-perform on the EXISTING scan row, -# not by re-enqueuing the original task: re-enqueuing scan-perform-scheduled would -# hit its "a scan is already executing" guard and no-op, leaving the scan stuck. -_SCAN_TASKS = ("scan-perform", "scan-perform-scheduled") - -# Tasks with proven idempotency are auto re-enqueued. Scans/summaries clear and -# rewrite their own rows. integration-jira is safe too: each finding is reserved in -# JiraIssueDispatch before the external call, so a re-run skips already-ticketed -# findings (worst case one finding missed on a mid-send crash, never a duplicate). -# Other external side effects stay terminal: integration-s3 rebuilds its upload from -# worker-local files that do not survive a crash, and report/Security Hub recovery is -# out of scope. -REENQUEUEABLE_TASKS = { - *_SCAN_TASKS, - "provider-deletion", - "tenant-deletion", - "scan-summary", - "scan-compliance-overviews", - "scan-provider-compliance-scores", - "scan-daily-severity", - "scan-finding-group-summaries", - "scan-reset-ephemeral-resources", - "integration-jira", +# Tasks with proven idempotency are eligible for auto re-enqueue, grouped so each +# group can be toggled independently by a feature flag (see config.django.base). +# Summaries clear and rewrite their own rows and deletions are idempotent. Tasks with +# external side effects are never eligible: integration-jira would create duplicate +# issues, integration-s3 rebuilds its upload from worker-local files that do not +# survive a crash, and report/Security Hub recovery is out of scope. +RECOVERY_TASK_GROUPS = { + "summaries": { + "scan-summary", + "scan-compliance-overviews", + "scan-provider-compliance-scores", + "scan-daily-severity", + "scan-finding-group-summaries", + "scan-reset-ephemeral-resources", + }, + "deletions": {"provider-deletion", "tenant-deletion"}, } -# Tasks excluded from generic recovery: attack-paths scans are handled by their own -# stale-cleanup (which also drops the temp Neo4j db), and the maintenance tasks must -# not self-recover (they run again on their own schedule). + +def reenqueueable_tasks() -> set[str]: + """Task names eligible for auto re-enqueue, honoring the per-group feature flags. + + A group whose flag is disabled is dropped, so its orphaned tasks are marked + terminal instead of re-enqueued. + """ + from django.conf import settings + + group_enabled = { + "summaries": settings.TASK_RECOVERY_SUMMARIES_ENABLED, + "deletions": settings.TASK_RECOVERY_DELETIONS_ENABLED, + } + return { + task + for group, tasks in RECOVERY_TASK_GROUPS.items() + if group_enabled[group] + for task in tasks + } + + +# Tasks the watchdog ignores entirely (not even marked terminal): scan tasks are not +# auto-recovered, since re-running a scan is not safe to do automatically; attack-paths +# scans are handled by their own stale-cleanup (which also drops the temp Neo4j db); +# and the maintenance tasks must not self-recover (they run again on their own schedule). _SKIP_RECOVERY = { + "scan-perform", + "scan-perform-scheduled", "attack-paths-scan-perform", "attack-paths-cleanup-stale-scans", "reconcile-orphan-tasks", @@ -166,15 +183,22 @@ def reconcile_orphans( logger.info("Orphan reconcile skipped: another run holds the lock") return {"acquired": False} - # Populate the task registry so we can re-enqueue any task by name. - import tasks.tasks # noqa: F401 + from django.conf import settings - result = _reconcile_task_results( - grace_minutes=grace_minutes, - max_attempts=max_attempts, - window_hours=window_hours, - dry_run=dry_run, - ) + if settings.TASK_RECOVERY_ENABLED: + # Populate the task registry so we can re-enqueue any task by name. + import tasks.tasks # noqa: F401 + + result = _reconcile_task_results( + grace_minutes=grace_minutes, + max_attempts=max_attempts, + window_hours=window_hours, + dry_run=dry_run, + ) + result["enabled"] = True + else: + logger.info("Orphan task recovery disabled by feature flag") + result = {"recovered": [], "failed": [], "skipped": [], "enabled": False} if not dry_run: from tasks.jobs.attack_paths.cleanup import cleanup_stale_attack_paths_scans @@ -264,34 +288,27 @@ def _recover_task(task_result, max_attempts: int, window_hours: int) -> str: task_result.date_done = now task_result.save(update_fields=["status", "date_done"]) - attempt = _recovery_attempt_count(name, kwargs_repr, window_hours) - if name not in REENQUEUEABLE_TASKS or attempt > max_attempts: - reason = ( - f"{name} is not allowlisted for auto recovery" - if name not in REENQUEUEABLE_TASKS - else f"recovery cap reached ({attempt}/{max_attempts})" - ) - _fail_domain_row(task_result.task_id, name, now) + if name not in reenqueueable_tasks(): logger.warning( - "Orphan %s (%s) not re-enqueued: %s", task_result.task_id, name, reason + "Orphan %s (%s) not re-enqueued: not allowlisted for auto recovery", + task_result.task_id, + name, ) return "failed" - # Scan tasks: re-run the EXISTING scan row directly via scan-perform, so the - # scheduled-scan "already executing" guard cannot turn recovery into a no-op. - # Falls through to the generic path only if no scan is linked yet (e.g. a - # scheduled task that died before creating one), where re-running it creates one. - if name in _SCAN_TASKS: - scan = _scan_for_task(task_result.task_id) - if scan is not None: - if not _reenqueue_scan(task_result.task_id, scan): - return "failed" - logger.info( - "Re-enqueued orphaned scan %s (was task %s)", - scan.id, - task_result.task_id, - ) - return "recovered" + # Count the attempt only once the task is allowlisted, so a task sitting in a + # disabled group does not burn its recovery budget while the flag is off (and is + # not already over the cap the moment the group is re-enabled). + attempt = _recovery_attempt_count(name, kwargs_repr, window_hours) + if attempt > max_attempts: + logger.warning( + "Orphan %s (%s) not re-enqueued: recovery cap reached (%d/%d)", + task_result.task_id, + name, + attempt, + max_attempts, + ) + return "failed" task_obj = current_app.tasks.get(name) if task_obj is None: @@ -311,7 +328,6 @@ def _recover_task(task_result, max_attempts: int, window_hours: int) -> str: task_result.task_id, name, ) - _fail_domain_row(task_result.task_id, name, now) return "failed" new_task_id = str(uuid4()) task_obj.apply_async( @@ -323,75 +339,3 @@ def _recover_task(task_result, max_attempts: int, window_hours: int) -> str: "Re-enqueued orphan %s (%s) as %s", task_result.task_id, name, new_task_id ) return "recovered" - - -def _scan_for_task(task_id: str): - """Return the Scan linked to a Celery task id, or None (read across tenants).""" - from api.db_router import MainRouter - from api.models import Scan - - return Scan.all_objects.using(MainRouter.admin_db).filter(task_id=task_id).first() - - -def _reenqueue_scan(old_task_id: str, scan) -> bool: - """Re-run an orphaned scan via scan-perform on the existing row. - - Pre-provisions the new task linkage (TaskResult + api.Task) and relinks the - Scan before enqueuing, so the FK is valid and a worker can never outrun the DB. - The relink is conditional on the scan still pointing at the old task, so a stale - orphan can never clobber a newer linkage. - """ - from django_celery_results.models import TaskResult - - from api.db_utils import rls_transaction - from api.models import Scan - from api.models import Task as APITask - from tasks.tasks import perform_scan_task - - tenant_id = str(scan.tenant_id) - new_task_id = str(uuid4()) - with rls_transaction(tenant_id): - locked_scan = Scan.all_objects.select_for_update().filter(id=scan.id).first() - if locked_scan is None or str(locked_scan.task_id) != old_task_id: - logger.info( - "Scan %s no longer points at task %s; skipping recovery re-enqueue", - scan.id, - old_task_id, - ) - return False - task_result_new, _ = TaskResult.objects.get_or_create( - task_id=new_task_id, - defaults={"status": states.PENDING, "task_name": "scan-perform"}, - ) - APITask.objects.update_or_create( - id=new_task_id, - tenant_id=tenant_id, - defaults={"task_runner_task": task_result_new}, - ) - locked_scan.task_id = new_task_id - locked_scan.recovery_count = (locked_scan.recovery_count or 0) + 1 - locked_scan.save(update_fields=["task_id", "recovery_count", "updated_at"]) - - perform_scan_task.apply_async( - kwargs={ - "tenant_id": tenant_id, - "scan_id": str(scan.id), - "provider_id": str(scan.provider_id), - }, - task_id=new_task_id, - ) - return True - - -def _fail_domain_row(old_task_id: str, name: str, now: datetime) -> None: - """Mark a scan terminal when its task is capped/denylisted instead of re-run.""" - from api.db_utils import rls_transaction - from api.models import Scan, StateChoices - - if name in _SCAN_TASKS: - scan = _scan_for_task(old_task_id) - if scan is not None: - with rls_transaction(str(scan.tenant_id)): - Scan.all_objects.filter(id=scan.id, task_id=old_task_id).update( - state=StateChoices.FAILED, completed_at=now - ) diff --git a/api/src/backend/tasks/jobs/scan.py b/api/src/backend/tasks/jobs/scan.py index 298a2b225a..db5018db90 100644 --- a/api/src/backend/tasks/jobs/scan.py +++ b/api/src/backend/tasks/jobs/scan.py @@ -118,19 +118,6 @@ ATTACK_SURFACE_PROVIDER_COMPATIBILITY = { _ATTACK_SURFACE_MAPPING_CACHE: dict[str, dict] = {} -def _clear_scan_rerun_state(tenant_id: str, scan_id: str) -> None: - """Remove rows derived from a previous execution of this scan.""" - with rls_transaction(tenant_id): - Finding.all_objects.filter(scan_id=scan_id).delete() - ResourceScanSummary.objects.filter(scan_id=scan_id).delete() - ScanCategorySummary.objects.filter(scan_id=scan_id).delete() - ScanGroupSummary.objects.filter(scan_id=scan_id).delete() - ScanSummary.objects.filter(scan_id=scan_id).delete() - AttackSurfaceOverview.objects.filter(scan_id=scan_id).delete() - ComplianceRequirementOverview.objects.filter(scan_id=scan_id).delete() - ComplianceOverviewSummary.objects.filter(scan_id=scan_id).delete() - - def aggregate_category_counts( categories: list[str], severity: str, @@ -489,10 +476,9 @@ def _create_compliance_summaries( ) ) - # Idempotent re-run: clear this scan's prior summaries before re-inserting, so - # a recovered scan's summary always reflects its own (re-derived) requirement - # rows rather than keeping a stale row (bulk_create ignore_conflicts alone would - # keep the old one). + # Idempotent re-run: clear this scan's prior summaries before re-inserting, so a + # recovered scan-compliance-overviews run reflects its own re-derived rows instead + # of keeping a stale one (bulk_create ignore_conflicts alone would keep the old). with rls_transaction(tenant_id): ComplianceOverviewSummary.objects.filter(scan_id=scan_id).delete() if summary_objects: @@ -1039,7 +1025,6 @@ def perform_prowler_scan( scan_instance.state = StateChoices.EXECUTING scan_instance.started_at = datetime.now(tz=timezone.utc) scan_instance.save(update_fields=["state", "started_at", "updated_at"]) - _clear_scan_rerun_state(tenant_id, scan_id) # Find the mutelist processor if it exists with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS): diff --git a/api/src/backend/tasks/tasks.py b/api/src/backend/tasks/tasks.py index 8f6b9bda0e..e617339973 100644 --- a/api/src/backend/tasks/tasks.py +++ b/api/src/backend/tasks/tasks.py @@ -260,7 +260,9 @@ def delete_provider_task(provider_id: str, tenant_id: str): return delete_provider(tenant_id=tenant_id, pk=provider_id) -@shared_task(base=RLSTask, name="scan-perform", queue="scans") +# acks_late=False: a re-run would duplicate findings and the task is not auto-recovered, +# so a crashed scan is dropped rather than redelivered by the broker (as before #11416). +@shared_task(base=RLSTask, name="scan-perform", queue="scans", acks_late=False) @handle_provider_deletion def perform_scan_task( tenant_id: str, scan_id: str, provider_id: str, checks_to_execute: list[str] = None @@ -304,7 +306,14 @@ def perform_scan_task( return result -@shared_task(base=RLSTask, bind=True, name="scan-perform-scheduled", queue="scans") +# acks_late=False: like scan-perform; a dropped run is re-fired by Beat on the next tick. +@shared_task( + base=RLSTask, + bind=True, + name="scan-perform-scheduled", + queue="scans", + acks_late=False, +) @handle_provider_deletion def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str): """ @@ -1151,10 +1160,13 @@ def security_hub_integration_task( return upload_security_hub_integration(tenant_id, provider_id, scan_id) +# acks_late=False: Jira sends are not deduplicated and the task is not auto-recovered, +# so a crashed send is dropped rather than redelivered (avoids duplicate Jira issues). @shared_task( base=RLSTask, name="integration-jira", queue="integrations", + acks_late=False, ) def jira_integration_task( tenant_id: str, diff --git a/api/src/backend/tasks/tests/test_deletion.py b/api/src/backend/tasks/tests/test_deletion.py index e6cd51aca8..0ed8c5ddb2 100644 --- a/api/src/backend/tasks/tests/test_deletion.py +++ b/api/src/backend/tasks/tests/test_deletion.py @@ -1,12 +1,11 @@ from unittest.mock import call, patch -from uuid import uuid4 import pytest from django.core.exceptions import ObjectDoesNotExist from tasks.jobs.deletion import delete_provider, delete_tenant from api.attack_paths import database as graph_database -from api.models import JiraIssueDispatch, Provider, Tenant, TenantComplianceSummary +from api.models import Provider, Tenant, TenantComplianceSummary @pytest.mark.django_db @@ -35,43 +34,6 @@ class TestDeleteProvider: str(instance.id), ) - def test_delete_provider_removes_jira_dispatches( - self, - providers_fixture, - findings_fixture, - integrations_fixture, - ): - """Deleting a provider removes JiraIssueDispatch rows for its findings only.""" - instance = providers_fixture[0] - tenant_id = str(instance.tenant_id) - finding = findings_fixture[0] - integration = integrations_fixture[0] - - # Dispatch for one of the provider's findings: must be removed with it. - JiraIssueDispatch.objects.create( - tenant_id=tenant_id, - integration=integration, - finding_id=finding.id, - ) - # Dispatch for an unrelated finding: must survive the provider deletion. - unrelated = JiraIssueDispatch.objects.create( - tenant_id=tenant_id, - integration=integration, - finding_id=uuid4(), - ) - - with ( - patch( - "tasks.jobs.deletion.graph_database.get_database_name", - return_value="tenant-db", - ), - patch("tasks.jobs.deletion.graph_database.drop_subgraph"), - ): - delete_provider(tenant_id, instance.id) - - assert not JiraIssueDispatch.objects.filter(finding_id=finding.id).exists() - assert JiraIssueDispatch.objects.filter(pk=unrelated.pk).exists() - def test_delete_provider_does_not_exist(self, tenants_fixture): with ( patch( diff --git a/api/src/backend/tasks/tests/test_integrations.py b/api/src/backend/tasks/tests/test_integrations.py index ba8d52c193..e246405cdd 100644 --- a/api/src/backend/tasks/tests/test_integrations.py +++ b/api/src/backend/tasks/tests/test_integrations.py @@ -1640,74 +1640,14 @@ class TestJiraIntegration: @patch("tasks.jobs.integrations.Finding") @patch("tasks.jobs.integrations.Integration") @patch("tasks.jobs.integrations.initialize_prowler_integration") - @patch("tasks.jobs.integrations.JiraIssueDispatch") - def test_send_findings_to_jira_skips_already_dispatched( - self, - mock_jira_dispatch, - mock_initialize_integration, - mock_integration_model, - mock_finding_model, - mock_rls_transaction, - ): - """A re-run skips findings already ticketed (no duplicate Jira issues).""" - mock_rls_transaction.return_value.__enter__ = MagicMock() - mock_rls_transaction.return_value.__exit__ = MagicMock() - mock_integration_model.objects.get.return_value = MagicMock() - # finding-1 was already dispatched in a prior run; finding-2 is new. - mock_jira_dispatch.objects.filter.return_value.values_list.return_value = [ - "finding-1" - ] - mock_jira_dispatch.objects.get_or_create.return_value = (MagicMock(), True) - - mock_jira_integration = MagicMock() - mock_jira_integration.send_finding.return_value = True - mock_initialize_integration.return_value = mock_jira_integration - - finding2 = MagicMock() - finding2.id = "finding-2" - finding2.check_id = "check_002" - finding2.severity = "low" - finding2.status = "FAIL" - finding2.status_extended = "" - finding2.compliance = {} - finding2.resources.exists.return_value = False - finding2.resources.first.return_value = None - finding2.scan.provider.provider = "aws" - finding2.check_metadata = { - "checktitle": "C2", - "risk": "", - "remediation": {"recommendation": {}, "code": {}}, - } - mock_finding_model.all_objects.select_related.return_value.prefetch_related.return_value.get.return_value = finding2 - - result = send_findings_to_jira( - "tenant-123", "integration-456", "PROJ", "Task", ["finding-1", "finding-2"] - ) - - # finding-1 skipped (already sent); only finding-2 sent -> no duplicate. - assert result == {"created_count": 1, "failed_count": 0, "skipped_count": 1} - mock_jira_integration.send_finding.assert_called_once() - assert ( - mock_jira_integration.send_finding.call_args.kwargs["check_id"] - == "check_002" - ) - - @patch("tasks.jobs.integrations.rls_transaction") - @patch("tasks.jobs.integrations.Finding") - @patch("tasks.jobs.integrations.Integration") - @patch("tasks.jobs.integrations.initialize_prowler_integration") - @patch("tasks.jobs.integrations.JiraIssueDispatch") def test_send_findings_to_jira_success( self, - mock_jira_dispatch, mock_initialize_integration, mock_integration_model, mock_finding_model, mock_rls_transaction, ): """Test successful sending of findings to Jira using send_finding method""" - mock_jira_dispatch.objects.filter.return_value.values_list.return_value = [] - mock_jira_dispatch.objects.get_or_create.return_value = (MagicMock(), True) tenant_id = "tenant-123" integration_id = "integration-456" project_key = "PROJ" @@ -1799,7 +1739,7 @@ class TestJiraIntegration: ) # Assertions - assert result == {"created_count": 2, "failed_count": 0, "skipped_count": 0} + assert result == {"created_count": 2, "failed_count": 0} # Verify Jira integration was initialized mock_initialize_integration.assert_called_once_with(integration) @@ -1831,10 +1771,8 @@ class TestJiraIntegration: @patch("tasks.jobs.integrations.Integration") @patch("tasks.jobs.integrations.initialize_prowler_integration") @patch("tasks.jobs.integrations.logger") - @patch("tasks.jobs.integrations.JiraIssueDispatch") def test_send_findings_to_jira_partial_failure( self, - mock_jira_dispatch, mock_logger, mock_initialize_integration, mock_integration_model, @@ -1842,8 +1780,6 @@ class TestJiraIntegration: mock_rls_transaction, ): """Test partial failure when sending findings to Jira""" - mock_jira_dispatch.objects.filter.return_value.values_list.return_value = [] - mock_jira_dispatch.objects.get_or_create.return_value = (MagicMock(), True) tenant_id = "tenant-123" integration_id = "integration-456" project_key = "PROJ" @@ -1897,35 +1833,23 @@ class TestJiraIntegration: ) # Assertions - assert result == {"created_count": 2, "failed_count": 1, "skipped_count": 0} + assert result == {"created_count": 2, "failed_count": 1} # Verify error was logged for the failed finding mock_logger.error.assert_called_with("Failed to send finding finding-2 to Jira") - # The failed finding's reservation is released so a later run can retry it. - mock_jira_dispatch.objects.filter.assert_any_call( - tenant_id=tenant_id, - integration_id=integration_id, - finding_id="finding-2", - ) - mock_jira_dispatch.objects.filter.return_value.delete.assert_called_once() - @patch("tasks.jobs.integrations.rls_transaction") @patch("tasks.jobs.integrations.Finding") @patch("tasks.jobs.integrations.Integration") @patch("tasks.jobs.integrations.initialize_prowler_integration") - @patch("tasks.jobs.integrations.JiraIssueDispatch") def test_send_findings_to_jira_no_resources( self, - mock_jira_dispatch, mock_initialize_integration, mock_integration_model, mock_finding_model, mock_rls_transaction, ): """Test sending findings to Jira when finding has no resources""" - mock_jira_dispatch.objects.filter.return_value.values_list.return_value = [] - mock_jira_dispatch.objects.get_or_create.return_value = (MagicMock(), True) tenant_id = "tenant-123" integration_id = "integration-456" project_key = "PROJ" @@ -1983,7 +1907,7 @@ class TestJiraIntegration: ) # Assertions - assert result == {"created_count": 1, "failed_count": 0, "skipped_count": 0} + assert result == {"created_count": 1, "failed_count": 0} # Verify send_finding was called with empty resource fields call_kwargs = mock_jira_integration.send_finding.call_args.kwargs @@ -1996,18 +1920,14 @@ class TestJiraIntegration: @patch("tasks.jobs.integrations.Finding") @patch("tasks.jobs.integrations.Integration") @patch("tasks.jobs.integrations.initialize_prowler_integration") - @patch("tasks.jobs.integrations.JiraIssueDispatch") def test_send_findings_to_jira_with_empty_check_metadata( self, - mock_jira_dispatch, mock_initialize_integration, mock_integration_model, mock_finding_model, mock_rls_transaction, ): """Test sending findings to Jira when check_metadata is empty or missing fields""" - mock_jira_dispatch.objects.filter.return_value.values_list.return_value = [] - mock_jira_dispatch.objects.get_or_create.return_value = (MagicMock(), True) tenant_id = "tenant-123" integration_id = "integration-456" project_key = "PROJ" @@ -2050,7 +1970,7 @@ class TestJiraIntegration: ) # Assertions - assert result == {"created_count": 1, "failed_count": 0, "skipped_count": 0} + assert result == {"created_count": 1, "failed_count": 0} # Verify send_finding was called with default/empty values call_kwargs = mock_jira_integration.send_finding.call_args.kwargs @@ -2063,94 +1983,3 @@ class TestJiraIntegration: assert call_kwargs["remediation_code_cli"] == "" assert call_kwargs["remediation_code_other"] == "" assert call_kwargs["compliance"] == {} - - @patch("tasks.jobs.integrations.rls_transaction") - @patch("tasks.jobs.integrations.Finding") - @patch("tasks.jobs.integrations.Integration") - @patch("tasks.jobs.integrations.initialize_prowler_integration") - @patch("tasks.jobs.integrations.JiraIssueDispatch") - def test_send_findings_to_jira_reserves_before_sending( - self, - mock_jira_dispatch, - mock_initialize_integration, - mock_integration_model, - mock_finding_model, - mock_rls_transaction, - ): - """The dispatch row is reserved before the external Jira call (reserve-then-act).""" - mock_rls_transaction.return_value.__enter__ = MagicMock() - mock_rls_transaction.return_value.__exit__ = MagicMock() - mock_integration_model.objects.get.return_value = MagicMock() - mock_jira_dispatch.objects.filter.return_value.values_list.return_value = [] - - order = [] - mock_jira_dispatch.objects.get_or_create.side_effect = lambda **kw: ( - order.append(("reserve", kw)) or (MagicMock(), True) - ) - - mock_jira_integration = MagicMock() - mock_jira_integration.send_finding.side_effect = lambda **kw: ( - order.append(("send", kw)) or True - ) - mock_initialize_integration.return_value = mock_jira_integration - - finding = MagicMock() - finding.id = "finding-1" - finding.check_id = "check_001" - finding.severity = "low" - finding.status = "FAIL" - finding.status_extended = "" - finding.compliance = {} - finding.resources.exists.return_value = False - finding.resources.first.return_value = None - finding.scan.provider.provider = "aws" - finding.check_metadata = { - "checktitle": "C1", - "risk": "", - "remediation": {"recommendation": {}, "code": {}}, - } - mock_finding_model.all_objects.select_related.return_value.prefetch_related.return_value.get.return_value = finding - - result = send_findings_to_jira( - "tenant-123", "integration-456", "PROJ", "Task", ["finding-1"] - ) - - assert result == {"created_count": 1, "failed_count": 0, "skipped_count": 0} - # Reservation must precede the external send. - assert [entry[0] for entry in order] == ["reserve", "send"] - # A successful send keeps the reservation (no rollback delete). - mock_jira_dispatch.objects.filter.return_value.delete.assert_not_called() - - @patch("tasks.jobs.integrations.rls_transaction") - @patch("tasks.jobs.integrations.Finding") - @patch("tasks.jobs.integrations.Integration") - @patch("tasks.jobs.integrations.initialize_prowler_integration") - @patch("tasks.jobs.integrations.JiraIssueDispatch") - def test_send_findings_to_jira_skips_when_already_reserved( - self, - mock_jira_dispatch, - mock_initialize_integration, - mock_integration_model, - mock_finding_model, - mock_rls_transaction, - ): - """A finding that races past the bulk pre-check but loses the reservation - (created=False) is skipped without a second issue, leaving the row intact.""" - mock_rls_transaction.return_value.__enter__ = MagicMock() - mock_rls_transaction.return_value.__exit__ = MagicMock() - mock_integration_model.objects.get.return_value = MagicMock() - mock_jira_dispatch.objects.filter.return_value.values_list.return_value = [] - # Another concurrent run already created the dispatch row. - mock_jira_dispatch.objects.get_or_create.return_value = (MagicMock(), False) - - mock_jira_integration = MagicMock() - mock_initialize_integration.return_value = mock_jira_integration - - result = send_findings_to_jira( - "tenant-123", "integration-456", "PROJ", "Task", ["finding-1"] - ) - - assert result == {"created_count": 0, "failed_count": 0, "skipped_count": 1} - mock_jira_integration.send_finding.assert_not_called() - # The reservation belongs to the run that won the race; do not delete it. - mock_jira_dispatch.objects.filter.return_value.delete.assert_not_called() diff --git a/api/src/backend/tasks/tests/test_orphan_recovery.py b/api/src/backend/tasks/tests/test_orphan_recovery.py index b426297bbd..abfa920b8e 100644 --- a/api/src/backend/tasks/tests/test_orphan_recovery.py +++ b/api/src/backend/tasks/tests/test_orphan_recovery.py @@ -4,17 +4,17 @@ from uuid import uuid4 import pytest from celery import states +from django.test import override_settings from django_celery_results.models import TaskResult -from api.models import Scan, StateChoices -from api.models import Task as APITask from tasks.jobs.orphan_recovery import ( _decode_celery_field, _reconcile_task_results, _recovery_attempt_count, - _reenqueue_scan, advisory_lock, is_worker_alive, + reconcile_orphans, + reenqueueable_tasks, ) @@ -130,9 +130,83 @@ class TestReconcileTaskResults: assert tr.task_id in result["failed"] mock_task.apply_async.assert_not_called() - def test_jira_integration_task_is_reenqueued(self, tenants_fixture): - """integration-jira is re-enqueued: its JiraIssueDispatch reservation makes a - re-run skip already-ticketed findings, so recovery cannot duplicate issues.""" + @override_settings(TASK_RECOVERY_SUMMARIES_ENABLED=False) + def test_disabled_group_task_is_not_reenqueued(self, tenants_fixture): + """A task whose group feature flag is off stays terminal, not re-enqueued.""" + tr = _orphan_result( + name="scan-summary", + kwargs={ + "tenant_id": str(tenants_fixture[0].id), + "scan_id": str(uuid4()), + }, + worker="dead@gone", + created_minutes_ago=60, + ) + p_alive, p_revoke, p_app, mock_task = self._patches(alive=False) + with ( + p_alive, + p_revoke, + p_app, + patch("tasks.jobs.orphan_recovery._recovery_attempt_count", return_value=1), + ): + result = _reconcile_task_results( + grace_minutes=2, max_attempts=3, window_hours=6, dry_run=False + ) + + assert tr.task_id in result["failed"] + mock_task.apply_async.assert_not_called() + + @override_settings(TASK_RECOVERY_SUMMARIES_ENABLED=False) + def test_disabled_group_task_does_not_consume_recovery_attempt( + self, tenants_fixture + ): + """A disabled-group task is failed without incrementing its Valkey attempt + counter, so re-enabling the group does not start it at the cap.""" + tr = _orphan_result( + name="scan-summary", + kwargs={"tenant_id": str(tenants_fixture[0].id), "scan_id": str(uuid4())}, + worker="dead@gone", + created_minutes_ago=60, + ) + p_alive, p_revoke, p_app, mock_task = self._patches(alive=False) + with ( + p_alive, + p_revoke, + p_app, + patch("tasks.jobs.orphan_recovery._recovery_attempt_count") as mock_count, + ): + result = _reconcile_task_results( + grace_minutes=2, max_attempts=3, window_hours=6, dry_run=False + ) + + assert tr.task_id in result["failed"] + mock_count.assert_not_called() + + def test_scan_task_is_skipped_entirely(self, tenants_fixture): + """Scan tasks are excluded from recovery: the watchdog never touches them.""" + tr = _orphan_result( + name="scan-perform", + kwargs={ + "tenant_id": str(tenants_fixture[0].id), + "scan_id": str(uuid4()), + }, + worker="dead@gone", + created_minutes_ago=60, + ) + p_alive, p_revoke, p_app, mock_task = self._patches(alive=False) + with p_alive, p_revoke, p_app: + result = _reconcile_task_results( + grace_minutes=2, max_attempts=3, window_hours=6, dry_run=False + ) + + assert tr.task_id not in result["recovered"] + assert tr.task_id not in result["failed"] + assert tr.task_id not in result["skipped"] + mock_task.apply_async.assert_not_called() + + def test_jira_integration_task_is_not_reenqueued(self, tenants_fixture): + """integration-jira stays terminal: re-running it would create duplicate Jira + issues, so an orphaned send is failed instead of re-enqueued.""" tenant = tenants_fixture[0] kwargs = { "tenant_id": str(tenant.id), @@ -158,13 +232,10 @@ class TestReconcileTaskResults: grace_minutes=2, max_attempts=3, window_hours=6, dry_run=False ) - assert tr.task_id in result["recovered"] + assert tr.task_id in result["failed"] tr.refresh_from_db() assert tr.status == states.REVOKED # stale result cleared (no pending alert) - mock_task.apply_async.assert_called_once() - call = mock_task.apply_async.call_args.kwargs - assert call["kwargs"] == kwargs - assert call["task_id"] != tr.task_id # fresh task id + mock_task.apply_async.assert_not_called() def test_skips_live_worker(self, tenants_fixture): tr = _orphan_result( @@ -246,98 +317,6 @@ class TestReconcileTaskResults: mock_task.apply_async.assert_not_called() -@pytest.mark.django_db -class TestScanRecovery: - """Scans are recovered by re-running scan-perform on the EXISTING scan row, - so even a scheduled-scan orphan (whose own task would no-op on its guard) is - actually re-executed.""" - - def _scan_orphan(self, tenant, provider, name): - old_id = str(uuid4()) - tr = TaskResult.objects.create( - task_id=old_id, - status=states.STARTED, - task_name=name, - worker="dead@gone", - task_kwargs=repr( - {"tenant_id": str(tenant.id), "provider_id": str(provider.id)} - ), - task_args=repr([]), - ) - TaskResult.objects.filter(pk=tr.pk).update( - date_created=datetime.now(tz=timezone.utc) - timedelta(minutes=60) - ) - APITask.objects.create(id=old_id, tenant_id=tenant.id, task_runner_task=tr) - scan = Scan.objects.create( - name="scan-orphan", - provider=provider, - trigger=Scan.TriggerChoices.SCHEDULED, - state=StateChoices.EXECUTING, - tenant_id=tenant.id, - task_id=old_id, - recovery_count=0, - ) - return old_id, scan - - @pytest.mark.parametrize("name", ["scan-perform", "scan-perform-scheduled"]) - def test_scan_recovered_via_scan_perform( - self, tenants_fixture, providers_fixture, name - ): - tenant, provider = tenants_fixture[0], providers_fixture[0] - old_id, scan = self._scan_orphan(tenant, provider, name) - - with ( - patch("tasks.jobs.orphan_recovery.is_worker_alive", return_value=False), - patch("tasks.jobs.orphan_recovery.revoke_task"), - patch("tasks.jobs.orphan_recovery._recovery_attempt_count", return_value=1), - patch("tasks.tasks.perform_scan_task") as mock_scan_task, - ): - result = _reconcile_task_results( - grace_minutes=2, max_attempts=3, window_hours=6, dry_run=False - ) - - assert old_id in result["recovered"] - scan.refresh_from_db() - assert str(scan.task_id) != old_id # relinked to a fresh task - assert scan.recovery_count == 1 - assert TaskResult.objects.get(task_id=old_id).status == states.REVOKED - # Recovered by re-running scan-perform on the existing scan row (so the - # scheduled guard cannot no-op it), regardless of the original task name. - mock_scan_task.apply_async.assert_called_once() - assert mock_scan_task.apply_async.call_args.kwargs["kwargs"]["scan_id"] == str( - scan.id - ) - - def test_reenqueue_skips_when_scan_already_repointed( - self, tenants_fixture, providers_fixture - ): - # The scan already points at a newer task, so a stale orphan must not relink - # it or launch a second concurrent run against the same scan row. - tenant, provider = tenants_fixture[0], providers_fixture[0] - newer_id = str(uuid4()) - tr = TaskResult.objects.create( - task_id=newer_id, status=states.STARTED, task_name="scan-perform" - ) - APITask.objects.create(id=newer_id, tenant_id=tenant.id, task_runner_task=tr) - scan = Scan.objects.create( - name="scan-orphan", - provider=provider, - trigger=Scan.TriggerChoices.SCHEDULED, - state=StateChoices.EXECUTING, - tenant_id=tenant.id, - task_id=newer_id, - recovery_count=0, - ) - - with patch("tasks.tasks.perform_scan_task") as mock_scan_task: - recovered = _reenqueue_scan(str(uuid4()), scan) - - assert recovered is False - mock_scan_task.apply_async.assert_not_called() - scan.refresh_from_db() - assert scan.recovery_count == 0 - - @pytest.mark.django_db class TestOrphanRecoveryHelpers: def test_advisory_lock_acquires_and_releases(self): @@ -370,3 +349,60 @@ class TestOrphanRecoveryHelpers: with patch("redis.from_url", return_value=redis_client): assert _recovery_attempt_count("probe-task", kwargs_repr, 6) == 1 assert _recovery_attempt_count("probe-task", kwargs_repr, 6) == 2 + + +class TestRecoveryFeatureFlags: + def test_all_groups_enabled_by_default(self): + tasks = reenqueueable_tasks() + assert "scan-summary" in tasks + assert {"provider-deletion", "tenant-deletion"} <= tasks + + @override_settings(TASK_RECOVERY_SUMMARIES_ENABLED=False) + def test_summaries_group_flag_excludes_summary_tasks(self): + tasks = reenqueueable_tasks() + assert "scan-summary" not in tasks + assert "scan-compliance-overviews" not in tasks + assert "provider-deletion" in tasks + + @override_settings(TASK_RECOVERY_DELETIONS_ENABLED=False) + def test_deletions_group_flag_excludes_deletion_tasks(self): + tasks = reenqueueable_tasks() + assert "provider-deletion" not in tasks + assert "tenant-deletion" not in tasks + assert "scan-summary" in tasks + + +@pytest.mark.django_db +class TestRecoveryMasterFlag: + @override_settings(TASK_RECOVERY_ENABLED=False) + def test_master_flag_disables_task_recovery(self): + with ( + patch( + "tasks.jobs.orphan_recovery._reconcile_task_results" + ) as mock_reconcile, + patch( + "tasks.jobs.attack_paths.cleanup.cleanup_stale_attack_paths_scans", + return_value={}, + ), + ): + result = reconcile_orphans(grace_minutes=2, max_attempts=3, dry_run=False) + + mock_reconcile.assert_not_called() + assert result["acquired"] is True + assert result["enabled"] is False + + @override_settings(TASK_RECOVERY_ENABLED=True) + def test_master_flag_enabled_runs_task_recovery(self): + with ( + patch( + "tasks.jobs.orphan_recovery._reconcile_task_results", + return_value={"recovered": [], "failed": [], "skipped": []}, + ) as mock_reconcile, + patch( + "tasks.jobs.attack_paths.cleanup.cleanup_stale_attack_paths_scans", + return_value={}, + ), + ): + reconcile_orphans(grace_minutes=2, max_attempts=3, dry_run=False) + + mock_reconcile.assert_called_once() diff --git a/api/src/backend/tasks/tests/test_scan.py b/api/src/backend/tasks/tests/test_scan.py index 2ff7e44e4d..8d3d0be93a 100644 --- a/api/src/backend/tasks/tests/test_scan.py +++ b/api/src/backend/tasks/tests/test_scan.py @@ -32,15 +32,12 @@ from tasks.utils import CustomEncoder from api.db_router import MainRouter from api.exceptions import ProviderConnectionError from api.models import ( - AttackSurfaceOverview, Finding, MuteRule, Provider, Resource, ResourceScanSummary, Scan, - ScanCategorySummary, - ScanGroupSummary, ScanSummary, StateChoices, StatusChoices, @@ -232,131 +229,6 @@ class TestPerformScan: # Assert that failed_findings_count is 0 (finding is PASS and muted) assert scan_resource.failed_findings_count == 0 - def test_perform_prowler_scan_idempotent_on_rerun( - self, - tenants_fixture, - scans_fixture, - providers_fixture, - ): - """Re-running a scan for the same scan_id must not duplicate findings.""" - with ( - patch("api.db_utils.rls_transaction"), - patch( - "tasks.jobs.scan.initialize_prowler_provider" - ) as mock_initialize_prowler_provider, - patch("tasks.jobs.scan.ProwlerScan") as mock_prowler_scan_class, - patch( - "tasks.jobs.scan.PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE", - new_callable=dict, - ), - patch("api.compliance.PROWLER_CHECKS", new_callable=dict) as mock_checks, - ): - mock_checks["aws"] = {"check1": {"compliance1"}} - - tenant = tenants_fixture[0] - scan = scans_fixture[0] - provider = providers_fixture[0] - provider.provider = Provider.ProviderChoices.AWS - provider.save() - - tenant_id = str(tenant.id) - scan_id = str(scan.id) - provider_id = str(provider.id) - - stale_resource = Resource.objects.create( - tenant_id=tenant.id, - provider=provider, - uid="stale_resource_uid", - name="stale", - region="stale-region", - service="stale-service", - type="stale-type", - ) - ResourceScanSummary.objects.create( - tenant_id=tenant.id, - scan_id=scan.id, - resource_id=stale_resource.id, - service="stale-service", - region="stale-region", - resource_type="stale-type", - ) - ScanCategorySummary.objects.create( - tenant_id=tenant.id, - scan=scan, - category="stale-category", - severity=Severity.medium, - total_findings=1, - ) - ScanGroupSummary.objects.create( - tenant_id=tenant.id, - scan=scan, - resource_group="stale-group", - severity=Severity.medium, - total_findings=1, - ) - ScanSummary.objects.create( - tenant_id=tenant.id, - scan=scan, - check_id="stale_check", - service="stale-service", - severity=Severity.medium, - region="stale-region", - total=1, - ) - AttackSurfaceOverview.objects.create( - tenant_id=tenant.id, - scan=scan, - attack_surface_type=AttackSurfaceOverview.AttackSurfaceTypeChoices.SECRETS, - total_findings=1, - ) - - finding = MagicMock() - finding.uid = "dup_probe_finding" - finding.status = StatusChoices.PASS - finding.status_extended = "x" - finding.severity = Severity.medium - finding.check_id = "check1" - finding.get_metadata.return_value = {"key": "value"} - finding.resource_uid = "resource_uid" - finding.resource_name = "resource_name" - finding.region = "region" - finding.service_name = "service_name" - finding.resource_type = "resource_type" - finding.resource_tags = {} - finding.muted = False - finding.raw = {} - finding.resource_metadata = {} - finding.resource_details = {} - finding.partition = "partition" - finding.compliance = {} - - mock_scan_instance = MagicMock() - mock_scan_instance.scan.return_value = [(100, [finding])] - mock_prowler_scan_class.return_value = mock_scan_instance - - mock_provider_instance = MagicMock() - mock_provider_instance.get_regions.return_value = ["region"] - mock_initialize_prowler_provider.return_value = mock_provider_instance - - # Run the same scan twice (simulating an orphan-recovery re-run). - perform_prowler_scan(tenant_id, scan_id, provider_id, ["check1"]) - perform_prowler_scan(tenant_id, scan_id, provider_id, ["check1"]) - - # Neither findings nor resources are duplicated by the re-run: findings are - # scope-deleted before re-insert; resources are upserted by (tenant, provider, uid). - assert Finding.objects.filter(scan=scan).count() == 1 - assert Resource.objects.filter(provider=provider).count() == 2 - assert ResourceScanSummary.objects.filter(scan_id=scan.id).count() == 1 - assert not ResourceScanSummary.objects.filter( - scan_id=scan.id, resource_id=stale_resource.id - ).exists() - assert not ScanCategorySummary.objects.filter(scan=scan).exists() - assert not ScanGroupSummary.objects.filter(scan=scan).exists() - assert not ScanSummary.objects.filter( - scan=scan, check_id="stale_check" - ).exists() - assert not AttackSurfaceOverview.objects.filter(scan=scan).exists() - @patch("tasks.jobs.scan.ProwlerScan") @patch( "tasks.jobs.scan.initialize_prowler_provider",