Compare commits

...

12 Commits

Author SHA1 Message Date
Víctor Fernández Poyatos
bcbf68736c fix(routing): apply correct routing and RLS settings to every connection 2025-06-17 12:28:18 +02:00
Víctor Fernández Poyatos
04056813eb Merge branch 'master' into PRWLR-4758-django-must-support-read-only-replica-database-config 2025-06-12 13:26:13 +02:00
Víctor Fernández Poyatos
6ab6099afc fix(db-router): allow relation between prowler_user and the rest 2025-05-21 12:36:32 +02:00
Víctor Fernández Poyatos
75db2ace1b Merge branch 'master' into PRWLR-4758-django-must-support-read-only-replica-database-config 2025-05-21 12:26:29 +02:00
Víctor Fernández Poyatos
9b45bec015 chore: update .gitignore 2025-05-21 12:21:45 +02:00
Víctor Fernández Poyatos
acfd1d99d6 fix(db-connectors): default values 2025-03-14 11:14:32 +01:00
Víctor Fernández Poyatos
f6780ab0dc Merge branch 'master' into PRWLR-4758-django-must-support-read-only-replica-database-config 2025-03-13 12:50:16 +01:00
Víctor Fernández Poyatos
6ca3a8e076 ref: use same var names as before 2025-03-12 16:20:43 +01:00
Víctor Fernández Poyatos
ac14f6f8d1 chore: update API changelog 2025-03-12 16:17:47 +01:00
Víctor Fernández Poyatos
d9bc3fbf3c fix(tests): adapt unit tests to db connectors changes 2025-03-12 16:14:21 +01:00
Víctor Fernández Poyatos
178d398aa1 ref: refactor forced db connections to use replica when needed 2025-03-12 16:13:51 +01:00
Víctor Fernández Poyatos
36f9051167 feat(db): add read replica configuration through env vars 2025-03-12 16:11:58 +01:00
19 changed files with 146 additions and 79 deletions

1
.gitignore vendored
View File

@@ -44,6 +44,7 @@ junit-reports/
# Cursor files
.cursorignore
.cursor/
# Terraform
.terraform*

View File

@@ -29,6 +29,9 @@ DJANGO_SENTRY_DSN=
# If running django and celery on host, use 'localhost', else use 'postgres-db'
POSTGRES_HOST=[localhost|postgres-db]
POSTGRES_PORT=5432
# Optional: host and port of a read-only replica used for read queries. If unset, they default to the values of POSTGRES_HOST and POSTGRES_PORT
POSTGRES_HOST_READ_ONLY=[localhost|postgres-db]
POSTGRES_PORT_READ_ONLY=5432
POSTGRES_ADMIN_USER=prowler
POSTGRES_ADMIN_PASSWORD=S3cret
POSTGRES_USER=prowler_user

View File

