mirror of
https://github.com/prowler-cloud/prowler.git
synced 2026-07-04 19:21:51 +00:00
feat(api): RBAC system (#6114)
This commit is contained in:
committed by
GitHub
parent
f9fbde6637
commit
d00d254c90
+1
-1
@@ -1,4 +1,4 @@
|
||||
FROM python:3.12-alpine AS build
|
||||
FROM python:3.12.8-alpine3.20 AS build
|
||||
|
||||
LABEL maintainer="https://github.com/prowler-cloud/api"
|
||||
|
||||
|
||||
+1
-1
@@ -8,7 +8,7 @@ description = "Prowler's API (Django/DRF)"
|
||||
license = "Apache-2.0"
|
||||
name = "prowler-api"
|
||||
package-mode = false
|
||||
version = "1.0.0"
|
||||
version = "1.1.0"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
celery = {extras = ["pytest"], version = "^5.4.0"}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from django.db import transaction
|
||||
from rest_framework import permissions
|
||||
from rest_framework.exceptions import NotAuthenticated
|
||||
@@ -8,6 +9,8 @@ from rest_framework_simplejwt.authentication import JWTAuthentication
|
||||
|
||||
from api.db_utils import POSTGRES_USER_VAR, tenant_transaction
|
||||
from api.filters import CustomDjangoFilterBackend
|
||||
from api.models import Role, Tenant
|
||||
from api.db_router import MainRouter
|
||||
|
||||
|
||||
class BaseViewSet(ModelViewSet):
|
||||
@@ -58,7 +61,39 @@ class BaseRLSViewSet(BaseViewSet):
|
||||
class BaseTenantViewset(BaseViewSet):
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
with transaction.atomic():
|
||||
return super().dispatch(request, *args, **kwargs)
|
||||
tenant = super().dispatch(request, *args, **kwargs)
|
||||
|
||||
try:
|
||||
# If the request is a POST, create the admin role
|
||||
if request.method == "POST":
|
||||
isinstance(tenant, dict) and self._create_admin_role(tenant.data["id"])
|
||||
except Exception as e:
|
||||
self._handle_creation_error(e, tenant)
|
||||
raise
|
||||
|
||||
return tenant
|
||||
|
||||
def _create_admin_role(self, tenant_id):
|
||||
Role.objects.using(MainRouter.admin_db).create(
|
||||
name="admin",
|
||||
tenant_id=tenant_id,
|
||||
manage_users=True,
|
||||
manage_account=True,
|
||||
manage_billing=True,
|
||||
manage_providers=True,
|
||||
manage_integrations=True,
|
||||
manage_scans=True,
|
||||
unlimited_visibility=True,
|
||||
)
|
||||
|
||||
def _handle_creation_error(self, error, tenant):
|
||||
if tenant.data.get("id"):
|
||||
try:
|
||||
Tenant.objects.using(MainRouter.admin_db).filter(
|
||||
id=tenant.data["id"]
|
||||
).delete()
|
||||
except ObjectDoesNotExist:
|
||||
pass # Tenant might not exist, handle gracefully
|
||||
|
||||
def initial(self, request, *args, **kwargs):
|
||||
if (
|
||||
|
||||
@@ -22,13 +22,11 @@ from api.db_utils import (
|
||||
StatusEnumField,
|
||||
)
|
||||
from api.models import (
|
||||
ComplianceOverview,
|
||||
Finding,
|
||||
Invitation,
|
||||
Membership,
|
||||
PermissionChoices,
|
||||
Provider,
|
||||
ProviderGroup,
|
||||
ProviderSecret,
|
||||
Resource,
|
||||
ResourceTag,
|
||||
Scan,
|
||||
@@ -36,6 +34,10 @@ from api.models import (
|
||||
SeverityChoices,
|
||||
StateChoices,
|
||||
StatusChoices,
|
||||
ProviderSecret,
|
||||
Invitation,
|
||||
Role,
|
||||
ComplianceOverview,
|
||||
Task,
|
||||
User,
|
||||
)
|
||||
@@ -481,6 +483,26 @@ class UserFilter(FilterSet):
|
||||
}
|
||||
|
||||
|
||||
class RoleFilter(FilterSet):
|
||||
inserted_at = DateFilter(field_name="inserted_at", lookup_expr="date")
|
||||
updated_at = DateFilter(field_name="updated_at", lookup_expr="date")
|
||||
permission_state = ChoiceFilter(
|
||||
choices=PermissionChoices.choices, method="filter_permission_state"
|
||||
)
|
||||
|
||||
def filter_permission_state(self, queryset, name, value):
|
||||
return Role.filter_by_permission_state(queryset, value)
|
||||
|
||||
class Meta:
|
||||
model = Role
|
||||
fields = {
|
||||
"id": ["exact", "in"],
|
||||
"name": ["exact", "in"],
|
||||
"inserted_at": ["gte", "lte"],
|
||||
"updated_at": ["gte", "lte"],
|
||||
}
|
||||
|
||||
|
||||
class ComplianceOverviewFilter(FilterSet):
|
||||
inserted_at = DateFilter(field_name="inserted_at", lookup_expr="date")
|
||||
provider_type = ChoiceFilter(choices=Provider.ProviderChoices.choices)
|
||||
|
||||
@@ -58,5 +58,96 @@
|
||||
"provider_group": "525e91e7-f3f3-4254-bbc3-27ce1ade86b1",
|
||||
"inserted_at": "2024-11-13T11:55:41.237Z"
|
||||
}
|
||||
},
|
||||
{
|
||||
"model": "api.role",
|
||||
"pk": "3f01e759-bdf9-4a99-8888-1ab805b79f93",
|
||||
"fields": {
|
||||
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
|
||||
"name": "admin_test",
|
||||
"manage_users": true,
|
||||
"manage_account": true,
|
||||
"manage_billing": true,
|
||||
"manage_providers": true,
|
||||
"manage_integrations": true,
|
||||
"manage_scans": true,
|
||||
"unlimited_visibility": true,
|
||||
"inserted_at": "2024-11-20T15:32:42.402Z",
|
||||
"updated_at": "2024-11-20T15:32:42.402Z"
|
||||
}
|
||||
},
|
||||
{
|
||||
"model": "api.role",
|
||||
"pk": "845ff03a-87ef-42ba-9786-6577c70c4df0",
|
||||
"fields": {
|
||||
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
|
||||
"name": "first_role",
|
||||
"manage_users": true,
|
||||
"manage_account": true,
|
||||
"manage_billing": true,
|
||||
"manage_providers": true,
|
||||
"manage_integrations": false,
|
||||
"manage_scans": false,
|
||||
"unlimited_visibility": true,
|
||||
"inserted_at": "2024-11-20T15:31:53.239Z",
|
||||
"updated_at": "2024-11-20T15:31:53.239Z"
|
||||
}
|
||||
},
|
||||
{
|
||||
"model": "api.role",
|
||||
"pk": "902d726c-4bd5-413a-a2a4-f7b4754b6b20",
|
||||
"fields": {
|
||||
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
|
||||
"name": "third_role",
|
||||
"manage_users": false,
|
||||
"manage_account": false,
|
||||
"manage_billing": false,
|
||||
"manage_providers": false,
|
||||
"manage_integrations": false,
|
||||
"manage_scans": true,
|
||||
"unlimited_visibility": false,
|
||||
"inserted_at": "2024-11-20T15:34:05.440Z",
|
||||
"updated_at": "2024-11-20T15:34:05.440Z"
|
||||
}
|
||||
},
|
||||
{
|
||||
"model": "api.roleprovidergrouprelationship",
|
||||
"pk": "57fd024a-0a7f-49b4-a092-fa0979a07aaf",
|
||||
"fields": {
|
||||
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
|
||||
"role": "3f01e759-bdf9-4a99-8888-1ab805b79f93",
|
||||
"provider_group": "3fe28fb8-e545-424c-9b8f-69aff638f430",
|
||||
"inserted_at": "2024-11-20T15:32:42.402Z"
|
||||
}
|
||||
},
|
||||
{
|
||||
"model": "api.roleprovidergrouprelationship",
|
||||
"pk": "a3cd0099-1c13-4df1-a5e5-ecdfec561b35",
|
||||
"fields": {
|
||||
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
|
||||
"role": "3f01e759-bdf9-4a99-8888-1ab805b79f93",
|
||||
"provider_group": "481769f5-db2b-447b-8b00-1dee18db90ec",
|
||||
"inserted_at": "2024-11-20T15:32:42.402Z"
|
||||
}
|
||||
},
|
||||
{
|
||||
"model": "api.roleprovidergrouprelationship",
|
||||
"pk": "cfd84182-a058-40c2-af3c-0189b174940f",
|
||||
"fields": {
|
||||
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
|
||||
"role": "3f01e759-bdf9-4a99-8888-1ab805b79f93",
|
||||
"provider_group": "525e91e7-f3f3-4254-bbc3-27ce1ade86b1",
|
||||
"inserted_at": "2024-11-20T15:32:42.402Z"
|
||||
}
|
||||
},
|
||||
{
|
||||
"model": "api.userrolerelationship",
|
||||
"pk": "92339663-e954-4fd8-98fb-8bfe15949975",
|
||||
"fields": {
|
||||
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
|
||||
"role": "3f01e759-bdf9-4a99-8888-1ab805b79f93",
|
||||
"user": "8b38e2eb-6689-4f1e-a4ba-95b275130200",
|
||||
"inserted_at": "2024-11-20T15:36:14.302Z"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
@@ -0,0 +1,246 @@
|
||||
# Generated by Django 5.1.1 on 2024-12-05 12:29
|
||||
|
||||
import api.rls
|
||||
import django.db.models.deletion
|
||||
import uuid
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("api", "0002_token_migrations"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="Role",
|
||||
fields=[
|
||||
(
|
||||
"id",
|
||||
models.UUIDField(
|
||||
default=uuid.uuid4,
|
||||
editable=False,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
),
|
||||
),
|
||||
("name", models.CharField(max_length=255)),
|
||||
("manage_users", models.BooleanField(default=False)),
|
||||
("manage_account", models.BooleanField(default=False)),
|
||||
("manage_billing", models.BooleanField(default=False)),
|
||||
("manage_providers", models.BooleanField(default=False)),
|
||||
("manage_integrations", models.BooleanField(default=False)),
|
||||
("manage_scans", models.BooleanField(default=False)),
|
||||
("unlimited_visibility", models.BooleanField(default=False)),
|
||||
("inserted_at", models.DateTimeField(auto_now_add=True)),
|
||||
("updated_at", models.DateTimeField(auto_now=True)),
|
||||
(
|
||||
"tenant",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"db_table": "roles",
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="RoleProviderGroupRelationship",
|
||||
fields=[
|
||||
(
|
||||
"id",
|
||||
models.UUIDField(
|
||||
default=uuid.uuid4,
|
||||
editable=False,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
),
|
||||
),
|
||||
("inserted_at", models.DateTimeField(auto_now_add=True)),
|
||||
(
|
||||
"tenant",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"db_table": "role_provider_group_relationship",
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="UserRoleRelationship",
|
||||
fields=[
|
||||
(
|
||||
"id",
|
||||
models.UUIDField(
|
||||
default=uuid.uuid4,
|
||||
editable=False,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
),
|
||||
),
|
||||
("inserted_at", models.DateTimeField(auto_now_add=True)),
|
||||
(
|
||||
"tenant",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"db_table": "role_user_relationship",
|
||||
},
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="roleprovidergrouprelationship",
|
||||
name="provider_group",
|
||||
field=models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE, to="api.providergroup"
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="roleprovidergrouprelationship",
|
||||
name="role",
|
||||
field=models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE, to="api.role"
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="role",
|
||||
name="provider_groups",
|
||||
field=models.ManyToManyField(
|
||||
related_name="roles",
|
||||
through="api.RoleProviderGroupRelationship",
|
||||
to="api.providergroup",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="userrolerelationship",
|
||||
name="role",
|
||||
field=models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE, to="api.role"
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="userrolerelationship",
|
||||
name="user",
|
||||
field=models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="role",
|
||||
name="users",
|
||||
field=models.ManyToManyField(
|
||||
related_name="roles",
|
||||
through="api.UserRoleRelationship",
|
||||
to=settings.AUTH_USER_MODEL,
|
||||
),
|
||||
),
|
||||
migrations.AddConstraint(
|
||||
model_name="roleprovidergrouprelationship",
|
||||
constraint=models.UniqueConstraint(
|
||||
fields=("role_id", "provider_group_id"),
|
||||
name="unique_role_provider_group_relationship",
|
||||
),
|
||||
),
|
||||
migrations.AddConstraint(
|
||||
model_name="roleprovidergrouprelationship",
|
||||
constraint=api.rls.RowLevelSecurityConstraint(
|
||||
"tenant_id",
|
||||
name="rls_on_roleprovidergrouprelationship",
|
||||
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
|
||||
),
|
||||
),
|
||||
migrations.AddConstraint(
|
||||
model_name="userrolerelationship",
|
||||
constraint=models.UniqueConstraint(
|
||||
fields=("role_id", "user_id"), name="unique_role_user_relationship"
|
||||
),
|
||||
),
|
||||
migrations.AddConstraint(
|
||||
model_name="userrolerelationship",
|
||||
constraint=api.rls.RowLevelSecurityConstraint(
|
||||
"tenant_id",
|
||||
name="rls_on_userrolerelationship",
|
||||
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
|
||||
),
|
||||
),
|
||||
migrations.AddConstraint(
|
||||
model_name="role",
|
||||
constraint=models.UniqueConstraint(
|
||||
fields=("tenant_id", "name"), name="unique_role_per_tenant"
|
||||
),
|
||||
),
|
||||
migrations.AddConstraint(
|
||||
model_name="role",
|
||||
constraint=api.rls.RowLevelSecurityConstraint(
|
||||
"tenant_id",
|
||||
name="rls_on_role",
|
||||
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
|
||||
),
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="InvitationRoleRelationship",
|
||||
fields=[
|
||||
(
|
||||
"id",
|
||||
models.UUIDField(
|
||||
default=uuid.uuid4,
|
||||
editable=False,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
),
|
||||
),
|
||||
("inserted_at", models.DateTimeField(auto_now_add=True)),
|
||||
(
|
||||
"invitation",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE, to="api.invitation"
|
||||
),
|
||||
),
|
||||
(
|
||||
"role",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE, to="api.role"
|
||||
),
|
||||
),
|
||||
(
|
||||
"tenant",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"db_table": "role_invitation_relationship",
|
||||
},
|
||||
),
|
||||
migrations.AddConstraint(
|
||||
model_name="invitationrolerelationship",
|
||||
constraint=models.UniqueConstraint(
|
||||
fields=("role_id", "invitation_id"),
|
||||
name="unique_role_invitation_relationship",
|
||||
),
|
||||
),
|
||||
migrations.AddConstraint(
|
||||
model_name="invitationrolerelationship",
|
||||
constraint=api.rls.RowLevelSecurityConstraint(
|
||||
"tenant_id",
|
||||
name="rls_on_invitationrolerelationship",
|
||||
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="role",
|
||||
name="invitations",
|
||||
field=models.ManyToManyField(
|
||||
related_name="roles",
|
||||
through="api.InvitationRoleRelationship",
|
||||
to="api.invitation",
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,43 @@
|
||||
from django.db import migrations
|
||||
from api.db_router import MainRouter
|
||||
|
||||
|
||||
def create_admin_role(apps, schema_editor):
|
||||
Tenant = apps.get_model("api", "Tenant")
|
||||
Role = apps.get_model("api", "Role")
|
||||
User = apps.get_model("api", "User")
|
||||
UserRoleRelationship = apps.get_model("api", "UserRoleRelationship")
|
||||
|
||||
for tenant in Tenant.objects.using(MainRouter.admin_db).all():
|
||||
admin_role, _ = Role.objects.using(MainRouter.admin_db).get_or_create(
|
||||
name="admin",
|
||||
tenant=tenant,
|
||||
defaults={
|
||||
"manage_users": True,
|
||||
"manage_account": True,
|
||||
"manage_billing": True,
|
||||
"manage_providers": True,
|
||||
"manage_integrations": True,
|
||||
"manage_scans": True,
|
||||
"unlimited_visibility": True,
|
||||
},
|
||||
)
|
||||
users = User.objects.using(MainRouter.admin_db).filter(
|
||||
membership__tenant=tenant
|
||||
)
|
||||
for user in users:
|
||||
UserRoleRelationship.objects.using(MainRouter.admin_db).get_or_create(
|
||||
user=user,
|
||||
role=admin_role,
|
||||
tenant=tenant,
|
||||
)
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("api", "0003_rbac"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.RunPython(create_admin_role),
|
||||
]
|
||||
+164
-14
@@ -69,6 +69,21 @@ class StateChoices(models.TextChoices):
|
||||
CANCELLED = "cancelled", _("Cancelled")
|
||||
|
||||
|
||||
class PermissionChoices(models.TextChoices):
|
||||
"""
|
||||
Represents the different permission states that a role can have.
|
||||
|
||||
Attributes:
|
||||
UNLIMITED: Indicates that the role possesses all permissions.
|
||||
LIMITED: Indicates that the role has some permissions but not all.
|
||||
NONE: Indicates that the role does not have any permissions.
|
||||
"""
|
||||
|
||||
UNLIMITED = "unlimited", _("Unlimited permissions")
|
||||
LIMITED = "limited", _("Limited permissions")
|
||||
NONE = "none", _("No permissions")
|
||||
|
||||
|
||||
class ActiveProviderManager(models.Manager):
|
||||
def get_queryset(self):
|
||||
return super().get_queryset().filter(self.active_provider_filter())
|
||||
@@ -294,23 +309,14 @@ class ProviderGroup(RowLevelSecurityProtectedModel):
|
||||
]
|
||||
|
||||
class JSONAPIMeta:
|
||||
resource_name = "provider-groups"
|
||||
resource_name = "provider-group"
|
||||
|
||||
|
||||
class ProviderGroupMembership(RowLevelSecurityProtectedModel):
|
||||
objects = ActiveProviderManager()
|
||||
all_objects = models.Manager()
|
||||
|
||||
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
|
||||
provider = models.ForeignKey(
|
||||
Provider,
|
||||
on_delete=models.CASCADE,
|
||||
)
|
||||
provider_group = models.ForeignKey(
|
||||
ProviderGroup,
|
||||
on_delete=models.CASCADE,
|
||||
)
|
||||
inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
|
||||
provider_group = models.ForeignKey(ProviderGroup, on_delete=models.CASCADE)
|
||||
provider = models.ForeignKey(Provider, on_delete=models.CASCADE)
|
||||
inserted_at = models.DateTimeField(auto_now_add=True)
|
||||
|
||||
class Meta:
|
||||
db_table = "provider_group_memberships"
|
||||
@@ -327,7 +333,7 @@ class ProviderGroupMembership(RowLevelSecurityProtectedModel):
|
||||
]
|
||||
|
||||
class JSONAPIMeta:
|
||||
resource_name = "provider-group-memberships"
|
||||
resource_name = "provider_groups-provider"
|
||||
|
||||
|
||||
class Task(RowLevelSecurityProtectedModel):
|
||||
@@ -851,6 +857,150 @@ class Invitation(RowLevelSecurityProtectedModel):
|
||||
resource_name = "invitations"
|
||||
|
||||
|
||||
class Role(RowLevelSecurityProtectedModel):
|
||||
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
|
||||
name = models.CharField(max_length=255)
|
||||
manage_users = models.BooleanField(default=False)
|
||||
manage_account = models.BooleanField(default=False)
|
||||
manage_billing = models.BooleanField(default=False)
|
||||
manage_providers = models.BooleanField(default=False)
|
||||
manage_integrations = models.BooleanField(default=False)
|
||||
manage_scans = models.BooleanField(default=False)
|
||||
unlimited_visibility = models.BooleanField(default=False)
|
||||
inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
|
||||
updated_at = models.DateTimeField(auto_now=True, editable=False)
|
||||
provider_groups = models.ManyToManyField(
|
||||
ProviderGroup, through="RoleProviderGroupRelationship", related_name="roles"
|
||||
)
|
||||
users = models.ManyToManyField(
|
||||
User, through="UserRoleRelationship", related_name="roles"
|
||||
)
|
||||
invitations = models.ManyToManyField(
|
||||
Invitation, through="InvitationRoleRelationship", related_name="roles"
|
||||
)
|
||||
|
||||
# Filter permission_state
|
||||
PERMISSION_FIELDS = [
|
||||
"manage_users",
|
||||
"manage_account",
|
||||
"manage_billing",
|
||||
"manage_providers",
|
||||
"manage_integrations",
|
||||
"manage_scans",
|
||||
]
|
||||
|
||||
@property
|
||||
def permission_state(self):
|
||||
values = [getattr(self, field) for field in self.PERMISSION_FIELDS]
|
||||
if all(values):
|
||||
return PermissionChoices.UNLIMITED
|
||||
elif not any(values):
|
||||
return PermissionChoices.NONE
|
||||
else:
|
||||
return PermissionChoices.LIMITED
|
||||
|
||||
@classmethod
|
||||
def filter_by_permission_state(cls, queryset, value):
|
||||
q_all_true = Q(**{field: True for field in cls.PERMISSION_FIELDS})
|
||||
q_all_false = Q(**{field: False for field in cls.PERMISSION_FIELDS})
|
||||
|
||||
if value == PermissionChoices.UNLIMITED:
|
||||
return queryset.filter(q_all_true)
|
||||
elif value == PermissionChoices.NONE:
|
||||
return queryset.filter(q_all_false)
|
||||
else:
|
||||
return queryset.exclude(q_all_true | q_all_false)
|
||||
|
||||
class Meta:
|
||||
db_table = "roles"
|
||||
constraints = [
|
||||
models.UniqueConstraint(
|
||||
fields=["tenant_id", "name"],
|
||||
name="unique_role_per_tenant",
|
||||
),
|
||||
RowLevelSecurityConstraint(
|
||||
field="tenant_id",
|
||||
name="rls_on_%(class)s",
|
||||
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
|
||||
),
|
||||
]
|
||||
|
||||
class JSONAPIMeta:
|
||||
resource_name = "role"
|
||||
|
||||
|
||||
class RoleProviderGroupRelationship(RowLevelSecurityProtectedModel):
|
||||
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
|
||||
role = models.ForeignKey(Role, on_delete=models.CASCADE)
|
||||
provider_group = models.ForeignKey(ProviderGroup, on_delete=models.CASCADE)
|
||||
inserted_at = models.DateTimeField(auto_now_add=True)
|
||||
|
||||
class Meta:
|
||||
db_table = "role_provider_group_relationship"
|
||||
constraints = [
|
||||
models.UniqueConstraint(
|
||||
fields=["role_id", "provider_group_id"],
|
||||
name="unique_role_provider_group_relationship",
|
||||
),
|
||||
RowLevelSecurityConstraint(
|
||||
field="tenant_id",
|
||||
name="rls_on_%(class)s",
|
||||
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
|
||||
),
|
||||
]
|
||||
|
||||
class JSONAPIMeta:
|
||||
resource_name = "role-provider_groups"
|
||||
|
||||
|
||||
class UserRoleRelationship(RowLevelSecurityProtectedModel):
|
||||
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
|
||||
role = models.ForeignKey(Role, on_delete=models.CASCADE)
|
||||
user = models.ForeignKey(User, on_delete=models.CASCADE)
|
||||
inserted_at = models.DateTimeField(auto_now_add=True)
|
||||
|
||||
class Meta:
|
||||
db_table = "role_user_relationship"
|
||||
constraints = [
|
||||
models.UniqueConstraint(
|
||||
fields=["role_id", "user_id"],
|
||||
name="unique_role_user_relationship",
|
||||
),
|
||||
RowLevelSecurityConstraint(
|
||||
field="tenant_id",
|
||||
name="rls_on_%(class)s",
|
||||
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
|
||||
),
|
||||
]
|
||||
|
||||
class JSONAPIMeta:
|
||||
resource_name = "user-roles"
|
||||
|
||||
|
||||
class InvitationRoleRelationship(RowLevelSecurityProtectedModel):
|
||||
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
|
||||
role = models.ForeignKey(Role, on_delete=models.CASCADE)
|
||||
invitation = models.ForeignKey(Invitation, on_delete=models.CASCADE)
|
||||
inserted_at = models.DateTimeField(auto_now_add=True)
|
||||
|
||||
class Meta:
|
||||
db_table = "role_invitation_relationship"
|
||||
constraints = [
|
||||
models.UniqueConstraint(
|
||||
fields=["role_id", "invitation_id"],
|
||||
name="unique_role_invitation_relationship",
|
||||
),
|
||||
RowLevelSecurityConstraint(
|
||||
field="tenant_id",
|
||||
name="rls_on_%(class)s",
|
||||
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
|
||||
),
|
||||
]
|
||||
|
||||
class JSONAPIMeta:
|
||||
resource_name = "invitation-roles"
|
||||
|
||||
|
||||
class ComplianceOverview(RowLevelSecurityProtectedModel):
|
||||
objects = ActiveProviderManager()
|
||||
all_objects = models.Manager()
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
from enum import Enum
|
||||
from rest_framework.permissions import BasePermission
|
||||
from api.models import User
|
||||
from api.db_router import MainRouter
|
||||
|
||||
|
||||
class Permissions(Enum):
|
||||
MANAGE_USERS = "manage_users"
|
||||
MANAGE_ACCOUNT = "manage_account"
|
||||
MANAGE_BILLING = "manage_billing"
|
||||
MANAGE_PROVIDERS = "manage_providers"
|
||||
MANAGE_INTEGRATIONS = "manage_integrations"
|
||||
MANAGE_SCANS = "manage_scans"
|
||||
UNLIMITED_VISIBILITY = "unlimited_visibility"
|
||||
|
||||
|
||||
class HasPermissions(BasePermission):
|
||||
"""
|
||||
Custom permission to check if the user's role has the required permissions.
|
||||
The required permissions should be specified in the view as a list in `required_permissions`.
|
||||
"""
|
||||
|
||||
def has_permission(self, request, view):
|
||||
required_permissions = getattr(view, "required_permissions", [])
|
||||
if not required_permissions:
|
||||
return True
|
||||
|
||||
user_roles = (
|
||||
User.objects.using(MainRouter.admin_db).get(id=request.user.id).roles.all()
|
||||
)
|
||||
if not user_roles:
|
||||
return False
|
||||
|
||||
for perm in required_permissions:
|
||||
if not getattr(user_roles[0], perm.value, False):
|
||||
return False
|
||||
|
||||
return True
|
||||
+1515
-43
File diff suppressed because it is too large
Load Diff
@@ -1,12 +1,10 @@
|
||||
import pytest
|
||||
from django.urls import reverse
|
||||
from unittest.mock import patch
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from conftest import TEST_PASSWORD, get_api_tokens, get_authorization_header
|
||||
|
||||
|
||||
@patch("api.v1.views.MainRouter.admin_db", new="default")
|
||||
@pytest.mark.django_db
|
||||
def test_basic_authentication():
|
||||
client = APIClient()
|
||||
|
||||
@@ -13,6 +13,7 @@ def test_check_resources_between_different_tenants(
|
||||
enforce_test_user_db_connection,
|
||||
authenticated_api_client,
|
||||
tenants_fixture,
|
||||
set_user_admin_roles_fixture,
|
||||
):
|
||||
client = authenticated_api_client
|
||||
|
||||
|
||||
@@ -6,8 +6,10 @@ from django.db.utils import ConnectionRouter
|
||||
from api.db_router import MainRouter
|
||||
from api.rls import Tenant
|
||||
from config.django.base import DATABASE_ROUTERS as PROD_DATABASE_ROUTERS
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
@patch("api.db_router.MainRouter.admin_db", new="admin")
|
||||
class TestMainDatabaseRouter:
|
||||
@pytest.fixture(scope="module")
|
||||
def router(self):
|
||||
|
||||
@@ -0,0 +1,306 @@
|
||||
import pytest
|
||||
from django.urls import reverse
|
||||
from rest_framework import status
|
||||
from unittest.mock import patch, ANY, Mock
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestUserViewSet:
|
||||
def test_list_users_with_all_permissions(self, authenticated_client_rbac):
|
||||
response = authenticated_client_rbac.get(reverse("user-list"))
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert isinstance(response.json()["data"], list)
|
||||
|
||||
def test_list_users_with_no_permissions(
|
||||
self, authenticated_client_no_permissions_rbac
|
||||
):
|
||||
response = authenticated_client_no_permissions_rbac.get(reverse("user-list"))
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
def test_retrieve_user_with_all_permissions(
|
||||
self, authenticated_client_rbac, create_test_user_rbac
|
||||
):
|
||||
response = authenticated_client_rbac.get(
|
||||
reverse("user-detail", kwargs={"pk": create_test_user_rbac.id})
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert (
|
||||
response.json()["data"]["attributes"]["email"]
|
||||
== create_test_user_rbac.email
|
||||
)
|
||||
|
||||
def test_retrieve_user_with_no_roles(
|
||||
self, authenticated_client_rbac_noroles, create_test_user_rbac_no_roles
|
||||
):
|
||||
response = authenticated_client_rbac_noroles.get(
|
||||
reverse("user-detail", kwargs={"pk": create_test_user_rbac_no_roles.id})
|
||||
)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
def test_retrieve_user_with_no_permissions(
|
||||
self, authenticated_client_no_permissions_rbac, create_test_user
|
||||
):
|
||||
response = authenticated_client_no_permissions_rbac.get(
|
||||
reverse("user-detail", kwargs={"pk": create_test_user.id})
|
||||
)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
def test_create_user_with_all_permissions(self, authenticated_client_rbac):
|
||||
valid_user_payload = {
|
||||
"name": "test",
|
||||
"password": "newpassword123",
|
||||
"email": "new_user@test.com",
|
||||
}
|
||||
response = authenticated_client_rbac.post(
|
||||
reverse("user-list"), data=valid_user_payload, format="vnd.api+json"
|
||||
)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
assert response.json()["data"]["attributes"]["email"] == "new_user@test.com"
|
||||
|
||||
def test_create_user_with_no_permissions(
|
||||
self, authenticated_client_no_permissions_rbac
|
||||
):
|
||||
valid_user_payload = {
|
||||
"name": "test",
|
||||
"password": "newpassword123",
|
||||
"email": "new_user@test.com",
|
||||
}
|
||||
response = authenticated_client_no_permissions_rbac.post(
|
||||
reverse("user-list"), data=valid_user_payload, format="vnd.api+json"
|
||||
)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
assert response.json()["data"]["attributes"]["email"] == "new_user@test.com"
|
||||
|
||||
def test_partial_update_user_with_all_permissions(
|
||||
self, authenticated_client_rbac, create_test_user_rbac
|
||||
):
|
||||
updated_data = {
|
||||
"data": {
|
||||
"type": "users",
|
||||
"id": str(create_test_user_rbac.id),
|
||||
"attributes": {"name": "Updated Name"},
|
||||
},
|
||||
}
|
||||
response = authenticated_client_rbac.patch(
|
||||
reverse("user-detail", kwargs={"pk": create_test_user_rbac.id}),
|
||||
data=updated_data,
|
||||
content_type="application/vnd.api+json",
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.json()["data"]["attributes"]["name"] == "Updated Name"
|
||||
|
||||
def test_partial_update_user_with_no_permissions(
|
||||
self, authenticated_client_no_permissions_rbac, create_test_user
|
||||
):
|
||||
updated_data = {
|
||||
"data": {
|
||||
"type": "users",
|
||||
"attributes": {"name": "Updated Name"},
|
||||
}
|
||||
}
|
||||
response = authenticated_client_no_permissions_rbac.patch(
|
||||
reverse("user-detail", kwargs={"pk": create_test_user.id}),
|
||||
data=updated_data,
|
||||
format="vnd.api+json",
|
||||
)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
def test_delete_user_with_all_permissions(
|
||||
self, authenticated_client_rbac, create_test_user_rbac
|
||||
):
|
||||
response = authenticated_client_rbac.delete(
|
||||
reverse("user-detail", kwargs={"pk": create_test_user_rbac.id})
|
||||
)
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
|
||||
def test_delete_user_with_no_permissions(
|
||||
self, authenticated_client_no_permissions_rbac, create_test_user
|
||||
):
|
||||
response = authenticated_client_no_permissions_rbac.delete(
|
||||
reverse("user-detail", kwargs={"pk": create_test_user.id})
|
||||
)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
def test_me_with_all_permissions(
|
||||
self, authenticated_client_rbac, create_test_user_rbac
|
||||
):
|
||||
response = authenticated_client_rbac.get(reverse("user-me"))
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert (
|
||||
response.json()["data"]["attributes"]["email"]
|
||||
== create_test_user_rbac.email
|
||||
)
|
||||
|
||||
def test_me_with_no_permissions(
|
||||
self, authenticated_client_no_permissions_rbac, create_test_user
|
||||
):
|
||||
response = authenticated_client_no_permissions_rbac.get(reverse("user-me"))
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.json()["data"]["attributes"]["email"] == "rbac_limited@rbac.com"
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestProviderViewSet:
|
||||
def test_list_providers_with_all_permissions(
|
||||
self, authenticated_client_rbac, providers_fixture
|
||||
):
|
||||
response = authenticated_client_rbac.get(reverse("provider-list"))
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert len(response.json()["data"]) == len(providers_fixture)
|
||||
|
||||
def test_list_providers_with_no_permissions(
|
||||
self, authenticated_client_no_permissions_rbac
|
||||
):
|
||||
response = authenticated_client_no_permissions_rbac.get(
|
||||
reverse("provider-list")
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert len(response.json()["data"]) == 0
|
||||
|
||||
def test_retrieve_provider_with_all_permissions(
|
||||
self, authenticated_client_rbac, providers_fixture
|
||||
):
|
||||
provider = providers_fixture[0]
|
||||
response = authenticated_client_rbac.get(
|
||||
reverse("provider-detail", kwargs={"pk": provider.id})
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.json()["data"]["attributes"]["alias"] == provider.alias
|
||||
|
||||
def test_retrieve_provider_with_no_permissions(
|
||||
self, authenticated_client_no_permissions_rbac, providers_fixture
|
||||
):
|
||||
provider = providers_fixture[0]
|
||||
response = authenticated_client_no_permissions_rbac.get(
|
||||
reverse("provider-detail", kwargs={"pk": provider.id})
|
||||
)
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
def test_create_provider_with_all_permissions(self, authenticated_client_rbac):
|
||||
payload = {"provider": "aws", "uid": "111111111111", "alias": "new_alias"}
|
||||
response = authenticated_client_rbac.post(
|
||||
reverse("provider-list"), data=payload, format="json"
|
||||
)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
assert response.json()["data"]["attributes"]["alias"] == "new_alias"
|
||||
|
||||
def test_create_provider_with_no_permissions(
|
||||
self, authenticated_client_no_permissions_rbac
|
||||
):
|
||||
payload = {"provider": "aws", "uid": "111111111111", "alias": "new_alias"}
|
||||
response = authenticated_client_no_permissions_rbac.post(
|
||||
reverse("provider-list"), data=payload, format="json"
|
||||
)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
def test_partial_update_provider_with_all_permissions(
|
||||
self, authenticated_client_rbac, providers_fixture
|
||||
):
|
||||
provider = providers_fixture[0]
|
||||
payload = {
|
||||
"data": {
|
||||
"type": "providers",
|
||||
"id": provider.id,
|
||||
"attributes": {"alias": "updated_alias"},
|
||||
},
|
||||
}
|
||||
response = authenticated_client_rbac.patch(
|
||||
reverse("provider-detail", kwargs={"pk": provider.id}),
|
||||
data=payload,
|
||||
content_type="application/vnd.api+json",
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.json()["data"]["attributes"]["alias"] == "updated_alias"
|
||||
|
||||
def test_partial_update_provider_with_no_permissions(
|
||||
self, authenticated_client_no_permissions_rbac, providers_fixture
|
||||
):
|
||||
provider = providers_fixture[0]
|
||||
update_payload = {
|
||||
"data": {
|
||||
"type": "providers",
|
||||
"attributes": {"alias": "updated_alias"},
|
||||
}
|
||||
}
|
||||
response = authenticated_client_no_permissions_rbac.patch(
|
||||
reverse("provider-detail", kwargs={"pk": provider.id}),
|
||||
data=update_payload,
|
||||
format="vnd.api+json",
|
||||
)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
@patch("api.v1.views.Task.objects.get")
|
||||
@patch("api.v1.views.delete_provider_task.delay")
|
||||
def test_delete_provider_with_all_permissions(
|
||||
self,
|
||||
mock_delete_task,
|
||||
mock_task_get,
|
||||
authenticated_client_rbac,
|
||||
providers_fixture,
|
||||
tasks_fixture,
|
||||
):
|
||||
prowler_task = tasks_fixture[0]
|
||||
task_mock = Mock()
|
||||
task_mock.id = prowler_task.id
|
||||
mock_delete_task.return_value = task_mock
|
||||
mock_task_get.return_value = prowler_task
|
||||
|
||||
provider1, *_ = providers_fixture
|
||||
response = authenticated_client_rbac.delete(
|
||||
reverse("provider-detail", kwargs={"pk": provider1.id})
|
||||
)
|
||||
assert response.status_code == status.HTTP_202_ACCEPTED
|
||||
mock_delete_task.assert_called_once_with(
|
||||
provider_id=str(provider1.id), tenant_id=ANY
|
||||
)
|
||||
assert "Content-Location" in response.headers
|
||||
assert response.headers["Content-Location"] == f"/api/v1/tasks/{task_mock.id}"
|
||||
|
||||
def test_delete_provider_with_no_permissions(
|
||||
self, authenticated_client_no_permissions_rbac, providers_fixture
|
||||
):
|
||||
provider = providers_fixture[0]
|
||||
response = authenticated_client_no_permissions_rbac.delete(
|
||||
reverse("provider-detail", kwargs={"pk": provider.id})
|
||||
)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
@patch("api.v1.views.Task.objects.get")
|
||||
@patch("api.v1.views.check_provider_connection_task.delay")
|
||||
def test_connection_with_all_permissions(
|
||||
self,
|
||||
mock_provider_connection,
|
||||
mock_task_get,
|
||||
authenticated_client_rbac,
|
||||
providers_fixture,
|
||||
tasks_fixture,
|
||||
):
|
||||
prowler_task = tasks_fixture[0]
|
||||
task_mock = Mock()
|
||||
task_mock.id = prowler_task.id
|
||||
task_mock.status = "PENDING"
|
||||
mock_provider_connection.return_value = task_mock
|
||||
mock_task_get.return_value = prowler_task
|
||||
|
||||
provider1, *_ = providers_fixture
|
||||
assert provider1.connected is None
|
||||
assert provider1.connection_last_checked_at is None
|
||||
|
||||
response = authenticated_client_rbac.post(
|
||||
reverse("provider-connection", kwargs={"pk": provider1.id})
|
||||
)
|
||||
assert response.status_code == status.HTTP_202_ACCEPTED
|
||||
mock_provider_connection.assert_called_once_with(
|
||||
provider_id=str(provider1.id), tenant_id=ANY
|
||||
)
|
||||
assert "Content-Location" in response.headers
|
||||
assert response.headers["Content-Location"] == f"/api/v1/tasks/{task_mock.id}"
|
||||
|
||||
def test_connection_with_no_permissions(
|
||||
self, authenticated_client_no_permissions_rbac, providers_fixture
|
||||
):
|
||||
provider = providers_fixture[0]
|
||||
response = authenticated_client_no_permissions_rbac.post(
|
||||
reverse("provider-connection", kwargs={"pk": provider.id})
|
||||
)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
@@ -9,11 +9,14 @@ from django.urls import reverse
|
||||
from rest_framework import status
|
||||
|
||||
from api.models import (
|
||||
Invitation,
|
||||
Membership,
|
||||
Provider,
|
||||
ProviderGroup,
|
||||
ProviderGroupMembership,
|
||||
Role,
|
||||
RoleProviderGroupRelationship,
|
||||
Invitation,
|
||||
UserRoleRelationship,
|
||||
ProviderSecret,
|
||||
Scan,
|
||||
StateChoices,
|
||||
@@ -50,7 +53,6 @@ class TestUserViewSet:
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.json()["data"]["attributes"]["email"] == create_test_user.email
|
||||
|
||||
@patch("api.db_router.MainRouter.admin_db", new="default")
|
||||
def test_users_create(self, client):
|
||||
valid_user_payload = {
|
||||
"name": "test",
|
||||
@@ -67,7 +69,6 @@ class TestUserViewSet:
|
||||
== valid_user_payload["email"].lower()
|
||||
)
|
||||
|
||||
@patch("api.db_router.MainRouter.admin_db", new="default")
|
||||
def test_users_create_duplicated_email(self, client):
|
||||
# Create a user
|
||||
self.test_users_create(client)
|
||||
@@ -122,7 +123,6 @@ class TestUserViewSet:
|
||||
"NonExistentEmail@prowler.com",
|
||||
],
|
||||
)
|
||||
@patch("api.db_router.MainRouter.admin_db", new="default")
|
||||
def test_users_create_used_email(self, authenticated_client, email):
|
||||
# First user created; no errors should occur
|
||||
user_payload = {
|
||||
@@ -418,7 +418,6 @@ class TestTenantViewSet:
|
||||
)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
|
||||
@patch("api.db_router.MainRouter.admin_db", new="default")
|
||||
@patch("api.v1.views.delete_tenant_task.apply_async")
|
||||
def test_tenants_delete(
|
||||
self, delete_tenant_mock, authenticated_client, tenants_fixture
|
||||
@@ -815,7 +814,7 @@ class TestProviderViewSet:
|
||||
@pytest.mark.parametrize(
|
||||
"include_values, expected_resources",
|
||||
[
|
||||
("provider_groups", ["provider-groups"]),
|
||||
("provider_groups", ["provider-group"]),
|
||||
],
|
||||
)
|
||||
def test_providers_list_include(
|
||||
@@ -1200,7 +1199,7 @@ class TestProviderGroupViewSet:
|
||||
def test_provider_group_create(self, authenticated_client):
|
||||
data = {
|
||||
"data": {
|
||||
"type": "provider-groups",
|
||||
"type": "provider-group",
|
||||
"attributes": {
|
||||
"name": "Test Provider Group",
|
||||
},
|
||||
@@ -1219,7 +1218,7 @@ class TestProviderGroupViewSet:
|
||||
def test_provider_group_create_invalid(self, authenticated_client):
|
||||
data = {
|
||||
"data": {
|
||||
"type": "provider-groups",
|
||||
"type": "provider-group",
|
||||
"attributes": {
|
||||
# Name is missing
|
||||
},
|
||||
@@ -1241,7 +1240,7 @@ class TestProviderGroupViewSet:
|
||||
data = {
|
||||
"data": {
|
||||
"id": str(provider_group.id),
|
||||
"type": "provider-groups",
|
||||
"type": "provider-group",
|
||||
"attributes": {
|
||||
"name": "Updated Provider Group Name",
|
||||
},
|
||||
@@ -1263,7 +1262,7 @@ class TestProviderGroupViewSet:
|
||||
data = {
|
||||
"data": {
|
||||
"id": str(provider_group.id),
|
||||
"type": "provider-groups",
|
||||
"type": "provider-group",
|
||||
"attributes": {
|
||||
"name": "", # Invalid name
|
||||
},
|
||||
@@ -1294,100 +1293,6 @@ class TestProviderGroupViewSet:
|
||||
)
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
def test_provider_group_providers_update(
|
||||
self, authenticated_client, provider_groups_fixture, providers_fixture
|
||||
):
|
||||
provider_group = provider_groups_fixture[0]
|
||||
provider_ids = [str(provider.id) for provider in providers_fixture]
|
||||
|
||||
data = {
|
||||
"data": {
|
||||
"type": "provider-group-memberships",
|
||||
"id": str(provider_group.id),
|
||||
"attributes": {"provider_ids": provider_ids},
|
||||
}
|
||||
}
|
||||
|
||||
response = authenticated_client.put(
|
||||
reverse("providergroup-providers", kwargs={"pk": provider_group.id}),
|
||||
data=json.dumps(data),
|
||||
content_type="application/vnd.api+json",
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
memberships = ProviderGroupMembership.objects.filter(
|
||||
provider_group=provider_group
|
||||
)
|
||||
assert memberships.count() == len(provider_ids)
|
||||
for membership in memberships:
|
||||
assert str(membership.provider_id) in provider_ids
|
||||
|
||||
def test_provider_group_providers_update_non_existent_provider(
|
||||
self, authenticated_client, provider_groups_fixture, providers_fixture
|
||||
):
|
||||
provider_group = provider_groups_fixture[0]
|
||||
provider_ids = [str(provider.id) for provider in providers_fixture]
|
||||
provider_ids[-1] = "1b59e032-3eb6-4694-93a5-df84cd9b3ce2"
|
||||
|
||||
data = {
|
||||
"data": {
|
||||
"type": "provider-group-memberships",
|
||||
"id": str(provider_group.id),
|
||||
"attributes": {"provider_ids": provider_ids},
|
||||
}
|
||||
}
|
||||
|
||||
response = authenticated_client.put(
|
||||
reverse("providergroup-providers", kwargs={"pk": provider_group.id}),
|
||||
data=json.dumps(data),
|
||||
content_type="application/vnd.api+json",
|
||||
)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
errors = response.json()["errors"]
|
||||
assert (
|
||||
errors[0]["detail"]
|
||||
== f"The following provider IDs do not exist: {provider_ids[-1]}"
|
||||
)
|
||||
|
||||
def test_provider_group_providers_update_invalid_provider(
|
||||
self, authenticated_client, provider_groups_fixture
|
||||
):
|
||||
provider_group = provider_groups_fixture[1]
|
||||
invalid_provider_id = "non-existent-id"
|
||||
data = {
|
||||
"data": {
|
||||
"type": "provider-group-memberships",
|
||||
"id": str(provider_group.id),
|
||||
"attributes": {"provider_ids": [invalid_provider_id]},
|
||||
}
|
||||
}
|
||||
|
||||
response = authenticated_client.put(
|
||||
reverse("providergroup-providers", kwargs={"pk": provider_group.id}),
|
||||
data=json.dumps(data),
|
||||
content_type="application/vnd.api+json",
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
errors = response.json()["errors"]
|
||||
assert errors[0]["detail"] == "Must be a valid UUID."
|
||||
|
||||
def test_provider_group_providers_update_invalid_payload(
|
||||
self, authenticated_client, provider_groups_fixture
|
||||
):
|
||||
provider_group = provider_groups_fixture[2]
|
||||
data = {
|
||||
# Missing "provider_ids"
|
||||
}
|
||||
|
||||
response = authenticated_client.put(
|
||||
reverse("providergroup-providers", kwargs={"pk": provider_group.id}),
|
||||
data=json.dumps(data),
|
||||
content_type="application/vnd.api+json",
|
||||
)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
errors = response.json()["errors"]
|
||||
assert errors[0]["detail"] == "Received document does not contain primary data"
|
||||
|
||||
def test_provider_group_retrieve_not_found(self, authenticated_client):
|
||||
response = authenticated_client.get(
|
||||
reverse("providergroup-detail", kwargs={"pk": "non-existent-id"})
|
||||
@@ -2652,7 +2557,9 @@ class TestInvitationViewSet:
|
||||
)
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
def test_invitations_create_valid(self, authenticated_client, create_test_user):
|
||||
def test_invitations_create_valid(
|
||||
self, authenticated_client, create_test_user, roles_fixture
|
||||
):
|
||||
user = create_test_user
|
||||
data = {
|
||||
"data": {
|
||||
@@ -2661,6 +2568,11 @@ class TestInvitationViewSet:
|
||||
"email": "any_email@prowler.com",
|
||||
"expires_at": self.TOMORROW_ISO,
|
||||
},
|
||||
"relationships": {
|
||||
"roles": {
|
||||
"data": [{"type": "role", "id": str(roles_fixture[0].id)}]
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
response = authenticated_client.post(
|
||||
@@ -2719,6 +2631,11 @@ class TestInvitationViewSet:
|
||||
response.json()["errors"][0]["source"]["pointer"]
|
||||
== "/data/attributes/email"
|
||||
)
|
||||
assert response.json()["errors"][1]["code"] == "required"
|
||||
assert (
|
||||
response.json()["errors"][1]["source"]["pointer"]
|
||||
== "/data/relationships/roles"
|
||||
)
|
||||
|
||||
def test_invitations_create_invalid_expires_at(
|
||||
self, authenticated_client, invitations_fixture
|
||||
@@ -2745,6 +2662,11 @@ class TestInvitationViewSet:
|
||||
response.json()["errors"][0]["source"]["pointer"]
|
||||
== "/data/attributes/expires_at"
|
||||
)
|
||||
assert response.json()["errors"][1]["code"] == "required"
|
||||
assert (
|
||||
response.json()["errors"][1]["source"]["pointer"]
|
||||
== "/data/relationships/roles"
|
||||
)
|
||||
|
||||
def test_invitations_partial_update_valid(
|
||||
self, authenticated_client, invitations_fixture
|
||||
@@ -2932,7 +2854,6 @@ class TestInvitationViewSet:
|
||||
== "This invitation cannot be revoked."
|
||||
)
|
||||
|
||||
@patch("api.db_router.MainRouter.admin_db", new="default")
|
||||
def test_invitations_accept_invitation_new_user(self, client, invitations_fixture):
|
||||
invitation, *_ = invitations_fixture
|
||||
|
||||
@@ -2958,7 +2879,6 @@ class TestInvitationViewSet:
|
||||
user__email__iexact=invitation.email, tenant=invitation.tenant
|
||||
).exists()
|
||||
|
||||
@patch("api.db_router.MainRouter.admin_db", new="default")
|
||||
def test_invitations_accept_invitation_existing_user(
|
||||
self, authenticated_client, create_test_user, tenants_fixture
|
||||
):
|
||||
@@ -2983,7 +2903,6 @@ class TestInvitationViewSet:
|
||||
response = authenticated_client.post(
|
||||
reverse("invitation-accept"), data=data, format="json"
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
invitation.refresh_from_db()
|
||||
assert Membership.objects.filter(
|
||||
@@ -2991,7 +2910,6 @@ class TestInvitationViewSet:
|
||||
).exists()
|
||||
assert invitation.state == Invitation.State.ACCEPTED.value
|
||||
|
||||
@patch("api.db_router.MainRouter.admin_db", new="default")
|
||||
def test_invitations_accept_invitation_invalid_token(self, authenticated_client):
|
||||
data = {
|
||||
"invitation_token": "invalid_token",
|
||||
@@ -3004,7 +2922,6 @@ class TestInvitationViewSet:
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert response.json()["errors"][0]["code"] == "not_found"
|
||||
|
||||
@patch("api.db_router.MainRouter.admin_db", new="default")
|
||||
def test_invitations_accept_invitation_invalid_token_expired(
|
||||
self, authenticated_client, invitations_fixture
|
||||
):
|
||||
@@ -3023,7 +2940,6 @@ class TestInvitationViewSet:
|
||||
|
||||
assert response.status_code == status.HTTP_410_GONE
|
||||
|
||||
@patch("api.db_router.MainRouter.admin_db", new="default")
|
||||
def test_invitations_accept_invitation_invalid_token_expired_new_user(
|
||||
self, client, invitations_fixture
|
||||
):
|
||||
@@ -3047,7 +2963,6 @@ class TestInvitationViewSet:
|
||||
|
||||
assert response.status_code == status.HTTP_410_GONE
|
||||
|
||||
@patch("api.db_router.MainRouter.admin_db", new="default")
|
||||
def test_invitations_accept_invitation_invalid_token_accepted(
|
||||
self, authenticated_client, invitations_fixture
|
||||
):
|
||||
@@ -3071,7 +2986,6 @@ class TestInvitationViewSet:
|
||||
== "This invitation is no longer valid."
|
||||
)
|
||||
|
||||
@patch("api.db_router.MainRouter.admin_db", new="default")
|
||||
def test_invitations_accept_invitation_invalid_token_revoked(
|
||||
self, authenticated_client, invitations_fixture
|
||||
):
|
||||
@@ -3166,6 +3080,622 @@ class TestInvitationViewSet:
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestRoleViewSet:
|
||||
def test_role_list(self, authenticated_client, roles_fixture):
|
||||
response = authenticated_client.get(reverse("role-list"))
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert (
|
||||
len(response.json()["data"]) == len(roles_fixture) + 2
|
||||
) # 2 default admin roles, one for each tenant
|
||||
|
||||
def test_role_retrieve(self, authenticated_client, roles_fixture):
|
||||
role = roles_fixture[0]
|
||||
response = authenticated_client.get(
|
||||
reverse("role-detail", kwargs={"pk": role.id})
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()["data"]
|
||||
assert data["id"] == str(role.id)
|
||||
assert data["attributes"]["name"] == role.name
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("permission_state", "index"),
|
||||
[("limited", 0), ("unlimited", 2), ("none", 3)],
|
||||
)
|
||||
def test_role_retrieve_permission_state(
|
||||
self, authenticated_client, roles_fixture, permission_state, index
|
||||
):
|
||||
role = roles_fixture[index]
|
||||
response = authenticated_client.get(
|
||||
reverse("role-detail", kwargs={"pk": role.id}),
|
||||
{"filter[permission_state]": permission_state},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()["data"]
|
||||
assert data["id"] == str(role.id)
|
||||
assert data["attributes"]["name"] == role.name
|
||||
assert data["attributes"]["permission_state"] == permission_state
|
||||
|
||||
def test_role_create(self, authenticated_client):
|
||||
data = {
|
||||
"data": {
|
||||
"type": "role",
|
||||
"attributes": {
|
||||
"name": "Test Role",
|
||||
"manage_users": "false",
|
||||
"manage_account": "false",
|
||||
"manage_billing": "false",
|
||||
"manage_providers": "true",
|
||||
"manage_integrations": "true",
|
||||
"manage_scans": "true",
|
||||
"unlimited_visibility": "true",
|
||||
},
|
||||
"relationships": {"provider_groups": {"data": []}},
|
||||
}
|
||||
}
|
||||
response = authenticated_client.post(
|
||||
reverse("role-list"),
|
||||
data=json.dumps(data),
|
||||
content_type="application/vnd.api+json",
|
||||
)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
response_data = response.json()["data"]
|
||||
assert response_data["attributes"]["name"] == "Test Role"
|
||||
assert Role.objects.filter(name="Test Role").exists()
|
||||
|
||||
def test_role_provider_groups_create(
|
||||
self, authenticated_client, provider_groups_fixture
|
||||
):
|
||||
data = {
|
||||
"data": {
|
||||
"type": "role",
|
||||
"attributes": {
|
||||
"name": "Test Role",
|
||||
"manage_users": "false",
|
||||
"manage_account": "false",
|
||||
"manage_billing": "false",
|
||||
"manage_providers": "true",
|
||||
"manage_integrations": "true",
|
||||
"manage_scans": "true",
|
||||
"unlimited_visibility": "true",
|
||||
},
|
||||
"relationships": {
|
||||
"provider_groups": {
|
||||
"data": [
|
||||
{"type": "provider-group", "id": str(provider_group.id)}
|
||||
for provider_group in provider_groups_fixture[:2]
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
response = authenticated_client.post(
|
||||
reverse("role-list"),
|
||||
data=json.dumps(data),
|
||||
content_type="application/vnd.api+json",
|
||||
)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
response_data = response.json()["data"]
|
||||
assert response_data["attributes"]["name"] == "Test Role"
|
||||
assert Role.objects.filter(name="Test Role").exists()
|
||||
relationships = (
|
||||
Role.objects.filter(name="Test Role").first().provider_groups.all()
|
||||
)
|
||||
assert relationships.count() == 2
|
||||
for relationship in relationships:
|
||||
assert relationship.id in [pg.id for pg in provider_groups_fixture[:2]]
|
||||
|
||||
def test_role_create_invalid(self, authenticated_client):
|
||||
data = {
|
||||
"data": {
|
||||
"type": "role",
|
||||
"attributes": {
|
||||
# Name is missing
|
||||
},
|
||||
}
|
||||
}
|
||||
response = authenticated_client.post(
|
||||
reverse("role-list"),
|
||||
data=json.dumps(data),
|
||||
content_type="application/vnd.api+json",
|
||||
)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
errors = response.json()["errors"]
|
||||
assert errors[0]["source"]["pointer"] == "/data/attributes/name"
|
||||
|
||||
def test_role_partial_update(self, authenticated_client, roles_fixture):
|
||||
role = roles_fixture[1]
|
||||
data = {
|
||||
"data": {
|
||||
"id": str(role.id),
|
||||
"type": "role",
|
||||
"attributes": {
|
||||
"name": "Updated Provider Group Name",
|
||||
},
|
||||
}
|
||||
}
|
||||
response = authenticated_client.patch(
|
||||
reverse("role-detail", kwargs={"pk": role.id}),
|
||||
data=json.dumps(data),
|
||||
content_type="application/vnd.api+json",
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
role.refresh_from_db()
|
||||
assert role.name == "Updated Provider Group Name"
|
||||
|
||||
def test_role_partial_update_invalid(self, authenticated_client, roles_fixture):
|
||||
role = roles_fixture[2]
|
||||
data = {
|
||||
"data": {
|
||||
"id": str(role.id),
|
||||
"type": "role",
|
||||
"attributes": {
|
||||
"name": "", # Invalid name
|
||||
},
|
||||
}
|
||||
}
|
||||
response = authenticated_client.patch(
|
||||
reverse("role-detail", kwargs={"pk": role.id}),
|
||||
data=json.dumps(data),
|
||||
content_type="application/vnd.api+json",
|
||||
)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
errors = response.json()["errors"]
|
||||
assert errors[0]["source"]["pointer"] == "/data/attributes/name"
|
||||
|
||||
def test_role_destroy(self, authenticated_client, roles_fixture):
|
||||
role = roles_fixture[2]
|
||||
response = authenticated_client.delete(
|
||||
reverse("role-detail", kwargs={"pk": role.id})
|
||||
)
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
assert not Role.objects.filter(id=role.id).exists()
|
||||
|
||||
def test_role_destroy_invalid(self, authenticated_client):
|
||||
response = authenticated_client.delete(
|
||||
reverse("role-detail", kwargs={"pk": "non-existent-id"})
|
||||
)
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
def test_role_retrieve_not_found(self, authenticated_client):
|
||||
response = authenticated_client.get(
|
||||
reverse("role-detail", kwargs={"pk": "non-existent-id"})
|
||||
)
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
def test_role_list_filters(self, authenticated_client, roles_fixture):
|
||||
role = roles_fixture[0]
|
||||
response = authenticated_client.get(
|
||||
reverse("role-list"), {"filter[name]": role.name}
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()["data"]
|
||||
assert len(data) == 1
|
||||
assert data[0]["attributes"]["name"] == role.name
|
||||
|
||||
def test_role_list_sorting(
|
||||
self, authenticated_client, set_user_admin_roles_fixture, roles_fixture
|
||||
):
|
||||
response = authenticated_client.get(reverse("role-list"), {"sort": "name"})
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()["data"]
|
||||
names = [
|
||||
item["attributes"]["name"]
|
||||
for item in data
|
||||
if item["attributes"]["name"] != "admin"
|
||||
]
|
||||
assert names == sorted(names, key=lambda v: v.lower())
|
||||
|
||||
def test_role_invalid_method(self, authenticated_client):
|
||||
response = authenticated_client.put(reverse("role-list"))
|
||||
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestUserRoleRelationshipViewSet:
|
||||
def test_create_relationship(
|
||||
self, authenticated_client, roles_fixture, create_test_user
|
||||
):
|
||||
data = {
|
||||
"data": [{"type": "role", "id": str(role.id)} for role in roles_fixture[:2]]
|
||||
}
|
||||
response = authenticated_client.post(
|
||||
reverse("user-roles-relationship", kwargs={"pk": create_test_user.id}),
|
||||
data=data,
|
||||
content_type="application/vnd.api+json",
|
||||
)
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
relationships = UserRoleRelationship.objects.filter(user=create_test_user.id)
|
||||
assert relationships.count() == 4
|
||||
for relationship in relationships[2:]: # Skip admin role
|
||||
assert relationship.role.id in [r.id for r in roles_fixture[:2]]
|
||||
|
||||
def test_create_relationship_already_exists(
|
||||
self, authenticated_client, roles_fixture, create_test_user
|
||||
):
|
||||
data = {
|
||||
"data": [{"type": "role", "id": str(role.id)} for role in roles_fixture[:2]]
|
||||
}
|
||||
authenticated_client.post(
|
||||
reverse("user-roles-relationship", kwargs={"pk": create_test_user.id}),
|
||||
data=data,
|
||||
content_type="application/vnd.api+json",
|
||||
)
|
||||
|
||||
data = {
|
||||
"data": [
|
||||
{"type": "role", "id": str(roles_fixture[0].id)},
|
||||
]
|
||||
}
|
||||
response = authenticated_client.post(
|
||||
reverse("user-roles-relationship", kwargs={"pk": create_test_user.id}),
|
||||
data=data,
|
||||
content_type="application/vnd.api+json",
|
||||
)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
errors = response.json()["errors"]["detail"]
|
||||
assert "already associated" in errors
|
||||
|
||||
def test_partial_update_relationship(
|
||||
self, authenticated_client, roles_fixture, create_test_user
|
||||
):
|
||||
data = {
|
||||
"data": [
|
||||
{"type": "role", "id": str(roles_fixture[1].id)},
|
||||
]
|
||||
}
|
||||
response = authenticated_client.patch(
|
||||
reverse("user-roles-relationship", kwargs={"pk": create_test_user.id}),
|
||||
data=data,
|
||||
content_type="application/vnd.api+json",
|
||||
)
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
relationships = UserRoleRelationship.objects.filter(user=create_test_user.id)
|
||||
assert relationships.count() == 1
|
||||
assert {rel.role.id for rel in relationships} == {roles_fixture[1].id}
|
||||
|
||||
data = {
|
||||
"data": [
|
||||
{"type": "role", "id": str(roles_fixture[1].id)},
|
||||
{"type": "role", "id": str(roles_fixture[2].id)},
|
||||
]
|
||||
}
|
||||
response = authenticated_client.patch(
|
||||
reverse("user-roles-relationship", kwargs={"pk": create_test_user.id}),
|
||||
data=data,
|
||||
content_type="application/vnd.api+json",
|
||||
)
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
relationships = UserRoleRelationship.objects.filter(user=create_test_user.id)
|
||||
assert relationships.count() == 2
|
||||
assert {rel.role.id for rel in relationships} == {
|
||||
roles_fixture[1].id,
|
||||
roles_fixture[2].id,
|
||||
}
|
||||
|
||||
def test_destroy_relationship(
|
||||
self, authenticated_client, roles_fixture, create_test_user
|
||||
):
|
||||
response = authenticated_client.delete(
|
||||
reverse("user-roles-relationship", kwargs={"pk": create_test_user.id}),
|
||||
)
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
relationships = UserRoleRelationship.objects.filter(role=roles_fixture[0].id)
|
||||
assert relationships.count() == 0
|
||||
|
||||
def test_invalid_provider_group_id(self, authenticated_client, create_test_user):
|
||||
invalid_id = "non-existent-id"
|
||||
data = {"data": [{"type": "provider-group", "id": invalid_id}]}
|
||||
response = authenticated_client.post(
|
||||
reverse("user-roles-relationship", kwargs={"pk": create_test_user.id}),
|
||||
data=data,
|
||||
content_type="application/vnd.api+json",
|
||||
)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
errors = response.json()["errors"][0]["detail"]
|
||||
assert "valid UUID" in errors
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestRoleProviderGroupRelationshipViewSet:
|
||||
def test_create_relationship(
|
||||
self, authenticated_client, roles_fixture, provider_groups_fixture
|
||||
):
|
||||
data = {
|
||||
"data": [
|
||||
{"type": "provider-group", "id": str(provider_group.id)}
|
||||
for provider_group in provider_groups_fixture[:2]
|
||||
]
|
||||
}
|
||||
response = authenticated_client.post(
|
||||
reverse(
|
||||
"role-provider-groups-relationship", kwargs={"pk": roles_fixture[0].id}
|
||||
),
|
||||
data=data,
|
||||
content_type="application/vnd.api+json",
|
||||
)
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
relationships = RoleProviderGroupRelationship.objects.filter(
|
||||
role=roles_fixture[0].id
|
||||
)
|
||||
assert relationships.count() == 2
|
||||
for relationship in relationships:
|
||||
assert relationship.provider_group.id in [
|
||||
pg.id for pg in provider_groups_fixture[:2]
|
||||
]
|
||||
|
||||
def test_create_relationship_already_exists(
|
||||
self, authenticated_client, roles_fixture, provider_groups_fixture
|
||||
):
|
||||
data = {
|
||||
"data": [
|
||||
{"type": "provider-group", "id": str(provider_group.id)}
|
||||
for provider_group in provider_groups_fixture[:2]
|
||||
]
|
||||
}
|
||||
authenticated_client.post(
|
||||
reverse(
|
||||
"role-provider-groups-relationship", kwargs={"pk": roles_fixture[0].id}
|
||||
),
|
||||
data=data,
|
||||
content_type="application/vnd.api+json",
|
||||
)
|
||||
|
||||
data = {
|
||||
"data": [
|
||||
{"type": "provider-group", "id": str(provider_groups_fixture[0].id)},
|
||||
]
|
||||
}
|
||||
response = authenticated_client.post(
|
||||
reverse(
|
||||
"role-provider-groups-relationship", kwargs={"pk": roles_fixture[0].id}
|
||||
),
|
||||
data=data,
|
||||
content_type="application/vnd.api+json",
|
||||
)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
errors = response.json()["errors"]["detail"]
|
||||
assert "already associated" in errors
|
||||
|
||||
def test_partial_update_relationship(
|
||||
self, authenticated_client, roles_fixture, provider_groups_fixture
|
||||
):
|
||||
data = {
|
||||
"data": [
|
||||
{"type": "provider-group", "id": str(provider_groups_fixture[1].id)},
|
||||
]
|
||||
}
|
||||
response = authenticated_client.patch(
|
||||
reverse(
|
||||
"role-provider-groups-relationship", kwargs={"pk": roles_fixture[2].id}
|
||||
),
|
||||
data=data,
|
||||
content_type="application/vnd.api+json",
|
||||
)
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
relationships = RoleProviderGroupRelationship.objects.filter(
|
||||
role=roles_fixture[2].id
|
||||
)
|
||||
assert relationships.count() == 1
|
||||
assert {rel.provider_group.id for rel in relationships} == {
|
||||
provider_groups_fixture[1].id
|
||||
}
|
||||
|
||||
data = {
|
||||
"data": [
|
||||
{"type": "provider-group", "id": str(provider_groups_fixture[1].id)},
|
||||
{"type": "provider-group", "id": str(provider_groups_fixture[2].id)},
|
||||
]
|
||||
}
|
||||
response = authenticated_client.patch(
|
||||
reverse(
|
||||
"role-provider-groups-relationship", kwargs={"pk": roles_fixture[2].id}
|
||||
),
|
||||
data=data,
|
||||
content_type="application/vnd.api+json",
|
||||
)
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
relationships = RoleProviderGroupRelationship.objects.filter(
|
||||
role=roles_fixture[2].id
|
||||
)
|
||||
assert relationships.count() == 2
|
||||
assert {rel.provider_group.id for rel in relationships} == {
|
||||
provider_groups_fixture[1].id,
|
||||
provider_groups_fixture[2].id,
|
||||
}
|
||||
|
||||
def test_destroy_relationship(
|
||||
self, authenticated_client, roles_fixture, provider_groups_fixture
|
||||
):
|
||||
response = authenticated_client.delete(
|
||||
reverse(
|
||||
"role-provider-groups-relationship", kwargs={"pk": roles_fixture[0].id}
|
||||
),
|
||||
)
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
relationships = RoleProviderGroupRelationship.objects.filter(
|
||||
role=roles_fixture[0].id
|
||||
)
|
||||
assert relationships.count() == 0
|
||||
|
||||
def test_invalid_provider_group_id(self, authenticated_client, roles_fixture):
|
||||
invalid_id = "non-existent-id"
|
||||
data = {"data": [{"type": "provider-group", "id": invalid_id}]}
|
||||
response = authenticated_client.post(
|
||||
reverse(
|
||||
"role-provider-groups-relationship", kwargs={"pk": roles_fixture[1].id}
|
||||
),
|
||||
data=data,
|
||||
content_type="application/vnd.api+json",
|
||||
)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
errors = response.json()["errors"][0]["detail"]
|
||||
assert "valid UUID" in errors
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestProviderGroupMembershipViewSet:
|
||||
def test_create_relationship(
|
||||
self, authenticated_client, providers_fixture, provider_groups_fixture
|
||||
):
|
||||
provider_group, *_ = provider_groups_fixture
|
||||
data = {
|
||||
"data": [
|
||||
{"type": "provider", "id": str(provider.id)}
|
||||
for provider in providers_fixture[:2]
|
||||
]
|
||||
}
|
||||
response = authenticated_client.post(
|
||||
reverse(
|
||||
"provider_group-providers-relationship",
|
||||
kwargs={"pk": provider_group.id},
|
||||
),
|
||||
data=data,
|
||||
content_type="application/vnd.api+json",
|
||||
)
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
relationships = ProviderGroupMembership.objects.filter(
|
||||
provider_group=provider_group.id
|
||||
)
|
||||
assert relationships.count() == 2
|
||||
for relationship in relationships:
|
||||
assert relationship.provider.id in [p.id for p in providers_fixture[:2]]
|
||||
|
||||
def test_create_relationship_already_exists(
|
||||
self, authenticated_client, providers_fixture, provider_groups_fixture
|
||||
):
|
||||
provider_group, *_ = provider_groups_fixture
|
||||
data = {
|
||||
"data": [
|
||||
{"type": "provider", "id": str(provider.id)}
|
||||
for provider in providers_fixture[:2]
|
||||
]
|
||||
}
|
||||
authenticated_client.post(
|
||||
reverse(
|
||||
"provider_group-providers-relationship",
|
||||
kwargs={"pk": provider_group.id},
|
||||
),
|
||||
data=data,
|
||||
content_type="application/vnd.api+json",
|
||||
)
|
||||
|
||||
data = {
|
||||
"data": [
|
||||
{"type": "provider", "id": str(providers_fixture[0].id)},
|
||||
]
|
||||
}
|
||||
response = authenticated_client.post(
|
||||
reverse(
|
||||
"provider_group-providers-relationship",
|
||||
kwargs={"pk": provider_group.id},
|
||||
),
|
||||
data=data,
|
||||
content_type="application/vnd.api+json",
|
||||
)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
errors = response.json()["errors"]["detail"]
|
||||
assert "already associated" in errors
|
||||
|
||||
def test_partial_update_relationship(
|
||||
self, authenticated_client, providers_fixture, provider_groups_fixture
|
||||
):
|
||||
provider_group, *_ = provider_groups_fixture
|
||||
data = {
|
||||
"data": [
|
||||
{"type": "provider", "id": str(providers_fixture[1].id)},
|
||||
]
|
||||
}
|
||||
response = authenticated_client.patch(
|
||||
reverse(
|
||||
"provider_group-providers-relationship",
|
||||
kwargs={"pk": provider_group.id},
|
||||
),
|
||||
data=data,
|
||||
content_type="application/vnd.api+json",
|
||||
)
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
relationships = ProviderGroupMembership.objects.filter(
|
||||
provider_group=provider_group.id
|
||||
)
|
||||
assert relationships.count() == 1
|
||||
assert {rel.provider.id for rel in relationships} == {providers_fixture[1].id}
|
||||
|
||||
data = {
|
||||
"data": [
|
||||
{"type": "provider", "id": str(providers_fixture[1].id)},
|
||||
{"type": "provider", "id": str(providers_fixture[2].id)},
|
||||
]
|
||||
}
|
||||
response = authenticated_client.patch(
|
||||
reverse(
|
||||
"provider_group-providers-relationship",
|
||||
kwargs={"pk": provider_group.id},
|
||||
),
|
||||
data=data,
|
||||
content_type="application/vnd.api+json",
|
||||
)
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
relationships = ProviderGroupMembership.objects.filter(
|
||||
provider_group=provider_group.id
|
||||
)
|
||||
assert relationships.count() == 2
|
||||
assert {rel.provider.id for rel in relationships} == {
|
||||
providers_fixture[1].id,
|
||||
providers_fixture[2].id,
|
||||
}
|
||||
|
||||
def test_destroy_relationship(
|
||||
self, authenticated_client, providers_fixture, provider_groups_fixture
|
||||
):
|
||||
provider_group, *_ = provider_groups_fixture
|
||||
data = {
|
||||
"data": [
|
||||
{"type": "provider", "id": str(provider.id)}
|
||||
for provider in providers_fixture[:2]
|
||||
]
|
||||
}
|
||||
response = authenticated_client.post(
|
||||
reverse(
|
||||
"provider_group-providers-relationship",
|
||||
kwargs={"pk": provider_group.id},
|
||||
),
|
||||
data=data,
|
||||
content_type="application/vnd.api+json",
|
||||
)
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
response = authenticated_client.delete(
|
||||
reverse(
|
||||
"provider_group-providers-relationship",
|
||||
kwargs={"pk": provider_group.id},
|
||||
),
|
||||
)
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
relationships = ProviderGroupMembership.objects.filter(
|
||||
provider_group=providers_fixture[0].id
|
||||
)
|
||||
assert relationships.count() == 0
|
||||
|
||||
def test_invalid_provider_group_id(
|
||||
self, authenticated_client, provider_groups_fixture
|
||||
):
|
||||
provider_group, *_ = provider_groups_fixture
|
||||
invalid_id = "non-existent-id"
|
||||
data = {"data": [{"type": "provider-group", "id": invalid_id}]}
|
||||
response = authenticated_client.post(
|
||||
reverse(
|
||||
"provider_group-providers-relationship",
|
||||
kwargs={"pk": provider_group.id},
|
||||
),
|
||||
data=data,
|
||||
content_type="application/vnd.api+json",
|
||||
)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
errors = response.json()["errors"][0]["detail"]
|
||||
assert "valid UUID" in errors
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestComplianceOverviewViewSet:
|
||||
def test_compliance_overview_list_none(self, authenticated_client):
|
||||
|
||||
@@ -14,16 +14,20 @@ from rest_framework_simplejwt.serializers import TokenObtainPairSerializer
|
||||
from rest_framework_simplejwt.tokens import RefreshToken
|
||||
|
||||
from api.models import (
|
||||
ComplianceOverview,
|
||||
Finding,
|
||||
Invitation,
|
||||
Membership,
|
||||
Provider,
|
||||
ProviderGroup,
|
||||
ProviderGroupMembership,
|
||||
ProviderSecret,
|
||||
Resource,
|
||||
ResourceTag,
|
||||
Finding,
|
||||
ProviderSecret,
|
||||
Invitation,
|
||||
InvitationRoleRelationship,
|
||||
Role,
|
||||
RoleProviderGroupRelationship,
|
||||
UserRoleRelationship,
|
||||
ComplianceOverview,
|
||||
Scan,
|
||||
StateChoices,
|
||||
Task,
|
||||
@@ -176,10 +180,26 @@ class UserSerializer(BaseSerializerV1):
|
||||
"""
|
||||
|
||||
memberships = serializers.ResourceRelatedField(many=True, read_only=True)
|
||||
roles = serializers.ResourceRelatedField(many=True, read_only=True)
|
||||
|
||||
class Meta:
|
||||
model = User
|
||||
fields = ["id", "name", "email", "company_name", "date_joined", "memberships"]
|
||||
fields = [
|
||||
"id",
|
||||
"name",
|
||||
"email",
|
||||
"company_name",
|
||||
"date_joined",
|
||||
"memberships",
|
||||
"roles",
|
||||
]
|
||||
extra_kwargs = {
|
||||
"roles": {"read_only": True},
|
||||
}
|
||||
|
||||
included_serializers = {
|
||||
"roles": "api.v1.serializers.RoleSerializer",
|
||||
}
|
||||
|
||||
|
||||
class UserCreateSerializer(BaseWriteSerializer):
|
||||
@@ -235,6 +255,73 @@ class UserUpdateSerializer(BaseWriteSerializer):
|
||||
return super().update(instance, validated_data)
|
||||
|
||||
|
||||
class RoleResourceIdentifierSerializer(serializers.Serializer):
|
||||
resource_type = serializers.CharField(source="type")
|
||||
id = serializers.UUIDField()
|
||||
|
||||
class JSONAPIMeta:
|
||||
resource_name = "role-identifier"
|
||||
|
||||
def to_representation(self, instance):
|
||||
"""
|
||||
Ensure 'type' is used in the output instead of 'resource_type'.
|
||||
"""
|
||||
representation = super().to_representation(instance)
|
||||
representation["type"] = representation.pop("resource_type", None)
|
||||
return representation
|
||||
|
||||
def to_internal_value(self, data):
|
||||
"""
|
||||
Map 'type' back to 'resource_type' during input.
|
||||
"""
|
||||
data["resource_type"] = data.pop("type", None)
|
||||
return super().to_internal_value(data)
|
||||
|
||||
|
||||
class UserRoleRelationshipSerializer(RLSSerializer, BaseWriteSerializer):
|
||||
"""
|
||||
Serializer for modifying user memberships
|
||||
"""
|
||||
|
||||
roles = serializers.ListField(
|
||||
child=RoleResourceIdentifierSerializer(),
|
||||
help_text="List of resource identifier objects representing roles.",
|
||||
)
|
||||
|
||||
def create(self, validated_data):
|
||||
role_ids = [item["id"] for item in validated_data["roles"]]
|
||||
roles = Role.objects.filter(id__in=role_ids)
|
||||
tenant_id = self.context.get("tenant_id")
|
||||
|
||||
new_relationships = [
|
||||
UserRoleRelationship(
|
||||
user=self.context.get("user"), role=r, tenant_id=tenant_id
|
||||
)
|
||||
for r in roles
|
||||
]
|
||||
UserRoleRelationship.objects.bulk_create(new_relationships)
|
||||
|
||||
return self.context.get("user")
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
role_ids = [item["id"] for item in validated_data["roles"]]
|
||||
roles = Role.objects.filter(id__in=role_ids)
|
||||
tenant_id = self.context.get("tenant_id")
|
||||
|
||||
instance.roles.clear()
|
||||
new_relationships = [
|
||||
UserRoleRelationship(user=instance, role=r, tenant_id=tenant_id)
|
||||
for r in roles
|
||||
]
|
||||
UserRoleRelationship.objects.bulk_create(new_relationships)
|
||||
|
||||
return instance
|
||||
|
||||
class Meta:
|
||||
model = UserRoleRelationship
|
||||
fields = ["id", "roles"]
|
||||
|
||||
|
||||
# Tasks
|
||||
class TaskBase(serializers.ModelSerializer):
|
||||
state_mapping = {
|
||||
@@ -361,31 +448,30 @@ class ProviderGroupSerializer(RLSSerializer, BaseWriteSerializer):
|
||||
providers = serializers.ResourceRelatedField(many=True, read_only=True)
|
||||
|
||||
def validate(self, attrs):
|
||||
tenant = self.context["tenant_id"]
|
||||
name = attrs.get("name", self.instance.name if self.instance else None)
|
||||
|
||||
# Exclude the current instance when checking for uniqueness during updates
|
||||
queryset = ProviderGroup.objects.filter(tenant=tenant, name=name)
|
||||
if self.instance:
|
||||
queryset = queryset.exclude(pk=self.instance.pk)
|
||||
|
||||
if queryset.exists():
|
||||
if ProviderGroup.objects.filter(name=attrs.get("name")).exists():
|
||||
raise serializers.ValidationError(
|
||||
{
|
||||
"name": "A provider group with this name already exists for this tenant."
|
||||
}
|
||||
{"name": "A provider group with this name already exists."}
|
||||
)
|
||||
|
||||
return super().validate(attrs)
|
||||
|
||||
class Meta:
|
||||
model = ProviderGroup
|
||||
fields = ["id", "name", "inserted_at", "updated_at", "providers", "url"]
|
||||
read_only_fields = ["id", "inserted_at", "updated_at"]
|
||||
fields = [
|
||||
"id",
|
||||
"name",
|
||||
"inserted_at",
|
||||
"updated_at",
|
||||
"providers",
|
||||
"roles",
|
||||
"url",
|
||||
]
|
||||
extra_kwargs = {
|
||||
"id": {"read_only": True},
|
||||
"inserted_at": {"read_only": True},
|
||||
"updated_at": {"read_only": True},
|
||||
"roles": {"read_only": True},
|
||||
"url": {"read_only": True},
|
||||
}
|
||||
|
||||
|
||||
@@ -406,41 +492,75 @@ class ProviderGroupUpdateSerializer(RLSSerializer, BaseWriteSerializer):
|
||||
fields = ["id", "name"]
|
||||
|
||||
|
||||
class ProviderGroupMembershipUpdateSerializer(RLSSerializer, BaseWriteSerializer):
|
||||
class ProviderResourceIdentifierSerializer(serializers.Serializer):
|
||||
resource_type = serializers.CharField(source="type")
|
||||
id = serializers.UUIDField()
|
||||
|
||||
class JSONAPIMeta:
|
||||
resource_name = "provider-identifier"
|
||||
|
||||
def to_representation(self, instance):
|
||||
"""
|
||||
Ensure 'type' is used in the output instead of 'resource_type'.
|
||||
"""
|
||||
representation = super().to_representation(instance)
|
||||
representation["type"] = representation.pop("resource_type", None)
|
||||
return representation
|
||||
|
||||
def to_internal_value(self, data):
|
||||
"""
|
||||
Map 'type' back to 'resource_type' during input.
|
||||
"""
|
||||
data["resource_type"] = data.pop("type", None)
|
||||
return super().to_internal_value(data)
|
||||
|
||||
|
||||
class ProviderGroupMembershipSerializer(RLSSerializer, BaseWriteSerializer):
|
||||
"""
|
||||
Serializer for modifying provider group memberships
|
||||
Serializer for modifying provider_group memberships
|
||||
"""
|
||||
|
||||
provider_ids = serializers.ListField(
|
||||
child=serializers.UUIDField(),
|
||||
help_text="List of provider UUIDs to add to the group",
|
||||
providers = serializers.ListField(
|
||||
child=ProviderResourceIdentifierSerializer(),
|
||||
help_text="List of resource identifier objects representing providers.",
|
||||
)
|
||||
|
||||
def validate(self, attrs):
|
||||
tenant_id = self.context["tenant_id"]
|
||||
provider_ids = attrs.get("provider_ids", [])
|
||||
def create(self, validated_data):
|
||||
provider_ids = [item["id"] for item in validated_data["providers"]]
|
||||
providers = Provider.objects.filter(id__in=provider_ids)
|
||||
tenant_id = self.context.get("tenant_id")
|
||||
|
||||
existing_provider_ids = set(
|
||||
Provider.objects.filter(
|
||||
id__in=provider_ids, tenant_id=tenant_id
|
||||
).values_list("id", flat=True)
|
||||
)
|
||||
provided_provider_ids = set(provider_ids)
|
||||
|
||||
missing_provider_ids = provided_provider_ids - existing_provider_ids
|
||||
|
||||
if missing_provider_ids:
|
||||
raise serializers.ValidationError(
|
||||
{
|
||||
"provider_ids": f"The following provider IDs do not exist: {', '.join(str(id) for id in missing_provider_ids)}"
|
||||
}
|
||||
new_relationships = [
|
||||
ProviderGroupMembership(
|
||||
provider_group=self.context.get("provider_group"),
|
||||
provider=p,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
for p in providers
|
||||
]
|
||||
ProviderGroupMembership.objects.bulk_create(new_relationships)
|
||||
|
||||
return super().validate(attrs)
|
||||
return self.context.get("provider_group")
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
provider_ids = [item["id"] for item in validated_data["providers"]]
|
||||
providers = Provider.objects.filter(id__in=provider_ids)
|
||||
tenant_id = self.context.get("tenant_id")
|
||||
|
||||
instance.providers.clear()
|
||||
new_relationships = [
|
||||
ProviderGroupMembership(
|
||||
provider_group=instance, provider=p, tenant_id=tenant_id
|
||||
)
|
||||
for p in providers
|
||||
]
|
||||
ProviderGroupMembership.objects.bulk_create(new_relationships)
|
||||
|
||||
return instance
|
||||
|
||||
class Meta:
|
||||
model = ProviderGroupMembership
|
||||
fields = ["id", "provider_ids"]
|
||||
fields = ["id", "providers"]
|
||||
|
||||
|
||||
# Providers
|
||||
@@ -1034,6 +1154,8 @@ class InvitationSerializer(RLSSerializer):
|
||||
Serializer for the Invitation model.
|
||||
"""
|
||||
|
||||
roles = serializers.ResourceRelatedField(many=True, queryset=Role.objects.all())
|
||||
|
||||
class Meta:
|
||||
model = Invitation
|
||||
fields = [
|
||||
@@ -1043,6 +1165,7 @@ class InvitationSerializer(RLSSerializer):
|
||||
"email",
|
||||
"state",
|
||||
"token",
|
||||
"roles",
|
||||
"expires_at",
|
||||
"inviter",
|
||||
"url",
|
||||
@@ -1050,6 +1173,8 @@ class InvitationSerializer(RLSSerializer):
|
||||
|
||||
|
||||
class InvitationBaseWriteSerializer(BaseWriteSerializer):
|
||||
roles = serializers.ResourceRelatedField(many=True, queryset=Role.objects.all())
|
||||
|
||||
def validate_email(self, value):
|
||||
user = User.objects.filter(email=value).first()
|
||||
tenant_id = self.context["tenant_id"]
|
||||
@@ -1086,31 +1211,54 @@ class InvitationCreateSerializer(InvitationBaseWriteSerializer, RLSSerializer):
|
||||
|
||||
class Meta:
|
||||
model = Invitation
|
||||
fields = ["email", "expires_at", "state", "token", "inviter"]
|
||||
fields = ["email", "expires_at", "state", "token", "inviter", "roles"]
|
||||
extra_kwargs = {
|
||||
"token": {"read_only": True},
|
||||
"state": {"read_only": True},
|
||||
"inviter": {"read_only": True},
|
||||
"expires_at": {"required": False},
|
||||
"roles": {"required": False},
|
||||
}
|
||||
|
||||
def create(self, validated_data):
|
||||
inviter = self.context.get("request").user
|
||||
tenant_id = self.context.get("tenant_id")
|
||||
validated_data["inviter"] = inviter
|
||||
return super().create(validated_data)
|
||||
roles = validated_data.pop("roles", [])
|
||||
invitation = super().create(validated_data)
|
||||
for role in roles:
|
||||
InvitationRoleRelationship.objects.create(
|
||||
role=role, invitation=invitation, tenant_id=tenant_id
|
||||
)
|
||||
|
||||
return invitation
|
||||
|
||||
|
||||
class InvitationUpdateSerializer(InvitationBaseWriteSerializer):
|
||||
class Meta:
|
||||
model = Invitation
|
||||
fields = ["id", "email", "expires_at", "state", "token"]
|
||||
fields = ["id", "email", "expires_at", "state", "token", "roles"]
|
||||
extra_kwargs = {
|
||||
"token": {"read_only": True},
|
||||
"state": {"read_only": True},
|
||||
"expires_at": {"required": False},
|
||||
"email": {"required": False},
|
||||
"roles": {"required": False},
|
||||
}
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
roles = validated_data.pop("roles", [])
|
||||
tenant_id = self.context.get("tenant_id")
|
||||
invitation = super().update(instance, validated_data)
|
||||
if roles:
|
||||
instance.roles.clear()
|
||||
for role in roles:
|
||||
InvitationRoleRelationship.objects.create(
|
||||
role=role, invitation=invitation, tenant_id=tenant_id
|
||||
)
|
||||
|
||||
return invitation
|
||||
|
||||
|
||||
class InvitationAcceptSerializer(RLSSerializer):
|
||||
"""Serializer for accepting an invitation."""
|
||||
@@ -1122,6 +1270,179 @@ class InvitationAcceptSerializer(RLSSerializer):
|
||||
fields = ["invitation_token"]
|
||||
|
||||
|
||||
# Roles
|
||||
|
||||
|
||||
class RoleSerializer(RLSSerializer, BaseWriteSerializer):
|
||||
provider_groups = serializers.ResourceRelatedField(
|
||||
many=True, queryset=ProviderGroup.objects.all()
|
||||
)
|
||||
permission_state = serializers.SerializerMethodField()
|
||||
|
||||
def get_permission_state(self, obj):
|
||||
return obj.permission_state
|
||||
|
||||
def validate(self, attrs):
|
||||
if Role.objects.filter(name=attrs.get("name")).exists():
|
||||
raise serializers.ValidationError(
|
||||
{"name": "A role with this name already exists."}
|
||||
)
|
||||
|
||||
if attrs.get("manage_providers"):
|
||||
attrs["unlimited_visibility"] = True
|
||||
|
||||
# Prevent updates to the admin role
|
||||
if getattr(self.instance, "name", None) == "admin":
|
||||
raise serializers.ValidationError(
|
||||
{"name": "The admin role cannot be updated."}
|
||||
)
|
||||
|
||||
return super().validate(attrs)
|
||||
|
||||
class Meta:
|
||||
model = Role
|
||||
fields = [
|
||||
"id",
|
||||
"name",
|
||||
"manage_users",
|
||||
"manage_account",
|
||||
"manage_billing",
|
||||
"manage_providers",
|
||||
"manage_integrations",
|
||||
"manage_scans",
|
||||
"permission_state",
|
||||
"unlimited_visibility",
|
||||
"inserted_at",
|
||||
"updated_at",
|
||||
"provider_groups",
|
||||
"users",
|
||||
"invitations",
|
||||
"url",
|
||||
]
|
||||
extra_kwargs = {
|
||||
"id": {"read_only": True},
|
||||
"inserted_at": {"read_only": True},
|
||||
"updated_at": {"read_only": True},
|
||||
"users": {"read_only": True},
|
||||
"url": {"read_only": True},
|
||||
}
|
||||
|
||||
|
||||
class RoleCreateSerializer(RoleSerializer):
|
||||
def create(self, validated_data):
|
||||
provider_groups = validated_data.pop("provider_groups", [])
|
||||
users = validated_data.pop("users", [])
|
||||
tenant_id = self.context.get("tenant_id")
|
||||
role = Role.objects.create(tenant_id=tenant_id, **validated_data)
|
||||
|
||||
through_model_instances = [
|
||||
RoleProviderGroupRelationship(
|
||||
role=role,
|
||||
provider_group=provider_group,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
for provider_group in provider_groups
|
||||
]
|
||||
RoleProviderGroupRelationship.objects.bulk_create(through_model_instances)
|
||||
|
||||
through_model_instances = [
|
||||
UserRoleRelationship(
|
||||
role=user,
|
||||
user=user,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
for user in users
|
||||
]
|
||||
UserRoleRelationship.objects.bulk_create(through_model_instances)
|
||||
|
||||
return role
|
||||
|
||||
|
||||
class RoleUpdateSerializer(RLSSerializer, BaseWriteSerializer):
|
||||
class Meta:
|
||||
model = Role
|
||||
fields = [
|
||||
"id",
|
||||
"name",
|
||||
"manage_users",
|
||||
"manage_account",
|
||||
"manage_billing",
|
||||
"manage_providers",
|
||||
"manage_integrations",
|
||||
"manage_scans",
|
||||
"unlimited_visibility",
|
||||
]
|
||||
|
||||
|
||||
class ProviderGroupResourceIdentifierSerializer(serializers.Serializer):
|
||||
resource_type = serializers.CharField(source="type")
|
||||
id = serializers.UUIDField()
|
||||
|
||||
class JSONAPIMeta:
|
||||
resource_name = "provider-group-identifier"
|
||||
|
||||
def to_representation(self, instance):
|
||||
"""
|
||||
Ensure 'type' is used in the output instead of 'resource_type'.
|
||||
"""
|
||||
representation = super().to_representation(instance)
|
||||
representation["type"] = representation.pop("resource_type", None)
|
||||
return representation
|
||||
|
||||
def to_internal_value(self, data):
|
||||
"""
|
||||
Map 'type' back to 'resource_type' during input.
|
||||
"""
|
||||
data["resource_type"] = data.pop("type", None)
|
||||
return super().to_internal_value(data)
|
||||
|
||||
|
||||
class RoleProviderGroupRelationshipSerializer(RLSSerializer, BaseWriteSerializer):
|
||||
"""
|
||||
Serializer for modifying role memberships
|
||||
"""
|
||||
|
||||
provider_groups = serializers.ListField(
|
||||
child=ProviderGroupResourceIdentifierSerializer(),
|
||||
help_text="List of resource identifier objects representing provider groups.",
|
||||
)
|
||||
|
||||
def create(self, validated_data):
|
||||
provider_group_ids = [item["id"] for item in validated_data["provider_groups"]]
|
||||
provider_groups = ProviderGroup.objects.filter(id__in=provider_group_ids)
|
||||
tenant_id = self.context.get("tenant_id")
|
||||
|
||||
new_relationships = [
|
||||
RoleProviderGroupRelationship(
|
||||
role=self.context.get("role"), provider_group=pg, tenant_id=tenant_id
|
||||
)
|
||||
for pg in provider_groups
|
||||
]
|
||||
RoleProviderGroupRelationship.objects.bulk_create(new_relationships)
|
||||
|
||||
return self.context.get("role")
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
provider_group_ids = [item["id"] for item in validated_data["provider_groups"]]
|
||||
provider_groups = ProviderGroup.objects.filter(id__in=provider_group_ids)
|
||||
tenant_id = self.context.get("tenant_id")
|
||||
|
||||
instance.provider_groups.clear()
|
||||
new_relationships = [
|
||||
RoleProviderGroupRelationship(
|
||||
role=instance, provider_group=pg, tenant_id=tenant_id
|
||||
)
|
||||
for pg in provider_groups
|
||||
]
|
||||
RoleProviderGroupRelationship.objects.bulk_create(new_relationships)
|
||||
|
||||
return instance
|
||||
|
||||
class Meta:
|
||||
model = RoleProviderGroupRelationship
|
||||
fields = ["id", "provider_groups"]
|
||||
|
||||
|
||||
# Compliance overview
|
||||
|
||||
|
||||
|
||||
@@ -3,16 +3,20 @@ from drf_spectacular.views import SpectacularRedocView
|
||||
from rest_framework_nested import routers
|
||||
|
||||
from api.v1.views import (
|
||||
ComplianceOverviewViewSet,
|
||||
CustomTokenObtainView,
|
||||
CustomTokenRefreshView,
|
||||
FindingViewSet,
|
||||
InvitationAcceptViewSet,
|
||||
InvitationViewSet,
|
||||
MembershipViewSet,
|
||||
OverviewViewSet,
|
||||
ProviderGroupViewSet,
|
||||
ProviderGroupProvidersRelationshipView,
|
||||
ProviderSecretViewSet,
|
||||
InvitationViewSet,
|
||||
InvitationAcceptViewSet,
|
||||
RoleViewSet,
|
||||
RoleProviderGroupRelationshipView,
|
||||
UserRoleRelationshipView,
|
||||
OverviewViewSet,
|
||||
ComplianceOverviewViewSet,
|
||||
ProviderViewSet,
|
||||
ResourceViewSet,
|
||||
ScanViewSet,
|
||||
@@ -29,11 +33,12 @@ router = routers.DefaultRouter(trailing_slash=False)
|
||||
router.register(r"users", UserViewSet, basename="user")
|
||||
router.register(r"tenants", TenantViewSet, basename="tenant")
|
||||
router.register(r"providers", ProviderViewSet, basename="provider")
|
||||
router.register(r"provider_groups", ProviderGroupViewSet, basename="providergroup")
|
||||
router.register(r"provider-groups", ProviderGroupViewSet, basename="providergroup")
|
||||
router.register(r"scans", ScanViewSet, basename="scan")
|
||||
router.register(r"tasks", TaskViewSet, basename="task")
|
||||
router.register(r"resources", ResourceViewSet, basename="resource")
|
||||
router.register(r"findings", FindingViewSet, basename="finding")
|
||||
router.register(r"roles", RoleViewSet, basename="role")
|
||||
router.register(
|
||||
r"compliance-overviews", ComplianceOverviewViewSet, basename="complianceoverview"
|
||||
)
|
||||
@@ -80,6 +85,27 @@ urlpatterns = [
|
||||
InvitationAcceptViewSet.as_view({"post": "accept"}),
|
||||
name="invitation-accept",
|
||||
),
|
||||
path(
|
||||
"roles/<uuid:pk>/relationships/provider_groups",
|
||||
RoleProviderGroupRelationshipView.as_view(
|
||||
{"post": "create", "patch": "partial_update", "delete": "destroy"}
|
||||
),
|
||||
name="role-provider-groups-relationship",
|
||||
),
|
||||
path(
|
||||
"users/<uuid:pk>/relationships/roles",
|
||||
UserRoleRelationshipView.as_view(
|
||||
{"post": "create", "patch": "partial_update", "delete": "destroy"}
|
||||
),
|
||||
name="user-roles-relationship",
|
||||
),
|
||||
path(
|
||||
"provider-groups/<uuid:pk>/relationships/providers",
|
||||
ProviderGroupProvidersRelationshipView.as_view(
|
||||
{"post": "create", "patch": "partial_update", "delete": "destroy"}
|
||||
),
|
||||
name="provider_group-providers-relationship",
|
||||
),
|
||||
path("", include(router.urls)),
|
||||
path("", include(tenants_router.urls)),
|
||||
path("", include(users_router.urls)),
|
||||
|
||||
+575
-74
@@ -8,6 +8,7 @@ from django.urls import reverse
|
||||
from django.utils.decorators import method_decorator
|
||||
from django.views.decorators.cache import cache_control
|
||||
from drf_spectacular.settings import spectacular_settings
|
||||
from drf_spectacular_jsonapi.schemas.openapi import JsonApiAutoSchema
|
||||
from drf_spectacular.utils import (
|
||||
OpenApiParameter,
|
||||
OpenApiResponse,
|
||||
@@ -25,8 +26,10 @@ from rest_framework.exceptions import (
|
||||
ValidationError,
|
||||
)
|
||||
from rest_framework.generics import GenericAPIView, get_object_or_404
|
||||
from rest_framework_json_api.views import Response
|
||||
from rest_framework_json_api.views import RelationshipView, Response
|
||||
from rest_framework_simplejwt.exceptions import InvalidToken, TokenError
|
||||
from rest_framework.permissions import SAFE_METHODS
|
||||
|
||||
from tasks.beat import schedule_provider_scan
|
||||
from tasks.tasks import (
|
||||
check_provider_connection_task,
|
||||
@@ -52,8 +55,12 @@ from api.filters import (
|
||||
TaskFilter,
|
||||
TenantFilter,
|
||||
UserFilter,
|
||||
RoleFilter,
|
||||
)
|
||||
from api.models import (
|
||||
StatusChoices,
|
||||
User,
|
||||
UserRoleRelationship,
|
||||
ComplianceOverview,
|
||||
Finding,
|
||||
Invitation,
|
||||
@@ -62,20 +69,27 @@ from api.models import (
|
||||
ProviderGroup,
|
||||
ProviderGroupMembership,
|
||||
ProviderSecret,
|
||||
Role,
|
||||
RoleProviderGroupRelationship,
|
||||
Resource,
|
||||
Scan,
|
||||
ScanSummary,
|
||||
SeverityChoices,
|
||||
StateChoices,
|
||||
StatusChoices,
|
||||
Task,
|
||||
User,
|
||||
)
|
||||
from api.pagination import ComplianceOverviewPagination
|
||||
from api.rbac.permissions import HasPermissions, Permissions
|
||||
from api.rls import Tenant
|
||||
from api.utils import validate_invitation
|
||||
from api.uuid_utils import datetime_to_uuid7
|
||||
from api.v1.serializers import (
|
||||
TokenSerializer,
|
||||
TokenRefreshSerializer,
|
||||
UserSerializer,
|
||||
UserCreateSerializer,
|
||||
UserUpdateSerializer,
|
||||
UserRoleRelationshipSerializer,
|
||||
ComplianceOverviewFullSerializer,
|
||||
ComplianceOverviewSerializer,
|
||||
FindingDynamicFilterSerializer,
|
||||
@@ -89,34 +103,39 @@ from api.v1.serializers import (
|
||||
OverviewProviderSerializer,
|
||||
OverviewSeveritySerializer,
|
||||
ProviderCreateSerializer,
|
||||
ProviderGroupMembershipUpdateSerializer,
|
||||
ProviderGroupMembershipSerializer,
|
||||
ProviderGroupSerializer,
|
||||
ProviderGroupUpdateSerializer,
|
||||
ProviderSecretCreateSerializer,
|
||||
ProviderSecretSerializer,
|
||||
ProviderSecretUpdateSerializer,
|
||||
RoleProviderGroupRelationshipSerializer,
|
||||
ProviderSerializer,
|
||||
ProviderUpdateSerializer,
|
||||
ResourceSerializer,
|
||||
ScanCreateSerializer,
|
||||
ScanSerializer,
|
||||
ScanUpdateSerializer,
|
||||
ScheduleDailyCreateSerializer,
|
||||
TaskSerializer,
|
||||
TenantSerializer,
|
||||
TokenRefreshSerializer,
|
||||
TokenSerializer,
|
||||
UserCreateSerializer,
|
||||
UserSerializer,
|
||||
UserUpdateSerializer,
|
||||
TaskSerializer,
|
||||
ScanSerializer,
|
||||
ScanCreateSerializer,
|
||||
ScanUpdateSerializer,
|
||||
ResourceSerializer,
|
||||
ProviderSecretSerializer,
|
||||
ProviderSecretUpdateSerializer,
|
||||
ProviderSecretCreateSerializer,
|
||||
RoleSerializer,
|
||||
RoleCreateSerializer,
|
||||
RoleUpdateSerializer,
|
||||
ScheduleDailyCreateSerializer,
|
||||
)
|
||||
|
||||
|
||||
CACHE_DECORATOR = cache_control(
|
||||
max_age=django_settings.CACHE_MAX_AGE,
|
||||
stale_while_revalidate=django_settings.CACHE_STALE_WHILE_REVALIDATE,
|
||||
)
|
||||
|
||||
|
||||
class RelationshipViewSchema(JsonApiAutoSchema):
|
||||
def _resolve_path_parameters(self, _path_variables):
|
||||
return []
|
||||
|
||||
|
||||
@extend_schema(
|
||||
tags=["Token"],
|
||||
summary="Obtain a token",
|
||||
@@ -271,6 +290,26 @@ class UserViewSet(BaseUserViewset):
|
||||
filterset_class = UserFilter
|
||||
ordering = ["-date_joined"]
|
||||
ordering_fields = ["name", "email", "company_name", "date_joined", "is_active"]
|
||||
required_permissions = [Permissions.MANAGE_USERS]
|
||||
permission_classes = BaseRLSViewSet.permission_classes + [HasPermissions]
|
||||
|
||||
def initial(self, request, *args, **kwargs):
|
||||
"""
|
||||
Sets required_permissions before permissions are checked.
|
||||
"""
|
||||
self.required_permissions = self.get_required_permissions()
|
||||
super().initial(request, *args, **kwargs)
|
||||
|
||||
def get_required_permissions(self):
|
||||
"""
|
||||
Returns the required permissions based on the request method.
|
||||
"""
|
||||
if self.action == "me":
|
||||
# No permissions required for me request
|
||||
return []
|
||||
else:
|
||||
# Require permission for the rest of the requests
|
||||
return [Permissions.MANAGE_USERS]
|
||||
|
||||
def get_queryset(self):
|
||||
# If called during schema generation, return an empty queryset
|
||||
@@ -347,11 +386,123 @@ class UserViewSet(BaseUserViewset):
|
||||
user=user, tenant=tenant, role=role
|
||||
)
|
||||
if invitation:
|
||||
user_role = []
|
||||
for role in invitation.roles.all():
|
||||
user_role.append(
|
||||
UserRoleRelationship.objects.using(MainRouter.admin_db).create(
|
||||
user=user, role=role, tenant=invitation.tenant
|
||||
)
|
||||
)
|
||||
invitation.state = Invitation.State.ACCEPTED
|
||||
invitation.save(using=MainRouter.admin_db)
|
||||
else:
|
||||
role = Role.objects.using(MainRouter.admin_db).create(
|
||||
name="admin",
|
||||
tenant_id=tenant.id,
|
||||
manage_users=True,
|
||||
manage_account=True,
|
||||
manage_billing=True,
|
||||
manage_providers=True,
|
||||
manage_integrations=True,
|
||||
manage_scans=True,
|
||||
unlimited_visibility=True,
|
||||
)
|
||||
UserRoleRelationship.objects.using(MainRouter.admin_db).create(
|
||||
user=user,
|
||||
role=role,
|
||||
tenant_id=tenant.id,
|
||||
)
|
||||
return Response(data=UserSerializer(user).data, status=status.HTTP_201_CREATED)
|
||||
|
||||
|
||||
@extend_schema_view(
|
||||
create=extend_schema(
|
||||
tags=["User"],
|
||||
summary="Create a new user-roles relationship",
|
||||
description="Add a new user-roles relationship to the system by providing the required user-roles details.",
|
||||
responses={
|
||||
204: OpenApiResponse(description="Relationship created successfully"),
|
||||
400: OpenApiResponse(
|
||||
description="Bad request (e.g., relationship already exists)"
|
||||
),
|
||||
},
|
||||
),
|
||||
partial_update=extend_schema(
|
||||
tags=["User"],
|
||||
summary="Partially update a user-roles relationship",
|
||||
description="Update the user-roles relationship information without affecting other fields.",
|
||||
responses={
|
||||
204: OpenApiResponse(
|
||||
response=None, description="Relationship updated successfully"
|
||||
)
|
||||
},
|
||||
),
|
||||
destroy=extend_schema(
|
||||
tags=["User"],
|
||||
summary="Delete a user-roles relationship",
|
||||
description="Remove the user-roles relationship from the system by their ID.",
|
||||
responses={
|
||||
204: OpenApiResponse(
|
||||
response=None, description="Relationship deleted successfully"
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
class UserRoleRelationshipView(RelationshipView, BaseRLSViewSet):
|
||||
queryset = User.objects.all()
|
||||
serializer_class = UserRoleRelationshipSerializer
|
||||
resource_name = "roles"
|
||||
http_method_names = ["post", "patch", "delete"]
|
||||
schema = RelationshipViewSchema()
|
||||
|
||||
def get_queryset(self):
|
||||
return User.objects.all()
|
||||
|
||||
def create(self, request, *args, **kwargs):
|
||||
user = self.get_object()
|
||||
|
||||
role_ids = [item["id"] for item in request.data]
|
||||
existing_relationships = UserRoleRelationship.objects.filter(
|
||||
user=user, role_id__in=role_ids
|
||||
)
|
||||
|
||||
if existing_relationships.exists():
|
||||
return Response(
|
||||
{"detail": "One or more roles are already associated with the user."},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
serializer = self.get_serializer(
|
||||
data={"roles": request.data},
|
||||
context={
|
||||
"user": user,
|
||||
"tenant_id": self.request.tenant_id,
|
||||
"request": request,
|
||||
},
|
||||
)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
serializer.save()
|
||||
|
||||
return Response(status=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
def partial_update(self, request, *args, **kwargs):
|
||||
user = self.get_object()
|
||||
serializer = self.get_serializer(
|
||||
instance=user,
|
||||
data={"roles": request.data},
|
||||
context={"tenant_id": self.request.tenant_id, "request": request},
|
||||
)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
serializer.save()
|
||||
return Response(status=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
def destroy(self, request, *args, **kwargs):
|
||||
user = self.get_object()
|
||||
user.roles.clear()
|
||||
|
||||
return Response(status=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
|
||||
@extend_schema_view(
|
||||
list=extend_schema(
|
||||
tags=["Tenant"],
|
||||
@@ -389,6 +540,8 @@ class TenantViewSet(BaseTenantViewset):
|
||||
search_fields = ["name"]
|
||||
ordering = ["-inserted_at"]
|
||||
ordering_fields = ["name", "inserted_at", "updated_at"]
|
||||
required_permissions = [Permissions.MANAGE_ACCOUNT]
|
||||
permission_classes = BaseRLSViewSet.permission_classes + [HasPermissions]
|
||||
|
||||
def get_queryset(self):
|
||||
return Tenant.objects.all()
|
||||
@@ -562,66 +715,139 @@ class ProviderGroupViewSet(BaseRLSViewSet):
|
||||
queryset = ProviderGroup.objects.all()
|
||||
serializer_class = ProviderGroupSerializer
|
||||
filterset_class = ProviderGroupFilter
|
||||
http_method_names = ["get", "post", "patch", "put", "delete"]
|
||||
http_method_names = ["get", "post", "patch", "delete"]
|
||||
ordering = ["inserted_at"]
|
||||
required_permissions = []
|
||||
permission_classes = BaseRLSViewSet.permission_classes + [HasPermissions]
|
||||
|
||||
def initial(self, request, *args, **kwargs):
|
||||
"""
|
||||
Sets required_permissions before permissions are checked.
|
||||
"""
|
||||
self.required_permissions = self.get_required_permissions()
|
||||
super().initial(request, *args, **kwargs)
|
||||
|
||||
def get_required_permissions(self):
|
||||
"""
|
||||
Returns the required permissions based on the request method.
|
||||
"""
|
||||
if self.request.method in SAFE_METHODS:
|
||||
# No permissions required for GET requests
|
||||
return []
|
||||
else:
|
||||
# Require permission for non-GET requests
|
||||
return [Permissions.MANAGE_PROVIDERS]
|
||||
|
||||
def get_queryset(self):
|
||||
return ProviderGroup.objects.prefetch_related("providers")
|
||||
user = self.request.user
|
||||
user_roles = user.roles.all()
|
||||
|
||||
# Check if any of the user's roles have UNLIMITED_VISIBILITY
|
||||
if getattr(user_roles[0], Permissions.UNLIMITED_VISIBILITY.value, False):
|
||||
# User has unlimited visibility, return all provider groups
|
||||
return ProviderGroup.objects.prefetch_related("providers")
|
||||
|
||||
# Collect provider groups associated with the user's roles
|
||||
provider_groups = (
|
||||
ProviderGroup.objects.filter(roles__in=user_roles)
|
||||
.distinct()
|
||||
.prefetch_related("providers")
|
||||
)
|
||||
|
||||
return provider_groups
|
||||
|
||||
def get_serializer_class(self):
|
||||
if self.action == "partial_update":
|
||||
return ProviderGroupUpdateSerializer
|
||||
elif self.action == "providers":
|
||||
if hasattr(self, "response_serializer_class"):
|
||||
return self.response_serializer_class
|
||||
return ProviderGroupMembershipUpdateSerializer
|
||||
return super().get_serializer_class()
|
||||
|
||||
@extend_schema(
|
||||
tags=["Provider Group"],
|
||||
summary="Add providers to a provider group",
|
||||
description="Add one or more providers to an existing provider group.",
|
||||
request=ProviderGroupMembershipUpdateSerializer,
|
||||
responses={200: OpenApiResponse(response=ProviderGroupSerializer)},
|
||||
)
|
||||
@action(detail=True, methods=["put"], url_name="providers")
|
||||
def providers(self, request, pk=None):
|
||||
|
||||
@extend_schema(tags=["Provider Group"])
|
||||
@extend_schema_view(
|
||||
create=extend_schema(
|
||||
summary="Create a new provider_group-providers relationship",
|
||||
description="Add a new provider_group-providers relationship to the system by providing the required provider_group-providers details.",
|
||||
responses={
|
||||
204: OpenApiResponse(description="Relationship created successfully"),
|
||||
400: OpenApiResponse(
|
||||
description="Bad request (e.g., relationship already exists)"
|
||||
),
|
||||
},
|
||||
),
|
||||
partial_update=extend_schema(
|
||||
summary="Partially update a provider_group-providers relationship",
|
||||
description="Update the provider_group-providers relationship information without affecting other fields.",
|
||||
responses={
|
||||
204: OpenApiResponse(
|
||||
response=None, description="Relationship updated successfully"
|
||||
)
|
||||
},
|
||||
),
|
||||
destroy=extend_schema(
|
||||
summary="Delete a provider_group-providers relationship",
|
||||
description="Remove the provider_group-providers relationship from the system by their ID.",
|
||||
responses={
|
||||
204: OpenApiResponse(
|
||||
response=None, description="Relationship deleted successfully"
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
class ProviderGroupProvidersRelationshipView(RelationshipView, BaseRLSViewSet):
|
||||
queryset = ProviderGroup.objects.all()
|
||||
serializer_class = ProviderGroupMembershipSerializer
|
||||
resource_name = "providers"
|
||||
http_method_names = ["post", "patch", "delete"]
|
||||
schema = RelationshipViewSchema()
|
||||
|
||||
def get_queryset(self):
|
||||
return ProviderGroup.objects.all()
|
||||
|
||||
def create(self, request, *args, **kwargs):
|
||||
provider_group = self.get_object()
|
||||
|
||||
# Validate input data
|
||||
serializer = self.get_serializer_class()(
|
||||
data=request.data,
|
||||
context=self.get_serializer_context(),
|
||||
provider_ids = [item["id"] for item in request.data]
|
||||
existing_relationships = ProviderGroupMembership.objects.filter(
|
||||
provider_group=provider_group, provider_id__in=provider_ids
|
||||
)
|
||||
|
||||
if existing_relationships.exists():
|
||||
return Response(
|
||||
{
|
||||
"detail": "One or more providers are already associated with the provider_group."
|
||||
},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
serializer = self.get_serializer(
|
||||
data={"providers": request.data},
|
||||
context={
|
||||
"provider_group": provider_group,
|
||||
"tenant_id": self.request.tenant_id,
|
||||
"request": request,
|
||||
},
|
||||
)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
serializer.save()
|
||||
|
||||
provider_ids = serializer.validated_data["provider_ids"]
|
||||
return Response(status=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
# Update memberships
|
||||
ProviderGroupMembership.objects.filter(
|
||||
provider_group=provider_group, tenant_id=request.tenant_id
|
||||
).delete()
|
||||
|
||||
provider_group_memberships = [
|
||||
ProviderGroupMembership(
|
||||
tenant_id=self.request.tenant_id,
|
||||
provider_group=provider_group,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
for provider_id in provider_ids
|
||||
]
|
||||
|
||||
ProviderGroupMembership.objects.bulk_create(
|
||||
provider_group_memberships, ignore_conflicts=True
|
||||
def partial_update(self, request, *args, **kwargs):
|
||||
provider_group = self.get_object()
|
||||
serializer = self.get_serializer(
|
||||
instance=provider_group,
|
||||
data={"providers": request.data},
|
||||
context={"tenant_id": self.request.tenant_id, "request": request},
|
||||
)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
serializer.save()
|
||||
return Response(status=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
# Return the updated provider group with providers
|
||||
provider_group.refresh_from_db()
|
||||
self.response_serializer_class = ProviderGroupSerializer
|
||||
response_serializer = ProviderGroupSerializer(
|
||||
provider_group, context=self.get_serializer_context()
|
||||
)
|
||||
return Response(data=response_serializer.data, status=status.HTTP_200_OK)
|
||||
def destroy(self, request, *args, **kwargs):
|
||||
provider_group = self.get_object()
|
||||
provider_group.providers.clear()
|
||||
|
||||
return Response(status=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
|
||||
@extend_schema_view(
|
||||
@@ -671,9 +897,41 @@ class ProviderViewSet(BaseRLSViewSet):
|
||||
"inserted_at",
|
||||
"updated_at",
|
||||
]
|
||||
required_permissions = []
|
||||
permission_classes = BaseRLSViewSet.permission_classes + [HasPermissions]
|
||||
|
||||
def initial(self, request, *args, **kwargs):
|
||||
"""
|
||||
Sets required_permissions before permissions are checked.
|
||||
"""
|
||||
self.required_permissions = self.get_required_permissions()
|
||||
super().initial(request, *args, **kwargs)
|
||||
|
||||
def get_required_permissions(self):
|
||||
"""
|
||||
Returns the required permissions based on the request method.
|
||||
"""
|
||||
if self.request.method in SAFE_METHODS:
|
||||
# No permissions required for GET requests
|
||||
return []
|
||||
else:
|
||||
# Require permission for non-GET requests
|
||||
return [Permissions.MANAGE_PROVIDERS]
|
||||
|
||||
def get_queryset(self):
|
||||
return Provider.objects.all()
|
||||
user = self.request.user
|
||||
user_roles = user.roles.all()
|
||||
if getattr(user_roles[0], Permissions.UNLIMITED_VISIBILITY.value, False):
|
||||
# User has unlimited visibility, return all providers
|
||||
return Provider.objects.all()
|
||||
|
||||
# User lacks permission, filter providers based on provider groups associated with the role
|
||||
provider_groups = user_roles[0].provider_groups.all()
|
||||
providers = Provider.objects.filter(
|
||||
provider_groups__in=provider_groups
|
||||
).distinct()
|
||||
|
||||
return providers
|
||||
|
||||
def get_serializer_class(self):
|
||||
if self.action == "create":
|
||||
@@ -793,9 +1051,40 @@ class ScanViewSet(BaseRLSViewSet):
|
||||
"inserted_at",
|
||||
"updated_at",
|
||||
]
|
||||
required_permissions = [Permissions.MANAGE_SCANS]
|
||||
permission_classes = BaseRLSViewSet.permission_classes + [HasPermissions]
|
||||
|
||||
def initial(self, request, *args, **kwargs):
|
||||
"""
|
||||
Sets required_permissions before permissions are checked.
|
||||
"""
|
||||
self.required_permissions = self.get_required_permissions()
|
||||
super().initial(request, *args, **kwargs)
|
||||
|
||||
def get_required_permissions(self):
|
||||
"""
|
||||
Returns the required permissions based on the request method.
|
||||
"""
|
||||
if self.request.method in SAFE_METHODS:
|
||||
# No permissions required for GET requests
|
||||
return []
|
||||
else:
|
||||
# Require permission for non-GET requests
|
||||
return [Permissions.MANAGE_SCANS]
|
||||
|
||||
def get_queryset(self):
|
||||
return Scan.objects.all()
|
||||
user = self.request.user
|
||||
user_roles = user.roles.all()
|
||||
if getattr(user_roles[0], Permissions.UNLIMITED_VISIBILITY.value, False):
|
||||
# User has unlimited visibility, return all scans
|
||||
return Scan.objects.all()
|
||||
|
||||
# User lacks permission, filter providers based on provider groups associated with the role
|
||||
provider_groups = user_roles[0].provider_groups.all()
|
||||
providers = Provider.objects.filter(
|
||||
provider_groups__in=provider_groups
|
||||
).distinct()
|
||||
return Scan.objects.filter(provider__in=providers).distinct()
|
||||
|
||||
def get_serializer_class(self):
|
||||
if self.action == "create":
|
||||
@@ -885,11 +1174,26 @@ class TaskViewSet(BaseRLSViewSet):
|
||||
search_fields = ["name"]
|
||||
ordering = ["-inserted_at"]
|
||||
ordering_fields = ["inserted_at", "completed_at", "name", "state"]
|
||||
required_permissions = []
|
||||
permission_classes = BaseRLSViewSet.permission_classes + [HasPermissions]
|
||||
|
||||
def get_queryset(self):
|
||||
return Task.objects.annotate(
|
||||
name=F("task_runner_task__task_name"), state=F("task_runner_task__status")
|
||||
)
|
||||
user = self.request.user
|
||||
user_roles = user.roles.all()
|
||||
if getattr(user_roles[0], Permissions.UNLIMITED_VISIBILITY.value, False):
|
||||
# User has unlimited visibility, return all tasks
|
||||
return Task.objects.annotate(
|
||||
name=F("task_runner_task__task_name"),
|
||||
state=F("task_runner_task__status"),
|
||||
)
|
||||
|
||||
# User lacks permission, filter tasks based on provider groups associated with the role
|
||||
provider_groups = user_roles[0].provider_groups.all()
|
||||
providers = Provider.objects.filter(
|
||||
provider_groups__in=provider_groups
|
||||
).distinct()
|
||||
scans = Scan.objects.filter(provider__in=providers).distinct()
|
||||
return Task.objects.filter(scan__in=scans).distinct()
|
||||
|
||||
def destroy(self, request, *args, pk=None, **kwargs):
|
||||
task = get_object_or_404(Task, pk=pk)
|
||||
@@ -950,11 +1254,31 @@ class ResourceViewSet(BaseRLSViewSet):
|
||||
"inserted_at",
|
||||
"updated_at",
|
||||
]
|
||||
required_permissions = []
|
||||
permission_classes = BaseRLSViewSet.permission_classes + [HasPermissions]
|
||||
|
||||
def initial(self, request, *args, **kwargs):
|
||||
"""
|
||||
Sets required_permissions before permissions are checked.
|
||||
"""
|
||||
self.required_permissions = ResourceViewSet.required_permissions
|
||||
super().initial(request, *args, **kwargs)
|
||||
|
||||
def get_queryset(self):
|
||||
queryset = Resource.objects.all()
|
||||
search_value = self.request.query_params.get("filter[search]", None)
|
||||
user = self.request.user
|
||||
user_roles = user.roles.all()
|
||||
if getattr(user_roles[0], Permissions.UNLIMITED_VISIBILITY.value, False):
|
||||
# User has unlimited visibility, return all scans
|
||||
queryset = Resource.objects.all()
|
||||
else:
|
||||
# User lacks permission, filter providers based on provider groups associated with the role
|
||||
provider_groups = user_roles[0].provider_groups.all()
|
||||
providers = Provider.objects.filter(
|
||||
provider_groups__in=provider_groups
|
||||
).distinct()
|
||||
queryset = Resource.objects.filter(provider__in=providers).distinct()
|
||||
|
||||
search_value = self.request.query_params.get("filter[search]", None)
|
||||
if search_value:
|
||||
# Django's ORM will build a LEFT JOIN and OUTER JOIN on the "through" table, resulting in duplicates
|
||||
# The duplicates then require a `distinct` query
|
||||
@@ -1025,11 +1349,15 @@ class FindingViewSet(BaseRLSViewSet):
|
||||
"inserted_at",
|
||||
"updated_at",
|
||||
]
|
||||
required_permissions = []
|
||||
permission_classes = BaseRLSViewSet.permission_classes + [HasPermissions]
|
||||
|
||||
def inserted_at_to_uuidv7(self, inserted_at):
|
||||
if inserted_at is None:
|
||||
return None
|
||||
return datetime_to_uuid7(inserted_at)
|
||||
def initial(self, request, *args, **kwargs):
|
||||
"""
|
||||
Sets required_permissions before permissions are checked.
|
||||
"""
|
||||
self.required_permissions = ResourceViewSet.required_permissions
|
||||
super().initial(request, *args, **kwargs)
|
||||
|
||||
def get_serializer_class(self):
|
||||
if self.action == "findings_services_regions":
|
||||
@@ -1038,9 +1366,21 @@ class FindingViewSet(BaseRLSViewSet):
|
||||
return super().get_serializer_class()
|
||||
|
||||
def get_queryset(self):
|
||||
queryset = Finding.objects.all()
|
||||
search_value = self.request.query_params.get("filter[search]", None)
|
||||
user = self.request.user
|
||||
user_roles = user.roles.all()
|
||||
if getattr(user_roles[0], Permissions.UNLIMITED_VISIBILITY.value, False):
|
||||
# User has unlimited visibility, return all scans
|
||||
queryset = Finding.objects.all()
|
||||
else:
|
||||
# User lacks permission, filter providers based on provider groups associated with the role
|
||||
provider_groups = user_roles[0].provider_groups.all()
|
||||
providers = Provider.objects.filter(
|
||||
provider_groups__in=provider_groups
|
||||
).distinct()
|
||||
scans = Scan.objects.filter(provider__in=providers).distinct()
|
||||
queryset = Finding.objects.filter(scan__in=scans).distinct()
|
||||
|
||||
search_value = self.request.query_params.get("filter[search]", None)
|
||||
if search_value:
|
||||
# Django's ORM will build a LEFT JOIN and OUTER JOIN on any "through" tables, resulting in duplicates
|
||||
# The duplicates then require a `distinct` query
|
||||
@@ -1068,6 +1408,11 @@ class FindingViewSet(BaseRLSViewSet):
|
||||
|
||||
return queryset
|
||||
|
||||
def inserted_at_to_uuidv7(self, inserted_at):
|
||||
if inserted_at is None:
|
||||
return None
|
||||
return datetime_to_uuid7(inserted_at)
|
||||
|
||||
@action(detail=False, methods=["get"], url_name="findings_services_regions")
|
||||
def findings_services_regions(self, request):
|
||||
queryset = self.get_queryset()
|
||||
@@ -1188,6 +1533,8 @@ class InvitationViewSet(BaseRLSViewSet):
|
||||
"state",
|
||||
"inviter",
|
||||
]
|
||||
required_permissions = [Permissions.MANAGE_ACCOUNT]
|
||||
permission_classes = BaseRLSViewSet.permission_classes + [HasPermissions]
|
||||
|
||||
def get_queryset(self):
|
||||
return Invitation.objects.all()
|
||||
@@ -1275,6 +1622,13 @@ class InvitationAcceptViewSet(BaseRLSViewSet):
|
||||
user=user,
|
||||
tenant=invitation.tenant,
|
||||
)
|
||||
user_role = []
|
||||
for role in invitation.roles.all():
|
||||
user_role.append(
|
||||
UserRoleRelationship.objects.using(MainRouter.admin_db).create(
|
||||
user=user, role=role, tenant=invitation.tenant
|
||||
)
|
||||
)
|
||||
invitation.state = Invitation.State.ACCEPTED
|
||||
invitation.save(using=MainRouter.admin_db)
|
||||
|
||||
@@ -1283,6 +1637,153 @@ class InvitationAcceptViewSet(BaseRLSViewSet):
|
||||
return Response(data=membership_serializer.data, status=status.HTTP_201_CREATED)
|
||||
|
||||
|
||||
@extend_schema(tags=["Role"])
|
||||
@extend_schema_view(
|
||||
list=extend_schema(
|
||||
tags=["Role"],
|
||||
summary="List all roles",
|
||||
description="Retrieve a list of all roles with options for filtering by various criteria.",
|
||||
),
|
||||
retrieve=extend_schema(
|
||||
tags=["Role"],
|
||||
summary="Retrieve data from a role",
|
||||
description="Fetch detailed information about a specific role by their ID.",
|
||||
),
|
||||
create=extend_schema(
|
||||
tags=["Role"],
|
||||
summary="Create a new role",
|
||||
description="Add a new role to the system by providing the required role details.",
|
||||
),
|
||||
partial_update=extend_schema(
|
||||
tags=["Role"],
|
||||
summary="Partially update a role",
|
||||
description="Update certain fields of an existing role's information without affecting other fields.",
|
||||
responses={200: RoleSerializer},
|
||||
),
|
||||
destroy=extend_schema(
|
||||
tags=["Role"],
|
||||
summary="Delete a role",
|
||||
description="Remove a role from the system by their ID.",
|
||||
),
|
||||
)
|
||||
class RoleViewSet(BaseRLSViewSet):
|
||||
queryset = Role.objects.all()
|
||||
serializer_class = RoleSerializer
|
||||
filterset_class = RoleFilter
|
||||
http_method_names = ["get", "post", "patch", "delete"]
|
||||
ordering = ["inserted_at"]
|
||||
required_permissions = [Permissions.MANAGE_ACCOUNT]
|
||||
permission_classes = BaseRLSViewSet.permission_classes + [HasPermissions]
|
||||
|
||||
def get_queryset(self):
|
||||
return Role.objects.all()
|
||||
|
||||
def get_serializer_class(self):
|
||||
if self.action == "create":
|
||||
return RoleCreateSerializer
|
||||
elif self.action == "partial_update":
|
||||
return RoleUpdateSerializer
|
||||
return super().get_serializer_class()
|
||||
|
||||
def partial_update(self, request, *args, **kwargs):
|
||||
user = request.user
|
||||
user_role = user.roles.all().first()
|
||||
# If the user is the owner of the role, the manage_account field is not editable
|
||||
if user_role and kwargs["pk"] == str(user_role.id):
|
||||
request.data["manage_account"] = str(user_role.manage_account).lower()
|
||||
return super().partial_update(request, *args, **kwargs)
|
||||
|
||||
|
||||
@extend_schema_view(
|
||||
create=extend_schema(
|
||||
tags=["Role"],
|
||||
summary="Create a new role-provider_groups relationship",
|
||||
description="Add a new role-provider_groups relationship to the system by providing the required role-provider_groups details.",
|
||||
responses={
|
||||
204: OpenApiResponse(description="Relationship created successfully"),
|
||||
400: OpenApiResponse(
|
||||
description="Bad request (e.g., relationship already exists)"
|
||||
),
|
||||
},
|
||||
),
|
||||
partial_update=extend_schema(
|
||||
tags=["Role"],
|
||||
summary="Partially update a role-provider_groups relationship",
|
||||
description="Update the role-provider_groups relationship information without affecting other fields.",
|
||||
responses={
|
||||
204: OpenApiResponse(
|
||||
response=None, description="Relationship updated successfully"
|
||||
)
|
||||
},
|
||||
),
|
||||
destroy=extend_schema(
|
||||
tags=["Role"],
|
||||
summary="Delete a role-provider_groups relationship",
|
||||
description="Remove the role-provider_groups relationship from the system by their ID.",
|
||||
responses={
|
||||
204: OpenApiResponse(
|
||||
response=None, description="Relationship deleted successfully"
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
class RoleProviderGroupRelationshipView(RelationshipView, BaseRLSViewSet):
|
||||
queryset = Role.objects.all()
|
||||
serializer_class = RoleProviderGroupRelationshipSerializer
|
||||
resource_name = "provider_groups"
|
||||
http_method_names = ["post", "patch", "delete"]
|
||||
schema = RelationshipViewSchema()
|
||||
|
||||
def get_queryset(self):
|
||||
return Role.objects.all()
|
||||
|
||||
def create(self, request, *args, **kwargs):
|
||||
role = self.get_object()
|
||||
|
||||
provider_group_ids = [item["id"] for item in request.data]
|
||||
existing_relationships = RoleProviderGroupRelationship.objects.filter(
|
||||
role=role, provider_group_id__in=provider_group_ids
|
||||
)
|
||||
|
||||
if existing_relationships.exists():
|
||||
return Response(
|
||||
{
|
||||
"detail": "One or more provider groups are already associated with the role."
|
||||
},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
serializer = self.get_serializer(
|
||||
data={"provider_groups": request.data},
|
||||
context={
|
||||
"role": role,
|
||||
"tenant_id": self.request.tenant_id,
|
||||
"request": request,
|
||||
},
|
||||
)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
serializer.save()
|
||||
|
||||
return Response(status=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
def partial_update(self, request, *args, **kwargs):
|
||||
role = self.get_object()
|
||||
serializer = self.get_serializer(
|
||||
instance=role,
|
||||
data={"provider_groups": request.data},
|
||||
context={"tenant_id": self.request.tenant_id, "request": request},
|
||||
)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
serializer.save()
|
||||
return Response(status=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
def destroy(self, request, *args, **kwargs):
|
||||
role = self.get_object()
|
||||
role.provider_groups.clear()
|
||||
|
||||
return Response(status=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
|
||||
@extend_schema_view(
|
||||
list=extend_schema(
|
||||
tags=["Compliance Overview"],
|
||||
|
||||
@@ -10,8 +10,8 @@ DATABASES = {
|
||||
"default": {
|
||||
"ENGINE": "psqlextra.backend",
|
||||
"NAME": "prowler_db_test",
|
||||
"USER": env("POSTGRES_USER", default="prowler"),
|
||||
"PASSWORD": env("POSTGRES_PASSWORD", default="S3cret"),
|
||||
"USER": env("POSTGRES_USER", default="prowler_admin"),
|
||||
"PASSWORD": env("POSTGRES_PASSWORD", default="postgres"),
|
||||
"HOST": env("POSTGRES_HOST", default="localhost"),
|
||||
"PORT": env("POSTGRES_PORT", default="5432"),
|
||||
},
|
||||
|
||||
+236
-1
@@ -10,6 +10,8 @@ from prowler.lib.check.models import Severity
|
||||
from prowler.lib.outputs.finding import Status
|
||||
from rest_framework import status
|
||||
from rest_framework.test import APIClient
|
||||
from unittest.mock import patch
|
||||
from api.db_utils import tenant_transaction
|
||||
|
||||
from api.models import (
|
||||
Finding,
|
||||
@@ -20,6 +22,7 @@ from api.models import (
|
||||
ProviderGroup,
|
||||
Resource,
|
||||
ResourceTag,
|
||||
Role,
|
||||
Scan,
|
||||
StateChoices,
|
||||
Task,
|
||||
@@ -27,6 +30,7 @@ from api.models import (
|
||||
ProviderSecret,
|
||||
Invitation,
|
||||
ComplianceOverview,
|
||||
UserRoleRelationship,
|
||||
)
|
||||
from api.rls import Tenant
|
||||
from api.v1.serializers import TokenSerializer
|
||||
@@ -83,8 +87,148 @@ def create_test_user(django_db_setup, django_db_blocker):
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def create_test_user_rbac(django_db_setup, django_db_blocker):
|
||||
with django_db_blocker.unblock():
|
||||
user = User.objects.create_user(
|
||||
name="testing",
|
||||
email="rbac@rbac.com",
|
||||
password=TEST_PASSWORD,
|
||||
)
|
||||
tenant = Tenant.objects.create(
|
||||
name="Tenant Test",
|
||||
)
|
||||
Membership.objects.create(
|
||||
user=user,
|
||||
tenant=tenant,
|
||||
role=Membership.RoleChoices.OWNER,
|
||||
)
|
||||
Role.objects.create(
|
||||
name="admin",
|
||||
tenant_id=tenant.id,
|
||||
manage_users=True,
|
||||
manage_account=True,
|
||||
manage_billing=True,
|
||||
manage_providers=True,
|
||||
manage_integrations=True,
|
||||
manage_scans=True,
|
||||
unlimited_visibility=True,
|
||||
)
|
||||
UserRoleRelationship.objects.create(
|
||||
user=user,
|
||||
role=Role.objects.get(name="admin"),
|
||||
tenant_id=tenant.id,
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def create_test_user_rbac_no_roles(django_db_setup, django_db_blocker):
|
||||
with django_db_blocker.unblock():
|
||||
user = User.objects.create_user(
|
||||
name="testing",
|
||||
email="rbac_noroles@rbac.com",
|
||||
password=TEST_PASSWORD,
|
||||
)
|
||||
tenant = Tenant.objects.create(
|
||||
name="Tenant Test",
|
||||
)
|
||||
Membership.objects.create(
|
||||
user=user,
|
||||
tenant=tenant,
|
||||
role=Membership.RoleChoices.OWNER,
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def create_test_user_rbac_limited(django_db_setup, django_db_blocker):
|
||||
with django_db_blocker.unblock():
|
||||
user = User.objects.create_user(
|
||||
name="testing_limited",
|
||||
email="rbac_limited@rbac.com",
|
||||
password=TEST_PASSWORD,
|
||||
)
|
||||
tenant = Tenant.objects.create(
|
||||
name="Tenant Test",
|
||||
)
|
||||
Membership.objects.create(
|
||||
user=user,
|
||||
tenant=tenant,
|
||||
role=Membership.RoleChoices.OWNER,
|
||||
)
|
||||
Role.objects.create(
|
||||
name="limited",
|
||||
tenant_id=tenant.id,
|
||||
manage_users=False,
|
||||
manage_account=False,
|
||||
manage_billing=False,
|
||||
manage_providers=False,
|
||||
manage_integrations=False,
|
||||
manage_scans=False,
|
||||
unlimited_visibility=False,
|
||||
)
|
||||
UserRoleRelationship.objects.create(
|
||||
user=user,
|
||||
role=Role.objects.get(name="limited"),
|
||||
tenant_id=tenant.id,
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def authenticated_client(create_test_user, tenants_fixture, client):
|
||||
def authenticated_client_rbac(create_test_user_rbac, tenants_fixture, client):
|
||||
client.user = create_test_user_rbac
|
||||
serializer = TokenSerializer(
|
||||
data={"type": "tokens", "email": "rbac@rbac.com", "password": TEST_PASSWORD}
|
||||
)
|
||||
serializer.is_valid()
|
||||
access_token = serializer.validated_data["access"]
|
||||
client.defaults["HTTP_AUTHORIZATION"] = f"Bearer {access_token}"
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def authenticated_client_rbac_noroles(
|
||||
create_test_user_rbac_no_roles, tenants_fixture, client
|
||||
):
|
||||
client.user = create_test_user_rbac_no_roles
|
||||
serializer = TokenSerializer(
|
||||
data={
|
||||
"type": "tokens",
|
||||
"email": "rbac_noroles@rbac.com",
|
||||
"password": TEST_PASSWORD,
|
||||
}
|
||||
)
|
||||
serializer.is_valid()
|
||||
access_token = serializer.validated_data["access"]
|
||||
client.defaults["HTTP_AUTHORIZATION"] = f"Bearer {access_token}"
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def authenticated_client_no_permissions_rbac(
|
||||
create_test_user_rbac_limited, tenants_fixture, client
|
||||
):
|
||||
client.user = create_test_user_rbac_limited
|
||||
serializer = TokenSerializer(
|
||||
data={
|
||||
"type": "tokens",
|
||||
"email": "rbac_limited@rbac.com",
|
||||
"password": TEST_PASSWORD,
|
||||
}
|
||||
)
|
||||
serializer.is_valid()
|
||||
access_token = serializer.validated_data["access"]
|
||||
client.defaults["HTTP_AUTHORIZATION"] = f"Bearer {access_token}"
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def authenticated_client(
|
||||
create_test_user, tenants_fixture, set_user_admin_roles_fixture, client
|
||||
):
|
||||
client.user = create_test_user
|
||||
serializer = TokenSerializer(
|
||||
data={"type": "tokens", "email": TEST_USER, "password": TEST_PASSWORD}
|
||||
@@ -104,6 +248,7 @@ def authenticated_api_client(create_test_user, tenants_fixture):
|
||||
serializer.is_valid()
|
||||
access_token = serializer.validated_data["access"]
|
||||
client.defaults["HTTP_AUTHORIZATION"] = f"Bearer {access_token}"
|
||||
|
||||
return client
|
||||
|
||||
|
||||
@@ -128,9 +273,33 @@ def tenants_fixture(create_test_user):
|
||||
tenant3 = Tenant.objects.create(
|
||||
name="Tenant Three",
|
||||
)
|
||||
|
||||
return tenant1, tenant2, tenant3
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def set_user_admin_roles_fixture(create_test_user, tenants_fixture):
|
||||
user = create_test_user
|
||||
for tenant in tenants_fixture[:2]:
|
||||
with tenant_transaction(str(tenant.id)):
|
||||
role = Role.objects.create(
|
||||
name="admin",
|
||||
tenant_id=tenant.id,
|
||||
manage_users=True,
|
||||
manage_account=True,
|
||||
manage_billing=True,
|
||||
manage_providers=True,
|
||||
manage_integrations=True,
|
||||
manage_scans=True,
|
||||
unlimited_visibility=True,
|
||||
)
|
||||
UserRoleRelationship.objects.create(
|
||||
user=user,
|
||||
role=role,
|
||||
tenant_id=tenant.id,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def invitations_fixture(create_test_user, tenants_fixture):
|
||||
user = create_test_user
|
||||
@@ -210,6 +379,57 @@ def provider_groups_fixture(tenants_fixture):
|
||||
return pgroup1, pgroup2, pgroup3
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def roles_fixture(tenants_fixture):
|
||||
tenant, *_ = tenants_fixture
|
||||
role1 = Role.objects.create(
|
||||
name="Role One",
|
||||
tenant_id=tenant.id,
|
||||
manage_users=True,
|
||||
manage_account=True,
|
||||
manage_billing=True,
|
||||
manage_providers=True,
|
||||
manage_integrations=False,
|
||||
manage_scans=True,
|
||||
unlimited_visibility=False,
|
||||
)
|
||||
role2 = Role.objects.create(
|
||||
name="Role Two",
|
||||
tenant_id=tenant.id,
|
||||
manage_users=False,
|
||||
manage_account=False,
|
||||
manage_billing=False,
|
||||
manage_providers=True,
|
||||
manage_integrations=True,
|
||||
manage_scans=True,
|
||||
unlimited_visibility=True,
|
||||
)
|
||||
role3 = Role.objects.create(
|
||||
name="Role Three",
|
||||
tenant_id=tenant.id,
|
||||
manage_users=True,
|
||||
manage_account=True,
|
||||
manage_billing=True,
|
||||
manage_providers=True,
|
||||
manage_integrations=True,
|
||||
manage_scans=True,
|
||||
unlimited_visibility=True,
|
||||
)
|
||||
role4 = Role.objects.create(
|
||||
name="Role Four",
|
||||
tenant_id=tenant.id,
|
||||
manage_users=False,
|
||||
manage_account=False,
|
||||
manage_billing=False,
|
||||
manage_providers=False,
|
||||
manage_integrations=False,
|
||||
manage_scans=False,
|
||||
unlimited_visibility=False,
|
||||
)
|
||||
|
||||
return role1, role2, role3, role4
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider_secret_fixture(providers_fixture):
|
||||
return tuple(
|
||||
@@ -544,3 +764,18 @@ def get_api_tokens(
|
||||
|
||||
def get_authorization_header(access_token: str) -> dict:
|
||||
return {"Authorization": f"Bearer {access_token}"}
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(items):
|
||||
"""Ensure test_rbac.py is executed first."""
|
||||
items.sort(key=lambda item: 0 if "test_rbac.py" in item.nodeid else 1)
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
# Apply the mock before the test session starts. This is necessary to avoid admin error when running the 0004_rbac_missing_admin_roles migration
|
||||
patch("api.db_router.MainRouter.admin_db", new="default").start()
|
||||
|
||||
|
||||
def pytest_unconfigure(config):
|
||||
# Stop all patches after the test session ends. This is necessary to avoid admin error when running the 0004_rbac_missing_admin_roles migration
|
||||
patch.stopall()
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from tasks.jobs.deletion import delete_provider, delete_tenant
|
||||
@@ -24,7 +22,6 @@ class TestDeleteProvider:
|
||||
delete_provider(non_existent_pk)
|
||||
|
||||
|
||||
@patch("api.db_router.MainRouter.admin_db", new="default")
|
||||
@pytest.mark.django_db
|
||||
class TestDeleteTenant:
|
||||
def test_delete_tenant_success(self, tenants_fixture, providers_fixture):
|
||||
|
||||
Reference in New Issue
Block a user