diff --git a/api/src/backend/api/v1/views.py b/api/src/backend/api/v1/views.py index 9fd7a5f887..7eb41e9036 100644 --- a/api/src/backend/api/v1/views.py +++ b/api/src/backend/api/v1/views.py @@ -4,6 +4,7 @@ import json import logging import os import time +import uuid from collections import defaultdict from copy import deepcopy from datetime import datetime, timedelta, timezone @@ -16,7 +17,7 @@ from allauth.socialaccount.providers.github.views import GitHubOAuth2Adapter from allauth.socialaccount.providers.google.views import GoogleOAuth2Adapter from allauth.socialaccount.providers.saml.views import FinishACSView, LoginView from botocore.exceptions import ClientError, NoCredentialsError, ParamValidationError -from celery import chain +from celery import chain, states from celery.result import AsyncResult from config.custom_logging import BackendLogger from config.env import env @@ -60,6 +61,7 @@ from django.utils.dateparse import parse_date from django.utils.decorators import method_decorator from django.views.decorators.cache import cache_control from django_celery_beat.models import PeriodicTask +from django_celery_results.models import TaskResult from drf_spectacular.settings import spectacular_settings from drf_spectacular.types import OpenApiTypes from drf_spectacular.utils import ( @@ -2534,28 +2536,45 @@ class ScanViewSet(BaseRLSViewSet): def create(self, request, *args, **kwargs): input_serializer = self.get_serializer(data=request.data) input_serializer.is_valid(raise_exception=True) + + # Broker publish is deferred to on_commit so the worker cannot read + # Scan before BaseRLSViewSet's dispatch-wide atomic commits. + pre_task_id = str(uuid.uuid4()) + with transaction.atomic(): scan = input_serializer.save() - with transaction.atomic(): - task = perform_scan_task.apply_async( - kwargs={ - "tenant_id": self.request.tenant_id, - "scan_id": str(scan.id), - "provider_id": str(scan.provider_id), - # Disabled for now - # checks_to_execute=scan.scanner_args.get("checks_to_execute") - }, + scan.task_id = pre_task_id + scan.save(update_fields=["task_id"]) + + attack_paths_db_utils.create_attack_paths_scan( + tenant_id=self.request.tenant_id, + scan_id=str(scan.id), + provider_id=str(scan.provider_id), ) - attack_paths_db_utils.create_attack_paths_scan( - tenant_id=self.request.tenant_id, - scan_id=str(scan.id), - provider_id=str(scan.provider_id), - ) + task_result, _ = TaskResult.objects.get_or_create( + task_id=pre_task_id, + defaults={"status": states.PENDING, "task_name": "scan-perform"}, + ) + prowler_task, _ = Task.objects.update_or_create( + id=pre_task_id, + tenant_id=self.request.tenant_id, + defaults={"task_runner_task": task_result}, + ) - prowler_task = Task.objects.get(id=task.id) - scan.task_id = task.id - scan.save(update_fields=["task_id"]) + scan_kwargs = { + "tenant_id": self.request.tenant_id, + "scan_id": str(scan.id), + "provider_id": str(scan.provider_id), + # Disabled for now + # checks_to_execute=scan.scanner_args.get("checks_to_execute") + } + + transaction.on_commit( + lambda: perform_scan_task.apply_async( + kwargs=scan_kwargs, task_id=pre_task_id + ) + ) self.response_serializer_class = TaskSerializer output_serializer = self.get_serializer(prowler_task)