mirror of
https://github.com/prowler-cloud/prowler.git
synced 2025-12-19 05:17:47 +00:00
feat(export): add API export system (#6878)
This commit is contained in:
committed by
GitHub
parent
c4528200b0
commit
669ec74e67
24
.env
24
.env
@@ -30,6 +30,30 @@ VALKEY_HOST=valkey
|
||||
VALKEY_PORT=6379
|
||||
VALKEY_DB=0
|
||||
|
||||
# API scan settings
|
||||
|
||||
# The path to the directory where scan output should be stored
|
||||
DJANGO_TMP_OUTPUT_DIRECTORY = "/tmp/prowler_api_output"
|
||||
|
||||
# The maximum number of findings to process in a single batch
|
||||
DJANGO_FINDINGS_BATCH_SIZE = 1000
|
||||
|
||||
# The AWS access key to be used when uploading scan output to an S3 bucket
|
||||
# If left empty, default AWS credentials resolution behavior will be used
|
||||
DJANGO_OUTPUT_S3_AWS_ACCESS_KEY_ID=""
|
||||
|
||||
# The AWS secret key to be used when uploading scan output to an S3 bucket
|
||||
DJANGO_OUTPUT_S3_AWS_SECRET_ACCESS_KEY=""
|
||||
|
||||
# An optional AWS session token
|
||||
DJANGO_OUTPUT_S3_AWS_SESSION_TOKEN=""
|
||||
|
||||
# The AWS region where your S3 bucket is located (e.g., "us-east-1")
|
||||
DJANGO_OUTPUT_S3_AWS_DEFAULT_REGION=""
|
||||
|
||||
# The name of the S3 bucket where scan output should be stored
|
||||
DJANGO_OUTPUT_S3_AWS_OUTPUT_BUCKET=""
|
||||
|
||||
# Django settings
|
||||
DJANGO_ALLOWED_HOSTS=localhost,127.0.0.1,prowler-api
|
||||
DJANGO_BIND_ADDRESS=0.0.0.0
|
||||
|
||||
@@ -8,6 +8,7 @@ All notable changes to the **Prowler API** are documented in this file.
|
||||
|
||||
### Added
|
||||
- Social login integration with Google and GitHub [(#6906)](https://github.com/prowler-cloud/prowler/pull/6906)
|
||||
- Add API scan report system, now all scans launched from the API will generate a compressed file with the report in OCSF, CSV and HTML formats [(#6878)](https://github.com/prowler-cloud/prowler/pull/6878).
|
||||
- Configurable Sentry integration [(#6874)](https://github.com/prowler-cloud/prowler/pull/6874)
|
||||
|
||||
### Changed
|
||||
|
||||
@@ -28,7 +28,7 @@ start_prod_server() {
|
||||
|
||||
start_worker() {
|
||||
echo "Starting the worker..."
|
||||
poetry run python -m celery -A config.celery worker -l "${DJANGO_LOGGING_LEVEL:-info}" -Q celery,scans,deletion -E --max-tasks-per-child 1
|
||||
poetry run python -m celery -A config.celery worker -l "${DJANGO_LOGGING_LEVEL:-info}" -Q celery,scans,scan-reports,deletion -E --max-tasks-per-child 1
|
||||
}
|
||||
|
||||
start_worker_beat() {
|
||||
|
||||
@@ -7,7 +7,7 @@ from rest_framework_json_api.serializers import ValidationError
|
||||
from api.db_utils import POSTGRES_TENANT_VAR, SET_CONFIG_QUERY
|
||||
|
||||
|
||||
def set_tenant(func):
|
||||
def set_tenant(func=None, *, keep_tenant=False):
|
||||
"""
|
||||
Decorator to set the tenant context for a Celery task based on the provided tenant_id.
|
||||
|
||||
@@ -40,20 +40,29 @@ def set_tenant(func):
|
||||
# The tenant context will be set before the task logic executes.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
@transaction.atomic
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
tenant_id = kwargs.pop("tenant_id")
|
||||
except KeyError:
|
||||
raise KeyError("This task requires the tenant_id")
|
||||
try:
|
||||
uuid.UUID(tenant_id)
|
||||
except ValueError:
|
||||
raise ValidationError("Tenant ID must be a valid UUID")
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(SET_CONFIG_QUERY, [POSTGRES_TENANT_VAR, tenant_id])
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
@transaction.atomic
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
if not keep_tenant:
|
||||
tenant_id = kwargs.pop("tenant_id")
|
||||
else:
|
||||
tenant_id = kwargs["tenant_id"]
|
||||
except KeyError:
|
||||
raise KeyError("This task requires the tenant_id")
|
||||
try:
|
||||
uuid.UUID(tenant_id)
|
||||
except ValueError:
|
||||
raise ValidationError("Tenant ID must be a valid UUID")
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(SET_CONFIG_QUERY, [POSTGRES_TENANT_VAR, tenant_id])
|
||||
|
||||
return func(*args, **kwargs)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
return wrapper
|
||||
|
||||
if func is None:
|
||||
return decorator
|
||||
else:
|
||||
return decorator(func)
|
||||
|
||||
15
api/src/backend/api/migrations/0012_scan_report_output.py
Normal file
15
api/src/backend/api/migrations/0012_scan_report_output.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("api", "0011_findings_performance_indexes_parent"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="scan",
|
||||
name="output_location",
|
||||
field=models.CharField(blank=True, max_length=200, null=True),
|
||||
),
|
||||
]
|
||||
@@ -414,6 +414,7 @@ class Scan(RowLevelSecurityProtectedModel):
|
||||
scheduler_task = models.ForeignKey(
|
||||
PeriodicTask, on_delete=models.CASCADE, null=True, blank=True
|
||||
)
|
||||
output_location = models.CharField(blank=True, null=True, max_length=200)
|
||||
# TODO: mutelist foreign key
|
||||
|
||||
class Meta(RowLevelSecurityProtectedModel.Meta):
|
||||
|
||||
@@ -4105,6 +4105,43 @@ paths:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ScanUpdateResponse'
|
||||
description: ''
|
||||
/api/v1/scans/{id}/report:
|
||||
get:
|
||||
operationId: scans_report_retrieve
|
||||
description: Returns a ZIP file containing the requested report
|
||||
summary: Download ZIP report
|
||||
parameters:
|
||||
- in: query
|
||||
name: fields[scan-reports]
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
enum:
|
||||
- id
|
||||
description: endpoint return only specific fields in the response on a per-type
|
||||
basis by including a fields[TYPE] query parameter.
|
||||
explode: false
|
||||
- in: path
|
||||
name: id
|
||||
schema:
|
||||
type: string
|
||||
format: uuid
|
||||
description: A UUID string identifying this scan.
|
||||
required: true
|
||||
tags:
|
||||
- Scan
|
||||
security:
|
||||
- jwtAuth: []
|
||||
responses:
|
||||
'200':
|
||||
description: Report obtained successfully
|
||||
'202':
|
||||
description: The task is in progress
|
||||
'403':
|
||||
description: There is a problem with credentials
|
||||
'404':
|
||||
description: The scan has no reports
|
||||
/api/v1/schedules/daily:
|
||||
post:
|
||||
operationId: schedules_daily_create
|
||||
|
||||
@@ -274,9 +274,10 @@ class TestValidateInvitation:
|
||||
expired_time = datetime.now(timezone.utc) - timedelta(days=1)
|
||||
invitation.expires_at = expired_time
|
||||
|
||||
with patch("api.utils.Invitation.objects.using") as mock_using, patch(
|
||||
"api.utils.datetime"
|
||||
) as mock_datetime:
|
||||
with (
|
||||
patch("api.utils.Invitation.objects.using") as mock_using,
|
||||
patch("api.utils.datetime") as mock_datetime,
|
||||
):
|
||||
mock_db = mock_using.return_value
|
||||
mock_db.get.return_value = invitation
|
||||
mock_datetime.now.return_value = datetime.now(timezone.utc)
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
import glob
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import ANY, Mock, patch
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
from botocore.exceptions import NoCredentialsError
|
||||
from conftest import API_JSON_CONTENT_TYPE, TEST_PASSWORD, TEST_USER
|
||||
from django.conf import settings
|
||||
from django.urls import reverse
|
||||
@@ -20,6 +24,7 @@ from api.models import (
|
||||
RoleProviderGroupRelationship,
|
||||
Scan,
|
||||
StateChoices,
|
||||
Task,
|
||||
User,
|
||||
UserRoleRelationship,
|
||||
)
|
||||
@@ -2079,9 +2084,9 @@ class TestScanViewSet:
|
||||
("started_at.gte", "2024-01-01", 3),
|
||||
("started_at.lte", "2024-01-01", 0),
|
||||
("trigger", Scan.TriggerChoices.MANUAL, 1),
|
||||
("state", StateChoices.AVAILABLE, 2),
|
||||
("state", StateChoices.AVAILABLE, 1),
|
||||
("state", StateChoices.FAILED, 1),
|
||||
("state.in", f"{StateChoices.FAILED},{StateChoices.AVAILABLE}", 3),
|
||||
("state.in", f"{StateChoices.FAILED},{StateChoices.AVAILABLE}", 2),
|
||||
("trigger", Scan.TriggerChoices.MANUAL, 1),
|
||||
]
|
||||
),
|
||||
@@ -2156,6 +2161,159 @@ class TestScanViewSet:
|
||||
response = authenticated_client.get(reverse("scan-list"), {"sort": "invalid"})
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
|
||||
def test_report_executing(self, authenticated_client, scans_fixture):
|
||||
"""
|
||||
When the scan is still executing (state == EXECUTING), the view should return
|
||||
the task data with HTTP 202 and a Content-Location header.
|
||||
"""
|
||||
scan = scans_fixture[0]
|
||||
scan.state = StateChoices.EXECUTING
|
||||
scan.save()
|
||||
|
||||
task = Task.objects.create(tenant_id=scan.tenant_id)
|
||||
dummy_task_data = {"id": str(task.id), "state": StateChoices.EXECUTING}
|
||||
|
||||
scan.task = task
|
||||
scan.save()
|
||||
|
||||
with patch(
|
||||
"api.v1.views.TaskSerializer",
|
||||
return_value=type("DummySerializer", (), {"data": dummy_task_data}),
|
||||
):
|
||||
url = reverse("scan-report", kwargs={"pk": scan.id})
|
||||
response = authenticated_client.get(url)
|
||||
assert response.status_code == status.HTTP_202_ACCEPTED
|
||||
assert "Content-Location" in response
|
||||
assert dummy_task_data["id"] in response["Content-Location"]
|
||||
|
||||
def test_report_celery_task_executing(self, authenticated_client, scans_fixture):
|
||||
"""
|
||||
When the scan is not executing but a related celery task exists and is running,
|
||||
the view should return that task data with HTTP 202.
|
||||
"""
|
||||
scan = scans_fixture[0]
|
||||
scan.state = StateChoices.COMPLETED
|
||||
scan.output_location = "dummy"
|
||||
scan.save()
|
||||
|
||||
dummy_task = Task.objects.create(tenant_id=scan.tenant_id)
|
||||
dummy_task.id = "dummy-task-id"
|
||||
dummy_task_data = {"id": dummy_task.id, "state": StateChoices.EXECUTING}
|
||||
|
||||
with patch("api.v1.views.Task.objects.get", return_value=dummy_task), patch(
|
||||
"api.v1.views.TaskSerializer",
|
||||
return_value=type("DummySerializer", (), {"data": dummy_task_data}),
|
||||
):
|
||||
url = reverse("scan-report", kwargs={"pk": scan.id})
|
||||
response = authenticated_client.get(url)
|
||||
assert response.status_code == status.HTTP_202_ACCEPTED
|
||||
assert "Content-Location" in response
|
||||
assert dummy_task_data["id"] in response["Content-Location"]
|
||||
|
||||
def test_report_no_output_location(self, authenticated_client, scans_fixture):
|
||||
"""
|
||||
If the scan does not have an output_location, the view should return a 404.
|
||||
"""
|
||||
scan = scans_fixture[0]
|
||||
scan.state = StateChoices.COMPLETED
|
||||
scan.output_location = ""
|
||||
scan.save()
|
||||
|
||||
url = reverse("scan-report", kwargs={"pk": scan.id})
|
||||
response = authenticated_client.get(url)
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert response.json()["errors"]["detail"] == "The scan has no reports."
|
||||
|
||||
def test_report_s3_no_credentials(
|
||||
self, authenticated_client, scans_fixture, monkeypatch
|
||||
):
|
||||
"""
|
||||
When output_location is an S3 URL and get_s3_client() raises a credentials exception,
|
||||
the view should return HTTP 403 with the proper error message.
|
||||
"""
|
||||
scan = scans_fixture[0]
|
||||
bucket = "test-bucket"
|
||||
key = "report.zip"
|
||||
scan.output_location = f"s3://{bucket}/{key}"
|
||||
scan.state = StateChoices.COMPLETED
|
||||
scan.save()
|
||||
|
||||
def fake_get_s3_client():
|
||||
raise NoCredentialsError()
|
||||
|
||||
monkeypatch.setattr("api.v1.views.get_s3_client", fake_get_s3_client)
|
||||
|
||||
url = reverse("scan-report", kwargs={"pk": scan.id})
|
||||
response = authenticated_client.get(url)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert (
|
||||
response.json()["errors"]["detail"]
|
||||
== "There is a problem with credentials."
|
||||
)
|
||||
|
||||
def test_report_s3_success(self, authenticated_client, scans_fixture, monkeypatch):
|
||||
"""
|
||||
When output_location is an S3 URL and the S3 client returns the file successfully,
|
||||
the view should return the ZIP file with HTTP 200 and proper headers.
|
||||
"""
|
||||
scan = scans_fixture[0]
|
||||
bucket = "test-bucket"
|
||||
key = "report.zip"
|
||||
scan.output_location = f"s3://{bucket}/{key}"
|
||||
scan.state = StateChoices.COMPLETED
|
||||
scan.save()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"api.v1.views.env", type("env", (), {"str": lambda self, key: bucket})()
|
||||
)
|
||||
|
||||
class FakeS3Client:
|
||||
def get_object(self, Bucket, Key):
|
||||
assert Bucket == bucket
|
||||
assert Key == key
|
||||
return {"Body": io.BytesIO(b"s3 zip content")}
|
||||
|
||||
monkeypatch.setattr("api.v1.views.get_s3_client", lambda: FakeS3Client())
|
||||
|
||||
url = reverse("scan-report", kwargs={"pk": scan.id})
|
||||
response = authenticated_client.get(url)
|
||||
assert response.status_code == 200
|
||||
expected_filename = os.path.basename("report.zip")
|
||||
content_disposition = response.get("Content-Disposition")
|
||||
assert content_disposition.startswith('attachment; filename="')
|
||||
assert f'filename="{expected_filename}"' in content_disposition
|
||||
assert response.content == b"s3 zip content"
|
||||
|
||||
def test_report_local_file(
|
||||
self, authenticated_client, scans_fixture, tmp_path, monkeypatch
|
||||
):
|
||||
"""
|
||||
When output_location is a local file path, the view should read the file from disk
|
||||
and return it with proper headers.
|
||||
"""
|
||||
scan = scans_fixture[0]
|
||||
file_content = b"local zip file content"
|
||||
file_path = tmp_path / "report.zip"
|
||||
file_path.write_bytes(file_content)
|
||||
|
||||
scan.output_location = str(file_path)
|
||||
scan.state = StateChoices.COMPLETED
|
||||
scan.save()
|
||||
|
||||
monkeypatch.setattr(
|
||||
glob,
|
||||
"glob",
|
||||
lambda pattern: [str(file_path)] if pattern == str(file_path) else [],
|
||||
)
|
||||
|
||||
url = reverse("scan-report", kwargs={"pk": scan.id})
|
||||
response = authenticated_client.get(url)
|
||||
assert response.status_code == 200
|
||||
assert response.content == file_content
|
||||
content_disposition = response.get("Content-Disposition")
|
||||
assert content_disposition.startswith('attachment; filename="')
|
||||
assert f'filename="{file_path.name}"' in content_disposition
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestTaskViewSet:
|
||||
|
||||
@@ -939,6 +939,14 @@ class ScanTaskSerializer(RLSSerializer):
|
||||
]
|
||||
|
||||
|
||||
class ScanReportSerializer(serializers.Serializer):
|
||||
id = serializers.CharField(source="scan")
|
||||
|
||||
class Meta:
|
||||
resource_name = "scan-reports"
|
||||
fields = ["id"]
|
||||
|
||||
|
||||
class ResourceTagSerializer(RLSSerializer):
|
||||
"""
|
||||
Serializer for the ResourceTag model
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
import glob
|
||||
import os
|
||||
|
||||
from allauth.socialaccount.providers.github.views import GitHubOAuth2Adapter
|
||||
from allauth.socialaccount.providers.google.views import GoogleOAuth2Adapter
|
||||
from botocore.exceptions import ClientError, NoCredentialsError, ParamValidationError
|
||||
from celery.result import AsyncResult
|
||||
from config.env import env
|
||||
from config.settings.social_login import (
|
||||
GITHUB_OAUTH_CALLBACK_URL,
|
||||
GOOGLE_OAUTH_CALLBACK_URL,
|
||||
@@ -12,6 +17,7 @@ from django.contrib.postgres.search import SearchQuery
|
||||
from django.db import transaction
|
||||
from django.db.models import Count, Exists, F, OuterRef, Prefetch, Q, Subquery, Sum
|
||||
from django.db.models.functions import Coalesce
|
||||
from django.http import HttpResponse
|
||||
from django.urls import reverse
|
||||
from django.utils.decorators import method_decorator
|
||||
from django.views.decorators.cache import cache_control
|
||||
@@ -38,11 +44,11 @@ from rest_framework.permissions import SAFE_METHODS
|
||||
from rest_framework_json_api.views import RelationshipView, Response
|
||||
from rest_framework_simplejwt.exceptions import InvalidToken, TokenError
|
||||
from tasks.beat import schedule_provider_scan
|
||||
from tasks.jobs.export import get_s3_client
|
||||
from tasks.tasks import (
|
||||
check_provider_connection_task,
|
||||
delete_provider_task,
|
||||
delete_tenant_task,
|
||||
perform_scan_summary_task,
|
||||
perform_scan_task,
|
||||
)
|
||||
|
||||
@@ -121,6 +127,7 @@ from api.v1.serializers import (
|
||||
RoleSerializer,
|
||||
RoleUpdateSerializer,
|
||||
ScanCreateSerializer,
|
||||
ScanReportSerializer,
|
||||
ScanSerializer,
|
||||
ScanUpdateSerializer,
|
||||
ScheduleDailyCreateSerializer,
|
||||
@@ -1116,6 +1123,18 @@ class ProviderViewSet(BaseRLSViewSet):
|
||||
request=ScanCreateSerializer,
|
||||
responses={202: OpenApiResponse(response=TaskSerializer)},
|
||||
),
|
||||
report=extend_schema(
|
||||
tags=["Scan"],
|
||||
summary="Download ZIP report",
|
||||
description="Returns a ZIP file containing the requested report",
|
||||
request=ScanReportSerializer,
|
||||
responses={
|
||||
200: OpenApiResponse(description="Report obtained successfully"),
|
||||
202: OpenApiResponse(description="The task is in progress"),
|
||||
403: OpenApiResponse(description="There is a problem with credentials"),
|
||||
404: OpenApiResponse(description="The scan has no reports"),
|
||||
},
|
||||
),
|
||||
)
|
||||
@method_decorator(CACHE_DECORATOR, name="list")
|
||||
@method_decorator(CACHE_DECORATOR, name="retrieve")
|
||||
@@ -1164,6 +1183,10 @@ class ScanViewSet(BaseRLSViewSet):
|
||||
return ScanCreateSerializer
|
||||
elif self.action == "partial_update":
|
||||
return ScanUpdateSerializer
|
||||
elif self.action == "report":
|
||||
if hasattr(self, "response_serializer_class"):
|
||||
return self.response_serializer_class
|
||||
return ScanReportSerializer
|
||||
return super().get_serializer_class()
|
||||
|
||||
def partial_update(self, request, *args, **kwargs):
|
||||
@@ -1181,6 +1204,93 @@ class ScanViewSet(BaseRLSViewSet):
|
||||
)
|
||||
return Response(data=read_serializer.data, status=status.HTTP_200_OK)
|
||||
|
||||
@action(detail=True, methods=["get"], url_name="report")
|
||||
def report(self, request, pk=None):
|
||||
scan_instance = self.get_object()
|
||||
|
||||
if scan_instance.state == StateChoices.EXECUTING:
|
||||
# If the scan is still running, return the task
|
||||
prowler_task = Task.objects.get(id=scan_instance.task.id)
|
||||
self.response_serializer_class = TaskSerializer
|
||||
output_serializer = self.get_serializer(prowler_task)
|
||||
return Response(
|
||||
data=output_serializer.data,
|
||||
status=status.HTTP_202_ACCEPTED,
|
||||
headers={
|
||||
"Content-Location": reverse(
|
||||
"task-detail", kwargs={"pk": output_serializer.data["id"]}
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
output_celery_task = Task.objects.get(
|
||||
task_runner_task__task_name="scan-report",
|
||||
task_runner_task__task_args__contains=pk,
|
||||
)
|
||||
self.response_serializer_class = TaskSerializer
|
||||
output_serializer = self.get_serializer(output_celery_task)
|
||||
if output_serializer.data["state"] == StateChoices.EXECUTING:
|
||||
# If the task is still running, return the task
|
||||
return Response(
|
||||
data=output_serializer.data,
|
||||
status=status.HTTP_202_ACCEPTED,
|
||||
headers={
|
||||
"Content-Location": reverse(
|
||||
"task-detail", kwargs={"pk": output_serializer.data["id"]}
|
||||
)
|
||||
},
|
||||
)
|
||||
except Task.DoesNotExist:
|
||||
# If the task does not exist, it means that the task is removed from the database
|
||||
pass
|
||||
|
||||
output_location = scan_instance.output_location
|
||||
if not output_location:
|
||||
return Response(
|
||||
{"detail": "The scan has no reports."},
|
||||
status=status.HTTP_404_NOT_FOUND,
|
||||
)
|
||||
|
||||
if scan_instance.output_location.startswith("s3://"):
|
||||
try:
|
||||
s3_client = get_s3_client()
|
||||
except (ClientError, NoCredentialsError, ParamValidationError):
|
||||
return Response(
|
||||
{"detail": "There is a problem with credentials."},
|
||||
status=status.HTTP_403_FORBIDDEN,
|
||||
)
|
||||
|
||||
bucket_name = env.str("DJANGO_OUTPUT_S3_AWS_OUTPUT_BUCKET")
|
||||
key = output_location[len(f"s3://{bucket_name}/") :]
|
||||
try:
|
||||
s3_object = s3_client.get_object(Bucket=bucket_name, Key=key)
|
||||
except ClientError as e:
|
||||
error_code = e.response.get("Error", {}).get("Code")
|
||||
if error_code == "NoSuchKey":
|
||||
return Response(
|
||||
{"detail": "The scan has no reports."},
|
||||
status=status.HTTP_404_NOT_FOUND,
|
||||
)
|
||||
return Response(
|
||||
{"detail": "There is a problem with credentials."},
|
||||
status=status.HTTP_403_FORBIDDEN,
|
||||
)
|
||||
file_content = s3_object["Body"].read()
|
||||
filename = os.path.basename(output_location.split("/")[-1])
|
||||
else:
|
||||
zip_files = glob.glob(output_location)
|
||||
file_path = zip_files[0]
|
||||
with open(file_path, "rb") as f:
|
||||
file_content = f.read()
|
||||
filename = os.path.basename(file_path)
|
||||
|
||||
response = HttpResponse(
|
||||
file_content, content_type="application/x-zip-compressed"
|
||||
)
|
||||
response["Content-Disposition"] = f'attachment; filename="{filename}"'
|
||||
return response
|
||||
|
||||
def create(self, request, *args, **kwargs):
|
||||
input_serializer = self.get_serializer(data=request.data)
|
||||
input_serializer.is_valid(raise_exception=True)
|
||||
@@ -1195,10 +1305,6 @@ class ScanViewSet(BaseRLSViewSet):
|
||||
# Disabled for now
|
||||
# checks_to_execute=scan.scanner_args.get("checks_to_execute"),
|
||||
},
|
||||
link=perform_scan_summary_task.si(
|
||||
tenant_id=self.request.tenant_id,
|
||||
scan_id=str(scan.id),
|
||||
),
|
||||
)
|
||||
|
||||
scan.task_id = task.id
|
||||
|
||||
@@ -221,3 +221,18 @@ CACHE_STALE_WHILE_REVALIDATE = env.int("DJANGO_STALE_WHILE_REVALIDATE", 60)
|
||||
TESTING = False
|
||||
|
||||
FINDINGS_MAX_DAYS_IN_RANGE = env.int("DJANGO_FINDINGS_MAX_DAYS_IN_RANGE", 7)
|
||||
|
||||
|
||||
# API export settings
|
||||
DJANGO_TMP_OUTPUT_DIRECTORY = env.str(
|
||||
"DJANGO_TMP_OUTPUT_DIRECTORY", "/tmp/prowler_api_output"
|
||||
)
|
||||
DJANGO_FINDINGS_BATCH_SIZE = env.str("DJANGO_FINDINGS_BATCH_SIZE", 1000)
|
||||
|
||||
DJANGO_OUTPUT_S3_AWS_OUTPUT_BUCKET = env.str("DJANGO_OUTPUT_S3_AWS_OUTPUT_BUCKET", "")
|
||||
DJANGO_OUTPUT_S3_AWS_ACCESS_KEY_ID = env.str("DJANGO_OUTPUT_S3_AWS_ACCESS_KEY_ID", "")
|
||||
DJANGO_OUTPUT_S3_AWS_SECRET_ACCESS_KEY = env.str(
|
||||
"DJANGO_OUTPUT_S3_AWS_SECRET_ACCESS_KEY", ""
|
||||
)
|
||||
DJANGO_OUTPUT_S3_AWS_SESSION_TOKEN = env.str("DJANGO_OUTPUT_S3_AWS_SESSION_TOKEN", "")
|
||||
DJANGO_OUTPUT_S3_AWS_DEFAULT_REGION = env.str("DJANGO_OUTPUT_S3_AWS_DEFAULT_REGION", "")
|
||||
|
||||
@@ -486,7 +486,7 @@ def scans_fixture(tenants_fixture, providers_fixture):
|
||||
name="Scan 1",
|
||||
provider=provider,
|
||||
trigger=Scan.TriggerChoices.MANUAL,
|
||||
state=StateChoices.AVAILABLE,
|
||||
state=StateChoices.COMPLETED,
|
||||
tenant_id=tenant.id,
|
||||
started_at="2024-01-02T00:00:00Z",
|
||||
)
|
||||
|
||||
156
api/src/backend/tasks/jobs/export.py
Normal file
156
api/src/backend/tasks/jobs/export.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import os
|
||||
import zipfile
|
||||
|
||||
import boto3
|
||||
import config.django.base as base
|
||||
from botocore.exceptions import ClientError, NoCredentialsError, ParamValidationError
|
||||
from celery.utils.log import get_task_logger
|
||||
from django.conf import settings
|
||||
|
||||
from prowler.config.config import (
|
||||
csv_file_suffix,
|
||||
html_file_suffix,
|
||||
json_ocsf_file_suffix,
|
||||
output_file_timestamp,
|
||||
)
|
||||
from prowler.lib.outputs.csv.csv import CSV
|
||||
from prowler.lib.outputs.html.html import HTML
|
||||
from prowler.lib.outputs.ocsf.ocsf import OCSF
|
||||
|
||||
logger = get_task_logger(__name__)
|
||||
|
||||
|
||||
# Predefined mapping for output formats and their configurations
|
||||
OUTPUT_FORMATS_MAPPING = {
|
||||
"csv": {
|
||||
"class": CSV,
|
||||
"suffix": csv_file_suffix,
|
||||
"kwargs": {},
|
||||
},
|
||||
"json-ocsf": {"class": OCSF, "suffix": json_ocsf_file_suffix, "kwargs": {}},
|
||||
"html": {"class": HTML, "suffix": html_file_suffix, "kwargs": {"stats": {}}},
|
||||
}
|
||||
|
||||
|
||||
def _compress_output_files(output_directory: str) -> str:
|
||||
"""
|
||||
Compress output files from all configured output formats into a ZIP archive.
|
||||
Args:
|
||||
output_directory (str): The directory where the output files are located.
|
||||
The function looks up all known suffixes in OUTPUT_FORMATS_MAPPING
|
||||
and compresses those files into a single ZIP.
|
||||
Returns:
|
||||
str: The full path to the newly created ZIP archive.
|
||||
"""
|
||||
zip_path = f"{output_directory}.zip"
|
||||
|
||||
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
|
||||
for suffix in [config["suffix"] for config in OUTPUT_FORMATS_MAPPING.values()]:
|
||||
zipf.write(
|
||||
f"{output_directory}{suffix}",
|
||||
f"output/{output_directory.split('/')[-1]}{suffix}",
|
||||
)
|
||||
|
||||
return zip_path
|
||||
|
||||
|
||||
def get_s3_client():
|
||||
"""
|
||||
Create and return a boto3 S3 client using AWS credentials from environment variables.
|
||||
|
||||
This function attempts to initialize an S3 client by reading the AWS access key, secret key,
|
||||
session token, and region from environment variables. It then validates the client by listing
|
||||
available S3 buckets. If an error occurs during this process (for example, due to missing or
|
||||
invalid credentials), it falls back to creating an S3 client without explicitly provided credentials,
|
||||
which may rely on other configuration sources (e.g., IAM roles).
|
||||
|
||||
Returns:
|
||||
boto3.client: A configured S3 client instance.
|
||||
|
||||
Raises:
|
||||
ClientError, NoCredentialsError, or ParamValidationError if both attempts to create a client fail.
|
||||
"""
|
||||
s3_client = None
|
||||
try:
|
||||
s3_client = boto3.client(
|
||||
"s3",
|
||||
aws_access_key_id=settings.DJANGO_OUTPUT_S3_AWS_ACCESS_KEY_ID,
|
||||
aws_secret_access_key=settings.DJANGO_OUTPUT_S3_AWS_SECRET_ACCESS_KEY,
|
||||
aws_session_token=settings.DJANGO_OUTPUT_S3_AWS_SESSION_TOKEN,
|
||||
region_name=settings.DJANGO_OUTPUT_S3_AWS_DEFAULT_REGION,
|
||||
)
|
||||
s3_client.list_buckets()
|
||||
except (ClientError, NoCredentialsError, ParamValidationError, ValueError):
|
||||
s3_client = boto3.client("s3")
|
||||
s3_client.list_buckets()
|
||||
|
||||
return s3_client
|
||||
|
||||
|
||||
def _upload_to_s3(tenant_id: str, zip_path: str, scan_id: str) -> str:
|
||||
"""
|
||||
Upload the specified ZIP file to an S3 bucket.
|
||||
If the S3 bucket environment variables are not configured,
|
||||
the function returns None without performing an upload.
|
||||
Args:
|
||||
tenant_id (str): The tenant identifier, used as part of the S3 key prefix.
|
||||
zip_path (str): The local file system path to the ZIP file to be uploaded.
|
||||
scan_id (str): The scan identifier, used as part of the S3 key prefix.
|
||||
Returns:
|
||||
str: The S3 URI of the uploaded file (e.g., "s3://<bucket>/<key>") if successful.
|
||||
None: If the required environment variables for the S3 bucket are not set.
|
||||
Raises:
|
||||
botocore.exceptions.ClientError: If the upload attempt to S3 fails for any reason.
|
||||
"""
|
||||
if not base.DJANGO_OUTPUT_S3_AWS_OUTPUT_BUCKET:
|
||||
return
|
||||
|
||||
try:
|
||||
s3 = get_s3_client()
|
||||
s3_key = f"{tenant_id}/{scan_id}/{os.path.basename(zip_path)}"
|
||||
s3.upload_file(
|
||||
Filename=zip_path,
|
||||
Bucket=base.DJANGO_OUTPUT_S3_AWS_OUTPUT_BUCKET,
|
||||
Key=s3_key,
|
||||
)
|
||||
return f"s3://{base.DJANGO_OUTPUT_S3_AWS_OUTPUT_BUCKET}/{s3_key}"
|
||||
except (ClientError, NoCredentialsError, ParamValidationError, ValueError) as e:
|
||||
logger.error(f"S3 upload failed: {str(e)}")
|
||||
|
||||
|
||||
def _generate_output_directory(
|
||||
output_directory, prowler_provider: object, tenant_id: str, scan_id: str
|
||||
) -> str:
|
||||
"""
|
||||
Generate a file system path for the output directory of a prowler scan.
|
||||
|
||||
This function constructs the output directory path by combining a base
|
||||
temporary output directory, the tenant ID, the scan ID, and details about
|
||||
the prowler provider along with a timestamp. The resulting path is used to
|
||||
store the output files of a prowler scan.
|
||||
|
||||
Note:
|
||||
This function depends on one external variable:
|
||||
- `output_file_timestamp`: A timestamp (as a string) used to uniquely identify the output.
|
||||
|
||||
Args:
|
||||
output_directory (str): The base output directory.
|
||||
prowler_provider (object): An identifier or descriptor for the prowler provider.
|
||||
Typically, this is a string indicating the provider (e.g., "aws").
|
||||
tenant_id (str): The unique identifier for the tenant.
|
||||
scan_id (str): The unique identifier for the scan.
|
||||
|
||||
Returns:
|
||||
str: The constructed file system path for the prowler scan output directory.
|
||||
|
||||
Example:
|
||||
>>> _generate_output_directory("/tmp", "aws", "tenant-1234", "scan-5678")
|
||||
'/tmp/tenant-1234/aws/scan-5678/prowler-output-2023-02-15T12:34:56'
|
||||
"""
|
||||
path = (
|
||||
f"{output_directory}/{tenant_id}/{scan_id}/prowler-output-"
|
||||
f"{prowler_provider}-{output_file_timestamp}"
|
||||
)
|
||||
os.makedirs("/".join(path.split("/")[:-1]), exist_ok=True)
|
||||
|
||||
return path
|
||||
@@ -1,14 +1,28 @@
|
||||
from celery import shared_task
|
||||
from shutil import rmtree
|
||||
|
||||
from celery import chain, shared_task
|
||||
from celery.utils.log import get_task_logger
|
||||
from config.celery import RLSTask
|
||||
from config.django.base import DJANGO_FINDINGS_BATCH_SIZE, DJANGO_TMP_OUTPUT_DIRECTORY
|
||||
from django_celery_beat.models import PeriodicTask
|
||||
from tasks.jobs.connection import check_provider_connection
|
||||
from tasks.jobs.deletion import delete_provider, delete_tenant
|
||||
from tasks.jobs.export import (
|
||||
OUTPUT_FORMATS_MAPPING,
|
||||
_compress_output_files,
|
||||
_generate_output_directory,
|
||||
_upload_to_s3,
|
||||
)
|
||||
from tasks.jobs.scan import aggregate_findings, perform_prowler_scan
|
||||
from tasks.utils import get_next_execution_datetime
|
||||
from tasks.utils import batched, get_next_execution_datetime
|
||||
|
||||
from api.db_utils import rls_transaction
|
||||
from api.decorators import set_tenant
|
||||
from api.models import Scan, StateChoices
|
||||
from api.models import Finding, Provider, Scan, ScanSummary, StateChoices
|
||||
from api.utils import initialize_prowler_provider
|
||||
from prowler.lib.outputs.finding import Finding as FindingOutput
|
||||
|
||||
logger = get_task_logger(__name__)
|
||||
|
||||
|
||||
@shared_task(base=RLSTask, name="provider-connection-check")
|
||||
@@ -68,13 +82,20 @@ def perform_scan_task(
|
||||
Returns:
|
||||
dict: The result of the scan execution, typically including the status and results of the performed checks.
|
||||
"""
|
||||
return perform_prowler_scan(
|
||||
result = perform_prowler_scan(
|
||||
tenant_id=tenant_id,
|
||||
scan_id=scan_id,
|
||||
provider_id=provider_id,
|
||||
checks_to_execute=checks_to_execute,
|
||||
)
|
||||
|
||||
chain(
|
||||
perform_scan_summary_task.si(tenant_id, scan_id),
|
||||
generate_outputs.si(scan_id, provider_id, tenant_id=tenant_id),
|
||||
).apply_async()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@shared_task(base=RLSTask, bind=True, name="scan-perform-scheduled", queue="scans")
|
||||
def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
|
||||
@@ -135,12 +156,11 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
|
||||
scheduler_task_id=periodic_task_instance.id,
|
||||
)
|
||||
|
||||
perform_scan_summary_task.apply_async(
|
||||
kwargs={
|
||||
"tenant_id": tenant_id,
|
||||
"scan_id": str(scan_instance.id),
|
||||
}
|
||||
)
|
||||
chain(
|
||||
perform_scan_summary_task.si(tenant_id, scan_instance.id),
|
||||
generate_outputs.si(str(scan_instance.id), provider_id, tenant_id=tenant_id),
|
||||
).apply_async()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -152,3 +172,108 @@ def perform_scan_summary_task(tenant_id: str, scan_id: str):
|
||||
@shared_task(name="tenant-deletion", queue="deletion")
|
||||
def delete_tenant_task(tenant_id: str):
|
||||
return delete_tenant(pk=tenant_id)
|
||||
|
||||
|
||||
@shared_task(
|
||||
base=RLSTask,
|
||||
name="scan-report",
|
||||
queue="scan-reports",
|
||||
)
|
||||
@set_tenant(keep_tenant=True)
|
||||
def generate_outputs(scan_id: str, provider_id: str, tenant_id: str):
|
||||
"""
|
||||
Process findings in batches and generate output files in multiple formats.
|
||||
|
||||
This function retrieves findings associated with a scan, processes them
|
||||
in batches of 50, and writes each batch to the corresponding output files.
|
||||
It reuses output writer instances across batches, updates them with each
|
||||
batch of transformed findings, and uses a flag to indicate when the final
|
||||
batch is being processed. Finally, the output files are compressed and
|
||||
uploaded to S3.
|
||||
|
||||
Args:
|
||||
tenant_id (str): The tenant identifier.
|
||||
scan_id (str): The scan identifier.
|
||||
provider_id (str): The provider_id id to be used in generating outputs.
|
||||
"""
|
||||
# Initialize the prowler provider
|
||||
prowler_provider = initialize_prowler_provider(Provider.objects.get(id=provider_id))
|
||||
|
||||
# Get the provider UID
|
||||
provider_uid = Provider.objects.get(id=provider_id).uid
|
||||
|
||||
# Generate and ensure the output directory exists
|
||||
output_directory = _generate_output_directory(
|
||||
DJANGO_TMP_OUTPUT_DIRECTORY, provider_uid, tenant_id, scan_id
|
||||
)
|
||||
|
||||
# Define auxiliary variables
|
||||
output_writers = {}
|
||||
scan_summary = FindingOutput._transform_findings_stats(
|
||||
ScanSummary.objects.filter(scan_id=scan_id)
|
||||
)
|
||||
|
||||
# Retrieve findings queryset
|
||||
findings_qs = Finding.all_objects.filter(scan_id=scan_id).order_by("uid")
|
||||
|
||||
# Process findings in batches
|
||||
for batch, is_last_batch in batched(
|
||||
findings_qs.iterator(), DJANGO_FINDINGS_BATCH_SIZE
|
||||
):
|
||||
finding_outputs = [
|
||||
FindingOutput.transform_api_finding(finding, prowler_provider)
|
||||
for finding in batch
|
||||
]
|
||||
|
||||
# Generate output files
|
||||
for mode, config in OUTPUT_FORMATS_MAPPING.items():
|
||||
kwargs = dict(config.get("kwargs", {}))
|
||||
if mode == "html":
|
||||
kwargs["provider"] = prowler_provider
|
||||
kwargs["stats"] = scan_summary
|
||||
|
||||
writer_class = config["class"]
|
||||
if writer_class in output_writers:
|
||||
writer = output_writers[writer_class]
|
||||
writer.transform(finding_outputs)
|
||||
writer.close_file = is_last_batch
|
||||
else:
|
||||
writer = writer_class(
|
||||
findings=finding_outputs,
|
||||
file_path=output_directory,
|
||||
file_extension=config["suffix"],
|
||||
from_cli=False,
|
||||
)
|
||||
writer.close_file = is_last_batch
|
||||
output_writers[writer_class] = writer
|
||||
|
||||
# Write the current batch using the writer
|
||||
writer.batch_write_data_to_file(**kwargs)
|
||||
|
||||
# TODO: Refactor the output classes to avoid this manual reset
|
||||
writer._data = []
|
||||
|
||||
# Compress output files
|
||||
output_directory = _compress_output_files(output_directory)
|
||||
|
||||
# Save to configured storage
|
||||
uploaded = _upload_to_s3(tenant_id, output_directory, scan_id)
|
||||
|
||||
if uploaded:
|
||||
output_directory = uploaded
|
||||
uploaded = True
|
||||
# Remove the local files after upload
|
||||
rmtree(DJANGO_TMP_OUTPUT_DIRECTORY, ignore_errors=True)
|
||||
else:
|
||||
uploaded = False
|
||||
|
||||
# Update the scan instance with the output path
|
||||
Scan.all_objects.filter(id=scan_id).update(output_location=output_directory)
|
||||
|
||||
logger.info(f"Scan output files generated, output location: {output_directory}")
|
||||
|
||||
return {
|
||||
"upload": uploaded,
|
||||
"scan_id": scan_id,
|
||||
"provider_id": provider_id,
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
from django_celery_beat.models import IntervalSchedule, PeriodicTask
|
||||
from django_celery_results.models import TaskResult
|
||||
from tasks.utils import get_next_execution_datetime
|
||||
from tasks.utils import batched, get_next_execution_datetime
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@@ -74,3 +74,29 @@ class TestGetNextExecutionDatetime:
|
||||
get_next_execution_datetime(
|
||||
task_id=task_result.task_id, provider_id="nonexistent"
|
||||
)
|
||||
|
||||
|
||||
class TestBatchedFunction:
|
||||
def test_empty_iterable(self):
|
||||
result = list(batched([], 3))
|
||||
assert result == [([], True)]
|
||||
|
||||
def test_exact_batches(self):
|
||||
result = list(batched([1, 2, 3, 4], 2))
|
||||
expected = [([1, 2], False), ([3, 4], False), ([], True)]
|
||||
assert result == expected
|
||||
|
||||
def test_inexact_batches(self):
|
||||
result = list(batched([1, 2, 3, 4, 5], 2))
|
||||
expected = [([1, 2], False), ([3, 4], False), ([5], True)]
|
||||
assert result == expected
|
||||
|
||||
def test_batch_size_one(self):
|
||||
result = list(batched([1, 2, 3], 1))
|
||||
expected = [([1], False), ([2], False), ([3], False), ([], True)]
|
||||
assert result == expected
|
||||
|
||||
def test_batch_size_greater_than_length(self):
|
||||
result = list(batched([1, 2, 3], 5))
|
||||
expected = [([1, 2, 3], True)]
|
||||
assert result == expected
|
||||
|
||||
@@ -24,3 +24,27 @@ def get_next_execution_datetime(task_id: int, provider_id: str) -> datetime:
|
||||
)
|
||||
|
||||
return current_scheduled_time + timedelta(**{interval.period: interval.every})
|
||||
|
||||
|
||||
def batched(iterable, batch_size):
|
||||
"""
|
||||
Yield successive batches from an iterable.
|
||||
|
||||
Args:
|
||||
iterable: An iterable source of items.
|
||||
batch_size (int): The number of items per batch.
|
||||
|
||||
Yields:
|
||||
tuple: A pair (batch, is_last_batch) where:
|
||||
- batch (list): A list of items (with length equal to batch_size,
|
||||
except possibly for the last batch).
|
||||
- is_last_batch (bool): True if this is the final batch, False otherwise.
|
||||
"""
|
||||
batch = []
|
||||
for item in iterable:
|
||||
batch.append(item)
|
||||
if len(batch) == batch_size:
|
||||
yield batch, False
|
||||
batch = []
|
||||
|
||||
yield batch, True
|
||||
|
||||
@@ -16,6 +16,7 @@ services:
|
||||
volumes:
|
||||
- "./api/src/backend:/home/prowler/backend"
|
||||
- "./api/pyproject.toml:/home/prowler/pyproject.toml"
|
||||
- "/tmp/prowler_api_output:/tmp/prowler_api_output"
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
@@ -85,6 +86,8 @@ services:
|
||||
env_file:
|
||||
- path: .env
|
||||
required: false
|
||||
volumes:
|
||||
- "/tmp/prowler_api_output:/tmp/prowler_api_output"
|
||||
depends_on:
|
||||
valkey:
|
||||
condition: service_healthy
|
||||
|
||||
@@ -7,6 +7,8 @@ services:
|
||||
required: false
|
||||
ports:
|
||||
- "${DJANGO_PORT:-8080}:${DJANGO_PORT:-8080}"
|
||||
volumes:
|
||||
- "/tmp/prowler_api_output:/tmp/prowler_api_output"
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
@@ -65,6 +67,8 @@ services:
|
||||
env_file:
|
||||
- path: .env
|
||||
required: false
|
||||
volumes:
|
||||
- "/tmp/prowler_api_output:/tmp/prowler_api_output"
|
||||
depends_on:
|
||||
valkey:
|
||||
condition: service_healthy
|
||||
|
||||
@@ -351,7 +351,6 @@ def prowler():
|
||||
if mode == "csv":
|
||||
csv_output = CSV(
|
||||
findings=finding_outputs,
|
||||
create_file_descriptor=True,
|
||||
file_path=f"{filename}{csv_file_suffix}",
|
||||
)
|
||||
generated_outputs["regular"].append(csv_output)
|
||||
@@ -361,7 +360,6 @@ def prowler():
|
||||
if mode == "json-asff":
|
||||
asff_output = ASFF(
|
||||
findings=finding_outputs,
|
||||
create_file_descriptor=True,
|
||||
file_path=f"{filename}{json_asff_file_suffix}",
|
||||
)
|
||||
generated_outputs["regular"].append(asff_output)
|
||||
@@ -371,7 +369,6 @@ def prowler():
|
||||
if mode == "json-ocsf":
|
||||
json_output = OCSF(
|
||||
findings=finding_outputs,
|
||||
create_file_descriptor=True,
|
||||
file_path=f"{filename}{json_ocsf_file_suffix}",
|
||||
)
|
||||
generated_outputs["regular"].append(json_output)
|
||||
@@ -379,7 +376,6 @@ def prowler():
|
||||
if mode == "html":
|
||||
html_output = HTML(
|
||||
findings=finding_outputs,
|
||||
create_file_descriptor=True,
|
||||
file_path=f"{filename}{html_file_suffix}",
|
||||
)
|
||||
generated_outputs["regular"].append(html_output)
|
||||
@@ -402,7 +398,6 @@ def prowler():
|
||||
cis = AWSCIS(
|
||||
findings=finding_outputs,
|
||||
compliance=bulk_compliance_frameworks[compliance_name],
|
||||
create_file_descriptor=True,
|
||||
file_path=filename,
|
||||
)
|
||||
generated_outputs["compliance"].append(cis)
|
||||
@@ -416,7 +411,6 @@ def prowler():
|
||||
mitre_attack = AWSMitreAttack(
|
||||
findings=finding_outputs,
|
||||
compliance=bulk_compliance_frameworks[compliance_name],
|
||||
create_file_descriptor=True,
|
||||
file_path=filename,
|
||||
)
|
||||
generated_outputs["compliance"].append(mitre_attack)
|
||||
@@ -430,7 +424,6 @@ def prowler():
|
||||
ens = AWSENS(
|
||||
findings=finding_outputs,
|
||||
compliance=bulk_compliance_frameworks[compliance_name],
|
||||
create_file_descriptor=True,
|
||||
file_path=filename,
|
||||
)
|
||||
generated_outputs["compliance"].append(ens)
|
||||
@@ -444,7 +437,6 @@ def prowler():
|
||||
aws_well_architected = AWSWellArchitected(
|
||||
findings=finding_outputs,
|
||||
compliance=bulk_compliance_frameworks[compliance_name],
|
||||
create_file_descriptor=True,
|
||||
file_path=filename,
|
||||
)
|
||||
generated_outputs["compliance"].append(aws_well_architected)
|
||||
@@ -458,7 +450,6 @@ def prowler():
|
||||
iso27001 = AWSISO27001(
|
||||
findings=finding_outputs,
|
||||
compliance=bulk_compliance_frameworks[compliance_name],
|
||||
create_file_descriptor=True,
|
||||
file_path=filename,
|
||||
)
|
||||
generated_outputs["compliance"].append(iso27001)
|
||||
@@ -472,7 +463,6 @@ def prowler():
|
||||
kisa_ismsp = AWSKISAISMSP(
|
||||
findings=finding_outputs,
|
||||
compliance=bulk_compliance_frameworks[compliance_name],
|
||||
create_file_descriptor=True,
|
||||
file_path=filename,
|
||||
)
|
||||
generated_outputs["compliance"].append(kisa_ismsp)
|
||||
@@ -485,7 +475,6 @@ def prowler():
|
||||
generic_compliance = GenericCompliance(
|
||||
findings=finding_outputs,
|
||||
compliance=bulk_compliance_frameworks[compliance_name],
|
||||
create_file_descriptor=True,
|
||||
file_path=filename,
|
||||
)
|
||||
generated_outputs["compliance"].append(generic_compliance)
|
||||
@@ -502,7 +491,6 @@ def prowler():
|
||||
cis = AzureCIS(
|
||||
findings=finding_outputs,
|
||||
compliance=bulk_compliance_frameworks[compliance_name],
|
||||
create_file_descriptor=True,
|
||||
file_path=filename,
|
||||
)
|
||||
generated_outputs["compliance"].append(cis)
|
||||
@@ -516,7 +504,6 @@ def prowler():
|
||||
mitre_attack = AzureMitreAttack(
|
||||
findings=finding_outputs,
|
||||
compliance=bulk_compliance_frameworks[compliance_name],
|
||||
create_file_descriptor=True,
|
||||
file_path=filename,
|
||||
)
|
||||
generated_outputs["compliance"].append(mitre_attack)
|
||||
@@ -530,7 +517,6 @@ def prowler():
|
||||
ens = AzureENS(
|
||||
findings=finding_outputs,
|
||||
compliance=bulk_compliance_frameworks[compliance_name],
|
||||
create_file_descriptor=True,
|
||||
file_path=filename,
|
||||
)
|
||||
generated_outputs["compliance"].append(ens)
|
||||
@@ -543,7 +529,6 @@ def prowler():
|
||||
generic_compliance = GenericCompliance(
|
||||
findings=finding_outputs,
|
||||
compliance=bulk_compliance_frameworks[compliance_name],
|
||||
create_file_descriptor=True,
|
||||
file_path=filename,
|
||||
)
|
||||
generated_outputs["compliance"].append(generic_compliance)
|
||||
@@ -560,7 +545,6 @@ def prowler():
|
||||
cis = GCPCIS(
|
||||
findings=finding_outputs,
|
||||
compliance=bulk_compliance_frameworks[compliance_name],
|
||||
create_file_descriptor=True,
|
||||
file_path=filename,
|
||||
)
|
||||
generated_outputs["compliance"].append(cis)
|
||||
@@ -574,7 +558,6 @@ def prowler():
|
||||
mitre_attack = GCPMitreAttack(
|
||||
findings=finding_outputs,
|
||||
compliance=bulk_compliance_frameworks[compliance_name],
|
||||
create_file_descriptor=True,
|
||||
file_path=filename,
|
||||
)
|
||||
generated_outputs["compliance"].append(mitre_attack)
|
||||
@@ -588,7 +571,6 @@ def prowler():
|
||||
ens = GCPENS(
|
||||
findings=finding_outputs,
|
||||
compliance=bulk_compliance_frameworks[compliance_name],
|
||||
create_file_descriptor=True,
|
||||
file_path=filename,
|
||||
)
|
||||
generated_outputs["compliance"].append(ens)
|
||||
@@ -601,7 +583,6 @@ def prowler():
|
||||
generic_compliance = GenericCompliance(
|
||||
findings=finding_outputs,
|
||||
compliance=bulk_compliance_frameworks[compliance_name],
|
||||
create_file_descriptor=True,
|
||||
file_path=filename,
|
||||
)
|
||||
generated_outputs["compliance"].append(generic_compliance)
|
||||
@@ -618,7 +599,6 @@ def prowler():
|
||||
cis = KubernetesCIS(
|
||||
findings=finding_outputs,
|
||||
compliance=bulk_compliance_frameworks[compliance_name],
|
||||
create_file_descriptor=True,
|
||||
file_path=filename,
|
||||
)
|
||||
generated_outputs["compliance"].append(cis)
|
||||
@@ -631,7 +611,6 @@ def prowler():
|
||||
generic_compliance = GenericCompliance(
|
||||
findings=finding_outputs,
|
||||
compliance=bulk_compliance_frameworks[compliance_name],
|
||||
create_file_descriptor=True,
|
||||
file_path=filename,
|
||||
)
|
||||
generated_outputs["compliance"].append(generic_compliance)
|
||||
@@ -648,7 +627,6 @@ def prowler():
|
||||
cis = Microsoft365CIS(
|
||||
findings=finding_outputs,
|
||||
compliance=bulk_compliance_frameworks[compliance_name],
|
||||
create_file_descriptor=True,
|
||||
file_path=filename,
|
||||
)
|
||||
generated_outputs["compliance"].append(cis)
|
||||
@@ -661,7 +639,6 @@ def prowler():
|
||||
generic_compliance = GenericCompliance(
|
||||
findings=finding_outputs,
|
||||
compliance=bulk_compliance_frameworks[compliance_name],
|
||||
create_file_descriptor=True,
|
||||
file_path=filename,
|
||||
)
|
||||
generated_outputs["compliance"].append(generic_compliance)
|
||||
|
||||
@@ -29,11 +29,11 @@ class ComplianceOutput(Output):
|
||||
self,
|
||||
findings: List[Finding],
|
||||
compliance: Compliance,
|
||||
create_file_descriptor: bool = False,
|
||||
file_path: str = None,
|
||||
file_extension: str = "",
|
||||
) -> None:
|
||||
self._data = []
|
||||
self.file_descriptor = None
|
||||
|
||||
if not file_extension and file_path:
|
||||
self._file_extension = "".join(Path(file_path).suffixes)
|
||||
@@ -48,7 +48,7 @@ class ComplianceOutput(Output):
|
||||
else compliance.Framework
|
||||
)
|
||||
self.transform(findings, compliance, compliance_name)
|
||||
if create_file_descriptor:
|
||||
if not self._file_descriptor and file_path:
|
||||
self.create_file_descriptor(file_path)
|
||||
|
||||
def batch_write_data_to_file(self) -> None:
|
||||
|
||||
@@ -98,10 +98,12 @@ class CSV(Output):
|
||||
fieldnames=self._data[0].keys(),
|
||||
delimiter=";",
|
||||
)
|
||||
csv_writer.writeheader()
|
||||
if self._file_descriptor.tell() == 0:
|
||||
csv_writer.writeheader()
|
||||
for finding in self._data:
|
||||
csv_writer.writerow(finding)
|
||||
self._file_descriptor.close()
|
||||
if self.close_file or self._from_cli:
|
||||
self._file_descriptor.close()
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
|
||||
@@ -1,13 +1,21 @@
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from prowler.config.config import prowler_version
|
||||
from prowler.lib.check.models import Check_Report, CheckMetadata
|
||||
from prowler.lib.check.models import (
|
||||
Check_Report,
|
||||
CheckMetadata,
|
||||
Code,
|
||||
Recommendation,
|
||||
Remediation,
|
||||
)
|
||||
from prowler.lib.logger import logger
|
||||
from prowler.lib.outputs.common import Status, fill_common_finding_data
|
||||
from prowler.lib.outputs.compliance.compliance import get_check_compliance
|
||||
from prowler.lib.outputs.utils import unroll_tags
|
||||
from prowler.lib.utils.utils import dict_to_lowercase, get_nested_attribute
|
||||
from prowler.providers.common.provider import Provider
|
||||
|
||||
@@ -267,3 +275,193 @@ class Finding(BaseModel):
|
||||
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
raise error
|
||||
|
||||
@classmethod
|
||||
def transform_api_finding(cls, finding, provider) -> "Finding":
|
||||
"""
|
||||
Transform a FindingModel instance into an API-friendly Finding object.
|
||||
|
||||
This class method extracts data from a FindingModel instance and maps its
|
||||
properties to a new Finding object. The transformation populates various
|
||||
fields including authentication details, timestamp, account information,
|
||||
check metadata (such as provider, check ID, title, type, service, severity,
|
||||
and remediation details), as well as resource-specific data. The resulting
|
||||
Finding object is structured for use in API responses or further processing.
|
||||
|
||||
Args:
|
||||
finding (API Finding): An API Finding instance containing data from the database.
|
||||
provider (Provider): the provider object.
|
||||
|
||||
Returns:
|
||||
Finding: A new Finding instance populated with data from the provided model.
|
||||
"""
|
||||
# Missing Finding's API values
|
||||
finding.muted = False
|
||||
finding.resource_details = ""
|
||||
resource = finding.resources.first()
|
||||
finding.resource_arn = resource.uid
|
||||
finding.resource_name = resource.name
|
||||
|
||||
# TODO: Change this when the API has all the values
|
||||
finding.resource = {}
|
||||
|
||||
finding.resource_id = resource.name if provider.type == "aws" else resource.uid
|
||||
|
||||
# AWS specified field
|
||||
finding.region = resource.region
|
||||
# Azure, GCP specified field
|
||||
finding.location = resource.region
|
||||
# K8s specified field
|
||||
if provider.type == "kubernetes":
|
||||
finding.namespace = resource.region.removeprefix("namespace: ")
|
||||
if provider.type == "azure":
|
||||
finding.subscription = list(provider.identity.subscriptions.keys())[0]
|
||||
elif provider.type == "gcp":
|
||||
finding.project_id = list(provider.projects.keys())[0]
|
||||
|
||||
finding.check_metadata = CheckMetadata(
|
||||
Provider=finding.check_metadata["provider"],
|
||||
CheckID=finding.check_metadata["checkid"],
|
||||
CheckTitle=finding.check_metadata["checktitle"],
|
||||
CheckType=finding.check_metadata["checktype"],
|
||||
ServiceName=finding.check_metadata["servicename"],
|
||||
SubServiceName=finding.check_metadata["subservicename"],
|
||||
Severity=finding.check_metadata["severity"],
|
||||
ResourceType=finding.check_metadata["resourcetype"],
|
||||
Description=finding.check_metadata["description"],
|
||||
Risk=finding.check_metadata["risk"],
|
||||
RelatedUrl=finding.check_metadata["relatedurl"],
|
||||
Remediation=Remediation(
|
||||
Recommendation=Recommendation(
|
||||
Text=finding.check_metadata["remediation"]["recommendation"][
|
||||
"text"
|
||||
],
|
||||
Url=finding.check_metadata["remediation"]["recommendation"]["url"],
|
||||
),
|
||||
Code=Code(
|
||||
NativeIaC=finding.check_metadata["remediation"]["code"][
|
||||
"nativeiac"
|
||||
],
|
||||
Terraform=finding.check_metadata["remediation"]["code"][
|
||||
"terraform"
|
||||
],
|
||||
CLI=finding.check_metadata["remediation"]["code"]["cli"],
|
||||
Other=finding.check_metadata["remediation"]["code"]["other"],
|
||||
),
|
||||
),
|
||||
ResourceIdTemplate=finding.check_metadata["resourceidtemplate"],
|
||||
Categories=finding.check_metadata["categories"],
|
||||
DependsOn=finding.check_metadata["dependson"],
|
||||
RelatedTo=finding.check_metadata["relatedto"],
|
||||
Notes=finding.check_metadata["notes"],
|
||||
)
|
||||
finding.resource_tags = unroll_tags(
|
||||
[{"key": tag.key, "value": tag.value} for tag in resource.tags.all()]
|
||||
)
|
||||
return cls.generate_output(provider, finding, SimpleNamespace())
|
||||
|
||||
def _transform_findings_stats(scan_summaries: list[dict]) -> dict:
|
||||
"""
|
||||
Aggregate and transform scan summary data into findings statistics.
|
||||
|
||||
This function processes a list of scan summary objects and calculates overall
|
||||
metrics such as the total number of passed and failed findings (including muted counts),
|
||||
as well as a breakdown of results by severity (critical, high, medium, and low).
It also retrieves the unique resource count from the associated scan information.
The final output is a dictionary of aggregated statistics intended for reporting or
further analysis.

Args:
    scan_summaries (list): A list of scan summary objects. Each object is expected
        to have attributes including:
        - _pass: Number of passed findings.
        - fail: Number of failed findings.
        - total: Total number of findings.
        - muted: Number of muted findings in the summary row.
        - severity: A string representing the severity level.
        Additionally, the first scan summary should have an associated
        `scan` attribute with a `unique_resource_count`.

Returns:
    dict: A dictionary containing aggregated findings statistics:
        - total_pass: Total number of passed findings.
        - total_muted_pass: Total number of muted passed findings.
        - total_fail: Total number of failed findings.
        - total_muted_fail: Total number of muted failed findings.
        - resources_count: The unique resource count extracted from the scan.
        - findings_count: Total number of findings.
        - total_critical_severity_fail: Failed findings with critical severity.
        - total_critical_severity_pass: Passed findings with critical severity.
        - total_high_severity_fail: Failed findings with high severity.
        - total_high_severity_pass: Passed findings with high severity.
        - total_medium_severity_fail: Failed findings with medium severity.
        - total_medium_severity_pass: Passed findings with medium severity.
        - total_low_severity_fail: Failed findings with low severity.
        - total_low_severity_pass: Passed findings with low severity.
        - all_fails_are_muted: A boolean indicating whether all failing findings are muted.
"""
# Initialize overall counters
total_pass = 0
total_fail = 0
muted_pass = 0
muted_fail = 0
findings_count = 0
resources_count = scan_summaries[0].scan.unique_resource_count

# Initialize severity breakdown counters
critical_severity_pass = 0
critical_severity_fail = 0
high_severity_pass = 0
high_severity_fail = 0
medium_severity_pass = 0
medium_severity_fail = 0
low_severity_pass = 0
low_severity_fail = 0

# Loop over each row from the database
for row in scan_summaries:
    # Accumulate overall totals
    total_pass += row._pass
    total_fail += row.fail
    findings_count += row.total

    if row.muted > 0:
        if row._pass > 0:
            muted_pass += row._pass
        if row.fail > 0:
            muted_fail += row.fail

    sev = row.severity.lower()
    if sev == "critical":
        critical_severity_pass += row._pass
        critical_severity_fail += row.fail
    elif sev == "high":
        high_severity_pass += row._pass
        high_severity_fail += row.fail
    elif sev == "medium":
        medium_severity_pass += row._pass
        medium_severity_fail += row.fail
    elif sev == "low":
        low_severity_pass += row._pass
        low_severity_fail += row.fail

all_fails_are_muted = (total_fail > 0) and (total_fail == muted_fail)

stats = {
    "total_pass": total_pass,
    "total_muted_pass": muted_pass,
    "total_fail": total_fail,
    "total_muted_fail": muted_fail,
    "resources_count": resources_count,
    "findings_count": findings_count,
    "total_critical_severity_fail": critical_severity_fail,
    "total_critical_severity_pass": critical_severity_pass,
    "total_high_severity_fail": high_severity_fail,
    "total_high_severity_pass": high_severity_pass,
    "total_medium_severity_fail": medium_severity_fail,
    "total_medium_severity_pass": medium_severity_pass,
    "total_low_severity_fail": low_severity_fail,
    "total_low_severity_pass": low_severity_pass,
    "all_fails_are_muted": all_fails_are_muted,
}
return stats
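
A quick, self-contained illustration of the aggregation above, using ad-hoc objects in place of database summary rows. The row values and the unique_resource_count are made up for the example; the import path follows the patch targets used in the tests further down.

from types import SimpleNamespace

from prowler.lib.outputs.finding import Finding

scan = SimpleNamespace(unique_resource_count=4)
rows = [
    SimpleNamespace(_pass=1, fail=2, total=3, muted=2, severity="critical", scan=scan),
    SimpleNamespace(_pass=3, fail=0, total=3, muted=0, severity="low", scan=scan),
]
stats = Finding._transform_findings_stats(rows)
# stats["total_fail"] == 2 and stats["total_muted_fail"] == 2,
# so stats["all_fails_are_muted"] is True.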
@@ -74,12 +74,15 @@ class HTML(Output):
                and not self._file_descriptor.closed
                and self._data
            ):
                HTML.write_header(self._file_descriptor, provider, stats)
                if self._file_descriptor.tell() == 0:
                    HTML.write_header(
                        self._file_descriptor, provider, stats, self._from_cli
                    )
                for finding in self._data:
                    self._file_descriptor.write(finding)
                HTML.write_footer(self._file_descriptor)
                # Close file descriptor
                self._file_descriptor.close()
                if self.close_file or self._from_cli:
                    HTML.write_footer(self._file_descriptor)
                    self._file_descriptor.close()
        except Exception as error:
            logger.error(
                f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
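
With this change the HTML writer can be driven in batches: the header is only written while the file is still empty (tell() == 0), and the footer and close only happen when close_file is set or the writer runs from the CLI. A rough sketch of how an API-side caller could drive it, assuming the descriptor is opened in append mode and that batch_write_data_to_file receives the provider and stats used by write_header; the batch loop, path, and variable names below are illustrative, not the PR's task code.

# Illustrative only: stream findings into a single HTML report in batches.
batches = get_finding_batches()  # hypothetical source of list[Finding] batches
for index, batch in enumerate(batches):
    html_writer = HTML(
        findings=batch,
        file_path="/tmp/prowler_api_output/scan-report",
        file_extension=".html",
        from_cli=False,
    )
    if index == len(batches) - 1:
        # Last batch: allow the footer to be written and the file closed.
        html_writer.close_file = True
    html_writer.batch_write_data_to_file(provider, stats)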
@@ -87,7 +90,10 @@ class HTML(Output):

    @staticmethod
    def write_header(
        file_descriptor: TextIOWrapper, provider: Provider, stats: dict
        file_descriptor: TextIOWrapper,
        provider: Provider,
        stats: dict,
        from_cli: bool = True,
    ) -> None:
        """
        Writes the header of the HTML file.
@@ -96,6 +102,7 @@ class HTML(Output):
            file_descriptor (file): the file descriptor to write the header
            provider (Provider): the provider object
            stats (dict): the statistics of the findings
            from_cli (bool): whether the request is from the CLI or not
        """
        try:
            file_descriptor.write(
@@ -153,7 +160,7 @@ class HTML(Output):
                    </div>
                </li>
                <li class="list-group-item">
                    <b>Parameters used:</b> {" ".join(sys.argv[1:])}
                    <b>Parameters used:</b> {" ".join(sys.argv[1:]) if from_cli else ""}
                </li>
                <li class="list-group-item">
                    <b>Date:</b> {timestamp.isoformat()}
@@ -199,7 +199,8 @@ class OCSF(Output):
                and not self._file_descriptor.closed
                and self._data
            ):
                self._file_descriptor.write("[")
                if self._file_descriptor.tell() == 0:
                    self._file_descriptor.write("[")
                for finding in self._data:
                    try:
                        self._file_descriptor.write(
@@ -210,14 +211,14 @@ class OCSF(Output):
                        logger.error(
                            f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
                        )
                if self._file_descriptor.tell() > 0:
                if self.close_file or self._from_cli:
                    if self._file_descriptor.tell() != 1:
                        self._file_descriptor.seek(
                            self._file_descriptor.tell() - 1, os.SEEK_SET
                        )
                        self._file_descriptor.truncate()
                    self._file_descriptor.write("]")
                    self._file_descriptor.close()
                    self._file_descriptor.close()
        except Exception as error:
            logger.error(
                f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
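
The OCSF writer now builds one JSON array incrementally across batches: "[" is only written when the file is empty, each finding is appended with a trailing comma, and when the writer is asked to close (or runs from the CLI) the trailing comma is truncated before "]" is written, with the tell() != 1 guard covering the empty-array case. The same technique shown in isolation, with plain file I/O rather than the OCSF class:

import json
import os


def append_json_batch(path: str, items: list, close: bool = False) -> None:
    """Append a batch of items to a JSON array on disk; close the array on the last batch."""
    with open(path, "a") as fd:
        if fd.tell() == 0:
            fd.write("[")
        for item in items:
            fd.write(json.dumps(item) + ",")
        if close:
            if fd.tell() != 1:
                # Drop the trailing comma before closing the array.
                fd.seek(fd.tell() - 1, os.SEEK_SET)
                fd.truncate()
            fd.write("]")


# Usage: two batches, the second one closes the array.
append_json_batch("/tmp/findings.ocsf.json", [{"id": 1}])
append_json_batch("/tmp/findings.ocsf.json", [{"id": 2}], close=True)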
@@ -23,7 +23,6 @@ class Output(ABC):
        file_descriptor: Property to access the file descriptor.
        transform: Abstract method to transform findings into a specific format.
        batch_write_data_to_file: Abstract method to write data to a file in batches.
        create_file_descriptor: Method to create a file descriptor for writing data to a file.
    """

    _data: list
@@ -33,21 +32,27 @@ class Output(ABC):
    def __init__(
        self,
        findings: List[Finding],
        create_file_descriptor: bool = False,
        file_path: str = None,
        file_extension: str = "",
        from_cli: bool = True,
    ) -> None:
        self._data = []
        self.close_file = False
        self.file_path = file_path
        self._file_descriptor = None
        # This parameter avoids a larger refactor: the CLI does not write in batches, but the API does
        self._from_cli = from_cli

        if not file_extension and file_path:
            self._file_extension = "".join(Path(file_path).suffixes)
        if file_extension:
            self._file_extension = file_extension
            self.file_path = f"{file_path}{self.file_extension}"

        if findings:
            self.transform(findings)
            if create_file_descriptor and file_path:
                self.create_file_descriptor(file_path)
            if not self._file_descriptor and file_path:
                self.create_file_descriptor(self.file_path)

    @property
    def data(self):
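
With these constructor changes an API caller can reuse the existing output classes without CLI behavior: it passes from_cli=False, a bare file_path plus an explicit file_extension, and a descriptor is only created lazily when none exists yet. A hedged sketch of what instantiation could look like; the CSV subclass, path, and finding_batch variable are illustrative, and only the keyword arguments visible in the hunk above are assumed.

csv_output = CSV(
    findings=finding_batch,  # assumed: a non-empty list[Finding] for this batch
    file_path="/tmp/prowler_api_output/scan-report",  # no suffix here...
    file_extension=".csv",  # ...the extension is appended to file_path explicitly
    from_cli=False,
)
# csv_output.file_path is now "/tmp/prowler_api_output/scan-report.csv" and a
# file descriptor was created because none existed yet.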
@@ -110,7 +110,7 @@ class S3:
        for output in output_list:
            try:
                # Object is not written to file so we need to temporarily write it
                if not hasattr(output, "file_descriptor"):
                if not output.file_descriptor:
                    output.file_descriptor = NamedTemporaryFile(mode="a")

                bucket_directory = self.get_object_path(self._output_directory)
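
The check switches from hasattr to a truthiness test because, after the Output changes above, the attribute now always exists but may still be None when nothing was written to disk; in that case the report is materialized into a NamedTemporaryFile before upload. A generic sketch of the same pattern with plain boto3; the bucket name, key, and report_content are illustrative, and this is not the PR's S3 class.

from tempfile import NamedTemporaryFile

import boto3


def upload_report(file_path: str, bucket: str, key: str) -> None:
    """Upload a local report file to the given S3 bucket and key."""
    boto3.client("s3").upload_file(file_path, bucket, key)


# If the report was never written to disk, write it to a temporary file first.
report_content = "CHECK_ID;STATUS\n"  # illustrative report body
with NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as tmp:
    tmp.write(report_content)
    tmp.flush()
upload_report(tmp.name, "example-prowler-output-bucket", "scans/scan-report.csv")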
@@ -921,7 +921,7 @@ class AzureProvider(Provider):
            # since that exception is not considered as critical, we keep filling another identity fields
            if sp_env_auth or client_id:
                # The id of the sp can be retrieved from environment variables
                identity.identity_id = getenv("AZURE_CLIENT_ID")
                identity.identity_id = getenv("AZURE_CLIENT_ID", default=client_id)
                identity.identity_type = "Service Principal"
            # Same here, if user can access AAD, some fields are retrieved if not, default value, for az cli
            # should work but it doesn't, pending issue
@@ -577,7 +577,7 @@ class TestASFF:
        assert loads(content) == expected_asff

    def test_batch_write_data_to_file_without_findings(self):
        assert not hasattr(ASFF([]), "_file_descriptor")
        assert not ASFF([])._file_descriptor

    def test_asff_generate_status(self):
        assert ASFF.generate_status("PASS") == "PASSED"
@@ -119,7 +119,7 @@ class TestCSV:
        assert content == expected_csv

    def test_batch_write_data_to_file_without_findings(self):
        assert not hasattr(CSV([]), "_file_descriptor")
        assert not CSV([])._file_descriptor

    @pytest.fixture
    def mock_output_class(self):
@@ -144,9 +144,7 @@ class TestCSV:
            file_path = file.name

            # Instantiate the mock class
            output_instance = mock_output_class(
                findings, create_file_descriptor=True, file_path=file_path
            )
            output_instance = mock_output_class(findings, file_path=file_path)

            # Check that transform was called once
            output_instance.transform.assert_called_once_with(findings)
@@ -1,4 +1,5 @@
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import pytest
@@ -55,6 +56,65 @@ def mock_get_check_compliance(*_):
    return {"mock_compliance_key": "mock_compliance_value"}


class DummyTag:
    def __init__(self, key, value):
        self.key = key
        self.value = value


class DummyTags:
    def __init__(self, tags):
        self._tags = tags

    def all(self):
        return self._tags


class DummyResource:
    def __init__(self, uid, name, resource_arn, region, tags):
        self.uid = uid
        self.name = name
        self.resource_arn = resource_arn
        self.region = region
        self.tags = DummyTags(tags)

    def __iter__(self):
        yield "uid", self.uid
        yield "name", self.name
        yield "region", self.region
        yield "tags", self.tags


class DummyResources:
    """Simulate a collection with a first() method."""

    def __init__(self, resource):
        self._resource = resource

    def first(self):
        return self._resource


class DummyProvider:
    def __init__(self, uid):
        self.uid = uid
        self.type = "aws"


class DummyScan:
    def __init__(self, provider):
        self.provider = provider


class DummyAPIFinding:
    """
    A dummy API finding model to simulate the database model.
    Attributes will be added dynamically.
    """

    pass

class TestFinding:
    @patch(
        "prowler.lib.outputs.finding.get_check_compliance",
@@ -461,3 +521,566 @@ class TestFinding:
            # Generate the finding
            with pytest.raises(ValidationError):
                Finding.generate_output(provider, check_output, output_options)

    @patch(
        "prowler.lib.outputs.finding.get_check_compliance",
        new=mock_get_check_compliance,
    )
    def test_transform_api_finding_aws(self):
        """
        Test that a dummy API Finding is correctly
        transformed into a Finding instance.
        """
        # Set up the dummy API finding attributes
        inserted_at = 1234567890
        provider = DummyProvider(uid="account123")
        provider.type = "aws"
        scan = DummyScan(provider=provider)

        # Create a dummy resource with one tag
        tag = DummyTag("env", "prod")
        resource = DummyResource(
            uid="res-uid-1",
            name="ResourceName1",
            resource_arn="arn",
            region="us-east-1",
            tags=[tag],
        )
        resources = DummyResources(resource)

        # Create a dummy check_metadata dict with all required fields
        check_metadata = {
            "provider": "test_provider",
            "checkid": "check-001",
            "checktitle": "Test Check",
            "checktype": ["type1"],
            "servicename": "TestService",
            "subservicename": "SubService",
            "severity": "high",
            "resourcetype": "TestResource",
            "description": "A test check",
            "risk": "High risk",
            "relatedurl": "http://example.com",
            "remediation": {
                "recommendation": {"text": "Fix it", "url": "http://fix.com"},
                "code": {
                    "nativeiac": "iac_code",
                    "terraform": "terraform_code",
                    "cli": "cli_code",
                    "other": "other_code",
                },
            },
            "resourceidtemplate": "template",
            "categories": ["cat-one", "cat-two"],
            "dependson": ["dep1"],
            "relatedto": ["rel1"],
            "notes": "Some notes",
        }

        # Create the dummy API finding and assign required attributes
        dummy_finding = DummyAPIFinding()
        dummy_finding.inserted_at = inserted_at
        dummy_finding.scan = scan
        dummy_finding.uid = "finding-uid-1"
        dummy_finding.status = "FAIL"  # will be converted to Status("FAIL")
        dummy_finding.status_extended = "extended"
        dummy_finding.check_metadata = check_metadata
        dummy_finding.resources = resources

        # Call the transform_api_finding classmethod
        finding_obj = Finding.transform_api_finding(dummy_finding, provider)

        # Check that metadata was built correctly
        meta = finding_obj.metadata
        assert meta.Provider == "test_provider"
        assert meta.CheckID == "check-001"
        assert meta.CheckTitle == "Test Check"
        assert meta.CheckType == ["type1"]
        assert meta.ServiceName == "TestService"
        assert meta.SubServiceName == "SubService"
        assert meta.Severity == "high"
        assert meta.ResourceType == "TestResource"
        assert meta.Description == "A test check"
        assert meta.Risk == "High risk"
        assert meta.RelatedUrl == "http://example.com"
        assert meta.Remediation.Recommendation.Text == "Fix it"
        assert meta.Remediation.Recommendation.Url == "http://fix.com"
        assert meta.Remediation.Code.NativeIaC == "iac_code"
        assert meta.Remediation.Code.Terraform == "terraform_code"
        assert meta.Remediation.Code.CLI == "cli_code"
        assert meta.Remediation.Code.Other == "other_code"
        assert meta.ResourceIdTemplate == "template"
        assert meta.Categories == ["cat-one", "cat-two"]
        assert meta.DependsOn == ["dep1"]
        assert meta.RelatedTo == ["rel1"]
        assert meta.Notes == "Some notes"

        # Check other Finding fields
        assert finding_obj.uid == "prowler-aws-check-001--us-east-1-ResourceName1"
        assert finding_obj.status == Status("FAIL")
        assert finding_obj.status_extended == "extended"
        # From the dummy resource
        assert finding_obj.resource_uid == "res-uid-1"
        assert finding_obj.resource_name == "ResourceName1"
        assert finding_obj.resource_details == ""
        # unroll_tags is called on a list with one tag -> expect {"env": "prod"}
        assert finding_obj.resource_tags == {"env": "prod"}
        assert finding_obj.region == "us-east-1"
        assert finding_obj.compliance == {
            "mock_compliance_key": "mock_compliance_value"
        }

    @patch(
        "prowler.lib.outputs.finding.get_check_compliance",
        new=mock_get_check_compliance,
    )
    def test_transform_api_finding_azure(self):
        provider = MagicMock()
        provider.type = "azure"
        provider.identity.identity_type = "mock_identity_type"
        provider.identity.identity_id = "mock_identity_id"
        provider.identity.subscriptions = {"default": "default"}
        provider.identity.tenant_ids = ["test-ing-432a-a828-d9c965196f87"]
        provider.identity.tenant_domain = "mock_tenant_domain"
        provider.region_config.name = "AzureCloud"

        api_finding = DummyAPIFinding()
        api_finding.id = "019514b3-9a66-7cde-921e-9d1ca0531ceb"
        api_finding.inserted_at = "2025-02-17 16:17:49"
        api_finding.updated_at = "2025-02-17 16:17:49"
        api_finding.uid = (
            "prowler-azure-defender_auto_provisioning_log_analytics_agent_vms_on-"
            "test-ing-4646-bed4-e74f14020726-global-default"
        )
        api_finding.delta = "new"
        api_finding.status = "FAIL"
        api_finding.status_extended = "Defender Auto Provisioning Log Analytics Agents from subscription Azure subscription 1 is set to OFF."
        api_finding.severity = "medium"
        api_finding.impact = "medium"
        api_finding.impact_extended = ""
        api_finding.raw_result = {}
        api_finding.check_id = "defender_auto_provisioning_log_analytics_agent_vms_on"
        api_finding.check_metadata = {
            "risk": "Missing critical security information about your Azure VMs, such as security alerts, security recommendations, and change tracking.",
            "notes": "",
            "checkid": "defender_auto_provisioning_log_analytics_agent_vms_on",
            "provider": "azure",
            "severity": "medium",
            "checktype": [],
            "dependson": [],
            "relatedto": [],
            "categories": [],
            "checktitle": "Ensure that Auto provisioning of 'Log Analytics agent for Azure VMs' is Set to 'On'",
            "compliance": None,
            "relatedurl": "https://docs.microsoft.com/en-us/azure/security-center/security-center-data-security",
            "description": (
                "Ensure that Auto provisioning of 'Log Analytics agent for Azure VMs' is Set to 'On'. "
                "The Microsoft Monitoring Agent scans for various security-related configurations and events such as system updates, "
                "OS vulnerabilities, endpoint protection, and provides alerts."
            ),
            "remediation": {
                "code": {
                    "cli": "",
                    "other": "https://www.trendmicro.com/cloudoneconformity-staging/knowledge-base/azure/SecurityCenter/automatic-provisioning-of-monitoring-agent.html",
                    "nativeiac": "",
                    "terraform": "",
                },
                "recommendation": {
                    "url": "https://learn.microsoft.com/en-us/azure/defender-for-cloud/monitoring-components",
                    "text": (
                        "Ensure comprehensive visibility into possible security vulnerabilities, including missing updates, "
                        "misconfigured operating system security settings, and active threats, allowing for timely mitigation and improved overall security posture"
                    ),
                },
            },
            "servicename": "defender",
            "checkaliases": [],
            "resourcetype": "AzureDefenderPlan",
            "subservicename": "",
            "resourceidtemplate": "",
        }
        api_finding.tags = {}
        api_resource = DummyResource(
            uid="/subscriptions/test-ing-4646-bed4-e74f14020726/providers/Microsoft.Security/autoProvisioningSettings/default",
            name="default",
            resource_arn="arn",
            region="global",
            tags=[],
        )
        api_finding.resources = DummyResources(api_resource)
        api_finding.subscription = "default"
        finding_obj = Finding.transform_api_finding(api_finding, provider)

        assert finding_obj.account_organization_uid == "test-ing-432a-a828-d9c965196f87"
        assert finding_obj.account_organization_name == "mock_tenant_domain"
        assert finding_obj.resource_uid == api_resource.uid
        assert finding_obj.resource_name == api_resource.name
        assert finding_obj.region == api_resource.region
        assert finding_obj.resource_tags == {}
        assert finding_obj.compliance == {
            "mock_compliance_key": "mock_compliance_value"
        }

        assert finding_obj.status == Status("FAIL")
        assert finding_obj.status_extended == (
            "Defender Auto Provisioning Log Analytics Agents from subscription Azure subscription 1 is set to OFF."
        )

        meta = finding_obj.metadata
        assert meta.Provider == "azure"
        assert meta.CheckID == "defender_auto_provisioning_log_analytics_agent_vms_on"
        assert (
            meta.CheckTitle
            == "Ensure that Auto provisioning of 'Log Analytics agent for Azure VMs' is Set to 'On'"
        )
        assert meta.Severity == "medium"
        assert meta.ResourceType == "AzureDefenderPlan"
        assert (
            meta.Remediation.Recommendation.Url
            == "https://learn.microsoft.com/en-us/azure/defender-for-cloud/monitoring-components"
        )
        assert meta.Remediation.Recommendation.Text.startswith(
            "Ensure comprehensive visibility"
        )

        expected_segments = [
            "prowler-azure",
            "defender_auto_provisioning_log_analytics_agent_vms_on",
            api_resource.region,
            api_resource.name,
        ]
        for segment in expected_segments:
            assert segment in finding_obj.uid

    @patch(
        "prowler.lib.outputs.finding.get_check_compliance",
        new=mock_get_check_compliance,
    )
    def test_transform_api_finding_gcp(self):
        provider = MagicMock()
        provider.type = "gcp"
        provider.identity.profile = "gcp_profile"
        dummy_project = MagicMock()
        dummy_project.id = "project1"
        dummy_project.name = "TestProject"
        dummy_project.labels = {"env": "prod"}
        dummy_org = MagicMock()
        dummy_org.id = "org-123"
        dummy_org.display_name = "Test Org"
        dummy_project.organization = dummy_org
        provider.projects = {"project1": dummy_project}

        dummy_finding = DummyAPIFinding()
        dummy_finding.inserted_at = "2025-02-17 16:17:49"
        dummy_finding.updated_at = "2025-02-17 16:17:49"
        dummy_finding.scan = DummyScan(provider=provider)
        dummy_finding.uid = "finding-uid-gcp"
        dummy_finding.status = "PASS"
        dummy_finding.status_extended = "GCP check extended"
        check_metadata = {
            "provider": "gcp",
            "checkid": "gcp-check-001",
            "checktitle": "Test GCP Check",
            "checktype": [],
            "servicename": "TestGCPService",
            "subservicename": "",
            "severity": "medium",
            "resourcetype": "GCPResourceType",
            "description": "GCP check description",
            "risk": "Medium risk",
            "relatedurl": "http://gcp.example.com",
            "remediation": {
                "code": {
                    "nativeiac": "iac_code",
                    "terraform": "terraform_code",
                    "cli": "cli_code",
                    "other": "other_code",
                },
                "recommendation": {"text": "Fix it", "url": "http://fix-gcp.com"},
            },
            "resourceidtemplate": "template",
            "categories": ["cat-one", "cat-two"],
            "dependson": ["dep1"],
            "relatedto": ["rel1"],
            "notes": "Some notes",
        }
        dummy_finding.check_metadata = check_metadata
        dummy_finding.raw_result = {}
        dummy_finding.project_id = "project1"

        resource = DummyResource(
            uid="gcp-resource-uid",
            name="gcp-resource-name",
            resource_arn="arn",
            region="us-central1",
            tags=[],
        )
        dummy_finding.resources = DummyResources(resource)
        finding_obj = Finding.transform_api_finding(dummy_finding, provider)

        assert finding_obj.auth_method == "Principal: gcp_profile"
        assert finding_obj.account_uid == dummy_project.id
        assert finding_obj.account_name == dummy_project.name
        assert finding_obj.account_tags == dummy_project.labels
        assert finding_obj.resource_name == resource.name
        assert finding_obj.resource_uid == resource.uid
        assert finding_obj.region == resource.region
        assert finding_obj.account_organization_uid == dummy_project.organization.id
        assert (
            finding_obj.account_organization_name
            == dummy_project.organization.display_name
        )
        assert finding_obj.compliance == {
            "mock_compliance_key": "mock_compliance_value"
        }
        assert finding_obj.status == Status("PASS")
        assert finding_obj.status_extended == "GCP check extended"
        expected_uid = f"prowler-gcp-{check_metadata['checkid']}-{dummy_project.id}-{resource.region}-{resource.name}"
        assert finding_obj.uid == expected_uid

    @patch(
        "prowler.lib.outputs.finding.get_check_compliance",
        new=mock_get_check_compliance,
    )
    def test_transform_api_finding_kubernetes(self):
        provider = MagicMock()
        provider.type = "kubernetes"
        provider.identity.context = "In-Cluster"
        provider.identity.cluster = "cluster-1"
        api_finding = DummyAPIFinding()
        api_finding.inserted_at = 1234567890
        api_finding.scan = DummyScan(provider=provider)
        api_finding.uid = "finding-uid-k8s"
        api_finding.status = "PASS"
        api_finding.status_extended = "K8s check extended"
        check_metadata = {
            "provider": "kubernetes",
            "checkid": "k8s-check-001",
            "checktitle": "Test K8s Check",
            "checktype": [],
            "servicename": "TestK8sService",
            "subservicename": "",
            "severity": "low",
            "resourcetype": "K8sResourceType",
            "description": "K8s check description",
            "risk": "Low risk",
            "relatedurl": "http://k8s.example.com",
            "remediation": {
                "code": {
                    "nativeiac": "iac_code",
                    "terraform": "terraform_code",
                    "cli": "cli_code",
                    "other": "other_code",
                },
                "recommendation": {"text": "Fix it", "url": "http://fix-k8s.com"},
            },
            "resourceidtemplate": "template",
            "categories": ["cat-one"],
            "dependson": [],
            "relatedto": [],
            "notes": "K8s notes",
        }
        api_finding.check_metadata = check_metadata
        api_finding.raw_result = {}
        api_finding.resource_name = "k8s-resource-name"
        api_finding.resource_id = "k8s-resource-uid"
        resource = DummyResource(
            uid="k8s-resource-uid",
            name="k8s-resource-name",
            resource_arn="arn",
            region="",
            tags=[],
        )
        resource.region = "namespace: default"
        api_finding.resources = DummyResources(resource)
        finding_obj = Finding.transform_api_finding(api_finding, provider)
        assert finding_obj.auth_method == "in-cluster"
        assert finding_obj.resource_name == "k8s-resource-name"
        assert finding_obj.resource_uid == "k8s-resource-uid"
        assert finding_obj.account_name == "context: In-Cluster"
        assert finding_obj.account_uid == "cluster-1"
        assert finding_obj.region == "namespace: default"

    @patch(
        "prowler.lib.outputs.finding.get_check_compliance",
        new=mock_get_check_compliance,
    )
    def test_transform_api_finding_microsoft365(self):
        provider = MagicMock()
        provider.type = "microsoft365"
        provider.identity.identity_type = "ms_identity_type"
        provider.identity.identity_id = "ms_identity_id"
        provider.identity.tenant_id = "ms-tenant-id"
        provider.identity.tenant_domain = "ms-tenant-domain"
        dummy_finding = DummyAPIFinding()
        dummy_finding.inserted_at = 1234567890
        dummy_finding.scan = DummyScan(provider=provider)
        dummy_finding.uid = "finding-uid-m365"
        dummy_finding.status = "PASS"
        dummy_finding.status_extended = "M365 check extended"
        check_metadata = {
            "provider": "microsoft365",
            "checkid": "m365-check-001",
            "checktitle": "Test M365 Check",
            "checktype": [],
            "servicename": "TestM365Service",
            "subservicename": "",
            "severity": "high",
            "resourcetype": "M365ResourceType",
            "description": "M365 check description",
            "risk": "High risk",
            "relatedurl": "http://m365.example.com",
            "remediation": {
                "code": {
                    "nativeiac": "iac_code",
                    "terraform": "terraform_code",
                    "cli": "cli_code",
                    "other": "other_code",
                },
                "recommendation": {"text": "Fix it", "url": "http://fix-m365.com"},
            },
            "resourceidtemplate": "template",
            "categories": ["cat-one"],
            "dependson": [],
            "relatedto": [],
            "notes": "M365 notes",
        }
        dummy_finding.check_metadata = check_metadata
        dummy_finding.raw_result = {}
        dummy_finding.resource_name = "ms-resource-name"
        dummy_finding.resource_id = "ms-resource-uid"
        dummy_finding.location = "global"
        resource = DummyResource(
            uid="ms-resource-uid",
            name="ms-resource-name",
            resource_arn="arn",
            region="global",
            tags=[],
        )
        dummy_finding.resources = DummyResources(resource)
        finding_obj = Finding.transform_api_finding(dummy_finding, provider)
        assert finding_obj.auth_method == "ms_identity_type: ms_identity_id"
        assert finding_obj.account_uid == "ms-tenant-id"
        assert finding_obj.account_name == "ms-tenant-domain"
        assert finding_obj.resource_name == "ms-resource-name"
        assert finding_obj.resource_uid == "ms-resource-uid"
        assert finding_obj.region == "global"

    def test_transform_findings_stats_all_fails_muted(self):
        """
        Test _transform_findings_stats when every failing finding is muted.
        """
        # Create a dummy scan object with a unique_resource_count
        dummy_scan = SimpleNamespace(unique_resource_count=10)
        # Build summaries covering each severity branch.
        ss1 = SimpleNamespace(
            _pass=1, fail=2, total=3, muted=2, severity="critical", scan=dummy_scan
        )
        ss2 = SimpleNamespace(
            _pass=2, fail=0, total=2, muted=0, severity="high", scan=dummy_scan
        )
        ss3 = SimpleNamespace(
            _pass=2, fail=3, total=5, muted=3, severity="medium", scan=dummy_scan
        )
        ss4 = SimpleNamespace(
            _pass=3, fail=0, total=3, muted=0, severity="low", scan=dummy_scan
        )

        summaries = [ss1, ss2, ss3, ss4]
        stats = Finding._transform_findings_stats(summaries)

        # Expected calculations:
        # total_pass = 1+2+2+3 = 8
        # total_fail = 2+0+3+0 = 5
        # findings_count = 3+2+5+3 = 13
        # muted_pass = (ss1: 1) + (ss3: 2) = 3
        # muted_fail = (ss1: 2) + (ss3: 3) = 5
        expected = {
            "total_pass": 8,
            "total_muted_pass": 3,
            "total_fail": 5,
            "total_muted_fail": 5,
            "resources_count": 10,
            "findings_count": 13,
            "total_critical_severity_fail": 2,
            "total_critical_severity_pass": 1,
            "total_high_severity_fail": 0,
            "total_high_severity_pass": 2,
            "total_medium_severity_fail": 3,
            "total_medium_severity_pass": 2,
            "total_low_severity_fail": 0,
            "total_low_severity_pass": 3,
            "all_fails_are_muted": True,  # total_fail equals muted_fail and total_fail > 0
        }
        assert stats == expected

    def test_transform_findings_stats_not_all_fails_muted(self):
        """
        Test _transform_findings_stats when at least one failing finding is not muted.
        """
        dummy_scan = SimpleNamespace(unique_resource_count=5)
        # Build summaries: one summary has fail > 0 but muted == 0
        ss1 = SimpleNamespace(
            _pass=1, fail=2, total=3, muted=0, severity="critical", scan=dummy_scan
        )
        ss2 = SimpleNamespace(
            _pass=2, fail=1, total=3, muted=1, severity="high", scan=dummy_scan
        )
        summaries = [ss1, ss2]
        stats = Finding._transform_findings_stats(summaries)

        # Expected calculations:
        # total_pass = 1+2 = 3
        # total_fail = 2+1 = 3
        # findings_count = 3+3 = 6
        # muted_pass = (ss2: 2) since ss1 muted is 0
        # muted_fail = (ss2: 1)
        # Severity breakdown: critical: pass 1, fail 2; high: pass 2, fail 1
        expected = {
            "total_pass": 3,
            "total_muted_pass": 2,
            "total_fail": 3,
            "total_muted_fail": 1,
            "resources_count": 5,
            "findings_count": 6,
            "total_critical_severity_fail": 2,
            "total_critical_severity_pass": 1,
            "total_high_severity_fail": 1,
            "total_high_severity_pass": 2,
            "total_medium_severity_fail": 0,
            "total_medium_severity_pass": 0,
            "total_low_severity_fail": 0,
            "total_low_severity_pass": 0,
            "all_fails_are_muted": False,  # 3 (total_fail) != 1 (muted_fail)
        }
        assert stats == expected

    def test_transform_api_finding_validation_error(self):
        """
        Test that if required data is missing (here, an empty metadata dict
        that raises a KeyError) the function logs the error and re-raises
        the exception.
        """
        provider = DummyProvider(uid="account123")
        # Create a dummy API finding that is missing some required metadata
        dummy_finding = DummyAPIFinding()
        dummy_finding.inserted_at = 1234567890
        dummy_finding.scan = DummyScan(provider=provider)
        dummy_finding.uid = "finding-uid-invalid"
        dummy_finding.status = "PASS"
        dummy_finding.status_extended = "extended"
        # Missing required metadata keys – using an empty dict
        dummy_finding.check_metadata = {}
        # Provide dummy resources with a minimal resource
        tag = DummyTag("env", "prod")
        resource = DummyResource(
            uid="res-uid-1",
            name="ResourceName1",
            resource_arn="arn",
            region="us-east-1",
            tags=[tag],
        )
        dummy_finding.resources = DummyResources(resource)

        with pytest.raises(KeyError):
            Finding.transform_api_finding(dummy_finding, provider)

@@ -492,7 +492,7 @@ class TestHTML:
        assert content == get_aws_html_header(args) + pass_html_finding + html_footer

    def test_batch_write_data_to_file_without_findings(self):
        assert not hasattr(HTML([]), "_file_descriptor")
        assert not HTML([])._file_descriptor

    def test_write_header(self):
        mock_file = StringIO()
@@ -256,7 +256,7 @@ class TestOCSF:
        assert json.loads(content) == expected_json_output

    def test_batch_write_data_to_file_without_findings(self):
        assert not hasattr(OCSF([]), "_file_descriptor")
        assert not OCSF([])._file_descriptor

    def test_finding_output_cloud_pass_low_muted(self):
        finding_output = generate_finding_output(
@@ -113,7 +113,6 @@ class TestS3:
            csv_file = f"test{extension}"
            csv = CSV(
                findings=[FINDING],
                create_file_descriptor=True,
                file_path=f"{CURRENT_DIRECTORY}/{csv_file}",
            )