feat(integrations): add s3 integration (#8056)

Author: Adrián Jesús Peña Rodríguez
Date: 2025-07-30 12:05:46 +02:00
Committed by: GitHub
parent 7ec514d9dd
commit 163fbaff19
28 changed files with 1165 additions and 43 deletions

View File

@@ -164,8 +164,9 @@ jobs:
working-directory: ./api
if: steps.are-non-ignored-files-changed.outputs.any_changed == 'true'
# 76352, 76353, 77323 come from the SDK, which cannot be upgraded yet; they do not affect the API
# TODO: Botocore needs urllib3 1.X, so we need to ignore vulnerabilities 77744 and 77745. Remove this once we upgrade to urllib3 2.X
run: |
poetry run safety check --ignore 70612,66963,74429,76352,76353,77323
poetry run safety check --ignore 70612,66963,74429,76352,76353,77323,77744,77745
- name: Vulture
working-directory: ./api

View File

@@ -115,7 +115,8 @@ repos:
- id: safety
name: safety
description: "Safety is a tool that checks your installed dependencies for known security vulnerabilities"
entry: bash -c 'safety check --ignore 70612,66963,74429,76352,76353'
# TODO: Botocore needs urllib3 1.X, so we need to ignore vulnerabilities 77744 and 77745. Remove this once we upgrade to urllib3 2.X
entry: bash -c 'safety check --ignore 70612,66963,74429,76352,76353,77744,77745'
language: system
- id: vulture

View File

@@ -6,6 +6,7 @@ All notable changes to the **Prowler API** are documented in this file.
### Added
- Github provider support [(#8271)](https://github.com/prowler-cloud/prowler/pull/8271)
- Integration with Amazon S3, enabling storage and retrieval of scan data via S3 buckets [(#8056)](https://github.com/prowler-cloud/prowler/pull/8056)
---

View File

@@ -44,6 +44,9 @@ USER prowler
WORKDIR /home/prowler
# Ensure output directory exists
RUN mkdir -p /tmp/prowler_api_output
COPY pyproject.toml ./
RUN pip install --no-cache-dir --upgrade pip && \

View File

@@ -32,7 +32,7 @@ start_prod_server() {
start_worker() {
echo "Starting the worker..."
poetry run python -m celery -A config.celery worker -l "${DJANGO_LOGGING_LEVEL:-info}" -Q celery,scans,scan-reports,deletion,backfill,overview -E --max-tasks-per-child 1
poetry run python -m celery -A config.celery worker -l "${DJANGO_LOGGING_LEVEL:-info}" -Q celery,scans,scan-reports,deletion,backfill,overview,integrations -E --max-tasks-per-child 1
}
start_worker_beat() {

View File

@@ -0,0 +1,19 @@
# Generated by Django 5.1.10 on 2025-07-17 11:52
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("api", "0043_github_provider"),
]
operations = [
migrations.AddConstraint(
model_name="integration",
constraint=models.UniqueConstraint(
fields=("configuration", "tenant"),
name="unique_configuration_per_tenant",
),
),
]

View File

@@ -0,0 +1,17 @@
# Generated by Django 5.1.10 on 2025-07-21 16:08
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("api", "0044_integration_unique_configuration_per_tenant"),
]
operations = [
migrations.AlterField(
model_name="scan",
name="output_location",
field=models.CharField(blank=True, max_length=4096, null=True),
),
]

View File

@@ -438,7 +438,7 @@ class Scan(RowLevelSecurityProtectedModel):
scheduler_task = models.ForeignKey(
PeriodicTask, on_delete=models.SET_NULL, null=True, blank=True
)
output_location = models.CharField(blank=True, null=True, max_length=200)
output_location = models.CharField(blank=True, null=True, max_length=4096)
provider = models.ForeignKey(
Provider,
on_delete=models.CASCADE,
@@ -1346,7 +1346,7 @@ class ScanSummary(RowLevelSecurityProtectedModel):
class Integration(RowLevelSecurityProtectedModel):
class IntegrationChoices(models.TextChoices):
S3 = "amazon_s3", _("Amazon S3")
AMAZON_S3 = "amazon_s3", _("Amazon S3")
AWS_SECURITY_HUB = "aws_security_hub", _("AWS Security Hub")
JIRA = "jira", _("JIRA")
SLACK = "slack", _("Slack")
@@ -1372,6 +1372,10 @@ class Integration(RowLevelSecurityProtectedModel):
db_table = "integrations"
constraints = [
models.UniqueConstraint(
fields=("configuration", "tenant"),
name="unique_configuration_per_tenant",
),
RowLevelSecurityConstraint(
field="tenant_id",
name="rls_on_%(class)s",

View File

@@ -2871,6 +2871,30 @@ paths:
responses:
'204':
description: No response body
/api/v1/integrations/{id}/connection:
post:
operationId: integrations_connection_create
description: Try to verify the integration connection
summary: Check integration connection
parameters:
- in: path
name: id
schema:
type: string
format: uuid
description: A UUID string identifying this integration.
required: true
tags:
- Integration
security:
- jwtAuth: []
responses:
'202':
content:
application/vnd.api+json:
schema:
$ref: '#/components/schemas/OpenApiResponseResponse'
description: ''
/api/v1/invitations/accept:
post:
operationId: invitations_accept_create

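For reference (not part of the diff), a minimal client sketch for the connection-check endpoint added above. The host, integration id and JWT are placeholder assumptions; the API answers 202 Accepted with a Content-Location header pointing at the task resource, which can then be polled.

import requests

HOST = "https://prowler.example.com"  # assumed host
INTEGRATION_ID = "00000000-0000-0000-0000-000000000000"  # assumed integration UUID
HEADERS = {
    "Authorization": "Bearer <access-token>",  # assumed JWT
    "Accept": "application/vnd.api+json",
}

# Trigger the asynchronous connection check (no request body is required).
response = requests.post(
    f"{HOST}/api/v1/integrations/{INTEGRATION_ID}/connection", headers=HEADERS
)
assert response.status_code == 202

# Content-Location holds the (relative) path to the task detail endpoint; poll it
# until the connection check finishes.
task_path = response.headers["Content-Location"]
task = requests.get(f"{HOST}{task_path}", headers=HEADERS)
print(task.json())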
View File

@@ -0,0 +1,103 @@
import pytest
from rest_framework.exceptions import ValidationError
from api.v1.serializer_utils.integrations import S3ConfigSerializer
class TestS3ConfigSerializer:
"""Test cases for S3ConfigSerializer validation."""
def test_validate_output_directory_valid_paths(self):
"""Test that valid output directory paths are accepted."""
serializer = S3ConfigSerializer()
# Test normal paths
assert serializer.validate_output_directory("test") == "test"
assert serializer.validate_output_directory("test/folder") == "test/folder"
assert serializer.validate_output_directory("my-folder_123") == "my-folder_123"
# Test paths with leading slashes (should be normalized)
assert serializer.validate_output_directory("/test") == "test"
assert serializer.validate_output_directory("/test/folder") == "test/folder"
# Test paths with excessive slashes (should be normalized)
assert serializer.validate_output_directory("///test") == "test"
assert serializer.validate_output_directory("///////test") == "test"
assert serializer.validate_output_directory("test//folder") == "test/folder"
assert serializer.validate_output_directory("test///folder") == "test/folder"
def test_validate_output_directory_empty_values(self):
"""Test that empty values raise validation errors."""
serializer = S3ConfigSerializer()
with pytest.raises(ValidationError, match="Output directory cannot be empty"):
serializer.validate_output_directory("")
with pytest.raises(
ValidationError, match="Output directory cannot be empty or just"
):
serializer.validate_output_directory(".")
with pytest.raises(
ValidationError, match="Output directory cannot be empty or just"
):
serializer.validate_output_directory("/")
def test_validate_output_directory_invalid_characters(self):
"""Test that invalid characters are rejected."""
serializer = S3ConfigSerializer()
invalid_chars = ["<", ">", ":", '"', "|", "?", "*"]
for char in invalid_chars:
with pytest.raises(
ValidationError, match="Output directory contains invalid characters"
):
serializer.validate_output_directory(f"test{char}folder")
def test_validate_output_directory_too_long(self):
"""Test that paths that are too long are rejected."""
serializer = S3ConfigSerializer()
# Create a path longer than 900 characters
long_path = "a" * 901
with pytest.raises(ValidationError, match="Output directory path is too long"):
serializer.validate_output_directory(long_path)
def test_validate_output_directory_edge_cases(self):
"""Test edge cases for output directory validation."""
serializer = S3ConfigSerializer()
# Test path at the limit (900 characters)
path_at_limit = "a" * 900
assert serializer.validate_output_directory(path_at_limit) == path_at_limit
# Test complex normalization
assert serializer.validate_output_directory("//test/../folder//") == "folder"
assert serializer.validate_output_directory("/test/./folder/") == "test/folder"
def test_s3_config_serializer_full_validation(self):
"""Test the full S3ConfigSerializer with valid data."""
data = {
"bucket_name": "my-test-bucket",
"output_directory": "///////test", # This should be normalized
}
serializer = S3ConfigSerializer(data=data)
assert serializer.is_valid()
validated_data = serializer.validated_data
assert validated_data["bucket_name"] == "my-test-bucket"
assert validated_data["output_directory"] == "test" # Normalized
def test_s3_config_serializer_invalid_data(self):
"""Test the full S3ConfigSerializer with invalid data."""
data = {
"bucket_name": "my-test-bucket",
"output_directory": "test<invalid", # Contains invalid character
}
serializer = S3ConfigSerializer(data=data)
assert not serializer.is_valid()
assert "output_directory" in serializer.errors

View File

@@ -5641,7 +5641,7 @@ class TestIntegrationViewSet:
[
# Amazon S3 - AWS credentials
(
Integration.IntegrationChoices.S3,
Integration.IntegrationChoices.AMAZON_S3,
{
"bucket_name": "bucket-name",
"output_directory": "output-directory",
@@ -5653,7 +5653,7 @@ class TestIntegrationViewSet:
),
# Amazon S3 - No credentials (AWS self-hosted)
(
Integration.IntegrationChoices.S3,
Integration.IntegrationChoices.AMAZON_S3,
{
"bucket_name": "bucket-name",
"output_directory": "output-directory",
@@ -5717,7 +5717,7 @@ class TestIntegrationViewSet:
"data": {
"type": "integrations",
"attributes": {
"integration_type": Integration.IntegrationChoices.S3,
"integration_type": Integration.IntegrationChoices.AMAZON_S3,
"configuration": {
"bucket_name": "bucket-name",
"output_directory": "output-directory",
@@ -5952,11 +5952,11 @@ class TestIntegrationViewSet:
("inserted_at", TODAY, 2),
("inserted_at.gte", "2024-01-01", 2),
("inserted_at.lte", "2024-01-01", 0),
("integration_type", Integration.IntegrationChoices.S3, 2),
("integration_type", Integration.IntegrationChoices.AMAZON_S3, 2),
("integration_type", Integration.IntegrationChoices.SLACK, 0),
(
"integration_type__in",
f"{Integration.IntegrationChoices.S3},{Integration.IntegrationChoices.SLACK}",
f"{Integration.IntegrationChoices.AMAZON_S3},{Integration.IntegrationChoices.SLACK}",
2,
),
]

View File

@@ -7,9 +7,10 @@ from rest_framework.exceptions import NotFound, ValidationError
from api.db_router import MainRouter
from api.exceptions import InvitationTokenExpiredException
from api.models import Invitation, Processor, Provider, Resource
from api.models import Integration, Invitation, Processor, Provider, Resource
from api.v1.serializers import FindingMetadataSerializer
from prowler.providers.aws.aws_provider import AwsProvider
from prowler.providers.aws.lib.s3.s3 import S3
from prowler.providers.azure.azure_provider import AzureProvider
from prowler.providers.common.models import Connection
from prowler.providers.gcp.gcp_provider import GcpProvider
@@ -175,6 +176,37 @@ def prowler_provider_connection_test(provider: Provider) -> Connection:
)
def prowler_integration_connection_test(integration: Integration) -> Connection:
"""Test the connection to a Prowler integration based on the given integration type.
Args:
integration (Integration): The integration object containing the integration type and associated credentials.
Returns:
Connection: A connection object representing the result of the connection test for the specified integration.
"""
if integration.integration_type == Integration.IntegrationChoices.AMAZON_S3:
return S3.test_connection(
**integration.credentials,
bucket_name=integration.configuration["bucket_name"],
raise_on_exception=False,
)
# TODO: It may be possible to unify the connection test for all integrations, but that requires refactoring
# to avoid code duplication. The AWS integrations are currently similar, so Security Hub and S3 could be unified with some changes in the SDK.
elif (
integration.integration_type == Integration.IntegrationChoices.AWS_SECURITY_HUB
):
pass
elif integration.integration_type == Integration.IntegrationChoices.JIRA:
pass
elif integration.integration_type == Integration.IntegrationChoices.SLACK:
pass
else:
raise ValueError(
f"Integration type {integration.integration_type} not supported"
)
def validate_invitation(
invitation_token: str, email: str, raise_not_found=False
) -> Invitation:

View File

@@ -1,3 +1,6 @@
import os
import re
from drf_spectacular.utils import extend_schema_field
from rest_framework_json_api import serializers
@@ -8,6 +11,41 @@ class S3ConfigSerializer(BaseValidateSerializer):
bucket_name = serializers.CharField()
output_directory = serializers.CharField()
def validate_output_directory(self, value):
"""
Validate the output_directory field to ensure it's a properly formatted path.
Prevents paths with excessive slashes like "///////test".
"""
if not value:
raise serializers.ValidationError("Output directory cannot be empty.")
# Normalize the path to remove excessive slashes
normalized_path = os.path.normpath(value)
# Remove leading slashes for S3 paths
if normalized_path.startswith("/"):
normalized_path = normalized_path.lstrip("/")
# Check for invalid characters or patterns
if re.search(r'[<>:"|?*]', normalized_path):
raise serializers.ValidationError(
'Output directory contains invalid characters. Avoid: < > : " | ? *'
)
# Check for empty path after normalization
if not normalized_path or normalized_path == ".":
raise serializers.ValidationError(
"Output directory cannot be empty or just '.'."
)
# Check for paths that are too long (the S3 key limit is 1024 characters; leave room for the filename)
if len(normalized_path) > 900:
raise serializers.ValidationError(
"Output directory path is too long (max 900 characters)."
)
return normalized_path
class Meta:
resource_name = "integrations"
@@ -98,7 +136,9 @@ class IntegrationCredentialField(serializers.JSONField):
},
"output_directory": {
"type": "string",
"description": "The directory path within the bucket where files will be saved.",
"description": 'The directory path within the bucket where files will be saved. Path will be normalized to remove excessive slashes and invalid characters are not allowed (< > : " | ? *). Maximum length is 900 characters.',
"maxLength": 900,
"pattern": '^[^<>:"|?*]+$',
},
},
"required": ["bucket_name", "output_directory"],

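As a standalone illustration (not part of the diff), the normalization in validate_output_directory above reduces to os.path.normpath plus stripping leading slashes; the expected results match the test cases shown earlier.

import os

def normalize(value: str) -> str:
    # Collapse redundant slashes and dot segments, then drop leading slashes for S3 keys.
    return os.path.normpath(value).lstrip("/")

print(normalize("///////test"))         # -> "test"
print(normalize("/test/./folder/"))     # -> "test/folder"
print(normalize("//test/../folder//"))  # -> "folder"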
View File

@@ -1950,6 +1950,16 @@ class ScheduleDailyCreateSerializer(serializers.Serializer):
class BaseWriteIntegrationSerializer(BaseWriteSerializer):
def validate(self, attrs):
if Integration.objects.filter(
configuration=attrs.get("configuration")
).exists():
raise serializers.ValidationError(
{"name": "This integration already exists."}
)
return super().validate(attrs)
@staticmethod
def validate_integration_data(
integration_type: str,
@@ -1957,7 +1967,7 @@ class BaseWriteIntegrationSerializer(BaseWriteSerializer):
configuration: dict,
credentials: dict,
):
if integration_type == Integration.IntegrationChoices.S3:
if integration_type == Integration.IntegrationChoices.AMAZON_S3:
config_serializer = S3ConfigSerializer
credentials_serializers = [AWSCredentialSerializer]
# TODO: This will be required for AWS Security Hub
@@ -1975,7 +1985,11 @@ class BaseWriteIntegrationSerializer(BaseWriteSerializer):
}
)
config_serializer(data=configuration).is_valid(raise_exception=True)
serializer_instance = config_serializer(data=configuration)
serializer_instance.is_valid(raise_exception=True)
# Apply the validated (and potentially transformed) data back to configuration
configuration.update(serializer_instance.validated_data)
for cred_serializer in credentials_serializers:
try:
@@ -2059,6 +2073,7 @@ class IntegrationCreateSerializer(BaseWriteIntegrationSerializer):
}
def validate(self, attrs):
super().validate(attrs)
integration_type = attrs.get("integration_type")
providers = attrs.get("providers")
configuration = attrs.get("configuration")
@@ -2118,6 +2133,7 @@ class IntegrationUpdateSerializer(BaseWriteIntegrationSerializer):
}
def validate(self, attrs):
super().validate(attrs)
integration_type = self.instance.integration_type
providers = attrs.get("providers")
configuration = attrs.get("configuration") or self.instance.configuration

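For context (not part of the diff), a create request for an Amazon S3 integration would follow the JSON:API shape used in the test suite above. All values are placeholders, and the placement of the credentials attribute alongside configuration is an assumption.

payload = {
    "data": {
        "type": "integrations",
        "attributes": {
            "integration_type": "amazon_s3",
            "configuration": {
                "bucket_name": "bucket-name",
                "output_directory": "output-directory",
            },
            # Assumed: static AWS credentials validated by AWSCredentialSerializer.
            "credentials": {
                "aws_access_key_id": "AKIATEST",
                "aws_secret_access_key": "secret123",
            },
        },
    }
}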
View File

@@ -57,6 +57,7 @@ from tasks.beat import schedule_provider_scan
from tasks.jobs.export import get_s3_client
from tasks.tasks import (
backfill_scan_resource_summaries_task,
check_integration_connection_task,
check_lighthouse_connection_task,
check_provider_connection_task,
delete_provider_task,
@@ -3838,6 +3839,32 @@ class IntegrationViewSet(BaseRLSViewSet):
context["allowed_providers"] = self.allowed_providers
return context
@extend_schema(
tags=["Integration"],
summary="Check integration connection",
description="Try to verify integration connection",
request=None,
responses={202: OpenApiResponse(response=TaskSerializer)},
)
@action(detail=True, methods=["post"], url_name="connection")
def connection(self, request, pk=None):
get_object_or_404(Integration, pk=pk)
with transaction.atomic():
task = check_integration_connection_task.delay(
integration_id=pk, tenant_id=self.request.tenant_id
)
prowler_task = Task.objects.get(id=task.id)
serializer = TaskSerializer(prowler_task)
return Response(
data=serializer.data,
status=status.HTTP_202_ACCEPTED,
headers={
"Content-Location": reverse(
"task-detail", kwargs={"pk": prowler_task.id}
)
},
)
@extend_schema_view(
list=extend_schema(

View File

@@ -1065,7 +1065,7 @@ def integrations_fixture(providers_fixture):
enabled=True,
connected=True,
integration_type="amazon_s3",
configuration={"key": "value"},
configuration={"key": "value1"},
credentials={"psswd": "1234"},
)
IntegrationProviderRelationship.objects.create(

View File

@@ -3,8 +3,11 @@ from datetime import datetime, timezone
import openai
from celery.utils.log import get_task_logger
from api.models import LighthouseConfiguration, Provider
from api.utils import prowler_provider_connection_test
from api.models import Integration, LighthouseConfiguration, Provider
from api.utils import (
prowler_integration_connection_test,
prowler_provider_connection_test,
)
logger = get_task_logger(__name__)
@@ -83,3 +86,30 @@ def check_lighthouse_connection(lighthouse_config_id: str):
lighthouse_config.is_active = False
lighthouse_config.save()
return {"connected": False, "error": str(e), "available_models": []}
def check_integration_connection(integration_id: str):
"""
Business logic to check the connection status of an integration.
Args:
integration_id (str): The primary key of the Integration instance to check.
"""
integration = Integration.objects.get(pk=integration_id)
try:
result = prowler_integration_connection_test(integration)
except Exception as e:
logger.warning(
f"Unexpected exception checking {integration.integration_type} integration connection: {str(e)}"
)
raise e
# Update integration connection status
integration.connected = result.is_connected
integration.connection_last_checked_at = datetime.now(tz=timezone.utc)
integration.save()
return {
"connected": result.is_connected,
"error": str(result.error) if result.error else None,
}

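To make the return shape concrete (illustration only, not part of the diff), the snippet below builds the same dictionary from a Connection object; it assumes Connection defaults error to None, as the tests later in this commit suggest.

from prowler.providers.common.models import Connection

ok = Connection(is_connected=True)
print({"connected": ok.is_connected, "error": str(ok.error) if ok.error else None})
# -> {"connected": True, "error": None}

failed = Connection(is_connected=False, error=Exception("Bucket not found"))
print({"connected": failed.is_connected, "error": str(failed.error) if failed.error else None})
# -> {"connected": False, "error": "Bucket not found"}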
View File

@@ -171,7 +171,7 @@ def get_s3_client():
return s3_client
def _upload_to_s3(tenant_id: str, zip_path: str, scan_id: str) -> str:
def _upload_to_s3(tenant_id: str, zip_path: str, scan_id: str) -> str | None:
"""
Upload the specified ZIP file to an S3 bucket.
If the S3 bucket environment variables are not configured,
@@ -188,7 +188,7 @@ def _upload_to_s3(tenant_id: str, zip_path: str, scan_id: str) -> str:
"""
bucket = base.DJANGO_OUTPUT_S3_AWS_OUTPUT_BUCKET
if not bucket:
return None
return
try:
s3 = get_s3_client()

View File

@@ -0,0 +1,160 @@
import os
from glob import glob
from celery.utils.log import get_task_logger
from api.db_utils import rls_transaction
from api.models import Integration
from prowler.lib.outputs.asff.asff import ASFF
from prowler.lib.outputs.compliance.generic.generic import GenericCompliance
from prowler.lib.outputs.csv.csv import CSV
from prowler.lib.outputs.html.html import HTML
from prowler.lib.outputs.ocsf.ocsf import OCSF
from prowler.providers.aws.lib.s3.s3 import S3
from prowler.providers.common.models import Connection
logger = get_task_logger(__name__)
def get_s3_client_from_integration(
integration: Integration,
) -> tuple[bool, S3 | Connection]:
"""
Create and return a Prowler S3 output client using AWS credentials from an integration.
Args:
integration (Integration): The integration to get the S3 client from.
Returns:
tuple[bool, S3 | Connection]: A tuple containing a boolean indicating if the connection was successful and the S3 client or connection object.
"""
s3 = S3(
**integration.credentials,
bucket_name=integration.configuration["bucket_name"],
output_directory=integration.configuration["output_directory"],
)
connection = s3.test_connection(
**integration.credentials,
bucket_name=integration.configuration["bucket_name"],
)
if connection.is_connected:
return True, s3
return False, connection
def upload_s3_integration(
tenant_id: str, provider_id: str, output_directory: str
) -> bool:
"""
Upload the specified output files to an S3 bucket from an integration.
Reconstructs output objects from files in the output directory instead of using serialized data.
Args:
tenant_id (str): The tenant identifier, used as part of the S3 key prefix.
provider_id (str): The provider identifier, used as part of the S3 key prefix.
output_directory (str): Path to the directory containing output files.
Returns:
bool: True if all integrations executed successfully, False otherwise.
Raises:
botocore.exceptions.ClientError: If the upload attempt to S3 fails for any reason.
"""
logger.info(f"Processing S3 integrations for provider {provider_id}")
try:
with rls_transaction(tenant_id):
integrations = list(
Integration.objects.filter(
integrationproviderrelationship__provider_id=provider_id,
integration_type=Integration.IntegrationChoices.AMAZON_S3,
)
)
if not integrations:
logger.error(f"No S3 integrations found for provider {provider_id}")
return False
integration_executions = 0
for integration in integrations:
try:
connected, s3 = get_s3_client_from_integration(integration)
# Since many scans will be sent to the same S3 bucket, we append the
# scan-specific output directory to the integration's S3 output directory to avoid
# overwriting files and to keep track of the scan origin.
folder = os.getenv("OUTPUT_DIRECTORY", "/tmp/prowler_api_output")
s3._output_directory = (
f"{s3._output_directory}{output_directory.split(folder)[-1]}"
)
except Exception as e:
logger.error(
f"S3 connection failed for integration {integration.id}: {e}"
)
continue
if connected:
try:
# Reconstruct generated_outputs from files in output directory
# This approach scans the output directory for files and creates the appropriate
# output objects based on file extensions and naming patterns.
generated_outputs = {"regular": [], "compliance": []}
# Find and recreate regular outputs (CSV, HTML, OCSF)
output_file_patterns = {
".csv": CSV,
".html": HTML,
".ocsf.json": OCSF,
".asff.json": ASFF,
}
base_dir = os.path.dirname(output_directory)
for extension, output_class in output_file_patterns.items():
pattern = f"{output_directory}*{extension}"
for file_path in glob(pattern):
if os.path.exists(file_path):
output = output_class(findings=[], file_path=file_path)
output.create_file_descriptor(file_path)
generated_outputs["regular"].append(output)
# Find and recreate compliance outputs
compliance_pattern = os.path.join(base_dir, "compliance", "*.csv")
for file_path in glob(compliance_pattern):
if os.path.exists(file_path):
output = GenericCompliance(
findings=[],
compliance=None,
file_path=file_path,
file_extension=".csv",
)
output.create_file_descriptor(file_path)
generated_outputs["compliance"].append(output)
# Use send_to_bucket with recreated generated_outputs objects
s3.send_to_bucket(generated_outputs)
except Exception as e:
logger.error(
f"S3 upload failed for integration {integration.id}: {e}"
)
continue
integration_executions += 1
else:
integration.connected = False
integration.save()
logger.error(
f"S3 upload failed for integration {integration.id}: {s3.error}"
)
result = integration_executions == len(integrations)
if result:
logger.info(
f"All the S3 integrations completed successfully for provider {provider_id}"
)
else:
logger.error(f"Some S3 integrations failed for provider {provider_id}")
return result
except Exception as e:
logger.error(f"S3 integrations failed for provider {provider_id}: {str(e)}")
return False

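The key-prefix computation in upload_s3_integration above is easy to misread, so here is a standalone illustration (not part of the diff, all paths assumed) of how the scan-specific suffix is appended to the integration's configured output directory.

folder = "/tmp/prowler_api_output"                              # OUTPUT_DIRECTORY default
output_directory = "/tmp/prowler_api_output/tenant-id/scan-id"  # per-scan output path
suffix = output_directory.split(folder)[-1]                     # "/tenant-id/scan-id"
s3_prefix = "configured-prefix" + suffix                        # integration output_directory + suffix
print(s3_prefix)                                                # -> "configured-prefix/tenant-id/scan-id"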
View File

@@ -2,13 +2,17 @@ from datetime import datetime, timedelta, timezone
from pathlib import Path
from shutil import rmtree
from celery import chain, shared_task
from celery import chain, group, shared_task
from celery.utils.log import get_task_logger
from config.celery import RLSTask
from config.django.base import DJANGO_FINDINGS_BATCH_SIZE, DJANGO_TMP_OUTPUT_DIRECTORY
from django_celery_beat.models import PeriodicTask
from tasks.jobs.backfill import backfill_resource_scan_summaries
from tasks.jobs.connection import check_lighthouse_connection, check_provider_connection
from tasks.jobs.connection import (
check_integration_connection,
check_lighthouse_connection,
check_provider_connection,
)
from tasks.jobs.deletion import delete_provider, delete_tenant
from tasks.jobs.export import (
COMPLIANCE_CLASS_MAP,
@@ -17,6 +21,7 @@ from tasks.jobs.export import (
_generate_output_directory,
_upload_to_s3,
)
from tasks.jobs.integrations import upload_s3_integration
from tasks.jobs.scan import (
aggregate_findings,
create_compliance_requirements,
@@ -27,7 +32,7 @@ from tasks.utils import batched, get_next_execution_datetime
from api.compliance import get_compliance_frameworks
from api.db_utils import rls_transaction
from api.decorators import set_tenant
from api.models import Finding, Provider, Scan, ScanSummary, StateChoices
from api.models import Finding, Integration, Provider, Scan, ScanSummary, StateChoices
from api.utils import initialize_prowler_provider
from api.v1.serializers import ScanTaskSerializer
from prowler.lib.check.compliance_models import Compliance
@@ -54,6 +59,10 @@ def _perform_scan_complete_tasks(tenant_id: str, scan_id: str, provider_id: str)
generate_outputs_task.si(
scan_id=scan_id, provider_id=provider_id, tenant_id=tenant_id
),
check_integrations_task.si(
tenant_id=tenant_id,
provider_id=provider_id,
),
).apply_async()
@@ -74,6 +83,18 @@ def check_provider_connection_task(provider_id: str):
return check_provider_connection(provider_id=provider_id)
@shared_task(base=RLSTask, name="integration-connection-check")
@set_tenant
def check_integration_connection_task(integration_id: str):
"""
Task to check the connection status of an integration.
Args:
integration_id (str): The primary key of the Integration instance to check.
"""
return check_integration_connection(integration_id=integration_id)
@shared_task(
base=RLSTask, name="provider-deletion", queue="deletion", autoretry_for=(Exception,)
)
@@ -361,7 +382,33 @@ def generate_outputs_task(scan_id: str, provider_id: str, tenant_id: str):
compressed = _compress_output_files(out_dir)
upload_uri = _upload_to_s3(tenant_id, compressed, scan_id)
# S3 integrations (need output_directory)
with rls_transaction(tenant_id):
s3_integrations = Integration.objects.filter(
integrationproviderrelationship__provider_id=provider_id,
integration_type=Integration.IntegrationChoices.AMAZON_S3,
)
if s3_integrations:
# Pass the output directory path to S3 integration task to reconstruct objects from files
s3_integration_task.apply_async(
kwargs={
"tenant_id": tenant_id,
"provider_id": provider_id,
"output_directory": out_dir,
}
).get(
disable_sync_subtasks=False
) # TODO: This synchronous execution is NOT recommended
# We're forced to do this because we need the files to exist before deletion occurs.
# Once we have the periodic file cleanup task implemented, we should:
# 1. Remove this .get() call and make it fully async
# 2. For Cloud deployments, develop a secondary approach where outputs are stored
# directly in S3 and read from there, eliminating local file dependencies
if upload_uri:
# TODO: We need to create a new periodic task to delete the output files
# This task shouldn't be responsible for deleting the output files
try:
rmtree(Path(compressed).parent, ignore_errors=True)
except Exception as e:
@@ -372,7 +419,10 @@ def generate_outputs_task(scan_id: str, provider_id: str, tenant_id: str):
Scan.all_objects.filter(id=scan_id).update(output_location=final_location)
logger.info(f"Scan outputs at {final_location}")
return {"upload": did_upload}
return {
"upload": did_upload,
}
@shared_task(name="backfill-scan-resource-summaries", queue="backfill")
@@ -420,3 +470,72 @@ def check_lighthouse_connection_task(lighthouse_config_id: str, tenant_id: str =
- 'available_models' (list): List of available models if connection is successful.
"""
return check_lighthouse_connection(lighthouse_config_id=lighthouse_config_id)
@shared_task(name="integration-check")
def check_integrations_task(tenant_id: str, provider_id: str):
"""
Check and execute all configured integrations for a provider.
Args:
tenant_id (str): The tenant identifier
provider_id (str): The provider identifier
"""
logger.info(f"Checking integrations for provider {provider_id}")
try:
with rls_transaction(tenant_id):
integrations = Integration.objects.filter(
integrationproviderrelationship__provider_id=provider_id
)
if not integrations.exists():
logger.info(f"No integrations configured for provider {provider_id}")
return {"integrations_processed": 0}
integration_tasks = []
# TODO: Add other integration types here
# slack_integrations = integrations.filter(
# integration_type=Integration.IntegrationChoices.SLACK
# )
# if slack_integrations.exists():
# integration_tasks.append(
# slack_integration_task.s(
# tenant_id=tenant_id,
# provider_id=provider_id,
# )
# )
except Exception as e:
logger.error(f"Integration check failed for provider {provider_id}: {str(e)}")
return {"integrations_processed": 0, "error": str(e)}
# Execute all integration tasks in parallel if any were found
if integration_tasks:
job = group(integration_tasks)
job.apply_async()
logger.info(f"Launched {len(integration_tasks)} integration task(s)")
return {"integrations_processed": len(integration_tasks)}
@shared_task(
base=RLSTask,
name="integration-s3",
queue="integrations",
)
def s3_integration_task(
tenant_id: str,
provider_id: str,
output_directory: str,
):
"""
Process S3 integrations for a provider.
Args:
tenant_id (str): The tenant identifier
provider_id (str): The provider identifier
output_directory (str): Path to the directory containing output files
"""
return upload_s3_integration(tenant_id, provider_id, output_directory)

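As a side note (illustration only, not part of the diff), the Celery primitives wired together above behave as follows; the task bodies are placeholders and a configured broker is assumed.

from celery import chain, group, shared_task

@shared_task
def step_a(x):
    return x + 1

@shared_task
def step_b(y):
    return y * 2

# chain runs tasks sequentially; .si() builds an immutable signature that ignores the
# previous task's result, which is why _perform_scan_complete_tasks uses .si() per step.
chain(step_a.si(1), step_b.si(2)).apply_async()

# group fans tasks out in parallel, which is how check_integrations_task is meant to
# dispatch the per-integration tasks once they are appended to integration_tasks.
group(step_b.si(i) for i in range(3)).apply_async()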
View File

@@ -0,0 +1,422 @@
from unittest.mock import MagicMock, patch
import pytest
from tasks.jobs.integrations import (
get_s3_client_from_integration,
upload_s3_integration,
)
from api.models import Integration
from api.utils import prowler_integration_connection_test
from prowler.providers.common.models import Connection
@pytest.mark.django_db
class TestS3IntegrationUploads:
@patch("tasks.jobs.integrations.S3")
def test_get_s3_client_from_integration_success(self, mock_s3_class):
mock_integration = MagicMock()
mock_integration.credentials = {
"aws_access_key_id": "AKIA...",
"aws_secret_access_key": "SECRET",
}
mock_integration.configuration = {
"bucket_name": "test-bucket",
"output_directory": "test-prefix",
}
mock_s3 = MagicMock()
mock_connection = MagicMock()
mock_connection.is_connected = True
mock_s3.test_connection.return_value = mock_connection
mock_s3_class.return_value = mock_s3
connected, s3 = get_s3_client_from_integration(mock_integration)
assert connected is True
assert s3 == mock_s3
mock_s3_class.assert_called_once_with(
**mock_integration.credentials,
bucket_name="test-bucket",
output_directory="test-prefix",
)
mock_s3.test_connection.assert_called_once_with(
**mock_integration.credentials,
bucket_name="test-bucket",
)
@patch("tasks.jobs.integrations.S3")
def test_get_s3_client_from_integration_failure(self, mock_s3_class):
mock_integration = MagicMock()
mock_integration.credentials = {}
mock_integration.configuration = {
"bucket_name": "test-bucket",
"output_directory": "test-prefix",
}
from prowler.providers.common.models import Connection
mock_connection = Connection()
mock_connection.is_connected = False
mock_connection.error = Exception("test error")
mock_s3 = MagicMock()
mock_s3.test_connection.return_value = mock_connection
mock_s3_class.return_value = mock_s3
connected, connection = get_s3_client_from_integration(mock_integration)
assert connected is False
assert isinstance(connection, Connection)
assert str(connection.error) == "test error"
@patch("tasks.jobs.integrations.GenericCompliance")
@patch("tasks.jobs.integrations.ASFF")
@patch("tasks.jobs.integrations.OCSF")
@patch("tasks.jobs.integrations.HTML")
@patch("tasks.jobs.integrations.CSV")
@patch("tasks.jobs.integrations.glob")
@patch("tasks.jobs.integrations.get_s3_client_from_integration")
@patch("tasks.jobs.integrations.rls_transaction")
@patch("tasks.jobs.integrations.Integration")
def test_upload_s3_integration_uploads_serialized_outputs(
self,
mock_integration_model,
mock_rls,
mock_get_s3,
mock_glob,
mock_csv,
mock_html,
mock_ocsf,
mock_asff,
mock_compliance,
):
tenant_id = "tenant-id"
provider_id = "provider-id"
integration = MagicMock()
integration.id = "i-1"
integration.configuration = {
"bucket_name": "bucket",
"output_directory": "prefix",
}
mock_integration_model.objects.filter.return_value = [integration]
mock_s3 = MagicMock()
mock_get_s3.return_value = (True, mock_s3)
# Mock the output classes to return mock instances
mock_csv_instance = MagicMock()
mock_html_instance = MagicMock()
mock_ocsf_instance = MagicMock()
mock_asff_instance = MagicMock()
mock_compliance_instance = MagicMock()
mock_csv.return_value = mock_csv_instance
mock_html.return_value = mock_html_instance
mock_ocsf.return_value = mock_ocsf_instance
mock_asff.return_value = mock_asff_instance
mock_compliance.return_value = mock_compliance_instance
# Mock glob to return test files
output_directory = "/tmp/prowler_output/scan123"
mock_glob.side_effect = [
["/tmp/prowler_output/scan123.csv"],
["/tmp/prowler_output/scan123.html"],
["/tmp/prowler_output/scan123.ocsf.json"],
["/tmp/prowler_output/scan123.asff.json"],
["/tmp/prowler_output/compliance/compliance.csv"],
]
with patch("os.path.exists", return_value=True):
with patch("os.getenv", return_value="/tmp/prowler_api_output"):
result = upload_s3_integration(tenant_id, provider_id, output_directory)
assert result is True
mock_s3.send_to_bucket.assert_called_once()
@patch("tasks.jobs.integrations.get_s3_client_from_integration")
@patch("tasks.jobs.integrations.rls_transaction")
@patch("tasks.jobs.integrations.Integration")
@patch("tasks.jobs.integrations.logger")
def test_upload_s3_integration_fails_connection_logs_error(
self, mock_logger, mock_integration_model, mock_rls, mock_get_s3
):
tenant_id = "tenant-id"
provider_id = "provider-id"
integration = MagicMock()
integration.id = "i-1"
integration.connected = True
mock_s3_client = MagicMock()
mock_s3_client.error = "Connection failed"
mock_integration_model.objects.filter.return_value = [integration]
mock_get_s3.return_value = (False, mock_s3_client)
output_directory = "/tmp/prowler_output/scan123"
result = upload_s3_integration(tenant_id, provider_id, output_directory)
assert result is False
integration.save.assert_called_once()
assert integration.connected is False
mock_logger.error.assert_any_call(
"S3 upload failed for integration i-1: Connection failed"
)
@patch("tasks.jobs.integrations.rls_transaction")
@patch("tasks.jobs.integrations.Integration")
@patch("tasks.jobs.integrations.logger")
def test_upload_s3_integration_logs_if_no_integrations(
self, mock_logger, mock_integration_model, mock_rls
):
mock_integration_model.objects.filter.return_value = []
output_directory = "/tmp/prowler_output/scan123"
result = upload_s3_integration("tenant", "provider", output_directory)
assert result is False
mock_logger.error.assert_called_once_with(
"No S3 integrations found for provider provider"
)
@patch(
"tasks.jobs.integrations.get_s3_client_from_integration",
side_effect=Exception("failed"),
)
@patch("tasks.jobs.integrations.rls_transaction")
@patch("tasks.jobs.integrations.Integration")
@patch("tasks.jobs.integrations.logger")
def test_upload_s3_integration_logs_connection_exception_and_continues(
self, mock_logger, mock_integration_model, mock_rls, mock_get_s3
):
tenant_id = "tenant-id"
provider_id = "provider-id"
integration = MagicMock()
integration.id = "i-1"
integration.configuration = {
"bucket_name": "bucket",
"output_directory": "prefix",
}
mock_integration_model.objects.filter.return_value = [integration]
output_directory = "/tmp/prowler_output/scan123"
result = upload_s3_integration(tenant_id, provider_id, output_directory)
assert result is False
mock_logger.error.assert_any_call(
"S3 connection failed for integration i-1: failed"
)
def test_s3_integration_validates_and_normalizes_output_directory(self):
"""Test that S3 integration validation normalizes output_directory paths."""
from api.models import Integration
from api.v1.serializers import BaseWriteIntegrationSerializer
integration_type = Integration.IntegrationChoices.AMAZON_S3
providers = []
configuration = {
"bucket_name": "test-bucket",
"output_directory": "///////test", # This should be normalized
}
credentials = {
"aws_access_key_id": "AKIATEST",
"aws_secret_access_key": "secret123",
}
# Should not raise an exception and should normalize the path
BaseWriteIntegrationSerializer.validate_integration_data(
integration_type, providers, configuration, credentials
)
# Verify that the path was normalized
assert configuration["output_directory"] == "test"
def test_s3_integration_rejects_invalid_output_directory_characters(self):
"""Test that S3 integration validation rejects invalid characters."""
from rest_framework.exceptions import ValidationError
from api.models import Integration
from api.v1.serializers import BaseWriteIntegrationSerializer
integration_type = Integration.IntegrationChoices.AMAZON_S3
providers = []
configuration = {
"bucket_name": "test-bucket",
"output_directory": "test<invalid", # Contains invalid character
}
credentials = {
"aws_access_key_id": "AKIATEST",
"aws_secret_access_key": "secret123",
}
with pytest.raises(ValidationError) as exc_info:
BaseWriteIntegrationSerializer.validate_integration_data(
integration_type, providers, configuration, credentials
)
# Should contain validation error about invalid characters
assert "Output directory contains invalid characters" in str(exc_info.value)
def test_s3_integration_rejects_empty_output_directory(self):
"""Test that S3 integration validation rejects empty directories."""
from rest_framework.exceptions import ValidationError
from api.models import Integration
from api.v1.serializers import BaseWriteIntegrationSerializer
integration_type = Integration.IntegrationChoices.AMAZON_S3
providers = []
configuration = {
"bucket_name": "test-bucket",
"output_directory": "/////", # This becomes empty after normalization
}
credentials = {
"aws_access_key_id": "AKIATEST",
"aws_secret_access_key": "secret123",
}
with pytest.raises(ValidationError) as exc_info:
BaseWriteIntegrationSerializer.validate_integration_data(
integration_type, providers, configuration, credentials
)
# Should contain validation error about empty directory
assert "Output directory cannot be empty" in str(exc_info.value)
def test_s3_integration_normalizes_complex_paths(self):
"""Test that S3 integration validation handles complex path normalization."""
from api.models import Integration
from api.v1.serializers import BaseWriteIntegrationSerializer
integration_type = Integration.IntegrationChoices.AMAZON_S3
providers = []
configuration = {
"bucket_name": "test-bucket",
"output_directory": "//test//folder///subfolder//",
}
credentials = {
"aws_access_key_id": "AKIATEST",
"aws_secret_access_key": "secret123",
}
BaseWriteIntegrationSerializer.validate_integration_data(
integration_type, providers, configuration, credentials
)
# Verify complex path normalization
assert configuration["output_directory"] == "test/folder/subfolder"
@pytest.mark.django_db
class TestProwlerIntegrationConnectionTest:
@patch("api.utils.S3")
def test_s3_integration_connection_success(self, mock_s3_class):
"""Test successful S3 integration connection."""
integration = MagicMock()
integration.integration_type = Integration.IntegrationChoices.AMAZON_S3
integration.credentials = {
"aws_access_key_id": "AKIA...",
"aws_secret_access_key": "SECRET",
}
integration.configuration = {"bucket_name": "test-bucket"}
mock_connection = Connection(is_connected=True)
mock_s3_class.test_connection.return_value = mock_connection
result = prowler_integration_connection_test(integration)
assert result.is_connected is True
mock_s3_class.test_connection.assert_called_once_with(
**integration.credentials,
bucket_name="test-bucket",
raise_on_exception=False,
)
@patch("api.utils.S3")
def test_aws_provider_exception_handling(self, mock_s3_class):
"""Test S3 connection exception is properly caught and returned."""
integration = MagicMock()
integration.integration_type = Integration.IntegrationChoices.AMAZON_S3
integration.credentials = {
"aws_access_key_id": "invalid",
"aws_secret_access_key": "credentials",
}
integration.configuration = {"bucket_name": "test-bucket"}
test_exception = Exception("Invalid credentials")
mock_connection = Connection(is_connected=False, error=test_exception)
mock_s3_class.test_connection.return_value = mock_connection
result = prowler_integration_connection_test(integration)
assert result.is_connected is False
assert result.error == test_exception
mock_s3_class.test_connection.assert_called_once_with(
aws_access_key_id="invalid",
aws_secret_access_key="credentials",
bucket_name="test-bucket",
raise_on_exception=False,
)
@patch("api.utils.AwsProvider")
@patch("api.utils.S3")
def test_s3_integration_connection_failure(self, mock_s3_class, mock_aws_provider):
"""Test S3 integration connection failure."""
integration = MagicMock()
integration.integration_type = Integration.IntegrationChoices.AMAZON_S3
integration.credentials = {
"aws_access_key_id": "AKIA...",
"aws_secret_access_key": "SECRET",
}
integration.configuration = {"bucket_name": "test-bucket"}
mock_session = MagicMock()
mock_aws_provider.return_value.session.current_session = mock_session
mock_connection = Connection(
is_connected=False, error=Exception("Bucket not found")
)
mock_s3_class.test_connection.return_value = mock_connection
result = prowler_integration_connection_test(integration)
assert result.is_connected is False
assert str(result.error) == "Bucket not found"
@patch("api.utils.AwsProvider")
@patch("api.utils.S3")
def test_aws_security_hub_integration_connection(
self, mock_s3_class, mock_aws_provider
):
"""Test AWS Security Hub integration only validates AWS session."""
integration = MagicMock()
integration.integration_type = Integration.IntegrationChoices.AWS_SECURITY_HUB
integration.credentials = {
"aws_access_key_id": "AKIA...",
"aws_secret_access_key": "SECRET",
}
integration.configuration = {"region": "us-east-1"}
mock_session = MagicMock()
mock_aws_provider.return_value.session.current_session = mock_session
# For AWS Security Hub, the function should return early after AWS session validation
result = prowler_integration_connection_test(integration)
# The function should not reach S3 test_connection for AWS_SECURITY_HUB
mock_s3_class.test_connection.assert_not_called()
# Since no exception was raised during AWS session creation, return None (success)
assert result is None
def test_unsupported_integration_type(self):
"""Test unsupported integration type raises ValueError."""
integration = MagicMock()
integration.integration_type = "UNSUPPORTED_TYPE"
integration.credentials = {}
integration.configuration = {}
with pytest.raises(
ValueError, match="Integration type UNSUPPORTED_TYPE not supported"
):
prowler_integration_connection_test(integration)

View File

@@ -1,9 +1,13 @@
import uuid
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from tasks.tasks import _perform_scan_complete_tasks, generate_outputs_task
from tasks.tasks import (
_perform_scan_complete_tasks,
check_integrations_task,
generate_outputs_task,
s3_integration_task,
)
# TODO Move this to outputs/reports jobs
@@ -27,7 +31,6 @@ class TestGenerateOutputs:
assert result == {"upload": False}
mock_filter.assert_called_once_with(scan_id=self.scan_id)
@patch("tasks.tasks.rmtree")
@patch("tasks.tasks._upload_to_s3")
@patch("tasks.tasks._compress_output_files")
@patch("tasks.tasks.get_compliance_frameworks")
@@ -46,7 +49,6 @@ class TestGenerateOutputs:
mock_get_available_frameworks,
mock_compress,
mock_upload,
mock_rmtree,
):
mock_scan_summary_filter.return_value.exists.return_value = True
@@ -96,6 +98,7 @@ class TestGenerateOutputs:
return_value=("out-dir", "comp-dir"),
),
patch("tasks.tasks.Scan.all_objects.filter") as mock_scan_update,
patch("tasks.tasks.rmtree"),
):
mock_compress.return_value = "/tmp/zipped.zip"
mock_upload.return_value = "s3://bucket/zipped.zip"
@@ -110,9 +113,6 @@ class TestGenerateOutputs:
mock_scan_update.return_value.update.assert_called_once_with(
output_location="s3://bucket/zipped.zip"
)
mock_rmtree.assert_called_once_with(
Path("/tmp/zipped.zip").parent, ignore_errors=True
)
def test_generate_outputs_fails_upload(self):
with (
@@ -144,6 +144,7 @@ class TestGenerateOutputs:
patch("tasks.tasks._compress_output_files", return_value="/tmp/compressed"),
patch("tasks.tasks._upload_to_s3", return_value=None),
patch("tasks.tasks.Scan.all_objects.filter") as mock_scan_update,
patch("tasks.tasks.rmtree"),
):
mock_filter.return_value.exists.return_value = True
mock_findings.return_value.order_by.return_value.iterator.return_value = [
@@ -153,7 +154,7 @@ class TestGenerateOutputs:
result = generate_outputs_task(
scan_id="scan",
provider_id="provider",
provider_id=self.provider_id,
tenant_id=self.tenant_id,
)
@@ -185,6 +186,7 @@ class TestGenerateOutputs:
patch("tasks.tasks._compress_output_files", return_value="/tmp/compressed"),
patch("tasks.tasks._upload_to_s3", return_value="s3://bucket/f.zip"),
patch("tasks.tasks.Scan.all_objects.filter"),
patch("tasks.tasks.rmtree"),
):
mock_filter.return_value.exists.return_value = True
mock_findings.return_value.order_by.return_value.iterator.return_value = [
@@ -255,8 +257,8 @@ class TestGenerateOutputs:
),
patch("tasks.tasks._compress_output_files", return_value="outdir.zip"),
patch("tasks.tasks._upload_to_s3", return_value="s3://bucket/outdir.zip"),
patch("tasks.tasks.rmtree"),
patch("tasks.tasks.Scan.all_objects.filter"),
patch("tasks.tasks.rmtree"),
patch(
"tasks.tasks.batched",
return_value=[
@@ -333,13 +335,13 @@ class TestGenerateOutputs:
),
patch("tasks.tasks._compress_output_files", return_value="outdir.zip"),
patch("tasks.tasks._upload_to_s3", return_value="s3://bucket/outdir.zip"),
patch("tasks.tasks.rmtree"),
patch(
"tasks.tasks.Scan.all_objects.filter",
return_value=MagicMock(update=lambda **kw: None),
),
patch("tasks.tasks.batched", return_value=two_batches),
patch("tasks.tasks.OUTPUT_FORMATS_MAPPING", {}),
patch("tasks.tasks.rmtree"),
patch(
"tasks.tasks.COMPLIANCE_CLASS_MAP",
{"aws": [(lambda name: True, TrackingComplianceWriter)]},
@@ -358,6 +360,7 @@ class TestGenerateOutputs:
assert writer.transform_calls == [([raw2], compliance_obj, "cis")]
assert result == {"upload": True}
# TODO: We need to add a periodic task to delete old output files
def test_generate_outputs_logs_rmtree_exception(self, caplog):
mock_finding_output = MagicMock()
mock_finding_output.compliance = {"cis": ["requirement-1", "requirement-2"]}
@@ -436,3 +439,88 @@ class TestScanCompleteTasks:
provider_id="provider-id",
tenant_id="tenant-id",
)
@pytest.mark.django_db
class TestCheckIntegrationsTask:
def setup_method(self):
self.scan_id = str(uuid.uuid4())
self.provider_id = str(uuid.uuid4())
self.tenant_id = str(uuid.uuid4())
self.output_directory = "/tmp/some-output-dir"
@patch("tasks.tasks.rls_transaction")
@patch("tasks.tasks.Integration.objects.filter")
def test_check_integrations_no_integrations(
self, mock_integration_filter, mock_rls
):
mock_integration_filter.return_value.exists.return_value = False
# Ensure rls_transaction is mocked
mock_rls.return_value.__enter__.return_value = None
result = check_integrations_task(
tenant_id=self.tenant_id,
provider_id=self.provider_id,
)
assert result == {"integrations_processed": 0}
mock_integration_filter.assert_called_once_with(
integrationproviderrelationship__provider_id=self.provider_id
)
@patch("tasks.tasks.group")
@patch("tasks.tasks.rls_transaction")
@patch("tasks.tasks.Integration.objects.filter")
def test_check_integrations_s3_success(
self, mock_integration_filter, mock_rls, mock_group
):
# Mock that we have some integrations
mock_integration_filter.return_value.exists.return_value = True
# Ensure rls_transaction is mocked
mock_rls.return_value.__enter__.return_value = None
# Since the current implementation doesn't actually create tasks yet (TODO comment),
# we test that no tasks are created but the function returns the correct count
result = check_integrations_task(
tenant_id=self.tenant_id,
provider_id=self.provider_id,
)
assert result == {"integrations_processed": 0}
mock_integration_filter.assert_called_once_with(
integrationproviderrelationship__provider_id=self.provider_id
)
# group should not be called since no integration tasks are created yet
mock_group.assert_not_called()
@patch("tasks.tasks.upload_s3_integration")
def test_s3_integration_task_success(self, mock_upload):
mock_upload.return_value = True
output_directory = "/tmp/prowler_api_output/test"
result = s3_integration_task(
tenant_id=self.tenant_id,
provider_id=self.provider_id,
output_directory=output_directory,
)
assert result is True
mock_upload.assert_called_once_with(
self.tenant_id, self.provider_id, output_directory
)
@patch("tasks.tasks.upload_s3_integration")
def test_s3_integration_task_failure(self, mock_upload):
mock_upload.return_value = False
output_directory = "/tmp/prowler_api_output/test"
result = s3_integration_task(
tenant_id=self.tenant_id,
provider_id=self.provider_id,
output_directory=output_directory,
)
assert result is False
mock_upload.assert_called_once_with(
self.tenant_id, self.provider_id, output_directory
)

View File

@@ -16,7 +16,7 @@ services:
volumes:
- "./api/src/backend:/home/prowler/backend"
- "./api/pyproject.toml:/home/prowler/pyproject.toml"
- "/tmp/prowler_api_output:/tmp/prowler_api_output"
- "outputs:/tmp/prowler_api_output"
depends_on:
postgres:
condition: service_healthy
@@ -87,7 +87,7 @@ services:
- path: .env
required: false
volumes:
- "/tmp/prowler_api_output:/tmp/prowler_api_output"
- "outputs:/tmp/prowler_api_output"
depends_on:
valkey:
condition: service_healthy
@@ -115,3 +115,7 @@ services:
entrypoint:
- "../docker-entrypoint.sh"
- "beat"
volumes:
outputs:
driver: local

View File

@@ -8,7 +8,7 @@ services:
ports:
- "${DJANGO_PORT:-8080}:${DJANGO_PORT:-8080}"
volumes:
- "/tmp/prowler_api_output:/tmp/prowler_api_output"
- "output:/tmp/prowler_api_output"
depends_on:
postgres:
condition: service_healthy
@@ -68,7 +68,7 @@ services:
- path: .env
required: false
volumes:
- "/tmp/prowler_api_output:/tmp/prowler_api_output"
- "output:/tmp/prowler_api_output"
depends_on:
valkey:
condition: service_healthy
@@ -91,3 +91,7 @@ services:
entrypoint:
- "../docker-entrypoint.sh"
- "beat"
volumes:
output:
driver: local

View File

@@ -22,6 +22,8 @@ All notable changes to the **Prowler SDK** are documented in this file.
- False positives in SQS encryption check for ephemeral queues [(#8330)](https://github.com/prowler-cloud/prowler/pull/8330)
- Add protocol validation check in security group checks to ensure proper protocol matching [(#8374)](https://github.com/prowler-cloud/prowler/pull/8374)
- Add missing audit evidence for controls 1.1.4 and 2.5.5 for ISMS-P compliance. [(#8386)](https://github.com/prowler-cloud/prowler/pull/8386)
- Use the correct @staticmethod decorator for `set_identity` and `set_session_config` methods in AwsProvider [(#8056)](https://github.com/prowler-cloud/prowler/pull/8056)
- Use the correct default values for `role_session_name` and `session_duration` in AwsSetUpSession [(#8056)](https://github.com/prowler-cloud/prowler/pull/8056)
---

View File

@@ -469,8 +469,8 @@ class AwsProvider(Provider):
return profile_region
@staticmethod
def set_identity(
self,
caller_identity: AWSCallerIdentity,
profile: str,
regions: set,
@@ -991,7 +991,8 @@ class AwsProvider(Provider):
mfa_TOTP = input("Enter MFA code: ")
return AWSMFAInfo(arn=mfa_ARN, totp=mfa_TOTP)
def set_session_config(self, retries_max_attempts: int) -> Config:
@staticmethod
def set_session_config(retries_max_attempts: int) -> Config:
"""
set_session_config returns a botocore Config object with the Prowler user agent and the default retry configuration if no value is passed as an argument

View File

@@ -125,7 +125,10 @@ class S3:
retries_max_attempts=retries_max_attempts,
regions=regions,
)
self._session = aws_setup_session._session
self._session = aws_setup_session._session.current_session.client(
__class__.__name__.lower(),
config=aws_setup_session._session.session_config,
)
self._bucket_name = bucket_name
self._output_directory = output_directory
@@ -487,4 +490,4 @@ class S3:
except Exception as error:
if raise_on_exception:
raise S3TestConnectionError(original_exception=error)
return Connection(error=error)
return Connection(is_connected=False, error=error)

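For clarity (not part of the diff), the client name in the constructor above is derived from the class name, so __class__.__name__.lower() resolves to "s3". A tiny standalone demonstration:

class S3:
    def client_name(self) -> str:
        # Referencing __class__ inside a method binds it to the defining class.
        return __class__.__name__.lower()

print(S3().client_name())  # -> "s3"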
View File

@@ -6,6 +6,7 @@ from prowler.providers.aws.aws_provider import (
get_aws_region_for_sts,
parse_iam_credentials_arn,
)
from prowler.providers.aws.config import ROLE_SESSION_NAME
from prowler.providers.aws.models import (
AWSAssumeRoleConfiguration,
AWSAssumeRoleInfo,
@@ -31,9 +32,9 @@ class AwsSetUpSession:
def __init__(
self,
role_arn: str = None,
session_duration: int = None,
session_duration: int = 3600,
external_id: str = None,
role_session_name: str = None,
role_session_name: str = ROLE_SESSION_NAME,
mfa: bool = None,
profile: str = None,
aws_access_key_id: str = None,