Compare commits

...

6 Commits

Author SHA1 Message Date
Adrián Jesús Peña Rodríguez
4a14e21553 Merge branch 'master' into PRWLR-5956-Export-Artifacts 2025-02-04 10:31:50 +01:00
Adrián Jesús Peña Rodríguez
0995c7a845 fix: scan unittests 2025-01-31 11:30:09 +01:00
Adrián Jesús Peña Rodríguez
63f8186bd6 fix: resolve ruff errors 2025-01-30 12:46:53 +01:00
Adrián Jesús Peña Rodríguez
d70c71c903 chore: findings fix restore 2025-01-30 12:18:51 +01:00
Adrián Jesús Peña Rodríguez
53c571c289 Merge branch 'master' into PRWLR-5956-Export-Artifacts 2025-01-30 11:59:54 +01:00
Adrián Jesús Peña Rodríguez
a3c7846cd9 feat(report): add export system and its endpoint 2025-01-30 11:54:44 +01:00
11 changed files with 760 additions and 455 deletions

17
.env
View File

@@ -30,6 +30,23 @@ VALKEY_HOST=valkey
VALKEY_PORT=6379
VALKEY_DB=0
# API scan settings
# The AWS access key to be used when uploading scan artifacts to an S3 bucket
# If left empty, default AWS credentials resolution behavior will be used
ARTIFACTS_AWS_ACCESS_KEY_ID=""
# The AWS secret key to be used when uploading scan artifacts to an S3 bucket
ARTIFACTS_AWS_SECRET_ACCESS_KEY=""
# An optional AWS session token
ARTIFACTS_AWS_SESSION_TOKEN=""
# The AWS region where your S3 bucket is located (e.g., "us-east-1")
ARTIFACTS_AWS_DEFAULT_REGION=""
# The name of the S3 bucket where scan artifacts should be stored
ARTIFACTS_AWS_S3_OUTPUT_BUCKET=""
# Django settings
DJANGO_ALLOWED_HOSTS=localhost,127.0.0.1,prowler-api
DJANGO_BIND_ADDRESS=0.0.0.0

View File

@@ -4093,6 +4093,38 @@ paths:
schema:
$ref: '#/components/schemas/ScanUpdateResponse'
description: ''
/api/v1/scans/{id}/report:
get:
operationId: scans_report_retrieve
description: Returns a ZIP file containing the requested report
summary: Download ZIP report
parameters:
- in: query
name: fields[scans]
schema:
type: array
items:
type: string
enum: []
description: endpoint return only specific fields in the response on a per-type
basis by including a fields[TYPE] query parameter.
explode: false
- in: path
name: id
schema:
type: string
format: uuid
description: A UUID string identifying this scan.
required: true
tags:
- Scan
security:
- jwtAuth: []
responses:
'200':
description: Report obtanined successfully
'404':
description: Report not found
/api/v1/schedules/daily:
post:
operationId: schedules_daily_create

View File