@@ -8,6 +8,7 @@ All notable changes to the **Prowler API** are documented in this file.
- Added SSO with SAML support [(#7822)](https://github.com/prowler-cloud/prowler/pull/7822).
- Support GCP Service Account key. [(#7824)](https://github.com/prowler-cloud/prowler/pull/7824)
- Added new `GET /compliance-overviews` endpoints to retrieve compliance metadata and specific requirements statuses [(#7877)](https://github.com/prowler-cloud/prowler/pull/7877).
- Support for read-only replicas in the database [(#7210)](https://github.com/prowler-cloud/prowler/pull/7210).
### Changed
- Reworked `GET /compliance-overviews` to return proper requirement metrics [(#7877)](https://github.com/prowler-cloud/prowler/pull/7877).
@@ -77,7 +78,6 @@ All notable changes to the **Prowler API** are documented in this file.
## [v1.6.0] (Prowler v5.5.0)
### Added
- Support for developing new integrations [(#7167)](https://github.com/prowler-cloud/prowler/pull/7167).
- HTTP Security Headers [(#7289)](https://github.com/prowler-cloud/prowler/pull/7289).
- New endpoint to get the compliance overviews metadata [(#7333)](https://github.com/prowler-cloud/prowler/pull/7333).

View File

@@ -51,7 +51,7 @@ class ProwlerSocialAccountAdapter(DefaultSocialAccountAdapter):
email_domain = user.email.split("@")[-1]
tenant = (
SAMLConfiguration.objects.using(MainRouter.admin_db)
SAMLConfiguration.objects.using(MainRouter.admin_read)
.get(email_domain=email_domain)
.tenant
)
@@ -60,7 +60,7 @@ class ProwlerSocialAccountAdapter(DefaultSocialAccountAdapter):
role_name = extra.get("userType", ["saml_default_role"])[0].strip()
try:
role = Role.objects.using(MainRouter.admin_db).get(
role = Role.objects.using(MainRouter.admin_read).get(
name=role_name, tenant_id=tenant.id
)
except Role.DoesNotExist:

View File

@@ -15,6 +15,8 @@ from api.rbac.permissions import HasPermissions
class BaseViewSet(ModelViewSet):
_rls_ctx = None
authentication_classes = [JWTAuthentication]
required_permissions = []
permission_classes = [permissions.IsAuthenticated, HasPermissions]
@@ -45,25 +47,29 @@ class BaseViewSet(ModelViewSet):
def get_queryset(self):
raise NotImplementedError
def finalize_response(self, request, response, *args, **kwargs):
try:
return super().finalize_response(request, response, *args, **kwargs)
finally:
if self._rls_ctx:
self._rls_ctx.__exit__(None, None, None)
self._rls_ctx = None
class BaseRLSViewSet(BaseViewSet):
def dispatch(self, request, *args, **kwargs):
with transaction.atomic():
return super().dispatch(request, *args, **kwargs)
def initial(self, request, *args, **kwargs):
# Ideally, this logic would be in the `.setup()` method but DRF view sets don't call it
# https://docs.djangoproject.com/en/5.1/ref/class-based-views/base/#django.views.generic.base.View.setup
if request.auth is None:
raise NotAuthenticated
tenant_id = request.auth.get("tenant_id")
if tenant_id is None:
raise NotAuthenticated("Tenant ID is not present in token")
if not tenant_id:
raise NotAuthenticated("Tenant ID missing in JWT")
with rls_transaction(tenant_id):
self.request.tenant_id = tenant_id
return super().initial(request, *args, **kwargs)
self._rls_ctx = rls_transaction(tenant_id)
self._rls_ctx.__enter__()
self.request.tenant_id = tenant_id
return super().initial(request, *args, **kwargs)
def get_serializer_context(self):
context = super().get_serializer_context()
@@ -117,17 +123,15 @@ class BaseTenantViewset(BaseViewSet):
raise NotAuthenticated("Tenant ID is not present in token")
user_id = str(request.user.id)
with rls_transaction(value=user_id, parameter=POSTGRES_USER_VAR):
return super().initial(request, *args, **kwargs)
self._rls_ctx = rls_transaction(value=user_id, parameter=POSTGRES_USER_VAR)
self._rls_ctx.__enter__()
return super().initial(request, *args, **kwargs)
class BaseUserViewset(BaseViewSet):
def dispatch(self, request, *args, **kwargs):
with transaction.atomic():
return super().dispatch(request, *args, **kwargs)
def initial(self, request, *args, **kwargs):
# TODO refactor after improving RLS on users
if request.stream is not None and request.stream.method == "POST":
return super().initial(request, *args, **kwargs)
if request.auth is None:
@@ -137,6 +141,8 @@ class BaseUserViewset(BaseViewSet):
if tenant_id is None:
raise NotAuthenticated("Tenant ID is not present in token")
with rls_transaction(tenant_id):
self.request.tenant_id = tenant_id
return super().initial(request, *args, **kwargs)
self._rls_ctx = rls_transaction(tenant_id)
self._rls_ctx.__enter__()
self.request.tenant_id = tenant_id
return super().initial(request, *args, **kwargs)

View File

@@ -3,27 +3,30 @@ ALLOWED_APPS = ("django", "socialaccount", "account", "authtoken", "silk")
class MainRouter:
default_db = "default"
default_read = "default_read"
prowler_user = "prowler_user"
admin_db = "admin"
admin_read = "admin_read"
def db_for_read(self, model, **hints): # noqa: F841
model_table_name = model._meta.db_table
if model_table_name.startswith("django_") or any(
model_table_name.startswith(f"{app}_") for app in ALLOWED_APPS
):
return self.admin_db
return None
if any(model_table_name.startswith(f"{app}_") for app in ALLOWED_APPS):
return self.admin_read
return self.default_read
def db_for_write(self, model, **hints): # noqa: F841
model_table_name = model._meta.db_table
if any(model_table_name.startswith(f"{app}_") for app in ALLOWED_APPS):
return self.admin_db
return None
return self.default_db
def allow_migrate(self, db, app_label, model_name=None, **hints): # noqa: F841
return db == self.admin_db
def allow_relation(self, obj1, obj2, **hints): # noqa: F841
# Allow relations if both objects are in either "default" or "admin" db connectors
if {obj1._state.db, obj2._state.db} <= {self.default_db, self.admin_db}:
# Allow relations if both objects are using one of our defined connectors
allowed = {self.default_db, self.default_read, self.admin_db, self.admin_read, self.prowler_user}
if {obj1._state.db, obj2._state.db} <= allowed:
return True
return None

View File

@@ -1,17 +1,19 @@
import re
import secrets
import uuid
from contextlib import contextmanager
from contextlib import ExitStack, contextmanager
from datetime import datetime, timedelta, timezone
from django.conf import settings
from django.contrib.auth.models import BaseUserManager
from django.db import connection, models, transaction
from django.db import connection, connections, models, transaction
from django_celery_beat.models import PeriodicTask
from psycopg2 import connect as psycopg2_connect
from psycopg2.extensions import AsIs, new_type, register_adapter, register_type
from rest_framework_json_api.serializers import ValidationError
from api.db_router import MainRouter
DB_USER = settings.DATABASES["default"]["USER"] if not settings.TESTING else "test"
DB_PASSWORD = (
settings.DATABASES["default"]["PASSWORD"] if not settings.TESTING else "test"
@@ -58,15 +60,29 @@ def rls_transaction(value: str, parameter: str = POSTGRES_TENANT_VAR):
value (str): Database configuration parameter value.
parameter (str): Database configuration parameter name, by default is 'api.tenant_id'.
"""
with transaction.atomic():
with connection.cursor() as cursor:
try:
# just in case the value is an UUID object
uuid.UUID(str(value))
except ValueError:
raise ValidationError("Must be a valid UUID")
try:
# just in case the value is a UUID object
uuid.UUID(str(value))
except ValueError:
raise ValidationError("Must be a valid UUID")
aliases = (
MainRouter.default_db,
MainRouter.admin_db,
MainRouter.prowler_user,
MainRouter.default_read,
)
with ExitStack() as stack:
cursors = []
for alias in aliases:
stack.enter_context(transaction.atomic(using=alias))
cursor = connections[alias].cursor()
cursor.execute(SET_CONFIG_QUERY, [parameter, value])
yield cursor
cursors.append(cursor)
stack.callback(cursor.close)
yield
class CustomUserManager(BaseUserManager):

View File

@@ -1,10 +1,6 @@
import uuid
from functools import wraps
from django.db import connection, transaction
from rest_framework_json_api.serializers import ValidationError
from api.db_utils import POSTGRES_TENANT_VAR, SET_CONFIG_QUERY
from api.db_utils import rls_transaction
def set_tenant(func=None, *, keep_tenant=False):
@@ -42,7 +38,6 @@ def set_tenant(func=None, *, keep_tenant=False):
def decorator(func):
@wraps(func)
@transaction.atomic
def wrapper(*args, **kwargs):
try:
if not keep_tenant:
@@ -51,14 +46,8 @@ def set_tenant(func=None, *, keep_tenant=False):
tenant_id = kwargs["tenant_id"]
except KeyError:
raise KeyError("This task requires the tenant_id")
try:
uuid.UUID(tenant_id)
except ValueError:
raise ValidationError("Tenant ID must be a valid UUID")
with connection.cursor() as cursor:
cursor.execute(SET_CONFIG_QUERY, [POSTGRES_TENANT_VAR, tenant_id])
return func(*args, **kwargs)
with rls_transaction(tenant_id):
return func(*args, **kwargs)
return wrapper

View File

@@ -1459,7 +1459,7 @@ class SAMLConfiguration(RowLevelSecurityProtectedModel):
)
# The email domain must be unique in the entire system
qs = SAMLConfiguration.objects.using(MainRouter.admin_db).filter(
qs = SAMLConfiguration.objects.using(MainRouter.admin_read).filter(
email_domain__iexact=self.email_domain
)
if qs.exists() and old_email_domain != self.email_domain:

View File

@@ -30,7 +30,9 @@ class HasPermissions(BasePermission):
return True
user_roles = (
User.objects.using(MainRouter.admin_db).get(id=request.user.id).roles.all()
User.objects.using(MainRouter.admin_read)
.get(id=request.user.id)
.roles.all()
)
if not user_roles:
return False

View File

@@ -67,7 +67,7 @@ class TestProwlerSocialAccountAdapter:
tenant = Tenant.objects.using(MainRouter.admin_db).get(
id=saml_setup["tenant_id"]
)
saml_config = SAMLConfiguration.objects.using(MainRouter.admin_db).get(
saml_config = SAMLConfiguration.objects.using(MainRouter.admin_read).get(
tenant=tenant
)
assert saml_config.email_domain == saml_setup["domain"]
@@ -76,7 +76,7 @@ class TestProwlerSocialAccountAdapter:
assert user.email == saml_setup["email"]
assert (
Membership.objects.using(MainRouter.admin_db)
Membership.objects.using(MainRouter.admin_read)
.filter(user=user, tenant=tenant)
.exists()
)

View File

@@ -1,15 +1,17 @@
from unittest.mock import patch
import pytest
from config.django.base import DATABASE_ROUTERS as PROD_DATABASE_ROUTERS
from django.conf import settings
from django.db.migrations.recorder import MigrationRecorder
from django.db.utils import ConnectionRouter
from api.db_router import MainRouter
from api.rls import Tenant
from config.django.base import DATABASE_ROUTERS as PROD_DATABASE_ROUTERS
from unittest.mock import patch
@patch("api.db_router.MainRouter.admin_db", new="admin")
@patch("api.db_router.MainRouter.admin_read", new="admin_read")
class TestMainDatabaseRouter:
@pytest.fixture(scope="module")
def router(self):
@@ -20,12 +22,12 @@ class TestMainDatabaseRouter:
@pytest.mark.parametrize("api_model", [Tenant])
def test_router_api_models(self, api_model, router):
assert router.db_for_read(api_model) == "default"
assert router.db_for_read(api_model) == "prowler_user_read"
assert router.db_for_write(api_model) == "default"
assert router.allow_migrate_model(MainRouter.admin_db, api_model)
assert not router.allow_migrate_model("default", api_model)
def test_router_django_models(self, router):
assert router.db_for_read(MigrationRecorder.Migration) == MainRouter.admin_db
assert not router.db_for_read(MigrationRecorder.Migration) == "default"
assert router.db_for_read(MigrationRecorder.Migration) == MainRouter.admin_read
assert router.db_for_read(MigrationRecorder.Migration) != "default"

View File

@@ -186,7 +186,7 @@ def validate_invitation(
try:
# Admin DB connector is used to bypass RLS protection since the invitation belongs to a tenant the user
# is not a member of yet
invitation = Invitation.objects.using(MainRouter.admin_db).get(
invitation = Invitation.objects.using(MainRouter.admin_read).get(
token=invitation_token, email=email
)
except Invitation.DoesNotExist:

View File

@@ -507,13 +507,13 @@ class TenantFinishACSView(FinishACSView):
email_domain = user.email.split("@")[-1]
tenant = (
SAMLConfiguration.objects.using(MainRouter.admin_db)
SAMLConfiguration.objects.using(MainRouter.admin_read)
.get(email_domain=email_domain)
.tenant
)
role_name = extra.get("userType", ["saml_default_role"])[0].strip()
try:
role = Role.objects.using(MainRouter.admin_db).get(
role = Role.objects.using(MainRouter.admin_read).get(
name=role_name, tenant=tenant
)
except Role.DoesNotExist:
@@ -2380,7 +2380,7 @@ class InvitationAcceptViewSet(BaseRLSViewSet):
)
# Proceed with accepting the invitation
user = User.objects.using(MainRouter.admin_db).get(email=user_email)
user = User.objects.using(MainRouter.admin_read).get(email=user_email)
membership = Membership.objects.using(MainRouter.admin_db).create(
user=user,
tenant=invitation.tenant,

View File

@@ -40,6 +40,7 @@ class RLSTask(Task):
):
from django_celery_results.models import TaskResult
from api.db_router import MainRouter
from api.models import Task as APITask
result = super().apply_async(
@@ -52,14 +53,14 @@ class RLSTask(Task):
shadow=shadow,
**options,
)
task_result_instance = TaskResult.objects.get(task_id=result.task_id)
from api.db_utils import rls_transaction
task_result_instance = TaskResult.objects.using(MainRouter.admin_db).get(
task_id=result.task_id
)
tenant_id = kwargs.get("tenant_id")
with rls_transaction(tenant_id):
APITask.objects.update_or_create(
id=task_result_instance.task_id,
tenant_id=tenant_id,
defaults={"task_runner_task": task_result_instance},
)
APITask.objects.using(MainRouter.admin_db).update_or_create(
id=task_result_instance.task_id,
tenant_id=tenant_id,
defaults={"task_runner_task": task_result_instance},
)
return result

View File

@@ -14,6 +14,19 @@ DATABASES = {
"HOST": env("POSTGRES_HOST", default="postgres-db"),
"PORT": env("POSTGRES_PORT", default="5432"),
},
"default_read": {
"ENGINE": "psqlextra.backend",
"NAME": env("POSTGRES_DB", default="prowler_db"),
"USER": env("POSTGRES_USER", default="prowler_user"),
"PASSWORD": env("POSTGRES_PASSWORD", default="prowler"),
"HOST": env(
"POSTGRES_HOST_READ_ONLY",
default=env("POSTGRES_HOST", default="postgres-db"),
),
"PORT": env(
"POSTGRES_PORT_READ_ONLY", default=env("POSTGRES_PORT", default="5432")
),
},
"admin": {
"ENGINE": "psqlextra.backend",
"NAME": env("POSTGRES_DB", default="prowler_db"),
@@ -22,6 +35,19 @@ DATABASES = {
"HOST": env("POSTGRES_HOST", default="postgres-db"),
"PORT": env("POSTGRES_PORT", default="5432"),
},
"admin_read": {
"ENGINE": "psqlextra.backend",
"NAME": env("POSTGRES_DB", default="prowler_db"),
"USER": env("POSTGRES_ADMIN_USER", default="prowler"),
"PASSWORD": env("POSTGRES_ADMIN_PASSWORD", default="S3cret"),
"HOST": env(
"POSTGRES_HOST_READ_ONLY",
default=env("POSTGRES_HOST", default="postgres-db"),
),
"PORT": env(
"POSTGRES_PORT_READ_ONLY", default=env("POSTGRES_PORT", default="5432")
),
},
}
DATABASES["default"] = DATABASES["prowler_user"]

View File

@@ -5,7 +5,6 @@ DEBUG = env.bool("DJANGO_DEBUG", default=False)
ALLOWED_HOSTS = env.list("DJANGO_ALLOWED_HOSTS", default=["localhost", "127.0.0.1"])
# Database
# TODO Use Django database routers https://docs.djangoproject.com/en/5.0/topics/db/multi-db/#automatic-database-routing
DATABASES = {
"prowler_user": {
"ENGINE": "django.db.backends.postgresql",
@@ -15,6 +14,14 @@ DATABASES = {
"HOST": env("POSTGRES_HOST"),
"PORT": env("POSTGRES_PORT"),
},
"default_read": {
"ENGINE": "django.db.backends.postgresql",
"NAME": env("POSTGRES_DB"),
"USER": env("POSTGRES_USER"),
"PASSWORD": env("POSTGRES_PASSWORD"),
"HOST": env("POSTGRES_HOST_READ_ONLY", default=env("POSTGRES_HOST")),
"PORT": env("POSTGRES_PORT_READ_ONLY", default=env("POSTGRES_PORT")),
},
"admin": {
"ENGINE": "psqlextra.backend",
"NAME": env("POSTGRES_DB"),
@@ -23,5 +30,13 @@ DATABASES = {
"HOST": env("POSTGRES_HOST"),
"PORT": env("POSTGRES_PORT"),
},
"admin_read": {
"ENGINE": "psqlextra.backend",
"NAME": env("POSTGRES_DB"),
"USER": env("POSTGRES_ADMIN_USER"),
"PASSWORD": env("POSTGRES_ADMIN_PASSWORD"),
"HOST": env("POSTGRES_HOST_READ_ONLY", default=env("POSTGRES_HOST")),
"PORT": env("POSTGRES_PORT_READ_ONLY", default=env("POSTGRES_PORT")),
},
}
DATABASES["default"] = DATABASES["prowler_user"]

View File

@@ -1154,6 +1154,9 @@ def pytest_configure(config):
# Apply the mock before the test session starts. This is necessary to avoid admin error when running the
# 0004_rbac_missing_admin_roles migration
patch("api.db_router.MainRouter.admin_db", new="default").start()
patch("api.db_router.MainRouter.admin_read", new="default").start()
patch("api.db_router.MainRouter.prowler_user", new="default").start()
patch("api.db_router.MainRouter.default_read", new="default").start()
def pytest_unconfigure(config):

View File

@@ -64,7 +64,7 @@ def delete_tenant(pk: str):
"""
deletion_summary = {}
for provider in Provider.objects.using(MainRouter.admin_db).filter(tenant_id=pk):
for provider in Provider.objects.using(MainRouter.admin_read).filter(tenant_id=pk):
summary = delete_provider(pk, provider.id)
deletion_summary.update(summary)