chore: unify ruff tooling and route code quality through the Makefile (#11675)

This commit is contained in:
Rubén De la Torre Vico
2026-06-23 17:15:05 +02:00
committed by GitHub
parent de7da3e960
commit 058a1dc8fe
194 changed files with 1112 additions and 1210 deletions
+28 -9
View File
@@ -110,17 +110,36 @@ repos:
priority: 30 priority: 30
## PYTHON — API + MCP Server (ruff) ## PYTHON — API + MCP Server (ruff)
- repo: https://github.com/astral-sh/ruff-pre-commit # Run ruff through `uv run` against each project so prek uses the exact ruff
rev: v0.15.11 # version pinned in that project's uv.lock — the same version GitHub Actions
# runs via `uv run ruff`. This removes the drift between the local hooks and
# CI. api/ and mcp_server/ are separate uv projects, so they need separate
# hooks (each `uv run --project` resolves its own pinned ruff + config).
- repo: local
hooks: hooks:
- id: ruff - id: ruff-check-api
name: "API + MCP - ruff check" name: "API - ruff check"
files: { glob: ["{api,mcp_server}/**/*.py"] } entry: uv run --project ./api ruff check --fix
args: ["--fix"] language: system
files: { glob: ["api/**/*.py"] }
priority: 30 priority: 30
- id: ruff-format - id: ruff-format-api
name: "API + MCP - ruff format" name: "API - ruff format"
files: { glob: ["{api,mcp_server}/**/*.py"] } entry: uv run --project ./api ruff format
language: system
files: { glob: ["api/**/*.py"] }
priority: 20
- id: ruff-check-mcp
name: "MCP - ruff check"
entry: uv run --project ./mcp_server ruff check --fix
language: system
files: { glob: ["mcp_server/**/*.py"] }
priority: 30
- id: ruff-format-mcp
name: "MCP - ruff format"
entry: uv run --project ./mcp_server ruff format
language: system
files: { glob: ["mcp_server/**/*.py"] }
priority: 20 priority: 20
## PYTHON — uv (API + SDK) ## PYTHON — uv (API + SDK)
+34 -11
View File
@@ -45,18 +45,41 @@ coverage-html: ## Show Test Coverage
coverage html && \ coverage html && \
open htmlcov/index.html open htmlcov/index.html
##@ Linting ##@ Code Quality
format: ## Format Code # `make` is the single entrypoint and mirrors CI exactly (uv run + same flags):
@echo "Running black..." # SDK (prowler/, util/) -> flake8 + black + pylint
black . # API & MCP server -> ruff (rules live in each project's pyproject.toml)
# `format` applies fixes (incl. ruff's import/upgrade autofixes); `lint` only
# verifies and is what CI gates on.
.PHONY: format format-sdk format-api format-mcp lint lint-sdk lint-api lint-mcp
lint: ## Lint Code format: format-sdk format-api format-mcp ## Format & autofix all components (SDK, API, MCP)
@echo "Running flake8..."
flake8 . --ignore=E266,W503,E203,E501,W605,E128 --exclude .venv,contrib lint: lint-sdk lint-api lint-mcp ## Lint all components (SDK, API, MCP) — mirrors CI
@echo "Running black... "
black --check . format-sdk: ## Format SDK code (black)
@echo "Running pylint..." uv run black --exclude "\.venv|api|ui|skills|mcp_server" .
pylint --disable=W,C,R,E -j 0 prowler util
lint-sdk: ## Lint SDK code (flake8, black --check, pylint)
uv run flake8 . --ignore=E266,W503,E203,E501,W605,E128 --exclude .venv,contrib,ui,api,skills,mcp_server
uv run black --exclude "\.venv|api|ui|skills|mcp_server" --check .
uv run pylint --disable=W,C,R,E -j 0 -rn -sn prowler/
format-api: ## Format & autofix API code (ruff)
cd api && uv run ruff check . --exclude contrib --fix
cd api && uv run ruff format . --exclude contrib
lint-api: ## Lint API code (ruff check + format --check)
cd api && uv run ruff check . --exclude contrib
cd api && uv run ruff format --check . --exclude contrib
format-mcp: ## Format & autofix MCP server code (ruff)
cd mcp_server && uv run ruff check . --fix
cd mcp_server && uv run ruff format .
lint-mcp: ## Lint MCP server code (ruff check + format --check)
cd mcp_server && uv run ruff check .
cd mcp_server && uv run ruff format --check .
##@ PyPI ##@ PyPI
pypi-clean: ## Delete the distribution files pypi-clean: ## Delete the distribution files
+19 -2
View File
@@ -14,7 +14,7 @@ dev = [
"pytest-env==1.1.3", "pytest-env==1.1.3",
"pytest-randomly==3.15.0", "pytest-randomly==3.15.0",
"pytest-xdist==3.6.1", "pytest-xdist==3.6.1",
"ruff==0.5.0", "ruff==0.15.11",
"tqdm==4.67.1", "tqdm==4.67.1",
"vulture==2.14", "vulture==2.14",
"prek==0.3.9" "prek==0.3.9"
@@ -73,6 +73,23 @@ package-mode = false
requires-python = ">=3.11,<3.13" requires-python = ">=3.11,<3.13"
version = "1.33.0" version = "1.33.0"
# Shared ruff baseline (kept in sync with mcp_server/pyproject.toml).
# target-version tracks this project's lowest supported Python.
[tool.ruff]
src = ["src"]
target-version = "py311"
[tool.ruff.lint]
# Defaults (E4/E7/E9, F) plus import sorting, modern-syntax upgrades, and
# comprehension lints — all mechanically auto-fixable. flake8-bugbear (B) is a
# good next step but needs manual cleanup (e.g. B904 raise-from), so it is left
# out of the shared baseline for now.
extend-select = [
"I", # isort — import ordering (prek's isort hook covers only the SDK)
"UP", # pyupgrade — modern syntax for the min supported Python
"C4" # flake8-comprehensions
]
[tool.uv] [tool.uv]
# Transitive pins matching master to avoid silent drift; bump deliberately. # Transitive pins matching master to avoid silent drift; bump deliberately.
constraint-dependencies = [ constraint-dependencies = [
@@ -393,7 +410,7 @@ constraint-dependencies = [
"rpds-py==0.30.0", "rpds-py==0.30.0",
"rsa==4.9.1", "rsa==4.9.1",
"ruamel-yaml==0.19.1", "ruamel-yaml==0.19.1",
"ruff==0.5.0", "ruff==0.15.11",
"s3transfer==0.14.0", "s3transfer==0.14.0",
"scaleway==2.10.3", "scaleway==2.10.3",
"scaleway-core==2.10.3", "scaleway-core==2.10.3",
+1 -2
View File
@@ -1,6 +1,4 @@
from allauth.socialaccount.adapter import DefaultSocialAccountAdapter from allauth.socialaccount.adapter import DefaultSocialAccountAdapter
from django.db import transaction
from api.db_router import MainRouter from api.db_router import MainRouter
from api.db_utils import rls_transaction from api.db_utils import rls_transaction
from api.models import ( from api.models import (
@@ -11,6 +9,7 @@ from api.models import (
User, User,
UserRoleRelationship, UserRoleRelationship,
) )
from django.db import transaction
class ProwlerSocialAccountAdapter(DefaultSocialAccountAdapter): class ProwlerSocialAccountAdapter(DefaultSocialAccountAdapter):
+6 -6
View File
@@ -1,14 +1,12 @@
import logging import logging
import os import os
import sys import sys
from pathlib import Path from pathlib import Path
from django.apps import AppConfig
from django.conf import settings
from config.custom_logging import BackendLogger from config.custom_logging import BackendLogger
from config.env import env from config.env import env
from django.apps import AppConfig
from django.conf import settings
logger = logging.getLogger(BackendLogger.API) logger = logging.getLogger(BackendLogger.API)
@@ -30,8 +28,10 @@ class ApiConfig(AppConfig):
name = "api" name = "api"
def ready(self): def ready(self):
from api import schema_extensions # noqa: F401 from api import (
from api import signals # noqa: F401 schema_extensions, # noqa: F401
signals, # noqa: F401
)
# Generate required cryptographic keys if not present, but only if: # Generate required cryptographic keys if not present, but only if:
# `"manage.py" not in sys.argv[0]`: If an external server (e.g., Gunicorn) is running the app # `"manage.py" not in sys.argv[0]`: If an external server (e.g., Gunicorn) is running the app
@@ -5,7 +5,6 @@ from api.attack_paths.queries import (
get_query_by_id, get_query_by_id,
) )
__all__ = [ __all__ = [
"AttackPathsQueryDefinition", "AttackPathsQueryDefinition",
"AttackPathsQueryParameterDefinition", "AttackPathsQueryParameterDefinition",
@@ -22,10 +22,8 @@ Label-injection pipeline:
import re import re
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
from tasks.jobs.attack_paths.config import get_provider_label from tasks.jobs.attack_paths.config import get_provider_label
# Step 1 - String / comment protection # Step 1 - String / comment protection
# Single combined regex: strings first, then line comments. # Single combined regex: strings first, then line comments.
# The regex engine finds the leftmost match, so a string like 'https://prowler.com' # The regex engine finds the leftmost match, so a string like 'https://prowler.com'
+3 -5
View File
@@ -1,18 +1,16 @@
import atexit import atexit
import logging import logging
import threading import threading
from collections.abc import Iterator
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Iterator from typing import Any
from uuid import UUID from uuid import UUID
import neo4j import neo4j
import neo4j.exceptions import neo4j.exceptions
from api.attack_paths.retryable_session import RetryableSession
from config.env import env from config.env import env
from django.conf import settings from django.conf import settings
from api.attack_paths.retryable_session import RetryableSession
from tasks.jobs.attack_paths.config import ( from tasks.jobs.attack_paths.config import (
BATCH_SIZE, BATCH_SIZE,
PROVIDER_RESOURCE_LABEL, PROVIDER_RESOURCE_LABEL,
@@ -1,12 +1,11 @@
from api.attack_paths.queries.types import (
AttackPathsQueryDefinition,
AttackPathsQueryParameterDefinition,
)
from api.attack_paths.queries.registry import ( from api.attack_paths.queries.registry import (
get_queries_for_provider, get_queries_for_provider,
get_query_by_id, get_query_by_id,
) )
from api.attack_paths.queries.types import (
AttackPathsQueryDefinition,
AttackPathsQueryParameterDefinition,
)
__all__ = [ __all__ = [
"AttackPathsQueryDefinition", "AttackPathsQueryDefinition",
@@ -5,7 +5,6 @@ from api.attack_paths.queries.types import (
) )
from tasks.jobs.attack_paths.config import PROWLER_FINDING_LABEL from tasks.jobs.attack_paths.config import PROWLER_FINDING_LABEL
# Custom Attack Path Queries # Custom Attack Path Queries
# -------------------------- # --------------------------
@@ -1,6 +1,5 @@
from api.attack_paths.queries.types import AttackPathsQueryDefinition
from api.attack_paths.queries.aws import AWS_QUERIES from api.attack_paths.queries.aws import AWS_QUERIES
from api.attack_paths.queries.types import AttackPathsQueryDefinition
# Query definitions organized by provider # Query definitions organized by provider
_QUERY_DEFINITIONS: dict[str, list[AttackPathsQueryDefinition]] = { _QUERY_DEFINITIONS: dict[str, list[AttackPathsQueryDefinition]] = {
@@ -1,5 +1,4 @@
import logging import logging
from collections.abc import Callable from collections.abc import Callable
from typing import Any from typing import Any
@@ -1,12 +1,10 @@
import logging import logging
from collections.abc import Iterable
from typing import Any, Iterable from typing import Any
import neo4j import neo4j
from api.attack_paths import AttackPathsQueryDefinition
from rest_framework.exceptions import APIException, PermissionDenied, ValidationError from api.attack_paths import database as graph_database
from api.attack_paths import database as graph_database, AttackPathsQueryDefinition
from api.attack_paths.cypher_sanitizer import ( from api.attack_paths.cypher_sanitizer import (
inject_provider_label, inject_provider_label,
validate_custom_query, validate_custom_query,
@@ -17,6 +15,7 @@ from api.attack_paths.queries.schema import (
get_cartography_schema_query, get_cartography_schema_query,
) )
from config.custom_logging import BackendLogger from config.custom_logging import BackendLogger
from rest_framework.exceptions import APIException, PermissionDenied, ValidationError
from tasks.jobs.attack_paths.config import ( from tasks.jobs.attack_paths.config import (
INTERNAL_LABELS, INTERNAL_LABELS,
INTERNAL_PROPERTIES, INTERNAL_PROPERTIES,
+3 -5
View File
@@ -1,6 +1,7 @@
from typing import Optional, Tuple
from uuid import UUID from uuid import UUID
from api.db_router import MainRouter
from api.models import TenantAPIKey, TenantAPIKeyManager
from cryptography.fernet import InvalidToken from cryptography.fernet import InvalidToken
from django.utils import timezone from django.utils import timezone
from drf_simple_apikey.backends import APIKeyAuthentication as BaseAPIKeyAuth from drf_simple_apikey.backends import APIKeyAuthentication as BaseAPIKeyAuth
@@ -10,9 +11,6 @@ from rest_framework.exceptions import AuthenticationFailed
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework_simplejwt.authentication import JWTAuthentication from rest_framework_simplejwt.authentication import JWTAuthentication
from api.db_router import MainRouter
from api.models import TenantAPIKey, TenantAPIKeyManager
class TenantAPIKeyAuthentication(BaseAPIKeyAuth): class TenantAPIKeyAuthentication(BaseAPIKeyAuth):
model = TenantAPIKey model = TenantAPIKey
@@ -81,7 +79,7 @@ class CombinedJWTOrAPIKeyAuthentication(BaseAuthentication):
jwt_auth = JWTAuthentication() jwt_auth = JWTAuthentication()
api_key_auth = TenantAPIKeyAuthentication() api_key_auth = TenantAPIKeyAuthentication()
def authenticate(self, request: Request) -> Optional[Tuple[object, dict]]: def authenticate(self, request: Request) -> tuple[object, dict] | None:
auth_header = request.headers.get("Authorization", "") auth_header = request.headers.get("Authorization", "")
# Prioritize JWT authentication if both are present # Prioritize JWT authentication if both are present
+6 -7
View File
@@ -1,3 +1,9 @@
from api.authentication import CombinedJWTOrAPIKeyAuthentication
from api.db_router import MainRouter, reset_read_db_alias, set_read_db_alias
from api.db_utils import POSTGRES_USER_VAR, rls_transaction
from api.filters import CustomDjangoFilterBackend
from api.models import Role, UserRoleRelationship
from api.rbac.permissions import HasPermissions
from django.conf import settings from django.conf import settings
from django.db import transaction from django.db import transaction
from rest_framework import permissions from rest_framework import permissions
@@ -8,13 +14,6 @@ from rest_framework.response import Response
from rest_framework_json_api import filters from rest_framework_json_api import filters
from rest_framework_json_api.views import ModelViewSet from rest_framework_json_api.views import ModelViewSet
from api.authentication import CombinedJWTOrAPIKeyAuthentication
from api.db_router import MainRouter, reset_read_db_alias, set_read_db_alias
from api.db_utils import POSTGRES_USER_VAR, rls_transaction
from api.filters import CustomDjangoFilterBackend
from api.models import Role, UserRoleRelationship
from api.rbac.permissions import HasPermissions
class BaseViewSet(ModelViewSet): class BaseViewSet(ModelViewSet):
authentication_classes = [CombinedJWTOrAPIKeyAuthentication] authentication_classes = [CombinedJWTOrAPIKeyAuthentication]
+1 -1
View File
@@ -352,7 +352,7 @@ def generate_compliance_overview_template(
total_requirements += 1 total_requirements += 1
provider_check_list = list(requirement.checks.get(provider_type, [])) provider_check_list = list(requirement.checks.get(provider_type, []))
total_checks = len(provider_check_list) total_checks = len(provider_check_list)
checks_dict = {check: None for check in provider_check_list} checks_dict = dict.fromkeys(provider_check_list)
req_status_val = "MANUAL" if total_checks == 0 else "PASS" req_status_val = "MANUAL" if total_checks == 0 else "PASS"
+10 -11
View File
@@ -3,8 +3,14 @@ import secrets
import time import time
import uuid import uuid
from contextlib import contextmanager from contextlib import contextmanager
from datetime import datetime, timedelta, timezone from datetime import UTC, datetime, timedelta
from api.db_router import (
READ_REPLICA_ALIAS,
get_read_db_alias,
reset_read_db_alias,
set_read_db_alias,
)
from celery.utils.log import get_task_logger from celery.utils.log import get_task_logger
from config.env import env from config.env import env
from django.conf import settings from django.conf import settings
@@ -22,13 +28,6 @@ from psycopg2 import sql as psycopg2_sql
from psycopg2.extensions import AsIs, new_type, register_adapter, register_type from psycopg2.extensions import AsIs, new_type, register_adapter, register_type
from rest_framework_json_api.serializers import ValidationError from rest_framework_json_api.serializers import ValidationError
from api.db_router import (
READ_REPLICA_ALIAS,
get_read_db_alias,
reset_read_db_alias,
set_read_db_alias,
)
logger = get_task_logger(__name__) logger = get_task_logger(__name__)
DB_USER = settings.DATABASES["default"]["USER"] if not settings.TESTING else "test" DB_USER = settings.DATABASES["default"]["USER"] if not settings.TESTING else "test"
@@ -170,7 +169,7 @@ def one_week_from_now():
""" """
Return a datetime object with a date one week from now. Return a datetime object with a date one week from now.
""" """
return datetime.now(timezone.utc) + timedelta(days=7) return datetime.now(UTC) + timedelta(days=7)
def generate_random_token(length: int = 14, symbols: str | None = None) -> str: def generate_random_token(length: int = 14, symbols: str | None = None) -> str:
@@ -405,10 +404,10 @@ def _should_create_index_on_partition(
# Unknown month abbreviation, include it to be safe # Unknown month abbreviation, include it to be safe
return True return True
partition_date = datetime(year, month, 1, tzinfo=timezone.utc) partition_date = datetime(year, month, 1, tzinfo=UTC)
# Get current month start # Get current month start
now = datetime.now(timezone.utc) now = datetime.now(UTC)
current_month_start = now.replace( current_month_start = now.replace(
day=1, hour=0, minute=0, second=0, microsecond=0 day=1, hour=0, minute=0, second=0, microsecond=0
) )
+3 -4
View File
@@ -1,14 +1,13 @@
import uuid import uuid
from functools import wraps from functools import wraps
from django.core.exceptions import ObjectDoesNotExist
from django.db import DatabaseError, connection, transaction
from rest_framework_json_api.serializers import ValidationError
from api.db_router import READ_REPLICA_ALIAS from api.db_router import READ_REPLICA_ALIAS
from api.db_utils import POSTGRES_TENANT_VAR, SET_CONFIG_QUERY, rls_transaction from api.db_utils import POSTGRES_TENANT_VAR, SET_CONFIG_QUERY, rls_transaction
from api.exceptions import ProviderDeletedException from api.exceptions import ProviderDeletedException
from api.models import Provider, Scan from api.models import Provider, Scan
from django.core.exceptions import ObjectDoesNotExist
from django.db import DatabaseError, connection, transaction
from rest_framework_json_api.serializers import ValidationError
def set_tenant(func=None, *, keep_tenant=False): def set_tenant(func=None, *, keep_tenant=False):
+26 -27
View File
@@ -1,19 +1,4 @@
from datetime import date, datetime, timedelta, timezone from datetime import UTC, date, datetime, timedelta
from dateutil.parser import parse
from django.conf import settings
from django.db.models import F, Q
from django_filters.rest_framework import (
BaseInFilter,
BooleanFilter,
CharFilter,
ChoiceFilter,
DateFilter,
FilterSet,
UUIDFilter,
)
from rest_framework_json_api.django_filters.backends import DjangoFilterBackend
from rest_framework_json_api.serializers import ValidationError
from api.constants import SEVERITY_ORDER from api.constants import SEVERITY_ORDER
from api.db_utils import ( from api.db_utils import (
@@ -68,6 +53,20 @@ from api.uuid_utils import (
uuid7_start, uuid7_start,
) )
from api.v1.serializers import TaskBase from api.v1.serializers import TaskBase
from dateutil.parser import parse
from django.conf import settings
from django.db.models import F, Q
from django_filters.rest_framework import (
BaseInFilter,
BooleanFilter,
CharFilter,
ChoiceFilter,
DateFilter,
FilterSet,
UUIDFilter,
)
from rest_framework_json_api.django_filters.backends import DjangoFilterBackend
from rest_framework_json_api.serializers import ValidationError
class CustomDjangoFilterBackend(DjangoFilterBackend): class CustomDjangoFilterBackend(DjangoFilterBackend):
@@ -598,12 +597,12 @@ class ResourceFilter(ProviderRelationshipFilterSet):
gte_date = ( gte_date = (
parse(self.data.get("updated_at__gte")).date() parse(self.data.get("updated_at__gte")).date()
if self.data.get("updated_at__gte") if self.data.get("updated_at__gte")
else datetime.now(timezone.utc).date() else datetime.now(UTC).date()
) )
lte_date = ( lte_date = (
parse(self.data.get("updated_at__lte")).date() parse(self.data.get("updated_at__lte")).date()
if self.data.get("updated_at__lte") if self.data.get("updated_at__lte")
else datetime.now(timezone.utc).date() else datetime.now(UTC).date()
) )
if abs(lte_date - gte_date) > timedelta( if abs(lte_date - gte_date) > timedelta(
@@ -748,9 +747,9 @@ class FindingFilter(CommonFindingFilters):
lte_date = cleaned.get("inserted_at__lte") or exact_date lte_date = cleaned.get("inserted_at__lte") or exact_date
if gte_date is None: if gte_date is None:
gte_date = datetime.now(timezone.utc).date() gte_date = datetime.now(UTC).date()
if lte_date is None: if lte_date is None:
lte_date = datetime.now(timezone.utc).date() lte_date = datetime.now(UTC).date()
if abs(lte_date - gte_date) > timedelta( if abs(lte_date - gte_date) > timedelta(
days=settings.FINDINGS_MAX_DAYS_IN_RANGE days=settings.FINDINGS_MAX_DAYS_IN_RANGE
@@ -844,7 +843,7 @@ class FindingFilter(CommonFindingFilters):
def maybe_date_to_datetime(value): def maybe_date_to_datetime(value):
dt = value dt = value
if isinstance(value, date): if isinstance(value, date):
dt = datetime.combine(value, datetime.min.time(), tzinfo=timezone.utc) dt = datetime.combine(value, datetime.min.time(), tzinfo=UTC)
return dt return dt
@@ -933,9 +932,9 @@ class FindingGroupFilter(CommonFindingFilters):
lte_date = cleaned.get("inserted_at__lte") or exact_date lte_date = cleaned.get("inserted_at__lte") or exact_date
if gte_date is None: if gte_date is None:
gte_date = datetime.now(timezone.utc).date() gte_date = datetime.now(UTC).date()
if lte_date is None: if lte_date is None:
lte_date = datetime.now(timezone.utc).date() lte_date = datetime.now(UTC).date()
if abs(lte_date - gte_date) > timedelta( if abs(lte_date - gte_date) > timedelta(
days=settings.FINDINGS_MAX_DAYS_IN_RANGE days=settings.FINDINGS_MAX_DAYS_IN_RANGE
@@ -977,7 +976,7 @@ class FindingGroupFilter(CommonFindingFilters):
"""Convert date to datetime if needed.""" """Convert date to datetime if needed."""
dt = value dt = value
if isinstance(value, date): if isinstance(value, date):
dt = datetime.combine(value, datetime.min.time(), tzinfo=timezone.utc) dt = datetime.combine(value, datetime.min.time(), tzinfo=UTC)
return dt return dt
@@ -1091,9 +1090,9 @@ class FindingGroupSummaryFilter(_CheckTitleToCheckIdMixin, FilterSet):
lte_date = cleaned.get("inserted_at__lte") or exact_date lte_date = cleaned.get("inserted_at__lte") or exact_date
if gte_date is None: if gte_date is None:
gte_date = datetime.now(timezone.utc).date() gte_date = datetime.now(UTC).date()
if lte_date is None: if lte_date is None:
lte_date = datetime.now(timezone.utc).date() lte_date = datetime.now(UTC).date()
if abs(lte_date - gte_date) > timedelta( if abs(lte_date - gte_date) > timedelta(
days=settings.FINDINGS_MAX_DAYS_IN_RANGE days=settings.FINDINGS_MAX_DAYS_IN_RANGE
@@ -1132,7 +1131,7 @@ class FindingGroupSummaryFilter(_CheckTitleToCheckIdMixin, FilterSet):
def _maybe_date_to_datetime(value): def _maybe_date_to_datetime(value):
dt = value dt = value
if isinstance(value, date): if isinstance(value, date):
dt = datetime.combine(value, datetime.min.time(), tzinfo=timezone.utc) dt = datetime.combine(value, datetime.min.time(), tzinfo=UTC)
return dt return dt
+2 -6
View File
@@ -12,7 +12,7 @@ import logging
import threading import threading
import time import time
from contextlib import suppress from contextlib import suppress
from datetime import datetime, timezone from datetime import UTC, datetime
from typing import Any from typing import Any
import redis import redis
@@ -62,11 +62,7 @@ class HealthJSONRenderer(JSONRenderer):
def _now_iso() -> str: def _now_iso() -> str:
return ( return datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z")
datetime.now(timezone.utc)
.isoformat(timespec="milliseconds")
.replace("+00:00", "Z")
)
def _measure(name: str, check_fn) -> tuple[dict[str, Any], float]: def _measure(name: str, check_fn) -> tuple[dict[str, Any], float]:
@@ -1,11 +1,8 @@
import random import random
from datetime import datetime, timezone from datetime import UTC, datetime
from math import ceil from math import ceil
from uuid import uuid4 from uuid import uuid4
from django.core.management.base import BaseCommand
from tqdm import tqdm
from api.db_utils import rls_transaction from api.db_utils import rls_transaction
from api.models import ( from api.models import (
Finding, Finding,
@@ -16,7 +13,9 @@ from api.models import (
Scan, Scan,
StatusChoices, StatusChoices,
) )
from django.core.management.base import BaseCommand
from prowler.lib.check.models import CheckMetadata from prowler.lib.check.models import CheckMetadata
from tqdm import tqdm
class Command(BaseCommand): class Command(BaseCommand):
@@ -116,7 +115,7 @@ class Command(BaseCommand):
trigger="manual", trigger="manual",
state="executing", state="executing",
progress=0, progress=0,
started_at=datetime.now(timezone.utc), started_at=datetime.now(UTC),
) )
scan_state = "completed" scan_state = "completed"
@@ -272,10 +271,8 @@ class Command(BaseCommand):
self.stdout.write(self.style.ERROR(f"Failed to populate test data: {e}")) self.stdout.write(self.style.ERROR(f"Failed to populate test data: {e}"))
scan_state = "failed" scan_state = "failed"
finally: finally:
scan.completed_at = datetime.now(timezone.utc) scan.completed_at = datetime.now(UTC)
scan.duration = int( scan.duration = int((datetime.now(UTC) - scan.started_at).total_seconds())
(datetime.now(timezone.utc) - scan.started_at).total_seconds()
)
scan.progress = 100 scan.progress = 100
scan.state = scan_state scan.state = scan_state
scan.unique_resource_count = num_resources scan.unique_resource_count = num_resources
@@ -1,5 +1,4 @@
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from tasks.jobs.orphan_recovery import reconcile_orphans from tasks.jobs.orphan_recovery import reconcile_orphans
+1 -2
View File
@@ -1,11 +1,10 @@
import logging import logging
import time import time
from config.custom_logging import BackendLogger
from django.core.handlers.asgi import ASGIRequest from django.core.handlers.asgi import ASGIRequest
from django.db import connections from django.db import connections
from config.custom_logging import BackendLogger
class CloseDBConnectionsMiddleware: class CloseDBConnectionsMiddleware:
""" """
+13 -14
View File
@@ -1,26 +1,13 @@
import uuid import uuid
from functools import partial from functools import partial
import api.rls
import django.contrib.auth.models import django.contrib.auth.models
import django.contrib.postgres.indexes import django.contrib.postgres.indexes
import django.contrib.postgres.search import django.contrib.postgres.search
import django.core.validators import django.core.validators
import django.db.models.deletion import django.db.models.deletion
import django.utils.timezone import django.utils.timezone
from django.conf import settings
from django.db import migrations, models
from psqlextra.backend.migrations.operations.add_default_partition import (
PostgresAddDefaultPartition,
)
from psqlextra.backend.migrations.operations.create_partitioned_model import (
PostgresCreatePartitionedModel,
)
from psqlextra.manager.manager import PostgresManager
from psqlextra.models.partitioned import PostgresPartitionedModel
from psqlextra.types import PostgresPartitioningMethod
from uuid6 import uuid7
import api.rls
from api.db_utils import ( from api.db_utils import (
DB_PROWLER_PASSWORD, DB_PROWLER_PASSWORD,
DB_PROWLER_USER, DB_PROWLER_USER,
@@ -53,6 +40,18 @@ from api.models import (
StateChoices, StateChoices,
StatusChoices, StatusChoices,
) )
from django.conf import settings
from django.db import migrations, models
from psqlextra.backend.migrations.operations.add_default_partition import (
PostgresAddDefaultPartition,
)
from psqlextra.backend.migrations.operations.create_partitioned_model import (
PostgresCreatePartitionedModel,
)
from psqlextra.manager.manager import PostgresManager
from psqlextra.models.partitioned import PostgresPartitionedModel
from psqlextra.types import PostgresPartitioningMethod
from uuid6 import uuid7
DB_NAME = settings.DATABASES["default"]["NAME"] DB_NAME = settings.DATABASES["default"]["NAME"]
@@ -1,8 +1,7 @@
from api.db_utils import DB_PROWLER_USER
from django.conf import settings from django.conf import settings
from django.db import migrations from django.db import migrations
from api.db_utils import DB_PROWLER_USER
DB_NAME = settings.DATABASES["default"]["NAME"] DB_NAME = settings.DATABASES["default"]["NAME"]
+1 -2
View File
@@ -2,12 +2,11 @@
import uuid import uuid
import api.rls
import django.db.models.deletion import django.db.models.deletion
from django.conf import settings from django.conf import settings
from django.db import migrations, models from django.db import migrations, models
import api.rls
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
@@ -1,6 +1,5 @@
from django.db import migrations
from api.db_router import MainRouter from api.db_router import MainRouter
from django.db import migrations
def create_admin_role(apps, schema_editor): def create_admin_role(apps, schema_editor):
@@ -1,12 +1,11 @@
import json import json
from datetime import datetime, timedelta, timezone from datetime import UTC, datetime, timedelta
import django.db.models.deletion import django.db.models.deletion
from django.db import migrations, models
from django_celery_beat.models import PeriodicTask
from api.db_utils import rls_transaction from api.db_utils import rls_transaction
from api.models import Scan, StateChoices from api.models import Scan, StateChoices
from django.db import migrations, models
from django_celery_beat.models import PeriodicTask
def migrate_daily_scheduled_scan_tasks(apps, schema_editor): def migrate_daily_scheduled_scan_tasks(apps, schema_editor):
@@ -17,11 +16,11 @@ def migrate_daily_scheduled_scan_tasks(apps, schema_editor):
tenant_id = task_kwargs["tenant_id"] tenant_id = task_kwargs["tenant_id"]
provider_id = task_kwargs["provider_id"] provider_id = task_kwargs["provider_id"]
current_time = datetime.now(timezone.utc) current_time = datetime.now(UTC)
scheduled_time_today = datetime.combine( scheduled_time_today = datetime.combine(
current_time.date(), current_time.date(),
daily_scheduled_scan_task.start_time.time(), daily_scheduled_scan_task.start_time.time(),
tzinfo=timezone.utc, tzinfo=UTC,
) )
if current_time < scheduled_time_today: if current_time < scheduled_time_today:
@@ -2,10 +2,9 @@
from functools import partial from functools import partial
from django.db import migrations
from api.db_utils import IntegrationTypeEnum, PostgresEnumMigration, register_enum from api.db_utils import IntegrationTypeEnum, PostgresEnumMigration, register_enum
from api.models import Integration from api.models import Integration
from django.db import migrations
IntegrationTypeEnumMigration = PostgresEnumMigration( IntegrationTypeEnumMigration = PostgresEnumMigration(
enum_name="integration_type", enum_name="integration_type",
@@ -2,12 +2,11 @@
import uuid import uuid
import django.db.models.deletion
from django.db import migrations, models
import api.db_utils import api.db_utils
import api.rls import api.rls
import django.db.models.deletion
from api.rls import RowLevelSecurityConstraint from api.rls import RowLevelSecurityConstraint
from django.db import migrations, models
class Migration(migrations.Migration): class Migration(migrations.Migration):
@@ -1,8 +1,7 @@
# Generated by Django 5.1.5 on 2025-03-25 11:29 # Generated by Django 5.1.5 on 2025-03-25 11:29
from django.db import migrations, models
import api.db_utils import api.db_utils
from django.db import migrations, models
class Migration(migrations.Migration): class Migration(migrations.Migration):
@@ -1,8 +1,7 @@
# Generated by Django 5.1.7 on 2025-04-16 08:47 # Generated by Django 5.1.7 on 2025-04-16 08:47
from django.db import migrations
import api.db_utils import api.db_utils
from django.db import migrations
class Migration(migrations.Migration): class Migration(migrations.Migration):
@@ -2,12 +2,11 @@
import uuid import uuid
import api.rls
import django.db.models.deletion import django.db.models.deletion
import uuid6 import uuid6
from django.db import migrations, models from django.db import migrations, models
import api.rls
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
@@ -1,8 +1,7 @@
from functools import partial from functools import partial
from django.db import migrations
from api.db_utils import create_index_on_partitions, drop_index_on_partitions from api.db_utils import create_index_on_partitions, drop_index_on_partitions
from django.db import migrations
class Migration(migrations.Migration): class Migration(migrations.Migration):
@@ -1,8 +1,7 @@
from functools import partial from functools import partial
from django.db import migrations
from api.db_utils import create_index_on_partitions, drop_index_on_partitions from api.db_utils import create_index_on_partitions, drop_index_on_partitions
from django.db import migrations
class Migration(migrations.Migration): class Migration(migrations.Migration):
@@ -2,12 +2,11 @@
import uuid import uuid
import django.db.models.deletion
from django.db import migrations, models
import api.db_utils import api.db_utils
import api.rls import api.rls
import django.db.models.deletion
from api.rls import RowLevelSecurityConstraint from api.rls import RowLevelSecurityConstraint
from django.db import migrations, models
class Migration(migrations.Migration): class Migration(migrations.Migration):
@@ -1,8 +1,7 @@
from functools import partial from functools import partial
from django.db import migrations
from api.db_utils import create_index_on_partitions, drop_index_on_partitions from api.db_utils import create_index_on_partitions, drop_index_on_partitions
from django.db import migrations
class Migration(migrations.Migration): class Migration(migrations.Migration):
@@ -2,12 +2,11 @@
import uuid import uuid
import api.rls
import django.core.validators import django.core.validators
import django.db.models.deletion import django.db.models.deletion
from django.db import migrations, models from django.db import migrations, models
import api.rls
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
+2 -3
View File
@@ -2,13 +2,12 @@
import uuid import uuid
import api.db_utils
import api.rls
import django.db.models.deletion import django.db.models.deletion
from django.conf import settings from django.conf import settings
from django.db import migrations, models from django.db import migrations, models
import api.db_utils
import api.rls
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
@@ -2,10 +2,9 @@
from functools import partial from functools import partial
from django.db import migrations
from api.db_utils import PostgresEnumMigration, ProcessorTypeEnum, register_enum from api.db_utils import PostgresEnumMigration, ProcessorTypeEnum, register_enum
from api.models import Processor from api.models import Processor
from django.db import migrations
ProcessorTypeEnumMigration = PostgresEnumMigration( ProcessorTypeEnumMigration = PostgresEnumMigration(
enum_name="processor_type", enum_name="processor_type",
@@ -2,12 +2,11 @@
import uuid import uuid
import django.db.models.deletion
from django.db import migrations, models
import api.db_utils import api.db_utils
import api.rls import api.rls
import django.db.models.deletion
from api.rls import RowLevelSecurityConstraint from api.rls import RowLevelSecurityConstraint
from django.db import migrations, models
class Migration(migrations.Migration): class Migration(migrations.Migration):
@@ -1,8 +1,7 @@
from functools import partial from functools import partial
from django.db import migrations
from api.db_utils import create_index_on_partitions, drop_index_on_partitions from api.db_utils import create_index_on_partitions, drop_index_on_partitions
from django.db import migrations
class Migration(migrations.Migration): class Migration(migrations.Migration):
@@ -1,8 +1,7 @@
from functools import partial from functools import partial
from django.db import migrations
from api.db_utils import create_index_on_partitions, drop_index_on_partitions from api.db_utils import create_index_on_partitions, drop_index_on_partitions
from django.db import migrations
class Migration(migrations.Migration): class Migration(migrations.Migration):
@@ -1,8 +1,7 @@
# Generated by Django 5.1.7 on 2025-07-09 14:44 # Generated by Django 5.1.7 on 2025-07-09 14:44
from django.db import migrations
import api.db_utils import api.db_utils
from django.db import migrations
class Migration(migrations.Migration): class Migration(migrations.Migration):
@@ -2,15 +2,14 @@
import uuid import uuid
import api.db_utils
import api.rls
import django.core.validators import django.core.validators
import django.db.models.deletion import django.db.models.deletion
import drf_simple_apikey.models import drf_simple_apikey.models
from django.conf import settings from django.conf import settings
from django.db import migrations, models from django.db import migrations, models
import api.db_utils
import api.rls
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
@@ -4,15 +4,14 @@ import json
import logging import logging
import uuid import uuid
import api.rls
import django.db.models.deletion import django.db.models.deletion
from api.db_router import MainRouter
from config.custom_logging import BackendLogger from config.custom_logging import BackendLogger
from cryptography.fernet import Fernet from cryptography.fernet import Fernet
from django.conf import settings from django.conf import settings
from django.db import migrations, models from django.db import migrations, models
import api.rls
from api.db_router import MainRouter
logger = logging.getLogger(BackendLogger.API) logger = logging.getLogger(BackendLogger.API)
@@ -1,8 +1,7 @@
# Generated by Django 5.1.7 on 2025-10-14 00:00 # Generated by Django 5.1.7 on 2025-10-14 00:00
from django.db import migrations
import api.db_utils import api.db_utils
from django.db import migrations
class Migration(migrations.Migration): class Migration(migrations.Migration):
@@ -2,14 +2,13 @@
import uuid import uuid
import api.rls
import django.contrib.postgres.fields import django.contrib.postgres.fields
import django.core.validators import django.core.validators
import django.db.models.deletion import django.db.models.deletion
from django.conf import settings from django.conf import settings
from django.db import migrations, models from django.db import migrations, models
import api.rls
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
@@ -1,8 +1,7 @@
# Generated by Django 5.1.10 on 2025-09-09 09:25 # Generated by Django 5.1.10 on 2025-09-09 09:25
from django.db import migrations
import api.db_utils import api.db_utils
from django.db import migrations
class Migration(migrations.Migration): class Migration(migrations.Migration):
@@ -1,8 +1,7 @@
# Generated by Django 5.1.13 on 2025-11-05 08:37 # Generated by Django 5.1.13 on 2025-11-05 08:37
from django.db import migrations
import api.db_utils import api.db_utils
from django.db import migrations
class Migration(migrations.Migration): class Migration(migrations.Migration):
@@ -2,11 +2,10 @@
import uuid import uuid
import api.rls
import django.db.models.deletion import django.db.models.deletion
from django.db import migrations, models from django.db import migrations, models
import api.rls
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
@@ -2,11 +2,10 @@
import uuid import uuid
import api.rls
import django.db.models.deletion import django.db.models.deletion
from django.db import migrations, models from django.db import migrations, models
import api.rls
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
@@ -2,11 +2,10 @@
import uuid import uuid
import api.rls
import django.db.models.deletion import django.db.models.deletion
from django.db import migrations, models from django.db import migrations, models
import api.rls
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
@@ -2,11 +2,10 @@
import uuid import uuid
import api.rls
import django.db.models.deletion import django.db.models.deletion
from django.db import migrations, models from django.db import migrations, models
import api.rls
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
@@ -1,10 +1,9 @@
# Generated by Django 5.1.14 on 2025-12-10 # Generated by Django 5.1.14 on 2025-12-10
from django.db import migrations
from tasks.tasks import backfill_daily_severity_summaries_task
from api.db_router import MainRouter from api.db_router import MainRouter
from api.rls import Tenant from api.rls import Tenant
from django.db import migrations
from tasks.tasks import backfill_daily_severity_summaries_task
def trigger_backfill_task(apps, schema_editor): def trigger_backfill_task(apps, schema_editor):
@@ -1,10 +1,9 @@
import uuid import uuid
import django.db.models.deletion
from django.db import migrations, models
import api.db_utils import api.db_utils
import api.rls import api.rls
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration): class Migration(migrations.Migration):
@@ -1,8 +1,7 @@
# Generated by Django migration for Alibaba Cloud provider support # Generated by Django migration for Alibaba Cloud provider support
from django.db import migrations
import api.db_utils import api.db_utils
from django.db import migrations
class Migration(migrations.Migration): class Migration(migrations.Migration):
@@ -1,10 +1,9 @@
import uuid import uuid
import django.db.models.deletion
from django.db import migrations, models
import api.db_utils import api.db_utils
import api.rls import api.rls
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration): class Migration(migrations.Migration):
@@ -1,10 +1,9 @@
import uuid import uuid
import api.rls
import django.db.models.deletion import django.db.models.deletion
from django.db import migrations, models from django.db import migrations, models
import api.rls
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
@@ -1,10 +1,9 @@
import uuid import uuid
import django.db.models.deletion
from django.db import migrations, models
import api.db_utils import api.db_utils
import api.rls import api.rls
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration): class Migration(migrations.Migration):
@@ -1,12 +1,10 @@
# Generated by Django 5.1.13 on 2025-11-06 16:20 # Generated by Django 5.1.13 on 2025-11-06 16:20
import api.rls
import django.db.models.deletion import django.db.models.deletion
from django.db import migrations, models from django.db import migrations, models
from uuid6 import uuid7 from uuid6 import uuid7
import api.rls
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
@@ -1,8 +1,7 @@
from functools import partial from functools import partial
from django.db import migrations
from api.db_utils import create_index_on_partitions, drop_index_on_partitions from api.db_utils import create_index_on_partitions, drop_index_on_partitions
from django.db import migrations
class Migration(migrations.Migration): class Migration(migrations.Migration):
@@ -1,8 +1,7 @@
# Generated by Django migration for Cloudflare provider support # Generated by Django migration for Cloudflare provider support
from django.db import migrations
import api.db_utils import api.db_utils
from django.db import migrations
class Migration(migrations.Migration): class Migration(migrations.Migration):
@@ -1,8 +1,7 @@
# Generated by Django migration for OpenStack provider support # Generated by Django migration for OpenStack provider support
from django.db import migrations
import api.db_utils import api.db_utils
from django.db import migrations
class Migration(migrations.Migration): class Migration(migrations.Migration):
@@ -2,9 +2,8 @@
# on different database connections, causing a deadlock when combined with RunPython # on different database connections, causing a deadlock when combined with RunPython
# in the same migration. # in the same migration.
from django.db import migrations
from api.db_router import MainRouter from api.db_router import MainRouter
from django.db import migrations
def backfill_graph_data_ready(apps, schema_editor): def backfill_graph_data_ready(apps, schema_editor):
@@ -2,14 +2,13 @@
import uuid import uuid
import api.rls
import django.db.models.deletion import django.db.models.deletion
from django.contrib.postgres.indexes import GinIndex, OpClass from django.contrib.postgres.indexes import GinIndex, OpClass
from django.db import migrations, models from django.db import migrations, models
from django.db.models.functions import Upper from django.db.models.functions import Upper
from django.utils import timezone from django.utils import timezone
import api.rls
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
@@ -1,10 +1,9 @@
# Generated by Django 5.1.14 on 2026-02-02 # Generated by Django 5.1.14 on 2026-02-02
from django.db import migrations
from tasks.tasks import backfill_finding_group_summaries_task
from api.db_router import MainRouter from api.db_router import MainRouter
from api.rls import Tenant from api.rls import Tenant
from django.db import migrations
from tasks.tasks import backfill_finding_group_summaries_task
def trigger_backfill_task(apps, schema_editor): def trigger_backfill_task(apps, schema_editor):
@@ -1,6 +1,5 @@
from django.db import migrations
import api.db_utils import api.db_utils
from django.db import migrations
class Migration(migrations.Migration): class Migration(migrations.Migration):
@@ -1,6 +1,5 @@
from django.db import migrations
import api.db_utils import api.db_utils
from django.db import migrations
class Migration(migrations.Migration): class Migration(migrations.Migration):
@@ -1,6 +1,5 @@
from django.db import migrations from django.db import migrations
TASK_NAME = "attack-paths-cleanup-stale-scans" TASK_NAME = "attack-paths-cleanup-stale-scans"
INTERVAL_HOURS = 1 INTERVAL_HOURS = 1
@@ -1,6 +1,5 @@
from django.db import migrations
import api.db_utils import api.db_utils
from django.db import migrations
class Migration(migrations.Migration): class Migration(migrations.Migration):
@@ -1,8 +1,7 @@
from django.db import migrations
from tasks.tasks import backfill_finding_group_summaries_task
from api.db_router import MainRouter from api.db_router import MainRouter
from api.rls import Tenant from api.rls import Tenant
from django.db import migrations
from tasks.tasks import backfill_finding_group_summaries_task
def trigger_backfill_task(apps, schema_editor): def trigger_backfill_task(apps, schema_editor):
@@ -1,8 +1,7 @@
from functools import partial from functools import partial
from django.db import migrations
from api.db_utils import create_index_on_partitions, drop_index_on_partitions from api.db_utils import create_index_on_partitions, drop_index_on_partitions
from django.db import migrations
class Migration(migrations.Migration): class Migration(migrations.Migration):
@@ -1,6 +1,5 @@
from django.db import migrations
import api.db_utils import api.db_utils
from django.db import migrations
class Migration(migrations.Migration): class Migration(migrations.Migration):
@@ -1,6 +1,5 @@
from django.db import migrations from django.db import migrations
TASK_NAME = "reconcile-orphan-tasks" TASK_NAME = "reconcile-orphan-tasks"
INTERVAL_MINUTES = 2 INTERVAL_MINUTES = 2
+30 -31
View File
@@ -1,37 +1,11 @@
import json import json
import logging import logging
import re import re
from datetime import datetime, timedelta, timezone from datetime import UTC, datetime, timedelta
from uuid import UUID, uuid4 from uuid import UUID, uuid4
import defusedxml import defusedxml
from allauth.socialaccount.models import SocialApp from allauth.socialaccount.models import SocialApp
from config.custom_logging import BackendLogger
from config.settings.social_login import SOCIALACCOUNT_PROVIDERS
from cryptography.fernet import Fernet, InvalidToken
from defusedxml import ElementTree as ET
from django.conf import settings
from django.contrib.auth.models import AbstractBaseUser
from django.contrib.postgres.fields import ArrayField
from django.contrib.postgres.indexes import GinIndex, OpClass
from django.contrib.postgres.search import SearchVector, SearchVectorField
from django.contrib.sites.models import Site
from django.core.exceptions import ValidationError
from django.core.validators import MinLengthValidator
from django.db import models
from django.db.models import Q
from django.db.models.functions import Upper
from django.utils import timezone as django_timezone
from django.utils.translation import gettext_lazy as _
from django_celery_beat.models import PeriodicTask
from django_celery_results.models import TaskResult
from drf_simple_apikey.crypto import get_crypto
from drf_simple_apikey.models import AbstractAPIKey, AbstractAPIKeyManager
from psqlextra.manager import PostgresManager
from psqlextra.models import PostgresPartitionedModel
from psqlextra.types import PostgresPartitioningMethod
from uuid6 import uuid7
from api.db_router import MainRouter from api.db_router import MainRouter
from api.db_utils import ( from api.db_utils import (
CustomUserManager, CustomUserManager,
@@ -58,7 +32,32 @@ from api.rls import (
RowLevelSecurityProtectedModel, RowLevelSecurityProtectedModel,
Tenant, Tenant,
) )
from config.custom_logging import BackendLogger
from config.settings.social_login import SOCIALACCOUNT_PROVIDERS
from cryptography.fernet import Fernet, InvalidToken
from defusedxml import ElementTree as ET
from django.conf import settings
from django.contrib.auth.models import AbstractBaseUser
from django.contrib.postgres.fields import ArrayField
from django.contrib.postgres.indexes import GinIndex, OpClass
from django.contrib.postgres.search import SearchVector, SearchVectorField
from django.contrib.sites.models import Site
from django.core.exceptions import ValidationError
from django.core.validators import MinLengthValidator
from django.db import models
from django.db.models import Q
from django.db.models.functions import Upper
from django.utils import timezone as django_timezone
from django.utils.translation import gettext_lazy as _
from django_celery_beat.models import PeriodicTask
from django_celery_results.models import TaskResult
from drf_simple_apikey.crypto import get_crypto
from drf_simple_apikey.models import AbstractAPIKey, AbstractAPIKeyManager
from prowler.lib.check.models import Severity from prowler.lib.check.models import Severity
from psqlextra.manager import PostgresManager
from psqlextra.models import PostgresPartitionedModel
from psqlextra.types import PostgresPartitioningMethod
from uuid6 import uuid7
fernet = Fernet(settings.SECRETS_ENCRYPTION_KEY.encode()) fernet = Fernet(settings.SECRETS_ENCRYPTION_KEY.encode())
@@ -1427,8 +1426,8 @@ class Role(RowLevelSecurityProtectedModel):
@classmethod @classmethod
def filter_by_permission_state(cls, queryset, value): def filter_by_permission_state(cls, queryset, value):
q_all_true = Q(**{field: True for field in cls.PERMISSION_FIELDS}) q_all_true = Q(**dict.fromkeys(cls.PERMISSION_FIELDS, True))
q_all_false = Q(**{field: False for field in cls.PERMISSION_FIELDS}) q_all_false = Q(**dict.fromkeys(cls.PERMISSION_FIELDS, False))
if value == PermissionChoices.UNLIMITED: if value == PermissionChoices.UNLIMITED:
return queryset.filter(q_all_true) return queryset.filter(q_all_true)
@@ -2011,11 +2010,11 @@ class SAMLToken(models.Model):
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
if not self.expires_at: if not self.expires_at:
self.expires_at = datetime.now(timezone.utc) + timedelta(seconds=15) self.expires_at = datetime.now(UTC) + timedelta(seconds=15)
super().save(*args, **kwargs) super().save(*args, **kwargs)
def is_expired(self) -> bool: def is_expired(self) -> bool:
return datetime.now(timezone.utc) >= self.expires_at return datetime.now(UTC) >= self.expires_at
class SAMLDomainIndex(models.Model): class SAMLDomainIndex(models.Model):
+20 -23
View File
@@ -1,21 +1,20 @@
from datetime import datetime, timezone from collections.abc import Generator
from typing import Generator, Optional from datetime import UTC, datetime
from dateutil.relativedelta import relativedelta
from django.conf import settings
from psqlextra.partitioning import (
PostgresPartitioningManager,
PostgresRangePartition,
PostgresRangePartitioningStrategy,
PostgresTimePartitionSize,
PostgresPartitioningError,
)
from psqlextra.partitioning.config import PostgresPartitioningConfig
from uuid6 import UUID
from api.models import Finding, ResourceFindingMapping from api.models import Finding, ResourceFindingMapping
from api.rls import RowLevelSecurityConstraint from api.rls import RowLevelSecurityConstraint
from api.uuid_utils import datetime_to_uuid7 from api.uuid_utils import datetime_to_uuid7
from dateutil.relativedelta import relativedelta
from django.conf import settings
from psqlextra.partitioning import (
PostgresPartitioningError,
PostgresPartitioningManager,
PostgresRangePartition,
PostgresRangePartitioningStrategy,
PostgresTimePartitionSize,
)
from psqlextra.partitioning.config import PostgresPartitioningConfig
from uuid6 import UUID
class PostgresUUIDv7RangePartition(PostgresRangePartition): class PostgresUUIDv7RangePartition(PostgresRangePartition):
@@ -24,7 +23,7 @@ class PostgresUUIDv7RangePartition(PostgresRangePartition):
from_values: UUID, from_values: UUID,
to_values: UUID, to_values: UUID,
size: PostgresTimePartitionSize, size: PostgresTimePartitionSize,
name_format: Optional[str] = None, name_format: str | None = None,
**kwargs, **kwargs,
) -> None: ) -> None:
self.from_values = from_values self.from_values = from_values
@@ -38,9 +37,7 @@ class PostgresUUIDv7RangePartition(PostgresRangePartition):
start_timestamp_ms = self.from_values.time start_timestamp_ms = self.from_values.time
self.start_datetime = datetime.fromtimestamp( self.start_datetime = datetime.fromtimestamp(start_timestamp_ms / 1000, UTC)
start_timestamp_ms / 1000, timezone.utc
)
def name(self) -> str: def name(self) -> str:
if not self.name_format: if not self.name_format:
@@ -82,8 +79,8 @@ class PostgresUUIDv7PartitioningStrategy(PostgresRangePartitioningStrategy):
size: PostgresTimePartitionSize, size: PostgresTimePartitionSize,
count: int, count: int,
start_date: datetime = None, start_date: datetime = None,
max_age: Optional[relativedelta] = None, max_age: relativedelta | None = None,
name_format: Optional[str] = None, name_format: str | None = None,
**kwargs, **kwargs,
) -> None: ) -> None:
self.start_date = start_date.replace( self.start_date = start_date.replace(
@@ -151,7 +148,7 @@ class PostgresUUIDv7PartitioningStrategy(PostgresRangePartitioningStrategy):
Returns: Returns:
datetime: A `datetime` object representing the start of the current month in UTC. datetime: A `datetime` object representing the start of the current month in UTC.
""" """
return datetime.now(timezone.utc).replace( return datetime.now(UTC).replace(
day=1, hour=0, minute=0, second=0, microsecond=0 day=1, hour=0, minute=0, second=0, microsecond=0
) )
@@ -171,7 +168,7 @@ manager = PostgresPartitioningManager(
PostgresPartitioningConfig( PostgresPartitioningConfig(
model=Finding, model=Finding,
strategy=PostgresUUIDv7PartitioningStrategy( strategy=PostgresUUIDv7PartitioningStrategy(
start_date=datetime.now(timezone.utc), start_date=datetime.now(UTC),
size=PostgresTimePartitionSize( size=PostgresTimePartitionSize(
months=settings.FINDINGS_TABLE_PARTITION_MONTHS months=settings.FINDINGS_TABLE_PARTITION_MONTHS
), ),
@@ -187,7 +184,7 @@ manager = PostgresPartitioningManager(
PostgresPartitioningConfig( PostgresPartitioningConfig(
model=ResourceFindingMapping, model=ResourceFindingMapping,
strategy=PostgresUUIDv7PartitioningStrategy( strategy=PostgresUUIDv7PartitioningStrategy(
start_date=datetime.now(timezone.utc), start_date=datetime.now(UTC),
size=PostgresTimePartitionSize( size=PostgresTimePartitionSize(
months=settings.FINDINGS_TABLE_PARTITION_MONTHS months=settings.FINDINGS_TABLE_PARTITION_MONTHS
), ),
+3 -4
View File
@@ -1,11 +1,10 @@
from enum import Enum from enum import Enum
from django.db.models import QuerySet
from rest_framework.exceptions import PermissionDenied
from rest_framework.permissions import BasePermission
from api.db_router import MainRouter from api.db_router import MainRouter
from api.models import Provider, Role, User from api.models import Provider, Role, User
from django.db.models import QuerySet
from rest_framework.exceptions import PermissionDenied
from rest_framework.permissions import BasePermission
class Permissions(Enum): class Permissions(Enum):
+1 -2
View File
@@ -1,10 +1,9 @@
from contextlib import nullcontext from contextlib import nullcontext
from api.db_utils import rls_transaction
from rest_framework.renderers import BaseRenderer from rest_framework.renderers import BaseRenderer
from rest_framework_json_api.renderers import JSONRenderer from rest_framework_json_api.renderers import JSONRenderer
from api.db_utils import rls_transaction
class PlainTextRenderer(BaseRenderer): class PlainTextRenderer(BaseRenderer):
media_type = "text/plain" media_type = "text/plain"
+1 -2
View File
@@ -1,12 +1,11 @@
from typing import Any from typing import Any
from uuid import uuid4 from uuid import uuid4
from api.db_utils import DB_USER, POSTGRES_TENANT_VAR
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.db import DEFAULT_DB_ALIAS, models from django.db import DEFAULT_DB_ALIAS, models
from django.db.backends.ddl_references import Statement, Table from django.db.backends.ddl_references import Statement, Table
from api.db_utils import DB_USER, POSTGRES_TENANT_VAR
class Tenant(models.Model): class Tenant(models.Model):
""" """
+6 -7
View File
@@ -1,10 +1,3 @@
from celery import states
from celery.signals import before_task_publish
from config.celery import celery_app
from django.db.models.signals import post_delete, pre_delete
from django.dispatch import receiver
from django_celery_results.backends.database import DatabaseBackend
from api.db_utils import delete_related_daily_task from api.db_utils import delete_related_daily_task
from api.models import ( from api.models import (
LighthouseProviderConfiguration, LighthouseProviderConfiguration,
@@ -14,6 +7,12 @@ from api.models import (
TenantAPIKey, TenantAPIKey,
User, User,
) )
from celery import states
from celery.signals import before_task_publish
from config.celery import celery_app
from django.db.models.signals import post_delete, pre_delete
from django.dispatch import receiver
from django_celery_results.backends.database import DatabaseBackend
def create_task_result_on_publish(sender=None, headers=None, **kwargs): # noqa: F841 def create_task_result_on_publish(sender=None, headers=None, **kwargs): # noqa: F841
+1 -1
View File
@@ -7,7 +7,7 @@ enforces the tenant gate (:class:`api.sse.channelmanager.SSEChannelManager`),
and the channel-name helpers (:func:`api.sse.utils.make_channel_name`). and the channel-name helpers (:func:`api.sse.utils.make_channel_name`).
""" """
from api.sse.utils import make_channel_name
from api.sse.base_views import BaseSSEViewSet from api.sse.base_views import BaseSSEViewSet
from api.sse.utils import make_channel_name
__all__ = ["BaseSSEViewSet", "make_channel_name"] __all__ = ["BaseSSEViewSet", "make_channel_name"]
+2 -3
View File
@@ -5,11 +5,10 @@ from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from uuid import UUID from uuid import UUID
from api.sse.utils import tenant_id_from_channel
from django_eventstream.channelmanager import DefaultChannelManager from django_eventstream.channelmanager import DefaultChannelManager
from rest_framework.request import Request from rest_framework.request import Request
from api.sse.utils import tenant_id_from_channel
if TYPE_CHECKING: if TYPE_CHECKING:
from api.models import User from api.models import User
@@ -41,7 +40,7 @@ class SSEChannelManager(DefaultChannelManager):
if tenant_id_from_channel(channel) == request_tenant_id if tenant_id_from_channel(channel) == request_tenant_id
} }
def can_read_channel(self, user: "User | None", channel: str) -> bool: def can_read_channel(self, user: User | None, channel: str) -> bool:
"""Re-verify tenant membership once the stream is established. """Re-verify tenant membership once the stream is established.
Args: Args:
@@ -1,15 +1,14 @@
import time import time
from datetime import datetime, timedelta, timezone from datetime import UTC, datetime, timedelta
from uuid import uuid4 from uuid import uuid4
import pytest import pytest
from api.models import Membership, Role, TenantAPIKey, User, UserRoleRelationship
from conftest import TEST_PASSWORD, get_api_tokens, get_authorization_header from conftest import TEST_PASSWORD, get_api_tokens, get_authorization_header
from django.urls import reverse from django.urls import reverse
from drf_simple_apikey.crypto import get_crypto from drf_simple_apikey.crypto import get_crypto
from rest_framework.test import APIClient from rest_framework.test import APIClient
from api.models import Membership, Role, TenantAPIKey, User, UserRoleRelationship
@pytest.mark.django_db @pytest.mark.django_db
def test_basic_authentication(): def test_basic_authentication():
@@ -468,7 +467,7 @@ class TestAPIKeyErrors:
name="Expired Key", name="Expired Key",
tenant_id=tenants_fixture[0].id, tenant_id=tenants_fixture[0].id,
entity=create_test_user, entity=create_test_user,
expiry_date=datetime.now(timezone.utc) - timedelta(days=1), expiry_date=datetime.now(UTC) - timedelta(days=1),
) )
api_key_headers = get_api_key_header(raw_key) api_key_headers = get_api_key_header(raw_key)
@@ -500,7 +499,7 @@ class TestAPIKeyErrors:
# Create a valid-looking key with non-existent UUID # Create a valid-looking key with non-existent UUID
crypto = get_crypto() crypto = get_crypto()
fake_uuid = str(uuid4()) fake_uuid = str(uuid4())
fake_expiry = (datetime.now(timezone.utc) + timedelta(days=30)).timestamp() fake_expiry = (datetime.now(UTC) + timedelta(days=30)).timestamp()
payload = {"_pk": fake_uuid, "_exp": fake_expiry} payload = {"_pk": fake_uuid, "_exp": fake_expiry}
encrypted_payload = crypto.generate(payload) encrypted_payload = crypto.generate(payload)
@@ -723,7 +722,7 @@ class TestAPIKeyLifecycle:
assert created_data["attributes"]["revoked"] is False assert created_data["attributes"]["revoked"] is False
# Create API key with expiry # Create API key with expiry
future_expiry = (datetime.now(timezone.utc) + timedelta(days=90)).isoformat() future_expiry = (datetime.now(UTC) + timedelta(days=90)).isoformat()
create_with_expiry_response = client.post( create_with_expiry_response = client.post(
reverse("api-key-list"), reverse("api-key-list"),
data={ data={
@@ -927,9 +926,9 @@ class TestAPIKeyLifecycle:
auth_response = client.get(reverse("provider-list"), headers=api_key_headers) auth_response = client.get(reverse("provider-list"), headers=api_key_headers)
# Must return 401 Unauthorized, not 500 Internal Server Error # Must return 401 Unauthorized, not 500 Internal Server Error
assert ( assert auth_response.status_code == 401, (
auth_response.status_code == 401 f"Expected 401 but got {auth_response.status_code}: {auth_response.json()}"
), f"Expected 401 but got {auth_response.status_code}: {auth_response.json()}" )
# Verify error message is present # Verify error message is present
response_json = auth_response.json() response_json = auth_response.json()
@@ -1267,7 +1266,7 @@ class TestAPIKeyRLSBypass:
name="Expired Test Key", name="Expired Test Key",
tenant_id=tenant.id, tenant_id=tenant.id,
entity=create_test_user, entity=create_test_user,
expiry_date=datetime.now(timezone.utc) - timedelta(days=1), expiry_date=datetime.now(UTC) - timedelta(days=1),
) )
api_key_headers = get_api_key_header(raw_key) api_key_headers = get_api_key_header(raw_key)
@@ -1,12 +1,11 @@
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import pytest import pytest
from api.models import Provider
from conftest import get_api_tokens, get_authorization_header from conftest import get_api_tokens, get_authorization_header
from django.urls import reverse from django.urls import reverse
from rest_framework.test import APIClient from rest_framework.test import APIClient
from api.models import Provider
@patch("api.v1.views.Task.objects.get") @patch("api.v1.views.Task.objects.get")
@patch("api.v1.views.delete_provider_task.delay") @patch("api.v1.views.delete_provider_task.delay")
@@ -1,11 +1,10 @@
"""Tests for rls_transaction retry and fallback logic.""" """Tests for rls_transaction retry and fallback logic."""
import pytest import pytest
from api.db_utils import rls_transaction
from django.db import DEFAULT_DB_ALIAS from django.db import DEFAULT_DB_ALIAS
from rest_framework_json_api.serializers import ValidationError from rest_framework_json_api.serializers import ValidationError
from api.db_utils import rls_transaction
@pytest.mark.django_db @pytest.mark.django_db
class TestRLSTransaction: class TestRLSTransaction:
@@ -1,10 +1,9 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from conftest import TEST_PASSWORD, TEST_USER, get_api_tokens, get_authorization_header
from django.urls import reverse from django.urls import reverse
from conftest import TEST_USER, TEST_PASSWORD, get_api_tokens, get_authorization_header
@patch("api.v1.views.schedule_provider_scan") @patch("api.v1.views.schedule_provider_scan")
@pytest.mark.django_db @pytest.mark.django_db
+1 -2
View File
@@ -3,11 +3,10 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from allauth.socialaccount.models import SocialLogin from allauth.socialaccount.models import SocialLogin
from django.contrib.auth import get_user_model
from api.adapters import ProwlerSocialAccountAdapter from api.adapters import ProwlerSocialAccountAdapter
from api.db_router import MainRouter from api.db_router import MainRouter
from api.models import SAMLConfiguration from api.models import SAMLConfiguration
from django.contrib.auth import get_user_model
User = get_user_model() User = get_user_model()
+2 -3
View File
@@ -4,11 +4,9 @@ import types
from pathlib import Path from pathlib import Path
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest
from django.conf import settings
import api import api
import api.apps as api_apps_module import api.apps as api_apps_module
import pytest
from api.apps import ( from api.apps import (
PRIVATE_KEY_FILE, PRIVATE_KEY_FILE,
PUBLIC_KEY_FILE, PUBLIC_KEY_FILE,
@@ -16,6 +14,7 @@ from api.apps import (
VERIFYING_KEY_ENV, VERIFYING_KEY_ENV,
ApiConfig, ApiConfig,
) )
from django.conf import settings
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
@@ -1,14 +1,12 @@
from types import SimpleNamespace from types import SimpleNamespace
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest
import neo4j import neo4j
import neo4j.exceptions import neo4j.exceptions
import pytest
from rest_framework.exceptions import APIException, PermissionDenied, ValidationError
from api.attack_paths import database as graph_database from api.attack_paths import database as graph_database
from api.attack_paths import views_helpers from api.attack_paths import views_helpers
from rest_framework.exceptions import APIException, PermissionDenied, ValidationError
from tasks.jobs.attack_paths.config import ( from tasks.jobs.attack_paths.config import (
PROVIDER_ELEMENT_ID_PROPERTY, PROVIDER_ELEMENT_ID_PROPERTY,
get_provider_label, get_provider_label,
@@ -6,15 +6,13 @@ never contacts Neo4j. These tests validate the database module behavior itself.
""" """
import threading import threading
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import api.attack_paths.database as db_module
import neo4j import neo4j
import neo4j.exceptions import neo4j.exceptions
import pytest import pytest
import api.attack_paths.database as db_module
class TestLazyInitialization: class TestLazyInitialization:
"""Test that Neo4j driver is initialized lazily on first use.""" """Test that Neo4j driver is initialized lazily on first use."""
@@ -1,15 +1,14 @@
import time import time
from datetime import datetime, timedelta, timezone from datetime import UTC, datetime, timedelta
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from uuid import uuid4 from uuid import uuid4
import pytest import pytest
from django.test import RequestFactory
from rest_framework.exceptions import AuthenticationFailed
from api.authentication import SSEAuthentication, TenantAPIKeyAuthentication from api.authentication import SSEAuthentication, TenantAPIKeyAuthentication
from api.db_router import MainRouter from api.db_router import MainRouter
from api.models import TenantAPIKey from api.models import TenantAPIKey
from django.test import RequestFactory
from rest_framework.exceptions import AuthenticationFailed
@pytest.mark.django_db @pytest.mark.django_db
@@ -104,7 +103,7 @@ class TestTenantAPIKeyAuthentication:
# Verify that last_used_at was updated # Verify that last_used_at was updated
api_key.refresh_from_db() api_key.refresh_from_db()
assert api_key.last_used_at is not None assert api_key.last_used_at is not None
assert (datetime.now(timezone.utc) - api_key.last_used_at).seconds < 5 assert (datetime.now(UTC) - api_key.last_used_at).seconds < 5
def test_authenticate_valid_api_key_uses_admin_database( def test_authenticate_valid_api_key_uses_admin_database(
self, auth_backend, api_keys_fixture, request_factory self, auth_backend, api_keys_fixture, request_factory
@@ -195,7 +194,7 @@ class TestTenantAPIKeyAuthentication:
name="Expired API Key", name="Expired API Key",
tenant_id=tenant.id, tenant_id=tenant.id,
entity=user, entity=user,
expiry_date=datetime.now(timezone.utc) - timedelta(days=1), expiry_date=datetime.now(UTC) - timedelta(days=1),
) )
request = request_factory.get("/") request = request_factory.get("/")
@@ -217,7 +216,7 @@ class TestTenantAPIKeyAuthentication:
# Manually create an encrypted key with a non-existent ID # Manually create an encrypted key with a non-existent ID
payload = { payload = {
"_pk": non_existent_uuid, "_pk": non_existent_uuid,
"_exp": (datetime.now(timezone.utc) + timedelta(days=30)).timestamp(), "_exp": (datetime.now(UTC) + timedelta(days=30)).timestamp(),
} }
encrypted_key = auth_backend.key_crypto.generate(payload) encrypted_key = auth_backend.key_crypto.generate(payload)
fake_key = f"{api_key.prefix}.{encrypted_key}" fake_key = f"{api_key.prefix}.{encrypted_key}"
@@ -368,7 +367,7 @@ class TestTenantAPIKeyAuthentication:
name="Short-lived API Key", name="Short-lived API Key",
tenant_id=tenant.id, tenant_id=tenant.id,
entity=user, entity=user,
expiry_date=datetime.now(timezone.utc) + timedelta(seconds=1), expiry_date=datetime.now(UTC) + timedelta(seconds=1),
) )
# Wait for the key to expire # Wait for the key to expire
@@ -1,7 +1,6 @@
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
from api import compliance as compliance_module from api import compliance as compliance_module
from api.compliance import ( from api.compliance import (
generate_compliance_overview_template, generate_compliance_overview_template,
@@ -3,13 +3,11 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from rest_framework.exceptions import ValidationError
from api.attack_paths.cypher_sanitizer import ( from api.attack_paths.cypher_sanitizer import (
inject_provider_label, inject_provider_label,
validate_custom_query, validate_custom_query,
) )
from rest_framework.exceptions import ValidationError
PROVIDER_ID = "019c41ee-7df3-7dec-a684-d839f95619f8" PROVIDER_ID = "019c41ee-7df3-7dec-a684-d839f95619f8"
LABEL = "_Provider_019c41ee7df37deca684d839f95619f8" LABEL = "_Provider_019c41ee7df37deca684d839f95619f8"
@@ -202,9 +200,7 @@ class TestClauseSplitting:
def test_multiple_match_clauses(self): def test_multiple_match_clauses(self):
cypher = ( cypher = (
"MATCH (a:AWSAccount)--(b:AWSRole) " "MATCH (a:AWSAccount)--(b:AWSRole) MATCH (b)--(c:AWSPolicy) RETURN a, b, c"
"MATCH (b)--(c:AWSPolicy) "
"RETURN a, b, c"
) )
result = _inject(cypher) result = _inject(cypher)
assert f"(a:AWSAccount:{LABEL})" in result assert f"(a:AWSAccount:{LABEL})" in result
@@ -265,9 +261,7 @@ class TestRealWorldQueries:
def test_custom_bare_query(self): def test_custom_bare_query(self):
cypher = ( cypher = (
"MATCH (a)-[:HAS_POLICY]->(b)\n" "MATCH (a)-[:HAS_POLICY]->(b)\nWHERE a.name CONTAINS 'admin'\nRETURN a, b"
"WHERE a.name CONTAINS 'admin'\n"
"RETURN a, b"
) )
result = _inject(cypher) result = _inject(cypher)
assert f"(a:{LABEL})" in result assert f"(a:{LABEL})" in result
@@ -344,9 +338,7 @@ class TestEdgeCases:
assert f"(outer:AWSAccount:{LABEL})" in result assert f"(outer:AWSAccount:{LABEL})" in result
def test_multiple_protected_regions(self): def test_multiple_protected_regions(self):
cypher = ( cypher = "MATCH (n:X {a: 'hello'}) WHERE n.b = \"world\" // comment\nRETURN n"
"MATCH (n:X {a: 'hello'}) " 'WHERE n.b = "world" ' "// comment\n" "RETURN n"
)
result = _inject(cypher) result = _inject(cypher)
assert "'hello'" in result assert "'hello'" in result
assert '"world"' in result assert '"world"' in result
+5 -5
View File
@@ -1,12 +1,12 @@
import pytest from unittest.mock import patch
from django.conf import settings
from django.db.migrations.recorder import MigrationRecorder
from django.db.utils import ConnectionRouter
import pytest
from api.db_router import MainRouter from api.db_router import MainRouter
from api.rls import Tenant from api.rls import Tenant
from config.django.base import DATABASE_ROUTERS as PROD_DATABASE_ROUTERS from config.django.base import DATABASE_ROUTERS as PROD_DATABASE_ROUTERS
from unittest.mock import patch from django.conf import settings
from django.db.migrations.recorder import MigrationRecorder
from django.db.utils import ConnectionRouter
@patch("api.db_router.MainRouter.admin_db", new="admin") @patch("api.db_router.MainRouter.admin_db", new="admin")
+16 -19
View File
@@ -1,14 +1,8 @@
from datetime import datetime, timezone from datetime import UTC, datetime
from enum import Enum from enum import Enum
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
from django.conf import settings
from django.db import DEFAULT_DB_ALIAS, OperationalError
from freezegun import freeze_time
from psycopg2 import sql as psycopg2_sql
from rest_framework_json_api.serializers import ValidationError
from api.db_utils import ( from api.db_utils import (
POSTGRES_TENANT_VAR, POSTGRES_TENANT_VAR,
PostgresEnumMigration, PostgresEnumMigration,
@@ -23,6 +17,11 @@ from api.db_utils import (
update_objects_in_batches, update_objects_in_batches,
) )
from api.models import Provider from api.models import Provider
from django.conf import settings
from django.db import DEFAULT_DB_ALIAS, OperationalError
from freezegun import freeze_time
from psycopg2 import sql as psycopg2_sql
from rest_framework_json_api.serializers import ValidationError
@pytest.fixture @pytest.fixture
@@ -94,18 +93,16 @@ class TestEnumToChoices:
class TestOneWeekFromNow: class TestOneWeekFromNow:
def test_one_week_from_now(self): def test_one_week_from_now(self):
with patch("api.db_utils.datetime") as mock_datetime: with patch("api.db_utils.datetime") as mock_datetime:
mock_datetime.now.return_value = datetime(2023, 1, 1, tzinfo=timezone.utc) mock_datetime.now.return_value = datetime(2023, 1, 1, tzinfo=UTC)
expected_result = datetime(2023, 1, 8, tzinfo=timezone.utc) expected_result = datetime(2023, 1, 8, tzinfo=UTC)
result = one_week_from_now() result = one_week_from_now()
assert result == expected_result assert result == expected_result
def test_one_week_from_now_with_timezone(self): def test_one_week_from_now_with_timezone(self):
with patch("api.db_utils.datetime") as mock_datetime: with patch("api.db_utils.datetime") as mock_datetime:
mock_datetime.now.return_value = datetime( mock_datetime.now.return_value = datetime(2023, 6, 15, 12, 0, tzinfo=UTC)
2023, 6, 15, 12, 0, tzinfo=timezone.utc expected_result = datetime(2023, 6, 22, 12, 0, tzinfo=UTC)
)
expected_result = datetime(2023, 6, 22, 12, 0, tzinfo=timezone.utc)
result = one_week_from_now() result = one_week_from_now()
assert result == expected_result assert result == expected_result
@@ -939,9 +936,9 @@ class TestPostgresEnumMigration:
mock_cursor.execute.assert_called_once() mock_cursor.execute.assert_called_once()
query_arg = mock_cursor.execute.call_args[0][0] query_arg = mock_cursor.execute.call_args[0][0]
assert isinstance( assert isinstance(query_arg, psycopg2_sql.Composable), (
query_arg, psycopg2_sql.Composable "create_enum_type must pass a psycopg2.sql.Composable, not a raw string."
), "create_enum_type must pass a psycopg2.sql.Composable, not a raw string." )
# Verify the composed SQL structure: CREATE TYPE <Identifier> AS ENUM (<Literals>) # Verify the composed SQL structure: CREATE TYPE <Identifier> AS ENUM (<Literals>)
parts = query_arg.seq parts = query_arg.seq
assert parts[0] == psycopg2_sql.SQL("CREATE TYPE ") assert parts[0] == psycopg2_sql.SQL("CREATE TYPE ")
@@ -962,9 +959,9 @@ class TestPostgresEnumMigration:
mock_cursor.execute.assert_called_once() mock_cursor.execute.assert_called_once()
query_arg = mock_cursor.execute.call_args[0][0] query_arg = mock_cursor.execute.call_args[0][0]
assert isinstance( assert isinstance(query_arg, psycopg2_sql.Composable), (
query_arg, psycopg2_sql.Composable "drop_enum_type must pass a psycopg2.sql.Composable, not a raw string."
), "drop_enum_type must pass a psycopg2.sql.Composable, not a raw string." )
# Verify the composed SQL structure: DROP TYPE <Identifier> # Verify the composed SQL structure: DROP TYPE <Identifier>
parts = query_arg.seq parts = query_arg.seq
assert parts[0] == psycopg2_sql.SQL("DROP TYPE ") assert parts[0] == psycopg2_sql.SQL("DROP TYPE ")
+2 -3
View File
@@ -2,12 +2,11 @@ import uuid
from unittest.mock import call, patch from unittest.mock import call, patch
import pytest import pytest
from django.core.exceptions import ObjectDoesNotExist
from django.db import DatabaseError, IntegrityError
from api.db_utils import POSTGRES_TENANT_VAR, SET_CONFIG_QUERY from api.db_utils import POSTGRES_TENANT_VAR, SET_CONFIG_QUERY
from api.decorators import handle_provider_deletion, set_tenant from api.decorators import handle_provider_deletion, set_tenant
from api.exceptions import ProviderDeletedException from api.exceptions import ProviderDeletedException
from django.core.exceptions import ObjectDoesNotExist
from django.db import DatabaseError, IntegrityError
@pytest.mark.django_db @pytest.mark.django_db
+1 -3
View File
@@ -7,15 +7,13 @@ Cover the IETF response envelope, status code mapping (200 / 503), the
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from api import health
from config import version as config_version from config import version as config_version
from django.core.cache import cache from django.core.cache import cache
from django.urls import reverse from django.urls import reverse
from rest_framework import status from rest_framework import status
from rest_framework.test import APIClient from rest_framework.test import APIClient
from api import health
HEALTH_MEDIA_TYPE = "application/health+json" HEALTH_MEDIA_TYPE = "application/health+json"
+1 -2
View File
@@ -1,11 +1,10 @@
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
from api.middleware import APILoggingMiddleware
from django.http import HttpResponse from django.http import HttpResponse
from django.test import RequestFactory from django.test import RequestFactory
from api.middleware import APILoggingMiddleware
@pytest.mark.django_db @pytest.mark.django_db
@patch("logging.getLogger") @patch("logging.getLogger")
+3 -4
View File
@@ -2,10 +2,6 @@ import json
from uuid import uuid4 from uuid import uuid4
import pytest import pytest
from django_celery_results.models import TaskResult
from rest_framework import status
from rest_framework.response import Response
from api.exceptions import ( from api.exceptions import (
TaskFailedException, TaskFailedException,
TaskInProgressException, TaskInProgressException,
@@ -14,6 +10,9 @@ from api.exceptions import (
from api.models import Task, User from api.models import Task, User
from api.rls import Tenant from api.rls import Tenant
from api.v1.mixins import PaginateByPkMixin, TaskManagementMixin from api.v1.mixins import PaginateByPkMixin, TaskManagementMixin
from django_celery_results.models import TaskResult
from rest_framework import status
from rest_framework.response import Response
@pytest.mark.django_db @pytest.mark.django_db

Some files were not shown because too many files have changed in this diff Show More