feat(api): make orphan-task recovery configurable and drop the Jira idempotency table (#11472)

This commit is contained in:
Adrián Peña
2026-06-09 09:16:48 +02:00
committed by GitHub
parent 662e7e9e18
commit 1f7caa6394
17 changed files with 349 additions and 841 deletions
+1 -3
View File
@@ -6,15 +6,13 @@ All notable changes to the **Prowler API** are documented in this file.
### 🚀 Added
- Automatic recovery of allowlisted idempotent background tasks whose worker died during a deploy or crash: stuck scan and summary tasks are detected and re-run instead of staying pending forever, with a `reconcile_orphan_tasks` management command for on-demand recovery [(#11416)](https://github.com/prowler-cloud/prowler/pull/11416)
- Jira integration no longer creates duplicate issues on a retried send; findings already ticketed are skipped [(#11416)](https://github.com/prowler-cloud/prowler/pull/11416)
- Opt-in automatic recovery of allowlisted idempotent background tasks whose worker died during a deploy or crash: when enabled via `DJANGO_TASK_RECOVERY_ENABLED` (off by default), stuck summary and deletion tasks are detected and re-run instead of staying pending forever (scan and Jira tasks are excluded), with a `reconcile_orphan_tasks` management command for on-demand recovery [(#11416)](https://github.com/prowler-cloud/prowler/pull/11416)
- DORA compliance framework support [(#11131)](https://github.com/prowler-cloud/prowler/pull/11131)
- Label Postgres connections with `application_name="<component>:<alias>"` (component injected per process via `DJANGO_APP_COMPONENT`) so connections are attributable by component in `pg_stat_activity` [(#11494)](https://github.com/prowler-cloud/prowler/pull/11494)
### 🔄 Changed
- Allowlisted idempotent background tasks are no longer lost when a worker is stopped or crashes mid-task; tasks with external side effects are marked terminal instead of blindly re-running [(#11416)](https://github.com/prowler-cloud/prowler/pull/11416)
- A recovered scan rewrites its findings, summaries, attack surface, and compliance data instead of appending to the previous run, so recovery never leaves stale or duplicate materialized rows [(#11416)](https://github.com/prowler-cloud/prowler/pull/11416)
### 🐞 Fixed
+37 -18
View File
@@ -1,10 +1,11 @@
# Orphan Celery task recovery
When a worker is terminated mid-task (a deploy, an OOM kill, a node eviction), the
task it was running can be left non-terminal forever: the `Scan` stays `EXECUTING`,
the `TaskResult` stays `STARTED`, and nothing re-runs it. This page describes the
mechanisms that detect and recover allowlisted idempotent orphans so users never
see a stuck scan and pending-task alerts do not fire.
task it was running can be left non-terminal forever: the `TaskResult` stays
`STARTED` and nothing re-runs it. This page describes the mechanisms that detect and
recover allowlisted idempotent orphans so pending-task alerts do not fire. Scan tasks
are not auto-recovered (re-running a scan is not safe to do automatically); the
watchdog covers the summary/aggregation and deletion tasks.
## How recovery works
@@ -13,29 +14,35 @@ see a stuck scan and pending-task alerts do not fire.
(`worker_prefetch_multiplier = 1`), and an abruptly-lost worker re-queues its task
(`task_reject_on_worker_lost`). On `SIGTERM` the worker is given a soft-shutdown
window (`worker_soft_shutdown_timeout`) to finish or re-queue in-flight work
before it is force-killed.
before it is force-killed. `scan-perform`, `scan-perform-scheduled` and
`integration-jira` opt out of redelivery with `acks_late=False`, so a crash drops
them rather than re-running and duplicating findings or Jira issues. Other
non-recovered side-effect tasks keep `acks_late=True`, so the broker can still
re-deliver them after a worker loss: the S3 upload rebuilds from worker-local files
that did not survive the crash and so no-ops, but Security Hub re-reads findings from
the DB and re-sends them to AWS.
2. **Periodic watchdog.** A Beat task, `reconcile-orphan-tasks`, runs every couple of
minutes (a `django_celery_beat` periodic task created by migration). For each
in-flight task result with an allowlisted idempotent task name, it pings the
worker recorded on the task's `TaskResult`:
- worker responds -> the task is still running, leave it alone;
- worker is gone (and the scan started before a short grace window) -> it is a
- worker is gone (and the task started before a short grace window) -> it is a
real orphan: the stale task is revoked and marked terminal (clearing the
pending/started alert), and the scan is re-enqueued from scratch.
pending/started alert), and the task is re-enqueued from its stored name and
kwargs.
The re-run is safe because only tasks with proven idempotency are allowlisted.
Scan persistence, for example, clears the scan's prior findings and materialized
summary/compliance rows before re-writing them. Jira sends are allowlisted too:
each finding is reserved in a dispatch table before the external call, so a re-run
skips already-ticketed findings (the worst case is one finding missed if a worker
is hard-killed mid-send, never a duplicate issue). Other external side effects stay
terminal: the S3 upload rebuilds from worker-local files that do not survive a
crash, and report/Security Hub recovery is out of scope.
The re-run is safe because only tasks with proven idempotency are allowlisted: the
summary/aggregation tasks clear and re-write their own rows, and deletions are
idempotent. Scan tasks and external side effects are excluded: re-running a scan is
not safe to do automatically, Jira sends would create duplicate issues, the S3
upload rebuilds from worker-local files that do not survive a crash, and
report/Security Hub recovery is out of scope.
3. **Recovery cap.** Each automatic re-enqueue increments `Scan.recovery_count`.
After `--max-attempts` recoveries (default 3) the scan is marked `FAILED` instead
of re-enqueued, so a task that repeatedly kills its worker cannot loop forever.
3. **Recovery cap.** A per-task Valkey counter limits how often the same task is
re-enqueued. After `--max-attempts` recoveries (default 3) the orphan is marked
terminal instead of re-enqueued, so a task that repeatedly kills its worker cannot
loop forever.
A Postgres advisory lock ensures that, even with multiple API/worker replicas, only
one reconciliation runs at a time; the others no-op.
@@ -63,6 +70,18 @@ All settings have safe defaults; override via environment variables.
| `DJANGO_CELERY_TASK_SOFT_TIME_LIMIT` | hard - 600 | Soft limit; raises `SoftTimeLimitExceeded` for cleanup. |
| `DJANGO_CELERY_LONG_TASK_TIME_LIMIT` | `172800` (48h) | Hard limit for scans and provider/tenant deletions, which can legitimately run for more than a day. |
| `DJANGO_CELERY_LONG_TASK_SOFT_TIME_LIMIT` | long hard - 600 | Soft limit for the long-running tasks above. |
| `DJANGO_TASK_RECOVERY_ENABLED` | `false` | Master switch for orphan-task recovery, disabled by default (opt-in); set to `true` to enable. When off, no orphan is detected, marked terminal, or re-enqueued (attack-paths stale cleanup still runs). |
| `DJANGO_TASK_RECOVERY_SUMMARIES_ENABLED` | `true` | Auto re-enqueue orphaned scan summary/aggregation tasks. |
| `DJANGO_TASK_RECOVERY_DELETIONS_ENABLED` | `true` | Auto re-enqueue orphaned provider/tenant deletion tasks. |
Recovery is opt-in: with the master flag off (the default) the sweep does nothing.
Once enabled, the per-group flags default to on, so every group recovers unless you
turn one off; a task whose group flag is off is marked terminal instead of
re-enqueued.
Turning recovery off only disables this watchdog sweep; it does not change Celery's
broker-level redelivery (`task_acks_late`/`task_reject_on_worker_lost`), which still
re-delivers tasks that keep `acks_late=True` on worker loss, independently of this flag.
`task_acks_late` and `task_reject_on_worker_lost` are enabled in `config/celery.py`.
@@ -20,7 +20,7 @@ class Command(BaseCommand):
"--max-attempts",
type=int,
default=3,
help="Give up re-running a task after this many recovery attempts (scans are marked FAILED).",
help="Give up re-running a task after this many recovery attempts; it is then left terminal instead of re-enqueued.",
)
parser.add_argument(
"--dry-run",
@@ -39,6 +39,16 @@ class Command(BaseCommand):
self.stdout.write("Reconcile skipped: another run holds the lock.")
return
if result.get("enabled") is False:
message = (
"Task recovery is disabled (DJANGO_TASK_RECOVERY_ENABLED is off); "
"no orphans were recovered."
)
if result.get("attack_paths") is not None:
message += " Attack-paths stale cleanup still ran."
self.stdout.write(message)
return
self.stdout.write(
self.style.SUCCESS(
"Orphan reconcile complete: "
@@ -1,17 +0,0 @@
# Generated by Django 5.1.15 on 2026-05-30 17:38
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("api", "0093_okta_provider"),
]
operations = [
migrations.AddField(
model_name="scan",
name="recovery_count",
field=models.IntegerField(default=0),
),
]
@@ -40,7 +40,7 @@ def delete_periodic_task(apps, schema_editor):
class Migration(migrations.Migration):
dependencies = [
("api", "0094_scan_recovery_count"),
("api", "0093_okta_provider"),
("django_celery_beat", "0019_alter_periodictasks_options"),
]
@@ -1,64 +0,0 @@
import uuid
import django.db.models.deletion
from django.db import migrations, models
import api.rls
class Migration(migrations.Migration):
dependencies = [
("api", "0095_reconcile_orphan_tasks_periodic_task"),
]
operations = [
migrations.CreateModel(
name="JiraIssueDispatch",
fields=[
(
"id",
models.UUIDField(
default=uuid.uuid4,
editable=False,
primary_key=True,
serialize=False,
),
),
("inserted_at", models.DateTimeField(auto_now_add=True)),
("finding_id", models.UUIDField()),
(
"integration",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="jira_dispatches",
to="api.integration",
),
),
(
"tenant",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
),
),
],
options={
"db_table": "jira_issue_dispatches",
"abstract": False,
},
),
migrations.AddConstraint(
model_name="jiraissuedispatch",
constraint=models.UniqueConstraint(
fields=("tenant_id", "integration_id", "finding_id"),
name="unique_jira_issue_dispatch",
),
),
migrations.AddConstraint(
model_name="jiraissuedispatch",
constraint=api.rls.RowLevelSecurityConstraint(
"tenant_id",
name="rls_on_jiraissuedispatch",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
),
),
]
-32
View File
@@ -666,9 +666,6 @@ class Scan(RowLevelSecurityProtectedModel):
state = StateEnumField(choices=StateChoices.choices, default=StateChoices.AVAILABLE)
unique_resource_count = models.IntegerField(default=0)
progress = models.IntegerField(default=0)
# Incremented by the scan-specific orphan-recovery path each time this scan is
# re-pointed to a fresh task; for observability (the retry cap is a Valkey counter).
recovery_count = models.IntegerField(default=0)
scanner_args = models.JSONField(default=dict)
duration = models.IntegerField(null=True, blank=True)
scheduled_at = models.DateTimeField(null=True, blank=True)
@@ -2001,35 +1998,6 @@ class IntegrationProviderRelationship(RowLevelSecurityProtectedModel):
]
class JiraIssueDispatch(RowLevelSecurityProtectedModel):
"""Tracks findings already sent to a Jira integration.
Lets the Jira task be re-run safely (e.g. by orphan recovery): findings with
an existing dispatch row are skipped, so no duplicate issues are created.
"""
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
integration = models.ForeignKey(
Integration, on_delete=models.CASCADE, related_name="jira_dispatches"
)
finding_id = models.UUIDField()
class Meta(RowLevelSecurityProtectedModel.Meta):
db_table = "jira_issue_dispatches"
constraints = [
models.UniqueConstraint(
fields=["tenant_id", "integration_id", "finding_id"],
name="unique_jira_issue_dispatch",
),
RowLevelSecurityConstraint(
field="tenant_id",
name="rls_on_%(class)s",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
),
]
class SAMLToken(models.Model):
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
+12
View File
@@ -307,6 +307,18 @@ ATTACK_PATHS_SCAN_STALE_THRESHOLD_MINUTES = env.int(
"ATTACK_PATHS_SCAN_STALE_THRESHOLD_MINUTES", 2880
) # 48h
# Orphan task recovery feature flags. The master switch is OFF by default, so task
# recovery is opt-in; enable it with DJANGO_TASK_RECOVERY_ENABLED=true. The per-group
# toggles default to enabled, so once the master is on every group recovers unless a
# group is explicitly turned off.
TASK_RECOVERY_ENABLED = env.bool("DJANGO_TASK_RECOVERY_ENABLED", False)
TASK_RECOVERY_SUMMARIES_ENABLED = env.bool(
"DJANGO_TASK_RECOVERY_SUMMARIES_ENABLED", True
)
TASK_RECOVERY_DELETIONS_ENABLED = env.bool(
"DJANGO_TASK_RECOVERY_DELETIONS_ENABLED", True
)
def label_postgres_connections(databases):
"""Tag each Postgres connection with ``application_name="<component>:<alias>"``
-9
View File
@@ -11,7 +11,6 @@ from api.db_utils import batch_delete, rls_transaction
from api.models import (
AttackPathsScan,
Finding,
JiraIssueDispatch,
Provider,
ProviderComplianceScore,
Resource,
@@ -81,14 +80,6 @@ def delete_provider(tenant_id: str, pk: str):
deletion_steps = [
("Scan Summaries", ScanSummary.all_objects.filter(scan__provider=instance)),
(
"Jira Issue Dispatches",
JiraIssueDispatch.objects.filter(
finding_id__in=Finding.all_objects.filter(
scan__provider=instance
).values_list("id", flat=True)
),
),
("Findings", Finding.all_objects.filter(scan__provider=instance)),
("Resources", Resource.all_objects.filter(provider=instance)),
("Scans", Scan.all_objects.filter(provider=instance)),
+51 -100
View File
@@ -9,7 +9,7 @@ from tasks.utils import batched
from api.db_router import READ_REPLICA_ALIAS, MainRouter
from api.db_utils import REPLICA_MAX_ATTEMPTS, REPLICA_RETRY_BASE_DELAY, rls_transaction
from api.models import Finding, Integration, JiraIssueDispatch, Provider
from api.models import Finding, Integration, Provider
from api.utils import initialize_prowler_integration, initialize_prowler_provider
from prowler.lib.outputs.asff.asff import ASFF
from prowler.lib.outputs.compliance.generic.generic import GenericCompliance
@@ -482,115 +482,66 @@ def send_findings_to_jira(
with rls_transaction(tenant_id):
integration = Integration.objects.get(id=integration_id)
jira_integration = initialize_prowler_integration(integration)
# Idempotency: findings already ticketed for this integration must not be
# sent again on a re-run (e.g. orphan recovery), to avoid duplicate issues
already_sent = {
str(fid)
for fid in JiraIssueDispatch.objects.filter(
integration_id=integration_id, finding_id__in=finding_ids
).values_list("finding_id", flat=True)
}
num_tickets_created = 0
skipped_count = 0
for finding_id in finding_ids:
if str(finding_id) in already_sent:
skipped_count += 1
continue
# Reserve the finding BEFORE the external call. The unique constraint on
# (tenant, integration, finding) makes the dispatch row the single source of
# truth, so a concurrent run or a retry that raced past the bulk pre-check
# cannot create a duplicate issue: created=False means another run already
# claimed it. The reservation is released below if the send does not succeed.
with rls_transaction(tenant_id):
_, created = JiraIssueDispatch.objects.get_or_create(
tenant_id=tenant_id,
integration_id=integration_id,
finding_id=finding_id,
finding_instance = (
Finding.all_objects.select_related("scan__provider")
.prefetch_related("resources")
.get(id=finding_id)
)
if not created:
skipped_count += 1
continue
sent = False
try:
with rls_transaction(tenant_id):
finding_instance = (
Finding.all_objects.select_related("scan__provider")
.prefetch_related("resources")
.get(id=finding_id)
)
# Extract resource information
resource = (
finding_instance.resources.first()
if finding_instance.resources.exists()
else None
)
resource_uid = resource.uid if resource else ""
resource_name = resource.name if resource else ""
resource_tags = {}
if resource and hasattr(resource, "tags"):
resource_tags = resource.get_tags(tenant_id)
# Extract resource information
resource = (
finding_instance.resources.first()
if finding_instance.resources.exists()
else None
)
resource_uid = resource.uid if resource else ""
resource_name = resource.name if resource else ""
resource_tags = {}
if resource and hasattr(resource, "tags"):
resource_tags = resource.get_tags(tenant_id)
# Get region
region = resource.region if resource and resource.region else ""
# Get region
region = resource.region if resource and resource.region else ""
# Extract remediation information from check_metadata
check_metadata = finding_instance.check_metadata
remediation = check_metadata.get("remediation", {})
recommendation = remediation.get("recommendation", {})
remediation_code = remediation.get("code", {})
# Extract remediation information from check_metadata
check_metadata = finding_instance.check_metadata
remediation = check_metadata.get("remediation", {})
recommendation = remediation.get("recommendation", {})
remediation_code = remediation.get("code", {})
# Send the individual finding to Jira
sent = bool(
jira_integration.send_finding(
check_id=finding_instance.check_id,
check_title=check_metadata.get("checktitle", ""),
severity=finding_instance.severity,
status=finding_instance.status,
status_extended=finding_instance.status_extended or "",
provider=finding_instance.scan.provider.provider,
region=region,
resource_uid=resource_uid,
resource_name=resource_name,
risk=check_metadata.get("risk", ""),
recommendation_text=recommendation.get("text", ""),
recommendation_url=recommendation.get("url", ""),
remediation_code_native_iac=remediation_code.get(
"nativeiac", ""
),
remediation_code_terraform=remediation_code.get(
"terraform", ""
),
remediation_code_cli=remediation_code.get("cli", ""),
remediation_code_other=remediation_code.get("other", ""),
resource_tags=resource_tags,
compliance=finding_instance.compliance or {},
project_key=project_key,
issue_type=issue_type,
)
)
finally:
if not sent:
# Release the reservation so a later run can retry this finding: it
# was not ticketed (send failed or raised), so the row must not block
# a future legitimate send.
with rls_transaction(tenant_id):
JiraIssueDispatch.objects.filter(
tenant_id=tenant_id,
integration_id=integration_id,
finding_id=finding_id,
).delete()
if sent:
num_tickets_created += 1
else:
logger.error(f"Failed to send finding {finding_id} to Jira")
# Send the individual finding to Jira
result = jira_integration.send_finding(
check_id=finding_instance.check_id,
check_title=check_metadata.get("checktitle", ""),
severity=finding_instance.severity,
status=finding_instance.status,
status_extended=finding_instance.status_extended or "",
provider=finding_instance.scan.provider.provider,
region=region,
resource_uid=resource_uid,
resource_name=resource_name,
risk=check_metadata.get("risk", ""),
recommendation_text=recommendation.get("text", ""),
recommendation_url=recommendation.get("url", ""),
remediation_code_native_iac=remediation_code.get("nativeiac", ""),
remediation_code_terraform=remediation_code.get("terraform", ""),
remediation_code_cli=remediation_code.get("cli", ""),
remediation_code_other=remediation_code.get("other", ""),
resource_tags=resource_tags,
compliance=finding_instance.compliance or {},
project_key=project_key,
issue_type=issue_type,
)
if result:
num_tickets_created += 1
else:
logger.error(f"Failed to send finding {finding_id} to Jira")
return {
"created_count": num_tickets_created,
"failed_count": len(finding_ids) - num_tickets_created - skipped_count,
"skipped_count": skipped_count,
"failed_count": len(finding_ids) - num_tickets_created,
}
+75 -131
View File
@@ -37,35 +37,52 @@ ORPHAN_RECOVERY_LOCK_KEY = 0x70726F77 # "prow"
# Non-terminal states that mean "a worker had this and may have died with it".
IN_FLIGHT_STATES = (states.STARTED, states.RECEIVED)
# Scan tasks are recovered by re-running scan-perform on the EXISTING scan row,
# not by re-enqueuing the original task: re-enqueuing scan-perform-scheduled would
# hit its "a scan is already executing" guard and no-op, leaving the scan stuck.
_SCAN_TASKS = ("scan-perform", "scan-perform-scheduled")
# Tasks with proven idempotency are auto re-enqueued. Scans/summaries clear and
# rewrite their own rows. integration-jira is safe too: each finding is reserved in
# JiraIssueDispatch before the external call, so a re-run skips already-ticketed
# findings (worst case one finding missed on a mid-send crash, never a duplicate).
# Other external side effects stay terminal: integration-s3 rebuilds its upload from
# worker-local files that do not survive a crash, and report/Security Hub recovery is
# out of scope.
REENQUEUEABLE_TASKS = {
*_SCAN_TASKS,
"provider-deletion",
"tenant-deletion",
"scan-summary",
"scan-compliance-overviews",
"scan-provider-compliance-scores",
"scan-daily-severity",
"scan-finding-group-summaries",
"scan-reset-ephemeral-resources",
"integration-jira",
# Tasks with proven idempotency are eligible for auto re-enqueue, grouped so each
# group can be toggled independently by a feature flag (see config.django.base).
# Summaries clear and rewrite their own rows and deletions are idempotent. Tasks with
# external side effects are never eligible: integration-jira would create duplicate
# issues, integration-s3 rebuilds its upload from worker-local files that do not
# survive a crash, and report/Security Hub recovery is out of scope.
RECOVERY_TASK_GROUPS = {
"summaries": {
"scan-summary",
"scan-compliance-overviews",
"scan-provider-compliance-scores",
"scan-daily-severity",
"scan-finding-group-summaries",
"scan-reset-ephemeral-resources",
},
"deletions": {"provider-deletion", "tenant-deletion"},
}
# Tasks excluded from generic recovery: attack-paths scans are handled by their own
# stale-cleanup (which also drops the temp Neo4j db), and the maintenance tasks must
# not self-recover (they run again on their own schedule).
def reenqueueable_tasks() -> set[str]:
"""Task names eligible for auto re-enqueue, honoring the per-group feature flags.
A group whose flag is disabled is dropped, so its orphaned tasks are marked
terminal instead of re-enqueued.
"""
from django.conf import settings
group_enabled = {
"summaries": settings.TASK_RECOVERY_SUMMARIES_ENABLED,
"deletions": settings.TASK_RECOVERY_DELETIONS_ENABLED,
}
return {
task
for group, tasks in RECOVERY_TASK_GROUPS.items()
if group_enabled[group]
for task in tasks
}
# Tasks the watchdog ignores entirely (not even marked terminal): scan tasks are not
# auto-recovered, since re-running a scan is not safe to do automatically; attack-paths
# scans are handled by their own stale-cleanup (which also drops the temp Neo4j db);
# and the maintenance tasks must not self-recover (they run again on their own schedule).
_SKIP_RECOVERY = {
"scan-perform",
"scan-perform-scheduled",
"attack-paths-scan-perform",
"attack-paths-cleanup-stale-scans",
"reconcile-orphan-tasks",
@@ -166,15 +183,22 @@ def reconcile_orphans(
logger.info("Orphan reconcile skipped: another run holds the lock")
return {"acquired": False}
# Populate the task registry so we can re-enqueue any task by name.
import tasks.tasks # noqa: F401
from django.conf import settings
result = _reconcile_task_results(
grace_minutes=grace_minutes,
max_attempts=max_attempts,
window_hours=window_hours,
dry_run=dry_run,
)
if settings.TASK_RECOVERY_ENABLED:
# Populate the task registry so we can re-enqueue any task by name.
import tasks.tasks # noqa: F401
result = _reconcile_task_results(
grace_minutes=grace_minutes,
max_attempts=max_attempts,
window_hours=window_hours,
dry_run=dry_run,
)
result["enabled"] = True
else:
logger.info("Orphan task recovery disabled by feature flag")
result = {"recovered": [], "failed": [], "skipped": [], "enabled": False}
if not dry_run:
from tasks.jobs.attack_paths.cleanup import cleanup_stale_attack_paths_scans
@@ -264,34 +288,27 @@ def _recover_task(task_result, max_attempts: int, window_hours: int) -> str:
task_result.date_done = now
task_result.save(update_fields=["status", "date_done"])
attempt = _recovery_attempt_count(name, kwargs_repr, window_hours)
if name not in REENQUEUEABLE_TASKS or attempt > max_attempts:
reason = (
f"{name} is not allowlisted for auto recovery"
if name not in REENQUEUEABLE_TASKS
else f"recovery cap reached ({attempt}/{max_attempts})"
)
_fail_domain_row(task_result.task_id, name, now)
if name not in reenqueueable_tasks():
logger.warning(
"Orphan %s (%s) not re-enqueued: %s", task_result.task_id, name, reason
"Orphan %s (%s) not re-enqueued: not allowlisted for auto recovery",
task_result.task_id,
name,
)
return "failed"
# Scan tasks: re-run the EXISTING scan row directly via scan-perform, so the
# scheduled-scan "already executing" guard cannot turn recovery into a no-op.
# Falls through to the generic path only if no scan is linked yet (e.g. a
# scheduled task that died before creating one), where re-running it creates one.
if name in _SCAN_TASKS:
scan = _scan_for_task(task_result.task_id)
if scan is not None:
if not _reenqueue_scan(task_result.task_id, scan):
return "failed"
logger.info(
"Re-enqueued orphaned scan %s (was task %s)",
scan.id,
task_result.task_id,
)
return "recovered"
# Count the attempt only once the task is allowlisted, so a task sitting in a
# disabled group does not burn its recovery budget while the flag is off (and is
# not already over the cap the moment the group is re-enabled).
attempt = _recovery_attempt_count(name, kwargs_repr, window_hours)
if attempt > max_attempts:
logger.warning(
"Orphan %s (%s) not re-enqueued: recovery cap reached (%d/%d)",
task_result.task_id,
name,
attempt,
max_attempts,
)
return "failed"
task_obj = current_app.tasks.get(name)
if task_obj is None:
@@ -311,7 +328,6 @@ def _recover_task(task_result, max_attempts: int, window_hours: int) -> str:
task_result.task_id,
name,
)
_fail_domain_row(task_result.task_id, name, now)
return "failed"
new_task_id = str(uuid4())
task_obj.apply_async(
@@ -323,75 +339,3 @@ def _recover_task(task_result, max_attempts: int, window_hours: int) -> str:
"Re-enqueued orphan %s (%s) as %s", task_result.task_id, name, new_task_id
)
return "recovered"
def _scan_for_task(task_id: str):
"""Return the Scan linked to a Celery task id, or None (read across tenants)."""
from api.db_router import MainRouter
from api.models import Scan
return Scan.all_objects.using(MainRouter.admin_db).filter(task_id=task_id).first()
def _reenqueue_scan(old_task_id: str, scan) -> bool:
"""Re-run an orphaned scan via scan-perform on the existing row.
Pre-provisions the new task linkage (TaskResult + api.Task) and relinks the
Scan before enqueuing, so the FK is valid and a worker can never outrun the DB.
The relink is conditional on the scan still pointing at the old task, so a stale
orphan can never clobber a newer linkage.
"""
from django_celery_results.models import TaskResult
from api.db_utils import rls_transaction
from api.models import Scan
from api.models import Task as APITask
from tasks.tasks import perform_scan_task
tenant_id = str(scan.tenant_id)
new_task_id = str(uuid4())
with rls_transaction(tenant_id):
locked_scan = Scan.all_objects.select_for_update().filter(id=scan.id).first()
if locked_scan is None or str(locked_scan.task_id) != old_task_id:
logger.info(
"Scan %s no longer points at task %s; skipping recovery re-enqueue",
scan.id,
old_task_id,
)
return False
task_result_new, _ = TaskResult.objects.get_or_create(
task_id=new_task_id,
defaults={"status": states.PENDING, "task_name": "scan-perform"},
)
APITask.objects.update_or_create(
id=new_task_id,
tenant_id=tenant_id,
defaults={"task_runner_task": task_result_new},
)
locked_scan.task_id = new_task_id
locked_scan.recovery_count = (locked_scan.recovery_count or 0) + 1
locked_scan.save(update_fields=["task_id", "recovery_count", "updated_at"])
perform_scan_task.apply_async(
kwargs={
"tenant_id": tenant_id,
"scan_id": str(scan.id),
"provider_id": str(scan.provider_id),
},
task_id=new_task_id,
)
return True
def _fail_domain_row(old_task_id: str, name: str, now: datetime) -> None:
"""Mark a scan terminal when its task is capped/denylisted instead of re-run."""
from api.db_utils import rls_transaction
from api.models import Scan, StateChoices
if name in _SCAN_TASKS:
scan = _scan_for_task(old_task_id)
if scan is not None:
with rls_transaction(str(scan.tenant_id)):
Scan.all_objects.filter(id=scan.id, task_id=old_task_id).update(
state=StateChoices.FAILED, completed_at=now
)
+3 -18
View File
@@ -118,19 +118,6 @@ ATTACK_SURFACE_PROVIDER_COMPATIBILITY = {
_ATTACK_SURFACE_MAPPING_CACHE: dict[str, dict] = {}
def _clear_scan_rerun_state(tenant_id: str, scan_id: str) -> None:
"""Remove rows derived from a previous execution of this scan."""
with rls_transaction(tenant_id):
Finding.all_objects.filter(scan_id=scan_id).delete()
ResourceScanSummary.objects.filter(scan_id=scan_id).delete()
ScanCategorySummary.objects.filter(scan_id=scan_id).delete()
ScanGroupSummary.objects.filter(scan_id=scan_id).delete()
ScanSummary.objects.filter(scan_id=scan_id).delete()
AttackSurfaceOverview.objects.filter(scan_id=scan_id).delete()
ComplianceRequirementOverview.objects.filter(scan_id=scan_id).delete()
ComplianceOverviewSummary.objects.filter(scan_id=scan_id).delete()
def aggregate_category_counts(
categories: list[str],
severity: str,
@@ -489,10 +476,9 @@ def _create_compliance_summaries(
)
)
# Idempotent re-run: clear this scan's prior summaries before re-inserting, so
# a recovered scan's summary always reflects its own (re-derived) requirement
# rows rather than keeping a stale row (bulk_create ignore_conflicts alone would
# keep the old one).
# Idempotent re-run: clear this scan's prior summaries before re-inserting, so a
# recovered scan-compliance-overviews run reflects its own re-derived rows instead
# of keeping a stale one (bulk_create ignore_conflicts alone would keep the old).
with rls_transaction(tenant_id):
ComplianceOverviewSummary.objects.filter(scan_id=scan_id).delete()
if summary_objects:
@@ -1039,7 +1025,6 @@ def perform_prowler_scan(
scan_instance.state = StateChoices.EXECUTING
scan_instance.started_at = datetime.now(tz=timezone.utc)
scan_instance.save(update_fields=["state", "started_at", "updated_at"])
_clear_scan_rerun_state(tenant_id, scan_id)
# Find the mutelist processor if it exists
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
+14 -2
View File
@@ -260,7 +260,9 @@ def delete_provider_task(provider_id: str, tenant_id: str):
return delete_provider(tenant_id=tenant_id, pk=provider_id)
@shared_task(base=RLSTask, name="scan-perform", queue="scans")
# acks_late=False: a re-run would duplicate findings and the task is not auto-recovered,
# so a crashed scan is dropped rather than redelivered by the broker (as before #11416).
@shared_task(base=RLSTask, name="scan-perform", queue="scans", acks_late=False)
@handle_provider_deletion
def perform_scan_task(
tenant_id: str, scan_id: str, provider_id: str, checks_to_execute: list[str] = None
@@ -304,7 +306,14 @@ def perform_scan_task(
return result
@shared_task(base=RLSTask, bind=True, name="scan-perform-scheduled", queue="scans")
# acks_late=False: like scan-perform; a dropped run is re-fired by Beat on the next tick.
@shared_task(
base=RLSTask,
bind=True,
name="scan-perform-scheduled",
queue="scans",
acks_late=False,
)
@handle_provider_deletion
def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
"""
@@ -1151,10 +1160,13 @@ def security_hub_integration_task(
return upload_security_hub_integration(tenant_id, provider_id, scan_id)
# acks_late=False: Jira sends are not deduplicated and the task is not auto-recovered,
# so a crashed send is dropped rather than redelivered (avoids duplicate Jira issues).
@shared_task(
base=RLSTask,
name="integration-jira",
queue="integrations",
acks_late=False,
)
def jira_integration_task(
tenant_id: str,
+1 -39
View File
@@ -1,12 +1,11 @@
from unittest.mock import call, patch
from uuid import uuid4
import pytest
from django.core.exceptions import ObjectDoesNotExist
from tasks.jobs.deletion import delete_provider, delete_tenant
from api.attack_paths import database as graph_database
from api.models import JiraIssueDispatch, Provider, Tenant, TenantComplianceSummary
from api.models import Provider, Tenant, TenantComplianceSummary
@pytest.mark.django_db
@@ -35,43 +34,6 @@ class TestDeleteProvider:
str(instance.id),
)
def test_delete_provider_removes_jira_dispatches(
self,
providers_fixture,
findings_fixture,
integrations_fixture,
):
"""Deleting a provider removes JiraIssueDispatch rows for its findings only."""
instance = providers_fixture[0]
tenant_id = str(instance.tenant_id)
finding = findings_fixture[0]
integration = integrations_fixture[0]
# Dispatch for one of the provider's findings: must be removed with it.
JiraIssueDispatch.objects.create(
tenant_id=tenant_id,
integration=integration,
finding_id=finding.id,
)
# Dispatch for an unrelated finding: must survive the provider deletion.
unrelated = JiraIssueDispatch.objects.create(
tenant_id=tenant_id,
integration=integration,
finding_id=uuid4(),
)
with (
patch(
"tasks.jobs.deletion.graph_database.get_database_name",
return_value="tenant-db",
),
patch("tasks.jobs.deletion.graph_database.drop_subgraph"),
):
delete_provider(tenant_id, instance.id)
assert not JiraIssueDispatch.objects.filter(finding_id=finding.id).exists()
assert JiraIssueDispatch.objects.filter(pk=unrelated.pk).exists()
def test_delete_provider_does_not_exist(self, tenants_fixture):
with (
patch(
@@ -1640,74 +1640,14 @@ class TestJiraIntegration:
@patch("tasks.jobs.integrations.Finding")
@patch("tasks.jobs.integrations.Integration")
@patch("tasks.jobs.integrations.initialize_prowler_integration")
@patch("tasks.jobs.integrations.JiraIssueDispatch")
def test_send_findings_to_jira_skips_already_dispatched(
self,
mock_jira_dispatch,
mock_initialize_integration,
mock_integration_model,
mock_finding_model,
mock_rls_transaction,
):
"""A re-run skips findings already ticketed (no duplicate Jira issues)."""
mock_rls_transaction.return_value.__enter__ = MagicMock()
mock_rls_transaction.return_value.__exit__ = MagicMock()
mock_integration_model.objects.get.return_value = MagicMock()
# finding-1 was already dispatched in a prior run; finding-2 is new.
mock_jira_dispatch.objects.filter.return_value.values_list.return_value = [
"finding-1"
]
mock_jira_dispatch.objects.get_or_create.return_value = (MagicMock(), True)
mock_jira_integration = MagicMock()
mock_jira_integration.send_finding.return_value = True
mock_initialize_integration.return_value = mock_jira_integration
finding2 = MagicMock()
finding2.id = "finding-2"
finding2.check_id = "check_002"
finding2.severity = "low"
finding2.status = "FAIL"
finding2.status_extended = ""
finding2.compliance = {}
finding2.resources.exists.return_value = False
finding2.resources.first.return_value = None
finding2.scan.provider.provider = "aws"
finding2.check_metadata = {
"checktitle": "C2",
"risk": "",
"remediation": {"recommendation": {}, "code": {}},
}
mock_finding_model.all_objects.select_related.return_value.prefetch_related.return_value.get.return_value = finding2
result = send_findings_to_jira(
"tenant-123", "integration-456", "PROJ", "Task", ["finding-1", "finding-2"]
)
# finding-1 skipped (already sent); only finding-2 sent -> no duplicate.
assert result == {"created_count": 1, "failed_count": 0, "skipped_count": 1}
mock_jira_integration.send_finding.assert_called_once()
assert (
mock_jira_integration.send_finding.call_args.kwargs["check_id"]
== "check_002"
)
@patch("tasks.jobs.integrations.rls_transaction")
@patch("tasks.jobs.integrations.Finding")
@patch("tasks.jobs.integrations.Integration")
@patch("tasks.jobs.integrations.initialize_prowler_integration")
@patch("tasks.jobs.integrations.JiraIssueDispatch")
def test_send_findings_to_jira_success(
self,
mock_jira_dispatch,
mock_initialize_integration,
mock_integration_model,
mock_finding_model,
mock_rls_transaction,
):
"""Test successful sending of findings to Jira using send_finding method"""
mock_jira_dispatch.objects.filter.return_value.values_list.return_value = []
mock_jira_dispatch.objects.get_or_create.return_value = (MagicMock(), True)
tenant_id = "tenant-123"
integration_id = "integration-456"
project_key = "PROJ"
@@ -1799,7 +1739,7 @@ class TestJiraIntegration:
)
# Assertions
assert result == {"created_count": 2, "failed_count": 0, "skipped_count": 0}
assert result == {"created_count": 2, "failed_count": 0}
# Verify Jira integration was initialized
mock_initialize_integration.assert_called_once_with(integration)
@@ -1831,10 +1771,8 @@ class TestJiraIntegration:
@patch("tasks.jobs.integrations.Integration")
@patch("tasks.jobs.integrations.initialize_prowler_integration")
@patch("tasks.jobs.integrations.logger")
@patch("tasks.jobs.integrations.JiraIssueDispatch")
def test_send_findings_to_jira_partial_failure(
self,
mock_jira_dispatch,
mock_logger,
mock_initialize_integration,
mock_integration_model,
@@ -1842,8 +1780,6 @@ class TestJiraIntegration:
mock_rls_transaction,
):
"""Test partial failure when sending findings to Jira"""
mock_jira_dispatch.objects.filter.return_value.values_list.return_value = []
mock_jira_dispatch.objects.get_or_create.return_value = (MagicMock(), True)
tenant_id = "tenant-123"
integration_id = "integration-456"
project_key = "PROJ"
@@ -1897,35 +1833,23 @@ class TestJiraIntegration:
)
# Assertions
assert result == {"created_count": 2, "failed_count": 1, "skipped_count": 0}
assert result == {"created_count": 2, "failed_count": 1}
# Verify error was logged for the failed finding
mock_logger.error.assert_called_with("Failed to send finding finding-2 to Jira")
# The failed finding's reservation is released so a later run can retry it.
mock_jira_dispatch.objects.filter.assert_any_call(
tenant_id=tenant_id,
integration_id=integration_id,
finding_id="finding-2",
)
mock_jira_dispatch.objects.filter.return_value.delete.assert_called_once()
@patch("tasks.jobs.integrations.rls_transaction")
@patch("tasks.jobs.integrations.Finding")
@patch("tasks.jobs.integrations.Integration")
@patch("tasks.jobs.integrations.initialize_prowler_integration")
@patch("tasks.jobs.integrations.JiraIssueDispatch")
def test_send_findings_to_jira_no_resources(
self,
mock_jira_dispatch,
mock_initialize_integration,
mock_integration_model,
mock_finding_model,
mock_rls_transaction,
):
"""Test sending findings to Jira when finding has no resources"""
mock_jira_dispatch.objects.filter.return_value.values_list.return_value = []
mock_jira_dispatch.objects.get_or_create.return_value = (MagicMock(), True)
tenant_id = "tenant-123"
integration_id = "integration-456"
project_key = "PROJ"
@@ -1983,7 +1907,7 @@ class TestJiraIntegration:
)
# Assertions
assert result == {"created_count": 1, "failed_count": 0, "skipped_count": 0}
assert result == {"created_count": 1, "failed_count": 0}
# Verify send_finding was called with empty resource fields
call_kwargs = mock_jira_integration.send_finding.call_args.kwargs
@@ -1996,18 +1920,14 @@ class TestJiraIntegration:
@patch("tasks.jobs.integrations.Finding")
@patch("tasks.jobs.integrations.Integration")
@patch("tasks.jobs.integrations.initialize_prowler_integration")
@patch("tasks.jobs.integrations.JiraIssueDispatch")
def test_send_findings_to_jira_with_empty_check_metadata(
self,
mock_jira_dispatch,
mock_initialize_integration,
mock_integration_model,
mock_finding_model,
mock_rls_transaction,
):
"""Test sending findings to Jira when check_metadata is empty or missing fields"""
mock_jira_dispatch.objects.filter.return_value.values_list.return_value = []
mock_jira_dispatch.objects.get_or_create.return_value = (MagicMock(), True)
tenant_id = "tenant-123"
integration_id = "integration-456"
project_key = "PROJ"
@@ -2050,7 +1970,7 @@ class TestJiraIntegration:
)
# Assertions
assert result == {"created_count": 1, "failed_count": 0, "skipped_count": 0}
assert result == {"created_count": 1, "failed_count": 0}
# Verify send_finding was called with default/empty values
call_kwargs = mock_jira_integration.send_finding.call_args.kwargs
@@ -2063,94 +1983,3 @@ class TestJiraIntegration:
assert call_kwargs["remediation_code_cli"] == ""
assert call_kwargs["remediation_code_other"] == ""
assert call_kwargs["compliance"] == {}
@patch("tasks.jobs.integrations.rls_transaction")
@patch("tasks.jobs.integrations.Finding")
@patch("tasks.jobs.integrations.Integration")
@patch("tasks.jobs.integrations.initialize_prowler_integration")
@patch("tasks.jobs.integrations.JiraIssueDispatch")
def test_send_findings_to_jira_reserves_before_sending(
self,
mock_jira_dispatch,
mock_initialize_integration,
mock_integration_model,
mock_finding_model,
mock_rls_transaction,
):
"""The dispatch row is reserved before the external Jira call (reserve-then-act)."""
mock_rls_transaction.return_value.__enter__ = MagicMock()
mock_rls_transaction.return_value.__exit__ = MagicMock()
mock_integration_model.objects.get.return_value = MagicMock()
mock_jira_dispatch.objects.filter.return_value.values_list.return_value = []
order = []
mock_jira_dispatch.objects.get_or_create.side_effect = lambda **kw: (
order.append(("reserve", kw)) or (MagicMock(), True)
)
mock_jira_integration = MagicMock()
mock_jira_integration.send_finding.side_effect = lambda **kw: (
order.append(("send", kw)) or True
)
mock_initialize_integration.return_value = mock_jira_integration
finding = MagicMock()
finding.id = "finding-1"
finding.check_id = "check_001"
finding.severity = "low"
finding.status = "FAIL"
finding.status_extended = ""
finding.compliance = {}
finding.resources.exists.return_value = False
finding.resources.first.return_value = None
finding.scan.provider.provider = "aws"
finding.check_metadata = {
"checktitle": "C1",
"risk": "",
"remediation": {"recommendation": {}, "code": {}},
}
mock_finding_model.all_objects.select_related.return_value.prefetch_related.return_value.get.return_value = finding
result = send_findings_to_jira(
"tenant-123", "integration-456", "PROJ", "Task", ["finding-1"]
)
assert result == {"created_count": 1, "failed_count": 0, "skipped_count": 0}
# Reservation must precede the external send.
assert [entry[0] for entry in order] == ["reserve", "send"]
# A successful send keeps the reservation (no rollback delete).
mock_jira_dispatch.objects.filter.return_value.delete.assert_not_called()
@patch("tasks.jobs.integrations.rls_transaction")
@patch("tasks.jobs.integrations.Finding")
@patch("tasks.jobs.integrations.Integration")
@patch("tasks.jobs.integrations.initialize_prowler_integration")
@patch("tasks.jobs.integrations.JiraIssueDispatch")
def test_send_findings_to_jira_skips_when_already_reserved(
self,
mock_jira_dispatch,
mock_initialize_integration,
mock_integration_model,
mock_finding_model,
mock_rls_transaction,
):
"""A finding that races past the bulk pre-check but loses the reservation
(created=False) is skipped without a second issue, leaving the row intact."""
mock_rls_transaction.return_value.__enter__ = MagicMock()
mock_rls_transaction.return_value.__exit__ = MagicMock()
mock_integration_model.objects.get.return_value = MagicMock()
mock_jira_dispatch.objects.filter.return_value.values_list.return_value = []
# Another concurrent run already created the dispatch row.
mock_jira_dispatch.objects.get_or_create.return_value = (MagicMock(), False)
mock_jira_integration = MagicMock()
mock_initialize_integration.return_value = mock_jira_integration
result = send_findings_to_jira(
"tenant-123", "integration-456", "PROJ", "Task", ["finding-1"]
)
assert result == {"created_count": 0, "failed_count": 0, "skipped_count": 1}
mock_jira_integration.send_finding.assert_not_called()
# The reservation belongs to the run that won the race; do not delete it.
mock_jira_dispatch.objects.filter.return_value.delete.assert_not_called()
@@ -4,17 +4,17 @@ from uuid import uuid4
import pytest
from celery import states
from django.test import override_settings
from django_celery_results.models import TaskResult
from api.models import Scan, StateChoices
from api.models import Task as APITask
from tasks.jobs.orphan_recovery import (
_decode_celery_field,
_reconcile_task_results,
_recovery_attempt_count,
_reenqueue_scan,
advisory_lock,
is_worker_alive,
reconcile_orphans,
reenqueueable_tasks,
)
@@ -130,9 +130,83 @@ class TestReconcileTaskResults:
assert tr.task_id in result["failed"]
mock_task.apply_async.assert_not_called()
def test_jira_integration_task_is_reenqueued(self, tenants_fixture):
"""integration-jira is re-enqueued: its JiraIssueDispatch reservation makes a
re-run skip already-ticketed findings, so recovery cannot duplicate issues."""
@override_settings(TASK_RECOVERY_SUMMARIES_ENABLED=False)
def test_disabled_group_task_is_not_reenqueued(self, tenants_fixture):
"""A task whose group feature flag is off stays terminal, not re-enqueued."""
tr = _orphan_result(
name="scan-summary",
kwargs={
"tenant_id": str(tenants_fixture[0].id),
"scan_id": str(uuid4()),
},
worker="dead@gone",
created_minutes_ago=60,
)
p_alive, p_revoke, p_app, mock_task = self._patches(alive=False)
with (
p_alive,
p_revoke,
p_app,
patch("tasks.jobs.orphan_recovery._recovery_attempt_count", return_value=1),
):
result = _reconcile_task_results(
grace_minutes=2, max_attempts=3, window_hours=6, dry_run=False
)
assert tr.task_id in result["failed"]
mock_task.apply_async.assert_not_called()
@override_settings(TASK_RECOVERY_SUMMARIES_ENABLED=False)
def test_disabled_group_task_does_not_consume_recovery_attempt(
self, tenants_fixture
):
"""A disabled-group task is failed without incrementing its Valkey attempt
counter, so re-enabling the group does not start it at the cap."""
tr = _orphan_result(
name="scan-summary",
kwargs={"tenant_id": str(tenants_fixture[0].id), "scan_id": str(uuid4())},
worker="dead@gone",
created_minutes_ago=60,
)
p_alive, p_revoke, p_app, mock_task = self._patches(alive=False)
with (
p_alive,
p_revoke,
p_app,
patch("tasks.jobs.orphan_recovery._recovery_attempt_count") as mock_count,
):
result = _reconcile_task_results(
grace_minutes=2, max_attempts=3, window_hours=6, dry_run=False
)
assert tr.task_id in result["failed"]
mock_count.assert_not_called()
def test_scan_task_is_skipped_entirely(self, tenants_fixture):
"""Scan tasks are excluded from recovery: the watchdog never touches them."""
tr = _orphan_result(
name="scan-perform",
kwargs={
"tenant_id": str(tenants_fixture[0].id),
"scan_id": str(uuid4()),
},
worker="dead@gone",
created_minutes_ago=60,
)
p_alive, p_revoke, p_app, mock_task = self._patches(alive=False)
with p_alive, p_revoke, p_app:
result = _reconcile_task_results(
grace_minutes=2, max_attempts=3, window_hours=6, dry_run=False
)
assert tr.task_id not in result["recovered"]
assert tr.task_id not in result["failed"]
assert tr.task_id not in result["skipped"]
mock_task.apply_async.assert_not_called()
def test_jira_integration_task_is_not_reenqueued(self, tenants_fixture):
"""integration-jira stays terminal: re-running it would create duplicate Jira
issues, so an orphaned send is failed instead of re-enqueued."""
tenant = tenants_fixture[0]
kwargs = {
"tenant_id": str(tenant.id),
@@ -158,13 +232,10 @@ class TestReconcileTaskResults:
grace_minutes=2, max_attempts=3, window_hours=6, dry_run=False
)
assert tr.task_id in result["recovered"]
assert tr.task_id in result["failed"]
tr.refresh_from_db()
assert tr.status == states.REVOKED # stale result cleared (no pending alert)
mock_task.apply_async.assert_called_once()
call = mock_task.apply_async.call_args.kwargs
assert call["kwargs"] == kwargs
assert call["task_id"] != tr.task_id # fresh task id
mock_task.apply_async.assert_not_called()
def test_skips_live_worker(self, tenants_fixture):
tr = _orphan_result(
@@ -246,98 +317,6 @@ class TestReconcileTaskResults:
mock_task.apply_async.assert_not_called()
@pytest.mark.django_db
class TestScanRecovery:
"""Scans are recovered by re-running scan-perform on the EXISTING scan row,
so even a scheduled-scan orphan (whose own task would no-op on its guard) is
actually re-executed."""
def _scan_orphan(self, tenant, provider, name):
old_id = str(uuid4())
tr = TaskResult.objects.create(
task_id=old_id,
status=states.STARTED,
task_name=name,
worker="dead@gone",
task_kwargs=repr(
{"tenant_id": str(tenant.id), "provider_id": str(provider.id)}
),
task_args=repr([]),
)
TaskResult.objects.filter(pk=tr.pk).update(
date_created=datetime.now(tz=timezone.utc) - timedelta(minutes=60)
)
APITask.objects.create(id=old_id, tenant_id=tenant.id, task_runner_task=tr)
scan = Scan.objects.create(
name="scan-orphan",
provider=provider,
trigger=Scan.TriggerChoices.SCHEDULED,
state=StateChoices.EXECUTING,
tenant_id=tenant.id,
task_id=old_id,
recovery_count=0,
)
return old_id, scan
@pytest.mark.parametrize("name", ["scan-perform", "scan-perform-scheduled"])
def test_scan_recovered_via_scan_perform(
self, tenants_fixture, providers_fixture, name
):
tenant, provider = tenants_fixture[0], providers_fixture[0]
old_id, scan = self._scan_orphan(tenant, provider, name)
with (
patch("tasks.jobs.orphan_recovery.is_worker_alive", return_value=False),
patch("tasks.jobs.orphan_recovery.revoke_task"),
patch("tasks.jobs.orphan_recovery._recovery_attempt_count", return_value=1),
patch("tasks.tasks.perform_scan_task") as mock_scan_task,
):
result = _reconcile_task_results(
grace_minutes=2, max_attempts=3, window_hours=6, dry_run=False
)
assert old_id in result["recovered"]
scan.refresh_from_db()
assert str(scan.task_id) != old_id # relinked to a fresh task
assert scan.recovery_count == 1
assert TaskResult.objects.get(task_id=old_id).status == states.REVOKED
# Recovered by re-running scan-perform on the existing scan row (so the
# scheduled guard cannot no-op it), regardless of the original task name.
mock_scan_task.apply_async.assert_called_once()
assert mock_scan_task.apply_async.call_args.kwargs["kwargs"]["scan_id"] == str(
scan.id
)
def test_reenqueue_skips_when_scan_already_repointed(
self, tenants_fixture, providers_fixture
):
# The scan already points at a newer task, so a stale orphan must not relink
# it or launch a second concurrent run against the same scan row.
tenant, provider = tenants_fixture[0], providers_fixture[0]
newer_id = str(uuid4())
tr = TaskResult.objects.create(
task_id=newer_id, status=states.STARTED, task_name="scan-perform"
)
APITask.objects.create(id=newer_id, tenant_id=tenant.id, task_runner_task=tr)
scan = Scan.objects.create(
name="scan-orphan",
provider=provider,
trigger=Scan.TriggerChoices.SCHEDULED,
state=StateChoices.EXECUTING,
tenant_id=tenant.id,
task_id=newer_id,
recovery_count=0,
)
with patch("tasks.tasks.perform_scan_task") as mock_scan_task:
recovered = _reenqueue_scan(str(uuid4()), scan)
assert recovered is False
mock_scan_task.apply_async.assert_not_called()
scan.refresh_from_db()
assert scan.recovery_count == 0
@pytest.mark.django_db
class TestOrphanRecoveryHelpers:
def test_advisory_lock_acquires_and_releases(self):
@@ -370,3 +349,60 @@ class TestOrphanRecoveryHelpers:
with patch("redis.from_url", return_value=redis_client):
assert _recovery_attempt_count("probe-task", kwargs_repr, 6) == 1
assert _recovery_attempt_count("probe-task", kwargs_repr, 6) == 2
class TestRecoveryFeatureFlags:
def test_all_groups_enabled_by_default(self):
tasks = reenqueueable_tasks()
assert "scan-summary" in tasks
assert {"provider-deletion", "tenant-deletion"} <= tasks
@override_settings(TASK_RECOVERY_SUMMARIES_ENABLED=False)
def test_summaries_group_flag_excludes_summary_tasks(self):
tasks = reenqueueable_tasks()
assert "scan-summary" not in tasks
assert "scan-compliance-overviews" not in tasks
assert "provider-deletion" in tasks
@override_settings(TASK_RECOVERY_DELETIONS_ENABLED=False)
def test_deletions_group_flag_excludes_deletion_tasks(self):
tasks = reenqueueable_tasks()
assert "provider-deletion" not in tasks
assert "tenant-deletion" not in tasks
assert "scan-summary" in tasks
@pytest.mark.django_db
class TestRecoveryMasterFlag:
@override_settings(TASK_RECOVERY_ENABLED=False)
def test_master_flag_disables_task_recovery(self):
with (
patch(
"tasks.jobs.orphan_recovery._reconcile_task_results"
) as mock_reconcile,
patch(
"tasks.jobs.attack_paths.cleanup.cleanup_stale_attack_paths_scans",
return_value={},
),
):
result = reconcile_orphans(grace_minutes=2, max_attempts=3, dry_run=False)
mock_reconcile.assert_not_called()
assert result["acquired"] is True
assert result["enabled"] is False
@override_settings(TASK_RECOVERY_ENABLED=True)
def test_master_flag_enabled_runs_task_recovery(self):
with (
patch(
"tasks.jobs.orphan_recovery._reconcile_task_results",
return_value={"recovered": [], "failed": [], "skipped": []},
) as mock_reconcile,
patch(
"tasks.jobs.attack_paths.cleanup.cleanup_stale_attack_paths_scans",
return_value={},
),
):
reconcile_orphans(grace_minutes=2, max_attempts=3, dry_run=False)
mock_reconcile.assert_called_once()
-128
View File
@@ -32,15 +32,12 @@ from tasks.utils import CustomEncoder
from api.db_router import MainRouter
from api.exceptions import ProviderConnectionError
from api.models import (
AttackSurfaceOverview,
Finding,
MuteRule,
Provider,
Resource,
ResourceScanSummary,
Scan,
ScanCategorySummary,
ScanGroupSummary,
ScanSummary,
StateChoices,
StatusChoices,
@@ -232,131 +229,6 @@ class TestPerformScan:
# Assert that failed_findings_count is 0 (finding is PASS and muted)
assert scan_resource.failed_findings_count == 0
def test_perform_prowler_scan_idempotent_on_rerun(
self,
tenants_fixture,
scans_fixture,
providers_fixture,
):
"""Re-running a scan for the same scan_id must not duplicate findings."""
with (
patch("api.db_utils.rls_transaction"),
patch(
"tasks.jobs.scan.initialize_prowler_provider"
) as mock_initialize_prowler_provider,
patch("tasks.jobs.scan.ProwlerScan") as mock_prowler_scan_class,
patch(
"tasks.jobs.scan.PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE",
new_callable=dict,
),
patch("api.compliance.PROWLER_CHECKS", new_callable=dict) as mock_checks,
):
mock_checks["aws"] = {"check1": {"compliance1"}}
tenant = tenants_fixture[0]
scan = scans_fixture[0]
provider = providers_fixture[0]
provider.provider = Provider.ProviderChoices.AWS
provider.save()
tenant_id = str(tenant.id)
scan_id = str(scan.id)
provider_id = str(provider.id)
stale_resource = Resource.objects.create(
tenant_id=tenant.id,
provider=provider,
uid="stale_resource_uid",
name="stale",
region="stale-region",
service="stale-service",
type="stale-type",
)
ResourceScanSummary.objects.create(
tenant_id=tenant.id,
scan_id=scan.id,
resource_id=stale_resource.id,
service="stale-service",
region="stale-region",
resource_type="stale-type",
)
ScanCategorySummary.objects.create(
tenant_id=tenant.id,
scan=scan,
category="stale-category",
severity=Severity.medium,
total_findings=1,
)
ScanGroupSummary.objects.create(
tenant_id=tenant.id,
scan=scan,
resource_group="stale-group",
severity=Severity.medium,
total_findings=1,
)
ScanSummary.objects.create(
tenant_id=tenant.id,
scan=scan,
check_id="stale_check",
service="stale-service",
severity=Severity.medium,
region="stale-region",
total=1,
)
AttackSurfaceOverview.objects.create(
tenant_id=tenant.id,
scan=scan,
attack_surface_type=AttackSurfaceOverview.AttackSurfaceTypeChoices.SECRETS,
total_findings=1,
)
finding = MagicMock()
finding.uid = "dup_probe_finding"
finding.status = StatusChoices.PASS
finding.status_extended = "x"
finding.severity = Severity.medium
finding.check_id = "check1"
finding.get_metadata.return_value = {"key": "value"}
finding.resource_uid = "resource_uid"
finding.resource_name = "resource_name"
finding.region = "region"
finding.service_name = "service_name"
finding.resource_type = "resource_type"
finding.resource_tags = {}
finding.muted = False
finding.raw = {}
finding.resource_metadata = {}
finding.resource_details = {}
finding.partition = "partition"
finding.compliance = {}
mock_scan_instance = MagicMock()
mock_scan_instance.scan.return_value = [(100, [finding])]
mock_prowler_scan_class.return_value = mock_scan_instance
mock_provider_instance = MagicMock()
mock_provider_instance.get_regions.return_value = ["region"]
mock_initialize_prowler_provider.return_value = mock_provider_instance
# Run the same scan twice (simulating an orphan-recovery re-run).
perform_prowler_scan(tenant_id, scan_id, provider_id, ["check1"])
perform_prowler_scan(tenant_id, scan_id, provider_id, ["check1"])
# Neither findings nor resources are duplicated by the re-run: findings are
# scope-deleted before re-insert; resources are upserted by (tenant, provider, uid).
assert Finding.objects.filter(scan=scan).count() == 1
assert Resource.objects.filter(provider=provider).count() == 2
assert ResourceScanSummary.objects.filter(scan_id=scan.id).count() == 1
assert not ResourceScanSummary.objects.filter(
scan_id=scan.id, resource_id=stale_resource.id
).exists()
assert not ScanCategorySummary.objects.filter(scan=scan).exists()
assert not ScanGroupSummary.objects.filter(scan=scan).exists()
assert not ScanSummary.objects.filter(
scan=scan, check_id="stale_check"
).exists()
assert not AttackSurfaceOverview.objects.filter(scan=scan).exists()
@patch("tasks.jobs.scan.ProwlerScan")
@patch(
"tasks.jobs.scan.initialize_prowler_provider",