@@ -819,6 +819,14 @@ class ScanTaskSerializer(RLSSerializer):
]
class ScanReportSerializer(RLSSerializer):
class Meta:
model = Scan
fields = [
"id",
]
class ResourceTagSerializer(RLSSerializer):
"""
Serializer for the ResourceTag model

View File

@@ -1,9 +1,16 @@
import glob
import os
import boto3
from botocore.exceptions import ClientError, NoCredentialsError, ParamValidationError
from celery.result import AsyncResult
from config.env import env
from django.conf import settings as django_settings
from django.contrib.postgres.aggregates import ArrayAgg
from django.contrib.postgres.search import SearchQuery
from django.db import transaction
from django.db.models import Count, F, OuterRef, Prefetch, Q, Subquery, Sum
from django.http import HttpResponse
from django.db.models.functions import Coalesce
from django.urls import reverse
from django.utils.decorators import method_decorator
@@ -114,6 +121,7 @@ from api.v1.serializers import (
RoleSerializer,
RoleUpdateSerializer,
ScanCreateSerializer,
ScanReportSerializer,
ScanSerializer,
ScanUpdateSerializer,
ScheduleDailyCreateSerializer,
@@ -126,6 +134,7 @@ from api.v1.serializers import (
UserSerializer,
UserUpdateSerializer,
)
from prowler.config.config import tmp_output_directory
CACHE_DECORATOR = cache_control(
max_age=django_settings.CACHE_MAX_AGE,
@@ -1073,6 +1082,8 @@ class ScanViewSet(BaseRLSViewSet):
return ScanCreateSerializer
elif self.action == "partial_update":
return ScanUpdateSerializer
elif self.action == "report":
return ScanReportSerializer
return super().get_serializer_class()
def partial_update(self, request, *args, **kwargs):
@@ -1127,6 +1138,101 @@ class ScanViewSet(BaseRLSViewSet):
},
)
@extend_schema(
tags=["Scan"],
summary="Download ZIP report",
description="Returns a ZIP file containing the requested report",
request=ScanReportSerializer,
responses={
200: OpenApiResponse(description="Report obtanined successfully"),
404: OpenApiResponse(description="Report not found"),
},
)
@action(detail=True, methods=["get"], url_name="report")
def report(self, request, pk=None):
s3_client = None
try:
s3_client = boto3.client("s3")
s3_client.list_buckets()
except (ClientError, NoCredentialsError, ParamValidationError):
try:
s3_client = boto3.client(
"s3",
aws_access_key_id=env.str("ARTIFACTS_AWS_ACCESS_KEY_ID"),
aws_secret_access_key=env.str("ARTIFACTS_AWS_SECRET_ACCESS_KEY"),
aws_session_token=env.str("ARTIFACTS_AWS_SESSION_TOKEN"),
region_name=env.str("ARTIFACTS_AWS_DEFAULT_REGION"),
)
s3_client.list_buckets()
except (ClientError, NoCredentialsError, ParamValidationError):
s3_client = None
if s3_client:
bucket_name = env.str("ARTIFACTS_AWS_S3_OUTPUT_BUCKET")
s3_prefix = f"{request.tenant_id}/{pk}/"
try:
response = s3_client.list_objects_v2(
Bucket=bucket_name, Prefix=s3_prefix
)
if response["KeyCount"] == 0:
return Response(
{"detail": "No files found in S3 storage"},
status=status.HTTP_404_NOT_FOUND,
)
zip_files = [
obj["Key"]
for obj in response.get("Contents", [])
if obj["Key"].endswith(".zip")
]
if not zip_files:
return Response(
{"detail": "No ZIP files found in S3 storage"},
status=status.HTTP_404_NOT_FOUND,
)
s3_key = zip_files[0]
s3_object = s3_client.get_object(Bucket=bucket_name, Key=s3_key)
file_content = s3_object["Body"].read()
filename = os.path.basename(s3_key)
except ClientError:
return Response(
{"detail": "Error accessing cloud storage"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
else:
local_path = os.path.join(
tmp_output_directory,
str(request.tenant_id),
str(pk),
"*.zip",
)
zip_files = glob.glob(local_path)
if not zip_files:
return Response(
{"detail": "No local files found"}, status=status.HTTP_404_NOT_FOUND
)
try:
file_path = zip_files[0]
with open(file_path, "rb") as f:
file_content = f.read()
filename = os.path.basename(file_path)
except IOError:
return Response(
{"detail": "Error reading local file"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
response = HttpResponse(
file_content, content_type="application/x-zip-compressed"
)
response["Content-Disposition"] = f'attachment; filename="{filename}"'
return response
@extend_schema_view(
list=extend_schema(

View File

@@ -1,8 +1,13 @@
import os
import time
import zipfile
from copy import deepcopy
from datetime import datetime, timezone
import boto3
from botocore.exceptions import ClientError
from celery.utils.log import get_task_logger
from config.env import env
from config.settings.celery import CELERY_DEADLOCK_ATTEMPTS
from django.db import IntegrityError, OperationalError
from django.db.models import Case, Count, IntegerField, Sum, When
@@ -25,15 +30,117 @@ from api.models import (
from api.models import StatusChoices as FindingStatus
from api.utils import initialize_prowler_provider
from api.v1.serializers import ScanTaskSerializer
from prowler.config.config import (
csv_file_suffix,
html_file_suffix,
json_asff_file_suffix,
json_ocsf_file_suffix,
output_file_timestamp,
tmp_output_directory,
)
from prowler.lib.outputs.asff.asff import ASFF
from prowler.lib.outputs.csv.csv import CSV
from prowler.lib.outputs.finding import Finding as ProwlerFinding
from prowler.lib.outputs.html.html import HTML
from prowler.lib.outputs.ocsf.ocsf import OCSF
from prowler.lib.scan.scan import Scan as ProwlerScan
logger = get_task_logger(__name__)
# Predefined mapping for output formats and their configurations
OUTPUT_FORMATS_MAPPING = {
"csv": {
"class": CSV,
"suffix": csv_file_suffix,
"kwargs": {},
},
"json-asff": {"class": ASFF, "suffix": json_asff_file_suffix, "kwargs": {}},
"json-ocsf": {"class": OCSF, "suffix": json_ocsf_file_suffix, "kwargs": {}},
"html": {"class": HTML, "suffix": html_file_suffix, "kwargs": {"stats": {}}},
}
# Mapping provider types to their identity components for output paths
PROVIDER_IDENTITY_MAP = {
"aws": lambda p: p.identity.account,
"azure": lambda p: p.identity.tenant_domain,
"gcp": lambda p: p.identity.profile,
"kubernetes": lambda p: p.identity.context.replace(":", "_").replace("/", "_"),
}
def _compress_output_files(output_directory: str) -> str:
"""
Compress output files from all configured output formats into a ZIP archive.
Args:
output_directory (str): The directory where the output files are located.
The function looks up all known suffixes in OUTPUT_FORMATS_MAPPING
and compresses those files into a single ZIP.
Returns:
str: The full path to the newly created ZIP archive.
"""
zip_path = f"{output_directory}.zip"
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
for suffix in [config["suffix"] for config in OUTPUT_FORMATS_MAPPING.values()]:
zipf.write(
f"{output_directory}{suffix}",
f"artifacts/{output_directory.split('/')[-1]}{suffix}",
)
return zip_path
def _upload_to_s3(tenant_id: str, zip_path: str, scan_id: str) -> str:
"""
Upload the specified ZIP file to an S3 bucket.
If the S3 bucket environment variables are not configured,
the function returns None without performing an upload.
Args:
tenant_id (str): The tenant identifier, used as part of the S3 key prefix.
zip_path (str): The local file system path to the ZIP file to be uploaded.
scan_id (str): The scan identifier, used as part of the S3 key prefix.
Returns:
str: The S3 URI of the uploaded file (e.g., "s3://<bucket>/<key>") if successful.
None: If the required environment variables for the S3 bucket are not set.
Raises:
botocore.exceptions.ClientError: If the upload attempt to S3 fails for any reason.
"""
if not env.str("ARTIFACTS_AWS_S3_OUTPUT_BUCKET", ""):
return
if env.str("ARTIFACTS_AWS_ACCESS_KEY_ID", ""):
s3 = boto3.client(
"s3",
aws_access_key_id=env.str("ARTIFACTS_AWS_ACCESS_KEY_ID"),
aws_secret_access_key=env.str("ARTIFACTS_AWS_SECRET_ACCESS_KEY"),
aws_session_token=env.str("ARTIFACTS_AWS_SESSION_TOKEN"),
region_name=env.str("ARTIFACTS_AWS_DEFAULT_REGION"),
)
else:
s3 = boto3.client("s3")
s3_key = f"{tenant_id}/{scan_id}/{os.path.basename(zip_path)}"
try:
s3.upload_file(
Filename=zip_path,
Bucket=env.str("ARTIFACTS_AWS_S3_OUTPUT_BUCKET"),
Key=s3_key,
)
return f"s3://{env.str("ARTIFACTS_AWS_S3_OUTPUT_BUCKET")}/{s3_key}"
except ClientError as e:
logger.error(f"S3 upload failed: {str(e)}")
raise e
def _create_finding_delta(
last_status: FindingStatus | None | str, new_status: FindingStatus | None
) -> Finding.DeltaChoices:
) -> Finding.DeltaChoices | None:
"""
Determine the delta status of a finding based on its previous and current status.
@@ -53,7 +160,11 @@ def _create_finding_delta(
def _store_resources(
finding: ProwlerFinding, tenant_id: str, provider_instance: Provider
finding: ProwlerFinding,
tenant_id: str,
provider_instance: Provider,
resource_cache: dict,
tag_cache: dict,
) -> tuple[Resource, tuple[str, str]]:
"""
Store resource information from a finding, including tags, in the database.
@@ -65,40 +176,91 @@ def _store_resources(
Returns:
tuple:
- Resource: The resource instance created or retrieved from the database.
- Resource: The resource instance created or updated from the database.
- tuple[str, str]: A tuple containing the resource UID and region.
"""
with rls_transaction(tenant_id):
resource_instance, created = Resource.objects.get_or_create(
tenant_id=tenant_id,
provider=provider_instance,
uid=finding.resource_uid,
defaults={
"region": finding.region,
"service": finding.service_name,
"type": finding.resource_type,
},
)
resource_uid = finding.resource_uid
# Check cache or create/update resource
if resource_uid in resource_cache:
resource_instance = resource_cache[resource_uid]
update_fields = []
for field, value in [
("region", finding.region),
("service", finding.service_name),
("type", finding.resource_type),
("name", finding.resource_name),
]:
if getattr(resource_instance, field) != value:
setattr(resource_instance, field, value)
update_fields.append(field)
if update_fields:
with rls_transaction(tenant_id):
resource_instance.save(update_fields=update_fields)
else:
with rls_transaction(tenant_id):
resource_instance, _ = Resource.objects.update_or_create(
tenant_id=tenant_id,
provider=provider_instance,
uid=resource_uid,
defaults={
"region": finding.region,
"service": finding.service_name,
"type": finding.resource_type,
"name": finding.resource_name,
},
)
resource_cache[resource_uid] = resource_instance
# Process tags with caching
tags = []
for key, value in finding.resource_tags.items():
tag_key = (key, value)
if tag_key not in tag_cache:
with rls_transaction(tenant_id):
tag_instance, _ = ResourceTag.objects.get_or_create(
tenant_id=tenant_id, key=key, value=value
)
tag_cache[tag_key] = tag_instance
tags.append(tag_cache[tag_key])
if not created:
resource_instance.region = finding.region
resource_instance.service = finding.service_name
resource_instance.type = finding.resource_type
resource_instance.save()
with rls_transaction(tenant_id):
tags = [
ResourceTag.objects.get_or_create(
tenant_id=tenant_id, key=key, value=value
)[0]
for key, value in finding.resource_tags.items()
]
resource_instance.upsert_or_delete_tags(tags=tags)
return resource_instance, (resource_instance.uid, resource_instance.region)
def _generate_output_directory(
prowler_provider: object, tenant_id: str, scan_id: str
) -> str:
"""
Generate a dynamic output directory path based on the given provider type.
Args:
prowler_provider (object): An object that has a `type` attribute indicating
the provider type (e.g., "aws", "azure", etc.).
tenant_id (str): A unique identifier for the tenant. Used to build the output path.
scan_id (str): A unique identifier for the scan. Included in the output path.
Returns:
str: The complete path to the output directory, including the tenant ID, scan ID,
provider identity, and a timestamp.
"""
provider_type = prowler_provider.type
get_identity = PROVIDER_IDENTITY_MAP.get(provider_type, lambda _: "unknown")
return (
f"{tmp_output_directory}/{tenant_id}/{scan_id}/prowler-output-"
f"{get_identity(prowler_provider)}-{output_file_timestamp}"
)
def perform_prowler_scan(
tenant_id: str, scan_id: str, provider_id: str, checks_to_execute: list[str] = None
tenant_id: str,
scan_id: str,
provider_id: str,
checks_to_execute: list[str] = None,
):
"""
Perform a scan using Prowler and store the findings and resources in the database.
@@ -120,6 +282,9 @@ def perform_prowler_scan(
exception = None
unique_resources = set()
start_time = time.time()
resource_cache = {}
tag_cache = {}
last_status_cache = {}
with rls_transaction(tenant_id):
provider_instance = Provider.objects.get(pk=provider_id)
@@ -129,6 +294,7 @@ def perform_prowler_scan(
scan_instance.save()
try:
# Provider initialization
with rls_transaction(tenant_id):
try:
prowler_provider = initialize_prowler_provider(provider_instance)
@@ -144,116 +310,66 @@ def perform_prowler_scan(
)
provider_instance.save()
prowler_scan = ProwlerScan(provider=prowler_provider, checks=checks_to_execute)
# Scan configuration
prowler_scan = ProwlerScan(
provider=prowler_provider, checks=checks_to_execute or []
)
output_directory = _generate_output_directory(
prowler_provider, tenant_id, scan_id
)
# Create the output directory
os.makedirs("/".join(output_directory.split("/")[:-1]), exist_ok=True)
resource_cache = {}
tag_cache = {}
last_status_cache = {}
for progress, findings in prowler_scan.scan():
all_findings = []
for progress, findings, stats in prowler_scan.scan():
for finding in findings:
if finding is None:
logger.error(f"None finding detected on scan {scan_id}.")
continue
for attempt in range(CELERY_DEADLOCK_ATTEMPTS):
try:
with rls_transaction(tenant_id):
# Process resource
resource_uid = finding.resource_uid
if resource_uid not in resource_cache:
# Get or create the resource
resource_instance, _ = Resource.objects.get_or_create(
tenant_id=tenant_id,
provider=provider_instance,
uid=resource_uid,
defaults={
"region": finding.region,
"service": finding.service_name,
"type": finding.resource_type,
"name": finding.resource_name,
},
)
resource_cache[resource_uid] = resource_instance
else:
resource_instance = resource_cache[resource_uid]
# Update resource fields if necessary
updated_fields = []
if (
finding.region
and resource_instance.region != finding.region
):
resource_instance.region = finding.region
updated_fields.append("region")
if resource_instance.service != finding.service_name:
resource_instance.service = finding.service_name
updated_fields.append("service")
if resource_instance.type != finding.resource_type:
resource_instance.type = finding.resource_type
updated_fields.append("type")
if updated_fields:
with rls_transaction(tenant_id):
resource_instance.save(update_fields=updated_fields)
resource_instance, resource_uid_region = _store_resources(
finding,
tenant_id,
provider_instance,
resource_cache,
tag_cache,
)
unique_resources.add(resource_uid_region)
break
except (OperationalError, IntegrityError) as db_err:
if attempt < CELERY_DEADLOCK_ATTEMPTS - 1:
logger.warning(
f"{'Deadlock error' if isinstance(db_err, OperationalError) else 'Integrity error'} "
f"detected when processing resource {resource_uid} on scan {scan_id}. Retrying..."
f"Database error ({type(db_err).__name__}) "
f"processing resource {finding.resource_uid}, retrying..."
)
time.sleep(0.1 * (2**attempt))
continue
else:
raise db_err
# Update tags
tags = []
with rls_transaction(tenant_id):
for key, value in finding.resource_tags.items():
tag_key = (key, value)
if tag_key not in tag_cache:
tag_instance, _ = ResourceTag.objects.get_or_create(
tenant_id=tenant_id, key=key, value=value
)
tag_cache[tag_key] = tag_instance
else:
tag_instance = tag_cache[tag_key]
tags.append(tag_instance)
resource_instance.upsert_or_delete_tags(tags=tags)
unique_resources.add((resource_instance.uid, resource_instance.region))
# Process finding
# Finding processing
with rls_transaction(tenant_id):
finding_uid = finding.uid
last_first_seen_at = None
if finding_uid not in last_status_cache:
most_recent_finding = (
Finding.all_objects.filter(
tenant_id=tenant_id, uid=finding_uid
)
most_recent = (
Finding.objects.filter(uid=finding_uid)
.order_by("-inserted_at")
.values("status", "first_seen_at")
.first()
)
last_status = None
if most_recent_finding:
last_status = most_recent_finding["status"]
last_first_seen_at = most_recent_finding["first_seen_at"]
last_status_cache[finding_uid] = last_status, last_first_seen_at
last_status, first_seen = (
(most_recent["status"], most_recent["first_seen_at"])
if most_recent
else (None, None)
)
last_status_cache[finding_uid] = (last_status, first_seen)
else:
last_status, last_first_seen_at = last_status_cache[finding_uid]
last_status, first_seen = last_status_cache[finding_uid]
status = FindingStatus[finding.status]
delta = _create_finding_delta(last_status, status)
# For the findings prior to the change, when a first finding is found with delta!="new" it will be
# assigned a current date as first_seen_at and the successive findings with the same UID will
# always get the date of the previous finding.
# For new findings, when a finding (delta="new") is found for the first time, the first_seen_at
# attribute will be assigned the current date, the following findings will get that date.
if not last_first_seen_at:
last_first_seen_at = datetime.now(tz=timezone.utc)
first_seen = first_seen or datetime.now(tz=timezone.utc)
# Create the finding
finding_instance = Finding.objects.create(
tenant_id=tenant_id,
uid=finding_uid,
@@ -266,92 +382,96 @@ def perform_prowler_scan(
raw_result=finding.raw,
check_id=finding.check_id,
scan=scan_instance,
first_seen_at=last_first_seen_at,
first_seen_at=first_seen,
)
finding_instance.add_resources([resource_instance])
# Update compliance data if applicable
if finding.status.value == "MUTED":
continue
# Update compliance status
if finding.status.value != "MUTED":
region_data = check_status_by_region.setdefault(finding.region, {})
if region_data.get(finding.check_id) != "FAIL":
region_data[finding.check_id] = finding.status.value
region_dict = check_status_by_region.setdefault(finding.region, {})
current_status = region_dict.get(finding.check_id)
if current_status == "FAIL":
continue
region_dict[finding.check_id] = finding.status.value
# Update scan progress
# Progress updates and output generation
with rls_transaction(tenant_id):
scan_instance.progress = progress
scan_instance.save()
all_findings.extend(findings)
# Generate output files
for mode, config in OUTPUT_FORMATS_MAPPING.items():
kwargs = dict(config["kwargs"])
if mode == "html":
kwargs["provider"] = prowler_provider
kwargs["stats"] = stats
config["class"](
findings=all_findings,
create_file_descriptor=True,
file_path=output_directory,
file_extension=config["suffix"],
).batch_write_data_to_file(**kwargs)
scan_instance.state = StateChoices.COMPLETED
# Compress output files
zip_path = _compress_output_files(output_directory)
# Save to configured storage
_upload_to_s3(tenant_id, zip_path, scan_id)
except Exception as e:
logger.error(f"Error performing scan {scan_id}: {e}")
logger.error(f"Scan {scan_id} failed: {str(e)}")
exception = e
scan_instance.state = StateChoices.FAILED
finally:
# Final scan updates
with rls_transaction(tenant_id):
scan_instance.duration = time.time() - start_time
scan_instance.completed_at = datetime.now(tz=timezone.utc)
scan_instance.unique_resource_count = len(unique_resources)
scan_instance.save()
if exception is None:
try:
regions = prowler_provider.get_regions()
except AttributeError:
regions = set()
# Compliance processing
if not exception:
compliance_template = PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE[
provider_instance.provider
]
compliance_overview_by_region = {
region: deepcopy(compliance_template) for region in regions
compliance_overview = {
region: deepcopy(compliance_template)
for region in getattr(prowler_provider, "get_regions", lambda: set())()
}
for region, check_status in check_status_by_region.items():
compliance_data = compliance_overview_by_region.setdefault(
region, deepcopy(compliance_template)
)
for check_name, status in check_status.items():
for region, checks in check_status_by_region.items():
for check_id, status in checks.items():
generate_scan_compliance(
compliance_data,
compliance_overview.setdefault(
region, deepcopy(compliance_template)
),
provider_instance.provider,
check_name,
check_id,
status,
)
# Prepare compliance overview objects
compliance_overview_objects = []
for region, compliance_data in compliance_overview_by_region.items():
for compliance_id, compliance in compliance_data.items():
compliance_overview_objects.append(
ComplianceOverview(
tenant_id=tenant_id,
scan=scan_instance,
region=region,
compliance_id=compliance_id,
framework=compliance["framework"],
version=compliance["version"],
description=compliance["description"],
requirements=compliance["requirements"],
requirements_passed=compliance["requirements_status"]["passed"],
requirements_failed=compliance["requirements_status"]["failed"],
requirements_manual=compliance["requirements_status"]["manual"],
total_requirements=compliance["total_requirements"],
)
ComplianceOverview.objects.bulk_create(
[
ComplianceOverview(
tenant_id=tenant_id,
scan=scan_instance,
region=region,
compliance_id=compliance_id,
**compliance_data,
)
with rls_transaction(tenant_id):
ComplianceOverview.objects.bulk_create(compliance_overview_objects)
for region, data in compliance_overview.items()
for compliance_id, compliance_data in data.items()
]
)
if exception is not None:
if exception:
raise exception
serializer = ScanTaskSerializer(instance=scan_instance)
return serializer.data
return ScanTaskSerializer(instance=scan_instance).data
def aggregate_findings(tenant_id: str, scan_id: str):
@@ -478,29 +598,28 @@ def aggregate_findings(tenant_id: str, scan_id: str):
),
)
with rls_transaction(tenant_id):
scan_aggregations = {
ScanSummary(
tenant_id=tenant_id,
scan_id=scan_id,
check_id=agg["check_id"],
service=agg["resources__service"],
severity=agg["severity"],
region=agg["resources__region"],
fail=agg["fail"],
_pass=agg["_pass"],
muted=agg["muted"],
total=agg["total"],
new=agg["new"],
changed=agg["changed"],
unchanged=agg["unchanged"],
fail_new=agg["fail_new"],
fail_changed=agg["fail_changed"],
pass_new=agg["pass_new"],
pass_changed=agg["pass_changed"],
muted_new=agg["muted_new"],
muted_changed=agg["muted_changed"],
)
for agg in aggregation
}
ScanSummary.objects.bulk_create(scan_aggregations, batch_size=3000)
ScanSummary.objects.bulk_create(
[
ScanSummary(
tenant_id=tenant_id,
scan_id=scan_id,
check_id=agg["check_id"],
service=agg["resources__service"],
severity=agg["severity"],
region=agg["resources__region"],
**{
k: v or 0
for k, v in agg.items()
if k
not in {
"check_id",
"resources__service",
"severity",
"resources__region",
}
},
)
for agg in aggregation
],
batch_size=3000,
)

View File

@@ -16,6 +16,7 @@ services:
volumes:
- "./api/src/backend:/home/prowler/backend"
- "./api/pyproject.toml:/home/prowler/pyproject.toml"
- "/tmp/prowler_api_output:/tmp/prowler_api_output"
depends_on:
postgres:
condition: service_healthy
@@ -85,6 +86,8 @@ services:
env_file:
- path: .env
required: false
volumes:
- "/tmp/prowler_api_output:/tmp/prowler_api_output"
depends_on:
valkey:
condition: service_healthy

View File

@@ -59,6 +59,7 @@ aws_services_json_file = "aws_regions_by_service.json"
# gcp_zones_json_file = "gcp_zones.json"
default_output_directory = getcwd() + "/output"
tmp_output_directory = "/tmp/prowler_api_output"
output_file_timestamp = timestamp.strftime("%Y%m%d%H%M%S")
timestamp_iso = timestamp.isoformat(sep=" ", timespec="seconds")
csv_file_suffix = ".csv"

View File

@@ -139,13 +139,9 @@ def remove_custom_checks_module(input_folder: str, provider: str):
def list_services(provider: str) -> set:
available_services = set()
checks_tuple = recover_checks_from_provider(provider)
split_character = "\\" if os.name == "nt" else "/"
for _, check_path in checks_tuple:
# Format: /absolute_path/prowler/providers/{provider}/services/{service_name}/{check_name}
if os.name == "nt":
service_name = check_path.split("\\")[-2]
else:
service_name = check_path.split("/")[-2]
available_services.add(service_name)
available_services.add(check_path.split(split_character)[-2])
return sorted(available_services)

View File

@@ -38,16 +38,18 @@ class Output(ABC):
file_extension: str = "",
) -> None:
self._data = []
self.file_path = file_path
if not file_extension and file_path:
self._file_extension = "".join(Path(file_path).suffixes)
if file_extension:
self._file_extension = file_extension
self.file_path = f"{file_path}{self.file_extension}"
if findings:
self.transform(findings)
if create_file_descriptor and file_path:
self.create_file_descriptor(file_path)
self.create_file_descriptor(self.file_path)
@property
def data(self):

View File

@@ -1,5 +1,5 @@
import datetime
from typing import Generator
from typing import Dict, Generator, List, Optional, Set
from prowler.lib.check.check import (
execute,
@@ -14,6 +14,7 @@ from prowler.lib.check.models import CheckMetadata, Severity
from prowler.lib.logger import logger
from prowler.lib.outputs.common import Status
from prowler.lib.outputs.finding import Finding
from prowler.lib.outputs.outputs import extract_findings_statistics
from prowler.lib.scan.exceptions.exceptions import (
ScanInvalidCategoryError,
ScanInvalidCheckError,
@@ -28,28 +29,25 @@ from prowler.providers.common.provider import Provider
class Scan:
_provider: Provider
# Refactor(Core): This should replace the Audit_Metadata
_number_of_checks_to_execute: int = 0
_number_of_checks_completed: int = 0
# TODO the str should be a set of Check objects
_checks_to_execute: list[str]
_service_checks_to_execute: dict[str, set[str]]
_service_checks_completed: dict[str, set[str]]
_checks_to_execute: List[str]
_service_checks_map: Dict[str, Set[str]]
_completed_checks: Set[str]
_progress: float = 0.0
_findings: List[Finding] = []
_duration: int = 0
_status: list[str] = None
_statuses: Optional[List[Status]] = None
def __init__(
self,
provider: Provider,
checks: list[str] = None,
services: list[str] = None,
compliances: list[str] = None,
categories: list[str] = None,
severities: list[str] = None,
excluded_checks: list[str] = None,
excluded_services: list[str] = None,
status: list[str] = None,
checks: Optional[List[str]] = None,
services: Optional[List[str]] = None,
compliances: Optional[List[str]] = None,
categories: Optional[List[str]] = None,
severities: Optional[List[str]] = None,
excluded_checks: Optional[List[str]] = None,
excluded_services: Optional[List[str]] = None,
status: Optional[List[str]] = None,
):
"""
Scan is the class that executes the checks and yields the progress and the findings.
@@ -74,151 +72,190 @@ class Scan:
ScanInvalidStatusError: If the status does not exist in the provider.
"""
self._provider = provider
self._statuses = self._validate_statuses(status)
# Validate the status
if status:
try:
for s in status:
Status(s)
if not self._status:
self._status = []
if s not in self._status:
self._status.append(s)
except ValueError:
raise ScanInvalidStatusError(f"Invalid status provided: {s}.")
# Load bulk compliance frameworks
bulk_compliance_frameworks = Compliance.get_bulk(provider.type)
bulk_checks_metadata = self._load_checks_metadata(bulk_compliance_frameworks)
# Get bulk checks metadata for the provider
bulk_checks_metadata = CheckMetadata.get_bulk(provider.type)
# Complete checks metadata with the compliance framework specification
bulk_checks_metadata = update_checks_metadata_with_compliance(
self._validate_inputs(
checks,
services,
compliances,
categories,
severities,
bulk_checks_metadata,
bulk_compliance_frameworks,
)
self._checks_to_execute = self._init_checks_to_execute(
bulk_checks_metadata,
bulk_compliance_frameworks,
checks,
services,
compliances,
categories,
severities,
excluded_checks,
excluded_services,
)
self._service_checks_map = get_service_checks_mapping(self._checks_to_execute)
self._completed_checks = set()
def _validate_statuses(
self, statuses: Optional[List[str]]
) -> Optional[List[Status]]:
"""Validate and convert status strings to Status enums."""
if not statuses:
return None
validated = []
for status in statuses:
try:
validated.append(Status(status))
except ValueError:
raise ScanInvalidStatusError(f"Invalid status: {status}")
return validated
def _load_checks_metadata(self, bulk_compliance_frameworks: dict) -> dict:
"""Load and enhance checks metadata with compliance information."""
bulk_checks_metadata = CheckMetadata.get_bulk(self._provider.type)
return update_checks_metadata_with_compliance(
bulk_compliance_frameworks, bulk_checks_metadata
)
# Create a list of valid categories
valid_categories = set()
for check, metadata in bulk_checks_metadata.items():
for category in metadata.Categories:
if category not in valid_categories:
valid_categories.add(category)
def _validate_inputs(
self,
checks: Optional[List[str]],
services: Optional[List[str]],
compliances: Optional[List[str]],
categories: Optional[List[str]],
severities: Optional[List[str]],
bulk_checks_metadata: dict,
bulk_compliance_frameworks: dict,
):
"""Validate all input parameters against provider capabilities."""
valid_services = list_services(self._provider.type)
valid_categories = self._extract_valid_categories(bulk_checks_metadata)
# Validate checks
self._validate_checks(checks, bulk_checks_metadata)
self._validate_services(services, valid_services)
self._validate_compliances(compliances, bulk_compliance_frameworks)
self._validate_categories(categories, valid_categories)
self._validate_severities(severities)
def _extract_valid_categories(self, checks_metadata: dict) -> Set[str]:
"""Extract unique categories from checks metadata."""
return {
category
for metadata in checks_metadata.values()
for category in metadata.Categories
}
def _validate_checks(self, checks: Optional[List[str]], metadata: dict):
"""Validate requested checks exist in provider."""
if checks:
for check in checks:
if check not in bulk_checks_metadata.keys():
raise ScanInvalidCheckError(f"Invalid check provided: {check}.")
invalid = [check for check in checks if check not in metadata]
if invalid:
raise ScanInvalidCheckError(f"Invalid checks: {', '.join(invalid)}")
# Validate services
def _validate_services(self, services: Optional[List[str]], valid_services: list):
"""Validate requested services exist in provider."""
if services:
for service in services:
if service not in list_services(provider.type):
raise ScanInvalidServiceError(
f"Invalid service provided: {service}."
)
invalid = [srv for srv in services if srv not in valid_services]
if invalid:
raise ScanInvalidServiceError(f"Invalid services: {', '.join(invalid)}")
# Validate compliances
def _validate_compliances(self, compliances: Optional[List[str]], frameworks: dict):
"""Validate compliance frameworks exist."""
if compliances:
for compliance in compliances:
if compliance not in bulk_compliance_frameworks.keys():
raise ScanInvalidComplianceFrameworkError(
f"Invalid compliance provided: {compliance}."
)
invalid = [comp for comp in compliances if comp not in frameworks]
if invalid:
raise ScanInvalidComplianceFrameworkError(
f"Invalid compliances: {', '.join(invalid)}"
)
# Validate categories
def _validate_categories(
self, categories: Optional[List[str]], valid_categories: Set[str]
):
"""Validate categories exist in provider checks."""
if categories:
for category in categories:
if category not in valid_categories:
raise ScanInvalidCategoryError(
f"Invalid category provided: {category}."
)
invalid = [cat for cat in categories if cat not in valid_categories]
if invalid:
raise ScanInvalidCategoryError(
f"Invalid categories: {', '.join(invalid)}"
)
# Validate severity
def _validate_severities(self, severities: Optional[List[str]]):
"""Validate severity values are valid."""
if severities:
for severity in severities:
try:
Severity(severity)
except ValueError:
raise ScanInvalidSeverityError(
f"Invalid severity provided: {severity}."
)
try:
[Severity(sev) for sev in severities]
except ValueError as e:
raise ScanInvalidSeverityError(f"Invalid severity: {e}")
# Load checks to execute
self._checks_to_execute = sorted(
load_checks_to_execute(
bulk_checks_metadata=bulk_checks_metadata,
bulk_compliance_frameworks=bulk_compliance_frameworks,
check_list=checks,
service_list=services,
compliance_frameworks=compliances,
categories=categories,
severities=severities,
provider=provider.type,
checks_file=None,
)
def _init_checks_to_execute(
self,
bulk_checks_metadata: dict,
bulk_compliance_frameworks: dict,
checks: Optional[List[str]],
services: Optional[List[str]],
compliances: Optional[List[str]],
categories: Optional[List[str]],
severities: Optional[List[str]],
excluded_checks: Optional[List[str]],
excluded_services: Optional[List[str]],
) -> List[str]:
"""Load and filter checks based on configuration."""
checks = load_checks_to_execute(
bulk_checks_metadata=bulk_checks_metadata,
bulk_compliance_frameworks=bulk_compliance_frameworks,
check_list=checks,
service_list=services,
compliance_frameworks=compliances,
categories=categories,
severities=severities,
provider=self._provider.type,
)
# Exclude checks
if excluded_checks:
for check in excluded_checks:
if check in self._checks_to_execute:
self._checks_to_execute.remove(check)
else:
raise ScanInvalidCheckError(
f"Invalid check provided: {check}. Check does not exist in the provider."
)
checks = [c for c in checks if c not in excluded_checks]
# Exclude services
if excluded_services:
for check in self._checks_to_execute:
if get_service_name_from_check_name(check) in excluded_services:
self._checks_to_execute.remove(check)
else:
raise ScanInvalidServiceError(
f"Invalid service provided: {check}. Service does not exist in the provider."
)
excluded_services_set = set(excluded_services)
checks = [
c
for c in checks
if get_service_from_check(c) not in excluded_services_set
]
self._number_of_checks_to_execute = len(self._checks_to_execute)
service_checks_to_execute = get_service_checks_to_execute(
self._checks_to_execute
)
service_checks_completed = dict()
self._service_checks_to_execute = service_checks_to_execute
self._service_checks_completed = service_checks_completed
return sorted(checks)
@property
def checks_to_execute(self) -> list[str]:
return self._checks_to_execute
def total_checks(self) -> int:
return len(self._checks_to_execute)
@property
def service_checks_to_execute(self) -> dict[str, set[str]]:
return self._service_checks_to_execute
@property
def service_checks_completed(self) -> dict[str, set[str]]:
return self._service_checks_completed
@property
def provider(self) -> Provider:
return self._provider
def completed_checks(self) -> int:
return len(self._completed_checks)
@property
def progress(self) -> float:
return (
self._number_of_checks_completed / self._number_of_checks_to_execute * 100
(self.completed_checks / self.total_checks * 100)
if self.total_checks
else 0
)
@property
def duration(self) -> int:
return self._duration
def remaining_services(self) -> Dict[str, Set[str]]:
return {
service: checks
for service, checks in self._service_checks_map.items()
if checks
}
def scan(
self,
custom_checks_metadata: dict = {},
) -> Generator[tuple[float, list[Finding]], None, None]:
def scan(self, custom_checks_metadata: dict = {}) -> Generator[float, List[Finding], dict]:
"""
Executes the scan by iterating over the checks to execute and executing each check.
Yields the progress and findings for each check.
@@ -233,143 +270,103 @@ class Scan:
ModuleNotFoundError: If the check does not exist in the provider or is from another provider.
Exception: If any other error occurs during the execution of a check.
"""
try:
checks_to_execute = self.checks_to_execute
# Initialize the Audit Metadata
# TODO: this should be done in the provider class
# Refactor(Core): Audit manager?
self._provider.audit_metadata = Audit_Metadata(
services_scanned=0,
expected_checks=checks_to_execute,
completed_checks=0,
audit_progress=0,
)
self._provider.audit_metadata = Audit_Metadata(
services_scanned=0,
expected_checks=self._checks_to_execute,
completed_checks=0,
audit_progress=0,
)
start_time = datetime.datetime.now()
start_time = datetime.datetime.now()
for check_name in checks_to_execute:
for check_name in self._checks_to_execute:
service = get_service_from_check(check_name)
try:
check_module = self._import_check_module(check_name, service)
findings = self._execute_check(check_module, check_name, custom_checks_metadata)
filtered_findings = self._filter_findings_by_status(findings)
except Exception as error:
logger.error(f"{check_name} failed: {error}")
continue
self._update_scan_state(check_name, service, filtered_findings)
stats = extract_findings_statistics(filtered_findings)
findings = []
for finding in filtered_findings:
try:
# Recover service from check name
service = get_service_name_from_check_name(check_name)
try:
# Import check module
check_module_path = f"prowler.providers.{self._provider.type}.services.{service}.{check_name}.{check_name}"
lib = import_check(check_module_path)
# Recover functions from check
check_to_execute = getattr(lib, check_name)
check = check_to_execute()
except ModuleNotFoundError:
logger.error(
f"Check '{check_name}' was not found for the {self._provider.type.upper()} provider"
findings.append(
Finding.generate_output(
self._provider, finding, output_options=None
)
continue
# Execute the check
check_findings = execute(
check,
self._provider,
custom_checks_metadata,
output_options=None,
)
except Exception:
continue
# Filter the findings by the status
if self._status:
for finding in check_findings:
if finding.status not in self._status:
check_findings.remove(finding)
yield self.progress, findings, stats
# Remove the executed check
self._service_checks_to_execute[service].remove(check_name)
if len(self._service_checks_to_execute[service]) == 0:
self._service_checks_to_execute.pop(service, None)
# Add the completed check
if service not in self._service_checks_completed:
self._service_checks_completed[service] = set()
self._service_checks_completed[service].add(check_name)
self._number_of_checks_completed += 1
self._duration = int((datetime.datetime.now() - start_time).total_seconds())
# This should be done just once all the service's checks are completed
# This metadata needs to get to the services not within the provider
# since it is present in the Scan class
self._provider.audit_metadata = update_audit_metadata(
self._provider.audit_metadata,
self.get_completed_services(),
self.get_completed_checks(),
)
findings = []
for finding in check_findings:
try:
findings.append(
Finding.generate_output(
self._provider, finding, output_options=None
)
)
except Exception:
continue
yield self.progress, findings
# If check does not exists in the provider or is from another provider
except ModuleNotFoundError:
logger.error(
f"Check '{check_name}' was not found for the {self._provider.type.upper()} provider"
)
except Exception as error:
logger.error(
f"{check_name} - {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
# Update the scan duration when all checks are completed
self._duration = int((datetime.datetime.now() - start_time).total_seconds())
except Exception as error:
def _import_check_module(self, check_name: str, service: str):
"""Dynamically import check module."""
module_path = (
f"prowler.providers.{self._provider.type}.services."
f"{service}.{check_name}.{check_name}"
)
try:
return import_check(module_path)
except ModuleNotFoundError:
logger.error(
f"{check_name} - {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
f"Check '{check_name}' not found for {self._provider.type.upper()}"
)
raise
def get_completed_services(self) -> set[str]:
"""
get_completed_services returns the services that have been completed.
def _execute_check(self, check_module, check_name: str, custom_checks_metadata: dict = {}) -> List[Finding]:
"""Execute a single check and return its findings."""
check_func = getattr(check_module, check_name)
return execute(check_func(), self._provider, custom_checks_metadata)
Example:
get_completed_services() -> {"ec2", "s3"}
"""
return self._service_checks_completed.keys()
def _filter_findings_by_status(self, findings: List[Finding]) -> List[Finding]:
"""Filter findings based on configured status filters."""
if not self._statuses:
return findings
return [f for f in findings if f.status in self._statuses]
def get_completed_checks(self) -> set[str]:
"""
get_completed_checks returns the checks that have been completed.
def _update_scan_state(
self, check_name: str, service: str, findings: List[Finding]
):
"""Update scan state after check completion."""
self._service_checks_map[service].discard(check_name)
self._completed_checks.add(check_name)
self._findings.extend(findings)
Example:
get_completed_checks() -> {"ec2_instance_public_ip", "s3_bucket_public"}
"""
completed_checks = set()
for checks in self._service_checks_completed.values():
completed_checks.update(checks)
return completed_checks
self._provider.audit_metadata = update_audit_metadata(
self._provider.audit_metadata,
self.remaining_services.keys(),
self._completed_checks,
)
def get_service_name_from_check_name(check_name: str) -> str:
def get_service_from_check(check_name: str) -> str:
"""
get_service_name_from_check_name returns the service name for a given check name.
Return the service name for a given check name.
Example:
get_service_name_from_check_name("ec2_instance_public") -> "ec2"
get_service_from_check("ec2_instance_public") -> "ec2"
"""
return check_name.split("_")[0]
def get_service_checks_to_execute(checks_to_execute: set[str]) -> dict[str, set[str]]:
def get_service_checks_mapping(checks: List[str]) -> Dict[str, Set[str]]:
"""
get_service_checks_to_execute returns a dictionary with the services and the checks to execute.
Return a dictionary with the services and the checks to execute.
Example:
get_service_checks_to_execute({"accessanalyzer_enabled", "ec2_instance_public_ip"})
get_service_checks_mapping({"accessanalyzer_enabled", "ec2_instance_public_ip"})
-> {"accessanalyzer": {"accessanalyzer_enabled"}, "ec2": {"ec2_instance_public_ip"}}
"""
service_checks_to_execute = dict()
for check in checks_to_execute:
# check -> accessanalyzer_enabled
# service -> accessanalyzer
service = get_service_name_from_check_name(check)
if service not in service_checks_to_execute:
service_checks_to_execute[service] = set()
service_checks_to_execute[service].add(check)
return service_checks_to_execute
service_map = {}
for check in checks:
service = get_service_from_check(check)
service_map.setdefault(service, set()).add(check)
return service_map

View File

@@ -1,7 +1,6 @@
from importlib.machinery import FileFinder
from pkgutil import ModuleInfo
from unittest import mock
import pytest
from mock import MagicMock, patch
@@ -13,7 +12,7 @@ from prowler.lib.scan.exceptions.exceptions import (
ScanInvalidSeverityError,
ScanInvalidStatusError,
)
from prowler.lib.scan.scan import Scan, get_service_checks_to_execute
from prowler.lib.scan.scan import Scan, get_service_checks_mapping
from tests.lib.outputs.fixtures.fixtures import generate_finding_output
from tests.providers.aws.utils import set_mocked_aws_provider
@@ -196,9 +195,9 @@ class TestScan:
mock_provider.type = "aws"
scan = Scan(mock_provider, checks=checks_to_execute)
assert scan.provider == mock_provider
assert scan._provider == mock_provider
# Check that the checks to execute are sorted and without duplicates
assert scan.checks_to_execute == [
assert scan._checks_to_execute == [
"accessanalyzer_enabled",
"accessanalyzer_enabled_without_findings",
"account_maintain_current_contact_details",
@@ -258,14 +257,16 @@ class TestScan:
"config_recorder_all_regions_enabled",
"workspaces_vpc_2private_1public_subnets_nat",
]
assert scan.service_checks_to_execute == get_service_checks_to_execute(
assert scan._service_checks_map == get_service_checks_mapping(
checks_to_execute
)
assert scan.service_checks_completed == {}
assert scan.progress == 0
assert scan.duration == 0
assert scan.get_completed_services() == set()
assert scan.get_completed_checks() == set()
assert scan._completed_checks == set()
assert scan._progress == 0
assert scan._duration == 0
all_values = set().union(*scan.remaining_services.values())
for check in scan._checks_to_execute:
assert check in all_values
assert scan.completed_checks == 0
def test_init_with_no_checks(
mock_provider,
@@ -281,18 +282,24 @@ class TestScan:
mock_load_checks_to_execute.assert_called_once()
mock_recover_checks_from_provider.assert_called_once_with("aws")
assert scan.provider == mock_provider
assert scan.checks_to_execute == ["accessanalyzer_enabled"]
assert scan.service_checks_to_execute == get_service_checks_to_execute(
assert scan._provider == mock_provider
assert scan._checks_to_execute == ["accessanalyzer_enabled"]
assert scan._service_checks_map == get_service_checks_mapping(
["accessanalyzer_enabled"]
)
assert scan.service_checks_completed == {}
assert scan.progress == 0
assert scan.get_completed_services() == set()
assert scan.get_completed_checks() == set()
assert scan._completed_checks == set()
assert scan._progress == 0
all_values = set().union(*scan.remaining_services.values())
for check in scan._checks_to_execute:
assert check in all_values
assert scan.completed_checks == 0
@patch("importlib.import_module")
@patch("prowler.lib.scan.scan.list_services")
@patch("prowler.lib.scan.scan.extract_findings_statistics")
def test_scan(
mock_extract_findings_statistics,
mock_list_services,
mock_import_module,
mock_global_provider,
mock_execute,
@@ -328,11 +335,9 @@ class TestScan:
assert results[0][0] == 100.0
assert scan.progress == 100.0
# Since the scan is mocked, the duration will always be 0 for now
assert scan.duration == 0
assert scan._number_of_checks_completed == 1
assert scan.service_checks_completed == {
"accessanalyzer": {"accessanalyzer_enabled"},
}
assert scan._duration == 0
assert scan.completed_checks == 1
assert scan._completed_checks == {"accessanalyzer_enabled"}
mock_logger.error.assert_not_called()
def test_init_invalid_severity(
@@ -396,7 +401,9 @@ class TestScan:
Scan(mock_provider, checks=checks_to_execute, status=["invalid_status"])
@patch("importlib.import_module")
@patch("prowler.lib.scan.scan.list_services")
def test_scan_filter_status(
mock_list_services,
mock_import_module,
mock_global_provider,
mock_recover_checks_from_provider,
@@ -422,4 +429,21 @@ class TestScan:
mock_recover_checks_from_provider.assert_called_once_with("aws")
results = list(scan.scan(custom_checks_metadata))
assert results[0] == (100.0, [])
assert results[0] == (100.0, [], {
"all_fails_are_muted": True,
"findings_count": 0,
"resources_count": 0,
"total_critical_severity_fail": 0,
"total_critical_severity_pass": 0,
"total_fail": 0,
"total_high_severity_fail": 0,
"total_high_severity_pass": 0,
"total_low_severity_fail": 0,
"total_low_severity_pass": 0,
"total_medium_severity_fail": 0,
"total_medium_severity_pass": 0,
"total_muted_fail": 0,
"total_muted_pass": 0,
"total_pass": 0,
},
)