mirror of
https://github.com/prowler-cloud/prowler.git
synced 2026-07-04 19:21:51 +00:00
feat(api): add provider group filters (#11573)
This commit is contained in:
@@ -6,6 +6,8 @@ All notable changes to the **Prowler API** are documented in this file.
|
||||
|
||||
### 🚀 Added
|
||||
|
||||
- Provider group filters for API endpoints that support cloud provider filtering, including exact and `__in` variants [(#11573)](https://github.com/prowler-cloud/prowler/pull/11573)
|
||||
- Provider filters for `GET /api/v1/compliance-overviews`, `/metadata`, and `/requirements`, using latest completed scans per matching provider [(#11587)](https://github.com/prowler-cloud/prowler/pull/11587)
|
||||
- Server-Sent Events (SSE) infrastructure for the API: a base viewset, a tenant-aware channel manager, and channel-name helpers backed by `django-eventstream` over Valkey Pub/Sub and served through the Gunicorn ASGI worker, so feature endpoints can stream events to clients over a single long-lived connection [(#11556)](https://github.com/prowler-cloud/prowler/pull/11556)
|
||||
|
||||
### 🔐 Security
|
||||
|
||||
@@ -102,7 +102,7 @@ class BaseProviderFilter(FilterSet):
|
||||
"""
|
||||
Abstract base filter for models with direct FK to Provider.
|
||||
|
||||
Provides standard provider_id and provider_type filters.
|
||||
Provides standard provider_id, provider_type, and provider_groups filters.
|
||||
Subclasses must define Meta.model.
|
||||
"""
|
||||
|
||||
@@ -116,6 +116,16 @@ class BaseProviderFilter(FilterSet):
|
||||
choices=Provider.ProviderChoices.choices,
|
||||
lookup_expr="in",
|
||||
)
|
||||
provider_groups = UUIDFilter(
|
||||
field_name="provider__provider_groups__id",
|
||||
lookup_expr="exact",
|
||||
distinct=True,
|
||||
)
|
||||
provider_groups__in = UUIDInFilter(
|
||||
field_name="provider__provider_groups__id",
|
||||
lookup_expr="in",
|
||||
distinct=True,
|
||||
)
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
@@ -126,7 +136,7 @@ class BaseScanProviderFilter(FilterSet):
|
||||
"""
|
||||
Abstract base filter for models with FK to Scan (and Scan has FK to Provider).
|
||||
|
||||
Provides standard provider_id and provider_type filters via scan relationship.
|
||||
Provides standard provider_id, provider_type, and provider_groups filters via scan relationship.
|
||||
Subclasses must define Meta.model.
|
||||
"""
|
||||
|
||||
@@ -140,6 +150,16 @@ class BaseScanProviderFilter(FilterSet):
|
||||
choices=Provider.ProviderChoices.choices,
|
||||
lookup_expr="in",
|
||||
)
|
||||
provider_groups = UUIDFilter(
|
||||
field_name="scan__provider__provider_groups__id",
|
||||
lookup_expr="exact",
|
||||
distinct=True,
|
||||
)
|
||||
provider_groups__in = UUIDInFilter(
|
||||
field_name="scan__provider__provider_groups__id",
|
||||
lookup_expr="in",
|
||||
distinct=True,
|
||||
)
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
@@ -160,6 +180,16 @@ class CommonFindingFilters(FilterSet):
|
||||
provider_type__in = ChoiceInFilter(
|
||||
choices=Provider.ProviderChoices.choices, field_name="scan__provider__provider"
|
||||
)
|
||||
provider_groups = UUIDFilter(
|
||||
field_name="scan__provider__provider_groups__id",
|
||||
lookup_expr="exact",
|
||||
distinct=True,
|
||||
)
|
||||
provider_groups__in = UUIDInFilter(
|
||||
field_name="scan__provider__provider_groups__id",
|
||||
lookup_expr="in",
|
||||
distinct=True,
|
||||
)
|
||||
provider_uid = CharFilter(field_name="scan__provider__uid", lookup_expr="exact")
|
||||
provider_uid__in = CharInFilter(field_name="scan__provider__uid", lookup_expr="in")
|
||||
provider_uid__icontains = CharFilter(
|
||||
@@ -370,6 +400,12 @@ class ProviderFilter(FilterSet):
|
||||
choices=Provider.ProviderChoices.choices,
|
||||
lookup_expr="in",
|
||||
)
|
||||
provider_groups = UUIDFilter(
|
||||
field_name="provider_groups__id", lookup_expr="exact", distinct=True
|
||||
)
|
||||
provider_groups__in = UUIDInFilter(
|
||||
field_name="provider_groups__id", lookup_expr="in", distinct=True
|
||||
)
|
||||
|
||||
class Meta:
|
||||
model = Provider
|
||||
@@ -395,6 +431,16 @@ class ProviderRelationshipFilterSet(FilterSet):
|
||||
provider_type__in = ChoiceInFilter(
|
||||
choices=Provider.ProviderChoices.choices, field_name="provider__provider"
|
||||
)
|
||||
provider_groups = UUIDFilter(
|
||||
field_name="provider__provider_groups__id",
|
||||
lookup_expr="exact",
|
||||
distinct=True,
|
||||
)
|
||||
provider_groups__in = UUIDInFilter(
|
||||
field_name="provider__provider_groups__id",
|
||||
lookup_expr="in",
|
||||
distinct=True,
|
||||
)
|
||||
provider_uid = CharFilter(field_name="provider__uid", lookup_expr="exact")
|
||||
provider_uid__in = CharInFilter(field_name="provider__uid", lookup_expr="in")
|
||||
provider_uid__icontains = CharFilter(
|
||||
@@ -1001,6 +1047,16 @@ class FindingGroupSummaryFilter(_CheckTitleToCheckIdMixin, FilterSet):
|
||||
field_name="provider__provider", choices=Provider.ProviderChoices.choices
|
||||
)
|
||||
provider_type__in = CharInFilter(field_name="provider__provider", lookup_expr="in")
|
||||
provider_groups = UUIDFilter(
|
||||
field_name="provider__provider_groups__id",
|
||||
lookup_expr="exact",
|
||||
distinct=True,
|
||||
)
|
||||
provider_groups__in = UUIDInFilter(
|
||||
field_name="provider__provider_groups__id",
|
||||
lookup_expr="in",
|
||||
distinct=True,
|
||||
)
|
||||
|
||||
class Meta:
|
||||
model = FindingGroupDailySummary
|
||||
@@ -1101,6 +1157,16 @@ class LatestFindingGroupSummaryFilter(_CheckTitleToCheckIdMixin, FilterSet):
|
||||
field_name="provider__provider", choices=Provider.ProviderChoices.choices
|
||||
)
|
||||
provider_type__in = CharInFilter(field_name="provider__provider", lookup_expr="in")
|
||||
provider_groups = UUIDFilter(
|
||||
field_name="provider__provider_groups__id",
|
||||
lookup_expr="exact",
|
||||
distinct=True,
|
||||
)
|
||||
provider_groups__in = UUIDInFilter(
|
||||
field_name="provider__provider_groups__id",
|
||||
lookup_expr="in",
|
||||
distinct=True,
|
||||
)
|
||||
|
||||
class Meta:
|
||||
model = FindingGroupDailySummary
|
||||
@@ -1280,12 +1346,19 @@ class RoleFilter(FilterSet):
|
||||
}
|
||||
|
||||
|
||||
class ComplianceOverviewFilter(FilterSet):
|
||||
class ComplianceOverviewFilter(BaseScanProviderFilter):
|
||||
"""
|
||||
Keep provider filters in the schema while runtime filtering resolves scans first.
|
||||
|
||||
Compliance overview provider filters are applied to the latest completed scans
|
||||
in the viewset, then this filterset handles the remaining compliance fields.
|
||||
"""
|
||||
|
||||
inserted_at = DateFilter(field_name="inserted_at", lookup_expr="date")
|
||||
scan_id = UUIDFilter(field_name="scan_id", required=True)
|
||||
scan_id = UUIDFilter(field_name="scan_id")
|
||||
region = CharFilter(field_name="region")
|
||||
|
||||
class Meta:
|
||||
class Meta(BaseScanProviderFilter.Meta):
|
||||
model = ComplianceRequirementOverview
|
||||
fields = {
|
||||
"inserted_at": ["date", "gte", "lte"],
|
||||
@@ -1306,6 +1379,16 @@ class ScanSummaryFilter(FilterSet):
|
||||
provider_type__in = ChoiceInFilter(
|
||||
field_name="scan__provider__provider", choices=Provider.ProviderChoices.choices
|
||||
)
|
||||
provider_groups = UUIDFilter(
|
||||
field_name="scan__provider__provider_groups__id",
|
||||
lookup_expr="exact",
|
||||
distinct=True,
|
||||
)
|
||||
provider_groups__in = UUIDInFilter(
|
||||
field_name="scan__provider__provider_groups__id",
|
||||
lookup_expr="in",
|
||||
distinct=True,
|
||||
)
|
||||
region = CharFilter(field_name="region")
|
||||
|
||||
class Meta:
|
||||
@@ -1329,6 +1412,16 @@ class DailySeveritySummaryFilter(FilterSet):
|
||||
provider_type__in = ChoiceInFilter(
|
||||
field_name="provider__provider", choices=Provider.ProviderChoices.choices
|
||||
)
|
||||
provider_groups = UUIDFilter(
|
||||
field_name="provider__provider_groups__id",
|
||||
lookup_expr="exact",
|
||||
distinct=True,
|
||||
)
|
||||
provider_groups__in = UUIDInFilter(
|
||||
field_name="provider__provider_groups__id",
|
||||
lookup_expr="in",
|
||||
distinct=True,
|
||||
)
|
||||
date_from = DateFilter(method="filter_noop")
|
||||
date_to = DateFilter(method="filter_noop")
|
||||
|
||||
@@ -1585,6 +1678,16 @@ class ThreatScoreSnapshotFilter(FilterSet):
|
||||
choices=Provider.ProviderChoices.choices,
|
||||
lookup_expr="in",
|
||||
)
|
||||
provider_groups = UUIDFilter(
|
||||
field_name="provider__provider_groups__id",
|
||||
lookup_expr="exact",
|
||||
distinct=True,
|
||||
)
|
||||
provider_groups__in = UUIDInFilter(
|
||||
field_name="provider__provider_groups__id",
|
||||
lookup_expr="in",
|
||||
distinct=True,
|
||||
)
|
||||
compliance_id = CharFilter(field_name="compliance_id", lookup_expr="exact")
|
||||
compliance_id__in = CharInFilter(field_name="compliance_id", lookup_expr="in")
|
||||
|
||||
@@ -1628,6 +1731,16 @@ class ResourceGroupOverviewFilter(FilterSet):
|
||||
choices=Provider.ProviderChoices.choices,
|
||||
lookup_expr="in",
|
||||
)
|
||||
provider_groups = UUIDFilter(
|
||||
field_name="scan__provider__provider_groups__id",
|
||||
lookup_expr="exact",
|
||||
distinct=True,
|
||||
)
|
||||
provider_groups__in = UUIDInFilter(
|
||||
field_name="scan__provider__provider_groups__id",
|
||||
lookup_expr="in",
|
||||
distinct=True,
|
||||
)
|
||||
resource_group = CharFilter(field_name="resource_group", lookup_expr="exact")
|
||||
resource_group__in = CharInFilter(field_name="resource_group", lookup_expr="in")
|
||||
|
||||
|
||||
@@ -1411,6 +1411,42 @@ class TestProviderViewSet:
|
||||
)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
|
||||
def test_providers_filter_provider_groups(
|
||||
self,
|
||||
authenticated_client,
|
||||
tenants_fixture,
|
||||
providers_fixture,
|
||||
provider_groups_fixture,
|
||||
):
|
||||
tenant = tenants_fixture[0]
|
||||
provider1, provider2, *_ = providers_fixture
|
||||
group1, group2, *_ = provider_groups_fixture
|
||||
ProviderGroupMembership.objects.create(
|
||||
tenant=tenant, provider=provider1, provider_group=group1
|
||||
)
|
||||
ProviderGroupMembership.objects.create(
|
||||
tenant=tenant, provider=provider1, provider_group=group2
|
||||
)
|
||||
ProviderGroupMembership.objects.create(
|
||||
tenant=tenant, provider=provider2, provider_group=group2
|
||||
)
|
||||
|
||||
response = authenticated_client.get(
|
||||
reverse("provider-list"), {"filter[provider_groups]": str(group1.id)}
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()["data"]
|
||||
assert [item["id"] for item in data] == [str(provider1.id)]
|
||||
|
||||
response = authenticated_client.get(
|
||||
reverse("provider-list"),
|
||||
{"filter[provider_groups__in]": f"{group1.id},{group2.id}"},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
provider_ids = {item["id"] for item in response.json()["data"]}
|
||||
assert provider_ids == {str(provider1.id), str(provider2.id)}
|
||||
assert len(response.json()["data"]) == 2
|
||||
|
||||
def test_providers_disable_pagination(
|
||||
self, authenticated_client, providers_fixture, tenants_fixture
|
||||
):
|
||||
@@ -3715,6 +3751,41 @@ class TestScanViewSet:
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert len(response.json()["data"]) == expected_count
|
||||
|
||||
def test_scans_filter_provider_groups(
|
||||
self,
|
||||
authenticated_client,
|
||||
tenants_fixture,
|
||||
scans_fixture,
|
||||
provider_groups_fixture,
|
||||
):
|
||||
tenant = tenants_fixture[0]
|
||||
scan1, scan2, *_ = scans_fixture
|
||||
group1, group2, *_ = provider_groups_fixture
|
||||
ProviderGroupMembership.objects.create(
|
||||
tenant=tenant, provider=scan1.provider, provider_group=group1
|
||||
)
|
||||
ProviderGroupMembership.objects.create(
|
||||
tenant=tenant, provider=scan1.provider, provider_group=group2
|
||||
)
|
||||
ProviderGroupMembership.objects.create(
|
||||
tenant=tenant, provider=scan2.provider, provider_group=group2
|
||||
)
|
||||
|
||||
response = authenticated_client.get(
|
||||
reverse("scan-list"), {"filter[provider_groups]": str(group1.id)}
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert {item["id"] for item in response.json()["data"]} == {str(scan1.id)}
|
||||
|
||||
response = authenticated_client.get(
|
||||
reverse("scan-list"),
|
||||
{"filter[provider_groups__in]": f"{group1.id},{group2.id}"},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
scan_ids = {item["id"] for item in response.json()["data"]}
|
||||
assert scan_ids == {str(scan1.id), str(scan2.id), str(scans_fixture[2].id)}
|
||||
assert len(response.json()["data"]) == 3
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"filter_name",
|
||||
[
|
||||
@@ -5996,6 +6067,49 @@ class TestResourceViewSet:
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert len(response.json()["data"]) == expected_count
|
||||
|
||||
def test_resource_filter_provider_groups(
|
||||
self,
|
||||
authenticated_client,
|
||||
tenants_fixture,
|
||||
resources_fixture,
|
||||
provider_groups_fixture,
|
||||
):
|
||||
tenant = tenants_fixture[0]
|
||||
resource1, resource2, resource3, *_ = resources_fixture
|
||||
group1, group2, *_ = provider_groups_fixture
|
||||
ProviderGroupMembership.objects.create(
|
||||
tenant=tenant, provider=resource1.provider, provider_group=group1
|
||||
)
|
||||
ProviderGroupMembership.objects.create(
|
||||
tenant=tenant, provider=resource1.provider, provider_group=group2
|
||||
)
|
||||
ProviderGroupMembership.objects.create(
|
||||
tenant=tenant, provider=resource3.provider, provider_group=group2
|
||||
)
|
||||
|
||||
response = authenticated_client.get(
|
||||
reverse("resource-list"),
|
||||
{"filter[updated_at]": TODAY, "filter[provider_groups]": str(group1.id)},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert len(response.json()["data"]) == 2
|
||||
assert {item["id"] for item in response.json()["data"]} == {
|
||||
str(resource1.id),
|
||||
str(resource2.id),
|
||||
}
|
||||
|
||||
response = authenticated_client.get(
|
||||
reverse("resource-list"),
|
||||
{
|
||||
"filter[updated_at]": TODAY,
|
||||
"filter[provider_groups__in]": f"{group1.id},{group2.id}",
|
||||
},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
resource_ids = {item["id"] for item in response.json()["data"]}
|
||||
assert resource_ids == {str(resource1.id), str(resource2.id), str(resource3.id)}
|
||||
assert len(response.json()["data"]) == 3
|
||||
|
||||
def test_resource_filter_by_scan_id(
|
||||
self, authenticated_client, resources_fixture, scans_fixture
|
||||
):
|
||||
@@ -7308,6 +7422,40 @@ class TestFindingViewSet:
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert len(response.json()["data"]) == 2
|
||||
|
||||
def test_finding_filter_provider_groups(
|
||||
self,
|
||||
authenticated_client,
|
||||
tenants_fixture,
|
||||
findings_fixture,
|
||||
provider_groups_fixture,
|
||||
):
|
||||
tenant = tenants_fixture[0]
|
||||
finding1, finding2, *_ = findings_fixture
|
||||
group1, group2, *_ = provider_groups_fixture
|
||||
ProviderGroupMembership.objects.create(
|
||||
tenant=tenant, provider=finding1.scan.provider, provider_group=group1
|
||||
)
|
||||
ProviderGroupMembership.objects.create(
|
||||
tenant=tenant, provider=finding1.scan.provider, provider_group=group2
|
||||
)
|
||||
|
||||
response = authenticated_client.get(
|
||||
reverse("finding-list"),
|
||||
{"filter[inserted_at]": TODAY, "filter[provider_groups]": str(group1.id)},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert len(response.json()["data"]) == 2
|
||||
|
||||
response = authenticated_client.get(
|
||||
reverse("finding-list"),
|
||||
{
|
||||
"filter[inserted_at]": TODAY,
|
||||
"filter[provider_groups__in]": f"{group1.id},{group2.id}",
|
||||
},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert len(response.json()["data"]) == 2
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"filter_name",
|
||||
(
|
||||
@@ -9278,6 +9426,118 @@ class TestComplianceOverviewViewSet:
|
||||
with patch("api.v1.views.backfill_compliance_summaries_task.delay") as mock:
|
||||
yield mock
|
||||
|
||||
def _create_completed_scan(self, provider, name):
|
||||
return Scan.objects.create(
|
||||
name=name,
|
||||
provider=provider,
|
||||
trigger=Scan.TriggerChoices.MANUAL,
|
||||
state=StateChoices.COMPLETED,
|
||||
tenant_id=provider.tenant_id,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
completed_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
def _create_requirement(
|
||||
self,
|
||||
scan,
|
||||
requirement_id,
|
||||
status_choice,
|
||||
region="eu-west-1",
|
||||
compliance_id="cis_1.4_aws",
|
||||
):
|
||||
passed = 1 if status_choice == StatusChoices.PASS else 0
|
||||
total = 1 if status_choice != StatusChoices.MANUAL else 0
|
||||
return ComplianceRequirementOverview.objects.create(
|
||||
tenant_id=scan.tenant_id,
|
||||
scan=scan,
|
||||
compliance_id=compliance_id,
|
||||
framework="CIS-1.4-AWS",
|
||||
version="1.4",
|
||||
description="CIS AWS Foundations Benchmark v1.4.0",
|
||||
region=region,
|
||||
requirement_id=requirement_id,
|
||||
requirement_status=status_choice,
|
||||
passed_checks=passed,
|
||||
failed_checks=0
|
||||
if status_choice in (StatusChoices.PASS, StatusChoices.MANUAL)
|
||||
else 1,
|
||||
total_checks=total,
|
||||
passed_findings=passed,
|
||||
total_findings=total,
|
||||
)
|
||||
|
||||
def _create_compliance_summary(
|
||||
self,
|
||||
scan,
|
||||
*,
|
||||
passed,
|
||||
failed,
|
||||
manual=0,
|
||||
compliance_id="cis_1.4_aws",
|
||||
):
|
||||
return ComplianceOverviewSummary.objects.create(
|
||||
tenant_id=scan.tenant_id,
|
||||
scan=scan,
|
||||
compliance_id=compliance_id,
|
||||
requirements_passed=passed,
|
||||
requirements_failed=failed,
|
||||
requirements_manual=manual,
|
||||
total_requirements=passed + failed + manual,
|
||||
)
|
||||
|
||||
def _overview_attrs_by_id(self, response):
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
return {item["id"]: item["attributes"] for item in response.json()["data"]}
|
||||
|
||||
def _prepare_latest_compliance_data(self, providers_fixture):
|
||||
provider1, provider2, provider3, *_ = providers_fixture
|
||||
old_scan = self._create_completed_scan(provider1, "old aws compliance scan")
|
||||
latest_scan1 = self._create_completed_scan(
|
||||
provider1, "latest aws compliance scan 1"
|
||||
)
|
||||
latest_scan2 = self._create_completed_scan(
|
||||
provider2, "latest aws compliance scan 2"
|
||||
)
|
||||
latest_gcp_scan = self._create_completed_scan(
|
||||
provider3, "latest gcp compliance scan"
|
||||
)
|
||||
|
||||
self._create_requirement(old_scan, "1.1", StatusChoices.FAIL)
|
||||
self._create_requirement(old_scan, "1.2", StatusChoices.FAIL)
|
||||
self._create_compliance_summary(old_scan, passed=0, failed=2)
|
||||
|
||||
self._create_requirement(
|
||||
latest_scan1, "1.1", StatusChoices.PASS, region="eu-west-1"
|
||||
)
|
||||
self._create_requirement(
|
||||
latest_scan1, "1.2", StatusChoices.PASS, region="eu-west-1"
|
||||
)
|
||||
self._create_compliance_summary(latest_scan1, passed=2, failed=0)
|
||||
|
||||
self._create_requirement(
|
||||
latest_scan2, "1.1", StatusChoices.FAIL, region="us-east-1"
|
||||
)
|
||||
self._create_requirement(
|
||||
latest_scan2, "1.2", StatusChoices.PASS, region="us-east-1"
|
||||
)
|
||||
self._create_compliance_summary(latest_scan2, passed=1, failed=1)
|
||||
|
||||
self._create_requirement(
|
||||
latest_gcp_scan,
|
||||
"gcp-1.1",
|
||||
StatusChoices.FAIL,
|
||||
region="europe-west1",
|
||||
compliance_id="cis_1.3_gcp",
|
||||
)
|
||||
self._create_compliance_summary(
|
||||
latest_gcp_scan,
|
||||
passed=0,
|
||||
failed=1,
|
||||
compliance_id="cis_1.3_gcp",
|
||||
)
|
||||
|
||||
return old_scan, latest_scan1, latest_scan2, latest_gcp_scan
|
||||
|
||||
def test_compliance_overview_list_none(
|
||||
self,
|
||||
authenticated_client,
|
||||
@@ -9425,6 +9685,283 @@ class TestComplianceOverviewViewSet:
|
||||
assert len(response.json()["data"]) >= 1
|
||||
mock_backfill_task.assert_not_called()
|
||||
|
||||
def test_compliance_overview_provider_id_filter_uses_latest_scan(
|
||||
self,
|
||||
authenticated_client,
|
||||
providers_fixture,
|
||||
mock_backfill_task,
|
||||
):
|
||||
_, latest_scan, *_ = self._prepare_latest_compliance_data(providers_fixture)
|
||||
|
||||
response = authenticated_client.get(
|
||||
reverse("complianceoverview-list"),
|
||||
{"filter[provider_id]": str(latest_scan.provider_id)},
|
||||
)
|
||||
|
||||
attrs_by_id = self._overview_attrs_by_id(response)
|
||||
assert attrs_by_id["cis_1.4_aws"]["requirements_passed"] == 2
|
||||
assert attrs_by_id["cis_1.4_aws"]["requirements_failed"] == 0
|
||||
assert "cis_1.3_gcp" not in attrs_by_id
|
||||
mock_backfill_task.assert_not_called()
|
||||
|
||||
def test_compliance_overview_provider_id_in_filter_aggregates_latest_scans(
|
||||
self,
|
||||
authenticated_client,
|
||||
providers_fixture,
|
||||
):
|
||||
_, latest_scan1, latest_scan2, *_ = self._prepare_latest_compliance_data(
|
||||
providers_fixture
|
||||
)
|
||||
|
||||
response = authenticated_client.get(
|
||||
reverse("complianceoverview-list"),
|
||||
{
|
||||
"filter[provider_id__in]": (
|
||||
f"{latest_scan1.provider_id},{latest_scan2.provider_id}"
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
attrs_by_id = self._overview_attrs_by_id(response)
|
||||
assert attrs_by_id["cis_1.4_aws"]["requirements_passed"] == 1
|
||||
assert attrs_by_id["cis_1.4_aws"]["requirements_failed"] == 1
|
||||
assert attrs_by_id["cis_1.4_aws"]["total_requirements"] == 2
|
||||
assert "cis_1.3_gcp" not in attrs_by_id
|
||||
|
||||
def test_compliance_overview_provider_type_filter_uses_latest_scans(
|
||||
self,
|
||||
authenticated_client,
|
||||
providers_fixture,
|
||||
):
|
||||
self._prepare_latest_compliance_data(providers_fixture)
|
||||
|
||||
response = authenticated_client.get(
|
||||
reverse("complianceoverview-list"),
|
||||
{"filter[provider_type]": Provider.ProviderChoices.AWS.value},
|
||||
)
|
||||
|
||||
attrs_by_id = self._overview_attrs_by_id(response)
|
||||
assert attrs_by_id["cis_1.4_aws"]["requirements_passed"] == 1
|
||||
assert attrs_by_id["cis_1.4_aws"]["requirements_failed"] == 1
|
||||
assert attrs_by_id["cis_1.4_aws"]["total_requirements"] == 2
|
||||
assert "cis_1.3_gcp" not in attrs_by_id
|
||||
|
||||
def test_compliance_overview_provider_groups_filters_use_latest_scans(
|
||||
self,
|
||||
authenticated_client,
|
||||
providers_fixture,
|
||||
provider_groups_fixture,
|
||||
tenants_fixture,
|
||||
):
|
||||
tenant = tenants_fixture[0]
|
||||
provider1, provider2, *_ = providers_fixture
|
||||
group1, group2, *_ = provider_groups_fixture
|
||||
_, latest_scan1, latest_scan2, *_ = self._prepare_latest_compliance_data(
|
||||
providers_fixture
|
||||
)
|
||||
ProviderGroupMembership.objects.create(
|
||||
tenant_id=tenant.id,
|
||||
provider=provider1,
|
||||
provider_group=group1,
|
||||
)
|
||||
ProviderGroupMembership.objects.create(
|
||||
tenant_id=tenant.id,
|
||||
provider=provider2,
|
||||
provider_group=group2,
|
||||
)
|
||||
|
||||
response = authenticated_client.get(
|
||||
reverse("complianceoverview-list"),
|
||||
{"filter[provider_groups]": str(group1.id)},
|
||||
)
|
||||
|
||||
attrs_by_id = self._overview_attrs_by_id(response)
|
||||
assert attrs_by_id["cis_1.4_aws"]["requirements_passed"] == 2
|
||||
assert attrs_by_id["cis_1.4_aws"]["requirements_failed"] == 0
|
||||
|
||||
response = authenticated_client.get(
|
||||
reverse("complianceoverview-list"),
|
||||
{"filter[provider_groups__in]": f"{group1.id},{group2.id}"},
|
||||
)
|
||||
|
||||
attrs_by_id = self._overview_attrs_by_id(response)
|
||||
assert attrs_by_id["cis_1.4_aws"]["requirements_passed"] == 1
|
||||
assert attrs_by_id["cis_1.4_aws"]["requirements_failed"] == 1
|
||||
assert attrs_by_id["cis_1.4_aws"]["total_requirements"] == 2
|
||||
|
||||
def _assert_latest_provider_scan_task_response(
|
||||
self,
|
||||
authenticated_client,
|
||||
endpoint,
|
||||
scan,
|
||||
query_params=None,
|
||||
):
|
||||
query_params = {**(query_params or {})}
|
||||
if not any(key.startswith("filter[provider_") for key in query_params):
|
||||
query_params = {
|
||||
"filter[provider_id]": str(scan.provider_id),
|
||||
**query_params,
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
ComplianceOverviewViewSet, "get_task_response_if_running"
|
||||
) as mock_task_response:
|
||||
mock_task_response.return_value = Response(
|
||||
{"detail": "Task is running"}, status=status.HTTP_202_ACCEPTED
|
||||
)
|
||||
|
||||
response = authenticated_client.get(reverse(endpoint), query_params)
|
||||
|
||||
assert response.status_code == status.HTTP_202_ACCEPTED
|
||||
mock_task_response.assert_called_once()
|
||||
_, kwargs = mock_task_response.call_args
|
||||
assert kwargs["task_name"] == "scan-compliance-overviews"
|
||||
assert str(kwargs["task_kwargs"]["tenant_id"]) == str(scan.tenant_id)
|
||||
assert str(kwargs["task_kwargs"]["scan_id"]) == str(scan.id)
|
||||
assert kwargs["raise_on_not_found"] is False
|
||||
|
||||
def test_compliance_overview_provider_filter_returns_running_task_without_data(
|
||||
self,
|
||||
authenticated_client,
|
||||
providers_fixture,
|
||||
):
|
||||
scan = self._create_completed_scan(
|
||||
providers_fixture[0], "latest scan without compliance data"
|
||||
)
|
||||
|
||||
self._assert_latest_provider_scan_task_response(
|
||||
authenticated_client,
|
||||
"complianceoverview-list",
|
||||
scan,
|
||||
)
|
||||
|
||||
def test_compliance_overview_provider_filter_returns_running_task_for_partial_data(
|
||||
self,
|
||||
authenticated_client,
|
||||
providers_fixture,
|
||||
):
|
||||
provider_with_data, provider_without_data, *_ = providers_fixture
|
||||
scan_with_data = self._create_completed_scan(
|
||||
provider_with_data, "latest scan with compliance data"
|
||||
)
|
||||
scan_without_data = self._create_completed_scan(
|
||||
provider_without_data, "latest scan without partial compliance data"
|
||||
)
|
||||
self._create_requirement(scan_with_data, "1.1", StatusChoices.PASS)
|
||||
|
||||
self._assert_latest_provider_scan_task_response(
|
||||
authenticated_client,
|
||||
"complianceoverview-list",
|
||||
scan_without_data,
|
||||
{
|
||||
"filter[provider_id__in]": (
|
||||
f"{provider_with_data.id},{provider_without_data.id}"
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
def test_compliance_overview_provider_filter_empty_response_uses_scan_data_presence(
|
||||
self,
|
||||
authenticated_client,
|
||||
providers_fixture,
|
||||
):
|
||||
scan = self._create_completed_scan(
|
||||
providers_fixture[0], "latest scan with filtered compliance data"
|
||||
)
|
||||
self._create_requirement(scan, "1.1", StatusChoices.PASS, region="eu-west-1")
|
||||
|
||||
with patch.object(
|
||||
ComplianceOverviewViewSet, "get_task_response_if_running"
|
||||
) as mock_task_response:
|
||||
mock_task_response.return_value = Response(
|
||||
{"detail": "Task is running"}, status=status.HTTP_202_ACCEPTED
|
||||
)
|
||||
|
||||
response = authenticated_client.get(
|
||||
reverse("complianceoverview-list"),
|
||||
{
|
||||
"filter[provider_id]": str(scan.provider_id),
|
||||
"filter[region]": "us-east-1",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.json()["data"] == []
|
||||
mock_task_response.assert_not_called()
|
||||
|
||||
def test_compliance_overview_metadata_provider_filter_returns_running_task_without_data(
|
||||
self,
|
||||
authenticated_client,
|
||||
providers_fixture,
|
||||
):
|
||||
scan = self._create_completed_scan(
|
||||
providers_fixture[0], "latest scan without compliance metadata"
|
||||
)
|
||||
|
||||
self._assert_latest_provider_scan_task_response(
|
||||
authenticated_client,
|
||||
"complianceoverview-metadata",
|
||||
scan,
|
||||
)
|
||||
|
||||
def test_compliance_overview_requirements_provider_filter_returns_running_task_without_data(
|
||||
self,
|
||||
authenticated_client,
|
||||
providers_fixture,
|
||||
):
|
||||
scan = self._create_completed_scan(
|
||||
providers_fixture[0], "latest scan without compliance requirements"
|
||||
)
|
||||
|
||||
self._assert_latest_provider_scan_task_response(
|
||||
authenticated_client,
|
||||
"complianceoverview-requirements",
|
||||
scan,
|
||||
{"filter[compliance_id]": "cis_1.4_aws"},
|
||||
)
|
||||
|
||||
def test_compliance_overview_metadata_accepts_provider_filters(
|
||||
self,
|
||||
authenticated_client,
|
||||
providers_fixture,
|
||||
):
|
||||
_, latest_scan, *_ = self._prepare_latest_compliance_data(providers_fixture)
|
||||
|
||||
response = authenticated_client.get(
|
||||
reverse("complianceoverview-metadata"),
|
||||
{"filter[provider_id]": str(latest_scan.provider_id)},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
regions = response.json()["data"]["attributes"]["regions"]
|
||||
assert regions == ["eu-west-1"]
|
||||
|
||||
def test_compliance_overview_requirements_accepts_provider_filters(
|
||||
self,
|
||||
authenticated_client,
|
||||
providers_fixture,
|
||||
):
|
||||
_, latest_scan1, latest_scan2, *_ = self._prepare_latest_compliance_data(
|
||||
providers_fixture
|
||||
)
|
||||
|
||||
response = authenticated_client.get(
|
||||
reverse("complianceoverview-requirements"),
|
||||
{
|
||||
"filter[provider_id__in]": (
|
||||
f"{latest_scan1.provider_id},{latest_scan2.provider_id}"
|
||||
),
|
||||
"filter[compliance_id]": "cis_1.4_aws",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
requirements_by_id = {
|
||||
item["id"]: item["attributes"] for item in response.json()["data"]
|
||||
}
|
||||
assert requirements_by_id["1.1"]["status"] == "FAIL"
|
||||
assert requirements_by_id["1.2"]["status"] == "PASS"
|
||||
|
||||
def test_compliance_overview_metadata(
|
||||
self, authenticated_client, compliance_requirements_overviews_fixture
|
||||
):
|
||||
@@ -10031,6 +10568,40 @@ class TestOverviewViewSet:
|
||||
for entry in grouped_data:
|
||||
assert "findings" not in entry["attributes"]
|
||||
|
||||
def test_overview_providers_count_applies_limited_visibility(
|
||||
self,
|
||||
authenticated_client_no_permissions_rbac,
|
||||
providers_fixture,
|
||||
provider_groups_fixture,
|
||||
tenants_fixture,
|
||||
):
|
||||
tenant = tenants_fixture[0]
|
||||
client = authenticated_client_no_permissions_rbac
|
||||
allowed_provider = providers_fixture[2]
|
||||
denied_provider = providers_fixture[4]
|
||||
provider_group = provider_groups_fixture[0]
|
||||
|
||||
ProviderGroupMembership.objects.create(
|
||||
tenant_id=tenant.id,
|
||||
provider_group=provider_group,
|
||||
provider=allowed_provider,
|
||||
)
|
||||
RoleProviderGroupRelationship.objects.create(
|
||||
tenant_id=tenant.id,
|
||||
role=client.user.roles.first(),
|
||||
provider_group=provider_group,
|
||||
)
|
||||
|
||||
response = client.get(reverse("overview-providers-count"))
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
aggregated = {
|
||||
entry["id"]: entry["attributes"]["count"]
|
||||
for entry in response.json()["data"]
|
||||
}
|
||||
assert aggregated == {allowed_provider.provider: 1}
|
||||
assert denied_provider.provider not in aggregated
|
||||
|
||||
def _create_scan(self, tenant, provider, name, started_at=None):
|
||||
scan_started = started_at or datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
return Scan.objects.create(
|
||||
@@ -10578,6 +11149,87 @@ class TestOverviewViewSet:
|
||||
assert combined_attributes["muted"] == 3
|
||||
assert combined_attributes["total"] == 14
|
||||
|
||||
def test_overview_findings_provider_groups_filter(
|
||||
self,
|
||||
authenticated_client,
|
||||
tenants_fixture,
|
||||
providers_fixture,
|
||||
provider_groups_fixture,
|
||||
):
|
||||
tenant = tenants_fixture[0]
|
||||
provider1, provider2, *_ = providers_fixture
|
||||
group1, group2, *_ = provider_groups_fixture
|
||||
ProviderGroupMembership.objects.create(
|
||||
tenant=tenant, provider=provider1, provider_group=group1
|
||||
)
|
||||
ProviderGroupMembership.objects.create(
|
||||
tenant=tenant, provider=provider1, provider_group=group2
|
||||
)
|
||||
ProviderGroupMembership.objects.create(
|
||||
tenant=tenant, provider=provider2, provider_group=group2
|
||||
)
|
||||
|
||||
scan1 = Scan.objects.create(
|
||||
name="scan-provider-group-one",
|
||||
provider=provider1,
|
||||
trigger=Scan.TriggerChoices.MANUAL,
|
||||
state=StateChoices.COMPLETED,
|
||||
tenant=tenant,
|
||||
)
|
||||
scan2 = Scan.objects.create(
|
||||
name="scan-provider-group-two",
|
||||
provider=provider2,
|
||||
trigger=Scan.TriggerChoices.MANUAL,
|
||||
state=StateChoices.COMPLETED,
|
||||
tenant=tenant,
|
||||
)
|
||||
ScanSummary.objects.create(
|
||||
tenant=tenant,
|
||||
scan=scan1,
|
||||
check_id="check-provider-group-one",
|
||||
service="service-a",
|
||||
severity="high",
|
||||
region="region-a",
|
||||
_pass=5,
|
||||
fail=1,
|
||||
muted=2,
|
||||
total=8,
|
||||
)
|
||||
ScanSummary.objects.create(
|
||||
tenant=tenant,
|
||||
scan=scan2,
|
||||
check_id="check-provider-group-two",
|
||||
service="service-b",
|
||||
severity="medium",
|
||||
region="region-b",
|
||||
_pass=2,
|
||||
fail=3,
|
||||
muted=1,
|
||||
total=6,
|
||||
)
|
||||
|
||||
response = authenticated_client.get(
|
||||
reverse("overview-findings"),
|
||||
{"filter[provider_groups]": str(group1.id)},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
attributes = response.json()["data"]["attributes"]
|
||||
assert attributes["pass"] == 5
|
||||
assert attributes["fail"] == 1
|
||||
assert attributes["muted"] == 2
|
||||
assert attributes["total"] == 8
|
||||
|
||||
response = authenticated_client.get(
|
||||
reverse("overview-findings"),
|
||||
{"filter[provider_groups__in]": f"{group1.id},{group2.id}"},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
attributes = response.json()["data"]["attributes"]
|
||||
assert attributes["pass"] == 7
|
||||
assert attributes["fail"] == 4
|
||||
assert attributes["muted"] == 3
|
||||
assert attributes["total"] == 14
|
||||
|
||||
def test_overview_findings_severity_provider_id_in_filter(
|
||||
self, authenticated_client, tenants_fixture, providers_fixture
|
||||
):
|
||||
@@ -11346,9 +11998,21 @@ class TestOverviewViewSet:
|
||||
@pytest.mark.parametrize(
|
||||
"filter_key,filter_value_fn,expected_total,expected_failed",
|
||||
[
|
||||
("filter[provider_id]", lambda p1, _: str(p1.id), 10, 5),
|
||||
("filter[provider_id]", lambda p1, *_: str(p1.id), 10, 5),
|
||||
("filter[provider_type]", lambda *_: "aws", 10, 5),
|
||||
("filter[provider_type__in]", lambda *_: "aws,gcp", 30, 20),
|
||||
(
|
||||
"filter[provider_groups]",
|
||||
lambda p1, _, group1, __: str(group1.id),
|
||||
10,
|
||||
5,
|
||||
),
|
||||
(
|
||||
"filter[provider_groups__in]",
|
||||
lambda p1, _, group1, group2: f"{group1.id},{group2.id}",
|
||||
30,
|
||||
20,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_overview_categories_filters(
|
||||
@@ -11356,6 +12020,7 @@ class TestOverviewViewSet:
|
||||
authenticated_client,
|
||||
tenants_fixture,
|
||||
providers_fixture,
|
||||
provider_groups_fixture,
|
||||
create_scan_category_summary,
|
||||
filter_key,
|
||||
filter_value_fn,
|
||||
@@ -11364,6 +12029,16 @@ class TestOverviewViewSet:
|
||||
):
|
||||
tenant = tenants_fixture[0]
|
||||
provider1, _, gcp_provider, *_ = providers_fixture
|
||||
group1, group2, *_ = provider_groups_fixture
|
||||
ProviderGroupMembership.objects.create(
|
||||
tenant=tenant, provider=provider1, provider_group=group1
|
||||
)
|
||||
ProviderGroupMembership.objects.create(
|
||||
tenant=tenant, provider=provider1, provider_group=group2
|
||||
)
|
||||
ProviderGroupMembership.objects.create(
|
||||
tenant=tenant, provider=gcp_provider, provider_group=group2
|
||||
)
|
||||
|
||||
scan1 = Scan.objects.create(
|
||||
name="categories-scan-1",
|
||||
@@ -11389,7 +12064,7 @@ class TestOverviewViewSet:
|
||||
|
||||
response = authenticated_client.get(
|
||||
reverse("overview-categories"),
|
||||
{filter_key: filter_value_fn(provider1, gcp_provider)},
|
||||
{filter_key: filter_value_fn(provider1, gcp_provider, group1, group2)},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()["data"]
|
||||
@@ -11563,10 +12238,22 @@ class TestOverviewViewSet:
|
||||
@pytest.mark.parametrize(
|
||||
"filter_key,filter_value_fn,expected_total,expected_failed",
|
||||
[
|
||||
("filter[provider_id]", lambda p1, p2: str(p1.id), 10, 5),
|
||||
("filter[provider_id__in]", lambda p1, p2: f"{p1.id},{p2.id}", 25, 12),
|
||||
("filter[provider_type]", lambda p1, p2: "aws", 10, 5),
|
||||
("filter[provider_type__in]", lambda p1, p2: "aws,gcp", 25, 12),
|
||||
("filter[provider_id]", lambda p1, *_: str(p1.id), 10, 5),
|
||||
("filter[provider_id__in]", lambda p1, p2, *_: f"{p1.id},{p2.id}", 25, 12),
|
||||
("filter[provider_type]", lambda *_: "aws", 10, 5),
|
||||
("filter[provider_type__in]", lambda *_: "aws,gcp", 25, 12),
|
||||
(
|
||||
"filter[provider_groups]",
|
||||
lambda p1, p2, group1, group2: str(group1.id),
|
||||
10,
|
||||
5,
|
||||
),
|
||||
(
|
||||
"filter[provider_groups__in]",
|
||||
lambda p1, p2, group1, group2: f"{group1.id},{group2.id}",
|
||||
25,
|
||||
12,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_overview_groups_provider_filters(
|
||||
@@ -11574,6 +12261,7 @@ class TestOverviewViewSet:
|
||||
authenticated_client,
|
||||
tenants_fixture,
|
||||
providers_fixture,
|
||||
provider_groups_fixture,
|
||||
create_scan_resource_group_summary,
|
||||
filter_key,
|
||||
filter_value_fn,
|
||||
@@ -11583,6 +12271,16 @@ class TestOverviewViewSet:
|
||||
tenant = tenants_fixture[0]
|
||||
provider1 = providers_fixture[0] # AWS
|
||||
gcp_provider = providers_fixture[2] # GCP
|
||||
group1, group2, *_ = provider_groups_fixture
|
||||
ProviderGroupMembership.objects.create(
|
||||
tenant=tenant, provider=provider1, provider_group=group1
|
||||
)
|
||||
ProviderGroupMembership.objects.create(
|
||||
tenant=tenant, provider=provider1, provider_group=group2
|
||||
)
|
||||
ProviderGroupMembership.objects.create(
|
||||
tenant=tenant, provider=gcp_provider, provider_group=group2
|
||||
)
|
||||
|
||||
scan1 = Scan.objects.create(
|
||||
name="aws-rg-scan",
|
||||
@@ -11608,7 +12306,7 @@ class TestOverviewViewSet:
|
||||
|
||||
response = authenticated_client.get(
|
||||
reverse("overview-resource-groups"),
|
||||
{filter_key: filter_value_fn(provider1, gcp_provider)},
|
||||
{filter_key: filter_value_fn(provider1, gcp_provider, group1, group2)},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()["data"]
|
||||
@@ -11783,6 +12481,49 @@ class TestOverviewViewSet:
|
||||
data = response.json()["data"]
|
||||
assert len(data) >= 1
|
||||
|
||||
def test_compliance_watchlist_provider_groups_filter(
|
||||
self,
|
||||
authenticated_client,
|
||||
provider_compliance_scores_fixture,
|
||||
providers_fixture,
|
||||
provider_groups_fixture,
|
||||
tenants_fixture,
|
||||
):
|
||||
tenant = tenants_fixture[0]
|
||||
provider1, provider2, *_ = providers_fixture
|
||||
group1, group2, *_ = provider_groups_fixture
|
||||
ProviderGroupMembership.objects.create(
|
||||
tenant=tenant, provider=provider1, provider_group=group1
|
||||
)
|
||||
ProviderGroupMembership.objects.create(
|
||||
tenant=tenant, provider=provider1, provider_group=group2
|
||||
)
|
||||
ProviderGroupMembership.objects.create(
|
||||
tenant=tenant, provider=provider2, provider_group=group2
|
||||
)
|
||||
|
||||
response = authenticated_client.get(
|
||||
reverse("overview-compliance-watchlist"),
|
||||
{"filter[provider_groups]": str(group1.id)},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()["data"]
|
||||
by_id = {item["id"]: item["attributes"] for item in data}
|
||||
assert by_id["aws_cis_2.0"]["requirements_passed"] == 1
|
||||
assert by_id["aws_cis_2.0"]["requirements_failed"] == 1
|
||||
assert by_id["aws_cis_2.0"]["requirements_manual"] == 1
|
||||
|
||||
response = authenticated_client.get(
|
||||
reverse("overview-compliance-watchlist"),
|
||||
{"filter[provider_groups__in]": f"{group1.id},{group2.id}"},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()["data"]
|
||||
by_id = {item["id"]: item["attributes"] for item in data}
|
||||
assert by_id["aws_cis_2.0"]["requirements_passed"] == 0
|
||||
assert by_id["aws_cis_2.0"]["requirements_failed"] == 2
|
||||
assert by_id["aws_cis_2.0"]["requirements_manual"] == 1
|
||||
|
||||
def test_compliance_watchlist_empty_result(self, authenticated_client):
|
||||
response = authenticated_client.get(reverse("overview-compliance-watchlist"))
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -17082,6 +17823,44 @@ class TestFindingGroupViewSet:
|
||||
# All fixture findings are from AWS provider
|
||||
assert len(response.json()["data"]) == 5
|
||||
|
||||
def test_finding_groups_provider_groups_filter(
|
||||
self,
|
||||
authenticated_client,
|
||||
tenants_fixture,
|
||||
finding_groups_fixture,
|
||||
providers_fixture,
|
||||
provider_groups_fixture,
|
||||
):
|
||||
tenant = tenants_fixture[0]
|
||||
provider1, provider2, *_ = providers_fixture
|
||||
group1, group2, *_ = provider_groups_fixture
|
||||
ProviderGroupMembership.objects.create(
|
||||
tenant=tenant, provider=provider1, provider_group=group1
|
||||
)
|
||||
ProviderGroupMembership.objects.create(
|
||||
tenant=tenant, provider=provider1, provider_group=group2
|
||||
)
|
||||
ProviderGroupMembership.objects.create(
|
||||
tenant=tenant, provider=provider2, provider_group=group2
|
||||
)
|
||||
|
||||
response = authenticated_client.get(
|
||||
reverse("finding-group-list"),
|
||||
{"filter[inserted_at]": TODAY, "filter[provider_groups]": str(group1.id)},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert len(response.json()["data"]) == 4
|
||||
|
||||
response = authenticated_client.get(
|
||||
reverse("finding-group-list"),
|
||||
{
|
||||
"filter[inserted_at]": TODAY,
|
||||
"filter[provider_groups__in]": f"{group1.id},{group2.id}",
|
||||
},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert len(response.json()["data"]) == 5
|
||||
|
||||
def test_finding_groups_check_id_filter(
|
||||
self, authenticated_client, finding_groups_fixture
|
||||
):
|
||||
@@ -17992,6 +18771,41 @@ class TestFindingGroupViewSet:
|
||||
# All providers in fixture are AWS
|
||||
assert len(data) == 5
|
||||
|
||||
def test_finding_groups_latest_provider_groups_filter(
|
||||
self,
|
||||
authenticated_client,
|
||||
tenants_fixture,
|
||||
finding_groups_fixture,
|
||||
providers_fixture,
|
||||
provider_groups_fixture,
|
||||
):
|
||||
tenant = tenants_fixture[0]
|
||||
provider1, provider2, *_ = providers_fixture
|
||||
group1, group2, *_ = provider_groups_fixture
|
||||
ProviderGroupMembership.objects.create(
|
||||
tenant=tenant, provider=provider1, provider_group=group1
|
||||
)
|
||||
ProviderGroupMembership.objects.create(
|
||||
tenant=tenant, provider=provider1, provider_group=group2
|
||||
)
|
||||
ProviderGroupMembership.objects.create(
|
||||
tenant=tenant, provider=provider2, provider_group=group2
|
||||
)
|
||||
|
||||
response = authenticated_client.get(
|
||||
reverse("finding-group-latest"),
|
||||
{"filter[provider_groups]": str(group1.id)},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert len(response.json()["data"]) == 4
|
||||
|
||||
response = authenticated_client.get(
|
||||
reverse("finding-group-latest"),
|
||||
{"filter[provider_groups__in]": f"{group1.id},{group2.id}"},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert len(response.json()["data"]) == 5
|
||||
|
||||
def test_finding_groups_latest_check_id_filter(
|
||||
self, authenticated_client, finding_groups_fixture
|
||||
):
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import uuid
|
||||
|
||||
from django.http import QueryDict
|
||||
from django.urls import reverse
|
||||
from django_celery_results.models import TaskResult
|
||||
from rest_framework import status
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.response import Response
|
||||
|
||||
from api.exceptions import (
|
||||
@@ -8,7 +12,7 @@ from api.exceptions import (
|
||||
TaskInProgressException,
|
||||
TaskNotFoundException,
|
||||
)
|
||||
from api.models import StateChoices, Task
|
||||
from api.models import Provider, StateChoices, Task
|
||||
from api.v1.serializers import TaskSerializer
|
||||
|
||||
|
||||
@@ -74,6 +78,162 @@ class PaginateByPkMixin:
|
||||
return self.get_paginated_response(serialized)
|
||||
|
||||
|
||||
class JsonApiFilterMixin:
|
||||
"""Shared helpers for manually applying django-filter to JSON:API params."""
|
||||
|
||||
jsonapi_filter_replace_dots = False
|
||||
|
||||
def _normalize_jsonapi_params(
|
||||
self,
|
||||
query_params,
|
||||
exclude_keys=None,
|
||||
replace_dots=None,
|
||||
):
|
||||
exclude_keys = exclude_keys or set()
|
||||
if replace_dots is None:
|
||||
replace_dots = self.jsonapi_filter_replace_dots
|
||||
|
||||
normalized = QueryDict(mutable=True)
|
||||
for key, values in query_params.lists():
|
||||
normalized_key = (
|
||||
key[7:-1] if key.startswith("filter[") and key.endswith("]") else key
|
||||
)
|
||||
if replace_dots:
|
||||
normalized_key = normalized_key.replace(".", "__")
|
||||
if normalized_key not in exclude_keys:
|
||||
normalized.setlist(normalized_key, values)
|
||||
return normalized
|
||||
|
||||
def _apply_filterset(
|
||||
self,
|
||||
queryset,
|
||||
filterset_class,
|
||||
exclude_keys=None,
|
||||
replace_dots=None,
|
||||
):
|
||||
normalized_params = self._normalize_jsonapi_params(
|
||||
self.request.query_params,
|
||||
exclude_keys=set(exclude_keys or []),
|
||||
replace_dots=replace_dots,
|
||||
)
|
||||
filterset = filterset_class(normalized_params, queryset=queryset)
|
||||
if not filterset.is_valid():
|
||||
raise ValidationError(filterset.errors)
|
||||
return filterset.qs
|
||||
|
||||
|
||||
class ProviderFilterParamsMixin(JsonApiFilterMixin):
|
||||
"""Shared extraction of provider filters from JSON:API query params."""
|
||||
|
||||
PROVIDER_FILTER_KEYS = frozenset(
|
||||
{
|
||||
"provider_id",
|
||||
"provider_id__in",
|
||||
"provider_type",
|
||||
"provider_type__in",
|
||||
"provider_groups",
|
||||
"provider_groups__in",
|
||||
}
|
||||
)
|
||||
PROVIDER_FILTER_DOT_ALIAS_KEYS = frozenset(
|
||||
{
|
||||
"provider_id.in",
|
||||
"provider_type.in",
|
||||
"provider_groups.in",
|
||||
}
|
||||
)
|
||||
PROVIDER_FILTER_QUERY_KEYS = PROVIDER_FILTER_KEYS | PROVIDER_FILTER_DOT_ALIAS_KEYS
|
||||
|
||||
def _csv_filter_values(self, value):
|
||||
return [item.strip() for item in value.split(",") if item.strip()]
|
||||
|
||||
def _validate_uuid_filter_values(self, field_name, values):
|
||||
try:
|
||||
for value in values:
|
||||
uuid.UUID(str(value))
|
||||
except (TypeError, ValueError, AttributeError):
|
||||
raise ValidationError({field_name: ["Enter a valid UUID."]})
|
||||
|
||||
def _has_provider_filters(self, include_dot_aliases=False):
|
||||
provider_filter_keys = (
|
||||
self.PROVIDER_FILTER_QUERY_KEYS
|
||||
if include_dot_aliases
|
||||
else self.PROVIDER_FILTER_KEYS
|
||||
)
|
||||
return any(
|
||||
self.request.query_params.get(f"filter[{key}]")
|
||||
for key in provider_filter_keys
|
||||
)
|
||||
|
||||
def _extract_provider_filters_from_params(
|
||||
self,
|
||||
*,
|
||||
validate_uuids=False,
|
||||
include_dot_aliases=False,
|
||||
):
|
||||
params = self.request.query_params
|
||||
filters = {}
|
||||
valid_provider_types = {
|
||||
choice[0] for choice in Provider.ProviderChoices.choices
|
||||
}
|
||||
|
||||
provider_id = params.get("filter[provider_id]")
|
||||
if provider_id:
|
||||
if validate_uuids:
|
||||
self._validate_uuid_filter_values("provider_id", [provider_id])
|
||||
filters["provider_id"] = provider_id
|
||||
|
||||
provider_id_in = params.get("filter[provider_id__in]")
|
||||
if include_dot_aliases:
|
||||
provider_id_in = provider_id_in or params.get("filter[provider_id.in]")
|
||||
if provider_id_in:
|
||||
values = self._csv_filter_values(provider_id_in)
|
||||
if validate_uuids:
|
||||
self._validate_uuid_filter_values("provider_id__in", values)
|
||||
filters["provider_id__in"] = values
|
||||
|
||||
provider_type = params.get("filter[provider_type]")
|
||||
if provider_type:
|
||||
if provider_type not in valid_provider_types:
|
||||
raise ValidationError(
|
||||
{"provider_type": f"Invalid choice: {provider_type}"}
|
||||
)
|
||||
filters["provider__provider"] = provider_type
|
||||
|
||||
provider_type_in = params.get("filter[provider_type__in]")
|
||||
if include_dot_aliases:
|
||||
provider_type_in = provider_type_in or params.get(
|
||||
"filter[provider_type.in]"
|
||||
)
|
||||
if provider_type_in:
|
||||
values = self._csv_filter_values(provider_type_in)
|
||||
invalid = [value for value in values if value not in valid_provider_types]
|
||||
if invalid:
|
||||
raise ValidationError(
|
||||
{"provider_type__in": f"Invalid choices: {', '.join(invalid)}"}
|
||||
)
|
||||
filters["provider__provider__in"] = values
|
||||
|
||||
provider_groups = params.get("filter[provider_groups]")
|
||||
if provider_groups:
|
||||
if validate_uuids:
|
||||
self._validate_uuid_filter_values("provider_groups", [provider_groups])
|
||||
filters["provider__provider_groups__id"] = provider_groups
|
||||
|
||||
provider_groups_in = params.get("filter[provider_groups__in]")
|
||||
if include_dot_aliases:
|
||||
provider_groups_in = provider_groups_in or params.get(
|
||||
"filter[provider_groups.in]"
|
||||
)
|
||||
if provider_groups_in:
|
||||
values = self._csv_filter_values(provider_groups_in)
|
||||
if validate_uuids:
|
||||
self._validate_uuid_filter_values("provider_groups__in", values)
|
||||
filters["provider__provider_groups__id__in"] = values
|
||||
|
||||
return filters
|
||||
|
||||
|
||||
class TaskManagementMixin:
|
||||
"""
|
||||
Mixin to manage task status checking.
|
||||
|
||||
+236
-150
@@ -228,7 +228,13 @@ from api.utils import (
|
||||
validate_invitation,
|
||||
)
|
||||
from api.uuid_utils import datetime_to_uuid7, uuid7_start
|
||||
from api.v1.mixins import DisablePaginationMixin, PaginateByPkMixin, TaskManagementMixin
|
||||
from api.v1.mixins import (
|
||||
DisablePaginationMixin,
|
||||
JsonApiFilterMixin,
|
||||
PaginateByPkMixin,
|
||||
ProviderFilterParamsMixin,
|
||||
TaskManagementMixin,
|
||||
)
|
||||
from api.v1.serializers import (
|
||||
AttackPathsCartographySchemaSerializer,
|
||||
AttackPathsCustomQueryRunRequestSerializer,
|
||||
@@ -4556,15 +4562,19 @@ class RoleProviderGroupRelationshipView(RelationshipView, BaseRLSViewSet):
|
||||
@extend_schema_view(
|
||||
list=extend_schema(
|
||||
tags=["Compliance Overview"],
|
||||
summary="List compliance overviews for a scan",
|
||||
description="Retrieve an overview of all the compliance in a given scan.",
|
||||
summary="List compliance overviews",
|
||||
description=(
|
||||
"Retrieve compliance overview data for a scan. When provider filters "
|
||||
"are provided, the endpoint uses the latest completed scan for each "
|
||||
"matching provider."
|
||||
),
|
||||
parameters=[
|
||||
OpenApiParameter(
|
||||
name="filter[scan_id]",
|
||||
required=True,
|
||||
required=False,
|
||||
type=OpenApiTypes.UUID,
|
||||
location=OpenApiParameter.QUERY,
|
||||
description="Related scan ID.",
|
||||
description="Related scan ID. Required unless a provider filter is provided.",
|
||||
),
|
||||
],
|
||||
responses={
|
||||
@@ -4579,19 +4589,23 @@ class RoleProviderGroupRelationshipView(RelationshipView, BaseRLSViewSet):
|
||||
description="Compliance overviews generation task failed"
|
||||
),
|
||||
},
|
||||
filters=True,
|
||||
),
|
||||
metadata=extend_schema(
|
||||
tags=["Compliance Overview"],
|
||||
summary="Retrieve metadata values from compliance overviews",
|
||||
description="Fetch unique metadata values from a set of compliance overviews. This is useful for dynamic "
|
||||
"filtering.",
|
||||
description=(
|
||||
"Fetch unique metadata values from compliance overviews. This is useful "
|
||||
"for dynamic filtering. When provider filters are provided, metadata is "
|
||||
"computed from the latest completed scan for each matching provider."
|
||||
),
|
||||
parameters=[
|
||||
OpenApiParameter(
|
||||
name="filter[scan_id]",
|
||||
required=True,
|
||||
required=False,
|
||||
type=OpenApiTypes.UUID,
|
||||
location=OpenApiParameter.QUERY,
|
||||
description="Related scan ID.",
|
||||
description="Related scan ID. Required unless a provider filter is provided.",
|
||||
),
|
||||
],
|
||||
responses={
|
||||
@@ -4606,19 +4620,24 @@ class RoleProviderGroupRelationshipView(RelationshipView, BaseRLSViewSet):
|
||||
description="Compliance overviews generation task failed"
|
||||
),
|
||||
},
|
||||
filters=True,
|
||||
),
|
||||
requirements=extend_schema(
|
||||
tags=["Compliance Overview"],
|
||||
summary="List compliance requirements overview for a scan",
|
||||
description="Retrieve a detailed overview of compliance requirements in a given scan, grouped by compliance "
|
||||
"framework. This endpoint provides requirement-level details and aggregates status across regions.",
|
||||
summary="List compliance requirements overview",
|
||||
description=(
|
||||
"Retrieve a detailed overview of compliance requirements, grouped by "
|
||||
"compliance framework. This endpoint provides requirement-level details "
|
||||
"and aggregates status across regions. When provider filters are provided, "
|
||||
"the endpoint uses the latest completed scan for each matching provider."
|
||||
),
|
||||
parameters=[
|
||||
OpenApiParameter(
|
||||
name="filter[scan_id]",
|
||||
required=True,
|
||||
required=False,
|
||||
type=OpenApiTypes.UUID,
|
||||
location=OpenApiParameter.QUERY,
|
||||
description="Related scan ID.",
|
||||
description="Related scan ID. Required unless a provider filter is provided.",
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="filter[compliance_id]",
|
||||
@@ -4677,7 +4696,10 @@ class RoleProviderGroupRelationshipView(RelationshipView, BaseRLSViewSet):
|
||||
@method_decorator(CACHE_DECORATOR, name="list")
|
||||
@method_decorator(CACHE_DECORATOR, name="requirements")
|
||||
@method_decorator(CACHE_DECORATOR, name="attributes")
|
||||
class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
|
||||
class ComplianceOverviewViewSet(
|
||||
ProviderFilterParamsMixin, BaseRLSViewSet, TaskManagementMixin
|
||||
):
|
||||
jsonapi_filter_replace_dots = True
|
||||
pagination_class = ComplianceOverviewPagination
|
||||
queryset = ComplianceRequirementOverview.objects.all()
|
||||
serializer_class = ComplianceOverviewSerializer
|
||||
@@ -4691,28 +4713,22 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
|
||||
required_permissions = []
|
||||
|
||||
def get_queryset(self):
|
||||
if getattr(self, "swagger_fake_view", False):
|
||||
return ComplianceRequirementOverview.objects.none()
|
||||
|
||||
role = get_role(self.request.user, self.request.tenant_id)
|
||||
unlimited_visibility = getattr(
|
||||
role, Permissions.UNLIMITED_VISIBILITY.value, False
|
||||
)
|
||||
|
||||
if unlimited_visibility:
|
||||
base_queryset = self.filter_queryset(
|
||||
ComplianceRequirementOverview.objects.filter(
|
||||
tenant_id=self.request.tenant_id
|
||||
)
|
||||
)
|
||||
else:
|
||||
providers = Provider.objects.filter(
|
||||
provider_groups__in=role.provider_groups.all()
|
||||
).distinct()
|
||||
base_queryset = self.filter_queryset(
|
||||
ComplianceRequirementOverview.objects.filter(
|
||||
tenant_id=self.request.tenant_id, scan__provider__in=providers
|
||||
)
|
||||
)
|
||||
base_queryset = ComplianceRequirementOverview.objects.filter(
|
||||
tenant_id=self.request.tenant_id
|
||||
)
|
||||
|
||||
return base_queryset
|
||||
if unlimited_visibility:
|
||||
return base_queryset
|
||||
|
||||
return base_queryset.filter(scan__provider__in=get_providers(role))
|
||||
|
||||
def get_serializer_class(self):
|
||||
if hasattr(self, "response_serializer_class"):
|
||||
@@ -4750,6 +4766,72 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
|
||||
|
||||
return summaries
|
||||
|
||||
def _validate_scan_selection(self, scan_id, has_provider_filters):
|
||||
if scan_id and has_provider_filters:
|
||||
raise ValidationError(
|
||||
[
|
||||
{
|
||||
"detail": "Use either filter[scan_id] or provider filters.",
|
||||
"status": 400,
|
||||
"source": {"pointer": "filter[scan_id]"},
|
||||
"code": "invalid",
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
if scan_id:
|
||||
self._validate_uuid_filter_values("scan_id", [scan_id])
|
||||
return
|
||||
|
||||
if has_provider_filters:
|
||||
return
|
||||
|
||||
raise ValidationError(
|
||||
[
|
||||
{
|
||||
"detail": "This query parameter is required unless a provider filter is provided.",
|
||||
"status": 400,
|
||||
"source": {"pointer": "filter[scan_id]"},
|
||||
"code": "required",
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
def _latest_scan_ids_for_provider_filters(self):
|
||||
role = get_role(self.request.user, self.request.tenant_id)
|
||||
scans = Scan.all_objects.filter(
|
||||
tenant_id=self.request.tenant_id,
|
||||
state=StateChoices.COMPLETED,
|
||||
)
|
||||
|
||||
if not getattr(role, Permissions.UNLIMITED_VISIBILITY.value, False):
|
||||
scans = scans.filter(provider__in=get_providers(role))
|
||||
|
||||
provider_filters = self._extract_provider_filters_from_params(
|
||||
validate_uuids=True,
|
||||
include_dot_aliases=True,
|
||||
)
|
||||
if provider_filters:
|
||||
scans = scans.filter(**provider_filters)
|
||||
|
||||
return list(
|
||||
scans.order_by("provider_id", "-inserted_at")
|
||||
.distinct("provider_id")
|
||||
.values_list("id", flat=True)
|
||||
)
|
||||
|
||||
def _filtered_queryset_for_latest_provider_scans(self, latest_scan_ids=None):
|
||||
if latest_scan_ids is None:
|
||||
latest_scan_ids = self._latest_scan_ids_for_provider_filters()
|
||||
queryset = self.get_queryset().filter(scan_id__in=latest_scan_ids)
|
||||
# Provider filters stay on the filterset for OpenAPI docs, but runtime
|
||||
# filtering happens on Scan first so compliance queries use scan IDs.
|
||||
return self._apply_filterset(
|
||||
queryset,
|
||||
self.filterset_class,
|
||||
exclude_keys=self.PROVIDER_FILTER_KEYS | {"scan_id"},
|
||||
)
|
||||
|
||||
def _get_compliance_template(self, *, provider=None, scan_id=None):
|
||||
"""Return the compliance template for the given provider or scan."""
|
||||
if provider is None and scan_id is not None:
|
||||
@@ -4865,6 +4947,36 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
def _task_response_for_latest_provider_scans(self, latest_scan_ids):
|
||||
for scan_id in latest_scan_ids:
|
||||
task_response = self._task_response_if_running(str(scan_id))
|
||||
if task_response:
|
||||
return task_response
|
||||
return None
|
||||
|
||||
def _latest_provider_scan_ids_without_data(self, latest_scan_ids):
|
||||
data_presence_queryset = self.get_queryset().filter(scan_id__in=latest_scan_ids)
|
||||
scan_ids_with_data = {
|
||||
str(scan_id)
|
||||
for scan_id in data_presence_queryset.values_list(
|
||||
"scan_id", flat=True
|
||||
).distinct()
|
||||
}
|
||||
return [
|
||||
scan_id
|
||||
for scan_id in latest_scan_ids
|
||||
if str(scan_id) not in scan_ids_with_data
|
||||
]
|
||||
|
||||
def _task_response_for_latest_provider_scans_without_data(
|
||||
self,
|
||||
latest_scan_ids,
|
||||
):
|
||||
scan_ids_to_check = self._latest_provider_scan_ids_without_data(
|
||||
latest_scan_ids,
|
||||
)
|
||||
return self._task_response_for_latest_provider_scans(scan_ids_to_check)
|
||||
|
||||
def _list_with_region_filter(self, scan_id, region_filter):
|
||||
"""
|
||||
Fall back to detailed ComplianceRequirementOverview query when region filter is applied.
|
||||
@@ -4905,8 +5017,25 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
|
||||
|
||||
return Response(data)
|
||||
|
||||
def _list_with_latest_provider_filters(self):
|
||||
latest_scan_ids = self._latest_scan_ids_for_provider_filters()
|
||||
queryset = self._filtered_queryset_for_latest_provider_scans(latest_scan_ids)
|
||||
data = self._aggregate_compliance_overview(queryset)
|
||||
task_response = self._task_response_for_latest_provider_scans_without_data(
|
||||
latest_scan_ids,
|
||||
)
|
||||
if task_response:
|
||||
return task_response
|
||||
|
||||
return Response(data)
|
||||
|
||||
def list(self, request, *args, **kwargs):
|
||||
scan_id = request.query_params.get("filter[scan_id]")
|
||||
has_provider_filters = self._has_provider_filters(include_dot_aliases=True)
|
||||
self._validate_scan_selection(scan_id, has_provider_filters)
|
||||
|
||||
if has_provider_filters:
|
||||
return self._list_with_latest_provider_filters()
|
||||
|
||||
# Specific scan requested - use optimized summaries with region support
|
||||
region_filter = request.query_params.get(
|
||||
@@ -4952,33 +5081,34 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
|
||||
@action(detail=False, methods=["get"], url_name="metadata")
|
||||
def metadata(self, request):
|
||||
scan_id = request.query_params.get("filter[scan_id]")
|
||||
if not scan_id:
|
||||
raise ValidationError(
|
||||
[
|
||||
{
|
||||
"detail": "This query parameter is required.",
|
||||
"status": 400,
|
||||
"source": {"pointer": "filter[scan_id]"},
|
||||
"code": "required",
|
||||
}
|
||||
]
|
||||
has_provider_filters = self._has_provider_filters(include_dot_aliases=True)
|
||||
self._validate_scan_selection(scan_id, has_provider_filters)
|
||||
|
||||
latest_scan_ids = None
|
||||
if has_provider_filters:
|
||||
latest_scan_ids = self._latest_scan_ids_for_provider_filters()
|
||||
queryset = self._filtered_queryset_for_latest_provider_scans(
|
||||
latest_scan_ids
|
||||
)
|
||||
else:
|
||||
queryset = self._apply_filterset(self.get_queryset(), self.filterset_class)
|
||||
|
||||
regions = list(
|
||||
self.get_queryset()
|
||||
.filter(scan_id=scan_id)
|
||||
.values_list("region", flat=True)
|
||||
.order_by("region")
|
||||
.distinct()
|
||||
queryset.values_list("region", flat=True).order_by("region").distinct()
|
||||
)
|
||||
result = {"regions": regions}
|
||||
|
||||
if regions:
|
||||
serializer = self.get_serializer(data=result)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
return Response(serializer.data, status=status.HTTP_200_OK)
|
||||
task_response = None
|
||||
if has_provider_filters:
|
||||
task_response = self._task_response_for_latest_provider_scans_without_data(
|
||||
latest_scan_ids,
|
||||
)
|
||||
elif not regions:
|
||||
task_response = self._task_response_if_running(scan_id)
|
||||
if task_response:
|
||||
return task_response
|
||||
|
||||
task_response = self._task_response_if_running(scan_id)
|
||||
if task_response:
|
||||
if has_provider_filters and task_response:
|
||||
return task_response
|
||||
|
||||
serializer = self.get_serializer(data=result)
|
||||
@@ -4988,19 +5118,10 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
|
||||
@action(detail=False, methods=["get"], url_name="requirements")
|
||||
def requirements(self, request):
|
||||
scan_id = request.query_params.get("filter[scan_id]")
|
||||
has_provider_filters = self._has_provider_filters(include_dot_aliases=True)
|
||||
compliance_id = request.query_params.get("filter[compliance_id]")
|
||||
|
||||
if not scan_id:
|
||||
raise ValidationError(
|
||||
[
|
||||
{
|
||||
"detail": "This query parameter is required.",
|
||||
"status": 400,
|
||||
"source": {"pointer": "filter[scan_id]"},
|
||||
"code": "required",
|
||||
}
|
||||
]
|
||||
)
|
||||
self._validate_scan_selection(scan_id, has_provider_filters)
|
||||
|
||||
if not compliance_id:
|
||||
raise ValidationError(
|
||||
@@ -5013,7 +5134,16 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
|
||||
}
|
||||
]
|
||||
)
|
||||
filtered_queryset = self.filter_queryset(self.get_queryset())
|
||||
latest_scan_ids = None
|
||||
if has_provider_filters:
|
||||
latest_scan_ids = self._latest_scan_ids_for_provider_filters()
|
||||
filtered_queryset = self._filtered_queryset_for_latest_provider_scans(
|
||||
latest_scan_ids
|
||||
)
|
||||
else:
|
||||
filtered_queryset = self._apply_filterset(
|
||||
self.get_queryset(), self.filterset_class
|
||||
)
|
||||
|
||||
all_requirements = filtered_queryset.values(
|
||||
"requirement_id",
|
||||
@@ -5072,13 +5202,22 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
|
||||
requirements_summary, many=True
|
||||
)
|
||||
|
||||
task_response = None
|
||||
if has_provider_filters:
|
||||
task_response = self._task_response_for_latest_provider_scans_without_data(
|
||||
latest_scan_ids,
|
||||
)
|
||||
elif not requirements_summary:
|
||||
task_response = self._task_response_if_running(scan_id)
|
||||
if task_response:
|
||||
return task_response
|
||||
|
||||
if has_provider_filters and task_response:
|
||||
return task_response
|
||||
|
||||
if requirements_summary:
|
||||
return Response(serializer.data, status=status.HTTP_200_OK)
|
||||
|
||||
task_response = self._task_response_if_running(scan_id)
|
||||
if task_response:
|
||||
return task_response
|
||||
|
||||
return Response(serializer.data, status=status.HTTP_200_OK)
|
||||
|
||||
@action(detail=False, methods=["get"], url_name="attributes")
|
||||
@@ -5317,7 +5456,7 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
|
||||
),
|
||||
)
|
||||
@method_decorator(CACHE_DECORATOR, name="list")
|
||||
class OverviewViewSet(BaseRLSViewSet):
|
||||
class OverviewViewSet(ProviderFilterParamsMixin, BaseRLSViewSet):
|
||||
queryset = ScanSummary.objects.all()
|
||||
http_method_names = ["get"]
|
||||
ordering = ["-inserted_at"]
|
||||
@@ -5434,18 +5573,6 @@ class OverviewViewSet(BaseRLSViewSet):
|
||||
tenant_id=tenant_id, scan_id__in=latest_scan_ids
|
||||
)
|
||||
|
||||
def _normalize_jsonapi_params(self, query_params, exclude_keys=None):
|
||||
"""Convert JSON:API filter params (filter[X]) to flat params (X)."""
|
||||
exclude_keys = exclude_keys or set()
|
||||
normalized = QueryDict(mutable=True)
|
||||
for key, values in query_params.lists():
|
||||
normalized_key = (
|
||||
key[7:-1] if key.startswith("filter[") and key.endswith("]") else key
|
||||
)
|
||||
if normalized_key not in exclude_keys:
|
||||
normalized.setlist(normalized_key, values)
|
||||
return normalized
|
||||
|
||||
def _ensure_allowed_providers(self):
|
||||
"""Populate allowed providers for RBAC-aware queries once per request."""
|
||||
if getattr(self, "_providers_initialized", False):
|
||||
@@ -5465,15 +5592,6 @@ class OverviewViewSet(BaseRLSViewSet):
|
||||
return queryset.filter(**provider_filter)
|
||||
return queryset
|
||||
|
||||
def _apply_filterset(self, queryset, filterset_class, exclude_keys=None):
|
||||
normalized_params = self._normalize_jsonapi_params(
|
||||
self.request.query_params, exclude_keys=set(exclude_keys or [])
|
||||
)
|
||||
filterset = filterset_class(normalized_params, queryset=queryset)
|
||||
if not filterset.is_valid():
|
||||
raise ValidationError(filterset.errors)
|
||||
return filterset.qs
|
||||
|
||||
def _latest_scan_ids_for_allowed_providers(self, tenant_id, provider_filters=None):
|
||||
provider_filter = self._get_provider_filter()
|
||||
queryset = Scan.all_objects.filter(
|
||||
@@ -5487,40 +5605,6 @@ class OverviewViewSet(BaseRLSViewSet):
|
||||
.values_list("id", flat=True)
|
||||
)
|
||||
|
||||
def _extract_provider_filters_from_params(self):
|
||||
"""Extract and validate provider filters from query params."""
|
||||
params = self.request.query_params
|
||||
filters = {}
|
||||
valid_provider_types = {c[0] for c in Provider.ProviderChoices.choices}
|
||||
|
||||
provider_id = params.get("filter[provider_id]")
|
||||
if provider_id:
|
||||
filters["provider_id"] = provider_id
|
||||
|
||||
provider_id_in = params.get("filter[provider_id__in]")
|
||||
if provider_id_in:
|
||||
filters["provider_id__in"] = provider_id_in.split(",")
|
||||
|
||||
provider_type = params.get("filter[provider_type]")
|
||||
if provider_type:
|
||||
if provider_type not in valid_provider_types:
|
||||
raise ValidationError(
|
||||
{"provider_type": f"Invalid choice: {provider_type}"}
|
||||
)
|
||||
filters["provider__provider"] = provider_type
|
||||
|
||||
provider_type_in = params.get("filter[provider_type__in]")
|
||||
if provider_type_in:
|
||||
types = provider_type_in.split(",")
|
||||
invalid = [t for t in types if t not in valid_provider_types]
|
||||
if invalid:
|
||||
raise ValidationError(
|
||||
{"provider_type__in": f"Invalid choices: {', '.join(invalid)}"}
|
||||
)
|
||||
filters["provider__provider__in"] = types
|
||||
|
||||
return filters
|
||||
|
||||
@action(detail=False, methods=["get"], url_name="providers")
|
||||
def providers(self, request):
|
||||
tenant_id = self.request.tenant_id
|
||||
@@ -5591,15 +5675,11 @@ class OverviewViewSet(BaseRLSViewSet):
|
||||
tenant_id = self.request.tenant_id
|
||||
providers_qs = Provider.objects.filter(tenant_id=tenant_id)
|
||||
|
||||
self._ensure_allowed_providers()
|
||||
if hasattr(self, "allowed_providers"):
|
||||
allowed_ids = list(self.allowed_providers.values_list("id", flat=True))
|
||||
if not allowed_ids:
|
||||
overview = []
|
||||
return Response(
|
||||
self.get_serializer(overview, many=True).data,
|
||||
status=status.HTTP_200_OK,
|
||||
)
|
||||
providers_qs = providers_qs.filter(id__in=allowed_ids)
|
||||
providers_qs = providers_qs.filter(
|
||||
id__in=self.allowed_providers.values("id")
|
||||
)
|
||||
|
||||
overview = (
|
||||
providers_qs.values("provider")
|
||||
@@ -5815,29 +5895,41 @@ class OverviewViewSet(BaseRLSViewSet):
|
||||
description="Retrieve a specific snapshot by ID. If not provided, returns latest snapshots.",
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="provider_id",
|
||||
name="filter[provider_id]",
|
||||
type=OpenApiTypes.UUID,
|
||||
location=OpenApiParameter.QUERY,
|
||||
description="Filter by specific provider ID",
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="provider_id__in",
|
||||
name="filter[provider_id__in]",
|
||||
type=OpenApiTypes.STR,
|
||||
location=OpenApiParameter.QUERY,
|
||||
description="Filter by multiple provider IDs (comma-separated UUIDs)",
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="provider_type",
|
||||
name="filter[provider_type]",
|
||||
type=OpenApiTypes.STR,
|
||||
location=OpenApiParameter.QUERY,
|
||||
description="Filter by provider type (aws, azure, gcp, etc.)",
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="provider_type__in",
|
||||
name="filter[provider_type__in]",
|
||||
type=OpenApiTypes.STR,
|
||||
location=OpenApiParameter.QUERY,
|
||||
description="Filter by multiple provider types (comma-separated)",
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="filter[provider_groups]",
|
||||
type=OpenApiTypes.UUID,
|
||||
location=OpenApiParameter.QUERY,
|
||||
description="Filter by provider group ID",
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="filter[provider_groups__in]",
|
||||
type=OpenApiTypes.STR,
|
||||
location=OpenApiParameter.QUERY,
|
||||
description="Filter by multiple provider group IDs (comma-separated UUIDs)",
|
||||
),
|
||||
],
|
||||
)
|
||||
@action(detail=False, methods=["get"], url_name="threatscore")
|
||||
@@ -6179,6 +6271,8 @@ class OverviewViewSet(BaseRLSViewSet):
|
||||
"provider_id__in",
|
||||
"provider_type",
|
||||
"provider_type__in",
|
||||
"provider_groups",
|
||||
"provider_groups__in",
|
||||
}
|
||||
filtered_queryset = self._apply_filterset(
|
||||
base_queryset, CategoryOverviewFilter, exclude_keys=provider_filter_keys
|
||||
@@ -6248,6 +6342,8 @@ class OverviewViewSet(BaseRLSViewSet):
|
||||
"provider_id__in",
|
||||
"provider_type",
|
||||
"provider_type__in",
|
||||
"provider_groups",
|
||||
"provider_groups__in",
|
||||
}
|
||||
filtered_queryset = self._apply_filterset(
|
||||
base_queryset,
|
||||
@@ -7298,7 +7394,7 @@ SEVERITY_ORDER_REVERSE = {v: k for k, v in SEVERITY_ORDER.items()}
|
||||
),
|
||||
retrieve=extend_schema(exclude=True),
|
||||
)
|
||||
class FindingGroupViewSet(BaseRLSViewSet):
|
||||
class FindingGroupViewSet(JsonApiFilterMixin, BaseRLSViewSet):
|
||||
"""
|
||||
ViewSet for Finding Groups - aggregates findings by check_id.
|
||||
|
||||
@@ -7314,6 +7410,7 @@ class FindingGroupViewSet(BaseRLSViewSet):
|
||||
queryset = FindingGroupDailySummary.objects.all()
|
||||
serializer_class = FindingGroupSerializer
|
||||
filterset_class = FindingGroupFilter
|
||||
jsonapi_filter_replace_dots = True
|
||||
filter_backends = [
|
||||
jsonapi_filters.QueryParameterValidationFilter,
|
||||
jsonapi_filters.OrderingFilter,
|
||||
@@ -7364,18 +7461,6 @@ class FindingGroupViewSet(BaseRLSViewSet):
|
||||
|
||||
return queryset
|
||||
|
||||
def _normalize_jsonapi_params(self, query_params):
|
||||
"""Convert JSON:API filter params (filter[X]) to flat params (X)."""
|
||||
normalized = QueryDict(mutable=True)
|
||||
for key, values in query_params.lists():
|
||||
normalized_key = (
|
||||
key[7:-1] if key.startswith("filter[") and key.endswith("]") else key
|
||||
)
|
||||
# Convert JSON:API dot notation to Django double underscore
|
||||
normalized_key = normalized_key.replace(".", "__")
|
||||
normalized.setlist(normalized_key, values)
|
||||
return normalized
|
||||
|
||||
@extend_schema(exclude=True)
|
||||
def retrieve(self, request, *args, **kwargs):
|
||||
raise MethodNotAllowed(method="GET")
|
||||
@@ -8494,9 +8579,10 @@ class FindingGroupViewSet(BaseRLSViewSet):
|
||||
|
||||
This endpoint returns finding groups without requiring date filters,
|
||||
automatically using the latest available data per check_id.
|
||||
All other filters (provider_id, provider_type, check_id) are still supported.
|
||||
Provider, provider group, check, and computed filters are still supported.
|
||||
""",
|
||||
tags=["Finding Groups"],
|
||||
filters=True,
|
||||
)
|
||||
@action(detail=False, methods=["get"], url_name="latest")
|
||||
def latest(self, request):
|
||||
|
||||
Reference in New Issue
Block a user