feat(api): RBAC system (#6114)

This commit is contained in:
Adrián Jesús Peña Rodríguez
2024-12-13 14:14:40 +01:00
committed by GitHub
parent f9fbde6637
commit d00d254c90
21 changed files with 4323 additions and 309 deletions
+1 -1
View File
@@ -1,4 +1,4 @@
FROM python:3.12-alpine AS build
FROM python:3.12.8-alpine3.20 AS build
LABEL maintainer="https://github.com/prowler-cloud/api"
+1 -1
View File
@@ -8,7 +8,7 @@ description = "Prowler's API (Django/DRF)"
license = "Apache-2.0"
name = "prowler-api"
package-mode = false
version = "1.0.0"
version = "1.1.0"
[tool.poetry.dependencies]
celery = {extras = ["pytest"], version = "^5.4.0"}
+36 -1
View File
@@ -1,3 +1,4 @@
from django.core.exceptions import ObjectDoesNotExist
from django.db import transaction
from rest_framework import permissions
from rest_framework.exceptions import NotAuthenticated
@@ -8,6 +9,8 @@ from rest_framework_simplejwt.authentication import JWTAuthentication
from api.db_utils import POSTGRES_USER_VAR, tenant_transaction
from api.filters import CustomDjangoFilterBackend
from api.models import Role, Tenant
from api.db_router import MainRouter
class BaseViewSet(ModelViewSet):
@@ -58,7 +61,39 @@ class BaseRLSViewSet(BaseViewSet):
class BaseTenantViewset(BaseViewSet):
def dispatch(self, request, *args, **kwargs):
with transaction.atomic():
return super().dispatch(request, *args, **kwargs)
tenant = super().dispatch(request, *args, **kwargs)
try:
# If the request is a POST, create the admin role
if request.method == "POST":
isinstance(tenant, dict) and self._create_admin_role(tenant.data["id"])
except Exception as e:
self._handle_creation_error(e, tenant)
raise
return tenant
def _create_admin_role(self, tenant_id):
Role.objects.using(MainRouter.admin_db).create(
name="admin",
tenant_id=tenant_id,
manage_users=True,
manage_account=True,
manage_billing=True,
manage_providers=True,
manage_integrations=True,
manage_scans=True,
unlimited_visibility=True,
)
def _handle_creation_error(self, error, tenant):
if tenant.data.get("id"):
try:
Tenant.objects.using(MainRouter.admin_db).filter(
id=tenant.data["id"]
).delete()
except ObjectDoesNotExist:
pass # Tenant might not exist, handle gracefully
def initial(self, request, *args, **kwargs):
if (
+25 -3
View File
@@ -22,13 +22,11 @@ from api.db_utils import (
StatusEnumField,
)
from api.models import (
ComplianceOverview,
Finding,
Invitation,
Membership,
PermissionChoices,
Provider,
ProviderGroup,
ProviderSecret,
Resource,
ResourceTag,
Scan,
@@ -36,6 +34,10 @@ from api.models import (
SeverityChoices,
StateChoices,
StatusChoices,
ProviderSecret,
Invitation,
Role,
ComplianceOverview,
Task,
User,
)
@@ -481,6 +483,26 @@ class UserFilter(FilterSet):
}
class RoleFilter(FilterSet):
inserted_at = DateFilter(field_name="inserted_at", lookup_expr="date")
updated_at = DateFilter(field_name="updated_at", lookup_expr="date")
permission_state = ChoiceFilter(
choices=PermissionChoices.choices, method="filter_permission_state"
)
def filter_permission_state(self, queryset, name, value):
return Role.filter_by_permission_state(queryset, value)
class Meta:
model = Role
fields = {
"id": ["exact", "in"],
"name": ["exact", "in"],
"inserted_at": ["gte", "lte"],
"updated_at": ["gte", "lte"],
}
class ComplianceOverviewFilter(FilterSet):
inserted_at = DateFilter(field_name="inserted_at", lookup_expr="date")
provider_type = ChoiceFilter(choices=Provider.ProviderChoices.choices)
@@ -58,5 +58,96 @@
"provider_group": "525e91e7-f3f3-4254-bbc3-27ce1ade86b1",
"inserted_at": "2024-11-13T11:55:41.237Z"
}
},
{
"model": "api.role",
"pk": "3f01e759-bdf9-4a99-8888-1ab805b79f93",
"fields": {
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
"name": "admin_test",
"manage_users": true,
"manage_account": true,
"manage_billing": true,
"manage_providers": true,
"manage_integrations": true,
"manage_scans": true,
"unlimited_visibility": true,
"inserted_at": "2024-11-20T15:32:42.402Z",
"updated_at": "2024-11-20T15:32:42.402Z"
}
},
{
"model": "api.role",
"pk": "845ff03a-87ef-42ba-9786-6577c70c4df0",
"fields": {
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
"name": "first_role",
"manage_users": true,
"manage_account": true,
"manage_billing": true,
"manage_providers": true,
"manage_integrations": false,
"manage_scans": false,
"unlimited_visibility": true,
"inserted_at": "2024-11-20T15:31:53.239Z",
"updated_at": "2024-11-20T15:31:53.239Z"
}
},
{
"model": "api.role",
"pk": "902d726c-4bd5-413a-a2a4-f7b4754b6b20",
"fields": {
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
"name": "third_role",
"manage_users": false,
"manage_account": false,
"manage_billing": false,
"manage_providers": false,
"manage_integrations": false,
"manage_scans": true,
"unlimited_visibility": false,
"inserted_at": "2024-11-20T15:34:05.440Z",
"updated_at": "2024-11-20T15:34:05.440Z"
}
},
{
"model": "api.roleprovidergrouprelationship",
"pk": "57fd024a-0a7f-49b4-a092-fa0979a07aaf",
"fields": {
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
"role": "3f01e759-bdf9-4a99-8888-1ab805b79f93",
"provider_group": "3fe28fb8-e545-424c-9b8f-69aff638f430",
"inserted_at": "2024-11-20T15:32:42.402Z"
}
},
{
"model": "api.roleprovidergrouprelationship",
"pk": "a3cd0099-1c13-4df1-a5e5-ecdfec561b35",
"fields": {
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
"role": "3f01e759-bdf9-4a99-8888-1ab805b79f93",
"provider_group": "481769f5-db2b-447b-8b00-1dee18db90ec",
"inserted_at": "2024-11-20T15:32:42.402Z"
}
},
{
"model": "api.roleprovidergrouprelationship",
"pk": "cfd84182-a058-40c2-af3c-0189b174940f",
"fields": {
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
"role": "3f01e759-bdf9-4a99-8888-1ab805b79f93",
"provider_group": "525e91e7-f3f3-4254-bbc3-27ce1ade86b1",
"inserted_at": "2024-11-20T15:32:42.402Z"
}
},
{
"model": "api.userrolerelationship",
"pk": "92339663-e954-4fd8-98fb-8bfe15949975",
"fields": {
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
"role": "3f01e759-bdf9-4a99-8888-1ab805b79f93",
"user": "8b38e2eb-6689-4f1e-a4ba-95b275130200",
"inserted_at": "2024-11-20T15:36:14.302Z"
}
}
]
+246
View File
@@ -0,0 +1,246 @@
# Generated by Django 5.1.1 on 2024-12-05 12:29
import api.rls
import django.db.models.deletion
import uuid
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("api", "0002_token_migrations"),
]
operations = [
migrations.CreateModel(
name="Role",
fields=[
(
"id",
models.UUIDField(
default=uuid.uuid4,
editable=False,
primary_key=True,
serialize=False,
),
),
("name", models.CharField(max_length=255)),
("manage_users", models.BooleanField(default=False)),
("manage_account", models.BooleanField(default=False)),
("manage_billing", models.BooleanField(default=False)),
("manage_providers", models.BooleanField(default=False)),
("manage_integrations", models.BooleanField(default=False)),
("manage_scans", models.BooleanField(default=False)),
("unlimited_visibility", models.BooleanField(default=False)),
("inserted_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
(
"tenant",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
),
),
],
options={
"db_table": "roles",
},
),
migrations.CreateModel(
name="RoleProviderGroupRelationship",
fields=[
(
"id",
models.UUIDField(
default=uuid.uuid4,
editable=False,
primary_key=True,
serialize=False,
),
),
("inserted_at", models.DateTimeField(auto_now_add=True)),
(
"tenant",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
),
),
],
options={
"db_table": "role_provider_group_relationship",
},
),
migrations.CreateModel(
name="UserRoleRelationship",
fields=[
(
"id",
models.UUIDField(
default=uuid.uuid4,
editable=False,
primary_key=True,
serialize=False,
),
),
("inserted_at", models.DateTimeField(auto_now_add=True)),
(
"tenant",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
),
),
],
options={
"db_table": "role_user_relationship",
},
),
migrations.AddField(
model_name="roleprovidergrouprelationship",
name="provider_group",
field=models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="api.providergroup"
),
),
migrations.AddField(
model_name="roleprovidergrouprelationship",
name="role",
field=models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="api.role"
),
),
migrations.AddField(
model_name="role",
name="provider_groups",
field=models.ManyToManyField(
related_name="roles",
through="api.RoleProviderGroupRelationship",
to="api.providergroup",
),
),
migrations.AddField(
model_name="userrolerelationship",
name="role",
field=models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="api.role"
),
),
migrations.AddField(
model_name="userrolerelationship",
name="user",
field=models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL
),
),
migrations.AddField(
model_name="role",
name="users",
field=models.ManyToManyField(
related_name="roles",
through="api.UserRoleRelationship",
to=settings.AUTH_USER_MODEL,
),
),
migrations.AddConstraint(
model_name="roleprovidergrouprelationship",
constraint=models.UniqueConstraint(
fields=("role_id", "provider_group_id"),
name="unique_role_provider_group_relationship",
),
),
migrations.AddConstraint(
model_name="roleprovidergrouprelationship",
constraint=api.rls.RowLevelSecurityConstraint(
"tenant_id",
name="rls_on_roleprovidergrouprelationship",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
),
),
migrations.AddConstraint(
model_name="userrolerelationship",
constraint=models.UniqueConstraint(
fields=("role_id", "user_id"), name="unique_role_user_relationship"
),
),
migrations.AddConstraint(
model_name="userrolerelationship",
constraint=api.rls.RowLevelSecurityConstraint(
"tenant_id",
name="rls_on_userrolerelationship",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
),
),
migrations.AddConstraint(
model_name="role",
constraint=models.UniqueConstraint(
fields=("tenant_id", "name"), name="unique_role_per_tenant"
),
),
migrations.AddConstraint(
model_name="role",
constraint=api.rls.RowLevelSecurityConstraint(
"tenant_id",
name="rls_on_role",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
),
),
migrations.CreateModel(
name="InvitationRoleRelationship",
fields=[
(
"id",
models.UUIDField(
default=uuid.uuid4,
editable=False,
primary_key=True,
serialize=False,
),
),
("inserted_at", models.DateTimeField(auto_now_add=True)),
(
"invitation",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="api.invitation"
),
),
(
"role",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="api.role"
),
),
(
"tenant",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
),
),
],
options={
"db_table": "role_invitation_relationship",
},
),
migrations.AddConstraint(
model_name="invitationrolerelationship",
constraint=models.UniqueConstraint(
fields=("role_id", "invitation_id"),
name="unique_role_invitation_relationship",
),
),
migrations.AddConstraint(
model_name="invitationrolerelationship",
constraint=api.rls.RowLevelSecurityConstraint(
"tenant_id",
name="rls_on_invitationrolerelationship",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
),
),
migrations.AddField(
model_name="role",
name="invitations",
field=models.ManyToManyField(
related_name="roles",
through="api.InvitationRoleRelationship",
to="api.invitation",
),
),
]
@@ -0,0 +1,43 @@
from django.db import migrations
from api.db_router import MainRouter
def create_admin_role(apps, schema_editor):
Tenant = apps.get_model("api", "Tenant")
Role = apps.get_model("api", "Role")
User = apps.get_model("api", "User")
UserRoleRelationship = apps.get_model("api", "UserRoleRelationship")
for tenant in Tenant.objects.using(MainRouter.admin_db).all():
admin_role, _ = Role.objects.using(MainRouter.admin_db).get_or_create(
name="admin",
tenant=tenant,
defaults={
"manage_users": True,
"manage_account": True,
"manage_billing": True,
"manage_providers": True,
"manage_integrations": True,
"manage_scans": True,
"unlimited_visibility": True,
},
)
users = User.objects.using(MainRouter.admin_db).filter(
membership__tenant=tenant
)
for user in users:
UserRoleRelationship.objects.using(MainRouter.admin_db).get_or_create(
user=user,
role=admin_role,
tenant=tenant,
)
class Migration(migrations.Migration):
dependencies = [
("api", "0003_rbac"),
]
operations = [
migrations.RunPython(create_admin_role),
]
+164 -14
View File
@@ -69,6 +69,21 @@ class StateChoices(models.TextChoices):
CANCELLED = "cancelled", _("Cancelled")
class PermissionChoices(models.TextChoices):
"""
Represents the different permission states that a role can have.
Attributes:
UNLIMITED: Indicates that the role possesses all permissions.
LIMITED: Indicates that the role has some permissions but not all.
NONE: Indicates that the role does not have any permissions.
"""
UNLIMITED = "unlimited", _("Unlimited permissions")
LIMITED = "limited", _("Limited permissions")
NONE = "none", _("No permissions")
class ActiveProviderManager(models.Manager):
def get_queryset(self):
return super().get_queryset().filter(self.active_provider_filter())
@@ -294,23 +309,14 @@ class ProviderGroup(RowLevelSecurityProtectedModel):
]
class JSONAPIMeta:
resource_name = "provider-groups"
resource_name = "provider-group"
class ProviderGroupMembership(RowLevelSecurityProtectedModel):
objects = ActiveProviderManager()
all_objects = models.Manager()
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
provider = models.ForeignKey(
Provider,
on_delete=models.CASCADE,
)
provider_group = models.ForeignKey(
ProviderGroup,
on_delete=models.CASCADE,
)
inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
provider_group = models.ForeignKey(ProviderGroup, on_delete=models.CASCADE)
provider = models.ForeignKey(Provider, on_delete=models.CASCADE)
inserted_at = models.DateTimeField(auto_now_add=True)
class Meta:
db_table = "provider_group_memberships"
@@ -327,7 +333,7 @@ class ProviderGroupMembership(RowLevelSecurityProtectedModel):
]
class JSONAPIMeta:
resource_name = "provider-group-memberships"
resource_name = "provider_groups-provider"
class Task(RowLevelSecurityProtectedModel):
@@ -851,6 +857,150 @@ class Invitation(RowLevelSecurityProtectedModel):
resource_name = "invitations"
class Role(RowLevelSecurityProtectedModel):
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
name = models.CharField(max_length=255)
manage_users = models.BooleanField(default=False)
manage_account = models.BooleanField(default=False)
manage_billing = models.BooleanField(default=False)
manage_providers = models.BooleanField(default=False)
manage_integrations = models.BooleanField(default=False)
manage_scans = models.BooleanField(default=False)
unlimited_visibility = models.BooleanField(default=False)
inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
updated_at = models.DateTimeField(auto_now=True, editable=False)
provider_groups = models.ManyToManyField(
ProviderGroup, through="RoleProviderGroupRelationship", related_name="roles"
)
users = models.ManyToManyField(
User, through="UserRoleRelationship", related_name="roles"
)
invitations = models.ManyToManyField(
Invitation, through="InvitationRoleRelationship", related_name="roles"
)
# Filter permission_state
PERMISSION_FIELDS = [
"manage_users",
"manage_account",
"manage_billing",
"manage_providers",
"manage_integrations",
"manage_scans",
]
@property
def permission_state(self):
values = [getattr(self, field) for field in self.PERMISSION_FIELDS]
if all(values):
return PermissionChoices.UNLIMITED
elif not any(values):
return PermissionChoices.NONE
else:
return PermissionChoices.LIMITED
@classmethod
def filter_by_permission_state(cls, queryset, value):
q_all_true = Q(**{field: True for field in cls.PERMISSION_FIELDS})
q_all_false = Q(**{field: False for field in cls.PERMISSION_FIELDS})
if value == PermissionChoices.UNLIMITED:
return queryset.filter(q_all_true)
elif value == PermissionChoices.NONE:
return queryset.filter(q_all_false)
else:
return queryset.exclude(q_all_true | q_all_false)
class Meta:
db_table = "roles"
constraints = [
models.UniqueConstraint(
fields=["tenant_id", "name"],
name="unique_role_per_tenant",
),
RowLevelSecurityConstraint(
field="tenant_id",
name="rls_on_%(class)s",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
),
]
class JSONAPIMeta:
resource_name = "role"
class RoleProviderGroupRelationship(RowLevelSecurityProtectedModel):
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
role = models.ForeignKey(Role, on_delete=models.CASCADE)
provider_group = models.ForeignKey(ProviderGroup, on_delete=models.CASCADE)
inserted_at = models.DateTimeField(auto_now_add=True)
class Meta:
db_table = "role_provider_group_relationship"
constraints = [
models.UniqueConstraint(
fields=["role_id", "provider_group_id"],
name="unique_role_provider_group_relationship",
),
RowLevelSecurityConstraint(
field="tenant_id",
name="rls_on_%(class)s",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
),
]
class JSONAPIMeta:
resource_name = "role-provider_groups"
class UserRoleRelationship(RowLevelSecurityProtectedModel):
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
role = models.ForeignKey(Role, on_delete=models.CASCADE)
user = models.ForeignKey(User, on_delete=models.CASCADE)
inserted_at = models.DateTimeField(auto_now_add=True)
class Meta:
db_table = "role_user_relationship"
constraints = [
models.UniqueConstraint(
fields=["role_id", "user_id"],
name="unique_role_user_relationship",
),
RowLevelSecurityConstraint(
field="tenant_id",
name="rls_on_%(class)s",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
),
]
class JSONAPIMeta:
resource_name = "user-roles"
class InvitationRoleRelationship(RowLevelSecurityProtectedModel):
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
role = models.ForeignKey(Role, on_delete=models.CASCADE)
invitation = models.ForeignKey(Invitation, on_delete=models.CASCADE)
inserted_at = models.DateTimeField(auto_now_add=True)
class Meta:
db_table = "role_invitation_relationship"
constraints = [
models.UniqueConstraint(
fields=["role_id", "invitation_id"],
name="unique_role_invitation_relationship",
),
RowLevelSecurityConstraint(
field="tenant_id",
name="rls_on_%(class)s",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
),
]
class JSONAPIMeta:
resource_name = "invitation-roles"
class ComplianceOverview(RowLevelSecurityProtectedModel):
objects = ActiveProviderManager()
all_objects = models.Manager()
+38
View File
@@ -0,0 +1,38 @@
from enum import Enum
from rest_framework.permissions import BasePermission
from api.models import User
from api.db_router import MainRouter
class Permissions(Enum):
MANAGE_USERS = "manage_users"
MANAGE_ACCOUNT = "manage_account"
MANAGE_BILLING = "manage_billing"
MANAGE_PROVIDERS = "manage_providers"
MANAGE_INTEGRATIONS = "manage_integrations"
MANAGE_SCANS = "manage_scans"
UNLIMITED_VISIBILITY = "unlimited_visibility"
class HasPermissions(BasePermission):
"""
Custom permission to check if the user's role has the required permissions.
The required permissions should be specified in the view as a list in `required_permissions`.
"""
def has_permission(self, request, view):
required_permissions = getattr(view, "required_permissions", [])
if not required_permissions:
return True
user_roles = (
User.objects.using(MainRouter.admin_db).get(id=request.user.id).roles.all()
)
if not user_roles:
return False
for perm in required_permissions:
if not getattr(user_roles[0], perm.value, False):
return False
return True
File diff suppressed because it is too large Load Diff
@@ -1,12 +1,10 @@
import pytest
from django.urls import reverse
from unittest.mock import patch
from rest_framework.test import APIClient
from conftest import TEST_PASSWORD, get_api_tokens, get_authorization_header
@patch("api.v1.views.MainRouter.admin_db", new="default")
@pytest.mark.django_db
def test_basic_authentication():
client = APIClient()
@@ -13,6 +13,7 @@ def test_check_resources_between_different_tenants(
enforce_test_user_db_connection,
authenticated_api_client,
tenants_fixture,
set_user_admin_roles_fixture,
):
client = authenticated_api_client
@@ -6,8 +6,10 @@ from django.db.utils import ConnectionRouter
from api.db_router import MainRouter
from api.rls import Tenant
from config.django.base import DATABASE_ROUTERS as PROD_DATABASE_ROUTERS
from unittest.mock import patch
@patch("api.db_router.MainRouter.admin_db", new="admin")
class TestMainDatabaseRouter:
@pytest.fixture(scope="module")
def router(self):
+306
View File
@@ -0,0 +1,306 @@
import pytest
from django.urls import reverse
from rest_framework import status
from unittest.mock import patch, ANY, Mock
@pytest.mark.django_db
class TestUserViewSet:
def test_list_users_with_all_permissions(self, authenticated_client_rbac):
response = authenticated_client_rbac.get(reverse("user-list"))
assert response.status_code == status.HTTP_200_OK
assert isinstance(response.json()["data"], list)
def test_list_users_with_no_permissions(
self, authenticated_client_no_permissions_rbac
):
response = authenticated_client_no_permissions_rbac.get(reverse("user-list"))
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_retrieve_user_with_all_permissions(
self, authenticated_client_rbac, create_test_user_rbac
):
response = authenticated_client_rbac.get(
reverse("user-detail", kwargs={"pk": create_test_user_rbac.id})
)
assert response.status_code == status.HTTP_200_OK
assert (
response.json()["data"]["attributes"]["email"]
== create_test_user_rbac.email
)
def test_retrieve_user_with_no_roles(
self, authenticated_client_rbac_noroles, create_test_user_rbac_no_roles
):
response = authenticated_client_rbac_noroles.get(
reverse("user-detail", kwargs={"pk": create_test_user_rbac_no_roles.id})
)
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_retrieve_user_with_no_permissions(
self, authenticated_client_no_permissions_rbac, create_test_user
):
response = authenticated_client_no_permissions_rbac.get(
reverse("user-detail", kwargs={"pk": create_test_user.id})
)
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_create_user_with_all_permissions(self, authenticated_client_rbac):
valid_user_payload = {
"name": "test",
"password": "newpassword123",
"email": "new_user@test.com",
}
response = authenticated_client_rbac.post(
reverse("user-list"), data=valid_user_payload, format="vnd.api+json"
)
assert response.status_code == status.HTTP_201_CREATED
assert response.json()["data"]["attributes"]["email"] == "new_user@test.com"
def test_create_user_with_no_permissions(
self, authenticated_client_no_permissions_rbac
):
valid_user_payload = {
"name": "test",
"password": "newpassword123",
"email": "new_user@test.com",
}
response = authenticated_client_no_permissions_rbac.post(
reverse("user-list"), data=valid_user_payload, format="vnd.api+json"
)
assert response.status_code == status.HTTP_201_CREATED
assert response.json()["data"]["attributes"]["email"] == "new_user@test.com"
def test_partial_update_user_with_all_permissions(
self, authenticated_client_rbac, create_test_user_rbac
):
updated_data = {
"data": {
"type": "users",
"id": str(create_test_user_rbac.id),
"attributes": {"name": "Updated Name"},
},
}
response = authenticated_client_rbac.patch(
reverse("user-detail", kwargs={"pk": create_test_user_rbac.id}),
data=updated_data,
content_type="application/vnd.api+json",
)
assert response.status_code == status.HTTP_200_OK
assert response.json()["data"]["attributes"]["name"] == "Updated Name"
def test_partial_update_user_with_no_permissions(
self, authenticated_client_no_permissions_rbac, create_test_user
):
updated_data = {
"data": {
"type": "users",
"attributes": {"name": "Updated Name"},
}
}
response = authenticated_client_no_permissions_rbac.patch(
reverse("user-detail", kwargs={"pk": create_test_user.id}),
data=updated_data,
format="vnd.api+json",
)
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_delete_user_with_all_permissions(
self, authenticated_client_rbac, create_test_user_rbac
):
response = authenticated_client_rbac.delete(
reverse("user-detail", kwargs={"pk": create_test_user_rbac.id})
)
assert response.status_code == status.HTTP_204_NO_CONTENT
def test_delete_user_with_no_permissions(
self, authenticated_client_no_permissions_rbac, create_test_user
):
response = authenticated_client_no_permissions_rbac.delete(
reverse("user-detail", kwargs={"pk": create_test_user.id})
)
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_me_with_all_permissions(
self, authenticated_client_rbac, create_test_user_rbac
):
response = authenticated_client_rbac.get(reverse("user-me"))
assert response.status_code == status.HTTP_200_OK
assert (
response.json()["data"]["attributes"]["email"]
== create_test_user_rbac.email
)
def test_me_with_no_permissions(
self, authenticated_client_no_permissions_rbac, create_test_user
):
response = authenticated_client_no_permissions_rbac.get(reverse("user-me"))
assert response.status_code == status.HTTP_200_OK
assert response.json()["data"]["attributes"]["email"] == "rbac_limited@rbac.com"
@pytest.mark.django_db
class TestProviderViewSet:
def test_list_providers_with_all_permissions(
self, authenticated_client_rbac, providers_fixture
):
response = authenticated_client_rbac.get(reverse("provider-list"))
assert response.status_code == status.HTTP_200_OK
assert len(response.json()["data"]) == len(providers_fixture)
def test_list_providers_with_no_permissions(
self, authenticated_client_no_permissions_rbac
):
response = authenticated_client_no_permissions_rbac.get(
reverse("provider-list")
)
assert response.status_code == status.HTTP_200_OK
assert len(response.json()["data"]) == 0
def test_retrieve_provider_with_all_permissions(
self, authenticated_client_rbac, providers_fixture
):
provider = providers_fixture[0]
response = authenticated_client_rbac.get(
reverse("provider-detail", kwargs={"pk": provider.id})
)
assert response.status_code == status.HTTP_200_OK
assert response.json()["data"]["attributes"]["alias"] == provider.alias
def test_retrieve_provider_with_no_permissions(
self, authenticated_client_no_permissions_rbac, providers_fixture
):
provider = providers_fixture[0]
response = authenticated_client_no_permissions_rbac.get(
reverse("provider-detail", kwargs={"pk": provider.id})
)
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_create_provider_with_all_permissions(self, authenticated_client_rbac):
payload = {"provider": "aws", "uid": "111111111111", "alias": "new_alias"}
response = authenticated_client_rbac.post(
reverse("provider-list"), data=payload, format="json"
)
assert response.status_code == status.HTTP_201_CREATED
assert response.json()["data"]["attributes"]["alias"] == "new_alias"
def test_create_provider_with_no_permissions(
self, authenticated_client_no_permissions_rbac
):
payload = {"provider": "aws", "uid": "111111111111", "alias": "new_alias"}
response = authenticated_client_no_permissions_rbac.post(
reverse("provider-list"), data=payload, format="json"
)
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_partial_update_provider_with_all_permissions(
self, authenticated_client_rbac, providers_fixture
):
provider = providers_fixture[0]
payload = {
"data": {
"type": "providers",
"id": provider.id,
"attributes": {"alias": "updated_alias"},
},
}
response = authenticated_client_rbac.patch(
reverse("provider-detail", kwargs={"pk": provider.id}),
data=payload,
content_type="application/vnd.api+json",
)
assert response.status_code == status.HTTP_200_OK
assert response.json()["data"]["attributes"]["alias"] == "updated_alias"
def test_partial_update_provider_with_no_permissions(
self, authenticated_client_no_permissions_rbac, providers_fixture
):
provider = providers_fixture[0]
update_payload = {
"data": {
"type": "providers",
"attributes": {"alias": "updated_alias"},
}
}
response = authenticated_client_no_permissions_rbac.patch(
reverse("provider-detail", kwargs={"pk": provider.id}),
data=update_payload,
format="vnd.api+json",
)
assert response.status_code == status.HTTP_403_FORBIDDEN
@patch("api.v1.views.Task.objects.get")
@patch("api.v1.views.delete_provider_task.delay")
def test_delete_provider_with_all_permissions(
self,
mock_delete_task,
mock_task_get,
authenticated_client_rbac,
providers_fixture,
tasks_fixture,
):
prowler_task = tasks_fixture[0]
task_mock = Mock()
task_mock.id = prowler_task.id
mock_delete_task.return_value = task_mock
mock_task_get.return_value = prowler_task
provider1, *_ = providers_fixture
response = authenticated_client_rbac.delete(
reverse("provider-detail", kwargs={"pk": provider1.id})
)
assert response.status_code == status.HTTP_202_ACCEPTED
mock_delete_task.assert_called_once_with(
provider_id=str(provider1.id), tenant_id=ANY
)
assert "Content-Location" in response.headers
assert response.headers["Content-Location"] == f"/api/v1/tasks/{task_mock.id}"
def test_delete_provider_with_no_permissions(
self, authenticated_client_no_permissions_rbac, providers_fixture
):
provider = providers_fixture[0]
response = authenticated_client_no_permissions_rbac.delete(
reverse("provider-detail", kwargs={"pk": provider.id})
)
assert response.status_code == status.HTTP_403_FORBIDDEN
@patch("api.v1.views.Task.objects.get")
@patch("api.v1.views.check_provider_connection_task.delay")
def test_connection_with_all_permissions(
self,
mock_provider_connection,
mock_task_get,
authenticated_client_rbac,
providers_fixture,
tasks_fixture,
):
prowler_task = tasks_fixture[0]
task_mock = Mock()
task_mock.id = prowler_task.id
task_mock.status = "PENDING"
mock_provider_connection.return_value = task_mock
mock_task_get.return_value = prowler_task
provider1, *_ = providers_fixture
assert provider1.connected is None
assert provider1.connection_last_checked_at is None
response = authenticated_client_rbac.post(
reverse("provider-connection", kwargs={"pk": provider1.id})
)
assert response.status_code == status.HTTP_202_ACCEPTED
mock_provider_connection.assert_called_once_with(
provider_id=str(provider1.id), tenant_id=ANY
)
assert "Content-Location" in response.headers
assert response.headers["Content-Location"] == f"/api/v1/tasks/{task_mock.id}"
def test_connection_with_no_permissions(
self, authenticated_client_no_permissions_rbac, providers_fixture
):
provider = providers_fixture[0]
response = authenticated_client_no_permissions_rbac.post(
reverse("provider-connection", kwargs={"pk": provider.id})
)
assert response.status_code == status.HTTP_403_FORBIDDEN
+643 -113
View File
@@ -9,11 +9,14 @@ from django.urls import reverse
from rest_framework import status
from api.models import (
Invitation,
Membership,
Provider,
ProviderGroup,
ProviderGroupMembership,
Role,
RoleProviderGroupRelationship,
Invitation,
UserRoleRelationship,
ProviderSecret,
Scan,
StateChoices,
@@ -50,7 +53,6 @@ class TestUserViewSet:
assert response.status_code == status.HTTP_200_OK
assert response.json()["data"]["attributes"]["email"] == create_test_user.email
@patch("api.db_router.MainRouter.admin_db", new="default")
def test_users_create(self, client):
valid_user_payload = {
"name": "test",
@@ -67,7 +69,6 @@ class TestUserViewSet:
== valid_user_payload["email"].lower()
)
@patch("api.db_router.MainRouter.admin_db", new="default")
def test_users_create_duplicated_email(self, client):
# Create a user
self.test_users_create(client)
@@ -122,7 +123,6 @@ class TestUserViewSet:
"NonExistentEmail@prowler.com",
],
)
@patch("api.db_router.MainRouter.admin_db", new="default")
def test_users_create_used_email(self, authenticated_client, email):
# First user created; no errors should occur
user_payload = {
@@ -418,7 +418,6 @@ class TestTenantViewSet:
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
@patch("api.db_router.MainRouter.admin_db", new="default")
@patch("api.v1.views.delete_tenant_task.apply_async")
def test_tenants_delete(
self, delete_tenant_mock, authenticated_client, tenants_fixture
@@ -815,7 +814,7 @@ class TestProviderViewSet:
@pytest.mark.parametrize(
"include_values, expected_resources",
[
("provider_groups", ["provider-groups"]),
("provider_groups", ["provider-group"]),
],
)
def test_providers_list_include(
@@ -1200,7 +1199,7 @@ class TestProviderGroupViewSet:
def test_provider_group_create(self, authenticated_client):
data = {
"data": {
"type": "provider-groups",
"type": "provider-group",
"attributes": {
"name": "Test Provider Group",
},
@@ -1219,7 +1218,7 @@ class TestProviderGroupViewSet:
def test_provider_group_create_invalid(self, authenticated_client):
data = {
"data": {
"type": "provider-groups",
"type": "provider-group",
"attributes": {
# Name is missing
},
@@ -1241,7 +1240,7 @@ class TestProviderGroupViewSet:
data = {
"data": {
"id": str(provider_group.id),
"type": "provider-groups",
"type": "provider-group",
"attributes": {
"name": "Updated Provider Group Name",
},
@@ -1263,7 +1262,7 @@ class TestProviderGroupViewSet:
data = {
"data": {
"id": str(provider_group.id),
"type": "provider-groups",
"type": "provider-group",
"attributes": {
"name": "", # Invalid name
},
@@ -1294,100 +1293,6 @@ class TestProviderGroupViewSet:
)
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_provider_group_providers_update(
self, authenticated_client, provider_groups_fixture, providers_fixture
):
provider_group = provider_groups_fixture[0]
provider_ids = [str(provider.id) for provider in providers_fixture]
data = {
"data": {
"type": "provider-group-memberships",
"id": str(provider_group.id),
"attributes": {"provider_ids": provider_ids},
}
}
response = authenticated_client.put(
reverse("providergroup-providers", kwargs={"pk": provider_group.id}),
data=json.dumps(data),
content_type="application/vnd.api+json",
)
assert response.status_code == status.HTTP_200_OK
memberships = ProviderGroupMembership.objects.filter(
provider_group=provider_group
)
assert memberships.count() == len(provider_ids)
for membership in memberships:
assert str(membership.provider_id) in provider_ids
def test_provider_group_providers_update_non_existent_provider(
self, authenticated_client, provider_groups_fixture, providers_fixture
):
provider_group = provider_groups_fixture[0]
provider_ids = [str(provider.id) for provider in providers_fixture]
provider_ids[-1] = "1b59e032-3eb6-4694-93a5-df84cd9b3ce2"
data = {
"data": {
"type": "provider-group-memberships",
"id": str(provider_group.id),
"attributes": {"provider_ids": provider_ids},
}
}
response = authenticated_client.put(
reverse("providergroup-providers", kwargs={"pk": provider_group.id}),
data=json.dumps(data),
content_type="application/vnd.api+json",
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
errors = response.json()["errors"]
assert (
errors[0]["detail"]
== f"The following provider IDs do not exist: {provider_ids[-1]}"
)
def test_provider_group_providers_update_invalid_provider(
self, authenticated_client, provider_groups_fixture
):
provider_group = provider_groups_fixture[1]
invalid_provider_id = "non-existent-id"
data = {
"data": {
"type": "provider-group-memberships",
"id": str(provider_group.id),
"attributes": {"provider_ids": [invalid_provider_id]},
}
}
response = authenticated_client.put(
reverse("providergroup-providers", kwargs={"pk": provider_group.id}),
data=json.dumps(data),
content_type="application/vnd.api+json",
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
errors = response.json()["errors"]
assert errors[0]["detail"] == "Must be a valid UUID."
def test_provider_group_providers_update_invalid_payload(
self, authenticated_client, provider_groups_fixture
):
provider_group = provider_groups_fixture[2]
data = {
# Missing "provider_ids"
}
response = authenticated_client.put(
reverse("providergroup-providers", kwargs={"pk": provider_group.id}),
data=json.dumps(data),
content_type="application/vnd.api+json",
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
errors = response.json()["errors"]
assert errors[0]["detail"] == "Received document does not contain primary data"
def test_provider_group_retrieve_not_found(self, authenticated_client):
response = authenticated_client.get(
reverse("providergroup-detail", kwargs={"pk": "non-existent-id"})
@@ -2652,7 +2557,9 @@ class TestInvitationViewSet:
)
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_invitations_create_valid(self, authenticated_client, create_test_user):
def test_invitations_create_valid(
self, authenticated_client, create_test_user, roles_fixture
):
user = create_test_user
data = {
"data": {
@@ -2661,6 +2568,11 @@ class TestInvitationViewSet:
"email": "any_email@prowler.com",
"expires_at": self.TOMORROW_ISO,
},
"relationships": {
"roles": {
"data": [{"type": "role", "id": str(roles_fixture[0].id)}]
}
},
}
}
response = authenticated_client.post(
@@ -2719,6 +2631,11 @@ class TestInvitationViewSet:
response.json()["errors"][0]["source"]["pointer"]
== "/data/attributes/email"
)
assert response.json()["errors"][1]["code"] == "required"
assert (
response.json()["errors"][1]["source"]["pointer"]
== "/data/relationships/roles"
)
def test_invitations_create_invalid_expires_at(
self, authenticated_client, invitations_fixture
@@ -2745,6 +2662,11 @@ class TestInvitationViewSet:
response.json()["errors"][0]["source"]["pointer"]
== "/data/attributes/expires_at"
)
assert response.json()["errors"][1]["code"] == "required"
assert (
response.json()["errors"][1]["source"]["pointer"]
== "/data/relationships/roles"
)
def test_invitations_partial_update_valid(
self, authenticated_client, invitations_fixture
@@ -2932,7 +2854,6 @@ class TestInvitationViewSet:
== "This invitation cannot be revoked."
)
@patch("api.db_router.MainRouter.admin_db", new="default")
def test_invitations_accept_invitation_new_user(self, client, invitations_fixture):
invitation, *_ = invitations_fixture
@@ -2958,7 +2879,6 @@ class TestInvitationViewSet:
user__email__iexact=invitation.email, tenant=invitation.tenant
).exists()
@patch("api.db_router.MainRouter.admin_db", new="default")
def test_invitations_accept_invitation_existing_user(
self, authenticated_client, create_test_user, tenants_fixture
):
@@ -2983,7 +2903,6 @@ class TestInvitationViewSet:
response = authenticated_client.post(
reverse("invitation-accept"), data=data, format="json"
)
assert response.status_code == status.HTTP_201_CREATED
invitation.refresh_from_db()
assert Membership.objects.filter(
@@ -2991,7 +2910,6 @@ class TestInvitationViewSet:
).exists()
assert invitation.state == Invitation.State.ACCEPTED.value
@patch("api.db_router.MainRouter.admin_db", new="default")
def test_invitations_accept_invitation_invalid_token(self, authenticated_client):
data = {
"invitation_token": "invalid_token",
@@ -3004,7 +2922,6 @@ class TestInvitationViewSet:
assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.json()["errors"][0]["code"] == "not_found"
@patch("api.db_router.MainRouter.admin_db", new="default")
def test_invitations_accept_invitation_invalid_token_expired(
self, authenticated_client, invitations_fixture
):
@@ -3023,7 +2940,6 @@ class TestInvitationViewSet:
assert response.status_code == status.HTTP_410_GONE
@patch("api.db_router.MainRouter.admin_db", new="default")
def test_invitations_accept_invitation_invalid_token_expired_new_user(
self, client, invitations_fixture
):
@@ -3047,7 +2963,6 @@ class TestInvitationViewSet:
assert response.status_code == status.HTTP_410_GONE
@patch("api.db_router.MainRouter.admin_db", new="default")
def test_invitations_accept_invitation_invalid_token_accepted(
self, authenticated_client, invitations_fixture
):
@@ -3071,7 +2986,6 @@ class TestInvitationViewSet:
== "This invitation is no longer valid."
)
@patch("api.db_router.MainRouter.admin_db", new="default")
def test_invitations_accept_invitation_invalid_token_revoked(
self, authenticated_client, invitations_fixture
):
@@ -3166,6 +3080,622 @@ class TestInvitationViewSet:
assert response.status_code == status.HTTP_400_BAD_REQUEST
@pytest.mark.django_db
class TestRoleViewSet:
def test_role_list(self, authenticated_client, roles_fixture):
response = authenticated_client.get(reverse("role-list"))
assert response.status_code == status.HTTP_200_OK
assert (
len(response.json()["data"]) == len(roles_fixture) + 2
) # 2 default admin roles, one for each tenant
def test_role_retrieve(self, authenticated_client, roles_fixture):
role = roles_fixture[0]
response = authenticated_client.get(
reverse("role-detail", kwargs={"pk": role.id})
)
assert response.status_code == status.HTTP_200_OK
data = response.json()["data"]
assert data["id"] == str(role.id)
assert data["attributes"]["name"] == role.name
@pytest.mark.parametrize(
("permission_state", "index"),
[("limited", 0), ("unlimited", 2), ("none", 3)],
)
def test_role_retrieve_permission_state(
self, authenticated_client, roles_fixture, permission_state, index
):
role = roles_fixture[index]
response = authenticated_client.get(
reverse("role-detail", kwargs={"pk": role.id}),
{"filter[permission_state]": permission_state},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()["data"]
assert data["id"] == str(role.id)
assert data["attributes"]["name"] == role.name
assert data["attributes"]["permission_state"] == permission_state
def test_role_create(self, authenticated_client):
data = {
"data": {
"type": "role",
"attributes": {
"name": "Test Role",
"manage_users": "false",
"manage_account": "false",
"manage_billing": "false",
"manage_providers": "true",
"manage_integrations": "true",
"manage_scans": "true",
"unlimited_visibility": "true",
},
"relationships": {"provider_groups": {"data": []}},
}
}
response = authenticated_client.post(
reverse("role-list"),
data=json.dumps(data),
content_type="application/vnd.api+json",
)
assert response.status_code == status.HTTP_201_CREATED
response_data = response.json()["data"]
assert response_data["attributes"]["name"] == "Test Role"
assert Role.objects.filter(name="Test Role").exists()
def test_role_provider_groups_create(
self, authenticated_client, provider_groups_fixture
):
data = {
"data": {
"type": "role",
"attributes": {
"name": "Test Role",
"manage_users": "false",
"manage_account": "false",
"manage_billing": "false",
"manage_providers": "true",
"manage_integrations": "true",
"manage_scans": "true",
"unlimited_visibility": "true",
},
"relationships": {
"provider_groups": {
"data": [
{"type": "provider-group", "id": str(provider_group.id)}
for provider_group in provider_groups_fixture[:2]
]
}
},
}
}
response = authenticated_client.post(
reverse("role-list"),
data=json.dumps(data),
content_type="application/vnd.api+json",
)
assert response.status_code == status.HTTP_201_CREATED
response_data = response.json()["data"]
assert response_data["attributes"]["name"] == "Test Role"
assert Role.objects.filter(name="Test Role").exists()
relationships = (
Role.objects.filter(name="Test Role").first().provider_groups.all()
)
assert relationships.count() == 2
for relationship in relationships:
assert relationship.id in [pg.id for pg in provider_groups_fixture[:2]]
def test_role_create_invalid(self, authenticated_client):
data = {
"data": {
"type": "role",
"attributes": {
# Name is missing
},
}
}
response = authenticated_client.post(
reverse("role-list"),
data=json.dumps(data),
content_type="application/vnd.api+json",
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
errors = response.json()["errors"]
assert errors[0]["source"]["pointer"] == "/data/attributes/name"
def test_role_partial_update(self, authenticated_client, roles_fixture):
role = roles_fixture[1]
data = {
"data": {
"id": str(role.id),
"type": "role",
"attributes": {
"name": "Updated Provider Group Name",
},
}
}
response = authenticated_client.patch(
reverse("role-detail", kwargs={"pk": role.id}),
data=json.dumps(data),
content_type="application/vnd.api+json",
)
assert response.status_code == status.HTTP_200_OK
role.refresh_from_db()
assert role.name == "Updated Provider Group Name"
def test_role_partial_update_invalid(self, authenticated_client, roles_fixture):
role = roles_fixture[2]
data = {
"data": {
"id": str(role.id),
"type": "role",
"attributes": {
"name": "", # Invalid name
},
}
}
response = authenticated_client.patch(
reverse("role-detail", kwargs={"pk": role.id}),
data=json.dumps(data),
content_type="application/vnd.api+json",
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
errors = response.json()["errors"]
assert errors[0]["source"]["pointer"] == "/data/attributes/name"
def test_role_destroy(self, authenticated_client, roles_fixture):
role = roles_fixture[2]
response = authenticated_client.delete(
reverse("role-detail", kwargs={"pk": role.id})
)
assert response.status_code == status.HTTP_204_NO_CONTENT
assert not Role.objects.filter(id=role.id).exists()
def test_role_destroy_invalid(self, authenticated_client):
response = authenticated_client.delete(
reverse("role-detail", kwargs={"pk": "non-existent-id"})
)
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_role_retrieve_not_found(self, authenticated_client):
response = authenticated_client.get(
reverse("role-detail", kwargs={"pk": "non-existent-id"})
)
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_role_list_filters(self, authenticated_client, roles_fixture):
role = roles_fixture[0]
response = authenticated_client.get(
reverse("role-list"), {"filter[name]": role.name}
)
assert response.status_code == status.HTTP_200_OK
data = response.json()["data"]
assert len(data) == 1
assert data[0]["attributes"]["name"] == role.name
def test_role_list_sorting(
self, authenticated_client, set_user_admin_roles_fixture, roles_fixture
):
response = authenticated_client.get(reverse("role-list"), {"sort": "name"})
assert response.status_code == status.HTTP_200_OK
data = response.json()["data"]
names = [
item["attributes"]["name"]
for item in data
if item["attributes"]["name"] != "admin"
]
assert names == sorted(names, key=lambda v: v.lower())
def test_role_invalid_method(self, authenticated_client):
response = authenticated_client.put(reverse("role-list"))
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
@pytest.mark.django_db
class TestUserRoleRelationshipViewSet:
def test_create_relationship(
self, authenticated_client, roles_fixture, create_test_user
):
data = {
"data": [{"type": "role", "id": str(role.id)} for role in roles_fixture[:2]]
}
response = authenticated_client.post(
reverse("user-roles-relationship", kwargs={"pk": create_test_user.id}),
data=data,
content_type="application/vnd.api+json",
)
assert response.status_code == status.HTTP_204_NO_CONTENT
relationships = UserRoleRelationship.objects.filter(user=create_test_user.id)
assert relationships.count() == 4
for relationship in relationships[2:]: # Skip admin role
assert relationship.role.id in [r.id for r in roles_fixture[:2]]
def test_create_relationship_already_exists(
self, authenticated_client, roles_fixture, create_test_user
):
data = {
"data": [{"type": "role", "id": str(role.id)} for role in roles_fixture[:2]]
}
authenticated_client.post(
reverse("user-roles-relationship", kwargs={"pk": create_test_user.id}),
data=data,
content_type="application/vnd.api+json",
)
data = {
"data": [
{"type": "role", "id": str(roles_fixture[0].id)},
]
}
response = authenticated_client.post(
reverse("user-roles-relationship", kwargs={"pk": create_test_user.id}),
data=data,
content_type="application/vnd.api+json",
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
errors = response.json()["errors"]["detail"]
assert "already associated" in errors
def test_partial_update_relationship(
self, authenticated_client, roles_fixture, create_test_user
):
data = {
"data": [
{"type": "role", "id": str(roles_fixture[1].id)},
]
}
response = authenticated_client.patch(
reverse("user-roles-relationship", kwargs={"pk": create_test_user.id}),
data=data,
content_type="application/vnd.api+json",
)
assert response.status_code == status.HTTP_204_NO_CONTENT
relationships = UserRoleRelationship.objects.filter(user=create_test_user.id)
assert relationships.count() == 1
assert {rel.role.id for rel in relationships} == {roles_fixture[1].id}
data = {
"data": [
{"type": "role", "id": str(roles_fixture[1].id)},
{"type": "role", "id": str(roles_fixture[2].id)},
]
}
response = authenticated_client.patch(
reverse("user-roles-relationship", kwargs={"pk": create_test_user.id}),
data=data,
content_type="application/vnd.api+json",
)
assert response.status_code == status.HTTP_204_NO_CONTENT
relationships = UserRoleRelationship.objects.filter(user=create_test_user.id)
assert relationships.count() == 2
assert {rel.role.id for rel in relationships} == {
roles_fixture[1].id,
roles_fixture[2].id,
}
def test_destroy_relationship(
self, authenticated_client, roles_fixture, create_test_user
):
response = authenticated_client.delete(
reverse("user-roles-relationship", kwargs={"pk": create_test_user.id}),
)
assert response.status_code == status.HTTP_204_NO_CONTENT
relationships = UserRoleRelationship.objects.filter(role=roles_fixture[0].id)
assert relationships.count() == 0
def test_invalid_provider_group_id(self, authenticated_client, create_test_user):
invalid_id = "non-existent-id"
data = {"data": [{"type": "provider-group", "id": invalid_id}]}
response = authenticated_client.post(
reverse("user-roles-relationship", kwargs={"pk": create_test_user.id}),
data=data,
content_type="application/vnd.api+json",
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
errors = response.json()["errors"][0]["detail"]
assert "valid UUID" in errors
@pytest.mark.django_db
class TestRoleProviderGroupRelationshipViewSet:
def test_create_relationship(
self, authenticated_client, roles_fixture, provider_groups_fixture
):
data = {
"data": [
{"type": "provider-group", "id": str(provider_group.id)}
for provider_group in provider_groups_fixture[:2]
]
}
response = authenticated_client.post(
reverse(
"role-provider-groups-relationship", kwargs={"pk": roles_fixture[0].id}
),
data=data,
content_type="application/vnd.api+json",
)
assert response.status_code == status.HTTP_204_NO_CONTENT
relationships = RoleProviderGroupRelationship.objects.filter(
role=roles_fixture[0].id
)
assert relationships.count() == 2
for relationship in relationships:
assert relationship.provider_group.id in [
pg.id for pg in provider_groups_fixture[:2]
]
def test_create_relationship_already_exists(
self, authenticated_client, roles_fixture, provider_groups_fixture
):
data = {
"data": [
{"type": "provider-group", "id": str(provider_group.id)}
for provider_group in provider_groups_fixture[:2]
]
}
authenticated_client.post(
reverse(
"role-provider-groups-relationship", kwargs={"pk": roles_fixture[0].id}
),
data=data,
content_type="application/vnd.api+json",
)
data = {
"data": [
{"type": "provider-group", "id": str(provider_groups_fixture[0].id)},
]
}
response = authenticated_client.post(
reverse(
"role-provider-groups-relationship", kwargs={"pk": roles_fixture[0].id}
),
data=data,
content_type="application/vnd.api+json",
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
errors = response.json()["errors"]["detail"]
assert "already associated" in errors
def test_partial_update_relationship(
self, authenticated_client, roles_fixture, provider_groups_fixture
):
data = {
"data": [
{"type": "provider-group", "id": str(provider_groups_fixture[1].id)},
]
}
response = authenticated_client.patch(
reverse(
"role-provider-groups-relationship", kwargs={"pk": roles_fixture[2].id}
),
data=data,
content_type="application/vnd.api+json",
)
assert response.status_code == status.HTTP_204_NO_CONTENT
relationships = RoleProviderGroupRelationship.objects.filter(
role=roles_fixture[2].id
)
assert relationships.count() == 1
assert {rel.provider_group.id for rel in relationships} == {
provider_groups_fixture[1].id
}
data = {
"data": [
{"type": "provider-group", "id": str(provider_groups_fixture[1].id)},
{"type": "provider-group", "id": str(provider_groups_fixture[2].id)},
]
}
response = authenticated_client.patch(
reverse(
"role-provider-groups-relationship", kwargs={"pk": roles_fixture[2].id}
),
data=data,
content_type="application/vnd.api+json",
)
assert response.status_code == status.HTTP_204_NO_CONTENT
relationships = RoleProviderGroupRelationship.objects.filter(
role=roles_fixture[2].id
)
assert relationships.count() == 2
assert {rel.provider_group.id for rel in relationships} == {
provider_groups_fixture[1].id,
provider_groups_fixture[2].id,
}
def test_destroy_relationship(
self, authenticated_client, roles_fixture, provider_groups_fixture
):
response = authenticated_client.delete(
reverse(
"role-provider-groups-relationship", kwargs={"pk": roles_fixture[0].id}
),
)
assert response.status_code == status.HTTP_204_NO_CONTENT
relationships = RoleProviderGroupRelationship.objects.filter(
role=roles_fixture[0].id
)
assert relationships.count() == 0
def test_invalid_provider_group_id(self, authenticated_client, roles_fixture):
invalid_id = "non-existent-id"
data = {"data": [{"type": "provider-group", "id": invalid_id}]}
response = authenticated_client.post(
reverse(
"role-provider-groups-relationship", kwargs={"pk": roles_fixture[1].id}
),
data=data,
content_type="application/vnd.api+json",
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
errors = response.json()["errors"][0]["detail"]
assert "valid UUID" in errors
@pytest.mark.django_db
class TestProviderGroupMembershipViewSet:
def test_create_relationship(
self, authenticated_client, providers_fixture, provider_groups_fixture
):
provider_group, *_ = provider_groups_fixture
data = {
"data": [
{"type": "provider", "id": str(provider.id)}
for provider in providers_fixture[:2]
]
}
response = authenticated_client.post(
reverse(
"provider_group-providers-relationship",
kwargs={"pk": provider_group.id},
),
data=data,
content_type="application/vnd.api+json",
)
assert response.status_code == status.HTTP_204_NO_CONTENT
relationships = ProviderGroupMembership.objects.filter(
provider_group=provider_group.id
)
assert relationships.count() == 2
for relationship in relationships:
assert relationship.provider.id in [p.id for p in providers_fixture[:2]]
def test_create_relationship_already_exists(
self, authenticated_client, providers_fixture, provider_groups_fixture
):
provider_group, *_ = provider_groups_fixture
data = {
"data": [
{"type": "provider", "id": str(provider.id)}
for provider in providers_fixture[:2]
]
}
authenticated_client.post(
reverse(
"provider_group-providers-relationship",
kwargs={"pk": provider_group.id},
),
data=data,
content_type="application/vnd.api+json",
)
data = {
"data": [
{"type": "provider", "id": str(providers_fixture[0].id)},
]
}
response = authenticated_client.post(
reverse(
"provider_group-providers-relationship",
kwargs={"pk": provider_group.id},
),
data=data,
content_type="application/vnd.api+json",
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
errors = response.json()["errors"]["detail"]
assert "already associated" in errors
def test_partial_update_relationship(
self, authenticated_client, providers_fixture, provider_groups_fixture
):
provider_group, *_ = provider_groups_fixture
data = {
"data": [
{"type": "provider", "id": str(providers_fixture[1].id)},
]
}
response = authenticated_client.patch(
reverse(
"provider_group-providers-relationship",
kwargs={"pk": provider_group.id},
),
data=data,
content_type="application/vnd.api+json",
)
assert response.status_code == status.HTTP_204_NO_CONTENT
relationships = ProviderGroupMembership.objects.filter(
provider_group=provider_group.id
)
assert relationships.count() == 1
assert {rel.provider.id for rel in relationships} == {providers_fixture[1].id}
data = {
"data": [
{"type": "provider", "id": str(providers_fixture[1].id)},
{"type": "provider", "id": str(providers_fixture[2].id)},
]
}
response = authenticated_client.patch(
reverse(
"provider_group-providers-relationship",
kwargs={"pk": provider_group.id},
),
data=data,
content_type="application/vnd.api+json",
)
assert response.status_code == status.HTTP_204_NO_CONTENT
relationships = ProviderGroupMembership.objects.filter(
provider_group=provider_group.id
)
assert relationships.count() == 2
assert {rel.provider.id for rel in relationships} == {
providers_fixture[1].id,
providers_fixture[2].id,
}
def test_destroy_relationship(
self, authenticated_client, providers_fixture, provider_groups_fixture
):
provider_group, *_ = provider_groups_fixture
data = {
"data": [
{"type": "provider", "id": str(provider.id)}
for provider in providers_fixture[:2]
]
}
response = authenticated_client.post(
reverse(
"provider_group-providers-relationship",
kwargs={"pk": provider_group.id},
),
data=data,
content_type="application/vnd.api+json",
)
assert response.status_code == status.HTTP_204_NO_CONTENT
response = authenticated_client.delete(
reverse(
"provider_group-providers-relationship",
kwargs={"pk": provider_group.id},
),
)
assert response.status_code == status.HTTP_204_NO_CONTENT
relationships = ProviderGroupMembership.objects.filter(
provider_group=providers_fixture[0].id
)
assert relationships.count() == 0
def test_invalid_provider_group_id(
self, authenticated_client, provider_groups_fixture
):
provider_group, *_ = provider_groups_fixture
invalid_id = "non-existent-id"
data = {"data": [{"type": "provider-group", "id": invalid_id}]}
response = authenticated_client.post(
reverse(
"provider_group-providers-relationship",
kwargs={"pk": provider_group.id},
),
data=data,
content_type="application/vnd.api+json",
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
errors = response.json()["errors"][0]["detail"]
assert "valid UUID" in errors
@pytest.mark.django_db
class TestComplianceOverviewViewSet:
def test_compliance_overview_list_none(self, authenticated_client):
+367 -46
View File
@@ -14,16 +14,20 @@ from rest_framework_simplejwt.serializers import TokenObtainPairSerializer
from rest_framework_simplejwt.tokens import RefreshToken
from api.models import (
ComplianceOverview,
Finding,
Invitation,
Membership,
Provider,
ProviderGroup,
ProviderGroupMembership,
ProviderSecret,
Resource,
ResourceTag,
Finding,
ProviderSecret,
Invitation,
InvitationRoleRelationship,
Role,
RoleProviderGroupRelationship,
UserRoleRelationship,
ComplianceOverview,
Scan,
StateChoices,
Task,
@@ -176,10 +180,26 @@ class UserSerializer(BaseSerializerV1):
"""
memberships = serializers.ResourceRelatedField(many=True, read_only=True)
roles = serializers.ResourceRelatedField(many=True, read_only=True)
class Meta:
model = User
fields = ["id", "name", "email", "company_name", "date_joined", "memberships"]
fields = [
"id",
"name",
"email",
"company_name",
"date_joined",
"memberships",
"roles",
]
extra_kwargs = {
"roles": {"read_only": True},
}
included_serializers = {
"roles": "api.v1.serializers.RoleSerializer",
}
class UserCreateSerializer(BaseWriteSerializer):
@@ -235,6 +255,73 @@ class UserUpdateSerializer(BaseWriteSerializer):
return super().update(instance, validated_data)
class RoleResourceIdentifierSerializer(serializers.Serializer):
resource_type = serializers.CharField(source="type")
id = serializers.UUIDField()
class JSONAPIMeta:
resource_name = "role-identifier"
def to_representation(self, instance):
"""
Ensure 'type' is used in the output instead of 'resource_type'.
"""
representation = super().to_representation(instance)
representation["type"] = representation.pop("resource_type", None)
return representation
def to_internal_value(self, data):
"""
Map 'type' back to 'resource_type' during input.
"""
data["resource_type"] = data.pop("type", None)
return super().to_internal_value(data)
class UserRoleRelationshipSerializer(RLSSerializer, BaseWriteSerializer):
"""
Serializer for modifying user memberships
"""
roles = serializers.ListField(
child=RoleResourceIdentifierSerializer(),
help_text="List of resource identifier objects representing roles.",
)
def create(self, validated_data):
role_ids = [item["id"] for item in validated_data["roles"]]
roles = Role.objects.filter(id__in=role_ids)
tenant_id = self.context.get("tenant_id")
new_relationships = [
UserRoleRelationship(
user=self.context.get("user"), role=r, tenant_id=tenant_id
)
for r in roles
]
UserRoleRelationship.objects.bulk_create(new_relationships)
return self.context.get("user")
def update(self, instance, validated_data):
role_ids = [item["id"] for item in validated_data["roles"]]
roles = Role.objects.filter(id__in=role_ids)
tenant_id = self.context.get("tenant_id")
instance.roles.clear()
new_relationships = [
UserRoleRelationship(user=instance, role=r, tenant_id=tenant_id)
for r in roles
]
UserRoleRelationship.objects.bulk_create(new_relationships)
return instance
class Meta:
model = UserRoleRelationship
fields = ["id", "roles"]
# Tasks
class TaskBase(serializers.ModelSerializer):
state_mapping = {
@@ -361,31 +448,30 @@ class ProviderGroupSerializer(RLSSerializer, BaseWriteSerializer):
providers = serializers.ResourceRelatedField(many=True, read_only=True)
def validate(self, attrs):
tenant = self.context["tenant_id"]
name = attrs.get("name", self.instance.name if self.instance else None)
# Exclude the current instance when checking for uniqueness during updates
queryset = ProviderGroup.objects.filter(tenant=tenant, name=name)
if self.instance:
queryset = queryset.exclude(pk=self.instance.pk)
if queryset.exists():
if ProviderGroup.objects.filter(name=attrs.get("name")).exists():
raise serializers.ValidationError(
{
"name": "A provider group with this name already exists for this tenant."
}
{"name": "A provider group with this name already exists."}
)
return super().validate(attrs)
class Meta:
model = ProviderGroup
fields = ["id", "name", "inserted_at", "updated_at", "providers", "url"]
read_only_fields = ["id", "inserted_at", "updated_at"]
fields = [
"id",
"name",
"inserted_at",
"updated_at",
"providers",
"roles",
"url",
]
extra_kwargs = {
"id": {"read_only": True},
"inserted_at": {"read_only": True},
"updated_at": {"read_only": True},
"roles": {"read_only": True},
"url": {"read_only": True},
}
@@ -406,41 +492,75 @@ class ProviderGroupUpdateSerializer(RLSSerializer, BaseWriteSerializer):
fields = ["id", "name"]
class ProviderGroupMembershipUpdateSerializer(RLSSerializer, BaseWriteSerializer):
class ProviderResourceIdentifierSerializer(serializers.Serializer):
resource_type = serializers.CharField(source="type")
id = serializers.UUIDField()
class JSONAPIMeta:
resource_name = "provider-identifier"
def to_representation(self, instance):
"""
Ensure 'type' is used in the output instead of 'resource_type'.
"""
representation = super().to_representation(instance)
representation["type"] = representation.pop("resource_type", None)
return representation
def to_internal_value(self, data):
"""
Map 'type' back to 'resource_type' during input.
"""
data["resource_type"] = data.pop("type", None)
return super().to_internal_value(data)
class ProviderGroupMembershipSerializer(RLSSerializer, BaseWriteSerializer):
"""
Serializer for modifying provider group memberships
Serializer for modifying provider_group memberships
"""
provider_ids = serializers.ListField(
child=serializers.UUIDField(),
help_text="List of provider UUIDs to add to the group",
providers = serializers.ListField(
child=ProviderResourceIdentifierSerializer(),
help_text="List of resource identifier objects representing providers.",
)
def validate(self, attrs):
tenant_id = self.context["tenant_id"]
provider_ids = attrs.get("provider_ids", [])
def create(self, validated_data):
provider_ids = [item["id"] for item in validated_data["providers"]]
providers = Provider.objects.filter(id__in=provider_ids)
tenant_id = self.context.get("tenant_id")
existing_provider_ids = set(
Provider.objects.filter(
id__in=provider_ids, tenant_id=tenant_id
).values_list("id", flat=True)
)
provided_provider_ids = set(provider_ids)
missing_provider_ids = provided_provider_ids - existing_provider_ids
if missing_provider_ids:
raise serializers.ValidationError(
{
"provider_ids": f"The following provider IDs do not exist: {', '.join(str(id) for id in missing_provider_ids)}"
}
new_relationships = [
ProviderGroupMembership(
provider_group=self.context.get("provider_group"),
provider=p,
tenant_id=tenant_id,
)
for p in providers
]
ProviderGroupMembership.objects.bulk_create(new_relationships)
return super().validate(attrs)
return self.context.get("provider_group")
def update(self, instance, validated_data):
provider_ids = [item["id"] for item in validated_data["providers"]]
providers = Provider.objects.filter(id__in=provider_ids)
tenant_id = self.context.get("tenant_id")
instance.providers.clear()
new_relationships = [
ProviderGroupMembership(
provider_group=instance, provider=p, tenant_id=tenant_id
)
for p in providers
]
ProviderGroupMembership.objects.bulk_create(new_relationships)
return instance
class Meta:
model = ProviderGroupMembership
fields = ["id", "provider_ids"]
fields = ["id", "providers"]
# Providers
@@ -1034,6 +1154,8 @@ class InvitationSerializer(RLSSerializer):
Serializer for the Invitation model.
"""
roles = serializers.ResourceRelatedField(many=True, queryset=Role.objects.all())
class Meta:
model = Invitation
fields = [
@@ -1043,6 +1165,7 @@ class InvitationSerializer(RLSSerializer):
"email",
"state",
"token",
"roles",
"expires_at",
"inviter",
"url",
@@ -1050,6 +1173,8 @@ class InvitationSerializer(RLSSerializer):
class InvitationBaseWriteSerializer(BaseWriteSerializer):
roles = serializers.ResourceRelatedField(many=True, queryset=Role.objects.all())
def validate_email(self, value):
user = User.objects.filter(email=value).first()
tenant_id = self.context["tenant_id"]
@@ -1086,31 +1211,54 @@ class InvitationCreateSerializer(InvitationBaseWriteSerializer, RLSSerializer):
class Meta:
model = Invitation
fields = ["email", "expires_at", "state", "token", "inviter"]
fields = ["email", "expires_at", "state", "token", "inviter", "roles"]
extra_kwargs = {
"token": {"read_only": True},
"state": {"read_only": True},
"inviter": {"read_only": True},
"expires_at": {"required": False},
"roles": {"required": False},
}
def create(self, validated_data):
inviter = self.context.get("request").user
tenant_id = self.context.get("tenant_id")
validated_data["inviter"] = inviter
return super().create(validated_data)
roles = validated_data.pop("roles", [])
invitation = super().create(validated_data)
for role in roles:
InvitationRoleRelationship.objects.create(
role=role, invitation=invitation, tenant_id=tenant_id
)
return invitation
class InvitationUpdateSerializer(InvitationBaseWriteSerializer):
class Meta:
model = Invitation
fields = ["id", "email", "expires_at", "state", "token"]
fields = ["id", "email", "expires_at", "state", "token", "roles"]
extra_kwargs = {
"token": {"read_only": True},
"state": {"read_only": True},
"expires_at": {"required": False},
"email": {"required": False},
"roles": {"required": False},
}
def update(self, instance, validated_data):
roles = validated_data.pop("roles", [])
tenant_id = self.context.get("tenant_id")
invitation = super().update(instance, validated_data)
if roles:
instance.roles.clear()
for role in roles:
InvitationRoleRelationship.objects.create(
role=role, invitation=invitation, tenant_id=tenant_id
)
return invitation
class InvitationAcceptSerializer(RLSSerializer):
"""Serializer for accepting an invitation."""
@@ -1122,6 +1270,179 @@ class InvitationAcceptSerializer(RLSSerializer):
fields = ["invitation_token"]
# Roles
class RoleSerializer(RLSSerializer, BaseWriteSerializer):
provider_groups = serializers.ResourceRelatedField(
many=True, queryset=ProviderGroup.objects.all()
)
permission_state = serializers.SerializerMethodField()
def get_permission_state(self, obj):
return obj.permission_state
def validate(self, attrs):
if Role.objects.filter(name=attrs.get("name")).exists():
raise serializers.ValidationError(
{"name": "A role with this name already exists."}
)
if attrs.get("manage_providers"):
attrs["unlimited_visibility"] = True
# Prevent updates to the admin role
if getattr(self.instance, "name", None) == "admin":
raise serializers.ValidationError(
{"name": "The admin role cannot be updated."}
)
return super().validate(attrs)
class Meta:
model = Role
fields = [
"id",
"name",
"manage_users",
"manage_account",
"manage_billing",
"manage_providers",
"manage_integrations",
"manage_scans",
"permission_state",
"unlimited_visibility",
"inserted_at",
"updated_at",
"provider_groups",
"users",
"invitations",
"url",
]
extra_kwargs = {
"id": {"read_only": True},
"inserted_at": {"read_only": True},
"updated_at": {"read_only": True},
"users": {"read_only": True},
"url": {"read_only": True},
}
class RoleCreateSerializer(RoleSerializer):
def create(self, validated_data):
provider_groups = validated_data.pop("provider_groups", [])
users = validated_data.pop("users", [])
tenant_id = self.context.get("tenant_id")
role = Role.objects.create(tenant_id=tenant_id, **validated_data)
through_model_instances = [
RoleProviderGroupRelationship(
role=role,
provider_group=provider_group,
tenant_id=tenant_id,
)
for provider_group in provider_groups
]
RoleProviderGroupRelationship.objects.bulk_create(through_model_instances)
through_model_instances = [
UserRoleRelationship(
role=user,
user=user,
tenant_id=tenant_id,
)
for user in users
]
UserRoleRelationship.objects.bulk_create(through_model_instances)
return role
class RoleUpdateSerializer(RLSSerializer, BaseWriteSerializer):
class Meta:
model = Role
fields = [
"id",
"name",
"manage_users",
"manage_account",
"manage_billing",
"manage_providers",
"manage_integrations",
"manage_scans",
"unlimited_visibility",
]
class ProviderGroupResourceIdentifierSerializer(serializers.Serializer):
resource_type = serializers.CharField(source="type")
id = serializers.UUIDField()
class JSONAPIMeta:
resource_name = "provider-group-identifier"
def to_representation(self, instance):
"""
Ensure 'type' is used in the output instead of 'resource_type'.
"""
representation = super().to_representation(instance)
representation["type"] = representation.pop("resource_type", None)
return representation
def to_internal_value(self, data):
"""
Map 'type' back to 'resource_type' during input.
"""
data["resource_type"] = data.pop("type", None)
return super().to_internal_value(data)
class RoleProviderGroupRelationshipSerializer(RLSSerializer, BaseWriteSerializer):
"""
Serializer for modifying role memberships
"""
provider_groups = serializers.ListField(
child=ProviderGroupResourceIdentifierSerializer(),
help_text="List of resource identifier objects representing provider groups.",
)
def create(self, validated_data):
provider_group_ids = [item["id"] for item in validated_data["provider_groups"]]
provider_groups = ProviderGroup.objects.filter(id__in=provider_group_ids)
tenant_id = self.context.get("tenant_id")
new_relationships = [
RoleProviderGroupRelationship(
role=self.context.get("role"), provider_group=pg, tenant_id=tenant_id
)
for pg in provider_groups
]
RoleProviderGroupRelationship.objects.bulk_create(new_relationships)
return self.context.get("role")
def update(self, instance, validated_data):
provider_group_ids = [item["id"] for item in validated_data["provider_groups"]]
provider_groups = ProviderGroup.objects.filter(id__in=provider_group_ids)
tenant_id = self.context.get("tenant_id")
instance.provider_groups.clear()
new_relationships = [
RoleProviderGroupRelationship(
role=instance, provider_group=pg, tenant_id=tenant_id
)
for pg in provider_groups
]
RoleProviderGroupRelationship.objects.bulk_create(new_relationships)
return instance
class Meta:
model = RoleProviderGroupRelationship
fields = ["id", "provider_groups"]
# Compliance overview
+31 -5
View File
@@ -3,16 +3,20 @@ from drf_spectacular.views import SpectacularRedocView
from rest_framework_nested import routers
from api.v1.views import (
ComplianceOverviewViewSet,
CustomTokenObtainView,
CustomTokenRefreshView,
FindingViewSet,
InvitationAcceptViewSet,
InvitationViewSet,
MembershipViewSet,
OverviewViewSet,
ProviderGroupViewSet,
ProviderGroupProvidersRelationshipView,
ProviderSecretViewSet,
InvitationViewSet,
InvitationAcceptViewSet,
RoleViewSet,
RoleProviderGroupRelationshipView,
UserRoleRelationshipView,
OverviewViewSet,
ComplianceOverviewViewSet,
ProviderViewSet,
ResourceViewSet,
ScanViewSet,
@@ -29,11 +33,12 @@ router = routers.DefaultRouter(trailing_slash=False)
router.register(r"users", UserViewSet, basename="user")
router.register(r"tenants", TenantViewSet, basename="tenant")
router.register(r"providers", ProviderViewSet, basename="provider")
router.register(r"provider_groups", ProviderGroupViewSet, basename="providergroup")
router.register(r"provider-groups", ProviderGroupViewSet, basename="providergroup")
router.register(r"scans", ScanViewSet, basename="scan")
router.register(r"tasks", TaskViewSet, basename="task")
router.register(r"resources", ResourceViewSet, basename="resource")
router.register(r"findings", FindingViewSet, basename="finding")
router.register(r"roles", RoleViewSet, basename="role")
router.register(
r"compliance-overviews", ComplianceOverviewViewSet, basename="complianceoverview"
)
@@ -80,6 +85,27 @@ urlpatterns = [
InvitationAcceptViewSet.as_view({"post": "accept"}),
name="invitation-accept",
),
path(
"roles/<uuid:pk>/relationships/provider_groups",
RoleProviderGroupRelationshipView.as_view(
{"post": "create", "patch": "partial_update", "delete": "destroy"}
),
name="role-provider-groups-relationship",
),
path(
"users/<uuid:pk>/relationships/roles",
UserRoleRelationshipView.as_view(
{"post": "create", "patch": "partial_update", "delete": "destroy"}
),
name="user-roles-relationship",
),
path(
"provider-groups/<uuid:pk>/relationships/providers",
ProviderGroupProvidersRelationshipView.as_view(
{"post": "create", "patch": "partial_update", "delete": "destroy"}
),
name="provider_group-providers-relationship",
),
path("", include(router.urls)),
path("", include(tenants_router.urls)),
path("", include(users_router.urls)),
+575 -74
View File
@@ -8,6 +8,7 @@ from django.urls import reverse
from django.utils.decorators import method_decorator
from django.views.decorators.cache import cache_control
from drf_spectacular.settings import spectacular_settings
from drf_spectacular_jsonapi.schemas.openapi import JsonApiAutoSchema
from drf_spectacular.utils import (
OpenApiParameter,
OpenApiResponse,
@@ -25,8 +26,10 @@ from rest_framework.exceptions import (
ValidationError,
)
from rest_framework.generics import GenericAPIView, get_object_or_404
from rest_framework_json_api.views import Response
from rest_framework_json_api.views import RelationshipView, Response
from rest_framework_simplejwt.exceptions import InvalidToken, TokenError
from rest_framework.permissions import SAFE_METHODS
from tasks.beat import schedule_provider_scan
from tasks.tasks import (
check_provider_connection_task,
@@ -52,8 +55,12 @@ from api.filters import (
TaskFilter,
TenantFilter,
UserFilter,
RoleFilter,
)
from api.models import (
StatusChoices,
User,
UserRoleRelationship,
ComplianceOverview,
Finding,
Invitation,
@@ -62,20 +69,27 @@ from api.models import (
ProviderGroup,
ProviderGroupMembership,
ProviderSecret,
Role,
RoleProviderGroupRelationship,
Resource,
Scan,
ScanSummary,
SeverityChoices,
StateChoices,
StatusChoices,
Task,
User,
)
from api.pagination import ComplianceOverviewPagination
from api.rbac.permissions import HasPermissions, Permissions
from api.rls import Tenant
from api.utils import validate_invitation
from api.uuid_utils import datetime_to_uuid7
from api.v1.serializers import (
TokenSerializer,
TokenRefreshSerializer,
UserSerializer,
UserCreateSerializer,
UserUpdateSerializer,
UserRoleRelationshipSerializer,
ComplianceOverviewFullSerializer,
ComplianceOverviewSerializer,
FindingDynamicFilterSerializer,
@@ -89,34 +103,39 @@ from api.v1.serializers import (
OverviewProviderSerializer,
OverviewSeveritySerializer,
ProviderCreateSerializer,
ProviderGroupMembershipUpdateSerializer,
ProviderGroupMembershipSerializer,
ProviderGroupSerializer,
ProviderGroupUpdateSerializer,
ProviderSecretCreateSerializer,
ProviderSecretSerializer,
ProviderSecretUpdateSerializer,
RoleProviderGroupRelationshipSerializer,
ProviderSerializer,
ProviderUpdateSerializer,
ResourceSerializer,
ScanCreateSerializer,
ScanSerializer,
ScanUpdateSerializer,
ScheduleDailyCreateSerializer,
TaskSerializer,
TenantSerializer,
TokenRefreshSerializer,
TokenSerializer,
UserCreateSerializer,
UserSerializer,
UserUpdateSerializer,
TaskSerializer,
ScanSerializer,
ScanCreateSerializer,
ScanUpdateSerializer,
ResourceSerializer,
ProviderSecretSerializer,
ProviderSecretUpdateSerializer,
ProviderSecretCreateSerializer,
RoleSerializer,
RoleCreateSerializer,
RoleUpdateSerializer,
ScheduleDailyCreateSerializer,
)
CACHE_DECORATOR = cache_control(
max_age=django_settings.CACHE_MAX_AGE,
stale_while_revalidate=django_settings.CACHE_STALE_WHILE_REVALIDATE,
)
class RelationshipViewSchema(JsonApiAutoSchema):
def _resolve_path_parameters(self, _path_variables):
return []
@extend_schema(
tags=["Token"],
summary="Obtain a token",
@@ -271,6 +290,26 @@ class UserViewSet(BaseUserViewset):
filterset_class = UserFilter
ordering = ["-date_joined"]
ordering_fields = ["name", "email", "company_name", "date_joined", "is_active"]
required_permissions = [Permissions.MANAGE_USERS]
permission_classes = BaseRLSViewSet.permission_classes + [HasPermissions]
def initial(self, request, *args, **kwargs):
"""
Sets required_permissions before permissions are checked.
"""
self.required_permissions = self.get_required_permissions()
super().initial(request, *args, **kwargs)
def get_required_permissions(self):
"""
Returns the required permissions based on the request method.
"""
if self.action == "me":
# No permissions required for me request
return []
else:
# Require permission for the rest of the requests
return [Permissions.MANAGE_USERS]
def get_queryset(self):
# If called during schema generation, return an empty queryset
@@ -347,11 +386,123 @@ class UserViewSet(BaseUserViewset):
user=user, tenant=tenant, role=role
)
if invitation:
user_role = []
for role in invitation.roles.all():
user_role.append(
UserRoleRelationship.objects.using(MainRouter.admin_db).create(
user=user, role=role, tenant=invitation.tenant
)
)
invitation.state = Invitation.State.ACCEPTED
invitation.save(using=MainRouter.admin_db)
else:
role = Role.objects.using(MainRouter.admin_db).create(
name="admin",
tenant_id=tenant.id,
manage_users=True,
manage_account=True,
manage_billing=True,
manage_providers=True,
manage_integrations=True,
manage_scans=True,
unlimited_visibility=True,
)
UserRoleRelationship.objects.using(MainRouter.admin_db).create(
user=user,
role=role,
tenant_id=tenant.id,
)
return Response(data=UserSerializer(user).data, status=status.HTTP_201_CREATED)
@extend_schema_view(
create=extend_schema(
tags=["User"],
summary="Create a new user-roles relationship",
description="Add a new user-roles relationship to the system by providing the required user-roles details.",
responses={
204: OpenApiResponse(description="Relationship created successfully"),
400: OpenApiResponse(
description="Bad request (e.g., relationship already exists)"
),
},
),
partial_update=extend_schema(
tags=["User"],
summary="Partially update a user-roles relationship",
description="Update the user-roles relationship information without affecting other fields.",
responses={
204: OpenApiResponse(
response=None, description="Relationship updated successfully"
)
},
),
destroy=extend_schema(
tags=["User"],
summary="Delete a user-roles relationship",
description="Remove the user-roles relationship from the system by their ID.",
responses={
204: OpenApiResponse(
response=None, description="Relationship deleted successfully"
)
},
),
)
class UserRoleRelationshipView(RelationshipView, BaseRLSViewSet):
queryset = User.objects.all()
serializer_class = UserRoleRelationshipSerializer
resource_name = "roles"
http_method_names = ["post", "patch", "delete"]
schema = RelationshipViewSchema()
def get_queryset(self):
return User.objects.all()
def create(self, request, *args, **kwargs):
user = self.get_object()
role_ids = [item["id"] for item in request.data]
existing_relationships = UserRoleRelationship.objects.filter(
user=user, role_id__in=role_ids
)
if existing_relationships.exists():
return Response(
{"detail": "One or more roles are already associated with the user."},
status=status.HTTP_400_BAD_REQUEST,
)
serializer = self.get_serializer(
data={"roles": request.data},
context={
"user": user,
"tenant_id": self.request.tenant_id,
"request": request,
},
)
serializer.is_valid(raise_exception=True)
serializer.save()
return Response(status=status.HTTP_204_NO_CONTENT)
def partial_update(self, request, *args, **kwargs):
user = self.get_object()
serializer = self.get_serializer(
instance=user,
data={"roles": request.data},
context={"tenant_id": self.request.tenant_id, "request": request},
)
serializer.is_valid(raise_exception=True)
serializer.save()
return Response(status=status.HTTP_204_NO_CONTENT)
def destroy(self, request, *args, **kwargs):
user = self.get_object()
user.roles.clear()
return Response(status=status.HTTP_204_NO_CONTENT)
@extend_schema_view(
list=extend_schema(
tags=["Tenant"],
@@ -389,6 +540,8 @@ class TenantViewSet(BaseTenantViewset):
search_fields = ["name"]
ordering = ["-inserted_at"]
ordering_fields = ["name", "inserted_at", "updated_at"]
required_permissions = [Permissions.MANAGE_ACCOUNT]
permission_classes = BaseRLSViewSet.permission_classes + [HasPermissions]
def get_queryset(self):
return Tenant.objects.all()
@@ -562,66 +715,139 @@ class ProviderGroupViewSet(BaseRLSViewSet):
queryset = ProviderGroup.objects.all()
serializer_class = ProviderGroupSerializer
filterset_class = ProviderGroupFilter
http_method_names = ["get", "post", "patch", "put", "delete"]
http_method_names = ["get", "post", "patch", "delete"]
ordering = ["inserted_at"]
required_permissions = []
permission_classes = BaseRLSViewSet.permission_classes + [HasPermissions]
def initial(self, request, *args, **kwargs):
"""
Sets required_permissions before permissions are checked.
"""
self.required_permissions = self.get_required_permissions()
super().initial(request, *args, **kwargs)
def get_required_permissions(self):
"""
Returns the required permissions based on the request method.
"""
if self.request.method in SAFE_METHODS:
# No permissions required for GET requests
return []
else:
# Require permission for non-GET requests
return [Permissions.MANAGE_PROVIDERS]
def get_queryset(self):
return ProviderGroup.objects.prefetch_related("providers")
user = self.request.user
user_roles = user.roles.all()
# Check if any of the user's roles have UNLIMITED_VISIBILITY
if getattr(user_roles[0], Permissions.UNLIMITED_VISIBILITY.value, False):
# User has unlimited visibility, return all provider groups
return ProviderGroup.objects.prefetch_related("providers")
# Collect provider groups associated with the user's roles
provider_groups = (
ProviderGroup.objects.filter(roles__in=user_roles)
.distinct()
.prefetch_related("providers")
)
return provider_groups
def get_serializer_class(self):
if self.action == "partial_update":
return ProviderGroupUpdateSerializer
elif self.action == "providers":
if hasattr(self, "response_serializer_class"):
return self.response_serializer_class
return ProviderGroupMembershipUpdateSerializer
return super().get_serializer_class()
@extend_schema(
tags=["Provider Group"],
summary="Add providers to a provider group",
description="Add one or more providers to an existing provider group.",
request=ProviderGroupMembershipUpdateSerializer,
responses={200: OpenApiResponse(response=ProviderGroupSerializer)},
)
@action(detail=True, methods=["put"], url_name="providers")
def providers(self, request, pk=None):
@extend_schema(tags=["Provider Group"])
@extend_schema_view(
create=extend_schema(
summary="Create a new provider_group-providers relationship",
description="Add a new provider_group-providers relationship to the system by providing the required provider_group-providers details.",
responses={
204: OpenApiResponse(description="Relationship created successfully"),
400: OpenApiResponse(
description="Bad request (e.g., relationship already exists)"
),
},
),
partial_update=extend_schema(
summary="Partially update a provider_group-providers relationship",
description="Update the provider_group-providers relationship information without affecting other fields.",
responses={
204: OpenApiResponse(
response=None, description="Relationship updated successfully"
)
},
),
destroy=extend_schema(
summary="Delete a provider_group-providers relationship",
description="Remove the provider_group-providers relationship from the system by their ID.",
responses={
204: OpenApiResponse(
response=None, description="Relationship deleted successfully"
)
},
),
)
class ProviderGroupProvidersRelationshipView(RelationshipView, BaseRLSViewSet):
queryset = ProviderGroup.objects.all()
serializer_class = ProviderGroupMembershipSerializer
resource_name = "providers"
http_method_names = ["post", "patch", "delete"]
schema = RelationshipViewSchema()
def get_queryset(self):
return ProviderGroup.objects.all()
def create(self, request, *args, **kwargs):
provider_group = self.get_object()
# Validate input data
serializer = self.get_serializer_class()(
data=request.data,
context=self.get_serializer_context(),
provider_ids = [item["id"] for item in request.data]
existing_relationships = ProviderGroupMembership.objects.filter(
provider_group=provider_group, provider_id__in=provider_ids
)
if existing_relationships.exists():
return Response(
{
"detail": "One or more providers are already associated with the provider_group."
},
status=status.HTTP_400_BAD_REQUEST,
)
serializer = self.get_serializer(
data={"providers": request.data},
context={
"provider_group": provider_group,
"tenant_id": self.request.tenant_id,
"request": request,
},
)
serializer.is_valid(raise_exception=True)
serializer.save()
provider_ids = serializer.validated_data["provider_ids"]
return Response(status=status.HTTP_204_NO_CONTENT)
# Update memberships
ProviderGroupMembership.objects.filter(
provider_group=provider_group, tenant_id=request.tenant_id
).delete()
provider_group_memberships = [
ProviderGroupMembership(
tenant_id=self.request.tenant_id,
provider_group=provider_group,
provider_id=provider_id,
)
for provider_id in provider_ids
]
ProviderGroupMembership.objects.bulk_create(
provider_group_memberships, ignore_conflicts=True
def partial_update(self, request, *args, **kwargs):
provider_group = self.get_object()
serializer = self.get_serializer(
instance=provider_group,
data={"providers": request.data},
context={"tenant_id": self.request.tenant_id, "request": request},
)
serializer.is_valid(raise_exception=True)
serializer.save()
return Response(status=status.HTTP_204_NO_CONTENT)
# Return the updated provider group with providers
provider_group.refresh_from_db()
self.response_serializer_class = ProviderGroupSerializer
response_serializer = ProviderGroupSerializer(
provider_group, context=self.get_serializer_context()
)
return Response(data=response_serializer.data, status=status.HTTP_200_OK)
def destroy(self, request, *args, **kwargs):
provider_group = self.get_object()
provider_group.providers.clear()
return Response(status=status.HTTP_204_NO_CONTENT)
@extend_schema_view(
@@ -671,9 +897,41 @@ class ProviderViewSet(BaseRLSViewSet):
"inserted_at",
"updated_at",
]
required_permissions = []
permission_classes = BaseRLSViewSet.permission_classes + [HasPermissions]
def initial(self, request, *args, **kwargs):
"""
Sets required_permissions before permissions are checked.
"""
self.required_permissions = self.get_required_permissions()
super().initial(request, *args, **kwargs)
def get_required_permissions(self):
"""
Returns the required permissions based on the request method.
"""
if self.request.method in SAFE_METHODS:
# No permissions required for GET requests
return []
else:
# Require permission for non-GET requests
return [Permissions.MANAGE_PROVIDERS]
def get_queryset(self):
return Provider.objects.all()
user = self.request.user
user_roles = user.roles.all()
if getattr(user_roles[0], Permissions.UNLIMITED_VISIBILITY.value, False):
# User has unlimited visibility, return all providers
return Provider.objects.all()
# User lacks permission, filter providers based on provider groups associated with the role
provider_groups = user_roles[0].provider_groups.all()
providers = Provider.objects.filter(
provider_groups__in=provider_groups
).distinct()
return providers
def get_serializer_class(self):
if self.action == "create":
@@ -793,9 +1051,40 @@ class ScanViewSet(BaseRLSViewSet):
"inserted_at",
"updated_at",
]
required_permissions = [Permissions.MANAGE_SCANS]
permission_classes = BaseRLSViewSet.permission_classes + [HasPermissions]
def initial(self, request, *args, **kwargs):
"""
Sets required_permissions before permissions are checked.
"""
self.required_permissions = self.get_required_permissions()
super().initial(request, *args, **kwargs)
def get_required_permissions(self):
"""
Returns the required permissions based on the request method.
"""
if self.request.method in SAFE_METHODS:
# No permissions required for GET requests
return []
else:
# Require permission for non-GET requests
return [Permissions.MANAGE_SCANS]
def get_queryset(self):
return Scan.objects.all()
user = self.request.user
user_roles = user.roles.all()
if getattr(user_roles[0], Permissions.UNLIMITED_VISIBILITY.value, False):
# User has unlimited visibility, return all scans
return Scan.objects.all()
# User lacks permission, filter providers based on provider groups associated with the role
provider_groups = user_roles[0].provider_groups.all()
providers = Provider.objects.filter(
provider_groups__in=provider_groups
).distinct()
return Scan.objects.filter(provider__in=providers).distinct()
def get_serializer_class(self):
if self.action == "create":
@@ -885,11 +1174,26 @@ class TaskViewSet(BaseRLSViewSet):
search_fields = ["name"]
ordering = ["-inserted_at"]
ordering_fields = ["inserted_at", "completed_at", "name", "state"]
required_permissions = []
permission_classes = BaseRLSViewSet.permission_classes + [HasPermissions]
def get_queryset(self):
return Task.objects.annotate(
name=F("task_runner_task__task_name"), state=F("task_runner_task__status")
)
user = self.request.user
user_roles = user.roles.all()
if getattr(user_roles[0], Permissions.UNLIMITED_VISIBILITY.value, False):
# User has unlimited visibility, return all tasks
return Task.objects.annotate(
name=F("task_runner_task__task_name"),
state=F("task_runner_task__status"),
)
# User lacks permission, filter tasks based on provider groups associated with the role
provider_groups = user_roles[0].provider_groups.all()
providers = Provider.objects.filter(
provider_groups__in=provider_groups
).distinct()
scans = Scan.objects.filter(provider__in=providers).distinct()
return Task.objects.filter(scan__in=scans).distinct()
def destroy(self, request, *args, pk=None, **kwargs):
task = get_object_or_404(Task, pk=pk)
@@ -950,11 +1254,31 @@ class ResourceViewSet(BaseRLSViewSet):
"inserted_at",
"updated_at",
]
required_permissions = []
permission_classes = BaseRLSViewSet.permission_classes + [HasPermissions]
def initial(self, request, *args, **kwargs):
"""
Sets required_permissions before permissions are checked.
"""
self.required_permissions = ResourceViewSet.required_permissions
super().initial(request, *args, **kwargs)
def get_queryset(self):
queryset = Resource.objects.all()
search_value = self.request.query_params.get("filter[search]", None)
user = self.request.user
user_roles = user.roles.all()
if getattr(user_roles[0], Permissions.UNLIMITED_VISIBILITY.value, False):
# User has unlimited visibility, return all scans
queryset = Resource.objects.all()
else:
# User lacks permission, filter providers based on provider groups associated with the role
provider_groups = user_roles[0].provider_groups.all()
providers = Provider.objects.filter(
provider_groups__in=provider_groups
).distinct()
queryset = Resource.objects.filter(provider__in=providers).distinct()
search_value = self.request.query_params.get("filter[search]", None)
if search_value:
# Django's ORM will build a LEFT JOIN and OUTER JOIN on the "through" table, resulting in duplicates
# The duplicates then require a `distinct` query
@@ -1025,11 +1349,15 @@ class FindingViewSet(BaseRLSViewSet):
"inserted_at",
"updated_at",
]
required_permissions = []
permission_classes = BaseRLSViewSet.permission_classes + [HasPermissions]
def inserted_at_to_uuidv7(self, inserted_at):
if inserted_at is None:
return None
return datetime_to_uuid7(inserted_at)
def initial(self, request, *args, **kwargs):
"""
Sets required_permissions before permissions are checked.
"""
self.required_permissions = ResourceViewSet.required_permissions
super().initial(request, *args, **kwargs)
def get_serializer_class(self):
if self.action == "findings_services_regions":
@@ -1038,9 +1366,21 @@ class FindingViewSet(BaseRLSViewSet):
return super().get_serializer_class()
def get_queryset(self):
queryset = Finding.objects.all()
search_value = self.request.query_params.get("filter[search]", None)
user = self.request.user
user_roles = user.roles.all()
if getattr(user_roles[0], Permissions.UNLIMITED_VISIBILITY.value, False):
# User has unlimited visibility, return all scans
queryset = Finding.objects.all()
else:
# User lacks permission, filter providers based on provider groups associated with the role
provider_groups = user_roles[0].provider_groups.all()
providers = Provider.objects.filter(
provider_groups__in=provider_groups
).distinct()
scans = Scan.objects.filter(provider__in=providers).distinct()
queryset = Finding.objects.filter(scan__in=scans).distinct()
search_value = self.request.query_params.get("filter[search]", None)
if search_value:
# Django's ORM will build a LEFT JOIN and OUTER JOIN on any "through" tables, resulting in duplicates
# The duplicates then require a `distinct` query
@@ -1068,6 +1408,11 @@ class FindingViewSet(BaseRLSViewSet):
return queryset
def inserted_at_to_uuidv7(self, inserted_at):
if inserted_at is None:
return None
return datetime_to_uuid7(inserted_at)
@action(detail=False, methods=["get"], url_name="findings_services_regions")
def findings_services_regions(self, request):
queryset = self.get_queryset()
@@ -1188,6 +1533,8 @@ class InvitationViewSet(BaseRLSViewSet):
"state",
"inviter",
]
required_permissions = [Permissions.MANAGE_ACCOUNT]
permission_classes = BaseRLSViewSet.permission_classes + [HasPermissions]
def get_queryset(self):
return Invitation.objects.all()
@@ -1275,6 +1622,13 @@ class InvitationAcceptViewSet(BaseRLSViewSet):
user=user,
tenant=invitation.tenant,
)
user_role = []
for role in invitation.roles.all():
user_role.append(
UserRoleRelationship.objects.using(MainRouter.admin_db).create(
user=user, role=role, tenant=invitation.tenant
)
)
invitation.state = Invitation.State.ACCEPTED
invitation.save(using=MainRouter.admin_db)
@@ -1283,6 +1637,153 @@ class InvitationAcceptViewSet(BaseRLSViewSet):
return Response(data=membership_serializer.data, status=status.HTTP_201_CREATED)
@extend_schema(tags=["Role"])
@extend_schema_view(
list=extend_schema(
tags=["Role"],
summary="List all roles",
description="Retrieve a list of all roles with options for filtering by various criteria.",
),
retrieve=extend_schema(
tags=["Role"],
summary="Retrieve data from a role",
description="Fetch detailed information about a specific role by their ID.",
),
create=extend_schema(
tags=["Role"],
summary="Create a new role",
description="Add a new role to the system by providing the required role details.",
),
partial_update=extend_schema(
tags=["Role"],
summary="Partially update a role",
description="Update certain fields of an existing role's information without affecting other fields.",
responses={200: RoleSerializer},
),
destroy=extend_schema(
tags=["Role"],
summary="Delete a role",
description="Remove a role from the system by their ID.",
),
)
class RoleViewSet(BaseRLSViewSet):
queryset = Role.objects.all()
serializer_class = RoleSerializer
filterset_class = RoleFilter
http_method_names = ["get", "post", "patch", "delete"]
ordering = ["inserted_at"]
required_permissions = [Permissions.MANAGE_ACCOUNT]
permission_classes = BaseRLSViewSet.permission_classes + [HasPermissions]
def get_queryset(self):
return Role.objects.all()
def get_serializer_class(self):
if self.action == "create":
return RoleCreateSerializer
elif self.action == "partial_update":
return RoleUpdateSerializer
return super().get_serializer_class()
def partial_update(self, request, *args, **kwargs):
user = request.user
user_role = user.roles.all().first()
# If the user is the owner of the role, the manage_account field is not editable
if user_role and kwargs["pk"] == str(user_role.id):
request.data["manage_account"] = str(user_role.manage_account).lower()
return super().partial_update(request, *args, **kwargs)
@extend_schema_view(
create=extend_schema(
tags=["Role"],
summary="Create a new role-provider_groups relationship",
description="Add a new role-provider_groups relationship to the system by providing the required role-provider_groups details.",
responses={
204: OpenApiResponse(description="Relationship created successfully"),
400: OpenApiResponse(
description="Bad request (e.g., relationship already exists)"
),
},
),
partial_update=extend_schema(
tags=["Role"],
summary="Partially update a role-provider_groups relationship",
description="Update the role-provider_groups relationship information without affecting other fields.",
responses={
204: OpenApiResponse(
response=None, description="Relationship updated successfully"
)
},
),
destroy=extend_schema(
tags=["Role"],
summary="Delete a role-provider_groups relationship",
description="Remove the role-provider_groups relationship from the system by their ID.",
responses={
204: OpenApiResponse(
response=None, description="Relationship deleted successfully"
)
},
),
)
class RoleProviderGroupRelationshipView(RelationshipView, BaseRLSViewSet):
queryset = Role.objects.all()
serializer_class = RoleProviderGroupRelationshipSerializer
resource_name = "provider_groups"
http_method_names = ["post", "patch", "delete"]
schema = RelationshipViewSchema()
def get_queryset(self):
return Role.objects.all()
def create(self, request, *args, **kwargs):
role = self.get_object()
provider_group_ids = [item["id"] for item in request.data]
existing_relationships = RoleProviderGroupRelationship.objects.filter(
role=role, provider_group_id__in=provider_group_ids
)
if existing_relationships.exists():
return Response(
{
"detail": "One or more provider groups are already associated with the role."
},
status=status.HTTP_400_BAD_REQUEST,
)
serializer = self.get_serializer(
data={"provider_groups": request.data},
context={
"role": role,
"tenant_id": self.request.tenant_id,
"request": request,
},
)
serializer.is_valid(raise_exception=True)
serializer.save()
return Response(status=status.HTTP_204_NO_CONTENT)
def partial_update(self, request, *args, **kwargs):
role = self.get_object()
serializer = self.get_serializer(
instance=role,
data={"provider_groups": request.data},
context={"tenant_id": self.request.tenant_id, "request": request},
)
serializer.is_valid(raise_exception=True)
serializer.save()
return Response(status=status.HTTP_204_NO_CONTENT)
def destroy(self, request, *args, **kwargs):
role = self.get_object()
role.provider_groups.clear()
return Response(status=status.HTTP_204_NO_CONTENT)
@extend_schema_view(
list=extend_schema(
tags=["Compliance Overview"],
+2 -2
View File
@@ -10,8 +10,8 @@ DATABASES = {
"default": {
"ENGINE": "psqlextra.backend",
"NAME": "prowler_db_test",
"USER": env("POSTGRES_USER", default="prowler"),
"PASSWORD": env("POSTGRES_PASSWORD", default="S3cret"),
"USER": env("POSTGRES_USER", default="prowler_admin"),
"PASSWORD": env("POSTGRES_PASSWORD", default="postgres"),
"HOST": env("POSTGRES_HOST", default="localhost"),
"PORT": env("POSTGRES_PORT", default="5432"),
},
+236 -1
View File
@@ -10,6 +10,8 @@ from prowler.lib.check.models import Severity
from prowler.lib.outputs.finding import Status
from rest_framework import status
from rest_framework.test import APIClient
from unittest.mock import patch
from api.db_utils import tenant_transaction
from api.models import (
Finding,
@@ -20,6 +22,7 @@ from api.models import (
ProviderGroup,
Resource,
ResourceTag,
Role,
Scan,
StateChoices,
Task,
@@ -27,6 +30,7 @@ from api.models import (
ProviderSecret,
Invitation,
ComplianceOverview,
UserRoleRelationship,
)
from api.rls import Tenant
from api.v1.serializers import TokenSerializer
@@ -83,8 +87,148 @@ def create_test_user(django_db_setup, django_db_blocker):
return user
@pytest.fixture(scope="function")
def create_test_user_rbac(django_db_setup, django_db_blocker):
with django_db_blocker.unblock():
user = User.objects.create_user(
name="testing",
email="rbac@rbac.com",
password=TEST_PASSWORD,
)
tenant = Tenant.objects.create(
name="Tenant Test",
)
Membership.objects.create(
user=user,
tenant=tenant,
role=Membership.RoleChoices.OWNER,
)
Role.objects.create(
name="admin",
tenant_id=tenant.id,
manage_users=True,
manage_account=True,
manage_billing=True,
manage_providers=True,
manage_integrations=True,
manage_scans=True,
unlimited_visibility=True,
)
UserRoleRelationship.objects.create(
user=user,
role=Role.objects.get(name="admin"),
tenant_id=tenant.id,
)
return user
@pytest.fixture(scope="function")
def create_test_user_rbac_no_roles(django_db_setup, django_db_blocker):
with django_db_blocker.unblock():
user = User.objects.create_user(
name="testing",
email="rbac_noroles@rbac.com",
password=TEST_PASSWORD,
)
tenant = Tenant.objects.create(
name="Tenant Test",
)
Membership.objects.create(
user=user,
tenant=tenant,
role=Membership.RoleChoices.OWNER,
)
return user
@pytest.fixture(scope="function")
def create_test_user_rbac_limited(django_db_setup, django_db_blocker):
with django_db_blocker.unblock():
user = User.objects.create_user(
name="testing_limited",
email="rbac_limited@rbac.com",
password=TEST_PASSWORD,
)
tenant = Tenant.objects.create(
name="Tenant Test",
)
Membership.objects.create(
user=user,
tenant=tenant,
role=Membership.RoleChoices.OWNER,
)
Role.objects.create(
name="limited",
tenant_id=tenant.id,
manage_users=False,
manage_account=False,
manage_billing=False,
manage_providers=False,
manage_integrations=False,
manage_scans=False,
unlimited_visibility=False,
)
UserRoleRelationship.objects.create(
user=user,
role=Role.objects.get(name="limited"),
tenant_id=tenant.id,
)
return user
@pytest.fixture
def authenticated_client(create_test_user, tenants_fixture, client):
def authenticated_client_rbac(create_test_user_rbac, tenants_fixture, client):
client.user = create_test_user_rbac
serializer = TokenSerializer(
data={"type": "tokens", "email": "rbac@rbac.com", "password": TEST_PASSWORD}
)
serializer.is_valid()
access_token = serializer.validated_data["access"]
client.defaults["HTTP_AUTHORIZATION"] = f"Bearer {access_token}"
return client
@pytest.fixture
def authenticated_client_rbac_noroles(
create_test_user_rbac_no_roles, tenants_fixture, client
):
client.user = create_test_user_rbac_no_roles
serializer = TokenSerializer(
data={
"type": "tokens",
"email": "rbac_noroles@rbac.com",
"password": TEST_PASSWORD,
}
)
serializer.is_valid()
access_token = serializer.validated_data["access"]
client.defaults["HTTP_AUTHORIZATION"] = f"Bearer {access_token}"
return client
@pytest.fixture
def authenticated_client_no_permissions_rbac(
create_test_user_rbac_limited, tenants_fixture, client
):
client.user = create_test_user_rbac_limited
serializer = TokenSerializer(
data={
"type": "tokens",
"email": "rbac_limited@rbac.com",
"password": TEST_PASSWORD,
}
)
serializer.is_valid()
access_token = serializer.validated_data["access"]
client.defaults["HTTP_AUTHORIZATION"] = f"Bearer {access_token}"
return client
@pytest.fixture
def authenticated_client(
create_test_user, tenants_fixture, set_user_admin_roles_fixture, client
):
client.user = create_test_user
serializer = TokenSerializer(
data={"type": "tokens", "email": TEST_USER, "password": TEST_PASSWORD}
@@ -104,6 +248,7 @@ def authenticated_api_client(create_test_user, tenants_fixture):
serializer.is_valid()
access_token = serializer.validated_data["access"]
client.defaults["HTTP_AUTHORIZATION"] = f"Bearer {access_token}"
return client
@@ -128,9 +273,33 @@ def tenants_fixture(create_test_user):
tenant3 = Tenant.objects.create(
name="Tenant Three",
)
return tenant1, tenant2, tenant3
@pytest.fixture
def set_user_admin_roles_fixture(create_test_user, tenants_fixture):
user = create_test_user
for tenant in tenants_fixture[:2]:
with tenant_transaction(str(tenant.id)):
role = Role.objects.create(
name="admin",
tenant_id=tenant.id,
manage_users=True,
manage_account=True,
manage_billing=True,
manage_providers=True,
manage_integrations=True,
manage_scans=True,
unlimited_visibility=True,
)
UserRoleRelationship.objects.create(
user=user,
role=role,
tenant_id=tenant.id,
)
@pytest.fixture
def invitations_fixture(create_test_user, tenants_fixture):
user = create_test_user
@@ -210,6 +379,57 @@ def provider_groups_fixture(tenants_fixture):
return pgroup1, pgroup2, pgroup3
@pytest.fixture
def roles_fixture(tenants_fixture):
tenant, *_ = tenants_fixture
role1 = Role.objects.create(
name="Role One",
tenant_id=tenant.id,
manage_users=True,
manage_account=True,
manage_billing=True,
manage_providers=True,
manage_integrations=False,
manage_scans=True,
unlimited_visibility=False,
)
role2 = Role.objects.create(
name="Role Two",
tenant_id=tenant.id,
manage_users=False,
manage_account=False,
manage_billing=False,
manage_providers=True,
manage_integrations=True,
manage_scans=True,
unlimited_visibility=True,
)
role3 = Role.objects.create(
name="Role Three",
tenant_id=tenant.id,
manage_users=True,
manage_account=True,
manage_billing=True,
manage_providers=True,
manage_integrations=True,
manage_scans=True,
unlimited_visibility=True,
)
role4 = Role.objects.create(
name="Role Four",
tenant_id=tenant.id,
manage_users=False,
manage_account=False,
manage_billing=False,
manage_providers=False,
manage_integrations=False,
manage_scans=False,
unlimited_visibility=False,
)
return role1, role2, role3, role4
@pytest.fixture
def provider_secret_fixture(providers_fixture):
return tuple(
@@ -544,3 +764,18 @@ def get_api_tokens(
def get_authorization_header(access_token: str) -> dict:
return {"Authorization": f"Bearer {access_token}"}
def pytest_collection_modifyitems(items):
"""Ensure test_rbac.py is executed first."""
items.sort(key=lambda item: 0 if "test_rbac.py" in item.nodeid else 1)
def pytest_configure(config):
# Apply the mock before the test session starts. This is necessary to avoid admin error when running the 0004_rbac_missing_admin_roles migration
patch("api.db_router.MainRouter.admin_db", new="default").start()
def pytest_unconfigure(config):
# Stop all patches after the test session ends. This is necessary to avoid admin error when running the 0004_rbac_missing_admin_roles migration
patch.stopall()
@@ -1,5 +1,3 @@
from unittest.mock import patch
import pytest
from django.core.exceptions import ObjectDoesNotExist
from tasks.jobs.deletion import delete_provider, delete_tenant
@@ -24,7 +22,6 @@ class TestDeleteProvider:
delete_provider(non_existent_pk)
@patch("api.db_router.MainRouter.admin_db", new="default")
@pytest.mark.django_db
class TestDeleteTenant:
def test_delete_tenant_success(self, tenants_fixture, providers_fixture):