mirror of
https://github.com/prowler-cloud/prowler.git
synced 2026-06-17 13:03:14 +00:00
Compare commits
15 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 88f6848913 | |||
| 2de298fb7b | |||
| 11f0845a91 | |||
| 42d99a17a6 | |||
| 832f10b7f6 | |||
| d133ad18a4 | |||
| 3539940a26 | |||
| 1192d94648 | |||
| a578f4af34 | |||
| d6528b674e | |||
| 75decbbedf | |||
| 4a14559a5f | |||
| c6f8620a0d | |||
| ca4889b43e | |||
| 057d061c7e |
@@ -145,7 +145,7 @@ SENTRY_RELEASE=local
|
||||
NEXT_PUBLIC_SENTRY_ENVIRONMENT=${SENTRY_ENVIRONMENT}
|
||||
|
||||
#### Prowler release version ####
|
||||
NEXT_PUBLIC_PROWLER_RELEASE_VERSION=v5.30.0
|
||||
NEXT_PUBLIC_PROWLER_RELEASE_VERSION=v5.30.2
|
||||
|
||||
# Social login credentials
|
||||
SOCIAL_GOOGLE_OAUTH_CALLBACK_URL="${AUTH_URL}/api/auth/callback/google"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
name: 'OSV-Scanner'
|
||||
description: 'Install osv-scanner and scan a lockfile, failing on HIGH/CRITICAL/UNKNOWN severity findings. Posts/updates a PR comment with findings on pull_request events (requires pull-requests: write).'
|
||||
description: 'Install osv-scanner and scan a lockfile, failing on CRITICAL severity findings. Posts/updates a PR comment with findings on pull_request events (requires pull-requests: write).'
|
||||
author: 'Prowler'
|
||||
|
||||
inputs:
|
||||
@@ -7,9 +7,9 @@ inputs:
|
||||
description: 'Path to the lockfile to scan, relative to the repository root (e.g. uv.lock, api/uv.lock, ui/pnpm-lock.yaml).'
|
||||
required: true
|
||||
severity-levels:
|
||||
description: 'Comma-separated severity levels that fail the scan. Default: HIGH,CRITICAL,UNKNOWN.'
|
||||
description: 'Comma-separated severity levels that fail the scan. Default: CRITICAL.'
|
||||
required: false
|
||||
default: 'HIGH,CRITICAL,UNKNOWN'
|
||||
default: 'CRITICAL'
|
||||
version:
|
||||
description: 'osv-scanner release tag to install. When overriding, you MUST also override binary-sha256.'
|
||||
required: false
|
||||
|
||||
@@ -12,9 +12,6 @@ on:
|
||||
branches:
|
||||
- 'master'
|
||||
- 'v5.*'
|
||||
paths:
|
||||
- 'api/**'
|
||||
- '.github/workflows/api-container-checks.yml'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
|
||||
@@ -16,13 +16,6 @@ on:
|
||||
branches:
|
||||
- "master"
|
||||
- "v5.*"
|
||||
paths:
|
||||
- 'api/**'
|
||||
- '.github/workflows/api-tests.yml'
|
||||
- '.github/workflows/api-security.yml'
|
||||
- '.github/actions/setup-python-uv/**'
|
||||
- '.github/actions/osv-scanner/**'
|
||||
- '.github/scripts/osv-scan.sh'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
|
||||
@@ -12,9 +12,6 @@ on:
|
||||
branches:
|
||||
- 'master'
|
||||
- 'v5.*'
|
||||
paths:
|
||||
- 'mcp_server/**'
|
||||
- '.github/workflows/mcp-container-checks.yml'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
|
||||
@@ -15,12 +15,6 @@ on:
|
||||
branches:
|
||||
- 'master'
|
||||
- 'v5.*'
|
||||
paths:
|
||||
- 'mcp_server/pyproject.toml'
|
||||
- 'mcp_server/uv.lock'
|
||||
- '.github/workflows/mcp-security.yml'
|
||||
- '.github/actions/osv-scanner/**'
|
||||
- '.github/scripts/osv-scan.sh'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
|
||||
@@ -15,12 +15,6 @@ on:
|
||||
branches:
|
||||
- 'master'
|
||||
- 'v5.*'
|
||||
paths:
|
||||
- 'prowler/**'
|
||||
- 'Dockerfile*'
|
||||
- 'pyproject.toml'
|
||||
- 'uv.lock'
|
||||
- '.github/workflows/sdk-container-checks.yml'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
|
||||
@@ -19,16 +19,6 @@ on:
|
||||
branches:
|
||||
- 'master'
|
||||
- 'v5.*'
|
||||
paths:
|
||||
- 'prowler/**'
|
||||
- 'tests/**'
|
||||
- 'pyproject.toml'
|
||||
- 'uv.lock'
|
||||
- '.github/workflows/sdk-tests.yml'
|
||||
- '.github/workflows/sdk-security.yml'
|
||||
- '.github/actions/setup-python-uv/**'
|
||||
- '.github/actions/osv-scanner/**'
|
||||
- '.github/scripts/osv-scan.sh'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
|
||||
@@ -12,9 +12,6 @@ on:
|
||||
branches:
|
||||
- 'master'
|
||||
- 'v5.*'
|
||||
paths:
|
||||
- 'ui/**'
|
||||
- '.github/workflows/ui-container-checks.yml'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
|
||||
@@ -15,12 +15,6 @@ on:
|
||||
branches:
|
||||
- 'master'
|
||||
- 'v5.*'
|
||||
paths:
|
||||
- 'ui/package.json'
|
||||
- 'ui/pnpm-lock.yaml'
|
||||
- '.github/workflows/ui-security.yml'
|
||||
- '.github/actions/osv-scanner/**'
|
||||
- '.github/scripts/osv-scan.sh'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
|
||||
+14
-4
@@ -2,11 +2,21 @@
|
||||
|
||||
All notable changes to the **Prowler API** are documented in this file.
|
||||
|
||||
## [1.32.0] (Prowler UNRELEASED)
|
||||
## [1.31.2] (Prowler v5.30.2)
|
||||
|
||||
### 🚀 Added
|
||||
### 🔄 Changed
|
||||
|
||||
- Server-Sent Events (SSE) infrastructure for the API: a base viewset, a tenant-aware channel manager, and channel-name helpers backed by `django-eventstream` over Valkey Pub/Sub and served through the Gunicorn ASGI worker, so feature endpoints can stream events to clients over a single long-lived connection [(#11556)](https://github.com/prowler-cloud/prowler/pull/11556)
|
||||
- `scan-compliance-overviews` task now streams the findings aggregation and the requirement-row writes so it runs faster and its peak memory no longer grows with the number of regions and frameworks [(#11591)](https://github.com/prowler-cloud/prowler/pull/11591)
|
||||
|
||||
---
|
||||
|
||||
## [1.31.1] (Prowler v5.30.1)
|
||||
|
||||
### 🐞 Fixed
|
||||
|
||||
- `compliance-overviews/attributes` now resolves the provider from the scan, so multi-provider universal frameworks (e.g. CSA CCM) return the check IDs of the scan's provider and Azure/GCP requirement details show their findings instead of appearing empty [(#11546)](https://github.com/prowler-cloud/prowler/pull/11546)
|
||||
- Attack Paths: `drop_subgraph` now deletes relationships first and then nodes in batches, using less memory on Neo4j when clearing a dense provider graph [(#11557)](https://github.com/prowler-cloud/prowler/pull/11557)
|
||||
- OCI scans now use API key credentials with the configured region instead of falling back to `/home/prowler/.oci/config` [(#11558)](https://github.com/prowler-cloud/prowler/pull/11558)
|
||||
|
||||
---
|
||||
|
||||
@@ -27,7 +37,7 @@ All notable changes to the **Prowler API** are documented in this file.
|
||||
|
||||
- Workers now shut down gracefully on deploy or restart, finishing or re-queueing in-flight tasks instead of being force-killed and leaving them stuck [(#11416)](https://github.com/prowler-cloud/prowler/pull/11416)
|
||||
- Resource `name` is now stored and refreshed on every scan, so resources no longer keep an empty name [(#11476)](https://github.com/prowler-cloud/prowler/pull/11476)
|
||||
- Compliance catalog now warms in background during startup. `compliance-overviews/attributes` returns `503` while warming, so the first request after a deploy no longer trips the API timeout [(#4554)](https://github.com/prowler-cloud/prowler-cloud/pull/4554)
|
||||
- Compliance catalog now warms in background during startup. `compliance-overviews/attributes` returns `503` while warming, so the first request after a deploy no longer trips the API timeout [(#11530)](https://github.com/prowler-cloud/prowler/pull/11530)
|
||||
|
||||
### 🔐 Security
|
||||
|
||||
|
||||
@@ -21,19 +21,13 @@ apply_fixtures() {
|
||||
}
|
||||
|
||||
start_dev_server() {
|
||||
echo "Starting the development server (Gunicorn ASGI, debug + reload)..."
|
||||
# Same server/worker as prod (config.asgi via the native `asgi` worker), so
|
||||
# SSE streams run on the event loop exactly as they do in production. DEBUG is
|
||||
# on so guniconf's `reload = DEBUG` hot-reloads edited code (and flips
|
||||
# `preload_app` off so reload actually takes).
|
||||
export DJANGO_DEBUG="${DJANGO_DEBUG:-True}"
|
||||
export DJANGO_BIND_ADDRESS="${DJANGO_BIND_ADDRESS:-0.0.0.0}"
|
||||
exec uv run gunicorn -c config/guniconf.py config.asgi:application
|
||||
echo "Starting the development server..."
|
||||
exec uv run python manage.py runserver 0.0.0.0:"${DJANGO_PORT:-8080}"
|
||||
}
|
||||
|
||||
start_prod_server() {
|
||||
echo "Starting the Gunicorn server..."
|
||||
exec uv run gunicorn -c config/guniconf.py config.asgi:application
|
||||
exec uv run gunicorn -c config/guniconf.py config.wsgi:application
|
||||
}
|
||||
|
||||
resolve_worker_hostname() {
|
||||
|
||||
+4
-6
@@ -41,10 +41,9 @@ dependencies = [
|
||||
"drf-spectacular==0.27.2",
|
||||
"drf-spectacular-jsonapi==0.5.1",
|
||||
"defusedxml==0.7.1",
|
||||
"django-eventstream==5.3.3",
|
||||
"gunicorn==26.0.0",
|
||||
"gunicorn==23.0.0",
|
||||
"lxml==6.1.0",
|
||||
"prowler @ git+https://github.com/prowler-cloud/prowler.git@master",
|
||||
"prowler @ git+https://github.com/prowler-cloud/prowler.git@v5.30",
|
||||
"psycopg2-binary==2.9.9",
|
||||
"pytest-celery[redis] (==1.3.0)",
|
||||
"sentry-sdk[django] (==2.56.0)",
|
||||
@@ -69,7 +68,7 @@ name = "prowler-api"
|
||||
package-mode = false
|
||||
# Needed for the SDK compatibility
|
||||
requires-python = ">=3.11,<3.13"
|
||||
version = "1.31.0"
|
||||
version = "1.31.2"
|
||||
|
||||
[tool.uv]
|
||||
# Transitive pins matching master to avoid silent drift; bump deliberately.
|
||||
@@ -210,7 +209,6 @@ constraint-dependencies = [
|
||||
"django-celery-results==2.6.0",
|
||||
"django-cors-headers==4.4.0",
|
||||
"django-environ==0.11.2",
|
||||
"django-eventstream==5.3.3",
|
||||
"django-filter==24.3",
|
||||
"django-guid==3.5.0",
|
||||
"django-postgres-extra==2.0.9",
|
||||
@@ -255,7 +253,7 @@ constraint-dependencies = [
|
||||
"grpc-google-iam-v1==0.14.3",
|
||||
"grpcio==1.76.0",
|
||||
"grpcio-status==1.76.0",
|
||||
"gunicorn==26.0.0",
|
||||
"gunicorn==23.0.0",
|
||||
"h11==0.16.0",
|
||||
"h2==4.3.0",
|
||||
"hpack==4.1.0",
|
||||
|
||||
@@ -175,7 +175,8 @@ def drop_subgraph(database: str, provider_id: str) -> int:
|
||||
"""
|
||||
Delete all nodes for a provider from the tenant database.
|
||||
|
||||
Uses batched deletion to avoid memory issues with large graphs.
|
||||
Deletes relationships then nodes in batches (not `DETACH DELETE`) so a dense
|
||||
provider's graph cannot exceed Neo4j's transaction memory limit.
|
||||
Silently returns 0 if the database doesn't exist.
|
||||
"""
|
||||
provider_label = get_provider_label(provider_id)
|
||||
@@ -183,13 +184,28 @@ def drop_subgraph(database: str, provider_id: str) -> int:
|
||||
|
||||
try:
|
||||
with get_session(database) as session:
|
||||
# Phase 1: delete relationships incident to provider nodes in batches.
|
||||
deleted_count = 1
|
||||
while deleted_count > 0:
|
||||
result = session.run(
|
||||
f"""
|
||||
MATCH (:`{provider_label}`)-[r]-()
|
||||
WITH DISTINCT r LIMIT $batch_size
|
||||
DELETE r
|
||||
RETURN COUNT(r) AS deleted_rels_count
|
||||
""",
|
||||
{"batch_size": BATCH_SIZE},
|
||||
)
|
||||
deleted_count = result.single().get("deleted_rels_count", 0)
|
||||
|
||||
# Phase 2: delete the now relationship-free nodes in batches.
|
||||
deleted_count = 1
|
||||
while deleted_count > 0:
|
||||
result = session.run(
|
||||
f"""
|
||||
MATCH (n:{PROVIDER_RESOURCE_LABEL}:`{provider_label}`)
|
||||
WITH n LIMIT $batch_size
|
||||
DETACH DELETE n
|
||||
DELETE n
|
||||
RETURN COUNT(n) AS deleted_nodes_count
|
||||
""",
|
||||
{"batch_size": BATCH_SIZE},
|
||||
|
||||
@@ -93,31 +93,3 @@ class CombinedJWTOrAPIKeyAuthentication(BaseAuthentication):
|
||||
|
||||
# Default fallback
|
||||
return self.jwt_auth.authenticate(request)
|
||||
|
||||
|
||||
class SSEAuthentication(CombinedJWTOrAPIKeyAuthentication):
    """Combined JWT/API-Key authentication with an `?access_token=<jwt>` fallback.

    The browser `EventSource` constructor takes only a URL and
    `withCredentials`, so it cannot send an `Authorization` header. SSE
    endpoints therefore also accept the JWT through the `access_token`
    query parameter (the bearer-token parameter name defined in RFC 6750
    Section 2.3), which keeps browser SSE clients on the same auth stack as
    the rest of the API.

    The query parameter is consulted only when the `Authorization` header
    is missing or empty.
    """

    def authenticate(self, request: Request):
        """Authenticate via the header when present, else via `?access_token`.

        Returns the parent class's result for header-based or
        credential-less requests, and a `(user, validated_token)` tuple when
        the query-string JWT is used. Token validation errors propagate.
        """
        query_token = None
        if not request.headers.get("Authorization", ""):
            query_token = request.query_params.get("access_token")

        if not query_token:
            # Either a header was supplied or there are no credentials at
            # all; in both cases the parent class decides the outcome.
            return super().authenticate(request)

        authenticator = JWTAuthentication()
        token = authenticator.get_validated_token(query_token)
        return authenticator.get_user(token), token
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
openapi: 3.0.3
|
||||
info:
|
||||
title: Prowler API
|
||||
version: 1.31.0
|
||||
version: 1.31.2
|
||||
description: |-
|
||||
Prowler API specification.
|
||||
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
"""Platform Server-Sent Events (SSE) infrastructure.
|
||||
|
||||
Wires `django-eventstream` into the API: a base viewset features
|
||||
subclass to expose an SSE endpoint
|
||||
(:class:`api.sse.base_views.BaseSSEViewSet`), the channel manager that
|
||||
enforces the tenant gate (:class:`api.sse.channelmanager.SSEChannelManager`),
|
||||
and the channel-name helpers (:func:`api.sse.utils.make_channel_name`).
|
||||
"""
|
||||
|
||||
from api.sse.utils import make_channel_name
|
||||
from api.sse.base_views import BaseSSEViewSet
|
||||
|
||||
__all__ = ["BaseSSEViewSet", "make_channel_name"]
|
||||
@@ -1,46 +0,0 @@
|
||||
"""Base view class for SSE endpoints."""
|
||||
|
||||
from api.authentication import SSEAuthentication
|
||||
from api.base_views import BaseRLSViewSet
|
||||
from django_eventstream.renderers import SSEEventRenderer
|
||||
from django_eventstream.views import events
|
||||
|
||||
|
||||
class BaseSSEViewSet(BaseRLSViewSet):
    """Common base for the platform's SSE endpoints.

    A feature viewset declares the channels a connection subscribes to by
    overriding `get_channels`, in the same way a regular DRF viewset
    overrides `get_queryset`. `list` stores that result on
    `request.sse_channels`, which the channel manager reads; platform and
    feature share no other coupling.
    """

    authentication_classes = [SSEAuthentication]
    # Pin the SSE renderer so content negotiation accepts the browser's
    # `Accept: text/event-stream`.
    renderer_classes = [SSEEventRenderer]

    def get_channels(self) -> set[str]:
        """Return the set of channel names for this connection.

        Implementations MUST raise the matching DRF exception
        (`NotAuthenticated`, `PermissionDenied`, `NotFound`) when
        authorization fails. An empty set would instead surface as
        django-eventstream's "No channels specified", hiding the real
        cause.
        """
        raise NotImplementedError

    def get_queryset(self):
        """Not implemented by default.

        The SSE list path bypasses serialization entirely, so most SSE
        viewsets only need `get_channels`. Subclasses that run their own
        queryset lookup inside `get_channels` should override this; the
        default raises the same error a missing override on a ModelViewSet
        would.
        """
        raise NotImplementedError

    def list(self, request, *_args, **kwargs):
        """Resolve channels under the regular DRF stack, then start streaming."""
        channels = self.get_channels()
        request.sse_channels = channels
        return events(request, **kwargs)
|
||||
@@ -1,33 +0,0 @@
|
||||
"""Channel manager that wires `django-eventstream` to platform SSE views."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from django_eventstream.channelmanager import DefaultChannelManager
|
||||
|
||||
from api.sse.utils import tenant_id_from_channel
|
||||
|
||||
|
||||
class SSEChannelManager(DefaultChannelManager):
    """Channel manager bridging `django-eventstream` and the platform SSE viewsets."""

    def get_channels_for_request(self, request, _view_kwargs):
        """Return the channels the viewset stored on the request, or an empty set."""
        try:
            return request.sse_channels
        except AttributeError:
            # The request never had channels attached to it.
            return set()

    def can_read_channel(self, user, channel):
        """Check that `user` belongs to the tenant named in `channel`.

        The tenant id is embedded in the channel name, so a cross-tenant
        subscription is refused here even if the URL-level check ever has a
        bug. Resource-level visibility was already enforced at connect.
        """
        authenticated = user is not None and user.is_authenticated
        if not authenticated:
            return False

        tenant_id = tenant_id_from_channel(channel)
        # A channel without a parseable tenant id is never readable.
        return tenant_id is not None and user.is_member_of_tenant(tenant_id)

    def is_channel_reliable(self, channel):
        """Always False: clients refetch canonical state from REST on reconnect."""
        return False
|
||||
@@ -1,40 +0,0 @@
|
||||
"""Channel-name convention shared by SSE publishers, consumers, and the
|
||||
channel manager. The format is `<prefix>:<tenant_id>:<resource_id>`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
CHANNEL_SEPARATOR = ":"
|
||||
|
||||
|
||||
def make_channel_name(
    prefix: str,
    tenant_id: str | uuid.UUID,
    resource_id: str | uuid.UUID,
) -> str:
    """Return the channel name `<prefix>:<tenant_id>:<resource_id>`.

    Args:
        prefix: Feature-owned prefix (e.g. `"lighthouse-session"`).
        tenant_id: Tenant the resource belongs to.
        resource_id: Resource identifier within the tenant.
    """
    # Only the two identifiers are stringified; the prefix is used as given.
    segments = [prefix, *map(str, (tenant_id, resource_id))]
    return CHANNEL_SEPARATOR.join(segments)
|
||||
|
||||
|
||||
def tenant_id_from_channel(channel: str) -> uuid.UUID | None:
    """Extract the tenant UUID from *channel*.

    Returns `None` when *channel* has fewer than three separator-delimited
    segments or when its second segment is not a valid UUID.

    Callers MUST treat a `None` result as "not authorized": a malformed
    channel cannot be safely read.
    """
    parts = channel.split(CHANNEL_SEPARATOR)
    if len(parts) >= 3:
        try:
            # The tenant id is always the second segment.
            return uuid.UUID(parts[1])
        except ValueError:
            pass
    return None
|
||||
@@ -542,3 +542,84 @@ class TestHasProviderData:
|
||||
):
|
||||
with pytest.raises(db_module.GraphDatabaseQueryException):
|
||||
db_module.has_provider_data("db-tenant-abc", "provider-123")
|
||||
|
||||
|
||||
class TestDropSubgraph:
    """Exercise `drop_subgraph`: relationships are drained first, then nodes, in batches."""

    _GET_SESSION = "api.attack_paths.database.get_session"

    @staticmethod
    def _result(count):
        """Build a fake query result whose `.single().get(...)` yields `count`."""
        fake_result = MagicMock()
        fake_result.single.return_value.get.return_value = count
        return fake_result

    @staticmethod
    def _session_ctx(session):
        """Wrap `session` in a mock context manager that does not swallow exceptions."""
        manager = MagicMock()
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        return manager

    @staticmethod
    def _failing_session_ctx(message, code):
        """Build a mock context manager whose `__enter__` raises a graph DB error."""
        manager = MagicMock()
        manager.__enter__.side_effect = db_module.GraphDatabaseQueryException(
            message=message,
            code=code,
        )
        return manager

    def test_deletes_relationships_then_nodes_in_batches(self):
        # Each phase sees one full batch followed by an empty one:
        # relationships first, nodes second.
        batch_counts = (1000, 0, 1000, 0)
        session = MagicMock()
        session.run.side_effect = [self._result(count) for count in batch_counts]

        with patch(self._GET_SESSION, return_value=self._session_ctx(session)):
            deleted = db_module.drop_subgraph("db-tenant-abc", "provider-123")

        # Only phase-2 node counts contribute to the return value.
        assert deleted == 1000
        assert session.run.call_count == 4

        queries = [invocation.args[0] for invocation in session.run.call_args_list]

        # Regression guard: the memory blow-up was caused by DETACH DELETE.
        assert not any("DETACH DELETE" in query for query in queries)

        rel_positions = [i for i, q in enumerate(queries) if "DELETE r" in q]
        node_positions = [i for i, q in enumerate(queries) if "DELETE n" in q]
        assert rel_positions and node_positions

        # DISTINCT avoids double-counting relationships matched from both ends.
        assert all("DISTINCT r" in queries[i] for i in rel_positions)

        # Relationships must be fully drained before nodes are deleted.
        assert max(rel_positions) < min(node_positions)

    def test_returns_zero_when_database_not_found(self):
        missing_db_ctx = self._failing_session_ctx(
            "Database does not exist",
            "Neo.ClientError.Database.DatabaseNotFound",
        )

        with patch(self._GET_SESSION, return_value=missing_db_ctx):
            deleted = db_module.drop_subgraph("db-tenant-gone", "provider-123")

        assert deleted == 0

    def test_raises_on_other_errors(self):
        broken_ctx = self._failing_session_ctx(
            "Connection refused",
            "Neo.TransientError.General.UnknownError",
        )

        with patch(self._GET_SESSION, return_value=broken_ctx):
            with pytest.raises(db_module.GraphDatabaseQueryException):
                db_module.drop_subgraph("db-tenant-abc", "provider-123")
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import time
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from django.test import RequestFactory
|
||||
from rest_framework.exceptions import AuthenticationFailed
|
||||
|
||||
from api.authentication import SSEAuthentication, TenantAPIKeyAuthentication
|
||||
from api.authentication import TenantAPIKeyAuthentication
|
||||
from api.db_router import MainRouter
|
||||
from api.models import TenantAPIKey
|
||||
|
||||
@@ -382,64 +382,3 @@ class TestTenantAPIKeyAuthentication:
|
||||
auth_backend.authenticate(request)
|
||||
|
||||
assert str(exc_info.value.detail) == "API Key has already expired."
|
||||
|
||||
|
||||
class TestSSEAuthentication:
    """`SSEAuthentication` offers `?access_token=<jwt>` as a fallback for
    browser `EventSource` clients; the standard `Authorization` header
    remains the authoritative source."""

    _JWT_CLASS = "api.authentication.JWTAuthentication"

    @staticmethod
    def _request(headers, query_params=None):
        """Build a mock request; `query_params` is only set when provided."""
        request = MagicMock()
        request.headers = headers
        if query_params is not None:
            request.query_params = query_params
        return request

    @staticmethod
    def _patch_parent_authenticate(return_value):
        """Patch `authenticate` on the direct parent of `SSEAuthentication`."""
        parent = SSEAuthentication.__bases__[0]
        return patch.object(parent, "authenticate", return_value=return_value)

    def test_header_present_delegates_to_super(self):
        request = self._request({"Authorization": "Bearer header-token"})

        with self._patch_parent_authenticate(("user", "tok")) as super_auth:
            outcome = SSEAuthentication().authenticate(request)

        super_auth.assert_called_once_with(request)
        assert outcome == ("user", "tok")

    def test_no_header_no_query_token_delegates_to_super(self):
        request = self._request({}, query_params={})

        with self._patch_parent_authenticate(None) as super_auth:
            outcome = SSEAuthentication().authenticate(request)

        super_auth.assert_called_once_with(request)
        assert outcome is None

    def test_query_token_used_only_as_fallback(self):
        request = self._request({}, query_params={"access_token": "query-jwt"})

        fake_jwt = MagicMock()
        fake_jwt.get_validated_token.return_value = "validated"
        fake_jwt.get_user.return_value = "query-user"

        with patch(self._JWT_CLASS, return_value=fake_jwt):
            user, token = SSEAuthentication().authenticate(request)

        fake_jwt.get_validated_token.assert_called_once_with("query-jwt")
        assert (user, token) == ("query-user", "validated")

    def test_query_token_invalid_raises_authentication_error(self):
        """A bad JWT in `?access_token` must surface as an auth error rather
        than being swallowed or treated as unauthenticated."""
        from rest_framework_simplejwt.exceptions import InvalidToken

        request = self._request({}, query_params={"access_token": "bad-token"})

        fake_jwt = MagicMock()
        fake_jwt.get_validated_token.side_effect = InvalidToken("Token is invalid")

        with patch(self._JWT_CLASS, return_value=fake_jwt):
            with pytest.raises(InvalidToken):
                SSEAuthentication().authenticate(request)

        fake_jwt.get_validated_token.assert_called_once_with("bad-token")
|
||||
|
||||
@@ -1,101 +0,0 @@
|
||||
"""Tests for the platform SSE infrastructure (``api.sse``).
|
||||
|
||||
Cover the two security-critical platform pieces — the channel-name
|
||||
convention (:mod:`api.sse.utils`) and the tenant gate enforced by
|
||||
:class:`api.sse.channelmanager.SSEChannelManager`. The SSE authentication
|
||||
class lives in :mod:`api.authentication` with the rest of the auth stack,
|
||||
so its tests live in ``test_authentication.py``. Per-feature SSE endpoints
|
||||
add their own tests on top of these.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.sse.channelmanager import SSEChannelManager
|
||||
from api.sse.utils import make_channel_name, tenant_id_from_channel
|
||||
|
||||
|
||||
class TestMakeChannel:
    """Channel names built by `make_channel_name` keep the tenant id recoverable."""

    def test_round_trips_tenant_id(self):
        tenant_id = uuid.uuid4()
        resource_id = uuid.uuid4()
        built = make_channel_name("lighthouse-session", tenant_id, resource_id)
        assert tenant_id_from_channel(built) == tenant_id

    def test_accepts_str_arguments(self):
        tenant_id = uuid.uuid4()
        built = make_channel_name("lighthouse-session", str(tenant_id), "resource-1")
        expected = f"lighthouse-session:{tenant_id}:resource-1"
        assert built == expected

    def test_prefix_with_hyphen_is_not_split(self):
        # Prefixes may contain hyphens but never colons, so the tenant id
        # stays in the second colon-separated segment.
        tenant_id = uuid.uuid4()
        built = make_channel_name("a-long-hyphenated-prefix", tenant_id, "res")
        assert tenant_id_from_channel(built) == tenant_id
|
||||
|
||||
|
||||
class TestTenantIdFromChannel:
    """`tenant_id_from_channel` returns a UUID only for well-formed channel names."""

    def test_returns_none_for_too_few_segments(self):
        for malformed in ("prefix:only", "garbage"):
            assert tenant_id_from_channel(malformed) is None

    def test_returns_none_for_non_uuid_tenant_segment(self):
        parsed = tenant_id_from_channel("prefix:not-a-uuid:resource")
        assert parsed is None

    def test_parses_valid_channel(self):
        tenant_id = uuid.uuid4()
        parsed = tenant_id_from_channel(f"prefix:{tenant_id}:resource")
        assert parsed == tenant_id
|
||||
|
||||
|
||||
@pytest.mark.django_db
class TestSSEChannelManager:
    """Tenant gate and request plumbing of `SSEChannelManager`."""

    def test_member_can_read_own_tenant_channel(
        self, create_test_user, tenants_fixture
    ):
        own_tenant = tenants_fixture[0]
        channel = make_channel_name("lighthouse-session", own_tenant.id, uuid.uuid4())
        manager = SSEChannelManager()
        assert manager.can_read_channel(create_test_user, channel)

    def test_non_member_cannot_read_other_tenant_channel(
        self, create_test_user, tenants_fixture
    ):
        # create_test_user is a member of tenant1 and tenant2 but not tenant3.
        foreign_tenant = tenants_fixture[2]
        channel = make_channel_name(
            "lighthouse-session", foreign_tenant.id, uuid.uuid4()
        )
        manager = SSEChannelManager()
        assert not manager.can_read_channel(create_test_user, channel)

    def test_anonymous_user_is_rejected(self, tenants_fixture):
        channel = make_channel_name(
            "lighthouse-session", tenants_fixture[0].id, uuid.uuid4()
        )
        # Both a missing user and an unauthenticated user object are refused.
        assert not SSEChannelManager().can_read_channel(None, channel)

        unauthenticated = MagicMock(is_authenticated=False)
        assert not SSEChannelManager().can_read_channel(unauthenticated, channel)

    def test_malformed_channel_is_rejected(self, create_test_user, tenants_fixture):
        manager = SSEChannelManager()
        assert not manager.can_read_channel(create_test_user, "garbage")

    def test_get_channels_for_request_reads_stashed_set(self):
        stashed = {"prefix:tenant:resource"}
        request = MagicMock()
        request.sse_channels = stashed
        channels = SSEChannelManager().get_channels_for_request(request, {})
        assert channels == {"prefix:tenant:resource"}

    def test_get_channels_for_request_defaults_to_empty(self):
        # A request that never went through BaseSSEViewSet.list has no
        # sse_channels attribute; the manager must not raise.
        bare_request = object()
        channels = SSEChannelManager().get_channels_for_request(bare_request, {})
        assert channels == set()

    def test_channel_is_not_reliable(self):
        # v1 ships without server-side replay storage.
        reliable = SSEChannelManager().is_channel_reliable("prefix:tenant:resource")
        assert reliable is False
|
||||
@@ -357,6 +357,30 @@ class TestGetProwlerProviderKwargs:
|
||||
expected_result = {**secret_dict, **expected_extra_kwargs}
|
||||
assert result == expected_result
|
||||
|
||||
def test_get_prowler_provider_kwargs_oraclecloud_converts_region_string_to_set(
    self,
):
    """The OCI `region` secret arrives as a string and must come back as a one-element set."""
    credentials = {
        "user": "ocid1.user.oc1..fake",
        "fingerprint": "00:11:22:33:44:55:66:77",
        "key_content": "-----BEGIN PRIVATE KEY-----\nfake\n-----END PRIVATE KEY-----",
        "tenancy": "ocid1.tenancy.oc1..fake",
        "region": "us-ashburn-1",
        "pass_phrase": "fake-passphrase",
    }
    provider = MagicMock(
        provider=Provider.ProviderChoices.ORACLECLOUD.value,
        secret=MagicMock(secret=credentials),
        uid="ocid1.tenancy.oc1..fake",
    )

    kwargs = get_prowler_provider_kwargs(provider)

    # Every secret field passes through unchanged except `region`.
    expected = dict(credentials)
    expected["region"] = {"us-ashburn-1"}
    assert kwargs == expected
|
||||
|
||||
def test_get_prowler_provider_kwargs_with_mutelist(self):
|
||||
provider_uid = "provider_uid"
|
||||
secret_dict = {"key": "value"}
|
||||
|
||||
@@ -9570,6 +9570,188 @@ class TestComplianceOverviewViewSet:
|
||||
assert "Category" in first_attr
|
||||
assert "AWSService" in first_attr
|
||||
|
||||
def test_compliance_overview_attributes_resolves_provider_from_scan(
|
||||
self, authenticated_client, tenants_fixture, providers_fixture
|
||||
):
|
||||
# csa_ccm_4.0 is a multi-provider universal framework: a single
|
||||
# compliance_id whose requirements expose different checks per provider.
|
||||
# Passing a scan must return the check IDs for that scan's provider,
|
||||
# otherwise the endpoint defaults to the first provider that declares the
|
||||
# framework and azure/gcp requirements end up with check IDs that match
|
||||
# no findings.
|
||||
tenant = tenants_fixture[0]
|
||||
gcp_provider = providers_fixture[2]
|
||||
azure_provider = providers_fixture[4]
|
||||
assert gcp_provider.provider == Provider.ProviderChoices.GCP.value
|
||||
assert azure_provider.provider == Provider.ProviderChoices.AZURE.value
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
gcp_scan = Scan.objects.create(
|
||||
name="gcp scan",
|
||||
provider=gcp_provider,
|
||||
trigger=Scan.TriggerChoices.MANUAL,
|
||||
state=StateChoices.COMPLETED,
|
||||
tenant_id=tenant.id,
|
||||
started_at=now,
|
||||
completed_at=now,
|
||||
)
|
||||
azure_scan = Scan.objects.create(
|
||||
name="azure scan",
|
||||
provider=azure_provider,
|
||||
trigger=Scan.TriggerChoices.MANUAL,
|
||||
state=StateChoices.COMPLETED,
|
||||
tenant_id=tenant.id,
|
||||
started_at=now,
|
||||
completed_at=now,
|
||||
)
|
||||
|
||||
def request_attributes(scan_id=None):
|
||||
params = {"filter[compliance_id]": "csa_ccm_4.0"}
|
||||
if scan_id is not None:
|
||||
params["filter[scan_id]"] = str(scan_id)
|
||||
return authenticated_client.get(
|
||||
reverse("complianceoverview-attributes"), params
|
||||
)
|
||||
|
||||
def collect_check_ids(scan_id=None):
|
||||
response = request_attributes(scan_id)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
check_ids = set()
|
||||
for item in response.json()["data"]:
|
||||
check_ids.update(item["attributes"]["attributes"]["check_ids"])
|
||||
return check_ids
|
||||
|
||||
gcp_check_ids = collect_check_ids(gcp_scan.id)
|
||||
azure_check_ids = collect_check_ids(azure_scan.id)
|
||||
|
||||
# Each scan resolves to its own provider's checks, and they differ.
|
||||
assert gcp_check_ids
|
||||
assert azure_check_ids
|
||||
assert gcp_check_ids != azure_check_ids
|
||||
|
||||
# The returned check IDs belong to the SDK's per-provider definition.
|
||||
from api.compliance import get_prowler_provider_compliance
|
||||
|
||||
def expected_check_ids(provider_type):
|
||||
framework = get_prowler_provider_compliance(provider_type)["csa_ccm_4.0"]
|
||||
expected = set()
|
||||
for requirement in framework.requirements:
|
||||
expected.update(requirement.checks.get(provider_type, []))
|
||||
return expected
|
||||
|
||||
assert gcp_check_ids <= expected_check_ids(Provider.ProviderChoices.GCP.value)
|
||||
assert azure_check_ids <= expected_check_ids(
|
||||
Provider.ProviderChoices.AZURE.value
|
||||
)
|
||||
|
||||
# An explicit scan_id is authoritative: a non-existent scan must fail
|
||||
# closed with 404 instead of silently falling back to another provider.
|
||||
missing_response = request_attributes("00000000-0000-0000-0000-000000000000")
|
||||
assert missing_response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
# A malformed scan_id is rejected with 404 as well.
|
||||
malformed_response = request_attributes("not-a-uuid")
|
||||
assert malformed_response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
# An empty value (filter[scan_id]=) must not fall back to the legacy
|
||||
# provider picker: the explicit (if blank) selector fails closed.
|
||||
empty_response = request_attributes("")
|
||||
assert empty_response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
# A scan belonging to another tenant is not visible (RLS), so it must
|
||||
# return 404 rather than leaking the fallback provider's check IDs.
|
||||
other_tenant = Tenant.objects.create(name="Other Compliance Tenant")
|
||||
foreign_provider = Provider.objects.create(
|
||||
provider="gcp",
|
||||
uid="foreign-gcp-test",
|
||||
alias="foreign_gcp",
|
||||
tenant_id=other_tenant.id,
|
||||
)
|
||||
foreign_scan = Scan.objects.create(
|
||||
name="foreign scan",
|
||||
provider=foreign_provider,
|
||||
trigger=Scan.TriggerChoices.MANUAL,
|
||||
state=StateChoices.COMPLETED,
|
||||
tenant_id=other_tenant.id,
|
||||
started_at=now,
|
||||
completed_at=now,
|
||||
)
|
||||
foreign_response = request_attributes(foreign_scan.id)
|
||||
assert foreign_response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
def test_compliance_overview_attributes_scan_scoped_by_provider_group(
|
||||
self,
|
||||
authenticated_client_no_permissions_rbac,
|
||||
providers_fixture,
|
||||
):
|
||||
# A user with limited visibility (no UNLIMITED_VISIBILITY) must only be
|
||||
# able to resolve scans for providers in its provider groups. Tenant RLS
|
||||
# alone is not enough here: both scans belong to the same tenant, so the
|
||||
# endpoint has to scope the scan lookup by provider group, otherwise a
|
||||
# restricted user could read another provider's compliance metadata.
|
||||
client = authenticated_client_no_permissions_rbac
|
||||
limited_user = client.user
|
||||
membership = Membership.objects.filter(user=limited_user).first()
|
||||
tenant = membership.tenant
|
||||
|
||||
allowed_provider = providers_fixture[2]
|
||||
denied_provider = providers_fixture[4]
|
||||
assert allowed_provider.provider == Provider.ProviderChoices.GCP.value
|
||||
assert denied_provider.provider == Provider.ProviderChoices.AZURE.value
|
||||
|
||||
provider_group = ProviderGroup.objects.create(
|
||||
name="limited-compliance-group",
|
||||
tenant_id=tenant.id,
|
||||
)
|
||||
ProviderGroupMembership.objects.create(
|
||||
tenant_id=tenant.id,
|
||||
provider_group=provider_group,
|
||||
provider=allowed_provider,
|
||||
)
|
||||
RoleProviderGroupRelationship.objects.create(
|
||||
tenant_id=tenant.id,
|
||||
role=limited_user.roles.first(),
|
||||
provider_group=provider_group,
|
||||
)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
allowed_scan = Scan.objects.create(
|
||||
name="allowed scan",
|
||||
provider=allowed_provider,
|
||||
trigger=Scan.TriggerChoices.MANUAL,
|
||||
state=StateChoices.COMPLETED,
|
||||
tenant_id=tenant.id,
|
||||
started_at=now,
|
||||
completed_at=now,
|
||||
)
|
||||
denied_scan = Scan.objects.create(
|
||||
name="denied scan",
|
||||
provider=denied_provider,
|
||||
trigger=Scan.TriggerChoices.MANUAL,
|
||||
state=StateChoices.COMPLETED,
|
||||
tenant_id=tenant.id,
|
||||
started_at=now,
|
||||
completed_at=now,
|
||||
)
|
||||
|
||||
def request_attributes(scan_id):
|
||||
return client.get(
|
||||
reverse("complianceoverview-attributes"),
|
||||
{
|
||||
"filter[compliance_id]": "csa_ccm_4.0",
|
||||
"filter[scan_id]": str(scan_id),
|
||||
},
|
||||
)
|
||||
|
||||
# The scan in the user's provider group resolves normally.
|
||||
assert request_attributes(allowed_scan.id).status_code == status.HTTP_200_OK
|
||||
|
||||
# The scan outside the user's provider group is invisible, so it fails
|
||||
# closed with 404 instead of leaking the other provider's check IDs.
|
||||
assert (
|
||||
request_attributes(denied_scan.id).status_code == status.HTTP_404_NOT_FOUND
|
||||
)
|
||||
|
||||
def test_compliance_overview_attributes_missing_compliance_id(
|
||||
self, authenticated_client
|
||||
):
|
||||
|
||||
@@ -243,6 +243,12 @@ def get_prowler_provider_kwargs(
|
||||
**prowler_provider_kwargs,
|
||||
"filter_accounts": [provider.uid],
|
||||
}
|
||||
elif provider.provider == Provider.ProviderChoices.ORACLECLOUD.value:
|
||||
if isinstance(prowler_provider_kwargs.get("region"), str):
|
||||
prowler_provider_kwargs = {
|
||||
**prowler_provider_kwargs,
|
||||
"region": {prowler_provider_kwargs["region"]},
|
||||
}
|
||||
elif provider.provider == Provider.ProviderChoices.OPENSTACK.value:
|
||||
# clouds_yaml_content, clouds_yaml_cloud and provider_id are validated
|
||||
# in the provider itself, so it's not needed here.
|
||||
|
||||
@@ -30,6 +30,7 @@ from dj_rest_auth.registration.views import SocialLoginView
|
||||
from django.conf import settings as django_settings
|
||||
from django.contrib.postgres.aggregates import ArrayAgg, BoolAnd, StringAgg
|
||||
from django.contrib.postgres.search import SearchQuery
|
||||
from django.core.exceptions import ValidationError as DjangoValidationError
|
||||
from django.db import transaction
|
||||
from django.db.models import (
|
||||
BooleanField,
|
||||
@@ -4644,6 +4645,16 @@ class RoleProviderGroupRelationshipView(RelationshipView, BaseRLSViewSet):
|
||||
location=OpenApiParameter.QUERY,
|
||||
description="Compliance framework ID to get attributes for.",
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="filter[scan_id]",
|
||||
required=False,
|
||||
type=OpenApiTypes.UUID,
|
||||
location=OpenApiParameter.QUERY,
|
||||
description="Scan ID used to resolve the provider for "
|
||||
"multi-provider universal frameworks (e.g. CSA CCM), so "
|
||||
"the returned check IDs match the scan's provider. When omitted, "
|
||||
"the first provider that declares the framework is used.",
|
||||
),
|
||||
],
|
||||
responses={
|
||||
200: OpenApiResponse(
|
||||
@@ -5084,7 +5095,51 @@ class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
|
||||
|
||||
provider_type = None
|
||||
|
||||
# If we couldn't determine from database, try each provider type
|
||||
# When a scan is provided, resolve the provider from it. Multi-provider
|
||||
# universal frameworks (e.g. CSA CCM) share a single compliance_id
|
||||
# across providers but expose different checks per provider, so the
|
||||
# metadata (and therefore the check IDs the UI uses to fetch findings)
|
||||
# must be returned for the scan's provider. Without this, the endpoint
|
||||
# falls back to the first provider that declares the framework and
|
||||
# returns its check IDs, leaving azure/gcp/... requirements with no
|
||||
# matching findings.
|
||||
scan_id = request.query_params.get("filter[scan_id]")
|
||||
if "filter[scan_id]" in request.query_params:
|
||||
# An explicit scan_id is authoritative: fail closed instead of
|
||||
# falling back to another provider. Otherwise an invalid, empty
|
||||
# (filter[scan_id]=) or inaccessible scan would silently return the
|
||||
# first provider's check IDs, recreating the multi-provider mismatch
|
||||
# this endpoint fixes.
|
||||
if not scan_id:
|
||||
raise NotFound(detail=f"Scan '{scan_id}' not found.")
|
||||
|
||||
# Tenant isolation is already enforced by Postgres RLS on the
|
||||
# connection (see BaseRLSViewSet). Scope the lookup by provider
|
||||
# group as well so a user with limited visibility can't resolve
|
||||
# another provider's scan and read its compliance metadata, mirroring
|
||||
# the RBAC scoping get_queryset() applies to the rest of the ViewSet.
|
||||
role = get_role(request.user, request.tenant_id)
|
||||
if getattr(role, Permissions.UNLIMITED_VISIBILITY.value, False):
|
||||
scan_queryset = Scan.objects.filter(tenant_id=request.tenant_id)
|
||||
else:
|
||||
scan_queryset = Scan.objects.filter(provider__in=get_providers(role))
|
||||
|
||||
try:
|
||||
scan = scan_queryset.select_related("provider").get(id=scan_id)
|
||||
except (Scan.DoesNotExist, DjangoValidationError, ValueError):
|
||||
raise NotFound(detail=f"Scan '{scan_id}' not found.")
|
||||
|
||||
provider_type = scan.provider.provider
|
||||
if compliance_id not in get_compliance_frameworks(provider_type):
|
||||
raise NotFound(
|
||||
detail=(
|
||||
f"Compliance framework '{compliance_id}' is not "
|
||||
f"available for scan '{scan_id}'."
|
||||
)
|
||||
)
|
||||
|
||||
# Fall back to the first provider that declares the framework. Keeps the
|
||||
# endpoint working for provider-agnostic callers that omit the scan.
|
||||
if not provider_type:
|
||||
for pt in Provider.ProviderChoices.values:
|
||||
if compliance_id in get_compliance_frameworks(pt):
|
||||
|
||||
@@ -3,7 +3,6 @@ from datetime import timedelta
|
||||
from config.custom_logging import LOGGING # noqa
|
||||
from config.env import BASE_DIR, env # noqa
|
||||
from config.settings.celery import * # noqa
|
||||
from config.settings.eventstream import * # noqa
|
||||
from config.settings.partitions import * # noqa
|
||||
from config.settings.sentry import * # noqa
|
||||
from config.settings.social_login import * # noqa
|
||||
@@ -45,7 +44,6 @@ INSTALLED_APPS = [
|
||||
"dj_rest_auth.registration",
|
||||
"rest_framework.authtoken",
|
||||
"drf_simple_apikey",
|
||||
"django_eventstream",
|
||||
]
|
||||
|
||||
MIDDLEWARE = [
|
||||
@@ -138,7 +136,6 @@ SPECTACULAR_SETTINGS = {
|
||||
}
|
||||
|
||||
WSGI_APPLICATION = "config.wsgi.application"
|
||||
ASGI_APPLICATION = "config.asgi.application"
|
||||
|
||||
DJANGO_GUID = {
|
||||
"GUID_HEADER_NAME": "Transaction-ID",
|
||||
|
||||
@@ -25,15 +25,6 @@ bind = f"{BIND_ADDRESS}:{PORT}"
|
||||
workers = env.int("DJANGO_WORKERS", default=multiprocessing.cpu_count() * 2 + 1)
|
||||
reload = DEBUG
|
||||
|
||||
# Native ASGI worker (gunicorn 24+). Required so SSE endpoints can keep the
|
||||
# event loop alive while waiting for events.
|
||||
worker_class = env("DJANGO_WORKER_CLASS", default="asgi")
|
||||
|
||||
# Preload the application before forking workers in production: the app is
|
||||
# imported once in the master and workers fork from it. In development, disable
|
||||
# preload so the server restarts on code changes.
|
||||
preload_app = not DEBUG
|
||||
|
||||
# Logging
|
||||
logconfig_dict = DJANGO_LOGGERS
|
||||
gunicorn_logger = logging.getLogger(BackendLogger.GUNICORN)
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
"""Server-Sent Events (SSE) configuration.
|
||||
|
||||
Wires django-eventstream into the platform: Valkey Pub/Sub backend on a
|
||||
dedicated DB (separate from the Celery broker), the platform channel
|
||||
manager, and headers that match the existing CORS allowlist.
|
||||
"""
|
||||
|
||||
from config.env import env
|
||||
from config.settings.celery import (
|
||||
VALKEY_HOST,
|
||||
VALKEY_PASSWORD,
|
||||
VALKEY_PORT,
|
||||
VALKEY_SCHEME,
|
||||
VALKEY_USERNAME,
|
||||
)
|
||||
|
||||
# Dedicated Valkey DB for the SSE Pub/Sub bus. Kept distinct from the
|
||||
# Celery broker DB so a noisy broker can't shoulder out streaming
|
||||
# traffic on the same keyspace.
|
||||
EVENTSTREAM_VALKEY_DB = env.int("EVENTSTREAM_VALKEY_DB", default=2)
|
||||
|
||||
EVENTSTREAM_REDIS: dict = {
|
||||
"host": VALKEY_HOST,
|
||||
"port": int(VALKEY_PORT),
|
||||
"db": EVENTSTREAM_VALKEY_DB,
|
||||
}
|
||||
if VALKEY_PASSWORD:
|
||||
EVENTSTREAM_REDIS["password"] = VALKEY_PASSWORD
|
||||
if VALKEY_USERNAME:
|
||||
EVENTSTREAM_REDIS["username"] = VALKEY_USERNAME
|
||||
if VALKEY_SCHEME == "rediss":
|
||||
EVENTSTREAM_REDIS["ssl"] = True
|
||||
|
||||
# Platform channel manager — performs the per-feature authorization and
|
||||
# rewrites the placeholder channel from the URL into the canonical
|
||||
# tenant-scoped channel name. See ``api.sse.channelmanager``.
|
||||
EVENTSTREAM_CHANNELMANAGER_CLASS = "api.sse.channelmanager.SSEChannelManager"
|
||||
|
||||
# Headers a browser EventSource may legitimately send. Keep tight; the
|
||||
# stream itself reads no body, so no permissive defaults.
|
||||
EVENTSTREAM_ALLOW_HEADERS = "Cache-Control, Last-Event-ID"
|
||||
+161
-137
@@ -5,6 +5,7 @@ import re
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
@@ -22,7 +23,6 @@ from django.db.models import (
|
||||
Max,
|
||||
Min,
|
||||
OuterRef,
|
||||
Prefetch,
|
||||
Q,
|
||||
Sum,
|
||||
When,
|
||||
@@ -357,68 +357,71 @@ def _copy_compliance_requirement_rows(
|
||||
|
||||
|
||||
def _persist_compliance_requirement_rows(
|
||||
tenant_id: str, rows: list[dict[str, Any]], batch_size: int = 10000
|
||||
) -> None:
|
||||
tenant_id: str, rows: Iterable[dict[str, Any]], batch_size: int = 10000
|
||||
) -> int:
|
||||
"""Persist compliance requirement rows using batched COPY with ORM fallback.
|
||||
|
||||
Splits large row sets into batches to reduce lock duration and improve concurrency.
|
||||
``rows`` is consumed lazily in batches, so peak memory stays at ~``batch_size``
|
||||
rows instead of the full set. A batch that fails COPY falls back to an ORM
|
||||
``bulk_create`` of just that batch.
|
||||
|
||||
Args:
|
||||
tenant_id: Target tenant UUID.
|
||||
rows: Precomputed row dictionaries that reflect the compliance
|
||||
overview state for a scan.
|
||||
rows: Iterable of row dictionaries reflecting the compliance overview
|
||||
state for a scan.
|
||||
batch_size: Number of rows per COPY batch (default: 10000).
|
||||
|
||||
Returns:
|
||||
int: total number of rows persisted.
|
||||
"""
|
||||
if not rows:
|
||||
return
|
||||
|
||||
total_rows = len(rows)
|
||||
total_batches = (total_rows + batch_size - 1) // batch_size
|
||||
|
||||
try:
|
||||
# Process rows in batches to reduce lock duration
|
||||
for batch_num in range(total_batches):
|
||||
start_idx = batch_num * batch_size
|
||||
end_idx = min(start_idx + batch_size, total_rows)
|
||||
batch = rows[start_idx:end_idx]
|
||||
total_rows = 0
|
||||
batch_num = 0
|
||||
|
||||
for batch, _is_last in batched(rows, batch_size):
|
||||
if not batch:
|
||||
continue
|
||||
batch_num += 1
|
||||
try:
|
||||
_copy_compliance_requirement_rows(tenant_id, batch)
|
||||
except Exception as error:
|
||||
logger.exception(
|
||||
f"COPY bulk insert for compliance requirements batch {batch_num} "
|
||||
"failed; falling back to ORM bulk_create for this batch",
|
||||
exc_info=error,
|
||||
)
|
||||
fallback_objects = [
|
||||
ComplianceRequirementOverview(
|
||||
id=row["id"],
|
||||
tenant_id=row["tenant_id"],
|
||||
inserted_at=row["inserted_at"],
|
||||
compliance_id=row["compliance_id"],
|
||||
framework=row["framework"],
|
||||
version=row["version"],
|
||||
description=row["description"],
|
||||
region=row["region"],
|
||||
requirement_id=row["requirement_id"],
|
||||
requirement_status=row["requirement_status"],
|
||||
passed_checks=row["passed_checks"],
|
||||
failed_checks=row["failed_checks"],
|
||||
total_checks=row["total_checks"],
|
||||
passed_findings=row.get("passed_findings", 0),
|
||||
total_findings=row.get("total_findings", 0),
|
||||
scan_id=row["scan_id"],
|
||||
)
|
||||
for row in batch
|
||||
]
|
||||
with rls_transaction(tenant_id):
|
||||
ComplianceRequirementOverview.objects.bulk_create(
|
||||
fallback_objects, batch_size=500
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Compliance COPY batch {batch_num + 1}/{total_batches}: "
|
||||
f"inserted {len(batch)} rows ({start_idx + len(batch)}/{total_rows} total)"
|
||||
)
|
||||
except Exception as error:
|
||||
logger.exception(
|
||||
"COPY bulk insert for compliance requirements failed; falling back to ORM bulk_create",
|
||||
exc_info=error,
|
||||
total_rows += len(batch)
|
||||
logger.info(
|
||||
f"Compliance COPY batch {batch_num}: inserted {len(batch)} rows "
|
||||
f"({total_rows} total)"
|
||||
)
|
||||
# Fallback: use ORM bulk_create for all remaining rows
|
||||
fallback_objects = [
|
||||
ComplianceRequirementOverview(
|
||||
id=row["id"],
|
||||
tenant_id=row["tenant_id"],
|
||||
inserted_at=row["inserted_at"],
|
||||
compliance_id=row["compliance_id"],
|
||||
framework=row["framework"],
|
||||
version=row["version"],
|
||||
description=row["description"],
|
||||
region=row["region"],
|
||||
requirement_id=row["requirement_id"],
|
||||
requirement_status=row["requirement_status"],
|
||||
passed_checks=row["passed_checks"],
|
||||
failed_checks=row["failed_checks"],
|
||||
total_checks=row["total_checks"],
|
||||
passed_findings=row.get("passed_findings", 0),
|
||||
total_findings=row.get("total_findings", 0),
|
||||
scan_id=row["scan_id"],
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
with rls_transaction(tenant_id):
|
||||
ComplianceRequirementOverview.objects.bulk_create(
|
||||
fallback_objects, batch_size=500
|
||||
)
|
||||
|
||||
return total_rows
|
||||
|
||||
|
||||
def _create_compliance_summaries(
|
||||
@@ -1445,9 +1448,13 @@ def _aggregate_findings_by_region(
|
||||
tenant_id: str, scan_id: str, modeled_threatscore_compliance_id: str
|
||||
) -> tuple[dict, dict]:
|
||||
"""
|
||||
Aggregate findings by region using optimized ORM queries.
|
||||
Aggregate findings by region using streaming, column-scoped ORM reads.
|
||||
|
||||
Replaces nested Python loops with efficient queries and aggregation.
|
||||
Reads only the consumed columns as tuples via ``values_list`` and streams
|
||||
them with ``.iterator()``, using the denormalized ``resource_regions`` array
|
||||
instead of ``prefetch_related("resources")``. ``resource_regions`` mirrors the
|
||||
regions of a finding's related resources, so it yields the same per-region
|
||||
tally without joining the resource table.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant UUID
|
||||
@@ -1459,12 +1466,12 @@ def _aggregate_findings_by_region(
|
||||
- check_status_by_region: {region: {check_id: status}}
|
||||
- findings_count_by_compliance: {region: {normalized_id: {requirement_id: {total, pass}}}}
|
||||
"""
|
||||
check_status_by_region = {}
|
||||
findings_count_by_compliance = {}
|
||||
check_status_by_region: dict = {}
|
||||
findings_count_by_compliance: dict = {}
|
||||
|
||||
normalized_id = re.sub(r"[^a-z0-9]", "", modeled_threatscore_compliance_id.lower())
|
||||
|
||||
with rls_transaction(tenant_id, using=READ_REPLICA_ALIAS):
|
||||
# Fetch only PASS/FAIL findings (optimized query reduces data transfer)
|
||||
# Other statuses are not needed for check_status or ThreatScore calculation
|
||||
findings = (
|
||||
Finding.all_objects.filter(
|
||||
tenant_id=tenant_id,
|
||||
@@ -1472,42 +1479,28 @@ def _aggregate_findings_by_region(
|
||||
muted=False,
|
||||
status__in=["PASS", "FAIL"],
|
||||
)
|
||||
.only("id", "check_id", "status", "compliance")
|
||||
.prefetch_related(
|
||||
Prefetch(
|
||||
"resources",
|
||||
queryset=Resource.objects.only("id", "region"),
|
||||
to_attr="small_resources",
|
||||
)
|
||||
.values_list("check_id", "status", "resource_regions", "compliance")
|
||||
.iterator(chunk_size=DJANGO_FINDINGS_BATCH_SIZE)
|
||||
)
|
||||
|
||||
for check_id, status, resource_regions, compliance in findings:
|
||||
threatscore_requirements = (compliance or {}).get(
|
||||
modeled_threatscore_compliance_id
|
||||
)
|
||||
)
|
||||
|
||||
# Process findings in a single pass (more efficient than original nested loops)
|
||||
normalized_id = re.sub(
|
||||
r"[^a-z0-9]", "", modeled_threatscore_compliance_id.lower()
|
||||
)
|
||||
|
||||
for finding in findings:
|
||||
status = finding.status
|
||||
|
||||
for resource in finding.small_resources:
|
||||
region = resource.region
|
||||
|
||||
# Aggregate check status by region
|
||||
current_status = check_status_by_region.setdefault(region, {})
|
||||
for region in resource_regions or ():
|
||||
# Priority: FAIL > any other status
|
||||
if current_status.get(finding.check_id) != "FAIL":
|
||||
current_status[finding.check_id] = status
|
||||
current_status = check_status_by_region.setdefault(region, {})
|
||||
if current_status.get(check_id) != "FAIL":
|
||||
current_status[check_id] = status
|
||||
|
||||
# Aggregate ThreatScore compliance counts
|
||||
if modeled_threatscore_compliance_id in (finding.compliance or {}):
|
||||
if threatscore_requirements:
|
||||
compliance_key = findings_count_by_compliance.setdefault(
|
||||
region, {}
|
||||
).setdefault(normalized_id, {})
|
||||
|
||||
for requirement_id in finding.compliance[
|
||||
modeled_threatscore_compliance_id
|
||||
]:
|
||||
for requirement_id in threatscore_requirements:
|
||||
requirement_stats = compliance_key.setdefault(
|
||||
requirement_id, {"total": 0, "pass": 0}
|
||||
)
|
||||
@@ -1554,8 +1547,8 @@ def create_compliance_requirements(tenant_id: str, scan_id: str):
|
||||
(compliance_id, requirement_id)
|
||||
)
|
||||
|
||||
compliance_requirement_rows: list[dict[str, Any]] = []
|
||||
regions = []
|
||||
requirements_created = 0
|
||||
requirement_statuses = defaultdict(
|
||||
lambda: {"fail_count": 0, "pass_count": 0, "total_count": 0}
|
||||
)
|
||||
@@ -1595,44 +1588,93 @@ def create_compliance_requirements(tenant_id: str, scan_id: str):
|
||||
else:
|
||||
requirement_stats["failed_checks"] += 1
|
||||
|
||||
# Prepare compliance requirement rows and compute summaries in single pass
|
||||
utc_datetime_now = datetime.now(tz=timezone.utc)
|
||||
|
||||
# Pre-compute shared strings (optimization: reduces string conversions)
|
||||
tenant_id_str = str(tenant_id)
|
||||
scan_id_str = str(scan_instance.id)
|
||||
|
||||
for region in regions:
|
||||
region_stats = region_requirement_stats.get(region, {})
|
||||
for compliance_id, compliance in compliance_template.items():
|
||||
modeled_compliance_id = _normalized_compliance_key(
|
||||
compliance["framework"], compliance["version"]
|
||||
# Per-framework constants that don't depend on the region.
|
||||
compliance_plan = []
|
||||
for compliance_id, compliance in compliance_template.items():
|
||||
modeled_compliance_id = _normalized_compliance_key(
|
||||
compliance["framework"], compliance["version"]
|
||||
)
|
||||
framework = compliance["framework"]
|
||||
version = compliance["version"] or ""
|
||||
requirements = [
|
||||
(
|
||||
requirement_id,
|
||||
requirement.get("description") or "",
|
||||
len(requirement["checks"]),
|
||||
)
|
||||
compliance_stats = region_stats.get(compliance_id, {})
|
||||
# Create an overview record for each requirement within each compliance framework
|
||||
for requirement_id, requirement in compliance[
|
||||
"requirements"
|
||||
].items():
|
||||
stats = compliance_stats.get(requirement_id)
|
||||
passed_checks = stats["passed_checks"] if stats else 0
|
||||
failed_checks = stats["failed_checks"] if stats else 0
|
||||
total_checks = len(requirement["checks"])
|
||||
if total_checks == 0:
|
||||
requirement_status = "MANUAL"
|
||||
elif failed_checks > 0:
|
||||
requirement_status = "FAIL"
|
||||
else:
|
||||
requirement_status = "PASS"
|
||||
].items()
|
||||
]
|
||||
compliance_plan.append(
|
||||
(
|
||||
compliance_id,
|
||||
framework,
|
||||
version,
|
||||
modeled_compliance_id,
|
||||
requirements,
|
||||
)
|
||||
)
|
||||
|
||||
compliance_requirement_rows.append(
|
||||
{
|
||||
# Yield rows lazily (consumed batch-by-batch by COPY) so peak memory
|
||||
# stays bounded; tally requirement_statuses in the same pass.
|
||||
def _iter_compliance_requirement_rows():
|
||||
for region in regions:
|
||||
region_stats = region_requirement_stats.get(region, {})
|
||||
region_findings = findings_count_by_compliance.get(region, {})
|
||||
for (
|
||||
compliance_id,
|
||||
framework,
|
||||
version,
|
||||
modeled_compliance_id,
|
||||
requirements,
|
||||
) in compliance_plan:
|
||||
compliance_stats = region_stats.get(compliance_id, {})
|
||||
compliance_findings = region_findings.get(
|
||||
modeled_compliance_id, {}
|
||||
)
|
||||
for requirement_id, description, total_checks in requirements:
|
||||
stats = compliance_stats.get(requirement_id)
|
||||
if stats:
|
||||
passed_checks = stats["passed_checks"]
|
||||
failed_checks = stats["failed_checks"]
|
||||
else:
|
||||
passed_checks = 0
|
||||
failed_checks = 0
|
||||
if total_checks == 0:
|
||||
requirement_status = "MANUAL"
|
||||
elif failed_checks > 0:
|
||||
requirement_status = "FAIL"
|
||||
else:
|
||||
requirement_status = "PASS"
|
||||
|
||||
finding_counts = compliance_findings.get(requirement_id)
|
||||
if finding_counts:
|
||||
passed_findings = finding_counts.get("pass", 0)
|
||||
total_findings = finding_counts.get("total", 0)
|
||||
else:
|
||||
passed_findings = 0
|
||||
total_findings = 0
|
||||
|
||||
key = (compliance_id, requirement_id)
|
||||
requirement_statuses[key]["total_count"] += 1
|
||||
if requirement_status == "FAIL":
|
||||
requirement_statuses[key]["fail_count"] += 1
|
||||
elif requirement_status == "PASS":
|
||||
requirement_statuses[key]["pass_count"] += 1
|
||||
|
||||
yield {
|
||||
"id": uuid.uuid4(),
|
||||
"tenant_id": tenant_id_str,
|
||||
"inserted_at": utc_datetime_now,
|
||||
"compliance_id": compliance_id,
|
||||
"framework": compliance["framework"],
|
||||
"version": compliance["version"] or "",
|
||||
"description": requirement.get("description") or "",
|
||||
"framework": framework,
|
||||
"version": version,
|
||||
"description": description,
|
||||
"region": region,
|
||||
"requirement_id": requirement_id,
|
||||
"requirement_status": requirement_status,
|
||||
@@ -1640,41 +1682,23 @@ def create_compliance_requirements(tenant_id: str, scan_id: str):
|
||||
"failed_checks": failed_checks,
|
||||
"total_checks": total_checks,
|
||||
"scan_id": scan_id_str,
|
||||
"passed_findings": findings_count_by_compliance.get(
|
||||
region, {}
|
||||
)
|
||||
.get(modeled_compliance_id, {})
|
||||
.get(requirement_id, {})
|
||||
.get("pass", 0),
|
||||
"total_findings": findings_count_by_compliance.get(
|
||||
region, {}
|
||||
)
|
||||
.get(modeled_compliance_id, {})
|
||||
.get(requirement_id, {})
|
||||
.get("total", 0),
|
||||
"passed_findings": passed_findings,
|
||||
"total_findings": total_findings,
|
||||
}
|
||||
)
|
||||
|
||||
# Update summary tracking (single-pass optimization)
|
||||
key = (compliance_id, requirement_id)
|
||||
requirement_statuses[key]["total_count"] += 1
|
||||
if requirement_status == "FAIL":
|
||||
requirement_statuses[key]["fail_count"] += 1
|
||||
elif requirement_status == "PASS":
|
||||
requirement_statuses[key]["pass_count"] += 1
|
||||
|
||||
# Idempotent re-run: COPY can't ON CONFLICT, so clear this scan's rows first.
|
||||
# Idempotent re-run: clear this scan's rows before re-inserting.
|
||||
with rls_transaction(tenant_id):
|
||||
ComplianceRequirementOverview.objects.filter(scan_id=scan_id).delete()
|
||||
|
||||
# Bulk create requirement records using PostgreSQL COPY
|
||||
_persist_compliance_requirement_rows(tenant_id, compliance_requirement_rows)
|
||||
requirements_created = _persist_compliance_requirement_rows(
|
||||
tenant_id, _iter_compliance_requirement_rows()
|
||||
)
|
||||
|
||||
# Create pre-aggregated summaries for fast compliance overview lookups
|
||||
_create_compliance_summaries(tenant_id, scan_id, requirement_statuses)
|
||||
|
||||
return {
|
||||
"requirements_created": len(compliance_requirement_rows),
|
||||
"requirements_created": requirements_created,
|
||||
"regions_processed": list(regions),
|
||||
"compliance_frameworks": (
|
||||
list(compliance_template.keys()) if regions else []
|
||||
|
||||
@@ -3674,19 +3674,19 @@ class TestAggregateFindingsByRegion:
|
||||
scan_id = str(uuid.uuid4())
|
||||
modeled_threatscore_compliance_id = "ProwlerThreatScore-1.0"
|
||||
|
||||
# Mock findings with resources
|
||||
mock_finding1 = MagicMock()
|
||||
mock_finding1.check_id = "check1"
|
||||
mock_finding1.status = "FAIL"
|
||||
mock_finding1.compliance = {modeled_threatscore_compliance_id: ["req1", "req2"]}
|
||||
|
||||
mock_resource1 = MagicMock()
|
||||
mock_resource1.region = "us-east-1"
|
||||
mock_finding1.small_resources = [mock_resource1]
|
||||
# (check_id, status, resource_regions, compliance) tuples
|
||||
finding_rows = [
|
||||
(
|
||||
"check1",
|
||||
"FAIL",
|
||||
["us-east-1"],
|
||||
{modeled_threatscore_compliance_id: ["req1", "req2"]},
|
||||
)
|
||||
]
|
||||
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.only.return_value = mock_queryset
|
||||
mock_queryset.prefetch_related.return_value = [mock_finding1]
|
||||
mock_queryset.values_list.return_value = mock_queryset
|
||||
mock_queryset.iterator.return_value = finding_rows
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
@@ -3700,6 +3700,12 @@ class TestAggregateFindingsByRegion:
|
||||
)
|
||||
)
|
||||
|
||||
# Streaming query contract: column-scoped values_list + iterator
|
||||
mock_queryset.values_list.assert_called_once_with(
|
||||
"check_id", "status", "resource_regions", "compliance"
|
||||
)
|
||||
mock_queryset.iterator.assert_called_once()
|
||||
|
||||
# Verify structure of check_status_by_region
|
||||
assert isinstance(check_status_by_region, dict)
|
||||
assert "us-east-1" in check_status_by_region
|
||||
@@ -3719,27 +3725,15 @@ class TestAggregateFindingsByRegion:
|
||||
scan_id = str(uuid.uuid4())
|
||||
modeled_threatscore_compliance_id = "ProwlerThreatScore-1.0"
|
||||
|
||||
# First finding with PASS status
|
||||
mock_finding1 = MagicMock()
|
||||
mock_finding1.check_id = "check1"
|
||||
mock_finding1.status = "PASS"
|
||||
mock_finding1.compliance = {}
|
||||
mock_resource1 = MagicMock()
|
||||
mock_resource1.region = "us-east-1"
|
||||
mock_finding1.small_resources = [mock_resource1]
|
||||
|
||||
# Second finding with FAIL status for same check/region
|
||||
mock_finding2 = MagicMock()
|
||||
mock_finding2.check_id = "check1"
|
||||
mock_finding2.status = "FAIL"
|
||||
mock_finding2.compliance = {}
|
||||
mock_resource2 = MagicMock()
|
||||
mock_resource2.region = "us-east-1"
|
||||
mock_finding2.small_resources = [mock_resource2]
|
||||
# Same check/region: PASS first, then FAIL — FAIL must win
|
||||
finding_rows = [
|
||||
("check1", "PASS", ["us-east-1"], {}),
|
||||
("check1", "FAIL", ["us-east-1"], {}),
|
||||
]
|
||||
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.only.return_value = mock_queryset
|
||||
mock_queryset.prefetch_related.return_value = [mock_finding1, mock_finding2]
|
||||
mock_queryset.values_list.return_value = mock_queryset
|
||||
mock_queryset.iterator.return_value = finding_rows
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
@@ -3751,6 +3745,12 @@ class TestAggregateFindingsByRegion:
|
||||
tenant_id, scan_id, modeled_threatscore_compliance_id
|
||||
)
|
||||
|
||||
# Streaming query contract: column-scoped values_list + iterator
|
||||
mock_queryset.values_list.assert_called_once_with(
|
||||
"check_id", "status", "resource_regions", "compliance"
|
||||
)
|
||||
mock_queryset.iterator.assert_called_once()
|
||||
|
||||
# FAIL should override PASS
|
||||
assert check_status_by_region["us-east-1"]["check1"] == "FAIL"
|
||||
|
||||
@@ -3765,8 +3765,8 @@ class TestAggregateFindingsByRegion:
|
||||
modeled_threatscore_compliance_id = "ProwlerThreatScore-1.0"
|
||||
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.only.return_value = mock_queryset
|
||||
mock_queryset.prefetch_related.return_value = []
|
||||
mock_queryset.values_list.return_value = mock_queryset
|
||||
mock_queryset.iterator.return_value = []
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
@@ -3778,6 +3778,12 @@ class TestAggregateFindingsByRegion:
|
||||
tenant_id, scan_id, modeled_threatscore_compliance_id
|
||||
)
|
||||
|
||||
# Streaming query contract: column-scoped values_list + iterator
|
||||
mock_queryset.values_list.assert_called_once_with(
|
||||
"check_id", "status", "resource_regions", "compliance"
|
||||
)
|
||||
mock_queryset.iterator.assert_called_once()
|
||||
|
||||
# Verify filter was called with muted=False
|
||||
mock_findings_filter.assert_called_once_with(
|
||||
tenant_id=tenant_id,
|
||||
@@ -3796,27 +3802,25 @@ class TestAggregateFindingsByRegion:
|
||||
scan_id = str(uuid.uuid4())
|
||||
modeled_threatscore_compliance_id = "ProwlerThreatScore-1.0"
|
||||
|
||||
# Finding with PASS status
|
||||
mock_finding1 = MagicMock()
|
||||
mock_finding1.check_id = "check1"
|
||||
mock_finding1.status = "PASS"
|
||||
mock_finding1.compliance = {modeled_threatscore_compliance_id: ["req1"]}
|
||||
mock_resource1 = MagicMock()
|
||||
mock_resource1.region = "us-east-1"
|
||||
mock_finding1.small_resources = [mock_resource1]
|
||||
|
||||
# Finding with FAIL status
|
||||
mock_finding2 = MagicMock()
|
||||
mock_finding2.check_id = "check2"
|
||||
mock_finding2.status = "FAIL"
|
||||
mock_finding2.compliance = {modeled_threatscore_compliance_id: ["req1"]}
|
||||
mock_resource2 = MagicMock()
|
||||
mock_resource2.region = "us-east-1"
|
||||
mock_finding2.small_resources = [mock_resource2]
|
||||
# PASS and FAIL findings mapped to the same ThreatScore requirement
|
||||
finding_rows = [
|
||||
(
|
||||
"check1",
|
||||
"PASS",
|
||||
["us-east-1"],
|
||||
{modeled_threatscore_compliance_id: ["req1"]},
|
||||
),
|
||||
(
|
||||
"check2",
|
||||
"FAIL",
|
||||
["us-east-1"],
|
||||
{modeled_threatscore_compliance_id: ["req1"]},
|
||||
),
|
||||
]
|
||||
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.only.return_value = mock_queryset
|
||||
mock_queryset.prefetch_related.return_value = [mock_finding1, mock_finding2]
|
||||
mock_queryset.values_list.return_value = mock_queryset
|
||||
mock_queryset.iterator.return_value = finding_rows
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
@@ -3828,6 +3832,12 @@ class TestAggregateFindingsByRegion:
|
||||
tenant_id, scan_id, modeled_threatscore_compliance_id
|
||||
)
|
||||
|
||||
# Streaming query contract: column-scoped values_list + iterator
|
||||
mock_queryset.values_list.assert_called_once_with(
|
||||
"check_id", "status", "resource_regions", "compliance"
|
||||
)
|
||||
mock_queryset.iterator.assert_called_once()
|
||||
|
||||
# Verify compliance counts
|
||||
normalized_id = re.sub(
|
||||
r"[^a-z0-9]", "", modeled_threatscore_compliance_id.lower()
|
||||
@@ -3850,27 +3860,15 @@ class TestAggregateFindingsByRegion:
|
||||
scan_id = str(uuid.uuid4())
|
||||
modeled_threatscore_compliance_id = "ProwlerThreatScore-1.0"
|
||||
|
||||
# Finding in us-east-1
|
||||
mock_finding1 = MagicMock()
|
||||
mock_finding1.check_id = "check1"
|
||||
mock_finding1.status = "FAIL"
|
||||
mock_finding1.compliance = {}
|
||||
mock_resource1 = MagicMock()
|
||||
mock_resource1.region = "us-east-1"
|
||||
mock_finding1.small_resources = [mock_resource1]
|
||||
|
||||
# Finding in us-west-2
|
||||
mock_finding2 = MagicMock()
|
||||
mock_finding2.check_id = "check1"
|
||||
mock_finding2.status = "PASS"
|
||||
mock_finding2.compliance = {}
|
||||
mock_resource2 = MagicMock()
|
||||
mock_resource2.region = "us-west-2"
|
||||
mock_finding2.small_resources = [mock_resource2]
|
||||
# One finding per region
|
||||
finding_rows = [
|
||||
("check1", "FAIL", ["us-east-1"], {}),
|
||||
("check1", "PASS", ["us-west-2"], {}),
|
||||
]
|
||||
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.only.return_value = mock_queryset
|
||||
mock_queryset.prefetch_related.return_value = [mock_finding1, mock_finding2]
|
||||
mock_queryset.values_list.return_value = mock_queryset
|
||||
mock_queryset.iterator.return_value = finding_rows
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
@@ -3882,6 +3880,12 @@ class TestAggregateFindingsByRegion:
|
||||
tenant_id, scan_id, modeled_threatscore_compliance_id
|
||||
)
|
||||
|
||||
# Streaming query contract: column-scoped values_list + iterator
|
||||
mock_queryset.values_list.assert_called_once_with(
|
||||
"check_id", "status", "resource_regions", "compliance"
|
||||
)
|
||||
mock_queryset.iterator.assert_called_once()
|
||||
|
||||
# Verify both regions are present with correct statuses
|
||||
assert "us-east-1" in check_status_by_region
|
||||
assert "us-west-2" in check_status_by_region
|
||||
@@ -3890,17 +3894,26 @@ class TestAggregateFindingsByRegion:
|
||||
|
||||
@patch("tasks.jobs.scan.Finding.all_objects.filter")
|
||||
@patch("tasks.jobs.scan.rls_transaction")
|
||||
def test_aggregate_findings_by_region_empty_findings(
|
||||
def test_aggregate_findings_by_region_multi_region_finding(
|
||||
self, mock_rls_transaction, mock_findings_filter
|
||||
):
|
||||
"""Test with no findings - should return empty dicts."""
|
||||
"""A finding with multiple resource_regions is tallied in every region."""
|
||||
tenant_id = str(uuid.uuid4())
|
||||
scan_id = str(uuid.uuid4())
|
||||
modeled_threatscore_compliance_id = "ProwlerThreatScore-1.0"
|
||||
|
||||
finding_rows = [
|
||||
(
|
||||
"check1",
|
||||
"FAIL",
|
||||
["us-east-1", "eu-west-1"],
|
||||
{modeled_threatscore_compliance_id: ["req1"]},
|
||||
)
|
||||
]
|
||||
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.only.return_value = mock_queryset
|
||||
mock_queryset.prefetch_related.return_value = []
|
||||
mock_queryset.values_list.return_value = mock_queryset
|
||||
mock_queryset.iterator.return_value = finding_rows
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
@@ -3914,6 +3927,92 @@ class TestAggregateFindingsByRegion:
|
||||
)
|
||||
)
|
||||
|
||||
# Streaming query contract: column-scoped values_list + iterator
|
||||
mock_queryset.values_list.assert_called_once_with(
|
||||
"check_id", "status", "resource_regions", "compliance"
|
||||
)
|
||||
mock_queryset.iterator.assert_called_once()
|
||||
|
||||
normalized_id = re.sub(
|
||||
r"[^a-z0-9]", "", modeled_threatscore_compliance_id.lower()
|
||||
)
|
||||
for region in ("us-east-1", "eu-west-1"):
|
||||
assert check_status_by_region[region]["check1"] == "FAIL"
|
||||
req_stats = findings_count_by_compliance[region][normalized_id]["req1"]
|
||||
assert req_stats == {"total": 1, "pass": 0}
|
||||
|
||||
@patch("tasks.jobs.scan.Finding.all_objects.filter")
|
||||
@patch("tasks.jobs.scan.rls_transaction")
|
||||
def test_aggregate_findings_by_region_skips_empty_regions(
|
||||
self, mock_rls_transaction, mock_findings_filter
|
||||
):
|
||||
"""A finding with no denormalized regions contributes nothing."""
|
||||
tenant_id = str(uuid.uuid4())
|
||||
scan_id = str(uuid.uuid4())
|
||||
modeled_threatscore_compliance_id = "ProwlerThreatScore-1.0"
|
||||
|
||||
finding_rows = [
|
||||
("check1", "FAIL", [], {modeled_threatscore_compliance_id: ["req1"]}),
|
||||
("check2", "PASS", None, {}),
|
||||
]
|
||||
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.values_list.return_value = mock_queryset
|
||||
mock_queryset.iterator.return_value = finding_rows
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
ctx.__exit__.return_value = False
|
||||
mock_rls_transaction.return_value = ctx
|
||||
mock_findings_filter.return_value = mock_queryset
|
||||
|
||||
check_status_by_region, findings_count_by_compliance = (
|
||||
_aggregate_findings_by_region(
|
||||
tenant_id, scan_id, modeled_threatscore_compliance_id
|
||||
)
|
||||
)
|
||||
|
||||
# Streaming query contract: column-scoped values_list + iterator
|
||||
mock_queryset.values_list.assert_called_once_with(
|
||||
"check_id", "status", "resource_regions", "compliance"
|
||||
)
|
||||
mock_queryset.iterator.assert_called_once()
|
||||
|
||||
assert check_status_by_region == {}
|
||||
assert findings_count_by_compliance == {}
|
||||
|
||||
@patch("tasks.jobs.scan.Finding.all_objects.filter")
|
||||
@patch("tasks.jobs.scan.rls_transaction")
|
||||
def test_aggregate_findings_by_region_empty_findings(
|
||||
self, mock_rls_transaction, mock_findings_filter
|
||||
):
|
||||
"""Test with no findings - should return empty dicts."""
|
||||
tenant_id = str(uuid.uuid4())
|
||||
scan_id = str(uuid.uuid4())
|
||||
modeled_threatscore_compliance_id = "ProwlerThreatScore-1.0"
|
||||
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.values_list.return_value = mock_queryset
|
||||
mock_queryset.iterator.return_value = []
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = None
|
||||
ctx.__exit__.return_value = False
|
||||
mock_rls_transaction.return_value = ctx
|
||||
mock_findings_filter.return_value = mock_queryset
|
||||
|
||||
check_status_by_region, findings_count_by_compliance = (
|
||||
_aggregate_findings_by_region(
|
||||
tenant_id, scan_id, modeled_threatscore_compliance_id
|
||||
)
|
||||
)
|
||||
|
||||
# Streaming query contract: column-scoped values_list + iterator
|
||||
mock_queryset.values_list.assert_called_once_with(
|
||||
"check_id", "status", "resource_regions", "compliance"
|
||||
)
|
||||
mock_queryset.iterator.assert_called_once()
|
||||
|
||||
assert check_status_by_region == {}
|
||||
assert findings_count_by_compliance == {}
|
||||
|
||||
|
||||
Generated
+8
-58
@@ -146,7 +146,6 @@ constraints = [
|
||||
{ name = "django-celery-results", specifier = "==2.6.0" },
|
||||
{ name = "django-cors-headers", specifier = "==4.4.0" },
|
||||
{ name = "django-environ", specifier = "==0.11.2" },
|
||||
{ name = "django-eventstream", specifier = "==5.3.3" },
|
||||
{ name = "django-filter", specifier = "==24.3" },
|
||||
{ name = "django-guid", specifier = "==3.5.0" },
|
||||
{ name = "django-postgres-extra", specifier = "==2.0.9" },
|
||||
@@ -191,7 +190,7 @@ constraints = [
|
||||
{ name = "grpc-google-iam-v1", specifier = "==0.14.3" },
|
||||
{ name = "grpcio", specifier = "==1.76.0" },
|
||||
{ name = "grpcio-status", specifier = "==1.76.0" },
|
||||
{ name = "gunicorn", specifier = "==26.0.0" },
|
||||
{ name = "gunicorn", specifier = "==23.0.0" },
|
||||
{ name = "h11", specifier = "==0.16.0" },
|
||||
{ name = "h2", specifier = "==4.3.0" },
|
||||
{ name = "hpack", specifier = "==4.1.0" },
|
||||
@@ -2363,19 +2362,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c4/f1/468b49cccba3b42dda571063a14c668bb0b53a1d5712426d18e36663bd53/django_environ-0.11.2-py2.py3-none-any.whl", hash = "sha256:0ff95ab4344bfeff693836aa978e6840abef2e2f1145adff7735892711590c05", size = 19141, upload-time = "2023-09-01T21:02:59.88Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "django-eventstream"
|
||||
version = "5.3.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "django" },
|
||||
{ name = "django-grip" },
|
||||
{ name = "gripcontrol" },
|
||||
{ name = "pyjwt", extra = ["crypto"] },
|
||||
{ name = "six" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f2/49/ec6cbc24f3f30465370df7096cfea9722bad2b0c1f35a7ff5d45fb96cff6/django_eventstream-5.3.3.tar.gz", hash = "sha256:6880b03298eebf18c1b736b972fb862eaf631dfbb79f8b27496418a3495d08dc", size = 47622, upload-time = "2025-10-23T00:22:40.291Z" }
|
||||
|
||||
[[package]]
|
||||
name = "django-filter"
|
||||
version = "24.3"
|
||||
@@ -2388,19 +2374,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/09/b1/92f1c30b47c1ebf510c35a2ccad9448f73437e5891bbd2b4febe357cc3de/django_filter-24.3-py3-none-any.whl", hash = "sha256:c4852822928ce17fb699bcfccd644b3574f1a2d80aeb2b4ff4f16b02dd49dc64", size = 95011, upload-time = "2024-08-02T13:27:55.616Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "django-grip"
|
||||
version = "3.5.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "django" },
|
||||
{ name = "gripcontrol" },
|
||||
{ name = "pubcontrol" },
|
||||
{ name = "six" },
|
||||
{ name = "werkzeug" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/cb/d0/2c7b04fa864073cd8cb324f8674672162282d97540d56732cbd3a9ae5bca/django-grip-3.5.2.tar.gz", hash = "sha256:1ee1601492cd110256bd03e4a68797a9fbefa27c15f5a838bf245df97db0450c", size = 7626, upload-time = "2025-03-24T18:53:58.677Z" }
|
||||
|
||||
[[package]]
|
||||
name = "django-guid"
|
||||
version = "3.5.0"
|
||||
@@ -3012,17 +2985,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c8/ab/717c58343cf02c5265b531384b248787e04d8160b8afe53d9eec053d7b44/greenlet-3.3.1-cp312-cp312-win_arm64.whl", hash = "sha256:bfb2d1763d777de5ee495c85309460f6fd8146e50ec9d0ae0183dbf6f0a829d1", size = 226403, upload-time = "2026-01-23T15:31:39.372Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gripcontrol"
|
||||
version = "4.4.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "pubcontrol" },
|
||||
{ name = "pyjwt", extra = ["crypto"] },
|
||||
{ name = "six" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/4f/51/1cbf88384fbe97a1454fb0adddcdca8cb90ceb99c3250274c334db844f4f/gripcontrol-4.4.0.tar.gz", hash = "sha256:44ee6fe244a02870aa4e5bc810138ccaf5070dce5eb149b8ee9e27b960a95c2d", size = 12526, upload-time = "2026-05-14T21:19:28.49Z" }
|
||||
|
||||
[[package]]
|
||||
name = "grpc-google-iam-v1"
|
||||
version = "0.14.3"
|
||||
@@ -3084,14 +3046,14 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "gunicorn"
|
||||
version = "26.0.0"
|
||||
version = "23.0.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "packaging" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/6d/b7/a4a3f632f823e432ce6bc65f62961b7980c898c77f075a2f7118cb3846fe/gunicorn-26.0.0.tar.gz", hash = "sha256:ca9346f85e3a4aeeb64d491045c16b9a35647abd37ea15efe53080eb8b090baf", size = 727286, upload-time = "2026-05-05T06:38:25.529Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/34/72/9614c465dc206155d93eff0ca20d42e1e35afc533971379482de953521a4/gunicorn-23.0.0.tar.gz", hash = "sha256:f014447a0101dc57e294f6c18ca6b40227a4c90e9bdb586042628030cba004ec", size = 375031, upload-time = "2024-08-10T20:25:27.378Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e6/40/9c2384fc2be4ad25dd4a49decd5ad9ea5a3639814c11bd40ab77cb9f0a14/gunicorn-26.0.0-py3-none-any.whl", hash = "sha256:40233d26a5f0d1872916188c276e21641155111c2853f0c2cd55260aec0d24fc", size = 212009, upload-time = "2026-05-05T06:38:23.007Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cb/7d/6dac2a6e1eba33ee43f318edbed4ff29151a49b5d37f080aad1e6469bca4/gunicorn-23.0.0-py3-none-any.whl", hash = "sha256:ec400d38950de4dfd418cff8328b2c8faed0edb0d517d3394e457c317908ca4d", size = 85029, upload-time = "2024-08-10T20:25:24.996Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4454,7 +4416,7 @@ wheels = [
|
||||
[[package]]
|
||||
name = "prowler"
|
||||
version = "5.30.0"
|
||||
source = { git = "https://github.com/prowler-cloud/prowler.git?rev=master#f1d741214a60df17158c3fdc97804fd1fde64f3a" }
|
||||
source = { git = "https://github.com/prowler-cloud/prowler.git?rev=v5.30#f1d741214a60df17158c3fdc97804fd1fde64f3a" }
|
||||
dependencies = [
|
||||
{ name = "alibabacloud-actiontrail20200706" },
|
||||
{ name = "alibabacloud-credentials" },
|
||||
@@ -4542,7 +4504,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "prowler-api"
|
||||
version = "1.31.0"
|
||||
version = "1.31.2"
|
||||
source = { virtual = "." }
|
||||
dependencies = [
|
||||
{ name = "cartography" },
|
||||
@@ -4555,7 +4517,6 @@ dependencies = [
|
||||
{ name = "django-celery-results" },
|
||||
{ name = "django-cors-headers" },
|
||||
{ name = "django-environ" },
|
||||
{ name = "django-eventstream" },
|
||||
{ name = "django-filter" },
|
||||
{ name = "django-guid" },
|
||||
{ name = "django-postgres-extra" },
|
||||
@@ -4620,7 +4581,6 @@ requires-dist = [
|
||||
{ name = "django-celery-results", specifier = "==2.6.0" },
|
||||
{ name = "django-cors-headers", specifier = "==4.4.0" },
|
||||
{ name = "django-environ", specifier = "==0.11.2" },
|
||||
{ name = "django-eventstream", specifier = "==5.3.3" },
|
||||
{ name = "django-filter", specifier = "==24.3" },
|
||||
{ name = "django-guid", specifier = "==3.5.0" },
|
||||
{ name = "django-postgres-extra", specifier = "==2.0.9" },
|
||||
@@ -4633,14 +4593,14 @@ requires-dist = [
|
||||
{ name = "drf-spectacular-jsonapi", specifier = "==0.5.1" },
|
||||
{ name = "fonttools", specifier = "==4.62.1" },
|
||||
{ name = "gevent", specifier = "==25.9.1" },
|
||||
{ name = "gunicorn", specifier = "==26.0.0" },
|
||||
{ name = "gunicorn", specifier = "==23.0.0" },
|
||||
{ name = "h2", specifier = "==4.3.0" },
|
||||
{ name = "lxml", specifier = "==6.1.0" },
|
||||
{ name = "markdown", specifier = "==3.10.2" },
|
||||
{ name = "matplotlib", specifier = "==3.10.8" },
|
||||
{ name = "neo4j", specifier = "==6.1.0" },
|
||||
{ name = "openai", specifier = "==1.109.1" },
|
||||
{ name = "prowler", git = "https://github.com/prowler-cloud/prowler.git?rev=master" },
|
||||
{ name = "prowler", git = "https://github.com/prowler-cloud/prowler.git?rev=v5.30" },
|
||||
{ name = "psycopg2-binary", specifier = "==2.9.9" },
|
||||
{ name = "pytest-celery", extras = ["redis"], specifier = "==1.3.0" },
|
||||
{ name = "reportlab", specifier = "==4.4.10" },
|
||||
@@ -4721,16 +4681,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/7b/08/9c66c269b0d417a0af9fb969535f0371b8c538633535a7a6a5ca3f9231e2/psycopg2_binary-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:81ff62668af011f9a48787564ab7eded4e9fb17a4a6a74af5ffa6a457400d2ab", size = 1163864, upload-time = "2023-10-28T09:37:28.155Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pubcontrol"
|
||||
version = "3.5.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "pyjwt", extra = ["crypto"] },
|
||||
{ name = "requests" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/25/6a/02202a247214a6ffd5148ab1b16aca1c334b40dca211bca0442c8b7c7447/pubcontrol-3.5.0.tar.gz", hash = "sha256:a5ec6b3f53edfd005675518e5e4cc23b34122776835ae7c6dbd1db173d1ff0cb", size = 18199, upload-time = "2023-07-05T19:11:40.477Z" }
|
||||
|
||||
[[package]]
|
||||
name = "py-deviceid"
|
||||
version = "0.1.1"
|
||||
|
||||
@@ -1,237 +0,0 @@
|
||||
---
|
||||
title: 'Server-Sent Events (SSE)'
|
||||
---
|
||||
|
||||
This guide explains how to add a **Server-Sent Events (SSE)** endpoint to the Prowler API. SSE lets the backend push a one-way stream of events to a client over a single long-lived HTTP connection — ideal for live progress, token-by-token LLM output, or any "the server has news for you" use case where the client should not poll.
|
||||
|
||||
<Info>
|
||||
The platform ships the SSE **infrastructure** (`api.sse`) and wiring. No feature endpoint streams over SSE out of the box — this guide shows how to build one on top of the shared base.
|
||||
</Info>
|
||||
|
||||
## When to use SSE
|
||||
|
||||
| Need | Use |
|
||||
|------|-----|
|
||||
| Server pushes incremental updates, client only reads | **SSE** |
|
||||
| Bidirectional, low-latency messaging (chat both ways, games) | WebSocket |
|
||||
| Client asks, server answers once | Plain REST |
|
||||
|
||||
SSE is the right tool when the **client only consumes**: scan progress, long-running job checkpoints, streamed LLM tokens, cross-client resource-sync notifications. It rides on plain HTTP, reconnects automatically in the browser via the native [`EventSource`](https://developer.mozilla.org/en-US/docs/Web/API/EventSource) API, and needs no extra protocol.
|
||||
|
||||
## How it works
|
||||
|
||||
SSE is wired through [`django-eventstream`](https://github.com/fanout/django_eventstream) and a small platform layer in `api/src/backend/api/sse/`:
|
||||
|
||||
| Piece | File | Responsibility |
|
||||
|-------|------|----------------|
|
||||
| `BaseSSEViewSet` | `api/sse/base_views.py` | Base DRF viewset a feature subclasses. The feature implements `get_channels`; the base handles auth, the tenant transaction, and delegates streaming to `django-eventstream`. |
|
||||
| `SSEChannelManager` | `api/sse/channelmanager.py` | Registered in `settings.EVENTSTREAM_CHANNELMANAGER_CLASS`. Reads the channel set off the request and enforces the platform-wide tenant gate. |
|
||||
| `SSEAuthentication` | `api/authentication.py` | Same JWT/API-key stack as the rest of the API, plus an `?access_token=<jwt>` fallback for browser `EventSource` clients. Lives with the other authentication classes, not in the `sse` package. |
|
||||
| `make_channel_name` / `tenant_id_from_channel` | `api/sse/utils.py` | Single source of truth for the channel-name format, so publishers and the channel manager agree byte-for-byte. |
|
||||
| Settings | `config/settings/eventstream.py` | Valkey Pub/Sub backend (dedicated DB), channel manager, allowed headers. |
|
||||
|
||||
### Transport: the server runs on ASGI
|
||||
|
||||
SSE connections are long-lived. Holding one open per synchronous worker would exhaust the worker pool, so the API runs under Gunicorn's native **`asgi` worker** (`config.asgi:application`). Streams are parked on the event loop while ordinary CRUD endpoints keep their synchronous execution (Django runs sync views in a thread-sensitive executor under ASGI). This is configured in `config/guniconf.py` and used by both the dev and production entrypoints — no separate server process is needed.
|
||||
|
||||
### The data flow
|
||||
|
||||
```
|
||||
publisher (Celery task / view) subscriber (browser, CLI)
|
||||
│ │
|
||||
│ send_event(channel, "scan.progress", …) │ GET …/event-stream
|
||||
▼ ▼
|
||||
Valkey Pub/Sub ◄────────────────────► BaseSSEViewSet.list
|
||||
(EVENTSTREAM_VALKEY_DB) → get_channels() (RLS-scoped)
|
||||
→ SSEChannelManager (tenant gate)
|
||||
→ StreamingHttpResponse (text/event-stream)
|
||||
```
|
||||
|
||||
A publisher anywhere in the system (most often a Celery task) calls `send_event(channel, event_type, payload)`. `django-eventstream` fans it out over Valkey Pub/Sub to every connection subscribed to that channel.
|
||||
|
||||
## Adding an SSE endpoint to your feature
|
||||
|
||||
The example below streams progress for a long-running **scan**. Adapt the resource, prefix, and event names to your feature.
|
||||
|
||||
<Steps>
|
||||
|
||||
<Step title="Pick a channel prefix">
|
||||
|
||||
Channels follow the format `<prefix>:<tenant_id>:<resource_id>`, built only through `make_channel_name`. The prefix is owned by your feature and may contain hyphens but **never colons** (the parser splits on `:`).
|
||||
|
||||
```python
|
||||
CHANNEL_PREFIX = "scan-progress"
|
||||
```
|
||||
|
||||
The tenant id is baked into every channel name. That is what lets the platform enforce cross-tenant isolation without knowing anything about your feature.
|
||||
|
||||
</Step>
|
||||
|
||||
<Step title="Subclass BaseSSEViewSet">
|
||||
|
||||
Create the viewset for the SSE sub-resource. The only required method is `get_channels`; it runs inside the tenant transaction set up by the base class, so any database lookup inside it is automatically RLS-scoped.
|
||||
|
||||
```python
|
||||
# scans/event_streams.py
|
||||
from api.sse import BaseSSEViewSet, make_channel_name
|
||||
from django.shortcuts import get_object_or_404
|
||||
from scans.models import Scan
|
||||
|
||||
CHANNEL_PREFIX = "scan-progress"
|
||||
|
||||
|
||||
class ScanEventStreamViewSet(BaseSSEViewSet):
|
||||
def get_queryset(self):
|
||||
# RLS already scopes to the tenant; narrow further as needed
|
||||
# (e.g. only scans the requesting user may see).
|
||||
return Scan.objects.filter(tenant_id=self.request.tenant_id)
|
||||
|
||||
def get_channels(self) -> set[str]:
|
||||
scan = get_object_or_404(self.get_queryset(), pk=self.kwargs["scan_pk"])
|
||||
return {make_channel_name(CHANNEL_PREFIX, scan.tenant_id, scan.id)}
|
||||
```
|
||||
|
||||
<Warning>
|
||||
`get_channels` **must raise** the relevant DRF exception (`NotFound`, `PermissionDenied`, `NotAuthenticated`) when authorization fails — `get_object_or_404` does this for you. Returning an empty set surfaces as django-eventstream's confusing "No channels specified" error instead of the real cause.
|
||||
</Warning>
|
||||
|
||||
</Step>
|
||||
|
||||
<Step title="Wire the URL as a sub-resource">
|
||||
|
||||
Mount the endpoint as an `event-stream` sub-resource. Keep it **outside the DRF router**, which would force the URL into a list/detail convention. Route the `get` method to the viewset's `list` action.
|
||||
|
||||
```python
|
||||
# scans/urls.py
|
||||
path(
|
||||
"scans/<uuid:scan_pk>/event-stream",
|
||||
ScanEventStreamViewSet.as_view({"get": "list"}),
|
||||
name="scan-event-stream",
|
||||
),
|
||||
```
|
||||
|
||||
</Step>
|
||||
|
||||
<Step title="Define your event vocabulary">
|
||||
|
||||
A feature owns its event types in `<app>/<domain>/events.py`: one `publish_<event>` function per event type, each body a **single** `send_event` call so the wire-level string lives in exactly one place.
|
||||
|
||||
```python
|
||||
# scans/events.py
|
||||
from django_eventstream import send_event
|
||||
|
||||
|
||||
def publish_progress(channel: str, checked: int, total: int) -> None:
|
||||
send_event(channel, "scan.progress", {"checked": checked, "total": total})
|
||||
|
||||
|
||||
def publish_end(channel: str, scan_id: str) -> None:
|
||||
# Terminal event carries the canonical id so reconnecting clients
|
||||
# can refetch the persisted resource over REST.
|
||||
send_event(channel, "scan.end", {"scan_id": scan_id})
|
||||
|
||||
|
||||
def publish_error(channel: str, code: str, detail: str) -> None:
|
||||
send_event(channel, "scan.error", {"code": code, "detail": detail})
|
||||
```
|
||||
|
||||
There is no platform-side enum, registry, or dispatch table — **the naming convention is the contract** (see below).
|
||||
|
||||
</Step>
|
||||
|
||||
<Step title="Publish from the producer">
|
||||
|
||||
Wherever the work happens — usually a Celery task — build the channel the same way and publish:
|
||||
|
||||
```python
|
||||
from api.sse import make_channel_name
|
||||
from scans.events import publish_progress, publish_end
|
||||
|
||||
channel = make_channel_name("scan-progress", scan.tenant_id, scan.id)
|
||||
publish_progress(channel, checked=42, total=100)
|
||||
...
|
||||
publish_end(channel, scan_id=str(scan.id))
|
||||
```
|
||||
|
||||
</Step>
|
||||
|
||||
</Steps>
|
||||
|
||||
## Event naming convention
|
||||
|
||||
Every event uses an event type of the form **`<resource>.<verb>`** (lowercased, dot-separated). The verb comes from this platform-wide vocabulary — if you need a verb that is not listed, document the addition in this guide so the catalog stays discoverable.
|
||||
|
||||
| Verb | When to use |
|
||||
|------|-------------|
|
||||
| `delta` | An incremental piece of a stream the client concatenates (LLM text tokens, audio chunks). Standard term across OpenAI / Anthropic / LiteLLM / Vercel AI SDK. |
|
||||
| `start` | Begin marker for a compound operation (e.g. a tool call whose execution will be reported by a matching `end`). |
|
||||
| `end` | Terminal marker. Carries the canonical resource id so reconnecting clients can refetch persisted state via REST. |
|
||||
| `progress` | Periodic checkpoint with quantifiable completion, e.g. `{"checked": 42, "total": 100}`. |
|
||||
| `created` / `updated` / `deleted` | Resource-lifecycle events for cross-client sync streams. |
|
||||
| `error` | Terminal failure. Carries a stable `code` for client switching and a human-readable `detail`. |
|
||||
|
||||
<Note>
|
||||
Payloads are **flat JSON**. The wire-level `event:` field already names the event type, so do **not** wrap the payload in `{"type": ..., "data": ...}`. Include the canonical resource UUID on terminal events so reconnecting clients can reconcile via REST.
|
||||
</Note>
|
||||
|
||||
## Authentication
|
||||
|
||||
SSE endpoints use the same authentication stack as the rest of the API. Non-browser clients (CLI, programmatic) send the standard `Authorization` header — JWT or API key.
|
||||
|
||||
Browser `EventSource` is the only widely available SSE client API and it **cannot set custom headers**. For that case only, the endpoint accepts a JWT via the `?access_token=<jwt>` query parameter. The header always wins when present — a header is intentional, while a query parameter can leak into referers and logs, so it is consulted only as a fallback.
|
||||
|
||||
```javascript
|
||||
// Browser
|
||||
const es = new EventSource(
|
||||
`/api/v1/scans/${scanId}/event-stream?access_token=${jwt}`
|
||||
);
|
||||
```
|
||||
|
||||
```bash
|
||||
# CLI / programmatic — header, exactly like every other endpoint
|
||||
curl -N -H "Authorization: Bearer $JWT" \
|
||||
https://<host>/api/v1/scans/$SCAN_ID/event-stream
|
||||
```
|
||||
|
||||
## Tenant isolation & security model
|
||||
|
||||
Authorization is enforced at two layers:
|
||||
|
||||
1. **At connect**, `get_channels` runs under the regular DRF stack inside the tenant transaction (`rls_transaction`). Resource lookups are RLS-scoped, so a user cannot even resolve a channel for a resource they cannot see. Narrow the queryset further (e.g. `created_by=request.user`) when a resource is per-user within a tenant.
|
||||
2. **After connect**, `SSEChannelManager.can_read_channel` re-verifies tenant membership by parsing the tenant id embedded in the channel name. Cross-tenant subscription is rejected even if a URL-level check ever has a bug. A malformed channel name is treated as "not authorized".
|
||||
|
||||
Because the tenant id lives inside the channel name, this gate works for any feature without the platform knowing anything about it.
|
||||
|
||||
## Reconnect & state recovery
|
||||
|
||||
The platform deliberately ships **without server-side replay** (`is_channel_reliable` returns `False`). When a client reconnects, it does **not** receive missed events. Instead:
|
||||
|
||||
- Terminal events (`*.end`) carry the canonical resource **UUID**.
|
||||
- On reconnect, the client refetches the authoritative state from the normal REST endpoint using that id.
|
||||
|
||||
Design your event payloads accordingly: deltas are ephemeral and concatenated in-flight; the durable truth always lives behind a REST resource.
|
||||
|
||||
## Local development
|
||||
|
||||
- The dev and production entrypoints both launch Gunicorn with the `asgi` worker (`config.asgi:application`). In dev, `DJANGO_DEBUG=True` enables hot reload; `preload_app` is automatically disabled under debug so edited code is picked up.
|
||||
- SSE uses a **dedicated Valkey database** (`EVENTSTREAM_VALKEY_DB`, default `2`) kept separate from the Celery broker so a noisy broker cannot crowd out streaming traffic. It reuses the same `VALKEY_*` connection settings as the rest of the platform.
|
||||
|
||||
| Env var | Default | Purpose |
|
||||
|---------|---------|---------|
|
||||
| `EVENTSTREAM_VALKEY_DB` | `2` | Valkey DB index for the SSE Pub/Sub bus |
|
||||
| `DJANGO_WORKER_CLASS` | `asgi` | Gunicorn worker class |
|
||||
|
||||
Test the stream end to end with `curl -N` (disable buffering) and an auth header:
|
||||
|
||||
```bash
|
||||
curl -N -H "Authorization: Bearer $JWT" \
|
||||
http://localhost:8080/api/v1/scans/$SCAN_ID/event-stream
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
The platform basis is covered by `api/tests/test_sse.py` (channel parsing, the tenant gate, and auth precedence). For a feature endpoint, test:
|
||||
|
||||
- `get_channels` returns the expected channel for an authorized resource and raises `NotFound`/`PermissionDenied` otherwise.
|
||||
- Each `publish_<event>` helper emits the correct event type and flat payload (mock `send_event`).
|
||||
- The producer builds the channel with `make_channel_name` using the resource's own `tenant_id`.
|
||||
+1
-2
@@ -395,8 +395,7 @@
|
||||
"developer-guide/lighthouse-architecture",
|
||||
"developer-guide/mcp-server",
|
||||
"developer-guide/ai-skills",
|
||||
"developer-guide/prowler-studio",
|
||||
"developer-guide/server-sent-events"
|
||||
"developer-guide/prowler-studio"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -138,6 +138,10 @@ To keep permissions focused:
|
||||
|
||||
4. Continue through the wizard and finish. No principals need to be granted access in step 3 unless you want other identities to impersonate this account.
|
||||
|
||||
<Note>
|
||||
To use this service account with `--organization-id`, additionally grant `roles/cloudasset.viewer` at the organization node and enable the Cloud Asset API in the service account's host project. See [Scanning a Specific GCP Organization](./organization). Without these, organization-wide scans silently fall back to listing only the projects accessible to the service account.
|
||||
</Note>
|
||||
|
||||
### Step 3: Generate a JSON Key
|
||||
|
||||
1. Open the newly created service account, move to the **Keys** tab, and choose **Add key > Create new key**.
|
||||
|
||||
@@ -11,8 +11,19 @@ prowler gcp --organization-id organization-id
|
||||
```
|
||||
|
||||
<Warning>
|
||||
Ensure the credentials used have one of the following roles at the organization level:
|
||||
Cloud Asset Viewer (`roles/cloudasset.viewer`), or Cloud Asset Owner (`roles/cloudasset.owner`).
|
||||
Ensure the credentials used have one of the following roles bound **at the organization node** (not at a project): Cloud Asset Viewer (`roles/cloudasset.viewer`) or Cloud Asset Owner (`roles/cloudasset.owner`). The role must be bound directly on the organization so the Cloud Asset API can enumerate projects across the whole hierarchy.
|
||||
|
||||
```bash
|
||||
gcloud organizations add-iam-policy-binding <organization-id> \
|
||||
--member="serviceAccount:<service-account-email>" \
|
||||
--role="roles/cloudasset.viewer"
|
||||
```
|
||||
|
||||
The Cloud Asset API (`cloudasset.googleapis.com`) must also be enabled in the project that owns the credentials (the service account's host project, or the quota project for user credentials):
|
||||
|
||||
```bash
|
||||
gcloud services enable cloudasset.googleapis.com --project <credentials-project-id>
|
||||
```
|
||||
|
||||
</Warning>
|
||||
<Note>
|
||||
|
||||
@@ -2,6 +2,16 @@
|
||||
|
||||
All notable changes to the **Prowler SDK** are documented in this file.
|
||||
|
||||
## [5.30.2] (Prowler v5.30.2)
|
||||
|
||||
### 🐞 Fixed
|
||||
|
||||
- GCP `logging_log_metric_filter_and_alert_*` checks now credit org-level aggregated sinks filtered to the Admin Activity audit stream [(#11575)](https://github.com/prowler-cloud/prowler/pull/11575)
|
||||
- A broken built-in provider no longer aborts the CLI when a different provider was invoked [(#11618)](https://github.com/prowler-cloud/prowler/pull/11618)
|
||||
- GCP organization scans with `--organization-id` no longer silently fall back to the credentials' host project when the Cloud Asset API call fails [(#11280)](https://github.com/prowler-cloud/prowler/pull/11280)
|
||||
|
||||
---
|
||||
|
||||
## [5.30.0] (Prowler v5.30.0)
|
||||
|
||||
### 🚀 Added
|
||||
|
||||
@@ -49,7 +49,7 @@ class _MutableTimestamp:
|
||||
|
||||
timestamp = _MutableTimestamp(datetime.today())
|
||||
timestamp_utc = _MutableTimestamp(datetime.now(timezone.utc))
|
||||
prowler_version = "5.30.0"
|
||||
prowler_version = "5.30.2"
|
||||
html_logo_url = "https://github.com/prowler-cloud/prowler/"
|
||||
square_logo_img = "https://raw.githubusercontent.com/prowler-cloud/prowler/dc7d2d5aeb92fdf12e8604f42ef6472cd3e8e889/docs/img/prowler-logo-black.png"
|
||||
aws_logo = "https://user-images.githubusercontent.com/38561120/235953920-3e3fba08-0795-41dc-b480-9bea57db9f2e.png"
|
||||
|
||||
@@ -15,6 +15,8 @@ from prowler.lib.check.models import Severity
|
||||
from prowler.lib.cli.redact import warn_sensitive_argument_values
|
||||
from prowler.lib.outputs.common import Status
|
||||
from prowler.providers.common.arguments import (
|
||||
PROVIDER_ALIASES,
|
||||
enforce_invoked_provider_loaded,
|
||||
init_providers_parser,
|
||||
validate_asff_usage,
|
||||
validate_provider_arguments,
|
||||
@@ -166,13 +168,13 @@ Detailed documentation at https://docs.prowler.com
|
||||
if sys.argv[1].startswith("-"):
|
||||
sys.argv = self.__set_default_provider__(sys.argv)
|
||||
|
||||
# Provider aliases mapping
|
||||
# Microsoft 365
|
||||
elif sys.argv[1] == "microsoft365":
|
||||
sys.argv[1] = "m365"
|
||||
# Oracle Cloud Infrastructure
|
||||
elif sys.argv[1] == "oci":
|
||||
sys.argv[1] = "oraclecloud"
|
||||
# Provider aliases mapping (single source: arguments.PROVIDER_ALIASES)
|
||||
elif sys.argv[1] in PROVIDER_ALIASES:
|
||||
sys.argv[1] = PROVIDER_ALIASES[sys.argv[1]]
|
||||
|
||||
# Selective fail-loud here (post argv-normalisation, pre parse_args)
|
||||
# so the invoked-provider check stays correct under parse(args=...).
|
||||
enforce_invoked_provider_loaded(self)
|
||||
|
||||
# Warn about sensitive flags passed with explicit values
|
||||
# Snapshot argv before parse_args() which may exit on errors
|
||||
|
||||
@@ -10,16 +10,43 @@ provider_arguments_lib_path = "lib.arguments.arguments"
|
||||
validate_provider_arguments_function = "validate_arguments"
|
||||
init_provider_arguments_function = "init_parser"
|
||||
|
||||
# Kept in sync with parser.py's argv normalisation; both consumers import this.
|
||||
PROVIDER_ALIASES = {
|
||||
"microsoft365": "m365",
|
||||
"oci": "oraclecloud",
|
||||
}
|
||||
|
||||
|
||||
def _invoked_provider_from_argv(available_providers: Sequence[str]) -> Optional[str]:
|
||||
"""Return the provider name the user invoked, or None.
|
||||
|
||||
Mirrors `ProwlerArgumentParser.parse()` resolution: only inspects
|
||||
`sys.argv[1]`. Scanning the whole argv would misclassify
|
||||
`prowler --output-directory stackit` as `stackit`.
|
||||
"""
|
||||
available = set(available_providers)
|
||||
if len(sys.argv) < 2:
|
||||
return "aws" if "aws" in available else None
|
||||
first = sys.argv[1]
|
||||
if first in ("-h", "--help", "-v", "--version"):
|
||||
return None
|
||||
if first.startswith("-"):
|
||||
return "aws" if "aws" in available else None
|
||||
normalized = PROVIDER_ALIASES.get(first, first)
|
||||
return normalized if normalized in available else None
|
||||
|
||||
|
||||
def init_providers_parser(self):
|
||||
"""init_providers_parser calls the provider init_parser function to load all the arguments and flags. Receives a ProwlerArgumentParser object"""
|
||||
# We need to call the arguments parser for each provider
|
||||
"""Build the subparser of each available provider.
|
||||
|
||||
Built-in load failures are captured silently on
|
||||
`self._builtin_load_failures`; the warn/exit decision is deferred to
|
||||
`enforce_invoked_provider_loaded()` because `parse(args=...)` can
|
||||
override `sys.argv` after this function ran.
|
||||
"""
|
||||
self._builtin_load_failures = {}
|
||||
providers = Provider.get_available_providers()
|
||||
for provider in providers:
|
||||
# Discriminate built-in vs external upfront via find_spec, so an
|
||||
# ImportError from a transitive dependency missing inside a built-in
|
||||
# arguments module surfaces clearly instead of being silently
|
||||
# re-routed to the entry-point path (which only has external providers).
|
||||
if Provider.is_builtin(provider):
|
||||
try:
|
||||
getattr(
|
||||
@@ -28,21 +55,9 @@ def init_providers_parser(self):
|
||||
),
|
||||
init_provider_arguments_function,
|
||||
)(self)
|
||||
except ImportError as e:
|
||||
logger.critical(
|
||||
f"Failed to load arguments for built-in provider '{provider}'. "
|
||||
f"Missing dependency: {e}. "
|
||||
f"Ensure all required dependencies are installed."
|
||||
)
|
||||
logger.debug("Full traceback:", exc_info=True)
|
||||
sys.exit(1)
|
||||
except Exception as error:
|
||||
logger.critical(
|
||||
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
sys.exit(1)
|
||||
self._builtin_load_failures[provider] = error
|
||||
else:
|
||||
# External provider — init_parser classmethod via entry point
|
||||
cls = Provider._load_ep_provider(provider)
|
||||
if cls and hasattr(cls, "init_parser"):
|
||||
try:
|
||||
@@ -53,6 +68,51 @@ def init_providers_parser(self):
|
||||
)
|
||||
|
||||
|
||||
def enforce_invoked_provider_loaded(self):
    """Apply selective fail-loud over the failures captured at init time.

    Called by `ProwlerArgumentParser.parse()` AFTER argv normalisation so
    the invoked provider matches what argparse will dispatch to — including
    the case where `parse(args=...)` overrode the ambient `sys.argv`.

    Invoked + failed → critical + `sys.exit(1)`. Others → warning.

    Args:
        self: the parser object. Its `_builtin_load_failures` mapping
            (``{provider_name: exception}``, filled by
            `init_providers_parser`) is read when present; a missing or
            empty mapping makes this function a no-op.

    Raises:
        SystemExit: with exit code 1 when the provider resolved from
            `sys.argv` is one of the providers that failed to load.
    """
    failures = getattr(self, "_builtin_load_failures", {})
    if not failures:
        return

    # Resolved now (not at init time) so it reflects the current sys.argv.
    invoked = _invoked_provider_from_argv(Provider.get_available_providers())

    # Providers that failed but were not invoked only deserve a warning.
    for provider, error in failures.items():
        if provider == invoked:
            continue
        if isinstance(error, ImportError):
            logger.warning(
                f"Skipping built-in provider '{provider}' due to missing "
                f"dependency: {error}. It will be unavailable in this "
                f"invocation, but the CLI continues because you invoked a "
                f"different provider."
            )
        else:
            logger.warning(
                f"Skipping built-in provider '{provider}': "
                f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
            )

    if invoked is None or invoked not in failures:
        return

    # The invoked provider itself is broken: report it and abort.
    error = failures[invoked]
    if isinstance(error, ImportError):
        logger.critical(
            f"Failed to load arguments for built-in provider '{invoked}'. "
            f"Missing dependency: {error}. "
            f"Ensure all required dependencies are installed."
        )
        # No exception is being handled here (it was captured earlier and
        # stored), so `exc_info=True` would attach no traceback. Pass the
        # stored exception instance so its traceback is actually logged.
        logger.debug("Full traceback:", exc_info=error)
    else:
        logger.critical(
            f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
        )
    sys.exit(1)
|
||||
|
||||
|
||||
def validate_provider_arguments(arguments: Namespace) -> tuple[bool, str]:
|
||||
"""validate_provider_arguments returns (True, "") if the provider arguments passed are valid and can be used together"""
|
||||
try:
|
||||
|
||||
@@ -34,11 +34,17 @@ class GCPBaseException(ProwlerException):
|
||||
"message": "Error loading Service Account Private Key credentials from dictionary",
|
||||
"remediation": "Check the dictionary and ensure it contains a Service Account Private Key.",
|
||||
},
|
||||
(3011, "GCPGetOrganizationProjectsError"): {
|
||||
"message": "Error retrieving projects under the organization via the Cloud Asset API",
|
||||
"remediation": "Ensure the Cloud Asset API is enabled in the credentials' project and that the principal has 'roles/cloudasset.viewer' bound at the organization level. See https://cloud.google.com/asset-inventory/docs/access-control.",
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self, code, file=None, original_exception=None, message=None):
|
||||
provider = "GCP"
|
||||
error_info = self.GCP_ERROR_CODES.get((code, self.__class__.__name__))
|
||||
# Copy the catalog entry so a custom message does not mutate the
|
||||
# class-level GCP_ERROR_CODES shared across exception instances.
|
||||
error_info = dict(self.GCP_ERROR_CODES.get((code, self.__class__.__name__)))
|
||||
if message:
|
||||
error_info["message"] = message
|
||||
super().__init__(
|
||||
@@ -104,3 +110,10 @@ class GCPLoadServiceAccountKeyFromDictError(GCPCredentialsError):
|
||||
super().__init__(
|
||||
3010, file=file, original_exception=original_exception, message=message
|
||||
)
|
||||
|
||||
|
||||
class GCPGetOrganizationProjectsError(GCPBaseException):
|
||||
def __init__(self, file=None, original_exception=None, message=None):
|
||||
super().__init__(
|
||||
3011, file=file, original_exception=original_exception, message=message
|
||||
)
|
||||
|
||||
@@ -21,6 +21,8 @@ from prowler.providers.common.models import Audit_Metadata, Connection
|
||||
from prowler.providers.common.provider import Provider
|
||||
from prowler.providers.gcp.config import DEFAULT_RETRY_ATTEMPTS
|
||||
from prowler.providers.gcp.exceptions.exceptions import (
|
||||
GCPBaseException,
|
||||
GCPGetOrganizationProjectsError,
|
||||
GCPInvalidProviderIdError,
|
||||
GCPLoadADCFromDictError,
|
||||
GCPLoadServiceAccountKeyFromDictError,
|
||||
@@ -621,10 +623,7 @@ class GcpProvider(Provider):
|
||||
credentials_file: str
|
||||
|
||||
Returns:
|
||||
dict[str, GCPProject]
|
||||
|
||||
Usage:
|
||||
>>> GcpProvider.get_projects(credentials=credentials, organization_id=organization_id)
|
||||
dict of project_id and GCPProject object
|
||||
"""
|
||||
projects = {}
|
||||
try:
|
||||
@@ -632,7 +631,10 @@ class GcpProvider(Provider):
|
||||
try:
|
||||
# Initialize Cloud Asset Inventory API for recursive project retrieval
|
||||
asset_service = discovery.build(
|
||||
"cloudasset", "v1", credentials=credentials
|
||||
"cloudasset",
|
||||
"v1",
|
||||
credentials=credentials,
|
||||
num_retries=DEFAULT_RETRY_ATTEMPTS,
|
||||
)
|
||||
# Set the scope to the specified organization and filter for projects
|
||||
scope = f"organizations/{organization_id}"
|
||||
@@ -643,7 +645,7 @@ class GcpProvider(Provider):
|
||||
)
|
||||
|
||||
while request is not None:
|
||||
response = request.execute()
|
||||
response = request.execute(num_retries=DEFAULT_RETRY_ATTEMPTS)
|
||||
|
||||
for asset in response.get("assets", []):
|
||||
# Extract labels and other project details
|
||||
@@ -688,13 +690,25 @@ class GcpProvider(Provider):
|
||||
)
|
||||
except HttpError as http_error:
|
||||
if "Cloud Asset API has not been used" in str(http_error):
|
||||
logger.error(
|
||||
f"Projects cannot be retrieved from the Organization since Cloud Asset API has not been used before or it is disabled [{http_error.__traceback__.tb_lineno}]. Enable it by visiting https://console.developers.google.com/apis/api/cloudasset.googleapis.com/ then retry."
|
||||
message = (
|
||||
"Projects cannot be retrieved from the Organization since the Cloud Asset API "
|
||||
"has not been used before or it is disabled. Enable it by visiting "
|
||||
"https://console.developers.google.com/apis/api/cloudasset.googleapis.com/ then retry."
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"{http_error.__class__.__name__}[{http_error.__traceback__.tb_lineno}]: {http_error}"
|
||||
message = (
|
||||
f"Cloud Asset API call failed while listing projects under organization "
|
||||
f"'{organization_id}': {http_error}. Ensure the credentials' principal has "
|
||||
"'roles/cloudasset.viewer' bound at the organization level."
|
||||
)
|
||||
logger.critical(
|
||||
f"{http_error.__class__.__name__}[{http_error.__traceback__.tb_lineno}]: {message}"
|
||||
)
|
||||
raise GCPGetOrganizationProjectsError(
|
||||
file=__file__,
|
||||
original_exception=http_error,
|
||||
message=message,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
# Initialize Cloud Resource Manager API for simple project listing
|
||||
@@ -781,8 +795,10 @@ class GcpProvider(Provider):
|
||||
labels={},
|
||||
lifecycle_state="ACTIVE",
|
||||
)
|
||||
# If no projects were able to be accessed via API, add them manually from the credentials file
|
||||
elif credentials_file:
|
||||
# If no projects were able to be accessed via API, add them manually from the credentials file.
|
||||
# Skip this fallback when an organization scan was explicitly requested: silently
|
||||
# downgrading scope to the service account's home project hides permission errors.
|
||||
elif credentials_file and not organization_id:
|
||||
with open(credentials_file, "r", encoding="utf-8") as file:
|
||||
project_id = json.load(file)["project_id"]
|
||||
# Handle empty or null project names
|
||||
@@ -798,6 +814,8 @@ class GcpProvider(Provider):
|
||||
labels={},
|
||||
lifecycle_state="ACTIVE",
|
||||
)
|
||||
except GCPBaseException as gcp_error:
|
||||
raise gcp_error
|
||||
except Exception as error:
|
||||
logger.critical(
|
||||
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import re
|
||||
|
||||
from pydantic.v1 import BaseModel
|
||||
|
||||
from prowler.lib.logger import logger
|
||||
from prowler.providers.gcp.config import DEFAULT_RETRY_ATTEMPTS
|
||||
from prowler.providers.gcp.gcp_provider import GcpProvider
|
||||
from prowler.providers.gcp.lib.service.service import GCPService
|
||||
from prowler.providers.gcp.services.monitoring.monitoring_service import Monitoring
|
||||
|
||||
|
||||
class Logging(GCPService):
|
||||
@@ -121,9 +124,86 @@ class Metric(BaseModel):
|
||||
bucket_name: str = ""
|
||||
|
||||
|
||||
# A positive selector of the Admin Activity stream: a ``logName`` predicate
# (``:`` has-substring or ``=`` equals) or a ``log_id()`` call. Written verbose
# so each fragment stays legible; ``(?![a-z_])`` keeps a longer stream name
# (``.../activity_v2``) from impersonating Admin Activity.
_ACTIVITY_SELECTOR = re.compile(
    r"""
    (?: logName \s* [:=] \s* | log_id \s* \( \s* ) # logName: / logName= / log_id(
    ["']? [^"'\s)]* # optional quote, then path prefix
    cloudaudit\.googleapis\.com/activity (?![a-z_]) # the Admin Activity stream itself
    """,
    re.IGNORECASE | re.VERBOSE,
)

# The same selector for *any* Cloud Audit stream (activity, data_access,
# system_event, policy, access_transparency, …). Used to strip the OR-combined
# audit clauses so we can prove nothing restrictive is left over.
_CLOUDAUDIT_SELECTOR = re.compile(
    r"""
    (?: logName \s* [:=] \s* | log_id \s* \( \s* ) # logName: / logName= / log_id(
    ["']? [^"'\s)]* # optional quote, then path prefix
    cloudaudit\.googleapis\.com/[a-z_]+ # any cloudaudit stream
    ["']? \s* \)? # optional closing quote / paren
    """,
    re.IGNORECASE | re.VERBOSE,
)

# Operators that exclude or narrow coverage. Any of these means we cannot prove
# the sink delivers the *whole* Admin Activity stream, so it is not credited.
_NEGATION_OR_RESTRICTION = re.compile(
    r"""
    \bNOT\b # NOT exclusion
    | \bAND\b # AND conjunction (restriction)
    | != | !: # "!=" / "!:" inequality
    | (?:^|[\s(]) -\s* [A-Za-z_] # leading "-" exclusion operator
    """,
    re.IGNORECASE | re.VERBOSE,
)


def _sink_delivers_activity_logs(sink_filter: str) -> bool:
    """True only when a sink's filter *provably* exports the full Admin Activity
    audit stream (or everything).

    Crediting flips a child project to PASS on a CIS security control, so the
    match is deliberately conservative: a false FAIL is safe, a false PASS is
    not. A non-``"all"`` filter is credited only when

    1. it positively selects the Admin Activity stream
       (``logName:.../activity``, ``logName="...activity"`` or
       ``log_id("...activity")``);
    2. it carries no operator that excludes or narrows the stream — ``NOT`` /
       ``-`` / ``!=`` (negation) or ``AND`` (restriction); and
    3. the whole filter is a strict ``selector (OR selector)*`` chain of Cloud
       Audit selectors (parentheses ignored). An ``OR`` only widens coverage,
       but any leftover predicate (``severity>=ERROR``, ``resource.type=...``)
       could narrow it, and two selectors placed side by side without ``OR``
       are implicitly AND-ed by the query language, which also narrows it.

    Sink filters encode the stream URL-encoded (``...%2Factivity``) or as a path
    — normalize before matching.

    Args:
        sink_filter: the sink's filter expression; empty/None or ``"all"``
            (any case, surrounding whitespace ignored) means "export everything".

    Returns:
        True when the filter is credited as delivering Admin Activity logs,
        False otherwise.
    """
    if not sink_filter or sink_filter.strip().lower() == "all":
        return True
    normalized = sink_filter.replace("%2F", "/").replace("%2f", "/")
    # 1. The Admin Activity stream must be positively selected.
    if not _ACTIVITY_SELECTOR.search(normalized):
        return False
    # 2. No operator may exclude or narrow that coverage.
    if _NEGATION_OR_RESTRICTION.search(normalized):
        return False
    # 3. Reduce every audit selector to a placeholder token and require strict
    #    alternation: selector, OR, selector, ... Merely deleting the OR glue
    #    would also accept adjacent selectors, i.e. an implicit AND.
    placeholder = "\x00"
    if placeholder in normalized:
        # Cannot tell a real selector from a literal NUL in the input.
        return False
    skeleton = _CLOUDAUDIT_SELECTOR.sub(f" {placeholder} ", normalized)
    tokens = re.sub(r"[()]", " ", skeleton).split()
    if len(tokens) % 2 == 0:
        # Even count: a dangling OR or two terms without a joiner.
        return False
    for position, token in enumerate(tokens):
        if position % 2 == 0:
            if token != placeholder:
                return False
        elif token.upper() != "OR":
            return False
    return True
|
||||
|
||||
|
||||
def get_projects_covered_by_aggregated_metric(
|
||||
logging_client, monitoring_client, metric_filter
|
||||
):
|
||||
logging_client: Logging,
|
||||
monitoring_client: Monitoring,
|
||||
metric_filter: str,
|
||||
) -> dict[str, str]:
|
||||
"""Return {project_id: metric_name} for scanned projects whose logs are routed,
|
||||
via an organization-level sink with includeChildren=True, to a bucket that holds
|
||||
a bucket-scoped log metric matching ``metric_filter`` that has an alert policy.
|
||||
@@ -133,6 +213,10 @@ def get_projects_covered_by_aggregated_metric(
|
||||
every child project's logs into one bucket, where a single bucket-scoped metric
|
||||
+ alert covers them all. Without crediting that, those child projects are falsely
|
||||
failed. Mirrors the org-sink handling already in ``logging_sink_created`` (#11355).
|
||||
|
||||
A sink is credited when it exports everything (``filter == "all"``) or when its
|
||||
filter carries the Admin Activity audit stream — the only stream the CIS metric
|
||||
filters can match (see ``_sink_delivers_activity_logs``).
|
||||
"""
|
||||
# Buckets that hold a matching, alerted, bucket-scoped metric -> metric name.
|
||||
bucket_to_metric = {}
|
||||
@@ -155,7 +239,7 @@ def get_projects_covered_by_aggregated_metric(
|
||||
for sink in logging_client.sinks:
|
||||
if not getattr(sink, "include_children", False):
|
||||
continue
|
||||
if getattr(sink, "filter", "all") != "all":
|
||||
if not _sink_delivers_activity_logs(getattr(sink, "filter", "all")):
|
||||
continue
|
||||
for bucket, metric_name in bucket_to_metric.items():
|
||||
# sink.destination e.g. "logging.googleapis.com/projects/.../buckets/X";
|
||||
|
||||
+1
-1
@@ -124,7 +124,7 @@ maintainers = [{name = "Prowler Engineering", email = "engineering@prowler.com"}
|
||||
name = "prowler"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10,<3.13"
|
||||
version = "5.30.0"
|
||||
version = "5.30.2"
|
||||
|
||||
[project.scripts]
|
||||
prowler = "prowler.__main__:prowler"
|
||||
|
||||
+297
-11
@@ -417,17 +417,19 @@ class TestIsBuiltinProvider:
|
||||
|
||||
|
||||
class TestInitProvidersParserBuiltinDependencyFailure:
|
||||
"""Tests the critical behavior fix: when a built-in provider's arguments
|
||||
module exists but its imports fail (e.g. boto3 not installed), we must
|
||||
fail loudly with a clear message — not silently fall through to entry
|
||||
points as if the provider were external."""
|
||||
"""Selective fail-loud: init captures failures silently, enforce emits
|
||||
warning for non-invoked and exits for the invoked broken provider."""
|
||||
|
||||
@patch("sys.argv", ["prowler", "aws"])
|
||||
@patch("prowler.providers.common.arguments.Provider.is_builtin")
|
||||
@patch("prowler.providers.common.arguments.import_module")
|
||||
def test_builtin_with_missing_transitive_dep_fails_loudly(
|
||||
self, mock_import, mock_is_builtin
|
||||
):
|
||||
from prowler.providers.common.arguments import init_providers_parser
|
||||
from prowler.providers.common.arguments import (
|
||||
enforce_invoked_provider_loaded,
|
||||
init_providers_parser,
|
||||
)
|
||||
|
||||
mock_is_builtin.return_value = True
|
||||
mock_import.side_effect = ImportError("No module named 'boto3'")
|
||||
@@ -435,14 +437,14 @@ class TestInitProvidersParserBuiltinDependencyFailure:
|
||||
parser = MagicMock()
|
||||
parser._providers = ["aws"]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"prowler.providers.common.arguments.Provider.get_available_providers",
|
||||
return_value=["aws"],
|
||||
),
|
||||
pytest.raises(SystemExit),
|
||||
with patch(
|
||||
"prowler.providers.common.arguments.Provider.get_available_providers",
|
||||
return_value=["aws"],
|
||||
):
|
||||
init_providers_parser(parser)
|
||||
assert "aws" in parser._builtin_load_failures
|
||||
with pytest.raises(SystemExit):
|
||||
enforce_invoked_provider_loaded(parser)
|
||||
|
||||
@patch("prowler.providers.common.arguments.Provider.is_builtin")
|
||||
@patch("prowler.providers.common.arguments.Provider._load_ep_provider")
|
||||
@@ -466,6 +468,290 @@ class TestInitProvidersParserBuiltinDependencyFailure:
|
||||
|
||||
ext_cls.init_parser.assert_called_once_with(parser)
|
||||
|
||||
@patch("sys.argv", ["prowler", "aws"])
|
||||
@patch("prowler.providers.common.arguments.Provider.is_builtin")
|
||||
@patch("prowler.providers.common.arguments.import_module")
|
||||
def test_unrelated_builtin_failure_does_not_abort_when_other_provider_invoked(
|
||||
self, mock_import, mock_is_builtin
|
||||
):
|
||||
"""Broken stackit + invoked aws → warning, no abort."""
|
||||
from prowler.providers.common.arguments import (
|
||||
enforce_invoked_provider_loaded,
|
||||
init_providers_parser,
|
||||
)
|
||||
|
||||
mock_is_builtin.return_value = True
|
||||
aws_module = MagicMock()
|
||||
|
||||
def import_side_effect(module_path):
|
||||
if "stackit" in module_path:
|
||||
raise ImportError("No module named 'stackit.objectstorage'")
|
||||
return aws_module
|
||||
|
||||
mock_import.side_effect = import_side_effect
|
||||
|
||||
parser = MagicMock()
|
||||
|
||||
with patch(
|
||||
"prowler.providers.common.arguments.Provider.get_available_providers",
|
||||
return_value=["aws", "stackit"],
|
||||
):
|
||||
init_providers_parser(parser)
|
||||
assert "stackit" in parser._builtin_load_failures
|
||||
enforce_invoked_provider_loaded(parser)
|
||||
|
||||
aws_module.init_parser.assert_called_once_with(parser)
|
||||
|
||||
@patch("sys.argv", ["prowler", "-h"])
|
||||
@patch("prowler.providers.common.arguments.Provider.is_builtin")
|
||||
@patch("prowler.providers.common.arguments.import_module")
|
||||
def test_no_provider_invoked_failure_does_not_abort(
|
||||
self, mock_import, mock_is_builtin
|
||||
):
|
||||
"""`prowler -h` + broken built-in → warning, help still renders."""
|
||||
from prowler.providers.common.arguments import (
|
||||
enforce_invoked_provider_loaded,
|
||||
init_providers_parser,
|
||||
)
|
||||
|
||||
mock_is_builtin.return_value = True
|
||||
mock_import.side_effect = ImportError("No module named 'stackit.objectstorage'")
|
||||
|
||||
parser = MagicMock()
|
||||
|
||||
with patch(
|
||||
"prowler.providers.common.arguments.Provider.get_available_providers",
|
||||
return_value=["stackit"],
|
||||
):
|
||||
init_providers_parser(parser)
|
||||
enforce_invoked_provider_loaded(parser)
|
||||
|
||||
@patch("sys.argv", ["prowler", "microsoft365"])
|
||||
@patch("prowler.providers.common.arguments.Provider.is_builtin")
|
||||
@patch("prowler.providers.common.arguments.import_module")
|
||||
def test_invoked_microsoft365_alias_still_triggers_fail_loud(
|
||||
self, mock_import, mock_is_builtin
|
||||
):
|
||||
"""Alias `microsoft365 → m365` must be normalised before matching."""
|
||||
from prowler.providers.common.arguments import (
|
||||
enforce_invoked_provider_loaded,
|
||||
init_providers_parser,
|
||||
)
|
||||
|
||||
mock_is_builtin.return_value = True
|
||||
mock_import.side_effect = ImportError("No module named 'msgraph'")
|
||||
|
||||
parser = MagicMock()
|
||||
|
||||
with patch(
|
||||
"prowler.providers.common.arguments.Provider.get_available_providers",
|
||||
return_value=["m365"],
|
||||
):
|
||||
init_providers_parser(parser)
|
||||
with pytest.raises(SystemExit):
|
||||
enforce_invoked_provider_loaded(parser)
|
||||
|
||||
@patch("sys.argv", ["prowler", "oci"])
|
||||
@patch("prowler.providers.common.arguments.Provider.is_builtin")
|
||||
@patch("prowler.providers.common.arguments.import_module")
|
||||
def test_invoked_oci_alias_still_triggers_fail_loud(
|
||||
self, mock_import, mock_is_builtin
|
||||
):
|
||||
"""Alias `oci → oraclecloud` must be normalised before matching."""
|
||||
from prowler.providers.common.arguments import (
|
||||
enforce_invoked_provider_loaded,
|
||||
init_providers_parser,
|
||||
)
|
||||
|
||||
mock_is_builtin.return_value = True
|
||||
mock_import.side_effect = ImportError("No module named 'oci'")
|
||||
|
||||
parser = MagicMock()
|
||||
|
||||
with patch(
|
||||
"prowler.providers.common.arguments.Provider.get_available_providers",
|
||||
return_value=["oraclecloud"],
|
||||
):
|
||||
init_providers_parser(parser)
|
||||
with pytest.raises(SystemExit):
|
||||
enforce_invoked_provider_loaded(parser)
|
||||
|
||||
@patch("sys.argv", ["prowler", "--output-directory", "stackit"])
|
||||
@patch("prowler.providers.common.arguments.Provider.is_builtin")
|
||||
@patch("prowler.providers.common.arguments.import_module")
|
||||
def test_flag_value_matching_provider_name_not_treated_as_invoked(
|
||||
self, mock_import, mock_is_builtin
|
||||
):
|
||||
"""Flag-first invocation → invoked is 'aws' (default), not the flag's value."""
|
||||
from prowler.providers.common.arguments import (
|
||||
enforce_invoked_provider_loaded,
|
||||
init_providers_parser,
|
||||
)
|
||||
|
||||
mock_is_builtin.return_value = True
|
||||
aws_module = MagicMock()
|
||||
|
||||
def import_side_effect(module_path):
|
||||
if "stackit" in module_path:
|
||||
raise ImportError("No module named 'stackit.objectstorage'")
|
||||
return aws_module
|
||||
|
||||
mock_import.side_effect = import_side_effect
|
||||
|
||||
parser = MagicMock()
|
||||
|
||||
with patch(
|
||||
"prowler.providers.common.arguments.Provider.get_available_providers",
|
||||
return_value=["aws", "stackit"],
|
||||
):
|
||||
init_providers_parser(parser)
|
||||
enforce_invoked_provider_loaded(parser)
|
||||
|
||||
aws_module.init_parser.assert_called_once_with(parser)
|
||||
|
||||
@patch("sys.argv", ["prowler", "aws"])
|
||||
@patch("prowler.providers.common.arguments.Provider.is_builtin")
|
||||
@patch("prowler.providers.common.arguments.import_module")
|
||||
def test_invoked_builtin_non_import_error_fails_loudly(
|
||||
self, mock_import, mock_is_builtin
|
||||
):
|
||||
"""Non-ImportError in invoked provider → still fail-loud."""
|
||||
from prowler.providers.common.arguments import (
|
||||
enforce_invoked_provider_loaded,
|
||||
init_providers_parser,
|
||||
)
|
||||
|
||||
mock_is_builtin.return_value = True
|
||||
mock_import.side_effect = RuntimeError("Unexpected error in aws init_parser")
|
||||
|
||||
parser = MagicMock()
|
||||
|
||||
with patch(
|
||||
"prowler.providers.common.arguments.Provider.get_available_providers",
|
||||
return_value=["aws"],
|
||||
):
|
||||
init_providers_parser(parser)
|
||||
with pytest.raises(SystemExit):
|
||||
enforce_invoked_provider_loaded(parser)
|
||||
|
||||
@patch("sys.argv", ["prowler", "aws"])
|
||||
@patch("prowler.providers.common.arguments.Provider.is_builtin")
|
||||
@patch("prowler.providers.common.arguments.import_module")
|
||||
def test_unrelated_builtin_non_import_error_does_not_abort(
|
||||
self, mock_import, mock_is_builtin
|
||||
):
|
||||
"""Non-ImportError in unrelated provider → warning, no abort."""
|
||||
from prowler.providers.common.arguments import (
|
||||
enforce_invoked_provider_loaded,
|
||||
init_providers_parser,
|
||||
)
|
||||
|
||||
mock_is_builtin.return_value = True
|
||||
aws_module = MagicMock()
|
||||
|
||||
def import_side_effect(module_path):
|
||||
if "stackit" in module_path:
|
||||
raise RuntimeError("Unexpected error in stackit init_parser")
|
||||
return aws_module
|
||||
|
||||
mock_import.side_effect = import_side_effect
|
||||
|
||||
parser = MagicMock()
|
||||
|
||||
with patch(
|
||||
"prowler.providers.common.arguments.Provider.get_available_providers",
|
||||
return_value=["aws", "stackit"],
|
||||
):
|
||||
init_providers_parser(parser)
|
||||
enforce_invoked_provider_loaded(parser)
|
||||
|
||||
aws_module.init_parser.assert_called_once_with(parser)
|
||||
|
||||
|
||||
class TestParseArgsOverrideAlignment:
|
||||
"""Regression: `parse(args=...)` overrides sys.argv AFTER __init__ ran;
|
||||
the selective fail-loud must read argv at enforce time, not init time."""
|
||||
|
||||
def test_enforce_reads_current_sys_argv_not_init_time_sys_argv(self):
|
||||
"""Init with argv=['prowler','-h'] (no provider) captures stackit
|
||||
failure silently. Enforce with argv=['prowler','stackit'] must
|
||||
fail-loud — proving alignment under parse(args=...)."""
|
||||
from prowler.providers.common.arguments import (
|
||||
enforce_invoked_provider_loaded,
|
||||
init_providers_parser,
|
||||
)
|
||||
|
||||
def import_side_effect(path):
|
||||
if "stackit" in path:
|
||||
raise ImportError("No module named 'stackit.objectstorage'")
|
||||
return MagicMock()
|
||||
|
||||
parser = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"prowler.providers.common.arguments.Provider.is_builtin",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"prowler.providers.common.arguments.Provider.get_available_providers",
|
||||
return_value=["aws", "stackit"],
|
||||
),
|
||||
patch(
|
||||
"prowler.providers.common.arguments.import_module",
|
||||
side_effect=import_side_effect,
|
||||
),
|
||||
):
|
||||
# Phase 1: __init__ with ambient argv = ['prowler', '-h']
|
||||
with patch("sys.argv", ["prowler", "-h"]):
|
||||
init_providers_parser(parser)
|
||||
# Failure captured silently — no SystemExit during init
|
||||
assert "stackit" in parser._builtin_load_failures
|
||||
|
||||
# Phase 2: parse(args=...) overrode sys.argv → stackit invoked
|
||||
with patch("sys.argv", ["prowler", "stackit"]):
|
||||
with pytest.raises(SystemExit):
|
||||
enforce_invoked_provider_loaded(parser)
|
||||
|
||||
def test_enforce_reads_current_sys_argv_for_no_invocation(self):
|
||||
"""Inverse: init's argv invokes stackit, but parse(args=['prowler',
|
||||
'-h']) overrides. Enforce must NOT fail-loud."""
|
||||
from prowler.providers.common.arguments import (
|
||||
enforce_invoked_provider_loaded,
|
||||
init_providers_parser,
|
||||
)
|
||||
|
||||
def import_side_effect(path):
|
||||
if "stackit" in path:
|
||||
raise ImportError("No module named 'stackit.objectstorage'")
|
||||
return MagicMock()
|
||||
|
||||
parser = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"prowler.providers.common.arguments.Provider.is_builtin",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"prowler.providers.common.arguments.Provider.get_available_providers",
|
||||
return_value=["aws", "stackit"],
|
||||
),
|
||||
patch(
|
||||
"prowler.providers.common.arguments.import_module",
|
||||
side_effect=import_side_effect,
|
||||
),
|
||||
):
|
||||
# Phase 1: __init__ with ambient argv pretending stackit invoked
|
||||
with patch("sys.argv", ["prowler", "stackit"]):
|
||||
init_providers_parser(parser)
|
||||
assert "stackit" in parser._builtin_load_failures
|
||||
|
||||
# Phase 2: parse(args=['prowler', '-h']) overrode sys.argv →
|
||||
# no provider invoked anymore → enforce must NOT exit
|
||||
with patch("sys.argv", ["prowler", "-h"]):
|
||||
enforce_invoked_provider_loaded(parser)
|
||||
|
||||
|
||||
class TestInitGlobalProviderBuiltinDependencyFailure:
|
||||
"""Same contract as TestInitProvidersParserBuiltinDependencyFailure but
|
||||
|
||||
@@ -13,6 +13,7 @@ from prowler.config.config import (
|
||||
)
|
||||
from prowler.providers.common.models import Connection
|
||||
from prowler.providers.gcp.exceptions.exceptions import (
|
||||
GCPGetOrganizationProjectsError,
|
||||
GCPInvalidProviderIdError,
|
||||
GCPNoAccesibleProjectsError,
|
||||
GCPTestConnectionError,
|
||||
@@ -1077,3 +1078,66 @@ class TestGCPProvider:
|
||||
|
||||
assert gcp_provider.skip_api_check is True
|
||||
mocked_is_api_active.assert_not_called()
|
||||
|
||||
def test_get_projects_organization_id_permission_denied_raises(self):
|
||||
"""When --organization-id is set and the Cloud Asset API returns a 403,
|
||||
get_projects must raise GCPGetOrganizationProjectsError instead of
|
||||
silently falling back to the service account's home project.
|
||||
|
||||
Regression test for https://github.com/prowler-cloud/prowler/issues/11250.
|
||||
"""
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
forbidden_response = MagicMock(status=403, reason="Forbidden")
|
||||
http_error = HttpError(
|
||||
resp=forbidden_response,
|
||||
content=b'{"error": {"code": 403, "message": "Permission denied on resource organization"}}',
|
||||
uri="https://cloudasset.googleapis.com/v1/organizations/123:listAssets",
|
||||
)
|
||||
|
||||
asset_service = MagicMock()
|
||||
asset_service.assets.return_value.list.return_value.execute.side_effect = (
|
||||
http_error
|
||||
)
|
||||
|
||||
with patch(
|
||||
"prowler.providers.gcp.gcp_provider.discovery.build",
|
||||
return_value=asset_service,
|
||||
):
|
||||
with pytest.raises(GCPGetOrganizationProjectsError):
|
||||
GcpProvider.get_projects(
|
||||
credentials=MagicMock(),
|
||||
organization_id="test-organization-id",
|
||||
credentials_file="test_credentials_file",
|
||||
)
|
||||
|
||||
def test_get_projects_organization_id_cloud_asset_api_disabled_raises(self):
|
||||
"""When --organization-id is set and the Cloud Asset API is disabled,
|
||||
get_projects must raise GCPGetOrganizationProjectsError with the
|
||||
enable-API remediation rather than swallowing the error."""
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
disabled_response = MagicMock(status=403, reason="Forbidden")
|
||||
http_error = HttpError(
|
||||
resp=disabled_response,
|
||||
content=b'{"error": {"message": "Cloud Asset API has not been used in project 123 before or it is disabled."}}',
|
||||
uri="https://cloudasset.googleapis.com/v1/organizations/123:listAssets",
|
||||
)
|
||||
|
||||
asset_service = MagicMock()
|
||||
asset_service.assets.return_value.list.return_value.execute.side_effect = (
|
||||
http_error
|
||||
)
|
||||
|
||||
with patch(
|
||||
"prowler.providers.gcp.gcp_provider.discovery.build",
|
||||
return_value=asset_service,
|
||||
):
|
||||
with pytest.raises(GCPGetOrganizationProjectsError) as exc_info:
|
||||
GcpProvider.get_projects(
|
||||
credentials=MagicMock(),
|
||||
organization_id="test-organization-id",
|
||||
credentials_file="test_credentials_file",
|
||||
)
|
||||
|
||||
assert "Cloud Asset API" in str(exc_info.value)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from prowler.providers.gcp.services.logging.logging_service import Logging
|
||||
from tests.providers.gcp.gcp_fixtures import (
|
||||
GCP_PROJECT_ID,
|
||||
@@ -291,6 +293,93 @@ class TestGetProjectsCoveredByAggregatedMetric:
|
||||
)
|
||||
assert self._run(logging_client, monitoring_client) == {}
|
||||
|
||||
def test_not_covered_when_sink_filter_omits_activity_stream(self):
|
||||
"""A sink that routes cloudaudit streams but NOT Admin Activity (here,
|
||||
data_access only) does not deliver the entries the CIS metric filters
|
||||
match, so it must not be credited — right service, wrong stream."""
|
||||
logging_client, monitoring_client = self._clients(
|
||||
sink_filter="logName: /logs/cloudaudit.googleapis.com%2Fdata_access"
|
||||
)
|
||||
assert self._run(logging_client, monitoring_client) == {}
|
||||
|
||||
def test_covered_when_sink_filter_carries_activity_stream_encoded(self):
|
||||
"""A sink filtered to the cloudaudit streams (URL-encoded logName form,
|
||||
as returned by the Logging API) delivers every Admin Activity entry the
|
||||
CIS metric filters can match, so it must be credited."""
|
||||
logging_client, monitoring_client = self._clients(
|
||||
sink_filter=(
|
||||
"logName: /logs/cloudaudit.googleapis.com%2Factivity OR "
|
||||
"logName: /logs/cloudaudit.googleapis.com%2Fdata_access"
|
||||
)
|
||||
)
|
||||
assert self._run(logging_client, monitoring_client) == {
|
||||
GCP_PROJECT_ID: "central-metric"
|
||||
}
|
||||
|
||||
def test_covered_when_sink_filter_carries_activity_stream_plain(self):
|
||||
logging_client, monitoring_client = self._clients(
|
||||
sink_filter='logName="projects/p/logs/cloudaudit.googleapis.com/activity"'
|
||||
)
|
||||
assert self._run(logging_client, monitoring_client) == {
|
||||
GCP_PROJECT_ID: "central-metric"
|
||||
}
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sink_filter",
|
||||
[
|
||||
# --- Negation: the stream is named but excluded. ---
|
||||
'NOT logName:"projects/p/logs/cloudaudit.googleapis.com%2Factivity"',
|
||||
'-logName:"projects/p/logs/cloudaudit.googleapis.com%2Factivity"',
|
||||
'NOT log_id("cloudaudit.googleapis.com/activity")',
|
||||
# "!=" inequality (and its spaced form) excludes the stream.
|
||||
'logName!="projects/p/logs/cloudaudit.googleapis.com%2Factivity"',
|
||||
'logName != "projects/p/logs/cloudaudit.googleapis.com/activity"',
|
||||
# Activity negated inside a compound filter.
|
||||
'resource.type="gce_instance" AND '
|
||||
'NOT logName:"projects/p/logs/cloudaudit.googleapis.com%2Factivity"',
|
||||
# --- Restriction: the stream is named but AND-narrowed, so only a
|
||||
# subset of Admin Activity entries reaches the bucket. ---
|
||||
'logName:"projects/p/logs/cloudaudit.googleapis.com%2Factivity" '
|
||||
'AND resource.type="gce_instance"',
|
||||
'log_id("cloudaudit.googleapis.com/activity") '
|
||||
'AND resource.type="gce_instance"',
|
||||
'logName="projects/p/logs/cloudaudit.googleapis.com/activity" '
|
||||
"AND severity>=ERROR",
|
||||
'logName:"projects/p/logs/cloudaudit.googleapis.com%2Factivity" '
|
||||
'AND protoPayload.methodName="SetIamPolicy"',
|
||||
# --- OR-ed with a non-audit predicate: fail closed, since we credit
|
||||
# only unions of provable Cloud Audit stream selectors. ---
|
||||
'logName:"projects/p/logs/cloudaudit.googleapis.com%2Factivity" '
|
||||
"OR severity>=ERROR",
|
||||
],
|
||||
)
|
||||
def test_not_covered_when_sink_filter_negated_or_restrictive(self, sink_filter):
|
||||
"""A filter that names the Admin Activity stream but negates, narrows, or
|
||||
mixes in an unprovable predicate is not credited — we credit only filters
|
||||
we can prove deliver every Admin Activity entry the CIS metrics match."""
|
||||
logging_client, monitoring_client = self._clients(sink_filter=sink_filter)
|
||||
assert self._run(logging_client, monitoring_client) == {}
|
||||
|
||||
def test_covered_when_activity_logname_has_hyphenated_path(self):
|
||||
"""A hyphen in the project path must not be mistaken for the ``-`` (NOT)
|
||||
negation operator — the activity stream is still delivered."""
|
||||
logging_client, monitoring_client = self._clients(
|
||||
sink_filter='logName="projects/my-project/logs/cloudaudit.googleapis.com/activity"'
|
||||
)
|
||||
assert self._run(logging_client, monitoring_client) == {
|
||||
GCP_PROJECT_ID: "central-metric"
|
||||
}
|
||||
|
||||
def test_covered_when_sink_filter_uses_log_id_selector(self):
|
||||
"""The ``log_id()`` form is an equivalent positive full-coverage selector
|
||||
of the Admin Activity stream and is credited like the ``logName`` form."""
|
||||
logging_client, monitoring_client = self._clients(
|
||||
sink_filter='log_id("cloudaudit.googleapis.com/activity")'
|
||||
)
|
||||
assert self._run(logging_client, monitoring_client) == {
|
||||
GCP_PROJECT_ID: "central-metric"
|
||||
}
|
||||
|
||||
def test_not_covered_when_sink_destination_bucket_differs(self):
|
||||
logging_client, monitoring_client = self._clients(
|
||||
sink_destination="logging.googleapis.com/projects/x/locations/eu/buckets/other"
|
||||
|
||||
+10
-1
@@ -2,6 +2,15 @@
|
||||
|
||||
All notable changes to the **Prowler UI** are documented in this file.
|
||||
|
||||
## [1.30.1] (Prowler v5.30.1)
|
||||
|
||||
### 🐞 Fixed
|
||||
|
||||
- Threat Map no longer shows an empty map for accounts that only have Okta or Google Workspace scans [(#11542)](https://github.com/prowler-cloud/prowler/pull/11542)
|
||||
- Compliance attributes requests now pass the selected scan, so multi-provider universal frameworks (e.g. CSA CCM) load the check IDs of the scan's provider and Azure/GCP requirement details show their findings instead of appearing empty [(#11546)](https://github.com/prowler-cloud/prowler/pull/11546)
|
||||
|
||||
---
|
||||
|
||||
## [1.30.0] (Prowler v5.30.0)
|
||||
|
||||
### 🚀 Added
|
||||
@@ -12,7 +21,7 @@ All notable changes to the **Prowler UI** are documented in this file.
|
||||
### 🔄 Changed
|
||||
|
||||
- Renamed "Customer Support" to "Support Desk" in the side menu, showing it only in Prowler Cloud/Enterprise, while "Community Support" now shows only in Prowler OSS [(#11508)](https://github.com/prowler-cloud/prowler/pull/11508)
|
||||
- Compliance detail page now shows a "still loading" retry state while the API warms its compliance catalog, instead of rendering an empty page [(#4554)](https://github.com/prowler-cloud/prowler-cloud/pull/4554)
|
||||
- Compliance detail page now shows a "still loading" retry state while the API warms its compliance catalog, instead of rendering an empty page [(#11530)](https://github.com/prowler-cloud/prowler/pull/11530)
|
||||
|
||||
### 🐞 Fixed
|
||||
|
||||
|
||||
@@ -73,12 +73,21 @@ export const getComplianceOverviewMetadataInfo = async ({
|
||||
}
|
||||
};
|
||||
|
||||
export const getComplianceAttributes = async (complianceId: string) => {
|
||||
export const getComplianceAttributes = async (
|
||||
complianceId: string,
|
||||
scanId?: string,
|
||||
) => {
|
||||
const headers = await getAuthHeaders({ contentType: false });
|
||||
|
||||
try {
|
||||
const url = new URL(`${apiBaseUrl}/compliance-overviews/attributes`);
|
||||
url.searchParams.append("filter[compliance_id]", complianceId);
|
||||
// Pass the scan so multi-provider universal frameworks (e.g. CSA CCM)
|
||||
// resolve the check IDs for the scan's provider instead of defaulting to
|
||||
// the first provider that declares the framework.
|
||||
if (scanId) {
|
||||
url.searchParams.append("filter[scan_id]", scanId);
|
||||
}
|
||||
|
||||
const response = await fetch(url.toString(), {
|
||||
headers,
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
|
||||
import { adaptRegionsOverviewToThreatMap } from "./threat-map.adapter";
|
||||
import type { RegionsOverviewResponse } from "./types";
|
||||
|
||||
function buildRegionsResponse(
|
||||
rows: Array<{ providerType: string; region: string }>,
|
||||
): RegionsOverviewResponse {
|
||||
return {
|
||||
data: rows.map(({ providerType, region }, index) => ({
|
||||
type: "regions-overview",
|
||||
id: `region-${index}`,
|
||||
attributes: {
|
||||
provider_type: providerType,
|
||||
region,
|
||||
total: 10,
|
||||
fail: 4,
|
||||
muted: 0,
|
||||
pass: 6,
|
||||
},
|
||||
})),
|
||||
meta: { version: "v1" },
|
||||
};
|
||||
}
|
||||
|
||||
describe("adaptRegionsOverviewToThreatMap", () => {
|
||||
it("maps okta regions to a global location", () => {
|
||||
const response = buildRegionsResponse([
|
||||
{ providerType: "okta", region: "global" },
|
||||
]);
|
||||
|
||||
const result = adaptRegionsOverviewToThreatMap(response);
|
||||
|
||||
expect(result.locations).toHaveLength(1);
|
||||
expect(result.locations[0]).toMatchObject({
|
||||
providerType: "okta",
|
||||
region: "global",
|
||||
name: "Okta - Global",
|
||||
totalFindings: 10,
|
||||
failFindings: 4,
|
||||
});
|
||||
expect(result.regions).toEqual(["global"]);
|
||||
});
|
||||
|
||||
it("maps googleworkspace regions to a global location", () => {
|
||||
const response = buildRegionsResponse([
|
||||
{ providerType: "googleworkspace", region: "global" },
|
||||
]);
|
||||
|
||||
const result = adaptRegionsOverviewToThreatMap(response);
|
||||
|
||||
expect(result.locations).toHaveLength(1);
|
||||
expect(result.locations[0]).toMatchObject({
|
||||
providerType: "googleworkspace",
|
||||
region: "global",
|
||||
name: "Google Workspace - Global",
|
||||
totalFindings: 10,
|
||||
failFindings: 4,
|
||||
});
|
||||
expect(result.regions).toEqual(["global"]);
|
||||
});
|
||||
});
|
||||
@@ -261,6 +261,19 @@ const ALIBABACLOUD_COORDINATES: Record<string, { lat: number; lng: number }> = {
|
||||
global: { lat: 30.3, lng: 120.2 }, // Global fallback (Hangzhou HQ)
|
||||
};
|
||||
|
||||
// Okta is a SaaS identity platform without user-facing regions
|
||||
const OKTA_COORDINATES: Record<string, { lat: number; lng: number }> = {
|
||||
global: { lat: 37.8, lng: -122.4 }, // Global fallback (San Francisco HQ)
|
||||
};
|
||||
|
||||
// Google Workspace is a SaaS suite without user-facing regions
|
||||
const GOOGLEWORKSPACE_COORDINATES: Record<
|
||||
string,
|
||||
{ lat: number; lng: number }
|
||||
> = {
|
||||
global: { lat: 37.4, lng: -122.1 }, // Global fallback (Mountain View HQ)
|
||||
};
|
||||
|
||||
const PROVIDER_COORDINATES: Record<
|
||||
string,
|
||||
Record<string, { lat: number; lng: number }>
|
||||
@@ -277,6 +290,8 @@ const PROVIDER_COORDINATES: Record<
|
||||
oraclecloud: ORACLECLOUD_COORDINATES,
|
||||
mongodbatlas: MONGODBATLAS_COORDINATES,
|
||||
alibabacloud: ALIBABACLOUD_COORDINATES,
|
||||
okta: OKTA_COORDINATES,
|
||||
googleworkspace: GOOGLEWORKSPACE_COORDINATES,
|
||||
};
|
||||
|
||||
// Returns [lng, lat] format for D3/GeoJSON compatibility
|
||||
|
||||
@@ -87,7 +87,7 @@ export default async function ComplianceDetail({
|
||||
"filter[scan_id]": selectedScanId ?? undefined,
|
||||
},
|
||||
}),
|
||||
getComplianceAttributes(complianceId),
|
||||
getComplianceAttributes(complianceId, selectedScanId ?? undefined),
|
||||
selectedScanId
|
||||
? getScan(selectedScanId, { include: "provider" })
|
||||
: Promise.resolve(null),
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
|
||||
import { ThreatMap } from "./threat-map";
|
||||
import type { ThreatMapData } from "./threat-map.types";
|
||||
|
||||
vi.mock("next/navigation", () => ({
|
||||
useRouter: () => ({ push: vi.fn() }),
|
||||
useSearchParams: () => new URLSearchParams(),
|
||||
}));
|
||||
|
||||
vi.mock("./horizontal-bar-chart", () => ({
|
||||
HorizontalBarChart: () => <div data-testid="bar-chart" />,
|
||||
}));
|
||||
|
||||
function buildLocation(providerType: string, region: string) {
|
||||
return {
|
||||
id: `${providerType}-${region}`,
|
||||
name: `${providerType} - ${region}`,
|
||||
region,
|
||||
regionCode: region,
|
||||
providerType,
|
||||
coordinates: [-122.4, 37.8] as [number, number],
|
||||
totalFindings: 10,
|
||||
failFindings: 4,
|
||||
riskLevel: "high" as const,
|
||||
severityData: [
|
||||
{ name: "Fail", value: 4, percentage: 40 },
|
||||
{ name: "Pass", value: 6, percentage: 60 },
|
||||
],
|
||||
};
|
||||
}
|
||||
|
||||
describe("ThreatMap region selector", () => {
|
||||
it("auto-selects the region when it is the only one available", () => {
|
||||
const data: ThreatMapData = {
|
||||
locations: [
|
||||
buildLocation("okta", "global"),
|
||||
buildLocation("googleworkspace", "global"),
|
||||
],
|
||||
regions: ["global"],
|
||||
};
|
||||
|
||||
render(<ThreatMap data={data} />);
|
||||
|
||||
const select = screen.getByRole("combobox", {
|
||||
name: "Filter threat map by region",
|
||||
});
|
||||
expect(select).toHaveValue("global");
|
||||
expect(screen.getByText("Global Regions")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.queryByText("Select a location on the map to view details"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("keeps All Regions as default when there are multiple regions", () => {
|
||||
const data: ThreatMapData = {
|
||||
locations: [
|
||||
buildLocation("aws", "us-east-1"),
|
||||
buildLocation("okta", "global"),
|
||||
],
|
||||
regions: ["global", "us-east-1"],
|
||||
};
|
||||
|
||||
render(<ThreatMap data={data} />);
|
||||
|
||||
const select = screen.getByRole("combobox", {
|
||||
name: "Filter threat map by region",
|
||||
});
|
||||
expect(select).toHaveValue("All Regions");
|
||||
expect(
|
||||
screen.getByRole("option", { name: "All Regions" }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("shows the global option capitalized while keeping its filter value", () => {
|
||||
const data: ThreatMapData = {
|
||||
locations: [
|
||||
buildLocation("aws", "us-east-1"),
|
||||
buildLocation("okta", "global"),
|
||||
],
|
||||
regions: ["global", "us-east-1"],
|
||||
};
|
||||
|
||||
render(<ThreatMap data={data} />);
|
||||
|
||||
const globalOption = screen.getByRole("option", { name: "Global" });
|
||||
expect(globalOption).toHaveValue("global");
|
||||
expect(
|
||||
screen.getByRole("option", { name: "us-east-1" }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
@@ -124,7 +124,11 @@ export function ThreatMap({
|
||||
x: number;
|
||||
y: number;
|
||||
} | null>(null);
|
||||
const [selectedRegion, setSelectedRegion] = useState("All Regions");
|
||||
// With a single region "All Regions" adds nothing, so it starts selected
|
||||
const hasSingleRegion = data.regions.length === 1;
|
||||
const [selectedRegion, setSelectedRegion] = useState(
|
||||
hasSingleRegion ? data.regions[0] : "All Regions",
|
||||
);
|
||||
const [worldData, setWorldData] = useState<FeatureCollection | null>(null);
|
||||
const [isLoadingMap, setIsLoadingMap] = useState(true);
|
||||
const [dimensions, setDimensions] = useState<{
|
||||
@@ -424,10 +428,12 @@ export function ThreatMap({
|
||||
onChange={(e) => setSelectedRegion(e.target.value)}
|
||||
className="border-border-neutral-primary bg-bg-neutral-secondary text-text-neutral-primary appearance-none rounded-lg border px-4 py-2 pr-10 text-sm focus:outline-none focus-visible:ring-2 focus-visible:ring-offset-2"
|
||||
>
|
||||
<option value="All Regions">All Regions</option>
|
||||
{!hasSingleRegion && (
|
||||
<option value="All Regions">All Regions</option>
|
||||
)}
|
||||
{sortedRegions.map((region) => (
|
||||
<option key={region} value={region}>
|
||||
{region}
|
||||
{region.toLowerCase() === "global" ? "Global" : region}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
@@ -467,7 +473,7 @@ export function ThreatMap({
|
||||
<div className="border-border-neutral-primary bg-bg-neutral-secondary absolute bottom-4 left-4 flex items-center gap-2 rounded-full border px-3 py-1.5">
|
||||
<div
|
||||
aria-hidden="true"
|
||||
className="bg-data-critical h-3 w-3 rounded"
|
||||
className="bg-bg-data-critical h-3 w-3 rounded"
|
||||
/>
|
||||
<span className="text-text-neutral-primary text-sm font-medium">
|
||||
{locationCount} Locations
|
||||
|
||||
Reference in New Issue
Block a user