feat(api): add provider group filters (#11573)

This commit is contained in:
Adrián Peña
2026-06-16 14:18:34 +02:00
committed by GitHub
parent 181197177c
commit e4d5ca11b3
5 changed files with 1338 additions and 163 deletions
+2
View File
@@ -6,6 +6,8 @@ All notable changes to the **Prowler API** are documented in this file.
### 🚀 Added
- Provider group filters for API endpoints that support cloud provider filtering, including exact and `__in` variants [(#11573)](https://github.com/prowler-cloud/prowler/pull/11573)
- Provider filters for `GET /api/v1/compliance-overviews`, `/metadata`, and `/requirements`, using latest completed scans per matching provider [(#11587)](https://github.com/prowler-cloud/prowler/pull/11587)
- Server-Sent Events (SSE) infrastructure for the API: a base viewset, a tenant-aware channel manager, and channel-name helpers backed by `django-eventstream` over Valkey Pub/Sub and served through the Gunicorn ASGI worker, so feature endpoints can stream events to clients over a single long-lived connection [(#11556)](https://github.com/prowler-cloud/prowler/pull/11556)
### 🔐 Security
+118 -5
View File
@@ -102,7 +102,7 @@ class BaseProviderFilter(FilterSet):
"""
Abstract base filter for models with direct FK to Provider.
Provides standard provider_id and provider_type filters.
Provides standard provider_id, provider_type, and provider_groups filters.
Subclasses must define Meta.model.
"""
@@ -116,6 +116,16 @@ class BaseProviderFilter(FilterSet):
choices=Provider.ProviderChoices.choices,
lookup_expr="in",
)
provider_groups = UUIDFilter(
field_name="provider__provider_groups__id",
lookup_expr="exact",
distinct=True,
)
provider_groups__in = UUIDInFilter(
field_name="provider__provider_groups__id",
lookup_expr="in",
distinct=True,
)
class Meta:
abstract = True
@@ -126,7 +136,7 @@ class BaseScanProviderFilter(FilterSet):
"""
Abstract base filter for models with FK to Scan (and Scan has FK to Provider).
Provides standard provider_id and provider_type filters via scan relationship.
Provides standard provider_id, provider_type, and provider_groups filters via scan relationship.
Subclasses must define Meta.model.
"""
@@ -140,6 +150,16 @@ class BaseScanProviderFilter(FilterSet):
choices=Provider.ProviderChoices.choices,
lookup_expr="in",
)
provider_groups = UUIDFilter(
field_name="scan__provider__provider_groups__id",
lookup_expr="exact",
distinct=True,
)
provider_groups__in = UUIDInFilter(
field_name="scan__provider__provider_groups__id",
lookup_expr="in",
distinct=True,
)
class Meta:
abstract = True
@@ -160,6 +180,16 @@ class CommonFindingFilters(FilterSet):
provider_type__in = ChoiceInFilter(
choices=Provider.ProviderChoices.choices, field_name="scan__provider__provider"
)
provider_groups = UUIDFilter(
field_name="scan__provider__provider_groups__id",
lookup_expr="exact",
distinct=True,
)
provider_groups__in = UUIDInFilter(
field_name="scan__provider__provider_groups__id",
lookup_expr="in",
distinct=True,
)
provider_uid = CharFilter(field_name="scan__provider__uid", lookup_expr="exact")
provider_uid__in = CharInFilter(field_name="scan__provider__uid", lookup_expr="in")
provider_uid__icontains = CharFilter(
@@ -370,6 +400,12 @@ class ProviderFilter(FilterSet):
choices=Provider.ProviderChoices.choices,
lookup_expr="in",
)
provider_groups = UUIDFilter(
field_name="provider_groups__id", lookup_expr="exact", distinct=True
)
provider_groups__in = UUIDInFilter(
field_name="provider_groups__id", lookup_expr="in", distinct=True
)
class Meta:
model = Provider
@@ -395,6 +431,16 @@ class ProviderRelationshipFilterSet(FilterSet):
provider_type__in = ChoiceInFilter(
choices=Provider.ProviderChoices.choices, field_name="provider__provider"
)
provider_groups = UUIDFilter(
field_name="provider__provider_groups__id",
lookup_expr="exact",
distinct=True,
)
provider_groups__in = UUIDInFilter(
field_name="provider__provider_groups__id",
lookup_expr="in",
distinct=True,
)
provider_uid = CharFilter(field_name="provider__uid", lookup_expr="exact")
provider_uid__in = CharInFilter(field_name="provider__uid", lookup_expr="in")
provider_uid__icontains = CharFilter(
@@ -1001,6 +1047,16 @@ class FindingGroupSummaryFilter(_CheckTitleToCheckIdMixin, FilterSet):
field_name="provider__provider", choices=Provider.ProviderChoices.choices
)
provider_type__in = CharInFilter(field_name="provider__provider", lookup_expr="in")
provider_groups = UUIDFilter(
field_name="provider__provider_groups__id",
lookup_expr="exact",
distinct=True,
)
provider_groups__in = UUIDInFilter(
field_name="provider__provider_groups__id",
lookup_expr="in",
distinct=True,
)
class Meta:
model = FindingGroupDailySummary
@@ -1101,6 +1157,16 @@ class LatestFindingGroupSummaryFilter(_CheckTitleToCheckIdMixin, FilterSet):
field_name="provider__provider", choices=Provider.ProviderChoices.choices
)
provider_type__in = CharInFilter(field_name="provider__provider", lookup_expr="in")
provider_groups = UUIDFilter(
field_name="provider__provider_groups__id",
lookup_expr="exact",
distinct=True,
)
provider_groups__in = UUIDInFilter(
field_name="provider__provider_groups__id",
lookup_expr="in",
distinct=True,
)
class Meta:
model = FindingGroupDailySummary
@@ -1280,12 +1346,19 @@ class RoleFilter(FilterSet):
}
class ComplianceOverviewFilter(FilterSet):
class ComplianceOverviewFilter(BaseScanProviderFilter):
"""
Keep provider filters in the schema while runtime filtering resolves scans first.
Compliance overview provider filters are applied to the latest completed scans
in the viewset, then this filterset handles the remaining compliance fields.
"""
inserted_at = DateFilter(field_name="inserted_at", lookup_expr="date")
scan_id = UUIDFilter(field_name="scan_id", required=True)
scan_id = UUIDFilter(field_name="scan_id")
region = CharFilter(field_name="region")
class Meta:
class Meta(BaseScanProviderFilter.Meta):
model = ComplianceRequirementOverview
fields = {
"inserted_at": ["date", "gte", "lte"],
@@ -1306,6 +1379,16 @@ class ScanSummaryFilter(FilterSet):
provider_type__in = ChoiceInFilter(
field_name="scan__provider__provider", choices=Provider.ProviderChoices.choices
)
provider_groups = UUIDFilter(
field_name="scan__provider__provider_groups__id",
lookup_expr="exact",
distinct=True,
)
provider_groups__in = UUIDInFilter(
field_name="scan__provider__provider_groups__id",
lookup_expr="in",
distinct=True,
)
region = CharFilter(field_name="region")
class Meta:
@@ -1329,6 +1412,16 @@ class DailySeveritySummaryFilter(FilterSet):
provider_type__in = ChoiceInFilter(
field_name="provider__provider", choices=Provider.ProviderChoices.choices
)
provider_groups = UUIDFilter(
field_name="provider__provider_groups__id",
lookup_expr="exact",
distinct=True,
)
provider_groups__in = UUIDInFilter(
field_name="provider__provider_groups__id",
lookup_expr="in",
distinct=True,
)
date_from = DateFilter(method="filter_noop")
date_to = DateFilter(method="filter_noop")
@@ -1585,6 +1678,16 @@ class ThreatScoreSnapshotFilter(FilterSet):
choices=Provider.ProviderChoices.choices,
lookup_expr="in",
)
provider_groups = UUIDFilter(
field_name="provider__provider_groups__id",
lookup_expr="exact",
distinct=True,
)
provider_groups__in = UUIDInFilter(
field_name="provider__provider_groups__id",
lookup_expr="in",
distinct=True,
)
compliance_id = CharFilter(field_name="compliance_id", lookup_expr="exact")
compliance_id__in = CharInFilter(field_name="compliance_id", lookup_expr="in")
@@ -1628,6 +1731,16 @@ class ResourceGroupOverviewFilter(FilterSet):
choices=Provider.ProviderChoices.choices,
lookup_expr="in",
)
provider_groups = UUIDFilter(
field_name="scan__provider__provider_groups__id",
lookup_expr="exact",
distinct=True,
)
provider_groups__in = UUIDInFilter(
field_name="scan__provider__provider_groups__id",
lookup_expr="in",
distinct=True,
)
resource_group = CharFilter(field_name="resource_group", lookup_expr="exact")
resource_group__in = CharInFilter(field_name="resource_group", lookup_expr="in")
+821 -7
View File
@@ -1411,6 +1411,42 @@ class TestProviderViewSet:
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
def test_providers_filter_provider_groups(
self,
authenticated_client,
tenants_fixture,
providers_fixture,
provider_groups_fixture,
):
tenant = tenants_fixture[0]
provider1, provider2, *_ = providers_fixture
group1, group2, *_ = provider_groups_fixture
ProviderGroupMembership.objects.create(
tenant=tenant, provider=provider1, provider_group=group1
)
ProviderGroupMembership.objects.create(
tenant=tenant, provider=provider1, provider_group=group2
)
ProviderGroupMembership.objects.create(
tenant=tenant, provider=provider2, provider_group=group2
)
response = authenticated_client.get(
reverse("provider-list"), {"filter[provider_groups]": str(group1.id)}
)
assert response.status_code == status.HTTP_200_OK
data = response.json()["data"]
assert [item["id"] for item in data] == [str(provider1.id)]
response = authenticated_client.get(
reverse("provider-list"),
{"filter[provider_groups__in]": f"{group1.id},{group2.id}"},
)
assert response.status_code == status.HTTP_200_OK
provider_ids = {item["id"] for item in response.json()["data"]}
assert provider_ids == {str(provider1.id), str(provider2.id)}
assert len(response.json()["data"]) == 2
def test_providers_disable_pagination(
self, authenticated_client, providers_fixture, tenants_fixture
):
@@ -3715,6 +3751,41 @@ class TestScanViewSet:
assert response.status_code == status.HTTP_200_OK
assert len(response.json()["data"]) == expected_count
def test_scans_filter_provider_groups(
self,
authenticated_client,
tenants_fixture,
scans_fixture,
provider_groups_fixture,
):
tenant = tenants_fixture[0]
scan1, scan2, *_ = scans_fixture
group1, group2, *_ = provider_groups_fixture
ProviderGroupMembership.objects.create(
tenant=tenant, provider=scan1.provider, provider_group=group1
)
ProviderGroupMembership.objects.create(
tenant=tenant, provider=scan1.provider, provider_group=group2
)
ProviderGroupMembership.objects.create(
tenant=tenant, provider=scan2.provider, provider_group=group2
)
response = authenticated_client.get(
reverse("scan-list"), {"filter[provider_groups]": str(group1.id)}
)
assert response.status_code == status.HTTP_200_OK
assert {item["id"] for item in response.json()["data"]} == {str(scan1.id)}
response = authenticated_client.get(
reverse("scan-list"),
{"filter[provider_groups__in]": f"{group1.id},{group2.id}"},
)
assert response.status_code == status.HTTP_200_OK
scan_ids = {item["id"] for item in response.json()["data"]}
assert scan_ids == {str(scan1.id), str(scan2.id), str(scans_fixture[2].id)}
assert len(response.json()["data"]) == 3
@pytest.mark.parametrize(
"filter_name",
[
@@ -5996,6 +6067,49 @@ class TestResourceViewSet:
assert response.status_code == status.HTTP_200_OK
assert len(response.json()["data"]) == expected_count
def test_resource_filter_provider_groups(
self,
authenticated_client,
tenants_fixture,
resources_fixture,
provider_groups_fixture,
):
tenant = tenants_fixture[0]
resource1, resource2, resource3, *_ = resources_fixture
group1, group2, *_ = provider_groups_fixture
ProviderGroupMembership.objects.create(
tenant=tenant, provider=resource1.provider, provider_group=group1
)
ProviderGroupMembership.objects.create(
tenant=tenant, provider=resource1.provider, provider_group=group2
)
ProviderGroupMembership.objects.create(
tenant=tenant, provider=resource3.provider, provider_group=group2
)
response = authenticated_client.get(
reverse("resource-list"),
{"filter[updated_at]": TODAY, "filter[provider_groups]": str(group1.id)},
)
assert response.status_code == status.HTTP_200_OK
assert len(response.json()["data"]) == 2
assert {item["id"] for item in response.json()["data"]} == {
str(resource1.id),
str(resource2.id),
}
response = authenticated_client.get(
reverse("resource-list"),
{
"filter[updated_at]": TODAY,
"filter[provider_groups__in]": f"{group1.id},{group2.id}",
},
)
assert response.status_code == status.HTTP_200_OK
resource_ids = {item["id"] for item in response.json()["data"]}
assert resource_ids == {str(resource1.id), str(resource2.id), str(resource3.id)}
assert len(response.json()["data"]) == 3
def test_resource_filter_by_scan_id(
self, authenticated_client, resources_fixture, scans_fixture
):
@@ -7308,6 +7422,40 @@ class TestFindingViewSet:
assert response.status_code == status.HTTP_200_OK
assert len(response.json()["data"]) == 2
def test_finding_filter_provider_groups(
self,
authenticated_client,
tenants_fixture,
findings_fixture,
provider_groups_fixture,
):
tenant = tenants_fixture[0]
finding1, finding2, *_ = findings_fixture
group1, group2, *_ = provider_groups_fixture
ProviderGroupMembership.objects.create(
tenant=tenant, provider=finding1.scan.provider, provider_group=group1
)
ProviderGroupMembership.objects.create(
tenant=tenant, provider=finding1.scan.provider, provider_group=group2
)
response = authenticated_client.get(
reverse("finding-list"),
{"filter[inserted_at]": TODAY, "filter[provider_groups]": str(group1.id)},
)
assert response.status_code == status.HTTP_200_OK
assert len(response.json()["data"]) == 2
response = authenticated_client.get(
reverse("finding-list"),
{
"filter[inserted_at]": TODAY,
"filter[provider_groups__in]": f"{group1.id},{group2.id}",
},
)
assert response.status_code == status.HTTP_200_OK
assert len(response.json()["data"]) == 2
@pytest.mark.parametrize(
"filter_name",
(
@@ -9278,6 +9426,118 @@ class TestComplianceOverviewViewSet:
with patch("api.v1.views.backfill_compliance_summaries_task.delay") as mock:
yield mock
def _create_completed_scan(self, provider, name):
return Scan.objects.create(
name=name,
provider=provider,
trigger=Scan.TriggerChoices.MANUAL,
state=StateChoices.COMPLETED,
tenant_id=provider.tenant_id,
started_at=datetime.now(timezone.utc),
completed_at=datetime.now(timezone.utc),
)
def _create_requirement(
self,
scan,
requirement_id,
status_choice,
region="eu-west-1",
compliance_id="cis_1.4_aws",
):
passed = 1 if status_choice == StatusChoices.PASS else 0
total = 1 if status_choice != StatusChoices.MANUAL else 0
return ComplianceRequirementOverview.objects.create(
tenant_id=scan.tenant_id,
scan=scan,
compliance_id=compliance_id,
framework="CIS-1.4-AWS",
version="1.4",
description="CIS AWS Foundations Benchmark v1.4.0",
region=region,
requirement_id=requirement_id,
requirement_status=status_choice,
passed_checks=passed,
failed_checks=0
if status_choice in (StatusChoices.PASS, StatusChoices.MANUAL)
else 1,
total_checks=total,
passed_findings=passed,
total_findings=total,
)
def _create_compliance_summary(
self,
scan,
*,
passed,
failed,
manual=0,
compliance_id="cis_1.4_aws",
):
return ComplianceOverviewSummary.objects.create(
tenant_id=scan.tenant_id,
scan=scan,
compliance_id=compliance_id,
requirements_passed=passed,
requirements_failed=failed,
requirements_manual=manual,
total_requirements=passed + failed + manual,
)
def _overview_attrs_by_id(self, response):
assert response.status_code == status.HTTP_200_OK
return {item["id"]: item["attributes"] for item in response.json()["data"]}
def _prepare_latest_compliance_data(self, providers_fixture):
provider1, provider2, provider3, *_ = providers_fixture
old_scan = self._create_completed_scan(provider1, "old aws compliance scan")
latest_scan1 = self._create_completed_scan(
provider1, "latest aws compliance scan 1"
)
latest_scan2 = self._create_completed_scan(
provider2, "latest aws compliance scan 2"
)
latest_gcp_scan = self._create_completed_scan(
provider3, "latest gcp compliance scan"
)
self._create_requirement(old_scan, "1.1", StatusChoices.FAIL)
self._create_requirement(old_scan, "1.2", StatusChoices.FAIL)
self._create_compliance_summary(old_scan, passed=0, failed=2)
self._create_requirement(
latest_scan1, "1.1", StatusChoices.PASS, region="eu-west-1"
)
self._create_requirement(
latest_scan1, "1.2", StatusChoices.PASS, region="eu-west-1"
)
self._create_compliance_summary(latest_scan1, passed=2, failed=0)
self._create_requirement(
latest_scan2, "1.1", StatusChoices.FAIL, region="us-east-1"
)
self._create_requirement(
latest_scan2, "1.2", StatusChoices.PASS, region="us-east-1"
)
self._create_compliance_summary(latest_scan2, passed=1, failed=1)
self._create_requirement(
latest_gcp_scan,
"gcp-1.1",
StatusChoices.FAIL,
region="europe-west1",
compliance_id="cis_1.3_gcp",
)
self._create_compliance_summary(
latest_gcp_scan,
passed=0,
failed=1,
compliance_id="cis_1.3_gcp",
)
return old_scan, latest_scan1, latest_scan2, latest_gcp_scan
def test_compliance_overview_list_none(
self,
authenticated_client,
@@ -9425,6 +9685,283 @@ class TestComplianceOverviewViewSet:
assert len(response.json()["data"]) >= 1
mock_backfill_task.assert_not_called()
def test_compliance_overview_provider_id_filter_uses_latest_scan(
self,
authenticated_client,
providers_fixture,
mock_backfill_task,
):
_, latest_scan, *_ = self._prepare_latest_compliance_data(providers_fixture)
response = authenticated_client.get(
reverse("complianceoverview-list"),
{"filter[provider_id]": str(latest_scan.provider_id)},
)
attrs_by_id = self._overview_attrs_by_id(response)
assert attrs_by_id["cis_1.4_aws"]["requirements_passed"] == 2
assert attrs_by_id["cis_1.4_aws"]["requirements_failed"] == 0
assert "cis_1.3_gcp" not in attrs_by_id
mock_backfill_task.assert_not_called()
def test_compliance_overview_provider_id_in_filter_aggregates_latest_scans(
self,
authenticated_client,
providers_fixture,
):
_, latest_scan1, latest_scan2, *_ = self._prepare_latest_compliance_data(
providers_fixture
)
response = authenticated_client.get(
reverse("complianceoverview-list"),
{
"filter[provider_id__in]": (
f"{latest_scan1.provider_id},{latest_scan2.provider_id}"
)
},
)
attrs_by_id = self._overview_attrs_by_id(response)
assert attrs_by_id["cis_1.4_aws"]["requirements_passed"] == 1
assert attrs_by_id["cis_1.4_aws"]["requirements_failed"] == 1
assert attrs_by_id["cis_1.4_aws"]["total_requirements"] == 2
assert "cis_1.3_gcp" not in attrs_by_id
def test_compliance_overview_provider_type_filter_uses_latest_scans(
self,
authenticated_client,
providers_fixture,
):
self._prepare_latest_compliance_data(providers_fixture)
response = authenticated_client.get(
reverse("complianceoverview-list"),
{"filter[provider_type]": Provider.ProviderChoices.AWS.value},
)
attrs_by_id = self._overview_attrs_by_id(response)
assert attrs_by_id["cis_1.4_aws"]["requirements_passed"] == 1
assert attrs_by_id["cis_1.4_aws"]["requirements_failed"] == 1
assert attrs_by_id["cis_1.4_aws"]["total_requirements"] == 2
assert "cis_1.3_gcp" not in attrs_by_id
def test_compliance_overview_provider_groups_filters_use_latest_scans(
self,
authenticated_client,
providers_fixture,
provider_groups_fixture,
tenants_fixture,
):
tenant = tenants_fixture[0]
provider1, provider2, *_ = providers_fixture
group1, group2, *_ = provider_groups_fixture
_, latest_scan1, latest_scan2, *_ = self._prepare_latest_compliance_data(
providers_fixture
)
ProviderGroupMembership.objects.create(
tenant_id=tenant.id,
provider=provider1,
provider_group=group1,
)
ProviderGroupMembership.objects.create(
tenant_id=tenant.id,
provider=provider2,
provider_group=group2,
)
response = authenticated_client.get(
reverse("complianceoverview-list"),
{"filter[provider_groups]": str(group1.id)},
)
attrs_by_id = self._overview_attrs_by_id(response)
assert attrs_by_id["cis_1.4_aws"]["requirements_passed"] == 2
assert attrs_by_id["cis_1.4_aws"]["requirements_failed"] == 0
response = authenticated_client.get(
reverse("complianceoverview-list"),
{"filter[provider_groups__in]": f"{group1.id},{group2.id}"},
)
attrs_by_id = self._overview_attrs_by_id(response)
assert attrs_by_id["cis_1.4_aws"]["requirements_passed"] == 1
assert attrs_by_id["cis_1.4_aws"]["requirements_failed"] == 1
assert attrs_by_id["cis_1.4_aws"]["total_requirements"] == 2
def _assert_latest_provider_scan_task_response(
self,
authenticated_client,
endpoint,
scan,
query_params=None,
):
query_params = {**(query_params or {})}
if not any(key.startswith("filter[provider_") for key in query_params):
query_params = {
"filter[provider_id]": str(scan.provider_id),
**query_params,
}
with patch.object(
ComplianceOverviewViewSet, "get_task_response_if_running"
) as mock_task_response:
mock_task_response.return_value = Response(
{"detail": "Task is running"}, status=status.HTTP_202_ACCEPTED
)
response = authenticated_client.get(reverse(endpoint), query_params)
assert response.status_code == status.HTTP_202_ACCEPTED
mock_task_response.assert_called_once()
_, kwargs = mock_task_response.call_args
assert kwargs["task_name"] == "scan-compliance-overviews"
assert str(kwargs["task_kwargs"]["tenant_id"]) == str(scan.tenant_id)
assert str(kwargs["task_kwargs"]["scan_id"]) == str(scan.id)
assert kwargs["raise_on_not_found"] is False
def test_compliance_overview_provider_filter_returns_running_task_without_data(
self,
authenticated_client,
providers_fixture,
):
scan = self._create_completed_scan(
providers_fixture[0], "latest scan without compliance data"
)
self._assert_latest_provider_scan_task_response(
authenticated_client,
"complianceoverview-list",
scan,
)
def test_compliance_overview_provider_filter_returns_running_task_for_partial_data(
self,
authenticated_client,
providers_fixture,
):
provider_with_data, provider_without_data, *_ = providers_fixture
scan_with_data = self._create_completed_scan(
provider_with_data, "latest scan with compliance data"
)
scan_without_data = self._create_completed_scan(
provider_without_data, "latest scan without partial compliance data"
)
self._create_requirement(scan_with_data, "1.1", StatusChoices.PASS)
self._assert_latest_provider_scan_task_response(
authenticated_client,
"complianceoverview-list",
scan_without_data,
{
"filter[provider_id__in]": (
f"{provider_with_data.id},{provider_without_data.id}"
)
},
)
def test_compliance_overview_provider_filter_empty_response_uses_scan_data_presence(
self,
authenticated_client,
providers_fixture,
):
scan = self._create_completed_scan(
providers_fixture[0], "latest scan with filtered compliance data"
)
self._create_requirement(scan, "1.1", StatusChoices.PASS, region="eu-west-1")
with patch.object(
ComplianceOverviewViewSet, "get_task_response_if_running"
) as mock_task_response:
mock_task_response.return_value = Response(
{"detail": "Task is running"}, status=status.HTTP_202_ACCEPTED
)
response = authenticated_client.get(
reverse("complianceoverview-list"),
{
"filter[provider_id]": str(scan.provider_id),
"filter[region]": "us-east-1",
},
)
assert response.status_code == status.HTTP_200_OK
assert response.json()["data"] == []
mock_task_response.assert_not_called()
def test_compliance_overview_metadata_provider_filter_returns_running_task_without_data(
self,
authenticated_client,
providers_fixture,
):
scan = self._create_completed_scan(
providers_fixture[0], "latest scan without compliance metadata"
)
self._assert_latest_provider_scan_task_response(
authenticated_client,
"complianceoverview-metadata",
scan,
)
def test_compliance_overview_requirements_provider_filter_returns_running_task_without_data(
self,
authenticated_client,
providers_fixture,
):
scan = self._create_completed_scan(
providers_fixture[0], "latest scan without compliance requirements"
)
self._assert_latest_provider_scan_task_response(
authenticated_client,
"complianceoverview-requirements",
scan,
{"filter[compliance_id]": "cis_1.4_aws"},
)
def test_compliance_overview_metadata_accepts_provider_filters(
self,
authenticated_client,
providers_fixture,
):
_, latest_scan, *_ = self._prepare_latest_compliance_data(providers_fixture)
response = authenticated_client.get(
reverse("complianceoverview-metadata"),
{"filter[provider_id]": str(latest_scan.provider_id)},
)
assert response.status_code == status.HTTP_200_OK
regions = response.json()["data"]["attributes"]["regions"]
assert regions == ["eu-west-1"]
def test_compliance_overview_requirements_accepts_provider_filters(
self,
authenticated_client,
providers_fixture,
):
_, latest_scan1, latest_scan2, *_ = self._prepare_latest_compliance_data(
providers_fixture
)
response = authenticated_client.get(
reverse("complianceoverview-requirements"),
{
"filter[provider_id__in]": (
f"{latest_scan1.provider_id},{latest_scan2.provider_id}"
),
"filter[compliance_id]": "cis_1.4_aws",
},
)
assert response.status_code == status.HTTP_200_OK
requirements_by_id = {
item["id"]: item["attributes"] for item in response.json()["data"]
}
assert requirements_by_id["1.1"]["status"] == "FAIL"
assert requirements_by_id["1.2"]["status"] == "PASS"
def test_compliance_overview_metadata(
self, authenticated_client, compliance_requirements_overviews_fixture
):
@@ -10031,6 +10568,40 @@ class TestOverviewViewSet:
for entry in grouped_data:
assert "findings" not in entry["attributes"]
def test_overview_providers_count_applies_limited_visibility(
self,
authenticated_client_no_permissions_rbac,
providers_fixture,
provider_groups_fixture,
tenants_fixture,
):
tenant = tenants_fixture[0]
client = authenticated_client_no_permissions_rbac
allowed_provider = providers_fixture[2]
denied_provider = providers_fixture[4]
provider_group = provider_groups_fixture[0]
ProviderGroupMembership.objects.create(
tenant_id=tenant.id,
provider_group=provider_group,
provider=allowed_provider,
)
RoleProviderGroupRelationship.objects.create(
tenant_id=tenant.id,
role=client.user.roles.first(),
provider_group=provider_group,
)
response = client.get(reverse("overview-providers-count"))
assert response.status_code == status.HTTP_200_OK
aggregated = {
entry["id"]: entry["attributes"]["count"]
for entry in response.json()["data"]
}
assert aggregated == {allowed_provider.provider: 1}
assert denied_provider.provider not in aggregated
def _create_scan(self, tenant, provider, name, started_at=None):
scan_started = started_at or datetime.now(timezone.utc) - timedelta(hours=1)
return Scan.objects.create(
@@ -10578,6 +11149,87 @@ class TestOverviewViewSet:
assert combined_attributes["muted"] == 3
assert combined_attributes["total"] == 14
def test_overview_findings_provider_groups_filter(
self,
authenticated_client,
tenants_fixture,
providers_fixture,
provider_groups_fixture,
):
tenant = tenants_fixture[0]
provider1, provider2, *_ = providers_fixture
group1, group2, *_ = provider_groups_fixture
ProviderGroupMembership.objects.create(
tenant=tenant, provider=provider1, provider_group=group1
)
ProviderGroupMembership.objects.create(
tenant=tenant, provider=provider1, provider_group=group2
)
ProviderGroupMembership.objects.create(
tenant=tenant, provider=provider2, provider_group=group2
)
scan1 = Scan.objects.create(
name="scan-provider-group-one",
provider=provider1,
trigger=Scan.TriggerChoices.MANUAL,
state=StateChoices.COMPLETED,
tenant=tenant,
)
scan2 = Scan.objects.create(
name="scan-provider-group-two",
provider=provider2,
trigger=Scan.TriggerChoices.MANUAL,
state=StateChoices.COMPLETED,
tenant=tenant,
)
ScanSummary.objects.create(
tenant=tenant,
scan=scan1,
check_id="check-provider-group-one",
service="service-a",
severity="high",
region="region-a",
_pass=5,
fail=1,
muted=2,
total=8,
)
ScanSummary.objects.create(
tenant=tenant,
scan=scan2,
check_id="check-provider-group-two",
service="service-b",
severity="medium",
region="region-b",
_pass=2,
fail=3,
muted=1,
total=6,
)
response = authenticated_client.get(
reverse("overview-findings"),
{"filter[provider_groups]": str(group1.id)},
)
assert response.status_code == status.HTTP_200_OK
attributes = response.json()["data"]["attributes"]
assert attributes["pass"] == 5
assert attributes["fail"] == 1
assert attributes["muted"] == 2
assert attributes["total"] == 8
response = authenticated_client.get(
reverse("overview-findings"),
{"filter[provider_groups__in]": f"{group1.id},{group2.id}"},
)
assert response.status_code == status.HTTP_200_OK
attributes = response.json()["data"]["attributes"]
assert attributes["pass"] == 7
assert attributes["fail"] == 4
assert attributes["muted"] == 3
assert attributes["total"] == 14
def test_overview_findings_severity_provider_id_in_filter(
self, authenticated_client, tenants_fixture, providers_fixture
):
@@ -11346,9 +11998,21 @@ class TestOverviewViewSet:
@pytest.mark.parametrize(
"filter_key,filter_value_fn,expected_total,expected_failed",
[
("filter[provider_id]", lambda p1, _: str(p1.id), 10, 5),
("filter[provider_id]", lambda p1, *_: str(p1.id), 10, 5),
("filter[provider_type]", lambda *_: "aws", 10, 5),
("filter[provider_type__in]", lambda *_: "aws,gcp", 30, 20),
(
"filter[provider_groups]",
lambda p1, _, group1, __: str(group1.id),
10,
5,
),
(
"filter[provider_groups__in]",
lambda p1, _, group1, group2: f"{group1.id},{group2.id}",
30,
20,
),
],
)
def test_overview_categories_filters(
@@ -11356,6 +12020,7 @@ class TestOverviewViewSet:
authenticated_client,
tenants_fixture,
providers_fixture,
provider_groups_fixture,
create_scan_category_summary,
filter_key,
filter_value_fn,
@@ -11364,6 +12029,16 @@ class TestOverviewViewSet:
):
tenant = tenants_fixture[0]
provider1, _, gcp_provider, *_ = providers_fixture
group1, group2, *_ = provider_groups_fixture
ProviderGroupMembership.objects.create(
tenant=tenant, provider=provider1, provider_group=group1
)
ProviderGroupMembership.objects.create(
tenant=tenant, provider=provider1, provider_group=group2
)
ProviderGroupMembership.objects.create(
tenant=tenant, provider=gcp_provider, provider_group=group2
)
scan1 = Scan.objects.create(
name="categories-scan-1",
@@ -11389,7 +12064,7 @@ class TestOverviewViewSet:
response = authenticated_client.get(
reverse("overview-categories"),
{filter_key: filter_value_fn(provider1, gcp_provider)},
{filter_key: filter_value_fn(provider1, gcp_provider, group1, group2)},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()["data"]
@@ -11563,10 +12238,22 @@ class TestOverviewViewSet:
@pytest.mark.parametrize(
"filter_key,filter_value_fn,expected_total,expected_failed",
[
("filter[provider_id]", lambda p1, p2: str(p1.id), 10, 5),
("filter[provider_id__in]", lambda p1, p2: f"{p1.id},{p2.id}", 25, 12),
("filter[provider_type]", lambda p1, p2: "aws", 10, 5),
("filter[provider_type__in]", lambda p1, p2: "aws,gcp", 25, 12),
("filter[provider_id]", lambda p1, *_: str(p1.id), 10, 5),
("filter[provider_id__in]", lambda p1, p2, *_: f"{p1.id},{p2.id}", 25, 12),
("filter[provider_type]", lambda *_: "aws", 10, 5),
("filter[provider_type__in]", lambda *_: "aws,gcp", 25, 12),
(
"filter[provider_groups]",
lambda p1, p2, group1, group2: str(group1.id),
10,
5,
),
(
"filter[provider_groups__in]",
lambda p1, p2, group1, group2: f"{group1.id},{group2.id}",
25,
12,
),
],
)
def test_overview_groups_provider_filters(
@@ -11574,6 +12261,7 @@ class TestOverviewViewSet:
authenticated_client,
tenants_fixture,
providers_fixture,
provider_groups_fixture,
create_scan_resource_group_summary,
filter_key,
filter_value_fn,
@@ -11583,6 +12271,16 @@ class TestOverviewViewSet:
tenant = tenants_fixture[0]
provider1 = providers_fixture[0] # AWS
gcp_provider = providers_fixture[2] # GCP
group1, group2, *_ = provider_groups_fixture
ProviderGroupMembership.objects.create(
tenant=tenant, provider=provider1, provider_group=group1
)
ProviderGroupMembership.objects.create(
tenant=tenant, provider=provider1, provider_group=group2
)
ProviderGroupMembership.objects.create(
tenant=tenant, provider=gcp_provider, provider_group=group2
)
scan1 = Scan.objects.create(
name="aws-rg-scan",
@@ -11608,7 +12306,7 @@ class TestOverviewViewSet:
response = authenticated_client.get(
reverse("overview-resource-groups"),
{filter_key: filter_value_fn(provider1, gcp_provider)},
{filter_key: filter_value_fn(provider1, gcp_provider, group1, group2)},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()["data"]
@@ -11783,6 +12481,49 @@ class TestOverviewViewSet:
data = response.json()["data"]
assert len(data) >= 1
def test_compliance_watchlist_provider_groups_filter(
self,
authenticated_client,
provider_compliance_scores_fixture,
providers_fixture,
provider_groups_fixture,
tenants_fixture,
):
tenant = tenants_fixture[0]
provider1, provider2, *_ = providers_fixture
group1, group2, *_ = provider_groups_fixture
ProviderGroupMembership.objects.create(
tenant=tenant, provider=provider1, provider_group=group1
)
ProviderGroupMembership.objects.create(
tenant=tenant, provider=provider1, provider_group=group2
)
ProviderGroupMembership.objects.create(
tenant=tenant, provider=provider2, provider_group=group2
)
response = authenticated_client.get(
reverse("overview-compliance-watchlist"),
{"filter[provider_groups]": str(group1.id)},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()["data"]
by_id = {item["id"]: item["attributes"] for item in data}
assert by_id["aws_cis_2.0"]["requirements_passed"] == 1
assert by_id["aws_cis_2.0"]["requirements_failed"] == 1
assert by_id["aws_cis_2.0"]["requirements_manual"] == 1
response = authenticated_client.get(
reverse("overview-compliance-watchlist"),
{"filter[provider_groups__in]": f"{group1.id},{group2.id}"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()["data"]
by_id = {item["id"]: item["attributes"] for item in data}
assert by_id["aws_cis_2.0"]["requirements_passed"] == 0
assert by_id["aws_cis_2.0"]["requirements_failed"] == 2
assert by_id["aws_cis_2.0"]["requirements_manual"] == 1
def test_compliance_watchlist_empty_result(self, authenticated_client):
response = authenticated_client.get(reverse("overview-compliance-watchlist"))
assert response.status_code == status.HTTP_200_OK
@@ -17082,6 +17823,44 @@ class TestFindingGroupViewSet:
# All fixture findings are from AWS provider
assert len(response.json()["data"]) == 5
def test_finding_groups_provider_groups_filter(
self,
authenticated_client,
tenants_fixture,
finding_groups_fixture,
providers_fixture,
provider_groups_fixture,
):
tenant = tenants_fixture[0]
provider1, provider2, *_ = providers_fixture
group1, group2, *_ = provider_groups_fixture
ProviderGroupMembership.objects.create(
tenant=tenant, provider=provider1, provider_group=group1
)
ProviderGroupMembership.objects.create(
tenant=tenant, provider=provider1, provider_group=group2
)
ProviderGroupMembership.objects.create(
tenant=tenant, provider=provider2, provider_group=group2
)
response = authenticated_client.get(
reverse("finding-group-list"),
{"filter[inserted_at]": TODAY, "filter[provider_groups]": str(group1.id)},
)
assert response.status_code == status.HTTP_200_OK
assert len(response.json()["data"]) == 4
response = authenticated_client.get(
reverse("finding-group-list"),
{
"filter[inserted_at]": TODAY,
"filter[provider_groups__in]": f"{group1.id},{group2.id}",
},
)
assert response.status_code == status.HTTP_200_OK
assert len(response.json()["data"]) == 5
def test_finding_groups_check_id_filter(
self, authenticated_client, finding_groups_fixture
):
@@ -17992,6 +18771,41 @@ class TestFindingGroupViewSet:
# All providers in fixture are AWS
assert len(data) == 5
def test_finding_groups_latest_provider_groups_filter(
self,
authenticated_client,
tenants_fixture,
finding_groups_fixture,
providers_fixture,
provider_groups_fixture,
):
tenant = tenants_fixture[0]
provider1, provider2, *_ = providers_fixture
group1, group2, *_ = provider_groups_fixture
ProviderGroupMembership.objects.create(
tenant=tenant, provider=provider1, provider_group=group1
)
ProviderGroupMembership.objects.create(
tenant=tenant, provider=provider1, provider_group=group2
)
ProviderGroupMembership.objects.create(
tenant=tenant, provider=provider2, provider_group=group2
)
response = authenticated_client.get(
reverse("finding-group-latest"),
{"filter[provider_groups]": str(group1.id)},
)
assert response.status_code == status.HTTP_200_OK
assert len(response.json()["data"]) == 4
response = authenticated_client.get(
reverse("finding-group-latest"),
{"filter[provider_groups__in]": f"{group1.id},{group2.id}"},
)
assert response.status_code == status.HTTP_200_OK
assert len(response.json()["data"]) == 5
def test_finding_groups_latest_check_id_filter(
self, authenticated_client, finding_groups_fixture
):
+161 -1
View File
@@ -1,6 +1,10 @@
import uuid
from django.http import QueryDict
from django.urls import reverse
from django_celery_results.models import TaskResult
from rest_framework import status
from rest_framework.exceptions import ValidationError
from rest_framework.response import Response
from api.exceptions import (
@@ -8,7 +12,7 @@ from api.exceptions import (
TaskInProgressException,
TaskNotFoundException,
)
from api.models import StateChoices, Task
from api.models import Provider, StateChoices, Task
from api.v1.serializers import TaskSerializer
@@ -74,6 +78,162 @@ class PaginateByPkMixin:
return self.get_paginated_response(serialized)
class JsonApiFilterMixin:
"""Shared helpers for manually applying django-filter to JSON:API params."""
jsonapi_filter_replace_dots = False
def _normalize_jsonapi_params(
self,
query_params,
exclude_keys=None,
replace_dots=None,
):
exclude_keys = exclude_keys or set()
if replace_dots is None:
replace_dots = self.jsonapi_filter_replace_dots
normalized = QueryDict(mutable=True)
for key, values in query_params.lists():
normalized_key = (
key[7:-1] if key.startswith("filter[") and key.endswith("]") else key
)
if replace_dots:
normalized_key = normalized_key.replace(".", "__")
if normalized_key not in exclude_keys:
normalized.setlist(normalized_key, values)
return normalized
def _apply_filterset(
self,
queryset,
filterset_class,
exclude_keys=None,
replace_dots=None,
):
normalized_params = self._normalize_jsonapi_params(
self.request.query_params,
exclude_keys=set(exclude_keys or []),
replace_dots=replace_dots,
)
filterset = filterset_class(normalized_params, queryset=queryset)
if not filterset.is_valid():
raise ValidationError(filterset.errors)
return filterset.qs
class ProviderFilterParamsMixin(JsonApiFilterMixin):
"""Shared extraction of provider filters from JSON:API query params."""
PROVIDER_FILTER_KEYS = frozenset(
{
"provider_id",
"provider_id__in",
"provider_type",
"provider_type__in",
"provider_groups",
"provider_groups__in",
}
)
PROVIDER_FILTER_DOT_ALIAS_KEYS = frozenset(
{
"provider_id.in",
"provider_type.in",
"provider_groups.in",
}
)
PROVIDER_FILTER_QUERY_KEYS = PROVIDER_FILTER_KEYS | PROVIDER_FILTER_DOT_ALIAS_KEYS
def _csv_filter_values(self, value):
return [item.strip() for item in value.split(",") if item.strip()]
def _validate_uuid_filter_values(self, field_name, values):
try:
for value in values:
uuid.UUID(str(value))
except (TypeError, ValueError, AttributeError):
raise ValidationError({field_name: ["Enter a valid UUID."]})
def _has_provider_filters(self, include_dot_aliases=False):
provider_filter_keys = (
self.PROVIDER_FILTER_QUERY_KEYS
if include_dot_aliases
else self.PROVIDER_FILTER_KEYS
)
return any(
self.request.query_params.get(f"filter[{key}]")
for key in provider_filter_keys
)
def _extract_provider_filters_from_params(
self,
*,
validate_uuids=False,
include_dot_aliases=False,
):
params = self.request.query_params
filters = {}
valid_provider_types = {
choice[0] for choice in Provider.ProviderChoices.choices
}
provider_id = params.get("filter[provider_id]")
if provider_id:
if validate_uuids:
self._validate_uuid_filter_values("provider_id", [provider_id])
filters["provider_id"] = provider_id
provider_id_in = params.get("filter[provider_id__in]")
if include_dot_aliases:
provider_id_in = provider_id_in or params.get("filter[provider_id.in]")
if provider_id_in:
values = self._csv_filter_values(provider_id_in)
if validate_uuids:
self._validate_uuid_filter_values("provider_id__in", values)
filters["provider_id__in"] = values
provider_type = params.get("filter[provider_type]")
if provider_type:
if provider_type not in valid_provider_types:
raise ValidationError(
{"provider_type": f"Invalid choice: {provider_type}"}
)
filters["provider__provider"] = provider_type
provider_type_in = params.get("filter[provider_type__in]")
if include_dot_aliases:
provider_type_in = provider_type_in or params.get(
"filter[provider_type.in]"
)
if provider_type_in:
values = self._csv_filter_values(provider_type_in)
invalid = [value for value in values if value not in valid_provider_types]
if invalid:
raise ValidationError(
{"provider_type__in": f"Invalid choices: {', '.join(invalid)}"}
)
filters["provider__provider__in"] = values
provider_groups = params.get("filter[provider_groups]")
if provider_groups:
if validate_uuids:
self._validate_uuid_filter_values("provider_groups", [provider_groups])
filters["provider__provider_groups__id"] = provider_groups
provider_groups_in = params.get("filter[provider_groups__in]")
if include_dot_aliases:
provider_groups_in = provider_groups_in or params.get(
"filter[provider_groups.in]"
)
if provider_groups_in:
values = self._csv_filter_values(provider_groups_in)
if validate_uuids:
self._validate_uuid_filter_values("provider_groups__in", values)
filters["provider__provider_groups__id__in"] = values
return filters
class TaskManagementMixin:
"""
Mixin to manage task status checking.
+236 -150
View File
@@ -228,7 +228,13 @@ from api.utils import (
validate_invitation,
)
from api.uuid_utils import datetime_to_uuid7, uuid7_start
from api.v1.mixins import DisablePaginationMixin, PaginateByPkMixin, TaskManagementMixin
from api.v1.mixins import (
DisablePaginationMixin,
JsonApiFilterMixin,
PaginateByPkMixin,
ProviderFilterParamsMixin,
TaskManagementMixin,
)
from api.v1.serializers import (
AttackPathsCartographySchemaSerializer,
AttackPathsCustomQueryRunRequestSerializer,
@@ -4556,15 +4562,19 @@ class RoleProviderGroupRelationshipView(RelationshipView, BaseRLSViewSet):
@extend_schema_view(
list=extend_schema(
tags=["Compliance Overview"],
summary="List compliance overviews for a scan",
description="Retrieve an overview of all the compliance in a given scan.",
summary="List compliance overviews",
description=(
"Retrieve compliance overview data for a scan. When provider filters "
"are provided, the endpoint uses the latest completed scan for each "
"matching provider."
),
parameters=[
OpenApiParameter(
name="filter[scan_id]",
required=True,
required=False,
type=OpenApiTypes.UUID,
location=OpenApiParameter.QUERY,
description="Related scan ID.",
description="Related scan ID. Required unless a provider filter is provided.",
),
],
responses={
@@ -4579,19 +4589,23 @@ class RoleProviderGroupRelationshipView(RelationshipView, BaseRLSViewSet):
description="Compliance overviews generation task failed"
),
},
filters=True,
),
metadata=extend_schema(
tags=["Compliance Overview"],
summary="Retrieve metadata values from compliance overviews",
description="Fetch unique metadata values from a set of compliance overviews. This is useful for dynamic "
"filtering.",
description=(
"Fetch unique metadata values from compliance overviews. This is useful "
"for dynamic filtering. When provider filters are provided, metadata is "
"computed from the latest completed scan for each matching provider."
),
parameters=[
OpenApiParameter(
name="filter[scan_id]",
required=True,
required=False,
type=OpenApiTypes.UUID,
location=OpenApiParameter.QUERY,
description="Related scan ID.",
description="Related scan ID. Required unless a provider filter is provided.",
),
],
responses={
@@ -4606,19 +4620,24 @@ class RoleProviderGroupRelationshipView(RelationshipView, BaseRLSViewSet):
description="Compliance overviews generation task failed"
),
},
filters=True,
),
requirements=extend_schema(
tags=["Compliance Overview"],
summary="List compliance requirements overview for a scan",
description="Retrieve a detailed overview of compliance requirements in a given scan, grouped by compliance "
"framework. This endpoint provides requirement-level details and aggregates status across regions.",
summary="List compliance requirements overview",
description=(
"Retrieve a detailed overview of compliance requirements, grouped by "
"compliance framework. This endpoint provides requirement-level details "
"and aggregates status across regions. When provider filters are provided, "
"the endpoint uses the latest completed scan for each matching provider."
),
parameters=[
OpenApiParameter(
name="filter[scan_id]",
required=True,
required=False,
type=OpenApiTypes.UUID,
location=OpenApiParameter.QUERY,
description="Related scan ID.",
description="Related scan ID. Required unless a provider filter is provided.",
),
OpenApiParameter(
name="filter[compliance_id]",
@@ -4677,7 +4696,10 @@ class RoleProviderGroupRelationshipView(RelationshipView, BaseRLSViewSet):
@method_decorator(CACHE_DECORATOR, name="list")
@method_decorator(CACHE_DECORATOR, name="requirements")
@method_decorator(CACHE_DECORATOR, name="attributes")
class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
class ComplianceOverviewViewSet(
ProviderFilterParamsMixin, BaseRLSViewSet, TaskManagementMixin
):
jsonapi_filter_replace_dots = True
pagination_class = ComplianceOverviewPagination
queryset = ComplianceRequirementOverview.objects.all()
serializer_class = ComplianceOverviewSerializer
@@ -4691,28 +4713,22 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
required_permissions = []
def get_queryset(self):
if getattr(self, "swagger_fake_view", False):
return ComplianceRequirementOverview.objects.none()
role = get_role(self.request.user, self.request.tenant_id)
unlimited_visibility = getattr(
role, Permissions.UNLIMITED_VISIBILITY.value, False
)
if unlimited_visibility:
base_queryset = self.filter_queryset(
ComplianceRequirementOverview.objects.filter(
tenant_id=self.request.tenant_id
)
)
else:
providers = Provider.objects.filter(
provider_groups__in=role.provider_groups.all()
).distinct()
base_queryset = self.filter_queryset(
ComplianceRequirementOverview.objects.filter(
tenant_id=self.request.tenant_id, scan__provider__in=providers
)
)
base_queryset = ComplianceRequirementOverview.objects.filter(
tenant_id=self.request.tenant_id
)
return base_queryset
if unlimited_visibility:
return base_queryset
return base_queryset.filter(scan__provider__in=get_providers(role))
def get_serializer_class(self):
if hasattr(self, "response_serializer_class"):
@@ -4750,6 +4766,72 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
return summaries
def _validate_scan_selection(self, scan_id, has_provider_filters):
if scan_id and has_provider_filters:
raise ValidationError(
[
{
"detail": "Use either filter[scan_id] or provider filters.",
"status": 400,
"source": {"pointer": "filter[scan_id]"},
"code": "invalid",
}
]
)
if scan_id:
self._validate_uuid_filter_values("scan_id", [scan_id])
return
if has_provider_filters:
return
raise ValidationError(
[
{
"detail": "This query parameter is required unless a provider filter is provided.",
"status": 400,
"source": {"pointer": "filter[scan_id]"},
"code": "required",
}
]
)
def _latest_scan_ids_for_provider_filters(self):
role = get_role(self.request.user, self.request.tenant_id)
scans = Scan.all_objects.filter(
tenant_id=self.request.tenant_id,
state=StateChoices.COMPLETED,
)
if not getattr(role, Permissions.UNLIMITED_VISIBILITY.value, False):
scans = scans.filter(provider__in=get_providers(role))
provider_filters = self._extract_provider_filters_from_params(
validate_uuids=True,
include_dot_aliases=True,
)
if provider_filters:
scans = scans.filter(**provider_filters)
return list(
scans.order_by("provider_id", "-inserted_at")
.distinct("provider_id")
.values_list("id", flat=True)
)
def _filtered_queryset_for_latest_provider_scans(self, latest_scan_ids=None):
if latest_scan_ids is None:
latest_scan_ids = self._latest_scan_ids_for_provider_filters()
queryset = self.get_queryset().filter(scan_id__in=latest_scan_ids)
# Provider filters stay on the filterset for OpenAPI docs, but runtime
# filtering happens on Scan first so compliance queries use scan IDs.
return self._apply_filterset(
queryset,
self.filterset_class,
exclude_keys=self.PROVIDER_FILTER_KEYS | {"scan_id"},
)
def _get_compliance_template(self, *, provider=None, scan_id=None):
"""Return the compliance template for the given provider or scan."""
if provider is None and scan_id is not None:
@@ -4865,6 +4947,36 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
def _task_response_for_latest_provider_scans(self, latest_scan_ids):
for scan_id in latest_scan_ids:
task_response = self._task_response_if_running(str(scan_id))
if task_response:
return task_response
return None
def _latest_provider_scan_ids_without_data(self, latest_scan_ids):
data_presence_queryset = self.get_queryset().filter(scan_id__in=latest_scan_ids)
scan_ids_with_data = {
str(scan_id)
for scan_id in data_presence_queryset.values_list(
"scan_id", flat=True
).distinct()
}
return [
scan_id
for scan_id in latest_scan_ids
if str(scan_id) not in scan_ids_with_data
]
def _task_response_for_latest_provider_scans_without_data(
self,
latest_scan_ids,
):
scan_ids_to_check = self._latest_provider_scan_ids_without_data(
latest_scan_ids,
)
return self._task_response_for_latest_provider_scans(scan_ids_to_check)
def _list_with_region_filter(self, scan_id, region_filter):
"""
Fall back to detailed ComplianceRequirementOverview query when region filter is applied.
@@ -4905,8 +5017,25 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
return Response(data)
def _list_with_latest_provider_filters(self):
latest_scan_ids = self._latest_scan_ids_for_provider_filters()
queryset = self._filtered_queryset_for_latest_provider_scans(latest_scan_ids)
data = self._aggregate_compliance_overview(queryset)
task_response = self._task_response_for_latest_provider_scans_without_data(
latest_scan_ids,
)
if task_response:
return task_response
return Response(data)
def list(self, request, *args, **kwargs):
scan_id = request.query_params.get("filter[scan_id]")
has_provider_filters = self._has_provider_filters(include_dot_aliases=True)
self._validate_scan_selection(scan_id, has_provider_filters)
if has_provider_filters:
return self._list_with_latest_provider_filters()
# Specific scan requested - use optimized summaries with region support
region_filter = request.query_params.get(
@@ -4952,33 +5081,34 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
@action(detail=False, methods=["get"], url_name="metadata")
def metadata(self, request):
scan_id = request.query_params.get("filter[scan_id]")
if not scan_id:
raise ValidationError(
[
{
"detail": "This query parameter is required.",
"status": 400,
"source": {"pointer": "filter[scan_id]"},
"code": "required",
}
]
has_provider_filters = self._has_provider_filters(include_dot_aliases=True)
self._validate_scan_selection(scan_id, has_provider_filters)
latest_scan_ids = None
if has_provider_filters:
latest_scan_ids = self._latest_scan_ids_for_provider_filters()
queryset = self._filtered_queryset_for_latest_provider_scans(
latest_scan_ids
)
else:
queryset = self._apply_filterset(self.get_queryset(), self.filterset_class)
regions = list(
self.get_queryset()
.filter(scan_id=scan_id)
.values_list("region", flat=True)
.order_by("region")
.distinct()
queryset.values_list("region", flat=True).order_by("region").distinct()
)
result = {"regions": regions}
if regions:
serializer = self.get_serializer(data=result)
serializer.is_valid(raise_exception=True)
return Response(serializer.data, status=status.HTTP_200_OK)
task_response = None
if has_provider_filters:
task_response = self._task_response_for_latest_provider_scans_without_data(
latest_scan_ids,
)
elif not regions:
task_response = self._task_response_if_running(scan_id)
if task_response:
return task_response
task_response = self._task_response_if_running(scan_id)
if task_response:
if has_provider_filters and task_response:
return task_response
serializer = self.get_serializer(data=result)
@@ -4988,19 +5118,10 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
@action(detail=False, methods=["get"], url_name="requirements")
def requirements(self, request):
scan_id = request.query_params.get("filter[scan_id]")
has_provider_filters = self._has_provider_filters(include_dot_aliases=True)
compliance_id = request.query_params.get("filter[compliance_id]")
if not scan_id:
raise ValidationError(
[
{
"detail": "This query parameter is required.",
"status": 400,
"source": {"pointer": "filter[scan_id]"},
"code": "required",
}
]
)
self._validate_scan_selection(scan_id, has_provider_filters)
if not compliance_id:
raise ValidationError(
@@ -5013,7 +5134,16 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
}
]
)
filtered_queryset = self.filter_queryset(self.get_queryset())
latest_scan_ids = None
if has_provider_filters:
latest_scan_ids = self._latest_scan_ids_for_provider_filters()
filtered_queryset = self._filtered_queryset_for_latest_provider_scans(
latest_scan_ids
)
else:
filtered_queryset = self._apply_filterset(
self.get_queryset(), self.filterset_class
)
all_requirements = filtered_queryset.values(
"requirement_id",
@@ -5072,13 +5202,22 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
requirements_summary, many=True
)
task_response = None
if has_provider_filters:
task_response = self._task_response_for_latest_provider_scans_without_data(
latest_scan_ids,
)
elif not requirements_summary:
task_response = self._task_response_if_running(scan_id)
if task_response:
return task_response
if has_provider_filters and task_response:
return task_response
if requirements_summary:
return Response(serializer.data, status=status.HTTP_200_OK)
task_response = self._task_response_if_running(scan_id)
if task_response:
return task_response
return Response(serializer.data, status=status.HTTP_200_OK)
@action(detail=False, methods=["get"], url_name="attributes")
@@ -5317,7 +5456,7 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
),
)
@method_decorator(CACHE_DECORATOR, name="list")
class OverviewViewSet(BaseRLSViewSet):
class OverviewViewSet(ProviderFilterParamsMixin, BaseRLSViewSet):
queryset = ScanSummary.objects.all()
http_method_names = ["get"]
ordering = ["-inserted_at"]
@@ -5434,18 +5573,6 @@ class OverviewViewSet(BaseRLSViewSet):
tenant_id=tenant_id, scan_id__in=latest_scan_ids
)
def _normalize_jsonapi_params(self, query_params, exclude_keys=None):
"""Convert JSON:API filter params (filter[X]) to flat params (X)."""
exclude_keys = exclude_keys or set()
normalized = QueryDict(mutable=True)
for key, values in query_params.lists():
normalized_key = (
key[7:-1] if key.startswith("filter[") and key.endswith("]") else key
)
if normalized_key not in exclude_keys:
normalized.setlist(normalized_key, values)
return normalized
def _ensure_allowed_providers(self):
"""Populate allowed providers for RBAC-aware queries once per request."""
if getattr(self, "_providers_initialized", False):
@@ -5465,15 +5592,6 @@ class OverviewViewSet(BaseRLSViewSet):
return queryset.filter(**provider_filter)
return queryset
def _apply_filterset(self, queryset, filterset_class, exclude_keys=None):
normalized_params = self._normalize_jsonapi_params(
self.request.query_params, exclude_keys=set(exclude_keys or [])
)
filterset = filterset_class(normalized_params, queryset=queryset)
if not filterset.is_valid():
raise ValidationError(filterset.errors)
return filterset.qs
def _latest_scan_ids_for_allowed_providers(self, tenant_id, provider_filters=None):
provider_filter = self._get_provider_filter()
queryset = Scan.all_objects.filter(
@@ -5487,40 +5605,6 @@ class OverviewViewSet(BaseRLSViewSet):
.values_list("id", flat=True)
)
def _extract_provider_filters_from_params(self):
"""Extract and validate provider filters from query params."""
params = self.request.query_params
filters = {}
valid_provider_types = {c[0] for c in Provider.ProviderChoices.choices}
provider_id = params.get("filter[provider_id]")
if provider_id:
filters["provider_id"] = provider_id
provider_id_in = params.get("filter[provider_id__in]")
if provider_id_in:
filters["provider_id__in"] = provider_id_in.split(",")
provider_type = params.get("filter[provider_type]")
if provider_type:
if provider_type not in valid_provider_types:
raise ValidationError(
{"provider_type": f"Invalid choice: {provider_type}"}
)
filters["provider__provider"] = provider_type
provider_type_in = params.get("filter[provider_type__in]")
if provider_type_in:
types = provider_type_in.split(",")
invalid = [t for t in types if t not in valid_provider_types]
if invalid:
raise ValidationError(
{"provider_type__in": f"Invalid choices: {', '.join(invalid)}"}
)
filters["provider__provider__in"] = types
return filters
@action(detail=False, methods=["get"], url_name="providers")
def providers(self, request):
tenant_id = self.request.tenant_id
@@ -5591,15 +5675,11 @@ class OverviewViewSet(BaseRLSViewSet):
tenant_id = self.request.tenant_id
providers_qs = Provider.objects.filter(tenant_id=tenant_id)
self._ensure_allowed_providers()
if hasattr(self, "allowed_providers"):
allowed_ids = list(self.allowed_providers.values_list("id", flat=True))
if not allowed_ids:
overview = []
return Response(
self.get_serializer(overview, many=True).data,
status=status.HTTP_200_OK,
)
providers_qs = providers_qs.filter(id__in=allowed_ids)
providers_qs = providers_qs.filter(
id__in=self.allowed_providers.values("id")
)
overview = (
providers_qs.values("provider")
@@ -5815,29 +5895,41 @@ class OverviewViewSet(BaseRLSViewSet):
description="Retrieve a specific snapshot by ID. If not provided, returns latest snapshots.",
),
OpenApiParameter(
name="provider_id",
name="filter[provider_id]",
type=OpenApiTypes.UUID,
location=OpenApiParameter.QUERY,
description="Filter by specific provider ID",
),
OpenApiParameter(
name="provider_id__in",
name="filter[provider_id__in]",
type=OpenApiTypes.STR,
location=OpenApiParameter.QUERY,
description="Filter by multiple provider IDs (comma-separated UUIDs)",
),
OpenApiParameter(
name="provider_type",
name="filter[provider_type]",
type=OpenApiTypes.STR,
location=OpenApiParameter.QUERY,
description="Filter by provider type (aws, azure, gcp, etc.)",
),
OpenApiParameter(
name="provider_type__in",
name="filter[provider_type__in]",
type=OpenApiTypes.STR,
location=OpenApiParameter.QUERY,
description="Filter by multiple provider types (comma-separated)",
),
OpenApiParameter(
name="filter[provider_groups]",
type=OpenApiTypes.UUID,
location=OpenApiParameter.QUERY,
description="Filter by provider group ID",
),
OpenApiParameter(
name="filter[provider_groups__in]",
type=OpenApiTypes.STR,
location=OpenApiParameter.QUERY,
description="Filter by multiple provider group IDs (comma-separated UUIDs)",
),
],
)
@action(detail=False, methods=["get"], url_name="threatscore")
@@ -6179,6 +6271,8 @@ class OverviewViewSet(BaseRLSViewSet):
"provider_id__in",
"provider_type",
"provider_type__in",
"provider_groups",
"provider_groups__in",
}
filtered_queryset = self._apply_filterset(
base_queryset, CategoryOverviewFilter, exclude_keys=provider_filter_keys
@@ -6248,6 +6342,8 @@ class OverviewViewSet(BaseRLSViewSet):
"provider_id__in",
"provider_type",
"provider_type__in",
"provider_groups",
"provider_groups__in",
}
filtered_queryset = self._apply_filterset(
base_queryset,
@@ -7298,7 +7394,7 @@ SEVERITY_ORDER_REVERSE = {v: k for k, v in SEVERITY_ORDER.items()}
),
retrieve=extend_schema(exclude=True),
)
class FindingGroupViewSet(BaseRLSViewSet):
class FindingGroupViewSet(JsonApiFilterMixin, BaseRLSViewSet):
"""
ViewSet for Finding Groups - aggregates findings by check_id.
@@ -7314,6 +7410,7 @@ class FindingGroupViewSet(BaseRLSViewSet):
queryset = FindingGroupDailySummary.objects.all()
serializer_class = FindingGroupSerializer
filterset_class = FindingGroupFilter
jsonapi_filter_replace_dots = True
filter_backends = [
jsonapi_filters.QueryParameterValidationFilter,
jsonapi_filters.OrderingFilter,
@@ -7364,18 +7461,6 @@ class FindingGroupViewSet(BaseRLSViewSet):
return queryset
def _normalize_jsonapi_params(self, query_params):
"""Convert JSON:API filter params (filter[X]) to flat params (X)."""
normalized = QueryDict(mutable=True)
for key, values in query_params.lists():
normalized_key = (
key[7:-1] if key.startswith("filter[") and key.endswith("]") else key
)
# Convert JSON:API dot notation to Django double underscore
normalized_key = normalized_key.replace(".", "__")
normalized.setlist(normalized_key, values)
return normalized
@extend_schema(exclude=True)
def retrieve(self, request, *args, **kwargs):
raise MethodNotAllowed(method="GET")
@@ -8494,9 +8579,10 @@ class FindingGroupViewSet(BaseRLSViewSet):
This endpoint returns finding groups without requiring date filters,
automatically using the latest available data per check_id.
All other filters (provider_id, provider_type, check_id) are still supported.
Provider, provider group, check, and computed filters are still supported.
""",
tags=["Finding Groups"],
filters=True,
)
@action(detail=False, methods=["get"], url_name="latest")
def latest(self, request):