mirror of
https://github.com/prowler-cloud/prowler.git
synced 2026-03-22 03:08:23 +00:00
chore(rls): Add tenant_id filters in views and improve querysets (#6211)
Co-authored-by: Víctor Fernández Poyatos <victor@prowler.com>
This commit is contained in:
@@ -515,8 +515,8 @@ class Resource(RowLevelSecurityProtectedModel):
|
||||
through="ResourceTagMapping",
|
||||
)
|
||||
|
||||
def get_tags(self) -> dict:
|
||||
return {tag.key: tag.value for tag in self.tags.all()}
|
||||
def get_tags(self, tenant_id: str) -> dict:
|
||||
return {tag.key: tag.value for tag in self.tags.filter(tenant_id=tenant_id)}
|
||||
|
||||
def clear_tags(self):
|
||||
self.tags.clear()
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from enum import Enum
|
||||
from rest_framework.permissions import BasePermission
|
||||
from api.models import Provider, Role, User
|
||||
from api.db_router import MainRouter
|
||||
from typing import Optional
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from rest_framework.permissions import BasePermission
|
||||
|
||||
from api.db_router import MainRouter
|
||||
from api.models import Provider, Role, User
|
||||
|
||||
|
||||
class Permissions(Enum):
|
||||
@@ -63,8 +65,11 @@ def get_providers(role: Role) -> QuerySet[Provider]:
|
||||
A QuerySet of Provider objects filtered by the role's provider groups.
|
||||
If the role has no provider groups, returns an empty queryset.
|
||||
"""
|
||||
tenant = role.tenant
|
||||
provider_groups = role.provider_groups.all()
|
||||
if not provider_groups.exists():
|
||||
return Provider.objects.none()
|
||||
|
||||
return Provider.objects.filter(provider_groups__in=provider_groups).distinct()
|
||||
return Provider.objects.filter(
|
||||
tenant=tenant, provider_groups__in=provider_groups
|
||||
).distinct()
|
||||
|
||||
@@ -7,9 +7,10 @@ from api.models import Resource, ResourceTag
|
||||
class TestResourceModel:
|
||||
def test_setting_tags(self, providers_fixture):
|
||||
provider, *_ = providers_fixture
|
||||
tenant_id = provider.tenant_id
|
||||
|
||||
resource = Resource.objects.create(
|
||||
tenant_id=provider.tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
uid="arn:aws:ec2:us-east-1:123456789012:instance/i-1234567890abcdef0",
|
||||
name="My Instance 1",
|
||||
@@ -20,12 +21,12 @@ class TestResourceModel:
|
||||
|
||||
tags = [
|
||||
ResourceTag.objects.create(
|
||||
tenant_id=provider.tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
key="key",
|
||||
value="value",
|
||||
),
|
||||
ResourceTag.objects.create(
|
||||
tenant_id=provider.tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
key="key2",
|
||||
value="value2",
|
||||
),
|
||||
@@ -33,9 +34,9 @@ class TestResourceModel:
|
||||
|
||||
resource.upsert_or_delete_tags(tags)
|
||||
|
||||
assert len(tags) == len(resource.tags.all())
|
||||
assert len(tags) == len(resource.tags.filter(tenant_id=tenant_id))
|
||||
|
||||
tags_dict = resource.get_tags()
|
||||
tags_dict = resource.get_tags(tenant_id=tenant_id)
|
||||
|
||||
for tag in tags:
|
||||
assert tag.key in tags_dict
|
||||
@@ -43,47 +44,51 @@ class TestResourceModel:
|
||||
|
||||
def test_adding_tags(self, resources_fixture):
|
||||
resource, *_ = resources_fixture
|
||||
tenant_id = str(resource.tenant_id)
|
||||
|
||||
tags = [
|
||||
ResourceTag.objects.create(
|
||||
tenant_id=resource.tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
key="env",
|
||||
value="test",
|
||||
),
|
||||
]
|
||||
before_count = len(resource.tags.all())
|
||||
before_count = len(resource.tags.filter(tenant_id=tenant_id))
|
||||
|
||||
resource.upsert_or_delete_tags(tags)
|
||||
|
||||
assert before_count + 1 == len(resource.tags.all())
|
||||
assert before_count + 1 == len(resource.tags.filter(tenant_id=tenant_id))
|
||||
|
||||
tags_dict = resource.get_tags()
|
||||
tags_dict = resource.get_tags(tenant_id=tenant_id)
|
||||
|
||||
assert "env" in tags_dict
|
||||
assert tags_dict["env"] == "test"
|
||||
|
||||
def test_adding_duplicate_tags(self, resources_fixture):
|
||||
resource, *_ = resources_fixture
|
||||
tenant_id = str(resource.tenant_id)
|
||||
|
||||
tags = resource.tags.all()
|
||||
tags = resource.tags.filter(tenant_id=tenant_id)
|
||||
|
||||
before_count = len(resource.tags.all())
|
||||
before_count = len(resource.tags.filter(tenant_id=tenant_id))
|
||||
|
||||
resource.upsert_or_delete_tags(tags)
|
||||
|
||||
# should be the same number of tags
|
||||
assert before_count == len(resource.tags.all())
|
||||
assert before_count == len(resource.tags.filter(tenant_id=tenant_id))
|
||||
|
||||
def test_add_tags_none(self, resources_fixture):
|
||||
resource, *_ = resources_fixture
|
||||
tenant_id = str(resource.tenant_id)
|
||||
resource.upsert_or_delete_tags(None)
|
||||
|
||||
assert len(resource.tags.all()) == 0
|
||||
assert resource.get_tags() == {}
|
||||
assert len(resource.tags.filter(tenant_id=tenant_id)) == 0
|
||||
assert resource.get_tags(tenant_id=tenant_id) == {}
|
||||
|
||||
def test_clear_tags(self, resources_fixture):
|
||||
resource, *_ = resources_fixture
|
||||
tenant_id = str(resource.tenant_id)
|
||||
resource.clear_tags()
|
||||
|
||||
assert len(resource.tags.all()) == 0
|
||||
assert resource.get_tags() == {}
|
||||
assert len(resource.tags.filter(tenant_id=tenant_id)) == 0
|
||||
assert resource.get_tags(tenant_id=tenant_id) == {}
|
||||
|
||||
@@ -340,7 +340,7 @@ class TestTenantViewSet:
|
||||
def test_tenants_list(self, authenticated_client, tenants_fixture):
|
||||
response = authenticated_client.get(reverse("tenant-list"))
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert len(response.json()["data"]) == len(tenants_fixture)
|
||||
assert len(response.json()["data"]) == 2 # Test user belongs to 2 tenants
|
||||
|
||||
def test_tenants_retrieve(self, authenticated_client, tenants_fixture):
|
||||
tenant1, *_ = tenants_fixture
|
||||
@@ -470,11 +470,11 @@ class TestTenantViewSet:
|
||||
(
|
||||
[
|
||||
("name", "Tenant One", 1),
|
||||
("name.icontains", "Tenant", 3),
|
||||
("inserted_at", TODAY, 3),
|
||||
("inserted_at.gte", "2024-01-01", 3),
|
||||
("name.icontains", "Tenant", 2),
|
||||
("inserted_at", TODAY, 2),
|
||||
("inserted_at.gte", "2024-01-01", 2),
|
||||
("inserted_at.lte", "2024-01-01", 0),
|
||||
("updated_at.gte", "2024-01-01", 3),
|
||||
("updated_at.gte", "2024-01-01", 2),
|
||||
("updated_at.lte", "2024-01-01", 0),
|
||||
]
|
||||
),
|
||||
@@ -510,7 +510,9 @@ class TestTenantViewSet:
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert len(response.json()["data"]) == page_size
|
||||
assert response.json()["meta"]["pagination"]["page"] == 1
|
||||
assert response.json()["meta"]["pagination"]["pages"] == len(tenants_fixture)
|
||||
assert (
|
||||
response.json()["meta"]["pagination"]["pages"] == 2
|
||||
) # Test user belongs to 2 tenants
|
||||
|
||||
def test_tenants_list_page_number(self, authenticated_client, tenants_fixture):
|
||||
page_size = 1
|
||||
@@ -523,13 +525,13 @@ class TestTenantViewSet:
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert len(response.json()["data"]) == page_size
|
||||
assert response.json()["meta"]["pagination"]["page"] == page_number
|
||||
assert response.json()["meta"]["pagination"]["pages"] == len(tenants_fixture)
|
||||
assert response.json()["meta"]["pagination"]["pages"] == 2
|
||||
|
||||
def test_tenants_list_sort_name(self, authenticated_client, tenants_fixture):
|
||||
_, tenant2, _ = tenants_fixture
|
||||
response = authenticated_client.get(reverse("tenant-list"), {"sort": "-name"})
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert len(response.json()["data"]) == 3
|
||||
assert len(response.json()["data"]) == 2
|
||||
assert response.json()["data"][0]["attributes"]["name"] == tenant2.name
|
||||
|
||||
def test_tenants_list_memberships_as_owner(
|
||||
@@ -2339,7 +2341,10 @@ class TestResourceViewSet:
|
||||
response.json()["errors"][0]["detail"] == "invalid sort parameter: invalid"
|
||||
)
|
||||
|
||||
def test_resources_retrieve(self, authenticated_client, resources_fixture):
|
||||
def test_resources_retrieve(
|
||||
self, authenticated_client, tenants_fixture, resources_fixture
|
||||
):
|
||||
tenant = tenants_fixture[0]
|
||||
resource_1, *_ = resources_fixture
|
||||
response = authenticated_client.get(
|
||||
reverse("resource-detail", kwargs={"pk": resource_1.id}),
|
||||
@@ -2350,7 +2355,9 @@ class TestResourceViewSet:
|
||||
assert response.json()["data"]["attributes"]["region"] == resource_1.region
|
||||
assert response.json()["data"]["attributes"]["service"] == resource_1.service
|
||||
assert response.json()["data"]["attributes"]["type"] == resource_1.type
|
||||
assert response.json()["data"]["attributes"]["tags"] == resource_1.get_tags()
|
||||
assert response.json()["data"]["attributes"]["tags"] == resource_1.get_tags(
|
||||
tenant_id=str(tenant.id)
|
||||
)
|
||||
|
||||
def test_resources_invalid_retrieve(self, authenticated_client):
|
||||
response = authenticated_client.get(
|
||||
@@ -3261,8 +3268,8 @@ class TestRoleViewSet:
|
||||
response = authenticated_client.get(reverse("role-list"))
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert (
|
||||
len(response.json()["data"]) == len(roles_fixture) + 2
|
||||
) # 2 default admin roles, one for each tenant
|
||||
len(response.json()["data"]) == len(roles_fixture) + 1
|
||||
) # 1 default admin role
|
||||
|
||||
def test_role_retrieve(self, authenticated_client, roles_fixture):
|
||||
role = roles_fixture[0]
|
||||
|
||||
@@ -874,7 +874,7 @@ class ResourceSerializer(RLSSerializer):
|
||||
}
|
||||
)
|
||||
def get_tags(self, obj):
|
||||
return obj.get_tags()
|
||||
return obj.get_tags(self.context.get("tenant_id"))
|
||||
|
||||
def get_fields(self):
|
||||
"""`type` is a Python reserved keyword."""
|
||||
@@ -1233,6 +1233,12 @@ class InvitationSerializer(RLSSerializer):
|
||||
|
||||
roles = serializers.ResourceRelatedField(many=True, queryset=Role.objects.all())
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
tenant_id = self.context.get("tenant_id")
|
||||
if tenant_id is not None:
|
||||
self.fields["roles"].queryset = Role.objects.filter(tenant_id=tenant_id)
|
||||
|
||||
class Meta:
|
||||
model = Invitation
|
||||
fields = [
|
||||
@@ -1252,6 +1258,12 @@ class InvitationSerializer(RLSSerializer):
|
||||
class InvitationBaseWriteSerializer(BaseWriteSerializer):
|
||||
roles = serializers.ResourceRelatedField(many=True, queryset=Role.objects.all())
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
tenant_id = self.context.get("tenant_id")
|
||||
if tenant_id is not None:
|
||||
self.fields["roles"].queryset = Role.objects.filter(tenant_id=tenant_id)
|
||||
|
||||
def validate_email(self, value):
|
||||
user = User.objects.filter(email=value).first()
|
||||
tenant_id = self.context["tenant_id"]
|
||||
@@ -1367,6 +1379,17 @@ class RoleSerializer(RLSSerializer, BaseWriteSerializer):
|
||||
queryset=ProviderGroup.objects.all(), many=True, required=False
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
tenant_id = self.context.get("tenant_id")
|
||||
if tenant_id is not None:
|
||||
self.fields["users"].queryset = User.objects.filter(
|
||||
membership__tenant__id=tenant_id
|
||||
)
|
||||
self.fields["provider_groups"].queryset = ProviderGroup.objects.filter(
|
||||
tenant_id=self.context.get("tenant_id")
|
||||
)
|
||||
|
||||
def get_permission_state(self, obj) -> str:
|
||||
return obj.permission_state
|
||||
|
||||
|
||||
@@ -309,7 +309,12 @@ class UserViewSet(BaseUserViewset):
|
||||
# If called during schema generation, return an empty queryset
|
||||
if getattr(self, "swagger_fake_view", False):
|
||||
return User.objects.none()
|
||||
return User.objects.filter(membership__tenant__id=self.request.tenant_id)
|
||||
queryset = (
|
||||
User.objects.filter(membership__tenant__id=self.request.tenant_id)
|
||||
if hasattr(self.request, "tenant_id")
|
||||
else User.objects.all()
|
||||
)
|
||||
return queryset.prefetch_related("memberships", "roles")
|
||||
|
||||
def get_permissions(self):
|
||||
if self.action == "create":
|
||||
@@ -452,7 +457,7 @@ class UserRoleRelationshipView(RelationshipView, BaseRLSViewSet):
|
||||
required_permissions = [Permissions.MANAGE_USERS]
|
||||
|
||||
def get_queryset(self):
|
||||
return User.objects.all()
|
||||
return User.objects.filter(membership__tenant__id=self.request.tenant_id)
|
||||
|
||||
def create(self, request, *args, **kwargs):
|
||||
user = self.get_object()
|
||||
@@ -540,7 +545,8 @@ class TenantViewSet(BaseTenantViewset):
|
||||
required_permissions = [Permissions.MANAGE_ACCOUNT]
|
||||
|
||||
def get_queryset(self):
|
||||
return Tenant.objects.all()
|
||||
queryset = Tenant.objects.filter(membership__user=self.request.user)
|
||||
return queryset.prefetch_related("memberships")
|
||||
|
||||
def create(self, request, *args, **kwargs):
|
||||
serializer = self.get_serializer(data=request.data)
|
||||
@@ -600,7 +606,8 @@ class MembershipViewSet(BaseTenantViewset):
|
||||
|
||||
def get_queryset(self):
|
||||
user = self.request.user
|
||||
return Membership.objects.filter(user_id=user.id)
|
||||
queryset = Membership.objects.filter(user_id=user.id)
|
||||
return queryset.select_related("user", "tenant")
|
||||
|
||||
|
||||
@extend_schema_view(
|
||||
@@ -736,10 +743,10 @@ class ProviderGroupViewSet(BaseRLSViewSet):
|
||||
# Check if any of the user's roles have UNLIMITED_VISIBILITY
|
||||
if user_roles.unlimited_visibility:
|
||||
# User has unlimited visibility, return all provider groups
|
||||
return ProviderGroup.objects.prefetch_related("providers")
|
||||
return ProviderGroup.objects.prefetch_related("providers", "roles")
|
||||
|
||||
# Collect provider groups associated with the user's roles
|
||||
return user_roles.provider_groups.all()
|
||||
return user_roles.provider_groups.all().prefetch_related("providers", "roles")
|
||||
|
||||
def get_serializer_class(self):
|
||||
if self.action == "create":
|
||||
@@ -790,7 +797,7 @@ class ProviderGroupProvidersRelationshipView(RelationshipView, BaseRLSViewSet):
|
||||
required_permissions = [Permissions.MANAGE_PROVIDERS]
|
||||
|
||||
def get_queryset(self):
|
||||
return ProviderGroup.objects.all()
|
||||
return ProviderGroup.objects.filter(tenant_id=self.request.tenant_id)
|
||||
|
||||
def create(self, request, *args, **kwargs):
|
||||
provider_group = self.get_object()
|
||||
@@ -904,10 +911,11 @@ class ProviderViewSet(BaseRLSViewSet):
|
||||
user_roles = get_role(self.request.user)
|
||||
if user_roles.unlimited_visibility:
|
||||
# User has unlimited visibility, return all providers
|
||||
return Provider.objects.all()
|
||||
|
||||
# User lacks permission, filter providers based on provider groups associated with the role
|
||||
return get_providers(user_roles)
|
||||
queryset = Provider.objects.filter(tenant_id=self.request.tenant_id)
|
||||
else:
|
||||
# User lacks permission, filter providers based on provider groups associated with the role
|
||||
queryset = get_providers(user_roles)
|
||||
return queryset.select_related("secret").prefetch_related("provider_groups")
|
||||
|
||||
def get_serializer_class(self):
|
||||
if self.action == "create":
|
||||
@@ -945,7 +953,7 @@ class ProviderViewSet(BaseRLSViewSet):
|
||||
get_object_or_404(Provider, pk=pk)
|
||||
with transaction.atomic():
|
||||
task = check_provider_connection_task.delay(
|
||||
provider_id=pk, tenant_id=request.tenant_id
|
||||
provider_id=pk, tenant_id=self.request.tenant_id
|
||||
)
|
||||
prowler_task = Task.objects.get(id=task.id)
|
||||
serializer = TaskSerializer(prowler_task)
|
||||
@@ -966,7 +974,7 @@ class ProviderViewSet(BaseRLSViewSet):
|
||||
|
||||
with transaction.atomic():
|
||||
task = delete_provider_task.delay(
|
||||
provider_id=pk, tenant_id=request.tenant_id
|
||||
provider_id=pk, tenant_id=self.request.tenant_id
|
||||
)
|
||||
prowler_task = Task.objects.get(id=task.id)
|
||||
serializer = TaskSerializer(prowler_task)
|
||||
@@ -1045,10 +1053,11 @@ class ScanViewSet(BaseRLSViewSet):
|
||||
user_roles = get_role(self.request.user)
|
||||
if user_roles.unlimited_visibility:
|
||||
# User has unlimited visibility, return all scans
|
||||
return Scan.objects.all()
|
||||
|
||||
# User lacks permission, filter providers based on provider groups associated with the role
|
||||
return Scan.objects.filter(provider__in=get_providers(user_roles))
|
||||
queryset = Scan.objects.filter(tenant_id=self.request.tenant_id)
|
||||
else:
|
||||
# User lacks permission, filter providers based on provider groups associated with the role
|
||||
queryset = Scan.objects.filter(provider__in=get_providers(user_roles))
|
||||
return queryset.select_related("provider", "task")
|
||||
|
||||
def get_serializer_class(self):
|
||||
if self.action == "create":
|
||||
@@ -1082,14 +1091,14 @@ class ScanViewSet(BaseRLSViewSet):
|
||||
with transaction.atomic():
|
||||
task = perform_scan_task.apply_async(
|
||||
kwargs={
|
||||
"tenant_id": request.tenant_id,
|
||||
"tenant_id": self.request.tenant_id,
|
||||
"scan_id": str(scan.id),
|
||||
"provider_id": str(scan.provider_id),
|
||||
# Disabled for now
|
||||
# checks_to_execute=scan.scanner_args.get("checks_to_execute"),
|
||||
},
|
||||
link=perform_scan_summary_task.si(
|
||||
tenant_id=request.tenant_id,
|
||||
tenant_id=self.request.tenant_id,
|
||||
scan_id=str(scan.id),
|
||||
),
|
||||
)
|
||||
@@ -1145,7 +1154,7 @@ class TaskViewSet(BaseRLSViewSet):
|
||||
return Task.objects.annotate(
|
||||
name=F("task_runner_task__task_name"),
|
||||
state=F("task_runner_task__status"),
|
||||
)
|
||||
).select_related("task_runner_task")
|
||||
|
||||
def destroy(self, request, *args, pk=None, **kwargs):
|
||||
task = get_object_or_404(Task, pk=pk)
|
||||
@@ -1206,17 +1215,20 @@ class ResourceViewSet(BaseRLSViewSet):
|
||||
"inserted_at",
|
||||
"updated_at",
|
||||
]
|
||||
# RBAC required permissions (implicit -> MANAGE_PROVIDERS enable unlimited visibility or check the visibility of the provider through the provider group)
|
||||
# RBAC required permissions (implicit -> MANAGE_PROVIDERS enable unlimited visibility or check the visibility of
|
||||
# the provider through the provider group)
|
||||
required_permissions = []
|
||||
|
||||
def get_queryset(self):
|
||||
user_roles = get_role(self.request.user)
|
||||
if user_roles.unlimited_visibility:
|
||||
# User has unlimited visibility, return all scans
|
||||
queryset = Resource.objects.all()
|
||||
queryset = Resource.objects.filter(tenant_id=self.request.tenant_id)
|
||||
else:
|
||||
# User lacks permission, filter providers based on provider groups associated with the role
|
||||
queryset = Resource.objects.filter(provider__in=get_providers(user_roles))
|
||||
queryset = Resource.objects.filter(
|
||||
tenant_id=self.request.tenant_id, provider__in=get_providers(user_roles)
|
||||
)
|
||||
|
||||
search_value = self.request.query_params.get("filter[search]", None)
|
||||
if search_value:
|
||||
@@ -1289,7 +1301,8 @@ class FindingViewSet(BaseRLSViewSet):
|
||||
"inserted_at",
|
||||
"updated_at",
|
||||
]
|
||||
# RBAC required permissions (implicit -> MANAGE_PROVIDERS enable unlimited visibility or check the visibility of the provider through the provider group)
|
||||
# RBAC required permissions (implicit -> MANAGE_PROVIDERS enable unlimited visibility or check the visibility of
|
||||
# the provider through the provider group)
|
||||
required_permissions = []
|
||||
|
||||
def get_serializer_class(self):
|
||||
@@ -1302,7 +1315,7 @@ class FindingViewSet(BaseRLSViewSet):
|
||||
user_roles = get_role(self.request.user)
|
||||
if user_roles.unlimited_visibility:
|
||||
# User has unlimited visibility, return all scans
|
||||
queryset = Finding.objects.all()
|
||||
queryset = Finding.objects.filter(tenant_id=self.request.tenant_id)
|
||||
else:
|
||||
# User lacks permission, filter providers based on provider groups associated with the role
|
||||
queryset = Finding.objects.filter(
|
||||
@@ -1409,7 +1422,7 @@ class ProviderSecretViewSet(BaseRLSViewSet):
|
||||
required_permissions = [Permissions.MANAGE_PROVIDERS]
|
||||
|
||||
def get_queryset(self):
|
||||
return ProviderSecret.objects.all()
|
||||
return ProviderSecret.objects.filter(tenant_id=self.request.tenant_id)
|
||||
|
||||
def get_serializer_class(self):
|
||||
if self.action == "create":
|
||||
@@ -1468,7 +1481,7 @@ class InvitationViewSet(BaseRLSViewSet):
|
||||
required_permissions = [Permissions.MANAGE_ACCOUNT]
|
||||
|
||||
def get_queryset(self):
|
||||
return Invitation.objects.all()
|
||||
return Invitation.objects.filter(tenant_id=self.request.tenant_id)
|
||||
|
||||
def get_serializer_class(self):
|
||||
if self.action == "create":
|
||||
@@ -1515,7 +1528,7 @@ class InvitationAcceptViewSet(BaseRLSViewSet):
|
||||
http_method_names = ["post"]
|
||||
|
||||
def get_queryset(self):
|
||||
return Invitation.objects.all()
|
||||
return Invitation.objects.filter(tenant_id=self.request.tenant_id)
|
||||
|
||||
def get_serializer_class(self):
|
||||
if hasattr(self, "response_serializer_class"):
|
||||
@@ -1607,7 +1620,7 @@ class RoleViewSet(BaseRLSViewSet):
|
||||
required_permissions = [Permissions.MANAGE_ACCOUNT]
|
||||
|
||||
def get_queryset(self):
|
||||
return Role.objects.all()
|
||||
return Role.objects.filter(tenant_id=self.request.tenant_id)
|
||||
|
||||
def get_serializer_class(self):
|
||||
if self.action == "create":
|
||||
@@ -1628,7 +1641,8 @@ class RoleViewSet(BaseRLSViewSet):
|
||||
create=extend_schema(
|
||||
tags=["Role"],
|
||||
summary="Create a new role-provider_groups relationship",
|
||||
description="Add a new role-provider_groups relationship to the system by providing the required role-provider_groups details.",
|
||||
description="Add a new role-provider_groups relationship to the system by providing the required "
|
||||
"role-provider_groups details.",
|
||||
responses={
|
||||
204: OpenApiResponse(description="Relationship created successfully"),
|
||||
400: OpenApiResponse(
|
||||
@@ -1667,7 +1681,7 @@ class RoleProviderGroupRelationshipView(RelationshipView, BaseRLSViewSet):
|
||||
required_permissions = [Permissions.MANAGE_ACCOUNT]
|
||||
|
||||
def get_queryset(self):
|
||||
return Role.objects.all()
|
||||
return Role.objects.filter(tenant_id=self.request.tenant_id)
|
||||
|
||||
def create(self, request, *args, **kwargs):
|
||||
role = self.get_object()
|
||||
@@ -1750,7 +1764,8 @@ class ComplianceOverviewViewSet(BaseRLSViewSet):
|
||||
search_fields = ["compliance_id"]
|
||||
ordering = ["compliance_id"]
|
||||
ordering_fields = ["inserted_at", "compliance_id", "framework", "region"]
|
||||
# RBAC required permissions (implicit -> MANAGE_PROVIDERS enable unlimited visibility or check the visibility of the provider through the provider group)
|
||||
# RBAC required permissions (implicit -> MANAGE_PROVIDERS enable unlimited visibility or check the visibility of
|
||||
# the provider through the provider group)
|
||||
required_permissions = []
|
||||
|
||||
def get_queryset(self):
|
||||
@@ -1761,20 +1776,28 @@ class ComplianceOverviewViewSet(BaseRLSViewSet):
|
||||
|
||||
if self.action == "retrieve":
|
||||
if unlimited_visibility:
|
||||
# User has unlimited visibility, return all compliance compliances
|
||||
return ComplianceOverview.objects.all()
|
||||
# User has unlimited visibility, return all compliance
|
||||
return ComplianceOverview.objects.filter(
|
||||
tenant_id=self.request.tenant_id
|
||||
)
|
||||
|
||||
providers = get_providers(role)
|
||||
return ComplianceOverview.objects.filter(scan__provider__in=providers)
|
||||
return ComplianceOverview.objects.filter(
|
||||
tenant_id=self.request.tenant_id, scan__provider__in=providers
|
||||
)
|
||||
|
||||
if unlimited_visibility:
|
||||
base_queryset = self.filter_queryset(ComplianceOverview.objects.all())
|
||||
base_queryset = self.filter_queryset(
|
||||
ComplianceOverview.objects.filter(tenant_id=self.request.tenant_id)
|
||||
)
|
||||
else:
|
||||
providers = Provider.objects.filter(
|
||||
provider_groups__in=role.provider_groups.all()
|
||||
).distinct()
|
||||
base_queryset = self.filter_queryset(
|
||||
ComplianceOverview.objects.filter(scan__provider__in=providers)
|
||||
ComplianceOverview.objects.filter(
|
||||
tenant_id=self.request.tenant_id, scan__provider__in=providers
|
||||
)
|
||||
)
|
||||
|
||||
max_failed_ids = (
|
||||
@@ -1853,7 +1876,8 @@ class OverviewViewSet(BaseRLSViewSet):
|
||||
queryset = ComplianceOverview.objects.all()
|
||||
http_method_names = ["get"]
|
||||
ordering = ["-id"]
|
||||
# RBAC required permissions (implicit -> MANAGE_PROVIDERS enable unlimited visibility or check the visibility of the provider through the provider group)
|
||||
# RBAC required permissions (implicit -> MANAGE_PROVIDERS enable unlimited visibility or check the visibility of
|
||||
# the provider through the provider group)
|
||||
required_permissions = []
|
||||
|
||||
def get_queryset(self):
|
||||
@@ -1862,8 +1886,10 @@ class OverviewViewSet(BaseRLSViewSet):
|
||||
|
||||
def _get_filtered_queryset(model):
|
||||
if role.unlimited_visibility:
|
||||
return model.objects.all()
|
||||
return model.objects.filter(scan__provider__in=providers)
|
||||
return model.objects.filter(tenant_id=self.request.tenant_id)
|
||||
return model.objects.filter(
|
||||
tenant_id=self.request.tenant_id, scan__provider__in=providers
|
||||
)
|
||||
|
||||
if self.action == "providers":
|
||||
return _get_filtered_queryset(Finding)
|
||||
@@ -1902,17 +1928,22 @@ class OverviewViewSet(BaseRLSViewSet):
|
||||
|
||||
@action(detail=False, methods=["get"], url_name="providers")
|
||||
def providers(self, request):
|
||||
tenant_id = self.request.tenant_id
|
||||
# Subquery to get the most recent finding for each uid
|
||||
latest_finding_ids = (
|
||||
Finding.objects.filter(
|
||||
uid=OuterRef("uid"), scan__provider=OuterRef("scan__provider")
|
||||
tenant_id=tenant_id,
|
||||
uid=OuterRef("uid"),
|
||||
scan__provider=OuterRef("scan__provider"),
|
||||
)
|
||||
.order_by("-id") # Most recent by id
|
||||
.values("id")[:1]
|
||||
)
|
||||
|
||||
# Filter findings to only include the most recent for each uid
|
||||
recent_findings = Finding.objects.filter(id__in=Subquery(latest_finding_ids))
|
||||
recent_findings = Finding.objects.filter(
|
||||
tenant_id=tenant_id, id__in=Subquery(latest_finding_ids)
|
||||
)
|
||||
|
||||
# Aggregate findings by provider
|
||||
findings_aggregated = (
|
||||
@@ -1929,8 +1960,10 @@ class OverviewViewSet(BaseRLSViewSet):
|
||||
)
|
||||
|
||||
# Aggregate total resources by provider
|
||||
resources_aggregated = Resource.objects.values("provider__provider").annotate(
|
||||
total_resources=Count("id")
|
||||
resources_aggregated = (
|
||||
Resource.objects.filter(tenant_id=tenant_id)
|
||||
.values("provider__provider")
|
||||
.annotate(total_resources=Count("id"))
|
||||
)
|
||||
|
||||
# Combine findings and resources data
|
||||
@@ -1962,12 +1995,15 @@ class OverviewViewSet(BaseRLSViewSet):
|
||||
|
||||
@action(detail=False, methods=["get"], url_name="findings")
|
||||
def findings(self, request):
|
||||
tenant_id = self.request.tenant_id
|
||||
queryset = self.get_queryset()
|
||||
filtered_queryset = self.filter_queryset(queryset)
|
||||
|
||||
latest_scan_subquery = (
|
||||
Scan.objects.filter(
|
||||
state=StateChoices.COMPLETED, provider_id=OuterRef("scan__provider_id")
|
||||
tenant_id=tenant_id,
|
||||
state=StateChoices.COMPLETED,
|
||||
provider_id=OuterRef("scan__provider_id"),
|
||||
)
|
||||
.order_by("-id")
|
||||
.values("id")[:1]
|
||||
@@ -2004,12 +2040,15 @@ class OverviewViewSet(BaseRLSViewSet):
|
||||
|
||||
@action(detail=False, methods=["get"], url_name="findings_severity")
|
||||
def findings_severity(self, request):
|
||||
tenant_id = self.request.tenant_id
|
||||
queryset = self.get_queryset()
|
||||
filtered_queryset = self.filter_queryset(queryset)
|
||||
|
||||
latest_scan_subquery = (
|
||||
Scan.objects.filter(
|
||||
state=StateChoices.COMPLETED, provider_id=OuterRef("scan__provider_id")
|
||||
tenant_id=tenant_id,
|
||||
state=StateChoices.COMPLETED,
|
||||
provider_id=OuterRef("scan__provider_id"),
|
||||
)
|
||||
.order_by("-id")
|
||||
.values("id")[:1]
|
||||
@@ -2037,12 +2076,15 @@ class OverviewViewSet(BaseRLSViewSet):
|
||||
|
||||
@action(detail=False, methods=["get"], url_name="services")
|
||||
def services(self, request):
|
||||
tenant_id = self.request.tenant_id
|
||||
queryset = self.get_queryset()
|
||||
filtered_queryset = self.filter_queryset(queryset)
|
||||
|
||||
latest_scan_subquery = (
|
||||
Scan.objects.filter(
|
||||
state=StateChoices.COMPLETED, provider_id=OuterRef("scan__provider_id")
|
||||
tenant_id=tenant_id,
|
||||
state=StateChoices.COMPLETED,
|
||||
provider_id=OuterRef("scan__provider_id"),
|
||||
)
|
||||
.order_by("-id")
|
||||
.values("id")[:1]
|
||||
|
||||
@@ -88,16 +88,14 @@ def create_test_user(django_db_setup, django_db_blocker):
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def create_test_user_rbac(django_db_setup, django_db_blocker):
|
||||
def create_test_user_rbac(django_db_setup, django_db_blocker, tenants_fixture):
|
||||
with django_db_blocker.unblock():
|
||||
user = User.objects.create_user(
|
||||
name="testing",
|
||||
email="rbac@rbac.com",
|
||||
password=TEST_PASSWORD,
|
||||
)
|
||||
tenant = Tenant.objects.create(
|
||||
name="Tenant Test",
|
||||
)
|
||||
tenant = tenants_fixture[0]
|
||||
Membership.objects.create(
|
||||
user=user,
|
||||
tenant=tenant,
|
||||
@@ -123,16 +121,14 @@ def create_test_user_rbac(django_db_setup, django_db_blocker):
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def create_test_user_rbac_no_roles(django_db_setup, django_db_blocker):
|
||||
def create_test_user_rbac_no_roles(django_db_setup, django_db_blocker, tenants_fixture):
|
||||
with django_db_blocker.unblock():
|
||||
user = User.objects.create_user(
|
||||
name="testing",
|
||||
email="rbac_noroles@rbac.com",
|
||||
password=TEST_PASSWORD,
|
||||
)
|
||||
tenant = Tenant.objects.create(
|
||||
name="Tenant Test",
|
||||
)
|
||||
tenant = tenants_fixture[0]
|
||||
Membership.objects.create(
|
||||
user=user,
|
||||
tenant=tenant,
|
||||
@@ -180,10 +176,16 @@ def create_test_user_rbac_limited(django_db_setup, django_db_blocker):
|
||||
@pytest.fixture
|
||||
def authenticated_client_rbac(create_test_user_rbac, tenants_fixture, client):
|
||||
client.user = create_test_user_rbac
|
||||
tenant_id = tenants_fixture[0].id
|
||||
serializer = TokenSerializer(
|
||||
data={"type": "tokens", "email": "rbac@rbac.com", "password": TEST_PASSWORD}
|
||||
data={
|
||||
"type": "tokens",
|
||||
"email": "rbac@rbac.com",
|
||||
"password": TEST_PASSWORD,
|
||||
"tenant_id": tenant_id,
|
||||
}
|
||||
)
|
||||
serializer.is_valid()
|
||||
serializer.is_valid(raise_exception=True)
|
||||
access_token = serializer.validated_data["access"]
|
||||
client.defaults["HTTP_AUTHORIZATION"] = f"Bearer {access_token}"
|
||||
return client
|
||||
@@ -303,7 +305,7 @@ def set_user_admin_roles_fixture(create_test_user, tenants_fixture):
|
||||
@pytest.fixture
|
||||
def invitations_fixture(create_test_user, tenants_fixture):
|
||||
user = create_test_user
|
||||
*_, tenant = tenants_fixture
|
||||
tenant = tenants_fixture[0]
|
||||
valid_invitation = Invitation.objects.create(
|
||||
email="testing@prowler.com",
|
||||
state=Invitation.State.PENDING,
|
||||
|
||||
Reference in New Issue
Block a user