mirror of
https://github.com/prowler-cloud/prowler.git
synced 2025-12-19 05:17:47 +00:00
feat(database): add db read replica support (#8869)
This commit is contained in:
committed by
GitHub
parent
046baa8eb9
commit
335db928dc
6
.env
6
.env
@@ -29,6 +29,12 @@ POSTGRES_ADMIN_PASSWORD=postgres
|
||||
POSTGRES_USER=prowler
|
||||
POSTGRES_PASSWORD=postgres
|
||||
POSTGRES_DB=prowler_db
|
||||
# Read replica settings (optional)
|
||||
# POSTGRES_REPLICA_HOST=postgres-db
|
||||
# POSTGRES_REPLICA_PORT=5432
|
||||
# POSTGRES_REPLICA_USER=prowler
|
||||
# POSTGRES_REPLICA_PASSWORD=postgres
|
||||
# POSTGRES_REPLICA_DB=prowler_db
|
||||
|
||||
# Celery-Prowler task settings
|
||||
TASK_RETRY_DELAY_SECONDS=0.1
|
||||
|
||||
@@ -11,6 +11,7 @@ All notable changes to the **Prowler API** are documented in this file.
|
||||
- API Key support [(#8805)](https://github.com/prowler-cloud/prowler/pull/8805)
|
||||
- SAML role mapping protection for single-admin tenants to prevent accidental lockout [(#8882)](https://github.com/prowler-cloud/prowler/pull/8882)
|
||||
- Support for `passed_findings` and `total_findings` fields in compliance requirement overview for accurate Prowler ThreatScore calculation [(#8582)](https://github.com/prowler-cloud/prowler/pull/8582)
|
||||
- Database read replica support [(#8869)](https://github.com/prowler-cloud/prowler/pull/8869)
|
||||
|
||||
### Changed
|
||||
- Now the MANAGE_ACCOUNT permission is required to modify or read user permissions instead of MANAGE_USERS [(#8281)](https://github.com/prowler-cloud/prowler/pull/8281)
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from django.db import transaction
|
||||
from rest_framework import permissions
|
||||
from rest_framework.exceptions import NotAuthenticated
|
||||
from rest_framework.filters import SearchFilter
|
||||
from rest_framework.permissions import SAFE_METHODS
|
||||
from rest_framework_json_api import filters
|
||||
from rest_framework_json_api.views import ModelViewSet
|
||||
|
||||
from api.authentication import CombinedJWTOrAPIKeyAuthentication
|
||||
from api.db_router import MainRouter
|
||||
from api.db_router import MainRouter, reset_read_db_alias, set_read_db_alias
|
||||
from api.db_utils import POSTGRES_USER_VAR, rls_transaction
|
||||
from api.filters import CustomDjangoFilterBackend
|
||||
from api.models import Role, Tenant
|
||||
@@ -31,6 +33,20 @@ class BaseViewSet(ModelViewSet):
|
||||
ordering_fields = "__all__"
|
||||
ordering = ["id"]
|
||||
|
||||
def _get_request_db_alias(self, request):
|
||||
if request is None:
|
||||
return MainRouter.default_db
|
||||
|
||||
read_alias = (
|
||||
MainRouter.replica_db
|
||||
if request.method in SAFE_METHODS
|
||||
and MainRouter.replica_db in settings.DATABASES
|
||||
else None
|
||||
)
|
||||
if read_alias:
|
||||
return read_alias
|
||||
return MainRouter.default_db
|
||||
|
||||
def initial(self, request, *args, **kwargs):
|
||||
"""
|
||||
Sets required_permissions before permissions are checked.
|
||||
@@ -48,8 +64,21 @@ class BaseViewSet(ModelViewSet):
|
||||
|
||||
class BaseRLSViewSet(BaseViewSet):
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
with transaction.atomic():
|
||||
return super().dispatch(request, *args, **kwargs)
|
||||
self.db_alias = self._get_request_db_alias(request)
|
||||
alias_token = None
|
||||
try:
|
||||
if self.db_alias != MainRouter.default_db:
|
||||
alias_token = set_read_db_alias(self.db_alias)
|
||||
|
||||
if request is not None:
|
||||
request.db_alias = self.db_alias
|
||||
|
||||
with transaction.atomic(using=self.db_alias):
|
||||
return super().dispatch(request, *args, **kwargs)
|
||||
finally:
|
||||
if alias_token is not None:
|
||||
reset_read_db_alias(alias_token)
|
||||
self.db_alias = MainRouter.default_db
|
||||
|
||||
def initial(self, request, *args, **kwargs):
|
||||
# Ideally, this logic would be in the `.setup()` method but DRF view sets don't call it
|
||||
@@ -61,7 +90,9 @@ class BaseRLSViewSet(BaseViewSet):
|
||||
if tenant_id is None:
|
||||
raise NotAuthenticated("Tenant ID is not present in token")
|
||||
|
||||
with rls_transaction(tenant_id):
|
||||
with rls_transaction(
|
||||
tenant_id, using=getattr(self, "db_alias", MainRouter.default_db)
|
||||
):
|
||||
self.request.tenant_id = tenant_id
|
||||
return super().initial(request, *args, **kwargs)
|
||||
|
||||
@@ -73,18 +104,33 @@ class BaseRLSViewSet(BaseViewSet):
|
||||
|
||||
class BaseTenantViewset(BaseViewSet):
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
with transaction.atomic():
|
||||
tenant = super().dispatch(request, *args, **kwargs)
|
||||
|
||||
self.db_alias = self._get_request_db_alias(request)
|
||||
alias_token = None
|
||||
try:
|
||||
# If the request is a POST, create the admin role
|
||||
if request.method == "POST":
|
||||
isinstance(tenant, dict) and self._create_admin_role(tenant.data["id"])
|
||||
except Exception as e:
|
||||
self._handle_creation_error(e, tenant)
|
||||
raise
|
||||
if self.db_alias != MainRouter.default_db:
|
||||
alias_token = set_read_db_alias(self.db_alias)
|
||||
|
||||
return tenant
|
||||
if request is not None:
|
||||
request.db_alias = self.db_alias
|
||||
|
||||
with transaction.atomic(using=self.db_alias):
|
||||
tenant = super().dispatch(request, *args, **kwargs)
|
||||
|
||||
try:
|
||||
# If the request is a POST, create the admin role
|
||||
if request.method == "POST":
|
||||
isinstance(tenant, dict) and self._create_admin_role(
|
||||
tenant.data["id"]
|
||||
)
|
||||
except Exception as e:
|
||||
self._handle_creation_error(e, tenant)
|
||||
raise
|
||||
|
||||
return tenant
|
||||
finally:
|
||||
if alias_token is not None:
|
||||
reset_read_db_alias(alias_token)
|
||||
self.db_alias = MainRouter.default_db
|
||||
|
||||
def _create_admin_role(self, tenant_id):
|
||||
Role.objects.using(MainRouter.admin_db).create(
|
||||
@@ -117,14 +163,31 @@ class BaseTenantViewset(BaseViewSet):
|
||||
raise NotAuthenticated("Tenant ID is not present in token")
|
||||
|
||||
user_id = str(request.user.id)
|
||||
with rls_transaction(value=user_id, parameter=POSTGRES_USER_VAR):
|
||||
with rls_transaction(
|
||||
value=user_id,
|
||||
parameter=POSTGRES_USER_VAR,
|
||||
using=getattr(self, "db_alias", MainRouter.default_db),
|
||||
):
|
||||
return super().initial(request, *args, **kwargs)
|
||||
|
||||
|
||||
class BaseUserViewset(BaseViewSet):
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
with transaction.atomic():
|
||||
return super().dispatch(request, *args, **kwargs)
|
||||
self.db_alias = self._get_request_db_alias(request)
|
||||
alias_token = None
|
||||
try:
|
||||
if self.db_alias != MainRouter.default_db:
|
||||
alias_token = set_read_db_alias(self.db_alias)
|
||||
|
||||
if request is not None:
|
||||
request.db_alias = self.db_alias
|
||||
|
||||
with transaction.atomic(using=self.db_alias):
|
||||
return super().dispatch(request, *args, **kwargs)
|
||||
finally:
|
||||
if alias_token is not None:
|
||||
reset_read_db_alias(alias_token)
|
||||
self.db_alias = MainRouter.default_db
|
||||
|
||||
def initial(self, request, *args, **kwargs):
|
||||
# TODO refactor after improving RLS on users
|
||||
@@ -137,6 +200,8 @@ class BaseUserViewset(BaseViewSet):
|
||||
if tenant_id is None:
|
||||
raise NotAuthenticated("Tenant ID is not present in token")
|
||||
|
||||
with rls_transaction(tenant_id):
|
||||
with rls_transaction(
|
||||
tenant_id, using=getattr(self, "db_alias", MainRouter.default_db)
|
||||
):
|
||||
self.request.tenant_id = tenant_id
|
||||
return super().initial(request, *args, **kwargs)
|
||||
|
||||
@@ -150,12 +150,16 @@ def generate_scan_compliance(
|
||||
requirement["checks"][check_id] = status
|
||||
requirement["checks_status"][status.lower()] += 1
|
||||
|
||||
if requirement["status"] != "FAIL" and any(
|
||||
value == "FAIL" for value in requirement["checks"].values()
|
||||
):
|
||||
requirement["status"] = "FAIL"
|
||||
compliance_overview[compliance_id]["requirements_status"]["passed"] -= 1
|
||||
compliance_overview[compliance_id]["requirements_status"]["failed"] += 1
|
||||
if requirement["status"] != "FAIL" and any(
|
||||
value == "FAIL" for value in requirement["checks"].values()
|
||||
):
|
||||
requirement["status"] = "FAIL"
|
||||
compliance_overview[compliance_id]["requirements_status"][
|
||||
"passed"
|
||||
] -= 1
|
||||
compliance_overview[compliance_id]["requirements_status"][
|
||||
"failed"
|
||||
] += 1
|
||||
|
||||
|
||||
def generate_compliance_overview_template(prowler_compliance: dict):
|
||||
|
||||
@@ -1,9 +1,31 @@
|
||||
from contextvars import ContextVar
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
ALLOWED_APPS = ("django", "socialaccount", "account", "authtoken", "silk")
|
||||
|
||||
_read_db_alias = ContextVar("read_db_alias", default=None)
|
||||
|
||||
|
||||
def set_read_db_alias(alias: str | None):
|
||||
if not alias:
|
||||
return None
|
||||
return _read_db_alias.set(alias)
|
||||
|
||||
|
||||
def get_read_db_alias() -> str | None:
|
||||
return _read_db_alias.get()
|
||||
|
||||
|
||||
def reset_read_db_alias(token) -> None:
|
||||
if token is not None:
|
||||
_read_db_alias.reset(token)
|
||||
|
||||
|
||||
class MainRouter:
|
||||
default_db = "default"
|
||||
admin_db = "admin"
|
||||
replica_db = "replica"
|
||||
|
||||
def db_for_read(self, model, **hints): # noqa: F841
|
||||
model_table_name = model._meta.db_table
|
||||
@@ -11,6 +33,9 @@ class MainRouter:
|
||||
model_table_name.startswith(f"{app}_") for app in ALLOWED_APPS
|
||||
):
|
||||
return self.admin_db
|
||||
read_alias = get_read_db_alias()
|
||||
if read_alias:
|
||||
return read_alias
|
||||
return None
|
||||
|
||||
def db_for_write(self, model, **hints): # noqa: F841
|
||||
@@ -27,3 +52,8 @@ class MainRouter:
|
||||
if {obj1._state.db, obj2._state.db} <= {self.default_db, self.admin_db}:
|
||||
return True
|
||||
return None
|
||||
|
||||
|
||||
READ_REPLICA_ALIAS = (
|
||||
MainRouter.replica_db if MainRouter.replica_db in settings.DATABASES else None
|
||||
)
|
||||
|
||||
@@ -6,12 +6,14 @@ from datetime import datetime, timedelta, timezone
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.auth.models import BaseUserManager
|
||||
from django.db import connection, models, transaction
|
||||
from django.db import DEFAULT_DB_ALIAS, connection, connections, models, transaction
|
||||
from django_celery_beat.models import PeriodicTask
|
||||
from psycopg2 import connect as psycopg2_connect
|
||||
from psycopg2.extensions import AsIs, new_type, register_adapter, register_type
|
||||
from rest_framework_json_api.serializers import ValidationError
|
||||
|
||||
from api.db_router import get_read_db_alias, reset_read_db_alias, set_read_db_alias
|
||||
|
||||
DB_USER = settings.DATABASES["default"]["USER"] if not settings.TESTING else "test"
|
||||
DB_PASSWORD = (
|
||||
settings.DATABASES["default"]["PASSWORD"] if not settings.TESTING else "test"
|
||||
@@ -49,7 +51,11 @@ def psycopg_connection(database_alias: str):
|
||||
|
||||
|
||||
@contextmanager
|
||||
def rls_transaction(value: str, parameter: str = POSTGRES_TENANT_VAR):
|
||||
def rls_transaction(
|
||||
value: str,
|
||||
parameter: str = POSTGRES_TENANT_VAR,
|
||||
using: str | None = None,
|
||||
):
|
||||
"""
|
||||
Creates a new database transaction setting the given configuration value for Postgres RLS. It validates the
|
||||
if the value is a valid UUID.
|
||||
@@ -57,16 +63,32 @@ def rls_transaction(value: str, parameter: str = POSTGRES_TENANT_VAR):
|
||||
Args:
|
||||
value (str): Database configuration parameter value.
|
||||
parameter (str): Database configuration parameter name, by default is 'api.tenant_id'.
|
||||
using (str | None): Optional database alias to run the transaction against. Defaults to the
|
||||
active read alias (if any) or Django's default connection.
|
||||
"""
|
||||
with transaction.atomic():
|
||||
with connection.cursor() as cursor:
|
||||
try:
|
||||
# just in case the value is a UUID object
|
||||
uuid.UUID(str(value))
|
||||
except ValueError:
|
||||
raise ValidationError("Must be a valid UUID")
|
||||
cursor.execute(SET_CONFIG_QUERY, [parameter, value])
|
||||
yield cursor
|
||||
requested_alias = using or get_read_db_alias()
|
||||
db_alias = requested_alias or DEFAULT_DB_ALIAS
|
||||
if db_alias not in connections:
|
||||
db_alias = DEFAULT_DB_ALIAS
|
||||
|
||||
router_token = None
|
||||
try:
|
||||
if db_alias != DEFAULT_DB_ALIAS:
|
||||
router_token = set_read_db_alias(db_alias)
|
||||
|
||||
with transaction.atomic(using=db_alias):
|
||||
conn = connections[db_alias]
|
||||
with conn.cursor() as cursor:
|
||||
try:
|
||||
# just in case the value is a UUID object
|
||||
uuid.UUID(str(value))
|
||||
except ValueError:
|
||||
raise ValidationError("Must be a valid UUID")
|
||||
cursor.execute(SET_CONFIG_QUERY, [parameter, value])
|
||||
yield cursor
|
||||
finally:
|
||||
if router_token is not None:
|
||||
reset_read_db_alias(router_token)
|
||||
|
||||
|
||||
class CustomUserManager(BaseUserManager):
|
||||
|
||||
@@ -11,11 +11,12 @@ class APIJSONRenderer(JSONRenderer):
|
||||
def render(self, data, accepted_media_type=None, renderer_context=None):
|
||||
request = renderer_context.get("request")
|
||||
tenant_id = getattr(request, "tenant_id", None) if request else None
|
||||
db_alias = getattr(request, "db_alias", None) if request else None
|
||||
include_param_present = "include" in request.query_params if request else False
|
||||
|
||||
# Use rls_transaction if needed for included resources, otherwise do nothing
|
||||
context_manager = (
|
||||
rls_transaction(tenant_id)
|
||||
rls_transaction(tenant_id, using=db_alias)
|
||||
if tenant_id and include_param_present
|
||||
else nullcontext()
|
||||
)
|
||||
|
||||
@@ -3449,20 +3449,16 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
|
||||
)
|
||||
filtered_queryset = self.filter_queryset(self.get_queryset())
|
||||
|
||||
all_requirements = (
|
||||
filtered_queryset.values(
|
||||
"requirement_id",
|
||||
"framework",
|
||||
"version",
|
||||
"description",
|
||||
"passed_findings",
|
||||
"total_findings",
|
||||
)
|
||||
.distinct()
|
||||
.annotate(
|
||||
total_instances=Count("id"),
|
||||
manual_count=Count("id", filter=Q(requirement_status="MANUAL")),
|
||||
)
|
||||
all_requirements = filtered_queryset.values(
|
||||
"requirement_id",
|
||||
"framework",
|
||||
"version",
|
||||
"description",
|
||||
).annotate(
|
||||
total_instances=Count("id"),
|
||||
manual_count=Count("id", filter=Q(requirement_status="MANUAL")),
|
||||
passed_findings_sum=Sum("passed_findings"),
|
||||
total_findings_sum=Sum("total_findings"),
|
||||
)
|
||||
|
||||
passed_instances = (
|
||||
@@ -3481,8 +3477,8 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
|
||||
total_instances = requirement["total_instances"]
|
||||
passed_count = passed_counts.get(requirement_id, 0)
|
||||
is_manual = requirement["manual_count"] == total_instances
|
||||
passed_findings = requirement["passed_findings"]
|
||||
total_findings = requirement["total_findings"]
|
||||
passed_findings = requirement["passed_findings_sum"] or 0
|
||||
total_findings = requirement["total_findings_sum"] or 0
|
||||
if is_manual:
|
||||
requirement_status = "MANUAL"
|
||||
elif passed_count == total_instances:
|
||||
|
||||
@@ -5,24 +5,39 @@ DEBUG = env.bool("DJANGO_DEBUG", default=True)
|
||||
ALLOWED_HOSTS = env.list("DJANGO_ALLOWED_HOSTS", default=["*"])
|
||||
|
||||
# Database
|
||||
default_db_name = env("POSTGRES_DB", default="prowler_db")
|
||||
default_db_user = env("POSTGRES_USER", default="prowler_user")
|
||||
default_db_password = env("POSTGRES_PASSWORD", default="prowler")
|
||||
default_db_host = env("POSTGRES_HOST", default="postgres-db")
|
||||
default_db_port = env("POSTGRES_PORT", default="5432")
|
||||
|
||||
DATABASES = {
|
||||
"prowler_user": {
|
||||
"ENGINE": "psqlextra.backend",
|
||||
"NAME": env("POSTGRES_DB", default="prowler_db"),
|
||||
"USER": env("POSTGRES_USER", default="prowler_user"),
|
||||
"PASSWORD": env("POSTGRES_PASSWORD", default="prowler"),
|
||||
"HOST": env("POSTGRES_HOST", default="postgres-db"),
|
||||
"PORT": env("POSTGRES_PORT", default="5432"),
|
||||
"NAME": default_db_name,
|
||||
"USER": default_db_user,
|
||||
"PASSWORD": default_db_password,
|
||||
"HOST": default_db_host,
|
||||
"PORT": default_db_port,
|
||||
},
|
||||
"admin": {
|
||||
"ENGINE": "psqlextra.backend",
|
||||
"NAME": env("POSTGRES_DB", default="prowler_db"),
|
||||
"NAME": default_db_name,
|
||||
"USER": env("POSTGRES_ADMIN_USER", default="prowler"),
|
||||
"PASSWORD": env("POSTGRES_ADMIN_PASSWORD", default="S3cret"),
|
||||
"HOST": env("POSTGRES_HOST", default="postgres-db"),
|
||||
"PORT": env("POSTGRES_PORT", default="5432"),
|
||||
"HOST": default_db_host,
|
||||
"PORT": default_db_port,
|
||||
},
|
||||
"replica": {
|
||||
"ENGINE": "psqlextra.backend",
|
||||
"NAME": env("POSTGRES_REPLICA_DB", default=default_db_name),
|
||||
"USER": env("POSTGRES_REPLICA_USER", default=default_db_user),
|
||||
"PASSWORD": env("POSTGRES_REPLICA_PASSWORD", default=default_db_password),
|
||||
"HOST": env("POSTGRES_REPLICA_HOST", default=default_db_host),
|
||||
"PORT": env("POSTGRES_REPLICA_PORT", default=default_db_port),
|
||||
},
|
||||
}
|
||||
|
||||
DATABASES["default"] = DATABASES["prowler_user"]
|
||||
|
||||
REST_FRAMEWORK["DEFAULT_RENDERER_CLASSES"] = tuple( # noqa: F405
|
||||
|
||||
@@ -6,22 +6,37 @@ ALLOWED_HOSTS = env.list("DJANGO_ALLOWED_HOSTS", default=["localhost", "127.0.0.
|
||||
|
||||
# Database
|
||||
# TODO Use Django database routers https://docs.djangoproject.com/en/5.0/topics/db/multi-db/#automatic-database-routing
|
||||
default_db_name = env("POSTGRES_DB")
|
||||
default_db_user = env("POSTGRES_USER")
|
||||
default_db_password = env("POSTGRES_PASSWORD")
|
||||
default_db_host = env("POSTGRES_HOST")
|
||||
default_db_port = env("POSTGRES_PORT")
|
||||
|
||||
DATABASES = {
|
||||
"prowler_user": {
|
||||
"ENGINE": "django.db.backends.postgresql",
|
||||
"NAME": env("POSTGRES_DB"),
|
||||
"USER": env("POSTGRES_USER"),
|
||||
"PASSWORD": env("POSTGRES_PASSWORD"),
|
||||
"HOST": env("POSTGRES_HOST"),
|
||||
"PORT": env("POSTGRES_PORT"),
|
||||
"ENGINE": "psqlextra.backend",
|
||||
"NAME": default_db_name,
|
||||
"USER": default_db_user,
|
||||
"PASSWORD": default_db_password,
|
||||
"HOST": default_db_host,
|
||||
"PORT": default_db_port,
|
||||
},
|
||||
"admin": {
|
||||
"ENGINE": "psqlextra.backend",
|
||||
"NAME": env("POSTGRES_DB"),
|
||||
"NAME": default_db_name,
|
||||
"USER": env("POSTGRES_ADMIN_USER"),
|
||||
"PASSWORD": env("POSTGRES_ADMIN_PASSWORD"),
|
||||
"HOST": env("POSTGRES_HOST"),
|
||||
"PORT": env("POSTGRES_PORT"),
|
||||
"HOST": default_db_host,
|
||||
"PORT": default_db_port,
|
||||
},
|
||||
"replica": {
|
||||
"ENGINE": "psqlextra.backend",
|
||||
"NAME": env("POSTGRES_REPLICA_DB", default=default_db_name),
|
||||
"USER": env("POSTGRES_REPLICA_USER", default=default_db_user),
|
||||
"PASSWORD": env("POSTGRES_REPLICA_PASSWORD", default=default_db_password),
|
||||
"HOST": env("POSTGRES_REPLICA_HOST", default=default_db_host),
|
||||
"PORT": env("POSTGRES_REPLICA_PORT", default=default_db_port),
|
||||
},
|
||||
}
|
||||
|
||||
DATABASES["default"] = DATABASES["prowler_user"]
|
||||
|
||||
@@ -5,6 +5,7 @@ from celery.utils.log import get_task_logger
|
||||
from config.django.base import DJANGO_FINDINGS_BATCH_SIZE
|
||||
from tasks.utils import batched
|
||||
|
||||
from api.db_router import READ_REPLICA_ALIAS
|
||||
from api.db_utils import rls_transaction
|
||||
from api.models import Finding, Integration, Provider
|
||||
from api.utils import initialize_prowler_integration, initialize_prowler_provider
|
||||
@@ -289,7 +290,7 @@ def upload_security_hub_integration(
|
||||
has_findings = False
|
||||
batch_number = 0
|
||||
|
||||
with rls_transaction(tenant_id):
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
qs = (
|
||||
Finding.all_objects.filter(tenant_id=tenant_id, scan_id=scan_id)
|
||||
.order_by("uid")
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from celery.utils.log import get_task_logger
|
||||
from config.settings.celery import CELERY_DEADLOCK_ATTEMPTS
|
||||
@@ -14,8 +18,11 @@ from api.compliance import (
|
||||
PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE,
|
||||
generate_scan_compliance,
|
||||
)
|
||||
from api.db_router import READ_REPLICA_ALIAS, MainRouter
|
||||
from api.db_utils import (
|
||||
create_objects_in_batches,
|
||||
POSTGRES_TENANT_VAR,
|
||||
SET_CONFIG_QUERY,
|
||||
psycopg_connection,
|
||||
rls_transaction,
|
||||
update_objects_in_batches,
|
||||
)
|
||||
@@ -40,6 +47,28 @@ from prowler.lib.scan.scan import Scan as ProwlerScan
|
||||
|
||||
logger = get_task_logger(__name__)
|
||||
|
||||
# Column order must match `ComplianceRequirementOverview` schema in
|
||||
# `api/models.py`. Keep this list minimal but sufficient to populate all
|
||||
# non-nullable fields plus the counters we care about.
|
||||
COMPLIANCE_REQUIREMENT_COPY_COLUMNS = (
|
||||
"id",
|
||||
"tenant_id",
|
||||
"inserted_at",
|
||||
"compliance_id",
|
||||
"framework",
|
||||
"version",
|
||||
"description",
|
||||
"region",
|
||||
"requirement_id",
|
||||
"requirement_status",
|
||||
"passed_checks",
|
||||
"failed_checks",
|
||||
"total_checks",
|
||||
"passed_findings",
|
||||
"total_findings",
|
||||
"scan_id",
|
||||
)
|
||||
|
||||
|
||||
def _create_finding_delta(
|
||||
last_status: FindingStatus | None | str, new_status: FindingStatus | None
|
||||
@@ -107,6 +136,124 @@ def _store_resources(
|
||||
return resource_instance, (resource_instance.uid, resource_instance.region)
|
||||
|
||||
|
||||
def _copy_compliance_requirement_rows(
|
||||
tenant_id: str, rows: list[dict[str, Any]]
|
||||
) -> None:
|
||||
"""Stream compliance requirement rows into Postgres using COPY.
|
||||
|
||||
We leverage the admin connection (when available) to bypass the COPY + RLS
|
||||
restriction, writing only the fields required by
|
||||
``ComplianceRequirementOverview``.
|
||||
|
||||
Args:
|
||||
tenant_id: Target tenant UUID.
|
||||
rows: List of row dictionaries prepared by
|
||||
:func:`create_compliance_requirements`.
|
||||
"""
|
||||
|
||||
csv_buffer = io.StringIO()
|
||||
writer = csv.writer(csv_buffer)
|
||||
|
||||
datetime_now = datetime.now(tz=timezone.utc)
|
||||
for row in rows:
|
||||
writer.writerow(
|
||||
[
|
||||
str(row.get("id")),
|
||||
str(row.get("tenant_id")),
|
||||
(row.get("inserted_at") or datetime_now).isoformat(),
|
||||
row.get("compliance_id") or "",
|
||||
row.get("framework") or "",
|
||||
row.get("version") or "",
|
||||
row.get("description") or "",
|
||||
row.get("region") or "",
|
||||
row.get("requirement_id") or "",
|
||||
row.get("requirement_status") or "",
|
||||
row.get("passed_checks", 0),
|
||||
row.get("failed_checks", 0),
|
||||
row.get("total_checks", 0),
|
||||
row.get("passed_findings", 0),
|
||||
row.get("total_findings", 0),
|
||||
str(row.get("scan_id")),
|
||||
]
|
||||
)
|
||||
|
||||
csv_buffer.seek(0)
|
||||
copy_sql = (
|
||||
"COPY compliance_requirements_overviews ("
|
||||
+ ", ".join(COMPLIANCE_REQUIREMENT_COPY_COLUMNS)
|
||||
+ ") FROM STDIN WITH (FORMAT CSV, DELIMITER ',', QUOTE '\"', ESCAPE '\"', NULL '\\N')"
|
||||
)
|
||||
|
||||
try:
|
||||
with psycopg_connection(MainRouter.admin_db) as connection:
|
||||
connection.autocommit = False
|
||||
try:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(SET_CONFIG_QUERY, [POSTGRES_TENANT_VAR, tenant_id])
|
||||
cursor.copy_expert(copy_sql, csv_buffer)
|
||||
connection.commit()
|
||||
except Exception:
|
||||
connection.rollback()
|
||||
raise
|
||||
finally:
|
||||
csv_buffer.close()
|
||||
|
||||
|
||||
def _persist_compliance_requirement_rows(
|
||||
tenant_id: str, rows: list[dict[str, Any]]
|
||||
) -> None:
|
||||
"""Persist compliance requirement rows using COPY with ORM fallback.
|
||||
|
||||
Args:
|
||||
tenant_id: Target tenant UUID.
|
||||
rows: Precomputed row dictionaries that reflect the compliance
|
||||
overview state for a scan.
|
||||
"""
|
||||
if not rows:
|
||||
return
|
||||
|
||||
try:
|
||||
_copy_compliance_requirement_rows(tenant_id, rows)
|
||||
except Exception as error:
|
||||
logger.exception(
|
||||
"COPY bulk insert for compliance requirements failed; falling back to ORM bulk_create",
|
||||
exc_info=error,
|
||||
)
|
||||
fallback_objects = [
|
||||
ComplianceRequirementOverview(
|
||||
id=row["id"],
|
||||
tenant_id=row["tenant_id"],
|
||||
inserted_at=row["inserted_at"],
|
||||
compliance_id=row["compliance_id"],
|
||||
framework=row["framework"],
|
||||
version=row["version"],
|
||||
description=row["description"],
|
||||
region=row["region"],
|
||||
requirement_id=row["requirement_id"],
|
||||
requirement_status=row["requirement_status"],
|
||||
passed_checks=row["passed_checks"],
|
||||
failed_checks=row["failed_checks"],
|
||||
total_checks=row["total_checks"],
|
||||
passed_findings=row.get("passed_findings", 0),
|
||||
total_findings=row.get("total_findings", 0),
|
||||
scan_id=row["scan_id"],
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
with rls_transaction(tenant_id):
|
||||
ComplianceRequirementOverview.objects.bulk_create(
|
||||
fallback_objects, batch_size=500
|
||||
)
|
||||
|
||||
|
||||
def _normalized_compliance_key(framework: str | None, version: str | None) -> str:
|
||||
"""Return normalized identifier used to group compliance totals."""
|
||||
|
||||
normalized_framework = (framework or "").lower().replace("-", "").replace("_", "")
|
||||
normalized_version = (version or "").lower().replace("-", "").replace("_", "")
|
||||
return f"{normalized_framework}{normalized_version}"
|
||||
|
||||
|
||||
def perform_prowler_scan(
|
||||
tenant_id: str,
|
||||
scan_id: str,
|
||||
@@ -143,7 +290,7 @@ def perform_prowler_scan(
|
||||
scan_instance.save()
|
||||
|
||||
# Find the mutelist processor if it exists
|
||||
with rls_transaction(tenant_id):
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
try:
|
||||
mutelist_processor = Processor.objects.get(
|
||||
tenant_id=tenant_id, processor_type=Processor.ProcessorChoices.MUTELIST
|
||||
@@ -272,7 +419,7 @@ def perform_prowler_scan(
|
||||
unique_resources.add((resource_instance.uid, resource_instance.region))
|
||||
|
||||
# Process finding
|
||||
with rls_transaction(tenant_id):
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
finding_uid = finding.uid
|
||||
last_first_seen_at = None
|
||||
if finding_uid not in last_status_cache:
|
||||
@@ -305,6 +452,12 @@ def perform_prowler_scan(
|
||||
# If the finding is muted at this time the reason must be the configured Mutelist
|
||||
muted_reason = "Muted by mutelist" if finding.muted else None
|
||||
|
||||
# Increment failed_findings_count cache if the finding status is FAIL and not muted
|
||||
if status == FindingStatus.FAIL and not finding.muted:
|
||||
resource_uid = finding.resource_uid
|
||||
resource_failed_findings_cache[resource_uid] += 1
|
||||
|
||||
with rls_transaction(tenant_id):
|
||||
# Create the finding
|
||||
finding_instance = Finding.objects.create(
|
||||
tenant_id=tenant_id,
|
||||
@@ -325,11 +478,6 @@ def perform_prowler_scan(
|
||||
)
|
||||
finding_instance.add_resources([resource_instance])
|
||||
|
||||
# Increment failed_findings_count cache if the finding status is FAIL and not muted
|
||||
if status == FindingStatus.FAIL and not finding.muted:
|
||||
resource_uid = finding.resource_uid
|
||||
resource_failed_findings_cache[resource_uid] += 1
|
||||
|
||||
# Update scan resource summaries
|
||||
scan_resource_cache.add(
|
||||
(
|
||||
@@ -439,7 +587,7 @@ def aggregate_findings(tenant_id: str, scan_id: str):
|
||||
- muted_new: Muted findings with a delta of 'new'.
|
||||
- muted_changed: Muted findings with a delta of 'changed'.
|
||||
"""
|
||||
with rls_transaction(tenant_id):
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
findings = Finding.objects.filter(tenant_id=tenant_id, scan_id=scan_id)
|
||||
|
||||
aggregation = findings.values(
|
||||
@@ -582,11 +730,28 @@ def create_compliance_requirements(tenant_id: str, scan_id: str):
|
||||
ValidationError: If tenant_id is not a valid UUID.
|
||||
"""
|
||||
try:
|
||||
with rls_transaction(tenant_id):
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
scan_instance = Scan.objects.get(pk=scan_id)
|
||||
provider_instance = scan_instance.provider
|
||||
prowler_provider = return_prowler_provider(provider_instance)
|
||||
|
||||
compliance_template = PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE[
|
||||
provider_instance.provider
|
||||
]
|
||||
modeled_threatscore_compliance_id = "ProwlerThreatScore-1.0"
|
||||
threatscore_requirements_by_check: dict[str, set[str]] = {}
|
||||
threatscore_framework = compliance_template.get(
|
||||
modeled_threatscore_compliance_id
|
||||
)
|
||||
if threatscore_framework:
|
||||
for requirement_id, requirement in threatscore_framework[
|
||||
"requirements"
|
||||
].items():
|
||||
for check_id in requirement["checks"]:
|
||||
threatscore_requirements_by_check.setdefault(check_id, set()).add(
|
||||
requirement_id
|
||||
)
|
||||
|
||||
# Get check status data by region from findings
|
||||
findings = (
|
||||
Finding.all_objects.filter(scan_id=scan_id, muted=False)
|
||||
@@ -603,8 +768,7 @@ def create_compliance_requirements(tenant_id: str, scan_id: str):
|
||||
|
||||
findings_count_by_compliance = {}
|
||||
check_status_by_region = {}
|
||||
modeled_threatscore_compliance_id = "ProwlerThreatScore-1.0"
|
||||
with rls_transaction(tenant_id):
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
for finding in findings:
|
||||
for resource in finding.small_resources:
|
||||
region = resource.region
|
||||
@@ -640,11 +804,6 @@ def create_compliance_requirements(tenant_id: str, scan_id: str):
|
||||
# If not available, use regions from findings
|
||||
regions = set(check_status_by_region.keys())
|
||||
|
||||
# Get compliance template for the provider
|
||||
compliance_template = PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE[
|
||||
provider_instance.provider
|
||||
]
|
||||
|
||||
# Create compliance data by region
|
||||
compliance_overview_by_region = {
|
||||
region: deepcopy(compliance_template) for region in regions
|
||||
@@ -663,50 +822,53 @@ def create_compliance_requirements(tenant_id: str, scan_id: str):
|
||||
status,
|
||||
)
|
||||
|
||||
# Prepare compliance requirement objects
|
||||
compliance_requirement_objects = []
|
||||
# Prepare compliance requirement rows
|
||||
compliance_requirement_rows: list[dict[str, Any]] = []
|
||||
utc_datetime_now = datetime.now(tz=timezone.utc)
|
||||
for region, compliance_data in compliance_overview_by_region.items():
|
||||
for compliance_id, compliance in compliance_data.items():
|
||||
modeled_framework = (
|
||||
compliance["framework"].lower().replace("-", "").replace("_", "")
|
||||
modeled_compliance_id = _normalized_compliance_key(
|
||||
compliance["framework"], compliance["version"]
|
||||
)
|
||||
modeled_version = (
|
||||
compliance["version"].lower().replace("-", "").replace("_", "")
|
||||
)
|
||||
modeled_compliance_id = f"{modeled_framework}{modeled_version}"
|
||||
# Create an overview record for each requirement within each compliance framework
|
||||
for requirement_id, requirement in compliance["requirements"].items():
|
||||
compliance_requirement_objects.append(
|
||||
ComplianceRequirementOverview(
|
||||
tenant_id=tenant_id,
|
||||
scan=scan_instance,
|
||||
region=region,
|
||||
compliance_id=compliance_id,
|
||||
framework=compliance["framework"],
|
||||
version=compliance["version"],
|
||||
requirement_id=requirement_id,
|
||||
description=requirement["description"],
|
||||
passed_checks=requirement["checks_status"]["pass"],
|
||||
failed_checks=requirement["checks_status"]["fail"],
|
||||
total_checks=requirement["checks_status"]["total"],
|
||||
requirement_status=requirement["status"],
|
||||
passed_findings=findings_count_by_compliance.get(region, {})
|
||||
checks_status = requirement["checks_status"]
|
||||
compliance_requirement_rows.append(
|
||||
{
|
||||
"id": uuid.uuid4(),
|
||||
"tenant_id": tenant_id,
|
||||
"inserted_at": utc_datetime_now,
|
||||
"compliance_id": compliance_id,
|
||||
"framework": compliance["framework"],
|
||||
"version": compliance["version"] or "",
|
||||
"description": requirement.get("description") or "",
|
||||
"region": region,
|
||||
"requirement_id": requirement_id,
|
||||
"requirement_status": requirement["status"],
|
||||
"passed_checks": checks_status["pass"],
|
||||
"failed_checks": checks_status["fail"],
|
||||
"total_checks": checks_status["total"],
|
||||
"scan_id": scan_instance.id,
|
||||
"passed_findings": findings_count_by_compliance.get(
|
||||
region, {}
|
||||
)
|
||||
.get(modeled_compliance_id, {})
|
||||
.get(requirement_id, {})
|
||||
.get("pass", 0),
|
||||
total_findings=findings_count_by_compliance.get(region, {})
|
||||
"total_findings": findings_count_by_compliance.get(
|
||||
region, {}
|
||||
)
|
||||
.get(modeled_compliance_id, {})
|
||||
.get(requirement_id, {})
|
||||
.get("total", 0),
|
||||
)
|
||||
}
|
||||
)
|
||||
# Bulk create requirement records
|
||||
create_objects_in_batches(
|
||||
tenant_id, ComplianceRequirementOverview, compliance_requirement_objects
|
||||
)
|
||||
|
||||
# Bulk create requirement records using PostgreSQL COPY
|
||||
_persist_compliance_requirement_rows(tenant_id, compliance_requirement_rows)
|
||||
|
||||
return {
|
||||
"requirements_created": len(compliance_requirement_objects),
|
||||
"requirements_created": len(compliance_requirement_rows),
|
||||
"regions_processed": list(regions),
|
||||
"compliance_frameworks": (
|
||||
list(compliance_overview_by_region.get(list(regions)[0], {}).keys())
|
||||
|
||||
@@ -34,6 +34,7 @@ from tasks.jobs.scan import (
|
||||
from tasks.utils import batched, get_next_execution_datetime
|
||||
|
||||
from api.compliance import get_compliance_frameworks
|
||||
from api.db_router import READ_REPLICA_ALIAS
|
||||
from api.db_utils import rls_transaction
|
||||
from api.decorators import set_tenant
|
||||
from api.models import Finding, Integration, Provider, Scan, ScanSummary, StateChoices
|
||||
@@ -343,70 +344,73 @@ def generate_outputs_task(scan_id: str, provider_id: str, tenant_id: str):
|
||||
.order_by("uid")
|
||||
.iterator()
|
||||
)
|
||||
for batch, is_last in batched(qs, DJANGO_FINDINGS_BATCH_SIZE):
|
||||
fos = [FindingOutput.transform_api_finding(f, prowler_provider) for f in batch]
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
for batch, is_last in batched(qs, DJANGO_FINDINGS_BATCH_SIZE):
|
||||
fos = [
|
||||
FindingOutput.transform_api_finding(f, prowler_provider) for f in batch
|
||||
]
|
||||
|
||||
# Outputs
|
||||
for mode, cfg in OUTPUT_FORMATS_MAPPING.items():
|
||||
# Skip ASFF generation if not needed
|
||||
if mode == "json-asff" and not generate_asff:
|
||||
continue
|
||||
# Outputs
|
||||
for mode, cfg in OUTPUT_FORMATS_MAPPING.items():
|
||||
# Skip ASFF generation if not needed
|
||||
if mode == "json-asff" and not generate_asff:
|
||||
continue
|
||||
|
||||
cls = cfg["class"]
|
||||
suffix = cfg["suffix"]
|
||||
extra = cfg.get("kwargs", {}).copy()
|
||||
if mode == "html":
|
||||
extra.update(provider=prowler_provider, stats=scan_summary)
|
||||
cls = cfg["class"]
|
||||
suffix = cfg["suffix"]
|
||||
extra = cfg.get("kwargs", {}).copy()
|
||||
if mode == "html":
|
||||
extra.update(provider=prowler_provider, stats=scan_summary)
|
||||
|
||||
writer, initialization = get_writer(
|
||||
output_writers,
|
||||
cls,
|
||||
lambda cls=cls, fos=fos, suffix=suffix: cls(
|
||||
findings=fos,
|
||||
file_path=out_dir,
|
||||
file_extension=suffix,
|
||||
from_cli=False,
|
||||
),
|
||||
is_last,
|
||||
)
|
||||
if not initialization:
|
||||
writer.transform(fos)
|
||||
writer.batch_write_data_to_file(**extra)
|
||||
writer._data.clear()
|
||||
writer, initialization = get_writer(
|
||||
output_writers,
|
||||
cls,
|
||||
lambda cls=cls, fos=fos, suffix=suffix: cls(
|
||||
findings=fos,
|
||||
file_path=out_dir,
|
||||
file_extension=suffix,
|
||||
from_cli=False,
|
||||
),
|
||||
is_last,
|
||||
)
|
||||
if not initialization:
|
||||
writer.transform(fos)
|
||||
writer.batch_write_data_to_file(**extra)
|
||||
writer._data.clear()
|
||||
|
||||
# Compliance CSVs
|
||||
for name in frameworks_avail:
|
||||
compliance_obj = frameworks_bulk[name]
|
||||
# Compliance CSVs
|
||||
for name in frameworks_avail:
|
||||
compliance_obj = frameworks_bulk[name]
|
||||
|
||||
klass = GenericCompliance
|
||||
for condition, cls in COMPLIANCE_CLASS_MAP.get(provider_type, []):
|
||||
if condition(name):
|
||||
klass = cls
|
||||
break
|
||||
klass = GenericCompliance
|
||||
for condition, cls in COMPLIANCE_CLASS_MAP.get(provider_type, []):
|
||||
if condition(name):
|
||||
klass = cls
|
||||
break
|
||||
|
||||
filename = f"{comp_dir}_{name}.csv"
|
||||
filename = f"{comp_dir}_{name}.csv"
|
||||
|
||||
writer, initialization = get_writer(
|
||||
compliance_writers,
|
||||
name,
|
||||
lambda klass=klass, fos=fos: klass(
|
||||
findings=fos,
|
||||
compliance=compliance_obj,
|
||||
file_path=filename,
|
||||
from_cli=False,
|
||||
),
|
||||
is_last,
|
||||
)
|
||||
if not initialization:
|
||||
writer.transform(fos, compliance_obj, name)
|
||||
writer.batch_write_data_to_file()
|
||||
writer._data.clear()
|
||||
writer, initialization = get_writer(
|
||||
compliance_writers,
|
||||
name,
|
||||
lambda klass=klass, fos=fos: klass(
|
||||
findings=fos,
|
||||
compliance=compliance_obj,
|
||||
file_path=filename,
|
||||
from_cli=False,
|
||||
),
|
||||
is_last,
|
||||
)
|
||||
if not initialization:
|
||||
writer.transform(fos, compliance_obj, name)
|
||||
writer.batch_write_data_to_file()
|
||||
writer._data.clear()
|
||||
|
||||
compressed = _compress_output_files(out_dir)
|
||||
upload_uri = _upload_to_s3(tenant_id, compressed, scan_id)
|
||||
|
||||
# S3 integrations (need output_directory)
|
||||
with rls_transaction(tenant_id):
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
s3_integrations = Integration.objects.filter(
|
||||
integrationproviderrelationship__provider_id=provider_id,
|
||||
integration_type=Integration.IntegrationChoices.AMAZON_S3,
|
||||
|
||||
@@ -1,17 +1,22 @@
|
||||
import csv
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from io import StringIO
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from tasks.jobs.scan import (
|
||||
_copy_compliance_requirement_rows,
|
||||
_create_finding_delta,
|
||||
_persist_compliance_requirement_rows,
|
||||
_store_resources,
|
||||
create_compliance_requirements,
|
||||
perform_prowler_scan,
|
||||
)
|
||||
from tasks.utils import CustomEncoder
|
||||
|
||||
from api.db_router import MainRouter
|
||||
from api.exceptions import ProviderConnectionError
|
||||
from api.models import Finding, Provider, Resource, Scan, StateChoices, StatusChoices
|
||||
from prowler.lib.check.models import Severity
|
||||
@@ -1045,3 +1050,773 @@ class TestCreateComplianceRequirements:
|
||||
|
||||
assert "requirements_created" in result
|
||||
assert result["requirements_created"] >= 0
|
||||
|
||||
|
||||
class TestComplianceRequirementCopy:
|
||||
@patch("tasks.jobs.scan.psycopg_connection")
|
||||
def test_copy_compliance_requirement_rows_streams_csv(
|
||||
self, mock_psycopg_connection, settings
|
||||
):
|
||||
settings.DATABASES.setdefault("admin", settings.DATABASES["default"])
|
||||
|
||||
connection = MagicMock()
|
||||
cursor = MagicMock()
|
||||
cursor_context = MagicMock()
|
||||
cursor_context.__enter__.return_value = cursor
|
||||
cursor_context.__exit__.return_value = False
|
||||
connection.cursor.return_value = cursor_context
|
||||
connection.__enter__.return_value = connection
|
||||
connection.__exit__.return_value = False
|
||||
|
||||
context_manager = MagicMock()
|
||||
context_manager.__enter__.return_value = connection
|
||||
context_manager.__exit__.return_value = False
|
||||
mock_psycopg_connection.return_value = context_manager
|
||||
|
||||
captured = {}
|
||||
|
||||
def copy_side_effect(sql, file_obj):
|
||||
captured["sql"] = sql
|
||||
captured["data"] = file_obj.read()
|
||||
|
||||
cursor.copy_expert.side_effect = copy_side_effect
|
||||
|
||||
row = {
|
||||
"id": uuid.uuid4(),
|
||||
"tenant_id": str(uuid.uuid4()),
|
||||
"compliance_id": "cisa_aws",
|
||||
"framework": "CISA",
|
||||
"version": None,
|
||||
"description": "desc",
|
||||
"region": "us-east-1",
|
||||
"requirement_id": "req-1",
|
||||
"requirement_status": "PASS",
|
||||
"passed_checks": 1,
|
||||
"failed_checks": 0,
|
||||
"total_checks": 1,
|
||||
"scan_id": uuid.uuid4(),
|
||||
}
|
||||
|
||||
with patch.object(MainRouter, "admin_db", "admin"):
|
||||
_copy_compliance_requirement_rows(str(row["tenant_id"]), [row])
|
||||
|
||||
mock_psycopg_connection.assert_called_once_with("admin")
|
||||
connection.cursor.assert_called_once()
|
||||
cursor.execute.assert_called_once()
|
||||
cursor.copy_expert.assert_called_once()
|
||||
|
||||
csv_rows = list(csv.reader(StringIO(captured["data"])))
|
||||
assert csv_rows[0][0] == str(row["id"])
|
||||
assert csv_rows[0][5] == ""
|
||||
assert csv_rows[0][-1] == str(row["scan_id"])
|
||||
|
||||
@patch("tasks.jobs.scan.ComplianceRequirementOverview.objects.bulk_create")
|
||||
@patch("tasks.jobs.scan.rls_transaction")
|
||||
@patch(
|
||||
"tasks.jobs.scan._copy_compliance_requirement_rows",
|
||||
side_effect=Exception("copy failed"),
|
||||
)
|
||||
def test_persist_compliance_requirement_rows_fallback(
|
||||
self, mock_copy, mock_rls_transaction, mock_bulk_create
|
||||
):
|
||||
inserted_at = datetime.now(timezone.utc)
|
||||
row = {
|
||||
"id": uuid.uuid4(),
|
||||
"tenant_id": str(uuid.uuid4()),
|
||||
"inserted_at": inserted_at,
|
||||
"compliance_id": "cisa_aws",
|
||||
"framework": "CISA",
|
||||
"version": "1.0",
|
||||
"description": "desc",
|
||||
"region": "us-east-1",
|
||||
"requirement_id": "req-1",
|
||||
"requirement_status": "PASS",
|
||||
"passed_checks": 1,
|
||||
"failed_checks": 0,
|
||||
"total_checks": 1,
|
||||
"scan_id": uuid.uuid4(),
|
||||
}
|
||||
|
||||
tenant_id = row["tenant_id"]
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
ctx.__exit__.return_value = False
|
||||
mock_rls_transaction.return_value = ctx
|
||||
|
||||
_persist_compliance_requirement_rows(tenant_id, [row])
|
||||
|
||||
mock_copy.assert_called_once_with(tenant_id, [row])
|
||||
mock_rls_transaction.assert_called_once_with(tenant_id)
|
||||
mock_bulk_create.assert_called_once()
|
||||
|
||||
args, kwargs = mock_bulk_create.call_args
|
||||
objects = args[0]
|
||||
assert len(objects) == 1
|
||||
fallback = objects[0]
|
||||
assert fallback.version == row["version"]
|
||||
assert fallback.compliance_id == row["compliance_id"]
|
||||
|
||||
@patch("tasks.jobs.scan.ComplianceRequirementOverview.objects.bulk_create")
|
||||
@patch("tasks.jobs.scan.rls_transaction")
|
||||
@patch("tasks.jobs.scan._copy_compliance_requirement_rows")
|
||||
def test_persist_compliance_requirement_rows_no_rows(
|
||||
self, mock_copy, mock_rls_transaction, mock_bulk_create
|
||||
):
|
||||
_persist_compliance_requirement_rows(str(uuid.uuid4()), [])
|
||||
|
||||
mock_copy.assert_not_called()
|
||||
mock_rls_transaction.assert_not_called()
|
||||
mock_bulk_create.assert_not_called()
|
||||
|
||||
@patch("tasks.jobs.scan.psycopg_connection")
|
||||
def test_copy_compliance_requirement_rows_multiple_rows(
|
||||
self, mock_psycopg_connection, settings
|
||||
):
|
||||
"""Test COPY with multiple rows to ensure batch processing works correctly."""
|
||||
settings.DATABASES.setdefault("admin", settings.DATABASES["default"])
|
||||
|
||||
connection = MagicMock()
|
||||
cursor = MagicMock()
|
||||
cursor_context = MagicMock()
|
||||
cursor_context.__enter__.return_value = cursor
|
||||
cursor_context.__exit__.return_value = False
|
||||
connection.cursor.return_value = cursor_context
|
||||
connection.__enter__.return_value = connection
|
||||
connection.__exit__.return_value = False
|
||||
|
||||
context_manager = MagicMock()
|
||||
context_manager.__enter__.return_value = connection
|
||||
context_manager.__exit__.return_value = False
|
||||
mock_psycopg_connection.return_value = context_manager
|
||||
|
||||
captured = {}
|
||||
|
||||
def copy_side_effect(sql, file_obj):
|
||||
captured["sql"] = sql
|
||||
captured["data"] = file_obj.read()
|
||||
|
||||
cursor.copy_expert.side_effect = copy_side_effect
|
||||
|
||||
tenant_id = str(uuid.uuid4())
|
||||
scan_id = uuid.uuid4()
|
||||
inserted_at = datetime.now(timezone.utc)
|
||||
|
||||
rows = [
|
||||
{
|
||||
"id": uuid.uuid4(),
|
||||
"tenant_id": tenant_id,
|
||||
"inserted_at": inserted_at,
|
||||
"compliance_id": "cisa_aws",
|
||||
"framework": "CISA",
|
||||
"version": "1.0",
|
||||
"description": "First requirement",
|
||||
"region": "us-east-1",
|
||||
"requirement_id": "req-1",
|
||||
"requirement_status": "PASS",
|
||||
"passed_checks": 5,
|
||||
"failed_checks": 0,
|
||||
"total_checks": 5,
|
||||
"scan_id": scan_id,
|
||||
},
|
||||
{
|
||||
"id": uuid.uuid4(),
|
||||
"tenant_id": tenant_id,
|
||||
"inserted_at": inserted_at,
|
||||
"compliance_id": "cisa_aws",
|
||||
"framework": "CISA",
|
||||
"version": "1.0",
|
||||
"description": "Second requirement",
|
||||
"region": "us-west-2",
|
||||
"requirement_id": "req-2",
|
||||
"requirement_status": "FAIL",
|
||||
"passed_checks": 3,
|
||||
"failed_checks": 2,
|
||||
"total_checks": 5,
|
||||
"scan_id": scan_id,
|
||||
},
|
||||
{
|
||||
"id": uuid.uuid4(),
|
||||
"tenant_id": tenant_id,
|
||||
"inserted_at": inserted_at,
|
||||
"compliance_id": "aws_foundational_security_aws",
|
||||
"framework": "AWS-Foundational-Security-Best-Practices",
|
||||
"version": "2.0",
|
||||
"description": "Third requirement",
|
||||
"region": "eu-west-1",
|
||||
"requirement_id": "req-3",
|
||||
"requirement_status": "MANUAL",
|
||||
"passed_checks": 0,
|
||||
"failed_checks": 0,
|
||||
"total_checks": 3,
|
||||
"scan_id": scan_id,
|
||||
},
|
||||
]
|
||||
|
||||
with patch.object(MainRouter, "admin_db", "admin"):
|
||||
_copy_compliance_requirement_rows(tenant_id, rows)
|
||||
|
||||
mock_psycopg_connection.assert_called_once_with("admin")
|
||||
connection.cursor.assert_called_once()
|
||||
cursor.execute.assert_called_once()
|
||||
cursor.copy_expert.assert_called_once()
|
||||
|
||||
csv_rows = list(csv.reader(StringIO(captured["data"])))
|
||||
assert len(csv_rows) == 3
|
||||
|
||||
# Validate first row
|
||||
assert csv_rows[0][0] == str(rows[0]["id"])
|
||||
assert csv_rows[0][1] == tenant_id
|
||||
assert csv_rows[0][3] == "cisa_aws"
|
||||
assert csv_rows[0][4] == "CISA"
|
||||
assert csv_rows[0][6] == "First requirement"
|
||||
assert csv_rows[0][7] == "us-east-1"
|
||||
assert csv_rows[0][10] == "5"
|
||||
assert csv_rows[0][11] == "0"
|
||||
assert csv_rows[0][12] == "5"
|
||||
|
||||
# Validate second row
|
||||
assert csv_rows[1][0] == str(rows[1]["id"])
|
||||
assert csv_rows[1][7] == "us-west-2"
|
||||
assert csv_rows[1][9] == "FAIL"
|
||||
assert csv_rows[1][10] == "3"
|
||||
assert csv_rows[1][11] == "2"
|
||||
|
||||
# Validate third row
|
||||
assert csv_rows[2][0] == str(rows[2]["id"])
|
||||
assert csv_rows[2][3] == "aws_foundational_security_aws"
|
||||
assert csv_rows[2][5] == "2.0"
|
||||
assert csv_rows[2][9] == "MANUAL"
|
||||
|
||||
@patch("tasks.jobs.scan.psycopg_connection")
|
||||
def test_copy_compliance_requirement_rows_null_values(
|
||||
self, mock_psycopg_connection, settings
|
||||
):
|
||||
"""Test COPY handles NULL/None values correctly in nullable fields."""
|
||||
settings.DATABASES.setdefault("admin", settings.DATABASES["default"])
|
||||
|
||||
connection = MagicMock()
|
||||
cursor = MagicMock()
|
||||
cursor_context = MagicMock()
|
||||
cursor_context.__enter__.return_value = cursor
|
||||
cursor_context.__exit__.return_value = False
|
||||
connection.cursor.return_value = cursor_context
|
||||
connection.__enter__.return_value = connection
|
||||
connection.__exit__.return_value = False
|
||||
|
||||
context_manager = MagicMock()
|
||||
context_manager.__enter__.return_value = connection
|
||||
context_manager.__exit__.return_value = False
|
||||
mock_psycopg_connection.return_value = context_manager
|
||||
|
||||
captured = {}
|
||||
|
||||
def copy_side_effect(sql, file_obj):
|
||||
captured["sql"] = sql
|
||||
captured["data"] = file_obj.read()
|
||||
|
||||
cursor.copy_expert.side_effect = copy_side_effect
|
||||
|
||||
# Row with all nullable fields set to None/empty
|
||||
row = {
|
||||
"id": uuid.uuid4(),
|
||||
"tenant_id": str(uuid.uuid4()),
|
||||
"compliance_id": "test_framework",
|
||||
"framework": "Test",
|
||||
"version": None, # nullable
|
||||
"description": None, # nullable
|
||||
"region": "",
|
||||
"requirement_id": "req-1",
|
||||
"requirement_status": "PASS",
|
||||
"passed_checks": 0,
|
||||
"failed_checks": 0,
|
||||
"total_checks": 0,
|
||||
"scan_id": uuid.uuid4(),
|
||||
}
|
||||
|
||||
with patch.object(MainRouter, "admin_db", "admin"):
|
||||
_copy_compliance_requirement_rows(str(row["tenant_id"]), [row])
|
||||
|
||||
csv_rows = list(csv.reader(StringIO(captured["data"])))
|
||||
assert len(csv_rows) == 1
|
||||
|
||||
# Validate that None values are converted to empty strings in CSV
|
||||
assert csv_rows[0][5] == "" # version
|
||||
assert csv_rows[0][6] == "" # description
|
||||
|
||||
@patch("tasks.jobs.scan.psycopg_connection")
|
||||
def test_copy_compliance_requirement_rows_special_characters(
|
||||
self, mock_psycopg_connection, settings
|
||||
):
|
||||
"""Test COPY correctly escapes special characters in CSV."""
|
||||
settings.DATABASES.setdefault("admin", settings.DATABASES["default"])
|
||||
|
||||
connection = MagicMock()
|
||||
cursor = MagicMock()
|
||||
cursor_context = MagicMock()
|
||||
cursor_context.__enter__.return_value = cursor
|
||||
cursor_context.__exit__.return_value = False
|
||||
connection.cursor.return_value = cursor_context
|
||||
connection.__enter__.return_value = connection
|
||||
connection.__exit__.return_value = False
|
||||
|
||||
context_manager = MagicMock()
|
||||
context_manager.__enter__.return_value = connection
|
||||
context_manager.__exit__.return_value = False
|
||||
mock_psycopg_connection.return_value = context_manager
|
||||
|
||||
captured = {}
|
||||
|
||||
def copy_side_effect(sql, file_obj):
|
||||
captured["sql"] = sql
|
||||
captured["data"] = file_obj.read()
|
||||
|
||||
cursor.copy_expert.side_effect = copy_side_effect
|
||||
|
||||
# Row with special characters that need escaping
|
||||
row = {
|
||||
"id": uuid.uuid4(),
|
||||
"tenant_id": str(uuid.uuid4()),
|
||||
"compliance_id": 'framework"with"quotes',
|
||||
"framework": "Framework,with,commas",
|
||||
"version": "1.0",
|
||||
"description": 'Description with "quotes", commas, and\nnewlines',
|
||||
"region": "us-east-1",
|
||||
"requirement_id": "req-1",
|
||||
"requirement_status": "PASS",
|
||||
"passed_checks": 1,
|
||||
"failed_checks": 0,
|
||||
"total_checks": 1,
|
||||
"scan_id": uuid.uuid4(),
|
||||
}
|
||||
|
||||
with patch.object(MainRouter, "admin_db", "admin"):
|
||||
_copy_compliance_requirement_rows(str(row["tenant_id"]), [row])
|
||||
|
||||
# Verify CSV was generated (csv module handles escaping automatically)
|
||||
csv_rows = list(csv.reader(StringIO(captured["data"])))
|
||||
assert len(csv_rows) == 1
|
||||
|
||||
# Verify special characters are preserved after CSV parsing
|
||||
assert csv_rows[0][3] == 'framework"with"quotes'
|
||||
assert csv_rows[0][4] == "Framework,with,commas"
|
||||
assert "quotes" in csv_rows[0][6]
|
||||
assert "commas" in csv_rows[0][6]
|
||||
|
||||
@patch("tasks.jobs.scan.psycopg_connection")
|
||||
def test_copy_compliance_requirement_rows_missing_inserted_at(
|
||||
self, mock_psycopg_connection, settings
|
||||
):
|
||||
"""Test COPY uses current datetime when inserted_at is missing."""
|
||||
settings.DATABASES.setdefault("admin", settings.DATABASES["default"])
|
||||
|
||||
connection = MagicMock()
|
||||
cursor = MagicMock()
|
||||
cursor_context = MagicMock()
|
||||
cursor_context.__enter__.return_value = cursor
|
||||
cursor_context.__exit__.return_value = False
|
||||
connection.cursor.return_value = cursor_context
|
||||
connection.__enter__.return_value = connection
|
||||
connection.__exit__.return_value = False
|
||||
|
||||
context_manager = MagicMock()
|
||||
context_manager.__enter__.return_value = connection
|
||||
context_manager.__exit__.return_value = False
|
||||
mock_psycopg_connection.return_value = context_manager
|
||||
|
||||
captured = {}
|
||||
|
||||
def copy_side_effect(sql, file_obj):
|
||||
captured["sql"] = sql
|
||||
captured["data"] = file_obj.read()
|
||||
|
||||
cursor.copy_expert.side_effect = copy_side_effect
|
||||
|
||||
# Row without inserted_at field
|
||||
row = {
|
||||
"id": uuid.uuid4(),
|
||||
"tenant_id": str(uuid.uuid4()),
|
||||
"compliance_id": "test_framework",
|
||||
"framework": "Test",
|
||||
"version": "1.0",
|
||||
"description": "desc",
|
||||
"region": "us-east-1",
|
||||
"requirement_id": "req-1",
|
||||
"requirement_status": "PASS",
|
||||
"passed_checks": 1,
|
||||
"failed_checks": 0,
|
||||
"total_checks": 1,
|
||||
"scan_id": uuid.uuid4(),
|
||||
# Note: inserted_at is intentionally missing
|
||||
}
|
||||
|
||||
before_call = datetime.now(timezone.utc)
|
||||
with patch.object(MainRouter, "admin_db", "admin"):
|
||||
_copy_compliance_requirement_rows(str(row["tenant_id"]), [row])
|
||||
after_call = datetime.now(timezone.utc)
|
||||
|
||||
csv_rows = list(csv.reader(StringIO(captured["data"])))
|
||||
assert len(csv_rows) == 1
|
||||
|
||||
# Verify inserted_at was auto-generated and is a valid ISO datetime
|
||||
inserted_at_str = csv_rows[0][2]
|
||||
inserted_at = datetime.fromisoformat(inserted_at_str)
|
||||
assert before_call <= inserted_at <= after_call
|
||||
|
||||
@patch("tasks.jobs.scan.psycopg_connection")
|
||||
def test_copy_compliance_requirement_rows_transaction_rollback_on_copy_error(
|
||||
self, mock_psycopg_connection, settings
|
||||
):
|
||||
"""Test transaction is rolled back when copy_expert fails."""
|
||||
settings.DATABASES.setdefault("admin", settings.DATABASES["default"])
|
||||
|
||||
connection = MagicMock()
|
||||
cursor = MagicMock()
|
||||
cursor_context = MagicMock()
|
||||
cursor_context.__enter__.return_value = cursor
|
||||
cursor_context.__exit__.return_value = False
|
||||
connection.cursor.return_value = cursor_context
|
||||
connection.__enter__.return_value = connection
|
||||
connection.__exit__.return_value = False
|
||||
|
||||
context_manager = MagicMock()
|
||||
context_manager.__enter__.return_value = connection
|
||||
context_manager.__exit__.return_value = False
|
||||
mock_psycopg_connection.return_value = context_manager
|
||||
|
||||
# Simulate copy_expert failure
|
||||
cursor.copy_expert.side_effect = Exception("COPY command failed")
|
||||
|
||||
row = {
|
||||
"id": uuid.uuid4(),
|
||||
"tenant_id": str(uuid.uuid4()),
|
||||
"compliance_id": "test",
|
||||
"framework": "Test",
|
||||
"version": "1.0",
|
||||
"description": "desc",
|
||||
"region": "us-east-1",
|
||||
"requirement_id": "req-1",
|
||||
"requirement_status": "PASS",
|
||||
"passed_checks": 1,
|
||||
"failed_checks": 0,
|
||||
"total_checks": 1,
|
||||
"scan_id": uuid.uuid4(),
|
||||
}
|
||||
|
||||
with patch.object(MainRouter, "admin_db", "admin"):
|
||||
with pytest.raises(Exception, match="COPY command failed"):
|
||||
_copy_compliance_requirement_rows(str(row["tenant_id"]), [row])
|
||||
|
||||
# Verify rollback was called
|
||||
connection.rollback.assert_called_once()
|
||||
connection.commit.assert_not_called()
|
||||
|
||||
@patch("tasks.jobs.scan.psycopg_connection")
|
||||
def test_copy_compliance_requirement_rows_transaction_rollback_on_set_config_error(
|
||||
self, mock_psycopg_connection, settings
|
||||
):
|
||||
"""Test transaction is rolled back when SET_CONFIG fails."""
|
||||
settings.DATABASES.setdefault("admin", settings.DATABASES["default"])
|
||||
|
||||
connection = MagicMock()
|
||||
cursor = MagicMock()
|
||||
cursor_context = MagicMock()
|
||||
cursor_context.__enter__.return_value = cursor
|
||||
cursor_context.__exit__.return_value = False
|
||||
connection.cursor.return_value = cursor_context
|
||||
connection.__enter__.return_value = connection
|
||||
connection.__exit__.return_value = False
|
||||
|
||||
context_manager = MagicMock()
|
||||
context_manager.__enter__.return_value = connection
|
||||
context_manager.__exit__.return_value = False
|
||||
mock_psycopg_connection.return_value = context_manager
|
||||
|
||||
# Simulate cursor.execute failure
|
||||
cursor.execute.side_effect = Exception("SET prowler.tenant_id failed")
|
||||
|
||||
row = {
|
||||
"id": uuid.uuid4(),
|
||||
"tenant_id": str(uuid.uuid4()),
|
||||
"compliance_id": "test",
|
||||
"framework": "Test",
|
||||
"version": "1.0",
|
||||
"description": "desc",
|
||||
"region": "us-east-1",
|
||||
"requirement_id": "req-1",
|
||||
"requirement_status": "PASS",
|
||||
"passed_checks": 1,
|
||||
"failed_checks": 0,
|
||||
"total_checks": 1,
|
||||
"scan_id": uuid.uuid4(),
|
||||
}
|
||||
|
||||
with patch.object(MainRouter, "admin_db", "admin"):
|
||||
with pytest.raises(Exception, match="SET prowler.tenant_id failed"):
|
||||
_copy_compliance_requirement_rows(str(row["tenant_id"]), [row])
|
||||
|
||||
# Verify rollback was called
|
||||
connection.rollback.assert_called_once()
|
||||
connection.commit.assert_not_called()
|
||||
|
||||
@patch("tasks.jobs.scan.psycopg_connection")
|
||||
def test_copy_compliance_requirement_rows_commit_on_success(
|
||||
self, mock_psycopg_connection, settings
|
||||
):
|
||||
"""Test transaction is committed on successful COPY."""
|
||||
settings.DATABASES.setdefault("admin", settings.DATABASES["default"])
|
||||
|
||||
connection = MagicMock()
|
||||
cursor = MagicMock()
|
||||
cursor_context = MagicMock()
|
||||
cursor_context.__enter__.return_value = cursor
|
||||
cursor_context.__exit__.return_value = False
|
||||
connection.cursor.return_value = cursor_context
|
||||
connection.__enter__.return_value = connection
|
||||
connection.__exit__.return_value = False
|
||||
|
||||
context_manager = MagicMock()
|
||||
context_manager.__enter__.return_value = connection
|
||||
context_manager.__exit__.return_value = False
|
||||
mock_psycopg_connection.return_value = context_manager
|
||||
|
||||
cursor.copy_expert.return_value = None # Success
|
||||
|
||||
row = {
|
||||
"id": uuid.uuid4(),
|
||||
"tenant_id": str(uuid.uuid4()),
|
||||
"compliance_id": "test",
|
||||
"framework": "Test",
|
||||
"version": "1.0",
|
||||
"description": "desc",
|
||||
"region": "us-east-1",
|
||||
"requirement_id": "req-1",
|
||||
"requirement_status": "PASS",
|
||||
"passed_checks": 1,
|
||||
"failed_checks": 0,
|
||||
"total_checks": 1,
|
||||
"scan_id": uuid.uuid4(),
|
||||
}
|
||||
|
||||
with patch.object(MainRouter, "admin_db", "admin"):
|
||||
_copy_compliance_requirement_rows(str(row["tenant_id"]), [row])
|
||||
|
||||
# Verify commit was called and rollback was not
|
||||
connection.commit.assert_called_once()
|
||||
connection.rollback.assert_not_called()
|
||||
# Verify autocommit was disabled
|
||||
assert connection.autocommit is False
|
||||
|
||||
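    # Illustrative sketch (assumption, not the real implementation): the three
    # tests above pin down the transactional contract of
    # _copy_compliance_requirement_rows -- autocommit disabled, the tenant set
    # for RLS before COPY, commit only on success, rollback on any failure.
    # The hypothetical helper below shows one way to satisfy that contract with
    # a psycopg2-style connection; the psycopg_connection helper is injected,
    # and the table/column names are placeholders.
    @staticmethod
    def _copy_rows_sketch(tenant_id, rows, columns, psycopg_connection, table):
        # Serialise the rows to an in-memory CSV buffer for COPY ... FROM STDIN
        buffer = StringIO()
        writer = csv.writer(buffer)
        for row in rows:
            writer.writerow([row[column] for column in columns])
        buffer.seek(0)

        with psycopg_connection(MainRouter.admin_db) as connection:
            connection.autocommit = False  # explicit transaction, as asserted above
            try:
                with connection.cursor() as cursor:
                    # Scope row-level security to the tenant before bulk loading
                    cursor.execute(
                        "SELECT set_config('prowler.tenant_id', %s, false)",
                        [tenant_id],
                    )
                    cursor.copy_expert(
                        f"COPY {table} ({', '.join(columns)}) FROM STDIN WITH CSV",
                        buffer,
                    )
                connection.commit()
            except Exception:
                connection.rollback()
                raise
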
@patch("tasks.jobs.scan._copy_compliance_requirement_rows")
|
||||
def test_persist_compliance_requirement_rows_success(self, mock_copy):
|
||||
"""Test successful COPY path without fallback to ORM."""
|
||||
mock_copy.return_value = None # Success, no exception
|
||||
|
||||
tenant_id = str(uuid.uuid4())
|
||||
rows = [
|
||||
{
|
||||
"id": uuid.uuid4(),
|
||||
"tenant_id": tenant_id,
|
||||
"inserted_at": datetime.now(timezone.utc),
|
||||
"compliance_id": "test",
|
||||
"framework": "Test",
|
||||
"version": "1.0",
|
||||
"description": "desc",
|
||||
"region": "us-east-1",
|
||||
"requirement_id": "req-1",
|
||||
"requirement_status": "PASS",
|
||||
"passed_checks": 1,
|
||||
"failed_checks": 0,
|
||||
"total_checks": 1,
|
||||
"scan_id": uuid.uuid4(),
|
||||
}
|
||||
]
|
||||
|
||||
_persist_compliance_requirement_rows(tenant_id, rows)
|
||||
|
||||
# Verify COPY was called
|
||||
mock_copy.assert_called_once_with(tenant_id, rows)
|
||||
|
||||
@patch("tasks.jobs.scan.logger")
|
||||
@patch("tasks.jobs.scan.ComplianceRequirementOverview.objects.bulk_create")
|
||||
@patch("tasks.jobs.scan.rls_transaction")
|
||||
@patch(
|
||||
"tasks.jobs.scan._copy_compliance_requirement_rows",
|
||||
side_effect=Exception("COPY failed"),
|
||||
)
|
||||
def test_persist_compliance_requirement_rows_fallback_logging(
|
||||
self, mock_copy, mock_rls_transaction, mock_bulk_create, mock_logger
|
||||
):
|
||||
"""Test logger.exception is called when COPY fails and fallback occurs."""
|
||||
tenant_id = str(uuid.uuid4())
|
||||
row = {
|
||||
"id": uuid.uuid4(),
|
||||
"tenant_id": tenant_id,
|
||||
"inserted_at": datetime.now(timezone.utc),
|
||||
"compliance_id": "test",
|
||||
"framework": "Test",
|
||||
"version": "1.0",
|
||||
"description": "desc",
|
||||
"region": "us-east-1",
|
||||
"requirement_id": "req-1",
|
||||
"requirement_status": "PASS",
|
||||
"passed_checks": 1,
|
||||
"failed_checks": 0,
|
||||
"total_checks": 1,
|
||||
"scan_id": uuid.uuid4(),
|
||||
}
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
ctx.__exit__.return_value = False
|
||||
mock_rls_transaction.return_value = ctx
|
||||
|
||||
_persist_compliance_requirement_rows(tenant_id, [row])
|
||||
|
||||
# Verify logger.exception was called
|
||||
mock_logger.exception.assert_called_once()
|
||||
args, kwargs = mock_logger.exception.call_args
|
||||
assert "COPY bulk insert" in args[0]
|
||||
assert "falling back to ORM" in args[0]
|
||||
assert kwargs.get("exc_info") is not None
|
||||
|
||||
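    # Illustrative sketch (assumption, not the real implementation): the
    # fallback behaviour asserted above and in the following tests -- try the
    # fast COPY path first; if it raises, log with exc_info and fall back to an
    # RLS-scoped ORM bulk_create with batch_size=500. Everything the sketch
    # needs (copy_rows, model, rls_transaction, log) is injected so it stays
    # self-contained; the names are hypothetical.
    @staticmethod
    def _persist_rows_sketch(tenant_id, rows, copy_rows, model, rls_transaction, log):
        try:
            copy_rows(tenant_id, rows)  # fast path: COPY bulk insert
        except Exception as error:
            # Surface the COPY failure, then degrade gracefully to the ORM
            log.exception(
                "COPY bulk insert failed, falling back to ORM bulk_create",
                exc_info=error,
            )
            with rls_transaction(tenant_id):
                model.objects.bulk_create(
                    [model(**row) for row in rows],
                    batch_size=500,
                )
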
@patch("tasks.jobs.scan.ComplianceRequirementOverview.objects.bulk_create")
|
||||
@patch("tasks.jobs.scan.rls_transaction")
|
||||
@patch(
|
||||
"tasks.jobs.scan._copy_compliance_requirement_rows",
|
||||
side_effect=Exception("copy failed"),
|
||||
)
|
||||
def test_persist_compliance_requirement_rows_fallback_multiple_rows(
|
||||
self, mock_copy, mock_rls_transaction, mock_bulk_create
|
||||
):
|
||||
"""Test ORM fallback with multiple rows."""
|
||||
tenant_id = str(uuid.uuid4())
|
||||
scan_id = uuid.uuid4()
|
||||
inserted_at = datetime.now(timezone.utc)
|
||||
|
||||
rows = [
|
||||
{
|
||||
"id": uuid.uuid4(),
|
||||
"tenant_id": tenant_id,
|
||||
"inserted_at": inserted_at,
|
||||
"compliance_id": "cisa_aws",
|
||||
"framework": "CISA",
|
||||
"version": "1.0",
|
||||
"description": "First requirement",
|
||||
"region": "us-east-1",
|
||||
"requirement_id": "req-1",
|
||||
"requirement_status": "PASS",
|
||||
"passed_checks": 5,
|
||||
"failed_checks": 0,
|
||||
"total_checks": 5,
|
||||
"scan_id": scan_id,
|
||||
},
|
||||
{
|
||||
"id": uuid.uuid4(),
|
||||
"tenant_id": tenant_id,
|
||||
"inserted_at": inserted_at,
|
||||
"compliance_id": "cisa_aws",
|
||||
"framework": "CISA",
|
||||
"version": "1.0",
|
||||
"description": "Second requirement",
|
||||
"region": "us-west-2",
|
||||
"requirement_id": "req-2",
|
||||
"requirement_status": "FAIL",
|
||||
"passed_checks": 2,
|
||||
"failed_checks": 3,
|
||||
"total_checks": 5,
|
||||
"scan_id": scan_id,
|
||||
},
|
||||
]
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
ctx.__exit__.return_value = False
|
||||
mock_rls_transaction.return_value = ctx
|
||||
|
||||
_persist_compliance_requirement_rows(tenant_id, rows)
|
||||
|
||||
mock_copy.assert_called_once_with(tenant_id, rows)
|
||||
mock_rls_transaction.assert_called_once_with(tenant_id)
|
||||
mock_bulk_create.assert_called_once()
|
||||
|
||||
args, kwargs = mock_bulk_create.call_args
|
||||
objects = args[0]
|
||||
assert len(objects) == 2
|
||||
assert kwargs["batch_size"] == 500
|
||||
|
||||
# Validate first object
|
||||
assert objects[0].id == rows[0]["id"]
|
||||
assert objects[0].tenant_id == rows[0]["tenant_id"]
|
||||
assert objects[0].compliance_id == rows[0]["compliance_id"]
|
||||
assert objects[0].framework == rows[0]["framework"]
|
||||
assert objects[0].region == rows[0]["region"]
|
||||
assert objects[0].passed_checks == 5
|
||||
assert objects[0].failed_checks == 0
|
||||
|
||||
# Validate second object
|
||||
assert objects[1].id == rows[1]["id"]
|
||||
assert objects[1].requirement_id == rows[1]["requirement_id"]
|
||||
assert objects[1].requirement_status == rows[1]["requirement_status"]
|
||||
assert objects[1].passed_checks == 2
|
||||
assert objects[1].failed_checks == 3
|
||||
|
||||
@patch("tasks.jobs.scan.ComplianceRequirementOverview.objects.bulk_create")
|
||||
@patch("tasks.jobs.scan.rls_transaction")
|
||||
@patch(
|
||||
"tasks.jobs.scan._copy_compliance_requirement_rows",
|
||||
side_effect=Exception("copy failed"),
|
||||
)
|
||||
def test_persist_compliance_requirement_rows_fallback_all_fields(
|
||||
self, mock_copy, mock_rls_transaction, mock_bulk_create
|
||||
):
|
||||
"""Test ORM fallback correctly maps all fields from row dict to model."""
|
||||
tenant_id = str(uuid.uuid4())
|
||||
row_id = uuid.uuid4()
|
||||
scan_id = uuid.uuid4()
|
||||
inserted_at = datetime.now(timezone.utc)
|
||||
|
||||
row = {
|
||||
"id": row_id,
|
||||
"tenant_id": tenant_id,
|
||||
"inserted_at": inserted_at,
|
||||
"compliance_id": "aws_foundational_security_aws",
|
||||
"framework": "AWS-Foundational-Security-Best-Practices",
|
||||
"version": "2.0",
|
||||
"description": "Ensure MFA is enabled",
|
||||
"region": "eu-west-1",
|
||||
"requirement_id": "iam.1",
|
||||
"requirement_status": "FAIL",
|
||||
"passed_checks": 10,
|
||||
"failed_checks": 5,
|
||||
"total_checks": 15,
|
||||
"scan_id": scan_id,
|
||||
}
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
ctx.__exit__.return_value = False
|
||||
mock_rls_transaction.return_value = ctx
|
||||
|
||||
_persist_compliance_requirement_rows(tenant_id, [row])
|
||||
|
||||
args, kwargs = mock_bulk_create.call_args
|
||||
objects = args[0]
|
||||
assert len(objects) == 1
|
||||
|
||||
obj = objects[0]
|
||||
# Validate ALL fields are correctly mapped
|
||||
assert obj.id == row_id
|
||||
assert obj.tenant_id == tenant_id
|
||||
assert obj.inserted_at == inserted_at
|
||||
assert obj.compliance_id == "aws_foundational_security_aws"
|
||||
assert obj.framework == "AWS-Foundational-Security-Best-Practices"
|
||||
assert obj.version == "2.0"
|
||||
assert obj.description == "Ensure MFA is enabled"
|
||||
assert obj.region == "eu-west-1"
|
||||
assert obj.requirement_id == "iam.1"
|
||||
assert obj.requirement_status == "FAIL"
|
||||
assert obj.passed_checks == 10
|
||||
assert obj.failed_checks == 5
|
||||
assert obj.total_checks == 15
|
||||
assert obj.scan_id == scan_id
|
||||
|
||||