chore: unify ruff tooling and route code quality through the Makefile (#11675)

This commit is contained in:
Rubén De la Torre Vico
2026-06-23 17:15:05 +02:00
committed by GitHub
parent de7da3e960
commit 058a1dc8fe
194 changed files with 1112 additions and 1210 deletions
+19 -2
View File
@@ -14,7 +14,7 @@ dev = [
"pytest-env==1.1.3",
"pytest-randomly==3.15.0",
"pytest-xdist==3.6.1",
"ruff==0.5.0",
"ruff==0.15.11",
"tqdm==4.67.1",
"vulture==2.14",
"prek==0.3.9"
@@ -73,6 +73,23 @@ package-mode = false
requires-python = ">=3.11,<3.13"
version = "1.33.0"
# Shared ruff baseline (kept in sync with mcp_server/pyproject.toml).
# target-version tracks this project's lowest supported Python.
[tool.ruff]
src = ["src"]
target-version = "py311"
[tool.ruff.lint]
# Defaults (E4/E7/E9, F) plus import sorting, modern-syntax upgrades, and
# comprehension lints — all mechanically auto-fixable. flake8-bugbear (B) is a
# good next step but needs manual cleanup (e.g. B904 raise-from), so it is left
# out of the shared baseline for now.
extend-select = [
"I", # isort — import ordering (prek's isort hook covers only the SDK)
"UP", # pyupgrade — modern syntax for the min supported Python
"C4" # flake8-comprehensions
]
[tool.uv]
# Transitive pins matching master to avoid silent drift; bump deliberately.
constraint-dependencies = [
@@ -393,7 +410,7 @@ constraint-dependencies = [
"rpds-py==0.30.0",
"rsa==4.9.1",
"ruamel-yaml==0.19.1",
"ruff==0.5.0",
"ruff==0.15.11",
"s3transfer==0.14.0",
"scaleway==2.10.3",
"scaleway-core==2.10.3",
+1 -2
View File
@@ -1,6 +1,4 @@
from allauth.socialaccount.adapter import DefaultSocialAccountAdapter
from django.db import transaction
from api.db_router import MainRouter
from api.db_utils import rls_transaction
from api.models import (
@@ -11,6 +9,7 @@ from api.models import (
User,
UserRoleRelationship,
)
from django.db import transaction
class ProwlerSocialAccountAdapter(DefaultSocialAccountAdapter):
+6 -6
View File
@@ -1,14 +1,12 @@
import logging
import os
import sys
from pathlib import Path
from django.apps import AppConfig
from django.conf import settings
from config.custom_logging import BackendLogger
from config.env import env
from django.apps import AppConfig
from django.conf import settings
logger = logging.getLogger(BackendLogger.API)
@@ -30,8 +28,10 @@ class ApiConfig(AppConfig):
name = "api"
def ready(self):
from api import schema_extensions # noqa: F401
from api import signals # noqa: F401
from api import (
schema_extensions, # noqa: F401
signals, # noqa: F401
)
# Generate required cryptographic keys if not present, but only if:
# `"manage.py" not in sys.argv[0]`: If an external server (e.g., Gunicorn) is running the app
@@ -5,7 +5,6 @@ from api.attack_paths.queries import (
get_query_by_id,
)
__all__ = [
"AttackPathsQueryDefinition",
"AttackPathsQueryParameterDefinition",
@@ -22,10 +22,8 @@ Label-injection pipeline:
import re
from rest_framework.exceptions import ValidationError
from tasks.jobs.attack_paths.config import get_provider_label
# Step 1 - String / comment protection
# Single combined regex: strings first, then line comments.
# The regex engine finds the leftmost match, so a string like 'https://prowler.com'
+3 -5
View File
@@ -1,18 +1,16 @@
import atexit
import logging
import threading
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any, Iterator
from typing import Any
from uuid import UUID
import neo4j
import neo4j.exceptions
from api.attack_paths.retryable_session import RetryableSession
from config.env import env
from django.conf import settings
from api.attack_paths.retryable_session import RetryableSession
from tasks.jobs.attack_paths.config import (
BATCH_SIZE,
PROVIDER_RESOURCE_LABEL,
@@ -1,12 +1,11 @@
from api.attack_paths.queries.types import (
AttackPathsQueryDefinition,
AttackPathsQueryParameterDefinition,
)
from api.attack_paths.queries.registry import (
get_queries_for_provider,
get_query_by_id,
)
from api.attack_paths.queries.types import (
AttackPathsQueryDefinition,
AttackPathsQueryParameterDefinition,
)
__all__ = [
"AttackPathsQueryDefinition",
@@ -5,7 +5,6 @@ from api.attack_paths.queries.types import (
)
from tasks.jobs.attack_paths.config import PROWLER_FINDING_LABEL
# Custom Attack Path Queries
# --------------------------
@@ -1,6 +1,5 @@
from api.attack_paths.queries.types import AttackPathsQueryDefinition
from api.attack_paths.queries.aws import AWS_QUERIES
from api.attack_paths.queries.types import AttackPathsQueryDefinition
# Query definitions organized by provider
_QUERY_DEFINITIONS: dict[str, list[AttackPathsQueryDefinition]] = {
@@ -1,5 +1,4 @@
import logging
from collections.abc import Callable
from typing import Any
@@ -1,12 +1,10 @@
import logging
from typing import Any, Iterable
from collections.abc import Iterable
from typing import Any
import neo4j
from rest_framework.exceptions import APIException, PermissionDenied, ValidationError
from api.attack_paths import database as graph_database, AttackPathsQueryDefinition
from api.attack_paths import AttackPathsQueryDefinition
from api.attack_paths import database as graph_database
from api.attack_paths.cypher_sanitizer import (
inject_provider_label,
validate_custom_query,
@@ -17,6 +15,7 @@ from api.attack_paths.queries.schema import (
get_cartography_schema_query,
)
from config.custom_logging import BackendLogger
from rest_framework.exceptions import APIException, PermissionDenied, ValidationError
from tasks.jobs.attack_paths.config import (
INTERNAL_LABELS,
INTERNAL_PROPERTIES,
+3 -5
View File
@@ -1,6 +1,7 @@
from typing import Optional, Tuple
from uuid import UUID
from api.db_router import MainRouter
from api.models import TenantAPIKey, TenantAPIKeyManager
from cryptography.fernet import InvalidToken
from django.utils import timezone
from drf_simple_apikey.backends import APIKeyAuthentication as BaseAPIKeyAuth
@@ -10,9 +11,6 @@ from rest_framework.exceptions import AuthenticationFailed
from rest_framework.request import Request
from rest_framework_simplejwt.authentication import JWTAuthentication
from api.db_router import MainRouter
from api.models import TenantAPIKey, TenantAPIKeyManager
class TenantAPIKeyAuthentication(BaseAPIKeyAuth):
model = TenantAPIKey
@@ -81,7 +79,7 @@ class CombinedJWTOrAPIKeyAuthentication(BaseAuthentication):
jwt_auth = JWTAuthentication()
api_key_auth = TenantAPIKeyAuthentication()
def authenticate(self, request: Request) -> Optional[Tuple[object, dict]]:
def authenticate(self, request: Request) -> tuple[object, dict] | None:
auth_header = request.headers.get("Authorization", "")
# Prioritize JWT authentication if both are present
+6 -7
View File
@@ -1,3 +1,9 @@
from api.authentication import CombinedJWTOrAPIKeyAuthentication
from api.db_router import MainRouter, reset_read_db_alias, set_read_db_alias
from api.db_utils import POSTGRES_USER_VAR, rls_transaction
from api.filters import CustomDjangoFilterBackend
from api.models import Role, UserRoleRelationship
from api.rbac.permissions import HasPermissions
from django.conf import settings
from django.db import transaction
from rest_framework import permissions
@@ -8,13 +14,6 @@ from rest_framework.response import Response
from rest_framework_json_api import filters
from rest_framework_json_api.views import ModelViewSet
from api.authentication import CombinedJWTOrAPIKeyAuthentication
from api.db_router import MainRouter, reset_read_db_alias, set_read_db_alias
from api.db_utils import POSTGRES_USER_VAR, rls_transaction
from api.filters import CustomDjangoFilterBackend
from api.models import Role, UserRoleRelationship
from api.rbac.permissions import HasPermissions
class BaseViewSet(ModelViewSet):
authentication_classes = [CombinedJWTOrAPIKeyAuthentication]
+1 -1
View File
@@ -352,7 +352,7 @@ def generate_compliance_overview_template(
total_requirements += 1
provider_check_list = list(requirement.checks.get(provider_type, []))
total_checks = len(provider_check_list)
checks_dict = {check: None for check in provider_check_list}
checks_dict = dict.fromkeys(provider_check_list)
req_status_val = "MANUAL" if total_checks == 0 else "PASS"
+10 -11
View File
@@ -3,8 +3,14 @@ import secrets
import time
import uuid
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
from datetime import UTC, datetime, timedelta
from api.db_router import (
READ_REPLICA_ALIAS,
get_read_db_alias,
reset_read_db_alias,
set_read_db_alias,
)
from celery.utils.log import get_task_logger
from config.env import env
from django.conf import settings
@@ -22,13 +28,6 @@ from psycopg2 import sql as psycopg2_sql
from psycopg2.extensions import AsIs, new_type, register_adapter, register_type
from rest_framework_json_api.serializers import ValidationError
from api.db_router import (
READ_REPLICA_ALIAS,
get_read_db_alias,
reset_read_db_alias,
set_read_db_alias,
)
logger = get_task_logger(__name__)
DB_USER = settings.DATABASES["default"]["USER"] if not settings.TESTING else "test"
@@ -170,7 +169,7 @@ def one_week_from_now():
"""
Return a datetime object with a date one week from now.
"""
return datetime.now(timezone.utc) + timedelta(days=7)
return datetime.now(UTC) + timedelta(days=7)
def generate_random_token(length: int = 14, symbols: str | None = None) -> str:
@@ -405,10 +404,10 @@ def _should_create_index_on_partition(
# Unknown month abbreviation, include it to be safe
return True
partition_date = datetime(year, month, 1, tzinfo=timezone.utc)
partition_date = datetime(year, month, 1, tzinfo=UTC)
# Get current month start
now = datetime.now(timezone.utc)
now = datetime.now(UTC)
current_month_start = now.replace(
day=1, hour=0, minute=0, second=0, microsecond=0
)
+3 -4
View File
@@ -1,14 +1,13 @@
import uuid
from functools import wraps
from django.core.exceptions import ObjectDoesNotExist
from django.db import DatabaseError, connection, transaction
from rest_framework_json_api.serializers import ValidationError
from api.db_router import READ_REPLICA_ALIAS
from api.db_utils import POSTGRES_TENANT_VAR, SET_CONFIG_QUERY, rls_transaction
from api.exceptions import ProviderDeletedException
from api.models import Provider, Scan
from django.core.exceptions import ObjectDoesNotExist
from django.db import DatabaseError, connection, transaction
from rest_framework_json_api.serializers import ValidationError
def set_tenant(func=None, *, keep_tenant=False):
+26 -27
View File
@@ -1,19 +1,4 @@
from datetime import date, datetime, timedelta, timezone
from dateutil.parser import parse
from django.conf import settings
from django.db.models import F, Q
from django_filters.rest_framework import (
BaseInFilter,
BooleanFilter,
CharFilter,
ChoiceFilter,
DateFilter,
FilterSet,
UUIDFilter,
)
from rest_framework_json_api.django_filters.backends import DjangoFilterBackend
from rest_framework_json_api.serializers import ValidationError
from datetime import UTC, date, datetime, timedelta
from api.constants import SEVERITY_ORDER
from api.db_utils import (
@@ -68,6 +53,20 @@ from api.uuid_utils import (
uuid7_start,
)
from api.v1.serializers import TaskBase
from dateutil.parser import parse
from django.conf import settings
from django.db.models import F, Q
from django_filters.rest_framework import (
BaseInFilter,
BooleanFilter,
CharFilter,
ChoiceFilter,
DateFilter,
FilterSet,
UUIDFilter,
)
from rest_framework_json_api.django_filters.backends import DjangoFilterBackend
from rest_framework_json_api.serializers import ValidationError
class CustomDjangoFilterBackend(DjangoFilterBackend):
@@ -598,12 +597,12 @@ class ResourceFilter(ProviderRelationshipFilterSet):
gte_date = (
parse(self.data.get("updated_at__gte")).date()
if self.data.get("updated_at__gte")
else datetime.now(timezone.utc).date()
else datetime.now(UTC).date()
)
lte_date = (
parse(self.data.get("updated_at__lte")).date()
if self.data.get("updated_at__lte")
else datetime.now(timezone.utc).date()
else datetime.now(UTC).date()
)
if abs(lte_date - gte_date) > timedelta(
@@ -748,9 +747,9 @@ class FindingFilter(CommonFindingFilters):
lte_date = cleaned.get("inserted_at__lte") or exact_date
if gte_date is None:
gte_date = datetime.now(timezone.utc).date()
gte_date = datetime.now(UTC).date()
if lte_date is None:
lte_date = datetime.now(timezone.utc).date()
lte_date = datetime.now(UTC).date()
if abs(lte_date - gte_date) > timedelta(
days=settings.FINDINGS_MAX_DAYS_IN_RANGE
@@ -844,7 +843,7 @@ class FindingFilter(CommonFindingFilters):
def maybe_date_to_datetime(value):
dt = value
if isinstance(value, date):
dt = datetime.combine(value, datetime.min.time(), tzinfo=timezone.utc)
dt = datetime.combine(value, datetime.min.time(), tzinfo=UTC)
return dt
@@ -933,9 +932,9 @@ class FindingGroupFilter(CommonFindingFilters):
lte_date = cleaned.get("inserted_at__lte") or exact_date
if gte_date is None:
gte_date = datetime.now(timezone.utc).date()
gte_date = datetime.now(UTC).date()
if lte_date is None:
lte_date = datetime.now(timezone.utc).date()
lte_date = datetime.now(UTC).date()
if abs(lte_date - gte_date) > timedelta(
days=settings.FINDINGS_MAX_DAYS_IN_RANGE
@@ -977,7 +976,7 @@ class FindingGroupFilter(CommonFindingFilters):
"""Convert date to datetime if needed."""
dt = value
if isinstance(value, date):
dt = datetime.combine(value, datetime.min.time(), tzinfo=timezone.utc)
dt = datetime.combine(value, datetime.min.time(), tzinfo=UTC)
return dt
@@ -1091,9 +1090,9 @@ class FindingGroupSummaryFilter(_CheckTitleToCheckIdMixin, FilterSet):
lte_date = cleaned.get("inserted_at__lte") or exact_date
if gte_date is None:
gte_date = datetime.now(timezone.utc).date()
gte_date = datetime.now(UTC).date()
if lte_date is None:
lte_date = datetime.now(timezone.utc).date()
lte_date = datetime.now(UTC).date()
if abs(lte_date - gte_date) > timedelta(
days=settings.FINDINGS_MAX_DAYS_IN_RANGE
@@ -1132,7 +1131,7 @@ class FindingGroupSummaryFilter(_CheckTitleToCheckIdMixin, FilterSet):
def _maybe_date_to_datetime(value):
dt = value
if isinstance(value, date):
dt = datetime.combine(value, datetime.min.time(), tzinfo=timezone.utc)
dt = datetime.combine(value, datetime.min.time(), tzinfo=UTC)
return dt
+2 -6
View File
@@ -12,7 +12,7 @@ import logging
import threading
import time
from contextlib import suppress
from datetime import datetime, timezone
from datetime import UTC, datetime
from typing import Any
import redis
@@ -62,11 +62,7 @@ class HealthJSONRenderer(JSONRenderer):
def _now_iso() -> str:
return (
datetime.now(timezone.utc)
.isoformat(timespec="milliseconds")
.replace("+00:00", "Z")
)
return datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z")
def _measure(name: str, check_fn) -> tuple[dict[str, Any], float]:
@@ -1,11 +1,8 @@
import random
from datetime import datetime, timezone
from datetime import UTC, datetime
from math import ceil
from uuid import uuid4
from django.core.management.base import BaseCommand
from tqdm import tqdm
from api.db_utils import rls_transaction
from api.models import (
Finding,
@@ -16,7 +13,9 @@ from api.models import (
Scan,
StatusChoices,
)
from django.core.management.base import BaseCommand
from prowler.lib.check.models import CheckMetadata
from tqdm import tqdm
class Command(BaseCommand):
@@ -116,7 +115,7 @@ class Command(BaseCommand):
trigger="manual",
state="executing",
progress=0,
started_at=datetime.now(timezone.utc),
started_at=datetime.now(UTC),
)
scan_state = "completed"
@@ -272,10 +271,8 @@ class Command(BaseCommand):
self.stdout.write(self.style.ERROR(f"Failed to populate test data: {e}"))
scan_state = "failed"
finally:
scan.completed_at = datetime.now(timezone.utc)
scan.duration = int(
(datetime.now(timezone.utc) - scan.started_at).total_seconds()
)
scan.completed_at = datetime.now(UTC)
scan.duration = int((datetime.now(UTC) - scan.started_at).total_seconds())
scan.progress = 100
scan.state = scan_state
scan.unique_resource_count = num_resources
@@ -1,5 +1,4 @@
from django.core.management.base import BaseCommand
from tasks.jobs.orphan_recovery import reconcile_orphans
+1 -2
View File
@@ -1,11 +1,10 @@
import logging
import time
from config.custom_logging import BackendLogger
from django.core.handlers.asgi import ASGIRequest
from django.db import connections
from config.custom_logging import BackendLogger
class CloseDBConnectionsMiddleware:
"""
+13 -14
View File
@@ -1,26 +1,13 @@
import uuid
from functools import partial
import api.rls
import django.contrib.auth.models
import django.contrib.postgres.indexes
import django.contrib.postgres.search
import django.core.validators
import django.db.models.deletion
import django.utils.timezone
from django.conf import settings
from django.db import migrations, models
from psqlextra.backend.migrations.operations.add_default_partition import (
PostgresAddDefaultPartition,
)
from psqlextra.backend.migrations.operations.create_partitioned_model import (
PostgresCreatePartitionedModel,
)
from psqlextra.manager.manager import PostgresManager
from psqlextra.models.partitioned import PostgresPartitionedModel
from psqlextra.types import PostgresPartitioningMethod
from uuid6 import uuid7
import api.rls
from api.db_utils import (
DB_PROWLER_PASSWORD,
DB_PROWLER_USER,
@@ -53,6 +40,18 @@ from api.models import (
StateChoices,
StatusChoices,
)
from django.conf import settings
from django.db import migrations, models
from psqlextra.backend.migrations.operations.add_default_partition import (
PostgresAddDefaultPartition,
)
from psqlextra.backend.migrations.operations.create_partitioned_model import (
PostgresCreatePartitionedModel,
)
from psqlextra.manager.manager import PostgresManager
from psqlextra.models.partitioned import PostgresPartitionedModel
from psqlextra.types import PostgresPartitioningMethod
from uuid6 import uuid7
DB_NAME = settings.DATABASES["default"]["NAME"]
@@ -1,8 +1,7 @@
from api.db_utils import DB_PROWLER_USER
from django.conf import settings
from django.db import migrations
from api.db_utils import DB_PROWLER_USER
DB_NAME = settings.DATABASES["default"]["NAME"]
+1 -2
View File
@@ -2,12 +2,11 @@
import uuid
import api.rls
import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
import api.rls
class Migration(migrations.Migration):
dependencies = [
@@ -1,6 +1,5 @@
from django.db import migrations
from api.db_router import MainRouter
from django.db import migrations
def create_admin_role(apps, schema_editor):
@@ -1,12 +1,11 @@
import json
from datetime import datetime, timedelta, timezone
from datetime import UTC, datetime, timedelta
import django.db.models.deletion
from django.db import migrations, models
from django_celery_beat.models import PeriodicTask
from api.db_utils import rls_transaction
from api.models import Scan, StateChoices
from django.db import migrations, models
from django_celery_beat.models import PeriodicTask
def migrate_daily_scheduled_scan_tasks(apps, schema_editor):
@@ -17,11 +16,11 @@ def migrate_daily_scheduled_scan_tasks(apps, schema_editor):
tenant_id = task_kwargs["tenant_id"]
provider_id = task_kwargs["provider_id"]
current_time = datetime.now(timezone.utc)
current_time = datetime.now(UTC)
scheduled_time_today = datetime.combine(
current_time.date(),
daily_scheduled_scan_task.start_time.time(),
tzinfo=timezone.utc,
tzinfo=UTC,
)
if current_time < scheduled_time_today:
@@ -2,10 +2,9 @@
from functools import partial
from django.db import migrations
from api.db_utils import IntegrationTypeEnum, PostgresEnumMigration, register_enum
from api.models import Integration
from django.db import migrations
IntegrationTypeEnumMigration = PostgresEnumMigration(
enum_name="integration_type",
@@ -2,12 +2,11 @@
import uuid
import django.db.models.deletion
from django.db import migrations, models
import api.db_utils
import api.rls
import django.db.models.deletion
from api.rls import RowLevelSecurityConstraint
from django.db import migrations, models
class Migration(migrations.Migration):
@@ -1,8 +1,7 @@
# Generated by Django 5.1.5 on 2025-03-25 11:29
from django.db import migrations, models
import api.db_utils
from django.db import migrations, models
class Migration(migrations.Migration):
@@ -1,8 +1,7 @@
# Generated by Django 5.1.7 on 2025-04-16 08:47
from django.db import migrations
import api.db_utils
from django.db import migrations
class Migration(migrations.Migration):
@@ -2,12 +2,11 @@
import uuid
import api.rls
import django.db.models.deletion
import uuid6
from django.db import migrations, models
import api.rls
class Migration(migrations.Migration):
dependencies = [
@@ -1,8 +1,7 @@
from functools import partial
from django.db import migrations
from api.db_utils import create_index_on_partitions, drop_index_on_partitions
from django.db import migrations
class Migration(migrations.Migration):
@@ -1,8 +1,7 @@
from functools import partial
from django.db import migrations
from api.db_utils import create_index_on_partitions, drop_index_on_partitions
from django.db import migrations
class Migration(migrations.Migration):
@@ -2,12 +2,11 @@
import uuid
import django.db.models.deletion
from django.db import migrations, models
import api.db_utils
import api.rls
import django.db.models.deletion
from api.rls import RowLevelSecurityConstraint
from django.db import migrations, models
class Migration(migrations.Migration):
@@ -1,8 +1,7 @@
from functools import partial
from django.db import migrations
from api.db_utils import create_index_on_partitions, drop_index_on_partitions
from django.db import migrations
class Migration(migrations.Migration):
@@ -2,12 +2,11 @@
import uuid
import api.rls
import django.core.validators
import django.db.models.deletion
from django.db import migrations, models
import api.rls
class Migration(migrations.Migration):
dependencies = [
+2 -3
View File
@@ -2,13 +2,12 @@
import uuid
import api.db_utils
import api.rls
import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
import api.db_utils
import api.rls
class Migration(migrations.Migration):
dependencies = [
@@ -2,10 +2,9 @@
from functools import partial
from django.db import migrations
from api.db_utils import PostgresEnumMigration, ProcessorTypeEnum, register_enum
from api.models import Processor
from django.db import migrations
ProcessorTypeEnumMigration = PostgresEnumMigration(
enum_name="processor_type",
@@ -2,12 +2,11 @@
import uuid
import django.db.models.deletion
from django.db import migrations, models
import api.db_utils
import api.rls
import django.db.models.deletion
from api.rls import RowLevelSecurityConstraint
from django.db import migrations, models
class Migration(migrations.Migration):
@@ -1,8 +1,7 @@
from functools import partial
from django.db import migrations
from api.db_utils import create_index_on_partitions, drop_index_on_partitions
from django.db import migrations
class Migration(migrations.Migration):
@@ -1,8 +1,7 @@
from functools import partial
from django.db import migrations
from api.db_utils import create_index_on_partitions, drop_index_on_partitions
from django.db import migrations
class Migration(migrations.Migration):
@@ -1,8 +1,7 @@
# Generated by Django 5.1.7 on 2025-07-09 14:44
from django.db import migrations
import api.db_utils
from django.db import migrations
class Migration(migrations.Migration):
@@ -2,15 +2,14 @@
import uuid
import api.db_utils
import api.rls
import django.core.validators
import django.db.models.deletion
import drf_simple_apikey.models
from django.conf import settings
from django.db import migrations, models
import api.db_utils
import api.rls
class Migration(migrations.Migration):
dependencies = [
@@ -4,15 +4,14 @@ import json
import logging
import uuid
import api.rls
import django.db.models.deletion
from api.db_router import MainRouter
from config.custom_logging import BackendLogger
from cryptography.fernet import Fernet
from django.conf import settings
from django.db import migrations, models
import api.rls
from api.db_router import MainRouter
logger = logging.getLogger(BackendLogger.API)
@@ -1,8 +1,7 @@
# Generated by Django 5.1.7 on 2025-10-14 00:00
from django.db import migrations
import api.db_utils
from django.db import migrations
class Migration(migrations.Migration):
@@ -2,14 +2,13 @@
import uuid
import api.rls
import django.contrib.postgres.fields
import django.core.validators
import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
import api.rls
class Migration(migrations.Migration):
dependencies = [
@@ -1,8 +1,7 @@
# Generated by Django 5.1.10 on 2025-09-09 09:25
from django.db import migrations
import api.db_utils
from django.db import migrations
class Migration(migrations.Migration):
@@ -1,8 +1,7 @@
# Generated by Django 5.1.13 on 2025-11-05 08:37
from django.db import migrations
import api.db_utils
from django.db import migrations
class Migration(migrations.Migration):
@@ -2,11 +2,10 @@
import uuid
import api.rls
import django.db.models.deletion
from django.db import migrations, models
import api.rls
class Migration(migrations.Migration):
dependencies = [
@@ -2,11 +2,10 @@
import uuid
import api.rls
import django.db.models.deletion
from django.db import migrations, models
import api.rls
class Migration(migrations.Migration):
dependencies = [
@@ -2,11 +2,10 @@
import uuid
import api.rls
import django.db.models.deletion
from django.db import migrations, models
import api.rls
class Migration(migrations.Migration):
dependencies = [
@@ -2,11 +2,10 @@
import uuid
import api.rls
import django.db.models.deletion
from django.db import migrations, models
import api.rls
class Migration(migrations.Migration):
dependencies = [
@@ -1,10 +1,9 @@
# Generated by Django 5.1.14 on 2025-12-10
from django.db import migrations
from tasks.tasks import backfill_daily_severity_summaries_task
from api.db_router import MainRouter
from api.rls import Tenant
from django.db import migrations
from tasks.tasks import backfill_daily_severity_summaries_task
def trigger_backfill_task(apps, schema_editor):
@@ -1,10 +1,9 @@
import uuid
import django.db.models.deletion
from django.db import migrations, models
import api.db_utils
import api.rls
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
@@ -1,8 +1,7 @@
# Generated by Django migration for Alibaba Cloud provider support
from django.db import migrations
import api.db_utils
from django.db import migrations
class Migration(migrations.Migration):
@@ -1,10 +1,9 @@
import uuid
import django.db.models.deletion
from django.db import migrations, models
import api.db_utils
import api.rls
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
@@ -1,10 +1,9 @@
import uuid
import api.rls
import django.db.models.deletion
from django.db import migrations, models
import api.rls
class Migration(migrations.Migration):
dependencies = [
@@ -1,10 +1,9 @@
import uuid
import django.db.models.deletion
from django.db import migrations, models
import api.db_utils
import api.rls
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
@@ -1,12 +1,10 @@
# Generated by Django 5.1.13 on 2025-11-06 16:20
import api.rls
import django.db.models.deletion
from django.db import migrations, models
from uuid6 import uuid7
import api.rls
class Migration(migrations.Migration):
dependencies = [
@@ -1,8 +1,7 @@
from functools import partial
from django.db import migrations
from api.db_utils import create_index_on_partitions, drop_index_on_partitions
from django.db import migrations
class Migration(migrations.Migration):
@@ -1,8 +1,7 @@
# Generated by Django migration for Cloudflare provider support
from django.db import migrations
import api.db_utils
from django.db import migrations
class Migration(migrations.Migration):
@@ -1,8 +1,7 @@
# Generated by Django migration for OpenStack provider support
from django.db import migrations
import api.db_utils
from django.db import migrations
class Migration(migrations.Migration):
@@ -2,9 +2,8 @@
# on different database connections, causing a deadlock when combined with RunPython
# in the same migration.
from django.db import migrations
from api.db_router import MainRouter
from django.db import migrations
def backfill_graph_data_ready(apps, schema_editor):
@@ -2,14 +2,13 @@
import uuid
import api.rls
import django.db.models.deletion
from django.contrib.postgres.indexes import GinIndex, OpClass
from django.db import migrations, models
from django.db.models.functions import Upper
from django.utils import timezone
import api.rls
class Migration(migrations.Migration):
dependencies = [
@@ -1,10 +1,9 @@
# Generated by Django 5.1.14 on 2026-02-02
from django.db import migrations
from tasks.tasks import backfill_finding_group_summaries_task
from api.db_router import MainRouter
from api.rls import Tenant
from django.db import migrations
from tasks.tasks import backfill_finding_group_summaries_task
def trigger_backfill_task(apps, schema_editor):
@@ -1,6 +1,5 @@
from django.db import migrations
import api.db_utils
from django.db import migrations
class Migration(migrations.Migration):
@@ -1,6 +1,5 @@
from django.db import migrations
import api.db_utils
from django.db import migrations
class Migration(migrations.Migration):
@@ -1,6 +1,5 @@
from django.db import migrations
TASK_NAME = "attack-paths-cleanup-stale-scans"
INTERVAL_HOURS = 1
@@ -1,6 +1,5 @@
from django.db import migrations
import api.db_utils
from django.db import migrations
class Migration(migrations.Migration):
@@ -1,8 +1,7 @@
from django.db import migrations
from tasks.tasks import backfill_finding_group_summaries_task
from api.db_router import MainRouter
from api.rls import Tenant
from django.db import migrations
from tasks.tasks import backfill_finding_group_summaries_task
def trigger_backfill_task(apps, schema_editor):
@@ -1,8 +1,7 @@
from functools import partial
from django.db import migrations
from api.db_utils import create_index_on_partitions, drop_index_on_partitions
from django.db import migrations
class Migration(migrations.Migration):
@@ -1,6 +1,5 @@
from django.db import migrations
import api.db_utils
from django.db import migrations
class Migration(migrations.Migration):
@@ -1,6 +1,5 @@
from django.db import migrations
TASK_NAME = "reconcile-orphan-tasks"
INTERVAL_MINUTES = 2
+30 -31
View File
@@ -1,37 +1,11 @@
import json
import logging
import re
from datetime import datetime, timedelta, timezone
from datetime import UTC, datetime, timedelta
from uuid import UUID, uuid4
import defusedxml
from allauth.socialaccount.models import SocialApp
from config.custom_logging import BackendLogger
from config.settings.social_login import SOCIALACCOUNT_PROVIDERS
from cryptography.fernet import Fernet, InvalidToken
from defusedxml import ElementTree as ET
from django.conf import settings
from django.contrib.auth.models import AbstractBaseUser
from django.contrib.postgres.fields import ArrayField
from django.contrib.postgres.indexes import GinIndex, OpClass
from django.contrib.postgres.search import SearchVector, SearchVectorField
from django.contrib.sites.models import Site
from django.core.exceptions import ValidationError
from django.core.validators import MinLengthValidator
from django.db import models
from django.db.models import Q
from django.db.models.functions import Upper
from django.utils import timezone as django_timezone
from django.utils.translation import gettext_lazy as _
from django_celery_beat.models import PeriodicTask
from django_celery_results.models import TaskResult
from drf_simple_apikey.crypto import get_crypto
from drf_simple_apikey.models import AbstractAPIKey, AbstractAPIKeyManager
from psqlextra.manager import PostgresManager
from psqlextra.models import PostgresPartitionedModel
from psqlextra.types import PostgresPartitioningMethod
from uuid6 import uuid7
from api.db_router import MainRouter
from api.db_utils import (
CustomUserManager,
@@ -58,7 +32,32 @@ from api.rls import (
RowLevelSecurityProtectedModel,
Tenant,
)
from config.custom_logging import BackendLogger
from config.settings.social_login import SOCIALACCOUNT_PROVIDERS
from cryptography.fernet import Fernet, InvalidToken
from defusedxml import ElementTree as ET
from django.conf import settings
from django.contrib.auth.models import AbstractBaseUser
from django.contrib.postgres.fields import ArrayField
from django.contrib.postgres.indexes import GinIndex, OpClass
from django.contrib.postgres.search import SearchVector, SearchVectorField
from django.contrib.sites.models import Site
from django.core.exceptions import ValidationError
from django.core.validators import MinLengthValidator
from django.db import models
from django.db.models import Q
from django.db.models.functions import Upper
from django.utils import timezone as django_timezone
from django.utils.translation import gettext_lazy as _
from django_celery_beat.models import PeriodicTask
from django_celery_results.models import TaskResult
from drf_simple_apikey.crypto import get_crypto
from drf_simple_apikey.models import AbstractAPIKey, AbstractAPIKeyManager
from prowler.lib.check.models import Severity
from psqlextra.manager import PostgresManager
from psqlextra.models import PostgresPartitionedModel
from psqlextra.types import PostgresPartitioningMethod
from uuid6 import uuid7
fernet = Fernet(settings.SECRETS_ENCRYPTION_KEY.encode())
@@ -1427,8 +1426,8 @@ class Role(RowLevelSecurityProtectedModel):
@classmethod
def filter_by_permission_state(cls, queryset, value):
q_all_true = Q(**{field: True for field in cls.PERMISSION_FIELDS})
q_all_false = Q(**{field: False for field in cls.PERMISSION_FIELDS})
q_all_true = Q(**dict.fromkeys(cls.PERMISSION_FIELDS, True))
q_all_false = Q(**dict.fromkeys(cls.PERMISSION_FIELDS, False))
if value == PermissionChoices.UNLIMITED:
return queryset.filter(q_all_true)
@@ -2011,11 +2010,11 @@ class SAMLToken(models.Model):
def save(self, *args, **kwargs):
if not self.expires_at:
self.expires_at = datetime.now(timezone.utc) + timedelta(seconds=15)
self.expires_at = datetime.now(UTC) + timedelta(seconds=15)
super().save(*args, **kwargs)
def is_expired(self) -> bool:
return datetime.now(timezone.utc) >= self.expires_at
return datetime.now(UTC) >= self.expires_at
class SAMLDomainIndex(models.Model):
+20 -23
View File
@@ -1,21 +1,20 @@
from datetime import datetime, timezone
from typing import Generator, Optional
from dateutil.relativedelta import relativedelta
from django.conf import settings
from psqlextra.partitioning import (
PostgresPartitioningManager,
PostgresRangePartition,
PostgresRangePartitioningStrategy,
PostgresTimePartitionSize,
PostgresPartitioningError,
)
from psqlextra.partitioning.config import PostgresPartitioningConfig
from uuid6 import UUID
from collections.abc import Generator
from datetime import UTC, datetime
from api.models import Finding, ResourceFindingMapping
from api.rls import RowLevelSecurityConstraint
from api.uuid_utils import datetime_to_uuid7
from dateutil.relativedelta import relativedelta
from django.conf import settings
from psqlextra.partitioning import (
PostgresPartitioningError,
PostgresPartitioningManager,
PostgresRangePartition,
PostgresRangePartitioningStrategy,
PostgresTimePartitionSize,
)
from psqlextra.partitioning.config import PostgresPartitioningConfig
from uuid6 import UUID
class PostgresUUIDv7RangePartition(PostgresRangePartition):
@@ -24,7 +23,7 @@ class PostgresUUIDv7RangePartition(PostgresRangePartition):
from_values: UUID,
to_values: UUID,
size: PostgresTimePartitionSize,
name_format: Optional[str] = None,
name_format: str | None = None,
**kwargs,
) -> None:
self.from_values = from_values
@@ -38,9 +37,7 @@ class PostgresUUIDv7RangePartition(PostgresRangePartition):
start_timestamp_ms = self.from_values.time
self.start_datetime = datetime.fromtimestamp(
start_timestamp_ms / 1000, timezone.utc
)
self.start_datetime = datetime.fromtimestamp(start_timestamp_ms / 1000, UTC)
def name(self) -> str:
if not self.name_format:
@@ -82,8 +79,8 @@ class PostgresUUIDv7PartitioningStrategy(PostgresRangePartitioningStrategy):
size: PostgresTimePartitionSize,
count: int,
start_date: datetime = None,
max_age: Optional[relativedelta] = None,
name_format: Optional[str] = None,
max_age: relativedelta | None = None,
name_format: str | None = None,
**kwargs,
) -> None:
self.start_date = start_date.replace(
@@ -151,7 +148,7 @@ class PostgresUUIDv7PartitioningStrategy(PostgresRangePartitioningStrategy):
Returns:
datetime: A `datetime` object representing the start of the current month in UTC.
"""
return datetime.now(timezone.utc).replace(
return datetime.now(UTC).replace(
day=1, hour=0, minute=0, second=0, microsecond=0
)
@@ -171,7 +168,7 @@ manager = PostgresPartitioningManager(
PostgresPartitioningConfig(
model=Finding,
strategy=PostgresUUIDv7PartitioningStrategy(
start_date=datetime.now(timezone.utc),
start_date=datetime.now(UTC),
size=PostgresTimePartitionSize(
months=settings.FINDINGS_TABLE_PARTITION_MONTHS
),
@@ -187,7 +184,7 @@ manager = PostgresPartitioningManager(
PostgresPartitioningConfig(
model=ResourceFindingMapping,
strategy=PostgresUUIDv7PartitioningStrategy(
start_date=datetime.now(timezone.utc),
start_date=datetime.now(UTC),
size=PostgresTimePartitionSize(
months=settings.FINDINGS_TABLE_PARTITION_MONTHS
),
+3 -4
View File
@@ -1,11 +1,10 @@
from enum import Enum
from django.db.models import QuerySet
from rest_framework.exceptions import PermissionDenied
from rest_framework.permissions import BasePermission
from api.db_router import MainRouter
from api.models import Provider, Role, User
from django.db.models import QuerySet
from rest_framework.exceptions import PermissionDenied
from rest_framework.permissions import BasePermission
class Permissions(Enum):
+1 -2
View File
@@ -1,10 +1,9 @@
from contextlib import nullcontext
from api.db_utils import rls_transaction
from rest_framework.renderers import BaseRenderer
from rest_framework_json_api.renderers import JSONRenderer
from api.db_utils import rls_transaction
class PlainTextRenderer(BaseRenderer):
media_type = "text/plain"
+1 -2
View File
@@ -1,12 +1,11 @@
from typing import Any
from uuid import uuid4
from api.db_utils import DB_USER, POSTGRES_TENANT_VAR
from django.core.exceptions import ValidationError
from django.db import DEFAULT_DB_ALIAS, models
from django.db.backends.ddl_references import Statement, Table
from api.db_utils import DB_USER, POSTGRES_TENANT_VAR
class Tenant(models.Model):
"""
+6 -7
View File
@@ -1,10 +1,3 @@
from celery import states
from celery.signals import before_task_publish
from config.celery import celery_app
from django.db.models.signals import post_delete, pre_delete
from django.dispatch import receiver
from django_celery_results.backends.database import DatabaseBackend
from api.db_utils import delete_related_daily_task
from api.models import (
LighthouseProviderConfiguration,
@@ -14,6 +7,12 @@ from api.models import (
TenantAPIKey,
User,
)
from celery import states
from celery.signals import before_task_publish
from config.celery import celery_app
from django.db.models.signals import post_delete, pre_delete
from django.dispatch import receiver
from django_celery_results.backends.database import DatabaseBackend
def create_task_result_on_publish(sender=None, headers=None, **kwargs): # noqa: F841
+1 -1
View File
@@ -7,7 +7,7 @@ enforces the tenant gate (:class:`api.sse.channelmanager.SSEChannelManager`),
and the channel-name helpers (:func:`api.sse.utils.make_channel_name`).
"""
from api.sse.utils import make_channel_name
from api.sse.base_views import BaseSSEViewSet
from api.sse.utils import make_channel_name
__all__ = ["BaseSSEViewSet", "make_channel_name"]
+2 -3
View File
@@ -5,11 +5,10 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from uuid import UUID
from api.sse.utils import tenant_id_from_channel
from django_eventstream.channelmanager import DefaultChannelManager
from rest_framework.request import Request
from api.sse.utils import tenant_id_from_channel
if TYPE_CHECKING:
from api.models import User
@@ -41,7 +40,7 @@ class SSEChannelManager(DefaultChannelManager):
if tenant_id_from_channel(channel) == request_tenant_id
}
def can_read_channel(self, user: "User | None", channel: str) -> bool:
def can_read_channel(self, user: User | None, channel: str) -> bool:
"""Re-verify tenant membership once the stream is established.
Args:
@@ -1,15 +1,14 @@
import time
from datetime import datetime, timedelta, timezone
from datetime import UTC, datetime, timedelta
from uuid import uuid4
import pytest
from api.models import Membership, Role, TenantAPIKey, User, UserRoleRelationship
from conftest import TEST_PASSWORD, get_api_tokens, get_authorization_header
from django.urls import reverse
from drf_simple_apikey.crypto import get_crypto
from rest_framework.test import APIClient
from api.models import Membership, Role, TenantAPIKey, User, UserRoleRelationship
@pytest.mark.django_db
def test_basic_authentication():
@@ -468,7 +467,7 @@ class TestAPIKeyErrors:
name="Expired Key",
tenant_id=tenants_fixture[0].id,
entity=create_test_user,
expiry_date=datetime.now(timezone.utc) - timedelta(days=1),
expiry_date=datetime.now(UTC) - timedelta(days=1),
)
api_key_headers = get_api_key_header(raw_key)
@@ -500,7 +499,7 @@ class TestAPIKeyErrors:
# Create a valid-looking key with non-existent UUID
crypto = get_crypto()
fake_uuid = str(uuid4())
fake_expiry = (datetime.now(timezone.utc) + timedelta(days=30)).timestamp()
fake_expiry = (datetime.now(UTC) + timedelta(days=30)).timestamp()
payload = {"_pk": fake_uuid, "_exp": fake_expiry}
encrypted_payload = crypto.generate(payload)
@@ -723,7 +722,7 @@ class TestAPIKeyLifecycle:
assert created_data["attributes"]["revoked"] is False
# Create API key with expiry
future_expiry = (datetime.now(timezone.utc) + timedelta(days=90)).isoformat()
future_expiry = (datetime.now(UTC) + timedelta(days=90)).isoformat()
create_with_expiry_response = client.post(
reverse("api-key-list"),
data={
@@ -927,9 +926,9 @@ class TestAPIKeyLifecycle:
auth_response = client.get(reverse("provider-list"), headers=api_key_headers)
# Must return 401 Unauthorized, not 500 Internal Server Error
assert (
auth_response.status_code == 401
), f"Expected 401 but got {auth_response.status_code}: {auth_response.json()}"
assert auth_response.status_code == 401, (
f"Expected 401 but got {auth_response.status_code}: {auth_response.json()}"
)
# Verify error message is present
response_json = auth_response.json()
@@ -1267,7 +1266,7 @@ class TestAPIKeyRLSBypass:
name="Expired Test Key",
tenant_id=tenant.id,
entity=create_test_user,
expiry_date=datetime.now(timezone.utc) - timedelta(days=1),
expiry_date=datetime.now(UTC) - timedelta(days=1),
)
api_key_headers = get_api_key_header(raw_key)
@@ -1,12 +1,11 @@
from unittest.mock import Mock, patch
import pytest
from api.models import Provider
from conftest import get_api_tokens, get_authorization_header
from django.urls import reverse
from rest_framework.test import APIClient
from api.models import Provider
@patch("api.v1.views.Task.objects.get")
@patch("api.v1.views.delete_provider_task.delay")
@@ -1,11 +1,10 @@
"""Tests for rls_transaction retry and fallback logic."""
import pytest
from api.db_utils import rls_transaction
from django.db import DEFAULT_DB_ALIAS
from rest_framework_json_api.serializers import ValidationError
from api.db_utils import rls_transaction
@pytest.mark.django_db
class TestRLSTransaction:
@@ -1,10 +1,9 @@
from unittest.mock import patch
import pytest
from conftest import TEST_PASSWORD, TEST_USER, get_api_tokens, get_authorization_header
from django.urls import reverse
from conftest import TEST_USER, TEST_PASSWORD, get_api_tokens, get_authorization_header
@patch("api.v1.views.schedule_provider_scan")
@pytest.mark.django_db
+1 -2
View File
@@ -3,11 +3,10 @@ from unittest.mock import MagicMock, patch
import pytest
from allauth.socialaccount.models import SocialLogin
from django.contrib.auth import get_user_model
from api.adapters import ProwlerSocialAccountAdapter
from api.db_router import MainRouter
from api.models import SAMLConfiguration
from django.contrib.auth import get_user_model
User = get_user_model()
+2 -3
View File
@@ -4,11 +4,9 @@ import types
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from django.conf import settings
import api
import api.apps as api_apps_module
import pytest
from api.apps import (
PRIVATE_KEY_FILE,
PUBLIC_KEY_FILE,
@@ -16,6 +14,7 @@ from api.apps import (
VERIFYING_KEY_ENV,
ApiConfig,
)
from django.conf import settings
@pytest.fixture(autouse=True)
@@ -1,14 +1,12 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
import neo4j
import neo4j.exceptions
from rest_framework.exceptions import APIException, PermissionDenied, ValidationError
import pytest
from api.attack_paths import database as graph_database
from api.attack_paths import views_helpers
from rest_framework.exceptions import APIException, PermissionDenied, ValidationError
from tasks.jobs.attack_paths.config import (
PROVIDER_ELEMENT_ID_PROPERTY,
get_provider_label,
@@ -6,15 +6,13 @@ never contacts Neo4j. These tests validate the database module behavior itself.
"""
import threading
from unittest.mock import MagicMock, patch
import api.attack_paths.database as db_module
import neo4j
import neo4j.exceptions
import pytest
import api.attack_paths.database as db_module
class TestLazyInitialization:
"""Test that Neo4j driver is initialized lazily on first use."""
@@ -1,15 +1,14 @@
import time
from datetime import datetime, timedelta, timezone
from datetime import UTC, datetime, timedelta
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from django.test import RequestFactory
from rest_framework.exceptions import AuthenticationFailed
from api.authentication import SSEAuthentication, TenantAPIKeyAuthentication
from api.db_router import MainRouter
from api.models import TenantAPIKey
from django.test import RequestFactory
from rest_framework.exceptions import AuthenticationFailed
@pytest.mark.django_db
@@ -104,7 +103,7 @@ class TestTenantAPIKeyAuthentication:
# Verify that last_used_at was updated
api_key.refresh_from_db()
assert api_key.last_used_at is not None
assert (datetime.now(timezone.utc) - api_key.last_used_at).seconds < 5
assert (datetime.now(UTC) - api_key.last_used_at).seconds < 5
def test_authenticate_valid_api_key_uses_admin_database(
self, auth_backend, api_keys_fixture, request_factory
@@ -195,7 +194,7 @@ class TestTenantAPIKeyAuthentication:
name="Expired API Key",
tenant_id=tenant.id,
entity=user,
expiry_date=datetime.now(timezone.utc) - timedelta(days=1),
expiry_date=datetime.now(UTC) - timedelta(days=1),
)
request = request_factory.get("/")
@@ -217,7 +216,7 @@ class TestTenantAPIKeyAuthentication:
# Manually create an encrypted key with a non-existent ID
payload = {
"_pk": non_existent_uuid,
"_exp": (datetime.now(timezone.utc) + timedelta(days=30)).timestamp(),
"_exp": (datetime.now(UTC) + timedelta(days=30)).timestamp(),
}
encrypted_key = auth_backend.key_crypto.generate(payload)
fake_key = f"{api_key.prefix}.{encrypted_key}"
@@ -368,7 +367,7 @@ class TestTenantAPIKeyAuthentication:
name="Short-lived API Key",
tenant_id=tenant.id,
entity=user,
expiry_date=datetime.now(timezone.utc) + timedelta(seconds=1),
expiry_date=datetime.now(UTC) + timedelta(seconds=1),
)
# Wait for the key to expire
@@ -1,7 +1,6 @@
from unittest.mock import MagicMock, patch
import pytest
from api import compliance as compliance_module
from api.compliance import (
generate_compliance_overview_template,
@@ -3,13 +3,11 @@
from unittest.mock import patch
import pytest
from rest_framework.exceptions import ValidationError
from api.attack_paths.cypher_sanitizer import (
inject_provider_label,
validate_custom_query,
)
from rest_framework.exceptions import ValidationError
PROVIDER_ID = "019c41ee-7df3-7dec-a684-d839f95619f8"
LABEL = "_Provider_019c41ee7df37deca684d839f95619f8"
@@ -202,9 +200,7 @@ class TestClauseSplitting:
def test_multiple_match_clauses(self):
cypher = (
"MATCH (a:AWSAccount)--(b:AWSRole) "
"MATCH (b)--(c:AWSPolicy) "
"RETURN a, b, c"
"MATCH (a:AWSAccount)--(b:AWSRole) MATCH (b)--(c:AWSPolicy) RETURN a, b, c"
)
result = _inject(cypher)
assert f"(a:AWSAccount:{LABEL})" in result
@@ -265,9 +261,7 @@ class TestRealWorldQueries:
def test_custom_bare_query(self):
cypher = (
"MATCH (a)-[:HAS_POLICY]->(b)\n"
"WHERE a.name CONTAINS 'admin'\n"
"RETURN a, b"
"MATCH (a)-[:HAS_POLICY]->(b)\nWHERE a.name CONTAINS 'admin'\nRETURN a, b"
)
result = _inject(cypher)
assert f"(a:{LABEL})" in result
@@ -344,9 +338,7 @@ class TestEdgeCases:
assert f"(outer:AWSAccount:{LABEL})" in result
def test_multiple_protected_regions(self):
cypher = (
"MATCH (n:X {a: 'hello'}) " 'WHERE n.b = "world" ' "// comment\n" "RETURN n"
)
cypher = "MATCH (n:X {a: 'hello'}) WHERE n.b = \"world\" // comment\nRETURN n"
result = _inject(cypher)
assert "'hello'" in result
assert '"world"' in result
+5 -5
View File
@@ -1,12 +1,12 @@
import pytest
from django.conf import settings
from django.db.migrations.recorder import MigrationRecorder
from django.db.utils import ConnectionRouter
from unittest.mock import patch
import pytest
from api.db_router import MainRouter
from api.rls import Tenant
from config.django.base import DATABASE_ROUTERS as PROD_DATABASE_ROUTERS
from unittest.mock import patch
from django.conf import settings
from django.db.migrations.recorder import MigrationRecorder
from django.db.utils import ConnectionRouter
@patch("api.db_router.MainRouter.admin_db", new="admin")
+16 -19
View File
@@ -1,14 +1,8 @@
from datetime import datetime, timezone
from datetime import UTC, datetime
from enum import Enum
from unittest.mock import MagicMock, patch
import pytest
from django.conf import settings
from django.db import DEFAULT_DB_ALIAS, OperationalError
from freezegun import freeze_time
from psycopg2 import sql as psycopg2_sql
from rest_framework_json_api.serializers import ValidationError
from api.db_utils import (
POSTGRES_TENANT_VAR,
PostgresEnumMigration,
@@ -23,6 +17,11 @@ from api.db_utils import (
update_objects_in_batches,
)
from api.models import Provider
from django.conf import settings
from django.db import DEFAULT_DB_ALIAS, OperationalError
from freezegun import freeze_time
from psycopg2 import sql as psycopg2_sql
from rest_framework_json_api.serializers import ValidationError
@pytest.fixture
@@ -94,18 +93,16 @@ class TestEnumToChoices:
class TestOneWeekFromNow:
def test_one_week_from_now(self):
with patch("api.db_utils.datetime") as mock_datetime:
mock_datetime.now.return_value = datetime(2023, 1, 1, tzinfo=timezone.utc)
expected_result = datetime(2023, 1, 8, tzinfo=timezone.utc)
mock_datetime.now.return_value = datetime(2023, 1, 1, tzinfo=UTC)
expected_result = datetime(2023, 1, 8, tzinfo=UTC)
result = one_week_from_now()
assert result == expected_result
def test_one_week_from_now_with_timezone(self):
with patch("api.db_utils.datetime") as mock_datetime:
mock_datetime.now.return_value = datetime(
2023, 6, 15, 12, 0, tzinfo=timezone.utc
)
expected_result = datetime(2023, 6, 22, 12, 0, tzinfo=timezone.utc)
mock_datetime.now.return_value = datetime(2023, 6, 15, 12, 0, tzinfo=UTC)
expected_result = datetime(2023, 6, 22, 12, 0, tzinfo=UTC)
result = one_week_from_now()
assert result == expected_result
@@ -939,9 +936,9 @@ class TestPostgresEnumMigration:
mock_cursor.execute.assert_called_once()
query_arg = mock_cursor.execute.call_args[0][0]
assert isinstance(
query_arg, psycopg2_sql.Composable
), "create_enum_type must pass a psycopg2.sql.Composable, not a raw string."
assert isinstance(query_arg, psycopg2_sql.Composable), (
"create_enum_type must pass a psycopg2.sql.Composable, not a raw string."
)
# Verify the composed SQL structure: CREATE TYPE <Identifier> AS ENUM (<Literals>)
parts = query_arg.seq
assert parts[0] == psycopg2_sql.SQL("CREATE TYPE ")
@@ -962,9 +959,9 @@ class TestPostgresEnumMigration:
mock_cursor.execute.assert_called_once()
query_arg = mock_cursor.execute.call_args[0][0]
assert isinstance(
query_arg, psycopg2_sql.Composable
), "drop_enum_type must pass a psycopg2.sql.Composable, not a raw string."
assert isinstance(query_arg, psycopg2_sql.Composable), (
"drop_enum_type must pass a psycopg2.sql.Composable, not a raw string."
)
# Verify the composed SQL structure: DROP TYPE <Identifier>
parts = query_arg.seq
assert parts[0] == psycopg2_sql.SQL("DROP TYPE ")
+2 -3
View File
@@ -2,12 +2,11 @@ import uuid
from unittest.mock import call, patch
import pytest
from django.core.exceptions import ObjectDoesNotExist
from django.db import DatabaseError, IntegrityError
from api.db_utils import POSTGRES_TENANT_VAR, SET_CONFIG_QUERY
from api.decorators import handle_provider_deletion, set_tenant
from api.exceptions import ProviderDeletedException
from django.core.exceptions import ObjectDoesNotExist
from django.db import DatabaseError, IntegrityError
@pytest.mark.django_db
+1 -3
View File
@@ -7,15 +7,13 @@ Cover the IETF response envelope, status code mapping (200 / 503), the
from unittest.mock import patch
import pytest
from api import health
from config import version as config_version
from django.core.cache import cache
from django.urls import reverse
from rest_framework import status
from rest_framework.test import APIClient
from api import health
HEALTH_MEDIA_TYPE = "application/health+json"
+1 -2
View File
@@ -1,11 +1,10 @@
from unittest.mock import MagicMock, patch
import pytest
from api.middleware import APILoggingMiddleware
from django.http import HttpResponse
from django.test import RequestFactory
from api.middleware import APILoggingMiddleware
@pytest.mark.django_db
@patch("logging.getLogger")
+3 -4
View File
@@ -2,10 +2,6 @@ import json
from uuid import uuid4
import pytest
from django_celery_results.models import TaskResult
from rest_framework import status
from rest_framework.response import Response
from api.exceptions import (
TaskFailedException,
TaskInProgressException,
@@ -14,6 +10,9 @@ from api.exceptions import (
from api.models import Task, User
from api.rls import Tenant
from api.v1.mixins import PaginateByPkMixin, TaskManagementMixin
from django_celery_results.models import TaskResult
from rest_framework import status
from rest_framework.response import Response
@pytest.mark.django_db
+7 -8
View File
@@ -1,10 +1,7 @@
from datetime import datetime, timezone
from datetime import UTC, datetime
import pytest
from allauth.socialaccount.models import SocialApp
from django.core.exceptions import ValidationError
from django.db import IntegrityError
from api.db_router import MainRouter
from api.models import (
ProviderComplianceScore,
@@ -16,6 +13,8 @@ from api.models import (
StatusChoices,
TenantComplianceSummary,
)
from django.core.exceptions import ValidationError
from django.db import IntegrityError
@pytest.mark.django_db
@@ -376,7 +375,7 @@ class TestProviderComplianceScoreModel:
def test_create_provider_compliance_score(self, providers_fixture, scans_fixture):
provider = providers_fixture[0]
scan = scans_fixture[0]
scan.completed_at = datetime.now(timezone.utc)
scan.completed_at = datetime.now(UTC)
scan.save()
score = ProviderComplianceScore.objects.create(
@@ -398,7 +397,7 @@ class TestProviderComplianceScoreModel:
):
provider = providers_fixture[0]
scan = scans_fixture[0]
scan.completed_at = datetime.now(timezone.utc)
scan.completed_at = datetime.now(UTC)
scan.save()
ProviderComplianceScore.objects.create(
@@ -427,12 +426,12 @@ class TestProviderComplianceScoreModel:
):
provider1, provider2, *_ = providers_fixture
scan1 = scans_fixture[0]
scan1.completed_at = datetime.now(timezone.utc)
scan1.completed_at = datetime.now(UTC)
scan1.save()
scan2 = scans_fixture[2]
scan2.state = StateChoices.COMPLETED
scan2.completed_at = datetime.now(timezone.utc)
scan2.completed_at = datetime.now(UTC)
scan2.save()
score1 = ProviderComplianceScore.objects.create(
+3 -4
View File
@@ -2,10 +2,6 @@ import json
from unittest.mock import ANY, Mock, patch
import pytest
from conftest import TEST_PASSWORD, TODAY
from django.urls import reverse
from rest_framework import status
from api.models import (
Membership,
ProviderGroup,
@@ -16,6 +12,9 @@ from api.models import (
UserRoleRelationship,
)
from api.v1.serializers import TokenSerializer
from conftest import TEST_PASSWORD, TODAY
from django.urls import reverse
from rest_framework import status
@pytest.mark.django_db

Some files were not shown because too many files have changed in this diff Show More