Mirror of https://github.com/prowler-cloud/prowler.git (synced 2026-02-09 02:30:43 +00:00)

Compare commits: 6 commits (feat/PROWL ... PRWLR-5956)
| Author | SHA1 | Date |
|---|---|---|
| | 4a14e21553 | |
| | 0995c7a845 | |
| | 63f8186bd6 | |
| | d70c71c903 | |
| | 53c571c289 | |
| | a3c7846cd9 | |
.env (+17)
@@ -30,6 +30,23 @@ VALKEY_HOST=valkey
VALKEY_PORT=6379
VALKEY_DB=0

# API scan settings
# The AWS access key to be used when uploading scan artifacts to an S3 bucket
# If left empty, default AWS credentials resolution behavior will be used
ARTIFACTS_AWS_ACCESS_KEY_ID=""

# The AWS secret key to be used when uploading scan artifacts to an S3 bucket
ARTIFACTS_AWS_SECRET_ACCESS_KEY=""

# An optional AWS session token
ARTIFACTS_AWS_SESSION_TOKEN=""

# The AWS region where your S3 bucket is located (e.g., "us-east-1")
ARTIFACTS_AWS_DEFAULT_REGION=""

# The name of the S3 bucket where scan artifacts should be stored
ARTIFACTS_AWS_S3_OUTPUT_BUCKET=""

# Django settings
DJANGO_ALLOWED_HOSTS=localhost,127.0.0.1,prowler-api
DJANGO_BIND_ADDRESS=0.0.0.0
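For orientation, a minimal sketch of how these ARTIFACTS_AWS_* settings can be turned into an S3 client, assuming the `config.env.env` helper that the API code further down in this compare already uses; everything else here is illustrative and not part of the change:

```python
import boto3
from config.env import env  # assumed helper, as used in the API views below


def build_artifacts_s3_client():
    """Build an S3 client from the ARTIFACTS_AWS_* settings, falling back
    to the default AWS credential resolution chain when they are left empty."""
    if env.str("ARTIFACTS_AWS_ACCESS_KEY_ID", ""):
        return boto3.client(
            "s3",
            aws_access_key_id=env.str("ARTIFACTS_AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=env.str("ARTIFACTS_AWS_SECRET_ACCESS_KEY"),
            aws_session_token=env.str("ARTIFACTS_AWS_SESSION_TOKEN") or None,
            region_name=env.str("ARTIFACTS_AWS_DEFAULT_REGION") or None,
        )
    return boto3.client("s3")
```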
@@ -4093,6 +4093,38 @@ paths:
              schema:
                $ref: '#/components/schemas/ScanUpdateResponse'
          description: ''
  /api/v1/scans/{id}/report:
    get:
      operationId: scans_report_retrieve
      description: Returns a ZIP file containing the requested report
      summary: Download ZIP report
      parameters:
      - in: query
        name: fields[scans]
        schema:
          type: array
          items:
            type: string
            enum: []
        description: endpoint returns only specific fields in the response on a per-type
          basis by including a fields[TYPE] query parameter.
        explode: false
      - in: path
        name: id
        schema:
          type: string
          format: uuid
        description: A UUID string identifying this scan.
        required: true
      tags:
      - Scan
      security:
      - jwtAuth: []
      responses:
        '200':
          description: Report obtained successfully
        '404':
          description: Report not found
  /api/v1/schedules/daily:
    post:
      operationId: schedules_daily_create
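As a usage illustration only (the host, token, and scan id below are placeholders, not values from this change, and the Bearer header is assumed for the `jwtAuth` scheme), downloading the report described by this schema could look like:

```python
import requests  # hypothetical client-side example

API = "http://localhost:8080/api/v1"   # placeholder host
TOKEN = "<jwt-access-token>"            # placeholder JWT
SCAN_ID = "<scan-uuid>"                 # placeholder scan id

resp = requests.get(
    f"{API}/scans/{SCAN_ID}/report",
    headers={"Authorization": f"Bearer {TOKEN}"},
)
resp.raise_for_status()

# The endpoint answers with a ZIP attachment; persist it as-is.
with open("prowler-report.zip", "wb") as f:
    f.write(resp.content)
```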
@@ -819,6 +819,14 @@ class ScanTaskSerializer(RLSSerializer):
        ]


class ScanReportSerializer(RLSSerializer):
    class Meta:
        model = Scan
        fields = [
            "id",
        ]


class ResourceTagSerializer(RLSSerializer):
    """
    Serializer for the ResourceTag model
@@ -1,9 +1,16 @@
import glob
import os

import boto3
from botocore.exceptions import ClientError, NoCredentialsError, ParamValidationError
from celery.result import AsyncResult
from config.env import env
from django.conf import settings as django_settings
from django.contrib.postgres.aggregates import ArrayAgg
from django.contrib.postgres.search import SearchQuery
from django.db import transaction
from django.db.models import Count, F, OuterRef, Prefetch, Q, Subquery, Sum
from django.http import HttpResponse
from django.db.models.functions import Coalesce
from django.urls import reverse
from django.utils.decorators import method_decorator
@@ -114,6 +121,7 @@ from api.v1.serializers import (
    RoleSerializer,
    RoleUpdateSerializer,
    ScanCreateSerializer,
    ScanReportSerializer,
    ScanSerializer,
    ScanUpdateSerializer,
    ScheduleDailyCreateSerializer,
@@ -126,6 +134,7 @@ from api.v1.serializers import (
    UserSerializer,
    UserUpdateSerializer,
)
from prowler.config.config import tmp_output_directory

CACHE_DECORATOR = cache_control(
    max_age=django_settings.CACHE_MAX_AGE,
@@ -1073,6 +1082,8 @@ class ScanViewSet(BaseRLSViewSet):
            return ScanCreateSerializer
        elif self.action == "partial_update":
            return ScanUpdateSerializer
        elif self.action == "report":
            return ScanReportSerializer
        return super().get_serializer_class()

    def partial_update(self, request, *args, **kwargs):
@@ -1127,6 +1138,101 @@ class ScanViewSet(BaseRLSViewSet):
            },
        )

    @extend_schema(
        tags=["Scan"],
        summary="Download ZIP report",
        description="Returns a ZIP file containing the requested report",
        request=ScanReportSerializer,
        responses={
            200: OpenApiResponse(description="Report obtained successfully"),
            404: OpenApiResponse(description="Report not found"),
        },
    )
    @action(detail=True, methods=["get"], url_name="report")
    def report(self, request, pk=None):
        s3_client = None
        try:
            s3_client = boto3.client("s3")
            s3_client.list_buckets()
        except (ClientError, NoCredentialsError, ParamValidationError):
            try:
                s3_client = boto3.client(
                    "s3",
                    aws_access_key_id=env.str("ARTIFACTS_AWS_ACCESS_KEY_ID"),
                    aws_secret_access_key=env.str("ARTIFACTS_AWS_SECRET_ACCESS_KEY"),
                    aws_session_token=env.str("ARTIFACTS_AWS_SESSION_TOKEN"),
                    region_name=env.str("ARTIFACTS_AWS_DEFAULT_REGION"),
                )
                s3_client.list_buckets()
            except (ClientError, NoCredentialsError, ParamValidationError):
                s3_client = None

        if s3_client:
            bucket_name = env.str("ARTIFACTS_AWS_S3_OUTPUT_BUCKET")
            s3_prefix = f"{request.tenant_id}/{pk}/"

            try:
                response = s3_client.list_objects_v2(
                    Bucket=bucket_name, Prefix=s3_prefix
                )
                if response["KeyCount"] == 0:
                    return Response(
                        {"detail": "No files found in S3 storage"},
                        status=status.HTTP_404_NOT_FOUND,
                    )

                zip_files = [
                    obj["Key"]
                    for obj in response.get("Contents", [])
                    if obj["Key"].endswith(".zip")
                ]
                if not zip_files:
                    return Response(
                        {"detail": "No ZIP files found in S3 storage"},
                        status=status.HTTP_404_NOT_FOUND,
                    )

                s3_key = zip_files[0]
                s3_object = s3_client.get_object(Bucket=bucket_name, Key=s3_key)
                file_content = s3_object["Body"].read()
                filename = os.path.basename(s3_key)

            except ClientError:
                return Response(
                    {"detail": "Error accessing cloud storage"},
                    status=status.HTTP_500_INTERNAL_SERVER_ERROR,
                )

        else:
            local_path = os.path.join(
                tmp_output_directory,
                str(request.tenant_id),
                str(pk),
                "*.zip",
            )
            zip_files = glob.glob(local_path)
            if not zip_files:
                return Response(
                    {"detail": "No local files found"}, status=status.HTTP_404_NOT_FOUND
                )

            try:
                file_path = zip_files[0]
                with open(file_path, "rb") as f:
                    file_content = f.read()
                filename = os.path.basename(file_path)
            except IOError:
                return Response(
                    {"detail": "Error reading local file"},
                    status=status.HTTP_500_INTERNAL_SERVER_ERROR,
                )

        response = HttpResponse(
            file_content, content_type="application/x-zip-compressed"
        )
        response["Content-Disposition"] = f'attachment; filename="{filename}"'
        return response


@extend_schema_view(
    list=extend_schema(
@@ -1,8 +1,13 @@
import os
import time
import zipfile
from copy import deepcopy
from datetime import datetime, timezone

import boto3
from botocore.exceptions import ClientError
from celery.utils.log import get_task_logger
from config.env import env
from config.settings.celery import CELERY_DEADLOCK_ATTEMPTS
from django.db import IntegrityError, OperationalError
from django.db.models import Case, Count, IntegerField, Sum, When
@@ -25,15 +30,117 @@ from api.models import (
from api.models import StatusChoices as FindingStatus
from api.utils import initialize_prowler_provider
from api.v1.serializers import ScanTaskSerializer
from prowler.config.config import (
    csv_file_suffix,
    html_file_suffix,
    json_asff_file_suffix,
    json_ocsf_file_suffix,
    output_file_timestamp,
    tmp_output_directory,
)
from prowler.lib.outputs.asff.asff import ASFF
from prowler.lib.outputs.csv.csv import CSV
from prowler.lib.outputs.finding import Finding as ProwlerFinding
from prowler.lib.outputs.html.html import HTML
from prowler.lib.outputs.ocsf.ocsf import OCSF
from prowler.lib.scan.scan import Scan as ProwlerScan

logger = get_task_logger(__name__)

# Predefined mapping for output formats and their configurations
OUTPUT_FORMATS_MAPPING = {
    "csv": {
        "class": CSV,
        "suffix": csv_file_suffix,
        "kwargs": {},
    },
    "json-asff": {"class": ASFF, "suffix": json_asff_file_suffix, "kwargs": {}},
    "json-ocsf": {"class": OCSF, "suffix": json_ocsf_file_suffix, "kwargs": {}},
    "html": {"class": HTML, "suffix": html_file_suffix, "kwargs": {"stats": {}}},
}

# Mapping provider types to their identity components for output paths
PROVIDER_IDENTITY_MAP = {
    "aws": lambda p: p.identity.account,
    "azure": lambda p: p.identity.tenant_domain,
    "gcp": lambda p: p.identity.profile,
    "kubernetes": lambda p: p.identity.context.replace(":", "_").replace("/", "_"),
}


def _compress_output_files(output_directory: str) -> str:
    """
    Compress output files from all configured output formats into a ZIP archive.

    Args:
        output_directory (str): The directory where the output files are located.
            The function looks up all known suffixes in OUTPUT_FORMATS_MAPPING
            and compresses those files into a single ZIP.

    Returns:
        str: The full path to the newly created ZIP archive.
    """
    zip_path = f"{output_directory}.zip"

    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
        for suffix in [config["suffix"] for config in OUTPUT_FORMATS_MAPPING.values()]:
            zipf.write(
                f"{output_directory}{suffix}",
                f"artifacts/{output_directory.split('/')[-1]}{suffix}",
            )

    return zip_path


def _upload_to_s3(tenant_id: str, zip_path: str, scan_id: str) -> str:
    """
    Upload the specified ZIP file to an S3 bucket.

    If the S3 bucket environment variables are not configured,
    the function returns None without performing an upload.

    Args:
        tenant_id (str): The tenant identifier, used as part of the S3 key prefix.
        zip_path (str): The local file system path to the ZIP file to be uploaded.
        scan_id (str): The scan identifier, used as part of the S3 key prefix.

    Returns:
        str: The S3 URI of the uploaded file (e.g., "s3://<bucket>/<key>") if successful.
        None: If the required environment variables for the S3 bucket are not set.

    Raises:
        botocore.exceptions.ClientError: If the upload attempt to S3 fails for any reason.
    """
    if not env.str("ARTIFACTS_AWS_S3_OUTPUT_BUCKET", ""):
        return

    if env.str("ARTIFACTS_AWS_ACCESS_KEY_ID", ""):
        s3 = boto3.client(
            "s3",
            aws_access_key_id=env.str("ARTIFACTS_AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=env.str("ARTIFACTS_AWS_SECRET_ACCESS_KEY"),
            aws_session_token=env.str("ARTIFACTS_AWS_SESSION_TOKEN"),
            region_name=env.str("ARTIFACTS_AWS_DEFAULT_REGION"),
        )
    else:
        s3 = boto3.client("s3")

    s3_key = f"{tenant_id}/{scan_id}/{os.path.basename(zip_path)}"
    try:
        s3.upload_file(
            Filename=zip_path,
            Bucket=env.str("ARTIFACTS_AWS_S3_OUTPUT_BUCKET"),
            Key=s3_key,
        )
        return f"s3://{env.str('ARTIFACTS_AWS_S3_OUTPUT_BUCKET')}/{s3_key}"
    except ClientError as e:
        logger.error(f"S3 upload failed: {str(e)}")
        raise e
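To make the artifact flow easier to follow, here is a small hypothetical walkthrough of where a report ends up, using the directory format from `_generate_output_directory` further down and the key layout from `_upload_to_s3` above; the tenant and scan identifiers are placeholders, not real values:

```python
import os

# Hypothetical illustration of the artifact pipeline defined in this file.
tenant_id, scan_id = "<tenant-id>", "<scan-id>"
output_directory = (
    f"/tmp/prowler_api_output/{tenant_id}/{scan_id}/"
    f"prowler-output-<identity>-<timestamp>"
)

# _compress_output_files() bundles the per-format files into "<output_directory>.zip" ...
zip_path = f"{output_directory}.zip"

# ... and _upload_to_s3() stores it under the same tenant/scan prefix that the
# /api/v1/scans/{id}/report endpoint later lists when serving the download.
s3_key = f"{tenant_id}/{scan_id}/{os.path.basename(zip_path)}"
print(s3_key)  # <tenant-id>/<scan-id>/prowler-output-<identity>-<timestamp>.zip
```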
def _create_finding_delta(
|
||||
last_status: FindingStatus | None | str, new_status: FindingStatus | None
|
||||
) -> Finding.DeltaChoices:
|
||||
) -> Finding.DeltaChoices | None:
|
||||
"""
|
||||
Determine the delta status of a finding based on its previous and current status.
|
||||
|
||||
@@ -53,7 +160,11 @@ def _create_finding_delta(
|
||||
|
||||
|
||||
def _store_resources(
|
||||
finding: ProwlerFinding, tenant_id: str, provider_instance: Provider
|
||||
finding: ProwlerFinding,
|
||||
tenant_id: str,
|
||||
provider_instance: Provider,
|
||||
resource_cache: dict,
|
||||
tag_cache: dict,
|
||||
) -> tuple[Resource, tuple[str, str]]:
|
||||
"""
|
||||
Store resource information from a finding, including tags, in the database.
|
||||
@@ -65,40 +176,91 @@ def _store_resources(
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
- Resource: The resource instance created or retrieved from the database.
|
||||
- Resource: The resource instance created or updated from the database.
|
||||
- tuple[str, str]: A tuple containing the resource UID and region.
|
||||
|
||||
"""
|
||||
with rls_transaction(tenant_id):
|
||||
resource_instance, created = Resource.objects.get_or_create(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_instance,
|
||||
uid=finding.resource_uid,
|
||||
defaults={
|
||||
"region": finding.region,
|
||||
"service": finding.service_name,
|
||||
"type": finding.resource_type,
|
||||
},
|
||||
)
|
||||
resource_uid = finding.resource_uid
|
||||
|
||||
# Check cache or create/update resource
|
||||
if resource_uid in resource_cache:
|
||||
resource_instance = resource_cache[resource_uid]
|
||||
update_fields = []
|
||||
for field, value in [
|
||||
("region", finding.region),
|
||||
("service", finding.service_name),
|
||||
("type", finding.resource_type),
|
||||
("name", finding.resource_name),
|
||||
]:
|
||||
if getattr(resource_instance, field) != value:
|
||||
setattr(resource_instance, field, value)
|
||||
update_fields.append(field)
|
||||
if update_fields:
|
||||
with rls_transaction(tenant_id):
|
||||
resource_instance.save(update_fields=update_fields)
|
||||
else:
|
||||
with rls_transaction(tenant_id):
|
||||
resource_instance, _ = Resource.objects.update_or_create(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_instance,
|
||||
uid=resource_uid,
|
||||
defaults={
|
||||
"region": finding.region,
|
||||
"service": finding.service_name,
|
||||
"type": finding.resource_type,
|
||||
"name": finding.resource_name,
|
||||
},
|
||||
)
|
||||
resource_cache[resource_uid] = resource_instance
|
||||
|
||||
# Process tags with caching
|
||||
tags = []
|
||||
for key, value in finding.resource_tags.items():
|
||||
tag_key = (key, value)
|
||||
if tag_key not in tag_cache:
|
||||
with rls_transaction(tenant_id):
|
||||
tag_instance, _ = ResourceTag.objects.get_or_create(
|
||||
tenant_id=tenant_id, key=key, value=value
|
||||
)
|
||||
tag_cache[tag_key] = tag_instance
|
||||
tags.append(tag_cache[tag_key])
|
||||
|
||||
if not created:
|
||||
resource_instance.region = finding.region
|
||||
resource_instance.service = finding.service_name
|
||||
resource_instance.type = finding.resource_type
|
||||
resource_instance.save()
|
||||
with rls_transaction(tenant_id):
|
||||
tags = [
|
||||
ResourceTag.objects.get_or_create(
|
||||
tenant_id=tenant_id, key=key, value=value
|
||||
)[0]
|
||||
for key, value in finding.resource_tags.items()
|
||||
]
|
||||
resource_instance.upsert_or_delete_tags(tags=tags)
|
||||
|
||||
return resource_instance, (resource_instance.uid, resource_instance.region)
|
||||
|
||||
|
||||
def _generate_output_directory(
|
||||
prowler_provider: object, tenant_id: str, scan_id: str
|
||||
) -> str:
|
||||
"""
|
||||
Generate a dynamic output directory path based on the given provider type.
|
||||
|
||||
Args:
|
||||
prowler_provider (object): An object that has a `type` attribute indicating
|
||||
the provider type (e.g., "aws", "azure", etc.).
|
||||
tenant_id (str): A unique identifier for the tenant. Used to build the output path.
|
||||
scan_id (str): A unique identifier for the scan. Included in the output path.
|
||||
|
||||
Returns:
|
||||
str: The complete path to the output directory, including the tenant ID, scan ID,
|
||||
provider identity, and a timestamp.
|
||||
|
||||
"""
|
||||
provider_type = prowler_provider.type
|
||||
get_identity = PROVIDER_IDENTITY_MAP.get(provider_type, lambda _: "unknown")
|
||||
return (
|
||||
f"{tmp_output_directory}/{tenant_id}/{scan_id}/prowler-output-"
|
||||
f"{get_identity(prowler_provider)}-{output_file_timestamp}"
|
||||
)
|
||||
|
||||
|
||||
def perform_prowler_scan(
|
||||
tenant_id: str, scan_id: str, provider_id: str, checks_to_execute: list[str] = None
|
||||
tenant_id: str,
|
||||
scan_id: str,
|
||||
provider_id: str,
|
||||
checks_to_execute: list[str] = None,
|
||||
):
|
||||
"""
|
||||
Perform a scan using Prowler and store the findings and resources in the database.
|
||||
@@ -120,6 +282,9 @@ def perform_prowler_scan(
|
||||
exception = None
|
||||
unique_resources = set()
|
||||
start_time = time.time()
|
||||
resource_cache = {}
|
||||
tag_cache = {}
|
||||
last_status_cache = {}
|
||||
|
||||
with rls_transaction(tenant_id):
|
||||
provider_instance = Provider.objects.get(pk=provider_id)
|
||||
@@ -129,6 +294,7 @@ def perform_prowler_scan(
|
||||
scan_instance.save()
|
||||
|
||||
try:
|
||||
# Provider initialization
|
||||
with rls_transaction(tenant_id):
|
||||
try:
|
||||
prowler_provider = initialize_prowler_provider(provider_instance)
|
||||
@@ -144,116 +310,66 @@ def perform_prowler_scan(
|
||||
)
|
||||
provider_instance.save()
|
||||
|
||||
prowler_scan = ProwlerScan(provider=prowler_provider, checks=checks_to_execute)
|
||||
# Scan configuration
|
||||
prowler_scan = ProwlerScan(
|
||||
provider=prowler_provider, checks=checks_to_execute or []
|
||||
)
|
||||
output_directory = _generate_output_directory(
|
||||
prowler_provider, tenant_id, scan_id
|
||||
)
|
||||
# Create the output directory
|
||||
os.makedirs("/".join(output_directory.split("/")[:-1]), exist_ok=True)
|
||||
|
||||
resource_cache = {}
|
||||
tag_cache = {}
|
||||
last_status_cache = {}
|
||||
|
||||
for progress, findings in prowler_scan.scan():
|
||||
all_findings = []
|
||||
for progress, findings, stats in prowler_scan.scan():
|
||||
for finding in findings:
|
||||
if finding is None:
|
||||
logger.error(f"None finding detected on scan {scan_id}.")
|
||||
continue
|
||||
for attempt in range(CELERY_DEADLOCK_ATTEMPTS):
|
||||
try:
|
||||
with rls_transaction(tenant_id):
|
||||
# Process resource
|
||||
resource_uid = finding.resource_uid
|
||||
if resource_uid not in resource_cache:
|
||||
# Get or create the resource
|
||||
resource_instance, _ = Resource.objects.get_or_create(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_instance,
|
||||
uid=resource_uid,
|
||||
defaults={
|
||||
"region": finding.region,
|
||||
"service": finding.service_name,
|
||||
"type": finding.resource_type,
|
||||
"name": finding.resource_name,
|
||||
},
|
||||
)
|
||||
resource_cache[resource_uid] = resource_instance
|
||||
else:
|
||||
resource_instance = resource_cache[resource_uid]
|
||||
|
||||
# Update resource fields if necessary
|
||||
updated_fields = []
|
||||
if (
|
||||
finding.region
|
||||
and resource_instance.region != finding.region
|
||||
):
|
||||
resource_instance.region = finding.region
|
||||
updated_fields.append("region")
|
||||
if resource_instance.service != finding.service_name:
|
||||
resource_instance.service = finding.service_name
|
||||
updated_fields.append("service")
|
||||
if resource_instance.type != finding.resource_type:
|
||||
resource_instance.type = finding.resource_type
|
||||
updated_fields.append("type")
|
||||
if updated_fields:
|
||||
with rls_transaction(tenant_id):
|
||||
resource_instance.save(update_fields=updated_fields)
|
||||
resource_instance, resource_uid_region = _store_resources(
|
||||
finding,
|
||||
tenant_id,
|
||||
provider_instance,
|
||||
resource_cache,
|
||||
tag_cache,
|
||||
)
|
||||
unique_resources.add(resource_uid_region)
|
||||
break
|
||||
except (OperationalError, IntegrityError) as db_err:
|
||||
if attempt < CELERY_DEADLOCK_ATTEMPTS - 1:
|
||||
logger.warning(
|
||||
f"{'Deadlock error' if isinstance(db_err, OperationalError) else 'Integrity error'} "
|
||||
f"detected when processing resource {resource_uid} on scan {scan_id}. Retrying..."
|
||||
f"Database error ({type(db_err).__name__}) "
|
||||
f"processing resource {finding.resource_uid}, retrying..."
|
||||
)
|
||||
time.sleep(0.1 * (2**attempt))
|
||||
continue
|
||||
else:
|
||||
raise db_err
|
||||
|
||||
# Update tags
|
||||
tags = []
|
||||
with rls_transaction(tenant_id):
|
||||
for key, value in finding.resource_tags.items():
|
||||
tag_key = (key, value)
|
||||
if tag_key not in tag_cache:
|
||||
tag_instance, _ = ResourceTag.objects.get_or_create(
|
||||
tenant_id=tenant_id, key=key, value=value
|
||||
)
|
||||
tag_cache[tag_key] = tag_instance
|
||||
else:
|
||||
tag_instance = tag_cache[tag_key]
|
||||
tags.append(tag_instance)
|
||||
resource_instance.upsert_or_delete_tags(tags=tags)
|
||||
|
||||
unique_resources.add((resource_instance.uid, resource_instance.region))
|
||||
|
||||
# Process finding
|
||||
# Finding processing
|
||||
with rls_transaction(tenant_id):
|
||||
finding_uid = finding.uid
|
||||
last_first_seen_at = None
|
||||
if finding_uid not in last_status_cache:
|
||||
most_recent_finding = (
|
||||
Finding.all_objects.filter(
|
||||
tenant_id=tenant_id, uid=finding_uid
|
||||
)
|
||||
most_recent = (
|
||||
Finding.objects.filter(uid=finding_uid)
|
||||
.order_by("-inserted_at")
|
||||
.values("status", "first_seen_at")
|
||||
.first()
|
||||
)
|
||||
last_status = None
|
||||
if most_recent_finding:
|
||||
last_status = most_recent_finding["status"]
|
||||
last_first_seen_at = most_recent_finding["first_seen_at"]
|
||||
last_status_cache[finding_uid] = last_status, last_first_seen_at
|
||||
last_status, first_seen = (
|
||||
(most_recent["status"], most_recent["first_seen_at"])
|
||||
if most_recent
|
||||
else (None, None)
|
||||
)
|
||||
last_status_cache[finding_uid] = (last_status, first_seen)
|
||||
else:
|
||||
last_status, last_first_seen_at = last_status_cache[finding_uid]
|
||||
last_status, first_seen = last_status_cache[finding_uid]
|
||||
|
||||
status = FindingStatus[finding.status]
|
||||
delta = _create_finding_delta(last_status, status)
|
||||
# For the findings prior to the change, when a first finding is found with delta!="new" it will be
|
||||
# assigned a current date as first_seen_at and the successive findings with the same UID will
|
||||
# always get the date of the previous finding.
|
||||
# For new findings, when a finding (delta="new") is found for the first time, the first_seen_at
|
||||
# attribute will be assigned the current date, the following findings will get that date.
|
||||
if not last_first_seen_at:
|
||||
last_first_seen_at = datetime.now(tz=timezone.utc)
|
||||
first_seen = first_seen or datetime.now(tz=timezone.utc)
|
||||
|
||||
# Create the finding
|
||||
finding_instance = Finding.objects.create(
|
||||
tenant_id=tenant_id,
|
||||
uid=finding_uid,
|
||||
@@ -266,92 +382,96 @@ def perform_prowler_scan(
|
||||
raw_result=finding.raw,
|
||||
check_id=finding.check_id,
|
||||
scan=scan_instance,
|
||||
first_seen_at=last_first_seen_at,
|
||||
first_seen_at=first_seen,
|
||||
)
|
||||
finding_instance.add_resources([resource_instance])
|
||||
|
||||
# Update compliance data if applicable
|
||||
if finding.status.value == "MUTED":
|
||||
continue
|
||||
# Update compliance status
|
||||
if finding.status.value != "MUTED":
|
||||
region_data = check_status_by_region.setdefault(finding.region, {})
|
||||
if region_data.get(finding.check_id) != "FAIL":
|
||||
region_data[finding.check_id] = finding.status.value
|
||||
|
||||
region_dict = check_status_by_region.setdefault(finding.region, {})
|
||||
current_status = region_dict.get(finding.check_id)
|
||||
if current_status == "FAIL":
|
||||
continue
|
||||
region_dict[finding.check_id] = finding.status.value
|
||||
|
||||
# Update scan progress
|
||||
# Progress updates and output generation
|
||||
with rls_transaction(tenant_id):
|
||||
scan_instance.progress = progress
|
||||
scan_instance.save()
|
||||
|
||||
all_findings.extend(findings)
|
||||
|
||||
# Generate output files
|
||||
for mode, config in OUTPUT_FORMATS_MAPPING.items():
|
||||
kwargs = dict(config["kwargs"])
|
||||
if mode == "html":
|
||||
kwargs["provider"] = prowler_provider
|
||||
kwargs["stats"] = stats
|
||||
config["class"](
|
||||
findings=all_findings,
|
||||
create_file_descriptor=True,
|
||||
file_path=output_directory,
|
||||
file_extension=config["suffix"],
|
||||
).batch_write_data_to_file(**kwargs)
|
||||
|
||||
scan_instance.state = StateChoices.COMPLETED
|
||||
|
||||
# Compress output files
|
||||
zip_path = _compress_output_files(output_directory)
|
||||
|
||||
# Save to configured storage
|
||||
_upload_to_s3(tenant_id, zip_path, scan_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error performing scan {scan_id}: {e}")
|
||||
logger.error(f"Scan {scan_id} failed: {str(e)}")
|
||||
exception = e
|
||||
scan_instance.state = StateChoices.FAILED
|
||||
|
||||
finally:
|
||||
# Final scan updates
|
||||
with rls_transaction(tenant_id):
|
||||
scan_instance.duration = time.time() - start_time
|
||||
scan_instance.completed_at = datetime.now(tz=timezone.utc)
|
||||
scan_instance.unique_resource_count = len(unique_resources)
|
||||
scan_instance.save()
|
||||
|
||||
if exception is None:
|
||||
try:
|
||||
regions = prowler_provider.get_regions()
|
||||
except AttributeError:
|
||||
regions = set()
|
||||
|
||||
# Compliance processing
|
||||
if not exception:
|
||||
compliance_template = PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE[
|
||||
provider_instance.provider
|
||||
]
|
||||
compliance_overview_by_region = {
|
||||
region: deepcopy(compliance_template) for region in regions
|
||||
compliance_overview = {
|
||||
region: deepcopy(compliance_template)
|
||||
for region in getattr(prowler_provider, "get_regions", lambda: set())()
|
||||
}
|
||||
|
||||
for region, check_status in check_status_by_region.items():
|
||||
compliance_data = compliance_overview_by_region.setdefault(
|
||||
region, deepcopy(compliance_template)
|
||||
)
|
||||
for check_name, status in check_status.items():
|
||||
for region, checks in check_status_by_region.items():
|
||||
for check_id, status in checks.items():
|
||||
generate_scan_compliance(
|
||||
compliance_data,
|
||||
compliance_overview.setdefault(
|
||||
region, deepcopy(compliance_template)
|
||||
),
|
||||
provider_instance.provider,
|
||||
check_name,
|
||||
check_id,
|
||||
status,
|
||||
)
|
||||
|
||||
# Prepare compliance overview objects
|
||||
compliance_overview_objects = []
|
||||
for region, compliance_data in compliance_overview_by_region.items():
|
||||
for compliance_id, compliance in compliance_data.items():
|
||||
compliance_overview_objects.append(
|
||||
ComplianceOverview(
|
||||
tenant_id=tenant_id,
|
||||
scan=scan_instance,
|
||||
region=region,
|
||||
compliance_id=compliance_id,
|
||||
framework=compliance["framework"],
|
||||
version=compliance["version"],
|
||||
description=compliance["description"],
|
||||
requirements=compliance["requirements"],
|
||||
requirements_passed=compliance["requirements_status"]["passed"],
|
||||
requirements_failed=compliance["requirements_status"]["failed"],
|
||||
requirements_manual=compliance["requirements_status"]["manual"],
|
||||
total_requirements=compliance["total_requirements"],
|
||||
)
|
||||
ComplianceOverview.objects.bulk_create(
|
||||
[
|
||||
ComplianceOverview(
|
||||
tenant_id=tenant_id,
|
||||
scan=scan_instance,
|
||||
region=region,
|
||||
compliance_id=compliance_id,
|
||||
**compliance_data,
|
||||
)
|
||||
with rls_transaction(tenant_id):
|
||||
ComplianceOverview.objects.bulk_create(compliance_overview_objects)
|
||||
for region, data in compliance_overview.items()
|
||||
for compliance_id, compliance_data in data.items()
|
||||
]
|
||||
)
|
||||
|
||||
if exception is not None:
|
||||
if exception:
|
||||
raise exception
|
||||
|
||||
serializer = ScanTaskSerializer(instance=scan_instance)
|
||||
return serializer.data
|
||||
return ScanTaskSerializer(instance=scan_instance).data
|
||||
|
||||
|
||||
def aggregate_findings(tenant_id: str, scan_id: str):
|
||||
@@ -478,29 +598,28 @@ def aggregate_findings(tenant_id: str, scan_id: str):
|
||||
),
|
||||
)
|
||||
|
||||
with rls_transaction(tenant_id):
|
||||
scan_aggregations = {
|
||||
ScanSummary(
|
||||
tenant_id=tenant_id,
|
||||
scan_id=scan_id,
|
||||
check_id=agg["check_id"],
|
||||
service=agg["resources__service"],
|
||||
severity=agg["severity"],
|
||||
region=agg["resources__region"],
|
||||
fail=agg["fail"],
|
||||
_pass=agg["_pass"],
|
||||
muted=agg["muted"],
|
||||
total=agg["total"],
|
||||
new=agg["new"],
|
||||
changed=agg["changed"],
|
||||
unchanged=agg["unchanged"],
|
||||
fail_new=agg["fail_new"],
|
||||
fail_changed=agg["fail_changed"],
|
||||
pass_new=agg["pass_new"],
|
||||
pass_changed=agg["pass_changed"],
|
||||
muted_new=agg["muted_new"],
|
||||
muted_changed=agg["muted_changed"],
|
||||
)
|
||||
for agg in aggregation
|
||||
}
|
||||
ScanSummary.objects.bulk_create(scan_aggregations, batch_size=3000)
|
||||
ScanSummary.objects.bulk_create(
|
||||
[
|
||||
ScanSummary(
|
||||
tenant_id=tenant_id,
|
||||
scan_id=scan_id,
|
||||
check_id=agg["check_id"],
|
||||
service=agg["resources__service"],
|
||||
severity=agg["severity"],
|
||||
region=agg["resources__region"],
|
||||
**{
|
||||
k: v or 0
|
||||
for k, v in agg.items()
|
||||
if k
|
||||
not in {
|
||||
"check_id",
|
||||
"resources__service",
|
||||
"severity",
|
||||
"resources__region",
|
||||
}
|
||||
},
|
||||
)
|
||||
for agg in aggregation
|
||||
],
|
||||
batch_size=3000,
|
||||
)
|
||||
|
||||
@@ -16,6 +16,7 @@ services:
    volumes:
      - "./api/src/backend:/home/prowler/backend"
      - "./api/pyproject.toml:/home/prowler/pyproject.toml"
      - "/tmp/prowler_api_output:/tmp/prowler_api_output"
    depends_on:
      postgres:
        condition: service_healthy
@@ -85,6 +86,8 @@ services:
    env_file:
      - path: .env
        required: false
    volumes:
      - "/tmp/prowler_api_output:/tmp/prowler_api_output"
    depends_on:
      valkey:
        condition: service_healthy
@@ -59,6 +59,7 @@ aws_services_json_file = "aws_regions_by_service.json"
# gcp_zones_json_file = "gcp_zones.json"

default_output_directory = getcwd() + "/output"
tmp_output_directory = "/tmp/prowler_api_output"
output_file_timestamp = timestamp.strftime("%Y%m%d%H%M%S")
timestamp_iso = timestamp.isoformat(sep=" ", timespec="seconds")
csv_file_suffix = ".csv"
@@ -139,13 +139,9 @@ def remove_custom_checks_module(input_folder: str, provider: str):
def list_services(provider: str) -> set:
    available_services = set()
    checks_tuple = recover_checks_from_provider(provider)
    split_character = "\\" if os.name == "nt" else "/"
    for _, check_path in checks_tuple:
        # Format: /absolute_path/prowler/providers/{provider}/services/{service_name}/{check_name}
        if os.name == "nt":
            service_name = check_path.split("\\")[-2]
        else:
            service_name = check_path.split("/")[-2]
        available_services.add(service_name)
        available_services.add(check_path.split(split_character)[-2])
    return sorted(available_services)
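A quick illustration of the path handling above; the check path is a made-up example that follows the format noted in the comment:

```python
# Made-up path following the "# Format:" comment above.
check_path = "/abs/prowler/providers/aws/services/ec2/ec2_instance_public_ip"
split_character = "/"  # "\\" when os.name == "nt", as in the diff above
assert check_path.split(split_character)[-2] == "ec2"
```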
@@ -38,16 +38,18 @@ class Output(ABC):
        file_extension: str = "",
    ) -> None:
        self._data = []
        self.file_path = file_path

        if not file_extension and file_path:
            self._file_extension = "".join(Path(file_path).suffixes)
        if file_extension:
            self._file_extension = file_extension
            self.file_path = f"{file_path}{self.file_extension}"

        if findings:
            self.transform(findings)
            if create_file_descriptor and file_path:
                self.create_file_descriptor(file_path)
                self.create_file_descriptor(self.file_path)

    @property
    def data(self):
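To make the new extension handling concrete, a small standalone sketch that mirrors the constructor logic above; this is not the real `Output` class, and the sample values are invented:

```python
from pathlib import Path


def resolve_output_path(file_path: str, file_extension: str = "") -> tuple[str, str]:
    """Mirror the extension handling shown in Output.__init__ above."""
    # Without an explicit extension, derive it from the path's suffixes.
    extension = file_extension or "".join(Path(file_path).suffixes)
    # With an explicit extension, append it to the base path, matching the new
    # `self.file_path = f"{file_path}{self.file_extension}"` line in the diff.
    full_path = f"{file_path}{file_extension}" if file_extension else file_path
    return full_path, extension


# e.g. ('/tmp/prowler_api_output/t/s/prowler-output-123.csv', '.csv')
print(resolve_output_path("/tmp/prowler_api_output/t/s/prowler-output-123", ".csv"))
```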
@@ -1,5 +1,5 @@
|
||||
import datetime
|
||||
from typing import Generator
|
||||
from typing import Dict, Generator, List, Optional, Set
|
||||
|
||||
from prowler.lib.check.check import (
|
||||
execute,
|
||||
@@ -14,6 +14,7 @@ from prowler.lib.check.models import CheckMetadata, Severity
|
||||
from prowler.lib.logger import logger
|
||||
from prowler.lib.outputs.common import Status
|
||||
from prowler.lib.outputs.finding import Finding
|
||||
from prowler.lib.outputs.outputs import extract_findings_statistics
|
||||
from prowler.lib.scan.exceptions.exceptions import (
|
||||
ScanInvalidCategoryError,
|
||||
ScanInvalidCheckError,
|
||||
@@ -28,28 +29,25 @@ from prowler.providers.common.provider import Provider
|
||||
|
||||
class Scan:
|
||||
_provider: Provider
|
||||
# Refactor(Core): This should replace the Audit_Metadata
|
||||
_number_of_checks_to_execute: int = 0
|
||||
_number_of_checks_completed: int = 0
|
||||
# TODO the str should be a set of Check objects
|
||||
_checks_to_execute: list[str]
|
||||
_service_checks_to_execute: dict[str, set[str]]
|
||||
_service_checks_completed: dict[str, set[str]]
|
||||
_checks_to_execute: List[str]
|
||||
_service_checks_map: Dict[str, Set[str]]
|
||||
_completed_checks: Set[str]
|
||||
_progress: float = 0.0
|
||||
_findings: List[Finding] = []
|
||||
_duration: int = 0
|
||||
_status: list[str] = None
|
||||
_statuses: Optional[List[Status]] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider: Provider,
|
||||
checks: list[str] = None,
|
||||
services: list[str] = None,
|
||||
compliances: list[str] = None,
|
||||
categories: list[str] = None,
|
||||
severities: list[str] = None,
|
||||
excluded_checks: list[str] = None,
|
||||
excluded_services: list[str] = None,
|
||||
status: list[str] = None,
|
||||
checks: Optional[List[str]] = None,
|
||||
services: Optional[List[str]] = None,
|
||||
compliances: Optional[List[str]] = None,
|
||||
categories: Optional[List[str]] = None,
|
||||
severities: Optional[List[str]] = None,
|
||||
excluded_checks: Optional[List[str]] = None,
|
||||
excluded_services: Optional[List[str]] = None,
|
||||
status: Optional[List[str]] = None,
|
||||
):
|
||||
"""
|
||||
Scan is the class that executes the checks and yields the progress and the findings.
|
||||
@@ -74,151 +72,190 @@ class Scan:
|
||||
ScanInvalidStatusError: If the status does not exist in the provider.
|
||||
"""
|
||||
self._provider = provider
|
||||
self._statuses = self._validate_statuses(status)
|
||||
|
||||
# Validate the status
|
||||
if status:
|
||||
try:
|
||||
for s in status:
|
||||
Status(s)
|
||||
if not self._status:
|
||||
self._status = []
|
||||
if s not in self._status:
|
||||
self._status.append(s)
|
||||
except ValueError:
|
||||
raise ScanInvalidStatusError(f"Invalid status provided: {s}.")
|
||||
|
||||
# Load bulk compliance frameworks
|
||||
bulk_compliance_frameworks = Compliance.get_bulk(provider.type)
|
||||
bulk_checks_metadata = self._load_checks_metadata(bulk_compliance_frameworks)
|
||||
|
||||
# Get bulk checks metadata for the provider
|
||||
bulk_checks_metadata = CheckMetadata.get_bulk(provider.type)
|
||||
# Complete checks metadata with the compliance framework specification
|
||||
bulk_checks_metadata = update_checks_metadata_with_compliance(
|
||||
self._validate_inputs(
|
||||
checks,
|
||||
services,
|
||||
compliances,
|
||||
categories,
|
||||
severities,
|
||||
bulk_checks_metadata,
|
||||
bulk_compliance_frameworks,
|
||||
)
|
||||
|
||||
self._checks_to_execute = self._init_checks_to_execute(
|
||||
bulk_checks_metadata,
|
||||
bulk_compliance_frameworks,
|
||||
checks,
|
||||
services,
|
||||
compliances,
|
||||
categories,
|
||||
severities,
|
||||
excluded_checks,
|
||||
excluded_services,
|
||||
)
|
||||
|
||||
self._service_checks_map = get_service_checks_mapping(self._checks_to_execute)
|
||||
self._completed_checks = set()
|
||||
|
||||
def _validate_statuses(
|
||||
self, statuses: Optional[List[str]]
|
||||
) -> Optional[List[Status]]:
|
||||
"""Validate and convert status strings to Status enums."""
|
||||
if not statuses:
|
||||
return None
|
||||
|
||||
validated = []
|
||||
for status in statuses:
|
||||
try:
|
||||
validated.append(Status(status))
|
||||
except ValueError:
|
||||
raise ScanInvalidStatusError(f"Invalid status: {status}")
|
||||
return validated
|
||||
|
||||
def _load_checks_metadata(self, bulk_compliance_frameworks: dict) -> dict:
|
||||
"""Load and enhance checks metadata with compliance information."""
|
||||
bulk_checks_metadata = CheckMetadata.get_bulk(self._provider.type)
|
||||
return update_checks_metadata_with_compliance(
|
||||
bulk_compliance_frameworks, bulk_checks_metadata
|
||||
)
|
||||
|
||||
# Create a list of valid categories
|
||||
valid_categories = set()
|
||||
for check, metadata in bulk_checks_metadata.items():
|
||||
for category in metadata.Categories:
|
||||
if category not in valid_categories:
|
||||
valid_categories.add(category)
|
||||
def _validate_inputs(
|
||||
self,
|
||||
checks: Optional[List[str]],
|
||||
services: Optional[List[str]],
|
||||
compliances: Optional[List[str]],
|
||||
categories: Optional[List[str]],
|
||||
severities: Optional[List[str]],
|
||||
bulk_checks_metadata: dict,
|
||||
bulk_compliance_frameworks: dict,
|
||||
):
|
||||
"""Validate all input parameters against provider capabilities."""
|
||||
valid_services = list_services(self._provider.type)
|
||||
valid_categories = self._extract_valid_categories(bulk_checks_metadata)
|
||||
|
||||
# Validate checks
|
||||
self._validate_checks(checks, bulk_checks_metadata)
|
||||
self._validate_services(services, valid_services)
|
||||
self._validate_compliances(compliances, bulk_compliance_frameworks)
|
||||
self._validate_categories(categories, valid_categories)
|
||||
self._validate_severities(severities)
|
||||
|
||||
def _extract_valid_categories(self, checks_metadata: dict) -> Set[str]:
|
||||
"""Extract unique categories from checks metadata."""
|
||||
return {
|
||||
category
|
||||
for metadata in checks_metadata.values()
|
||||
for category in metadata.Categories
|
||||
}
|
||||
|
||||
def _validate_checks(self, checks: Optional[List[str]], metadata: dict):
|
||||
"""Validate requested checks exist in provider."""
|
||||
if checks:
|
||||
for check in checks:
|
||||
if check not in bulk_checks_metadata.keys():
|
||||
raise ScanInvalidCheckError(f"Invalid check provided: {check}.")
|
||||
invalid = [check for check in checks if check not in metadata]
|
||||
if invalid:
|
||||
raise ScanInvalidCheckError(f"Invalid checks: {', '.join(invalid)}")
|
||||
|
||||
# Validate services
|
||||
def _validate_services(self, services: Optional[List[str]], valid_services: list):
|
||||
"""Validate requested services exist in provider."""
|
||||
if services:
|
||||
for service in services:
|
||||
if service not in list_services(provider.type):
|
||||
raise ScanInvalidServiceError(
|
||||
f"Invalid service provided: {service}."
|
||||
)
|
||||
invalid = [srv for srv in services if srv not in valid_services]
|
||||
if invalid:
|
||||
raise ScanInvalidServiceError(f"Invalid services: {', '.join(invalid)}")
|
||||
|
||||
# Validate compliances
|
||||
def _validate_compliances(self, compliances: Optional[List[str]], frameworks: dict):
|
||||
"""Validate compliance frameworks exist."""
|
||||
if compliances:
|
||||
for compliance in compliances:
|
||||
if compliance not in bulk_compliance_frameworks.keys():
|
||||
raise ScanInvalidComplianceFrameworkError(
|
||||
f"Invalid compliance provided: {compliance}."
|
||||
)
|
||||
invalid = [comp for comp in compliances if comp not in frameworks]
|
||||
if invalid:
|
||||
raise ScanInvalidComplianceFrameworkError(
|
||||
f"Invalid compliances: {', '.join(invalid)}"
|
||||
)
|
||||
|
||||
# Validate categories
|
||||
def _validate_categories(
|
||||
self, categories: Optional[List[str]], valid_categories: Set[str]
|
||||
):
|
||||
"""Validate categories exist in provider checks."""
|
||||
if categories:
|
||||
for category in categories:
|
||||
if category not in valid_categories:
|
||||
raise ScanInvalidCategoryError(
|
||||
f"Invalid category provided: {category}."
|
||||
)
|
||||
invalid = [cat for cat in categories if cat not in valid_categories]
|
||||
if invalid:
|
||||
raise ScanInvalidCategoryError(
|
||||
f"Invalid categories: {', '.join(invalid)}"
|
||||
)
|
||||
|
||||
# Validate severity
|
||||
def _validate_severities(self, severities: Optional[List[str]]):
|
||||
"""Validate severity values are valid."""
|
||||
if severities:
|
||||
for severity in severities:
|
||||
try:
|
||||
Severity(severity)
|
||||
except ValueError:
|
||||
raise ScanInvalidSeverityError(
|
||||
f"Invalid severity provided: {severity}."
|
||||
)
|
||||
try:
|
||||
[Severity(sev) for sev in severities]
|
||||
except ValueError as e:
|
||||
raise ScanInvalidSeverityError(f"Invalid severity: {e}")
|
||||
|
||||
# Load checks to execute
|
||||
self._checks_to_execute = sorted(
|
||||
load_checks_to_execute(
|
||||
bulk_checks_metadata=bulk_checks_metadata,
|
||||
bulk_compliance_frameworks=bulk_compliance_frameworks,
|
||||
check_list=checks,
|
||||
service_list=services,
|
||||
compliance_frameworks=compliances,
|
||||
categories=categories,
|
||||
severities=severities,
|
||||
provider=provider.type,
|
||||
checks_file=None,
|
||||
)
|
||||
def _init_checks_to_execute(
|
||||
self,
|
||||
bulk_checks_metadata: dict,
|
||||
bulk_compliance_frameworks: dict,
|
||||
checks: Optional[List[str]],
|
||||
services: Optional[List[str]],
|
||||
compliances: Optional[List[str]],
|
||||
categories: Optional[List[str]],
|
||||
severities: Optional[List[str]],
|
||||
excluded_checks: Optional[List[str]],
|
||||
excluded_services: Optional[List[str]],
|
||||
) -> List[str]:
|
||||
"""Load and filter checks based on configuration."""
|
||||
checks = load_checks_to_execute(
|
||||
bulk_checks_metadata=bulk_checks_metadata,
|
||||
bulk_compliance_frameworks=bulk_compliance_frameworks,
|
||||
check_list=checks,
|
||||
service_list=services,
|
||||
compliance_frameworks=compliances,
|
||||
categories=categories,
|
||||
severities=severities,
|
||||
provider=self._provider.type,
|
||||
)
|
||||
|
||||
# Exclude checks
|
||||
if excluded_checks:
|
||||
for check in excluded_checks:
|
||||
if check in self._checks_to_execute:
|
||||
self._checks_to_execute.remove(check)
|
||||
else:
|
||||
raise ScanInvalidCheckError(
|
||||
f"Invalid check provided: {check}. Check does not exist in the provider."
|
||||
)
|
||||
checks = [c for c in checks if c not in excluded_checks]
|
||||
|
||||
# Exclude services
|
||||
if excluded_services:
|
||||
for check in self._checks_to_execute:
|
||||
if get_service_name_from_check_name(check) in excluded_services:
|
||||
self._checks_to_execute.remove(check)
|
||||
else:
|
||||
raise ScanInvalidServiceError(
|
||||
f"Invalid service provided: {check}. Service does not exist in the provider."
|
||||
)
|
||||
excluded_services_set = set(excluded_services)
|
||||
checks = [
|
||||
c
|
||||
for c in checks
|
||||
if get_service_from_check(c) not in excluded_services_set
|
||||
]
|
||||
|
||||
self._number_of_checks_to_execute = len(self._checks_to_execute)
|
||||
|
||||
service_checks_to_execute = get_service_checks_to_execute(
|
||||
self._checks_to_execute
|
||||
)
|
||||
service_checks_completed = dict()
|
||||
|
||||
self._service_checks_to_execute = service_checks_to_execute
|
||||
self._service_checks_completed = service_checks_completed
|
||||
return sorted(checks)
|
||||
|
||||
@property
|
||||
def checks_to_execute(self) -> list[str]:
|
||||
return self._checks_to_execute
|
||||
def total_checks(self) -> int:
|
||||
return len(self._checks_to_execute)
|
||||
|
||||
@property
|
||||
def service_checks_to_execute(self) -> dict[str, set[str]]:
|
||||
return self._service_checks_to_execute
|
||||
|
||||
@property
|
||||
def service_checks_completed(self) -> dict[str, set[str]]:
|
||||
return self._service_checks_completed
|
||||
|
||||
@property
|
||||
def provider(self) -> Provider:
|
||||
return self._provider
|
||||
def completed_checks(self) -> int:
|
||||
return len(self._completed_checks)
|
||||
|
||||
@property
|
||||
def progress(self) -> float:
|
||||
return (
|
||||
self._number_of_checks_completed / self._number_of_checks_to_execute * 100
|
||||
(self.completed_checks / self.total_checks * 100)
|
||||
if self.total_checks
|
||||
else 0
|
||||
)
|
||||
|
||||
@property
|
||||
def duration(self) -> int:
|
||||
return self._duration
|
||||
def remaining_services(self) -> Dict[str, Set[str]]:
|
||||
return {
|
||||
service: checks
|
||||
for service, checks in self._service_checks_map.items()
|
||||
if checks
|
||||
}
|
||||
|
||||
def scan(
|
||||
self,
|
||||
custom_checks_metadata: dict = {},
|
||||
) -> Generator[tuple[float, list[Finding]], None, None]:
|
||||
def scan(self, custom_checks_metadata: dict = {}) -> Generator[float, List[Finding], dict]:
|
||||
"""
|
||||
Executes the scan by iterating over the checks to execute and executing each check.
|
||||
Yields the progress and findings for each check.
|
||||
@@ -233,143 +270,103 @@ class Scan:
|
||||
ModuleNotFoundError: If the check does not exist in the provider or is from another provider.
|
||||
Exception: If any other error occurs during the execution of a check.
|
||||
"""
|
||||
try:
|
||||
checks_to_execute = self.checks_to_execute
|
||||
# Initialize the Audit Metadata
|
||||
# TODO: this should be done in the provider class
|
||||
# Refactor(Core): Audit manager?
|
||||
self._provider.audit_metadata = Audit_Metadata(
|
||||
services_scanned=0,
|
||||
expected_checks=checks_to_execute,
|
||||
completed_checks=0,
|
||||
audit_progress=0,
|
||||
)
|
||||
self._provider.audit_metadata = Audit_Metadata(
|
||||
services_scanned=0,
|
||||
expected_checks=self._checks_to_execute,
|
||||
completed_checks=0,
|
||||
audit_progress=0,
|
||||
)
|
||||
|
||||
start_time = datetime.datetime.now()
|
||||
start_time = datetime.datetime.now()
|
||||
|
||||
for check_name in checks_to_execute:
|
||||
for check_name in self._checks_to_execute:
|
||||
service = get_service_from_check(check_name)
|
||||
try:
|
||||
check_module = self._import_check_module(check_name, service)
|
||||
findings = self._execute_check(check_module, check_name, custom_checks_metadata)
|
||||
filtered_findings = self._filter_findings_by_status(findings)
|
||||
except Exception as error:
|
||||
logger.error(f"{check_name} failed: {error}")
|
||||
continue
|
||||
|
||||
self._update_scan_state(check_name, service, filtered_findings)
|
||||
stats = extract_findings_statistics(filtered_findings)
|
||||
|
||||
findings = []
|
||||
for finding in filtered_findings:
|
||||
try:
|
||||
# Recover service from check name
|
||||
service = get_service_name_from_check_name(check_name)
|
||||
try:
|
||||
# Import check module
|
||||
check_module_path = f"prowler.providers.{self._provider.type}.services.{service}.{check_name}.{check_name}"
|
||||
lib = import_check(check_module_path)
|
||||
# Recover functions from check
|
||||
check_to_execute = getattr(lib, check_name)
|
||||
check = check_to_execute()
|
||||
except ModuleNotFoundError:
|
||||
logger.error(
|
||||
f"Check '{check_name}' was not found for the {self._provider.type.upper()} provider"
|
||||
findings.append(
|
||||
Finding.generate_output(
|
||||
self._provider, finding, output_options=None
|
||||
)
|
||||
continue
|
||||
# Execute the check
|
||||
check_findings = execute(
|
||||
check,
|
||||
self._provider,
|
||||
custom_checks_metadata,
|
||||
output_options=None,
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# Filter the findings by the status
|
||||
if self._status:
|
||||
for finding in check_findings:
|
||||
if finding.status not in self._status:
|
||||
check_findings.remove(finding)
|
||||
yield self.progress, findings, stats
|
||||
|
||||
# Remove the executed check
|
||||
self._service_checks_to_execute[service].remove(check_name)
|
||||
if len(self._service_checks_to_execute[service]) == 0:
|
||||
self._service_checks_to_execute.pop(service, None)
|
||||
# Add the completed check
|
||||
if service not in self._service_checks_completed:
|
||||
self._service_checks_completed[service] = set()
|
||||
self._service_checks_completed[service].add(check_name)
|
||||
self._number_of_checks_completed += 1
|
||||
self._duration = int((datetime.datetime.now() - start_time).total_seconds())
|
||||
|
||||
# This should be done just once all the service's checks are completed
|
||||
# This metadata needs to get to the services not within the provider
|
||||
# since it is present in the Scan class
|
||||
self._provider.audit_metadata = update_audit_metadata(
|
||||
self._provider.audit_metadata,
|
||||
self.get_completed_services(),
|
||||
self.get_completed_checks(),
|
||||
)
|
||||
|
||||
findings = []
|
||||
for finding in check_findings:
|
||||
try:
|
||||
findings.append(
|
||||
Finding.generate_output(
|
||||
self._provider, finding, output_options=None
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
yield self.progress, findings
|
||||
# If check does not exists in the provider or is from another provider
|
||||
except ModuleNotFoundError:
|
||||
logger.error(
|
||||
f"Check '{check_name}' was not found for the {self._provider.type.upper()} provider"
|
||||
)
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"{check_name} - {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
# Update the scan duration when all checks are completed
|
||||
self._duration = int((datetime.datetime.now() - start_time).total_seconds())
|
||||
except Exception as error:
|
||||
def _import_check_module(self, check_name: str, service: str):
|
||||
"""Dynamically import check module."""
|
||||
module_path = (
|
||||
f"prowler.providers.{self._provider.type}.services."
|
||||
f"{service}.{check_name}.{check_name}"
|
||||
)
|
||||
try:
|
||||
return import_check(module_path)
|
||||
except ModuleNotFoundError:
|
||||
logger.error(
|
||||
f"{check_name} - {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
f"Check '{check_name}' not found for {self._provider.type.upper()}"
|
||||
)
|
||||
raise
|
||||
|
||||
def get_completed_services(self) -> set[str]:
|
||||
"""
|
||||
get_completed_services returns the services that have been completed.
|
||||
def _execute_check(self, check_module, check_name: str, custom_checks_metadata: dict = {}) -> List[Finding]:
|
||||
"""Execute a single check and return its findings."""
|
||||
check_func = getattr(check_module, check_name)
|
||||
return execute(check_func(), self._provider, custom_checks_metadata)
|
||||
|
||||
Example:
|
||||
get_completed_services() -> {"ec2", "s3"}
|
||||
"""
|
||||
return self._service_checks_completed.keys()
|
||||
def _filter_findings_by_status(self, findings: List[Finding]) -> List[Finding]:
|
||||
"""Filter findings based on configured status filters."""
|
||||
if not self._statuses:
|
||||
return findings
|
||||
return [f for f in findings if f.status in self._statuses]
|
||||
|
||||
def get_completed_checks(self) -> set[str]:
|
||||
"""
|
||||
get_completed_checks returns the checks that have been completed.
|
||||
def _update_scan_state(
|
||||
self, check_name: str, service: str, findings: List[Finding]
|
||||
):
|
||||
"""Update scan state after check completion."""
|
||||
self._service_checks_map[service].discard(check_name)
|
||||
self._completed_checks.add(check_name)
|
||||
self._findings.extend(findings)
|
||||
|
||||
Example:
|
||||
get_completed_checks() -> {"ec2_instance_public_ip", "s3_bucket_public"}
|
||||
"""
|
||||
completed_checks = set()
|
||||
for checks in self._service_checks_completed.values():
|
||||
completed_checks.update(checks)
|
||||
return completed_checks
|
||||
self._provider.audit_metadata = update_audit_metadata(
|
||||
self._provider.audit_metadata,
|
||||
self.remaining_services.keys(),
|
||||
self._completed_checks,
|
||||
)
|
||||
|
||||
|
||||
def get_service_name_from_check_name(check_name: str) -> str:
|
||||
def get_service_from_check(check_name: str) -> str:
|
||||
"""
|
||||
get_service_name_from_check_name returns the service name for a given check name.
|
||||
Return the service name for a given check name.
|
||||
|
||||
Example:
|
||||
get_service_name_from_check_name("ec2_instance_public") -> "ec2"
|
||||
get_service_from_check("ec2_instance_public") -> "ec2"
|
||||
"""
|
||||
return check_name.split("_")[0]
|
||||
|
||||
|
||||
def get_service_checks_to_execute(checks_to_execute: set[str]) -> dict[str, set[str]]:
|
||||
def get_service_checks_mapping(checks: List[str]) -> Dict[str, Set[str]]:
|
||||
"""
|
||||
get_service_checks_to_execute returns a dictionary with the services and the checks to execute.
|
||||
Return a dictionary with the services and the checks to execute.
|
||||
|
||||
Example:
|
||||
get_service_checks_to_execute({"accessanalyzer_enabled", "ec2_instance_public_ip"})
|
||||
get_service_checks_mapping({"accessanalyzer_enabled", "ec2_instance_public_ip"})
|
||||
-> {"accessanalyzer": {"accessanalyzer_enabled"}, "ec2": {"ec2_instance_public_ip"}}
|
||||
"""
|
||||
service_checks_to_execute = dict()
|
||||
for check in checks_to_execute:
|
||||
# check -> accessanalyzer_enabled
|
||||
# service -> accessanalyzer
|
||||
service = get_service_name_from_check_name(check)
|
||||
if service not in service_checks_to_execute:
|
||||
service_checks_to_execute[service] = set()
|
||||
service_checks_to_execute[service].add(check)
|
||||
return service_checks_to_execute
|
||||
service_map = {}
|
||||
for check in checks:
|
||||
service = get_service_from_check(check)
|
||||
service_map.setdefault(service, set()).add(check)
|
||||
return service_map
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from importlib.machinery import FileFinder
|
||||
from pkgutil import ModuleInfo
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from mock import MagicMock, patch
|
||||
|
||||
@@ -13,7 +12,7 @@ from prowler.lib.scan.exceptions.exceptions import (
|
||||
ScanInvalidSeverityError,
|
||||
ScanInvalidStatusError,
|
||||
)
|
||||
from prowler.lib.scan.scan import Scan, get_service_checks_to_execute
|
||||
from prowler.lib.scan.scan import Scan, get_service_checks_mapping
|
||||
from tests.lib.outputs.fixtures.fixtures import generate_finding_output
|
||||
from tests.providers.aws.utils import set_mocked_aws_provider
|
||||
|
||||
@@ -196,9 +195,9 @@ class TestScan:
|
||||
mock_provider.type = "aws"
|
||||
scan = Scan(mock_provider, checks=checks_to_execute)
|
||||
|
||||
assert scan.provider == mock_provider
|
||||
assert scan._provider == mock_provider
|
||||
# Check that the checks to execute are sorted and without duplicates
|
||||
assert scan.checks_to_execute == [
|
||||
assert scan._checks_to_execute == [
|
||||
"accessanalyzer_enabled",
|
||||
"accessanalyzer_enabled_without_findings",
|
||||
"account_maintain_current_contact_details",
|
||||
@@ -258,14 +257,16 @@ class TestScan:
|
||||
"config_recorder_all_regions_enabled",
|
||||
"workspaces_vpc_2private_1public_subnets_nat",
|
||||
]
|
||||
assert scan.service_checks_to_execute == get_service_checks_to_execute(
|
||||
assert scan._service_checks_map == get_service_checks_mapping(
|
||||
checks_to_execute
|
||||
)
|
||||
assert scan.service_checks_completed == {}
|
||||
assert scan.progress == 0
|
||||
assert scan.duration == 0
|
||||
assert scan.get_completed_services() == set()
|
||||
assert scan.get_completed_checks() == set()
|
||||
assert scan._completed_checks == set()
|
||||
assert scan._progress == 0
|
||||
assert scan._duration == 0
|
||||
all_values = set().union(*scan.remaining_services.values())
|
||||
for check in scan._checks_to_execute:
|
||||
assert check in all_values
|
||||
assert scan.completed_checks == 0
|
||||
|
||||
def test_init_with_no_checks(
|
||||
mock_provider,
|
||||
@@ -281,18 +282,24 @@ class TestScan:
|
||||
mock_load_checks_to_execute.assert_called_once()
|
||||
mock_recover_checks_from_provider.assert_called_once_with("aws")
|
||||
|
||||
assert scan.provider == mock_provider
|
||||
assert scan.checks_to_execute == ["accessanalyzer_enabled"]
|
||||
assert scan.service_checks_to_execute == get_service_checks_to_execute(
|
||||
assert scan._provider == mock_provider
|
||||
assert scan._checks_to_execute == ["accessanalyzer_enabled"]
|
||||
assert scan._service_checks_map == get_service_checks_mapping(
|
||||
["accessanalyzer_enabled"]
|
||||
)
|
||||
assert scan.service_checks_completed == {}
|
||||
assert scan.progress == 0
|
||||
assert scan.get_completed_services() == set()
|
||||
assert scan.get_completed_checks() == set()
|
||||
assert scan._completed_checks == set()
|
||||
assert scan._progress == 0
|
||||
all_values = set().union(*scan.remaining_services.values())
|
||||
for check in scan._checks_to_execute:
|
||||
assert check in all_values
|
||||
assert scan.completed_checks == 0
|
||||
|
||||
@patch("importlib.import_module")
|
||||
@patch("prowler.lib.scan.scan.list_services")
|
||||
@patch("prowler.lib.scan.scan.extract_findings_statistics")
|
||||
def test_scan(
|
||||
mock_extract_findings_statistics,
|
||||
mock_list_services,
|
||||
mock_import_module,
|
||||
mock_global_provider,
|
||||
mock_execute,
|
||||
@@ -328,11 +335,9 @@ class TestScan:
|
||||
assert results[0][0] == 100.0
|
||||
assert scan.progress == 100.0
|
||||
# Since the scan is mocked, the duration will always be 0 for now
|
||||
assert scan.duration == 0
|
||||
assert scan._number_of_checks_completed == 1
|
||||
assert scan.service_checks_completed == {
|
||||
"accessanalyzer": {"accessanalyzer_enabled"},
|
||||
}
|
||||
assert scan._duration == 0
|
||||
assert scan.completed_checks == 1
|
||||
assert scan._completed_checks == {"accessanalyzer_enabled"}
|
||||
mock_logger.error.assert_not_called()
|
||||
|
||||
def test_init_invalid_severity(
|
||||
@@ -396,7 +401,9 @@ class TestScan:
|
||||
Scan(mock_provider, checks=checks_to_execute, status=["invalid_status"])
|
||||
|
||||
@patch("importlib.import_module")
|
||||
@patch("prowler.lib.scan.scan.list_services")
|
||||
def test_scan_filter_status(
|
||||
mock_list_services,
|
||||
mock_import_module,
|
||||
mock_global_provider,
|
||||
mock_recover_checks_from_provider,
|
||||
@@ -422,4 +429,21 @@ class TestScan:
|
||||
mock_recover_checks_from_provider.assert_called_once_with("aws")
|
||||
results = list(scan.scan(custom_checks_metadata))
|
||||
|
||||
assert results[0] == (100.0, [])
|
||||
assert results[0] == (100.0, [], {
|
||||
"all_fails_are_muted": True,
|
||||
"findings_count": 0,
|
||||
"resources_count": 0,
|
||||
"total_critical_severity_fail": 0,
|
||||
"total_critical_severity_pass": 0,
|
||||
"total_fail": 0,
|
||||
"total_high_severity_fail": 0,
|
||||
"total_high_severity_pass": 0,
|
||||
"total_low_severity_fail": 0,
|
||||
"total_low_severity_pass": 0,
|
||||
"total_medium_severity_fail": 0,
|
||||
"total_medium_severity_pass": 0,
|
||||
"total_muted_fail": 0,
|
||||
"total_muted_pass": 0,
|
||||
"total_pass": 0,
|
||||
},
|
||||
)
|
||||
|
||||