From b620f1202783cd8a8acb46d6d15c7fe35203af7e Mon Sep 17 00:00:00 2001 From: Pepe Fagoaga Date: Mon, 13 Jan 2025 16:22:40 +0545 Subject: [PATCH] chore(rls): Add tenant_id filters in views and improve querysets (#6211) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Víctor Fernández Poyatos --- api/src/backend/api/models.py | 4 +- api/src/backend/api/rbac/permissions.py | 13 ++- api/src/backend/api/tests/test_models.py | 37 +++--- api/src/backend/api/tests/test_views.py | 31 ++++-- api/src/backend/api/v1/serializers.py | 25 ++++- api/src/backend/api/v1/views.py | 136 +++++++++++++++-------- api/src/backend/conftest.py | 24 ++-- 7 files changed, 177 insertions(+), 93 deletions(-) diff --git a/api/src/backend/api/models.py b/api/src/backend/api/models.py index daecae276b..3c240a2f7c 100644 --- a/api/src/backend/api/models.py +++ b/api/src/backend/api/models.py @@ -515,8 +515,8 @@ class Resource(RowLevelSecurityProtectedModel): through="ResourceTagMapping", ) - def get_tags(self) -> dict: - return {tag.key: tag.value for tag in self.tags.all()} + def get_tags(self, tenant_id: str) -> dict: + return {tag.key: tag.value for tag in self.tags.filter(tenant_id=tenant_id)} def clear_tags(self): self.tags.clear() diff --git a/api/src/backend/api/rbac/permissions.py b/api/src/backend/api/rbac/permissions.py index abc435846b..6a95e82932 100644 --- a/api/src/backend/api/rbac/permissions.py +++ b/api/src/backend/api/rbac/permissions.py @@ -1,9 +1,11 @@ from enum import Enum -from rest_framework.permissions import BasePermission -from api.models import Provider, Role, User -from api.db_router import MainRouter from typing import Optional + from django.db.models import QuerySet +from rest_framework.permissions import BasePermission + +from api.db_router import MainRouter +from api.models import Provider, Role, User class Permissions(Enum): @@ -63,8 +65,11 @@ def get_providers(role: Role) -> QuerySet[Provider]: A QuerySet of Provider objects filtered by the role's provider groups. If the role has no provider groups, returns an empty queryset. """ + tenant = role.tenant provider_groups = role.provider_groups.all() if not provider_groups.exists(): return Provider.objects.none() - return Provider.objects.filter(provider_groups__in=provider_groups).distinct() + return Provider.objects.filter( + tenant=tenant, provider_groups__in=provider_groups + ).distinct() diff --git a/api/src/backend/api/tests/test_models.py b/api/src/backend/api/tests/test_models.py index c7fdf9deb1..b6ef1a66c9 100644 --- a/api/src/backend/api/tests/test_models.py +++ b/api/src/backend/api/tests/test_models.py @@ -7,9 +7,10 @@ from api.models import Resource, ResourceTag class TestResourceModel: def test_setting_tags(self, providers_fixture): provider, *_ = providers_fixture + tenant_id = provider.tenant_id resource = Resource.objects.create( - tenant_id=provider.tenant_id, + tenant_id=tenant_id, provider=provider, uid="arn:aws:ec2:us-east-1:123456789012:instance/i-1234567890abcdef0", name="My Instance 1", @@ -20,12 +21,12 @@ class TestResourceModel: tags = [ ResourceTag.objects.create( - tenant_id=provider.tenant_id, + tenant_id=tenant_id, key="key", value="value", ), ResourceTag.objects.create( - tenant_id=provider.tenant_id, + tenant_id=tenant_id, key="key2", value="value2", ), @@ -33,9 +34,9 @@ class TestResourceModel: resource.upsert_or_delete_tags(tags) - assert len(tags) == len(resource.tags.all()) + assert len(tags) == len(resource.tags.filter(tenant_id=tenant_id)) - tags_dict = resource.get_tags() + tags_dict = resource.get_tags(tenant_id=tenant_id) for tag in tags: assert tag.key in tags_dict @@ -43,47 +44,51 @@ class TestResourceModel: def test_adding_tags(self, resources_fixture): resource, *_ = resources_fixture + tenant_id = str(resource.tenant_id) tags = [ ResourceTag.objects.create( - tenant_id=resource.tenant_id, + tenant_id=tenant_id, key="env", value="test", ), ] - before_count = len(resource.tags.all()) + before_count = len(resource.tags.filter(tenant_id=tenant_id)) resource.upsert_or_delete_tags(tags) - assert before_count + 1 == len(resource.tags.all()) + assert before_count + 1 == len(resource.tags.filter(tenant_id=tenant_id)) - tags_dict = resource.get_tags() + tags_dict = resource.get_tags(tenant_id=tenant_id) assert "env" in tags_dict assert tags_dict["env"] == "test" def test_adding_duplicate_tags(self, resources_fixture): resource, *_ = resources_fixture + tenant_id = str(resource.tenant_id) - tags = resource.tags.all() + tags = resource.tags.filter(tenant_id=tenant_id) - before_count = len(resource.tags.all()) + before_count = len(resource.tags.filter(tenant_id=tenant_id)) resource.upsert_or_delete_tags(tags) # should be the same number of tags - assert before_count == len(resource.tags.all()) + assert before_count == len(resource.tags.filter(tenant_id=tenant_id)) def test_add_tags_none(self, resources_fixture): resource, *_ = resources_fixture + tenant_id = str(resource.tenant_id) resource.upsert_or_delete_tags(None) - assert len(resource.tags.all()) == 0 - assert resource.get_tags() == {} + assert len(resource.tags.filter(tenant_id=tenant_id)) == 0 + assert resource.get_tags(tenant_id=tenant_id) == {} def test_clear_tags(self, resources_fixture): resource, *_ = resources_fixture + tenant_id = str(resource.tenant_id) resource.clear_tags() - assert len(resource.tags.all()) == 0 - assert resource.get_tags() == {} + assert len(resource.tags.filter(tenant_id=tenant_id)) == 0 + assert resource.get_tags(tenant_id=tenant_id) == {} diff --git a/api/src/backend/api/tests/test_views.py b/api/src/backend/api/tests/test_views.py index 73677da6a6..9b057bcdd2 100644 --- a/api/src/backend/api/tests/test_views.py +++ b/api/src/backend/api/tests/test_views.py @@ -340,7 +340,7 @@ class TestTenantViewSet: def test_tenants_list(self, authenticated_client, tenants_fixture): response = authenticated_client.get(reverse("tenant-list")) assert response.status_code == status.HTTP_200_OK - assert len(response.json()["data"]) == len(tenants_fixture) + assert len(response.json()["data"]) == 2 # Test user belongs to 2 tenants def test_tenants_retrieve(self, authenticated_client, tenants_fixture): tenant1, *_ = tenants_fixture @@ -470,11 +470,11 @@ class TestTenantViewSet: ( [ ("name", "Tenant One", 1), - ("name.icontains", "Tenant", 3), - ("inserted_at", TODAY, 3), - ("inserted_at.gte", "2024-01-01", 3), + ("name.icontains", "Tenant", 2), + ("inserted_at", TODAY, 2), + ("inserted_at.gte", "2024-01-01", 2), ("inserted_at.lte", "2024-01-01", 0), - ("updated_at.gte", "2024-01-01", 3), + ("updated_at.gte", "2024-01-01", 2), ("updated_at.lte", "2024-01-01", 0), ] ), @@ -510,7 +510,9 @@ class TestTenantViewSet: assert response.status_code == status.HTTP_200_OK assert len(response.json()["data"]) == page_size assert response.json()["meta"]["pagination"]["page"] == 1 - assert response.json()["meta"]["pagination"]["pages"] == len(tenants_fixture) + assert ( + response.json()["meta"]["pagination"]["pages"] == 2 + ) # Test user belongs to 2 tenants def test_tenants_list_page_number(self, authenticated_client, tenants_fixture): page_size = 1 @@ -523,13 +525,13 @@ class TestTenantViewSet: assert response.status_code == status.HTTP_200_OK assert len(response.json()["data"]) == page_size assert response.json()["meta"]["pagination"]["page"] == page_number - assert response.json()["meta"]["pagination"]["pages"] == len(tenants_fixture) + assert response.json()["meta"]["pagination"]["pages"] == 2 def test_tenants_list_sort_name(self, authenticated_client, tenants_fixture): _, tenant2, _ = tenants_fixture response = authenticated_client.get(reverse("tenant-list"), {"sort": "-name"}) assert response.status_code == status.HTTP_200_OK - assert len(response.json()["data"]) == 3 + assert len(response.json()["data"]) == 2 assert response.json()["data"][0]["attributes"]["name"] == tenant2.name def test_tenants_list_memberships_as_owner( @@ -2339,7 +2341,10 @@ class TestResourceViewSet: response.json()["errors"][0]["detail"] == "invalid sort parameter: invalid" ) - def test_resources_retrieve(self, authenticated_client, resources_fixture): + def test_resources_retrieve( + self, authenticated_client, tenants_fixture, resources_fixture + ): + tenant = tenants_fixture[0] resource_1, *_ = resources_fixture response = authenticated_client.get( reverse("resource-detail", kwargs={"pk": resource_1.id}), @@ -2350,7 +2355,9 @@ class TestResourceViewSet: assert response.json()["data"]["attributes"]["region"] == resource_1.region assert response.json()["data"]["attributes"]["service"] == resource_1.service assert response.json()["data"]["attributes"]["type"] == resource_1.type - assert response.json()["data"]["attributes"]["tags"] == resource_1.get_tags() + assert response.json()["data"]["attributes"]["tags"] == resource_1.get_tags( + tenant_id=str(tenant.id) + ) def test_resources_invalid_retrieve(self, authenticated_client): response = authenticated_client.get( @@ -3261,8 +3268,8 @@ class TestRoleViewSet: response = authenticated_client.get(reverse("role-list")) assert response.status_code == status.HTTP_200_OK assert ( - len(response.json()["data"]) == len(roles_fixture) + 2 - ) # 2 default admin roles, one for each tenant + len(response.json()["data"]) == len(roles_fixture) + 1 + ) # 1 default admin role def test_role_retrieve(self, authenticated_client, roles_fixture): role = roles_fixture[0] diff --git a/api/src/backend/api/v1/serializers.py b/api/src/backend/api/v1/serializers.py index a2aec3d64c..2cd20c941b 100644 --- a/api/src/backend/api/v1/serializers.py +++ b/api/src/backend/api/v1/serializers.py @@ -874,7 +874,7 @@ class ResourceSerializer(RLSSerializer): } ) def get_tags(self, obj): - return obj.get_tags() + return obj.get_tags(self.context.get("tenant_id")) def get_fields(self): """`type` is a Python reserved keyword.""" @@ -1233,6 +1233,12 @@ class InvitationSerializer(RLSSerializer): roles = serializers.ResourceRelatedField(many=True, queryset=Role.objects.all()) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + tenant_id = self.context.get("tenant_id") + if tenant_id is not None: + self.fields["roles"].queryset = Role.objects.filter(tenant_id=tenant_id) + class Meta: model = Invitation fields = [ @@ -1252,6 +1258,12 @@ class InvitationSerializer(RLSSerializer): class InvitationBaseWriteSerializer(BaseWriteSerializer): roles = serializers.ResourceRelatedField(many=True, queryset=Role.objects.all()) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + tenant_id = self.context.get("tenant_id") + if tenant_id is not None: + self.fields["roles"].queryset = Role.objects.filter(tenant_id=tenant_id) + def validate_email(self, value): user = User.objects.filter(email=value).first() tenant_id = self.context["tenant_id"] @@ -1367,6 +1379,17 @@ class RoleSerializer(RLSSerializer, BaseWriteSerializer): queryset=ProviderGroup.objects.all(), many=True, required=False ) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + tenant_id = self.context.get("tenant_id") + if tenant_id is not None: + self.fields["users"].queryset = User.objects.filter( + membership__tenant__id=tenant_id + ) + self.fields["provider_groups"].queryset = ProviderGroup.objects.filter( + tenant_id=self.context.get("tenant_id") + ) + def get_permission_state(self, obj) -> str: return obj.permission_state diff --git a/api/src/backend/api/v1/views.py b/api/src/backend/api/v1/views.py index c23a650ad9..ce24a2f830 100644 --- a/api/src/backend/api/v1/views.py +++ b/api/src/backend/api/v1/views.py @@ -309,7 +309,12 @@ class UserViewSet(BaseUserViewset): # If called during schema generation, return an empty queryset if getattr(self, "swagger_fake_view", False): return User.objects.none() - return User.objects.filter(membership__tenant__id=self.request.tenant_id) + queryset = ( + User.objects.filter(membership__tenant__id=self.request.tenant_id) + if hasattr(self.request, "tenant_id") + else User.objects.all() + ) + return queryset.prefetch_related("memberships", "roles") def get_permissions(self): if self.action == "create": @@ -452,7 +457,7 @@ class UserRoleRelationshipView(RelationshipView, BaseRLSViewSet): required_permissions = [Permissions.MANAGE_USERS] def get_queryset(self): - return User.objects.all() + return User.objects.filter(membership__tenant__id=self.request.tenant_id) def create(self, request, *args, **kwargs): user = self.get_object() @@ -540,7 +545,8 @@ class TenantViewSet(BaseTenantViewset): required_permissions = [Permissions.MANAGE_ACCOUNT] def get_queryset(self): - return Tenant.objects.all() + queryset = Tenant.objects.filter(membership__user=self.request.user) + return queryset.prefetch_related("memberships") def create(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.data) @@ -600,7 +606,8 @@ class MembershipViewSet(BaseTenantViewset): def get_queryset(self): user = self.request.user - return Membership.objects.filter(user_id=user.id) + queryset = Membership.objects.filter(user_id=user.id) + return queryset.select_related("user", "tenant") @extend_schema_view( @@ -736,10 +743,10 @@ class ProviderGroupViewSet(BaseRLSViewSet): # Check if any of the user's roles have UNLIMITED_VISIBILITY if user_roles.unlimited_visibility: # User has unlimited visibility, return all provider groups - return ProviderGroup.objects.prefetch_related("providers") + return ProviderGroup.objects.prefetch_related("providers", "roles") # Collect provider groups associated with the user's roles - return user_roles.provider_groups.all() + return user_roles.provider_groups.all().prefetch_related("providers", "roles") def get_serializer_class(self): if self.action == "create": @@ -790,7 +797,7 @@ class ProviderGroupProvidersRelationshipView(RelationshipView, BaseRLSViewSet): required_permissions = [Permissions.MANAGE_PROVIDERS] def get_queryset(self): - return ProviderGroup.objects.all() + return ProviderGroup.objects.filter(tenant_id=self.request.tenant_id) def create(self, request, *args, **kwargs): provider_group = self.get_object() @@ -904,10 +911,11 @@ class ProviderViewSet(BaseRLSViewSet): user_roles = get_role(self.request.user) if user_roles.unlimited_visibility: # User has unlimited visibility, return all providers - return Provider.objects.all() - - # User lacks permission, filter providers based on provider groups associated with the role - return get_providers(user_roles) + queryset = Provider.objects.filter(tenant_id=self.request.tenant_id) + else: + # User lacks permission, filter providers based on provider groups associated with the role + queryset = get_providers(user_roles) + return queryset.select_related("secret").prefetch_related("provider_groups") def get_serializer_class(self): if self.action == "create": @@ -945,7 +953,7 @@ class ProviderViewSet(BaseRLSViewSet): get_object_or_404(Provider, pk=pk) with transaction.atomic(): task = check_provider_connection_task.delay( - provider_id=pk, tenant_id=request.tenant_id + provider_id=pk, tenant_id=self.request.tenant_id ) prowler_task = Task.objects.get(id=task.id) serializer = TaskSerializer(prowler_task) @@ -966,7 +974,7 @@ class ProviderViewSet(BaseRLSViewSet): with transaction.atomic(): task = delete_provider_task.delay( - provider_id=pk, tenant_id=request.tenant_id + provider_id=pk, tenant_id=self.request.tenant_id ) prowler_task = Task.objects.get(id=task.id) serializer = TaskSerializer(prowler_task) @@ -1045,10 +1053,11 @@ class ScanViewSet(BaseRLSViewSet): user_roles = get_role(self.request.user) if user_roles.unlimited_visibility: # User has unlimited visibility, return all scans - return Scan.objects.all() - - # User lacks permission, filter providers based on provider groups associated with the role - return Scan.objects.filter(provider__in=get_providers(user_roles)) + queryset = Scan.objects.filter(tenant_id=self.request.tenant_id) + else: + # User lacks permission, filter providers based on provider groups associated with the role + queryset = Scan.objects.filter(provider__in=get_providers(user_roles)) + return queryset.select_related("provider", "task") def get_serializer_class(self): if self.action == "create": @@ -1082,14 +1091,14 @@ class ScanViewSet(BaseRLSViewSet): with transaction.atomic(): task = perform_scan_task.apply_async( kwargs={ - "tenant_id": request.tenant_id, + "tenant_id": self.request.tenant_id, "scan_id": str(scan.id), "provider_id": str(scan.provider_id), # Disabled for now # checks_to_execute=scan.scanner_args.get("checks_to_execute"), }, link=perform_scan_summary_task.si( - tenant_id=request.tenant_id, + tenant_id=self.request.tenant_id, scan_id=str(scan.id), ), ) @@ -1145,7 +1154,7 @@ class TaskViewSet(BaseRLSViewSet): return Task.objects.annotate( name=F("task_runner_task__task_name"), state=F("task_runner_task__status"), - ) + ).select_related("task_runner_task") def destroy(self, request, *args, pk=None, **kwargs): task = get_object_or_404(Task, pk=pk) @@ -1206,17 +1215,20 @@ class ResourceViewSet(BaseRLSViewSet): "inserted_at", "updated_at", ] - # RBAC required permissions (implicit -> MANAGE_PROVIDERS enable unlimited visibility or check the visibility of the provider through the provider group) + # RBAC required permissions (implicit -> MANAGE_PROVIDERS enable unlimited visibility or check the visibility of + # the provider through the provider group) required_permissions = [] def get_queryset(self): user_roles = get_role(self.request.user) if user_roles.unlimited_visibility: # User has unlimited visibility, return all scans - queryset = Resource.objects.all() + queryset = Resource.objects.filter(tenant_id=self.request.tenant_id) else: # User lacks permission, filter providers based on provider groups associated with the role - queryset = Resource.objects.filter(provider__in=get_providers(user_roles)) + queryset = Resource.objects.filter( + tenant_id=self.request.tenant_id, provider__in=get_providers(user_roles) + ) search_value = self.request.query_params.get("filter[search]", None) if search_value: @@ -1289,7 +1301,8 @@ class FindingViewSet(BaseRLSViewSet): "inserted_at", "updated_at", ] - # RBAC required permissions (implicit -> MANAGE_PROVIDERS enable unlimited visibility or check the visibility of the provider through the provider group) + # RBAC required permissions (implicit -> MANAGE_PROVIDERS enable unlimited visibility or check the visibility of + # the provider through the provider group) required_permissions = [] def get_serializer_class(self): @@ -1302,7 +1315,7 @@ class FindingViewSet(BaseRLSViewSet): user_roles = get_role(self.request.user) if user_roles.unlimited_visibility: # User has unlimited visibility, return all scans - queryset = Finding.objects.all() + queryset = Finding.objects.filter(tenant_id=self.request.tenant_id) else: # User lacks permission, filter providers based on provider groups associated with the role queryset = Finding.objects.filter( @@ -1409,7 +1422,7 @@ class ProviderSecretViewSet(BaseRLSViewSet): required_permissions = [Permissions.MANAGE_PROVIDERS] def get_queryset(self): - return ProviderSecret.objects.all() + return ProviderSecret.objects.filter(tenant_id=self.request.tenant_id) def get_serializer_class(self): if self.action == "create": @@ -1468,7 +1481,7 @@ class InvitationViewSet(BaseRLSViewSet): required_permissions = [Permissions.MANAGE_ACCOUNT] def get_queryset(self): - return Invitation.objects.all() + return Invitation.objects.filter(tenant_id=self.request.tenant_id) def get_serializer_class(self): if self.action == "create": @@ -1515,7 +1528,7 @@ class InvitationAcceptViewSet(BaseRLSViewSet): http_method_names = ["post"] def get_queryset(self): - return Invitation.objects.all() + return Invitation.objects.filter(tenant_id=self.request.tenant_id) def get_serializer_class(self): if hasattr(self, "response_serializer_class"): @@ -1607,7 +1620,7 @@ class RoleViewSet(BaseRLSViewSet): required_permissions = [Permissions.MANAGE_ACCOUNT] def get_queryset(self): - return Role.objects.all() + return Role.objects.filter(tenant_id=self.request.tenant_id) def get_serializer_class(self): if self.action == "create": @@ -1628,7 +1641,8 @@ class RoleViewSet(BaseRLSViewSet): create=extend_schema( tags=["Role"], summary="Create a new role-provider_groups relationship", - description="Add a new role-provider_groups relationship to the system by providing the required role-provider_groups details.", + description="Add a new role-provider_groups relationship to the system by providing the required " + "role-provider_groups details.", responses={ 204: OpenApiResponse(description="Relationship created successfully"), 400: OpenApiResponse( @@ -1667,7 +1681,7 @@ class RoleProviderGroupRelationshipView(RelationshipView, BaseRLSViewSet): required_permissions = [Permissions.MANAGE_ACCOUNT] def get_queryset(self): - return Role.objects.all() + return Role.objects.filter(tenant_id=self.request.tenant_id) def create(self, request, *args, **kwargs): role = self.get_object() @@ -1750,7 +1764,8 @@ class ComplianceOverviewViewSet(BaseRLSViewSet): search_fields = ["compliance_id"] ordering = ["compliance_id"] ordering_fields = ["inserted_at", "compliance_id", "framework", "region"] - # RBAC required permissions (implicit -> MANAGE_PROVIDERS enable unlimited visibility or check the visibility of the provider through the provider group) + # RBAC required permissions (implicit -> MANAGE_PROVIDERS enable unlimited visibility or check the visibility of + # the provider through the provider group) required_permissions = [] def get_queryset(self): @@ -1761,20 +1776,28 @@ class ComplianceOverviewViewSet(BaseRLSViewSet): if self.action == "retrieve": if unlimited_visibility: - # User has unlimited visibility, return all compliance compliances - return ComplianceOverview.objects.all() + # User has unlimited visibility, return all compliance + return ComplianceOverview.objects.filter( + tenant_id=self.request.tenant_id + ) providers = get_providers(role) - return ComplianceOverview.objects.filter(scan__provider__in=providers) + return ComplianceOverview.objects.filter( + tenant_id=self.request.tenant_id, scan__provider__in=providers + ) if unlimited_visibility: - base_queryset = self.filter_queryset(ComplianceOverview.objects.all()) + base_queryset = self.filter_queryset( + ComplianceOverview.objects.filter(tenant_id=self.request.tenant_id) + ) else: providers = Provider.objects.filter( provider_groups__in=role.provider_groups.all() ).distinct() base_queryset = self.filter_queryset( - ComplianceOverview.objects.filter(scan__provider__in=providers) + ComplianceOverview.objects.filter( + tenant_id=self.request.tenant_id, scan__provider__in=providers + ) ) max_failed_ids = ( @@ -1853,7 +1876,8 @@ class OverviewViewSet(BaseRLSViewSet): queryset = ComplianceOverview.objects.all() http_method_names = ["get"] ordering = ["-id"] - # RBAC required permissions (implicit -> MANAGE_PROVIDERS enable unlimited visibility or check the visibility of the provider through the provider group) + # RBAC required permissions (implicit -> MANAGE_PROVIDERS enable unlimited visibility or check the visibility of + # the provider through the provider group) required_permissions = [] def get_queryset(self): @@ -1862,8 +1886,10 @@ class OverviewViewSet(BaseRLSViewSet): def _get_filtered_queryset(model): if role.unlimited_visibility: - return model.objects.all() - return model.objects.filter(scan__provider__in=providers) + return model.objects.filter(tenant_id=self.request.tenant_id) + return model.objects.filter( + tenant_id=self.request.tenant_id, scan__provider__in=providers + ) if self.action == "providers": return _get_filtered_queryset(Finding) @@ -1902,17 +1928,22 @@ class OverviewViewSet(BaseRLSViewSet): @action(detail=False, methods=["get"], url_name="providers") def providers(self, request): + tenant_id = self.request.tenant_id # Subquery to get the most recent finding for each uid latest_finding_ids = ( Finding.objects.filter( - uid=OuterRef("uid"), scan__provider=OuterRef("scan__provider") + tenant_id=tenant_id, + uid=OuterRef("uid"), + scan__provider=OuterRef("scan__provider"), ) .order_by("-id") # Most recent by id .values("id")[:1] ) # Filter findings to only include the most recent for each uid - recent_findings = Finding.objects.filter(id__in=Subquery(latest_finding_ids)) + recent_findings = Finding.objects.filter( + tenant_id=tenant_id, id__in=Subquery(latest_finding_ids) + ) # Aggregate findings by provider findings_aggregated = ( @@ -1929,8 +1960,10 @@ class OverviewViewSet(BaseRLSViewSet): ) # Aggregate total resources by provider - resources_aggregated = Resource.objects.values("provider__provider").annotate( - total_resources=Count("id") + resources_aggregated = ( + Resource.objects.filter(tenant_id=tenant_id) + .values("provider__provider") + .annotate(total_resources=Count("id")) ) # Combine findings and resources data @@ -1962,12 +1995,15 @@ class OverviewViewSet(BaseRLSViewSet): @action(detail=False, methods=["get"], url_name="findings") def findings(self, request): + tenant_id = self.request.tenant_id queryset = self.get_queryset() filtered_queryset = self.filter_queryset(queryset) latest_scan_subquery = ( Scan.objects.filter( - state=StateChoices.COMPLETED, provider_id=OuterRef("scan__provider_id") + tenant_id=tenant_id, + state=StateChoices.COMPLETED, + provider_id=OuterRef("scan__provider_id"), ) .order_by("-id") .values("id")[:1] @@ -2004,12 +2040,15 @@ class OverviewViewSet(BaseRLSViewSet): @action(detail=False, methods=["get"], url_name="findings_severity") def findings_severity(self, request): + tenant_id = self.request.tenant_id queryset = self.get_queryset() filtered_queryset = self.filter_queryset(queryset) latest_scan_subquery = ( Scan.objects.filter( - state=StateChoices.COMPLETED, provider_id=OuterRef("scan__provider_id") + tenant_id=tenant_id, + state=StateChoices.COMPLETED, + provider_id=OuterRef("scan__provider_id"), ) .order_by("-id") .values("id")[:1] @@ -2037,12 +2076,15 @@ class OverviewViewSet(BaseRLSViewSet): @action(detail=False, methods=["get"], url_name="services") def services(self, request): + tenant_id = self.request.tenant_id queryset = self.get_queryset() filtered_queryset = self.filter_queryset(queryset) latest_scan_subquery = ( Scan.objects.filter( - state=StateChoices.COMPLETED, provider_id=OuterRef("scan__provider_id") + tenant_id=tenant_id, + state=StateChoices.COMPLETED, + provider_id=OuterRef("scan__provider_id"), ) .order_by("-id") .values("id")[:1] diff --git a/api/src/backend/conftest.py b/api/src/backend/conftest.py index d05de31804..77b9bc41d3 100644 --- a/api/src/backend/conftest.py +++ b/api/src/backend/conftest.py @@ -88,16 +88,14 @@ def create_test_user(django_db_setup, django_db_blocker): @pytest.fixture(scope="function") -def create_test_user_rbac(django_db_setup, django_db_blocker): +def create_test_user_rbac(django_db_setup, django_db_blocker, tenants_fixture): with django_db_blocker.unblock(): user = User.objects.create_user( name="testing", email="rbac@rbac.com", password=TEST_PASSWORD, ) - tenant = Tenant.objects.create( - name="Tenant Test", - ) + tenant = tenants_fixture[0] Membership.objects.create( user=user, tenant=tenant, @@ -123,16 +121,14 @@ def create_test_user_rbac(django_db_setup, django_db_blocker): @pytest.fixture(scope="function") -def create_test_user_rbac_no_roles(django_db_setup, django_db_blocker): +def create_test_user_rbac_no_roles(django_db_setup, django_db_blocker, tenants_fixture): with django_db_blocker.unblock(): user = User.objects.create_user( name="testing", email="rbac_noroles@rbac.com", password=TEST_PASSWORD, ) - tenant = Tenant.objects.create( - name="Tenant Test", - ) + tenant = tenants_fixture[0] Membership.objects.create( user=user, tenant=tenant, @@ -180,10 +176,16 @@ def create_test_user_rbac_limited(django_db_setup, django_db_blocker): @pytest.fixture def authenticated_client_rbac(create_test_user_rbac, tenants_fixture, client): client.user = create_test_user_rbac + tenant_id = tenants_fixture[0].id serializer = TokenSerializer( - data={"type": "tokens", "email": "rbac@rbac.com", "password": TEST_PASSWORD} + data={ + "type": "tokens", + "email": "rbac@rbac.com", + "password": TEST_PASSWORD, + "tenant_id": tenant_id, + } ) - serializer.is_valid() + serializer.is_valid(raise_exception=True) access_token = serializer.validated_data["access"] client.defaults["HTTP_AUTHORIZATION"] = f"Bearer {access_token}" return client @@ -303,7 +305,7 @@ def set_user_admin_roles_fixture(create_test_user, tenants_fixture): @pytest.fixture def invitations_fixture(create_test_user, tenants_fixture): user = create_test_user - *_, tenant = tenants_fixture + tenant = tenants_fixture[0] valid_invitation = Invitation.objects.create( email="testing@prowler.com", state=Invitation.State.PENDING,