mirror of
https://github.com/prowler-cloud/prowler.git
synced 2026-05-06 08:47:18 +00:00
fix(api): Attack Paths AWS region fallback and stale SCHEDULED cleanup (#10917)
This commit is contained in:
@@ -2,6 +2,14 @@
|
||||
|
||||
All notable changes to the **Prowler API** are documented in this file.
|
||||
|
||||
## [1.26.1] (Prowler UNRELEASED)
|
||||
|
||||
### 🐞 Fixed
|
||||
|
||||
- Attack Paths: AWS scans no longer fail when enabled regions cannot be retrieved, and scans stuck in `scheduled` state are now cleaned up after the stale threshold [(#10917)](https://github.com/prowler-cloud/prowler/pull/10917)
|
||||
|
||||
---
|
||||
|
||||
## [1.26.0] (Prowler v5.25.0)
|
||||
|
||||
### 🚀 Added
|
||||
|
||||
@@ -49,7 +49,7 @@ def start_aws_ingestion(
|
||||
}
|
||||
|
||||
boto3_session = get_boto3_session(prowler_api_provider, prowler_sdk_provider)
|
||||
regions: list[str] = list(prowler_sdk_provider._enabled_regions)
|
||||
regions: list[str] = resolve_aws_regions(prowler_api_provider, prowler_sdk_provider)
|
||||
requested_syncs = list(cartography_aws.RESOURCE_FUNCTIONS.keys())
|
||||
|
||||
sync_args = cartography_aws._build_aws_sync_kwargs(
|
||||
@@ -226,6 +226,48 @@ def get_boto3_session(
|
||||
return boto3_session
|
||||
|
||||
|
||||
def resolve_aws_regions(
    prowler_api_provider: ProwlerAPIProvider,
    prowler_sdk_provider: ProwlerSDKProvider,
) -> list[str]:
    """Resolve the regions to scan, falling back when `_enabled_regions` is `None`.

    The SDK silently sets `_enabled_regions` to `None` when `ec2:DescribeRegions`
    fails (missing IAM permission, transient error). Without a fallback the
    Cartography ingestion crashes with a non-actionable `TypeError`. Try the
    user's `audited_regions` next, then the partition's static region list.
    Excluded regions are honored on every branch.

    Args:
        prowler_api_provider: API-side provider; only its `uid` is read, for
            log and error messages.
        prowler_sdk_provider: SDK-side provider exposing `_enabled_regions`,
            `identity.audited_regions`, `identity.partition` and, optionally,
            `_excluded_regions`.

    Returns:
        The resolved region names, minus excluded regions, sorted
        alphabetically.

    Raises:
        RuntimeError: If the partition has no static region data to fall
            back to. The original `KeyError` is attached as `__cause__`.
    """
    if prowler_sdk_provider._enabled_regions is not None:
        regions = set(prowler_sdk_provider._enabled_regions)
    elif prowler_sdk_provider.identity.audited_regions:
        regions = set(prowler_sdk_provider.identity.audited_regions)
    else:
        partition = prowler_sdk_provider.identity.partition
        try:
            # Coerce to a set so the difference below works whatever
            # iterable type the SDK hands back.
            regions = set(
                prowler_sdk_provider.get_available_aws_service_regions(
                    "ec2", partition
                )
            )
        except KeyError as error:
            # Chain explicitly so the traceback shows which lookup failed.
            raise RuntimeError(
                f"No region data available for partition {partition!r}; "
                f"cannot determine regions to scan for "
                f"{prowler_api_provider.uid}"
            ) from error

        logger.warning(
            f"Could not enumerate enabled regions for AWS account "
            f"{prowler_api_provider.uid}; falling back to all regions in "
            f"partition {partition!r}"
        )

    # `_excluded_regions` may be absent or `None`; both mean "exclude nothing".
    excluded = set(getattr(prowler_sdk_provider, "_excluded_regions", None) or ())
    return sorted(regions - excluded)
|
||||
|
||||
|
||||
def get_aioboto3_session(boto3_session: boto3.Session) -> aioboto3.Session:
    """Build an `aioboto3.Session` on top of the botocore session held by `boto3_session`."""
    botocore_session = boto3_session._session
    return aioboto3.Session(botocore_session=botocore_session)
|
||||
|
||||
|
||||
@@ -18,28 +18,45 @@ logger = get_task_logger(__name__)
|
||||
|
||||
def cleanup_stale_attack_paths_scans() -> dict:
    """
    Mark stale `AttackPathsScan` rows as `FAILED`.

    Covers two stuck-state scenarios:
        1. `EXECUTING` scans whose workers are dead, or that have exceeded the
           stale threshold while alive.
        2. `SCHEDULED` scans that never made it to a worker — parent scan
           crashed before dispatch, broker lost the message, etc. Detected by
           age plus the parent `Scan` no longer being in flight.

    Returns:
        A dict with `cleaned_up_count` (number of scans cleaned up) and
        `scan_ids` (their ids as strings, `EXECUTING` results first).
    """
    # Both passes share one cutoff so they judge "stale" identically.
    threshold = timedelta(minutes=ATTACK_PATHS_SCAN_STALE_THRESHOLD_MINUTES)
    cutoff = datetime.now(tz=timezone.utc) - threshold

    cleaned_up: list[str] = []
    cleaned_up.extend(_cleanup_stale_executing_scans(cutoff))
    cleaned_up.extend(_cleanup_stale_scheduled_scans(cutoff))

    logger.info(
        f"Stale `AttackPathsScan` cleanup: {len(cleaned_up)} scan(s) cleaned up"
    )
    return {"cleaned_up_count": len(cleaned_up), "scan_ids": cleaned_up}
|
||||
|
||||
|
||||
def _cleanup_stale_executing_scans(cutoff: datetime) -> list[str]:
|
||||
"""
|
||||
Two-pass detection for `EXECUTING` scans:
|
||||
1. If `TaskResult.worker` exists, ping the worker.
|
||||
- Dead worker: cleanup immediately (any age).
|
||||
- Alive + past threshold: revoke the task, then cleanup.
|
||||
- Alive + within threshold: skip.
|
||||
2. If no worker field: fall back to time-based heuristic only.
|
||||
"""
|
||||
threshold = timedelta(minutes=ATTACK_PATHS_SCAN_STALE_THRESHOLD_MINUTES)
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
cutoff = now - threshold
|
||||
|
||||
executing_scans = (
|
||||
executing_scans = list(
|
||||
AttackPathsScan.all_objects.using(MainRouter.admin_db)
|
||||
.filter(state=StateChoices.EXECUTING)
|
||||
.select_related("task__task_runner_task")
|
||||
)
|
||||
|
||||
# Cache worker liveness so each worker is pinged at most once
|
||||
executing_scans = list(executing_scans)
|
||||
workers = {
|
||||
tr.worker
|
||||
for scan in executing_scans
|
||||
@@ -48,7 +65,7 @@ def cleanup_stale_attack_paths_scans() -> dict:
|
||||
}
|
||||
worker_alive = {w: _is_worker_alive(w) for w in workers}
|
||||
|
||||
cleaned_up = []
|
||||
cleaned_up: list[str] = []
|
||||
|
||||
for scan in executing_scans:
|
||||
task_result = (
|
||||
@@ -65,9 +82,7 @@ def cleanup_stale_attack_paths_scans() -> dict:
|
||||
|
||||
# Alive but stale — revoke before cleanup
|
||||
_revoke_task(task_result)
|
||||
reason = (
|
||||
"Scan exceeded stale threshold — " "cleaned up by periodic task"
|
||||
)
|
||||
reason = "Scan exceeded stale threshold — cleaned up by periodic task"
|
||||
else:
|
||||
reason = "Worker dead — cleaned up by periodic task"
|
||||
else:
|
||||
@@ -82,10 +97,57 @@ def cleanup_stale_attack_paths_scans() -> dict:
|
||||
if _cleanup_scan(scan, task_result, reason):
|
||||
cleaned_up.append(str(scan.id))
|
||||
|
||||
logger.info(
|
||||
f"Stale `AttackPathsScan` cleanup: {len(cleaned_up)} scan(s) cleaned up"
|
||||
return cleaned_up
|
||||
|
||||
|
||||
def _cleanup_stale_scheduled_scans(cutoff: datetime) -> list[str]:
    """
    Cleanup `SCHEDULED` scans that never reached a worker.

    Detection:
        - `state == SCHEDULED`
        - `started_at < cutoff`
        - parent `Scan` is no longer in flight (terminal state or missing). This
          avoids cleaning up rows whose parent Prowler scan is legitimately still
          running.

    For each match: revoke the queued task (best-effort; harmless if already
    consumed), atomically flip to `FAILED`, and mark the `TaskResult`. The
    temp Neo4j database is never created while `SCHEDULED`, so no drop is
    needed.

    Args:
        cutoff: Only scans whose `started_at` is earlier than this are
            considered.

    Returns:
        The ids (as strings) of the scans that were actually cleaned up.
    """
    # Materialize the queryset up front: rows are updated while we iterate.
    scheduled_scans = list(
        AttackPathsScan.all_objects.using(MainRouter.admin_db)
        .filter(
            state=StateChoices.SCHEDULED,
            started_at__lt=cutoff,
        )
        .select_related("task__task_runner_task", "scan")
    )

    cleaned_up: list[str] = []
    parent_terminal = (
        StateChoices.COMPLETED,
        StateChoices.FAILED,
        StateChoices.CANCELLED,
    )
    reason = "Scan never started — cleaned up by periodic task"

    for scan in scheduled_scans:
        parent_scan = scan.scan
        # A parent that is still in flight may yet dispatch this scan.
        if parent_scan is not None and parent_scan.state not in parent_terminal:
            continue

        task_result = (
            getattr(scan.task, "task_runner_task", None) if scan.task else None
        )
        if task_result:
            # Do not terminate: the task has not started running anywhere.
            _revoke_task(task_result, terminate=False)

        if _cleanup_scheduled_scan(scan, task_result, reason):
            cleaned_up.append(str(scan.id))

    return cleaned_up
|
||||
|
||||
|
||||
def _is_worker_alive(worker: str) -> bool:
|
||||
@@ -98,12 +160,17 @@ def _is_worker_alive(worker: str) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def _revoke_task(task_result, terminate: bool = True) -> None:
    """Revoke a Celery task. Non-fatal on failure.

    `terminate=True` SIGTERMs the worker if the task is mid-execution; use
    for EXECUTING cleanup. `terminate=False` only marks the task id revoked
    across workers, so any worker pulling the queued message discards it;
    use for SCHEDULED cleanup where the task hasn't run yet.

    Args:
        task_result: Object exposing the Celery `task_id` to revoke.
        terminate: Whether to also send `SIGTERM` to a running task.
    """
    # Pass the terminate/signal pair only when termination is requested.
    kwargs = {"terminate": True, "signal": "SIGTERM"} if terminate else {}
    try:
        current_app.control.revoke(task_result.task_id, **kwargs)
        logger.info(f"Revoked task {task_result.task_id}")
    except Exception:
        # Best-effort: a broker hiccup must not abort the cleanup run.
        logger.exception(f"Failed to revoke task {task_result.task_id}")
|
||||
@@ -125,28 +192,64 @@ def _cleanup_scan(scan, task_result, reason: str) -> bool:
|
||||
except Exception:
|
||||
logger.exception(f"Failed to drop temp database {tmp_db_name}")
|
||||
|
||||
# 2. Lock row, verify still EXECUTING, mark FAILED — all atomic
|
||||
with rls_transaction(str(scan.tenant_id)):
|
||||
try:
|
||||
fresh_scan = AttackPathsScan.objects.select_for_update().get(id=scan.id)
|
||||
except AttackPathsScan.DoesNotExist:
|
||||
logger.warning(f"Scan {scan_id_str} no longer exists, skipping")
|
||||
return False
|
||||
fresh_scan = _finalize_failed_scan(scan, StateChoices.EXECUTING, reason)
|
||||
if fresh_scan is None:
|
||||
return False
|
||||
|
||||
if fresh_scan.state != StateChoices.EXECUTING:
|
||||
logger.info(f"Scan {scan_id_str} is now {fresh_scan.state}, skipping")
|
||||
return False
|
||||
|
||||
_mark_scan_finished(fresh_scan, StateChoices.FAILED, {"global_error": reason})
|
||||
|
||||
# 3. Mark `TaskResult` as `FAILURE` (not RLS-protected, outside lock)
|
||||
# Mark `TaskResult` as `FAILURE` (not RLS-protected, outside lock)
|
||||
if task_result:
|
||||
task_result.status = states.FAILURE
|
||||
task_result.date_done = datetime.now(tz=timezone.utc)
|
||||
task_result.save(update_fields=["status", "date_done"])
|
||||
|
||||
# 4. Recover graph_data_ready if provider data still exists
|
||||
recover_graph_data_ready(fresh_scan)
|
||||
|
||||
logger.info(f"Cleaned up stale scan {scan_id_str}: {reason}")
|
||||
return True
|
||||
|
||||
|
||||
def _cleanup_scheduled_scan(scan, task_result, reason: str) -> bool:
    """
    Fail a `SCHEDULED` scan that never got picked up by a worker.

    No temp Neo4j database is dropped here: it only exists once the worker
    has moved the scan to `EXECUTING`, so a drop would only add log noise.

    Returns `True` when the scan was cleaned up, `False` when it was skipped
    because the row is gone or no longer `SCHEDULED`.
    """
    scan_id_str = str(scan.id)

    # Lock + state check + FAILED transition happen in the shared helper.
    if _finalize_failed_scan(scan, StateChoices.SCHEDULED, reason) is None:
        return False

    if task_result:
        finished_at = datetime.now(tz=timezone.utc)
        task_result.status = states.FAILURE
        task_result.date_done = finished_at
        task_result.save(update_fields=["status", "date_done"])

    logger.info(f"Cleaned up scheduled scan {scan_id_str}: {reason}")
    return True
|
||||
|
||||
|
||||
def _finalize_failed_scan(scan, expected_state: str, reason: str):
    """
    Lock the scan row, confirm it is still in `expected_state`, and mark it
    `FAILED` with `reason` as the global error.

    Returns the locked row when the transition happened, or `None` when the
    row no longer exists or has already left `expected_state`.
    """
    scan_id_str = str(scan.id)
    tenant_id = str(scan.tenant_id)

    with rls_transaction(tenant_id):
        locked_rows = AttackPathsScan.objects.select_for_update()
        try:
            fresh_scan = locked_rows.get(id=scan.id)
        except AttackPathsScan.DoesNotExist:
            logger.warning(f"Scan {scan_id_str} no longer exists, skipping")
            return None

        # Someone else moved the row on while we were waiting for the lock.
        if fresh_scan.state != expected_state:
            logger.info(f"Scan {scan_id_str} is now {fresh_scan.state}, skipping")
            return None

        failure_details = {"global_error": reason}
        _mark_scan_finished(fresh_scan, StateChoices.FAILED, failure_details)

    return fresh_scan
|
||||
|
||||
@@ -67,25 +67,52 @@ def retrieve_attack_paths_scan(
|
||||
return None
|
||||
|
||||
|
||||
def set_attack_paths_scan_task_id(
    tenant_id: str,
    scan_pk: str,
    task_id: str,
) -> None:
    """Store the Celery `task_id` on the matching `AttackPathsScan` row.

    Invoked at dispatch time (once `apply_async` has returned) so that the
    row already carries the task id while it is still `SCHEDULED`. The
    periodic cleanup relies on it to revoke queued messages for scans that
    never reached a worker.
    """
    with rls_transaction(tenant_id):
        matching_rows = ProwlerAPIAttackPathsScan.objects.filter(id=scan_pk)
        matching_rows.update(task_id=task_id)
|
||||
|
||||
|
||||
def starting_attack_paths_scan(
    attack_paths_scan: ProwlerAPIAttackPathsScan,
    cartography_config: CartographyConfig,
) -> bool:
    """Flip the row from `SCHEDULED` to `EXECUTING` atomically.

    The row is re-read under `select_for_update` so the state check and the
    transition cannot interleave with the periodic cleanup.

    Args:
        attack_paths_scan: In-memory scan object held by the caller; its
            `state`, `started_at` and `update_tag` are refreshed on success.
        cartography_config: Supplies the `update_tag` stored on the row.

    Returns:
        `True` if the transition happened. `False` if the row is gone or has
        already moved past `SCHEDULED` (e.g., periodic cleanup raced ahead
        and marked it `FAILED` while the worker message was still in flight).
    """
    with rls_transaction(attack_paths_scan.tenant_id):
        try:
            locked = ProwlerAPIAttackPathsScan.objects.select_for_update().get(
                id=attack_paths_scan.id
            )
        except ProwlerAPIAttackPathsScan.DoesNotExist:
            return False

        if locked.state != StateChoices.SCHEDULED:
            return False

        locked.state = StateChoices.EXECUTING
        locked.started_at = datetime.now(tz=timezone.utc)
        locked.update_tag = cartography_config.update_tag
        locked.save(update_fields=["state", "started_at", "update_tag"])

    # Keep the in-memory object the caller is holding in sync.
    attack_paths_scan.state = locked.state
    attack_paths_scan.started_at = locked.started_at
    attack_paths_scan.update_tag = locked.update_tag
    return True
|
||||
|
||||
|
||||
def _mark_scan_finished(
|
||||
|
||||
@@ -97,6 +97,19 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
|
||||
)
|
||||
attack_paths_scan = db_utils.retrieve_attack_paths_scan(tenant_id, scan_id)
|
||||
|
||||
# Idempotency guard: cleanup may have flipped this row to a terminal state
|
||||
# while the message was still in flight. Bail out before touching state.
|
||||
if attack_paths_scan and attack_paths_scan.state in (
|
||||
StateChoices.FAILED,
|
||||
StateChoices.COMPLETED,
|
||||
StateChoices.CANCELLED,
|
||||
):
|
||||
logger.warning(
|
||||
f"Attack Paths scan {attack_paths_scan.id} already in terminal "
|
||||
f"state {attack_paths_scan.state}; skipping execution"
|
||||
)
|
||||
return {}
|
||||
|
||||
# Checks before starting the scan
|
||||
if not cartography_ingestion_function:
|
||||
ingestion_exceptions = {
|
||||
@@ -114,12 +127,17 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
|
||||
|
||||
else:
|
||||
if not attack_paths_scan:
|
||||
# Safety net for in-flight messages or direct task invocations; dispatcher normally pre-creates the row.
|
||||
logger.warning(
|
||||
f"No Attack Paths Scan found for scan {scan_id} and tenant {tenant_id}, let's create it then"
|
||||
)
|
||||
attack_paths_scan = db_utils.create_attack_paths_scan(
|
||||
tenant_id, scan_id, prowler_api_provider.id
|
||||
)
|
||||
if attack_paths_scan and task_id:
|
||||
db_utils.set_attack_paths_scan_task_id(
|
||||
tenant_id, attack_paths_scan.id, task_id
|
||||
)
|
||||
|
||||
tmp_database_name = graph_database.get_database_name(
|
||||
attack_paths_scan.id, temporary=True
|
||||
@@ -141,9 +159,13 @@ def run(tenant_id: str, scan_id: str, task_id: str) -> dict[str, Any]:
|
||||
)
|
||||
|
||||
# Starting the Attack Paths scan
|
||||
db_utils.starting_attack_paths_scan(
|
||||
attack_paths_scan, task_id, tenant_cartography_config
|
||||
)
|
||||
if not db_utils.starting_attack_paths_scan(
|
||||
attack_paths_scan, tenant_cartography_config
|
||||
):
|
||||
logger.warning(
|
||||
f"Attack Paths scan {attack_paths_scan.id} no longer in SCHEDULED state; cleanup likely raced ahead"
|
||||
)
|
||||
return {}
|
||||
|
||||
scan_t0 = time.perf_counter()
|
||||
logger.info(
|
||||
|
||||
@@ -173,10 +173,25 @@ def _perform_scan_complete_tasks(tenant_id: str, scan_id: str, provider_id: str)
|
||||
).apply_async()
|
||||
|
||||
if can_provider_run_attack_paths_scan(tenant_id, provider_id):
|
||||
perform_attack_paths_scan_task.apply_async(
|
||||
# Row is normally created upstream, so this is a safeguard so we can attach the task id below
|
||||
attack_paths_scan = attack_paths_db_utils.retrieve_attack_paths_scan(
|
||||
tenant_id, scan_id
|
||||
)
|
||||
if attack_paths_scan is None:
|
||||
attack_paths_scan = attack_paths_db_utils.create_attack_paths_scan(
|
||||
tenant_id, scan_id, provider_id
|
||||
)
|
||||
|
||||
# Persist the Celery task id so the periodic cleanup can revoke scans stuck in SCHEDULED
|
||||
result = perform_attack_paths_scan_task.apply_async(
|
||||
kwargs={"tenant_id": tenant_id, "scan_id": scan_id}
|
||||
)
|
||||
|
||||
if attack_paths_scan and result:
|
||||
attack_paths_db_utils.set_attack_paths_scan_task_id(
|
||||
tenant_id, attack_paths_scan.id, result.task_id
|
||||
)
|
||||
|
||||
|
||||
@shared_task(base=RLSTask, name="provider-connection-check")
|
||||
@set_tenant
|
||||
|
||||
@@ -135,7 +135,7 @@ class TestAttackPathsRun:
|
||||
assert result == ingestion_result
|
||||
mock_retrieve_scan.assert_called_once_with(str(tenant.id), str(scan.id))
|
||||
mock_starting.assert_called_once()
|
||||
config = mock_starting.call_args[0][2]
|
||||
config = mock_starting.call_args[0][1]
|
||||
assert config.neo4j_database == "tenant-db"
|
||||
mock_get_db_name.assert_has_calls(
|
||||
[call(attack_paths_scan.id, temporary=True), call(provider.tenant_id)]
|
||||
@@ -2732,3 +2732,143 @@ class TestCleanupStaleAttackPathsScans:
|
||||
assert result["cleaned_up_count"] == 2
|
||||
# Worker should be pinged exactly once — cache prevents second ping
|
||||
mock_alive.assert_called_once_with("shared-worker@host")
|
||||
|
||||
# `SCHEDULED` state cleanup
|
||||
def _create_scheduled_scan(
    self,
    tenant,
    provider,
    *,
    age_minutes,
    parent_state,
    with_task=True,
):
    """Build a SCHEDULED AttackPathsScan whose parent Scan is in `parent_state`.

    `started_at` is backdated by `age_minutes`, letting callers put the row
    comfortably beyond the cleanup cutoff. When `with_task` is true a
    PENDING `TaskResult` and its `Task` are created and attached.

    Returns the `(attack_paths_scan, task_result)` pair; `task_result` is
    `None` when `with_task` is false.
    """
    parent_scan = Scan.objects.create(
        name="Parent Prowler scan",
        provider=provider,
        trigger=Scan.TriggerChoices.MANUAL,
        state=parent_state,
        tenant_id=tenant.id,
    )

    backdated_start = datetime.now(tz=timezone.utc) - timedelta(minutes=age_minutes)
    ap_scan = AttackPathsScan.objects.create(
        tenant_id=tenant.id,
        provider=provider,
        scan=parent_scan,
        state=StateChoices.SCHEDULED,
        started_at=backdated_start,
    )

    if not with_task:
        return ap_scan, None

    task_result = TaskResult.objects.create(
        task_id=str(ap_scan.id),
        task_name="attack-paths-scan-perform",
        status="PENDING",
    )
    ap_scan.task = Task.objects.create(
        id=task_result.task_id,
        task_runner_task=task_result,
        tenant_id=tenant.id,
    )
    ap_scan.save(update_fields=["task_id"])

    return ap_scan, task_result
|
||||
|
||||
@patch("tasks.jobs.attack_paths.cleanup.recover_graph_data_ready")
@patch("tasks.jobs.attack_paths.cleanup.graph_database.drop_database")
@patch(
    "tasks.jobs.attack_paths.cleanup.rls_transaction",
    new=lambda *args, **kwargs: nullcontext(),
)
@patch("tasks.jobs.attack_paths.cleanup._revoke_task")
def test_cleans_up_scheduled_scan_when_parent_is_terminal(
    self,
    mock_revoke,
    mock_drop_db,
    mock_recover,
    tenants_fixture,
    providers_fixture,
):
    """An old SCHEDULED scan whose parent already FAILED gets marked FAILED."""
    from tasks.jobs.attack_paths.cleanup import cleanup_stale_attack_paths_scans

    tenant, provider = tenants_fixture[0], providers_fixture[0]
    provider.provider = Provider.ProviderChoices.AWS
    provider.save()

    # Three days old: well beyond whatever the stale threshold is.
    three_days_in_minutes = 24 * 60 * 3
    ap_scan, task_result = self._create_scheduled_scan(
        tenant,
        provider,
        age_minutes=three_days_in_minutes,
        parent_state=StateChoices.FAILED,
    )

    result = cleanup_stale_attack_paths_scans()

    assert result["cleaned_up_count"] == 1
    assert str(ap_scan.id) in result["scan_ids"]

    ap_scan.refresh_from_db()
    expected_exceptions = {
        "global_error": "Scan never started — cleaned up by periodic task"
    }
    assert ap_scan.state == StateChoices.FAILED
    assert ap_scan.progress == 100
    assert ap_scan.completed_at is not None
    assert ap_scan.ingestion_exceptions == expected_exceptions

    # The revoke for a SCHEDULED scan must never terminate a running worker.
    mock_revoke.assert_called_once()
    assert mock_revoke.call_args.kwargs == {"terminate": False}

    # No temp DB exists for a SCHEDULED scan, hence nothing to drop...
    mock_drop_db.assert_not_called()
    # ...and the tenant's Neo4j data is left alone on this path.
    mock_recover.assert_not_called()

    task_result.refresh_from_db()
    assert task_result.status == "FAILURE"
    assert task_result.date_done is not None
|
||||
|
||||
@patch("tasks.jobs.attack_paths.cleanup.recover_graph_data_ready")
@patch("tasks.jobs.attack_paths.cleanup.graph_database.drop_database")
@patch(
    "tasks.jobs.attack_paths.cleanup.rls_transaction",
    new=lambda *args, **kwargs: nullcontext(),
)
@patch("tasks.jobs.attack_paths.cleanup._revoke_task")
def test_skips_scheduled_scan_when_parent_still_in_flight(
    self,
    mock_revoke,
    mock_drop_db,
    mock_recover,
    tenants_fixture,
    providers_fixture,
):
    """An old SCHEDULED scan is left untouched while its parent is EXECUTING."""
    from tasks.jobs.attack_paths.cleanup import cleanup_stale_attack_paths_scans

    tenant, provider = tenants_fixture[0], providers_fixture[0]
    provider.provider = Provider.ProviderChoices.AWS
    provider.save()

    three_days_in_minutes = 24 * 60 * 3
    ap_scan, _ = self._create_scheduled_scan(
        tenant,
        provider,
        age_minutes=three_days_in_minutes,
        parent_state=StateChoices.EXECUTING,
    )

    result = cleanup_stale_attack_paths_scans()

    assert result["cleaned_up_count"] == 0

    # The row keeps its state and nothing was revoked.
    ap_scan.refresh_from_db()
    assert ap_scan.state == StateChoices.SCHEDULED
    mock_revoke.assert_not_called()
|
||||
|
||||
@@ -842,6 +842,72 @@ class TestScanCompleteTasks:
|
||||
# Attack Paths task should be skipped when provider cannot run it
|
||||
mock_attack_paths_task.assert_not_called()
|
||||
|
||||
@pytest.mark.parametrize(
    "row_pre_existing",
    [True, False],
    ids=["row-pre-existing", "row-missing-fallback"],
)
@patch("tasks.tasks.aggregate_attack_surface_task.apply_async")
@patch("tasks.tasks.chain")
@patch("tasks.tasks.create_compliance_requirements_task.si")
@patch("tasks.tasks.update_provider_compliance_scores_task.si")
@patch("tasks.tasks.perform_scan_summary_task.si")
@patch("tasks.tasks.generate_outputs_task.si")
@patch("tasks.tasks.generate_compliance_reports_task.si")
@patch("tasks.tasks.check_integrations_task.si")
@patch("tasks.tasks.attack_paths_db_utils.set_attack_paths_scan_task_id")
@patch("tasks.tasks.attack_paths_db_utils.create_attack_paths_scan")
@patch("tasks.tasks.attack_paths_db_utils.retrieve_attack_paths_scan")
@patch("tasks.tasks.perform_attack_paths_scan_task.apply_async")
@patch("tasks.tasks.can_provider_run_attack_paths_scan", return_value=True)
def test_scan_complete_dispatches_attack_paths_scan(
    self,
    _mock_can_run_attack_paths,
    mock_attack_paths_task,
    mock_retrieve,
    mock_create,
    mock_set_task_id,
    mock_check_integrations_task,
    mock_compliance_reports_task,
    mock_outputs_task,
    mock_scan_summary_task,
    mock_update_compliance_scores_task,
    mock_compliance_requirements_task,
    mock_chain,
    mock_attack_surface_task,
    row_pre_existing,
):
    """For a provider that can run Attack Paths, dispatch is expected to:

    1. Use the row that already exists, or create it when it is missing.
    2. Invoke `apply_async` on the Attack Paths task.
    3. Store the Celery task id returned by that call on the row.
    """
    ap_row = MagicMock(id="ap-scan-id")
    mock_retrieve.return_value = ap_row if row_pre_existing else None
    if not row_pre_existing:
        # Fallback path: the lookup misses and the row gets created.
        mock_create.return_value = ap_row

    mock_attack_paths_task.return_value = MagicMock(task_id="celery-task-id")

    _perform_scan_complete_tasks("tenant-id", "scan-id", "provider-id")

    mock_retrieve.assert_called_once_with("tenant-id", "scan-id")
    if not row_pre_existing:
        mock_create.assert_called_once_with("tenant-id", "scan-id", "provider-id")
    else:
        mock_create.assert_not_called()

    mock_attack_paths_task.assert_called_once_with(
        kwargs={"tenant_id": "tenant-id", "scan_id": "scan-id"}
    )

    mock_set_task_id.assert_called_once_with(
        "tenant-id", "ap-scan-id", "celery-task-id"
    )
|
||||
|
||||
|
||||
class TestAttackPathsTasks:
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user