feat(export): add API export system (#6878)

Author: Adrián Jesús Peña Rodríguez
Date: 2025-02-26 15:49:44 +01:00
Committed by: GitHub
Parent: c4528200b0
Commit: 669ec74e67
34 changed files with 1613 additions and 90 deletions

.env
View File

@@ -30,6 +30,30 @@ VALKEY_HOST=valkey
VALKEY_PORT=6379
VALKEY_DB=0
# API scan settings
# The path to the directory where scan output should be stored
DJANGO_TMP_OUTPUT_DIRECTORY="/tmp/prowler_api_output"
# The maximum number of findings to process in a single batch
DJANGO_FINDINGS_BATCH_SIZE=1000
# The AWS access key to be used when uploading scan output to an S3 bucket
# If left empty, default AWS credentials resolution behavior will be used
DJANGO_OUTPUT_S3_AWS_ACCESS_KEY_ID=""
# The AWS secret key to be used when uploading scan output to an S3 bucket
DJANGO_OUTPUT_S3_AWS_SECRET_ACCESS_KEY=""
# An optional AWS session token
DJANGO_OUTPUT_S3_AWS_SESSION_TOKEN=""
# The AWS region where your S3 bucket is located (e.g., "us-east-1")
DJANGO_OUTPUT_S3_AWS_DEFAULT_REGION=""
# The name of the S3 bucket where scan output should be stored
DJANGO_OUTPUT_S3_AWS_OUTPUT_BUCKET=""
# Django settings
DJANGO_ALLOWED_HOSTS=localhost,127.0.0.1,prowler-api
DJANGO_BIND_ADDRESS=0.0.0.0

View File

@@ -8,6 +8,7 @@ All notable changes to the **Prowler API** are documented in this file.
### Added
- Social login integration with Google and GitHub [(#6906)](https://github.com/prowler-cloud/prowler/pull/6906)
- Add API scan report system; all scans launched from the API now generate a compressed file with the report in OCSF, CSV and HTML formats [(#6878)](https://github.com/prowler-cloud/prowler/pull/6878).
- Configurable Sentry integration [(#6874)](https://github.com/prowler-cloud/prowler/pull/6874)
### Changed

View File

@@ -28,7 +28,7 @@ start_prod_server() {
start_worker() {
echo "Starting the worker..."
poetry run python -m celery -A config.celery worker -l "${DJANGO_LOGGING_LEVEL:-info}" -Q celery,scans,deletion -E --max-tasks-per-child 1
poetry run python -m celery -A config.celery worker -l "${DJANGO_LOGGING_LEVEL:-info}" -Q celery,scans,scan-reports,deletion -E --max-tasks-per-child 1
}
start_worker_beat() {

View File

@@ -7,7 +7,7 @@ from rest_framework_json_api.serializers import ValidationError
from api.db_utils import POSTGRES_TENANT_VAR, SET_CONFIG_QUERY
def set_tenant(func):
def set_tenant(func=None, *, keep_tenant=False):
"""
Decorator to set the tenant context for a Celery task based on the provided tenant_id.
@@ -40,20 +40,29 @@ def set_tenant(func):
# The tenant context will be set before the task logic executes.
"""
@wraps(func)
@transaction.atomic
def wrapper(*args, **kwargs):
try:
tenant_id = kwargs.pop("tenant_id")
except KeyError:
raise KeyError("This task requires the tenant_id")
try:
uuid.UUID(tenant_id)
except ValueError:
raise ValidationError("Tenant ID must be a valid UUID")
with connection.cursor() as cursor:
cursor.execute(SET_CONFIG_QUERY, [POSTGRES_TENANT_VAR, tenant_id])
def decorator(func):
@wraps(func)
@transaction.atomic
def wrapper(*args, **kwargs):
try:
if not keep_tenant:
tenant_id = kwargs.pop("tenant_id")
else:
tenant_id = kwargs["tenant_id"]
except KeyError:
raise KeyError("This task requires the tenant_id")
try:
uuid.UUID(tenant_id)
except ValueError:
raise ValidationError("Tenant ID must be a valid UUID")
with connection.cursor() as cursor:
cursor.execute(SET_CONFIG_QUERY, [POSTGRES_TENANT_VAR, tenant_id])
return func(*args, **kwargs)
return func(*args, **kwargs)
return wrapper
return wrapper
if func is None:
return decorator
else:
return decorator(func)
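A minimal usage sketch (hypothetical task names) of the two decorator forms: the default form pops `tenant_id` before the task body runs, while `keep_tenant=True` leaves it in `kwargs`, which is what the new `generate_outputs` task relies on.

```python
from celery import shared_task

from api.decorators import set_tenant
from config.celery import RLSTask


@shared_task(base=RLSTask, name="example-task")  # hypothetical task
@set_tenant
def example_task(scan_id: str):
    # tenant_id was consumed by the decorator after setting the tenant context
    ...


@shared_task(base=RLSTask, name="example-export-task")  # hypothetical task
@set_tenant(keep_tenant=True)
def example_export_task(scan_id: str, tenant_id: str):
    # tenant_id is still available because keep_tenant=True
    ...


# Callers always pass tenant_id as a keyword argument, e.g.:
# example_task.apply_async(kwargs={"tenant_id": "<uuid>", "scan_id": "<uuid>"})
```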

View File

@@ -0,0 +1,15 @@
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("api", "0011_findings_performance_indexes_parent"),
]
operations = [
migrations.AddField(
model_name="scan",
name="output_location",
field=models.CharField(blank=True, max_length=200, null=True),
),
]

View File

@@ -414,6 +414,7 @@ class Scan(RowLevelSecurityProtectedModel):
scheduler_task = models.ForeignKey(
PeriodicTask, on_delete=models.CASCADE, null=True, blank=True
)
output_location = models.CharField(blank=True, null=True, max_length=200)
# TODO: mutelist foreign key
class Meta(RowLevelSecurityProtectedModel.Meta):

View File

@@ -4105,6 +4105,43 @@ paths:
schema:
$ref: '#/components/schemas/ScanUpdateResponse'
description: ''
/api/v1/scans/{id}/report:
get:
operationId: scans_report_retrieve
description: Returns a ZIP file containing the requested report
summary: Download ZIP report
parameters:
- in: query
name: fields[scan-reports]
schema:
type: array
items:
type: string
enum:
- id
description: endpoint returns only specific fields in the response on a per-type
basis by including a fields[TYPE] query parameter.
explode: false
- in: path
name: id
schema:
type: string
format: uuid
description: A UUID string identifying this scan.
required: true
tags:
- Scan
security:
- jwtAuth: []
responses:
'200':
description: Report obtained successfully
'202':
description: The task is in progress
'403':
description: There is a problem with credentials
'404':
description: The scan has no reports
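A hedged client-side sketch (base URL, token, and polling interval are placeholders; `requests` is assumed available) of how a consumer might handle the documented responses: save the ZIP on 200, retry on 202 while the report task referenced by `Content-Location` is still running, and give up on 403/404.

```python
import time

import requests

API = "http://localhost:8080/api/v1"            # placeholder base URL
HEADERS = {"Authorization": "Bearer <token>"}    # placeholder JWT


def download_report(scan_id: str, dest: str = "report.zip") -> bool:
    while True:
        resp = requests.get(f"{API}/scans/{scan_id}/report", headers=HEADERS)
        if resp.status_code == 200:
            with open(dest, "wb") as f:
                f.write(resp.content)
            return True
        if resp.status_code == 202:
            # The report is still being generated; the task detail URL is in
            # the Content-Location header, so wait and try again.
            time.sleep(5)
            continue
        # 403: credentials problem, 404: the scan has no reports
        return False
```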
/api/v1/schedules/daily:
post:
operationId: schedules_daily_create

View File

@@ -274,9 +274,10 @@ class TestValidateInvitation:
expired_time = datetime.now(timezone.utc) - timedelta(days=1)
invitation.expires_at = expired_time
with patch("api.utils.Invitation.objects.using") as mock_using, patch(
"api.utils.datetime"
) as mock_datetime:
with (
patch("api.utils.Invitation.objects.using") as mock_using,
patch("api.utils.datetime") as mock_datetime,
):
mock_db = mock_using.return_value
mock_db.get.return_value = invitation
mock_datetime.now.return_value = datetime.now(timezone.utc)

View File

@@ -1,9 +1,13 @@
import glob
import io
import json
import os
from datetime import datetime, timedelta, timezone
from unittest.mock import ANY, Mock, patch
import jwt
import pytest
from botocore.exceptions import NoCredentialsError
from conftest import API_JSON_CONTENT_TYPE, TEST_PASSWORD, TEST_USER
from django.conf import settings
from django.urls import reverse
@@ -20,6 +24,7 @@ from api.models import (
RoleProviderGroupRelationship,
Scan,
StateChoices,
Task,
User,
UserRoleRelationship,
)
@@ -2079,9 +2084,9 @@ class TestScanViewSet:
("started_at.gte", "2024-01-01", 3),
("started_at.lte", "2024-01-01", 0),
("trigger", Scan.TriggerChoices.MANUAL, 1),
("state", StateChoices.AVAILABLE, 2),
("state", StateChoices.AVAILABLE, 1),
("state", StateChoices.FAILED, 1),
("state.in", f"{StateChoices.FAILED},{StateChoices.AVAILABLE}", 3),
("state.in", f"{StateChoices.FAILED},{StateChoices.AVAILABLE}", 2),
("trigger", Scan.TriggerChoices.MANUAL, 1),
]
),
@@ -2156,6 +2161,159 @@ class TestScanViewSet:
response = authenticated_client.get(reverse("scan-list"), {"sort": "invalid"})
assert response.status_code == status.HTTP_400_BAD_REQUEST
def test_report_executing(self, authenticated_client, scans_fixture):
"""
When the scan is still executing (state == EXECUTING), the view should return
the task data with HTTP 202 and a Content-Location header.
"""
scan = scans_fixture[0]
scan.state = StateChoices.EXECUTING
scan.save()
task = Task.objects.create(tenant_id=scan.tenant_id)
dummy_task_data = {"id": str(task.id), "state": StateChoices.EXECUTING}
scan.task = task
scan.save()
with patch(
"api.v1.views.TaskSerializer",
return_value=type("DummySerializer", (), {"data": dummy_task_data}),
):
url = reverse("scan-report", kwargs={"pk": scan.id})
response = authenticated_client.get(url)
assert response.status_code == status.HTTP_202_ACCEPTED
assert "Content-Location" in response
assert dummy_task_data["id"] in response["Content-Location"]
def test_report_celery_task_executing(self, authenticated_client, scans_fixture):
"""
When the scan is not executing but a related celery task exists and is running,
the view should return that task data with HTTP 202.
"""
scan = scans_fixture[0]
scan.state = StateChoices.COMPLETED
scan.output_location = "dummy"
scan.save()
dummy_task = Task.objects.create(tenant_id=scan.tenant_id)
dummy_task.id = "dummy-task-id"
dummy_task_data = {"id": dummy_task.id, "state": StateChoices.EXECUTING}
with patch("api.v1.views.Task.objects.get", return_value=dummy_task), patch(
"api.v1.views.TaskSerializer",
return_value=type("DummySerializer", (), {"data": dummy_task_data}),
):
url = reverse("scan-report", kwargs={"pk": scan.id})
response = authenticated_client.get(url)
assert response.status_code == status.HTTP_202_ACCEPTED
assert "Content-Location" in response
assert dummy_task_data["id"] in response["Content-Location"]
def test_report_no_output_location(self, authenticated_client, scans_fixture):
"""
If the scan does not have an output_location, the view should return a 404.
"""
scan = scans_fixture[0]
scan.state = StateChoices.COMPLETED
scan.output_location = ""
scan.save()
url = reverse("scan-report", kwargs={"pk": scan.id})
response = authenticated_client.get(url)
assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.json()["errors"]["detail"] == "The scan has no reports."
def test_report_s3_no_credentials(
self, authenticated_client, scans_fixture, monkeypatch
):
"""
When output_location is an S3 URL and get_s3_client() raises a credentials exception,
the view should return HTTP 403 with the proper error message.
"""
scan = scans_fixture[0]
bucket = "test-bucket"
key = "report.zip"
scan.output_location = f"s3://{bucket}/{key}"
scan.state = StateChoices.COMPLETED
scan.save()
def fake_get_s3_client():
raise NoCredentialsError()
monkeypatch.setattr("api.v1.views.get_s3_client", fake_get_s3_client)
url = reverse("scan-report", kwargs={"pk": scan.id})
response = authenticated_client.get(url)
assert response.status_code == status.HTTP_403_FORBIDDEN
assert (
response.json()["errors"]["detail"]
== "There is a problem with credentials."
)
def test_report_s3_success(self, authenticated_client, scans_fixture, monkeypatch):
"""
When output_location is an S3 URL and the S3 client returns the file successfully,
the view should return the ZIP file with HTTP 200 and proper headers.
"""
scan = scans_fixture[0]
bucket = "test-bucket"
key = "report.zip"
scan.output_location = f"s3://{bucket}/{key}"
scan.state = StateChoices.COMPLETED
scan.save()
monkeypatch.setattr(
"api.v1.views.env", type("env", (), {"str": lambda self, key: bucket})()
)
class FakeS3Client:
def get_object(self, Bucket, Key):
assert Bucket == bucket
assert Key == key
return {"Body": io.BytesIO(b"s3 zip content")}
monkeypatch.setattr("api.v1.views.get_s3_client", lambda: FakeS3Client())
url = reverse("scan-report", kwargs={"pk": scan.id})
response = authenticated_client.get(url)
assert response.status_code == 200
expected_filename = os.path.basename("report.zip")
content_disposition = response.get("Content-Disposition")
assert content_disposition.startswith('attachment; filename="')
assert f'filename="{expected_filename}"' in content_disposition
assert response.content == b"s3 zip content"
def test_report_local_file(
self, authenticated_client, scans_fixture, tmp_path, monkeypatch
):
"""
When output_location is a local file path, the view should read the file from disk
and return it with proper headers.
"""
scan = scans_fixture[0]
file_content = b"local zip file content"
file_path = tmp_path / "report.zip"
file_path.write_bytes(file_content)
scan.output_location = str(file_path)
scan.state = StateChoices.COMPLETED
scan.save()
monkeypatch.setattr(
glob,
"glob",
lambda pattern: [str(file_path)] if pattern == str(file_path) else [],
)
url = reverse("scan-report", kwargs={"pk": scan.id})
response = authenticated_client.get(url)
assert response.status_code == 200
assert response.content == file_content
content_disposition = response.get("Content-Disposition")
assert content_disposition.startswith('attachment; filename="')
assert f'filename="{file_path.name}"' in content_disposition
@pytest.mark.django_db
class TestTaskViewSet:

View File

@@ -939,6 +939,14 @@ class ScanTaskSerializer(RLSSerializer):
]
class ScanReportSerializer(serializers.Serializer):
id = serializers.CharField(source="scan")
class Meta:
resource_name = "scan-reports"
fields = ["id"]
class ResourceTagSerializer(RLSSerializer):
"""
Serializer for the ResourceTag model

View File

@@ -1,6 +1,11 @@
import glob
import os
from allauth.socialaccount.providers.github.views import GitHubOAuth2Adapter
from allauth.socialaccount.providers.google.views import GoogleOAuth2Adapter
from botocore.exceptions import ClientError, NoCredentialsError, ParamValidationError
from celery.result import AsyncResult
from config.env import env
from config.settings.social_login import (
GITHUB_OAUTH_CALLBACK_URL,
GOOGLE_OAUTH_CALLBACK_URL,
@@ -12,6 +17,7 @@ from django.contrib.postgres.search import SearchQuery
from django.db import transaction
from django.db.models import Count, Exists, F, OuterRef, Prefetch, Q, Subquery, Sum
from django.db.models.functions import Coalesce
from django.http import HttpResponse
from django.urls import reverse
from django.utils.decorators import method_decorator
from django.views.decorators.cache import cache_control
@@ -38,11 +44,11 @@ from rest_framework.permissions import SAFE_METHODS
from rest_framework_json_api.views import RelationshipView, Response
from rest_framework_simplejwt.exceptions import InvalidToken, TokenError
from tasks.beat import schedule_provider_scan
from tasks.jobs.export import get_s3_client
from tasks.tasks import (
check_provider_connection_task,
delete_provider_task,
delete_tenant_task,
perform_scan_summary_task,
perform_scan_task,
)
@@ -121,6 +127,7 @@ from api.v1.serializers import (
RoleSerializer,
RoleUpdateSerializer,
ScanCreateSerializer,
ScanReportSerializer,
ScanSerializer,
ScanUpdateSerializer,
ScheduleDailyCreateSerializer,
@@ -1116,6 +1123,18 @@ class ProviderViewSet(BaseRLSViewSet):
request=ScanCreateSerializer,
responses={202: OpenApiResponse(response=TaskSerializer)},
),
report=extend_schema(
tags=["Scan"],
summary="Download ZIP report",
description="Returns a ZIP file containing the requested report",
request=ScanReportSerializer,
responses={
200: OpenApiResponse(description="Report obtained successfully"),
202: OpenApiResponse(description="The task is in progress"),
403: OpenApiResponse(description="There is a problem with credentials"),
404: OpenApiResponse(description="The scan has no reports"),
},
),
)
@method_decorator(CACHE_DECORATOR, name="list")
@method_decorator(CACHE_DECORATOR, name="retrieve")
@@ -1164,6 +1183,10 @@ class ScanViewSet(BaseRLSViewSet):
return ScanCreateSerializer
elif self.action == "partial_update":
return ScanUpdateSerializer
elif self.action == "report":
if hasattr(self, "response_serializer_class"):
return self.response_serializer_class
return ScanReportSerializer
return super().get_serializer_class()
def partial_update(self, request, *args, **kwargs):
@@ -1181,6 +1204,93 @@ class ScanViewSet(BaseRLSViewSet):
)
return Response(data=read_serializer.data, status=status.HTTP_200_OK)
@action(detail=True, methods=["get"], url_name="report")
def report(self, request, pk=None):
scan_instance = self.get_object()
if scan_instance.state == StateChoices.EXECUTING:
# If the scan is still running, return the task
prowler_task = Task.objects.get(id=scan_instance.task.id)
self.response_serializer_class = TaskSerializer
output_serializer = self.get_serializer(prowler_task)
return Response(
data=output_serializer.data,
status=status.HTTP_202_ACCEPTED,
headers={
"Content-Location": reverse(
"task-detail", kwargs={"pk": output_serializer.data["id"]}
)
},
)
try:
output_celery_task = Task.objects.get(
task_runner_task__task_name="scan-report",
task_runner_task__task_args__contains=pk,
)
self.response_serializer_class = TaskSerializer
output_serializer = self.get_serializer(output_celery_task)
if output_serializer.data["state"] == StateChoices.EXECUTING:
# If the task is still running, return the task
return Response(
data=output_serializer.data,
status=status.HTTP_202_ACCEPTED,
headers={
"Content-Location": reverse(
"task-detail", kwargs={"pk": output_serializer.data["id"]}
)
},
)
except Task.DoesNotExist:
# If the task does not exist, it has already been removed from the database
pass
output_location = scan_instance.output_location
if not output_location:
return Response(
{"detail": "The scan has no reports."},
status=status.HTTP_404_NOT_FOUND,
)
if scan_instance.output_location.startswith("s3://"):
try:
s3_client = get_s3_client()
except (ClientError, NoCredentialsError, ParamValidationError):
return Response(
{"detail": "There is a problem with credentials."},
status=status.HTTP_403_FORBIDDEN,
)
bucket_name = env.str("DJANGO_OUTPUT_S3_AWS_OUTPUT_BUCKET")
key = output_location[len(f"s3://{bucket_name}/") :]
try:
s3_object = s3_client.get_object(Bucket=bucket_name, Key=key)
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code")
if error_code == "NoSuchKey":
return Response(
{"detail": "The scan has no reports."},
status=status.HTTP_404_NOT_FOUND,
)
return Response(
{"detail": "There is a problem with credentials."},
status=status.HTTP_403_FORBIDDEN,
)
file_content = s3_object["Body"].read()
filename = os.path.basename(output_location.split("/")[-1])
else:
zip_files = glob.glob(output_location)
file_path = zip_files[0]
with open(file_path, "rb") as f:
file_content = f.read()
filename = os.path.basename(file_path)
response = HttpResponse(
file_content, content_type="application/x-zip-compressed"
)
response["Content-Disposition"] = f'attachment; filename="{filename}"'
return response
def create(self, request, *args, **kwargs):
input_serializer = self.get_serializer(data=request.data)
input_serializer.is_valid(raise_exception=True)
@@ -1195,10 +1305,6 @@ class ScanViewSet(BaseRLSViewSet):
# Disabled for now
# checks_to_execute=scan.scanner_args.get("checks_to_execute"),
},
link=perform_scan_summary_task.si(
tenant_id=self.request.tenant_id,
scan_id=str(scan.id),
),
)
scan.task_id = task.id

View File

@@ -221,3 +221,18 @@ CACHE_STALE_WHILE_REVALIDATE = env.int("DJANGO_STALE_WHILE_REVALIDATE", 60)
TESTING = False
FINDINGS_MAX_DAYS_IN_RANGE = env.int("DJANGO_FINDINGS_MAX_DAYS_IN_RANGE", 7)
# API export settings
DJANGO_TMP_OUTPUT_DIRECTORY = env.str(
"DJANGO_TMP_OUTPUT_DIRECTORY", "/tmp/prowler_api_output"
)
DJANGO_FINDINGS_BATCH_SIZE = env.int("DJANGO_FINDINGS_BATCH_SIZE", 1000)
DJANGO_OUTPUT_S3_AWS_OUTPUT_BUCKET = env.str("DJANGO_OUTPUT_S3_AWS_OUTPUT_BUCKET", "")
DJANGO_OUTPUT_S3_AWS_ACCESS_KEY_ID = env.str("DJANGO_OUTPUT_S3_AWS_ACCESS_KEY_ID", "")
DJANGO_OUTPUT_S3_AWS_SECRET_ACCESS_KEY = env.str(
"DJANGO_OUTPUT_S3_AWS_SECRET_ACCESS_KEY", ""
)
DJANGO_OUTPUT_S3_AWS_SESSION_TOKEN = env.str("DJANGO_OUTPUT_S3_AWS_SESSION_TOKEN", "")
DJANGO_OUTPUT_S3_AWS_DEFAULT_REGION = env.str("DJANGO_OUTPUT_S3_AWS_DEFAULT_REGION", "")

View File

@@ -486,7 +486,7 @@ def scans_fixture(tenants_fixture, providers_fixture):
name="Scan 1",
provider=provider,
trigger=Scan.TriggerChoices.MANUAL,
state=StateChoices.AVAILABLE,
state=StateChoices.COMPLETED,
tenant_id=tenant.id,
started_at="2024-01-02T00:00:00Z",
)

View File

@@ -0,0 +1,156 @@
import os
import zipfile
import boto3
import config.django.base as base
from botocore.exceptions import ClientError, NoCredentialsError, ParamValidationError
from celery.utils.log import get_task_logger
from django.conf import settings
from prowler.config.config import (
csv_file_suffix,
html_file_suffix,
json_ocsf_file_suffix,
output_file_timestamp,
)
from prowler.lib.outputs.csv.csv import CSV
from prowler.lib.outputs.html.html import HTML
from prowler.lib.outputs.ocsf.ocsf import OCSF
logger = get_task_logger(__name__)
# Predefined mapping for output formats and their configurations
OUTPUT_FORMATS_MAPPING = {
"csv": {
"class": CSV,
"suffix": csv_file_suffix,
"kwargs": {},
},
"json-ocsf": {"class": OCSF, "suffix": json_ocsf_file_suffix, "kwargs": {}},
"html": {"class": HTML, "suffix": html_file_suffix, "kwargs": {"stats": {}}},
}
def _compress_output_files(output_directory: str) -> str:
"""
Compress output files from all configured output formats into a ZIP archive.
Args:
output_directory (str): The directory where the output files are located.
The function looks up all known suffixes in OUTPUT_FORMATS_MAPPING
and compresses those files into a single ZIP.
Returns:
str: The full path to the newly created ZIP archive.
"""
zip_path = f"{output_directory}.zip"
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
for suffix in [config["suffix"] for config in OUTPUT_FORMATS_MAPPING.values()]:
zipf.write(
f"{output_directory}{suffix}",
f"output/{output_directory.split('/')[-1]}{suffix}",
)
return zip_path
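A short hedged sketch of the resulting archive layout, assuming the per-format output files already exist on disk (the exact suffixes come from `OUTPUT_FORMATS_MAPPING`; the directory below is hypothetical):

```python
import zipfile

from tasks.jobs.export import _compress_output_files

# Hypothetical directory produced by _generate_output_directory()
output_directory = "/tmp/prowler_api_output/tenant-1/scan-1/prowler-output-aws-20250226"

zip_path = _compress_output_files(output_directory)  # "<output_directory>.zip"
with zipfile.ZipFile(zip_path) as zf:
    print(zf.namelist())
# One member per configured format, all under the "output/" prefix,
# i.e. "output/prowler-output-aws-20250226<suffix>" for each configured suffix.
```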
def get_s3_client():
"""
Create and return a boto3 S3 client using AWS credentials from environment variables.
This function attempts to initialize an S3 client by reading the AWS access key, secret key,
session token, and region from environment variables. It then validates the client by listing
available S3 buckets. If an error occurs during this process (for example, due to missing or
invalid credentials), it falls back to creating an S3 client without explicitly provided credentials,
which may rely on other configuration sources (e.g., IAM roles).
Returns:
boto3.client: A configured S3 client instance.
Raises:
ClientError, NoCredentialsError, or ParamValidationError if both attempts to create a client fail.
"""
s3_client = None
try:
s3_client = boto3.client(
"s3",
aws_access_key_id=settings.DJANGO_OUTPUT_S3_AWS_ACCESS_KEY_ID,
aws_secret_access_key=settings.DJANGO_OUTPUT_S3_AWS_SECRET_ACCESS_KEY,
aws_session_token=settings.DJANGO_OUTPUT_S3_AWS_SESSION_TOKEN,
region_name=settings.DJANGO_OUTPUT_S3_AWS_DEFAULT_REGION,
)
s3_client.list_buckets()
except (ClientError, NoCredentialsError, ParamValidationError, ValueError):
s3_client = boto3.client("s3")
s3_client.list_buckets()
return s3_client
def _upload_to_s3(tenant_id: str, zip_path: str, scan_id: str) -> str:
"""
Upload the specified ZIP file to an S3 bucket.
If the S3 bucket environment variables are not configured,
the function returns None without performing an upload.
Args:
tenant_id (str): The tenant identifier, used as part of the S3 key prefix.
zip_path (str): The local file system path to the ZIP file to be uploaded.
scan_id (str): The scan identifier, used as part of the S3 key prefix.
Returns:
str: The S3 URI of the uploaded file (e.g., "s3://<bucket>/<key>") if successful.
None: If the required environment variables for the S3 bucket are not set.
Raises:
botocore.exceptions.ClientError: If the upload attempt to S3 fails for any reason.
"""
if not base.DJANGO_OUTPUT_S3_AWS_OUTPUT_BUCKET:
return
try:
s3 = get_s3_client()
s3_key = f"{tenant_id}/{scan_id}/{os.path.basename(zip_path)}"
s3.upload_file(
Filename=zip_path,
Bucket=base.DJANGO_OUTPUT_S3_AWS_OUTPUT_BUCKET,
Key=s3_key,
)
return f"s3://{base.DJANGO_OUTPUT_S3_AWS_OUTPUT_BUCKET}/{s3_key}"
except (ClientError, NoCredentialsError, ParamValidationError, ValueError) as e:
logger.error(f"S3 upload failed: {str(e)}")
def _generate_output_directory(
output_directory, prowler_provider: object, tenant_id: str, scan_id: str
) -> str:
"""
Generate a file system path for the output directory of a prowler scan.
This function constructs the output directory path by combining a base
temporary output directory, the tenant ID, the scan ID, and details about
the prowler provider along with a timestamp. The resulting path is used to
store the output files of a prowler scan.
Note:
This function depends on one external variable:
- `output_file_timestamp`: A timestamp (as a string) used to uniquely identify the output.
Args:
output_directory (str): The base output directory.
prowler_provider (object): An identifier or descriptor for the prowler provider.
Typically, this is a string indicating the provider (e.g., "aws").
tenant_id (str): The unique identifier for the tenant.
scan_id (str): The unique identifier for the scan.
Returns:
str: The constructed file system path for the prowler scan output directory.
Example:
>>> _generate_output_directory("/tmp", "aws", "tenant-1234", "scan-5678")
'/tmp/tenant-1234/scan-5678/prowler-output-aws-2023-02-15T12:34:56'
"""
path = (
f"{output_directory}/{tenant_id}/{scan_id}/prowler-output-"
f"{prowler_provider}-{output_file_timestamp}"
)
os.makedirs("/".join(path.split("/")[:-1]), exist_ok=True)
return path

View File

@@ -1,14 +1,28 @@
from celery import shared_task
from shutil import rmtree
from celery import chain, shared_task
from celery.utils.log import get_task_logger
from config.celery import RLSTask
from config.django.base import DJANGO_FINDINGS_BATCH_SIZE, DJANGO_TMP_OUTPUT_DIRECTORY
from django_celery_beat.models import PeriodicTask
from tasks.jobs.connection import check_provider_connection
from tasks.jobs.deletion import delete_provider, delete_tenant
from tasks.jobs.export import (
OUTPUT_FORMATS_MAPPING,
_compress_output_files,
_generate_output_directory,
_upload_to_s3,
)
from tasks.jobs.scan import aggregate_findings, perform_prowler_scan
from tasks.utils import get_next_execution_datetime
from tasks.utils import batched, get_next_execution_datetime
from api.db_utils import rls_transaction
from api.decorators import set_tenant
from api.models import Scan, StateChoices
from api.models import Finding, Provider, Scan, ScanSummary, StateChoices
from api.utils import initialize_prowler_provider
from prowler.lib.outputs.finding import Finding as FindingOutput
logger = get_task_logger(__name__)
@shared_task(base=RLSTask, name="provider-connection-check")
@@ -68,13 +82,20 @@ def perform_scan_task(
Returns:
dict: The result of the scan execution, typically including the status and results of the performed checks.
"""
return perform_prowler_scan(
result = perform_prowler_scan(
tenant_id=tenant_id,
scan_id=scan_id,
provider_id=provider_id,
checks_to_execute=checks_to_execute,
)
chain(
perform_scan_summary_task.si(tenant_id, scan_id),
generate_outputs.si(scan_id, provider_id, tenant_id=tenant_id),
).apply_async()
return result
@shared_task(base=RLSTask, bind=True, name="scan-perform-scheduled", queue="scans")
def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
@@ -135,12 +156,11 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
scheduler_task_id=periodic_task_instance.id,
)
perform_scan_summary_task.apply_async(
kwargs={
"tenant_id": tenant_id,
"scan_id": str(scan_instance.id),
}
)
chain(
perform_scan_summary_task.si(tenant_id, scan_instance.id),
generate_outputs.si(str(scan_instance.id), provider_id, tenant_id=tenant_id),
).apply_async()
return result
@@ -152,3 +172,108 @@ def perform_scan_summary_task(tenant_id: str, scan_id: str):
@shared_task(name="tenant-deletion", queue="deletion")
def delete_tenant_task(tenant_id: str):
return delete_tenant(pk=tenant_id)
@shared_task(
base=RLSTask,
name="scan-report",
queue="scan-reports",
)
@set_tenant(keep_tenant=True)
def generate_outputs(scan_id: str, provider_id: str, tenant_id: str):
"""
Process findings in batches and generate output files in multiple formats.
This function retrieves findings associated with a scan, processes them
in batches of 50, and writes each batch to the corresponding output files.
It reuses output writer instances across batches, updates them with each
batch of transformed findings, and uses a flag to indicate when the final
batch is being processed. Finally, the output files are compressed and
uploaded to S3.
Args:
tenant_id (str): The tenant identifier.
scan_id (str): The scan identifier.
provider_id (str): The provider id used to generate the outputs.
"""
# Initialize the prowler provider
prowler_provider = initialize_prowler_provider(Provider.objects.get(id=provider_id))
# Get the provider UID
provider_uid = Provider.objects.get(id=provider_id).uid
# Generate and ensure the output directory exists
output_directory = _generate_output_directory(
DJANGO_TMP_OUTPUT_DIRECTORY, provider_uid, tenant_id, scan_id
)
# Define auxiliary variables
output_writers = {}
scan_summary = FindingOutput._transform_findings_stats(
ScanSummary.objects.filter(scan_id=scan_id)
)
# Retrieve findings queryset
findings_qs = Finding.all_objects.filter(scan_id=scan_id).order_by("uid")
# Process findings in batches
for batch, is_last_batch in batched(
findings_qs.iterator(), DJANGO_FINDINGS_BATCH_SIZE
):
finding_outputs = [
FindingOutput.transform_api_finding(finding, prowler_provider)
for finding in batch
]
# Generate output files
for mode, config in OUTPUT_FORMATS_MAPPING.items():
kwargs = dict(config.get("kwargs", {}))
if mode == "html":
kwargs["provider"] = prowler_provider
kwargs["stats"] = scan_summary
writer_class = config["class"]
if writer_class in output_writers:
writer = output_writers[writer_class]
writer.transform(finding_outputs)
writer.close_file = is_last_batch
else:
writer = writer_class(
findings=finding_outputs,
file_path=output_directory,
file_extension=config["suffix"],
from_cli=False,
)
writer.close_file = is_last_batch
output_writers[writer_class] = writer
# Write the current batch using the writer
writer.batch_write_data_to_file(**kwargs)
# TODO: Refactor the output classes to avoid this manual reset
writer._data = []
# Compress output files
output_directory = _compress_output_files(output_directory)
# Save to configured storage
uploaded = _upload_to_s3(tenant_id, output_directory, scan_id)
if uploaded:
output_directory = uploaded
uploaded = True
# Remove the local files after upload
rmtree(DJANGO_TMP_OUTPUT_DIRECTORY, ignore_errors=True)
else:
uploaded = False
# Update the scan instance with the output path
Scan.all_objects.filter(id=scan_id).update(output_location=output_directory)
logger.info(f"Scan output files generated, output location: {output_directory}")
return {
"upload": uploaded,
"scan_id": scan_id,
"provider_id": provider_id,
}
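The report task is normally chained after a scan (as above), but the same signature can also be dispatched on its own; a hedged sketch with placeholder UUIDs, mirroring the chains used in `perform_scan_task` and `perform_scheduled_scan_task`:

```python
from celery import chain

from tasks.tasks import generate_outputs, perform_scan_summary_task

tenant_id = "11111111-1111-1111-1111-111111111111"    # placeholder UUIDs
scan_id = "22222222-2222-2222-2222-222222222222"
provider_id = "33333333-3333-3333-3333-333333333333"

# Summary first, then the "scan-report" task, which runs on the
# scan-reports queue added to the worker command above.
chain(
    perform_scan_summary_task.si(tenant_id, scan_id),
    generate_outputs.si(scan_id, provider_id, tenant_id=tenant_id),
).apply_async()
```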

View File

@@ -4,7 +4,7 @@ from unittest.mock import patch
import pytest
from django_celery_beat.models import IntervalSchedule, PeriodicTask
from django_celery_results.models import TaskResult
from tasks.utils import get_next_execution_datetime
from tasks.utils import batched, get_next_execution_datetime
@pytest.mark.django_db
@@ -74,3 +74,29 @@ class TestGetNextExecutionDatetime:
get_next_execution_datetime(
task_id=task_result.task_id, provider_id="nonexistent"
)
class TestBatchedFunction:
def test_empty_iterable(self):
result = list(batched([], 3))
assert result == [([], True)]
def test_exact_batches(self):
result = list(batched([1, 2, 3, 4], 2))
expected = [([1, 2], False), ([3, 4], False), ([], True)]
assert result == expected
def test_inexact_batches(self):
result = list(batched([1, 2, 3, 4, 5], 2))
expected = [([1, 2], False), ([3, 4], False), ([5], True)]
assert result == expected
def test_batch_size_one(self):
result = list(batched([1, 2, 3], 1))
expected = [([1], False), ([2], False), ([3], False), ([], True)]
assert result == expected
def test_batch_size_greater_than_length(self):
result = list(batched([1, 2, 3], 5))
expected = [([1, 2, 3], True)]
assert result == expected

View File

@@ -24,3 +24,27 @@ def get_next_execution_datetime(task_id: int, provider_id: str) -> datetime:
)
return current_scheduled_time + timedelta(**{interval.period: interval.every})
def batched(iterable, batch_size):
"""
Yield successive batches from an iterable.
Args:
iterable: An iterable source of items.
batch_size (int): The number of items per batch.
Yields:
tuple: A pair (batch, is_last_batch) where:
- batch (list): A list of items (with length equal to batch_size,
except possibly for the last batch).
- is_last_batch (bool): True if this is the final batch, False otherwise.
"""
batch = []
for item in iterable:
batch.append(item)
if len(batch) == batch_size:
yield batch, False
batch = []
yield batch, True
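A small sketch of how the `(batch, is_last_batch)` pairs are consumed; `generate_outputs` uses the flag to decide when the output writers should finalize and close their files. Note that when the iterable length is an exact multiple of `batch_size`, a final empty batch is yielded with the flag set:

```python
for batch, is_last_batch in batched(iter(range(5)), 2):
    print(batch, "last" if is_last_batch else "more to come")
# [0, 1] more to come
# [2, 3] more to come
# [4] last
```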

View File

@@ -16,6 +16,7 @@ services:
volumes:
- "./api/src/backend:/home/prowler/backend"
- "./api/pyproject.toml:/home/prowler/pyproject.toml"
- "/tmp/prowler_api_output:/tmp/prowler_api_output"
depends_on:
postgres:
condition: service_healthy
@@ -85,6 +86,8 @@ services:
env_file:
- path: .env
required: false
volumes:
- "/tmp/prowler_api_output:/tmp/prowler_api_output"
depends_on:
valkey:
condition: service_healthy

View File

@@ -7,6 +7,8 @@ services:
required: false
ports:
- "${DJANGO_PORT:-8080}:${DJANGO_PORT:-8080}"
volumes:
- "/tmp/prowler_api_output:/tmp/prowler_api_output"
depends_on:
postgres:
condition: service_healthy
@@ -65,6 +67,8 @@ services:
env_file:
- path: .env
required: false
volumes:
- "/tmp/prowler_api_output:/tmp/prowler_api_output"
depends_on:
valkey:
condition: service_healthy

View File

@@ -351,7 +351,6 @@ def prowler():
if mode == "csv":
csv_output = CSV(
findings=finding_outputs,
create_file_descriptor=True,
file_path=f"{filename}{csv_file_suffix}",
)
generated_outputs["regular"].append(csv_output)
@@ -361,7 +360,6 @@ def prowler():
if mode == "json-asff":
asff_output = ASFF(
findings=finding_outputs,
create_file_descriptor=True,
file_path=f"{filename}{json_asff_file_suffix}",
)
generated_outputs["regular"].append(asff_output)
@@ -371,7 +369,6 @@ def prowler():
if mode == "json-ocsf":
json_output = OCSF(
findings=finding_outputs,
create_file_descriptor=True,
file_path=f"{filename}{json_ocsf_file_suffix}",
)
generated_outputs["regular"].append(json_output)
@@ -379,7 +376,6 @@ def prowler():
if mode == "html":
html_output = HTML(
findings=finding_outputs,
create_file_descriptor=True,
file_path=f"{filename}{html_file_suffix}",
)
generated_outputs["regular"].append(html_output)
@@ -402,7 +398,6 @@ def prowler():
cis = AWSCIS(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(cis)
@@ -416,7 +411,6 @@ def prowler():
mitre_attack = AWSMitreAttack(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(mitre_attack)
@@ -430,7 +424,6 @@ def prowler():
ens = AWSENS(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(ens)
@@ -444,7 +437,6 @@ def prowler():
aws_well_architected = AWSWellArchitected(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(aws_well_architected)
@@ -458,7 +450,6 @@ def prowler():
iso27001 = AWSISO27001(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(iso27001)
@@ -472,7 +463,6 @@ def prowler():
kisa_ismsp = AWSKISAISMSP(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(kisa_ismsp)
@@ -485,7 +475,6 @@ def prowler():
generic_compliance = GenericCompliance(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(generic_compliance)
@@ -502,7 +491,6 @@ def prowler():
cis = AzureCIS(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(cis)
@@ -516,7 +504,6 @@ def prowler():
mitre_attack = AzureMitreAttack(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(mitre_attack)
@@ -530,7 +517,6 @@ def prowler():
ens = AzureENS(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(ens)
@@ -543,7 +529,6 @@ def prowler():
generic_compliance = GenericCompliance(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(generic_compliance)
@@ -560,7 +545,6 @@ def prowler():
cis = GCPCIS(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(cis)
@@ -574,7 +558,6 @@ def prowler():
mitre_attack = GCPMitreAttack(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(mitre_attack)
@@ -588,7 +571,6 @@ def prowler():
ens = GCPENS(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(ens)
@@ -601,7 +583,6 @@ def prowler():
generic_compliance = GenericCompliance(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(generic_compliance)
@@ -618,7 +599,6 @@ def prowler():
cis = KubernetesCIS(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(cis)
@@ -631,7 +611,6 @@ def prowler():
generic_compliance = GenericCompliance(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(generic_compliance)
@@ -648,7 +627,6 @@ def prowler():
cis = Microsoft365CIS(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(cis)
@@ -661,7 +639,6 @@ def prowler():
generic_compliance = GenericCompliance(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
create_file_descriptor=True,
file_path=filename,
)
generated_outputs["compliance"].append(generic_compliance)

View File

@@ -29,11 +29,11 @@ class ComplianceOutput(Output):
self,
findings: List[Finding],
compliance: Compliance,
create_file_descriptor: bool = False,
file_path: str = None,
file_extension: str = "",
) -> None:
self._data = []
self.file_descriptor = None
if not file_extension and file_path:
self._file_extension = "".join(Path(file_path).suffixes)
@@ -48,7 +48,7 @@ class ComplianceOutput(Output):
else compliance.Framework
)
self.transform(findings, compliance, compliance_name)
if create_file_descriptor:
if not self._file_descriptor and file_path:
self.create_file_descriptor(file_path)
def batch_write_data_to_file(self) -> None:

View File

@@ -98,10 +98,12 @@ class CSV(Output):
fieldnames=self._data[0].keys(),
delimiter=";",
)
csv_writer.writeheader()
if self._file_descriptor.tell() == 0:
csv_writer.writeheader()
for finding in self._data:
csv_writer.writerow(finding)
self._file_descriptor.close()
if self.close_file or self._from_cli:
self._file_descriptor.close()
except Exception as error:
logger.error(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"

View File

@@ -1,13 +1,21 @@
from datetime import datetime
from types import SimpleNamespace
from typing import Optional, Union
from pydantic import BaseModel, Field, ValidationError
from prowler.config.config import prowler_version
from prowler.lib.check.models import Check_Report, CheckMetadata
from prowler.lib.check.models import (
Check_Report,
CheckMetadata,
Code,
Recommendation,
Remediation,
)
from prowler.lib.logger import logger
from prowler.lib.outputs.common import Status, fill_common_finding_data
from prowler.lib.outputs.compliance.compliance import get_check_compliance
from prowler.lib.outputs.utils import unroll_tags
from prowler.lib.utils.utils import dict_to_lowercase, get_nested_attribute
from prowler.providers.common.provider import Provider
@@ -267,3 +275,193 @@ class Finding(BaseModel):
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
raise error
@classmethod
def transform_api_finding(cls, finding, provider) -> "Finding":
"""
Transform an API Finding model instance into a Finding output object.
This class method extracts data from a FindingModel instance and maps its
properties to a new Finding object. The transformation populates various
fields including authentication details, timestamp, account information,
check metadata (such as provider, check ID, title, type, service, severity,
and remediation details), as well as resource-specific data. The resulting
Finding object is structured for use in API responses or further processing.
Args:
finding (API Finding): An API Finding instance containing data from the database.
provider (Provider): the provider object.
Returns:
Finding: A new Finding instance populated with data from the provided model.
"""
# Missing Finding's API values
finding.muted = False
finding.resource_details = ""
resource = finding.resources.first()
finding.resource_arn = resource.uid
finding.resource_name = resource.name
# TODO: Change this when the API has all the values
finding.resource = {}
finding.resource_id = resource.name if provider.type == "aws" else resource.uid
# AWS-specific field
finding.region = resource.region
# Azure/GCP-specific field
finding.location = resource.region
# Kubernetes-specific field
if provider.type == "kubernetes":
finding.namespace = resource.region.removeprefix("namespace: ")
if provider.type == "azure":
finding.subscription = list(provider.identity.subscriptions.keys())[0]
elif provider.type == "gcp":
finding.project_id = list(provider.projects.keys())[0]
finding.check_metadata = CheckMetadata(
Provider=finding.check_metadata["provider"],
CheckID=finding.check_metadata["checkid"],
CheckTitle=finding.check_metadata["checktitle"],
CheckType=finding.check_metadata["checktype"],
ServiceName=finding.check_metadata["servicename"],
SubServiceName=finding.check_metadata["subservicename"],
Severity=finding.check_metadata["severity"],
ResourceType=finding.check_metadata["resourcetype"],
Description=finding.check_metadata["description"],
Risk=finding.check_metadata["risk"],
RelatedUrl=finding.check_metadata["relatedurl"],
Remediation=Remediation(
Recommendation=Recommendation(
Text=finding.check_metadata["remediation"]["recommendation"][
"text"
],
Url=finding.check_metadata["remediation"]["recommendation"]["url"],
),
Code=Code(
NativeIaC=finding.check_metadata["remediation"]["code"][
"nativeiac"
],
Terraform=finding.check_metadata["remediation"]["code"][
"terraform"
],
CLI=finding.check_metadata["remediation"]["code"]["cli"],
Other=finding.check_metadata["remediation"]["code"]["other"],
),
),
ResourceIdTemplate=finding.check_metadata["resourceidtemplate"],
Categories=finding.check_metadata["categories"],
DependsOn=finding.check_metadata["dependson"],
RelatedTo=finding.check_metadata["relatedto"],
Notes=finding.check_metadata["notes"],
)
finding.resource_tags = unroll_tags(
[{"key": tag.key, "value": tag.value} for tag in resource.tags.all()]
)
return cls.generate_output(provider, finding, SimpleNamespace())
def _transform_findings_stats(scan_summaries: list[dict]) -> dict:
"""
Aggregate and transform scan summary data into findings statistics.
This function processes a list of scan summary objects and calculates overall
metrics such as the total number of passed and failed findings (including muted counts),
as well as a breakdown of results by severity (critical, high, medium, and low).
It also retrieves the unique resource count from the associated scan information.
The final output is a dictionary of aggregated statistics intended for reporting or
further analysis.
Args:
scan_summaries (list[dict]): A list of scan summary objects. Each object is expected
to have attributes including:
- _pass: Number of passed findings.
- fail: Number of failed findings.
- total: Total number of findings.
- muted: Number indicating if the finding is muted.
- severity: A string representing the severity level.
Additionally, the first scan summary should have an associated
`scan` attribute with a `unique_resource_count`.
Returns:
dict: A dictionary containing aggregated findings statistics:
- total_pass: Total number of passed findings.
- total_muted_pass: Total number of muted passed findings.
- total_fail: Total number of failed findings.
- total_muted_fail: Total number of muted failed findings.
- resources_count: The unique resource count extracted from the scan.
- findings_count: Total number of findings.
- total_critical_severity_fail: Failed findings with critical severity.
- total_critical_severity_pass: Passed findings with critical severity.
- total_high_severity_fail: Failed findings with high severity.
- total_high_severity_pass: Passed findings with high severity.
- total_medium_severity_fail: Failed findings with medium severity.
- total_medium_severity_pass: Passed findings with medium severity.
- total_low_severity_fail: Failed findings with low severity.
- total_low_severity_pass: Passed findings with low severity.
- all_fails_are_muted: A boolean indicating whether all failing findings are muted.
"""
# Initialize overall counters
total_pass = 0
total_fail = 0
muted_pass = 0
muted_fail = 0
findings_count = 0
resources_count = scan_summaries[0].scan.unique_resource_count
# Initialize severity breakdown counters
critical_severity_pass = 0
critical_severity_fail = 0
high_severity_pass = 0
high_severity_fail = 0
medium_severity_pass = 0
medium_severity_fail = 0
low_severity_pass = 0
low_severity_fail = 0
# Loop over each row from the database
for row in scan_summaries:
# Accumulate overall totals
total_pass += row._pass
total_fail += row.fail
findings_count += row.total
if row.muted > 0:
if row._pass > 0:
muted_pass += row._pass
if row.fail > 0:
muted_fail += row.fail
sev = row.severity.lower()
if sev == "critical":
critical_severity_pass += row._pass
critical_severity_fail += row.fail
elif sev == "high":
high_severity_pass += row._pass
high_severity_fail += row.fail
elif sev == "medium":
medium_severity_pass += row._pass
medium_severity_fail += row.fail
elif sev == "low":
low_severity_pass += row._pass
low_severity_fail += row.fail
all_fails_are_muted = (total_fail > 0) and (total_fail == muted_fail)
stats = {
"total_pass": total_pass,
"total_muted_pass": muted_pass,
"total_fail": total_fail,
"total_muted_fail": muted_fail,
"resources_count": resources_count,
"findings_count": findings_count,
"total_critical_severity_fail": critical_severity_fail,
"total_critical_severity_pass": critical_severity_pass,
"total_high_severity_fail": high_severity_fail,
"total_high_severity_pass": high_severity_pass,
"total_medium_severity_fail": medium_severity_fail,
"total_medium_severity_pass": medium_severity_pass,
"total_low_severity_fail": low_severity_fail,
"total_low_severity_pass": low_severity_pass,
"all_fails_are_muted": all_fails_are_muted,
}
return stats
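A tiny hedged sketch of the expected input shape, using `SimpleNamespace` rows in place of `ScanSummary` instances (the real caller passes the queryset from `generate_outputs`):

```python
from types import SimpleNamespace

scan = SimpleNamespace(unique_resource_count=3)
rows = [
    SimpleNamespace(_pass=2, fail=1, total=3, muted=0, severity="high", scan=scan),
    SimpleNamespace(_pass=0, fail=2, total=2, muted=2, severity="critical", scan=scan),
]

stats = Finding._transform_findings_stats(rows)
# stats["total_pass"] == 2, stats["total_fail"] == 3,
# stats["total_muted_fail"] == 2, stats["total_critical_severity_fail"] == 2,
# stats["all_fails_are_muted"] is False
```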

View File

@@ -74,12 +74,15 @@ class HTML(Output):
and not self._file_descriptor.closed
and self._data
):
HTML.write_header(self._file_descriptor, provider, stats)
if self._file_descriptor.tell() == 0:
HTML.write_header(
self._file_descriptor, provider, stats, self._from_cli
)
for finding in self._data:
self._file_descriptor.write(finding)
HTML.write_footer(self._file_descriptor)
# Close file descriptor
self._file_descriptor.close()
if self.close_file or self._from_cli:
HTML.write_footer(self._file_descriptor)
self._file_descriptor.close()
except Exception as error:
logger.error(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
@@ -87,7 +90,10 @@ class HTML(Output):
@staticmethod
def write_header(
file_descriptor: TextIOWrapper, provider: Provider, stats: dict
file_descriptor: TextIOWrapper,
provider: Provider,
stats: dict,
from_cli: bool = True,
) -> None:
"""
Writes the header of the HTML file.
@@ -96,6 +102,7 @@ class HTML(Output):
file_descriptor (file): the file descriptor to write the header
provider (Provider): the provider object
stats (dict): the statistics of the findings
from_cli (bool): whether the request is from the CLI or not
"""
try:
file_descriptor.write(
@@ -153,7 +160,7 @@ class HTML(Output):
</div>
</li>
<li class="list-group-item">
<b>Parameters used:</b> {" ".join(sys.argv[1:])}
<b>Parameters used:</b> {" ".join(sys.argv[1:]) if from_cli else ""}
</li>
<li class="list-group-item">
<b>Date:</b> {timestamp.isoformat()}

View File

@@ -199,7 +199,8 @@ class OCSF(Output):
and not self._file_descriptor.closed
and self._data
):
self._file_descriptor.write("[")
if self._file_descriptor.tell() == 0:
self._file_descriptor.write("[")
for finding in self._data:
try:
self._file_descriptor.write(
@@ -210,14 +211,14 @@ class OCSF(Output):
logger.error(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
if self._file_descriptor.tell() > 0:
if self.close_file or self._from_cli:
if self._file_descriptor.tell() != 1:
self._file_descriptor.seek(
self._file_descriptor.tell() - 1, os.SEEK_SET
)
self._file_descriptor.truncate()
self._file_descriptor.write("]")
self._file_descriptor.close()
self._file_descriptor.close()
except Exception as error:
logger.error(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"

View File

@@ -23,7 +23,6 @@ class Output(ABC):
file_descriptor: Property to access the file descriptor.
transform: Abstract method to transform findings into a specific format.
batch_write_data_to_file: Abstract method to write data to a file in batches.
create_file_descriptor: Method to create a file descriptor for writing data to a file.
"""
_data: list
@@ -33,21 +32,27 @@ class Output(ABC):
def __init__(
self,
findings: List[Finding],
create_file_descriptor: bool = False,
file_path: str = None,
file_extension: str = "",
from_cli: bool = True,
) -> None:
self._data = []
self.close_file = False
self.file_path = file_path
self._file_descriptor = None
# This flag avoids a larger refactor: the CLI writes everything at once, while the API writes in batches
self._from_cli = from_cli
if not file_extension and file_path:
self._file_extension = "".join(Path(file_path).suffixes)
if file_extension:
self._file_extension = file_extension
self.file_path = f"{file_path}{self.file_extension}"
if findings:
self.transform(findings)
if create_file_descriptor and file_path:
self.create_file_descriptor(file_path)
if not self._file_descriptor and file_path:
self.create_file_descriptor(self.file_path)
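A hedged sketch (not part of the diff) of the batch-writing flow these flags enable on the API side: the writer is built once with `from_cli=False` so the file descriptor stays open, successive batches are fed through `transform()`, and `close_file` is only set on the last batch so headers/footers are written once and the file is closed once. The iterable, batch size, and file path are placeholders.

```python
from prowler.lib.outputs.csv.csv import CSV
from tasks.utils import batched

transformed_findings: list = []  # placeholder: Finding objects from transform_api_finding

writer = None
for batch, is_last in batched(iter(transformed_findings), 100):  # placeholder batch size
    if writer is None:
        # First batch: transform() runs inside __init__ and the file is opened.
        writer = CSV(
            findings=batch,
            file_path="/tmp/prowler-output-example",  # placeholder path
            file_extension=".csv",
            from_cli=False,   # keep the descriptor open between batches
        )
    else:
        writer.transform(batch)
    writer.close_file = is_last   # header written once, file closed on the last batch
    writer.batch_write_data_to_file()
    writer._data = []             # manual reset, as generate_outputs does
```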
@property
def data(self):

View File

@@ -110,7 +110,7 @@ class S3:
for output in output_list:
try:
# Object is not written to file so we need to temporarily write it
if not hasattr(output, "file_descriptor"):
if not output.file_descriptor:
output.file_descriptor = NamedTemporaryFile(mode="a")
bucket_directory = self.get_object_path(self._output_directory)

View File

@@ -921,7 +921,7 @@ class AzureProvider(Provider):
# since that exception is not considered as critical, we keep filling another identity fields
if sp_env_auth or client_id:
# The id of the sp can be retrieved from environment variables
identity.identity_id = getenv("AZURE_CLIENT_ID")
identity.identity_id = getenv("AZURE_CLIENT_ID", default=client_id)
identity.identity_type = "Service Principal"
# Same here, if user can access AAD, some fields are retrieved if not, default value, for az cli
# should work but it doesn't, pending issue

View File

@@ -577,7 +577,7 @@ class TestASFF:
assert loads(content) == expected_asff
def test_batch_write_data_to_file_without_findings(self):
assert not hasattr(ASFF([]), "_file_descriptor")
assert not ASFF([])._file_descriptor
def test_asff_generate_status(self):
assert ASFF.generate_status("PASS") == "PASSED"

View File

@@ -119,7 +119,7 @@ class TestCSV:
assert content == expected_csv
def test_batch_write_data_to_file_without_findings(self):
assert not hasattr(CSV([]), "_file_descriptor")
assert not CSV([])._file_descriptor
@pytest.fixture
def mock_output_class(self):
@@ -144,9 +144,7 @@ class TestCSV:
file_path = file.name
# Instantiate the mock class
output_instance = mock_output_class(
findings, create_file_descriptor=True, file_path=file_path
)
output_instance = mock_output_class(findings, file_path=file_path)
# Check that transform was called once
output_instance.transform.assert_called_once_with(findings)

View File

@@ -1,4 +1,5 @@
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
@@ -55,6 +56,65 @@ def mock_get_check_compliance(*_):
return {"mock_compliance_key": "mock_compliance_value"}
class DummyTag:
def __init__(self, key, value):
self.key = key
self.value = value
class DummyTags:
def __init__(self, tags):
self._tags = tags
def all(self):
return self._tags
class DummyResource:
def __init__(self, uid, name, resource_arn, region, tags):
self.uid = uid
self.name = name
self.resource_arn = resource_arn
self.region = region
self.tags = DummyTags(tags)
def __iter__(self):
yield "uid", self.uid
yield "name", self.name
yield "region", self.region
yield "tags", self.tags
class DummyResources:
"""Simulate a collection with a first() method."""
def __init__(self, resource):
self._resource = resource
def first(self):
return self._resource
class DummyProvider:
def __init__(self, uid):
self.uid = uid
self.type = "aws"
class DummyScan:
def __init__(self, provider):
self.provider = provider
class DummyAPIFinding:
"""
A dummy API finding model to simulate the database model.
Attributes will be added dynamically.
"""
pass
class TestFinding:
@patch(
"prowler.lib.outputs.finding.get_check_compliance",
@@ -461,3 +521,566 @@ class TestFinding:
# Generate the finding
with pytest.raises(ValidationError):
Finding.generate_output(provider, check_output, output_options)
@patch(
"prowler.lib.outputs.finding.get_check_compliance",
new=mock_get_check_compliance,
)
def test_transform_api_finding_aws(self):
"""
Test that a dummy API Finding is correctly
transformed into a Finding instance.
"""
# Set up the dummy API finding attributes
inserted_at = 1234567890
provider = DummyProvider(uid="account123")
provider.type = "aws"
scan = DummyScan(provider=provider)
# Create a dummy resource with one tag
tag = DummyTag("env", "prod")
resource = DummyResource(
uid="res-uid-1",
name="ResourceName1",
resource_arn="arn",
region="us-east-1",
tags=[tag],
)
resources = DummyResources(resource)
# Create a dummy check_metadata dict with all required fields
check_metadata = {
"provider": "test_provider",
"checkid": "check-001",
"checktitle": "Test Check",
"checktype": ["type1"],
"servicename": "TestService",
"subservicename": "SubService",
"severity": "high",
"resourcetype": "TestResource",
"description": "A test check",
"risk": "High risk",
"relatedurl": "http://example.com",
"remediation": {
"recommendation": {"text": "Fix it", "url": "http://fix.com"},
"code": {
"nativeiac": "iac_code",
"terraform": "terraform_code",
"cli": "cli_code",
"other": "other_code",
},
},
"resourceidtemplate": "template",
"categories": ["cat-one", "cat-two"],
"dependson": ["dep1"],
"relatedto": ["rel1"],
"notes": "Some notes",
}
# Create the dummy API finding and assign required attributes
dummy_finding = DummyAPIFinding()
dummy_finding.inserted_at = inserted_at
dummy_finding.scan = scan
dummy_finding.uid = "finding-uid-1"
dummy_finding.status = "FAIL" # will be converted to Status("FAIL")
dummy_finding.status_extended = "extended"
dummy_finding.check_metadata = check_metadata
dummy_finding.resources = resources
# Call the transform_api_finding classmethod
finding_obj = Finding.transform_api_finding(dummy_finding, provider)
# Check that metadata was built correctly
meta = finding_obj.metadata
assert meta.Provider == "test_provider"
assert meta.CheckID == "check-001"
assert meta.CheckTitle == "Test Check"
assert meta.CheckType == ["type1"]
assert meta.ServiceName == "TestService"
assert meta.SubServiceName == "SubService"
assert meta.Severity == "high"
assert meta.ResourceType == "TestResource"
assert meta.Description == "A test check"
assert meta.Risk == "High risk"
assert meta.RelatedUrl == "http://example.com"
assert meta.Remediation.Recommendation.Text == "Fix it"
assert meta.Remediation.Recommendation.Url == "http://fix.com"
assert meta.Remediation.Code.NativeIaC == "iac_code"
assert meta.Remediation.Code.Terraform == "terraform_code"
assert meta.Remediation.Code.CLI == "cli_code"
assert meta.Remediation.Code.Other == "other_code"
assert meta.ResourceIdTemplate == "template"
assert meta.Categories == ["cat-one", "cat-two"]
assert meta.DependsOn == ["dep1"]
assert meta.RelatedTo == ["rel1"]
assert meta.Notes == "Some notes"
# Check other Finding fields
assert finding_obj.uid == "prowler-aws-check-001--us-east-1-ResourceName1"
assert finding_obj.status == Status("FAIL")
assert finding_obj.status_extended == "extended"
# From the dummy resource
assert finding_obj.resource_uid == "res-uid-1"
assert finding_obj.resource_name == "ResourceName1"
assert finding_obj.resource_details == ""
# unroll_tags is called on a list with one tag -> expect {"env": "prod"}
assert finding_obj.resource_tags == {"env": "prod"}
assert finding_obj.region == "us-east-1"
assert finding_obj.compliance == {
"mock_compliance_key": "mock_compliance_value"
}
@patch(
"prowler.lib.outputs.finding.get_check_compliance",
new=mock_get_check_compliance,
)
def test_transform_api_finding_azure(self):
provider = MagicMock()
provider.type = "azure"
provider.identity.identity_type = "mock_identity_type"
provider.identity.identity_id = "mock_identity_id"
provider.identity.subscriptions = {"default": "default"}
provider.identity.tenant_ids = ["test-ing-432a-a828-d9c965196f87"]
provider.identity.tenant_domain = "mock_tenant_domain"
provider.region_config.name = "AzureCloud"
api_finding = DummyAPIFinding()
api_finding.id = "019514b3-9a66-7cde-921e-9d1ca0531ceb"
api_finding.inserted_at = "2025-02-17 16:17:49"
api_finding.updated_at = "2025-02-17 16:17:49"
api_finding.uid = (
"prowler-azure-defender_auto_provisioning_log_analytics_agent_vms_on-"
"test-ing-4646-bed4-e74f14020726-global-default"
)
api_finding.delta = "new"
api_finding.status = "FAIL"
api_finding.status_extended = "Defender Auto Provisioning Log Analytics Agents from subscription Azure subscription 1 is set to OFF."
api_finding.severity = "medium"
api_finding.impact = "medium"
api_finding.impact_extended = ""
api_finding.raw_result = {}
api_finding.check_id = "defender_auto_provisioning_log_analytics_agent_vms_on"
api_finding.check_metadata = {
"risk": "Missing critical security information about your Azure VMs, such as security alerts, security recommendations, and change tracking.",
"notes": "",
"checkid": "defender_auto_provisioning_log_analytics_agent_vms_on",
"provider": "azure",
"severity": "medium",
"checktype": [],
"dependson": [],
"relatedto": [],
"categories": [],
"checktitle": "Ensure that Auto provisioning of 'Log Analytics agent for Azure VMs' is Set to 'On'",
"compliance": None,
"relatedurl": "https://docs.microsoft.com/en-us/azure/security-center/security-center-data-security",
"description": (
"Ensure that Auto provisioning of 'Log Analytics agent for Azure VMs' is Set to 'On'. "
"The Microsoft Monitoring Agent scans for various security-related configurations and events such as system updates, "
"OS vulnerabilities, endpoint protection, and provides alerts."
),
"remediation": {
"code": {
"cli": "",
"other": "https://www.trendmicro.com/cloudoneconformity-staging/knowledge-base/azure/SecurityCenter/automatic-provisioning-of-monitoring-agent.html",
"nativeiac": "",
"terraform": "",
},
"recommendation": {
"url": "https://learn.microsoft.com/en-us/azure/defender-for-cloud/monitoring-components",
"text": (
"Ensure comprehensive visibility into possible security vulnerabilities, including missing updates, "
"misconfigured operating system security settings, and active threats, allowing for timely mitigation and improved overall security posture"
),
},
},
"servicename": "defender",
"checkaliases": [],
"resourcetype": "AzureDefenderPlan",
"subservicename": "",
"resourceidtemplate": "",
}
api_finding.tags = {}
api_resource = DummyResource(
uid="/subscriptions/test-ing-4646-bed4-e74f14020726/providers/Microsoft.Security/autoProvisioningSettings/default",
name="default",
resource_arn="arn",
region="global",
tags=[],
)
api_finding.resources = DummyResources(api_resource)
api_finding.subscription = "default"
finding_obj = Finding.transform_api_finding(api_finding, provider)
assert finding_obj.account_organization_uid == "test-ing-432a-a828-d9c965196f87"
assert finding_obj.account_organization_name == "mock_tenant_domain"
assert finding_obj.resource_uid == api_resource.uid
assert finding_obj.resource_name == api_resource.name
assert finding_obj.region == api_resource.region
assert finding_obj.resource_tags == {}
assert finding_obj.compliance == {
"mock_compliance_key": "mock_compliance_value"
}
assert finding_obj.status == Status("FAIL")
assert finding_obj.status_extended == (
"Defender Auto Provisioning Log Analytics Agents from subscription Azure subscription 1 is set to OFF."
)
meta = finding_obj.metadata
assert meta.Provider == "azure"
assert meta.CheckID == "defender_auto_provisioning_log_analytics_agent_vms_on"
assert (
meta.CheckTitle
== "Ensure that Auto provisioning of 'Log Analytics agent for Azure VMs' is Set to 'On'"
)
assert meta.Severity == "medium"
assert meta.ResourceType == "AzureDefenderPlan"
assert (
meta.Remediation.Recommendation.Url
== "https://learn.microsoft.com/en-us/azure/defender-for-cloud/monitoring-components"
)
assert meta.Remediation.Recommendation.Text.startswith(
"Ensure comprehensive visibility"
)
expected_segments = [
"prowler-azure",
"defender_auto_provisioning_log_analytics_agent_vms_on",
api_resource.region,
api_resource.name,
]
for segment in expected_segments:
assert segment in finding_obj.uid
@patch(
"prowler.lib.outputs.finding.get_check_compliance",
new=mock_get_check_compliance,
)
def test_transform_api_finding_gcp(self):
provider = MagicMock()
provider.type = "gcp"
provider.identity.profile = "gcp_profile"
dummy_project = MagicMock()
dummy_project.id = "project1"
dummy_project.name = "TestProject"
dummy_project.labels = {"env": "prod"}
dummy_org = MagicMock()
dummy_org.id = "org-123"
dummy_org.display_name = "Test Org"
dummy_project.organization = dummy_org
provider.projects = {"project1": dummy_project}
dummy_finding = DummyAPIFinding()
dummy_finding.inserted_at = "2025-02-17 16:17:49"
dummy_finding.updated_at = "2025-02-17 16:17:49"
dummy_finding.scan = DummyScan(provider=provider)
dummy_finding.uid = "finding-uid-gcp"
dummy_finding.status = "PASS"
dummy_finding.status_extended = "GCP check extended"
check_metadata = {
"provider": "gcp",
"checkid": "gcp-check-001",
"checktitle": "Test GCP Check",
"checktype": [],
"servicename": "TestGCPService",
"subservicename": "",
"severity": "medium",
"resourcetype": "GCPResourceType",
"description": "GCP check description",
"risk": "Medium risk",
"relatedurl": "http://gcp.example.com",
"remediation": {
"code": {
"nativeiac": "iac_code",
"terraform": "terraform_code",
"cli": "cli_code",
"other": "other_code",
},
"recommendation": {"text": "Fix it", "url": "http://fix-gcp.com"},
},
"resourceidtemplate": "template",
"categories": ["cat-one", "cat-two"],
"dependson": ["dep1"],
"relatedto": ["rel1"],
"notes": "Some notes",
}
dummy_finding.check_metadata = check_metadata
dummy_finding.raw_result = {}
dummy_finding.project_id = "project1"
resource = DummyResource(
uid="gcp-resource-uid",
name="gcp-resource-name",
resource_arn="arn",
region="us-central1",
tags=[],
)
dummy_finding.resources = DummyResources(resource)
finding_obj = Finding.transform_api_finding(dummy_finding, provider)
assert finding_obj.auth_method == "Principal: gcp_profile"
assert finding_obj.account_uid == dummy_project.id
assert finding_obj.account_name == dummy_project.name
assert finding_obj.account_tags == dummy_project.labels
assert finding_obj.resource_name == resource.name
assert finding_obj.resource_uid == resource.uid
assert finding_obj.region == resource.region
assert finding_obj.account_organization_uid == dummy_project.organization.id
assert (
finding_obj.account_organization_name
== dummy_project.organization.display_name
)
assert finding_obj.compliance == {
"mock_compliance_key": "mock_compliance_value"
}
assert finding_obj.status == Status("PASS")
assert finding_obj.status_extended == "GCP check extended"
expected_uid = f"prowler-gcp-{check_metadata['checkid']}-{dummy_project.id}-{resource.region}-{resource.name}"
assert finding_obj.uid == expected_uid
@patch(
"prowler.lib.outputs.finding.get_check_compliance",
new=mock_get_check_compliance,
)
def test_transform_api_finding_kubernetes(self):
provider = MagicMock()
provider.type = "kubernetes"
provider.identity.context = "In-Cluster"
provider.identity.cluster = "cluster-1"
api_finding = DummyAPIFinding()
api_finding.inserted_at = 1234567890
api_finding.scan = DummyScan(provider=provider)
api_finding.uid = "finding-uid-k8s"
api_finding.status = "PASS"
api_finding.status_extended = "K8s check extended"
check_metadata = {
"provider": "kubernetes",
"checkid": "k8s-check-001",
"checktitle": "Test K8s Check",
"checktype": [],
"servicename": "TestK8sService",
"subservicename": "",
"severity": "low",
"resourcetype": "K8sResourceType",
"description": "K8s check description",
"risk": "Low risk",
"relatedurl": "http://k8s.example.com",
"remediation": {
"code": {
"nativeiac": "iac_code",
"terraform": "terraform_code",
"cli": "cli_code",
"other": "other_code",
},
"recommendation": {"text": "Fix it", "url": "http://fix-k8s.com"},
},
"resourceidtemplate": "template",
"categories": ["cat-one"],
"dependson": [],
"relatedto": [],
"notes": "K8s notes",
}
api_finding.check_metadata = check_metadata
api_finding.raw_result = {}
api_finding.resource_name = "k8s-resource-name"
api_finding.resource_id = "k8s-resource-uid"
resource = DummyResource(
uid="k8s-resource-uid",
name="k8s-resource-name",
resource_arn="arn",
region="namespace: default",
tags=[],
)
api_finding.resources = DummyResources(resource)
finding_obj = Finding.transform_api_finding(api_finding, provider)
assert finding_obj.auth_method == "in-cluster"
assert finding_obj.resource_name == "k8s-resource-name"
assert finding_obj.resource_uid == "k8s-resource-uid"
assert finding_obj.account_name == "context: In-Cluster"
assert finding_obj.account_uid == "cluster-1"
assert finding_obj.region == "namespace: default"
@patch(
"prowler.lib.outputs.finding.get_check_compliance",
new=mock_get_check_compliance,
)
def test_transform_api_finding_microsoft365(self):
provider = MagicMock()
provider.type = "microsoft365"
provider.identity.identity_type = "ms_identity_type"
provider.identity.identity_id = "ms_identity_id"
provider.identity.tenant_id = "ms-tenant-id"
provider.identity.tenant_domain = "ms-tenant-domain"
dummy_finding = DummyAPIFinding()
dummy_finding.inserted_at = 1234567890
dummy_finding.scan = DummyScan(provider=provider)
dummy_finding.uid = "finding-uid-m365"
dummy_finding.status = "PASS"
dummy_finding.status_extended = "M365 check extended"
check_metadata = {
"provider": "microsoft365",
"checkid": "m365-check-001",
"checktitle": "Test M365 Check",
"checktype": [],
"servicename": "TestM365Service",
"subservicename": "",
"severity": "high",
"resourcetype": "M365ResourceType",
"description": "M365 check description",
"risk": "High risk",
"relatedurl": "http://m365.example.com",
"remediation": {
"code": {
"nativeiac": "iac_code",
"terraform": "terraform_code",
"cli": "cli_code",
"other": "other_code",
},
"recommendation": {"text": "Fix it", "url": "http://fix-m365.com"},
},
"resourceidtemplate": "template",
"categories": ["cat-one"],
"dependson": [],
"relatedto": [],
"notes": "M365 notes",
}
dummy_finding.check_metadata = check_metadata
dummy_finding.raw_result = {}
dummy_finding.resource_name = "ms-resource-name"
dummy_finding.resource_id = "ms-resource-uid"
dummy_finding.location = "global"
resource = DummyResource(
uid="ms-resource-uid",
name="ms-resource-name",
resource_arn="arn",
region="global",
tags=[],
)
dummy_finding.resources = DummyResources(resource)
finding_obj = Finding.transform_api_finding(dummy_finding, provider)
assert finding_obj.auth_method == "ms_identity_type: ms_identity_id"
assert finding_obj.account_uid == "ms-tenant-id"
assert finding_obj.account_name == "ms-tenant-domain"
assert finding_obj.resource_name == "ms-resource-name"
assert finding_obj.resource_uid == "ms-resource-uid"
assert finding_obj.region == "global"
def test_transform_findings_stats_all_fails_muted(self):
"""
Test _transform_findings_stats when every failing finding is muted.
"""
# Create a dummy scan object with a unique_resource_count
dummy_scan = SimpleNamespace(unique_resource_count=10)
# Build summaries covering each severity branch.
ss1 = SimpleNamespace(
_pass=1, fail=2, total=3, muted=2, severity="critical", scan=dummy_scan
)
ss2 = SimpleNamespace(
_pass=2, fail=0, total=2, muted=0, severity="high", scan=dummy_scan
)
ss3 = SimpleNamespace(
_pass=2, fail=3, total=5, muted=3, severity="medium", scan=dummy_scan
)
ss4 = SimpleNamespace(
_pass=3, fail=0, total=3, muted=0, severity="low", scan=dummy_scan
)
summaries = [ss1, ss2, ss3, ss4]
stats = Finding._transform_findings_stats(summaries)
# Expected calculations:
# total_pass = 1+2+2+3 = 8
# total_fail = 2+0+3+0 = 5
# findings_count = 3+2+5+3 = 13
# muted_pass = (ss1: 1) + (ss3: 2) = 3
# muted_fail = (ss1: 2) + (ss3: 3) = 5
expected = {
"total_pass": 8,
"total_muted_pass": 3,
"total_fail": 5,
"total_muted_fail": 5,
"resources_count": 10,
"findings_count": 13,
"total_critical_severity_fail": 2,
"total_critical_severity_pass": 1,
"total_high_severity_fail": 0,
"total_high_severity_pass": 2,
"total_medium_severity_fail": 3,
"total_medium_severity_pass": 2,
"total_low_severity_fail": 0,
"total_low_severity_pass": 3,
"all_fails_are_muted": True, # total_fail equals muted_fail and total_fail > 0
}
assert stats == expected
def test_transform_findings_stats_not_all_fails_muted(self):
"""
Test _transform_findings_stats when at least one failing finding is not muted.
"""
dummy_scan = SimpleNamespace(unique_resource_count=5)
# Build summaries: one summary has fail > 0 but muted == 0
ss1 = SimpleNamespace(
_pass=1, fail=2, total=3, muted=0, severity="critical", scan=dummy_scan
)
ss2 = SimpleNamespace(
_pass=2, fail=1, total=3, muted=1, severity="high", scan=dummy_scan
)
summaries = [ss1, ss2]
stats = Finding._transform_findings_stats(summaries)
# Expected calculations:
# total_pass = 1+2 = 3
# total_fail = 2+1 = 3
# findings_count = 3+3 = 6
# muted_pass = (ss2: 2) since ss1 muted is 0
# muted_fail = (ss2: 1)
# Severity breakdown: critical: pass 1, fail 2; high: pass 2, fail 1
expected = {
"total_pass": 3,
"total_muted_pass": 2,
"total_fail": 3,
"total_muted_fail": 1,
"resources_count": 5,
"findings_count": 6,
"total_critical_severity_fail": 2,
"total_critical_severity_pass": 1,
"total_high_severity_fail": 1,
"total_high_severity_pass": 2,
"total_medium_severity_fail": 0,
"total_medium_severity_pass": 0,
"total_low_severity_fail": 0,
"total_low_severity_pass": 0,
"all_fails_are_muted": False, # 3 (total_fail) != 1 (muted_fail)
}
assert stats == expected
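# For reference, a minimal sketch of the aggregation the two stats tests above
# exercise. It is illustrative only: the function name is assumed for the
# example and it only reflects the behaviour implied by the expected dicts,
# not the actual Finding._transform_findings_stats implementation.
def _findings_stats_sketch(summaries):
    stats = {
        "total_pass": 0,
        "total_muted_pass": 0,
        "total_fail": 0,
        "total_muted_fail": 0,
        "findings_count": 0,
    }
    for severity in ("critical", "high", "medium", "low"):
        stats[f"total_{severity}_severity_fail"] = 0
        stats[f"total_{severity}_severity_pass"] = 0
    for summary in summaries:
        stats["total_pass"] += summary._pass
        stats["total_fail"] += summary.fail
        stats["findings_count"] += summary.total
        # A summary that contains muted findings contributes its pass/fail
        # counts to the muted totals (see the expected values above).
        if summary.muted > 0:
            stats["total_muted_pass"] += summary._pass
            stats["total_muted_fail"] += summary.fail
        stats[f"total_{summary.severity}_severity_pass"] += summary._pass
        stats[f"total_{summary.severity}_severity_fail"] += summary.fail
    stats["resources_count"] = summaries[0].scan.unique_resource_count
    # Only true when there is at least one failure and every failure is muted.
    stats["all_fails_are_muted"] = (
        stats["total_fail"] > 0
        and stats["total_fail"] == stats["total_muted_fail"]
    )
    return stats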
def test_transform_api_finding_validation_error(self):
"""
Test that if required data is missing the error is re-raised to the caller.
Here the empty check_metadata dict makes the metadata lookup fail, so the
transformation raises a KeyError.
"""
provider = DummyProvider(uid="account123")
# Create a dummy API finding that is missing some required metadata
dummy_finding = DummyAPIFinding()
dummy_finding.inserted_at = 1234567890
dummy_finding.scan = DummyScan(provider=provider)
dummy_finding.uid = "finding-uid-invalid"
dummy_finding.status = "PASS"
dummy_finding.status_extended = "extended"
# Use an empty dict so the required metadata keys are missing
dummy_finding.check_metadata = {}
# Provide a dummy resources container with a minimal resource
tag = DummyTag("env", "prod")
resource = DummyResource(
uid="res-uid-1",
name="ResourceName1",
resource_arn="arn",
region="us-east-1",
tags=[tag],
)
dummy_finding.resources = DummyResources(resource)
with pytest.raises(KeyError):
Finding.transform_api_finding(dummy_finding, provider)
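# Illustrative usage only (not part of the test module; the names below are
# assumed for the example): the pattern these tests exercise is one output
# Finding per stored API finding, e.g.
#
#     output_findings = [
#         Finding.transform_api_finding(api_finding, provider)
#         for api_finding in api_findings
#     ]
#
# where `api_findings` and `provider` would come from the API's database layer
# and provider resolution, respectively.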

View File

@@ -492,7 +492,7 @@ class TestHTML:
assert content == get_aws_html_header(args) + pass_html_finding + html_footer
def test_batch_write_data_to_file_without_findings(self):
assert not hasattr(HTML([]), "_file_descriptor")
assert not HTML([])._file_descriptor
def test_write_header(self):
mock_file = StringIO()

View File

@@ -256,7 +256,7 @@ class TestOCSF:
assert json.loads(content) == expected_json_output
def test_batch_write_data_to_file_without_findings(self):
assert not hasattr(OCSF([]), "_file_descriptor")
assert not OCSF([])._file_descriptor
def test_finding_output_cloud_pass_low_muted(self):
finding_output = generate_finding_output(

View File

@@ -113,7 +113,6 @@ class TestS3:
csv_file = f"test{extension}"
csv = CSV(
findings=[FINDING],
create_file_descriptor=True,
file_path=f"{CURRENT_DIRECTORY}/{csv_file}",
)