fix(RLS): enforce config security (#6066)

This commit is contained in:
Pepe Fagoaga
2024-12-13 12:55:09 +01:00
committed by GitHub
parent 32f69d24b6
commit da4f9b8e5f
5 changed files with 44 additions and 38 deletions

View File

@@ -1,14 +1,12 @@
import uuid
from django.db import connection, transaction
from django.db import transaction
from rest_framework import permissions
from rest_framework.exceptions import NotAuthenticated
from rest_framework.filters import SearchFilter
from rest_framework_json_api import filters
from rest_framework_json_api.serializers import ValidationError
from rest_framework_json_api.views import ModelViewSet
from rest_framework_simplejwt.authentication import JWTAuthentication
from api.db_utils import POSTGRES_USER_VAR, tenant_transaction
from api.filters import CustomDjangoFilterBackend
@@ -47,13 +45,7 @@ class BaseRLSViewSet(BaseViewSet):
if tenant_id is None:
raise NotAuthenticated("Tenant ID is not present in token")
try:
uuid.UUID(tenant_id)
except ValueError:
raise ValidationError("Tenant ID must be a valid UUID")
with connection.cursor() as cursor:
cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
with tenant_transaction(tenant_id):
self.request.tenant_id = tenant_id
return super().initial(request, *args, **kwargs)
@@ -75,8 +67,7 @@ class BaseTenantViewset(BaseViewSet):
):
user_id = str(request.user.id)
with connection.cursor() as cursor:
cursor.execute(f"SELECT set_config('api.user_id', '{user_id}', TRUE);")
with tenant_transaction(value=user_id, parameter=POSTGRES_USER_VAR):
return super().initial(request, *args, **kwargs)
# TODO: DRY this when we have time
@@ -87,13 +78,7 @@ class BaseTenantViewset(BaseViewSet):
if tenant_id is None:
raise NotAuthenticated("Tenant ID is not present in token")
try:
uuid.UUID(tenant_id)
except ValueError:
raise ValidationError("Tenant ID must be a valid UUID")
with connection.cursor() as cursor:
cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
with tenant_transaction(tenant_id):
self.request.tenant_id = tenant_id
return super().initial(request, *args, **kwargs)
@@ -114,12 +99,6 @@ class BaseUserViewset(BaseViewSet):
if tenant_id is None:
raise NotAuthenticated("Tenant ID is not present in token")
try:
uuid.UUID(tenant_id)
except ValueError:
raise ValidationError("Tenant ID must be a valid UUID")
with connection.cursor() as cursor:
cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
with tenant_transaction(tenant_id):
self.request.tenant_id = tenant_id
return super().initial(request, *args, **kwargs)

View File

@@ -1,4 +1,5 @@
import secrets
import uuid
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
@@ -8,6 +9,7 @@ from django.core.paginator import Paginator
from django.db import connection, models, transaction
from psycopg2 import connect as psycopg2_connect
from psycopg2.extensions import AsIs, new_type, register_adapter, register_type
from rest_framework_json_api.serializers import ValidationError
DB_USER = settings.DATABASES["default"]["USER"] if not settings.TESTING else "test"
DB_PASSWORD = (
@@ -23,6 +25,8 @@ TASK_RUNNER_DB_TABLE = "django_celery_results_taskresult"
POSTGRES_TENANT_VAR = "api.tenant_id"
POSTGRES_USER_VAR = "api.user_id"
SET_CONFIG_QUERY = "SELECT set_config(%s, %s::text, TRUE);"
@contextmanager
def psycopg_connection(database_alias: str):
@@ -44,10 +48,23 @@ def psycopg_connection(database_alias: str):
@contextmanager
def tenant_transaction(tenant_id: str):
def tenant_transaction(value: str, parameter: str = POSTGRES_TENANT_VAR):
"""
Creates a new database transaction setting the given configuration value. It validates the
if the value is a valid UUID to be used for Postgres RLS.
Args:
value (str): Database configuration parameter value.
parameter (str): Database configuration parameter name
"""
with transaction.atomic():
with connection.cursor() as cursor:
cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
try:
# just in case the value is an UUID object
uuid.UUID(str(value))
except ValueError:
raise ValidationError("Must be a valid UUID")
cursor.execute(SET_CONFIG_QUERY, [parameter, value])
yield cursor

View File

@@ -1,6 +1,10 @@
import uuid
from functools import wraps
from django.db import connection, transaction
from rest_framework_json_api.serializers import ValidationError
from api.db_utils import POSTGRES_TENANT_VAR, SET_CONFIG_QUERY
def set_tenant(func):
@@ -31,7 +35,7 @@ def set_tenant(func):
pass
# When calling the task
some_task.delay(arg1, tenant_id="1234-abcd-5678")
some_task.delay(arg1, tenant_id="8db7ca86-03cc-4d42-99f6-5e480baf6ab5")
# The tenant context will be set before the task logic executes.
"""
@@ -43,9 +47,12 @@ def set_tenant(func):
tenant_id = kwargs.pop("tenant_id")
except KeyError:
raise KeyError("This task requires the tenant_id")
try:
uuid.UUID(tenant_id)
except ValueError:
raise ValidationError("Tenant ID must be a valid UUID")
with connection.cursor() as cursor:
cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
cursor.execute(SET_CONFIG_QUERY, [POSTGRES_TENANT_VAR, tenant_id])
return func(*args, **kwargs)

View File

@@ -1,7 +1,9 @@
from unittest.mock import patch, call
import uuid
from unittest.mock import call, patch
import pytest
from api.db_utils import POSTGRES_TENANT_VAR, SET_CONFIG_QUERY
from api.decorators import set_tenant
@@ -15,12 +17,12 @@ class TestSetTenantDecorator:
def random_func(arg):
return arg
tenant_id = "1234-abcd-5678"
tenant_id = str(uuid.uuid4())
result = random_func("test_arg", tenant_id=tenant_id)
assert (
call(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
call(SET_CONFIG_QUERY, [POSTGRES_TENANT_VAR, tenant_id])
in mock_cursor.execute.mock_calls
)
assert result == "test_arg"

View File

@@ -1,3 +1,4 @@
import uuid
from unittest.mock import MagicMock, patch
import pytest
@@ -212,7 +213,7 @@ class TestPerformScan:
mock_get_or_create_resource,
mock_get_or_create_tag,
):
tenant_id = "tenant123"
tenant_id = uuid.uuid4()
provider_instance = MagicMock()
provider_instance.id = "provider456"
@@ -260,7 +261,7 @@ class TestPerformScan:
mock_get_or_create_resource,
mock_get_or_create_tag,
):
tenant_id = "tenant123"
tenant_id = uuid.uuid4()
provider_instance = MagicMock()
provider_instance.id = "provider456"
@@ -317,7 +318,7 @@ class TestPerformScan:
mock_get_or_create_resource,
mock_get_or_create_tag,
):
tenant_id = "tenant123"
tenant_id = uuid.uuid4()
provider_instance = MagicMock()
provider_instance.id = "provider456"