Compare commits

...

25 Commits

Author SHA1 Message Date
Adrián Jesús Peña Rodríguez
772667c8cc feat: add auth_method and partition information 2025-02-19 11:29:42 +01:00
Pablo Lara
eb0f6addb4 feat: allow users to download exports 2025-02-19 10:40:12 +01:00
Pablo Lara
526b20ca0a feat: render a button to download the exports 2025-02-19 10:40:12 +01:00
Pablo Lara
b2911f4161 chore: update env to make the API work 2025-02-19 10:40:11 +01:00
Adrián Jesús Peña Rodríguez
afce5dcf0d chore: api format 2025-02-18 16:22:56 +01:00
Adrián Jesús Peña Rodríguez
38fb72e819 fix: s3 unittests tests 2025-02-18 16:17:10 +01:00
Adrián Jesús Peña Rodríguez
97d55d7aa7 fix: fix the batch writing when launching the CLI 2025-02-18 14:27:43 +01:00
Adrián Jesús Peña Rodríguez
79aded5aa3 Merge branch 'master' into PRWLR-5956-Export-Artifacts-only 2025-02-18 13:37:22 +01:00
Adrián Jesús Peña Rodríguez
7139683809 chore: rename variables 2025-02-18 10:11:54 +01:00
Adrián Jesús Peña Rodríguez
d9889776ad test: add export unittests 2025-02-13 21:28:44 +01:00
Adrián Jesús Peña Rodríguez
621e71cfbe ref: improve code 2025-02-13 20:31:39 +01:00
Adrián Jesús Peña Rodríguez
41aec46578 Merge branch 'master' into PRWLR-5956-Export-Artifacts-only 2025-02-12 18:49:12 +01:00
Adrián Jesús Peña Rodríguez
820a8809b5 chore: ruff format 2025-02-12 18:41:55 +01:00
Adrián Jesús Peña Rodríguez
cbf8cf73cf fix: html close file 2025-02-12 17:53:44 +01:00
Adrián Jesús Peña Rodríguez
7e7da99628 ref: move the api output folder 2025-02-12 17:23:23 +01:00
Adrián Jesús Peña Rodríguez
d90b4fa324 chore: remove comment 2025-02-12 17:14:29 +01:00
Adrián Jesús Peña Rodríguez
32e880e9c4 chore: restore rls.py 2025-02-12 16:53:06 +01:00
Adrián Jesús Peña Rodríguez
492e9f24a2 fix: solve duplicated findings 2025-02-12 15:52:53 +01:00
Adrián Jesús Peña Rodríguez
f7e27402aa fix: add condition before close the csv file 2025-02-12 14:37:01 +01:00
Adrián Jesús Peña Rodríguez
747b97fe87 ref: improve export code 2025-02-12 10:47:31 +01:00
Adrián Jesús Peña Rodríguez
d5e2d75c9b chore: update api changelog 2025-02-10 17:26:01 +01:00
Adrián Jesús Peña Rodríguez
82d53c5158 chore: update api schema 2025-02-10 16:39:41 +01:00
Adrián Jesús Peña Rodríguez
326fddd206 Merge branch 'master' into PRWLR-5956-Export-Artifacts-only 2025-02-10 16:28:20 +01:00
Adrián Jesús Peña Rodríguez
63b59e4d42 chore: apply ruff 2025-02-10 16:23:37 +01:00
Adrián Jesús Peña Rodríguez
a790a5060e feat(export): add api export system 2025-02-10 16:09:18 +01:00
34 changed files with 1267 additions and 89 deletions

24
.env
View File

@@ -30,6 +30,30 @@ VALKEY_HOST=valkey
VALKEY_PORT=6379
VALKEY_DB=0
# API scan settings
# The path to the directory where scan artifacts should be stored
DJANGO_TMP_OUTPUT_DIRECTORY = "/tmp/prowler_api_output"
# The maximum number of findings to process in a single batch
DJANGO_FINDINGS_BATCH_SIZE = 1000
# The AWS access key to be used when uploading scan artifacts to an S3 bucket
# If left empty, default AWS credentials resolution behavior will be used
DJANGO_ARTIFACTS_AWS_ACCESS_KEY_ID=""
# The AWS secret key to be used when uploading scan artifacts to an S3 bucket
DJANGO_ARTIFACTS_AWS_SECRET_ACCESS_KEY=""
# An optional AWS session token
DJANGO_ARTIFACTS_AWS_SESSION_TOKEN=""
# The AWS region where your S3 bucket is located (e.g., "us-east-1")
DJANGO_ARTIFACTS_AWS_DEFAULT_REGION=""
# The name of the S3 bucket where scan artifacts should be stored
DJANGO_ARTIFACTS_AWS_S3_OUTPUT_BUCKET=""
# Django settings
DJANGO_ALLOWED_HOSTS=localhost,127.0.0.1,prowler-api
DJANGO_BIND_ADDRESS=0.0.0.0

View File

@@ -11,6 +11,13 @@ All notable changes to the **Prowler API** are documented in this file.
---
## [v1.5.0] (Prowler v5.4.0) - 2025-XX-XX
### Added
- Add API scan report system: all scans launched from the API now generate a compressed file with the report in OCSF, CSV and HTML formats [(#6878)](https://github.com/prowler-cloud/prowler/pull/6878).
---
## [v1.4.0] (Prowler v5.3.0) - 2025-02-10
### Changed

4
api/poetry.lock generated
View File

@@ -3466,8 +3466,8 @@ tzlocal = "5.2"
[package.source]
type = "git"
url = "https://github.com/prowler-cloud/prowler.git"
reference = "master"
resolved_reference = "7469377079bb88487a741625dcd431aa5375e793"
reference = "PRWLR-5956-Export-Artifacts-only"
resolved_reference = "492e9f24a2666d203950cfd85959d7d3f621b957"
[[package]]
name = "psutil"

View File

@@ -28,7 +28,7 @@ drf-nested-routers = "^0.94.1"
drf-spectacular = "0.27.2"
drf-spectacular-jsonapi = "0.5.1"
gunicorn = "23.0.0"
prowler = {git = "https://github.com/prowler-cloud/prowler.git", branch = "master"}
prowler = {git = "https://github.com/prowler-cloud/prowler.git", branch = "PRWLR-5956-Export-Artifacts-only"}
psycopg2-binary = "2.9.9"
pytest-celery = {extras = ["redis"], version = "^1.0.1"}
# Needed for prowler compatibility

View File

@@ -7,7 +7,7 @@ from rest_framework_json_api.serializers import ValidationError
from api.db_utils import POSTGRES_TENANT_VAR, SET_CONFIG_QUERY
def set_tenant(func):
def set_tenant(func=None, *, keep_tenant=False):
"""
Decorator to set the tenant context for a Celery task based on the provided tenant_id.
@@ -40,20 +40,29 @@ def set_tenant(func):
# The tenant context will be set before the task logic executes.
"""
@wraps(func)
@transaction.atomic
def wrapper(*args, **kwargs):
try:
tenant_id = kwargs.pop("tenant_id")
except KeyError:
raise KeyError("This task requires the tenant_id")
try:
uuid.UUID(tenant_id)
except ValueError:
raise ValidationError("Tenant ID must be a valid UUID")
with connection.cursor() as cursor:
cursor.execute(SET_CONFIG_QUERY, [POSTGRES_TENANT_VAR, tenant_id])
def decorator(func):
@wraps(func)
@transaction.atomic
def wrapper(*args, **kwargs):
try:
if not keep_tenant:
tenant_id = kwargs.pop("tenant_id")
else:
tenant_id = kwargs["tenant_id"]
except KeyError:
raise KeyError("This task requires the tenant_id")
try:
uuid.UUID(tenant_id)
except ValueError:
raise ValidationError("Tenant ID must be a valid UUID")
with connection.cursor() as cursor:
cursor.execute(SET_CONFIG_QUERY, [POSTGRES_TENANT_VAR, tenant_id])
return func(*args, **kwargs)
return func(*args, **kwargs)
return wrapper
return wrapper
if func is None:
return decorator
else:
return decorator(func)
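For orientation, a minimal usage sketch of the updated decorator under the two calling conventions it now supports (the first task mirrors generate_outputs later in this PR; the second task name is hypothetical):

@set_tenant(keep_tenant=True)
def generate_outputs(scan_id: str, provider_id: str, tenant_id: str):
    # tenant_id stays in kwargs because keep_tenant=True
    ...

@set_tenant
def some_rls_task(provider_id: str):  # hypothetical task name, for illustration only
    # tenant_id is popped from kwargs before the task body runs
    ...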

View File

@@ -0,0 +1,17 @@
# Generated by Django 5.1.5 on 2025-02-07 10:59
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("api", "0009_increase_provider_uid_maximum_length"),
]
operations = [
migrations.AddField(
model_name="scan",
name="output_path",
field=models.CharField(blank=True, max_length=200, null=True),
),
]

View File

@@ -414,6 +414,7 @@ class Scan(RowLevelSecurityProtectedModel):
scheduler_task = models.ForeignKey(
PeriodicTask, on_delete=models.CASCADE, null=True, blank=True
)
output_path = models.CharField(blank=True, null=True, max_length=200)
# TODO: mutelist foreign key
class Meta(RowLevelSecurityProtectedModel.Meta):

View File

@@ -4105,6 +4105,39 @@ paths:
schema:
$ref: '#/components/schemas/ScanUpdateResponse'
description: ''
/api/v1/scans/{id}/report:
get:
operationId: scans_report_retrieve
description: Returns a ZIP file containing the requested report
summary: Download ZIP report
parameters:
- in: query
name: fields[scan-reports]
schema:
type: array
items:
type: string
enum:
- id
description: endpoint return only specific fields in the response on a per-type
basis by including a fields[TYPE] query parameter.
explode: false
- in: path
name: id
schema:
type: string
format: uuid
description: A UUID string identifying this scan.
required: true
tags:
- Scan
security:
- jwtAuth: []
responses:
'200':
description: Report obtained successfully
'423':
description: There is a problem with the AWS credentials
/api/v1/schedules/daily:
post:
operationId: schedules_daily_create
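For illustration, a minimal client sketch for the new report endpoint; the base URL, token, and scan UUID are placeholders, not values from this PR:

import requests

BASE_URL = "http://localhost:8080/api/v1"  # placeholder
TOKEN = "<jwt-access-token>"               # placeholder
SCAN_ID = "<scan-uuid>"                    # placeholder

# GET /api/v1/scans/{id}/report returns the ZIP archive with the OCSF, CSV and HTML outputs
response = requests.get(
    f"{BASE_URL}/scans/{SCAN_ID}/report",
    headers={"Authorization": f"Bearer {TOKEN}"},
)
response.raise_for_status()
with open(f"scan-{SCAN_ID}-report.zip", "wb") as f:
    f.write(response.content)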

View File

@@ -274,9 +274,10 @@ class TestValidateInvitation:
expired_time = datetime.now(timezone.utc) - timedelta(days=1)
invitation.expires_at = expired_time
with patch("api.utils.Invitation.objects.using") as mock_using, patch(
"api.utils.datetime"
) as mock_datetime:
with (
patch("api.utils.Invitation.objects.using") as mock_using,
patch("api.utils.datetime") as mock_datetime,
):
mock_db = mock_using.return_value
mock_db.get.return_value = invitation
mock_datetime.now.return_value = datetime.now(timezone.utc)

View File

@@ -1,11 +1,14 @@
import json
from datetime import datetime, timedelta, timezone
from unittest.mock import ANY, Mock, patch
from io import BytesIO
from unittest.mock import ANY, Mock, mock_open, patch
import jwt
import pytest
from botocore.exceptions import ClientError
from conftest import API_JSON_CONTENT_TYPE, TEST_PASSWORD, TEST_USER
from django.conf import settings
from django.http import HttpResponse
from django.urls import reverse
from rest_framework import status
@@ -2156,6 +2159,134 @@ class TestScanViewSet:
response = authenticated_client.get(reverse("scan-list"), {"sort": "invalid"})
assert response.status_code == status.HTTP_400_BAD_REQUEST
def test_report_no_output(self, authenticated_client, scans_fixture):
"""
If the scan's output_path is empty, the view should return a 404 response.
"""
scan = scans_fixture[0]
url = reverse("scan-report", kwargs={"pk": scan.pk})
response = authenticated_client.get(url)
assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.data["detail"] == "No files found"
def test_report_s3_get_client_fail(self, authenticated_client, scans_fixture):
"""
If output_path starts with "s3://", but get_s3_client() fails, the view should return a 403.
"""
scan = scans_fixture[0]
scan.output_path = "s3://bucket/path/to/file.zip"
scan.save()
url = reverse("scan-report", kwargs={"pk": scan.pk})
# Patch get_s3_client to raise one of the expected exceptions
client_err = ClientError({"Error": {"Code": "AccessDenied"}}, "get_s3_client")
with patch("api.v1.views.get_s3_client", side_effect=client_err):
response = authenticated_client.get(url)
assert response.status_code == status.HTTP_403_FORBIDDEN
assert (
response.data["detail"]
== "There is a problem with the AWS credentials."
)
def test_report_s3_get_object_fail(self, authenticated_client, scans_fixture):
"""
If output_path starts with "s3://", and get_s3_client() succeeds but get_object() fails,
the view should return a 500 error.
"""
scan = scans_fixture[0]
scan.output_path = "s3://bucket/path/to/file.zip"
scan.save()
url = reverse("scan-report", kwargs={"pk": scan.pk})
client_err = ClientError({"Error": {"Code": "NoSuchKey"}}, "get_object")
with patch("api.v1.views.get_s3_client") as mock_get_s3_client, patch(
"api.v1.views.env.str", return_value="bucket"
):
s3_client = mock_get_s3_client.return_value
s3_client.get_object.side_effect = client_err
response = authenticated_client.get(url)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert response.data["detail"] == "Error accessing cloud storage"
def test_report_s3_success(self, authenticated_client, scans_fixture):
"""
If output_path starts with "s3://", and S3 functions succeed, the view should return an
HttpResponse with the ZIP content.
"""
fake_file_content = b"fake s3 zip content"
scan = scans_fixture[0]
scan.output_path = "s3://bucket/path/to/file.zip"
scan.save()
url = reverse("scan-report", kwargs={"pk": scan.pk})
# Create a dummy S3 object with a 'Body' that returns our fake file content
dummy_body = BytesIO(fake_file_content)
dummy_s3_object = {"Body": dummy_body}
with patch("api.v1.views.get_s3_client") as mock_get_s3_client, patch(
"api.v1.views.env.str", return_value="bucket"
):
s3_client = mock_get_s3_client.return_value
s3_client.get_object.return_value = dummy_s3_object
response = authenticated_client.get(url)
# The view returns an HttpResponse (not a DRF Response) for file downloads
assert isinstance(response, HttpResponse)
assert response.status_code == 200
assert response.content == fake_file_content
# Check that the Content-Disposition header includes the filename "file.zip"
content_disp = response.get("Content-Disposition", "")
assert content_disp.startswith('attachment; filename="')
assert "file.zip" in content_disp
def test_report_local_no_files(self, authenticated_client, scans_fixture):
"""
If output_path does not start with "s3://" and glob.glob finds no matching files,
the view should return a 404 response.
"""
scan = scans_fixture[0]
scan.output_path = "/non/existent/path/*.zip"
scan.save()
url = reverse("scan-report", kwargs={"pk": scan.pk})
with patch("api.v1.views.glob.glob", return_value=[]):
response = authenticated_client.get(url)
assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.data["detail"] == "No local files found"
def test_report_local_ioerror(self, authenticated_client, scans_fixture):
"""
If output_path does not start with "s3://", glob.glob finds a file but reading it raises an IOError,
the view should return a 500 response.
"""
scan = scans_fixture[0]
scan.output_path = "/path/to/file.zip"
scan.save()
url = reverse("scan-report", kwargs={"pk": scan.pk})
with patch("api.v1.views.glob.glob", return_value=["/path/to/file.zip"]), patch(
"api.v1.views.open", side_effect=IOError
):
response = authenticated_client.get(url)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert response.data["detail"] == "Error reading local file"
def test_report_local_success(self, authenticated_client, scans_fixture):
"""
If output_path does not start with "s3://", and a local file is found and read successfully,
the view should return an HttpResponse with the file contents.
"""
fake_file_content = b"local zip file content"
scan = scans_fixture[0]
scan.output_path = "/path/to/file.zip"
scan.save()
url = reverse("scan-report", kwargs={"pk": scan.pk})
m_open = mock_open(read_data=fake_file_content)
with patch("api.v1.views.glob.glob", return_value=["/path/to/file.zip"]), patch(
"api.v1.views.open", m_open
):
response = authenticated_client.get(url)
assert isinstance(response, HttpResponse)
assert response.status_code == 200
assert response.content == fake_file_content
content_disp = response.get("Content-Disposition", "")
assert content_disp.startswith('attachment; filename="')
# The filename should be the basename of the file path
assert "file.zip" in content_disp
@pytest.mark.django_db
class TestTaskViewSet:

View File

@@ -873,6 +873,14 @@ class ScanTaskSerializer(RLSSerializer):
]
class ScanReportSerializer(serializers.Serializer):
id = serializers.CharField(source="scan")
class Meta:
resource_name = "scan-reports"
fields = ["id"]
class ResourceTagSerializer(RLSSerializer):
"""
Serializer for the ResourceTag model

View File

@@ -1,6 +1,11 @@
import glob
import os
from allauth.socialaccount.providers.github.views import GitHubOAuth2Adapter
from allauth.socialaccount.providers.google.views import GoogleOAuth2Adapter
from botocore.exceptions import ClientError, NoCredentialsError, ParamValidationError
from celery.result import AsyncResult
from config.env import env
from config.settings.social_login import (
GITHUB_OAUTH_CALLBACK_URL,
GOOGLE_OAUTH_CALLBACK_URL,
@@ -12,6 +17,7 @@ from django.contrib.postgres.search import SearchQuery
from django.db import transaction
from django.db.models import Count, F, OuterRef, Prefetch, Q, Subquery, Sum
from django.db.models.functions import Coalesce
from django.http import HttpResponse
from django.urls import reverse
from django.utils.decorators import method_decorator
from django.views.decorators.cache import cache_control
@@ -38,11 +44,11 @@ from rest_framework.permissions import SAFE_METHODS
from rest_framework_json_api.views import RelationshipView, Response
from rest_framework_simplejwt.exceptions import InvalidToken, TokenError
from tasks.beat import schedule_provider_scan
from tasks.jobs.export import get_s3_client
from tasks.tasks import (
check_provider_connection_task,
delete_provider_task,
delete_tenant_task,
perform_scan_summary_task,
perform_scan_task,
)
@@ -121,6 +127,7 @@ from api.v1.serializers import (
RoleSerializer,
RoleUpdateSerializer,
ScanCreateSerializer,
ScanReportSerializer,
ScanSerializer,
ScanUpdateSerializer,
ScheduleDailyCreateSerializer,
@@ -1164,6 +1171,8 @@ class ScanViewSet(BaseRLSViewSet):
return ScanCreateSerializer
elif self.action == "partial_update":
return ScanUpdateSerializer
elif self.action == "report":
return ScanReportSerializer
return super().get_serializer_class()
def partial_update(self, request, *args, **kwargs):
@@ -1181,6 +1190,72 @@ class ScanViewSet(BaseRLSViewSet):
)
return Response(data=read_serializer.data, status=status.HTTP_200_OK)
@extend_schema(
tags=["Scan"],
summary="Download ZIP report",
description="Returns a ZIP file containing the requested report",
request=ScanReportSerializer,
responses={
200: OpenApiResponse(description="Report obtained successfully"),
403: OpenApiResponse(
description="There is a problem with the AWS credentials"
),
},
)
@action(detail=True, methods=["get"], url_name="report")
def report(self, request, pk=None):
scan_instance = Scan.objects.get(pk=pk)
output_path = scan_instance.output_path
if not output_path:
return Response(
{"detail": "No files found"}, status=status.HTTP_404_NOT_FOUND
)
if scan_instance.output_path.startswith("s3://"):
try:
s3_client = get_s3_client()
except (ClientError, NoCredentialsError, ParamValidationError):
return Response(
{"detail": "There is a problem with the AWS credentials."},
status=status.HTTP_403_FORBIDDEN,
)
bucket_name = env.str("DJANGO_ARTIFACTS_AWS_S3_OUTPUT_BUCKET")
try:
key = output_path[len(f"s3://{bucket_name}/") :]
s3_object = s3_client.get_object(Bucket=bucket_name, Key=key)
file_content = s3_object["Body"].read()
filename = os.path.basename(output_path.split("/")[-1])
except ClientError:
return Response(
{"detail": "Error accessing cloud storage"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
else:
zip_files = glob.glob(output_path)
if not zip_files:
return Response(
{"detail": "No local files found"}, status=status.HTTP_404_NOT_FOUND
)
try:
file_path = zip_files[0]
with open(file_path, "rb") as f:
file_content = f.read()
filename = os.path.basename(file_path)
except IOError:
return Response(
{"detail": "Error reading local file"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
response = HttpResponse(
file_content, content_type="application/x-zip-compressed"
)
response["Content-Disposition"] = f'attachment; filename="{filename}"'
return response
def create(self, request, *args, **kwargs):
input_serializer = self.get_serializer(data=request.data)
input_serializer.is_valid(raise_exception=True)
@@ -1195,10 +1270,6 @@ class ScanViewSet(BaseRLSViewSet):
# Disabled for now
# checks_to_execute=scan.scanner_args.get("checks_to_execute"),
},
link=perform_scan_summary_task.si(
tenant_id=self.request.tenant_id,
scan_id=str(scan.id),
),
)
scan.task_id = task.id

View File

@@ -219,3 +219,20 @@ CACHE_STALE_WHILE_REVALIDATE = env.int("DJANGO_STALE_WHILE_REVALIDATE", 60)
TESTING = False
FINDINGS_MAX_DAYS_IN_RANGE = env.int("DJANGO_FINDINGS_MAX_DAYS_IN_RANGE", 7)
# API export settings
DJANGO_TMP_OUTPUT_DIRECTORY = env.str(
"DJANGO_TMP_OUTPUT_DIRECTORY", "/tmp/prowler_api_output"
)
DJANGO_FINDINGS_BATCH_SIZE = env.str("DJANGO_FINDINGS_BATCH_SIZE", 1000)
DJANGO_ARTIFACTS_AWS_S3_OUTPUT_BUCKET = env.str(
"DJANGO_ARTIFACTS_AWS_S3_OUTPUT_BUCKET", ""
)
DJANGO_ARTIFACTS_AWS_ACCESS_KEY_ID = env.str("DJANGO_ARTIFACTS_AWS_ACCESS_KEY_ID", "")
DJANGO_ARTIFACTS_AWS_SECRET_ACCESS_KEY = env.str(
"DJANGO_ARTIFACTS_AWS_SECRET_ACCESS_KEY", ""
)
DJANGO_ARTIFACTS_AWS_SESSION_TOKEN = env.str("DJANGO_ARTIFACTS_AWS_SESSION_TOKEN", "")
DJANGO_ARTIFACTS_AWS_DEFAULT_REGION = env.str("DJANGO_ARTIFACTS_AWS_DEFAULT_REGION", "")

View File

@@ -0,0 +1,157 @@
import os
import zipfile
import boto3
import config.django.base as base
from botocore.exceptions import ClientError, NoCredentialsError, ParamValidationError
from celery.utils.log import get_task_logger
from config.env import env
from prowler.config.config import (
csv_file_suffix,
html_file_suffix,
json_ocsf_file_suffix,
output_file_timestamp,
)
from prowler.lib.outputs.csv.csv import CSV
from prowler.lib.outputs.html.html import HTML
from prowler.lib.outputs.ocsf.ocsf import OCSF
logger = get_task_logger(__name__)
# Predefined mapping for output formats and their configurations
OUTPUT_FORMATS_MAPPING = {
"csv": {
"class": CSV,
"suffix": csv_file_suffix,
"kwargs": {},
},
"json-ocsf": {"class": OCSF, "suffix": json_ocsf_file_suffix, "kwargs": {}},
"html": {"class": HTML, "suffix": html_file_suffix, "kwargs": {"stats": {}}},
}
def _compress_output_files(output_directory: str) -> str:
"""
Compress output files from all configured output formats into a ZIP archive.
Args:
output_directory (str): The directory where the output files are located.
The function looks up all known suffixes in OUTPUT_FORMATS_MAPPING
and compresses those files into a single ZIP.
Returns:
str: The full path to the newly created ZIP archive.
"""
zip_path = f"{output_directory}.zip"
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
for suffix in [config["suffix"] for config in OUTPUT_FORMATS_MAPPING.values()]:
zipf.write(
f"{output_directory}{suffix}",
f"artifacts/{output_directory.split('/')[-1]}{suffix}",
)
return zip_path
def get_s3_client():
"""
Create and return a boto3 S3 client using AWS credentials from environment variables.
This function attempts to initialize an S3 client by reading the AWS access key, secret key,
session token, and region from environment variables. It then validates the client by listing
available S3 buckets. If an error occurs during this process (for example, due to missing or
invalid credentials), it falls back to creating an S3 client without explicitly provided credentials,
which may rely on other configuration sources (e.g., IAM roles).
Returns:
boto3.client: A configured S3 client instance.
Raises:
ClientError, NoCredentialsError, or ParamValidationError if both attempts to create a client fail.
"""
s3_client = None
try:
s3_client = boto3.client(
"s3",
aws_access_key_id=env.str("DJANGO_ARTIFACTS_AWS_ACCESS_KEY_ID"),
aws_secret_access_key=env.str("DJANGO_ARTIFACTS_AWS_SECRET_ACCESS_KEY"),
aws_session_token=env.str("DJANGO_ARTIFACTS_AWS_SESSION_TOKEN"),
region_name=env.str("DJANGO_ARTIFACTS_AWS_DEFAULT_REGION"),
)
s3_client.list_buckets()
except (ClientError, NoCredentialsError, ParamValidationError):
s3_client = boto3.client("s3")
s3_client.list_buckets()
return s3_client
def _upload_to_s3(tenant_id: str, zip_path: str, scan_id: str) -> str:
"""
Upload the specified ZIP file to an S3 bucket.
If the S3 bucket environment variables are not configured,
the function returns None without performing an upload.
Args:
tenant_id (str): The tenant identifier, used as part of the S3 key prefix.
zip_path (str): The local file system path to the ZIP file to be uploaded.
scan_id (str): The scan identifier, used as part of the S3 key prefix.
Returns:
str: The S3 URI of the uploaded file (e.g., "s3://<bucket>/<key>") if successful.
None: If the required environment variables for the S3 bucket are not set.
Raises:
botocore.exceptions.ClientError: If the upload attempt to S3 fails for any reason.
"""
if not base.DJANGO_ARTIFACTS_AWS_S3_OUTPUT_BUCKET:
return
s3 = get_s3_client()
s3_key = f"{tenant_id}/{scan_id}/{os.path.basename(zip_path)}"
try:
s3.upload_file(
Filename=zip_path,
Bucket=base.DJANGO_ARTIFACTS_AWS_S3_OUTPUT_BUCKET,
Key=s3_key,
)
return f"s3://{base.DJANGO_ARTIFACTS_AWS_S3_OUTPUT_BUCKET}/{s3_key}"
except ClientError as e:
logger.error(f"S3 upload failed: {str(e)}")
raise e
def _generate_output_directory(
output_directory, prowler_provider: object, tenant_id: str, scan_id: str
) -> str:
"""
Generate a file system path for the output directory of a prowler scan.
This function constructs the output directory path by combining a base
temporary output directory, the tenant ID, the scan ID, and details about
the prowler provider along with a timestamp. The resulting path is used to
store the output files of a prowler scan.
Note:
This function depends on one external variable:
- `output_file_timestamp`: A timestamp (as a string) used to uniquely identify the output.
Args:
output_directory (str): The base output directory.
prowler_provider (object): An identifier or descriptor for the prowler provider.
Typically, this is a string indicating the provider (e.g., "aws").
tenant_id (str): The unique identifier for the tenant.
scan_id (str): The unique identifier for the scan.
Returns:
str: The constructed file system path for the prowler scan output directory.
Example:
>>> _generate_output_directory("/tmp", "aws", "tenant-1234", "scan-5678")
'/tmp/tenant-1234/scan-5678/prowler-output-aws-2023-02-15T12:34:56'
"""
path = (
f"{output_directory}/{tenant_id}/{scan_id}/prowler-output-"
f"{prowler_provider}-{output_file_timestamp}"
)
os.makedirs("/".join(path.split("/")[:-1]), exist_ok=True)
return path

View File

@@ -1,14 +1,26 @@
from celery import shared_task
from celery import chain, shared_task
from celery.utils.log import get_task_logger
from config.celery import RLSTask
from config.django.base import DJANGO_FINDINGS_BATCH_SIZE, DJANGO_TMP_OUTPUT_DIRECTORY
from django_celery_beat.models import PeriodicTask
from tasks.jobs.connection import check_provider_connection
from tasks.jobs.deletion import delete_provider, delete_tenant
from tasks.jobs.export import (
OUTPUT_FORMATS_MAPPING,
_compress_output_files,
_generate_output_directory,
_upload_to_s3,
)
from tasks.jobs.scan import aggregate_findings, perform_prowler_scan
from tasks.utils import get_next_execution_datetime
from api.db_utils import rls_transaction
from api.decorators import set_tenant
from api.models import Scan, StateChoices
from api.models import Finding, Provider, Scan, ScanSummary, StateChoices
from api.utils import initialize_prowler_provider
from prowler.lib.outputs.finding import Finding as FindingOutput
logger = get_task_logger(__name__)
@shared_task(base=RLSTask, name="provider-connection-check")
@@ -68,13 +80,20 @@ def perform_scan_task(
Returns:
dict: The result of the scan execution, typically including the status and results of the performed checks.
"""
return perform_prowler_scan(
result = perform_prowler_scan(
tenant_id=tenant_id,
scan_id=scan_id,
provider_id=provider_id,
checks_to_execute=checks_to_execute,
)
chain(
perform_scan_summary_task.s(tenant_id, scan_id),
generate_outputs.si(scan_id, provider_id, tenant_id=tenant_id),
).apply_async()
return result
@shared_task(base=RLSTask, bind=True, name="scan-perform-scheduled", queue="scans")
def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
@@ -135,12 +154,11 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
scheduler_task_id=periodic_task_instance.id,
)
perform_scan_summary_task.apply_async(
kwargs={
"tenant_id": tenant_id,
"scan_id": str(scan_instance.id),
}
)
chain(
perform_scan_summary_task.s(tenant_id, scan_instance.id),
generate_outputs.si(str(scan_instance.id), provider_id, tenant_id=tenant_id),
).apply_async()
return result
@@ -152,3 +170,125 @@ def perform_scan_summary_task(tenant_id: str, scan_id: str):
@shared_task(name="tenant-deletion")
def delete_tenant_task(tenant_id: str):
return delete_tenant(pk=tenant_id)
def batched(iterable, batch_size):
"""
Yield successive batches from an iterable.
Args:
iterable: An iterable source of items.
batch_size (int): The number of items per batch.
Yields:
tuple: A pair (batch, is_last_batch) where:
- batch (list): A list of items (with length equal to batch_size,
except possibly for the last batch).
- is_last_batch (bool): True if this is the final batch, False otherwise.
"""
batch = []
for item in iterable:
batch.append(item)
if len(batch) == batch_size:
yield batch, False
batch = []
yield batch, True
@shared_task(base=RLSTask, name="scan-output", queue="scans")
@set_tenant(keep_tenant=True)
def generate_outputs(scan_id: str, provider_id: str, tenant_id: str):
"""
Process findings in batches and generate output files in multiple formats.
This function retrieves findings associated with a scan, processes them
in batches of DJANGO_FINDINGS_BATCH_SIZE, and writes each batch to the corresponding output files.
It reuses output writer instances across batches, updates them with each
batch of transformed findings, and uses a flag to indicate when the final
batch is being processed. Finally, the output files are compressed and
uploaded to S3.
Args:
tenant_id (str): The tenant identifier.
scan_id (str): The scan identifier.
provider_id (str): The provider identifier to be used when generating outputs.
"""
# Initialize the prowler provider
prowler_provider = initialize_prowler_provider(Provider.objects.get(id=provider_id))
# Get the provider UID
provider_uid = Provider.objects.get(id=provider_id).uid
# Generate and ensure the output directory exists
output_directory = _generate_output_directory(
DJANGO_TMP_OUTPUT_DIRECTORY, provider_uid, tenant_id, scan_id
)
# Define auxiliary variables
output_writers = {}
scan_summary = FindingOutput._transform_findings_stats(
ScanSummary.objects.filter(scan_id=scan_id)
)
# Retrieve findings queryset
findings_qs = Finding.objects.filter(scan_id=scan_id).order_by("uid")
# Process findings in batches
for batch, is_last_batch in batched(
findings_qs.iterator(), DJANGO_FINDINGS_BATCH_SIZE
):
finding_outputs = [
FindingOutput.transform_api_finding(finding) for finding in batch
]
# Generate output files
for mode, config in OUTPUT_FORMATS_MAPPING.items():
kwargs = dict(config.get("kwargs", {}))
if mode == "html":
kwargs["provider"] = prowler_provider
kwargs["stats"] = scan_summary
writer_class = config["class"]
if writer_class in output_writers:
writer = output_writers[writer_class]
writer.transform(finding_outputs)
writer.close_file = is_last_batch
else:
writer = writer_class(
findings=finding_outputs,
file_path=output_directory,
file_extension=config["suffix"],
from_cli=False,
)
writer.close_file = is_last_batch
output_writers[writer_class] = writer
# Write the current batch using the writer
writer.batch_write_data_to_file(**kwargs)
# TODO: Refactor the output classes to avoid this manual reset
writer._data = []
# Compress output files
output_directory = _compress_output_files(output_directory)
# Save to configured storage
uploaded = _upload_to_s3(tenant_id, output_directory, scan_id)
if uploaded:
output_directory = uploaded
uploaded = True
else:
uploaded = False
# Update the scan instance with the output path
Scan.objects.filter(id=scan_id).update(output_path=output_directory)
logger.info(f"Scan output files generated, output location: {output_directory}")
return {
"upload": uploaded,
"scan_id": scan_id,
"provider_id": provider_id,
}
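A quick worked example of the batched helper defined above (illustrative values only):

# list(batched(range(5), 2)) produces:
#   ([0, 1], False)
#   ([2, 3], False)
#   ([4], True)   # the final, possibly shorter, batch is flagged with is_last_batch=True
# Note: when the item count is an exact multiple of batch_size, the final yield is an empty batch flagged as last.
list(batched(range(5), 2))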

View File

@@ -16,6 +16,7 @@ services:
volumes:
- "./api/src/backend:/home/prowler/backend"
- "./api/pyproject.toml:/home/prowler/pyproject.toml"
- "/tmp/prowler_api_output:/tmp/prowler_api_output"
depends_on:
postgres:
condition: service_healthy
@@ -85,6 +86,8 @@ services:
env_file:
- path: .env
required: false
volumes:
- "/tmp/prowler_api_output:/tmp/prowler_api_output"
depends_on:
valkey:
condition: service_healthy

View File

@@ -7,6 +7,8 @@ services:
required: false
ports:
- "${DJANGO_PORT:-8080}:${DJANGO_PORT:-8080}"
volumes:
- "/tmp/prowler_api_output:/tmp/prowler_api_output"
depends_on:
postgres:
condition: service_healthy
@@ -65,6 +67,8 @@ services:
env_file:
- path: .env
required: false
volumes:
- "/tmp/prowler_api_output:/tmp/prowler_api_output"
depends_on:
valkey:
condition: service_healthy

View File

@@ -351,7 +351,6 @@ def prowler():
if mode == "csv":
csv_output = CSV(
findings=finding_outputs,
create_file_descriptor=True,
file_path=f"{filename}{csv_file_suffix}",
)
generated_outputs["regular"].append(csv_output)
@@ -361,7 +360,6 @@ def prowler():
if mode == "json-asff":
asff_output = ASFF(
findings=finding_outputs,
create_file_descriptor=True,
file_path=f"{filename}{json_asff_file_suffix}",
)
generated_outputs["regular"].append(asff_output)
@@ -371,7 +369,6 @@ def prowler():
if mode == "json-ocsf":
json_output = OCSF(
findings=finding_outputs,
create_file_descriptor=True,
file_path=f"{filename}{json_ocsf_file_suffix}",
)
generated_outputs["regular"].append(json_output)
@@ -379,7 +376,6 @@ def prowler():
if mode == "html":
html_output = HTML(
findings=finding_outputs,
create_file_descriptor=True,
file_path=f"{filename}{html_file_suffix}",
)
generated_outputs["regular"].append(html_output)
@@ -402,7 +398,6 @@ def prowler():
cis = AWSCIS(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(cis)
@@ -416,7 +411,6 @@ def prowler():
mitre_attack = AWSMitreAttack(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(mitre_attack)
@@ -430,7 +424,6 @@ def prowler():
ens = AWSENS(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(ens)
@@ -444,7 +437,6 @@ def prowler():
aws_well_architected = AWSWellArchitected(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(aws_well_architected)
@@ -458,7 +450,6 @@ def prowler():
iso27001 = AWSISO27001(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(iso27001)
@@ -472,7 +463,6 @@ def prowler():
kisa_ismsp = AWSKISAISMSP(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(kisa_ismsp)
@@ -485,7 +475,6 @@ def prowler():
generic_compliance = GenericCompliance(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(generic_compliance)
@@ -502,7 +491,6 @@ def prowler():
cis = AzureCIS(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(cis)
@@ -516,7 +504,6 @@ def prowler():
mitre_attack = AzureMitreAttack(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(mitre_attack)
@@ -530,7 +517,6 @@ def prowler():
ens = AzureENS(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(ens)
@@ -543,7 +529,6 @@ def prowler():
generic_compliance = GenericCompliance(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(generic_compliance)
@@ -560,7 +545,6 @@ def prowler():
cis = GCPCIS(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(cis)
@@ -574,7 +558,6 @@ def prowler():
mitre_attack = GCPMitreAttack(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(mitre_attack)
@@ -588,7 +571,6 @@ def prowler():
ens = GCPENS(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(ens)
@@ -601,7 +583,6 @@ def prowler():
generic_compliance = GenericCompliance(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(generic_compliance)
@@ -618,7 +599,6 @@ def prowler():
cis = KubernetesCIS(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(cis)
@@ -631,7 +611,6 @@ def prowler():
generic_compliance = GenericCompliance(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(generic_compliance)
@@ -648,7 +627,6 @@ def prowler():
cis = Microsoft365CIS(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(cis)
@@ -661,7 +639,6 @@ def prowler():
generic_compliance = GenericCompliance(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(generic_compliance)

View File

@@ -29,11 +29,11 @@ class ComplianceOutput(Output):
self,
findings: List[Finding],
compliance: Compliance,
create_file_descriptor: bool = False,
file_path: str = None,
file_extension: str = "",
) -> None:
self._data = []
self.file_descriptor = None
if not file_extension and file_path:
self._file_extension = "".join(Path(file_path).suffixes)
@@ -48,7 +48,7 @@ class ComplianceOutput(Output):
else compliance.Framework
)
self.transform(findings, compliance, compliance_name)
if create_file_descriptor:
if not self._file_descriptor and file_path:
self.create_file_descriptor(file_path)
def batch_write_data_to_file(self) -> None:

View File

@@ -98,10 +98,12 @@ class CSV(Output):
fieldnames=self._data[0].keys(),
delimiter=";",
)
csv_writer.writeheader()
if self._file_descriptor.tell() == 0:
csv_writer.writeheader()
for finding in self._data:
csv_writer.writerow(finding)
self._file_descriptor.close()
if self.close_file or self._from_cli:
self._file_descriptor.close()
except Exception as error:
logger.error(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"

View File

@@ -1,13 +1,20 @@
from datetime import datetime
from typing import Optional, Union
from typing import Optional, Tuple, Union
from pydantic import BaseModel, Field, ValidationError
from prowler.config.config import prowler_version
from prowler.lib.check.models import Check_Report, CheckMetadata
from prowler.lib.check.models import (
Check_Report,
CheckMetadata,
Remediation,
Code,
Recommendation,
)
from prowler.lib.logger import logger
from prowler.lib.outputs.common import Status, fill_common_finding_data
from prowler.lib.outputs.compliance.compliance import get_check_compliance
from prowler.lib.outputs.utils import unroll_tags
from prowler.lib.utils.utils import dict_to_lowercase, get_nested_attribute
from prowler.providers.common.provider import Provider
@@ -267,3 +274,207 @@ class Finding(BaseModel):
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
raise error
@staticmethod
def _get_auth_method_and_partition(provider: Provider) -> Tuple[str, str]:
"""
Extract the authentication method and partition information based on the provider type.
"""
auth_method = ""
partition = ""
if provider.type == "aws":
auth_method = f"profile: {get_nested_attribute(provider, 'identity.profile')}"
partition = get_nested_attribute(provider, "identity.partition")
elif provider.type == "azure":
auth_method = f"{provider.identity.identity_type}: {provider.identity.identity_id}"
partition = get_nested_attribute(provider, "region_config.name")
elif provider.type == "gcp":
auth_method = f"Principal: {get_nested_attribute(provider, 'identity.profile')}"
elif provider.type == "kubernetes":
auth_method = "in-cluster" if provider.identity.context == "In-Cluster" else "kubeconfig"
elif provider.type == "microsoft365":
auth_method = f"{provider.identity.identity_type}: {provider.identity.identity_id}"
return auth_method, partition
@classmethod
def transform_api_finding(cls, finding, provider) -> "Finding":
"""
Transform a FindingModel instance into an API-friendly Finding object.
This class method extracts data from a FindingModel instance and maps its
properties to a new Finding object. The transformation populates various
fields including authentication details, timestamp, account information,
check metadata (such as provider, check ID, title, type, service, severity,
and remediation details), as well as resource-specific data. The resulting
Finding object is structured for use in API responses or further processing.
Args:
finding (API Finding): An API Finding instance containing data from the database.
provider (Provider): the provider object.
Returns:
Finding: A new Finding instance populated with data from the provided model.
"""
output_data = {}
output_data["auth_method"], output_data["partition"] = cls._get_auth_method_and_partition(provider)
output_data["timestamp"] = finding.inserted_at
output_data["account_uid"] = finding.scan.provider.uid
output_data["account_name"] = ""
output_data["metadata"] = CheckMetadata(
Provider=finding.check_metadata["provider"],
CheckID=finding.check_metadata["checkid"],
CheckTitle=finding.check_metadata["checktitle"],
CheckType=finding.check_metadata["checktype"],
ServiceName=finding.check_metadata["servicename"],
SubServiceName=finding.check_metadata["subservicename"],
Severity=finding.check_metadata["severity"],
ResourceType=finding.check_metadata["resourcetype"],
Description=finding.check_metadata["description"],
Risk=finding.check_metadata["risk"],
RelatedUrl=finding.check_metadata["relatedurl"],
Remediation=Remediation(
Recommendation=Recommendation(
Text=finding.check_metadata["remediation"]["recommendation"][
"text"
],
Url=finding.check_metadata["remediation"]["recommendation"]["url"],
),
Code=Code(
NativeIaC=finding.check_metadata["remediation"]["code"][
"nativeiac"
],
Terraform=finding.check_metadata["remediation"]["code"][
"terraform"
],
CLI=finding.check_metadata["remediation"]["code"]["cli"],
Other=finding.check_metadata["remediation"]["code"]["other"],
),
),
ResourceIdTemplate=finding.check_metadata["resourceidtemplate"],
Categories=finding.check_metadata["categories"],
DependsOn=finding.check_metadata["dependson"],
RelatedTo=finding.check_metadata["relatedto"],
Notes=finding.check_metadata["notes"],
)
output_data["uid"] = finding.uid
output_data["status"] = Status(finding.status)
output_data["status_extended"] = finding.status_extended
resource = finding.resources.first()
output_data["resource_uid"] = resource.uid
output_data["resource_name"] = resource.name
output_data["resource_details"] = ""
resource_tags = resource.tags.all()
output_data["resource_tags"] = unroll_tags(
[{"key": tag.key, "value": tag.value} for tag in resource_tags]
)
output_data["region"] = resource.region
output_data["compliance"] = {}
return cls(**output_data)
def _transform_findings_stats(scan_summaries: list[dict]) -> dict:
"""
Aggregate and transform scan summary data into findings statistics.
This function processes a list of scan summary objects and calculates overall
metrics such as the total number of passed and failed findings (including muted counts),
as well as a breakdown of results by severity (critical, high, medium, and low).
It also retrieves the unique resource count from the associated scan information.
The final output is a dictionary of aggregated statistics intended for reporting or
further analysis.
Args:
scan_summaries (list[dict]): A list of scan summary objects. Each object is expected
to have attributes including:
- _pass: Number of passed findings.
- fail: Number of failed findings.
- total: Total number of findings.
- muted: Number of muted findings in the summary row.
- severity: A string representing the severity level.
Additionally, the first scan summary should have an associated
`scan` attribute with a `unique_resource_count`.
Returns:
dict: A dictionary containing aggregated findings statistics:
- total_pass: Total number of passed findings.
- total_muted_pass: Total number of muted passed findings.
- total_fail: Total number of failed findings.
- total_muted_fail: Total number of muted failed findings.
- resources_count: The unique resource count extracted from the scan.
- findings_count: Total number of findings.
- total_critical_severity_fail: Failed findings with critical severity.
- total_critical_severity_pass: Passed findings with critical severity.
- total_high_severity_fail: Failed findings with high severity.
- total_high_severity_pass: Passed findings with high severity.
- total_medium_severity_fail: Failed findings with medium severity.
- total_medium_severity_pass: Passed findings with medium severity.
- total_low_severity_fail: Failed findings with low severity.
- total_low_severity_pass: Passed findings with low severity.
- all_fails_are_muted: A boolean indicating whether all failing findings are muted.
"""
# Initialize overall counters
total_pass = 0
total_fail = 0
muted_pass = 0
muted_fail = 0
findings_count = 0
resources_count = scan_summaries[0].scan.unique_resource_count
# Initialize severity breakdown counters
critical_severity_pass = 0
critical_severity_fail = 0
high_severity_pass = 0
high_severity_fail = 0
medium_severity_pass = 0
medium_severity_fail = 0
low_severity_pass = 0
low_severity_fail = 0
# Loop over each row from the database
for row in scan_summaries:
# Accumulate overall totals
total_pass += row._pass
total_fail += row.fail
findings_count += row.total
if row.muted > 0:
if row._pass > 0:
muted_pass += row._pass
if row.fail > 0:
muted_fail += row.fail
sev = row.severity.lower()
if sev == "critical":
critical_severity_pass += row._pass
critical_severity_fail += row.fail
elif sev == "high":
high_severity_pass += row._pass
high_severity_fail += row.fail
elif sev == "medium":
medium_severity_pass += row._pass
medium_severity_fail += row.fail
elif sev == "low":
low_severity_pass += row._pass
low_severity_fail += row.fail
all_fails_are_muted = (total_fail > 0) and (total_fail == muted_fail)
stats = {
"total_pass": total_pass,
"total_muted_pass": muted_pass,
"total_fail": total_fail,
"total_muted_fail": muted_fail,
"resources_count": resources_count,
"findings_count": findings_count,
"total_critical_severity_fail": critical_severity_fail,
"total_critical_severity_pass": critical_severity_pass,
"total_high_severity_fail": high_severity_fail,
"total_high_severity_pass": high_severity_pass,
"total_medium_severity_fail": medium_severity_fail,
"total_medium_severity_pass": medium_severity_pass,
"total_low_severity_fail": low_severity_fail,
"total_low_severity_pass": low_severity_pass,
"all_fails_are_muted": all_fails_are_muted,
}
return stats

View File

@@ -74,12 +74,13 @@ class HTML(Output):
and not self._file_descriptor.closed
and self._data
):
HTML.write_header(self._file_descriptor, provider, stats)
if self._file_descriptor.tell() == 0:
HTML.write_header(self._file_descriptor, provider, stats)
for finding in self._data:
self._file_descriptor.write(finding)
HTML.write_footer(self._file_descriptor)
# Close file descriptor
self._file_descriptor.close()
if self.close_file or self._from_cli:
HTML.write_footer(self._file_descriptor)
self._file_descriptor.close()
except Exception as error:
logger.error(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"

View File

@@ -200,7 +200,8 @@ class OCSF(Output):
and not self._file_descriptor.closed
and self._data
):
self._file_descriptor.write("[")
if self._file_descriptor.tell() == 0:
self._file_descriptor.write("[")
for finding in self._data:
try:
self._file_descriptor.write(
@@ -211,14 +212,14 @@ class OCSF(Output):
logger.error(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
if self._file_descriptor.tell() > 0:
if self.close_file or self._from_cli:
if self._file_descriptor.tell() != 1:
self._file_descriptor.seek(
self._file_descriptor.tell() - 1, os.SEEK_SET
)
self._file_descriptor.truncate()
self._file_descriptor.write("]")
self._file_descriptor.close()
self._file_descriptor.close()
except Exception as error:
logger.error(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"

View File

@@ -23,7 +23,6 @@ class Output(ABC):
file_descriptor: Property to access the file descriptor.
transform: Abstract method to transform findings into a specific format.
batch_write_data_to_file: Abstract method to write data to a file in batches.
create_file_descriptor: Method to create a file descriptor for writing data to a file.
"""
_data: list
@@ -33,21 +32,27 @@ class Output(ABC):
def __init__(
self,
findings: List[Finding],
create_file_descriptor: bool = False,
file_path: str = None,
file_extension: str = "",
from_cli: bool = True,
) -> None:
self._data = []
self.close_file = False
self.file_path = file_path
self._file_descriptor = None
# This parameter avoids refactoring more code: the CLI does not write in batches, but the API does
self._from_cli = from_cli
if not file_extension and file_path:
self._file_extension = "".join(Path(file_path).suffixes)
if file_extension:
self._file_extension = file_extension
self.file_path = f"{file_path}{self.file_extension}"
if findings:
self.transform(findings)
if create_file_descriptor and file_path:
self.create_file_descriptor(file_path)
if not self._file_descriptor and file_path:
self.create_file_descriptor(self.file_path)
@property
def data(self):

View File

@@ -110,7 +110,7 @@ class S3:
for output in output_list:
try:
# Object is not written to file so we need to temporarily write it
if not hasattr(output, "file_descriptor"):
if not output.file_descriptor:
output.file_descriptor = NamedTemporaryFile(mode="a")
bucket_directory = self.get_object_path(self._output_directory)

View File

@@ -577,7 +577,7 @@ class TestASFF:
assert loads(content) == expected_asff
def test_batch_write_data_to_file_without_findings(self):
assert not hasattr(ASFF([]), "_file_descriptor")
assert not ASFF([])._file_descriptor
def test_asff_generate_status(self):
assert ASFF.generate_status("PASS") == "PASSED"

View File

@@ -119,7 +119,7 @@ class TestCSV:
assert content == expected_csv
def test_batch_write_data_to_file_without_findings(self):
assert not hasattr(CSV([]), "_file_descriptor")
assert not CSV([])._file_descriptor
@pytest.fixture
def mock_output_class(self):
@@ -144,9 +144,7 @@ class TestCSV:
file_path = file.name
# Instantiate the mock class
output_instance = mock_output_class(
findings, create_file_descriptor=True, file_path=file_path
)
output_instance = mock_output_class(findings, file_path=file_path)
# Check that transform was called once
output_instance.transform.assert_called_once_with(findings)

View File

@@ -1,4 +1,5 @@
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
@@ -55,6 +56,58 @@ def mock_get_check_compliance(*_):
return {"mock_compliance_key": "mock_compliance_value"}
class DummyTag:
def __init__(self, key, value):
self.key = key
self.value = value
class DummyTags:
def __init__(self, tags):
self._tags = tags
def all(self):
return self._tags
class DummyResource:
def __init__(self, uid, name, region, tags):
self.uid = uid
self.name = name
self.region = region
self.tags = DummyTags(tags)
class DummyResources:
"""Simulate a collection with a first() method."""
def __init__(self, resource):
self._resource = resource
def first(self):
return self._resource
class DummyProvider:
def __init__(self, uid):
self.uid = uid
self.type = "aws"
class DummyScan:
def __init__(self, provider):
self.provider = provider
class DummyAPIFinding:
"""
A dummy API finding model to simulate the database model.
Attributes will be added dynamically.
"""
pass
class TestFinding:
@patch(
"prowler.lib.outputs.finding.get_check_compliance",
@@ -461,3 +514,222 @@ class TestFinding:
# Generate the finding
with pytest.raises(ValidationError):
Finding.generate_output(provider, check_output, output_options)
def test_transform_api_finding(self):
"""
Test that a dummy API Finding is correctly
transformed into a Finding instance.
"""
# Set up the dummy API finding attributes
inserted_at = 1234567890
provider = DummyProvider(uid="account123")
scan = DummyScan(provider=provider)
# Create a dummy resource with one tag
tag = DummyTag("env", "prod")
resource = DummyResource(
uid="res-uid-1", name="ResourceName1", region="us-east-1", tags=[tag]
)
resources = DummyResources(resource)
# Create a dummy check_metadata dict with all required fields
check_metadata = {
"provider": "test_provider",
"checkid": "check-001",
"checktitle": "Test Check",
"checktype": ["type1"],
"servicename": "TestService",
"subservicename": "SubService",
"severity": "high",
"resourcetype": "TestResource",
"description": "A test check",
"risk": "High risk",
"relatedurl": "http://example.com",
"remediation": {
"recommendation": {"text": "Fix it", "url": "http://fix.com"},
"code": {
"nativeiac": "iac_code",
"terraform": "terraform_code",
"cli": "cli_code",
"other": "other_code",
},
},
"resourceidtemplate": "template",
"categories": ["cat-one", "cat-two"],
"dependson": ["dep1"],
"relatedto": ["rel1"],
"notes": "Some notes",
}
# Create the dummy API finding and assign required attributes
dummy_finding = DummyAPIFinding()
dummy_finding.inserted_at = inserted_at
dummy_finding.scan = scan
dummy_finding.uid = "finding-uid-1"
dummy_finding.status = "FAIL" # will be converted to Status("FAIL")
dummy_finding.status_extended = "extended"
dummy_finding.check_metadata = check_metadata
dummy_finding.resources = resources
# Call the transform_api_finding classmethod
finding_obj = Finding.transform_api_finding(dummy_finding, provider)
# Fields directly set in transform_api_finding
assert finding_obj.auth_method == "profile: "
assert finding_obj.timestamp == inserted_at
assert finding_obj.account_uid == "account123"
assert finding_obj.account_name == ""
# Check that metadata was built correctly
meta = finding_obj.metadata
assert meta.Provider == "test_provider"
assert meta.CheckID == "check-001"
assert meta.CheckTitle == "Test Check"
assert meta.CheckType == ["type1"]
assert meta.ServiceName == "TestService"
assert meta.SubServiceName == "SubService"
assert meta.Severity == "high"
assert meta.ResourceType == "TestResource"
assert meta.Description == "A test check"
assert meta.Risk == "High risk"
assert meta.RelatedUrl == "http://example.com"
assert meta.Remediation.Recommendation.Text == "Fix it"
assert meta.Remediation.Recommendation.Url == "http://fix.com"
assert meta.Remediation.Code.NativeIaC == "iac_code"
assert meta.Remediation.Code.Terraform == "terraform_code"
assert meta.Remediation.Code.CLI == "cli_code"
assert meta.Remediation.Code.Other == "other_code"
assert meta.ResourceIdTemplate == "template"
assert meta.Categories == ["cat-one", "cat-two"]
assert meta.DependsOn == ["dep1"]
assert meta.RelatedTo == ["rel1"]
assert meta.Notes == "Some notes"
# Check other Finding fields
assert finding_obj.uid == "finding-uid-1"
assert finding_obj.status == Status("FAIL")
assert finding_obj.status_extended == "extended"
# From the dummy resource
assert finding_obj.resource_uid == "res-uid-1"
assert finding_obj.resource_name == "ResourceName1"
assert finding_obj.resource_details == ""
# unroll_tags is called on a list with one tag -> expect {"env": "prod"}
assert finding_obj.resource_tags == {"env": "prod"}
assert finding_obj.region == "us-east-1"
# compliance is hardcoded to an empty dict
assert finding_obj.compliance == {}
def test_transform_findings_stats_all_fails_muted(self):
"""
Test _transform_findings_stats when every failing finding is muted.
"""
# Create a dummy scan object with a unique_resource_count
dummy_scan = SimpleNamespace(unique_resource_count=10)
# Build summaries covering each severity branch.
ss1 = SimpleNamespace(
_pass=1, fail=2, total=3, muted=2, severity="critical", scan=dummy_scan
)
ss2 = SimpleNamespace(
_pass=2, fail=0, total=2, muted=0, severity="high", scan=dummy_scan
)
ss3 = SimpleNamespace(
_pass=2, fail=3, total=5, muted=3, severity="medium", scan=dummy_scan
)
ss4 = SimpleNamespace(
_pass=3, fail=0, total=3, muted=0, severity="low", scan=dummy_scan
)
summaries = [ss1, ss2, ss3, ss4]
stats = Finding._transform_findings_stats(summaries)
# Expected calculations:
# total_pass = 1+2+2+3 = 8
# total_fail = 2+0+3+0 = 5
# findings_count = 3+2+5+3 = 13
# muted_pass = (ss1: 1) + (ss3: 2) = 3
# muted_fail = (ss1: 2) + (ss3: 3) = 5
expected = {
"total_pass": 8,
"total_muted_pass": 3,
"total_fail": 5,
"total_muted_fail": 5,
"resources_count": 10,
"findings_count": 13,
"total_critical_severity_fail": 2,
"total_critical_severity_pass": 1,
"total_high_severity_fail": 0,
"total_high_severity_pass": 2,
"total_medium_severity_fail": 3,
"total_medium_severity_pass": 2,
"total_low_severity_fail": 0,
"total_low_severity_pass": 3,
"all_fails_are_muted": True, # total_fail equals muted_fail and total_fail > 0
}
assert stats == expected
def test_transform_findings_stats_not_all_fails_muted(self):
"""
Test _transform_findings_stats when at least one failing finding is not muted.
"""
dummy_scan = SimpleNamespace(unique_resource_count=5)
# Build summaries: one summary has fail > 0 but muted == 0
ss1 = SimpleNamespace(
_pass=1, fail=2, total=3, muted=0, severity="critical", scan=dummy_scan
)
ss2 = SimpleNamespace(
_pass=2, fail=1, total=3, muted=1, severity="high", scan=dummy_scan
)
summaries = [ss1, ss2]
stats = Finding._transform_findings_stats(summaries)
# Expected calculations:
# total_pass = 1+2 = 3
# total_fail = 2+1 = 3
# findings_count = 3+3 = 6
# muted_pass = (ss2: 2) since ss1 muted is 0
# muted_fail = (ss2: 1)
# Severity breakdown: critical: pass 1, fail 2; high: pass 2, fail 1
expected = {
"total_pass": 3,
"total_muted_pass": 2,
"total_fail": 3,
"total_muted_fail": 1,
"resources_count": 5,
"findings_count": 6,
"total_critical_severity_fail": 2,
"total_critical_severity_pass": 1,
"total_high_severity_fail": 1,
"total_high_severity_pass": 2,
"total_medium_severity_fail": 0,
"total_medium_severity_pass": 0,
"total_low_severity_fail": 0,
"total_low_severity_pass": 0,
"all_fails_are_muted": False, # 3 (total_fail) != 1 (muted_fail)
}
assert stats == expected
def test_transform_api_finding_validation_error(self):
"""
Test that if required data is missing (causing a ValidationError)
the function logs the error and re-raises the exception.
For example, if the metadata dict is missing required keys.
"""
provider = DummyProvider(uid="account123")
# Create a dummy API finding that is missing some required metadata
dummy_finding = DummyAPIFinding()
dummy_finding.inserted_at = 1234567890
dummy_finding.scan = DummyScan(provider=provider)
dummy_finding.uid = "finding-uid-invalid"
dummy_finding.status = "PASS"
dummy_finding.status_extended = "extended"
# Missing required metadata keys using an empty dict
dummy_finding.check_metadata = {}
# Provide a dummy resources with a minimal resource
tag = DummyTag("env", "prod")
resource = DummyResource(
uid="res-uid-1", name="ResourceName1", region="us-east-1", tags=[tag]
)
dummy_finding.resources = DummyResources(resource)
with pytest.raises(KeyError):
Finding.transform_api_finding(dummy_finding, provider)

View File

@@ -492,7 +492,7 @@ class TestHTML:
assert content == get_aws_html_header(args) + pass_html_finding + html_footer
def test_batch_write_data_to_file_without_findings(self):
assert not hasattr(HTML([]), "_file_descriptor")
assert not HTML([])._file_descriptor
def test_write_header(self):
mock_file = StringIO()

View File

@@ -258,7 +258,7 @@ class TestOCSF:
assert json.loads(content) == expected_json_output
def test_batch_write_data_to_file_without_findings(self):
assert not hasattr(OCSF([]), "_file_descriptor")
assert not OCSF([])._file_descriptor
def test_finding_output_cloud_pass_low_muted(self):
finding_output = generate_finding_output(

View File

@@ -113,7 +113,6 @@ class TestS3:
csv_file = f"test{extension}"
csv = CSV(
findings=[FINDING],
create_file_descriptor=True,
file_path=f"{CURRENT_DIRECTORY}/{csv_file}",
)

View File

@@ -198,3 +198,40 @@ export const updateScan = async (formData: FormData) => {
};
}
};
export const getExportsZip = async (scanId: string) => {
const session = await auth();
const keyServer = process.env.API_BASE_URL;
const url = new URL(`${keyServer}/scans/${scanId}/report`);
try {
const response = await fetch(url.toString(), {
headers: {
Authorization: `Bearer ${session?.accessToken}`,
},
});
if (!response.ok) {
const errorData = await response.json();
throw new Error(
errorData?.errors?.[0]?.detail || "Failed to fetch report",
);
}
// Get the blob data as an array buffer
const arrayBuffer = await response.arrayBuffer();
// Convert to base64
const base64 = Buffer.from(arrayBuffer).toString("base64");
return {
success: true,
data: base64,
filename: `scan-${scanId}-report.zip`,
};
} catch (error) {
return {
error: getErrorMessage(error),
};
}
};

View File

@@ -13,10 +13,12 @@ import {
EditDocumentBulkIcon,
} from "@nextui-org/shared-icons";
import { Row } from "@tanstack/react-table";
// import clsx from "clsx";
import { DownloadIcon } from "lucide-react";
import { useState } from "react";
import { getExportsZip } from "@/actions/scans";
import { VerticalDotsIcon } from "@/components/icons";
import { useToast } from "@/components/ui";
import { CustomAlertModal } from "@/components/ui/custom";
import { EditScanForm } from "../../forms";
@@ -30,9 +32,47 @@ const iconClasses =
export function DataTableRowActions<ScanProps>({
row,
}: DataTableRowActionsProps<ScanProps>) {
const { toast } = useToast();
const [isEditOpen, setIsEditOpen] = useState(false);
const scanId = (row.original as { id: string }).id;
const scanName = (row.original as any).attributes?.name;
const scanState = (row.original as any).attributes?.state;
const handleExportZip = async () => {
const result = await getExportsZip(scanId);
if (result?.success && result?.data) {
// Convert base64 to blob
const binaryString = window.atob(result.data);
const bytes = new Uint8Array(binaryString.length);
for (let i = 0; i < binaryString.length; i++) {
bytes[i] = binaryString.charCodeAt(i);
}
const blob = new Blob([bytes], { type: "application/zip" });
// Create download link
const url = window.URL.createObjectURL(blob);
const a = document.createElement("a");
a.href = url;
a.download = result.filename;
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
window.URL.revokeObjectURL(url);
toast({
title: "Download Complete",
description: "Your scan report has been downloaded successfully.",
});
} else if (result?.error) {
toast({
variant: "destructive",
title: "Download Failed",
description: result.error,
});
}
};
return (
<>
<CustomAlertModal
@@ -63,6 +103,18 @@ export function DataTableRowActions<ScanProps>({
color="default"
variant="flat"
>
<DropdownSection title="Export artifacts">
<DropdownItem
key="export"
description="Available only for completed scans"
textValue="Export Scan Artifacts"
startContent={<DownloadIcon className={iconClasses} />}
onPress={handleExportZip}
isDisabled={scanState !== "completed"}
>
Download .zip
</DropdownItem>
</DropdownSection>
<DropdownSection title="Actions">
<DropdownItem
key="edit"

View File

@@ -95,7 +95,7 @@ const ToastTitle = React.forwardRef<
>(({ className, ...props }, ref) => (
<ToastPrimitives.Title
ref={ref}
className={cn("[&+div]:text-md text-lg font-semibold", className)}
className={cn("[&+div]:text-md font-semibold", className)}
{...props}
/>
));
@@ -107,7 +107,7 @@ const ToastDescription = React.forwardRef<
>(({ className, ...props }, ref) => (
<ToastPrimitives.Description
ref={ref}
className={cn("text-md opacity-90", className)}
className={cn("text-small opacity-90", className)}
{...props}
/>
));