chore(rls): Add tenant_id filters in views and improve querysets (#6211)

Co-authored-by: Víctor Fernández Poyatos <victor@prowler.com>
This commit is contained in:
Pepe Fagoaga
2025-01-13 16:22:40 +05:45
committed by GitHub
parent 00722181ad
commit b620f12027
7 changed files with 177 additions and 93 deletions

View File

@@ -515,8 +515,8 @@ class Resource(RowLevelSecurityProtectedModel):
through="ResourceTagMapping",
)
def get_tags(self) -> dict:
return {tag.key: tag.value for tag in self.tags.all()}
def get_tags(self, tenant_id: str) -> dict:
return {tag.key: tag.value for tag in self.tags.filter(tenant_id=tenant_id)}
def clear_tags(self):
self.tags.clear()

View File

@@ -1,9 +1,11 @@
from enum import Enum
from rest_framework.permissions import BasePermission
from api.models import Provider, Role, User
from api.db_router import MainRouter
from typing import Optional
from django.db.models import QuerySet
from rest_framework.permissions import BasePermission
from api.db_router import MainRouter
from api.models import Provider, Role, User
class Permissions(Enum):
@@ -63,8 +65,11 @@ def get_providers(role: Role) -> QuerySet[Provider]:
A QuerySet of Provider objects filtered by the role's provider groups.
If the role has no provider groups, returns an empty queryset.
"""
tenant = role.tenant
provider_groups = role.provider_groups.all()
if not provider_groups.exists():
return Provider.objects.none()
return Provider.objects.filter(provider_groups__in=provider_groups).distinct()
return Provider.objects.filter(
tenant=tenant, provider_groups__in=provider_groups
).distinct()

View File

@@ -7,9 +7,10 @@ from api.models import Resource, ResourceTag
class TestResourceModel:
def test_setting_tags(self, providers_fixture):
provider, *_ = providers_fixture
tenant_id = provider.tenant_id
resource = Resource.objects.create(
tenant_id=provider.tenant_id,
tenant_id=tenant_id,
provider=provider,
uid="arn:aws:ec2:us-east-1:123456789012:instance/i-1234567890abcdef0",
name="My Instance 1",
@@ -20,12 +21,12 @@ class TestResourceModel:
tags = [
ResourceTag.objects.create(
tenant_id=provider.tenant_id,
tenant_id=tenant_id,
key="key",
value="value",
),
ResourceTag.objects.create(
tenant_id=provider.tenant_id,
tenant_id=tenant_id,
key="key2",
value="value2",
),
@@ -33,9 +34,9 @@ class TestResourceModel:
resource.upsert_or_delete_tags(tags)
assert len(tags) == len(resource.tags.all())
assert len(tags) == len(resource.tags.filter(tenant_id=tenant_id))
tags_dict = resource.get_tags()
tags_dict = resource.get_tags(tenant_id=tenant_id)
for tag in tags:
assert tag.key in tags_dict
@@ -43,47 +44,51 @@ class TestResourceModel:
def test_adding_tags(self, resources_fixture):
resource, *_ = resources_fixture
tenant_id = str(resource.tenant_id)
tags = [
ResourceTag.objects.create(
tenant_id=resource.tenant_id,
tenant_id=tenant_id,
key="env",
value="test",
),
]
before_count = len(resource.tags.all())
before_count = len(resource.tags.filter(tenant_id=tenant_id))
resource.upsert_or_delete_tags(tags)
assert before_count + 1 == len(resource.tags.all())
assert before_count + 1 == len(resource.tags.filter(tenant_id=tenant_id))
tags_dict = resource.get_tags()
tags_dict = resource.get_tags(tenant_id=tenant_id)
assert "env" in tags_dict
assert tags_dict["env"] == "test"
def test_adding_duplicate_tags(self, resources_fixture):
resource, *_ = resources_fixture
tenant_id = str(resource.tenant_id)
tags = resource.tags.all()
tags = resource.tags.filter(tenant_id=tenant_id)
before_count = len(resource.tags.all())
before_count = len(resource.tags.filter(tenant_id=tenant_id))
resource.upsert_or_delete_tags(tags)
# should be the same number of tags
assert before_count == len(resource.tags.all())
assert before_count == len(resource.tags.filter(tenant_id=tenant_id))
def test_add_tags_none(self, resources_fixture):
resource, *_ = resources_fixture
tenant_id = str(resource.tenant_id)
resource.upsert_or_delete_tags(None)
assert len(resource.tags.all()) == 0
assert resource.get_tags() == {}
assert len(resource.tags.filter(tenant_id=tenant_id)) == 0
assert resource.get_tags(tenant_id=tenant_id) == {}
def test_clear_tags(self, resources_fixture):
resource, *_ = resources_fixture
tenant_id = str(resource.tenant_id)
resource.clear_tags()
assert len(resource.tags.all()) == 0
assert resource.get_tags() == {}
assert len(resource.tags.filter(tenant_id=tenant_id)) == 0
assert resource.get_tags(tenant_id=tenant_id) == {}

View File

@@ -340,7 +340,7 @@ class TestTenantViewSet:
def test_tenants_list(self, authenticated_client, tenants_fixture):
response = authenticated_client.get(reverse("tenant-list"))
assert response.status_code == status.HTTP_200_OK
assert len(response.json()["data"]) == len(tenants_fixture)
assert len(response.json()["data"]) == 2 # Test user belongs to 2 tenants
def test_tenants_retrieve(self, authenticated_client, tenants_fixture):
tenant1, *_ = tenants_fixture
@@ -470,11 +470,11 @@ class TestTenantViewSet:
(
[
("name", "Tenant One", 1),
("name.icontains", "Tenant", 3),
("inserted_at", TODAY, 3),
("inserted_at.gte", "2024-01-01", 3),
("name.icontains", "Tenant", 2),
("inserted_at", TODAY, 2),
("inserted_at.gte", "2024-01-01", 2),
("inserted_at.lte", "2024-01-01", 0),
("updated_at.gte", "2024-01-01", 3),
("updated_at.gte", "2024-01-01", 2),
("updated_at.lte", "2024-01-01", 0),
]
),
@@ -510,7 +510,9 @@ class TestTenantViewSet:
assert response.status_code == status.HTTP_200_OK
assert len(response.json()["data"]) == page_size
assert response.json()["meta"]["pagination"]["page"] == 1
assert response.json()["meta"]["pagination"]["pages"] == len(tenants_fixture)
assert (
response.json()["meta"]["pagination"]["pages"] == 2
) # Test user belongs to 2 tenants
def test_tenants_list_page_number(self, authenticated_client, tenants_fixture):
page_size = 1
@@ -523,13 +525,13 @@ class TestTenantViewSet:
assert response.status_code == status.HTTP_200_OK
assert len(response.json()["data"]) == page_size
assert response.json()["meta"]["pagination"]["page"] == page_number
assert response.json()["meta"]["pagination"]["pages"] == len(tenants_fixture)
assert response.json()["meta"]["pagination"]["pages"] == 2
def test_tenants_list_sort_name(self, authenticated_client, tenants_fixture):
_, tenant2, _ = tenants_fixture
response = authenticated_client.get(reverse("tenant-list"), {"sort": "-name"})
assert response.status_code == status.HTTP_200_OK
assert len(response.json()["data"]) == 3
assert len(response.json()["data"]) == 2
assert response.json()["data"][0]["attributes"]["name"] == tenant2.name
def test_tenants_list_memberships_as_owner(
@@ -2339,7 +2341,10 @@ class TestResourceViewSet:
response.json()["errors"][0]["detail"] == "invalid sort parameter: invalid"
)
def test_resources_retrieve(self, authenticated_client, resources_fixture):
def test_resources_retrieve(
self, authenticated_client, tenants_fixture, resources_fixture
):
tenant = tenants_fixture[0]
resource_1, *_ = resources_fixture
response = authenticated_client.get(
reverse("resource-detail", kwargs={"pk": resource_1.id}),
@@ -2350,7 +2355,9 @@ class TestResourceViewSet:
assert response.json()["data"]["attributes"]["region"] == resource_1.region
assert response.json()["data"]["attributes"]["service"] == resource_1.service
assert response.json()["data"]["attributes"]["type"] == resource_1.type
assert response.json()["data"]["attributes"]["tags"] == resource_1.get_tags()
assert response.json()["data"]["attributes"]["tags"] == resource_1.get_tags(
tenant_id=str(tenant.id)
)
def test_resources_invalid_retrieve(self, authenticated_client):
response = authenticated_client.get(
@@ -3261,8 +3268,8 @@ class TestRoleViewSet:
response = authenticated_client.get(reverse("role-list"))
assert response.status_code == status.HTTP_200_OK
assert (
len(response.json()["data"]) == len(roles_fixture) + 2
) # 2 default admin roles, one for each tenant
len(response.json()["data"]) == len(roles_fixture) + 1
) # 1 default admin role
def test_role_retrieve(self, authenticated_client, roles_fixture):
role = roles_fixture[0]

View File

@@ -874,7 +874,7 @@ class ResourceSerializer(RLSSerializer):
}
)
def get_tags(self, obj):
return obj.get_tags()
return obj.get_tags(self.context.get("tenant_id"))
def get_fields(self):
"""`type` is a Python reserved keyword."""
@@ -1233,6 +1233,12 @@ class InvitationSerializer(RLSSerializer):
roles = serializers.ResourceRelatedField(many=True, queryset=Role.objects.all())
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
tenant_id = self.context.get("tenant_id")
if tenant_id is not None:
self.fields["roles"].queryset = Role.objects.filter(tenant_id=tenant_id)
class Meta:
model = Invitation
fields = [
@@ -1252,6 +1258,12 @@ class InvitationSerializer(RLSSerializer):
class InvitationBaseWriteSerializer(BaseWriteSerializer):
roles = serializers.ResourceRelatedField(many=True, queryset=Role.objects.all())
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
tenant_id = self.context.get("tenant_id")
if tenant_id is not None:
self.fields["roles"].queryset = Role.objects.filter(tenant_id=tenant_id)
def validate_email(self, value):
user = User.objects.filter(email=value).first()
tenant_id = self.context["tenant_id"]
@@ -1367,6 +1379,17 @@ class RoleSerializer(RLSSerializer, BaseWriteSerializer):
queryset=ProviderGroup.objects.all(), many=True, required=False
)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
tenant_id = self.context.get("tenant_id")
if tenant_id is not None:
self.fields["users"].queryset = User.objects.filter(
membership__tenant__id=tenant_id
)
self.fields["provider_groups"].queryset = ProviderGroup.objects.filter(
tenant_id=self.context.get("tenant_id")
)
def get_permission_state(self, obj) -> str:
return obj.permission_state

View File

@@ -309,7 +309,12 @@ class UserViewSet(BaseUserViewset):
# If called during schema generation, return an empty queryset
if getattr(self, "swagger_fake_view", False):
return User.objects.none()
return User.objects.filter(membership__tenant__id=self.request.tenant_id)
queryset = (
User.objects.filter(membership__tenant__id=self.request.tenant_id)
if hasattr(self.request, "tenant_id")
else User.objects.all()
)
return queryset.prefetch_related("memberships", "roles")
def get_permissions(self):
if self.action == "create":
@@ -452,7 +457,7 @@ class UserRoleRelationshipView(RelationshipView, BaseRLSViewSet):
required_permissions = [Permissions.MANAGE_USERS]
def get_queryset(self):
return User.objects.all()
return User.objects.filter(membership__tenant__id=self.request.tenant_id)
def create(self, request, *args, **kwargs):
user = self.get_object()
@@ -540,7 +545,8 @@ class TenantViewSet(BaseTenantViewset):
required_permissions = [Permissions.MANAGE_ACCOUNT]
def get_queryset(self):
return Tenant.objects.all()
queryset = Tenant.objects.filter(membership__user=self.request.user)
return queryset.prefetch_related("memberships")
def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
@@ -600,7 +606,8 @@ class MembershipViewSet(BaseTenantViewset):
def get_queryset(self):
user = self.request.user
return Membership.objects.filter(user_id=user.id)
queryset = Membership.objects.filter(user_id=user.id)
return queryset.select_related("user", "tenant")
@extend_schema_view(
@@ -736,10 +743,10 @@ class ProviderGroupViewSet(BaseRLSViewSet):
# Check if any of the user's roles have UNLIMITED_VISIBILITY
if user_roles.unlimited_visibility:
# User has unlimited visibility, return all provider groups
return ProviderGroup.objects.prefetch_related("providers")
return ProviderGroup.objects.prefetch_related("providers", "roles")
# Collect provider groups associated with the user's roles
return user_roles.provider_groups.all()
return user_roles.provider_groups.all().prefetch_related("providers", "roles")
def get_serializer_class(self):
if self.action == "create":
@@ -790,7 +797,7 @@ class ProviderGroupProvidersRelationshipView(RelationshipView, BaseRLSViewSet):
required_permissions = [Permissions.MANAGE_PROVIDERS]
def get_queryset(self):
return ProviderGroup.objects.all()
return ProviderGroup.objects.filter(tenant_id=self.request.tenant_id)
def create(self, request, *args, **kwargs):
provider_group = self.get_object()
@@ -904,10 +911,11 @@ class ProviderViewSet(BaseRLSViewSet):
user_roles = get_role(self.request.user)
if user_roles.unlimited_visibility:
# User has unlimited visibility, return all providers
return Provider.objects.all()
# User lacks permission, filter providers based on provider groups associated with the role
return get_providers(user_roles)
queryset = Provider.objects.filter(tenant_id=self.request.tenant_id)
else:
# User lacks permission, filter providers based on provider groups associated with the role
queryset = get_providers(user_roles)
return queryset.select_related("secret").prefetch_related("provider_groups")
def get_serializer_class(self):
if self.action == "create":
@@ -945,7 +953,7 @@ class ProviderViewSet(BaseRLSViewSet):
get_object_or_404(Provider, pk=pk)
with transaction.atomic():
task = check_provider_connection_task.delay(
provider_id=pk, tenant_id=request.tenant_id
provider_id=pk, tenant_id=self.request.tenant_id
)
prowler_task = Task.objects.get(id=task.id)
serializer = TaskSerializer(prowler_task)
@@ -966,7 +974,7 @@ class ProviderViewSet(BaseRLSViewSet):
with transaction.atomic():
task = delete_provider_task.delay(
provider_id=pk, tenant_id=request.tenant_id
provider_id=pk, tenant_id=self.request.tenant_id
)
prowler_task = Task.objects.get(id=task.id)
serializer = TaskSerializer(prowler_task)
@@ -1045,10 +1053,11 @@ class ScanViewSet(BaseRLSViewSet):
user_roles = get_role(self.request.user)
if user_roles.unlimited_visibility:
# User has unlimited visibility, return all scans
return Scan.objects.all()
# User lacks permission, filter providers based on provider groups associated with the role
return Scan.objects.filter(provider__in=get_providers(user_roles))
queryset = Scan.objects.filter(tenant_id=self.request.tenant_id)
else:
# User lacks permission, filter providers based on provider groups associated with the role
queryset = Scan.objects.filter(provider__in=get_providers(user_roles))
return queryset.select_related("provider", "task")
def get_serializer_class(self):
if self.action == "create":
@@ -1082,14 +1091,14 @@ class ScanViewSet(BaseRLSViewSet):
with transaction.atomic():
task = perform_scan_task.apply_async(
kwargs={
"tenant_id": request.tenant_id,
"tenant_id": self.request.tenant_id,
"scan_id": str(scan.id),
"provider_id": str(scan.provider_id),
# Disabled for now
# checks_to_execute=scan.scanner_args.get("checks_to_execute"),
},
link=perform_scan_summary_task.si(
tenant_id=request.tenant_id,
tenant_id=self.request.tenant_id,
scan_id=str(scan.id),
),
)
@@ -1145,7 +1154,7 @@ class TaskViewSet(BaseRLSViewSet):
return Task.objects.annotate(
name=F("task_runner_task__task_name"),
state=F("task_runner_task__status"),
)
).select_related("task_runner_task")
def destroy(self, request, *args, pk=None, **kwargs):
task = get_object_or_404(Task, pk=pk)
@@ -1206,17 +1215,20 @@ class ResourceViewSet(BaseRLSViewSet):
"inserted_at",
"updated_at",
]
# RBAC required permissions (implicit -> MANAGE_PROVIDERS enable unlimited visibility or check the visibility of the provider through the provider group)
# RBAC required permissions (implicit -> MANAGE_PROVIDERS enable unlimited visibility or check the visibility of
# the provider through the provider group)
required_permissions = []
def get_queryset(self):
user_roles = get_role(self.request.user)
if user_roles.unlimited_visibility:
# User has unlimited visibility, return all scans
queryset = Resource.objects.all()
queryset = Resource.objects.filter(tenant_id=self.request.tenant_id)
else:
# User lacks permission, filter providers based on provider groups associated with the role
queryset = Resource.objects.filter(provider__in=get_providers(user_roles))
queryset = Resource.objects.filter(
tenant_id=self.request.tenant_id, provider__in=get_providers(user_roles)
)
search_value = self.request.query_params.get("filter[search]", None)
if search_value:
@@ -1289,7 +1301,8 @@ class FindingViewSet(BaseRLSViewSet):
"inserted_at",
"updated_at",
]
# RBAC required permissions (implicit -> MANAGE_PROVIDERS enable unlimited visibility or check the visibility of the provider through the provider group)
# RBAC required permissions (implicit -> MANAGE_PROVIDERS enable unlimited visibility or check the visibility of
# the provider through the provider group)
required_permissions = []
def get_serializer_class(self):
@@ -1302,7 +1315,7 @@ class FindingViewSet(BaseRLSViewSet):
user_roles = get_role(self.request.user)
if user_roles.unlimited_visibility:
# User has unlimited visibility, return all scans
queryset = Finding.objects.all()
queryset = Finding.objects.filter(tenant_id=self.request.tenant_id)
else:
# User lacks permission, filter providers based on provider groups associated with the role
queryset = Finding.objects.filter(
@@ -1409,7 +1422,7 @@ class ProviderSecretViewSet(BaseRLSViewSet):
required_permissions = [Permissions.MANAGE_PROVIDERS]
def get_queryset(self):
return ProviderSecret.objects.all()
return ProviderSecret.objects.filter(tenant_id=self.request.tenant_id)
def get_serializer_class(self):
if self.action == "create":
@@ -1468,7 +1481,7 @@ class InvitationViewSet(BaseRLSViewSet):
required_permissions = [Permissions.MANAGE_ACCOUNT]
def get_queryset(self):
return Invitation.objects.all()
return Invitation.objects.filter(tenant_id=self.request.tenant_id)
def get_serializer_class(self):
if self.action == "create":
@@ -1515,7 +1528,7 @@ class InvitationAcceptViewSet(BaseRLSViewSet):
http_method_names = ["post"]
def get_queryset(self):
return Invitation.objects.all()
return Invitation.objects.filter(tenant_id=self.request.tenant_id)
def get_serializer_class(self):
if hasattr(self, "response_serializer_class"):
@@ -1607,7 +1620,7 @@ class RoleViewSet(BaseRLSViewSet):
required_permissions = [Permissions.MANAGE_ACCOUNT]
def get_queryset(self):
return Role.objects.all()
return Role.objects.filter(tenant_id=self.request.tenant_id)
def get_serializer_class(self):
if self.action == "create":
@@ -1628,7 +1641,8 @@ class RoleViewSet(BaseRLSViewSet):
create=extend_schema(
tags=["Role"],
summary="Create a new role-provider_groups relationship",
description="Add a new role-provider_groups relationship to the system by providing the required role-provider_groups details.",
description="Add a new role-provider_groups relationship to the system by providing the required "
"role-provider_groups details.",
responses={
204: OpenApiResponse(description="Relationship created successfully"),
400: OpenApiResponse(
@@ -1667,7 +1681,7 @@ class RoleProviderGroupRelationshipView(RelationshipView, BaseRLSViewSet):
required_permissions = [Permissions.MANAGE_ACCOUNT]
def get_queryset(self):
return Role.objects.all()
return Role.objects.filter(tenant_id=self.request.tenant_id)
def create(self, request, *args, **kwargs):
role = self.get_object()
@@ -1750,7 +1764,8 @@ class ComplianceOverviewViewSet(BaseRLSViewSet):
search_fields = ["compliance_id"]
ordering = ["compliance_id"]
ordering_fields = ["inserted_at", "compliance_id", "framework", "region"]
# RBAC required permissions (implicit -> MANAGE_PROVIDERS enable unlimited visibility or check the visibility of the provider through the provider group)
# RBAC required permissions (implicit -> MANAGE_PROVIDERS enable unlimited visibility or check the visibility of
# the provider through the provider group)
required_permissions = []
def get_queryset(self):
@@ -1761,20 +1776,28 @@ class ComplianceOverviewViewSet(BaseRLSViewSet):
if self.action == "retrieve":
if unlimited_visibility:
# User has unlimited visibility, return all compliance compliances
return ComplianceOverview.objects.all()
# User has unlimited visibility, return all compliance
return ComplianceOverview.objects.filter(
tenant_id=self.request.tenant_id
)
providers = get_providers(role)
return ComplianceOverview.objects.filter(scan__provider__in=providers)
return ComplianceOverview.objects.filter(
tenant_id=self.request.tenant_id, scan__provider__in=providers
)
if unlimited_visibility:
base_queryset = self.filter_queryset(ComplianceOverview.objects.all())
base_queryset = self.filter_queryset(
ComplianceOverview.objects.filter(tenant_id=self.request.tenant_id)
)
else:
providers = Provider.objects.filter(
provider_groups__in=role.provider_groups.all()
).distinct()
base_queryset = self.filter_queryset(
ComplianceOverview.objects.filter(scan__provider__in=providers)
ComplianceOverview.objects.filter(
tenant_id=self.request.tenant_id, scan__provider__in=providers
)
)
max_failed_ids = (
@@ -1853,7 +1876,8 @@ class OverviewViewSet(BaseRLSViewSet):
queryset = ComplianceOverview.objects.all()
http_method_names = ["get"]
ordering = ["-id"]
# RBAC required permissions (implicit -> MANAGE_PROVIDERS enable unlimited visibility or check the visibility of the provider through the provider group)
# RBAC required permissions (implicit -> MANAGE_PROVIDERS enable unlimited visibility or check the visibility of
# the provider through the provider group)
required_permissions = []
def get_queryset(self):
@@ -1862,8 +1886,10 @@ class OverviewViewSet(BaseRLSViewSet):
def _get_filtered_queryset(model):
if role.unlimited_visibility:
return model.objects.all()
return model.objects.filter(scan__provider__in=providers)
return model.objects.filter(tenant_id=self.request.tenant_id)
return model.objects.filter(
tenant_id=self.request.tenant_id, scan__provider__in=providers
)
if self.action == "providers":
return _get_filtered_queryset(Finding)
@@ -1902,17 +1928,22 @@ class OverviewViewSet(BaseRLSViewSet):
@action(detail=False, methods=["get"], url_name="providers")
def providers(self, request):
tenant_id = self.request.tenant_id
# Subquery to get the most recent finding for each uid
latest_finding_ids = (
Finding.objects.filter(
uid=OuterRef("uid"), scan__provider=OuterRef("scan__provider")
tenant_id=tenant_id,
uid=OuterRef("uid"),
scan__provider=OuterRef("scan__provider"),
)
.order_by("-id") # Most recent by id
.values("id")[:1]
)
# Filter findings to only include the most recent for each uid
recent_findings = Finding.objects.filter(id__in=Subquery(latest_finding_ids))
recent_findings = Finding.objects.filter(
tenant_id=tenant_id, id__in=Subquery(latest_finding_ids)
)
# Aggregate findings by provider
findings_aggregated = (
@@ -1929,8 +1960,10 @@ class OverviewViewSet(BaseRLSViewSet):
)
# Aggregate total resources by provider
resources_aggregated = Resource.objects.values("provider__provider").annotate(
total_resources=Count("id")
resources_aggregated = (
Resource.objects.filter(tenant_id=tenant_id)
.values("provider__provider")
.annotate(total_resources=Count("id"))
)
# Combine findings and resources data
@@ -1962,12 +1995,15 @@ class OverviewViewSet(BaseRLSViewSet):
@action(detail=False, methods=["get"], url_name="findings")
def findings(self, request):
tenant_id = self.request.tenant_id
queryset = self.get_queryset()
filtered_queryset = self.filter_queryset(queryset)
latest_scan_subquery = (
Scan.objects.filter(
state=StateChoices.COMPLETED, provider_id=OuterRef("scan__provider_id")
tenant_id=tenant_id,
state=StateChoices.COMPLETED,
provider_id=OuterRef("scan__provider_id"),
)
.order_by("-id")
.values("id")[:1]
@@ -2004,12 +2040,15 @@ class OverviewViewSet(BaseRLSViewSet):
@action(detail=False, methods=["get"], url_name="findings_severity")
def findings_severity(self, request):
tenant_id = self.request.tenant_id
queryset = self.get_queryset()
filtered_queryset = self.filter_queryset(queryset)
latest_scan_subquery = (
Scan.objects.filter(
state=StateChoices.COMPLETED, provider_id=OuterRef("scan__provider_id")
tenant_id=tenant_id,
state=StateChoices.COMPLETED,
provider_id=OuterRef("scan__provider_id"),
)
.order_by("-id")
.values("id")[:1]
@@ -2037,12 +2076,15 @@ class OverviewViewSet(BaseRLSViewSet):
@action(detail=False, methods=["get"], url_name="services")
def services(self, request):
tenant_id = self.request.tenant_id
queryset = self.get_queryset()
filtered_queryset = self.filter_queryset(queryset)
latest_scan_subquery = (
Scan.objects.filter(
state=StateChoices.COMPLETED, provider_id=OuterRef("scan__provider_id")
tenant_id=tenant_id,
state=StateChoices.COMPLETED,
provider_id=OuterRef("scan__provider_id"),
)
.order_by("-id")
.values("id")[:1]

View File

@@ -88,16 +88,14 @@ def create_test_user(django_db_setup, django_db_blocker):
@pytest.fixture(scope="function")
def create_test_user_rbac(django_db_setup, django_db_blocker):
def create_test_user_rbac(django_db_setup, django_db_blocker, tenants_fixture):
with django_db_blocker.unblock():
user = User.objects.create_user(
name="testing",
email="rbac@rbac.com",
password=TEST_PASSWORD,
)
tenant = Tenant.objects.create(
name="Tenant Test",
)
tenant = tenants_fixture[0]
Membership.objects.create(
user=user,
tenant=tenant,
@@ -123,16 +121,14 @@ def create_test_user_rbac(django_db_setup, django_db_blocker):
@pytest.fixture(scope="function")
def create_test_user_rbac_no_roles(django_db_setup, django_db_blocker):
def create_test_user_rbac_no_roles(django_db_setup, django_db_blocker, tenants_fixture):
with django_db_blocker.unblock():
user = User.objects.create_user(
name="testing",
email="rbac_noroles@rbac.com",
password=TEST_PASSWORD,
)
tenant = Tenant.objects.create(
name="Tenant Test",
)
tenant = tenants_fixture[0]
Membership.objects.create(
user=user,
tenant=tenant,
@@ -180,10 +176,16 @@ def create_test_user_rbac_limited(django_db_setup, django_db_blocker):
@pytest.fixture
def authenticated_client_rbac(create_test_user_rbac, tenants_fixture, client):
client.user = create_test_user_rbac
tenant_id = tenants_fixture[0].id
serializer = TokenSerializer(
data={"type": "tokens", "email": "rbac@rbac.com", "password": TEST_PASSWORD}
data={
"type": "tokens",
"email": "rbac@rbac.com",
"password": TEST_PASSWORD,
"tenant_id": tenant_id,
}
)
serializer.is_valid()
serializer.is_valid(raise_exception=True)
access_token = serializer.validated_data["access"]
client.defaults["HTTP_AUTHORIZATION"] = f"Bearer {access_token}"
return client
@@ -303,7 +305,7 @@ def set_user_admin_roles_fixture(create_test_user, tenants_fixture):
@pytest.fixture
def invitations_fixture(create_test_user, tenants_fixture):
user = create_test_user
*_, tenant = tenants_fixture
tenant = tenants_fixture[0]
valid_invitation = Invitation.objects.create(
email="testing@prowler.com",
state=Invitation.State.PENDING,