feat(database): add db read replica support (#8869)

Author: Víctor Fernández Poyatos
Date: 2025-10-10 12:27:43 +02:00
Committed by: GitHub
Parent: 046baa8eb9
Commit: 335db928dc
14 changed files with 1267 additions and 170 deletions

.env

@@ -29,6 +29,12 @@ POSTGRES_ADMIN_PASSWORD=postgres
POSTGRES_USER=prowler
POSTGRES_PASSWORD=postgres
POSTGRES_DB=prowler_db
# Read replica settings (optional)
# POSTGRES_REPLICA_HOST=postgres-db
# POSTGRES_REPLICA_PORT=5432
# POSTGRES_REPLICA_USER=prowler
# POSTGRES_REPLICA_PASSWORD=postgres
# POSTGRES_REPLICA_DB=prowler_db
# Celery-Prowler task settings
TASK_RETRY_DELAY_SECONDS=0.1
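
A minimal sketch, mirroring the Django settings change later in this commit, of how each optional variable is expected to fall back to the primary connection value when it stays commented out (the env() helper and default_db_* names are the ones used in the settings modules below):

DATABASES["replica"] = {
    "ENGINE": "psqlextra.backend",
    "NAME": env("POSTGRES_REPLICA_DB", default=default_db_name),
    "USER": env("POSTGRES_REPLICA_USER", default=default_db_user),
    "PASSWORD": env("POSTGRES_REPLICA_PASSWORD", default=default_db_password),
    "HOST": env("POSTGRES_REPLICA_HOST", default=default_db_host),
    "PORT": env("POSTGRES_REPLICA_PORT", default=default_db_port),
}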


@@ -11,6 +11,7 @@ All notable changes to the **Prowler API** are documented in this file.
- API Key support [(#8805)](https://github.com/prowler-cloud/prowler/pull/8805)
- SAML role mapping protection for single-admin tenants to prevent accidental lockout [(#8882)](https://github.com/prowler-cloud/prowler/pull/8882)
- Support for `passed_findings` and `total_findings` fields in compliance requirement overview for accurate Prowler ThreatScore calculation [(#8582)](https://github.com/prowler-cloud/prowler/pull/8582)
- Database read replica support [(#8869)](https://github.com/prowler-cloud/prowler/pull/8869)
### Changed
- Now the MANAGE_ACCOUNT permission is required to modify or read user permissions instead of MANAGE_USERS [(#8281)](https://github.com/prowler-cloud/prowler/pull/8281)


@@ -1,13 +1,15 @@
from django.conf import settings
from django.core.exceptions import ObjectDoesNotExist
from django.db import transaction
from rest_framework import permissions
from rest_framework.exceptions import NotAuthenticated
from rest_framework.filters import SearchFilter
from rest_framework.permissions import SAFE_METHODS
from rest_framework_json_api import filters
from rest_framework_json_api.views import ModelViewSet
from api.authentication import CombinedJWTOrAPIKeyAuthentication
from api.db_router import MainRouter
from api.db_router import MainRouter, reset_read_db_alias, set_read_db_alias
from api.db_utils import POSTGRES_USER_VAR, rls_transaction
from api.filters import CustomDjangoFilterBackend
from api.models import Role, Tenant
@@ -31,6 +33,20 @@ class BaseViewSet(ModelViewSet):
ordering_fields = "__all__"
ordering = ["id"]
def _get_request_db_alias(self, request):
if request is None:
return MainRouter.default_db
read_alias = (
MainRouter.replica_db
if request.method in SAFE_METHODS
and MainRouter.replica_db in settings.DATABASES
else None
)
if read_alias:
return read_alias
return MainRouter.default_db
def initial(self, request, *args, **kwargs):
"""
Sets required_permissions before permissions are checked.
@@ -48,8 +64,21 @@ class BaseViewSet(ModelViewSet):
class BaseRLSViewSet(BaseViewSet):
def dispatch(self, request, *args, **kwargs):
with transaction.atomic():
return super().dispatch(request, *args, **kwargs)
self.db_alias = self._get_request_db_alias(request)
alias_token = None
try:
if self.db_alias != MainRouter.default_db:
alias_token = set_read_db_alias(self.db_alias)
if request is not None:
request.db_alias = self.db_alias
with transaction.atomic(using=self.db_alias):
return super().dispatch(request, *args, **kwargs)
finally:
if alias_token is not None:
reset_read_db_alias(alias_token)
self.db_alias = MainRouter.default_db
def initial(self, request, *args, **kwargs):
# Ideally, this logic would be in the `.setup()` method but DRF view sets don't call it
@@ -61,7 +90,9 @@ class BaseRLSViewSet(BaseViewSet):
if tenant_id is None:
raise NotAuthenticated("Tenant ID is not present in token")
with rls_transaction(tenant_id):
with rls_transaction(
tenant_id, using=getattr(self, "db_alias", MainRouter.default_db)
):
self.request.tenant_id = tenant_id
return super().initial(request, *args, **kwargs)
@@ -73,18 +104,33 @@ class BaseRLSViewSet(BaseViewSet):
class BaseTenantViewset(BaseViewSet):
def dispatch(self, request, *args, **kwargs):
with transaction.atomic():
tenant = super().dispatch(request, *args, **kwargs)
self.db_alias = self._get_request_db_alias(request)
alias_token = None
try:
# If the request is a POST, create the admin role
if request.method == "POST":
isinstance(tenant, dict) and self._create_admin_role(tenant.data["id"])
except Exception as e:
self._handle_creation_error(e, tenant)
raise
if self.db_alias != MainRouter.default_db:
alias_token = set_read_db_alias(self.db_alias)
return tenant
if request is not None:
request.db_alias = self.db_alias
with transaction.atomic(using=self.db_alias):
tenant = super().dispatch(request, *args, **kwargs)
try:
# If the request is a POST, create the admin role
if request.method == "POST":
isinstance(tenant, dict) and self._create_admin_role(
tenant.data["id"]
)
except Exception as e:
self._handle_creation_error(e, tenant)
raise
return tenant
finally:
if alias_token is not None:
reset_read_db_alias(alias_token)
self.db_alias = MainRouter.default_db
def _create_admin_role(self, tenant_id):
Role.objects.using(MainRouter.admin_db).create(
@@ -117,14 +163,31 @@ class BaseTenantViewset(BaseViewSet):
raise NotAuthenticated("Tenant ID is not present in token")
user_id = str(request.user.id)
with rls_transaction(value=user_id, parameter=POSTGRES_USER_VAR):
with rls_transaction(
value=user_id,
parameter=POSTGRES_USER_VAR,
using=getattr(self, "db_alias", MainRouter.default_db),
):
return super().initial(request, *args, **kwargs)
class BaseUserViewset(BaseViewSet):
def dispatch(self, request, *args, **kwargs):
with transaction.atomic():
return super().dispatch(request, *args, **kwargs)
self.db_alias = self._get_request_db_alias(request)
alias_token = None
try:
if self.db_alias != MainRouter.default_db:
alias_token = set_read_db_alias(self.db_alias)
if request is not None:
request.db_alias = self.db_alias
with transaction.atomic(using=self.db_alias):
return super().dispatch(request, *args, **kwargs)
finally:
if alias_token is not None:
reset_read_db_alias(alias_token)
self.db_alias = MainRouter.default_db
def initial(self, request, *args, **kwargs):
# TODO refactor after improving RLS on users
@@ -137,6 +200,8 @@ class BaseUserViewset(BaseViewSet):
if tenant_id is None:
raise NotAuthenticated("Tenant ID is not present in token")
with rls_transaction(tenant_id):
with rls_transaction(
tenant_id, using=getattr(self, "db_alias", MainRouter.default_db)
):
self.request.tenant_id = tenant_id
return super().initial(request, *args, **kwargs)
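
The dispatch overrides above all follow the same set-token / run-in-transaction / reset-token pattern. A condensed sketch of that flow, assuming a "replica" entry exists in settings.DATABASES (dispatch_with_read_routing and run_view are illustrative names, not part of the commit):

from django.db import transaction
from rest_framework.permissions import SAFE_METHODS

from api.db_router import MainRouter, reset_read_db_alias, set_read_db_alias


def dispatch_with_read_routing(request, run_view):
    # Safe methods (GET, HEAD, OPTIONS) read from the replica; writes stay on default.
    alias = (
        MainRouter.replica_db
        if request.method in SAFE_METHODS
        else MainRouter.default_db
    )
    token = set_read_db_alias(alias) if alias != MainRouter.default_db else None
    try:
        request.db_alias = alias
        with transaction.atomic(using=alias):
            return run_view(request)
    finally:
        # Always restore the previous routing state, even if the view raised.
        if token is not None:
            reset_read_db_alias(token)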


@@ -150,12 +150,16 @@ def generate_scan_compliance(
requirement["checks"][check_id] = status
requirement["checks_status"][status.lower()] += 1
if requirement["status"] != "FAIL" and any(
value == "FAIL" for value in requirement["checks"].values()
):
requirement["status"] = "FAIL"
compliance_overview[compliance_id]["requirements_status"]["passed"] -= 1
compliance_overview[compliance_id]["requirements_status"]["failed"] += 1
if requirement["status"] != "FAIL" and any(
value == "FAIL" for value in requirement["checks"].values()
):
requirement["status"] = "FAIL"
compliance_overview[compliance_id]["requirements_status"][
"passed"
] -= 1
compliance_overview[compliance_id]["requirements_status"][
"failed"
] += 1
def generate_compliance_overview_template(prowler_compliance: dict):


@@ -1,9 +1,31 @@
from contextvars import ContextVar
from django.conf import settings
ALLOWED_APPS = ("django", "socialaccount", "account", "authtoken", "silk")
_read_db_alias = ContextVar("read_db_alias", default=None)
def set_read_db_alias(alias: str | None):
if not alias:
return None
return _read_db_alias.set(alias)
def get_read_db_alias() -> str | None:
return _read_db_alias.get()
def reset_read_db_alias(token) -> None:
if token is not None:
_read_db_alias.reset(token)
class MainRouter:
default_db = "default"
admin_db = "admin"
replica_db = "replica"
def db_for_read(self, model, **hints): # noqa: F841
model_table_name = model._meta.db_table
@@ -11,6 +33,9 @@ class MainRouter:
model_table_name.startswith(f"{app}_") for app in ALLOWED_APPS
):
return self.admin_db
read_alias = get_read_db_alias()
if read_alias:
return read_alias
return None
def db_for_write(self, model, **hints): # noqa: F841
@@ -27,3 +52,8 @@ class MainRouter:
if {obj1._state.db, obj2._state.db} <= {self.default_db, self.admin_db}:
return True
return None
READ_REPLICA_ALIAS = (
MainRouter.replica_db if MainRouter.replica_db in settings.DATABASES else None
)
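
A usage sketch of the ContextVar helpers above: while a token is active, QuerySet.db (which consults MainRouter.db_for_read) resolves to the replica for application tables, and resetting the token restores default routing. This assumes MainRouter is registered in DATABASE_ROUTERS and a "replica" entry is configured; Tenant is used purely for illustration:

from api.db_router import get_read_db_alias, reset_read_db_alias, set_read_db_alias
from api.models import Tenant

token = set_read_db_alias("replica")
try:
    assert get_read_db_alias() == "replica"
    # db_for_read answers with the active alias for non-admin application tables.
    print(Tenant.objects.all().db)  # -> "replica"
finally:
    reset_read_db_alias(token)
    assert get_read_db_alias() is None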


@@ -6,12 +6,14 @@ from datetime import datetime, timedelta, timezone
from django.conf import settings
from django.contrib.auth.models import BaseUserManager
from django.db import connection, models, transaction
from django.db import DEFAULT_DB_ALIAS, connection, connections, models, transaction
from django_celery_beat.models import PeriodicTask
from psycopg2 import connect as psycopg2_connect
from psycopg2.extensions import AsIs, new_type, register_adapter, register_type
from rest_framework_json_api.serializers import ValidationError
from api.db_router import get_read_db_alias, reset_read_db_alias, set_read_db_alias
DB_USER = settings.DATABASES["default"]["USER"] if not settings.TESTING else "test"
DB_PASSWORD = (
settings.DATABASES["default"]["PASSWORD"] if not settings.TESTING else "test"
@@ -49,7 +51,11 @@ def psycopg_connection(database_alias: str):
@contextmanager
def rls_transaction(value: str, parameter: str = POSTGRES_TENANT_VAR):
def rls_transaction(
value: str,
parameter: str = POSTGRES_TENANT_VAR,
using: str | None = None,
):
"""
Creates a new database transaction setting the given configuration value for Postgres RLS. It validates
that the value is a valid UUID.
@@ -57,16 +63,32 @@ def rls_transaction(value: str, parameter: str = POSTGRES_TENANT_VAR):
Args:
value (str): Database configuration parameter value.
parameter (str): Database configuration parameter name, by default is 'api.tenant_id'.
using (str | None): Optional database alias to run the transaction against. Defaults to the
active read alias (if any) or Django's default connection.
"""
with transaction.atomic():
with connection.cursor() as cursor:
try:
# just in case the value is a UUID object
uuid.UUID(str(value))
except ValueError:
raise ValidationError("Must be a valid UUID")
cursor.execute(SET_CONFIG_QUERY, [parameter, value])
yield cursor
requested_alias = using or get_read_db_alias()
db_alias = requested_alias or DEFAULT_DB_ALIAS
if db_alias not in connections:
db_alias = DEFAULT_DB_ALIAS
router_token = None
try:
if db_alias != DEFAULT_DB_ALIAS:
router_token = set_read_db_alias(db_alias)
with transaction.atomic(using=db_alias):
conn = connections[db_alias]
with conn.cursor() as cursor:
try:
# just in case the value is a UUID object
uuid.UUID(str(value))
except ValueError:
raise ValidationError("Must be a valid UUID")
cursor.execute(SET_CONFIG_QUERY, [parameter, value])
yield cursor
finally:
if router_token is not None:
reset_read_db_alias(router_token)
class CustomUserManager(BaseUserManager):
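
With the using parameter in place, read-only work in views and Celery jobs can pin its RLS transaction to the replica. A minimal usage sketch (count_scan_findings is an illustrative helper; READ_REPLICA_ALIAS is None when no replica entry exists, in which case rls_transaction falls back to the default connection exactly as before):

from api.db_router import READ_REPLICA_ALIAS
from api.db_utils import rls_transaction
from api.models import Finding


def count_scan_findings(tenant_id: str, scan_id: str) -> int:
    # RLS-scoped, read-only query routed to the replica when one is configured.
    with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
        return Finding.objects.filter(tenant_id=tenant_id, scan_id=scan_id).count()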


@@ -11,11 +11,12 @@ class APIJSONRenderer(JSONRenderer):
def render(self, data, accepted_media_type=None, renderer_context=None):
request = renderer_context.get("request")
tenant_id = getattr(request, "tenant_id", None) if request else None
db_alias = getattr(request, "db_alias", None) if request else None
include_param_present = "include" in request.query_params if request else False
# Use rls_transaction if needed for included resources, otherwise do nothing
context_manager = (
rls_transaction(tenant_id)
rls_transaction(tenant_id, using=db_alias)
if tenant_id and include_param_present
else nullcontext()
)


@@ -3449,20 +3449,16 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
)
filtered_queryset = self.filter_queryset(self.get_queryset())
all_requirements = (
filtered_queryset.values(
"requirement_id",
"framework",
"version",
"description",
"passed_findings",
"total_findings",
)
.distinct()
.annotate(
total_instances=Count("id"),
manual_count=Count("id", filter=Q(requirement_status="MANUAL")),
)
all_requirements = filtered_queryset.values(
"requirement_id",
"framework",
"version",
"description",
).annotate(
total_instances=Count("id"),
manual_count=Count("id", filter=Q(requirement_status="MANUAL")),
passed_findings_sum=Sum("passed_findings"),
total_findings_sum=Sum("total_findings"),
)
passed_instances = (
@@ -3481,8 +3477,8 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
total_instances = requirement["total_instances"]
passed_count = passed_counts.get(requirement_id, 0)
is_manual = requirement["manual_count"] == total_instances
passed_findings = requirement["passed_findings"]
total_findings = requirement["total_findings"]
passed_findings = requirement["passed_findings_sum"] or 0
total_findings = requirement["total_findings_sum"] or 0
if is_manual:
requirement_status = "MANUAL"
elif passed_count == total_instances:


@@ -5,24 +5,39 @@ DEBUG = env.bool("DJANGO_DEBUG", default=True)
ALLOWED_HOSTS = env.list("DJANGO_ALLOWED_HOSTS", default=["*"])
# Database
default_db_name = env("POSTGRES_DB", default="prowler_db")
default_db_user = env("POSTGRES_USER", default="prowler_user")
default_db_password = env("POSTGRES_PASSWORD", default="prowler")
default_db_host = env("POSTGRES_HOST", default="postgres-db")
default_db_port = env("POSTGRES_PORT", default="5432")
DATABASES = {
"prowler_user": {
"ENGINE": "psqlextra.backend",
"NAME": env("POSTGRES_DB", default="prowler_db"),
"USER": env("POSTGRES_USER", default="prowler_user"),
"PASSWORD": env("POSTGRES_PASSWORD", default="prowler"),
"HOST": env("POSTGRES_HOST", default="postgres-db"),
"PORT": env("POSTGRES_PORT", default="5432"),
"NAME": default_db_name,
"USER": default_db_user,
"PASSWORD": default_db_password,
"HOST": default_db_host,
"PORT": default_db_port,
},
"admin": {
"ENGINE": "psqlextra.backend",
"NAME": env("POSTGRES_DB", default="prowler_db"),
"NAME": default_db_name,
"USER": env("POSTGRES_ADMIN_USER", default="prowler"),
"PASSWORD": env("POSTGRES_ADMIN_PASSWORD", default="S3cret"),
"HOST": env("POSTGRES_HOST", default="postgres-db"),
"PORT": env("POSTGRES_PORT", default="5432"),
"HOST": default_db_host,
"PORT": default_db_port,
},
"replica": {
"ENGINE": "psqlextra.backend",
"NAME": env("POSTGRES_REPLICA_DB", default=default_db_name),
"USER": env("POSTGRES_REPLICA_USER", default=default_db_user),
"PASSWORD": env("POSTGRES_REPLICA_PASSWORD", default=default_db_password),
"HOST": env("POSTGRES_REPLICA_HOST", default=default_db_host),
"PORT": env("POSTGRES_REPLICA_PORT", default=default_db_port),
},
}
DATABASES["default"] = DATABASES["prowler_user"]
REST_FRAMEWORK["DEFAULT_RENDERER_CLASSES"] = tuple( # noqa: F405


@@ -6,22 +6,37 @@ ALLOWED_HOSTS = env.list("DJANGO_ALLOWED_HOSTS", default=["localhost", "127.0.0.
# Database
# TODO Use Django database routers https://docs.djangoproject.com/en/5.0/topics/db/multi-db/#automatic-database-routing
default_db_name = env("POSTGRES_DB")
default_db_user = env("POSTGRES_USER")
default_db_password = env("POSTGRES_PASSWORD")
default_db_host = env("POSTGRES_HOST")
default_db_port = env("POSTGRES_PORT")
DATABASES = {
"prowler_user": {
"ENGINE": "django.db.backends.postgresql",
"NAME": env("POSTGRES_DB"),
"USER": env("POSTGRES_USER"),
"PASSWORD": env("POSTGRES_PASSWORD"),
"HOST": env("POSTGRES_HOST"),
"PORT": env("POSTGRES_PORT"),
"ENGINE": "psqlextra.backend",
"NAME": default_db_name,
"USER": default_db_user,
"PASSWORD": default_db_password,
"HOST": default_db_host,
"PORT": default_db_port,
},
"admin": {
"ENGINE": "psqlextra.backend",
"NAME": env("POSTGRES_DB"),
"NAME": default_db_name,
"USER": env("POSTGRES_ADMIN_USER"),
"PASSWORD": env("POSTGRES_ADMIN_PASSWORD"),
"HOST": env("POSTGRES_HOST"),
"PORT": env("POSTGRES_PORT"),
"HOST": default_db_host,
"PORT": default_db_port,
},
"replica": {
"ENGINE": "psqlextra.backend",
"NAME": env("POSTGRES_REPLICA_DB", default=default_db_name),
"USER": env("POSTGRES_REPLICA_USER", default=default_db_user),
"PASSWORD": env("POSTGRES_REPLICA_PASSWORD", default=default_db_password),
"HOST": env("POSTGRES_REPLICA_HOST", default=default_db_host),
"PORT": env("POSTGRES_REPLICA_PORT", default=default_db_port),
},
}
DATABASES["default"] = DATABASES["prowler_user"]


@@ -5,6 +5,7 @@ from celery.utils.log import get_task_logger
from config.django.base import DJANGO_FINDINGS_BATCH_SIZE
from tasks.utils import batched
from api.db_router import READ_REPLICA_ALIAS
from api.db_utils import rls_transaction
from api.models import Finding, Integration, Provider
from api.utils import initialize_prowler_integration, initialize_prowler_provider
@@ -289,7 +290,7 @@ def upload_security_hub_integration(
has_findings = False
batch_number = 0
with rls_transaction(tenant_id):
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
qs = (
Finding.all_objects.filter(tenant_id=tenant_id, scan_id=scan_id)
.order_by("uid")


@@ -1,8 +1,12 @@
import csv
import io
import json
import time
import uuid
from collections import defaultdict
from copy import deepcopy
from datetime import datetime, timezone
from typing import Any
from celery.utils.log import get_task_logger
from config.settings.celery import CELERY_DEADLOCK_ATTEMPTS
@@ -14,8 +18,11 @@ from api.compliance import (
PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE,
generate_scan_compliance,
)
from api.db_router import READ_REPLICA_ALIAS, MainRouter
from api.db_utils import (
create_objects_in_batches,
POSTGRES_TENANT_VAR,
SET_CONFIG_QUERY,
psycopg_connection,
rls_transaction,
update_objects_in_batches,
)
@@ -40,6 +47,28 @@ from prowler.lib.scan.scan import Scan as ProwlerScan
logger = get_task_logger(__name__)
# Column order must match `ComplianceRequirementOverview` schema in
# `api/models.py`. Keep this list minimal but sufficient to populate all
# non-nullable fields plus the counters we care about.
COMPLIANCE_REQUIREMENT_COPY_COLUMNS = (
"id",
"tenant_id",
"inserted_at",
"compliance_id",
"framework",
"version",
"description",
"region",
"requirement_id",
"requirement_status",
"passed_checks",
"failed_checks",
"total_checks",
"passed_findings",
"total_findings",
"scan_id",
)
def _create_finding_delta(
last_status: FindingStatus | None | str, new_status: FindingStatus | None
@@ -107,6 +136,124 @@ def _store_resources(
return resource_instance, (resource_instance.uid, resource_instance.region)
def _copy_compliance_requirement_rows(
tenant_id: str, rows: list[dict[str, Any]]
) -> None:
"""Stream compliance requirement rows into Postgres using COPY.
We leverage the admin connection (when available) to bypass the COPY + RLS
restriction, writing only the fields required by
``ComplianceRequirementOverview``.
Args:
tenant_id: Target tenant UUID.
rows: List of row dictionaries prepared by
:func:`create_compliance_requirements`.
"""
csv_buffer = io.StringIO()
writer = csv.writer(csv_buffer)
datetime_now = datetime.now(tz=timezone.utc)
for row in rows:
writer.writerow(
[
str(row.get("id")),
str(row.get("tenant_id")),
(row.get("inserted_at") or datetime_now).isoformat(),
row.get("compliance_id") or "",
row.get("framework") or "",
row.get("version") or "",
row.get("description") or "",
row.get("region") or "",
row.get("requirement_id") or "",
row.get("requirement_status") or "",
row.get("passed_checks", 0),
row.get("failed_checks", 0),
row.get("total_checks", 0),
row.get("passed_findings", 0),
row.get("total_findings", 0),
str(row.get("scan_id")),
]
)
csv_buffer.seek(0)
copy_sql = (
"COPY compliance_requirements_overviews ("
+ ", ".join(COMPLIANCE_REQUIREMENT_COPY_COLUMNS)
+ ") FROM STDIN WITH (FORMAT CSV, DELIMITER ',', QUOTE '\"', ESCAPE '\"', NULL '\\N')"
)
try:
with psycopg_connection(MainRouter.admin_db) as connection:
connection.autocommit = False
try:
with connection.cursor() as cursor:
cursor.execute(SET_CONFIG_QUERY, [POSTGRES_TENANT_VAR, tenant_id])
cursor.copy_expert(copy_sql, csv_buffer)
connection.commit()
except Exception:
connection.rollback()
raise
finally:
csv_buffer.close()
def _persist_compliance_requirement_rows(
tenant_id: str, rows: list[dict[str, Any]]
) -> None:
"""Persist compliance requirement rows using COPY with ORM fallback.
Args:
tenant_id: Target tenant UUID.
rows: Precomputed row dictionaries that reflect the compliance
overview state for a scan.
"""
if not rows:
return
try:
_copy_compliance_requirement_rows(tenant_id, rows)
except Exception as error:
logger.exception(
"COPY bulk insert for compliance requirements failed; falling back to ORM bulk_create",
exc_info=error,
)
fallback_objects = [
ComplianceRequirementOverview(
id=row["id"],
tenant_id=row["tenant_id"],
inserted_at=row["inserted_at"],
compliance_id=row["compliance_id"],
framework=row["framework"],
version=row["version"],
description=row["description"],
region=row["region"],
requirement_id=row["requirement_id"],
requirement_status=row["requirement_status"],
passed_checks=row["passed_checks"],
failed_checks=row["failed_checks"],
total_checks=row["total_checks"],
passed_findings=row.get("passed_findings", 0),
total_findings=row.get("total_findings", 0),
scan_id=row["scan_id"],
)
for row in rows
]
with rls_transaction(tenant_id):
ComplianceRequirementOverview.objects.bulk_create(
fallback_objects, batch_size=500
)
def _normalized_compliance_key(framework: str | None, version: str | None) -> str:
"""Return normalized identifier used to group compliance totals."""
normalized_framework = (framework or "").lower().replace("-", "").replace("_", "")
normalized_version = (version or "").lower().replace("-", "").replace("_", "")
return f"{normalized_framework}{normalized_version}"
def perform_prowler_scan(
tenant_id: str,
scan_id: str,
@@ -143,7 +290,7 @@ def perform_prowler_scan(
scan_instance.save()
# Find the mutelist processor if it exists
with rls_transaction(tenant_id):
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
try:
mutelist_processor = Processor.objects.get(
tenant_id=tenant_id, processor_type=Processor.ProcessorChoices.MUTELIST
@@ -272,7 +419,7 @@ def perform_prowler_scan(
unique_resources.add((resource_instance.uid, resource_instance.region))
# Process finding
with rls_transaction(tenant_id):
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
finding_uid = finding.uid
last_first_seen_at = None
if finding_uid not in last_status_cache:
@@ -305,6 +452,12 @@ def perform_prowler_scan(
# If the finding is muted at this time the reason must be the configured Mutelist
muted_reason = "Muted by mutelist" if finding.muted else None
# Increment failed_findings_count cache if the finding status is FAIL and not muted
if status == FindingStatus.FAIL and not finding.muted:
resource_uid = finding.resource_uid
resource_failed_findings_cache[resource_uid] += 1
with rls_transaction(tenant_id):
# Create the finding
finding_instance = Finding.objects.create(
tenant_id=tenant_id,
@@ -325,11 +478,6 @@ def perform_prowler_scan(
)
finding_instance.add_resources([resource_instance])
# Increment failed_findings_count cache if the finding status is FAIL and not muted
if status == FindingStatus.FAIL and not finding.muted:
resource_uid = finding.resource_uid
resource_failed_findings_cache[resource_uid] += 1
# Update scan resource summaries
scan_resource_cache.add(
(
@@ -439,7 +587,7 @@ def aggregate_findings(tenant_id: str, scan_id: str):
- muted_new: Muted findings with a delta of 'new'.
- muted_changed: Muted findings with a delta of 'changed'.
"""
with rls_transaction(tenant_id):
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
findings = Finding.objects.filter(tenant_id=tenant_id, scan_id=scan_id)
aggregation = findings.values(
@@ -582,11 +730,28 @@ def create_compliance_requirements(tenant_id: str, scan_id: str):
ValidationError: If tenant_id is not a valid UUID.
"""
try:
with rls_transaction(tenant_id):
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
scan_instance = Scan.objects.get(pk=scan_id)
provider_instance = scan_instance.provider
prowler_provider = return_prowler_provider(provider_instance)
compliance_template = PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE[
provider_instance.provider
]
modeled_threatscore_compliance_id = "ProwlerThreatScore-1.0"
threatscore_requirements_by_check: dict[str, set[str]] = {}
threatscore_framework = compliance_template.get(
modeled_threatscore_compliance_id
)
if threatscore_framework:
for requirement_id, requirement in threatscore_framework[
"requirements"
].items():
for check_id in requirement["checks"]:
threatscore_requirements_by_check.setdefault(check_id, set()).add(
requirement_id
)
# Get check status data by region from findings
findings = (
Finding.all_objects.filter(scan_id=scan_id, muted=False)
@@ -603,8 +768,7 @@ def create_compliance_requirements(tenant_id: str, scan_id: str):
findings_count_by_compliance = {}
check_status_by_region = {}
modeled_threatscore_compliance_id = "ProwlerThreatScore-1.0"
with rls_transaction(tenant_id):
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
for finding in findings:
for resource in finding.small_resources:
region = resource.region
@@ -640,11 +804,6 @@ def create_compliance_requirements(tenant_id: str, scan_id: str):
# If not available, use regions from findings
regions = set(check_status_by_region.keys())
# Get compliance template for the provider
compliance_template = PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE[
provider_instance.provider
]
# Create compliance data by region
compliance_overview_by_region = {
region: deepcopy(compliance_template) for region in regions
@@ -663,50 +822,53 @@ def create_compliance_requirements(tenant_id: str, scan_id: str):
status,
)
# Prepare compliance requirement objects
compliance_requirement_objects = []
# Prepare compliance requirement rows
compliance_requirement_rows: list[dict[str, Any]] = []
utc_datetime_now = datetime.now(tz=timezone.utc)
for region, compliance_data in compliance_overview_by_region.items():
for compliance_id, compliance in compliance_data.items():
modeled_framework = (
compliance["framework"].lower().replace("-", "").replace("_", "")
modeled_compliance_id = _normalized_compliance_key(
compliance["framework"], compliance["version"]
)
modeled_version = (
compliance["version"].lower().replace("-", "").replace("_", "")
)
modeled_compliance_id = f"{modeled_framework}{modeled_version}"
# Create an overview record for each requirement within each compliance framework
for requirement_id, requirement in compliance["requirements"].items():
compliance_requirement_objects.append(
ComplianceRequirementOverview(
tenant_id=tenant_id,
scan=scan_instance,
region=region,
compliance_id=compliance_id,
framework=compliance["framework"],
version=compliance["version"],
requirement_id=requirement_id,
description=requirement["description"],
passed_checks=requirement["checks_status"]["pass"],
failed_checks=requirement["checks_status"]["fail"],
total_checks=requirement["checks_status"]["total"],
requirement_status=requirement["status"],
passed_findings=findings_count_by_compliance.get(region, {})
checks_status = requirement["checks_status"]
compliance_requirement_rows.append(
{
"id": uuid.uuid4(),
"tenant_id": tenant_id,
"inserted_at": utc_datetime_now,
"compliance_id": compliance_id,
"framework": compliance["framework"],
"version": compliance["version"] or "",
"description": requirement.get("description") or "",
"region": region,
"requirement_id": requirement_id,
"requirement_status": requirement["status"],
"passed_checks": checks_status["pass"],
"failed_checks": checks_status["fail"],
"total_checks": checks_status["total"],
"scan_id": scan_instance.id,
"passed_findings": findings_count_by_compliance.get(
region, {}
)
.get(modeled_compliance_id, {})
.get(requirement_id, {})
.get("pass", 0),
total_findings=findings_count_by_compliance.get(region, {})
"total_findings": findings_count_by_compliance.get(
region, {}
)
.get(modeled_compliance_id, {})
.get(requirement_id, {})
.get("total", 0),
)
}
)
# Bulk create requirement records
create_objects_in_batches(
tenant_id, ComplianceRequirementOverview, compliance_requirement_objects
)
# Bulk create requirement records using PostgreSQL COPY
_persist_compliance_requirement_rows(tenant_id, compliance_requirement_rows)
return {
"requirements_created": len(compliance_requirement_objects),
"requirements_created": len(compliance_requirement_rows),
"regions_processed": list(regions),
"compliance_frameworks": (
list(compliance_overview_by_region.get(list(regions)[0], {}).keys())

View File

@@ -34,6 +34,7 @@ from tasks.jobs.scan import (
from tasks.utils import batched, get_next_execution_datetime
from api.compliance import get_compliance_frameworks
from api.db_router import READ_REPLICA_ALIAS
from api.db_utils import rls_transaction
from api.decorators import set_tenant
from api.models import Finding, Integration, Provider, Scan, ScanSummary, StateChoices
@@ -343,70 +344,73 @@ def generate_outputs_task(scan_id: str, provider_id: str, tenant_id: str):
.order_by("uid")
.iterator()
)
for batch, is_last in batched(qs, DJANGO_FINDINGS_BATCH_SIZE):
fos = [FindingOutput.transform_api_finding(f, prowler_provider) for f in batch]
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
for batch, is_last in batched(qs, DJANGO_FINDINGS_BATCH_SIZE):
fos = [
FindingOutput.transform_api_finding(f, prowler_provider) for f in batch
]
# Outputs
for mode, cfg in OUTPUT_FORMATS_MAPPING.items():
# Skip ASFF generation if not needed
if mode == "json-asff" and not generate_asff:
continue
# Outputs
for mode, cfg in OUTPUT_FORMATS_MAPPING.items():
# Skip ASFF generation if not needed
if mode == "json-asff" and not generate_asff:
continue
cls = cfg["class"]
suffix = cfg["suffix"]
extra = cfg.get("kwargs", {}).copy()
if mode == "html":
extra.update(provider=prowler_provider, stats=scan_summary)
cls = cfg["class"]
suffix = cfg["suffix"]
extra = cfg.get("kwargs", {}).copy()
if mode == "html":
extra.update(provider=prowler_provider, stats=scan_summary)
writer, initialization = get_writer(
output_writers,
cls,
lambda cls=cls, fos=fos, suffix=suffix: cls(
findings=fos,
file_path=out_dir,
file_extension=suffix,
from_cli=False,
),
is_last,
)
if not initialization:
writer.transform(fos)
writer.batch_write_data_to_file(**extra)
writer._data.clear()
writer, initialization = get_writer(
output_writers,
cls,
lambda cls=cls, fos=fos, suffix=suffix: cls(
findings=fos,
file_path=out_dir,
file_extension=suffix,
from_cli=False,
),
is_last,
)
if not initialization:
writer.transform(fos)
writer.batch_write_data_to_file(**extra)
writer._data.clear()
# Compliance CSVs
for name in frameworks_avail:
compliance_obj = frameworks_bulk[name]
# Compliance CSVs
for name in frameworks_avail:
compliance_obj = frameworks_bulk[name]
klass = GenericCompliance
for condition, cls in COMPLIANCE_CLASS_MAP.get(provider_type, []):
if condition(name):
klass = cls
break
klass = GenericCompliance
for condition, cls in COMPLIANCE_CLASS_MAP.get(provider_type, []):
if condition(name):
klass = cls
break
filename = f"{comp_dir}_{name}.csv"
filename = f"{comp_dir}_{name}.csv"
writer, initialization = get_writer(
compliance_writers,
name,
lambda klass=klass, fos=fos: klass(
findings=fos,
compliance=compliance_obj,
file_path=filename,
from_cli=False,
),
is_last,
)
if not initialization:
writer.transform(fos, compliance_obj, name)
writer.batch_write_data_to_file()
writer._data.clear()
writer, initialization = get_writer(
compliance_writers,
name,
lambda klass=klass, fos=fos: klass(
findings=fos,
compliance=compliance_obj,
file_path=filename,
from_cli=False,
),
is_last,
)
if not initialization:
writer.transform(fos, compliance_obj, name)
writer.batch_write_data_to_file()
writer._data.clear()
compressed = _compress_output_files(out_dir)
upload_uri = _upload_to_s3(tenant_id, compressed, scan_id)
# S3 integrations (need output_directory)
with rls_transaction(tenant_id):
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
s3_integrations = Integration.objects.filter(
integrationproviderrelationship__provider_id=provider_id,
integration_type=Integration.IntegrationChoices.AMAZON_S3,


@@ -1,17 +1,22 @@
import csv
import json
import uuid
from datetime import datetime
from datetime import datetime, timezone
from io import StringIO
from unittest.mock import MagicMock, patch
import pytest
from tasks.jobs.scan import (
_copy_compliance_requirement_rows,
_create_finding_delta,
_persist_compliance_requirement_rows,
_store_resources,
create_compliance_requirements,
perform_prowler_scan,
)
from tasks.utils import CustomEncoder
from api.db_router import MainRouter
from api.exceptions import ProviderConnectionError
from api.models import Finding, Provider, Resource, Scan, StateChoices, StatusChoices
from prowler.lib.check.models import Severity
@@ -1045,3 +1050,773 @@ class TestCreateComplianceRequirements:
assert "requirements_created" in result
assert result["requirements_created"] >= 0
class TestComplianceRequirementCopy:
@patch("tasks.jobs.scan.psycopg_connection")
def test_copy_compliance_requirement_rows_streams_csv(
self, mock_psycopg_connection, settings
):
settings.DATABASES.setdefault("admin", settings.DATABASES["default"])
connection = MagicMock()
cursor = MagicMock()
cursor_context = MagicMock()
cursor_context.__enter__.return_value = cursor
cursor_context.__exit__.return_value = False
connection.cursor.return_value = cursor_context
connection.__enter__.return_value = connection
connection.__exit__.return_value = False
context_manager = MagicMock()
context_manager.__enter__.return_value = connection
context_manager.__exit__.return_value = False
mock_psycopg_connection.return_value = context_manager
captured = {}
def copy_side_effect(sql, file_obj):
captured["sql"] = sql
captured["data"] = file_obj.read()
cursor.copy_expert.side_effect = copy_side_effect
row = {
"id": uuid.uuid4(),
"tenant_id": str(uuid.uuid4()),
"compliance_id": "cisa_aws",
"framework": "CISA",
"version": None,
"description": "desc",
"region": "us-east-1",
"requirement_id": "req-1",
"requirement_status": "PASS",
"passed_checks": 1,
"failed_checks": 0,
"total_checks": 1,
"scan_id": uuid.uuid4(),
}
with patch.object(MainRouter, "admin_db", "admin"):
_copy_compliance_requirement_rows(str(row["tenant_id"]), [row])
mock_psycopg_connection.assert_called_once_with("admin")
connection.cursor.assert_called_once()
cursor.execute.assert_called_once()
cursor.copy_expert.assert_called_once()
csv_rows = list(csv.reader(StringIO(captured["data"])))
assert csv_rows[0][0] == str(row["id"])
assert csv_rows[0][5] == ""
assert csv_rows[0][-1] == str(row["scan_id"])
@patch("tasks.jobs.scan.ComplianceRequirementOverview.objects.bulk_create")
@patch("tasks.jobs.scan.rls_transaction")
@patch(
"tasks.jobs.scan._copy_compliance_requirement_rows",
side_effect=Exception("copy failed"),
)
def test_persist_compliance_requirement_rows_fallback(
self, mock_copy, mock_rls_transaction, mock_bulk_create
):
inserted_at = datetime.now(timezone.utc)
row = {
"id": uuid.uuid4(),
"tenant_id": str(uuid.uuid4()),
"inserted_at": inserted_at,
"compliance_id": "cisa_aws",
"framework": "CISA",
"version": "1.0",
"description": "desc",
"region": "us-east-1",
"requirement_id": "req-1",
"requirement_status": "PASS",
"passed_checks": 1,
"failed_checks": 0,
"total_checks": 1,
"scan_id": uuid.uuid4(),
}
tenant_id = row["tenant_id"]
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
_persist_compliance_requirement_rows(tenant_id, [row])
mock_copy.assert_called_once_with(tenant_id, [row])
mock_rls_transaction.assert_called_once_with(tenant_id)
mock_bulk_create.assert_called_once()
args, kwargs = mock_bulk_create.call_args
objects = args[0]
assert len(objects) == 1
fallback = objects[0]
assert fallback.version == row["version"]
assert fallback.compliance_id == row["compliance_id"]
@patch("tasks.jobs.scan.ComplianceRequirementOverview.objects.bulk_create")
@patch("tasks.jobs.scan.rls_transaction")
@patch("tasks.jobs.scan._copy_compliance_requirement_rows")
def test_persist_compliance_requirement_rows_no_rows(
self, mock_copy, mock_rls_transaction, mock_bulk_create
):
_persist_compliance_requirement_rows(str(uuid.uuid4()), [])
mock_copy.assert_not_called()
mock_rls_transaction.assert_not_called()
mock_bulk_create.assert_not_called()
@patch("tasks.jobs.scan.psycopg_connection")
def test_copy_compliance_requirement_rows_multiple_rows(
self, mock_psycopg_connection, settings
):
"""Test COPY with multiple rows to ensure batch processing works correctly."""
settings.DATABASES.setdefault("admin", settings.DATABASES["default"])
connection = MagicMock()
cursor = MagicMock()
cursor_context = MagicMock()
cursor_context.__enter__.return_value = cursor
cursor_context.__exit__.return_value = False
connection.cursor.return_value = cursor_context
connection.__enter__.return_value = connection
connection.__exit__.return_value = False
context_manager = MagicMock()
context_manager.__enter__.return_value = connection
context_manager.__exit__.return_value = False
mock_psycopg_connection.return_value = context_manager
captured = {}
def copy_side_effect(sql, file_obj):
captured["sql"] = sql
captured["data"] = file_obj.read()
cursor.copy_expert.side_effect = copy_side_effect
tenant_id = str(uuid.uuid4())
scan_id = uuid.uuid4()
inserted_at = datetime.now(timezone.utc)
rows = [
{
"id": uuid.uuid4(),
"tenant_id": tenant_id,
"inserted_at": inserted_at,
"compliance_id": "cisa_aws",
"framework": "CISA",
"version": "1.0",
"description": "First requirement",
"region": "us-east-1",
"requirement_id": "req-1",
"requirement_status": "PASS",
"passed_checks": 5,
"failed_checks": 0,
"total_checks": 5,
"scan_id": scan_id,
},
{
"id": uuid.uuid4(),
"tenant_id": tenant_id,
"inserted_at": inserted_at,
"compliance_id": "cisa_aws",
"framework": "CISA",
"version": "1.0",
"description": "Second requirement",
"region": "us-west-2",
"requirement_id": "req-2",
"requirement_status": "FAIL",
"passed_checks": 3,
"failed_checks": 2,
"total_checks": 5,
"scan_id": scan_id,
},
{
"id": uuid.uuid4(),
"tenant_id": tenant_id,
"inserted_at": inserted_at,
"compliance_id": "aws_foundational_security_aws",
"framework": "AWS-Foundational-Security-Best-Practices",
"version": "2.0",
"description": "Third requirement",
"region": "eu-west-1",
"requirement_id": "req-3",
"requirement_status": "MANUAL",
"passed_checks": 0,
"failed_checks": 0,
"total_checks": 3,
"scan_id": scan_id,
},
]
with patch.object(MainRouter, "admin_db", "admin"):
_copy_compliance_requirement_rows(tenant_id, rows)
mock_psycopg_connection.assert_called_once_with("admin")
connection.cursor.assert_called_once()
cursor.execute.assert_called_once()
cursor.copy_expert.assert_called_once()
csv_rows = list(csv.reader(StringIO(captured["data"])))
assert len(csv_rows) == 3
# Validate first row
assert csv_rows[0][0] == str(rows[0]["id"])
assert csv_rows[0][1] == tenant_id
assert csv_rows[0][3] == "cisa_aws"
assert csv_rows[0][4] == "CISA"
assert csv_rows[0][6] == "First requirement"
assert csv_rows[0][7] == "us-east-1"
assert csv_rows[0][10] == "5"
assert csv_rows[0][11] == "0"
assert csv_rows[0][12] == "5"
# Validate second row
assert csv_rows[1][0] == str(rows[1]["id"])
assert csv_rows[1][7] == "us-west-2"
assert csv_rows[1][9] == "FAIL"
assert csv_rows[1][10] == "3"
assert csv_rows[1][11] == "2"
# Validate third row
assert csv_rows[2][0] == str(rows[2]["id"])
assert csv_rows[2][3] == "aws_foundational_security_aws"
assert csv_rows[2][5] == "2.0"
assert csv_rows[2][9] == "MANUAL"
@patch("tasks.jobs.scan.psycopg_connection")
def test_copy_compliance_requirement_rows_null_values(
self, mock_psycopg_connection, settings
):
"""Test COPY handles NULL/None values correctly in nullable fields."""
settings.DATABASES.setdefault("admin", settings.DATABASES["default"])
connection = MagicMock()
cursor = MagicMock()
cursor_context = MagicMock()
cursor_context.__enter__.return_value = cursor
cursor_context.__exit__.return_value = False
connection.cursor.return_value = cursor_context
connection.__enter__.return_value = connection
connection.__exit__.return_value = False
context_manager = MagicMock()
context_manager.__enter__.return_value = connection
context_manager.__exit__.return_value = False
mock_psycopg_connection.return_value = context_manager
captured = {}
def copy_side_effect(sql, file_obj):
captured["sql"] = sql
captured["data"] = file_obj.read()
cursor.copy_expert.side_effect = copy_side_effect
# Row with all nullable fields set to None/empty
row = {
"id": uuid.uuid4(),
"tenant_id": str(uuid.uuid4()),
"compliance_id": "test_framework",
"framework": "Test",
"version": None, # nullable
"description": None, # nullable
"region": "",
"requirement_id": "req-1",
"requirement_status": "PASS",
"passed_checks": 0,
"failed_checks": 0,
"total_checks": 0,
"scan_id": uuid.uuid4(),
}
with patch.object(MainRouter, "admin_db", "admin"):
_copy_compliance_requirement_rows(str(row["tenant_id"]), [row])
csv_rows = list(csv.reader(StringIO(captured["data"])))
assert len(csv_rows) == 1
# Validate that None values are converted to empty strings in CSV
assert csv_rows[0][5] == "" # version
assert csv_rows[0][6] == "" # description
@patch("tasks.jobs.scan.psycopg_connection")
def test_copy_compliance_requirement_rows_special_characters(
self, mock_psycopg_connection, settings
):
"""Test COPY correctly escapes special characters in CSV."""
settings.DATABASES.setdefault("admin", settings.DATABASES["default"])
connection = MagicMock()
cursor = MagicMock()
cursor_context = MagicMock()
cursor_context.__enter__.return_value = cursor
cursor_context.__exit__.return_value = False
connection.cursor.return_value = cursor_context
connection.__enter__.return_value = connection
connection.__exit__.return_value = False
context_manager = MagicMock()
context_manager.__enter__.return_value = connection
context_manager.__exit__.return_value = False
mock_psycopg_connection.return_value = context_manager
captured = {}
def copy_side_effect(sql, file_obj):
captured["sql"] = sql
captured["data"] = file_obj.read()
cursor.copy_expert.side_effect = copy_side_effect
# Row with special characters that need escaping
row = {
"id": uuid.uuid4(),
"tenant_id": str(uuid.uuid4()),
"compliance_id": 'framework"with"quotes',
"framework": "Framework,with,commas",
"version": "1.0",
"description": 'Description with "quotes", commas, and\nnewlines',
"region": "us-east-1",
"requirement_id": "req-1",
"requirement_status": "PASS",
"passed_checks": 1,
"failed_checks": 0,
"total_checks": 1,
"scan_id": uuid.uuid4(),
}
with patch.object(MainRouter, "admin_db", "admin"):
_copy_compliance_requirement_rows(str(row["tenant_id"]), [row])
# Verify CSV was generated (csv module handles escaping automatically)
csv_rows = list(csv.reader(StringIO(captured["data"])))
assert len(csv_rows) == 1
# Verify special characters are preserved after CSV parsing
assert csv_rows[0][3] == 'framework"with"quotes'
assert csv_rows[0][4] == "Framework,with,commas"
assert "quotes" in csv_rows[0][6]
assert "commas" in csv_rows[0][6]
@patch("tasks.jobs.scan.psycopg_connection")
def test_copy_compliance_requirement_rows_missing_inserted_at(
self, mock_psycopg_connection, settings
):
"""Test COPY uses current datetime when inserted_at is missing."""
settings.DATABASES.setdefault("admin", settings.DATABASES["default"])
connection = MagicMock()
cursor = MagicMock()
cursor_context = MagicMock()
cursor_context.__enter__.return_value = cursor
cursor_context.__exit__.return_value = False
connection.cursor.return_value = cursor_context
connection.__enter__.return_value = connection
connection.__exit__.return_value = False
context_manager = MagicMock()
context_manager.__enter__.return_value = connection
context_manager.__exit__.return_value = False
mock_psycopg_connection.return_value = context_manager
captured = {}
def copy_side_effect(sql, file_obj):
captured["sql"] = sql
captured["data"] = file_obj.read()
cursor.copy_expert.side_effect = copy_side_effect
# Row without inserted_at field
row = {
"id": uuid.uuid4(),
"tenant_id": str(uuid.uuid4()),
"compliance_id": "test_framework",
"framework": "Test",
"version": "1.0",
"description": "desc",
"region": "us-east-1",
"requirement_id": "req-1",
"requirement_status": "PASS",
"passed_checks": 1,
"failed_checks": 0,
"total_checks": 1,
"scan_id": uuid.uuid4(),
# Note: inserted_at is intentionally missing
}
before_call = datetime.now(timezone.utc)
with patch.object(MainRouter, "admin_db", "admin"):
_copy_compliance_requirement_rows(str(row["tenant_id"]), [row])
after_call = datetime.now(timezone.utc)
csv_rows = list(csv.reader(StringIO(captured["data"])))
assert len(csv_rows) == 1
# Verify inserted_at was auto-generated and is a valid ISO datetime
inserted_at_str = csv_rows[0][2]
inserted_at = datetime.fromisoformat(inserted_at_str)
assert before_call <= inserted_at <= after_call
@patch("tasks.jobs.scan.psycopg_connection")
def test_copy_compliance_requirement_rows_transaction_rollback_on_copy_error(
self, mock_psycopg_connection, settings
):
"""Test transaction is rolled back when copy_expert fails."""
settings.DATABASES.setdefault("admin", settings.DATABASES["default"])
connection = MagicMock()
cursor = MagicMock()
cursor_context = MagicMock()
cursor_context.__enter__.return_value = cursor
cursor_context.__exit__.return_value = False
connection.cursor.return_value = cursor_context
connection.__enter__.return_value = connection
connection.__exit__.return_value = False
context_manager = MagicMock()
context_manager.__enter__.return_value = connection
context_manager.__exit__.return_value = False
mock_psycopg_connection.return_value = context_manager
# Simulate copy_expert failure
cursor.copy_expert.side_effect = Exception("COPY command failed")
row = {
"id": uuid.uuid4(),
"tenant_id": str(uuid.uuid4()),
"compliance_id": "test",
"framework": "Test",
"version": "1.0",
"description": "desc",
"region": "us-east-1",
"requirement_id": "req-1",
"requirement_status": "PASS",
"passed_checks": 1,
"failed_checks": 0,
"total_checks": 1,
"scan_id": uuid.uuid4(),
}
with patch.object(MainRouter, "admin_db", "admin"):
with pytest.raises(Exception, match="COPY command failed"):
_copy_compliance_requirement_rows(str(row["tenant_id"]), [row])
# Verify rollback was called
connection.rollback.assert_called_once()
connection.commit.assert_not_called()
@patch("tasks.jobs.scan.psycopg_connection")
def test_copy_compliance_requirement_rows_transaction_rollback_on_set_config_error(
self, mock_psycopg_connection, settings
):
"""Test transaction is rolled back when SET_CONFIG fails."""
settings.DATABASES.setdefault("admin", settings.DATABASES["default"])
connection = MagicMock()
cursor = MagicMock()
cursor_context = MagicMock()
cursor_context.__enter__.return_value = cursor
cursor_context.__exit__.return_value = False
connection.cursor.return_value = cursor_context
connection.__enter__.return_value = connection
connection.__exit__.return_value = False
context_manager = MagicMock()
context_manager.__enter__.return_value = connection
context_manager.__exit__.return_value = False
mock_psycopg_connection.return_value = context_manager
# Simulate cursor.execute failure
cursor.execute.side_effect = Exception("SET prowler.tenant_id failed")
row = {
"id": uuid.uuid4(),
"tenant_id": str(uuid.uuid4()),
"compliance_id": "test",
"framework": "Test",
"version": "1.0",
"description": "desc",
"region": "us-east-1",
"requirement_id": "req-1",
"requirement_status": "PASS",
"passed_checks": 1,
"failed_checks": 0,
"total_checks": 1,
"scan_id": uuid.uuid4(),
}
with patch.object(MainRouter, "admin_db", "admin"):
with pytest.raises(Exception, match="SET prowler.tenant_id failed"):
_copy_compliance_requirement_rows(str(row["tenant_id"]), [row])
# Verify rollback was called
connection.rollback.assert_called_once()
connection.commit.assert_not_called()
@patch("tasks.jobs.scan.psycopg_connection")
def test_copy_compliance_requirement_rows_commit_on_success(
self, mock_psycopg_connection, settings
):
"""Test transaction is committed on successful COPY."""
settings.DATABASES.setdefault("admin", settings.DATABASES["default"])
connection = MagicMock()
cursor = MagicMock()
cursor_context = MagicMock()
cursor_context.__enter__.return_value = cursor
cursor_context.__exit__.return_value = False
connection.cursor.return_value = cursor_context
connection.__enter__.return_value = connection
connection.__exit__.return_value = False
context_manager = MagicMock()
context_manager.__enter__.return_value = connection
context_manager.__exit__.return_value = False
mock_psycopg_connection.return_value = context_manager
cursor.copy_expert.return_value = None # Success
row = {
"id": uuid.uuid4(),
"tenant_id": str(uuid.uuid4()),
"compliance_id": "test",
"framework": "Test",
"version": "1.0",
"description": "desc",
"region": "us-east-1",
"requirement_id": "req-1",
"requirement_status": "PASS",
"passed_checks": 1,
"failed_checks": 0,
"total_checks": 1,
"scan_id": uuid.uuid4(),
}
with patch.object(MainRouter, "admin_db", "admin"):
_copy_compliance_requirement_rows(str(row["tenant_id"]), [row])
# Verify commit was called and rollback was not
connection.commit.assert_called_once()
connection.rollback.assert_not_called()
# Verify autocommit was disabled
assert connection.autocommit is False
@patch("tasks.jobs.scan._copy_compliance_requirement_rows")
def test_persist_compliance_requirement_rows_success(self, mock_copy):
"""Test successful COPY path without fallback to ORM."""
mock_copy.return_value = None # Success, no exception
tenant_id = str(uuid.uuid4())
rows = [
{
"id": uuid.uuid4(),
"tenant_id": tenant_id,
"inserted_at": datetime.now(timezone.utc),
"compliance_id": "test",
"framework": "Test",
"version": "1.0",
"description": "desc",
"region": "us-east-1",
"requirement_id": "req-1",
"requirement_status": "PASS",
"passed_checks": 1,
"failed_checks": 0,
"total_checks": 1,
"scan_id": uuid.uuid4(),
}
]
_persist_compliance_requirement_rows(tenant_id, rows)
# Verify COPY was called
mock_copy.assert_called_once_with(tenant_id, rows)
@patch("tasks.jobs.scan.logger")
@patch("tasks.jobs.scan.ComplianceRequirementOverview.objects.bulk_create")
@patch("tasks.jobs.scan.rls_transaction")
@patch(
"tasks.jobs.scan._copy_compliance_requirement_rows",
side_effect=Exception("COPY failed"),
)
def test_persist_compliance_requirement_rows_fallback_logging(
self, mock_copy, mock_rls_transaction, mock_bulk_create, mock_logger
):
"""Test logger.exception is called when COPY fails and fallback occurs."""
tenant_id = str(uuid.uuid4())
row = {
"id": uuid.uuid4(),
"tenant_id": tenant_id,
"inserted_at": datetime.now(timezone.utc),
"compliance_id": "test",
"framework": "Test",
"version": "1.0",
"description": "desc",
"region": "us-east-1",
"requirement_id": "req-1",
"requirement_status": "PASS",
"passed_checks": 1,
"failed_checks": 0,
"total_checks": 1,
"scan_id": uuid.uuid4(),
}
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
_persist_compliance_requirement_rows(tenant_id, [row])
# Verify logger.exception was called
mock_logger.exception.assert_called_once()
args, kwargs = mock_logger.exception.call_args
assert "COPY bulk insert" in args[0]
assert "falling back to ORM" in args[0]
assert kwargs.get("exc_info") is not None
@patch("tasks.jobs.scan.ComplianceRequirementOverview.objects.bulk_create")
@patch("tasks.jobs.scan.rls_transaction")
@patch(
"tasks.jobs.scan._copy_compliance_requirement_rows",
side_effect=Exception("copy failed"),
)
def test_persist_compliance_requirement_rows_fallback_multiple_rows(
self, mock_copy, mock_rls_transaction, mock_bulk_create
):
"""Test ORM fallback with multiple rows."""
tenant_id = str(uuid.uuid4())
scan_id = uuid.uuid4()
inserted_at = datetime.now(timezone.utc)
rows = [
{
"id": uuid.uuid4(),
"tenant_id": tenant_id,
"inserted_at": inserted_at,
"compliance_id": "cisa_aws",
"framework": "CISA",
"version": "1.0",
"description": "First requirement",
"region": "us-east-1",
"requirement_id": "req-1",
"requirement_status": "PASS",
"passed_checks": 5,
"failed_checks": 0,
"total_checks": 5,
"scan_id": scan_id,
},
{
"id": uuid.uuid4(),
"tenant_id": tenant_id,
"inserted_at": inserted_at,
"compliance_id": "cisa_aws",
"framework": "CISA",
"version": "1.0",
"description": "Second requirement",
"region": "us-west-2",
"requirement_id": "req-2",
"requirement_status": "FAIL",
"passed_checks": 2,
"failed_checks": 3,
"total_checks": 5,
"scan_id": scan_id,
},
]
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
_persist_compliance_requirement_rows(tenant_id, rows)
mock_copy.assert_called_once_with(tenant_id, rows)
mock_rls_transaction.assert_called_once_with(tenant_id)
mock_bulk_create.assert_called_once()
args, kwargs = mock_bulk_create.call_args
objects = args[0]
assert len(objects) == 2
assert kwargs["batch_size"] == 500
# Validate first object
assert objects[0].id == rows[0]["id"]
assert objects[0].tenant_id == rows[0]["tenant_id"]
assert objects[0].compliance_id == rows[0]["compliance_id"]
assert objects[0].framework == rows[0]["framework"]
assert objects[0].region == rows[0]["region"]
assert objects[0].passed_checks == 5
assert objects[0].failed_checks == 0
# Validate second object
assert objects[1].id == rows[1]["id"]
assert objects[1].requirement_id == rows[1]["requirement_id"]
assert objects[1].requirement_status == rows[1]["requirement_status"]
assert objects[1].passed_checks == 2
assert objects[1].failed_checks == 3
@patch("tasks.jobs.scan.ComplianceRequirementOverview.objects.bulk_create")
@patch("tasks.jobs.scan.rls_transaction")
@patch(
"tasks.jobs.scan._copy_compliance_requirement_rows",
side_effect=Exception("copy failed"),
)
def test_persist_compliance_requirement_rows_fallback_all_fields(
self, mock_copy, mock_rls_transaction, mock_bulk_create
):
"""Test ORM fallback correctly maps all fields from row dict to model."""
tenant_id = str(uuid.uuid4())
row_id = uuid.uuid4()
scan_id = uuid.uuid4()
inserted_at = datetime.now(timezone.utc)
row = {
"id": row_id,
"tenant_id": tenant_id,
"inserted_at": inserted_at,
"compliance_id": "aws_foundational_security_aws",
"framework": "AWS-Foundational-Security-Best-Practices",
"version": "2.0",
"description": "Ensure MFA is enabled",
"region": "eu-west-1",
"requirement_id": "iam.1",
"requirement_status": "FAIL",
"passed_checks": 10,
"failed_checks": 5,
"total_checks": 15,
"scan_id": scan_id,
}
ctx = MagicMock()
ctx.__enter__.return_value = None
ctx.__exit__.return_value = False
mock_rls_transaction.return_value = ctx
_persist_compliance_requirement_rows(tenant_id, [row])
args, kwargs = mock_bulk_create.call_args
objects = args[0]
assert len(objects) == 1
obj = objects[0]
# Validate ALL fields are correctly mapped
assert obj.id == row_id
assert obj.tenant_id == tenant_id
assert obj.inserted_at == inserted_at
assert obj.compliance_id == "aws_foundational_security_aws"
assert obj.framework == "AWS-Foundational-Security-Best-Practices"
assert obj.version == "2.0"
assert obj.description == "Ensure MFA is enabled"
assert obj.region == "eu-west-1"
assert obj.requirement_id == "iam.1"
assert obj.requirement_status == "FAIL"
assert obj.passed_checks == 10
assert obj.failed_checks == 5
assert obj.total_checks == 15
assert obj.scan_id == scan_id