diff --git a/api/CHANGELOG.md b/api/CHANGELOG.md index 1aaa3a65ab..f969bb7ffe 100644 --- a/api/CHANGELOG.md +++ b/api/CHANGELOG.md @@ -6,6 +6,8 @@ All notable changes to the **Prowler API** are documented in this file. ### 🚀 Added +- Provider group filters for API endpoints that support cloud provider filtering, including exact and `__in` variants [(#11573)](https://github.com/prowler-cloud/prowler/pull/11573) +- Provider filters for `GET /api/v1/compliance-overviews`, `/metadata`, and `/requirements`, using latest completed scans per matching provider [(#11587)](https://github.com/prowler-cloud/prowler/pull/11587) - Server-Sent Events (SSE) infrastructure for the API: a base viewset, a tenant-aware channel manager, and channel-name helpers backed by `django-eventstream` over Valkey Pub/Sub and served through the Gunicorn ASGI worker, so feature endpoints can stream events to clients over a single long-lived connection [(#11556)](https://github.com/prowler-cloud/prowler/pull/11556) ### 🔐 Security diff --git a/api/src/backend/api/filters.py b/api/src/backend/api/filters.py index c787d714b1..195982011d 100644 --- a/api/src/backend/api/filters.py +++ b/api/src/backend/api/filters.py @@ -102,7 +102,7 @@ class BaseProviderFilter(FilterSet): """ Abstract base filter for models with direct FK to Provider. - Provides standard provider_id and provider_type filters. + Provides standard provider_id, provider_type, and provider_groups filters. Subclasses must define Meta.model. """ @@ -116,6 +116,16 @@ class BaseProviderFilter(FilterSet): choices=Provider.ProviderChoices.choices, lookup_expr="in", ) + provider_groups = UUIDFilter( + field_name="provider__provider_groups__id", + lookup_expr="exact", + distinct=True, + ) + provider_groups__in = UUIDInFilter( + field_name="provider__provider_groups__id", + lookup_expr="in", + distinct=True, + ) class Meta: abstract = True @@ -126,7 +136,7 @@ class BaseScanProviderFilter(FilterSet): """ Abstract base filter for models with FK to Scan (and Scan has FK to Provider). - Provides standard provider_id and provider_type filters via scan relationship. + Provides standard provider_id, provider_type, and provider_groups filters via scan relationship. Subclasses must define Meta.model. """ @@ -140,6 +150,16 @@ class BaseScanProviderFilter(FilterSet): choices=Provider.ProviderChoices.choices, lookup_expr="in", ) + provider_groups = UUIDFilter( + field_name="scan__provider__provider_groups__id", + lookup_expr="exact", + distinct=True, + ) + provider_groups__in = UUIDInFilter( + field_name="scan__provider__provider_groups__id", + lookup_expr="in", + distinct=True, + ) class Meta: abstract = True @@ -160,6 +180,16 @@ class CommonFindingFilters(FilterSet): provider_type__in = ChoiceInFilter( choices=Provider.ProviderChoices.choices, field_name="scan__provider__provider" ) + provider_groups = UUIDFilter( + field_name="scan__provider__provider_groups__id", + lookup_expr="exact", + distinct=True, + ) + provider_groups__in = UUIDInFilter( + field_name="scan__provider__provider_groups__id", + lookup_expr="in", + distinct=True, + ) provider_uid = CharFilter(field_name="scan__provider__uid", lookup_expr="exact") provider_uid__in = CharInFilter(field_name="scan__provider__uid", lookup_expr="in") provider_uid__icontains = CharFilter( @@ -370,6 +400,12 @@ class ProviderFilter(FilterSet): choices=Provider.ProviderChoices.choices, lookup_expr="in", ) + provider_groups = UUIDFilter( + field_name="provider_groups__id", lookup_expr="exact", distinct=True + ) + provider_groups__in = UUIDInFilter( + field_name="provider_groups__id", lookup_expr="in", distinct=True + ) class Meta: model = Provider @@ -395,6 +431,16 @@ class ProviderRelationshipFilterSet(FilterSet): provider_type__in = ChoiceInFilter( choices=Provider.ProviderChoices.choices, field_name="provider__provider" ) + provider_groups = UUIDFilter( + field_name="provider__provider_groups__id", + lookup_expr="exact", + distinct=True, + ) + provider_groups__in = UUIDInFilter( + field_name="provider__provider_groups__id", + lookup_expr="in", + distinct=True, + ) provider_uid = CharFilter(field_name="provider__uid", lookup_expr="exact") provider_uid__in = CharInFilter(field_name="provider__uid", lookup_expr="in") provider_uid__icontains = CharFilter( @@ -1001,6 +1047,16 @@ class FindingGroupSummaryFilter(_CheckTitleToCheckIdMixin, FilterSet): field_name="provider__provider", choices=Provider.ProviderChoices.choices ) provider_type__in = CharInFilter(field_name="provider__provider", lookup_expr="in") + provider_groups = UUIDFilter( + field_name="provider__provider_groups__id", + lookup_expr="exact", + distinct=True, + ) + provider_groups__in = UUIDInFilter( + field_name="provider__provider_groups__id", + lookup_expr="in", + distinct=True, + ) class Meta: model = FindingGroupDailySummary @@ -1101,6 +1157,16 @@ class LatestFindingGroupSummaryFilter(_CheckTitleToCheckIdMixin, FilterSet): field_name="provider__provider", choices=Provider.ProviderChoices.choices ) provider_type__in = CharInFilter(field_name="provider__provider", lookup_expr="in") + provider_groups = UUIDFilter( + field_name="provider__provider_groups__id", + lookup_expr="exact", + distinct=True, + ) + provider_groups__in = UUIDInFilter( + field_name="provider__provider_groups__id", + lookup_expr="in", + distinct=True, + ) class Meta: model = FindingGroupDailySummary @@ -1280,12 +1346,19 @@ class RoleFilter(FilterSet): } -class ComplianceOverviewFilter(FilterSet): +class ComplianceOverviewFilter(BaseScanProviderFilter): + """ + Keep provider filters in the schema while runtime filtering resolves scans first. + + Compliance overview provider filters are applied to the latest completed scans + in the viewset, then this filterset handles the remaining compliance fields. + """ + inserted_at = DateFilter(field_name="inserted_at", lookup_expr="date") - scan_id = UUIDFilter(field_name="scan_id", required=True) + scan_id = UUIDFilter(field_name="scan_id") region = CharFilter(field_name="region") - class Meta: + class Meta(BaseScanProviderFilter.Meta): model = ComplianceRequirementOverview fields = { "inserted_at": ["date", "gte", "lte"], @@ -1306,6 +1379,16 @@ class ScanSummaryFilter(FilterSet): provider_type__in = ChoiceInFilter( field_name="scan__provider__provider", choices=Provider.ProviderChoices.choices ) + provider_groups = UUIDFilter( + field_name="scan__provider__provider_groups__id", + lookup_expr="exact", + distinct=True, + ) + provider_groups__in = UUIDInFilter( + field_name="scan__provider__provider_groups__id", + lookup_expr="in", + distinct=True, + ) region = CharFilter(field_name="region") class Meta: @@ -1329,6 +1412,16 @@ class DailySeveritySummaryFilter(FilterSet): provider_type__in = ChoiceInFilter( field_name="provider__provider", choices=Provider.ProviderChoices.choices ) + provider_groups = UUIDFilter( + field_name="provider__provider_groups__id", + lookup_expr="exact", + distinct=True, + ) + provider_groups__in = UUIDInFilter( + field_name="provider__provider_groups__id", + lookup_expr="in", + distinct=True, + ) date_from = DateFilter(method="filter_noop") date_to = DateFilter(method="filter_noop") @@ -1585,6 +1678,16 @@ class ThreatScoreSnapshotFilter(FilterSet): choices=Provider.ProviderChoices.choices, lookup_expr="in", ) + provider_groups = UUIDFilter( + field_name="provider__provider_groups__id", + lookup_expr="exact", + distinct=True, + ) + provider_groups__in = UUIDInFilter( + field_name="provider__provider_groups__id", + lookup_expr="in", + distinct=True, + ) compliance_id = CharFilter(field_name="compliance_id", lookup_expr="exact") compliance_id__in = CharInFilter(field_name="compliance_id", lookup_expr="in") @@ -1628,6 +1731,16 @@ class ResourceGroupOverviewFilter(FilterSet): choices=Provider.ProviderChoices.choices, lookup_expr="in", ) + provider_groups = UUIDFilter( + field_name="scan__provider__provider_groups__id", + lookup_expr="exact", + distinct=True, + ) + provider_groups__in = UUIDInFilter( + field_name="scan__provider__provider_groups__id", + lookup_expr="in", + distinct=True, + ) resource_group = CharFilter(field_name="resource_group", lookup_expr="exact") resource_group__in = CharInFilter(field_name="resource_group", lookup_expr="in") diff --git a/api/src/backend/api/tests/test_views.py b/api/src/backend/api/tests/test_views.py index 67fb6c0663..35098c1907 100644 --- a/api/src/backend/api/tests/test_views.py +++ b/api/src/backend/api/tests/test_views.py @@ -1411,6 +1411,42 @@ class TestProviderViewSet: ) assert response.status_code == status.HTTP_400_BAD_REQUEST + def test_providers_filter_provider_groups( + self, + authenticated_client, + tenants_fixture, + providers_fixture, + provider_groups_fixture, + ): + tenant = tenants_fixture[0] + provider1, provider2, *_ = providers_fixture + group1, group2, *_ = provider_groups_fixture + ProviderGroupMembership.objects.create( + tenant=tenant, provider=provider1, provider_group=group1 + ) + ProviderGroupMembership.objects.create( + tenant=tenant, provider=provider1, provider_group=group2 + ) + ProviderGroupMembership.objects.create( + tenant=tenant, provider=provider2, provider_group=group2 + ) + + response = authenticated_client.get( + reverse("provider-list"), {"filter[provider_groups]": str(group1.id)} + ) + assert response.status_code == status.HTTP_200_OK + data = response.json()["data"] + assert [item["id"] for item in data] == [str(provider1.id)] + + response = authenticated_client.get( + reverse("provider-list"), + {"filter[provider_groups__in]": f"{group1.id},{group2.id}"}, + ) + assert response.status_code == status.HTTP_200_OK + provider_ids = {item["id"] for item in response.json()["data"]} + assert provider_ids == {str(provider1.id), str(provider2.id)} + assert len(response.json()["data"]) == 2 + def test_providers_disable_pagination( self, authenticated_client, providers_fixture, tenants_fixture ): @@ -3715,6 +3751,41 @@ class TestScanViewSet: assert response.status_code == status.HTTP_200_OK assert len(response.json()["data"]) == expected_count + def test_scans_filter_provider_groups( + self, + authenticated_client, + tenants_fixture, + scans_fixture, + provider_groups_fixture, + ): + tenant = tenants_fixture[0] + scan1, scan2, *_ = scans_fixture + group1, group2, *_ = provider_groups_fixture + ProviderGroupMembership.objects.create( + tenant=tenant, provider=scan1.provider, provider_group=group1 + ) + ProviderGroupMembership.objects.create( + tenant=tenant, provider=scan1.provider, provider_group=group2 + ) + ProviderGroupMembership.objects.create( + tenant=tenant, provider=scan2.provider, provider_group=group2 + ) + + response = authenticated_client.get( + reverse("scan-list"), {"filter[provider_groups]": str(group1.id)} + ) + assert response.status_code == status.HTTP_200_OK + assert {item["id"] for item in response.json()["data"]} == {str(scan1.id)} + + response = authenticated_client.get( + reverse("scan-list"), + {"filter[provider_groups__in]": f"{group1.id},{group2.id}"}, + ) + assert response.status_code == status.HTTP_200_OK + scan_ids = {item["id"] for item in response.json()["data"]} + assert scan_ids == {str(scan1.id), str(scan2.id), str(scans_fixture[2].id)} + assert len(response.json()["data"]) == 3 + @pytest.mark.parametrize( "filter_name", [ @@ -5996,6 +6067,49 @@ class TestResourceViewSet: assert response.status_code == status.HTTP_200_OK assert len(response.json()["data"]) == expected_count + def test_resource_filter_provider_groups( + self, + authenticated_client, + tenants_fixture, + resources_fixture, + provider_groups_fixture, + ): + tenant = tenants_fixture[0] + resource1, resource2, resource3, *_ = resources_fixture + group1, group2, *_ = provider_groups_fixture + ProviderGroupMembership.objects.create( + tenant=tenant, provider=resource1.provider, provider_group=group1 + ) + ProviderGroupMembership.objects.create( + tenant=tenant, provider=resource1.provider, provider_group=group2 + ) + ProviderGroupMembership.objects.create( + tenant=tenant, provider=resource3.provider, provider_group=group2 + ) + + response = authenticated_client.get( + reverse("resource-list"), + {"filter[updated_at]": TODAY, "filter[provider_groups]": str(group1.id)}, + ) + assert response.status_code == status.HTTP_200_OK + assert len(response.json()["data"]) == 2 + assert {item["id"] for item in response.json()["data"]} == { + str(resource1.id), + str(resource2.id), + } + + response = authenticated_client.get( + reverse("resource-list"), + { + "filter[updated_at]": TODAY, + "filter[provider_groups__in]": f"{group1.id},{group2.id}", + }, + ) + assert response.status_code == status.HTTP_200_OK + resource_ids = {item["id"] for item in response.json()["data"]} + assert resource_ids == {str(resource1.id), str(resource2.id), str(resource3.id)} + assert len(response.json()["data"]) == 3 + def test_resource_filter_by_scan_id( self, authenticated_client, resources_fixture, scans_fixture ): @@ -7308,6 +7422,40 @@ class TestFindingViewSet: assert response.status_code == status.HTTP_200_OK assert len(response.json()["data"]) == 2 + def test_finding_filter_provider_groups( + self, + authenticated_client, + tenants_fixture, + findings_fixture, + provider_groups_fixture, + ): + tenant = tenants_fixture[0] + finding1, finding2, *_ = findings_fixture + group1, group2, *_ = provider_groups_fixture + ProviderGroupMembership.objects.create( + tenant=tenant, provider=finding1.scan.provider, provider_group=group1 + ) + ProviderGroupMembership.objects.create( + tenant=tenant, provider=finding1.scan.provider, provider_group=group2 + ) + + response = authenticated_client.get( + reverse("finding-list"), + {"filter[inserted_at]": TODAY, "filter[provider_groups]": str(group1.id)}, + ) + assert response.status_code == status.HTTP_200_OK + assert len(response.json()["data"]) == 2 + + response = authenticated_client.get( + reverse("finding-list"), + { + "filter[inserted_at]": TODAY, + "filter[provider_groups__in]": f"{group1.id},{group2.id}", + }, + ) + assert response.status_code == status.HTTP_200_OK + assert len(response.json()["data"]) == 2 + @pytest.mark.parametrize( "filter_name", ( @@ -9278,6 +9426,118 @@ class TestComplianceOverviewViewSet: with patch("api.v1.views.backfill_compliance_summaries_task.delay") as mock: yield mock + def _create_completed_scan(self, provider, name): + return Scan.objects.create( + name=name, + provider=provider, + trigger=Scan.TriggerChoices.MANUAL, + state=StateChoices.COMPLETED, + tenant_id=provider.tenant_id, + started_at=datetime.now(timezone.utc), + completed_at=datetime.now(timezone.utc), + ) + + def _create_requirement( + self, + scan, + requirement_id, + status_choice, + region="eu-west-1", + compliance_id="cis_1.4_aws", + ): + passed = 1 if status_choice == StatusChoices.PASS else 0 + total = 1 if status_choice != StatusChoices.MANUAL else 0 + return ComplianceRequirementOverview.objects.create( + tenant_id=scan.tenant_id, + scan=scan, + compliance_id=compliance_id, + framework="CIS-1.4-AWS", + version="1.4", + description="CIS AWS Foundations Benchmark v1.4.0", + region=region, + requirement_id=requirement_id, + requirement_status=status_choice, + passed_checks=passed, + failed_checks=0 + if status_choice in (StatusChoices.PASS, StatusChoices.MANUAL) + else 1, + total_checks=total, + passed_findings=passed, + total_findings=total, + ) + + def _create_compliance_summary( + self, + scan, + *, + passed, + failed, + manual=0, + compliance_id="cis_1.4_aws", + ): + return ComplianceOverviewSummary.objects.create( + tenant_id=scan.tenant_id, + scan=scan, + compliance_id=compliance_id, + requirements_passed=passed, + requirements_failed=failed, + requirements_manual=manual, + total_requirements=passed + failed + manual, + ) + + def _overview_attrs_by_id(self, response): + assert response.status_code == status.HTTP_200_OK + return {item["id"]: item["attributes"] for item in response.json()["data"]} + + def _prepare_latest_compliance_data(self, providers_fixture): + provider1, provider2, provider3, *_ = providers_fixture + old_scan = self._create_completed_scan(provider1, "old aws compliance scan") + latest_scan1 = self._create_completed_scan( + provider1, "latest aws compliance scan 1" + ) + latest_scan2 = self._create_completed_scan( + provider2, "latest aws compliance scan 2" + ) + latest_gcp_scan = self._create_completed_scan( + provider3, "latest gcp compliance scan" + ) + + self._create_requirement(old_scan, "1.1", StatusChoices.FAIL) + self._create_requirement(old_scan, "1.2", StatusChoices.FAIL) + self._create_compliance_summary(old_scan, passed=0, failed=2) + + self._create_requirement( + latest_scan1, "1.1", StatusChoices.PASS, region="eu-west-1" + ) + self._create_requirement( + latest_scan1, "1.2", StatusChoices.PASS, region="eu-west-1" + ) + self._create_compliance_summary(latest_scan1, passed=2, failed=0) + + self._create_requirement( + latest_scan2, "1.1", StatusChoices.FAIL, region="us-east-1" + ) + self._create_requirement( + latest_scan2, "1.2", StatusChoices.PASS, region="us-east-1" + ) + self._create_compliance_summary(latest_scan2, passed=1, failed=1) + + self._create_requirement( + latest_gcp_scan, + "gcp-1.1", + StatusChoices.FAIL, + region="europe-west1", + compliance_id="cis_1.3_gcp", + ) + self._create_compliance_summary( + latest_gcp_scan, + passed=0, + failed=1, + compliance_id="cis_1.3_gcp", + ) + + return old_scan, latest_scan1, latest_scan2, latest_gcp_scan + def test_compliance_overview_list_none( self, authenticated_client, @@ -9425,6 +9685,283 @@ class TestComplianceOverviewViewSet: assert len(response.json()["data"]) >= 1 mock_backfill_task.assert_not_called() + def test_compliance_overview_provider_id_filter_uses_latest_scan( + self, + authenticated_client, + providers_fixture, + mock_backfill_task, + ): + _, latest_scan, *_ = self._prepare_latest_compliance_data(providers_fixture) + + response = authenticated_client.get( + reverse("complianceoverview-list"), + {"filter[provider_id]": str(latest_scan.provider_id)}, + ) + + attrs_by_id = self._overview_attrs_by_id(response) + assert attrs_by_id["cis_1.4_aws"]["requirements_passed"] == 2 + assert attrs_by_id["cis_1.4_aws"]["requirements_failed"] == 0 + assert "cis_1.3_gcp" not in attrs_by_id + mock_backfill_task.assert_not_called() + + def test_compliance_overview_provider_id_in_filter_aggregates_latest_scans( + self, + authenticated_client, + providers_fixture, + ): + _, latest_scan1, latest_scan2, *_ = self._prepare_latest_compliance_data( + providers_fixture + ) + + response = authenticated_client.get( + reverse("complianceoverview-list"), + { + "filter[provider_id__in]": ( + f"{latest_scan1.provider_id},{latest_scan2.provider_id}" + ) + }, + ) + + attrs_by_id = self._overview_attrs_by_id(response) + assert attrs_by_id["cis_1.4_aws"]["requirements_passed"] == 1 + assert attrs_by_id["cis_1.4_aws"]["requirements_failed"] == 1 + assert attrs_by_id["cis_1.4_aws"]["total_requirements"] == 2 + assert "cis_1.3_gcp" not in attrs_by_id + + def test_compliance_overview_provider_type_filter_uses_latest_scans( + self, + authenticated_client, + providers_fixture, + ): + self._prepare_latest_compliance_data(providers_fixture) + + response = authenticated_client.get( + reverse("complianceoverview-list"), + {"filter[provider_type]": Provider.ProviderChoices.AWS.value}, + ) + + attrs_by_id = self._overview_attrs_by_id(response) + assert attrs_by_id["cis_1.4_aws"]["requirements_passed"] == 1 + assert attrs_by_id["cis_1.4_aws"]["requirements_failed"] == 1 + assert attrs_by_id["cis_1.4_aws"]["total_requirements"] == 2 + assert "cis_1.3_gcp" not in attrs_by_id + + def test_compliance_overview_provider_groups_filters_use_latest_scans( + self, + authenticated_client, + providers_fixture, + provider_groups_fixture, + tenants_fixture, + ): + tenant = tenants_fixture[0] + provider1, provider2, *_ = providers_fixture + group1, group2, *_ = provider_groups_fixture + _, latest_scan1, latest_scan2, *_ = self._prepare_latest_compliance_data( + providers_fixture + ) + ProviderGroupMembership.objects.create( + tenant_id=tenant.id, + provider=provider1, + provider_group=group1, + ) + ProviderGroupMembership.objects.create( + tenant_id=tenant.id, + provider=provider2, + provider_group=group2, + ) + + response = authenticated_client.get( + reverse("complianceoverview-list"), + {"filter[provider_groups]": str(group1.id)}, + ) + + attrs_by_id = self._overview_attrs_by_id(response) + assert attrs_by_id["cis_1.4_aws"]["requirements_passed"] == 2 + assert attrs_by_id["cis_1.4_aws"]["requirements_failed"] == 0 + + response = authenticated_client.get( + reverse("complianceoverview-list"), + {"filter[provider_groups__in]": f"{group1.id},{group2.id}"}, + ) + + attrs_by_id = self._overview_attrs_by_id(response) + assert attrs_by_id["cis_1.4_aws"]["requirements_passed"] == 1 + assert attrs_by_id["cis_1.4_aws"]["requirements_failed"] == 1 + assert attrs_by_id["cis_1.4_aws"]["total_requirements"] == 2 + + def _assert_latest_provider_scan_task_response( + self, + authenticated_client, + endpoint, + scan, + query_params=None, + ): + query_params = {**(query_params or {})} + if not any(key.startswith("filter[provider_") for key in query_params): + query_params = { + "filter[provider_id]": str(scan.provider_id), + **query_params, + } + + with patch.object( + ComplianceOverviewViewSet, "get_task_response_if_running" + ) as mock_task_response: + mock_task_response.return_value = Response( + {"detail": "Task is running"}, status=status.HTTP_202_ACCEPTED + ) + + response = authenticated_client.get(reverse(endpoint), query_params) + + assert response.status_code == status.HTTP_202_ACCEPTED + mock_task_response.assert_called_once() + _, kwargs = mock_task_response.call_args + assert kwargs["task_name"] == "scan-compliance-overviews" + assert str(kwargs["task_kwargs"]["tenant_id"]) == str(scan.tenant_id) + assert str(kwargs["task_kwargs"]["scan_id"]) == str(scan.id) + assert kwargs["raise_on_not_found"] is False + + def test_compliance_overview_provider_filter_returns_running_task_without_data( + self, + authenticated_client, + providers_fixture, + ): + scan = self._create_completed_scan( + providers_fixture[0], "latest scan without compliance data" + ) + + self._assert_latest_provider_scan_task_response( + authenticated_client, + "complianceoverview-list", + scan, + ) + + def test_compliance_overview_provider_filter_returns_running_task_for_partial_data( + self, + authenticated_client, + providers_fixture, + ): + provider_with_data, provider_without_data, *_ = providers_fixture + scan_with_data = self._create_completed_scan( + provider_with_data, "latest scan with compliance data" + ) + scan_without_data = self._create_completed_scan( + provider_without_data, "latest scan without partial compliance data" + ) + self._create_requirement(scan_with_data, "1.1", StatusChoices.PASS) + + self._assert_latest_provider_scan_task_response( + authenticated_client, + "complianceoverview-list", + scan_without_data, + { + "filter[provider_id__in]": ( + f"{provider_with_data.id},{provider_without_data.id}" + ) + }, + ) + + def test_compliance_overview_provider_filter_empty_response_uses_scan_data_presence( + self, + authenticated_client, + providers_fixture, + ): + scan = self._create_completed_scan( + providers_fixture[0], "latest scan with filtered compliance data" + ) + self._create_requirement(scan, "1.1", StatusChoices.PASS, region="eu-west-1") + + with patch.object( + ComplianceOverviewViewSet, "get_task_response_if_running" + ) as mock_task_response: + mock_task_response.return_value = Response( + {"detail": "Task is running"}, status=status.HTTP_202_ACCEPTED + ) + + response = authenticated_client.get( + reverse("complianceoverview-list"), + { + "filter[provider_id]": str(scan.provider_id), + "filter[region]": "us-east-1", + }, + ) + + assert response.status_code == status.HTTP_200_OK + assert response.json()["data"] == [] + mock_task_response.assert_not_called() + + def test_compliance_overview_metadata_provider_filter_returns_running_task_without_data( + self, + authenticated_client, + providers_fixture, + ): + scan = self._create_completed_scan( + providers_fixture[0], "latest scan without compliance metadata" + ) + + self._assert_latest_provider_scan_task_response( + authenticated_client, + "complianceoverview-metadata", + scan, + ) + + def test_compliance_overview_requirements_provider_filter_returns_running_task_without_data( + self, + authenticated_client, + providers_fixture, + ): + scan = self._create_completed_scan( + providers_fixture[0], "latest scan without compliance requirements" + ) + + self._assert_latest_provider_scan_task_response( + authenticated_client, + "complianceoverview-requirements", + scan, + {"filter[compliance_id]": "cis_1.4_aws"}, + ) + + def test_compliance_overview_metadata_accepts_provider_filters( + self, + authenticated_client, + providers_fixture, + ): + _, latest_scan, *_ = self._prepare_latest_compliance_data(providers_fixture) + + response = authenticated_client.get( + reverse("complianceoverview-metadata"), + {"filter[provider_id]": str(latest_scan.provider_id)}, + ) + + assert response.status_code == status.HTTP_200_OK + regions = response.json()["data"]["attributes"]["regions"] + assert regions == ["eu-west-1"] + + def test_compliance_overview_requirements_accepts_provider_filters( + self, + authenticated_client, + providers_fixture, + ): + _, latest_scan1, latest_scan2, *_ = self._prepare_latest_compliance_data( + providers_fixture + ) + + response = authenticated_client.get( + reverse("complianceoverview-requirements"), + { + "filter[provider_id__in]": ( + f"{latest_scan1.provider_id},{latest_scan2.provider_id}" + ), + "filter[compliance_id]": "cis_1.4_aws", + }, + ) + + assert response.status_code == status.HTTP_200_OK + requirements_by_id = { + item["id"]: item["attributes"] for item in response.json()["data"] + } + assert requirements_by_id["1.1"]["status"] == "FAIL" + assert requirements_by_id["1.2"]["status"] == "PASS" + def test_compliance_overview_metadata( self, authenticated_client, compliance_requirements_overviews_fixture ): @@ -10031,6 +10568,40 @@ class TestOverviewViewSet: for entry in grouped_data: assert "findings" not in entry["attributes"] + def test_overview_providers_count_applies_limited_visibility( + self, + authenticated_client_no_permissions_rbac, + providers_fixture, + provider_groups_fixture, + tenants_fixture, + ): + tenant = tenants_fixture[0] + client = authenticated_client_no_permissions_rbac + allowed_provider = providers_fixture[2] + denied_provider = providers_fixture[4] + provider_group = provider_groups_fixture[0] + + ProviderGroupMembership.objects.create( + tenant_id=tenant.id, + provider_group=provider_group, + provider=allowed_provider, + ) + RoleProviderGroupRelationship.objects.create( + tenant_id=tenant.id, + role=client.user.roles.first(), + provider_group=provider_group, + ) + + response = client.get(reverse("overview-providers-count")) + + assert response.status_code == status.HTTP_200_OK + aggregated = { + entry["id"]: entry["attributes"]["count"] + for entry in response.json()["data"] + } + assert aggregated == {allowed_provider.provider: 1} + assert denied_provider.provider not in aggregated + def _create_scan(self, tenant, provider, name, started_at=None): scan_started = started_at or datetime.now(timezone.utc) - timedelta(hours=1) return Scan.objects.create( @@ -10578,6 +11149,87 @@ class TestOverviewViewSet: assert combined_attributes["muted"] == 3 assert combined_attributes["total"] == 14 + def test_overview_findings_provider_groups_filter( + self, + authenticated_client, + tenants_fixture, + providers_fixture, + provider_groups_fixture, + ): + tenant = tenants_fixture[0] + provider1, provider2, *_ = providers_fixture + group1, group2, *_ = provider_groups_fixture + ProviderGroupMembership.objects.create( + tenant=tenant, provider=provider1, provider_group=group1 + ) + ProviderGroupMembership.objects.create( + tenant=tenant, provider=provider1, provider_group=group2 + ) + ProviderGroupMembership.objects.create( + tenant=tenant, provider=provider2, provider_group=group2 + ) + + scan1 = Scan.objects.create( + name="scan-provider-group-one", + provider=provider1, + trigger=Scan.TriggerChoices.MANUAL, + state=StateChoices.COMPLETED, + tenant=tenant, + ) + scan2 = Scan.objects.create( + name="scan-provider-group-two", + provider=provider2, + trigger=Scan.TriggerChoices.MANUAL, + state=StateChoices.COMPLETED, + tenant=tenant, + ) + ScanSummary.objects.create( + tenant=tenant, + scan=scan1, + check_id="check-provider-group-one", + service="service-a", + severity="high", + region="region-a", + _pass=5, + fail=1, + muted=2, + total=8, + ) + ScanSummary.objects.create( + tenant=tenant, + scan=scan2, + check_id="check-provider-group-two", + service="service-b", + severity="medium", + region="region-b", + _pass=2, + fail=3, + muted=1, + total=6, + ) + + response = authenticated_client.get( + reverse("overview-findings"), + {"filter[provider_groups]": str(group1.id)}, + ) + assert response.status_code == status.HTTP_200_OK + attributes = response.json()["data"]["attributes"] + assert attributes["pass"] == 5 + assert attributes["fail"] == 1 + assert attributes["muted"] == 2 + assert attributes["total"] == 8 + + response = authenticated_client.get( + reverse("overview-findings"), + {"filter[provider_groups__in]": f"{group1.id},{group2.id}"}, + ) + assert response.status_code == status.HTTP_200_OK + attributes = response.json()["data"]["attributes"] + assert attributes["pass"] == 7 + assert attributes["fail"] == 4 + assert attributes["muted"] == 3 + assert attributes["total"] == 14 + def test_overview_findings_severity_provider_id_in_filter( self, authenticated_client, tenants_fixture, providers_fixture ): @@ -11346,9 +11998,21 @@ class TestOverviewViewSet: @pytest.mark.parametrize( "filter_key,filter_value_fn,expected_total,expected_failed", [ - ("filter[provider_id]", lambda p1, _: str(p1.id), 10, 5), + ("filter[provider_id]", lambda p1, *_: str(p1.id), 10, 5), ("filter[provider_type]", lambda *_: "aws", 10, 5), ("filter[provider_type__in]", lambda *_: "aws,gcp", 30, 20), + ( + "filter[provider_groups]", + lambda p1, _, group1, __: str(group1.id), + 10, + 5, + ), + ( + "filter[provider_groups__in]", + lambda p1, _, group1, group2: f"{group1.id},{group2.id}", + 30, + 20, + ), ], ) def test_overview_categories_filters( @@ -11356,6 +12020,7 @@ class TestOverviewViewSet: authenticated_client, tenants_fixture, providers_fixture, + provider_groups_fixture, create_scan_category_summary, filter_key, filter_value_fn, @@ -11364,6 +12029,16 @@ class TestOverviewViewSet: ): tenant = tenants_fixture[0] provider1, _, gcp_provider, *_ = providers_fixture + group1, group2, *_ = provider_groups_fixture + ProviderGroupMembership.objects.create( + tenant=tenant, provider=provider1, provider_group=group1 + ) + ProviderGroupMembership.objects.create( + tenant=tenant, provider=provider1, provider_group=group2 + ) + ProviderGroupMembership.objects.create( + tenant=tenant, provider=gcp_provider, provider_group=group2 + ) scan1 = Scan.objects.create( name="categories-scan-1", @@ -11389,7 +12064,7 @@ class TestOverviewViewSet: response = authenticated_client.get( reverse("overview-categories"), - {filter_key: filter_value_fn(provider1, gcp_provider)}, + {filter_key: filter_value_fn(provider1, gcp_provider, group1, group2)}, ) assert response.status_code == status.HTTP_200_OK data = response.json()["data"] @@ -11563,10 +12238,22 @@ class TestOverviewViewSet: @pytest.mark.parametrize( "filter_key,filter_value_fn,expected_total,expected_failed", [ - ("filter[provider_id]", lambda p1, p2: str(p1.id), 10, 5), - ("filter[provider_id__in]", lambda p1, p2: f"{p1.id},{p2.id}", 25, 12), - ("filter[provider_type]", lambda p1, p2: "aws", 10, 5), - ("filter[provider_type__in]", lambda p1, p2: "aws,gcp", 25, 12), + ("filter[provider_id]", lambda p1, *_: str(p1.id), 10, 5), + ("filter[provider_id__in]", lambda p1, p2, *_: f"{p1.id},{p2.id}", 25, 12), + ("filter[provider_type]", lambda *_: "aws", 10, 5), + ("filter[provider_type__in]", lambda *_: "aws,gcp", 25, 12), + ( + "filter[provider_groups]", + lambda p1, p2, group1, group2: str(group1.id), + 10, + 5, + ), + ( + "filter[provider_groups__in]", + lambda p1, p2, group1, group2: f"{group1.id},{group2.id}", + 25, + 12, + ), ], ) def test_overview_groups_provider_filters( @@ -11574,6 +12261,7 @@ class TestOverviewViewSet: authenticated_client, tenants_fixture, providers_fixture, + provider_groups_fixture, create_scan_resource_group_summary, filter_key, filter_value_fn, @@ -11583,6 +12271,16 @@ class TestOverviewViewSet: tenant = tenants_fixture[0] provider1 = providers_fixture[0] # AWS gcp_provider = providers_fixture[2] # GCP + group1, group2, *_ = provider_groups_fixture + ProviderGroupMembership.objects.create( + tenant=tenant, provider=provider1, provider_group=group1 + ) + ProviderGroupMembership.objects.create( + tenant=tenant, provider=provider1, provider_group=group2 + ) + ProviderGroupMembership.objects.create( + tenant=tenant, provider=gcp_provider, provider_group=group2 + ) scan1 = Scan.objects.create( name="aws-rg-scan", @@ -11608,7 +12306,7 @@ class TestOverviewViewSet: response = authenticated_client.get( reverse("overview-resource-groups"), - {filter_key: filter_value_fn(provider1, gcp_provider)}, + {filter_key: filter_value_fn(provider1, gcp_provider, group1, group2)}, ) assert response.status_code == status.HTTP_200_OK data = response.json()["data"] @@ -11783,6 +12481,49 @@ class TestOverviewViewSet: data = response.json()["data"] assert len(data) >= 1 + def test_compliance_watchlist_provider_groups_filter( + self, + authenticated_client, + provider_compliance_scores_fixture, + providers_fixture, + provider_groups_fixture, + tenants_fixture, + ): + tenant = tenants_fixture[0] + provider1, provider2, *_ = providers_fixture + group1, group2, *_ = provider_groups_fixture + ProviderGroupMembership.objects.create( + tenant=tenant, provider=provider1, provider_group=group1 + ) + ProviderGroupMembership.objects.create( + tenant=tenant, provider=provider1, provider_group=group2 + ) + ProviderGroupMembership.objects.create( + tenant=tenant, provider=provider2, provider_group=group2 + ) + + response = authenticated_client.get( + reverse("overview-compliance-watchlist"), + {"filter[provider_groups]": str(group1.id)}, + ) + assert response.status_code == status.HTTP_200_OK + data = response.json()["data"] + by_id = {item["id"]: item["attributes"] for item in data} + assert by_id["aws_cis_2.0"]["requirements_passed"] == 1 + assert by_id["aws_cis_2.0"]["requirements_failed"] == 1 + assert by_id["aws_cis_2.0"]["requirements_manual"] == 1 + + response = authenticated_client.get( + reverse("overview-compliance-watchlist"), + {"filter[provider_groups__in]": f"{group1.id},{group2.id}"}, + ) + assert response.status_code == status.HTTP_200_OK + data = response.json()["data"] + by_id = {item["id"]: item["attributes"] for item in data} + assert by_id["aws_cis_2.0"]["requirements_passed"] == 0 + assert by_id["aws_cis_2.0"]["requirements_failed"] == 2 + assert by_id["aws_cis_2.0"]["requirements_manual"] == 1 + def test_compliance_watchlist_empty_result(self, authenticated_client): response = authenticated_client.get(reverse("overview-compliance-watchlist")) assert response.status_code == status.HTTP_200_OK @@ -17082,6 +17823,44 @@ class TestFindingGroupViewSet: # All fixture findings are from AWS provider assert len(response.json()["data"]) == 5 + def test_finding_groups_provider_groups_filter( + self, + authenticated_client, + tenants_fixture, + finding_groups_fixture, + providers_fixture, + provider_groups_fixture, + ): + tenant = tenants_fixture[0] + provider1, provider2, *_ = providers_fixture + group1, group2, *_ = provider_groups_fixture + ProviderGroupMembership.objects.create( + tenant=tenant, provider=provider1, provider_group=group1 + ) + ProviderGroupMembership.objects.create( + tenant=tenant, provider=provider1, provider_group=group2 + ) + ProviderGroupMembership.objects.create( + tenant=tenant, provider=provider2, provider_group=group2 + ) + + response = authenticated_client.get( + reverse("finding-group-list"), + {"filter[inserted_at]": TODAY, "filter[provider_groups]": str(group1.id)}, + ) + assert response.status_code == status.HTTP_200_OK + assert len(response.json()["data"]) == 4 + + response = authenticated_client.get( + reverse("finding-group-list"), + { + "filter[inserted_at]": TODAY, + "filter[provider_groups__in]": f"{group1.id},{group2.id}", + }, + ) + assert response.status_code == status.HTTP_200_OK + assert len(response.json()["data"]) == 5 + def test_finding_groups_check_id_filter( self, authenticated_client, finding_groups_fixture ): @@ -17992,6 +18771,41 @@ class TestFindingGroupViewSet: # All providers in fixture are AWS assert len(data) == 5 + def test_finding_groups_latest_provider_groups_filter( + self, + authenticated_client, + tenants_fixture, + finding_groups_fixture, + providers_fixture, + provider_groups_fixture, + ): + tenant = tenants_fixture[0] + provider1, provider2, *_ = providers_fixture + group1, group2, *_ = provider_groups_fixture + ProviderGroupMembership.objects.create( + tenant=tenant, provider=provider1, provider_group=group1 + ) + ProviderGroupMembership.objects.create( + tenant=tenant, provider=provider1, provider_group=group2 + ) + ProviderGroupMembership.objects.create( + tenant=tenant, provider=provider2, provider_group=group2 + ) + + response = authenticated_client.get( + reverse("finding-group-latest"), + {"filter[provider_groups]": str(group1.id)}, + ) + assert response.status_code == status.HTTP_200_OK + assert len(response.json()["data"]) == 4 + + response = authenticated_client.get( + reverse("finding-group-latest"), + {"filter[provider_groups__in]": f"{group1.id},{group2.id}"}, + ) + assert response.status_code == status.HTTP_200_OK + assert len(response.json()["data"]) == 5 + def test_finding_groups_latest_check_id_filter( self, authenticated_client, finding_groups_fixture ): diff --git a/api/src/backend/api/v1/mixins.py b/api/src/backend/api/v1/mixins.py index e1a5d3470f..849c788366 100644 --- a/api/src/backend/api/v1/mixins.py +++ b/api/src/backend/api/v1/mixins.py @@ -1,6 +1,10 @@ +import uuid + +from django.http import QueryDict from django.urls import reverse from django_celery_results.models import TaskResult from rest_framework import status +from rest_framework.exceptions import ValidationError from rest_framework.response import Response from api.exceptions import ( @@ -8,7 +12,7 @@ from api.exceptions import ( TaskInProgressException, TaskNotFoundException, ) -from api.models import StateChoices, Task +from api.models import Provider, StateChoices, Task from api.v1.serializers import TaskSerializer @@ -74,6 +78,162 @@ class PaginateByPkMixin: return self.get_paginated_response(serialized) +class JsonApiFilterMixin: + """Shared helpers for manually applying django-filter to JSON:API params.""" + + jsonapi_filter_replace_dots = False + + def _normalize_jsonapi_params( + self, + query_params, + exclude_keys=None, + replace_dots=None, + ): + exclude_keys = exclude_keys or set() + if replace_dots is None: + replace_dots = self.jsonapi_filter_replace_dots + + normalized = QueryDict(mutable=True) + for key, values in query_params.lists(): + normalized_key = ( + key[7:-1] if key.startswith("filter[") and key.endswith("]") else key + ) + if replace_dots: + normalized_key = normalized_key.replace(".", "__") + if normalized_key not in exclude_keys: + normalized.setlist(normalized_key, values) + return normalized + + def _apply_filterset( + self, + queryset, + filterset_class, + exclude_keys=None, + replace_dots=None, + ): + normalized_params = self._normalize_jsonapi_params( + self.request.query_params, + exclude_keys=set(exclude_keys or []), + replace_dots=replace_dots, + ) + filterset = filterset_class(normalized_params, queryset=queryset) + if not filterset.is_valid(): + raise ValidationError(filterset.errors) + return filterset.qs + + +class ProviderFilterParamsMixin(JsonApiFilterMixin): + """Shared extraction of provider filters from JSON:API query params.""" + + PROVIDER_FILTER_KEYS = frozenset( + { + "provider_id", + "provider_id__in", + "provider_type", + "provider_type__in", + "provider_groups", + "provider_groups__in", + } + ) + PROVIDER_FILTER_DOT_ALIAS_KEYS = frozenset( + { + "provider_id.in", + "provider_type.in", + "provider_groups.in", + } + ) + PROVIDER_FILTER_QUERY_KEYS = PROVIDER_FILTER_KEYS | PROVIDER_FILTER_DOT_ALIAS_KEYS + + def _csv_filter_values(self, value): + return [item.strip() for item in value.split(",") if item.strip()] + + def _validate_uuid_filter_values(self, field_name, values): + try: + for value in values: + uuid.UUID(str(value)) + except (TypeError, ValueError, AttributeError): + raise ValidationError({field_name: ["Enter a valid UUID."]}) + + def _has_provider_filters(self, include_dot_aliases=False): + provider_filter_keys = ( + self.PROVIDER_FILTER_QUERY_KEYS + if include_dot_aliases + else self.PROVIDER_FILTER_KEYS + ) + return any( + self.request.query_params.get(f"filter[{key}]") + for key in provider_filter_keys + ) + + def _extract_provider_filters_from_params( + self, + *, + validate_uuids=False, + include_dot_aliases=False, + ): + params = self.request.query_params + filters = {} + valid_provider_types = { + choice[0] for choice in Provider.ProviderChoices.choices + } + + provider_id = params.get("filter[provider_id]") + if provider_id: + if validate_uuids: + self._validate_uuid_filter_values("provider_id", [provider_id]) + filters["provider_id"] = provider_id + + provider_id_in = params.get("filter[provider_id__in]") + if include_dot_aliases: + provider_id_in = provider_id_in or params.get("filter[provider_id.in]") + if provider_id_in: + values = self._csv_filter_values(provider_id_in) + if validate_uuids: + self._validate_uuid_filter_values("provider_id__in", values) + filters["provider_id__in"] = values + + provider_type = params.get("filter[provider_type]") + if provider_type: + if provider_type not in valid_provider_types: + raise ValidationError( + {"provider_type": f"Invalid choice: {provider_type}"} + ) + filters["provider__provider"] = provider_type + + provider_type_in = params.get("filter[provider_type__in]") + if include_dot_aliases: + provider_type_in = provider_type_in or params.get( + "filter[provider_type.in]" + ) + if provider_type_in: + values = self._csv_filter_values(provider_type_in) + invalid = [value for value in values if value not in valid_provider_types] + if invalid: + raise ValidationError( + {"provider_type__in": f"Invalid choices: {', '.join(invalid)}"} + ) + filters["provider__provider__in"] = values + + provider_groups = params.get("filter[provider_groups]") + if provider_groups: + if validate_uuids: + self._validate_uuid_filter_values("provider_groups", [provider_groups]) + filters["provider__provider_groups__id"] = provider_groups + + provider_groups_in = params.get("filter[provider_groups__in]") + if include_dot_aliases: + provider_groups_in = provider_groups_in or params.get( + "filter[provider_groups.in]" + ) + if provider_groups_in: + values = self._csv_filter_values(provider_groups_in) + if validate_uuids: + self._validate_uuid_filter_values("provider_groups__in", values) + filters["provider__provider_groups__id__in"] = values + + return filters + + class TaskManagementMixin: """ Mixin to manage task status checking. diff --git a/api/src/backend/api/v1/views.py b/api/src/backend/api/v1/views.py index 7bcf9a9e69..7e9aae5b94 100644 --- a/api/src/backend/api/v1/views.py +++ b/api/src/backend/api/v1/views.py @@ -228,7 +228,13 @@ from api.utils import ( validate_invitation, ) from api.uuid_utils import datetime_to_uuid7, uuid7_start -from api.v1.mixins import DisablePaginationMixin, PaginateByPkMixin, TaskManagementMixin +from api.v1.mixins import ( + DisablePaginationMixin, + JsonApiFilterMixin, + PaginateByPkMixin, + ProviderFilterParamsMixin, + TaskManagementMixin, +) from api.v1.serializers import ( AttackPathsCartographySchemaSerializer, AttackPathsCustomQueryRunRequestSerializer, @@ -4556,15 +4562,19 @@ class RoleProviderGroupRelationshipView(RelationshipView, BaseRLSViewSet): @extend_schema_view( list=extend_schema( tags=["Compliance Overview"], - summary="List compliance overviews for a scan", - description="Retrieve an overview of all the compliance in a given scan.", + summary="List compliance overviews", + description=( + "Retrieve compliance overview data for a scan. When provider filters " + "are provided, the endpoint uses the latest completed scan for each " + "matching provider." + ), parameters=[ OpenApiParameter( name="filter[scan_id]", - required=True, + required=False, type=OpenApiTypes.UUID, location=OpenApiParameter.QUERY, - description="Related scan ID.", + description="Related scan ID. Required unless a provider filter is provided.", ), ], responses={ @@ -4579,19 +4589,23 @@ class RoleProviderGroupRelationshipView(RelationshipView, BaseRLSViewSet): description="Compliance overviews generation task failed" ), }, + filters=True, ), metadata=extend_schema( tags=["Compliance Overview"], summary="Retrieve metadata values from compliance overviews", - description="Fetch unique metadata values from a set of compliance overviews. This is useful for dynamic " - "filtering.", + description=( + "Fetch unique metadata values from compliance overviews. This is useful " + "for dynamic filtering. When provider filters are provided, metadata is " + "computed from the latest completed scan for each matching provider." + ), parameters=[ OpenApiParameter( name="filter[scan_id]", - required=True, + required=False, type=OpenApiTypes.UUID, location=OpenApiParameter.QUERY, - description="Related scan ID.", + description="Related scan ID. Required unless a provider filter is provided.", ), ], responses={ @@ -4606,19 +4620,24 @@ class RoleProviderGroupRelationshipView(RelationshipView, BaseRLSViewSet): description="Compliance overviews generation task failed" ), }, + filters=True, ), requirements=extend_schema( tags=["Compliance Overview"], - summary="List compliance requirements overview for a scan", - description="Retrieve a detailed overview of compliance requirements in a given scan, grouped by compliance " - "framework. This endpoint provides requirement-level details and aggregates status across regions.", + summary="List compliance requirements overview", + description=( + "Retrieve a detailed overview of compliance requirements, grouped by " + "compliance framework. This endpoint provides requirement-level details " + "and aggregates status across regions. When provider filters are provided, " + "the endpoint uses the latest completed scan for each matching provider." + ), parameters=[ OpenApiParameter( name="filter[scan_id]", - required=True, + required=False, type=OpenApiTypes.UUID, location=OpenApiParameter.QUERY, - description="Related scan ID.", + description="Related scan ID. Required unless a provider filter is provided.", ), OpenApiParameter( name="filter[compliance_id]", @@ -4677,7 +4696,10 @@ class RoleProviderGroupRelationshipView(RelationshipView, BaseRLSViewSet): @method_decorator(CACHE_DECORATOR, name="list") @method_decorator(CACHE_DECORATOR, name="requirements") @method_decorator(CACHE_DECORATOR, name="attributes") -class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin): +class ComplianceOverviewViewSet( + ProviderFilterParamsMixin, BaseRLSViewSet, TaskManagementMixin +): + jsonapi_filter_replace_dots = True pagination_class = ComplianceOverviewPagination queryset = ComplianceRequirementOverview.objects.all() serializer_class = ComplianceOverviewSerializer @@ -4691,28 +4713,22 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin): required_permissions = [] def get_queryset(self): + if getattr(self, "swagger_fake_view", False): + return ComplianceRequirementOverview.objects.none() + role = get_role(self.request.user, self.request.tenant_id) unlimited_visibility = getattr( role, Permissions.UNLIMITED_VISIBILITY.value, False ) - if unlimited_visibility: - base_queryset = self.filter_queryset( - ComplianceRequirementOverview.objects.filter( - tenant_id=self.request.tenant_id - ) - ) - else: - providers = Provider.objects.filter( - provider_groups__in=role.provider_groups.all() - ).distinct() - base_queryset = self.filter_queryset( - ComplianceRequirementOverview.objects.filter( - tenant_id=self.request.tenant_id, scan__provider__in=providers - ) - ) + base_queryset = ComplianceRequirementOverview.objects.filter( + tenant_id=self.request.tenant_id + ) - return base_queryset + if unlimited_visibility: + return base_queryset + + return base_queryset.filter(scan__provider__in=get_providers(role)) def get_serializer_class(self): if hasattr(self, "response_serializer_class"): @@ -4750,6 +4766,72 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin): return summaries + def _validate_scan_selection(self, scan_id, has_provider_filters): + if scan_id and has_provider_filters: + raise ValidationError( + [ + { + "detail": "Use either filter[scan_id] or provider filters.", + "status": 400, + "source": {"pointer": "filter[scan_id]"}, + "code": "invalid", + } + ] + ) + + if scan_id: + self._validate_uuid_filter_values("scan_id", [scan_id]) + return + + if has_provider_filters: + return + + raise ValidationError( + [ + { + "detail": "This query parameter is required unless a provider filter is provided.", + "status": 400, + "source": {"pointer": "filter[scan_id]"}, + "code": "required", + } + ] + ) + + def _latest_scan_ids_for_provider_filters(self): + role = get_role(self.request.user, self.request.tenant_id) + scans = Scan.all_objects.filter( + tenant_id=self.request.tenant_id, + state=StateChoices.COMPLETED, + ) + + if not getattr(role, Permissions.UNLIMITED_VISIBILITY.value, False): + scans = scans.filter(provider__in=get_providers(role)) + + provider_filters = self._extract_provider_filters_from_params( + validate_uuids=True, + include_dot_aliases=True, + ) + if provider_filters: + scans = scans.filter(**provider_filters) + + return list( + scans.order_by("provider_id", "-inserted_at") + .distinct("provider_id") + .values_list("id", flat=True) + ) + + def _filtered_queryset_for_latest_provider_scans(self, latest_scan_ids=None): + if latest_scan_ids is None: + latest_scan_ids = self._latest_scan_ids_for_provider_filters() + queryset = self.get_queryset().filter(scan_id__in=latest_scan_ids) + # Provider filters stay on the filterset for OpenAPI docs, but runtime + # filtering happens on Scan first so compliance queries use scan IDs. + return self._apply_filterset( + queryset, + self.filterset_class, + exclude_keys=self.PROVIDER_FILTER_KEYS | {"scan_id"}, + ) + def _get_compliance_template(self, *, provider=None, scan_id=None): """Return the compliance template for the given provider or scan.""" if provider is None and scan_id is not None: @@ -4865,6 +4947,36 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin): status=status.HTTP_500_INTERNAL_SERVER_ERROR, ) + def _task_response_for_latest_provider_scans(self, latest_scan_ids): + for scan_id in latest_scan_ids: + task_response = self._task_response_if_running(str(scan_id)) + if task_response: + return task_response + return None + + def _latest_provider_scan_ids_without_data(self, latest_scan_ids): + data_presence_queryset = self.get_queryset().filter(scan_id__in=latest_scan_ids) + scan_ids_with_data = { + str(scan_id) + for scan_id in data_presence_queryset.values_list( + "scan_id", flat=True + ).distinct() + } + return [ + scan_id + for scan_id in latest_scan_ids + if str(scan_id) not in scan_ids_with_data + ] + + def _task_response_for_latest_provider_scans_without_data( + self, + latest_scan_ids, + ): + scan_ids_to_check = self._latest_provider_scan_ids_without_data( + latest_scan_ids, + ) + return self._task_response_for_latest_provider_scans(scan_ids_to_check) + def _list_with_region_filter(self, scan_id, region_filter): """ Fall back to detailed ComplianceRequirementOverview query when region filter is applied. @@ -4905,8 +5017,25 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin): return Response(data) + def _list_with_latest_provider_filters(self): + latest_scan_ids = self._latest_scan_ids_for_provider_filters() + queryset = self._filtered_queryset_for_latest_provider_scans(latest_scan_ids) + data = self._aggregate_compliance_overview(queryset) + task_response = self._task_response_for_latest_provider_scans_without_data( + latest_scan_ids, + ) + if task_response: + return task_response + + return Response(data) + def list(self, request, *args, **kwargs): scan_id = request.query_params.get("filter[scan_id]") + has_provider_filters = self._has_provider_filters(include_dot_aliases=True) + self._validate_scan_selection(scan_id, has_provider_filters) + + if has_provider_filters: + return self._list_with_latest_provider_filters() # Specific scan requested - use optimized summaries with region support region_filter = request.query_params.get( @@ -4952,33 +5081,34 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin): @action(detail=False, methods=["get"], url_name="metadata") def metadata(self, request): scan_id = request.query_params.get("filter[scan_id]") - if not scan_id: - raise ValidationError( - [ - { - "detail": "This query parameter is required.", - "status": 400, - "source": {"pointer": "filter[scan_id]"}, - "code": "required", - } - ] + has_provider_filters = self._has_provider_filters(include_dot_aliases=True) + self._validate_scan_selection(scan_id, has_provider_filters) + + latest_scan_ids = None + if has_provider_filters: + latest_scan_ids = self._latest_scan_ids_for_provider_filters() + queryset = self._filtered_queryset_for_latest_provider_scans( + latest_scan_ids ) + else: + queryset = self._apply_filterset(self.get_queryset(), self.filterset_class) + regions = list( - self.get_queryset() - .filter(scan_id=scan_id) - .values_list("region", flat=True) - .order_by("region") - .distinct() + queryset.values_list("region", flat=True).order_by("region").distinct() ) result = {"regions": regions} - if regions: - serializer = self.get_serializer(data=result) - serializer.is_valid(raise_exception=True) - return Response(serializer.data, status=status.HTTP_200_OK) + task_response = None + if has_provider_filters: + task_response = self._task_response_for_latest_provider_scans_without_data( + latest_scan_ids, + ) + elif not regions: + task_response = self._task_response_if_running(scan_id) + if task_response: + return task_response - task_response = self._task_response_if_running(scan_id) - if task_response: + if has_provider_filters and task_response: return task_response serializer = self.get_serializer(data=result) @@ -4988,19 +5118,10 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin): @action(detail=False, methods=["get"], url_name="requirements") def requirements(self, request): scan_id = request.query_params.get("filter[scan_id]") + has_provider_filters = self._has_provider_filters(include_dot_aliases=True) compliance_id = request.query_params.get("filter[compliance_id]") - if not scan_id: - raise ValidationError( - [ - { - "detail": "This query parameter is required.", - "status": 400, - "source": {"pointer": "filter[scan_id]"}, - "code": "required", - } - ] - ) + self._validate_scan_selection(scan_id, has_provider_filters) if not compliance_id: raise ValidationError( @@ -5013,7 +5134,16 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin): } ] ) - filtered_queryset = self.filter_queryset(self.get_queryset()) + latest_scan_ids = None + if has_provider_filters: + latest_scan_ids = self._latest_scan_ids_for_provider_filters() + filtered_queryset = self._filtered_queryset_for_latest_provider_scans( + latest_scan_ids + ) + else: + filtered_queryset = self._apply_filterset( + self.get_queryset(), self.filterset_class + ) all_requirements = filtered_queryset.values( "requirement_id", @@ -5072,13 +5202,22 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin): requirements_summary, many=True ) + task_response = None + if has_provider_filters: + task_response = self._task_response_for_latest_provider_scans_without_data( + latest_scan_ids, + ) + elif not requirements_summary: + task_response = self._task_response_if_running(scan_id) + if task_response: + return task_response + + if has_provider_filters and task_response: + return task_response + if requirements_summary: return Response(serializer.data, status=status.HTTP_200_OK) - task_response = self._task_response_if_running(scan_id) - if task_response: - return task_response - return Response(serializer.data, status=status.HTTP_200_OK) @action(detail=False, methods=["get"], url_name="attributes") @@ -5317,7 +5456,7 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin): ), ) @method_decorator(CACHE_DECORATOR, name="list") -class OverviewViewSet(BaseRLSViewSet): +class OverviewViewSet(ProviderFilterParamsMixin, BaseRLSViewSet): queryset = ScanSummary.objects.all() http_method_names = ["get"] ordering = ["-inserted_at"] @@ -5434,18 +5573,6 @@ class OverviewViewSet(BaseRLSViewSet): tenant_id=tenant_id, scan_id__in=latest_scan_ids ) - def _normalize_jsonapi_params(self, query_params, exclude_keys=None): - """Convert JSON:API filter params (filter[X]) to flat params (X).""" - exclude_keys = exclude_keys or set() - normalized = QueryDict(mutable=True) - for key, values in query_params.lists(): - normalized_key = ( - key[7:-1] if key.startswith("filter[") and key.endswith("]") else key - ) - if normalized_key not in exclude_keys: - normalized.setlist(normalized_key, values) - return normalized - def _ensure_allowed_providers(self): """Populate allowed providers for RBAC-aware queries once per request.""" if getattr(self, "_providers_initialized", False): @@ -5465,15 +5592,6 @@ class OverviewViewSet(BaseRLSViewSet): return queryset.filter(**provider_filter) return queryset - def _apply_filterset(self, queryset, filterset_class, exclude_keys=None): - normalized_params = self._normalize_jsonapi_params( - self.request.query_params, exclude_keys=set(exclude_keys or []) - ) - filterset = filterset_class(normalized_params, queryset=queryset) - if not filterset.is_valid(): - raise ValidationError(filterset.errors) - return filterset.qs - def _latest_scan_ids_for_allowed_providers(self, tenant_id, provider_filters=None): provider_filter = self._get_provider_filter() queryset = Scan.all_objects.filter( @@ -5487,40 +5605,6 @@ class OverviewViewSet(BaseRLSViewSet): .values_list("id", flat=True) ) - def _extract_provider_filters_from_params(self): - """Extract and validate provider filters from query params.""" - params = self.request.query_params - filters = {} - valid_provider_types = {c[0] for c in Provider.ProviderChoices.choices} - - provider_id = params.get("filter[provider_id]") - if provider_id: - filters["provider_id"] = provider_id - - provider_id_in = params.get("filter[provider_id__in]") - if provider_id_in: - filters["provider_id__in"] = provider_id_in.split(",") - - provider_type = params.get("filter[provider_type]") - if provider_type: - if provider_type not in valid_provider_types: - raise ValidationError( - {"provider_type": f"Invalid choice: {provider_type}"} - ) - filters["provider__provider"] = provider_type - - provider_type_in = params.get("filter[provider_type__in]") - if provider_type_in: - types = provider_type_in.split(",") - invalid = [t for t in types if t not in valid_provider_types] - if invalid: - raise ValidationError( - {"provider_type__in": f"Invalid choices: {', '.join(invalid)}"} - ) - filters["provider__provider__in"] = types - - return filters - @action(detail=False, methods=["get"], url_name="providers") def providers(self, request): tenant_id = self.request.tenant_id @@ -5591,15 +5675,11 @@ class OverviewViewSet(BaseRLSViewSet): tenant_id = self.request.tenant_id providers_qs = Provider.objects.filter(tenant_id=tenant_id) + self._ensure_allowed_providers() if hasattr(self, "allowed_providers"): - allowed_ids = list(self.allowed_providers.values_list("id", flat=True)) - if not allowed_ids: - overview = [] - return Response( - self.get_serializer(overview, many=True).data, - status=status.HTTP_200_OK, - ) - providers_qs = providers_qs.filter(id__in=allowed_ids) + providers_qs = providers_qs.filter( + id__in=self.allowed_providers.values("id") + ) overview = ( providers_qs.values("provider") @@ -5815,29 +5895,41 @@ class OverviewViewSet(BaseRLSViewSet): description="Retrieve a specific snapshot by ID. If not provided, returns latest snapshots.", ), OpenApiParameter( - name="provider_id", + name="filter[provider_id]", type=OpenApiTypes.UUID, location=OpenApiParameter.QUERY, description="Filter by specific provider ID", ), OpenApiParameter( - name="provider_id__in", + name="filter[provider_id__in]", type=OpenApiTypes.STR, location=OpenApiParameter.QUERY, description="Filter by multiple provider IDs (comma-separated UUIDs)", ), OpenApiParameter( - name="provider_type", + name="filter[provider_type]", type=OpenApiTypes.STR, location=OpenApiParameter.QUERY, description="Filter by provider type (aws, azure, gcp, etc.)", ), OpenApiParameter( - name="provider_type__in", + name="filter[provider_type__in]", type=OpenApiTypes.STR, location=OpenApiParameter.QUERY, description="Filter by multiple provider types (comma-separated)", ), + OpenApiParameter( + name="filter[provider_groups]", + type=OpenApiTypes.UUID, + location=OpenApiParameter.QUERY, + description="Filter by provider group ID", + ), + OpenApiParameter( + name="filter[provider_groups__in]", + type=OpenApiTypes.STR, + location=OpenApiParameter.QUERY, + description="Filter by multiple provider group IDs (comma-separated UUIDs)", + ), ], ) @action(detail=False, methods=["get"], url_name="threatscore") @@ -6179,6 +6271,8 @@ class OverviewViewSet(BaseRLSViewSet): "provider_id__in", "provider_type", "provider_type__in", + "provider_groups", + "provider_groups__in", } filtered_queryset = self._apply_filterset( base_queryset, CategoryOverviewFilter, exclude_keys=provider_filter_keys @@ -6248,6 +6342,8 @@ class OverviewViewSet(BaseRLSViewSet): "provider_id__in", "provider_type", "provider_type__in", + "provider_groups", + "provider_groups__in", } filtered_queryset = self._apply_filterset( base_queryset, @@ -7298,7 +7394,7 @@ SEVERITY_ORDER_REVERSE = {v: k for k, v in SEVERITY_ORDER.items()} ), retrieve=extend_schema(exclude=True), ) -class FindingGroupViewSet(BaseRLSViewSet): +class FindingGroupViewSet(JsonApiFilterMixin, BaseRLSViewSet): """ ViewSet for Finding Groups - aggregates findings by check_id. @@ -7314,6 +7410,7 @@ class FindingGroupViewSet(BaseRLSViewSet): queryset = FindingGroupDailySummary.objects.all() serializer_class = FindingGroupSerializer filterset_class = FindingGroupFilter + jsonapi_filter_replace_dots = True filter_backends = [ jsonapi_filters.QueryParameterValidationFilter, jsonapi_filters.OrderingFilter, @@ -7364,18 +7461,6 @@ class FindingGroupViewSet(BaseRLSViewSet): return queryset - def _normalize_jsonapi_params(self, query_params): - """Convert JSON:API filter params (filter[X]) to flat params (X).""" - normalized = QueryDict(mutable=True) - for key, values in query_params.lists(): - normalized_key = ( - key[7:-1] if key.startswith("filter[") and key.endswith("]") else key - ) - # Convert JSON:API dot notation to Django double underscore - normalized_key = normalized_key.replace(".", "__") - normalized.setlist(normalized_key, values) - return normalized - @extend_schema(exclude=True) def retrieve(self, request, *args, **kwargs): raise MethodNotAllowed(method="GET") @@ -8494,9 +8579,10 @@ class FindingGroupViewSet(BaseRLSViewSet): This endpoint returns finding groups without requiring date filters, automatically using the latest available data per check_id. - All other filters (provider_id, provider_type, check_id) are still supported. + Provider, provider group, check, and computed filters are still supported. """, tags=["Finding Groups"], + filters=True, ) @action(detail=False, methods=["get"], url_name="latest") def latest(self, request):