mirror of
https://github.com/prowler-cloud/prowler.git
synced 2026-07-04 19:21:51 +00:00
chore: unify ruff tooling and route code quality through the Makefile (#11675)
This commit is contained in:
committed by
GitHub
parent
de7da3e960
commit
058a1dc8fe
+19
-2
@@ -14,7 +14,7 @@ dev = [
|
||||
"pytest-env==1.1.3",
|
||||
"pytest-randomly==3.15.0",
|
||||
"pytest-xdist==3.6.1",
|
||||
"ruff==0.5.0",
|
||||
"ruff==0.15.11",
|
||||
"tqdm==4.67.1",
|
||||
"vulture==2.14",
|
||||
"prek==0.3.9"
|
||||
@@ -73,6 +73,23 @@ package-mode = false
|
||||
requires-python = ">=3.11,<3.13"
|
||||
version = "1.33.0"
|
||||
|
||||
# Shared ruff baseline (kept in sync with mcp_server/pyproject.toml).
|
||||
# target-version tracks this project's lowest supported Python.
|
||||
[tool.ruff]
|
||||
src = ["src"]
|
||||
target-version = "py311"
|
||||
|
||||
[tool.ruff.lint]
|
||||
# Defaults (E4/E7/E9, F) plus import sorting, modern-syntax upgrades, and
|
||||
# comprehension lints — all mechanically auto-fixable. flake8-bugbear (B) is a
|
||||
# good next step but needs manual cleanup (e.g. B904 raise-from), so it is left
|
||||
# out of the shared baseline for now.
|
||||
extend-select = [
|
||||
"I", # isort — import ordering (prek's isort hook covers only the SDK)
|
||||
"UP", # pyupgrade — modern syntax for the min supported Python
|
||||
"C4" # flake8-comprehensions
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
# Transitive pins matching master to avoid silent drift; bump deliberately.
|
||||
constraint-dependencies = [
|
||||
@@ -393,7 +410,7 @@ constraint-dependencies = [
|
||||
"rpds-py==0.30.0",
|
||||
"rsa==4.9.1",
|
||||
"ruamel-yaml==0.19.1",
|
||||
"ruff==0.5.0",
|
||||
"ruff==0.15.11",
|
||||
"s3transfer==0.14.0",
|
||||
"scaleway==2.10.3",
|
||||
"scaleway-core==2.10.3",
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
from allauth.socialaccount.adapter import DefaultSocialAccountAdapter
|
||||
from django.db import transaction
|
||||
|
||||
from api.db_router import MainRouter
|
||||
from api.db_utils import rls_transaction
|
||||
from api.models import (
|
||||
@@ -11,6 +9,7 @@ from api.models import (
|
||||
User,
|
||||
UserRoleRelationship,
|
||||
)
|
||||
from django.db import transaction
|
||||
|
||||
|
||||
class ProwlerSocialAccountAdapter(DefaultSocialAccountAdapter):
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from django.apps import AppConfig
|
||||
from django.conf import settings
|
||||
|
||||
from config.custom_logging import BackendLogger
|
||||
from config.env import env
|
||||
from django.apps import AppConfig
|
||||
from django.conf import settings
|
||||
|
||||
logger = logging.getLogger(BackendLogger.API)
|
||||
|
||||
@@ -30,8 +28,10 @@ class ApiConfig(AppConfig):
|
||||
name = "api"
|
||||
|
||||
def ready(self):
|
||||
from api import schema_extensions # noqa: F401
|
||||
from api import signals # noqa: F401
|
||||
from api import (
|
||||
schema_extensions, # noqa: F401
|
||||
signals, # noqa: F401
|
||||
)
|
||||
|
||||
# Generate required cryptographic keys if not present, but only if:
|
||||
# `"manage.py" not in sys.argv[0]`: If an external server (e.g., Gunicorn) is running the app
|
||||
|
||||
@@ -5,7 +5,6 @@ from api.attack_paths.queries import (
|
||||
get_query_by_id,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AttackPathsQueryDefinition",
|
||||
"AttackPathsQueryParameterDefinition",
|
||||
|
||||
@@ -22,10 +22,8 @@ Label-injection pipeline:
|
||||
import re
|
||||
|
||||
from rest_framework.exceptions import ValidationError
|
||||
|
||||
from tasks.jobs.attack_paths.config import get_provider_label
|
||||
|
||||
|
||||
# Step 1 - String / comment protection
|
||||
# Single combined regex: strings first, then line comments.
|
||||
# The regex engine finds the leftmost match, so a string like 'https://prowler.com'
|
||||
|
||||
@@ -1,18 +1,16 @@
|
||||
import atexit
|
||||
import logging
|
||||
import threading
|
||||
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Iterator
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import neo4j
|
||||
import neo4j.exceptions
|
||||
|
||||
from api.attack_paths.retryable_session import RetryableSession
|
||||
from config.env import env
|
||||
from django.conf import settings
|
||||
|
||||
from api.attack_paths.retryable_session import RetryableSession
|
||||
from tasks.jobs.attack_paths.config import (
|
||||
BATCH_SIZE,
|
||||
PROVIDER_RESOURCE_LABEL,
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
from api.attack_paths.queries.types import (
|
||||
AttackPathsQueryDefinition,
|
||||
AttackPathsQueryParameterDefinition,
|
||||
)
|
||||
from api.attack_paths.queries.registry import (
|
||||
get_queries_for_provider,
|
||||
get_query_by_id,
|
||||
)
|
||||
|
||||
from api.attack_paths.queries.types import (
|
||||
AttackPathsQueryDefinition,
|
||||
AttackPathsQueryParameterDefinition,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AttackPathsQueryDefinition",
|
||||
|
||||
@@ -5,7 +5,6 @@ from api.attack_paths.queries.types import (
|
||||
)
|
||||
from tasks.jobs.attack_paths.config import PROWLER_FINDING_LABEL
|
||||
|
||||
|
||||
# Custom Attack Path Queries
|
||||
# --------------------------
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from api.attack_paths.queries.types import AttackPathsQueryDefinition
|
||||
from api.attack_paths.queries.aws import AWS_QUERIES
|
||||
|
||||
from api.attack_paths.queries.types import AttackPathsQueryDefinition
|
||||
|
||||
# Query definitions organized by provider
|
||||
_QUERY_DEFINITIONS: dict[str, list[AttackPathsQueryDefinition]] = {
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import logging
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
import logging
|
||||
|
||||
from typing import Any, Iterable
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
import neo4j
|
||||
|
||||
from rest_framework.exceptions import APIException, PermissionDenied, ValidationError
|
||||
|
||||
from api.attack_paths import database as graph_database, AttackPathsQueryDefinition
|
||||
from api.attack_paths import AttackPathsQueryDefinition
|
||||
from api.attack_paths import database as graph_database
|
||||
from api.attack_paths.cypher_sanitizer import (
|
||||
inject_provider_label,
|
||||
validate_custom_query,
|
||||
@@ -17,6 +15,7 @@ from api.attack_paths.queries.schema import (
|
||||
get_cartography_schema_query,
|
||||
)
|
||||
from config.custom_logging import BackendLogger
|
||||
from rest_framework.exceptions import APIException, PermissionDenied, ValidationError
|
||||
from tasks.jobs.attack_paths.config import (
|
||||
INTERNAL_LABELS,
|
||||
INTERNAL_PROPERTIES,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import Optional, Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from api.db_router import MainRouter
|
||||
from api.models import TenantAPIKey, TenantAPIKeyManager
|
||||
from cryptography.fernet import InvalidToken
|
||||
from django.utils import timezone
|
||||
from drf_simple_apikey.backends import APIKeyAuthentication as BaseAPIKeyAuth
|
||||
@@ -10,9 +11,6 @@ from rest_framework.exceptions import AuthenticationFailed
|
||||
from rest_framework.request import Request
|
||||
from rest_framework_simplejwt.authentication import JWTAuthentication
|
||||
|
||||
from api.db_router import MainRouter
|
||||
from api.models import TenantAPIKey, TenantAPIKeyManager
|
||||
|
||||
|
||||
class TenantAPIKeyAuthentication(BaseAPIKeyAuth):
|
||||
model = TenantAPIKey
|
||||
@@ -81,7 +79,7 @@ class CombinedJWTOrAPIKeyAuthentication(BaseAuthentication):
|
||||
jwt_auth = JWTAuthentication()
|
||||
api_key_auth = TenantAPIKeyAuthentication()
|
||||
|
||||
def authenticate(self, request: Request) -> Optional[Tuple[object, dict]]:
|
||||
def authenticate(self, request: Request) -> tuple[object, dict] | None:
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
|
||||
# Prioritize JWT authentication if both are present
|
||||
|
||||
@@ -1,3 +1,9 @@
|
||||
from api.authentication import CombinedJWTOrAPIKeyAuthentication
|
||||
from api.db_router import MainRouter, reset_read_db_alias, set_read_db_alias
|
||||
from api.db_utils import POSTGRES_USER_VAR, rls_transaction
|
||||
from api.filters import CustomDjangoFilterBackend
|
||||
from api.models import Role, UserRoleRelationship
|
||||
from api.rbac.permissions import HasPermissions
|
||||
from django.conf import settings
|
||||
from django.db import transaction
|
||||
from rest_framework import permissions
|
||||
@@ -8,13 +14,6 @@ from rest_framework.response import Response
|
||||
from rest_framework_json_api import filters
|
||||
from rest_framework_json_api.views import ModelViewSet
|
||||
|
||||
from api.authentication import CombinedJWTOrAPIKeyAuthentication
|
||||
from api.db_router import MainRouter, reset_read_db_alias, set_read_db_alias
|
||||
from api.db_utils import POSTGRES_USER_VAR, rls_transaction
|
||||
from api.filters import CustomDjangoFilterBackend
|
||||
from api.models import Role, UserRoleRelationship
|
||||
from api.rbac.permissions import HasPermissions
|
||||
|
||||
|
||||
class BaseViewSet(ModelViewSet):
|
||||
authentication_classes = [CombinedJWTOrAPIKeyAuthentication]
|
||||
|
||||
@@ -352,7 +352,7 @@ def generate_compliance_overview_template(
|
||||
total_requirements += 1
|
||||
provider_check_list = list(requirement.checks.get(provider_type, []))
|
||||
total_checks = len(provider_check_list)
|
||||
checks_dict = {check: None for check in provider_check_list}
|
||||
checks_dict = dict.fromkeys(provider_check_list)
|
||||
|
||||
req_status_val = "MANUAL" if total_checks == 0 else "PASS"
|
||||
|
||||
|
||||
@@ -3,8 +3,14 @@ import secrets
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from api.db_router import (
|
||||
READ_REPLICA_ALIAS,
|
||||
get_read_db_alias,
|
||||
reset_read_db_alias,
|
||||
set_read_db_alias,
|
||||
)
|
||||
from celery.utils.log import get_task_logger
|
||||
from config.env import env
|
||||
from django.conf import settings
|
||||
@@ -22,13 +28,6 @@ from psycopg2 import sql as psycopg2_sql
|
||||
from psycopg2.extensions import AsIs, new_type, register_adapter, register_type
|
||||
from rest_framework_json_api.serializers import ValidationError
|
||||
|
||||
from api.db_router import (
|
||||
READ_REPLICA_ALIAS,
|
||||
get_read_db_alias,
|
||||
reset_read_db_alias,
|
||||
set_read_db_alias,
|
||||
)
|
||||
|
||||
logger = get_task_logger(__name__)
|
||||
|
||||
DB_USER = settings.DATABASES["default"]["USER"] if not settings.TESTING else "test"
|
||||
@@ -170,7 +169,7 @@ def one_week_from_now():
|
||||
"""
|
||||
Return a datetime object with a date one week from now.
|
||||
"""
|
||||
return datetime.now(timezone.utc) + timedelta(days=7)
|
||||
return datetime.now(UTC) + timedelta(days=7)
|
||||
|
||||
|
||||
def generate_random_token(length: int = 14, symbols: str | None = None) -> str:
|
||||
@@ -405,10 +404,10 @@ def _should_create_index_on_partition(
|
||||
# Unknown month abbreviation, include it to be safe
|
||||
return True
|
||||
|
||||
partition_date = datetime(year, month, 1, tzinfo=timezone.utc)
|
||||
partition_date = datetime(year, month, 1, tzinfo=UTC)
|
||||
|
||||
# Get current month start
|
||||
now = datetime.now(timezone.utc)
|
||||
now = datetime.now(UTC)
|
||||
current_month_start = now.replace(
|
||||
day=1, hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
import uuid
|
||||
from functools import wraps
|
||||
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from django.db import DatabaseError, connection, transaction
|
||||
from rest_framework_json_api.serializers import ValidationError
|
||||
|
||||
from api.db_router import READ_REPLICA_ALIAS
|
||||
from api.db_utils import POSTGRES_TENANT_VAR, SET_CONFIG_QUERY, rls_transaction
|
||||
from api.exceptions import ProviderDeletedException
|
||||
from api.models import Provider, Scan
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from django.db import DatabaseError, connection, transaction
|
||||
from rest_framework_json_api.serializers import ValidationError
|
||||
|
||||
|
||||
def set_tenant(func=None, *, keep_tenant=False):
|
||||
|
||||
@@ -1,19 +1,4 @@
|
||||
from datetime import date, datetime, timedelta, timezone
|
||||
|
||||
from dateutil.parser import parse
|
||||
from django.conf import settings
|
||||
from django.db.models import F, Q
|
||||
from django_filters.rest_framework import (
|
||||
BaseInFilter,
|
||||
BooleanFilter,
|
||||
CharFilter,
|
||||
ChoiceFilter,
|
||||
DateFilter,
|
||||
FilterSet,
|
||||
UUIDFilter,
|
||||
)
|
||||
from rest_framework_json_api.django_filters.backends import DjangoFilterBackend
|
||||
from rest_framework_json_api.serializers import ValidationError
|
||||
from datetime import UTC, date, datetime, timedelta
|
||||
|
||||
from api.constants import SEVERITY_ORDER
|
||||
from api.db_utils import (
|
||||
@@ -68,6 +53,20 @@ from api.uuid_utils import (
|
||||
uuid7_start,
|
||||
)
|
||||
from api.v1.serializers import TaskBase
|
||||
from dateutil.parser import parse
|
||||
from django.conf import settings
|
||||
from django.db.models import F, Q
|
||||
from django_filters.rest_framework import (
|
||||
BaseInFilter,
|
||||
BooleanFilter,
|
||||
CharFilter,
|
||||
ChoiceFilter,
|
||||
DateFilter,
|
||||
FilterSet,
|
||||
UUIDFilter,
|
||||
)
|
||||
from rest_framework_json_api.django_filters.backends import DjangoFilterBackend
|
||||
from rest_framework_json_api.serializers import ValidationError
|
||||
|
||||
|
||||
class CustomDjangoFilterBackend(DjangoFilterBackend):
|
||||
@@ -598,12 +597,12 @@ class ResourceFilter(ProviderRelationshipFilterSet):
|
||||
gte_date = (
|
||||
parse(self.data.get("updated_at__gte")).date()
|
||||
if self.data.get("updated_at__gte")
|
||||
else datetime.now(timezone.utc).date()
|
||||
else datetime.now(UTC).date()
|
||||
)
|
||||
lte_date = (
|
||||
parse(self.data.get("updated_at__lte")).date()
|
||||
if self.data.get("updated_at__lte")
|
||||
else datetime.now(timezone.utc).date()
|
||||
else datetime.now(UTC).date()
|
||||
)
|
||||
|
||||
if abs(lte_date - gte_date) > timedelta(
|
||||
@@ -748,9 +747,9 @@ class FindingFilter(CommonFindingFilters):
|
||||
lte_date = cleaned.get("inserted_at__lte") or exact_date
|
||||
|
||||
if gte_date is None:
|
||||
gte_date = datetime.now(timezone.utc).date()
|
||||
gte_date = datetime.now(UTC).date()
|
||||
if lte_date is None:
|
||||
lte_date = datetime.now(timezone.utc).date()
|
||||
lte_date = datetime.now(UTC).date()
|
||||
|
||||
if abs(lte_date - gte_date) > timedelta(
|
||||
days=settings.FINDINGS_MAX_DAYS_IN_RANGE
|
||||
@@ -844,7 +843,7 @@ class FindingFilter(CommonFindingFilters):
|
||||
def maybe_date_to_datetime(value):
|
||||
dt = value
|
||||
if isinstance(value, date):
|
||||
dt = datetime.combine(value, datetime.min.time(), tzinfo=timezone.utc)
|
||||
dt = datetime.combine(value, datetime.min.time(), tzinfo=UTC)
|
||||
return dt
|
||||
|
||||
|
||||
@@ -933,9 +932,9 @@ class FindingGroupFilter(CommonFindingFilters):
|
||||
lte_date = cleaned.get("inserted_at__lte") or exact_date
|
||||
|
||||
if gte_date is None:
|
||||
gte_date = datetime.now(timezone.utc).date()
|
||||
gte_date = datetime.now(UTC).date()
|
||||
if lte_date is None:
|
||||
lte_date = datetime.now(timezone.utc).date()
|
||||
lte_date = datetime.now(UTC).date()
|
||||
|
||||
if abs(lte_date - gte_date) > timedelta(
|
||||
days=settings.FINDINGS_MAX_DAYS_IN_RANGE
|
||||
@@ -977,7 +976,7 @@ class FindingGroupFilter(CommonFindingFilters):
|
||||
"""Convert date to datetime if needed."""
|
||||
dt = value
|
||||
if isinstance(value, date):
|
||||
dt = datetime.combine(value, datetime.min.time(), tzinfo=timezone.utc)
|
||||
dt = datetime.combine(value, datetime.min.time(), tzinfo=UTC)
|
||||
return dt
|
||||
|
||||
|
||||
@@ -1091,9 +1090,9 @@ class FindingGroupSummaryFilter(_CheckTitleToCheckIdMixin, FilterSet):
|
||||
lte_date = cleaned.get("inserted_at__lte") or exact_date
|
||||
|
||||
if gte_date is None:
|
||||
gte_date = datetime.now(timezone.utc).date()
|
||||
gte_date = datetime.now(UTC).date()
|
||||
if lte_date is None:
|
||||
lte_date = datetime.now(timezone.utc).date()
|
||||
lte_date = datetime.now(UTC).date()
|
||||
|
||||
if abs(lte_date - gte_date) > timedelta(
|
||||
days=settings.FINDINGS_MAX_DAYS_IN_RANGE
|
||||
@@ -1132,7 +1131,7 @@ class FindingGroupSummaryFilter(_CheckTitleToCheckIdMixin, FilterSet):
|
||||
def _maybe_date_to_datetime(value):
|
||||
dt = value
|
||||
if isinstance(value, date):
|
||||
dt = datetime.combine(value, datetime.min.time(), tzinfo=timezone.utc)
|
||||
dt = datetime.combine(value, datetime.min.time(), tzinfo=UTC)
|
||||
return dt
|
||||
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ import logging
|
||||
import threading
|
||||
import time
|
||||
from contextlib import suppress
|
||||
from datetime import datetime, timezone
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
import redis
|
||||
@@ -62,11 +62,7 @@ class HealthJSONRenderer(JSONRenderer):
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
return (
|
||||
datetime.now(timezone.utc)
|
||||
.isoformat(timespec="milliseconds")
|
||||
.replace("+00:00", "Z")
|
||||
)
|
||||
return datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z")
|
||||
|
||||
|
||||
def _measure(name: str, check_fn) -> tuple[dict[str, Any], float]:
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
import random
|
||||
from datetime import datetime, timezone
|
||||
from datetime import UTC, datetime
|
||||
from math import ceil
|
||||
from uuid import uuid4
|
||||
|
||||
from django.core.management.base import BaseCommand
|
||||
from tqdm import tqdm
|
||||
|
||||
from api.db_utils import rls_transaction
|
||||
from api.models import (
|
||||
Finding,
|
||||
@@ -16,7 +13,9 @@ from api.models import (
|
||||
Scan,
|
||||
StatusChoices,
|
||||
)
|
||||
from django.core.management.base import BaseCommand
|
||||
from prowler.lib.check.models import CheckMetadata
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
@@ -116,7 +115,7 @@ class Command(BaseCommand):
|
||||
trigger="manual",
|
||||
state="executing",
|
||||
progress=0,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
started_at=datetime.now(UTC),
|
||||
)
|
||||
scan_state = "completed"
|
||||
|
||||
@@ -272,10 +271,8 @@ class Command(BaseCommand):
|
||||
self.stdout.write(self.style.ERROR(f"Failed to populate test data: {e}"))
|
||||
scan_state = "failed"
|
||||
finally:
|
||||
scan.completed_at = datetime.now(timezone.utc)
|
||||
scan.duration = int(
|
||||
(datetime.now(timezone.utc) - scan.started_at).total_seconds()
|
||||
)
|
||||
scan.completed_at = datetime.now(UTC)
|
||||
scan.duration = int((datetime.now(UTC) - scan.started_at).total_seconds())
|
||||
scan.progress = 100
|
||||
scan.state = scan_state
|
||||
scan.unique_resource_count = num_resources
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from django.core.management.base import BaseCommand
|
||||
|
||||
from tasks.jobs.orphan_recovery import reconcile_orphans
|
||||
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
from config.custom_logging import BackendLogger
|
||||
from django.core.handlers.asgi import ASGIRequest
|
||||
from django.db import connections
|
||||
|
||||
from config.custom_logging import BackendLogger
|
||||
|
||||
|
||||
class CloseDBConnectionsMiddleware:
|
||||
"""
|
||||
|
||||
@@ -1,26 +1,13 @@
|
||||
import uuid
|
||||
from functools import partial
|
||||
|
||||
import api.rls
|
||||
import django.contrib.auth.models
|
||||
import django.contrib.postgres.indexes
|
||||
import django.contrib.postgres.search
|
||||
import django.core.validators
|
||||
import django.db.models.deletion
|
||||
import django.utils.timezone
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
from psqlextra.backend.migrations.operations.add_default_partition import (
|
||||
PostgresAddDefaultPartition,
|
||||
)
|
||||
from psqlextra.backend.migrations.operations.create_partitioned_model import (
|
||||
PostgresCreatePartitionedModel,
|
||||
)
|
||||
from psqlextra.manager.manager import PostgresManager
|
||||
from psqlextra.models.partitioned import PostgresPartitionedModel
|
||||
from psqlextra.types import PostgresPartitioningMethod
|
||||
from uuid6 import uuid7
|
||||
|
||||
import api.rls
|
||||
from api.db_utils import (
|
||||
DB_PROWLER_PASSWORD,
|
||||
DB_PROWLER_USER,
|
||||
@@ -53,6 +40,18 @@ from api.models import (
|
||||
StateChoices,
|
||||
StatusChoices,
|
||||
)
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
from psqlextra.backend.migrations.operations.add_default_partition import (
|
||||
PostgresAddDefaultPartition,
|
||||
)
|
||||
from psqlextra.backend.migrations.operations.create_partitioned_model import (
|
||||
PostgresCreatePartitionedModel,
|
||||
)
|
||||
from psqlextra.manager.manager import PostgresManager
|
||||
from psqlextra.models.partitioned import PostgresPartitionedModel
|
||||
from psqlextra.types import PostgresPartitioningMethod
|
||||
from uuid6 import uuid7
|
||||
|
||||
DB_NAME = settings.DATABASES["default"]["NAME"]
|
||||
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from api.db_utils import DB_PROWLER_USER
|
||||
from django.conf import settings
|
||||
from django.db import migrations
|
||||
|
||||
from api.db_utils import DB_PROWLER_USER
|
||||
|
||||
DB_NAME = settings.DATABASES["default"]["NAME"]
|
||||
|
||||
|
||||
|
||||
@@ -2,12 +2,11 @@
|
||||
|
||||
import uuid
|
||||
|
||||
import api.rls
|
||||
import django.db.models.deletion
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
import api.rls
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from django.db import migrations
|
||||
|
||||
from api.db_router import MainRouter
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
def create_admin_role(apps, schema_editor):
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
import json
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
from django_celery_beat.models import PeriodicTask
|
||||
|
||||
from api.db_utils import rls_transaction
|
||||
from api.models import Scan, StateChoices
|
||||
from django.db import migrations, models
|
||||
from django_celery_beat.models import PeriodicTask
|
||||
|
||||
|
||||
def migrate_daily_scheduled_scan_tasks(apps, schema_editor):
|
||||
@@ -17,11 +16,11 @@ def migrate_daily_scheduled_scan_tasks(apps, schema_editor):
|
||||
tenant_id = task_kwargs["tenant_id"]
|
||||
provider_id = task_kwargs["provider_id"]
|
||||
|
||||
current_time = datetime.now(timezone.utc)
|
||||
current_time = datetime.now(UTC)
|
||||
scheduled_time_today = datetime.combine(
|
||||
current_time.date(),
|
||||
daily_scheduled_scan_task.start_time.time(),
|
||||
tzinfo=timezone.utc,
|
||||
tzinfo=UTC,
|
||||
)
|
||||
|
||||
if current_time < scheduled_time_today:
|
||||
|
||||
@@ -2,10 +2,9 @@
|
||||
|
||||
from functools import partial
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
from api.db_utils import IntegrationTypeEnum, PostgresEnumMigration, register_enum
|
||||
from api.models import Integration
|
||||
from django.db import migrations
|
||||
|
||||
IntegrationTypeEnumMigration = PostgresEnumMigration(
|
||||
enum_name="integration_type",
|
||||
|
||||
@@ -2,12 +2,11 @@
|
||||
|
||||
import uuid
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
import api.db_utils
|
||||
import api.rls
|
||||
import django.db.models.deletion
|
||||
from api.rls import RowLevelSecurityConstraint
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
# Generated by Django 5.1.5 on 2025-03-25 11:29
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
import api.db_utils
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
# Generated by Django 5.1.7 on 2025-04-16 08:47
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
import api.db_utils
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
@@ -2,12 +2,11 @@
|
||||
|
||||
import uuid
|
||||
|
||||
import api.rls
|
||||
import django.db.models.deletion
|
||||
import uuid6
|
||||
from django.db import migrations, models
|
||||
|
||||
import api.rls
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from functools import partial
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
from api.db_utils import create_index_on_partitions, drop_index_on_partitions
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from functools import partial
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
from api.db_utils import create_index_on_partitions, drop_index_on_partitions
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
@@ -2,12 +2,11 @@
|
||||
|
||||
import uuid
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
import api.db_utils
|
||||
import api.rls
|
||||
import django.db.models.deletion
|
||||
from api.rls import RowLevelSecurityConstraint
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from functools import partial
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
from api.db_utils import create_index_on_partitions, drop_index_on_partitions
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
@@ -2,12 +2,11 @@
|
||||
|
||||
import uuid
|
||||
|
||||
import api.rls
|
||||
import django.core.validators
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
import api.rls
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
|
||||
@@ -2,13 +2,12 @@
|
||||
|
||||
import uuid
|
||||
|
||||
import api.db_utils
|
||||
import api.rls
|
||||
import django.db.models.deletion
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
import api.db_utils
|
||||
import api.rls
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
|
||||
@@ -2,10 +2,9 @@
|
||||
|
||||
from functools import partial
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
from api.db_utils import PostgresEnumMigration, ProcessorTypeEnum, register_enum
|
||||
from api.models import Processor
|
||||
from django.db import migrations
|
||||
|
||||
ProcessorTypeEnumMigration = PostgresEnumMigration(
|
||||
enum_name="processor_type",
|
||||
|
||||
@@ -2,12 +2,11 @@
|
||||
|
||||
import uuid
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
import api.db_utils
|
||||
import api.rls
|
||||
import django.db.models.deletion
|
||||
from api.rls import RowLevelSecurityConstraint
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from functools import partial
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
from api.db_utils import create_index_on_partitions, drop_index_on_partitions
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from functools import partial
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
from api.db_utils import create_index_on_partitions, drop_index_on_partitions
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
# Generated by Django 5.1.7 on 2025-07-09 14:44
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
import api.db_utils
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
@@ -2,15 +2,14 @@
|
||||
|
||||
import uuid
|
||||
|
||||
import api.db_utils
|
||||
import api.rls
|
||||
import django.core.validators
|
||||
import django.db.models.deletion
|
||||
import drf_simple_apikey.models
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
import api.db_utils
|
||||
import api.rls
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
|
||||
@@ -4,15 +4,14 @@ import json
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
import api.rls
|
||||
import django.db.models.deletion
|
||||
from api.db_router import MainRouter
|
||||
from config.custom_logging import BackendLogger
|
||||
from cryptography.fernet import Fernet
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
import api.rls
|
||||
from api.db_router import MainRouter
|
||||
|
||||
logger = logging.getLogger(BackendLogger.API)
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
# Generated by Django 5.1.7 on 2025-10-14 00:00
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
import api.db_utils
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
@@ -2,14 +2,13 @@
|
||||
|
||||
import uuid
|
||||
|
||||
import api.rls
|
||||
import django.contrib.postgres.fields
|
||||
import django.core.validators
|
||||
import django.db.models.deletion
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
import api.rls
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
# Generated by Django 5.1.10 on 2025-09-09 09:25
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
import api.db_utils
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
# Generated by Django 5.1.13 on 2025-11-05 08:37
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
import api.db_utils
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
@@ -2,11 +2,10 @@
|
||||
|
||||
import uuid
|
||||
|
||||
import api.rls
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
import api.rls
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
|
||||
@@ -2,11 +2,10 @@
|
||||
|
||||
import uuid
|
||||
|
||||
import api.rls
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
import api.rls
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
|
||||
@@ -2,11 +2,10 @@
|
||||
|
||||
import uuid
|
||||
|
||||
import api.rls
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
import api.rls
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
|
||||
@@ -2,11 +2,10 @@
|
||||
|
||||
import uuid
|
||||
|
||||
import api.rls
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
import api.rls
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
# Generated by Django 5.1.14 on 2025-12-10
|
||||
|
||||
from django.db import migrations
|
||||
from tasks.tasks import backfill_daily_severity_summaries_task
|
||||
|
||||
from api.db_router import MainRouter
|
||||
from api.rls import Tenant
|
||||
from django.db import migrations
|
||||
from tasks.tasks import backfill_daily_severity_summaries_task
|
||||
|
||||
|
||||
def trigger_backfill_task(apps, schema_editor):
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import uuid
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
import api.db_utils
|
||||
import api.rls
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
# Generated by Django migration for Alibaba Cloud provider support
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
import api.db_utils
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import uuid
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
import api.db_utils
|
||||
import api.rls
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import uuid
|
||||
|
||||
import api.rls
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
import api.rls
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import uuid
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
import api.db_utils
|
||||
import api.rls
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
# Generated by Django 5.1.13 on 2025-11-06 16:20
|
||||
|
||||
import api.rls
|
||||
import django.db.models.deletion
|
||||
|
||||
from django.db import migrations, models
|
||||
from uuid6 import uuid7
|
||||
|
||||
import api.rls
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from functools import partial
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
from api.db_utils import create_index_on_partitions, drop_index_on_partitions
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
# Generated by Django migration for Cloudflare provider support
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
import api.db_utils
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
# Generated by Django migration for OpenStack provider support
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
import api.db_utils
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
@@ -2,9 +2,8 @@
|
||||
# on different database connections, causing a deadlock when combined with RunPython
|
||||
# in the same migration.
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
from api.db_router import MainRouter
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
def backfill_graph_data_ready(apps, schema_editor):
|
||||
|
||||
@@ -2,14 +2,13 @@
|
||||
|
||||
import uuid
|
||||
|
||||
import api.rls
|
||||
import django.db.models.deletion
|
||||
from django.contrib.postgres.indexes import GinIndex, OpClass
|
||||
from django.db import migrations, models
|
||||
from django.db.models.functions import Upper
|
||||
from django.utils import timezone
|
||||
|
||||
import api.rls
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
# Generated by Django 5.1.14 on 2026-02-02
|
||||
|
||||
from django.db import migrations
|
||||
from tasks.tasks import backfill_finding_group_summaries_task
|
||||
|
||||
from api.db_router import MainRouter
|
||||
from api.rls import Tenant
|
||||
from django.db import migrations
|
||||
from tasks.tasks import backfill_finding_group_summaries_task
|
||||
|
||||
|
||||
def trigger_backfill_task(apps, schema_editor):
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from django.db import migrations
|
||||
|
||||
import api.db_utils
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from django.db import migrations
|
||||
|
||||
import api.db_utils
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
TASK_NAME = "attack-paths-cleanup-stale-scans"
|
||||
INTERVAL_HOURS = 1
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from django.db import migrations
|
||||
|
||||
import api.db_utils
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from django.db import migrations
|
||||
from tasks.tasks import backfill_finding_group_summaries_task
|
||||
|
||||
from api.db_router import MainRouter
|
||||
from api.rls import Tenant
|
||||
from django.db import migrations
|
||||
from tasks.tasks import backfill_finding_group_summaries_task
|
||||
|
||||
|
||||
def trigger_backfill_task(apps, schema_editor):
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from functools import partial
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
from api.db_utils import create_index_on_partitions, drop_index_on_partitions
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from django.db import migrations
|
||||
|
||||
import api.db_utils
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
TASK_NAME = "reconcile-orphan-tasks"
|
||||
INTERVAL_MINUTES = 2
|
||||
|
||||
|
||||
@@ -1,37 +1,11 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import defusedxml
|
||||
from allauth.socialaccount.models import SocialApp
|
||||
from config.custom_logging import BackendLogger
|
||||
from config.settings.social_login import SOCIALACCOUNT_PROVIDERS
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
from defusedxml import ElementTree as ET
|
||||
from django.conf import settings
|
||||
from django.contrib.auth.models import AbstractBaseUser
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
from django.contrib.postgres.indexes import GinIndex, OpClass
|
||||
from django.contrib.postgres.search import SearchVector, SearchVectorField
|
||||
from django.contrib.sites.models import Site
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.core.validators import MinLengthValidator
|
||||
from django.db import models
|
||||
from django.db.models import Q
|
||||
from django.db.models.functions import Upper
|
||||
from django.utils import timezone as django_timezone
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django_celery_beat.models import PeriodicTask
|
||||
from django_celery_results.models import TaskResult
|
||||
from drf_simple_apikey.crypto import get_crypto
|
||||
from drf_simple_apikey.models import AbstractAPIKey, AbstractAPIKeyManager
|
||||
from psqlextra.manager import PostgresManager
|
||||
from psqlextra.models import PostgresPartitionedModel
|
||||
from psqlextra.types import PostgresPartitioningMethod
|
||||
from uuid6 import uuid7
|
||||
|
||||
from api.db_router import MainRouter
|
||||
from api.db_utils import (
|
||||
CustomUserManager,
|
||||
@@ -58,7 +32,32 @@ from api.rls import (
|
||||
RowLevelSecurityProtectedModel,
|
||||
Tenant,
|
||||
)
|
||||
from config.custom_logging import BackendLogger
|
||||
from config.settings.social_login import SOCIALACCOUNT_PROVIDERS
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
from defusedxml import ElementTree as ET
|
||||
from django.conf import settings
|
||||
from django.contrib.auth.models import AbstractBaseUser
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
from django.contrib.postgres.indexes import GinIndex, OpClass
|
||||
from django.contrib.postgres.search import SearchVector, SearchVectorField
|
||||
from django.contrib.sites.models import Site
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.core.validators import MinLengthValidator
|
||||
from django.db import models
|
||||
from django.db.models import Q
|
||||
from django.db.models.functions import Upper
|
||||
from django.utils import timezone as django_timezone
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django_celery_beat.models import PeriodicTask
|
||||
from django_celery_results.models import TaskResult
|
||||
from drf_simple_apikey.crypto import get_crypto
|
||||
from drf_simple_apikey.models import AbstractAPIKey, AbstractAPIKeyManager
|
||||
from prowler.lib.check.models import Severity
|
||||
from psqlextra.manager import PostgresManager
|
||||
from psqlextra.models import PostgresPartitionedModel
|
||||
from psqlextra.types import PostgresPartitioningMethod
|
||||
from uuid6 import uuid7
|
||||
|
||||
fernet = Fernet(settings.SECRETS_ENCRYPTION_KEY.encode())
|
||||
|
||||
@@ -1427,8 +1426,8 @@ class Role(RowLevelSecurityProtectedModel):
|
||||
|
||||
@classmethod
|
||||
def filter_by_permission_state(cls, queryset, value):
|
||||
q_all_true = Q(**{field: True for field in cls.PERMISSION_FIELDS})
|
||||
q_all_false = Q(**{field: False for field in cls.PERMISSION_FIELDS})
|
||||
q_all_true = Q(**dict.fromkeys(cls.PERMISSION_FIELDS, True))
|
||||
q_all_false = Q(**dict.fromkeys(cls.PERMISSION_FIELDS, False))
|
||||
|
||||
if value == PermissionChoices.UNLIMITED:
|
||||
return queryset.filter(q_all_true)
|
||||
@@ -2011,11 +2010,11 @@ class SAMLToken(models.Model):
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
if not self.expires_at:
|
||||
self.expires_at = datetime.now(timezone.utc) + timedelta(seconds=15)
|
||||
self.expires_at = datetime.now(UTC) + timedelta(seconds=15)
|
||||
super().save(*args, **kwargs)
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
return datetime.now(timezone.utc) >= self.expires_at
|
||||
return datetime.now(UTC) >= self.expires_at
|
||||
|
||||
|
||||
class SAMLDomainIndex(models.Model):
|
||||
|
||||
@@ -1,21 +1,20 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Generator, Optional
|
||||
|
||||
from dateutil.relativedelta import relativedelta
|
||||
from django.conf import settings
|
||||
from psqlextra.partitioning import (
|
||||
PostgresPartitioningManager,
|
||||
PostgresRangePartition,
|
||||
PostgresRangePartitioningStrategy,
|
||||
PostgresTimePartitionSize,
|
||||
PostgresPartitioningError,
|
||||
)
|
||||
from psqlextra.partitioning.config import PostgresPartitioningConfig
|
||||
from uuid6 import UUID
|
||||
from collections.abc import Generator
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from api.models import Finding, ResourceFindingMapping
|
||||
from api.rls import RowLevelSecurityConstraint
|
||||
from api.uuid_utils import datetime_to_uuid7
|
||||
from dateutil.relativedelta import relativedelta
|
||||
from django.conf import settings
|
||||
from psqlextra.partitioning import (
|
||||
PostgresPartitioningError,
|
||||
PostgresPartitioningManager,
|
||||
PostgresRangePartition,
|
||||
PostgresRangePartitioningStrategy,
|
||||
PostgresTimePartitionSize,
|
||||
)
|
||||
from psqlextra.partitioning.config import PostgresPartitioningConfig
|
||||
from uuid6 import UUID
|
||||
|
||||
|
||||
class PostgresUUIDv7RangePartition(PostgresRangePartition):
|
||||
@@ -24,7 +23,7 @@ class PostgresUUIDv7RangePartition(PostgresRangePartition):
|
||||
from_values: UUID,
|
||||
to_values: UUID,
|
||||
size: PostgresTimePartitionSize,
|
||||
name_format: Optional[str] = None,
|
||||
name_format: str | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.from_values = from_values
|
||||
@@ -38,9 +37,7 @@ class PostgresUUIDv7RangePartition(PostgresRangePartition):
|
||||
|
||||
start_timestamp_ms = self.from_values.time
|
||||
|
||||
self.start_datetime = datetime.fromtimestamp(
|
||||
start_timestamp_ms / 1000, timezone.utc
|
||||
)
|
||||
self.start_datetime = datetime.fromtimestamp(start_timestamp_ms / 1000, UTC)
|
||||
|
||||
def name(self) -> str:
|
||||
if not self.name_format:
|
||||
@@ -82,8 +79,8 @@ class PostgresUUIDv7PartitioningStrategy(PostgresRangePartitioningStrategy):
|
||||
size: PostgresTimePartitionSize,
|
||||
count: int,
|
||||
start_date: datetime = None,
|
||||
max_age: Optional[relativedelta] = None,
|
||||
name_format: Optional[str] = None,
|
||||
max_age: relativedelta | None = None,
|
||||
name_format: str | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.start_date = start_date.replace(
|
||||
@@ -151,7 +148,7 @@ class PostgresUUIDv7PartitioningStrategy(PostgresRangePartitioningStrategy):
|
||||
Returns:
|
||||
datetime: A `datetime` object representing the start of the current month in UTC.
|
||||
"""
|
||||
return datetime.now(timezone.utc).replace(
|
||||
return datetime.now(UTC).replace(
|
||||
day=1, hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
|
||||
@@ -171,7 +168,7 @@ manager = PostgresPartitioningManager(
|
||||
PostgresPartitioningConfig(
|
||||
model=Finding,
|
||||
strategy=PostgresUUIDv7PartitioningStrategy(
|
||||
start_date=datetime.now(timezone.utc),
|
||||
start_date=datetime.now(UTC),
|
||||
size=PostgresTimePartitionSize(
|
||||
months=settings.FINDINGS_TABLE_PARTITION_MONTHS
|
||||
),
|
||||
@@ -187,7 +184,7 @@ manager = PostgresPartitioningManager(
|
||||
PostgresPartitioningConfig(
|
||||
model=ResourceFindingMapping,
|
||||
strategy=PostgresUUIDv7PartitioningStrategy(
|
||||
start_date=datetime.now(timezone.utc),
|
||||
start_date=datetime.now(UTC),
|
||||
size=PostgresTimePartitionSize(
|
||||
months=settings.FINDINGS_TABLE_PARTITION_MONTHS
|
||||
),
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
from enum import Enum
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from rest_framework.exceptions import PermissionDenied
|
||||
from rest_framework.permissions import BasePermission
|
||||
|
||||
from api.db_router import MainRouter
|
||||
from api.models import Provider, Role, User
|
||||
from django.db.models import QuerySet
|
||||
from rest_framework.exceptions import PermissionDenied
|
||||
from rest_framework.permissions import BasePermission
|
||||
|
||||
|
||||
class Permissions(Enum):
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from contextlib import nullcontext
|
||||
|
||||
from api.db_utils import rls_transaction
|
||||
from rest_framework.renderers import BaseRenderer
|
||||
from rest_framework_json_api.renderers import JSONRenderer
|
||||
|
||||
from api.db_utils import rls_transaction
|
||||
|
||||
|
||||
class PlainTextRenderer(BaseRenderer):
|
||||
media_type = "text/plain"
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from api.db_utils import DB_USER, POSTGRES_TENANT_VAR
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db import DEFAULT_DB_ALIAS, models
|
||||
from django.db.backends.ddl_references import Statement, Table
|
||||
|
||||
from api.db_utils import DB_USER, POSTGRES_TENANT_VAR
|
||||
|
||||
|
||||
class Tenant(models.Model):
|
||||
"""
|
||||
|
||||
@@ -1,10 +1,3 @@
|
||||
from celery import states
|
||||
from celery.signals import before_task_publish
|
||||
from config.celery import celery_app
|
||||
from django.db.models.signals import post_delete, pre_delete
|
||||
from django.dispatch import receiver
|
||||
from django_celery_results.backends.database import DatabaseBackend
|
||||
|
||||
from api.db_utils import delete_related_daily_task
|
||||
from api.models import (
|
||||
LighthouseProviderConfiguration,
|
||||
@@ -14,6 +7,12 @@ from api.models import (
|
||||
TenantAPIKey,
|
||||
User,
|
||||
)
|
||||
from celery import states
|
||||
from celery.signals import before_task_publish
|
||||
from config.celery import celery_app
|
||||
from django.db.models.signals import post_delete, pre_delete
|
||||
from django.dispatch import receiver
|
||||
from django_celery_results.backends.database import DatabaseBackend
|
||||
|
||||
|
||||
def create_task_result_on_publish(sender=None, headers=None, **kwargs): # noqa: F841
|
||||
|
||||
@@ -7,7 +7,7 @@ enforces the tenant gate (:class:`api.sse.channelmanager.SSEChannelManager`),
|
||||
and the channel-name helpers (:func:`api.sse.utils.make_channel_name`).
|
||||
"""
|
||||
|
||||
from api.sse.utils import make_channel_name
|
||||
from api.sse.base_views import BaseSSEViewSet
|
||||
from api.sse.utils import make_channel_name
|
||||
|
||||
__all__ = ["BaseSSEViewSet", "make_channel_name"]
|
||||
|
||||
@@ -5,11 +5,10 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from api.sse.utils import tenant_id_from_channel
|
||||
from django_eventstream.channelmanager import DefaultChannelManager
|
||||
from rest_framework.request import Request
|
||||
|
||||
from api.sse.utils import tenant_id_from_channel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from api.models import User
|
||||
|
||||
@@ -41,7 +40,7 @@ class SSEChannelManager(DefaultChannelManager):
|
||||
if tenant_id_from_channel(channel) == request_tenant_id
|
||||
}
|
||||
|
||||
def can_read_channel(self, user: "User | None", channel: str) -> bool:
|
||||
def can_read_channel(self, user: User | None, channel: str) -> bool:
|
||||
"""Re-verify tenant membership once the stream is established.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
import time
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from api.models import Membership, Role, TenantAPIKey, User, UserRoleRelationship
|
||||
from conftest import TEST_PASSWORD, get_api_tokens, get_authorization_header
|
||||
from django.urls import reverse
|
||||
from drf_simple_apikey.crypto import get_crypto
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from api.models import Membership, Role, TenantAPIKey, User, UserRoleRelationship
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_basic_authentication():
|
||||
@@ -468,7 +467,7 @@ class TestAPIKeyErrors:
|
||||
name="Expired Key",
|
||||
tenant_id=tenants_fixture[0].id,
|
||||
entity=create_test_user,
|
||||
expiry_date=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
expiry_date=datetime.now(UTC) - timedelta(days=1),
|
||||
)
|
||||
|
||||
api_key_headers = get_api_key_header(raw_key)
|
||||
@@ -500,7 +499,7 @@ class TestAPIKeyErrors:
|
||||
# Create a valid-looking key with non-existent UUID
|
||||
crypto = get_crypto()
|
||||
fake_uuid = str(uuid4())
|
||||
fake_expiry = (datetime.now(timezone.utc) + timedelta(days=30)).timestamp()
|
||||
fake_expiry = (datetime.now(UTC) + timedelta(days=30)).timestamp()
|
||||
payload = {"_pk": fake_uuid, "_exp": fake_expiry}
|
||||
encrypted_payload = crypto.generate(payload)
|
||||
|
||||
@@ -723,7 +722,7 @@ class TestAPIKeyLifecycle:
|
||||
assert created_data["attributes"]["revoked"] is False
|
||||
|
||||
# Create API key with expiry
|
||||
future_expiry = (datetime.now(timezone.utc) + timedelta(days=90)).isoformat()
|
||||
future_expiry = (datetime.now(UTC) + timedelta(days=90)).isoformat()
|
||||
create_with_expiry_response = client.post(
|
||||
reverse("api-key-list"),
|
||||
data={
|
||||
@@ -927,9 +926,9 @@ class TestAPIKeyLifecycle:
|
||||
auth_response = client.get(reverse("provider-list"), headers=api_key_headers)
|
||||
|
||||
# Must return 401 Unauthorized, not 500 Internal Server Error
|
||||
assert (
|
||||
auth_response.status_code == 401
|
||||
), f"Expected 401 but got {auth_response.status_code}: {auth_response.json()}"
|
||||
assert auth_response.status_code == 401, (
|
||||
f"Expected 401 but got {auth_response.status_code}: {auth_response.json()}"
|
||||
)
|
||||
|
||||
# Verify error message is present
|
||||
response_json = auth_response.json()
|
||||
@@ -1267,7 +1266,7 @@ class TestAPIKeyRLSBypass:
|
||||
name="Expired Test Key",
|
||||
tenant_id=tenant.id,
|
||||
entity=create_test_user,
|
||||
expiry_date=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
expiry_date=datetime.now(UTC) - timedelta(days=1),
|
||||
)
|
||||
|
||||
api_key_headers = get_api_key_header(raw_key)
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from api.models import Provider
|
||||
from conftest import get_api_tokens, get_authorization_header
|
||||
from django.urls import reverse
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from api.models import Provider
|
||||
|
||||
|
||||
@patch("api.v1.views.Task.objects.get")
|
||||
@patch("api.v1.views.delete_provider_task.delay")
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
"""Tests for rls_transaction retry and fallback logic."""
|
||||
|
||||
import pytest
|
||||
from api.db_utils import rls_transaction
|
||||
from django.db import DEFAULT_DB_ALIAS
|
||||
from rest_framework_json_api.serializers import ValidationError
|
||||
|
||||
from api.db_utils import rls_transaction
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestRLSTransaction:
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from conftest import TEST_PASSWORD, TEST_USER, get_api_tokens, get_authorization_header
|
||||
from django.urls import reverse
|
||||
|
||||
from conftest import TEST_USER, TEST_PASSWORD, get_api_tokens, get_authorization_header
|
||||
|
||||
|
||||
@patch("api.v1.views.schedule_provider_scan")
|
||||
@pytest.mark.django_db
|
||||
|
||||
@@ -3,11 +3,10 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from allauth.socialaccount.models import SocialLogin
|
||||
from django.contrib.auth import get_user_model
|
||||
|
||||
from api.adapters import ProwlerSocialAccountAdapter
|
||||
from api.db_router import MainRouter
|
||||
from api.models import SAMLConfiguration
|
||||
from django.contrib.auth import get_user_model
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
|
||||
@@ -4,11 +4,9 @@ import types
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from django.conf import settings
|
||||
|
||||
import api
|
||||
import api.apps as api_apps_module
|
||||
import pytest
|
||||
from api.apps import (
|
||||
PRIVATE_KEY_FILE,
|
||||
PUBLIC_KEY_FILE,
|
||||
@@ -16,6 +14,7 @@ from api.apps import (
|
||||
VERIFYING_KEY_ENV,
|
||||
ApiConfig,
|
||||
)
|
||||
from django.conf import settings
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
|
||||
import neo4j
|
||||
import neo4j.exceptions
|
||||
|
||||
from rest_framework.exceptions import APIException, PermissionDenied, ValidationError
|
||||
|
||||
import pytest
|
||||
from api.attack_paths import database as graph_database
|
||||
from api.attack_paths import views_helpers
|
||||
from rest_framework.exceptions import APIException, PermissionDenied, ValidationError
|
||||
from tasks.jobs.attack_paths.config import (
|
||||
PROVIDER_ELEMENT_ID_PROPERTY,
|
||||
get_provider_label,
|
||||
|
||||
@@ -6,15 +6,13 @@ never contacts Neo4j. These tests validate the database module behavior itself.
|
||||
"""
|
||||
|
||||
import threading
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import api.attack_paths.database as db_module
|
||||
import neo4j
|
||||
import neo4j.exceptions
|
||||
import pytest
|
||||
|
||||
import api.attack_paths.database as db_module
|
||||
|
||||
|
||||
class TestLazyInitialization:
|
||||
"""Test that Neo4j driver is initialized lazily on first use."""
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
import time
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from django.test import RequestFactory
|
||||
from rest_framework.exceptions import AuthenticationFailed
|
||||
|
||||
from api.authentication import SSEAuthentication, TenantAPIKeyAuthentication
|
||||
from api.db_router import MainRouter
|
||||
from api.models import TenantAPIKey
|
||||
from django.test import RequestFactory
|
||||
from rest_framework.exceptions import AuthenticationFailed
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@@ -104,7 +103,7 @@ class TestTenantAPIKeyAuthentication:
|
||||
# Verify that last_used_at was updated
|
||||
api_key.refresh_from_db()
|
||||
assert api_key.last_used_at is not None
|
||||
assert (datetime.now(timezone.utc) - api_key.last_used_at).seconds < 5
|
||||
assert (datetime.now(UTC) - api_key.last_used_at).seconds < 5
|
||||
|
||||
def test_authenticate_valid_api_key_uses_admin_database(
|
||||
self, auth_backend, api_keys_fixture, request_factory
|
||||
@@ -195,7 +194,7 @@ class TestTenantAPIKeyAuthentication:
|
||||
name="Expired API Key",
|
||||
tenant_id=tenant.id,
|
||||
entity=user,
|
||||
expiry_date=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
expiry_date=datetime.now(UTC) - timedelta(days=1),
|
||||
)
|
||||
|
||||
request = request_factory.get("/")
|
||||
@@ -217,7 +216,7 @@ class TestTenantAPIKeyAuthentication:
|
||||
# Manually create an encrypted key with a non-existent ID
|
||||
payload = {
|
||||
"_pk": non_existent_uuid,
|
||||
"_exp": (datetime.now(timezone.utc) + timedelta(days=30)).timestamp(),
|
||||
"_exp": (datetime.now(UTC) + timedelta(days=30)).timestamp(),
|
||||
}
|
||||
encrypted_key = auth_backend.key_crypto.generate(payload)
|
||||
fake_key = f"{api_key.prefix}.{encrypted_key}"
|
||||
@@ -368,7 +367,7 @@ class TestTenantAPIKeyAuthentication:
|
||||
name="Short-lived API Key",
|
||||
tenant_id=tenant.id,
|
||||
entity=user,
|
||||
expiry_date=datetime.now(timezone.utc) + timedelta(seconds=1),
|
||||
expiry_date=datetime.now(UTC) + timedelta(seconds=1),
|
||||
)
|
||||
|
||||
# Wait for the key to expire
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api import compliance as compliance_module
|
||||
from api.compliance import (
|
||||
generate_compliance_overview_template,
|
||||
|
||||
@@ -3,13 +3,11 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from rest_framework.exceptions import ValidationError
|
||||
|
||||
from api.attack_paths.cypher_sanitizer import (
|
||||
inject_provider_label,
|
||||
validate_custom_query,
|
||||
)
|
||||
from rest_framework.exceptions import ValidationError
|
||||
|
||||
PROVIDER_ID = "019c41ee-7df3-7dec-a684-d839f95619f8"
|
||||
LABEL = "_Provider_019c41ee7df37deca684d839f95619f8"
|
||||
@@ -202,9 +200,7 @@ class TestClauseSplitting:
|
||||
|
||||
def test_multiple_match_clauses(self):
|
||||
cypher = (
|
||||
"MATCH (a:AWSAccount)--(b:AWSRole) "
|
||||
"MATCH (b)--(c:AWSPolicy) "
|
||||
"RETURN a, b, c"
|
||||
"MATCH (a:AWSAccount)--(b:AWSRole) MATCH (b)--(c:AWSPolicy) RETURN a, b, c"
|
||||
)
|
||||
result = _inject(cypher)
|
||||
assert f"(a:AWSAccount:{LABEL})" in result
|
||||
@@ -265,9 +261,7 @@ class TestRealWorldQueries:
|
||||
|
||||
def test_custom_bare_query(self):
|
||||
cypher = (
|
||||
"MATCH (a)-[:HAS_POLICY]->(b)\n"
|
||||
"WHERE a.name CONTAINS 'admin'\n"
|
||||
"RETURN a, b"
|
||||
"MATCH (a)-[:HAS_POLICY]->(b)\nWHERE a.name CONTAINS 'admin'\nRETURN a, b"
|
||||
)
|
||||
result = _inject(cypher)
|
||||
assert f"(a:{LABEL})" in result
|
||||
@@ -344,9 +338,7 @@ class TestEdgeCases:
|
||||
assert f"(outer:AWSAccount:{LABEL})" in result
|
||||
|
||||
def test_multiple_protected_regions(self):
|
||||
cypher = (
|
||||
"MATCH (n:X {a: 'hello'}) " 'WHERE n.b = "world" ' "// comment\n" "RETURN n"
|
||||
)
|
||||
cypher = "MATCH (n:X {a: 'hello'}) WHERE n.b = \"world\" // comment\nRETURN n"
|
||||
result = _inject(cypher)
|
||||
assert "'hello'" in result
|
||||
assert '"world"' in result
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import pytest
|
||||
from django.conf import settings
|
||||
from django.db.migrations.recorder import MigrationRecorder
|
||||
from django.db.utils import ConnectionRouter
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from api.db_router import MainRouter
|
||||
from api.rls import Tenant
|
||||
from config.django.base import DATABASE_ROUTERS as PROD_DATABASE_ROUTERS
|
||||
from unittest.mock import patch
|
||||
from django.conf import settings
|
||||
from django.db.migrations.recorder import MigrationRecorder
|
||||
from django.db.utils import ConnectionRouter
|
||||
|
||||
|
||||
@patch("api.db_router.MainRouter.admin_db", new="admin")
|
||||
|
||||
@@ -1,14 +1,8 @@
|
||||
from datetime import datetime, timezone
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from django.conf import settings
|
||||
from django.db import DEFAULT_DB_ALIAS, OperationalError
|
||||
from freezegun import freeze_time
|
||||
from psycopg2 import sql as psycopg2_sql
|
||||
from rest_framework_json_api.serializers import ValidationError
|
||||
|
||||
from api.db_utils import (
|
||||
POSTGRES_TENANT_VAR,
|
||||
PostgresEnumMigration,
|
||||
@@ -23,6 +17,11 @@ from api.db_utils import (
|
||||
update_objects_in_batches,
|
||||
)
|
||||
from api.models import Provider
|
||||
from django.conf import settings
|
||||
from django.db import DEFAULT_DB_ALIAS, OperationalError
|
||||
from freezegun import freeze_time
|
||||
from psycopg2 import sql as psycopg2_sql
|
||||
from rest_framework_json_api.serializers import ValidationError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -94,18 +93,16 @@ class TestEnumToChoices:
|
||||
class TestOneWeekFromNow:
|
||||
def test_one_week_from_now(self):
|
||||
with patch("api.db_utils.datetime") as mock_datetime:
|
||||
mock_datetime.now.return_value = datetime(2023, 1, 1, tzinfo=timezone.utc)
|
||||
expected_result = datetime(2023, 1, 8, tzinfo=timezone.utc)
|
||||
mock_datetime.now.return_value = datetime(2023, 1, 1, tzinfo=UTC)
|
||||
expected_result = datetime(2023, 1, 8, tzinfo=UTC)
|
||||
|
||||
result = one_week_from_now()
|
||||
assert result == expected_result
|
||||
|
||||
def test_one_week_from_now_with_timezone(self):
|
||||
with patch("api.db_utils.datetime") as mock_datetime:
|
||||
mock_datetime.now.return_value = datetime(
|
||||
2023, 6, 15, 12, 0, tzinfo=timezone.utc
|
||||
)
|
||||
expected_result = datetime(2023, 6, 22, 12, 0, tzinfo=timezone.utc)
|
||||
mock_datetime.now.return_value = datetime(2023, 6, 15, 12, 0, tzinfo=UTC)
|
||||
expected_result = datetime(2023, 6, 22, 12, 0, tzinfo=UTC)
|
||||
|
||||
result = one_week_from_now()
|
||||
assert result == expected_result
|
||||
@@ -939,9 +936,9 @@ class TestPostgresEnumMigration:
|
||||
|
||||
mock_cursor.execute.assert_called_once()
|
||||
query_arg = mock_cursor.execute.call_args[0][0]
|
||||
assert isinstance(
|
||||
query_arg, psycopg2_sql.Composable
|
||||
), "create_enum_type must pass a psycopg2.sql.Composable, not a raw string."
|
||||
assert isinstance(query_arg, psycopg2_sql.Composable), (
|
||||
"create_enum_type must pass a psycopg2.sql.Composable, not a raw string."
|
||||
)
|
||||
# Verify the composed SQL structure: CREATE TYPE <Identifier> AS ENUM (<Literals>)
|
||||
parts = query_arg.seq
|
||||
assert parts[0] == psycopg2_sql.SQL("CREATE TYPE ")
|
||||
@@ -962,9 +959,9 @@ class TestPostgresEnumMigration:
|
||||
|
||||
mock_cursor.execute.assert_called_once()
|
||||
query_arg = mock_cursor.execute.call_args[0][0]
|
||||
assert isinstance(
|
||||
query_arg, psycopg2_sql.Composable
|
||||
), "drop_enum_type must pass a psycopg2.sql.Composable, not a raw string."
|
||||
assert isinstance(query_arg, psycopg2_sql.Composable), (
|
||||
"drop_enum_type must pass a psycopg2.sql.Composable, not a raw string."
|
||||
)
|
||||
# Verify the composed SQL structure: DROP TYPE <Identifier>
|
||||
parts = query_arg.seq
|
||||
assert parts[0] == psycopg2_sql.SQL("DROP TYPE ")
|
||||
|
||||
@@ -2,12 +2,11 @@ import uuid
|
||||
from unittest.mock import call, patch
|
||||
|
||||
import pytest
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from django.db import DatabaseError, IntegrityError
|
||||
|
||||
from api.db_utils import POSTGRES_TENANT_VAR, SET_CONFIG_QUERY
|
||||
from api.decorators import handle_provider_deletion, set_tenant
|
||||
from api.exceptions import ProviderDeletedException
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from django.db import DatabaseError, IntegrityError
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
|
||||
@@ -7,15 +7,13 @@ Cover the IETF response envelope, status code mapping (200 / 503), the
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from api import health
|
||||
from config import version as config_version
|
||||
from django.core.cache import cache
|
||||
from django.urls import reverse
|
||||
from rest_framework import status
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from api import health
|
||||
|
||||
|
||||
HEALTH_MEDIA_TYPE = "application/health+json"
|
||||
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from api.middleware import APILoggingMiddleware
|
||||
from django.http import HttpResponse
|
||||
from django.test import RequestFactory
|
||||
|
||||
from api.middleware import APILoggingMiddleware
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@patch("logging.getLogger")
|
||||
|
||||
@@ -2,10 +2,6 @@ import json
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from django_celery_results.models import TaskResult
|
||||
from rest_framework import status
|
||||
from rest_framework.response import Response
|
||||
|
||||
from api.exceptions import (
|
||||
TaskFailedException,
|
||||
TaskInProgressException,
|
||||
@@ -14,6 +10,9 @@ from api.exceptions import (
|
||||
from api.models import Task, User
|
||||
from api.rls import Tenant
|
||||
from api.v1.mixins import PaginateByPkMixin, TaskManagementMixin
|
||||
from django_celery_results.models import TaskResult
|
||||
from rest_framework import status
|
||||
from rest_framework.response import Response
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
from datetime import datetime, timezone
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytest
|
||||
from allauth.socialaccount.models import SocialApp
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db import IntegrityError
|
||||
|
||||
from api.db_router import MainRouter
|
||||
from api.models import (
|
||||
ProviderComplianceScore,
|
||||
@@ -16,6 +13,8 @@ from api.models import (
|
||||
StatusChoices,
|
||||
TenantComplianceSummary,
|
||||
)
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db import IntegrityError
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@@ -376,7 +375,7 @@ class TestProviderComplianceScoreModel:
|
||||
def test_create_provider_compliance_score(self, providers_fixture, scans_fixture):
|
||||
provider = providers_fixture[0]
|
||||
scan = scans_fixture[0]
|
||||
scan.completed_at = datetime.now(timezone.utc)
|
||||
scan.completed_at = datetime.now(UTC)
|
||||
scan.save()
|
||||
|
||||
score = ProviderComplianceScore.objects.create(
|
||||
@@ -398,7 +397,7 @@ class TestProviderComplianceScoreModel:
|
||||
):
|
||||
provider = providers_fixture[0]
|
||||
scan = scans_fixture[0]
|
||||
scan.completed_at = datetime.now(timezone.utc)
|
||||
scan.completed_at = datetime.now(UTC)
|
||||
scan.save()
|
||||
|
||||
ProviderComplianceScore.objects.create(
|
||||
@@ -427,12 +426,12 @@ class TestProviderComplianceScoreModel:
|
||||
):
|
||||
provider1, provider2, *_ = providers_fixture
|
||||
scan1 = scans_fixture[0]
|
||||
scan1.completed_at = datetime.now(timezone.utc)
|
||||
scan1.completed_at = datetime.now(UTC)
|
||||
scan1.save()
|
||||
|
||||
scan2 = scans_fixture[2]
|
||||
scan2.state = StateChoices.COMPLETED
|
||||
scan2.completed_at = datetime.now(timezone.utc)
|
||||
scan2.completed_at = datetime.now(UTC)
|
||||
scan2.save()
|
||||
|
||||
score1 = ProviderComplianceScore.objects.create(
|
||||
|
||||
@@ -2,10 +2,6 @@ import json
|
||||
from unittest.mock import ANY, Mock, patch
|
||||
|
||||
import pytest
|
||||
from conftest import TEST_PASSWORD, TODAY
|
||||
from django.urls import reverse
|
||||
from rest_framework import status
|
||||
|
||||
from api.models import (
|
||||
Membership,
|
||||
ProviderGroup,
|
||||
@@ -16,6 +12,9 @@ from api.models import (
|
||||
UserRoleRelationship,
|
||||
)
|
||||
from api.v1.serializers import TokenSerializer
|
||||
from conftest import TEST_PASSWORD, TODAY
|
||||
from django.urls import reverse
|
||||
from rest_framework import status
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user