feat(api): add check title search for finding groups (#10377)

This commit is contained in:
Adrián Peña
2026-03-19 16:48:26 +01:00
committed by GitHub
parent cece2cb87e
commit 2fe92cfce3
6 changed files with 65 additions and 22 deletions

View File

@@ -7,6 +7,7 @@ All notable changes to the **Prowler API** are documented in this file.
### 🚀 Added ### 🚀 Added
- `CORS_ALLOWED_ORIGINS` configurable via environment variable [(#10355)](https://github.com/prowler-cloud/prowler/pull/10355) - `CORS_ALLOWED_ORIGINS` configurable via environment variable [(#10355)](https://github.com/prowler-cloud/prowler/pull/10355)
- Finding groups support `check_title` substring filtering [(#10377)](https://github.com/prowler-cloud/prowler/pull/10377)
- Attack Paths: Tenant and provider related labels to the nodes so they can be easily filtered on custom queries [(#10308)](https://github.com/prowler-cloud/prowler/pull/10308) - Attack Paths: Tenant and provider related labels to the nodes so they can be easily filtered on custom queries [(#10308)](https://github.com/prowler-cloud/prowler/pull/10308)
### 🔄 Changed ### 🔄 Changed

View File

@@ -926,6 +926,9 @@ class FindingGroupSummaryFilter(FilterSet):
check_id = CharFilter(field_name="check_id", lookup_expr="exact") check_id = CharFilter(field_name="check_id", lookup_expr="exact")
check_id__in = CharInFilter(field_name="check_id", lookup_expr="in") check_id__in = CharInFilter(field_name="check_id", lookup_expr="in")
check_id__icontains = CharFilter(field_name="check_id", lookup_expr="icontains") check_id__icontains = CharFilter(field_name="check_id", lookup_expr="icontains")
check_title__icontains = CharFilter(
field_name="check_title", lookup_expr="icontains"
)
# Provider filters # Provider filters
provider_id = UUIDFilter(field_name="provider_id", lookup_expr="exact") provider_id = UUIDFilter(field_name="provider_id", lookup_expr="exact")
@@ -1025,6 +1028,9 @@ class LatestFindingGroupSummaryFilter(FilterSet):
check_id = CharFilter(field_name="check_id", lookup_expr="exact") check_id = CharFilter(field_name="check_id", lookup_expr="exact")
check_id__in = CharInFilter(field_name="check_id", lookup_expr="in") check_id__in = CharInFilter(field_name="check_id", lookup_expr="in")
check_id__icontains = CharFilter(field_name="check_id", lookup_expr="icontains") check_id__icontains = CharFilter(field_name="check_id", lookup_expr="icontains")
check_title__icontains = CharFilter(
field_name="check_title", lookup_expr="icontains"
)
# Provider filters # Provider filters
provider_id = UUIDFilter(field_name="provider_id", lookup_expr="exact") provider_id = UUIDFilter(field_name="provider_id", lookup_expr="exact")

View File

@@ -301,7 +301,7 @@ class TestTokenSwitchTenant:
assert invalid_tenant_response.status_code == 400 assert invalid_tenant_response.status_code == 400
assert invalid_tenant_response.json()["errors"][0]["code"] == "invalid" assert invalid_tenant_response.json()["errors"][0]["code"] == "invalid"
assert invalid_tenant_response.json()["errors"][0]["detail"] == ( assert invalid_tenant_response.json()["errors"][0]["detail"] == (
"Tenant does not exist or user is not a " "member." "Tenant does not exist or user is not a member."
) )
@@ -912,10 +912,9 @@ class TestAPIKeyLifecycle:
auth_response = client.get(reverse("provider-list"), headers=api_key_headers) auth_response = client.get(reverse("provider-list"), headers=api_key_headers)
# Must return 401 Unauthorized, not 500 Internal Server Error # Must return 401 Unauthorized, not 500 Internal Server Error
assert auth_response.status_code == 401, ( assert (
f"Expected 401 but got {auth_response.status_code}: " auth_response.status_code == 401
f"{auth_response.json()}" ), f"Expected 401 but got {auth_response.status_code}: {auth_response.json()}"
)
# Verify error message is present # Verify error message is present
response_json = auth_response.json() response_json = auth_response.json()

View File

@@ -10,11 +10,11 @@ from django.conf import settings
import api import api
import api.apps as api_apps_module import api.apps as api_apps_module
from api.apps import ( from api.apps import (
ApiConfig,
PRIVATE_KEY_FILE, PRIVATE_KEY_FILE,
PUBLIC_KEY_FILE, PUBLIC_KEY_FILE,
SIGNING_KEY_ENV, SIGNING_KEY_ENV,
VERIFYING_KEY_ENV, VERIFYING_KEY_ENV,
ApiConfig,
) )
@@ -187,9 +187,10 @@ def test_ready_initializes_driver_for_api_process(monkeypatch):
_set_argv(monkeypatch, ["gunicorn"]) _set_argv(monkeypatch, ["gunicorn"])
_set_testing(monkeypatch, False) _set_testing(monkeypatch, False)
with patch.object(ApiConfig, "_ensure_crypto_keys", return_value=None), patch( with (
"api.attack_paths.database.init_driver" patch.object(ApiConfig, "_ensure_crypto_keys", return_value=None),
) as init_driver: patch("api.attack_paths.database.init_driver") as init_driver,
):
config.ready() config.ready()
init_driver.assert_called_once() init_driver.assert_called_once()
@@ -200,9 +201,10 @@ def test_ready_skips_driver_for_celery(monkeypatch):
_set_argv(monkeypatch, ["celery", "-A", "api"]) _set_argv(monkeypatch, ["celery", "-A", "api"])
_set_testing(monkeypatch, False) _set_testing(monkeypatch, False)
with patch.object(ApiConfig, "_ensure_crypto_keys", return_value=None), patch( with (
"api.attack_paths.database.init_driver" patch.object(ApiConfig, "_ensure_crypto_keys", return_value=None),
) as init_driver: patch("api.attack_paths.database.init_driver") as init_driver,
):
config.ready() config.ready()
init_driver.assert_not_called() init_driver.assert_not_called()
@@ -213,9 +215,10 @@ def test_ready_skips_driver_for_manage_py_skip_command(monkeypatch):
_set_argv(monkeypatch, ["manage.py", "migrate"]) _set_argv(monkeypatch, ["manage.py", "migrate"])
_set_testing(monkeypatch, False) _set_testing(monkeypatch, False)
with patch.object(ApiConfig, "_ensure_crypto_keys", return_value=None), patch( with (
"api.attack_paths.database.init_driver" patch.object(ApiConfig, "_ensure_crypto_keys", return_value=None),
) as init_driver: patch("api.attack_paths.database.init_driver") as init_driver,
):
config.ready() config.ready()
init_driver.assert_not_called() init_driver.assert_not_called()
@@ -226,9 +229,10 @@ def test_ready_skips_driver_when_testing(monkeypatch):
_set_argv(monkeypatch, ["gunicorn"]) _set_argv(monkeypatch, ["gunicorn"])
_set_testing(monkeypatch, True) _set_testing(monkeypatch, True)
with patch.object(ApiConfig, "_ensure_crypto_keys", return_value=None), patch( with (
"api.attack_paths.database.init_driver" patch.object(ApiConfig, "_ensure_crypto_keys", return_value=None),
) as init_driver: patch("api.attack_paths.database.init_driver") as init_driver,
):
config.ready() config.ready()
init_driver.assert_not_called() init_driver.assert_not_called()

View File

@@ -15526,6 +15526,22 @@ class TestFindingGroupViewSet:
assert len(response.json()["data"]) == 1 assert len(response.json()["data"]) == 1
assert "bucket" in response.json()["data"][0]["id"].lower() assert "bucket" in response.json()["data"][0]["id"].lower()
def test_finding_groups_check_title_icontains(
self, authenticated_client, finding_groups_fixture
):
"""Test searching check titles with icontains."""
response = authenticated_client.get(
reverse("finding-group-list"),
{
"filter[inserted_at]": TODAY,
"filter[check_title.icontains]": "public access",
},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()["data"]
assert len(data) == 1
assert data[0]["id"] == "s3_bucket_public_access"
def test_resources_not_found(self, authenticated_client): def test_resources_not_found(self, authenticated_client):
"""Test 404 returned for nonexistent check_id.""" """Test 404 returned for nonexistent check_id."""
response = authenticated_client.get( response = authenticated_client.get(

View File

@@ -4,7 +4,6 @@ import json
import logging import logging
import os import os
import time import time
from collections import defaultdict from collections import defaultdict
from copy import deepcopy from copy import deepcopy
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
@@ -12,7 +11,6 @@ from decimal import ROUND_HALF_UP, Decimal, InvalidOperation
from urllib.parse import urljoin from urllib.parse import urljoin
import sentry_sdk import sentry_sdk
from allauth.socialaccount.models import SocialAccount, SocialApp from allauth.socialaccount.models import SocialAccount, SocialApp
from allauth.socialaccount.providers.github.views import GitHubOAuth2Adapter from allauth.socialaccount.providers.github.views import GitHubOAuth2Adapter
from allauth.socialaccount.providers.google.views import GoogleOAuth2Adapter from allauth.socialaccount.providers.google.views import GoogleOAuth2Adapter
@@ -76,6 +74,7 @@ from rest_framework.exceptions import (
) )
from rest_framework.generics import GenericAPIView, get_object_or_404 from rest_framework.generics import GenericAPIView, get_object_or_404
from rest_framework.permissions import SAFE_METHODS from rest_framework.permissions import SAFE_METHODS
from rest_framework_json_api import filters as jsonapi_filters
from rest_framework_json_api.views import RelationshipView, Response from rest_framework_json_api.views import RelationshipView, Response
from rest_framework_simplejwt.exceptions import InvalidToken, TokenError from rest_framework_simplejwt.exceptions import InvalidToken, TokenError
from tasks.beat import schedule_provider_scan from tasks.beat import schedule_provider_scan
@@ -100,7 +99,6 @@ from api.attack_paths import database as graph_database
from api.attack_paths import get_queries_for_provider, get_query_by_id from api.attack_paths import get_queries_for_provider, get_query_by_id
from api.attack_paths import views_helpers as attack_paths_views_helpers from api.attack_paths import views_helpers as attack_paths_views_helpers
from api.base_views import BaseRLSViewSet, BaseTenantViewset, BaseUserViewset from api.base_views import BaseRLSViewSet, BaseTenantViewset, BaseUserViewset
from api.renderers import APIJSONRenderer, PlainTextRenderer
from api.compliance import ( from api.compliance import (
PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE, PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE,
get_compliance_frameworks, get_compliance_frameworks,
@@ -199,6 +197,7 @@ from api.models import (
) )
from api.pagination import ComplianceOverviewPagination from api.pagination import ComplianceOverviewPagination
from api.rbac.permissions import Permissions, get_providers, get_role from api.rbac.permissions import Permissions, get_providers, get_role
from api.renderers import APIJSONRenderer, PlainTextRenderer
from api.rls import Tenant from api.rls import Tenant
from api.utils import ( from api.utils import (
CustomOAuth2Client, CustomOAuth2Client,
@@ -6777,13 +6776,29 @@ class FindingGroupViewSet(BaseRLSViewSet):
queryset = FindingGroupDailySummary.objects.all() queryset = FindingGroupDailySummary.objects.all()
serializer_class = FindingGroupSerializer serializer_class = FindingGroupSerializer
filterset_class = FindingGroupSummaryFilter filterset_class = FindingGroupSummaryFilter
filter_backends = [
jsonapi_filters.QueryParameterValidationFilter,
jsonapi_filters.OrderingFilter,
CustomDjangoFilterBackend,
]
http_method_names = ["get"] http_method_names = ["get"]
required_permissions = [] required_permissions = []
def get_filterset_class(self): def get_filterset_class(self):
"""Return appropriate filter based on action.""" """Return the filterset class used for schema generation and the list action.
Note: The resources and latest_resources actions do not use this method
at runtime. They manually instantiate FindingGroupFilter /
LatestFindingGroupFilter against a Finding queryset (see
_get_finding_queryset). The class returned here for those actions only
affects the OpenAPI schema generated by drf-spectacular.
"""
if self.action == "latest": if self.action == "latest":
return LatestFindingGroupSummaryFilter return LatestFindingGroupSummaryFilter
if self.action == "resources":
return FindingGroupFilter
if self.action == "latest_resources":
return LatestFindingGroupFilter
return FindingGroupSummaryFilter return FindingGroupSummaryFilter
def get_queryset(self): def get_queryset(self):
@@ -7237,6 +7252,7 @@ class FindingGroupViewSet(BaseRLSViewSet):
and timing information including how long they have been failing. and timing information including how long they have been failing.
""", """,
tags=["Finding Groups"], tags=["Finding Groups"],
filters=True,
) )
@action(detail=True, methods=["get"], url_path="resources") @action(detail=True, methods=["get"], url_path="resources")
def resources(self, request, pk=None): def resources(self, request, pk=None):
@@ -7311,6 +7327,7 @@ class FindingGroupViewSet(BaseRLSViewSet):
and timing information. No date filters required. and timing information. No date filters required.
""", """,
tags=["Finding Groups"], tags=["Finding Groups"],
filters=True,
) )
@action( @action(
detail=False, detail=False,