Compare commits

...

27 Commits

Author SHA1 Message Date
Pepe Fagoaga 9b0487eb9a chore(postgres): use pgbouncer 2025-06-09 13:08:57 +02:00
Pablo Lara 3a99909b75 chore: align Next.js version to 14.2.29 across Prowler and Cloud (#7962) 2025-06-06 13:54:42 +02:00
Pablo Lara 2ecd9ad2c5 docs: update changelog (#7960) 2025-06-06 13:17:38 +02:00
Alejandro Bailo 50dc396aa3 feat: scan id filter drowpdown (#7949)
Co-authored-by: Pablo Lara <larabjj@gmail.com>
2025-06-06 12:38:14 +02:00
Andoni Alonso acf333493a chore(api): reorder docker layers to speed up build times (#7957) 2025-06-06 10:42:14 +02:00
Pedro Martín bd6272f5a7 feat(docs): add information about tenants and read-only roles (#7956) 2025-06-06 10:14:33 +02:00
Pepe Fagoaga 8c95e1efaf chore: update API changelog for v5.7.3 (#7948) 2025-06-05 15:54:36 +02:00
Hugo Pereira Brito 845a0aa0d5 fix(changelog): add entries for password encryption in v5.7.3 (#7939)
Co-authored-by: Víctor Fernández Poyatos <victor@prowler.com>
2025-06-05 14:23:12 +02:00
Hugo Pereira Brito 75a11be9e6 fix(docs): add final permission assignments example (#7943) 2025-06-05 18:07:43 +05:45
Hugo Pereira Brito a778d005b6 fix(docs): add mfa warning for users (#7924)
Co-authored-by: Pepe Fagoaga <pepe@prowler.com>
2025-06-05 17:55:27 +05:45
Pedro Martín 1281f4ec5e chore(changelog): update following the correct format (#7908) 2025-06-05 17:52:36 +05:45
Víctor Fernández Poyatos 6332427e5e fix(compliance): add manual status to requirements (#7938) 2025-06-05 10:54:51 +02:00
Alejandro Bailo d89df83904 fix: Improve the perfomance removing regions heatmap (#7934) 2025-06-05 08:13:47 +02:00
Víctor Fernández Poyatos be420afebc feat(database): handle already closed connections (#7935) 2025-06-04 16:09:36 +02:00
Adrián Jesús Peña Rodríguez fb914a2c90 revert: remove get_with_retry (#7932) 2025-06-04 15:01:47 +02:00
Pablo Lara 4ac3cfc33d docs: update changelog (#7931) 2025-06-04 13:54:25 +02:00
Alejandro Bailo c74360ab63 fix: clear filters sync (#7928) 2025-06-04 13:32:52 +02:00
Alejandro Bailo 4dc4d82d42 feat: aws-well-architected compliance detailed view (#7925) 2025-06-04 12:26:27 +02:00
Víctor Fernández Poyatos 6e7a32cb51 revert(views): calling order to initial view method (#7921) 2025-06-03 16:38:00 +02:00
Alejandro Bailo 49e501c4be feat: CIS compliance detail view (#7913)
Co-authored-by: Víctor Fernández Poyatos <victor@prowler.com>
2025-06-03 15:47:46 +02:00
Víctor Fernández Poyatos 9ee78fe65f fix(views): calling order to initial view method (#7918) 2025-06-03 13:34:44 +02:00
Víctor Fernández Poyatos 7a0549d39c fix(rls): Apply persistent RLS transactions (#7916) 2025-06-03 13:10:41 +02:00
Alejandro Bailo 3e8c86d880 feat: ISO compliance detail view (#7897)
Co-authored-by: Víctor Fernández Poyatos <victor@prowler.com>
Co-authored-by: Pablo Lara <larabjj@gmail.com>
2025-06-03 09:20:52 +02:00
Pablo Lara e34c18757d fix: Fix named export for addCredentialsServiceAccountFormSchema (#7909) 2025-06-03 08:33:24 +02:00
Alejandro Bailo 5c1a47d108 feat: compliance detail view + ENS (#7853)
Co-authored-by: Víctor Fernández Poyatos <victor@prowler.com>
2025-06-02 18:20:22 +02:00
Víctor Fernández Poyatos 59c51d5a4a feat(compliance): Rework compliance overviews (#7877) 2025-06-02 17:06:24 +02:00
Pedro Martín 66aa67f636 feat(changelog): update version with fixes (#7904)
Co-authored-by: Rubén De la Torre Vico <ruben@prowler.com>
2025-06-02 12:32:45 +02:00
110 changed files with 8405 additions and 1355 deletions
+1 -1
View File
@@ -16,7 +16,7 @@ AUTH_SECRET="N/c6mnaS5+SWq81+819OrzQZlmx1Vxtp/orjttJSmw8="
PROWLER_API_VERSION="stable"
# PostgreSQL settings
# If running Django and celery on host, use 'localhost', else use 'postgres-db'
POSTGRES_HOST=postgres-db
POSTGRES_HOST=postgres-db-proxy
POSTGRES_PORT=5432
POSTGRES_ADMIN_USER=prowler_admin
POSTGRES_ADMIN_PASSWORD=postgres
+15 -1
View File
@@ -6,13 +6,27 @@ All notable changes to the **Prowler API** are documented in this file.
### Added
- Support GCP Service Account key. [(#7824)](https://github.com/prowler-cloud/prowler/pull/7824)
- Added new `GET /compliance-overviews` endpoints to retrieve compliance metadata and specific requirements statuses [(#7877)](https://github.com/prowler-cloud/prowler/pull/7877).
### Changed
- Reworked `GET /compliance-overviews` to return proper requirement metrics [(#7877)](https://github.com/prowler-cloud/prowler/pull/7877).
---
## [v1.8.3] (Prowler v5.7.3)
### Added
- Database backend to handle already closed connections [(#7935)](https://github.com/prowler-cloud/prowler/pull/7935).
### Changed
- Renamed field encrypted_password to password for M365 provider [(#7784)](https://github.com/prowler-cloud/prowler/pull/7784)
### Fixed
- Fixed transaction persistence with RLS operations [(#7916)](https://github.com/prowler-cloud/prowler/pull/7916).
- Reverted the change `get_with_retry` to use the original `get` method for retrieving tasks [(#7932)](https://github.com/prowler-cloud/prowler/pull/7932).
- Fixed the connection status verification before launching a scan [(#7831)](https://github.com/prowler-cloud/prowler/pull/7831)
---
## [v1.8.2] (Prowler v5.7.2)
@@ -28,7 +42,7 @@ All notable changes to the **Prowler API** are documented in this file.
## [v1.8.1] (Prowler v5.7.1)
### Fixed
- Added database index to improve performance on finding lookup. [(#7800)](https://github.com/prowler-cloud/prowler/pull/7800)
- Added database index to improve performance on finding lookup [(#7800)](https://github.com/prowler-cloud/prowler/pull/7800).
---
+3 -4
View File
@@ -37,18 +37,17 @@ COPY pyproject.toml ./
RUN pip install --no-cache-dir --upgrade pip && \
pip install --no-cache-dir poetry
COPY src/backend/ ./backend/
ENV PATH="/home/prowler/.local/bin:$PATH"
# Add `--no-root` to avoid installing the current project as a package
RUN poetry install --no-root && \
rm -rf ~/.cache/pip
COPY docker-entrypoint.sh ./docker-entrypoint.sh
RUN poetry run python "$(poetry env info --path)/src/prowler/prowler/providers/m365/lib/powershell/m365_powershell.py"
COPY src/backend/ ./backend/
COPY docker-entrypoint.sh ./docker-entrypoint.sh
WORKDIR /home/prowler/backend
# Development image
+46 -20
View File
@@ -1,5 +1,4 @@
from django.core.exceptions import ObjectDoesNotExist
from django.db import transaction
from rest_framework import permissions
from rest_framework.exceptions import NotAuthenticated
from rest_framework.filters import SearchFilter
@@ -47,11 +46,9 @@ class BaseViewSet(ModelViewSet):
class BaseRLSViewSet(BaseViewSet):
def dispatch(self, request, *args, **kwargs):
with transaction.atomic():
return super().dispatch(request, *args, **kwargs)
def initial(self, request, *args, **kwargs):
super().initial(request, *args, **kwargs)
# Ideally, this logic would be in the `.setup()` method but DRF view sets don't call it
# https://docs.djangoproject.com/en/5.1/ref/class-based-views/base/#django.views.generic.base.View.setup
if request.auth is None:
@@ -61,9 +58,19 @@ class BaseRLSViewSet(BaseViewSet):
if tenant_id is None:
raise NotAuthenticated("Tenant ID is not present in token")
with rls_transaction(tenant_id):
self.request.tenant_id = tenant_id
return super().initial(request, *args, **kwargs)
self.request.tenant_id = tenant_id
self._rls_cm = rls_transaction(tenant_id)
self._rls_cm.__enter__()
def finalize_response(self, request, response, *args, **kwargs):
response = super().finalize_response(request, response, *args, **kwargs)
if hasattr(self, "_rls_cm"):
self._rls_cm.__exit__(None, None, None)
del self._rls_cm
return response
def get_serializer_context(self):
context = super().get_serializer_context()
@@ -73,8 +80,7 @@ class BaseRLSViewSet(BaseViewSet):
class BaseTenantViewset(BaseViewSet):
def dispatch(self, request, *args, **kwargs):
with transaction.atomic():
tenant = super().dispatch(request, *args, **kwargs)
tenant = super().dispatch(request, *args, **kwargs)
try:
# If the request is a POST, create the admin role
@@ -109,6 +115,8 @@ class BaseTenantViewset(BaseViewSet):
pass # Tenant might not exist, handle gracefully
def initial(self, request, *args, **kwargs):
super().initial(request, *args, **kwargs)
if request.auth is None:
raise NotAuthenticated
@@ -117,19 +125,27 @@ class BaseTenantViewset(BaseViewSet):
raise NotAuthenticated("Tenant ID is not present in token")
user_id = str(request.user.id)
with rls_transaction(value=user_id, parameter=POSTGRES_USER_VAR):
return super().initial(request, *args, **kwargs)
self._rls_cm = rls_transaction(value=user_id, parameter=POSTGRES_USER_VAR)
self._rls_cm.__enter__()
def finalize_response(self, request, response, *args, **kwargs):
response = super().finalize_response(request, response, *args, **kwargs)
if hasattr(self, "_rls_cm"):
self._rls_cm.__exit__(None, None, None)
del self._rls_cm
return response
class BaseUserViewset(BaseViewSet):
def dispatch(self, request, *args, **kwargs):
with transaction.atomic():
return super().dispatch(request, *args, **kwargs)
def initial(self, request, *args, **kwargs):
super().initial(request, *args, **kwargs)
# TODO refactor after improving RLS on users
if request.stream is not None and request.stream.method == "POST":
return super().initial(request, *args, **kwargs)
return
if request.auth is None:
raise NotAuthenticated
@@ -137,6 +153,16 @@ class BaseUserViewset(BaseViewSet):
if tenant_id is None:
raise NotAuthenticated("Tenant ID is not present in token")
with rls_transaction(tenant_id):
self.request.tenant_id = tenant_id
return super().initial(request, *args, **kwargs)
self.request.tenant_id = tenant_id
self._rls_cm = rls_transaction(tenant_id)
self._rls_cm.__enter__()
def finalize_response(self, request, response, *args, **kwargs):
response = super().finalize_response(request, response, *args, **kwargs)
if hasattr(self, "_rls_cm"):
self._rls_cm.__exit__(None, None, None)
del self._rls_cm
return response
+7 -7
View File
@@ -190,6 +190,8 @@ def generate_compliance_overview_template(prowler_compliance: dict):
total_checks = len(requirement.Checks)
checks_dict = {check: None for check in requirement.Checks}
req_status_val = "MANUAL" if total_checks == 0 else "PASS"
# Build requirement dictionary
requirement_dict = {
"name": requirement.Name or requirement.Id,
@@ -204,20 +206,18 @@ def generate_compliance_overview_template(prowler_compliance: dict):
"manual": 0,
"total": total_checks,
},
"status": "PASS",
"status": req_status_val,
}
# Update requirements status
if total_checks == 0:
# Update requirements status counts for the framework
if req_status_val == "MANUAL":
requirements_status["manual"] += 1
elif req_status_val == "PASS":
requirements_status["passed"] += 1
# Add requirement to compliance requirements
compliance_requirements[requirement.Id] = requirement_dict
# Calculate pending requirements
pending_requirements = total_requirements - requirements_status["manual"]
requirements_status["passed"] = pending_requirements
# Build compliance dictionary
compliance_dict = {
"framework": compliance_data.Framework,
+124 -11
View File
@@ -1,3 +1,4 @@
import re
import secrets
import uuid
from contextlib import contextmanager
@@ -152,6 +153,28 @@ def delete_related_daily_task(provider_id: str):
PeriodicTask.objects.filter(name=task_name).delete()
def create_objects_in_batches(
tenant_id: str, model, objects: list, batch_size: int = 500
):
"""
Bulk-create model instances in repeated, per-tenant RLS transactions.
All chunks execute in their own transaction, so no single transaction
grows too large.
Args:
tenant_id (str): UUID string of the tenant under which to set RLS.
model: Django model class whose `.objects.bulk_create()` will be called.
objects (list): List of model instances (unsaved) to bulk-create.
batch_size (int): Maximum number of objects per bulk_create call.
"""
total = len(objects)
for i in range(0, total, batch_size):
chunk = objects[i : i + batch_size]
with rls_transaction(value=tenant_id, parameter=POSTGRES_TENANT_VAR):
model.objects.bulk_create(chunk, batch_size)
# Postgres Enums
@@ -227,6 +250,72 @@ def register_enum(apps, schema_editor, enum_class): # noqa: F841
register_adapter(enum_class, enum_adapter)
def _should_create_index_on_partition(
partition_name: str, all_partitions: bool = False
) -> bool:
"""
Determine if we should create an index on this partition.
Args:
partition_name: The name of the partition (e.g., "findings_2025_aug", "findings_default")
all_partitions: If True, create on all partitions. If False, only current/future partitions.
Returns:
bool: True if index should be created on this partition, False otherwise.
"""
if all_partitions:
return True
# Extract date from partition name if it follows the pattern
# Partition names look like: findings_2025_aug, findings_2025_jul, etc.
date_pattern = r"(\d{4})_([a-z]{3})$"
match = re.search(date_pattern, partition_name)
if not match:
# If we can't parse the date, include it to be safe (e.g., default partition)
return True
try:
year_str, month_abbr = match.groups()
year = int(year_str)
# Map month abbreviations to numbers
month_map = {
"jan": 1,
"feb": 2,
"mar": 3,
"apr": 4,
"may": 5,
"jun": 6,
"jul": 7,
"aug": 8,
"sep": 9,
"oct": 10,
"nov": 11,
"dec": 12,
}
month = month_map.get(month_abbr.lower())
if month is None:
# Unknown month abbreviation, include it to be safe
return True
partition_date = datetime(year, month, 1, tzinfo=timezone.utc)
# Get current month start
now = datetime.now(timezone.utc)
current_month_start = now.replace(
day=1, hour=0, minute=0, second=0, microsecond=0
)
# Include current month and future partitions
return partition_date >= current_month_start
except (ValueError, TypeError):
# If date parsing fails, include it to be safe
return True
def create_index_on_partitions(
apps, # noqa: F841
schema_editor,
@@ -235,16 +324,39 @@ def create_index_on_partitions(
columns: str,
method: str = "BTREE",
where: str = "",
all_partitions: bool = True,
):
"""
Create an index on every existing partition of `parent_table`.
Create an index on existing partitions of `parent_table`.
Args:
parent_table: The name of the root table (e.g. "findings").
index_name: A short name for the index (will be prefixed per-partition).
columns: The parenthesized column list, e.g. "tenant_id, scan_id, status".
method: The index method—BTREE, GIN, etc. Defaults to BTREE.
where: Optional WHERE clause (without the leading "WHERE"), e.g. "status = 'FAIL'".
method: The index method—BTREE, GIN, etc. Defaults to BTREE.
where: Optional WHERE clause (without the leading "WHERE"), e.g. "status = 'FAIL'".
all_partitions: Whether to create indexes on all partitions or just current/future ones.
Defaults to False (current/future only) to avoid maintenance overhead
on old partitions where the index may not be needed.
Examples:
# Create index only on current and future partitions (recommended for new indexes)
create_index_on_partitions(
apps, schema_editor,
parent_table="findings",
index_name="new_performance_idx",
columns="tenant_id, status, severity",
all_partitions=False # Default behavior
)
# Create index on all partitions (use when migrating existing critical indexes)
create_index_on_partitions(
apps, schema_editor,
parent_table="findings",
index_name="critical_existing_idx",
columns="tenant_id, scan_id",
all_partitions=True
)
"""
with connection.cursor() as cursor:
cursor.execute(
@@ -259,13 +371,14 @@ def create_index_on_partitions(
where_sql = f" WHERE {where}" if where else ""
for partition in partitions:
idx_name = f"{partition.replace('.', '_')}_{index_name}"
sql = (
f"CREATE INDEX CONCURRENTLY IF NOT EXISTS {idx_name} "
f"ON {partition} USING {method} ({columns})"
f"{where_sql};"
)
schema_editor.execute(sql)
if _should_create_index_on_partition(partition, all_partitions):
idx_name = f"{partition.replace('.', '_')}_{index_name}"
sql = (
f"CREATE INDEX CONCURRENTLY IF NOT EXISTS {idx_name} "
f"ON {partition} USING {method} ({columns})"
f"{where_sql};"
)
schema_editor.execute(sql)
def drop_index_on_partitions(
@@ -279,7 +392,7 @@ def drop_index_on_partitions(
Args:
parent_table: The name of the root table (e.g. "findings").
index_name: The same short name used when creating them.
index_name: The same short name used when creating them.
"""
with connection.cursor() as cursor:
cursor.execute(
+34 -4
View File
@@ -3,7 +3,7 @@ from rest_framework import status
from rest_framework.exceptions import APIException
from rest_framework_json_api.exceptions import exception_handler
from rest_framework_json_api.serializers import ValidationError
from rest_framework_simplejwt.exceptions import TokenError, InvalidToken
from rest_framework_simplejwt.exceptions import InvalidToken, TokenError
class ModelValidationError(ValidationError):
@@ -32,6 +32,31 @@ class InvitationTokenExpiredException(APIException):
default_code = "token_expired"
# Task Management Exceptions (non-HTTP)
class TaskManagementError(Exception):
"""Base exception for task management errors."""
def __init__(self, task=None):
self.task = task
super().__init__()
class TaskFailedException(TaskManagementError):
"""Raised when a task has failed."""
class TaskNotFoundException(TaskManagementError):
"""Raised when a task is not found."""
class TaskInProgressException(TaskManagementError):
"""Raised when a task is running but there's no related Task object to return."""
def __init__(self, task_result=None):
self.task_result = task_result
super().__init__()
def custom_exception_handler(exc, context):
if isinstance(exc, django_validation_error):
if hasattr(exc, "error_dict"):
@@ -39,7 +64,12 @@ def custom_exception_handler(exc, context):
else:
exc = ValidationError(detail=exc.messages[0], code=exc.code)
elif isinstance(exc, (TokenError, InvalidToken)):
exc.detail["messages"] = [
message_item["message"] for message_item in exc.detail["messages"]
]
if (
hasattr(exc, "detail")
and isinstance(exc.detail, dict)
and "messages" in exc.detail
):
exc.detail["messages"] = [
message_item["message"] for message_item in exc.detail["messages"]
]
return exception_handler(exc, context)
+4 -5
View File
@@ -22,7 +22,7 @@ from api.db_utils import (
StatusEnumField,
)
from api.models import (
ComplianceOverview,
ComplianceRequirementOverview,
Finding,
Integration,
Invitation,
@@ -637,12 +637,11 @@ class RoleFilter(FilterSet):
class ComplianceOverviewFilter(FilterSet):
inserted_at = DateFilter(field_name="inserted_at", lookup_expr="date")
provider_type = ChoiceFilter(choices=Provider.ProviderChoices.choices)
provider_type__in = ChoiceInFilter(choices=Provider.ProviderChoices.choices)
scan_id = UUIDFilter(field_name="scan__id")
scan_id = UUIDFilter(field_name="scan_id")
region = CharFilter(field_name="region")
class Meta:
model = ComplianceOverview
model = ComplianceRequirementOverview
fields = {
"inserted_at": ["date", "gte", "lte"],
"compliance_id": ["exact", "icontains"],
@@ -0,0 +1,124 @@
# Generated by Django 5.1.8 on 2025-05-21 11:37
import uuid
import django.db.models.deletion
from django.db import migrations, models
import api.db_utils
import api.rls
from api.rls import RowLevelSecurityConstraint
class Migration(migrations.Migration):
dependencies = [
("api", "0026_provider_secret_gcp_service_account"),
]
operations = [
migrations.CreateModel(
name="ComplianceRequirementOverview",
fields=[
(
"id",
models.UUIDField(
default=uuid.uuid4,
editable=False,
primary_key=True,
serialize=False,
),
),
("inserted_at", models.DateTimeField(auto_now_add=True)),
("compliance_id", models.TextField(blank=False)),
("framework", models.TextField(blank=False)),
("version", models.TextField(blank=True)),
("description", models.TextField(blank=True)),
("region", models.TextField(blank=False)),
("requirement_id", models.TextField(blank=False)),
(
"requirement_status",
api.db_utils.StatusEnumField(
choices=[
("FAIL", "Fail"),
("PASS", "Pass"),
("MANUAL", "Manual"),
]
),
),
("passed_checks", models.IntegerField(default=0)),
("failed_checks", models.IntegerField(default=0)),
("total_checks", models.IntegerField(default=0)),
(
"scan",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="compliance_requirements_overviews",
related_query_name="compliance_requirements_overview",
to="api.scan",
),
),
(
"tenant",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
),
),
],
options={
"db_table": "compliance_requirements_overviews",
"abstract": False,
"indexes": [
models.Index(
fields=["tenant_id", "scan_id"], name="cro_tenant_scan_idx"
),
models.Index(
fields=["tenant_id", "scan_id", "compliance_id"],
name="cro_scan_comp_idx",
),
models.Index(
fields=["tenant_id", "scan_id", "compliance_id", "region"],
name="cro_scan_comp_reg_idx",
),
models.Index(
fields=[
"tenant_id",
"scan_id",
"compliance_id",
"requirement_id",
],
name="cro_scan_comp_req_idx",
),
models.Index(
fields=[
"tenant_id",
"scan_id",
"compliance_id",
"requirement_id",
"region",
],
name="cro_scan_comp_req_reg_idx",
),
],
"constraints": [
models.UniqueConstraint(
fields=(
"tenant_id",
"scan_id",
"compliance_id",
"requirement_id",
"region",
),
name="unique_tenant_compliance_requirement_overview",
)
],
},
),
migrations.AddConstraint(
model_name="ComplianceRequirementOverview",
constraint=RowLevelSecurityConstraint(
"tenant_id",
name="rls_on_compliancerequirementoverview",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
),
),
]
@@ -0,0 +1,29 @@
from functools import partial
from django.db import migrations
from api.db_utils import create_index_on_partitions, drop_index_on_partitions
class Migration(migrations.Migration):
atomic = False
dependencies = [
("api", "0027_compliance_requirement_overviews"),
]
operations = [
migrations.RunPython(
partial(
create_index_on_partitions,
parent_table="findings",
index_name="find_tenant_scan_check_idx",
columns="tenant_id, scan_id, check_id",
),
reverse_code=partial(
drop_index_on_partitions,
parent_table="findings",
index_name="find_tenant_scan_check_idx",
),
)
]
@@ -0,0 +1,17 @@
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("api", "0028_findings_check_index_partitions"),
]
operations = [
migrations.AddIndex(
model_name="finding",
index=models.Index(
fields=["tenant_id", "scan_id", "check_id"],
name="find_tenant_scan_check_idx",
),
),
]
+76 -40
View File
@@ -1,9 +1,7 @@
import json
import re
import time
from uuid import UUID, uuid4
from config.env import env
from cryptography.fernet import Fernet
from django.conf import settings
from django.contrib.auth.models import AbstractBaseUser
@@ -354,42 +352,6 @@ class ProviderGroupMembership(RowLevelSecurityProtectedModel):
resource_name = "provider_groups-provider"
class TaskManager(models.Manager):
def get_with_retry(
self,
id: str,
max_retries: int = None,
delay_seconds: float = None,
):
"""
Retry fetching a Task by ID in case it hasn't been created yet.
Args:
id (str): The Celery task ID (expected to match Task model PK).
max_retries (int, optional): Number of retry attempts. Defaults to env TASK_RETRY_ATTEMPTS or 5.
delay_seconds (float, optional): Delay between retries in seconds. Defaults to env TASK_RETRY_DELAY_SECONDS or 0.1.
Returns:
Task: The retrieved Task instance.
Raises:
Task.DoesNotExist: If the task is not found after all retries.
"""
max_retries = max_retries or env.int("TASK_RETRY_ATTEMPTS", default=5)
delay_seconds = delay_seconds or env.float(
"TASK_RETRY_DELAY_SECONDS", default=0.1
)
for _attempt in range(max_retries):
try:
return self.get(id=id)
except self.model.DoesNotExist:
time.sleep(delay_seconds)
raise self.model.DoesNotExist(
f"Task with ID {id} not found after {max_retries} retries."
)
class Task(RowLevelSecurityProtectedModel):
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
@@ -402,8 +364,6 @@ class Task(RowLevelSecurityProtectedModel):
blank=True,
)
objects = TaskManager()
class Meta(RowLevelSecurityProtectedModel.Meta):
db_table = "tasks"
@@ -802,6 +762,10 @@ class Finding(PostgresPartitionedModel, RowLevelSecurityProtectedModel):
GinIndex(fields=["resource_services"], name="gin_find_service_idx"),
GinIndex(fields=["resource_regions"], name="gin_find_region_idx"),
GinIndex(fields=["resource_types"], name="gin_find_rtype_idx"),
models.Index(
fields=["tenant_id", "scan_id", "check_id"],
name="find_tenant_scan_check_idx",
),
]
class JSONAPIMeta:
@@ -1183,6 +1147,78 @@ class ComplianceOverview(RowLevelSecurityProtectedModel):
resource_name = "compliance-overviews"
class ComplianceRequirementOverview(RowLevelSecurityProtectedModel):
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
compliance_id = models.TextField(blank=False)
framework = models.TextField(blank=False)
version = models.TextField(blank=True)
description = models.TextField(blank=True)
region = models.TextField(blank=False)
requirement_id = models.TextField(blank=False)
requirement_status = StatusEnumField(choices=StatusChoices)
passed_checks = models.IntegerField(default=0)
failed_checks = models.IntegerField(default=0)
total_checks = models.IntegerField(default=0)
scan = models.ForeignKey(
Scan,
on_delete=models.CASCADE,
related_name="compliance_requirements_overviews",
related_query_name="compliance_requirements_overview",
)
class Meta(RowLevelSecurityProtectedModel.Meta):
db_table = "compliance_requirements_overviews"
constraints = [
models.UniqueConstraint(
fields=(
"tenant_id",
"scan_id",
"compliance_id",
"requirement_id",
"region",
),
name="unique_tenant_compliance_requirement_overview",
),
RowLevelSecurityConstraint(
field="tenant_id",
name="rls_on_%(class)s",
statements=["SELECT", "INSERT", "DELETE"],
),
]
indexes = [
models.Index(fields=["tenant_id", "scan_id"], name="cro_tenant_scan_idx"),
models.Index(
fields=["tenant_id", "scan_id", "compliance_id"],
name="cro_scan_comp_idx",
),
models.Index(
fields=["tenant_id", "scan_id", "compliance_id", "region"],
name="cro_scan_comp_reg_idx",
),
models.Index(
fields=["tenant_id", "scan_id", "compliance_id", "requirement_id"],
name="cro_scan_comp_req_idx",
),
models.Index(
fields=[
"tenant_id",
"scan_id",
"compliance_id",
"requirement_id",
"region",
],
name="cro_scan_comp_req_reg_idx",
),
]
class JSONAPIMeta:
resource_name = "compliance-requirements-overviews"
class ScanSummary(RowLevelSecurityProtectedModel):
objects = ActiveProviderManager()
all_objects = models.Manager()
+1 -1
View File
@@ -1,4 +1,4 @@
from rest_framework_json_api.pagination import JsonApiPageNumberPagination
from drf_spectacular_jsonapi.schemas.pagination import JsonApiPageNumberPagination
class ComplianceOverviewPagination(JsonApiPageNumberPagination):
+268 -278
View File
@@ -10,9 +10,7 @@ paths:
/api/v1/compliance-overviews:
get:
operationId: compliance_overviews_list
description: Retrieve an overview of all the compliance in a given scan. If
no region filters are provided, the region with the most fails will be returned
by default.
description: Retrieve an overview of all the compliance in a given scan.
summary: List compliance overviews for a scan
parameters:
- in: query
@@ -22,15 +20,13 @@ paths:
items:
type: string
enum:
- inserted_at
- compliance_id
- id
- framework
- version
- requirements_status
- region
- provider_type
- scan
- url
- requirements_passed
- requirements_failed
- requirements_manual
- total_requirements
description: endpoint return only specific fields in the response on a per-type
basis by including a fields[TYPE] query parameter.
explode: false
@@ -74,44 +70,6 @@ paths:
schema:
type: string
format: date-time
- in: query
name: filter[provider_type]
schema:
type: string
enum:
- aws
- azure
- gcp
- kubernetes
- m365
description: |-
* `aws` - AWS
* `azure` - Azure
* `gcp` - GCP
* `kubernetes` - Kubernetes
* `m365` - M365
- in: query
name: filter[provider_type__in]
schema:
type: array
items:
type: string
enum:
- aws
- azure
- gcp
- kubernetes
- m365
description: |-
Multiple values may be separated by commas.
* `aws` - AWS
* `azure` - Azure
* `gcp` - GCP
* `kubernetes` - Kubernetes
* `m365` - M365
explode: false
style: form
- in: query
name: filter[region]
schema:
@@ -171,14 +129,8 @@ paths:
items:
type: string
enum:
- inserted_at
- -inserted_at
- compliance_id
- -compliance_id
- framework
- -framework
- region
- -region
explode: false
tags:
- Compliance Overview
@@ -190,41 +142,43 @@ paths:
application/vnd.api+json:
schema:
$ref: '#/components/schemas/PaginatedComplianceOverviewList'
description: ''
/api/v1/compliance-overviews/{id}:
description: Compliance overviews obtained successfully
'202':
content:
application/vnd.api+json:
schema:
$ref: '#/components/schemas/PaginatedTaskList'
description: The task is in progress
'500':
description: Compliance overviews generation task failed
/api/v1/compliance-overviews/attributes:
get:
operationId: compliance_overviews_retrieve
description: Fetch detailed information about a specific compliance overview
by its ID, including detailed requirement information and check's status.
summary: Retrieve data from a specific compliance overview
operationId: compliance_overviews_attributes_retrieve
description: Retrieve detailed attribute information for all requirements in
a specific compliance framework along with the associated check IDs for each
requirement.
summary: Get compliance requirement attributes
parameters:
- in: query
name: fields[compliance-overviews]
name: fields[compliance-requirements-attributes]
schema:
type: array
items:
type: string
enum:
- inserted_at
- compliance_id
- id
- framework
- version
- requirements_status
- region
- provider_type
- scan
- url
- description
- requirements
- attributes
description: endpoint return only specific fields in the response on a per-type
basis by including a fields[TYPE] query parameter.
explode: false
- in: path
name: id
- in: query
name: filter[compliance_id]
schema:
type: string
format: uuid
description: A UUID string identifying this compliance overview.
description: Compliance framework ID to get attributes for.
required: true
tags:
- Compliance Overview
@@ -235,8 +189,8 @@ paths:
content:
application/vnd.api+json:
schema:
$ref: '#/components/schemas/ComplianceOverviewFullResponse'
description: ''
$ref: '#/components/schemas/PaginatedComplianceOverviewAttributesList'
description: Compliance attributes obtained successfully
/api/v1/compliance-overviews/metadata:
get:
operationId: compliance_overviews_metadata_retrieve
@@ -271,8 +225,142 @@ paths:
content:
application/vnd.api+json:
schema:
$ref: '#/components/schemas/ComplianceOverviewMetadataResponse'
description: ''
$ref: '#/components/schemas/OpenApiResponseResponse'
description: Compliance overviews metadata obtained successfully
'202':
description: The task is in progress
'500':
description: Compliance overviews generation task failed
/api/v1/compliance-overviews/requirements:
get:
operationId: compliance_overviews_requirements_retrieve
description: Retrieve a detailed overview of compliance requirements in a given
scan, grouped by compliance framework. This endpoint provides requirement-level
details and aggregates status across regions.
summary: List compliance requirements overview for a scan
parameters:
- in: query
name: fields[compliance-requirements-details]
schema:
type: array
items:
type: string
enum:
- id
- framework
- version
- description
- status
description: endpoint return only specific fields in the response on a per-type
basis by including a fields[TYPE] query parameter.
explode: false
- in: query
name: filter[compliance_id]
schema:
type: string
description: Compliance ID.
required: true
- in: query
name: filter[compliance_id__icontains]
schema:
type: string
- in: query
name: filter[framework]
schema:
type: string
- in: query
name: filter[framework__icontains]
schema:
type: string
- in: query
name: filter[framework__iexact]
schema:
type: string
- in: query
name: filter[inserted_at]
schema:
type: string
format: date
- in: query
name: filter[inserted_at__date]
schema:
type: string
format: date
- in: query
name: filter[inserted_at__gte]
schema:
type: string
format: date-time
- in: query
name: filter[inserted_at__lte]
schema:
type: string
format: date-time
- in: query
name: filter[region]
schema:
type: string
- in: query
name: filter[region__icontains]
schema:
type: string
- in: query
name: filter[region__in]
schema:
type: array
items:
type: string
description: Multiple values may be separated by commas.
explode: false
style: form
- in: query
name: filter[scan_id]
schema:
type: string
format: uuid
description: Related scan ID.
required: true
- name: filter[search]
required: false
in: query
description: A search term.
schema:
type: string
- in: query
name: filter[version]
schema:
type: string
- in: query
name: filter[version__icontains]
schema:
type: string
- name: sort
required: false
in: query
description: '[list of fields to sort by](https://jsonapi.org/format/#fetching-sorting)'
schema:
type: array
items:
type: string
enum:
- compliance_id
- -compliance_id
explode: false
tags:
- Compliance Overview
security:
- jwtAuth: []
responses:
'200':
content:
application/vnd.api+json:
schema:
$ref: '#/components/schemas/PaginatedComplianceOverviewDetailList'
description: Compliance requirement details obtained successfully
'202':
description: The task is in progress
'500':
description: Compliance overviews generation task failed
/api/v1/findings:
get:
operationId: findings_list
@@ -6839,80 +6927,37 @@ components:
properties:
type:
allOf:
- $ref: '#/components/schemas/Type7f7Enum'
- $ref: '#/components/schemas/ComplianceOverviewTypeEnum'
description: The [type](https://jsonapi.org/format/#document-resource-object-identification)
member is used to describe resource objects that share common attributes
and relationships.
id:
type: string
format: uuid
id: {}
attributes:
type: object
properties:
inserted_at:
id:
type: string
format: date-time
readOnly: true
compliance_id:
type: string
maxLength: 100
framework:
type: string
maxLength: 100
version:
type: string
maxLength: 50
requirements_status:
type: object
properties:
passed:
type: integer
failed:
type: integer
manual:
type: integer
total:
type: integer
readOnly: true
region:
type: string
maxLength: 50
provider_type:
type: string
nullable: true
readOnly: true
requirements_passed:
type: integer
requirements_failed:
type: integer
requirements_manual:
type: integer
total_requirements:
type: integer
required:
- compliance_id
- id
- framework
relationships:
type: object
properties:
scan:
type: object
properties:
data:
type: object
properties:
id:
type: string
format: uuid
type:
type: string
enum:
- scans
title: Resource Type Name
description: The [type](https://jsonapi.org/format/#document-resource-object-identification)
member is used to describe resource objects that share common
attributes and relationships.
required:
- id
- type
required:
- data
description: The identifier of the related object.
title: Resource Identifier
nullable: true
ComplianceOverviewFull:
- version
- requirements_passed
- requirements_failed
- requirements_manual
- total_requirements
ComplianceOverviewAttributes:
type: object
required:
- type
@@ -6921,134 +6966,78 @@ components:
properties:
type:
allOf:
- $ref: '#/components/schemas/Type7f7Enum'
- $ref: '#/components/schemas/ComplianceOverviewAttributesTypeEnum'
description: The [type](https://jsonapi.org/format/#document-resource-object-identification)
member is used to describe resource objects that share common attributes
and relationships.
id:
type: string
format: uuid
id: {}
attributes:
type: object
properties:
inserted_at:
id:
type: string
format: date-time
readOnly: true
compliance_id:
type: string
maxLength: 100
framework:
type: string
maxLength: 100
version:
type: string
maxLength: 50
requirements_status:
type: object
properties:
passed:
type: integer
failed:
type: integer
manual:
type: integer
total:
type: integer
readOnly: true
region:
type: string
maxLength: 50
provider_type:
type: string
nullable: true
readOnly: true
description:
type: string
requirements:
type: object
properties:
requirement_id:
type: object
properties:
name:
type: string
checks:
type: object
properties:
check_name:
type: object
properties:
status:
type: string
enum:
- PASS
- FAIL
- null
description: Each key in the 'checks' object is a check name,
with values as 'PASS', 'FAIL', or null.
status:
type: string
enum:
- PASS
- FAIL
- MANUAL
attributes:
type: array
items:
type: object
description:
type: string
checks_status:
type: object
properties:
total:
type: integer
pass:
type: integer
fail:
type: integer
manual:
type: integer
readOnly: true
attributes: {}
required:
- compliance_id
- id
- framework
relationships:
- version
- description
- attributes
ComplianceOverviewAttributesTypeEnum:
type: string
enum:
- compliance-requirements-attributes
ComplianceOverviewDetail:
type: object
required:
- type
- id
additionalProperties: false
properties:
type:
allOf:
- $ref: '#/components/schemas/ComplianceOverviewDetailTypeEnum'
description: The [type](https://jsonapi.org/format/#document-resource-object-identification)
member is used to describe resource objects that share common attributes
and relationships.
id: {}
attributes:
type: object
properties:
scan:
type: object
properties:
data:
type: object
properties:
id:
type: string
format: uuid
type:
type: string
enum:
- scans
title: Resource Type Name
description: The [type](https://jsonapi.org/format/#document-resource-object-identification)
member is used to describe resource objects that share common
attributes and relationships.
required:
- id
- type
required:
- data
description: The identifier of the related object.
title: Resource Identifier
nullable: true
ComplianceOverviewFullResponse:
type: object
properties:
data:
$ref: '#/components/schemas/ComplianceOverviewFull'
required:
- data
id:
type: string
framework:
type: string
version:
type: string
description:
type: string
status:
enum:
- FAIL
- PASS
- MANUAL
type: string
description: |-
* `FAIL` - Fail
* `PASS` - Pass
* `MANUAL` - Manual
required:
- id
- framework
- version
- description
- status
ComplianceOverviewDetailTypeEnum:
type: string
enum:
- compliance-requirements-details
ComplianceOverviewMetadata:
type: object
required:
@@ -7072,17 +7061,14 @@ components:
type: string
required:
- regions
ComplianceOverviewMetadataResponse:
type: object
properties:
data:
$ref: '#/components/schemas/ComplianceOverviewMetadata'
required:
- data
ComplianceOverviewMetadataTypeEnum:
type: string
enum:
- compliance-overviews-metadata
ComplianceOverviewTypeEnum:
type: string
enum:
- compliance-overviews
Finding:
type: object
required:
@@ -8386,7 +8372,7 @@ components:
type: object
properties:
data:
$ref: '#/components/schemas/Membership'
$ref: '#/components/schemas/ComplianceOverviewMetadata'
required:
- data
OverviewFinding:
@@ -8601,29 +8587,33 @@ components:
type: string
enum:
- findings-severity-overview
PaginatedComplianceOverviewAttributesList:
type: object
properties:
data:
type: array
items:
$ref: '#/components/schemas/ComplianceOverviewAttributes'
required:
- data
PaginatedComplianceOverviewDetailList:
type: object
properties:
data:
type: array
items:
$ref: '#/components/schemas/ComplianceOverviewDetail'
required:
- data
PaginatedComplianceOverviewList:
type: object
required:
- count
- results
properties:
count:
type: integer
example: 123
next:
type: string
nullable: true
format: uri
example: http://api.example.org/accounts/?page[number]=4
previous:
type: string
nullable: true
format: uri
example: http://api.example.org/accounts/?page[number]=2
results:
data:
type: array
items:
$ref: '#/components/schemas/ComplianceOverview'
required:
- data
PaginatedFindingList:
type: object
properties:
@@ -11904,6 +11894,7 @@ components:
type: object
required:
- type
- id
additionalProperties: false
properties:
type:
@@ -11912,6 +11903,9 @@ components:
description: The [type](https://jsonapi.org/format/#document-resource-object-identification)
member is used to describe resource objects that share common attributes
and relationships.
id:
type: string
format: uuid
attributes:
type: object
properties:
@@ -12326,10 +12320,6 @@ components:
type: string
enum:
- roles
Type7f7Enum:
type: string
enum:
- compliance-overviews
Type8cdEnum:
type: string
enum:
+6 -6
View File
@@ -1,12 +1,12 @@
from unittest.mock import patch, MagicMock
from unittest.mock import MagicMock, patch
from api.compliance import (
generate_compliance_overview_template,
generate_scan_compliance,
get_prowler_provider_checks,
get_prowler_provider_compliance,
load_prowler_compliance,
load_prowler_checks,
generate_scan_compliance,
generate_compliance_overview_template,
load_prowler_compliance,
)
from api.models import Provider
@@ -69,7 +69,7 @@ class TestCompliance:
load_prowler_compliance()
from api.compliance import PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE, PROWLER_CHECKS
from api.compliance import PROWLER_CHECKS, PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE
assert PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE == {
"template_key": "template_value"
@@ -268,7 +268,7 @@ class TestCompliance:
"manual": 0,
"total": 0,
},
"status": "PASS",
"status": "MANUAL",
},
},
"requirements_status": {
@@ -3,9 +3,13 @@ from enum import Enum
from unittest.mock import patch
import pytest
from django.conf import settings
from freezegun import freeze_time
from api.db_utils import (
_should_create_index_on_partition,
batch_delete,
create_objects_in_batches,
enum_to_choices,
generate_random_token,
one_week_from_now,
@@ -138,3 +142,88 @@ class TestBatchDelete:
)
assert Provider.objects.all().count() == 0
assert summary == {"api.Provider": create_test_providers}
class TestShouldCreateIndexOnPartition:
@freeze_time("2025-05-15 00:00:00Z")
@pytest.mark.parametrize(
"partition_name, all_partitions, expected",
[
("any_name", True, True),
("findings_default", True, True),
("findings_2022_jan", True, True),
("foo_bar", False, True),
("findings_2025_MAY", False, True),
("findings_2025_may", False, True),
("findings_2025_jun", False, True),
("findings_2025_apr", False, False),
("findings_2025_xyz", False, True),
],
)
def test_partition_inclusion_logic(self, partition_name, all_partitions, expected):
assert (
_should_create_index_on_partition(partition_name, all_partitions)
is expected
)
@freeze_time("2025-05-15 00:00:00Z")
def test_invalid_date_components(self):
# even if regex matches but int conversion fails, we fallback True
# (e.g. year too big, month number parse error)
bad_name = "findings_99999_jan"
assert _should_create_index_on_partition(bad_name, False) is True
bad_name2 = "findings_2025_abc"
# abc not in month_map → fallback True
assert _should_create_index_on_partition(bad_name2, False) is True
@pytest.mark.django_db
class TestCreateObjectsInBatches:
@pytest.fixture
def tenant(self, tenants_fixture):
return tenants_fixture[0]
def make_provider_instances(self, tenant, count):
"""
Return a list of `count` unsaved Provider instances for the given tenant.
"""
base_uid = 1000
return [
Provider(
tenant=tenant,
uid=str(base_uid + i),
provider=Provider.ProviderChoices.AWS,
)
for i in range(count)
]
def test_exact_multiple_of_batch(self, tenant):
total = 6
batch_size = 3
objs = self.make_provider_instances(tenant, total)
create_objects_in_batches(str(tenant.id), Provider, objs, batch_size=batch_size)
qs = Provider.objects.filter(tenant=tenant)
assert qs.count() == total
def test_non_multiple_of_batch(self, tenant):
total = 7
batch_size = 3
objs = self.make_provider_instances(tenant, total)
create_objects_in_batches(str(tenant.id), Provider, objs, batch_size=batch_size)
qs = Provider.objects.filter(tenant=tenant)
assert qs.count() == total
def test_batch_size_default(self, tenant):
default_size = settings.DJANGO_DELETION_BATCH_SIZE
total = default_size + 2
objs = self.make_provider_instances(tenant, total)
create_objects_in_batches(str(tenant.id), Provider, objs)
qs = Provider.objects.filter(tenant=tenant)
assert qs.count() == total
+379
View File
@@ -0,0 +1,379 @@
import json
from uuid import uuid4
import pytest
from django_celery_results.models import TaskResult
from rest_framework import status
from rest_framework.response import Response
from api.exceptions import (
TaskFailedException,
TaskInProgressException,
TaskNotFoundException,
)
from api.models import Task, User
from api.rls import Tenant
from api.v1.mixins import PaginateByPkMixin, TaskManagementMixin
@pytest.mark.django_db
class TestPaginateByPkMixin:
@pytest.fixture
def tenant(self):
return Tenant.objects.create(name="Test Tenant")
@pytest.fixture
def users(self, tenant):
# Create 5 users with proper email field
users = []
for i in range(5):
user = User.objects.create(email=f"user{i}@example.com", name=f"User {i}")
users.append(user)
return users
class DummyView(PaginateByPkMixin):
def __init__(self, page):
self._page = page
def paginate_queryset(self, qs):
return self._page
def get_serializer(self, queryset, many):
class S:
def __init__(self, data):
# serialize to list of ids
self.data = [obj.id for obj in data] if many else queryset.id
return S(queryset)
def get_paginated_response(self, data):
return Response({"results": data}, status=status.HTTP_200_OK)
def test_no_pagination(self, users):
base_qs = User.objects.all().order_by("id")
view = self.DummyView(page=None)
resp = view.paginate_by_pk(
request=None, base_queryset=base_qs, manager=User.objects
)
# since no pagination, should return all ids in order
expected = [u.id for u in base_qs]
assert isinstance(resp, Response)
assert resp.data == expected
def test_with_pagination(self, users):
base_qs = User.objects.all().order_by("id")
# simulate paging to first 2 ids
page = [base_qs[1].id, base_qs[3].id]
view = self.DummyView(page=page)
resp = view.paginate_by_pk(
request=None, base_queryset=base_qs, manager=User.objects
)
# should fetch only those two users, in the same order as page
assert resp.status_code == status.HTTP_200_OK
assert resp.data == {"results": page}
@pytest.mark.django_db
class TestTaskManagementMixin:
class DummyView(TaskManagementMixin):
pass
@pytest.fixture
def tenant(self):
return Tenant.objects.create(name="Test Tenant")
@pytest.fixture(autouse=True)
def cleanup(self):
Task.objects.all().delete()
TaskResult.objects.all().delete()
def test_no_task_and_no_taskresult_raises_not_found(self):
view = self.DummyView()
with pytest.raises(TaskNotFoundException):
view.check_task_status("task_xyz", {"foo": "bar"})
def test_no_task_and_no_taskresult_returns_none_when_not_raising(self):
view = self.DummyView()
result = view.check_task_status(
"task_xyz", {"foo": "bar"}, raise_on_not_found=False
)
assert result is None
def test_taskresult_pending_raises_in_progress(self):
task_kwargs = {"foo": "bar"}
tr = TaskResult.objects.create(
task_id=str(uuid4()),
task_name="task_xyz",
task_kwargs=json.dumps(task_kwargs),
status="PENDING",
)
view = self.DummyView()
with pytest.raises(TaskInProgressException) as excinfo:
view.check_task_status("task_xyz", task_kwargs, raise_on_not_found=False)
assert hasattr(excinfo.value, "task_result")
assert excinfo.value.task_result == tr
def test_taskresult_started_raises_in_progress(self):
task_kwargs = {"foo": "bar"}
tr = TaskResult.objects.create(
task_id=str(uuid4()),
task_name="task_xyz",
task_kwargs=json.dumps(task_kwargs),
status="STARTED",
)
view = self.DummyView()
with pytest.raises(TaskInProgressException) as excinfo:
view.check_task_status("task_xyz", task_kwargs, raise_on_not_found=False)
assert hasattr(excinfo.value, "task_result")
assert excinfo.value.task_result == tr
def test_taskresult_progress_raises_in_progress(self):
task_kwargs = {"foo": "bar"}
tr = TaskResult.objects.create(
task_id=str(uuid4()),
task_name="task_xyz",
task_kwargs=json.dumps(task_kwargs),
status="PROGRESS",
)
view = self.DummyView()
with pytest.raises(TaskInProgressException) as excinfo:
view.check_task_status("task_xyz", task_kwargs, raise_on_not_found=False)
assert hasattr(excinfo.value, "task_result")
assert excinfo.value.task_result == tr
def test_taskresult_failure_raises_failed(self):
task_kwargs = {"a": 1}
TaskResult.objects.create(
task_id=str(uuid4()),
task_name="task_fail",
task_kwargs=json.dumps(task_kwargs),
status="FAILURE",
)
view = self.DummyView()
with pytest.raises(TaskFailedException):
view.check_task_status("task_fail", task_kwargs, raise_on_not_found=False)
def test_taskresult_failure_returns_none_when_not_raising(self):
task_kwargs = {"a": 1}
TaskResult.objects.create(
task_id=str(uuid4()),
task_name="task_fail",
task_kwargs=json.dumps(task_kwargs),
status="FAILURE",
)
view = self.DummyView()
result = view.check_task_status(
"task_fail", task_kwargs, raise_on_failed=False, raise_on_not_found=False
)
assert result is None
def test_taskresult_success_returns_none(self):
task_kwargs = {"x": 2}
TaskResult.objects.create(
task_id=str(uuid4()),
task_name="task_ok",
task_kwargs=json.dumps(task_kwargs),
status="SUCCESS",
)
view = self.DummyView()
# should not raise, and returns None
assert (
view.check_task_status("task_ok", task_kwargs, raise_on_not_found=False)
is None
)
def test_taskresult_revoked_returns_none(self):
task_kwargs = {"x": 2}
TaskResult.objects.create(
task_id=str(uuid4()),
task_name="task_revoked",
task_kwargs=json.dumps(task_kwargs),
status="REVOKED",
)
view = self.DummyView()
# should not raise, and returns None
assert (
view.check_task_status(
"task_revoked", task_kwargs, raise_on_not_found=False
)
is None
)
def test_task_with_failed_status_raises_failed(self, tenant):
task_kwargs = {"provider_id": "test"}
tr = TaskResult.objects.create(
task_id=str(uuid4()),
task_name="scan_task",
task_kwargs=json.dumps(task_kwargs),
status="FAILURE",
)
task = Task.objects.create(tenant=tenant, task_runner_task=tr)
view = self.DummyView()
with pytest.raises(TaskFailedException) as excinfo:
view.check_task_status("scan_task", task_kwargs)
# Check that the exception contains the expected task
assert hasattr(excinfo.value, "task")
assert excinfo.value.task == task
def test_task_with_cancelled_status_raises_failed(self, tenant):
task_kwargs = {"provider_id": "test"}
tr = TaskResult.objects.create(
task_id=str(uuid4()),
task_name="scan_task",
task_kwargs=json.dumps(task_kwargs),
status="REVOKED",
)
task = Task.objects.create(tenant=tenant, task_runner_task=tr)
view = self.DummyView()
with pytest.raises(TaskFailedException) as excinfo:
view.check_task_status("scan_task", task_kwargs)
# Check that the exception contains the expected task
assert hasattr(excinfo.value, "task")
assert excinfo.value.task == task
def test_task_with_failed_status_returns_task_when_not_raising(self, tenant):
task_kwargs = {"provider_id": "test"}
tr = TaskResult.objects.create(
task_id=str(uuid4()),
task_name="scan_task",
task_kwargs=json.dumps(task_kwargs),
status="FAILURE",
)
task = Task.objects.create(tenant=tenant, task_runner_task=tr)
view = self.DummyView()
result = view.check_task_status("scan_task", task_kwargs, raise_on_failed=False)
assert result == task
def test_task_with_completed_status_returns_none(self, tenant):
task_kwargs = {"provider_id": "test"}
tr = TaskResult.objects.create(
task_id=str(uuid4()),
task_name="scan_task",
task_kwargs=json.dumps(task_kwargs),
status="SUCCESS",
)
Task.objects.create(tenant=tenant, task_runner_task=tr)
view = self.DummyView()
result = view.check_task_status("scan_task", task_kwargs)
assert result is None
def test_task_with_executing_status_returns_task(self, tenant):
task_kwargs = {"provider_id": "test"}
tr = TaskResult.objects.create(
task_id=str(uuid4()),
task_name="scan_task",
task_kwargs=json.dumps(task_kwargs),
status="STARTED",
)
task = Task.objects.create(tenant=tenant, task_runner_task=tr)
view = self.DummyView()
result = view.check_task_status("scan_task", task_kwargs)
assert result is not None
assert result.pk == task.pk
def test_task_with_pending_status_returns_task(self, tenant):
task_kwargs = {"provider_id": "test"}
tr = TaskResult.objects.create(
task_id=str(uuid4()),
task_name="scan_task",
task_kwargs=json.dumps(task_kwargs),
status="PENDING",
)
task = Task.objects.create(tenant=tenant, task_runner_task=tr)
view = self.DummyView()
result = view.check_task_status("scan_task", task_kwargs)
assert result is not None
assert result.pk == task.pk
def test_get_task_response_if_running_returns_none_for_completed_task(self, tenant):
task_kwargs = {"provider_id": "test"}
tr = TaskResult.objects.create(
task_id=str(uuid4()),
task_name="scan_task",
task_kwargs=json.dumps(task_kwargs),
status="SUCCESS",
)
Task.objects.create(tenant=tenant, task_runner_task=tr)
view = self.DummyView()
result = view.get_task_response_if_running("scan_task", task_kwargs)
assert result is None
def test_get_task_response_if_running_returns_none_for_no_task(self):
view = self.DummyView()
result = view.get_task_response_if_running(
"nonexistent", {"foo": "bar"}, raise_on_not_found=False
)
assert result is None
def test_get_task_response_if_running_returns_202_for_executing_task(self, tenant):
task_kwargs = {"provider_id": "test"}
tr = TaskResult.objects.create(
task_id=str(uuid4()),
task_name="scan_task",
task_kwargs=json.dumps(task_kwargs),
status="STARTED",
)
task = Task.objects.create(tenant=tenant, task_runner_task=tr)
view = self.DummyView()
result = view.get_task_response_if_running("scan_task", task_kwargs)
assert isinstance(result, Response)
assert result.status_code == status.HTTP_202_ACCEPTED
assert "Content-Location" in result.headers
# The response should contain the serialized task data
assert result.data is not None
assert "id" in result.data
assert str(result.data["id"]) == str(task.id)
def test_get_task_response_if_running_returns_none_for_available_task(self, tenant):
task_kwargs = {"provider_id": "test"}
tr = TaskResult.objects.create(
task_id=str(uuid4()),
task_name="scan_task",
task_kwargs=json.dumps(task_kwargs),
status="PENDING",
)
Task.objects.create(tenant=tenant, task_runner_task=tr)
view = self.DummyView()
result = view.get_task_response_if_running("scan_task", task_kwargs)
# PENDING maps to AVAILABLE, which is not EXECUTING, so should return None
assert result is None
def test_kwargs_filtering_works_correctly(self, tenant):
# Create tasks with different kwargs
task_kwargs_1 = {"provider_id": "test1", "scan_type": "full"}
task_kwargs_2 = {"provider_id": "test2", "scan_type": "quick"}
tr1 = TaskResult.objects.create(
task_id=str(uuid4()),
task_name="scan_task",
task_kwargs=json.dumps(task_kwargs_1),
status="STARTED",
)
tr2 = TaskResult.objects.create(
task_id=str(uuid4()),
task_name="scan_task",
task_kwargs=json.dumps(task_kwargs_2),
status="STARTED",
)
task1 = Task.objects.create(tenant=tenant, task_runner_task=tr1)
task2 = Task.objects.create(tenant=tenant, task_runner_task=tr2)
view = self.DummyView()
# Should find task1 when searching for its kwargs
result1 = view.check_task_status("scan_task", {"provider_id": "test1"})
assert result1 is not None
assert result1.pk == task1.pk
# Should find task2 when searching for its kwargs
result2 = view.check_task_status("scan_task", {"provider_id": "test2"})
assert result2 is not None
assert result2.pk == task2.pk
# Should not find anything when searching for non-existent kwargs
result3 = view.check_task_status(
"scan_task", {"provider_id": "test3"}, raise_on_not_found=False
)
assert result3 is None
+1 -36
View File
@@ -1,9 +1,6 @@
import uuid
from unittest import mock
import pytest
from api.models import Resource, ResourceTag, Task
from api.models import Resource, ResourceTag
@pytest.mark.django_db
@@ -123,35 +120,3 @@ class TestResourceModel:
# compliance={},
# )
# assert Finding.objects.filter(uid=long_uid).exists()
@pytest.mark.django_db
class TestTaskManager:
def test_get_with_retry_success(self):
task_id = uuid.uuid4()
call_counter = {"count": 0}
def side_effect(*args, **kwargs):
if call_counter["count"] < 2:
call_counter["count"] += 1
raise Task.DoesNotExist()
return Task(id=task_id)
with mock.patch.object(Task.objects, "get", side_effect=side_effect):
task = Task.objects.get_with_retry(
task_id, max_retries=5, delay_seconds=0.01
)
assert task.id == task_id
assert call_counter["count"] == 2
def test_get_with_retry_fail(self):
non_existent_id = uuid.uuid4()
with mock.patch.object(Task.objects, "get", side_effect=Task.DoesNotExist):
with pytest.raises(Task.DoesNotExist) as excinfo:
Task.objects.get_with_retry(
non_existent_id, max_retries=3, delay_seconds=0.01
)
assert str(non_existent_id) in str(excinfo.value)
+235 -178
View File
@@ -15,10 +15,10 @@ from django.conf import settings
from django.urls import reverse
from django_celery_results.models import TaskResult
from rest_framework import status
from rest_framework.response import Response
from api.compliance import get_compliance_frameworks
from api.models import (
ComplianceOverview,
Integration,
Invitation,
Membership,
@@ -35,6 +35,7 @@ from api.models import (
UserRoleRelationship,
)
from api.rls import Tenant
from api.v1.views import ComplianceOverviewViewSet
TODAY = str(datetime.today().date())
@@ -4761,210 +4762,266 @@ class TestComplianceOverviewViewSet:
assert len(response.json()["data"]) == 0
def test_compliance_overview_list(
self, authenticated_client, compliance_overviews_fixture
self, authenticated_client, compliance_requirements_overviews_fixture
):
# List compliance overviews with existing data
compliance_overview1, compliance_overview2 = compliance_overviews_fixture
scan_id = str(compliance_overview1.scan.id)
requirement_overview1 = compliance_requirements_overviews_fixture[0]
scan_id = str(requirement_overview1.scan.id)
response = authenticated_client.get(
reverse("complianceoverview-list"),
{"filter[scan_id]": scan_id},
)
assert response.status_code == status.HTTP_200_OK
assert (
len(response.json()["data"]) == 1
) # Due to the custom get_queryset method, only one compliance_id
data = response.json()["data"]
assert len(data) == 2 # Two compliance frameworks
def test_compliance_overview_list_missing_scan_id(self, authenticated_client):
# Attempt to list compliance overviews without providing filter[scan_id]
response = authenticated_client.get(reverse("complianceoverview-list"))
# Check that we get aggregated data for each compliance framework
framework_ids = [item["id"] for item in data]
assert "aws_account_security_onboarding_aws" in framework_ids
assert "cis_1.4_aws" in framework_ids
# Check structure of response
for item in data:
assert "id" in item
assert "attributes" in item
attributes = item["attributes"]
assert "framework" in attributes
assert "version" in attributes
assert "requirements_passed" in attributes
assert "requirements_failed" in attributes
assert "requirements_manual" in attributes
assert "total_requirements" in attributes
def test_compliance_overview_metadata(
self, authenticated_client, compliance_requirements_overviews_fixture
):
requirement_overview1 = compliance_requirements_overviews_fixture[0]
scan_id = str(requirement_overview1.scan.id)
response = authenticated_client.get(
reverse("complianceoverview-metadata"),
{"filter[scan_id]": scan_id},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()["data"]
assert "attributes" in data
assert "regions" in data["attributes"]
assert isinstance(data["attributes"]["regions"], list)
def test_compliance_overview_requirements(
self, authenticated_client, compliance_requirements_overviews_fixture
):
requirement_overview1 = compliance_requirements_overviews_fixture[0]
scan_id = str(requirement_overview1.scan.id)
compliance_id = requirement_overview1.compliance_id
response = authenticated_client.get(
reverse("complianceoverview-requirements"),
{
"filter[scan_id]": scan_id,
"filter[compliance_id]": compliance_id,
},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()["data"]
assert len(data) > 0
# Check structure of requirements response
for item in data:
assert "id" in item
assert "attributes" in item
attributes = item["attributes"]
assert "framework" in attributes
assert "version" in attributes
assert "description" in attributes
assert "status" in attributes
def test_compliance_overview_requirements_manual(
self, authenticated_client, compliance_requirements_overviews_fixture
):
scan_id = str(compliance_requirements_overviews_fixture[0].scan.id)
# Compliance with a manual requirement
compliance_id = "aws_account_security_onboarding_aws"
response = authenticated_client.get(
reverse("complianceoverview-requirements"),
{
"filter[scan_id]": scan_id,
"filter[compliance_id]": compliance_id,
},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()["data"]
assert data[-1]["attributes"]["status"] == "MANUAL"
def test_compliance_overview_requirements_missing_scan_id(
self, authenticated_client
):
response = authenticated_client.get(
reverse("complianceoverview-requirements"),
{"filter[compliance_id]": "aws_account_security_onboarding_aws"},
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.json()["errors"][0]["source"]["pointer"] == "filter[scan_id]"
assert response.json()["errors"][0]["code"] == "required"
def test_compliance_overview_requirements_missing_compliance_id(
self, authenticated_client, compliance_requirements_overviews_fixture
):
requirement_overview1 = compliance_requirements_overviews_fixture[0]
scan_id = str(requirement_overview1.scan.id)
response = authenticated_client.get(
reverse("complianceoverview-requirements"),
{"filter[scan_id]": scan_id},
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
def test_compliance_overview_attributes(self, authenticated_client):
response = authenticated_client.get(
reverse("complianceoverview-attributes"),
{"filter[compliance_id]": "aws_account_security_onboarding_aws"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()["data"]
assert len(data) > 0
# Check structure of attributes response
for item in data:
assert "id" in item
assert "attributes" in item
attributes = item["attributes"]
assert "framework" in attributes
assert "version" in attributes
assert "description" in attributes
assert "attributes" in attributes
assert "metadata" in attributes["attributes"]
assert "check_ids" in attributes["attributes"]
def test_compliance_overview_attributes_missing_compliance_id(
self, authenticated_client
):
response = authenticated_client.get(
reverse("complianceoverview-attributes"),
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
def test_compliance_overview_task_management_integration(
self, authenticated_client, compliance_requirements_overviews_fixture
):
"""Test that task management mixin is properly integrated"""
from unittest.mock import patch
requirement_overview1 = compliance_requirements_overviews_fixture[0]
scan_id = str(requirement_overview1.scan.id)
# Mock a running task
with patch.object(
ComplianceOverviewViewSet, "get_task_response_if_running"
) as mock_task_response:
mock_response = Response(
{"detail": "Task is running"}, status=status.HTTP_202_ACCEPTED
)
mock_task_response.return_value = mock_response
response = authenticated_client.get(
reverse("complianceoverview-list"),
{"filter[scan_id]": scan_id},
)
assert response.status_code == status.HTTP_202_ACCEPTED
mock_task_response.assert_called_once()
def test_compliance_overview_task_failed_exception(
self, authenticated_client, compliance_requirements_overviews_fixture
):
"""Test handling of TaskFailedException"""
from unittest.mock import patch
from api.exceptions import TaskFailedException
requirement_overview1 = compliance_requirements_overviews_fixture[0]
scan_id = str(requirement_overview1.scan.id)
# Mock a failed task
with patch.object(
ComplianceOverviewViewSet, "get_task_response_if_running"
) as mock_task_response:
mock_task_response.side_effect = TaskFailedException("Task failed")
response = authenticated_client.get(
reverse("complianceoverview-list"),
{"filter[scan_id]": scan_id},
)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert "Task failed to generate compliance overview data" in str(
response.data
)
@pytest.mark.parametrize(
"filter_name, filter_value, expected_count",
"filter_name, filter_value_attr, expected_count_min",
[
("compliance_id", "aws_account_security_onboarding_aws", 1),
("compliance_id.icontains", "security_onboarding", 1),
("framework", "AWS-Account-Security-Onboarding", 1),
("framework.icontains", "security-onboarding", 1),
("version", "1.0", 1),
("version", "2.0", 0),
("version.icontains", "0", 1),
("region", "eu-west-1", 1),
("region.icontains", "west-1", 1),
("region.in", "eu-west-1,eu-west-2", 1),
("inserted_at.date", "2024-01-01", 0),
("inserted_at.date", TODAY, 1),
("inserted_at.gte", "2024-01-01", 1),
("scan_id", "scan.id", 1),
("compliance_id", "compliance_id", 1),
("framework", "framework", 1),
("version", "version", 1),
("region", "region", 1),
],
)
def test_compliance_overview_filters(
self,
authenticated_client,
compliance_overviews_fixture,
compliance_requirements_overviews_fixture,
filter_name,
filter_value,
expected_count,
filter_value_attr,
expected_count_min,
):
# Test filtering compliance overviews
compliance_overview1 = compliance_overviews_fixture[0]
scan_id = str(compliance_overview1.scan.id)
requirement_overview = compliance_requirements_overviews_fixture[0]
scan_id = str(requirement_overview.scan.id)
filter_value = requirement_overview
for attr in filter_value_attr.split("."):
filter_value = getattr(filter_value, attr)
filter_value = str(filter_value)
query_params = {
"filter[scan_id]": scan_id,
f"filter[{filter_name}]": filter_value,
}
if filter_name == "scan_id":
query_params = {"filter[scan_id]": filter_value}
response = authenticated_client.get(
reverse("complianceoverview-list"),
{
"filter[scan_id]": scan_id,
f"filter[{filter_name}]": filter_value,
},
)
assert response.status_code == status.HTTP_200_OK
assert len(response.json()["data"]) == expected_count
@pytest.mark.parametrize(
"filter_name",
["invalid_filter", "unknown_field"],
)
def test_compliance_overview_filters_invalid(
self, authenticated_client, compliance_overviews_fixture, filter_name
):
# Test handling of invalid filters
compliance_overview1 = compliance_overviews_fixture[0]
scan_id = str(compliance_overview1.scan.id)
response = authenticated_client.get(
reverse("complianceoverview-list"),
{
"filter[scan_id]": scan_id,
f"filter[{filter_name}]": "some_value",
},
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
@pytest.mark.parametrize(
"sort_field",
["inserted_at", "-inserted_at", "compliance_id", "-compliance_id"],
)
def test_compliance_overview_sort(
self, authenticated_client, compliance_overviews_fixture, sort_field
):
# Test sorting compliance overviews
compliance_overview1 = compliance_overviews_fixture[0]
scan_id = str(compliance_overview1.scan.id)
response = authenticated_client.get(
reverse("complianceoverview-list"),
{
"filter[scan_id]": scan_id,
"sort": sort_field,
},
)
assert response.status_code == status.HTTP_200_OK
def test_compliance_overview_sort_invalid(
self, authenticated_client, compliance_overviews_fixture
):
# Test handling of invalid sort parameters
compliance_overview1 = compliance_overviews_fixture[0]
scan_id = str(compliance_overview1.scan.id)
response = authenticated_client.get(
reverse("complianceoverview-list"),
{
"filter[scan_id]": scan_id,
"sort": "invalid_field",
},
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.json()["errors"][0]["code"] == "invalid"
assert "invalid sort parameter" in response.json()["errors"][0]["detail"]
def test_compliance_overview_retrieve(
self, authenticated_client, compliance_overviews_fixture
):
# Retrieve a specific compliance overview
compliance_overview1 = compliance_overviews_fixture[0]
response = authenticated_client.get(
reverse(
"complianceoverview-detail",
kwargs={"pk": compliance_overview1.id},
),
)
assert response.status_code == status.HTTP_200_OK
data = response.json()["data"]
assert data["id"] == str(compliance_overview1.id)
attributes = data["attributes"]
assert attributes["compliance_id"] == compliance_overview1.compliance_id
assert attributes["framework"] == compliance_overview1.framework
assert attributes["version"] == compliance_overview1.version
assert attributes["region"] == compliance_overview1.region
assert attributes["description"] == compliance_overview1.description
assert "requirements" in attributes
def test_compliance_overview_invalid_retrieve(self, authenticated_client):
# Attempt to retrieve a compliance overview with an invalid ID
response = authenticated_client.get(
reverse(
"complianceoverview-detail",
kwargs={"pk": "invalid-id"},
),
)
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_compliance_overview_list_queryset(
self, authenticated_client, compliance_overviews_fixture
):
compliance_overview1, compliance_overview2 = compliance_overviews_fixture
scan_id = str(compliance_overview1.scan.id)
response = authenticated_client.get(
reverse("complianceoverview-list"),
{"filter[scan_id]": scan_id},
)
# No filters, most fails should be returned
assert len(response.json()["data"]) == 1
assert response.json()["data"][0]["id"] == str(compliance_overview2.id)
compliance_overview1.requirements_failed = 5
compliance_overview1.save()
response = authenticated_client.get(
reverse("complianceoverview-list"),
{"filter[scan_id]": scan_id},
)
# No filters, now compliance_overview1 has more fails
assert len(response.json()["data"]) == 1
assert response.json()["data"][0]["id"] == str(compliance_overview1.id)
def test_compliance_overview_metadata(
self, authenticated_client, compliance_overviews_fixture
):
response = authenticated_client.get(
reverse("complianceoverview-metadata"),
{"filter[scan_id]": str(compliance_overviews_fixture[0].scan_id)},
)
data = response.json()
expected_regions = set(
ComplianceOverview.objects.all()
.values_list("region", flat=True)
.distinct("region")
query_params,
)
assert response.status_code == status.HTTP_200_OK
assert data["data"]["type"] == "compliance-overviews-metadata"
assert data["data"]["id"] is None
assert set(data["data"]["attributes"]["regions"]) == expected_regions
response_data = response.json()
def test_compliance_overview_metadata_missing_scan_id(self, authenticated_client):
# Attempt to list compliance overviews without providing filter[scan_id]
response = authenticated_client.get(reverse("complianceoverview-metadata"))
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.json()["errors"][0]["source"]["pointer"] == "filter[scan_id]"
assert response.json()["errors"][0]["code"] == "required"
assert len(response_data["data"]) >= expected_count_min
if response_data["data"]:
first_item = response_data["data"][0]
assert "id" in first_item
assert "type" in first_item
assert first_item["type"] == "compliance-overviews"
assert "attributes" in first_item
attributes = first_item["attributes"]
assert "framework" in attributes
assert "version" in attributes
assert "requirements_passed" in attributes
assert "requirements_failed" in attributes
assert "requirements_manual" in attributes
assert "total_requirements" in attributes
if filter_name == "compliance_id":
assert first_item["id"] == filter_value
elif filter_name == "framework":
assert attributes["framework"] == filter_value
elif filter_name == "version":
assert attributes["version"] == filter_value
@pytest.mark.django_db
+189
View File
@@ -1,5 +1,16 @@
from django.urls import reverse
from django_celery_results.models import TaskResult
from rest_framework import status
from rest_framework.response import Response
from api.exceptions import (
TaskFailedException,
TaskInProgressException,
TaskNotFoundException,
)
from api.models import StateChoices, Task
from api.v1.serializers import TaskSerializer
class PaginateByPkMixin:
"""
@@ -31,3 +42,181 @@ class PaginateByPkMixin:
serialized = self.get_serializer(queryset, many=True).data
return self.get_paginated_response(serialized)
class TaskManagementMixin:
"""
Mixin to manage task status checking.
This mixin provides functionality to check if a task with specific parameters
is running, completed, failed, or doesn't exist. It returns the task when running
and raises specific exceptions for failed/not found scenarios that can be handled
at the view level.
"""
def check_task_status(
self,
task_name: str,
task_kwargs: dict,
raise_on_failed: bool = True,
raise_on_not_found: bool = True,
) -> Task | None:
"""
Check the status of a task with given name and kwargs.
This method first checks for a related Task object, and if not found,
checks TaskResult directly. If a TaskResult is found and running but
there's no related Task, it raises TaskInProgressException.
Args:
task_name (str): The name of the task to check
task_kwargs (dict): The kwargs to match against the task
raise_on_failed (bool): Whether to raise exception if task failed
raise_on_not_found (bool): Whether to raise exception if task not found
Returns:
Task | None: The task instance if found (regardless of state), None if not found and raise_on_not_found=False
Raises:
TaskFailedException: If task failed and raise_on_failed=True
TaskNotFoundException: If task not found and raise_on_not_found=True
TaskInProgressException: If task is running but no related Task object exists
"""
# First, try to find a Task object with related TaskResult
try:
# Build the filter for task kwargs
task_filter = {
"task_runner_task__task_name": task_name,
}
# Add kwargs filters - we need to check if the task kwargs contain our parameters
for key, value in task_kwargs.items():
task_filter["task_runner_task__task_kwargs__contains"] = str(value)
task = (
Task.objects.filter(**task_filter)
.select_related("task_runner_task")
.order_by("-inserted_at")
.first()
)
if task:
# Get task state using the same logic as TaskSerializer
task_state_mapping = {
"PENDING": StateChoices.AVAILABLE,
"STARTED": StateChoices.EXECUTING,
"PROGRESS": StateChoices.EXECUTING,
"SUCCESS": StateChoices.COMPLETED,
"FAILURE": StateChoices.FAILED,
"REVOKED": StateChoices.CANCELLED,
}
celery_status = (
task.task_runner_task.status if task.task_runner_task else None
)
task_state = task_state_mapping.get(
celery_status or "", StateChoices.AVAILABLE
)
# Check task state and raise exceptions accordingly
if task_state in (StateChoices.FAILED, StateChoices.CANCELLED):
if raise_on_failed:
raise TaskFailedException(task=task)
return task
elif task_state == StateChoices.COMPLETED:
return None
return task
except Task.DoesNotExist:
pass
# If no Task found, check TaskResult directly
try:
# Build the filter for TaskResult
task_result_filter = {
"task_name": task_name,
}
# Add kwargs filters - check if the task kwargs contain our parameters
for key, value in task_kwargs.items():
task_result_filter["task_kwargs__contains"] = str(value)
task_result = (
TaskResult.objects.filter(**task_result_filter)
.order_by("-date_created")
.first()
)
if task_result:
# Check if the TaskResult indicates a running task
if task_result.status in ["PENDING", "STARTED", "PROGRESS"]:
# Task is running but no related Task object exists
raise TaskInProgressException(task_result=task_result)
elif task_result.status == "FAILURE":
if raise_on_failed:
raise TaskFailedException(task=None)
# For other statuses (SUCCESS, REVOKED), we don't have a Task to return,
# so we treat it as not found
except TaskResult.DoesNotExist:
pass
# No task found at all
if raise_on_not_found:
raise TaskNotFoundException()
return None
def get_task_response_if_running(
self,
task_name: str,
task_kwargs: dict,
raise_on_failed: bool = True,
raise_on_not_found: bool = True,
) -> Response | None:
"""
Get a 202 response with task details if the task is currently running.
This method is useful for endpoints that should return task status when
a background task is in progress, similar to the compliance overview endpoints.
Args:
task_name (str): The name of the task to check
task_kwargs (dict): The kwargs to match against the task
Returns:
Response | None: 202 response with task details if running, None otherwise
"""
task = self.check_task_status(
task_name=task_name,
task_kwargs=task_kwargs,
raise_on_failed=raise_on_failed,
raise_on_not_found=raise_on_not_found,
)
if not task:
return None
# Get task state
task_state_mapping = {
"PENDING": StateChoices.AVAILABLE,
"STARTED": StateChoices.EXECUTING,
"PROGRESS": StateChoices.EXECUTING,
"SUCCESS": StateChoices.COMPLETED,
"FAILURE": StateChoices.FAILED,
"REVOKED": StateChoices.CANCELLED,
}
celery_status = task.task_runner_task.status if task.task_runner_task else None
task_state = task_state_mapping.get(celery_status or "", StateChoices.AVAILABLE)
if task_state == StateChoices.EXECUTING:
self.response_serializer_class = TaskSerializer
serializer = TaskSerializer(task)
return Response(
data=serializer.data,
status=status.HTTP_202_ACCEPTED,
headers={
"Content-Location": reverse("task-detail", kwargs={"pk": task.id})
},
)
+43 -112
View File
@@ -14,7 +14,6 @@ from rest_framework_simplejwt.serializers import TokenObtainPairSerializer
from rest_framework_simplejwt.tokens import RefreshToken
from api.models import (
ComplianceOverview,
Finding,
Integration,
IntegrationProviderRelationship,
@@ -31,6 +30,7 @@ from api.models import (
RoleProviderGroupRelationship,
Scan,
StateChoices,
StatusChoices,
Task,
User,
UserRoleRelationship,
@@ -1679,130 +1679,61 @@ class RoleProviderGroupRelationshipSerializer(RLSSerializer, BaseWriteSerializer
# Compliance overview
class ComplianceOverviewSerializer(RLSSerializer):
class ComplianceOverviewSerializer(serializers.Serializer):
"""
Serializer for the ComplianceOverview model.
Serializer for compliance requirement status aggregated by compliance framework.
This serializer is used to format aggregated compliance framework data,
providing counts of passed, failed, and manual requirements along with
an overall global status for each framework.
"""
requirements_status = serializers.SerializerMethodField(
read_only=True, method_name="get_requirements_status"
)
provider_type = serializers.SerializerMethodField(read_only=True)
# Add ID field which will be used for resource identification
id = serializers.CharField()
framework = serializers.CharField()
version = serializers.CharField()
requirements_passed = serializers.IntegerField()
requirements_failed = serializers.IntegerField()
requirements_manual = serializers.IntegerField()
total_requirements = serializers.IntegerField()
class Meta:
model = ComplianceOverview
fields = [
"id",
"inserted_at",
"compliance_id",
"framework",
"version",
"requirements_status",
"region",
"provider_type",
"scan",
"url",
]
@extend_schema_field(
{
"type": "object",
"properties": {
"passed": {"type": "integer"},
"failed": {"type": "integer"},
"manual": {"type": "integer"},
"total": {"type": "integer"},
},
}
)
def get_requirements_status(self, obj):
return {
"passed": obj.requirements_passed,
"failed": obj.requirements_failed,
"manual": obj.requirements_manual,
"total": obj.total_requirements,
}
@extend_schema_field(serializers.CharField(allow_null=True))
def get_provider_type(self, obj):
"""
Retrieves the provider_type from scan.provider.provider_type.
"""
try:
return obj.scan.provider.provider
except AttributeError:
return None
class JSONAPIMeta:
resource_name = "compliance-overviews"
class ComplianceOverviewFullSerializer(ComplianceOverviewSerializer):
requirements = serializers.SerializerMethodField(read_only=True)
class ComplianceOverviewDetailSerializer(serializers.Serializer):
"""
Serializer for detailed compliance requirement information.
class Meta(ComplianceOverviewSerializer.Meta):
fields = ComplianceOverviewSerializer.Meta.fields + [
"description",
"requirements",
]
This serializer formats the aggregated requirement data, showing detailed status
and counts for each requirement across all regions.
"""
@extend_schema_field(
{
"type": "object",
"properties": {
"requirement_id": {
"type": "object",
"properties": {
"name": {"type": "string"},
"checks": {
"type": "object",
"properties": {
"check_name": {
"type": "object",
"properties": {
"status": {
"type": "string",
"enum": ["PASS", "FAIL", None],
},
},
}
},
"description": "Each key in the 'checks' object is a check name, with values as "
"'PASS', 'FAIL', or null.",
},
"status": {
"type": "string",
"enum": ["PASS", "FAIL", "MANUAL"],
},
"attributes": {
"type": "array",
"items": {
"type": "object",
},
},
"description": {"type": "string"},
"checks_status": {
"type": "object",
"properties": {
"total": {"type": "integer"},
"pass": {"type": "integer"},
"fail": {"type": "integer"},
"manual": {"type": "integer"},
},
},
},
}
},
}
)
def get_requirements(self, obj):
"""
Returns the detailed structure of requirements.
"""
return obj.requirements
id = serializers.CharField()
framework = serializers.CharField()
version = serializers.CharField()
description = serializers.CharField()
status = serializers.ChoiceField(choices=StatusChoices.choices)
class JSONAPIMeta:
resource_name = "compliance-requirements-details"
class ComplianceOverviewAttributesSerializer(serializers.Serializer):
id = serializers.CharField()
framework = serializers.CharField()
version = serializers.CharField()
description = serializers.CharField()
attributes = serializers.JSONField()
class JSONAPIMeta:
resource_name = "compliance-requirements-attributes"
class ComplianceOverviewMetadataSerializer(serializers.Serializer):
regions = serializers.ListField(child=serializers.CharField(), allow_empty=True)
class Meta:
class JSONAPIMeta:
resource_name = "compliance-overviews-metadata"
+361 -53
View File
@@ -17,7 +17,7 @@ from django.conf import settings as django_settings
from django.contrib.postgres.aggregates import ArrayAgg
from django.contrib.postgres.search import SearchQuery
from django.db import transaction
from django.db.models import Count, Exists, F, OuterRef, Prefetch, Q, Subquery, Sum
from django.db.models import Count, Exists, F, OuterRef, Prefetch, Q, Sum
from django.db.models.functions import Coalesce
from django.http import HttpResponse
from django.urls import reverse
@@ -26,10 +26,10 @@ from django.utils.decorators import method_decorator
from django.views.decorators.cache import cache_control
from django_celery_beat.models import PeriodicTask
from drf_spectacular.settings import spectacular_settings
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import (
OpenApiParameter,
OpenApiResponse,
OpenApiTypes,
extend_schema,
extend_schema_view,
)
@@ -58,8 +58,12 @@ from tasks.tasks import (
)
from api.base_views import BaseRLSViewSet, BaseTenantViewset, BaseUserViewset
from api.compliance import get_compliance_frameworks
from api.compliance import (
PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE,
get_compliance_frameworks,
)
from api.db_router import MainRouter
from api.exceptions import TaskFailedException
from api.filters import (
ComplianceOverviewFilter,
FindingFilter,
@@ -81,6 +85,7 @@ from api.filters import (
)
from api.models import (
ComplianceOverview,
ComplianceRequirementOverview,
Finding,
Integration,
Invitation,
@@ -111,9 +116,10 @@ from api.utils import (
validate_invitation,
)
from api.uuid_utils import datetime_to_uuid7, uuid7_start
from api.v1.mixins import PaginateByPkMixin
from api.v1.mixins import PaginateByPkMixin, TaskManagementMixin
from api.v1.serializers import (
ComplianceOverviewFullSerializer,
ComplianceOverviewAttributesSerializer,
ComplianceOverviewDetailSerializer,
ComplianceOverviewMetadataSerializer,
ComplianceOverviewSerializer,
FindingDynamicFilterSerializer,
@@ -1086,7 +1092,7 @@ class ProviderViewSet(BaseRLSViewSet):
task = check_provider_connection_task.delay(
provider_id=pk, tenant_id=self.request.tenant_id
)
prowler_task = Task.objects.get_with_retry(id=task.id)
prowler_task = Task.objects.get(id=task.id)
serializer = TaskSerializer(prowler_task)
return Response(
data=serializer.data,
@@ -1109,7 +1115,7 @@ class ProviderViewSet(BaseRLSViewSet):
task = delete_provider_task.delay(
provider_id=pk, tenant_id=self.request.tenant_id
)
prowler_task = Task.objects.get_with_retry(id=task.id)
prowler_task = Task.objects.get(id=task.id)
serializer = TaskSerializer(prowler_task)
return Response(
data=serializer.data,
@@ -1489,7 +1495,7 @@ class ScanViewSet(BaseRLSViewSet):
},
)
prowler_task = Task.objects.get_with_retry(id=task.id)
prowler_task = Task.objects.get(id=task.id)
scan.task_id = task.id
scan.save(update_fields=["task_id"])
@@ -2391,8 +2397,7 @@ class RoleProviderGroupRelationshipView(RelationshipView, BaseRLSViewSet):
list=extend_schema(
tags=["Compliance Overview"],
summary="List compliance overviews for a scan",
description="Retrieve an overview of all the compliance in a given scan. If no region filters are provided, the"
" region with the most fails will be returned by default.",
description="Retrieve an overview of all the compliance in a given scan.",
parameters=[
OpenApiParameter(
name="filter[scan_id]",
@@ -2402,12 +2407,18 @@ class RoleProviderGroupRelationshipView(RelationshipView, BaseRLSViewSet):
description="Related scan ID.",
),
],
),
retrieve=extend_schema(
tags=["Compliance Overview"],
summary="Retrieve data from a specific compliance overview",
description="Fetch detailed information about a specific compliance overview by its ID, including detailed "
"requirement information and check's status.",
responses={
200: OpenApiResponse(
description="Compliance overviews obtained successfully",
response=ComplianceOverviewSerializer(many=True),
),
202: OpenApiResponse(
description="The task is in progress", response=TaskSerializer
),
500: OpenApiResponse(
description="Compliance overviews generation task failed"
),
},
),
metadata=extend_schema(
tags=["Compliance Overview"],
@@ -2423,19 +2434,84 @@ class RoleProviderGroupRelationshipView(RelationshipView, BaseRLSViewSet):
description="Related scan ID.",
),
],
responses={
200: OpenApiResponse(
description="Compliance overviews metadata obtained successfully",
response=ComplianceOverviewMetadataSerializer,
),
202: OpenApiResponse(description="The task is in progress"),
500: OpenApiResponse(
description="Compliance overviews generation task failed"
),
},
),
requirements=extend_schema(
tags=["Compliance Overview"],
summary="List compliance requirements overview for a scan",
description="Retrieve a detailed overview of compliance requirements in a given scan, grouped by compliance "
"framework. This endpoint provides requirement-level details and aggregates status across regions.",
parameters=[
OpenApiParameter(
name="filter[scan_id]",
required=True,
type=OpenApiTypes.UUID,
location=OpenApiParameter.QUERY,
description="Related scan ID.",
),
OpenApiParameter(
name="filter[compliance_id]",
required=True,
type=OpenApiTypes.STR,
location=OpenApiParameter.QUERY,
description="Compliance ID.",
),
],
responses={
200: OpenApiResponse(
description="Compliance requirement details obtained successfully",
response=ComplianceOverviewDetailSerializer(many=True),
),
202: OpenApiResponse(description="The task is in progress"),
500: OpenApiResponse(
description="Compliance overviews generation task failed"
),
},
filters=True,
),
attributes=extend_schema(
tags=["Compliance Overview"],
summary="Get compliance requirement attributes",
description="Retrieve detailed attribute information for all requirements in a specific compliance framework "
"along with the associated check IDs for each requirement.",
parameters=[
OpenApiParameter(
name="filter[compliance_id]",
required=True,
type=str,
location=OpenApiParameter.QUERY,
description="Compliance framework ID to get attributes for.",
),
],
responses={
200: OpenApiResponse(
description="Compliance attributes obtained successfully",
response=ComplianceOverviewAttributesSerializer(many=True),
),
},
),
)
@method_decorator(CACHE_DECORATOR, name="list")
@method_decorator(CACHE_DECORATOR, name="retrieve")
class ComplianceOverviewViewSet(BaseRLSViewSet):
@method_decorator(CACHE_DECORATOR, name="requirements")
@method_decorator(CACHE_DECORATOR, name="attributes")
class ComplianceOverviewViewSet(BaseRLSViewSet, TaskManagementMixin):
pagination_class = ComplianceOverviewPagination
queryset = ComplianceOverview.objects.all()
queryset = ComplianceRequirementOverview.objects.all()
serializer_class = ComplianceOverviewSerializer
filterset_class = ComplianceOverviewFilter
http_method_names = ["get"]
search_fields = ["compliance_id"]
ordering = ["compliance_id"]
ordering_fields = ["inserted_at", "compliance_id", "framework", "region"]
ordering_fields = ["compliance_id"]
# RBAC required permissions (implicit -> MANAGE_PROVIDERS enable unlimited visibility or check the visibility of
# the provider through the provider group)
required_permissions = []
@@ -2446,51 +2522,44 @@ class ComplianceOverviewViewSet(BaseRLSViewSet):
role, Permissions.UNLIMITED_VISIBILITY.value, False
)
if self.action == "retrieve":
if unlimited_visibility:
# User has unlimited visibility, return all compliance
return ComplianceOverview.objects.filter(
tenant_id=self.request.tenant_id
)
providers = get_providers(role)
return ComplianceOverview.objects.filter(
tenant_id=self.request.tenant_id, scan__provider__in=providers
)
if unlimited_visibility:
base_queryset = self.filter_queryset(
ComplianceOverview.objects.filter(tenant_id=self.request.tenant_id)
ComplianceRequirementOverview.objects.filter(
tenant_id=self.request.tenant_id
)
)
else:
providers = Provider.objects.filter(
provider_groups__in=role.provider_groups.all()
).distinct()
base_queryset = self.filter_queryset(
ComplianceOverview.objects.filter(
ComplianceRequirementOverview.objects.filter(
tenant_id=self.request.tenant_id, scan__provider__in=providers
)
)
max_failed_ids = (
base_queryset.filter(compliance_id=OuterRef("compliance_id"))
.order_by("-requirements_failed")
.values("id")[:1]
)
return base_queryset.filter(id__in=Subquery(max_failed_ids)).order_by(
"compliance_id"
)
return base_queryset
def get_serializer_class(self):
if self.action == "retrieve":
return ComplianceOverviewFullSerializer
if hasattr(self, "response_serializer_class"):
return self.response_serializer_class
elif self.action == "list":
return ComplianceOverviewSerializer
elif self.action == "metadata":
return ComplianceOverviewMetadataSerializer
elif self.action == "attributes":
return ComplianceOverviewAttributesSerializer
elif self.action == "requirements":
return ComplianceOverviewDetailSerializer
return super().get_serializer_class()
@extend_schema(exclude=True)
def retrieve(self, request, *args, **kwargs):
raise MethodNotAllowed(method="GET")
def list(self, request, *args, **kwargs):
if not request.query_params.get("filter[scan_id]"):
scan_id = request.query_params.get("filter[scan_id]")
if not scan_id:
raise ValidationError(
[
{
@@ -2501,7 +2570,82 @@ class ComplianceOverviewViewSet(BaseRLSViewSet):
}
]
)
return super().list(request, *args, **kwargs)
try:
if task := self.get_task_response_if_running(
task_name="scan-compliance-overviews",
task_kwargs={"tenant_id": self.request.tenant_id, "scan_id": scan_id},
raise_on_not_found=False,
):
return task
except TaskFailedException:
return Response(
{"detail": "Task failed to generate compliance overview data."},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
queryset = self.filter_queryset(self.filter_queryset(self.get_queryset()))
requirement_status_subquery = queryset.values(
"compliance_id", "requirement_id"
).annotate(
fail_count=Count("id", filter=Q(requirement_status="FAIL")),
pass_count=Count("id", filter=Q(requirement_status="PASS")),
total_count=Count("id"),
)
compliance_data = {}
framework_info = {}
for item in queryset.values("compliance_id", "framework", "version").distinct():
framework_info[item["compliance_id"]] = {
"framework": item["framework"],
"version": item["version"],
}
for item in requirement_status_subquery:
compliance_id = item["compliance_id"]
if item["fail_count"] > 0:
req_status = "FAIL"
elif item["pass_count"] == item["total_count"]:
req_status = "PASS"
else:
req_status = "MANUAL"
if compliance_id not in compliance_data:
compliance_data[compliance_id] = {
"total_requirements": 0,
"requirements_passed": 0,
"requirements_failed": 0,
"requirements_manual": 0,
}
compliance_data[compliance_id]["total_requirements"] += 1
if req_status == "PASS":
compliance_data[compliance_id]["requirements_passed"] += 1
elif req_status == "FAIL":
compliance_data[compliance_id]["requirements_failed"] += 1
else:
compliance_data[compliance_id]["requirements_manual"] += 1
response_data = []
for compliance_id, data in compliance_data.items():
framework = framework_info.get(compliance_id, {})
response_data.append(
{
"id": compliance_id,
"compliance_id": compliance_id,
"framework": framework.get("framework", ""),
"version": framework.get("version", ""),
"requirements_passed": data["requirements_passed"],
"requirements_failed": data["requirements_failed"],
"requirements_manual": data["requirements_manual"],
"total_requirements": data["total_requirements"],
}
)
serializer = self.get_serializer(response_data, many=True)
return Response(serializer.data)
@action(detail=False, methods=["get"], url_name="metadata")
def metadata(self, request):
@@ -2517,11 +2661,21 @@ class ComplianceOverviewViewSet(BaseRLSViewSet):
}
]
)
tenant_id = self.request.tenant_id
try:
if task := self.get_task_response_if_running(
task_name="scan-compliance-overviews",
task_kwargs={"tenant_id": self.request.tenant_id, "scan_id": scan_id},
raise_on_not_found=False,
):
return task
except TaskFailedException:
return Response(
{"detail": "Task failed to generate compliance overview data."},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
regions = list(
ComplianceOverview.objects.filter(tenant_id=tenant_id, scan_id=scan_id)
self.get_queryset()
.filter(scan_id=scan_id)
.values_list("region", flat=True)
.order_by("region")
.distinct()
@@ -2532,6 +2686,160 @@ class ComplianceOverviewViewSet(BaseRLSViewSet):
serializer.is_valid(raise_exception=True)
return Response(serializer.data, status=status.HTTP_200_OK)
@action(detail=False, methods=["get"], url_name="requirements")
def requirements(self, request):
scan_id = request.query_params.get("filter[scan_id]")
compliance_id = request.query_params.get("filter[compliance_id]")
if not scan_id:
raise ValidationError(
[
{
"detail": "This query parameter is required.",
"status": 400,
"source": {"pointer": "filter[scan_id]"},
"code": "required",
}
]
)
if not compliance_id:
raise ValidationError(
[
{
"detail": "This query parameter is required.",
"status": 400,
"source": {"pointer": "filter[compliance_id]"},
"code": "required",
}
]
)
try:
if task := self.get_task_response_if_running(
task_name="scan-compliance-overviews",
task_kwargs={"tenant_id": self.request.tenant_id, "scan_id": scan_id},
raise_on_not_found=False,
):
return task
except TaskFailedException:
return Response(
{"detail": "Task failed to generate compliance overview data."},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
filtered_queryset = self.filter_queryset(self.get_queryset())
all_requirements = (
filtered_queryset.values(
"requirement_id", "framework", "version", "description"
)
.distinct()
.annotate(
total_instances=Count("id"),
manual_count=Count("id", filter=Q(requirement_status="MANUAL")),
)
)
passed_instances = (
filtered_queryset.filter(requirement_status="PASS")
.values("requirement_id")
.annotate(pass_count=Count("id"))
)
passed_counts = {
item["requirement_id"]: item["pass_count"] for item in passed_instances
}
requirements_summary = []
for requirement in all_requirements:
requirement_id = requirement["requirement_id"]
total_instances = requirement["total_instances"]
passed_count = passed_counts.get(requirement_id, 0)
is_manual = requirement["manual_count"] == total_instances
if is_manual:
requirement_status = "MANUAL"
elif passed_count == total_instances:
requirement_status = "PASS"
else:
requirement_status = "FAIL"
requirements_summary.append(
{
"id": requirement_id,
"framework": requirement["framework"],
"version": requirement["version"],
"description": requirement["description"],
"status": requirement_status,
}
)
serializer = self.get_serializer(requirements_summary, many=True)
return Response(serializer.data, status=status.HTTP_200_OK)
@action(detail=False, methods=["get"], url_name="attributes")
def attributes(self, request):
compliance_id = request.query_params.get("filter[compliance_id]")
if not compliance_id:
raise ValidationError(
[
{
"detail": "This query parameter is required.",
"status": 400,
"source": {"pointer": "filter[compliance_id]"},
"code": "required",
}
]
)
provider_type = None
try:
sample_requirement = (
self.get_queryset().filter(compliance_id=compliance_id).first()
)
if sample_requirement:
provider_type = sample_requirement.scan.provider.provider
except Exception:
pass
# If we couldn't determine from database, try each provider type
if not provider_type:
for pt in Provider.ProviderChoices.values:
if compliance_id in PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE.get(pt, {}):
provider_type = pt
break
if not provider_type:
raise NotFound(detail=f"Compliance framework '{compliance_id}' not found.")
compliance_template = PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE.get(
provider_type, {}
)
compliance_framework = compliance_template.get(compliance_id)
if not compliance_framework:
raise NotFound(detail=f"Compliance framework '{compliance_id}' not found.")
attribute_data = []
for requirement_id, requirement in compliance_framework.get(
"requirements", {}
).items():
check_ids = list(requirement.get("checks", {}).keys())
metadata = requirement.get("attributes", [])
attribute_data.append(
{
"id": requirement_id,
"framework": compliance_framework.get("framework", ""),
"version": compliance_framework.get("version", ""),
"description": requirement.get("description", ""),
"attributes": {"metadata": metadata, "check_ids": check_ids},
}
)
serializer = self.get_serializer(attribute_data, many=True)
return Response(serializer.data, status=status.HTTP_200_OK)
@extend_schema(tags=["Overview"])
@extend_schema_view(
@@ -2578,7 +2886,7 @@ class ComplianceOverviewViewSet(BaseRLSViewSet):
class OverviewViewSet(BaseRLSViewSet):
queryset = ComplianceOverview.objects.all()
http_method_names = ["get"]
ordering = ["-id"]
ordering = ["-inserted_at"]
# RBAC required permissions (implicit -> MANAGE_PROVIDERS enable unlimited visibility or check the visibility of
# the provider through the provider group)
required_permissions = []
@@ -2823,7 +3131,7 @@ class ScheduleViewSet(BaseRLSViewSet):
with transaction.atomic():
task = schedule_provider_scan(provider_instance)
prowler_task = Task.objects.get_with_retry(id=task.id)
prowler_task = Task.objects.get(id=task.id)
self.response_serializer_class = TaskSerializer
output_serializer = self.get_serializer(prowler_task)
+2
View File
@@ -26,6 +26,7 @@ INSTALLED_APPS = [
"rest_framework",
"corsheaders",
"drf_spectacular",
"drf_spectacular_jsonapi",
"django_guid",
"rest_framework_json_api",
"django_celery_results",
@@ -127,6 +128,7 @@ DJANGO_GUID = {
}
DATABASE_ROUTERS = ["api.db_router.MainRouter"]
POSTGRES_EXTRA_DB_BACKEND_BASE = "database_backend"
# Password validation
+110
View File
@@ -15,6 +15,7 @@ from tasks.jobs.backfill import backfill_resource_scan_summaries
from api.db_utils import rls_transaction
from api.models import (
ComplianceOverview,
ComplianceRequirementOverview,
Finding,
Integration,
IntegrationProviderRelationship,
@@ -29,6 +30,7 @@ from api.models import (
Scan,
ScanSummary,
StateChoices,
StatusChoices,
Task,
User,
UserRoleRelationship,
@@ -777,6 +779,114 @@ def compliance_overviews_fixture(scans_fixture, tenants_fixture):
return compliance_overview1, compliance_overview2
@pytest.fixture
def compliance_requirements_overviews_fixture(scans_fixture, tenants_fixture):
"""Fixture for ComplianceRequirementOverview objects used by the new ComplianceOverviewViewSet."""
tenant = tenants_fixture[0]
scan1, scan2, scan3 = scans_fixture
# Create ComplianceRequirementOverview objects for scan1
requirement_overview1 = ComplianceRequirementOverview.objects.create(
tenant=tenant,
scan=scan1,
compliance_id="aws_account_security_onboarding_aws",
framework="AWS-Account-Security-Onboarding",
version="1.0",
description="Description for AWS Account Security Onboarding",
region="eu-west-1",
requirement_id="requirement1",
requirement_status=StatusChoices.PASS,
passed_checks=2,
failed_checks=0,
total_checks=2,
)
requirement_overview2 = ComplianceRequirementOverview.objects.create(
tenant=tenant,
scan=scan1,
compliance_id="aws_account_security_onboarding_aws",
framework="AWS-Account-Security-Onboarding",
version="1.0",
description="Description for AWS Account Security Onboarding",
region="eu-west-1",
requirement_id="requirement2",
requirement_status=StatusChoices.PASS,
passed_checks=2,
failed_checks=0,
total_checks=2,
)
requirement_overview3 = ComplianceRequirementOverview.objects.create(
tenant=tenant,
scan=scan1,
compliance_id="aws_account_security_onboarding_aws",
framework="AWS-Account-Security-Onboarding",
version="1.0",
description="Description for AWS Account Security Onboarding",
region="eu-west-2",
requirement_id="requirement1",
requirement_status=StatusChoices.PASS,
passed_checks=2,
failed_checks=0,
total_checks=2,
)
requirement_overview4 = ComplianceRequirementOverview.objects.create(
tenant=tenant,
scan=scan1,
compliance_id="aws_account_security_onboarding_aws",
framework="AWS-Account-Security-Onboarding",
version="1.0",
description="Description for AWS Account Security Onboarding",
region="eu-west-2",
requirement_id="requirement2",
requirement_status=StatusChoices.FAIL,
passed_checks=1,
failed_checks=1,
total_checks=2,
)
requirement_overview5 = ComplianceRequirementOverview.objects.create(
tenant=tenant,
scan=scan1,
compliance_id="aws_account_security_onboarding_aws",
framework="AWS-Account-Security-Onboarding",
version="1.0",
description="Description for AWS Account Security Onboarding (MANUAL)",
region="eu-west-2",
requirement_id="requirement3",
requirement_status=StatusChoices.MANUAL,
passed_checks=0,
failed_checks=0,
total_checks=0,
)
# Create a different compliance framework for testing
requirement_overview6 = ComplianceRequirementOverview.objects.create(
tenant=tenant,
scan=scan1,
compliance_id="cis_1.4_aws",
framework="CIS-1.4-AWS",
version="1.4",
description="CIS AWS Foundations Benchmark v1.4.0",
region="eu-west-1",
requirement_id="cis_requirement1",
requirement_status=StatusChoices.FAIL,
passed_checks=0,
failed_checks=3,
total_checks=3,
)
return (
requirement_overview1,
requirement_overview2,
requirement_overview3,
requirement_overview4,
requirement_overview5,
requirement_overview6,
)
def get_api_tokens(
api_client, user_email: str, user_password: str, tenant_id: str = None
) -> tuple[str, str]:
+15
View File
@@ -0,0 +1,15 @@
import django.db
from django.db.backends.postgresql.base import (
DatabaseWrapper as BuiltinPostgresDatabaseWrapper,
)
from psycopg2 import InterfaceError
class DatabaseWrapper(BuiltinPostgresDatabaseWrapper):
def create_cursor(self, name=None):
try:
return super().create_cursor(name=name)
except InterfaceError:
django.db.close_old_connections()
django.db.connection.connect()
return super().create_cursor(name=name)
+113 -70
View File
@@ -13,9 +13,9 @@ from api.compliance import (
PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE,
generate_scan_compliance,
)
from api.db_utils import rls_transaction
from api.db_utils import create_objects_in_batches, rls_transaction
from api.models import (
ComplianceOverview,
ComplianceRequirementOverview,
Finding,
Provider,
Resource,
@@ -119,7 +119,6 @@ def perform_prowler_scan(
ValueError: If the provider cannot be connected.
"""
check_status_by_region = {}
exception = None
unique_resources = set()
scan_resource_cache: set[tuple[str, str, str, str]] = set()
@@ -293,16 +292,6 @@ def perform_prowler_scan(
)
finding_instance.add_resources([resource_instance])
# Update compliance data if applicable
if finding.status.value == "MUTED":
continue
region_dict = check_status_by_region.setdefault(finding.region, {})
current_status = region_dict.get(finding.check_id)
if current_status == "FAIL":
continue
region_dict[finding.check_id] = finding.status.value
# Update scan resource summaries
scan_resource_cache.add(
(
@@ -335,63 +324,6 @@ def perform_prowler_scan(
if exception is not None:
raise exception
try:
regions = prowler_provider.get_regions()
except AttributeError:
regions = set()
compliance_template = PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE[
provider_instance.provider
]
compliance_overview_by_region = {
region: deepcopy(compliance_template) for region in regions
}
for region, check_status in check_status_by_region.items():
compliance_data = compliance_overview_by_region.setdefault(
region, deepcopy(compliance_template)
)
for check_name, status in check_status.items():
generate_scan_compliance(
compliance_data,
provider_instance.provider,
check_name,
status,
)
# Prepare compliance overview objects
compliance_overview_objects = []
for region, compliance_data in compliance_overview_by_region.items():
for compliance_id, compliance in compliance_data.items():
compliance_overview_objects.append(
ComplianceOverview(
tenant_id=tenant_id,
scan=scan_instance,
region=region,
compliance_id=compliance_id,
framework=compliance["framework"],
version=compliance["version"],
description=compliance["description"],
requirements=compliance["requirements"],
requirements_passed=compliance["requirements_status"]["passed"],
requirements_failed=compliance["requirements_status"]["failed"],
requirements_manual=compliance["requirements_status"]["manual"],
total_requirements=compliance["total_requirements"],
)
)
try:
with rls_transaction(tenant_id):
ComplianceOverview.objects.bulk_create(
compliance_overview_objects, batch_size=500
)
except Exception as overview_exception:
import sentry_sdk
sentry_sdk.capture_exception(overview_exception)
logger.error(
f"Error storing compliance overview for scan {scan_id}: {overview_exception}"
)
try:
resource_scan_summaries = [
ResourceScanSummary(
@@ -570,3 +502,114 @@ def aggregate_findings(tenant_id: str, scan_id: str):
for agg in aggregation
}
ScanSummary.objects.bulk_create(scan_aggregations, batch_size=3000)
def create_compliance_requirements(tenant_id: str, scan_id: str):
"""
Create detailed compliance requirement overview records for a scan.
This function processes the compliance data collected during a scan and creates
individual records for each compliance requirement in each region. These detailed
records provide a granular view of compliance status.
Args:
tenant_id (str): The ID of the tenant for which to create records.
scan_id (str): The ID of the scan for which to create records.
Returns:
dict: A dictionary containing the number of requirements created and the regions processed.
Raises:
ValidationError: If tenant_id is not a valid UUID.
"""
try:
with rls_transaction(tenant_id):
scan_instance = Scan.objects.get(pk=scan_id)
provider_instance = scan_instance.provider
prowler_provider = initialize_prowler_provider(provider_instance)
# Get check status data by region from findings
check_status_by_region = {}
with rls_transaction(tenant_id):
findings = Finding.objects.filter(scan_id=scan_id, muted=False)
for finding in findings:
# Get region from resources
for resource in finding.resources.all():
region = resource.region
region_dict = check_status_by_region.setdefault(region, {})
current_status = region_dict.get(finding.check_id)
if current_status == "FAIL":
continue
region_dict[finding.check_id] = finding.status
try:
# Try to get regions from provider
regions = prowler_provider.get_regions()
except (AttributeError, Exception):
# If not available, use regions from findings
regions = set(check_status_by_region.keys())
# Get compliance template for the provider
compliance_template = PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE[
provider_instance.provider
]
# Create compliance data by region
compliance_overview_by_region = {
region: deepcopy(compliance_template) for region in regions
}
# Apply check statuses to compliance data
for region, check_status in check_status_by_region.items():
compliance_data = compliance_overview_by_region.setdefault(
region, deepcopy(compliance_template)
)
for check_name, status in check_status.items():
generate_scan_compliance(
compliance_data,
provider_instance.provider,
check_name,
status,
)
# Prepare compliance requirement objects
compliance_requirement_objects = []
for region, compliance_data in compliance_overview_by_region.items():
for compliance_id, compliance in compliance_data.items():
# Create an overview record for each requirement within each compliance framework
for requirement_id, requirement in compliance["requirements"].items():
compliance_requirement_objects.append(
ComplianceRequirementOverview(
tenant_id=tenant_id,
scan=scan_instance,
region=region,
compliance_id=compliance_id,
framework=compliance["framework"],
version=compliance["version"],
requirement_id=requirement_id,
description=requirement["description"],
passed_checks=requirement["checks_status"]["pass"],
failed_checks=requirement["checks_status"]["fail"],
total_checks=requirement["checks_status"]["total"],
requirement_status=requirement["status"],
)
)
# Bulk create requirement records
create_objects_in_batches(
tenant_id, ComplianceRequirementOverview, compliance_requirement_objects
)
return {
"requirements_created": len(compliance_requirement_objects),
"regions_processed": list(regions),
"compliance_frameworks": (
list(compliance_overview_by_region.get(list(regions)[0], {}).keys())
if regions
else []
),
}
except Exception as e:
logger.error(f"Error creating compliance requirements for scan {scan_id}: {e}")
raise e
+25 -1
View File
@@ -17,7 +17,11 @@ from tasks.jobs.export import (
_generate_output_directory,
_upload_to_s3,
)
from tasks.jobs.scan import aggregate_findings, perform_prowler_scan
from tasks.jobs.scan import (
aggregate_findings,
create_compliance_requirements,
perform_prowler_scan,
)
from tasks.utils import batched, get_next_execution_datetime
from api.compliance import get_compliance_frameworks
@@ -101,6 +105,7 @@ def perform_scan_task(
chain(
perform_scan_summary_task.si(tenant_id, scan_id),
create_compliance_requirements_task.si(tenant_id=tenant_id, scan_id=scan_id),
generate_outputs.si(
scan_id=scan_id, provider_id=provider_id, tenant_id=tenant_id
),
@@ -211,6 +216,9 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
chain(
perform_scan_summary_task.si(tenant_id, scan_instance.id),
create_compliance_requirements_task.si(
tenant_id=tenant_id, scan_id=str(scan_instance.id)
),
generate_outputs.si(
scan_id=str(scan_instance.id), provider_id=provider_id, tenant_id=tenant_id
),
@@ -371,3 +379,19 @@ def backfill_scan_resource_summaries_task(tenant_id: str, scan_id: str):
scan_id (str): The scan identifier.
"""
return backfill_resource_scan_summaries(tenant_id=tenant_id, scan_id=scan_id)
@shared_task(base=RLSTask, name="scan-compliance-overviews")
def create_compliance_requirements_task(tenant_id: str, scan_id: str):
"""
Creates detailed compliance requirement records for a scan.
This task processes the compliance data collected during a scan and creates
individual records for each compliance requirement in each region. These detailed
records provide a granular view of compliance status.
Args:
tenant_id (str): The tenant ID for which to create records.
scan_id (str): The ID of the scan for which to create records.
"""
return create_compliance_requirements(tenant_id=tenant_id, scan_id=scan_id)
+816 -8
View File
@@ -7,11 +7,13 @@ import pytest
from tasks.jobs.scan import (
_create_finding_delta,
_store_resources,
create_compliance_requirements,
perform_prowler_scan,
)
from tasks.utils import CustomEncoder
from api.models import (
ComplianceRequirementOverview,
Finding,
Provider,
Resource,
@@ -235,7 +237,7 @@ class TestPerformScan:
):
tenant_id = uuid.uuid4()
provider_instance = MagicMock()
provider_instance.id = "provider456"
provider_instance.id = "provider123"
finding = MagicMock()
finding.resource_uid = "resource_uid_123"
@@ -250,15 +252,16 @@ class TestPerformScan:
resource_instance.region = finding.region
mock_get_or_create_resource.return_value = (resource_instance, True)
tag_instance = MagicMock()
mock_get_or_create_tag.return_value = (tag_instance, True)
resource, resource_uid_tuple = _store_resources(
finding, tenant_id, provider_instance
finding, str(tenant_id), provider_instance
)
mock_get_or_create_resource.assert_called_once_with(
tenant_id=tenant_id,
tenant_id=str(tenant_id),
provider=provider_instance,
uid=finding.resource_uid,
defaults={
@@ -305,11 +308,11 @@ class TestPerformScan:
mock_get_or_create_tag.return_value = (tag_instance, True)
resource, resource_uid_tuple = _store_resources(
finding, tenant_id, provider_instance
finding, str(tenant_id), provider_instance
)
mock_get_or_create_resource.assert_called_once_with(
tenant_id=tenant_id,
tenant_id=str(tenant_id),
provider=provider_instance,
uid=finding.resource_uid,
defaults={
@@ -363,14 +366,14 @@ class TestPerformScan:
]
resource, resource_uid_tuple = _store_resources(
finding, tenant_id, provider_instance
finding, str(tenant_id), provider_instance
)
mock_get_or_create_tag.assert_any_call(
tenant_id=tenant_id, key="tag1", value="value1"
tenant_id=str(tenant_id), key="tag1", value="value1"
)
mock_get_or_create_tag.assert_any_call(
tenant_id=tenant_id, key="tag2", value="value2"
tenant_id=str(tenant_id), key="tag2", value="value2"
)
resource_instance.upsert_or_delete_tags.assert_called_once()
tags_passed = resource_instance.upsert_or_delete_tags.call_args[1]["tags"]
@@ -382,3 +385,808 @@ class TestPerformScan:
# TODO Add tests for aggregations
@pytest.mark.django_db
class TestCreateComplianceRequirements:
def test_create_compliance_requirements_success(
self,
tenants_fixture,
scans_fixture,
providers_fixture,
findings_fixture,
resources_fixture,
):
with (
patch("api.db_utils.rls_transaction"),
patch(
"tasks.jobs.scan.initialize_prowler_provider"
) as mock_initialize_prowler_provider,
patch(
"tasks.jobs.scan.PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE"
) as mock_compliance_template,
patch("tasks.jobs.scan.generate_scan_compliance"),
patch("tasks.jobs.scan.create_objects_in_batches") as mock_create_objects,
patch("api.models.Finding.objects.filter") as mock_findings_filter,
):
tenant = tenants_fixture[0]
scan = scans_fixture[0]
provider = providers_fixture[0]
provider.provider = Provider.ProviderChoices.AWS
provider.save()
scan.provider = provider
scan.save()
tenant_id = str(tenant.id)
scan_id = str(scan.id)
mock_prowler_provider_instance = MagicMock()
mock_prowler_provider_instance.get_regions.return_value = [
"us-east-1",
"us-west-2",
]
mock_initialize_prowler_provider.return_value = (
mock_prowler_provider_instance
)
mock_compliance_template.__getitem__.return_value = {
"cis_1.4_aws": {
"framework": "CIS AWS Foundations Benchmark",
"version": "1.4.0",
"requirements": {
"1.1": {
"description": "Ensure root access key does not exist",
"checks_status": {
"pass": 0,
"fail": 0,
"manual": 0,
"total": 1,
},
"status": "PASS",
},
"1.2": {
"description": "Ensure MFA is enabled for root account",
"checks_status": {
"pass": 0,
"fail": 1,
"manual": 0,
"total": 1,
},
"status": "FAIL",
},
},
},
"aws_account_security_onboarding_aws": {
"framework": "AWS Account Security Onboarding",
"version": "1.0",
"requirements": {
"requirement1": {
"description": "Basic security requirement",
"checks_status": {
"pass": 1,
"fail": 0,
"manual": 0,
"total": 1,
},
"status": "PASS",
},
},
},
}
mock_findings_filter.return_value = []
result = create_compliance_requirements(tenant_id, scan_id)
assert "requirements_created" in result
assert "regions_processed" in result
assert "compliance_frameworks" in result
assert result["regions_processed"] == ["us-east-1", "us-west-2"]
assert result["requirements_created"] == 6
assert len(result["compliance_frameworks"]) == 2
mock_create_objects.assert_called_once()
call_args = mock_create_objects.call_args[0]
assert call_args[0] == tenant_id
assert call_args[1] == ComplianceRequirementOverview
assert len(call_args[2]) == 6
compliance_objects = call_args[2]
for obj in compliance_objects:
assert isinstance(obj, ComplianceRequirementOverview)
assert obj.tenant.id == tenant.id
assert obj.scan == scan
assert obj.region in ["us-east-1", "us-west-2"]
assert obj.compliance_id in [
"cis_1.4_aws",
"aws_account_security_onboarding_aws",
]
def test_create_compliance_requirements_with_findings(
self,
tenants_fixture,
scans_fixture,
providers_fixture,
):
with (
patch("api.db_utils.rls_transaction"),
patch(
"tasks.jobs.scan.initialize_prowler_provider"
) as mock_initialize_prowler_provider,
patch(
"tasks.jobs.scan.PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE"
) as mock_compliance_template,
patch(
"tasks.jobs.scan.generate_scan_compliance"
) as mock_generate_compliance,
patch("tasks.jobs.scan.create_objects_in_batches"),
patch("api.models.Finding.objects.filter") as mock_findings_filter,
):
tenant = tenants_fixture[0]
scan = scans_fixture[0]
provider = providers_fixture[0]
provider.provider = Provider.ProviderChoices.AWS
provider.save()
scan.provider = provider
scan.save()
tenant_id = str(tenant.id)
scan_id = str(scan.id)
mock_finding1 = MagicMock()
mock_finding1.check_id = "check1"
mock_finding1.status = "PASS"
mock_resource1 = MagicMock()
mock_resource1.region = "us-east-1"
mock_finding1.resources.all.return_value = [mock_resource1]
mock_finding2 = MagicMock()
mock_finding2.check_id = "check2"
mock_finding2.status = "FAIL"
mock_resource2 = MagicMock()
mock_resource2.region = "us-west-2"
mock_finding2.resources.all.return_value = [mock_resource2]
mock_findings_filter.return_value = [mock_finding1, mock_finding2]
mock_prowler_provider_instance = MagicMock()
mock_prowler_provider_instance.get_regions.return_value = [
"us-east-1",
"us-west-2",
]
mock_initialize_prowler_provider.return_value = (
mock_prowler_provider_instance
)
mock_compliance_template.__getitem__.return_value = {
"test_compliance": {
"framework": "Test Framework",
"version": "1.0",
"requirements": {
"req_1": {
"description": "Test Requirement 1",
"checks": {"check_1": None},
"checks_status": {
"pass": 2,
"fail": 1,
"manual": 0,
"total": 3,
},
"status": "FAIL",
},
"req_2": {
"description": "Test Requirement 2",
"checks": {"check_2": None},
"checks_status": {
"pass": 2,
"fail": 0,
"manual": 0,
"total": 2,
},
"status": "PASS",
},
},
}
}
result = create_compliance_requirements(tenant_id, scan_id)
mock_findings_filter.assert_called_once_with(scan_id=scan_id, muted=False)
assert mock_generate_compliance.call_count == 2
assert result["requirements_created"] == 4
assert set(result["regions_processed"]) == {"us-east-1", "us-west-2"}
def test_create_compliance_requirements_no_provider_regions(
self,
tenants_fixture,
scans_fixture,
providers_fixture,
):
with (
patch("api.db_utils.rls_transaction"),
patch(
"tasks.jobs.scan.initialize_prowler_provider"
) as mock_initialize_prowler_provider,
patch(
"tasks.jobs.scan.PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE"
) as mock_compliance_template,
patch("tasks.jobs.scan.generate_scan_compliance"),
patch("tasks.jobs.scan.create_objects_in_batches"),
patch("api.models.Finding.objects.filter") as mock_findings_filter,
):
tenant = tenants_fixture[0]
scan = scans_fixture[0]
provider = providers_fixture[0]
provider.provider = Provider.ProviderChoices.KUBERNETES
provider.save()
scan.provider = provider
scan.save()
tenant_id = str(tenant.id)
scan_id = str(scan.id)
mock_finding = MagicMock()
mock_finding.check_id = "check1"
mock_finding.status = "PASS"
mock_resource = MagicMock()
mock_resource.region = "default"
mock_finding.resources.all.return_value = [mock_resource]
mock_findings_filter.return_value = [mock_finding]
mock_prowler_provider_instance = MagicMock()
mock_prowler_provider_instance.get_regions.side_effect = AttributeError(
"No get_regions method"
)
mock_initialize_prowler_provider.return_value = (
mock_prowler_provider_instance
)
mock_compliance_template.__getitem__.return_value = {
"kubernetes_cis": {
"framework": "CIS Kubernetes Benchmark",
"version": "1.6.0",
"requirements": {
"1.1": {
"description": "Test requirement",
"checks_status": {
"pass": 0,
"fail": 0,
"manual": 0,
"total": 1,
},
"status": "PASS",
},
},
},
}
result = create_compliance_requirements(tenant_id, scan_id)
assert result["regions_processed"] == ["default"]
def test_create_compliance_requirements_empty_findings(
self,
tenants_fixture,
scans_fixture,
providers_fixture,
):
with (
patch("api.db_utils.rls_transaction"),
patch(
"tasks.jobs.scan.initialize_prowler_provider"
) as mock_initialize_prowler_provider,
patch(
"tasks.jobs.scan.PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE"
) as mock_compliance_template,
patch(
"tasks.jobs.scan.generate_scan_compliance"
) as mock_generate_compliance,
patch("tasks.jobs.scan.create_objects_in_batches"),
patch("api.models.Finding.objects.filter") as mock_findings_filter,
):
tenant = tenants_fixture[0]
scan = scans_fixture[0]
provider = providers_fixture[0]
provider.provider = Provider.ProviderChoices.AWS
provider.save()
scan.provider = provider
scan.save()
tenant_id = str(tenant.id)
scan_id = str(scan.id)
mock_findings_filter.return_value = []
mock_prowler_provider_instance = MagicMock()
mock_prowler_provider_instance.get_regions.return_value = ["us-east-1"]
mock_initialize_prowler_provider.return_value = (
mock_prowler_provider_instance
)
mock_compliance_template.__getitem__.return_value = {
"cis_1.4_aws": {
"framework": "CIS AWS Foundations Benchmark",
"version": "1.4.0",
"requirements": {
"1.1": {
"description": "Test requirement",
"checks_status": {
"pass": 0,
"fail": 0,
"manual": 0,
"total": 1,
},
"status": "PASS",
},
},
},
}
mock_findings_filter.return_value = []
result = create_compliance_requirements(tenant_id, scan_id)
assert result["regions_processed"] == ["us-east-1"]
assert result["requirements_created"] == 1
mock_generate_compliance.assert_not_called()
def test_create_compliance_requirements_error_handling(
self,
tenants_fixture,
scans_fixture,
providers_fixture,
):
with (
patch("api.db_utils.rls_transaction"),
patch(
"tasks.jobs.scan.initialize_prowler_provider"
) as mock_initialize_prowler_provider,
):
tenant = tenants_fixture[0]
scan = scans_fixture[0]
provider = providers_fixture[0]
provider.provider = Provider.ProviderChoices.AWS
provider.save()
scan.provider = provider
scan.save()
tenant_id = str(tenant.id)
scan_id = str(scan.id)
mock_initialize_prowler_provider.side_effect = Exception(
"Provider initialization failed"
)
with pytest.raises(Exception, match="Provider initialization failed"):
create_compliance_requirements(tenant_id, scan_id)
def test_create_compliance_requirements_muted_findings_excluded(
self,
tenants_fixture,
scans_fixture,
providers_fixture,
):
with (
patch("api.db_utils.rls_transaction"),
patch(
"tasks.jobs.scan.initialize_prowler_provider"
) as mock_initialize_prowler_provider,
patch(
"tasks.jobs.scan.PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE"
) as mock_compliance_template,
patch("tasks.jobs.scan.generate_scan_compliance"),
patch("tasks.jobs.scan.create_objects_in_batches"),
patch("api.models.Finding.objects.filter") as mock_findings_filter,
):
tenant = tenants_fixture[0]
scan = scans_fixture[0]
provider = providers_fixture[0]
provider.provider = Provider.ProviderChoices.AWS
provider.save()
scan.provider = provider
scan.save()
tenant_id = str(tenant.id)
scan_id = str(scan.id)
mock_findings_filter.return_value = []
mock_prowler_provider_instance = MagicMock()
mock_prowler_provider_instance.get_regions.return_value = ["us-east-1"]
mock_initialize_prowler_provider.return_value = (
mock_prowler_provider_instance
)
mock_compliance_template.__getitem__.return_value = {}
mock_findings_filter.return_value = []
create_compliance_requirements(tenant_id, scan_id)
mock_findings_filter.assert_called_once_with(scan_id=scan_id, muted=False)
def test_create_compliance_requirements_check_status_priority(
self,
tenants_fixture,
scans_fixture,
providers_fixture,
):
with (
patch("api.db_utils.rls_transaction"),
patch(
"tasks.jobs.scan.initialize_prowler_provider"
) as mock_initialize_prowler_provider,
patch(
"tasks.jobs.scan.PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE"
) as mock_compliance_template,
patch(
"tasks.jobs.scan.generate_scan_compliance"
) as mock_generate_compliance,
patch("tasks.jobs.scan.create_objects_in_batches"),
patch("api.models.Finding.objects.filter") as mock_findings_filter,
):
tenant = tenants_fixture[0]
scan = scans_fixture[0]
provider = providers_fixture[0]
provider.provider = Provider.ProviderChoices.AWS
provider.save()
scan.provider = provider
scan.save()
tenant_id = str(tenant.id)
scan_id = str(scan.id)
mock_finding1 = MagicMock()
mock_finding1.check_id = "check1"
mock_finding1.status = "PASS"
mock_resource1 = MagicMock()
mock_resource1.region = "us-east-1"
mock_finding1.resources.all.return_value = [mock_resource1]
mock_finding2 = MagicMock()
mock_finding2.check_id = "check1"
mock_finding2.status = "FAIL"
mock_resource2 = MagicMock()
mock_resource2.region = "us-east-1"
mock_finding2.resources.all.return_value = [mock_resource2]
mock_findings_filter.return_value = [mock_finding1, mock_finding2]
mock_prowler_provider_instance = MagicMock()
mock_prowler_provider_instance.get_regions.return_value = ["us-east-1"]
mock_initialize_prowler_provider.return_value = (
mock_prowler_provider_instance
)
mock_compliance_template.__getitem__.return_value = {
"cis_1.4_aws": {
"framework": "CIS AWS Foundations Benchmark",
"version": "1.4.0",
"requirements": {
"1.1": {
"description": "Test requirement",
"checks_status": {
"pass": 0,
"fail": 0,
"manual": 0,
"total": 1,
},
"status": "PASS",
},
},
},
}
create_compliance_requirements(tenant_id, scan_id)
assert mock_generate_compliance.call_count == 1
def test_compliance_overview_aggregation_requirement_fail_priority(
self,
tenants_fixture,
scans_fixture,
providers_fixture,
):
with (
patch("api.db_utils.rls_transaction"),
patch(
"tasks.jobs.scan.initialize_prowler_provider"
) as mock_initialize_prowler_provider,
patch(
"tasks.jobs.scan.PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE"
) as mock_compliance_template,
patch(
"tasks.jobs.scan.generate_scan_compliance"
) as mock_generate_compliance,
patch("tasks.jobs.scan.create_objects_in_batches") as mock_create_objects,
patch("api.models.Finding.objects.filter") as mock_findings_filter,
):
tenant = tenants_fixture[0]
scan = scans_fixture[0]
providers_fixture[0]
mock_findings_filter.return_value = []
mock_prowler_provider = MagicMock()
mock_prowler_provider.get_regions.return_value = [
"us-east-1",
"us-west-2",
"eu-west-1",
]
mock_initialize_prowler_provider.return_value = mock_prowler_provider
mock_compliance_template.__getitem__.return_value = {
"test_compliance": {
"framework": "Test Framework",
"version": "1.0",
"requirements": {
"req_1": {
"description": "Test Requirement 1",
"checks": {"check_1": None},
"checks_status": {
"pass": 2,
"fail": 1,
"manual": 0,
"total": 3,
},
"status": "FAIL",
}
},
}
}
mock_generate_compliance.return_value = {
"test_compliance": {
"framework": "Test Framework",
"version": "1.0",
"requirements": {
"req_1": {
"description": "Test Requirement 1",
"checks": {
"check_1": {
"us-east-1": {"status": "PASS"},
"us-west-2": {"status": "FAIL"},
"eu-west-1": {"status": "PASS"},
}
},
"checks_status": {
"pass": 2,
"fail": 1,
"manual": 0,
"total": 3,
},
"status": "FAIL",
}
},
}
}
created_objects = []
mock_create_objects.side_effect = (
lambda tenant_id, model, objs, batch_size=500: created_objects.extend(
objs
)
)
create_compliance_requirements(str(tenant.id), str(scan.id))
assert len(created_objects) == 3
assert all(obj.requirement_status == "FAIL" for obj in created_objects)
def test_compliance_overview_aggregation_requirement_pass_all_regions(
self,
tenants_fixture,
scans_fixture,
providers_fixture,
):
with (
patch("api.db_utils.rls_transaction"),
patch(
"tasks.jobs.scan.initialize_prowler_provider"
) as mock_initialize_prowler_provider,
patch(
"tasks.jobs.scan.PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE"
) as mock_compliance_template,
patch(
"tasks.jobs.scan.generate_scan_compliance"
) as mock_generate_compliance,
patch("tasks.jobs.scan.create_objects_in_batches") as mock_create_objects,
patch("api.models.Finding.objects.filter") as mock_findings_filter,
):
tenant = tenants_fixture[0]
scan = scans_fixture[0]
providers_fixture[0]
mock_findings_filter.return_value = []
mock_prowler_provider = MagicMock()
mock_prowler_provider.get_regions.return_value = ["us-east-1", "us-west-2"]
mock_initialize_prowler_provider.return_value = mock_prowler_provider
mock_compliance_template.__getitem__.return_value = {
"test_compliance": {
"framework": "Test Framework",
"version": "1.0",
"requirements": {
"req_1": {
"description": "Test Requirement 1",
"checks": {"check_1": None},
"checks_status": {
"pass": 2,
"fail": 0,
"manual": 0,
"total": 2,
},
"status": "PASS",
}
},
}
}
mock_generate_compliance.return_value = {
"test_compliance": {
"framework": "Test Framework",
"version": "1.0",
"requirements": {
"req_1": {
"description": "Test Requirement 1",
"checks": {
"check_1": {
"us-east-1": {"status": "PASS"},
"us-west-2": {"status": "PASS"},
}
},
"checks_status": {
"pass": 2,
"fail": 0,
"manual": 0,
"total": 2,
},
"status": "PASS",
}
},
}
}
created_objects = []
mock_create_objects.side_effect = (
lambda tenant_id, model, objs, batch_size=500: created_objects.extend(
objs
)
)
create_compliance_requirements(str(tenant.id), str(scan.id))
assert len(created_objects) == 2
assert all(obj.requirement_status == "PASS" for obj in created_objects)
def test_compliance_overview_aggregation_multiple_requirements_mixed_status(
self,
tenants_fixture,
scans_fixture,
providers_fixture,
):
with (
patch("api.db_utils.rls_transaction"),
patch(
"tasks.jobs.scan.initialize_prowler_provider"
) as mock_initialize_prowler_provider,
patch(
"tasks.jobs.scan.PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE"
) as mock_compliance_template,
patch(
"tasks.jobs.scan.generate_scan_compliance"
) as mock_generate_compliance,
patch("tasks.jobs.scan.create_objects_in_batches") as mock_create_objects,
patch("api.models.Finding.objects.filter") as mock_findings_filter,
):
tenant = tenants_fixture[0]
scan = scans_fixture[0]
providers_fixture[0]
mock_findings_filter.return_value = []
mock_prowler_provider = MagicMock()
mock_prowler_provider.get_regions.return_value = ["us-east-1", "us-west-2"]
mock_initialize_prowler_provider.return_value = mock_prowler_provider
mock_compliance_template.__getitem__.return_value = {
"test_compliance": {
"framework": "Test Framework",
"version": "1.0",
"requirements": {
"req_1": {
"description": "Test Requirement 1",
"checks": {"check_1": None},
"checks_status": {
"pass": 2,
"fail": 0,
"manual": 0,
"total": 2,
},
"status": "PASS",
},
"req_2": {
"description": "Test Requirement 2",
"checks": {"check_2": None},
"checks_status": {
"pass": 1,
"fail": 1,
"manual": 0,
"total": 2,
},
"status": "FAIL",
},
},
}
}
mock_generate_compliance.return_value = {
"test_compliance": {
"framework": "Test Framework",
"version": "1.0",
"requirements": {
"req_1": {
"description": "Test Requirement 1",
"checks": {
"check_1": {
"us-east-1": {"status": "PASS"},
"us-west-2": {"status": "PASS"},
}
},
"checks_status": {
"pass": 2,
"fail": 0,
"manual": 0,
"total": 2,
},
"status": "PASS",
},
"req_2": {
"description": "Test Requirement 2",
"checks": {
"check_2": {
"us-east-1": {"status": "PASS"},
"us-west-2": {"status": "FAIL"},
}
},
"checks_status": {
"pass": 1,
"fail": 1,
"manual": 0,
"total": 2,
},
"status": "FAIL",
},
},
}
}
created_objects = []
mock_create_objects.side_effect = (
lambda tenant_id, model, objs, batch_size=500: created_objects.extend(
objs
)
)
create_compliance_requirements(str(tenant.id), str(scan.id))
assert len(created_objects) == 4
req_1_objects = [
obj for obj in created_objects if obj.requirement_id == "req_1"
]
req_2_objects = [
obj for obj in created_objects if obj.requirement_id == "req_2"
]
assert len(req_1_objects) == 2
assert len(req_2_objects) == 2
assert all(obj.requirement_status == "PASS" for obj in req_1_objects)
assert all(obj.requirement_status == "FAIL" for obj in req_2_objects)
+21 -2
View File
@@ -26,6 +26,25 @@ services:
ports:
- ${UI_PORT:-3000}:${UI_PORT:-3000}
postgres-proxy:
image: edoburu/pgbouncer:latest
hostname: "postgres-db-proxy"
environment:
- DB_HOST=postgres-db
- DB_PORT=5432
- DB_USER=${POSTGRES_ADMIN_USER}
- DB_PASSWORD=${POSTGRES_ADMIN_PASSWORD}
- ADMIN_USERS=prowler_admin
- AUTH_TYPE=scram-sha-256
env_file:
- path: ./.env
required: false
ports:
- "5432:5432"
depends_on:
postgres:
condition: service_healthy
postgres:
image: postgres:16.3-alpine3.20
hostname: "postgres-db"
@@ -38,8 +57,8 @@ services:
env_file:
- path: .env
required: false
ports:
- "${POSTGRES_PORT:-5432}:${POSTGRES_PORT:-5432}"
# ports:
# - "${POSTGRES_PORT:-5432}:${POSTGRES_PORT:-5432}"
healthcheck:
test: ["CMD-SHELL", "sh -c 'pg_isready -U ${POSTGRES_ADMIN_USER} -d ${POSTGRES_DB}'"]
interval: 5s
+3
View File
@@ -170,6 +170,9 @@ These two new environment variables are **required** to execute the PowerShell m
- `M365_USER` should be your Microsoft account email using the **assigned domain in the tenant**. This means it must look like `example@YourCompany.onmicrosoft.com` or `example@YourCompany.com`, but it must be the exact domain assigned to that user in the tenant.
???+ warning
The user must not be MFA capable. Microsoft does not allow MFA capable users to authenticate programmatically. See [Microsoft documentation](https://learn.microsoft.com/en-us/entra/identity-platform/scenario-desktop-acquire-token-username-password?tabs=dotnet) for more information.
???+ warning
Using a tenant domain other than the one assigned — even if it belongs to the same tenant — will cause Prowler to fail, as Microsoft authentication will not succeed.
@@ -138,6 +138,10 @@ Follow these steps to assign the permissions:
![Grant Admin Consent](./img/grant-admin-consent-delegated.png)
The final result of permission assignment should be this:
![Final Permission Assignment](./img/final-permissions-m365.png)
---
### Assign required roles to your user
Binary file not shown.

After

Width:  |  Height:  |  Size: 81 KiB

+6
View File
@@ -4,6 +4,9 @@ The **Prowler App** supports multiple users within a single tenant, enabling sea
[Roles](#roles) help you control user permissions, determining what actions each user can perform and the data they can access within Prowler. By default, each account includes an immutable **admin** role, ensuring that your account always retains administrative access.
???+ note
If the account is created without an invitation, a new tenant will be provisioned for it. However, if the account is created through an invitation, the user will join the inviters tenant.
## Membership
To get to User-Invitation Management we will focus on the Membership section.
@@ -156,6 +159,9 @@ Follow these steps to create a role for your account:
<img src="../img/rbac/role_create_1.png" alt="Role parameters" width="700"/>
???+ note
To assign read-only access, select only the `Unlimited Visibility` permission when creating the role. Then, go to the Users page and assign this role to the appropriate user.
#### Editing a Role
Follow these steps to edit a role on your account:
+14 -5
View File
@@ -2,7 +2,7 @@
All notable changes to the **Prowler SDK** are documented in this file.
## [5.8.0] (Prowler v5.8.0)
## [v5.8.0] (Prowler UNRELEASED)
### Added
- Add CIS 1.11 compliance framework for Kubernetes. [(#7790)](https://github.com/prowler-cloud/prowler/pull/7790)
@@ -23,12 +23,25 @@ All notable changes to the **Prowler SDK** are documented in this file.
- Add NIS 2 compliance framework for Azure. [(7857)](https://github.com/prowler-cloud/prowler/pull/7857)
- Add search bar in Dashboard Overview page. [(#7804)](https://github.com/prowler-cloud/prowler/pull/7804)
---
## [v5.7.3] (Prowler v5.7.3)
### Fixed
- Automatically encrypt password in Microsoft365 provider. [(#7784)](https://github.com/prowler-cloud/prowler/pull/7784)
- Remove last encrypted password appearances. [(#7825)](https://github.com/prowler-cloud/prowler/pull/7825)
---
## [v5.7.2] (Prowler v5.7.2)
### Fixed
- Fix `m365_powershell test_credentials` to use sanitized credentials. [(#7761)](https://github.com/prowler-cloud/prowler/pull/7761)
- Fix `admincenter_users_admins_reduced_license_footprint` check logic to pass when admin user has no license. [(#7779)](https://github.com/prowler-cloud/prowler/pull/7779)
- Fix `m365_powershell` to close the PowerShell sessions in msgraph services. [(#7816)](https://github.com/prowler-cloud/prowler/pull/7816)
- Fix `defender_ensure_notify_alerts_severity_is_high`check to accept high or lower severity. [(#7862)](https://github.com/prowler-cloud/prowler/pull/7862)
- Replace `Directory.Read.All` permission with `Domain.Read.All` which is more restrictive. [(#7888)](https://github.com/prowler-cloud/prowler/pull/7888)
- Split calls to list Azure Functions attributes. [(#7778)](https://github.com/prowler-cloud/prowler/pull/7778)
---
@@ -57,14 +70,12 @@ All notable changes to the **Prowler SDK** are documented in this file.
- Update and upgrade CIS for all the providers [(#7738)](https://github.com/prowler-cloud/prowler/pull/7738)
- Cover policies with conditions with SNS endpoint in `sns_topics_not_publicly_accessible`. [(#7750)](https://github.com/prowler-cloud/prowler/pull/7750)
- Change severity logic for `ec2_securitygroup_allow_ingress_from_internet_to_all_ports` check. [(#7764)](https://github.com/prowler-cloud/prowler/pull/7764)
- Automatically encrypt password in Microsoft365 provider. [(#7784)](https://github.com/prowler-cloud/prowler/pull/7784)
---
## [v5.6.0] (Prowler v5.6.0)
### Added
- Add SOC2 compliance framework to Azure. [(#7489)](https://github.com/prowler-cloud/prowler/pull/7489)
- Add check for unused Service Accounts in GCP. [(#7419)](https://github.com/prowler-cloud/prowler/pull/7419)
- Add Powershell to Microsoft365. [(#7331)](https://github.com/prowler-cloud/prowler/pull/7331)
@@ -114,7 +125,6 @@ All notable changes to the **Prowler SDK** are documented in this file.
- Add Microsoft User and User Credential auth to reports [(#7681)](https://github.com/prowler-cloud/prowler/pull/7681)
### Fixed
- Fix package name location in pyproject.toml while replicating for prowler-cloud. [(#7531)](https://github.com/prowler-cloud/prowler/pull/7531)
- Remove cache in PyPI release action. [(#7532)](https://github.com/prowler-cloud/prowler/pull/7532)
- Add the correct values for logger.info inside iam service. [(#7526)](https://github.com/prowler-cloud/prowler/pull/7526)
@@ -135,7 +145,6 @@ All notable changes to the **Prowler SDK** are documented in this file.
## [v5.5.1] (Prowler v5.5.1)
### Fixed
- Add default name to contacts in Azure Defender. [(#7483)](https://github.com/prowler-cloud/prowler/pull/7483)
- Handle projects without ID in GCP. [(#7496)](https://github.com/prowler-cloud/prowler/pull/7496)
- Restore packages location in PyProject. [(#7510)](https://github.com/prowler-cloud/prowler/pull/7510)
+2 -1
View File
@@ -22,7 +22,8 @@ module.exports = {
},
},
rules: {
"no-console": 1,
// console.error are allowed but no console.log
"no-console": ["error", { allow: ["error"] }],
eqeqeq: 2,
quotes: ["error", "double", "avoid-escape"],
"@typescript-eslint/no-explicit-any": "off",
+19
View File
@@ -4,16 +4,35 @@ All notable changes to the **Prowler UI** are documented in this file.
## [v1.8.0] (Prowler v5.8.0) Not released
### 🐞 Fixes
- Fix sync between filter buttons and URL when filters change. [(#7928)](https://github.com/prowler-cloud/prowler/pull/7928)
- Improve heatmap perfomance. [(#7934)](https://github.com/prowler-cloud/prowler/pull/7934)
### 🚀 Added
- New profile page with details about the user and their roles. [(#7780)](https://github.com/prowler-cloud/prowler/pull/7780)
- Improved `SnippetChip` component and show resource name in new findings table. [(#7813)](https://github.com/prowler-cloud/prowler/pull/7813)
- Possibility to edit the organization name. [(#7829)](https://github.com/prowler-cloud/prowler/pull/7829)
- Add GCP credential method (Account Service Key). [(#7872)](https://github.com/prowler-cloud/prowler/pull/7872)
- Add compliance detail view: ENS [(#7853)](https://github.com/prowler-cloud/prowler/pull/7853)
- Add compliance detail view: ISO [(#7897)](https://github.com/prowler-cloud/prowler/pull/7897)
- Add compliance detail view: CIS [(#7913)](https://github.com/prowler-cloud/prowler/pull/7913)
- Add compliance detail view: AWS Well-Architected Framework [(#7925)](https://github.com/prowler-cloud/prowler/pull/7925)
- Improve `Scan ID` filter by adding more context and enhancing the UI/UX. [(#7949)](https://github.com/prowler-cloud/prowler/pull/7949)
### 🔄 Changed
- Add `Provider UID` filter to scans page. [(#7820)](https://github.com/prowler-cloud/prowler/pull/7820)
- Aligned Next.js version to `v14.2.29` across Prowler and Cloud environments for consistency and improved maintainability. [(#7962)](https://github.com/prowler-cloud/prowler/pull/7962)
---
## [v1.7.3] (Prowler v5.7.3)
### 🐞 Fixes
- Fix encrypted password typo in `formSchemas`. [(#7828)](https://github.com/prowler-cloud/prowler/pull/7828)
---
+74 -1
View File
@@ -30,7 +30,6 @@ export const getCompliancesOverview = async ({
});
const data = await compliances.json();
const parsedData = parseStringify(data);
revalidatePath("/compliance");
return parsedData;
} catch (error) {
@@ -79,3 +78,77 @@ export const getComplianceOverviewMetadataInfo = async ({
return undefined;
}
};
export const getComplianceAttributes = async (complianceId: string) => {
const headers = await getAuthHeaders({ contentType: false });
try {
const url = new URL(`${apiBaseUrl}/compliance-overviews/attributes`);
url.searchParams.append("filter[compliance_id]", complianceId);
const response = await fetch(url.toString(), {
headers,
});
if (!response.ok) {
throw new Error(
`Failed to fetch compliance attributes: ${response.statusText}`,
);
}
const data = await response.json();
const parsedData = parseStringify(data);
return parsedData;
} catch (error) {
// eslint-disable-next-line no-console
console.error("Error fetching compliance attributes:", error);
return undefined;
}
// */
};
export const getComplianceRequirements = async ({
complianceId,
scanId,
region,
}: {
complianceId: string;
scanId: string;
region?: string | string[];
}) => {
const headers = await getAuthHeaders({ contentType: false });
try {
const url = new URL(`${apiBaseUrl}/compliance-overviews/requirements`);
url.searchParams.append("filter[compliance_id]", complianceId);
url.searchParams.append("filter[scan_id]", scanId);
if (region) {
const regionValue = Array.isArray(region) ? region.join(",") : region;
url.searchParams.append("filter[region__in]", regionValue);
//remove page param
}
url.searchParams.delete("page");
const response = await fetch(url.toString(), {
headers,
});
if (!response.ok) {
throw new Error(
`Failed to fetch compliance requirements: ${response.statusText}`,
);
}
const data = await response.json();
const parsedData = parseStringify(data);
return parsedData;
} catch (error) {
// eslint-disable-next-line no-console
console.error("Error fetching compliance requirements:", error);
return undefined;
}
// */
};
@@ -0,0 +1,269 @@
import { Spacer } from "@nextui-org/react";
import Image from "next/image";
import { Suspense } from "react";
import {
getComplianceAttributes,
getComplianceOverviewMetadataInfo,
getComplianceRequirements,
} from "@/actions/compliances";
import { getProvider } from "@/actions/providers";
import { getScans } from "@/actions/scans";
import {
BarChart,
BarChartSkeleton,
ClientAccordionWrapper,
ComplianceHeader,
HeatmapChart,
HeatmapChartSkeleton,
PieChart,
PieChartSkeleton,
SkeletonAccordion,
} from "@/components/compliance";
import { getComplianceIcon } from "@/components/icons/compliance/IconCompliance";
import { ContentLayout } from "@/components/ui";
import {
calculateCategoryHeatmapData,
getComplianceMapper,
} from "@/lib/compliance/commons";
import { ScanProps } from "@/types";
import { Framework, RequirementsTotals } from "@/types/compliance";
interface ComplianceDetailSearchParams {
complianceId: string;
version?: string;
scanId?: string;
"filter[region__in]"?: string;
"filter[cis_profile_level]"?: string;
}
const ComplianceIconSmall = ({
logoPath,
title,
}: {
logoPath: string;
title: string;
}) => {
return (
<div className="relative h-6 w-6 flex-shrink-0">
<Image
src={logoPath}
alt={`${title} logo`}
fill
className="h-10 w-10 min-w-10 rounded-md border-1 border-gray-300 bg-white object-contain p-[2px]"
/>
</div>
);
};
const ChartsWrapper = ({
children,
}: {
children: React.ReactNode;
logoPath?: string;
}) => {
return (
<div className="mb-8 flex w-full flex-col items-center justify-between lg:flex-row">
{children}
</div>
);
};
export default async function ComplianceDetail({
params,
searchParams,
}: {
params: { compliancetitle: string };
searchParams: ComplianceDetailSearchParams;
}) {
const { compliancetitle } = params;
const { complianceId, version, scanId } = searchParams;
const regionFilter = searchParams["filter[region__in]"];
const cisProfileFilter = searchParams["filter[cis_profile_level]"];
const logoPath = getComplianceIcon(compliancetitle);
// Create a key that includes region filter for Suspense
const searchParamsKey = JSON.stringify(searchParams || {});
const formattedTitle = compliancetitle.split("-").join(" ");
const pageTitle = version
? `Compliance Details: ${formattedTitle} - ${version}`
: `Compliance Details: ${formattedTitle}`;
// Fetch scans data
const scansData = await getScans({
filters: {
"filter[state]": "completed",
},
});
// Expand scans with provider information
const expandedScansData = scansData?.data?.length
? await Promise.all(
scansData.data.map(async (scan: ScanProps) => {
const providerId = scan.relationships?.provider?.data?.id;
if (!providerId) {
return { ...scan, providerInfo: null };
}
const formData = new FormData();
formData.append("id", providerId);
const providerData = await getProvider(formData);
return {
...scan,
providerInfo: providerData?.data
? {
provider: providerData.data.attributes.provider,
uid: providerData.data.attributes.uid,
alias: providerData.data.attributes.alias,
}
: null,
};
}),
)
: [];
const selectedScanId = scanId || expandedScansData[0]?.id || null;
// Fetch metadata info for regions
const metadataInfoData = await getComplianceOverviewMetadataInfo({
filters: {
"filter[scan_id]": selectedScanId,
},
});
const uniqueRegions = metadataInfoData?.data?.attributes?.regions || [];
return (
<ContentLayout
title={pageTitle}
icon={
logoPath ? (
<ComplianceIconSmall logoPath={logoPath} title={compliancetitle} />
) : (
"fluent-mdl2:compliance-audit"
)
}
>
<ComplianceHeader
scans={expandedScansData}
uniqueRegions={uniqueRegions}
showSearch={false}
framework={compliancetitle}
showProviders={false}
/>
<Suspense
key={searchParamsKey}
fallback={
<div className="space-y-8">
<ChartsWrapper logoPath={logoPath}>
<PieChartSkeleton />
<BarChartSkeleton />
<HeatmapChartSkeleton />
</ChartsWrapper>
<SkeletonAccordion />
</div>
}
>
<SSRComplianceContent
complianceId={complianceId}
scanId={selectedScanId}
region={regionFilter}
filter={cisProfileFilter}
logoPath={logoPath}
/>
</Suspense>
</ContentLayout>
);
}
const SSRComplianceContent = async ({
complianceId,
scanId,
region,
filter,
logoPath,
}: {
complianceId: string;
scanId: string;
region?: string;
filter?: string;
logoPath?: string;
}) => {
if (!scanId) {
return (
<div className="space-y-8">
<ChartsWrapper logoPath={logoPath}>
<PieChart pass={0} fail={0} manual={0} />
<BarChart sections={[]} />
<HeatmapChart categories={[]} />
</ChartsWrapper>
<ClientAccordionWrapper items={[]} defaultExpandedKeys={[]} />
</div>
);
}
// Get compliance data and attributes once
const [attributesData, requirementsData] = await Promise.all([
getComplianceAttributes(complianceId),
getComplianceRequirements({
complianceId,
scanId,
region,
}),
]);
// Determine framework from the first attribute item
const framework = attributesData?.data?.[0]?.attributes?.framework;
const mapper = getComplianceMapper(framework);
// Use the same data for both compliance view and heatmap
const data = mapper.mapComplianceData(
attributesData,
requirementsData,
filter,
);
// Calculate category heatmap data
const categoryHeatmapData = calculateCategoryHeatmapData(data);
const totalRequirements: RequirementsTotals = data.reduce(
(acc: RequirementsTotals, framework: Framework) => ({
pass: acc.pass + framework.pass,
fail: acc.fail + framework.fail,
manual: acc.manual + framework.manual,
}),
{ pass: 0, fail: 0, manual: 0 },
);
const accordionItems = mapper.toAccordionItems(data, scanId);
const topFailedSections = mapper.getTopFailedSections(data);
// Todo: rethink as every compliance has a different number of items
// const defaultKeys = accordionItems.slice(0, 2).map((item) => item.key);
const defaultKeys = [""];
return (
<div className="space-y-8">
<ChartsWrapper logoPath={logoPath}>
<PieChart
pass={totalRequirements.pass}
fail={totalRequirements.fail}
manual={totalRequirements.manual}
/>
<BarChart sections={topFailedSections} />
<HeatmapChart categories={categoryHeatmapData} />
</ChartsWrapper>
<Spacer className="h-1 w-full rounded-full bg-gray-200 dark:bg-gray-800" />
<ClientAccordionWrapper
items={accordionItems}
defaultExpandedKeys={defaultKeys}
/>
</div>
);
};
+21 -34
View File
@@ -1,6 +1,4 @@
export const dynamic = "force-dynamic";
import { Spacer } from "@nextui-org/react";
import { Suspense } from "react";
import { getCompliancesOverview } from "@/actions/compliances";
@@ -12,11 +10,10 @@ import {
ComplianceSkeletonGrid,
NoScansAvailable,
} from "@/components/compliance";
import { DataCompliance } from "@/components/compliance/data-compliance";
import { FilterControls } from "@/components/filters";
import { ComplianceHeader } from "@/components/compliance/compliance-header/compliance-header";
import { ContentLayout } from "@/components/ui";
import { DataTableFilterCustom } from "@/components/ui/table/data-table-filter-custom";
import { ComplianceOverviewData, ScanProps, SearchParamsProps } from "@/types";
import { ScanProps, SearchParamsProps } from "@/types";
import { ComplianceOverviewData } from "@/types/compliance";
export default async function Compliance({
searchParams,
@@ -84,21 +81,10 @@ export default async function Compliance({
<ContentLayout title="Compliance" icon="fluent-mdl2:compliance-audit">
{selectedScanId ? (
<>
<FilterControls search />
<Spacer y={8} />
<DataCompliance scans={expandedScansData} />
<Spacer y={8} />
<DataTableFilterCustom
filters={[
{
key: "region__in",
labelCheckboxGroup: "Regions",
values: uniqueRegions,
},
]}
defaultOpen={true}
<ComplianceHeader
scans={expandedScansData}
uniqueRegions={uniqueRegions}
/>
<Spacer y={12} />
<Suspense key={searchParamsKey} fallback={<ComplianceSkeletonGrid />}>
<SSRComplianceGrid searchParams={searchParams} />
</Suspense>
@@ -133,7 +119,11 @@ const SSRComplianceGrid = async ({
});
// Check if the response contains no data
if (!compliancesData || compliancesData?.data?.length === 0) {
if (
!compliancesData ||
!compliancesData.data ||
compliancesData.data.length === 0
) {
return (
<div className="flex h-full items-center">
<div className="text-sm text-default-500">
@@ -155,25 +145,22 @@ const SSRComplianceGrid = async ({
return (
<div className="grid grid-cols-1 gap-4 sm:grid-cols-2 lg:grid-cols-3 2xl:grid-cols-4">
{compliancesData.data.map((compliance: ComplianceOverviewData) => {
const { attributes } = compliance;
const {
framework,
version,
requirements_status: { passed, total },
compliance_id,
} = attributes;
const { attributes, id } = compliance;
const { framework, version, requirements_passed, total_requirements } =
attributes;
return (
<ComplianceCard
key={compliance.id}
key={id}
title={framework}
version={version}
passingRequirements={passed}
totalRequirements={total}
prevPassingRequirements={passed}
prevTotalRequirements={total}
passingRequirements={requirements_passed}
totalRequirements={total_requirements}
prevPassingRequirements={requirements_passed}
prevTotalRequirements={total_requirements}
scanId={scanId}
complianceId={compliance_id}
complianceId={id}
id={id}
/>
);
})}
+13 -12
View File
@@ -19,6 +19,7 @@ import { ContentLayout } from "@/components/ui";
import { DataTable, DataTableFilterCustom } from "@/components/ui/table";
import {
createDict,
createScanDetailsMapping,
extractFiltersAndQuery,
extractSortAndKey,
hasDateOrScanFilter,
@@ -27,7 +28,8 @@ import {
createProviderDetailsMapping,
extractProviderUIDs,
} from "@/lib/provider-helpers";
import { FindingProps, ScanProps, SearchParamsProps } from "@/types/components";
import { ScanProps } from "@/types";
import { FindingProps, SearchParamsProps } from "@/types/components";
export default async function Findings({
searchParams,
@@ -47,7 +49,7 @@ export default async function Findings({
filters,
}),
getProviders({ pageSize: 50 }),
getScans({}),
getScans({ pageSize: 50 }),
]);
// Extract unique regions and services from the new endpoint
@@ -75,20 +77,17 @@ export default async function Findings({
});
// Extract scan UUIDs with "completed" state and more than one resource
const completedScans = scansData?.data
?.filter(
(scan: ScanProps) =>
scan.attributes.state === "completed" &&
scan.attributes.unique_resource_count > 1,
)
.map((scan: ScanProps) => ({
id: scan.id,
name: scan.attributes.name,
}));
const completedScans = scansData?.data?.filter(
(scan: ScanProps) =>
scan.attributes.state === "completed" &&
scan.attributes.unique_resource_count > 1,
);
const completedScanIds =
completedScans?.map((scan: ScanProps) => scan.id) || [];
const scanDetails = createScanDetailsMapping(completedScans, providersData);
return (
<ContentLayout title="Findings" icon="carbon:data-view-alt">
<FilterControls search date />
@@ -118,12 +117,14 @@ export default async function Findings({
key: "scan__in",
labelCheckboxGroup: "Scan ID",
values: completedScanIds,
valueLabelMapping: scanDetails,
index: 9,
},
]}
defaultOpen={true}
/>
<Spacer y={8} />
<Suspense key={searchParamsKey} fallback={<SkeletonTableFindings />}>
<SSRDataTable searchParams={searchParams} />
</Suspense>
+1 -1
View File
@@ -9,7 +9,7 @@ import { MembershipsCard } from "@/components/users/profile/memberships-card";
import { RolesCard } from "@/components/users/profile/roles-card";
import { SkeletonUserInfo } from "@/components/users/profile/skeleton-user-info";
import { isUserOwnerAndHasManageAccount } from "@/lib/permissions";
import { RoleDetail, TenantDetailData } from "@/types/users/users";
import { RoleDetail, TenantDetailData } from "@/types/users";
export default async function Profile() {
return (
@@ -0,0 +1,192 @@
"use client";
import { useSearchParams } from "next/navigation";
import { useEffect, useRef, useState } from "react";
import { getFindings } from "@/actions/findings/findings";
import {
ColumnFindings,
SkeletonTableFindings,
} from "@/components/findings/table";
import { Accordion } from "@/components/ui/accordion/Accordion";
import { DataTable } from "@/components/ui/table";
import { createDict } from "@/lib";
import { getComplianceMapper } from "@/lib/compliance/commons";
import { Requirement } from "@/types/compliance";
import { FindingProps, FindingsResponse } from "@/types/components";
interface ClientAccordionContentProps {
requirement: Requirement;
scanId: string;
framework: string;
disableFindings?: boolean;
}
export const ClientAccordionContent = ({
requirement,
framework,
scanId,
disableFindings = false,
}: ClientAccordionContentProps) => {
const [findings, setFindings] = useState<FindingsResponse | null>(null);
const [expandedFindings, setExpandedFindings] = useState<FindingProps[]>([]);
const searchParams = useSearchParams();
const pageNumber = searchParams.get("page") || "1";
const complianceId = searchParams.get("complianceId");
const defaultSort = "severity,status,-inserted_at";
const sort = searchParams.get("sort") || defaultSort;
const loadedPageRef = useRef<string | null>(null);
const loadedSortRef = useRef<string | null>(null);
const isExpandedRef = useRef(false);
const region = searchParams.get("filter[region__in]") || "";
useEffect(() => {
async function loadFindings() {
if (
!disableFindings &&
requirement.check_ids?.length > 0 &&
requirement.status !== "No findings" &&
(loadedPageRef.current !== pageNumber ||
loadedSortRef.current !== sort ||
!isExpandedRef.current)
) {
loadedPageRef.current = pageNumber;
loadedSortRef.current = sort;
isExpandedRef.current = true;
try {
const checkIds = requirement.check_ids;
const encodedSort = sort.replace(/^\+/, "");
const findingsData = await getFindings({
filters: {
"filter[check_id__in]": checkIds.join(","),
"filter[scan]": scanId,
...(region && { "filter[region__in]": region }),
},
page: parseInt(pageNumber, 10),
sort: encodedSort,
});
setFindings(findingsData);
if (findingsData?.data) {
// Create dictionaries for resources, scans, and providers
const resourceDict = createDict("resources", findingsData);
const scanDict = createDict("scans", findingsData);
const providerDict = createDict("providers", findingsData);
// Expand each finding with its corresponding resource, scan, and provider
const expandedData = findingsData.data.map(
(finding: FindingProps) => {
const scan = scanDict[finding.relationships?.scan?.data?.id];
const resource =
resourceDict[finding.relationships?.resources?.data?.[0]?.id];
const provider =
providerDict[scan?.relationships?.provider?.data?.id];
return {
...finding,
relationships: { scan, resource, provider },
};
},
);
setExpandedFindings(expandedData);
}
} catch (error) {
console.error("Error loading findings:", error);
}
}
}
loadFindings();
}, [requirement, scanId, pageNumber, sort, region, disableFindings]);
const renderDetails = () => {
if (!complianceId) {
return null;
}
const mapper = getComplianceMapper(framework);
const detailsComponent = mapper.getDetailsComponent(requirement);
return <div className="w-full">{detailsComponent}</div>;
};
if (disableFindings) {
return (
<div className="w-full">
{renderDetails()}
<p className="mt-2 text-sm font-medium text-gray-800">
This requirement has no checks; therefore, there are no findings.
</p>
</div>
);
}
const checks = requirement.check_ids || [];
const checksList = (
<div className="mb-2 flex items-center">
<span>{checks.join(", ")}</span>
</div>
);
const accordionChecksItems = [
{
key: "checks",
title: (
<div className="flex items-center gap-2">
<span className="text-primary">{checks.length}</span>
{checks.length > 1 ? <span>Checks</span> : <span>Check</span>}
</div>
),
content: checksList,
},
];
const renderFindingsTable = () => {
if (findings === null && requirement.status !== "MANUAL") {
return <SkeletonTableFindings />;
}
if (findings?.data?.length && findings.data.length > 0) {
return (
<div className="p-1">
<DataTable
// Remove the updated_at column as compliance is for the last scan
columns={ColumnFindings.filter(
(_, index) => index !== 4 && index !== 7,
)}
data={expandedFindings || []}
metadata={findings?.meta}
disableScroll={true}
/>
</div>
);
}
return (
<div className="text-sm font-medium text-gray-800">
There are no findings for this regions
</div>
);
};
return (
<div className="w-full">
{renderDetails()}
{checks.length > 0 && (
<div className="mb-2 mt-2">
<Accordion
items={accordionChecksItems}
variant="light"
defaultExpandedKeys={[""]}
className="rounded-lg bg-white dark:bg-prowler-blue-400"
/>
</div>
)}
{renderFindingsTable()}
</div>
);
};
@@ -0,0 +1,79 @@
"use client";
import { useState } from "react";
import { Accordion, AccordionItemProps } from "@/components/ui";
import { CustomButton } from "@/components/ui/custom";
export const ClientAccordionWrapper = ({
items,
defaultExpandedKeys,
}: {
items: AccordionItemProps[];
defaultExpandedKeys: string[];
}) => {
const [selectedKeys, setSelectedKeys] =
useState<string[]>(defaultExpandedKeys);
const [isExpanded, setIsExpanded] = useState(false);
// Function to get all keys except the last level (requirements)
const getAllKeysExceptLastLevel = (items: AccordionItemProps[]): string[] => {
const keys: string[] = [];
const traverse = (items: AccordionItemProps[], level: number = 0) => {
items.forEach((item) => {
// Add current item key if it's not the last level
if (item.items && item.items.length > 0) {
keys.push(item.key);
// Check if the children have their own children (not the last level)
const hasGrandChildren = item.items.some(
(child) => child.items && child.items.length > 0,
);
if (hasGrandChildren) {
traverse(item.items, level + 1);
}
}
});
};
traverse(items);
return keys;
};
const handleToggleExpand = () => {
if (isExpanded) {
setSelectedKeys(defaultExpandedKeys);
} else {
const allKeys = getAllKeysExceptLastLevel(items);
setSelectedKeys(allKeys);
}
setIsExpanded(!isExpanded);
};
const handleSelectionChange = (keys: string[]) => {
setSelectedKeys(keys);
};
return (
<div className="space-y-4">
<div className="flex justify-end">
<CustomButton
variant="flat"
size="sm"
onPress={handleToggleExpand}
ariaLabel={isExpanded ? "Collapse all" : "Expand all"}
>
{isExpanded ? "Collapse all" : "Expand all"}
</CustomButton>
</div>
<Accordion
items={items}
variant="light"
selectionMode="multiple"
defaultExpandedKeys={defaultExpandedKeys}
selectedKeys={selectedKeys}
onSelectionChange={handleSelectionChange}
/>
</div>
);
};
@@ -0,0 +1,21 @@
import { FindingStatus, StatusFindingBadge } from "@/components/ui/table";
interface ComplianceAccordionRequirementTitleProps {
type: string;
name: string;
status: FindingStatus;
}
export const ComplianceAccordionRequirementTitle = ({
name,
status,
}: ComplianceAccordionRequirementTitleProps) => {
return (
<div className="flex w-full items-center justify-between gap-2">
<div className="flex w-5/6 items-center gap-1">
<span>{name}</span>
</div>
<StatusFindingBadge status={status} />
</div>
);
};
@@ -0,0 +1,137 @@
import { Tooltip } from "@nextui-org/react";
interface ComplianceAccordionTitleProps {
label: string;
pass: number;
fail: number;
manual?: number;
isParentLevel?: boolean;
}
export const ComplianceAccordionTitle = ({
label,
pass,
fail,
manual = 0,
isParentLevel = false,
}: ComplianceAccordionTitleProps) => {
const total = pass + fail + manual;
const passPercentage = (pass / total) * 100;
const failPercentage = (fail / total) * 100;
const manualPercentage = (manual / total) * 100;
return (
<div className="flex flex-col items-start justify-between gap-1 md:flex-row md:items-center md:gap-2">
<div className="overflow-hidden md:min-w-0 md:flex-1">
<span
className="block max-w-[600px] overflow-hidden truncate text-ellipsis text-sm"
title={label}
>
{label.charAt(0).toUpperCase() + label.slice(1)}
</span>
</div>
<div className="mr-4 flex items-center gap-2">
<div className="hidden lg:block">
{total > 0 && isParentLevel && (
<span className="whitespace-nowrap text-xs font-medium text-gray-600">
Requirements:
</span>
)}
</div>
<div className="flex h-1.5 w-[200px] overflow-hidden rounded-full bg-gray-100 shadow-inner">
{total > 0 ? (
<div className="flex w-full">
{pass > 0 && (
<Tooltip
content={
<div className="px-1 py-0.5">
<div className="text-xs font-medium">Pass</div>
<div className="text-tiny text-default-400">
{pass} ({passPercentage.toFixed(1)}%)
</div>
</div>
}
size="sm"
placement="top"
delay={0}
closeDelay={0}
>
<div
className="h-full bg-[#3CEC6D] transition-all duration-200 hover:brightness-110"
style={{
width: `${passPercentage}%`,
marginRight: pass > 0 ? "2px" : "0",
}}
/>
</Tooltip>
)}
{fail > 0 && (
<Tooltip
content={
<div className="px-1 py-0.5">
<div className="text-xs font-medium">Fail</div>
<div className="text-tiny text-default-400">
{fail} ({failPercentage.toFixed(1)}%)
</div>
</div>
}
size="sm"
placement="top"
delay={0}
closeDelay={0}
>
<div
className="h-full bg-[#FB718F] transition-all duration-200 hover:brightness-110"
style={{
width: `${failPercentage}%`,
marginRight: manual > 0 ? "2px" : "0",
}}
/>
</Tooltip>
)}
{manual > 0 && (
<Tooltip
content={
<div className="px-1 py-0.5">
<div className="text-xs font-medium">Manual</div>
<div className="text-tiny text-default-400">
{manual} ({manualPercentage.toFixed(1)}%)
</div>
</div>
}
size="sm"
placement="top"
delay={0}
closeDelay={0}
>
<div
className="h-full bg-[#868994] transition-all duration-200 hover:brightness-110"
style={{ width: `${manualPercentage}%` }}
/>
</Tooltip>
)}
</div>
) : (
<div className="h-full w-full bg-gray-200" />
)}
</div>
<Tooltip
content={
<div className="px-1 py-0.5">
<div className="text-xs font-medium">Total requirements</div>
<div className="text-tiny text-default-400">{total}</div>
</div>
}
size="sm"
placement="top"
>
<div className="min-w-[32px] text-center text-xs font-medium text-default-600">
{total > 0 ? total : "—"}
</div>
</Tooltip>
</div>
</div>
);
};
+33 -2
View File
@@ -2,7 +2,7 @@
import { Card, CardBody, Progress } from "@nextui-org/react";
import Image from "next/image";
import { useSearchParams } from "next/navigation";
import { useRouter, useSearchParams } from "next/navigation";
import React, { useState } from "react";
import { DownloadIconButton, toast } from "@/components/ui";
@@ -19,6 +19,7 @@ interface ComplianceCardProps {
prevTotalRequirements: number;
scanId: string;
complianceId: string;
id: string;
}
export const ComplianceCard: React.FC<ComplianceCardProps> = ({
@@ -28,8 +29,10 @@ export const ComplianceCard: React.FC<ComplianceCardProps> = ({
totalRequirements,
scanId,
complianceId,
id,
}) => {
const searchParams = useSearchParams();
const router = useRouter();
const hasRegionFilter = searchParams.has("filter[region__in]");
const [isDownloading, setIsDownloading] = useState<boolean>(false);
@@ -68,6 +71,28 @@ export const ComplianceCard: React.FC<ComplianceCardProps> = ({
return "success";
};
const isPressable =
id.includes("ens") ||
id.includes("iso") ||
id.includes("cis_") ||
id.includes("pillar");
const navigateToDetail = () => {
// We will unlock this while developing the rest of complainces.
if (!isPressable) {
return;
}
const formattedTitleForUrl = encodeURIComponent(title);
const path = `/compliance/${formattedTitleForUrl}`;
const params = new URLSearchParams();
params.set("complianceId", id);
params.set("version", version);
params.set("scanId", scanId);
router.push(`${path}?${params.toString()}`);
};
const handleDownload = async () => {
setIsDownloading(true);
try {
@@ -78,7 +103,13 @@ export const ComplianceCard: React.FC<ComplianceCardProps> = ({
};
return (
<Card fullWidth isHoverable shadow="sm">
<Card
fullWidth
isHoverable
shadow="sm"
isPressable={isPressable}
onPress={navigateToDetail}
>
<CardBody className="flex flex-row items-center justify-between space-x-4 dark:bg-prowler-blue-800">
<div className="flex w-full items-center space-x-4">
<Image
@@ -0,0 +1,193 @@
"use client";
import { useTheme } from "next-themes";
import {
Bar,
BarChart as RechartsBarChart,
Legend,
ResponsiveContainer,
Tooltip,
XAxis,
YAxis,
} from "recharts";
import { translateType } from "@/lib/compliance/ens";
import { FailedSection } from "@/types/compliance";
interface FailedSectionsListProps {
sections: FailedSection[];
}
const title = (
<h3 className="whitespace-nowrap text-xs font-semibold uppercase tracking-wide">
Failed Sections (Top 5)
</h3>
);
export const BarChart = ({ sections }: FailedSectionsListProps) => {
const { theme } = useTheme();
const getTypeColor = (type: string) => {
switch (type.toLowerCase()) {
case "requisito":
return "#ff5356";
case "recomendacion":
return "#FDC53A"; // Increased contrast from #FDDD8A
case "refuerzo":
return "#7FB5FF"; // Increased contrast from #B5D7FF
default:
return "#ff5356";
}
};
const chartData = [...sections]
.sort((a, b) => b.total - a.total)
.slice(0, 5)
.map((section) => ({
name: section.name.charAt(0).toUpperCase() + section.name.slice(1),
...section.types,
}));
const allTypes = Array.from(
new Set(sections.flatMap((section) => Object.keys(section.types || {}))),
);
// Add empty bars to complete up to 5 bars for better distribution
while (chartData.length < 5) {
const emptyBar: any = { name: "" };
allTypes.forEach((type) => {
emptyBar[type] = 0;
});
chartData.push(emptyBar);
}
// Calculate the maximum value to ensure proper scaling
const maxValue = Math.max(
...chartData.map((item) =>
allTypes.reduce((sum, type) => sum + ((item as any)[type] || 0), 0),
),
);
// Set minimum domain to ensure bars are always visible
const domainMax = Math.max(maxValue, 1);
// Check if there are no failed sections
if (!sections || sections.length === 0) {
return (
<div className="flex w-[400px] flex-col items-center justify-between lg:w-[600px]">
{title}
<div className="flex h-[320px] w-full items-center justify-center">
<p className="text-sm text-gray-500">There are no failed sections</p>
</div>
</div>
);
}
return (
<div className="flex h-[320px] w-[400px] flex-col items-center justify-between lg:w-[400px]">
<div>{title}</div>
<div className="h-full w-full">
<ResponsiveContainer width="100%" height="100%">
<RechartsBarChart
data={chartData}
layout="vertical"
margin={{ top: 12, bottom: 0 }}
maxBarSize={32}
>
<XAxis
type="number"
fontSize={12}
axisLine={false}
tickLine={false}
allowDecimals={false}
hide={true}
domain={[0, domainMax]}
tick={{
fontSize: 12,
fill: theme === "dark" ? "#94a3b8" : "#374151",
}}
/>
<YAxis
type="category"
dataKey="name"
width={1}
tick={{
fontSize: 12,
fill: theme === "dark" ? "#94a3b8" : "#374151",
textAnchor: "start",
style: {
transform: "translateX(10px) translateY(-26px)",
},
width: 400,
}}
axisLine={false}
tickLine={false}
/>
<Tooltip
content={(props) => {
if (!props.active || !props.payload || !props.payload.length) {
return null;
}
const data = props.payload[0].payload;
if (!data.name || data.name === "") {
return null;
}
const hasValues = allTypes.some((type) => data[type] > 0);
if (!hasValues) {
return null;
}
return (
<div
style={{
backgroundColor: theme === "dark" ? "#1e293b" : "white",
border: `1px solid ${theme === "dark" ? "#475569" : "rgba(0, 0, 0, 0.1)"}`,
borderRadius: "6px",
boxShadow: "0px 4px 12px rgba(0, 0, 0, 0.15)",
fontSize: "12px",
padding: "8px 12px",
color: theme === "dark" ? "white" : "black",
}}
>
{props.payload.map((entry: any, index: number) => (
<div key={index} style={{ color: entry.color }}>
{translateType(entry.dataKey)}: {entry.value}
</div>
))}
</div>
);
}}
cursor={false}
/>
{allTypes.map((type, i) => (
<Bar
key={type}
dataKey={type}
stackId="a"
fill={getTypeColor(type)}
radius={i === allTypes.length - 1 ? [0, 4, 4, 0] : [0, 0, 0, 0]}
/>
))}
<Legend
formatter={(value) => translateType(value)}
wrapperStyle={{
fontSize: "10px",
display: "flex",
justifyContent: "center",
width: "100%",
paddingTop: "16px",
marginBottom: "16px",
}}
iconType="circle"
layout="horizontal"
verticalAlign="bottom"
/>
</RechartsBarChart>
</ResponsiveContainer>
</div>
</div>
);
};
@@ -0,0 +1,152 @@
"use client";
import { cn } from "@nextui-org/react";
import { useTheme } from "next-themes";
import { useState } from "react";
import { CategoryData } from "@/types/compliance";
interface HeatmapChartProps {
categories?: CategoryData[];
}
const getHeatmapColor = (percentage: number): string => {
if (percentage === 0) return "#10b981"; // Green for 0% failures
if (percentage <= 25) return "#eab308"; // Yellow
if (percentage <= 50) return "#f97316"; // Orange
if (percentage <= 100) return "#ef4444"; // Red
return "#ef4444";
};
const capitalizeFirstLetter = (text: string): string => {
const lowerText = text.toLowerCase();
const firstLetterIndex = lowerText.search(/[a-zA-Z]/);
if (firstLetterIndex === -1) return text; // No letters found
return (
lowerText.slice(0, firstLetterIndex) +
lowerText.charAt(firstLetterIndex).toUpperCase() +
lowerText.slice(firstLetterIndex + 1)
);
};
export const HeatmapChart = ({ categories = [] }: HeatmapChartProps) => {
const { theme } = useTheme();
const [hoveredItem, setHoveredItem] = useState<CategoryData | null>(null);
const [mousePosition, setMousePosition] = useState({ x: 0, y: 0 });
// Use categories data and prepare it
const heatmapData = categories
.filter((item) => item.totalRequirements > 0)
.sort((a, b) => b.failurePercentage - a.failurePercentage)
.slice(0, 9); // Exactly 9 items for 3x3 grid
// Check if there are no items with data
if (!categories.length || heatmapData.length === 0) {
return (
<div className="flex w-[400px] flex-col items-center justify-between lg:w-[400px]">
<h3 className="whitespace-nowrap text-xs font-semibold uppercase tracking-wide">
Sections Failure Rate
</h3>
<div className="flex h-[320px] w-full items-center justify-center">
<p className="text-sm text-gray-500">No category data available</p>
</div>
</div>
);
}
const handleMouseEnter = (item: CategoryData, event: React.MouseEvent) => {
setHoveredItem(item);
setMousePosition({ x: event.clientX, y: event.clientY });
};
const handleMouseMove = (event: React.MouseEvent) => {
setMousePosition({ x: event.clientX, y: event.clientY });
};
const handleMouseLeave = () => {
setHoveredItem(null);
};
return (
<div className="flex h-[320px] w-[400px] flex-col items-center justify-between lg:w-[400px]">
<div>
<h3 className="whitespace-nowrap text-xs font-semibold uppercase tracking-wide">
Sections Failure Rate
</h3>
</div>
<div className="h-full w-full p-2">
<div
className={cn(
"grid h-full w-full gap-1",
heatmapData.length < 3 ? "grid-cols-1" : "grid-cols-3",
)}
style={{
gridTemplateRows:
heatmapData.length < 3
? `repeat(${heatmapData.length}, ${heatmapData.length}fr)`
: `repeat(${Math.min(Math.ceil(heatmapData.length / 3), 3)}, 1fr)`,
}}
>
{heatmapData.map((item) => (
<div
key={item.name}
className="flex items-center justify-center rounded border p-1"
style={{
backgroundColor: getHeatmapColor(item.failurePercentage),
borderColor: theme === "dark" ? "#374151" : "#e5e7eb",
}}
onMouseEnter={(e) => handleMouseEnter(item, e)}
onMouseMove={handleMouseMove}
onMouseLeave={handleMouseLeave}
>
<div className="w-full px-1 text-center">
<div
className="truncate text-xs font-semibold"
style={{
color: theme === "dark" ? "#ffffff" : "#000000",
}}
title={capitalizeFirstLetter(item.name)}
>
{capitalizeFirstLetter(item.name)}
</div>
<div
className="text-xs"
style={{
color: theme === "dark" ? "#ffffff" : "#000000",
}}
>
{item.failurePercentage}%
</div>
</div>
</div>
))}
</div>
{/* Custom Tooltip */}
{hoveredItem && (
<div
className="pointer-events-none fixed z-50 rounded border px-3 py-2 text-xs shadow-lg"
style={{
left: mousePosition.x + 10,
top: mousePosition.y - 10,
backgroundColor: theme === "dark" ? "#1e293b" : "white",
borderColor: theme === "dark" ? "#475569" : "rgba(0, 0, 0, 0.1)",
color: theme === "dark" ? "white" : "black",
}}
>
<div className="mb-1 font-semibold">
{capitalizeFirstLetter(hoveredItem.name)}
</div>
<div>Failure Rate: {hoveredItem.failurePercentage}%</div>
<div>
Failed: {hoveredItem.failedRequirements}/
{hoveredItem.totalRequirements}
</div>
</div>
)}
</div>
</div>
);
};
@@ -0,0 +1,192 @@
"use client";
import { useTheme } from "next-themes";
import {
Cell,
Label,
Pie,
PieChart as RechartsPieChart,
Tooltip,
} from "recharts";
import { ChartConfig, ChartContainer } from "@/components/ui/chart/Chart";
interface PieChartProps {
pass: number;
fail: number;
manual: number;
}
const chartConfig = {
number: {
label: "Requirements",
},
pass: {
label: "Pass",
color: "hsl(var(--chart-success))",
},
fail: {
label: "Fail",
color: "hsl(var(--chart-fail))",
},
manual: {
label: "Manual",
color: "hsl(var(--chart-warning))",
},
} satisfies ChartConfig;
export const PieChart = ({ pass, fail, manual }: PieChartProps) => {
const { theme } = useTheme();
const chartData = [
{
name: "Pass",
value: pass,
fill: "#3CEC6D",
},
{
name: "Fail",
value: fail,
fill: "#FB718F",
},
{
name: "Manual",
value: manual,
fill: "#868994",
},
];
const totalRequirements = pass + fail + manual;
const emptyChartData = [
{
name: "Empty",
value: 1,
fill: "#64748b",
},
];
interface CustomTooltipProps {
active: boolean;
payload: {
payload: {
name: string;
value: number;
fill: string;
};
}[];
}
const CustomTooltip = ({ active, payload }: CustomTooltipProps) => {
if (active && payload && payload.length) {
const data = payload[0];
return (
<div
style={{
backgroundColor: theme === "dark" ? "#1e293b" : "white",
border: `1px solid ${theme === "dark" ? "#475569" : "rgba(0, 0, 0, 0.1)"}`,
borderRadius: "6px",
boxShadow: "0px 4px 12px rgba(0, 0, 0, 0.15)",
fontSize: "12px",
padding: "8px 12px",
color: theme === "dark" ? "white" : "black",
}}
>
<div style={{ display: "flex", alignItems: "center", gap: "8px" }}>
<div
style={{
width: "8px",
height: "8px",
borderRadius: "50%",
backgroundColor: data.payload.fill,
}}
/>
<span>
{data.payload.name}: {data.payload.value}
</span>
</div>
</div>
);
}
return null;
};
return (
<div className="flex h-[320px] flex-col items-center justify-between">
<h3 className="whitespace-nowrap text-xs font-semibold uppercase tracking-wide">
Requirements Status
</h3>
<ChartContainer
config={chartConfig}
className="aspect-square w-[200px] min-w-[200px]"
>
<RechartsPieChart>
<Tooltip
cursor={false}
content={<CustomTooltip active={false} payload={[]} />}
/>
<Pie
data={totalRequirements > 0 ? chartData : emptyChartData}
dataKey="value"
nameKey="name"
innerRadius={70}
outerRadius={100}
paddingAngle={2}
cornerRadius={4}
>
{(totalRequirements > 0 ? chartData : emptyChartData).map(
(entry, index) => (
<Cell key={`cell-${index}`} fill={entry.fill} />
),
)}
<Label
content={({ viewBox }) => {
if (viewBox && "cx" in viewBox && "cy" in viewBox) {
return (
<text
x={viewBox.cx}
y={viewBox.cy}
textAnchor="middle"
dominantBaseline="middle"
>
<tspan
x={viewBox.cx}
y={viewBox.cy}
className="fill-foreground text-xl font-bold"
>
{totalRequirements}
</tspan>
<tspan
x={viewBox.cx}
y={(viewBox.cy || 0) + 20}
className="fill-foreground text-xs"
>
Total
</tspan>
</text>
);
}
}}
/>
</Pie>
</RechartsPieChart>
</ChartContainer>
<div className="mt-2 grid grid-cols-3 gap-4">
<div className="flex flex-col items-center">
<div className="text-muted-foreground text-sm">Pass</div>
<div className="font-semibold text-system-success-medium">{pass}</div>
</div>
<div className="flex flex-col items-center">
<div className="text-muted-foreground text-sm">Fail</div>
<div className="font-semibold text-system-error-medium">{fail}</div>
</div>
<div className="flex flex-col items-center">
<div className="text-muted-foreground text-sm">Manual</div>
<div className="font-semibold text-prowler-grey-light">{manual}</div>
</div>
</div>
</div>
);
};
@@ -0,0 +1,91 @@
import Link from "next/link";
import { SeverityBadge } from "@/components/ui/table";
import { Requirement } from "@/types/compliance";
export const AWSWellArchitectedCustomDetails = ({
requirement,
}: {
requirement: Requirement;
}) => {
return (
<div className="space-y-4">
{requirement.description && (
<div>
<h4 className="text-muted-foreground mb-1 text-sm font-medium">
Description
</h4>
<p className="text-sm">{requirement.description}</p>
</div>
)}
{requirement.well_architected_name && (
<div>
<h4 className="text-muted-foreground mb-1 text-sm font-medium">
Best Practice
</h4>
<p className="text-sm">{requirement.well_architected_name}</p>
</div>
)}
{requirement.well_architected_question_id && (
<div>
<h4 className="text-muted-foreground mb-1 text-sm font-medium">
Question ID
</h4>
<p className="text-sm">{requirement.well_architected_question_id}</p>
</div>
)}
{requirement.well_architected_practice_id && (
<div>
<h4 className="text-muted-foreground mb-1 text-sm font-medium">
Practice ID
</h4>
<p className="text-sm">{requirement.well_architected_practice_id}</p>
</div>
)}
{requirement.level_of_risk && (
<div>
<h4 className="text-muted-foreground mb-1 text-sm font-medium">
Level of Risk
</h4>
<SeverityBadge
severity={
requirement.level_of_risk.toString().toLowerCase() as
| "low"
| "medium"
| "high"
}
/>
</div>
)}
{requirement.assessment_method && (
<div>
<h4 className="text-muted-foreground mb-1 text-sm font-medium">
Assessment Method
</h4>
<p className="text-sm">{requirement.assessment_method}</p>
</div>
)}
{requirement.implementation_guidance_url && (
<div>
<h4 className="text-muted-foreground mb-1 text-sm font-medium">
Implementation Guidance
</h4>
<Link
href={requirement.implementation_guidance_url as string}
target="_blank"
rel="noopener noreferrer"
className="break-all text-sm text-blue-600 underline hover:text-blue-800"
>
{requirement.implementation_guidance_url}
</Link>
</div>
)}
</div>
);
};
@@ -0,0 +1,151 @@
import Link from "next/link";
import ReactMarkdown from "react-markdown";
import { Requirement } from "@/types/compliance";
interface CISDetailsProps {
requirement: Requirement;
}
export const CISCustomDetails = ({ requirement }: CISDetailsProps) => {
const processReferences = (
references: string | number | string[] | undefined,
): string[] => {
if (typeof references !== "string") return [];
// Use regex to extract all URLs that start with https://
const urlRegex = /https:\/\/[^:]+/g;
const urls = references.match(urlRegex);
return urls || [];
};
return (
<div className="space-y-4">
{requirement.profile && (
<div>
<h4 className="text-muted-foreground mb-1 text-sm font-medium">
Profile Level
</h4>
<p className="text-sm">{requirement.profile}</p>
</div>
)}
{requirement.subsection && (
<div>
<h4 className="text-muted-foreground mb-1 text-sm font-medium">
SubSection
</h4>
<p className="text-sm">{requirement.subsection}</p>
</div>
)}
{requirement.assessment_status && (
<div>
<h4 className="text-muted-foreground mb-1 text-sm font-medium">
Assessment Status
</h4>
<p className="text-sm">{requirement.assessment_status}</p>
</div>
)}
{requirement.description && (
<div>
<h4 className="text-muted-foreground mb-1 text-sm font-medium">
Description
</h4>
<p className="text-sm">{requirement.description}</p>
</div>
)}
{requirement.rationale_statement && (
<div>
<h4 className="text-muted-foreground mb-1 text-sm font-medium">
Rationale Statement
</h4>
<p className="text-sm">{requirement.rationale_statement}</p>
</div>
)}
{requirement.impact_statement && (
<div>
<h4 className="text-muted-foreground mb-1 text-sm font-medium">
Impact Statement
</h4>
<p className="text-sm">{requirement.impact_statement}</p>
</div>
)}
{requirement.remediation_procedure &&
typeof requirement.remediation_procedure === "string" && (
<div>
<h4 className="text-muted-foreground mb-1 text-sm font-medium">
Remediation Procedure
</h4>
{/* Prettier -> "plugins": ["prettier-plugin-tailwindcss"] is not ready yet to "prose": */}
{/* eslint-disable-next-line */}
<div className="prose prose-sm max-w-none dark:prose-invert">
<ReactMarkdown>{requirement.remediation_procedure}</ReactMarkdown>
</div>
</div>
)}
{requirement.audit_procedure &&
typeof requirement.audit_procedure === "string" && (
<div>
<h4 className="text-muted-foreground mb-1 text-sm font-medium">
Audit Procedure
</h4>
{/* eslint-disable-next-line */}
<div className="prose prose-sm max-w-none dark:prose-invert">
<ReactMarkdown>{requirement.audit_procedure}</ReactMarkdown>
</div>
</div>
)}
{requirement.additional_information && (
<div>
<h4 className="text-muted-foreground mb-1 text-sm font-medium">
Additional Information
</h4>
<p className="whitespace-pre-wrap text-sm">
{requirement.additional_information}
</p>
</div>
)}
{requirement.default_value && (
<div>
<h4 className="text-muted-foreground mb-1 text-sm font-medium">
Default Value
</h4>
<p className="text-sm">{requirement.default_value}</p>
</div>
)}
{requirement.references && (
<div>
<h4 className="text-muted-foreground mb-1 text-sm font-medium">
References
</h4>
<div className="text-sm">
{processReferences(requirement.references).map(
(url: string, index: number) => (
<div key={index}>
<Link
href={url}
target="_blank"
rel="noopener noreferrer"
className="break-all text-blue-600 underline hover:text-blue-800"
>
{url}
</Link>
</div>
),
)}
</div>
</div>
)}
</div>
);
};
@@ -0,0 +1,47 @@
import { translateType } from "@/lib/compliance/ens";
import { Requirement } from "@/types/compliance";
export const ENSCustomDetails = ({
requirement,
}: {
requirement: Requirement;
}) => {
return (
<div className="mb-4">
<div className="mb-2 text-sm text-gray-600">
{requirement.description}
</div>
<div className="flex flex-col gap-2 text-sm">
<div className="flex items-center gap-2">
<span className="font-medium">Type:</span>
<span className="capitalize">
{translateType(requirement.type as string)}
</span>
</div>
<div className="flex items-center gap-2">
<span className="font-medium">Level:</span>
<span className="capitalize">{requirement.nivel}</span>
</div>
{requirement.dimensiones &&
Array.isArray(requirement.dimensiones) &&
requirement.dimensiones.length > 0 && (
<div className="flex items-center gap-2">
<span className="font-medium">Dimensions:</span>
<div className="flex flex-wrap gap-1">
{requirement.dimensiones.map(
(dimension: string, index: number) => (
<span
key={index}
className="rounded-full bg-gray-100 px-2 py-0.5 text-xs capitalize dark:bg-prowler-blue-400"
>
{dimension}
</span>
),
)}
</div>
</div>
)}
</div>
</div>
);
};
@@ -0,0 +1,23 @@
import { Requirement } from "@/types/compliance";
export const ISOCustomDetails = ({
requirement,
}: {
requirement: Requirement;
}) => {
return (
<div className="mb-4">
<div className="mb-2 text-sm text-gray-600">
{requirement.description}
</div>
<div className="flex flex-col gap-2 text-sm">
{requirement.objetive_name && (
<div className="flex items-center gap-2">
<span className="font-medium">Objective:</span>
<span>{requirement.objetive_name}</span>
</div>
)}
</div>
</div>
);
};
@@ -0,0 +1,71 @@
"use client";
import { Spacer } from "@nextui-org/react";
import { FilterControls } from "@/components/filters";
import { DataTableFilterCustom } from "@/components/ui/table/data-table-filter-custom";
import { DataCompliance } from "./data-compliance";
import { SelectScanComplianceDataProps } from "./scan-selector";
interface ComplianceHeaderProps {
scans: SelectScanComplianceDataProps["scans"];
uniqueRegions: string[];
showSearch?: boolean;
showRegionFilter?: boolean;
framework?: string; // Framework name to show specific filters
showProviders?: boolean;
}
export const ComplianceHeader = ({
scans,
uniqueRegions,
showSearch = true,
showRegionFilter = true,
framework,
showProviders = true,
}: ComplianceHeaderProps) => {
const frameworkFilters = [];
// Add CIS Profile Level filter if framework is CIS
if (framework === "CIS") {
frameworkFilters.push({
key: "cis_profile_level",
labelCheckboxGroup: "Level",
values: ["Level 1", "Level 2"],
index: 0, // Show first
showSelectAll: false, // No "Select All" option since Level 2 includes Level 1
defaultValues: ["Level 2"], // Default to Level 2 selected (which includes Level 1)
});
}
// Prepare region filters
const regionFilters = showRegionFilter
? [
{
key: "region__in",
labelCheckboxGroup: "Regions",
values: uniqueRegions,
index: 1, // Show after framework filters
defaultToSelectAll: true, // Default to all regions selected
},
]
: [];
const allFilters = [...frameworkFilters, ...regionFilters];
return (
<>
{showSearch && <FilterControls search />}
<Spacer y={8} />
{showProviders && <DataCompliance scans={scans} />}
{allFilters.length > 0 && (
<>
{showProviders && <Spacer y={8} />}
<DataTableFilterCustom filters={allFilters} defaultOpen={true} />
</>
)}
<Spacer y={12} />
</>
);
};
@@ -1,8 +1,9 @@
import { Divider } from "@nextui-org/react";
import { Divider, Tooltip } from "@nextui-org/react";
import React from "react";
import { DateWithTime, EntityInfoShort } from "@/components/ui/entities";
import { ProviderType } from "@/types";
interface ComplianceScanInfoProps {
scan: {
providerInfo: {
@@ -21,18 +22,25 @@ export const ComplianceScanInfo: React.FC<ComplianceScanInfoProps> = ({
scan,
}) => {
return (
<div className="flex w-fit items-center">
<div className="flex items-center gap-2">
<EntityInfoShort
cloudProvider={scan.providerInfo.provider}
entityAlias={scan.providerInfo.alias}
entityId={scan.providerInfo.uid}
hideCopyButton
snippetWidth="max-w-[100px]"
/>
<Divider orientation="vertical" className="mx-2 h-6" />
<div className="flex flex-col items-start">
<p className="text-xs text-default-500">
{scan.attributes.name || "- -"}
</p>
<Divider orientation="vertical" className="h-6" />
<div className="flex flex-col items-start whitespace-nowrap">
<Tooltip
content={scan.attributes.name || "- -"}
placement="top"
size="sm"
>
<p className="text-xs text-default-500">
{scan.attributes.name || "- -"}
</p>
</Tooltip>
<DateWithTime inline dateTime={scan.attributes.completed_at} />
</div>
</div>
@@ -3,8 +3,10 @@
import { useRouter, useSearchParams } from "next/navigation";
import { useEffect } from "react";
import { SelectScanComplianceData } from "@/components/compliance/data-compliance";
import { SelectScanComplianceDataProps } from "@/types";
import {
ScanSelector,
SelectScanComplianceDataProps,
} from "@/components/compliance/compliance-header/index";
interface DataComplianceProps {
scans: SelectScanComplianceDataProps["scans"];
}
@@ -33,8 +35,8 @@ export const DataCompliance = ({ scans }: DataComplianceProps) => {
return (
<div className="flex flex-col gap-4">
<div className="grid grid-cols-1 items-center gap-x-4 gap-y-4 md:grid-cols-2 lg:grid-cols-3">
<SelectScanComplianceData
<div className="flex max-w-fit">
<ScanSelector
scans={scans}
selectedScanId={selectedScanId}
onSelectionChange={handleScanChange}
@@ -0,0 +1,2 @@
export * from "./data-compliance";
export * from "./scan-selector";
@@ -1,10 +1,22 @@
import { Select, SelectItem } from "@nextui-org/react";
import { SelectScanComplianceDataProps } from "@/types";
import { ProviderType, ScanProps } from "@/types";
import { ComplianceScanInfo } from "../compliance-scan-info";
import { ComplianceScanInfo } from "./compliance-scan-info";
export const SelectScanComplianceData = ({
export interface SelectScanComplianceDataProps {
scans: (ScanProps & {
providerInfo: {
provider: ProviderType;
uid: string;
alias: string;
};
})[];
selectedScanId: string;
onSelectionChange: (selectedKey: string) => void;
}
export const ScanSelector = ({
scans,
selectedScanId,
onSelectionChange,
@@ -14,7 +26,7 @@ export const SelectScanComplianceData = ({
aria-label="Select a Scan"
placeholder="Select a scan"
classNames={{
selectorIcon: "right-2",
trigger: "w-full min-w-[365px]",
}}
size="lg"
labelPlacement="outside"
@@ -1,2 +0,0 @@
export * from "./data-compliance";
export * from "./select-scan-compliance-data";
+19 -2
View File
@@ -1,4 +1,21 @@
export * from "./compliance-accordion/client-accordion-content";
export * from "./compliance-accordion/client-accordion-wrapper";
export * from "./compliance-accordion/compliance-accordion-requeriment-title";
export * from "./compliance-accordion/compliance-accordion-title";
export * from "./compliance-card";
export * from "./compliance-scan-info";
export * from "./compliance-skeleton-grid";
export * from "./compliance-charts/bar-chart";
export * from "./compliance-charts/heatmap-chart";
export * from "./compliance-charts/pie-chart";
export * from "./compliance-custom-details/cis-details";
export * from "./compliance-custom-details/ens-details";
export * from "./compliance-custom-details/iso-details";
export * from "./compliance-header/compliance-header";
export * from "./compliance-header/compliance-scan-info";
export * from "./compliance-header/data-compliance";
export * from "./compliance-header/scan-selector";
export * from "./no-scans-available";
export * from "./skeletons/bar-chart-skeleton";
export * from "./skeletons/compliance-accordion-skeleton";
export * from "./skeletons/compliance-grid-skeleton";
export * from "./skeletons/heatmap-chart-skeleton";
export * from "./skeletons/pie-chart-skeleton";
@@ -0,0 +1,53 @@
"use client";
import { Skeleton } from "@nextui-org/react";
export const BarChartSkeleton = () => {
return (
<div className="flex w-[400px] flex-col items-center justify-between lg:w-[600px]">
{/* Title skeleton */}
<Skeleton className="h-4 w-40 rounded-lg">
<div className="h-4 bg-default-200" />
</Skeleton>
{/* Chart area skeleton */}
<div className="ml-24 flex h-full flex-col justify-center space-y-2 p-4">
{/* Bar chart skeleton - 5 horizontal bars */}
{Array.from({ length: 5 }).map((_, index) => (
<div key={index} className="flex items-center space-x-4">
{/* Bar skeleton with varying widths */}
<Skeleton
className={`h-10 rounded-lg ${
index === 0
? "w-48"
: index === 1
? "w-40"
: index === 2
? "w-32"
: index === 3
? "w-24"
: "w-16"
}`}
>
<div className="h-6 bg-default-200" />
</Skeleton>
</div>
))}
{/* Legend skeleton */}
<div className="flex justify-center space-x-4 pt-2">
{Array.from({ length: 3 }).map((_, index) => (
<div key={index} className="flex items-center space-x-1">
<Skeleton className="h-3 w-3 rounded-full">
<div className="h-3 w-3 bg-default-200" />
</Skeleton>
<Skeleton className="h-3 w-16 rounded-lg">
<div className="h-3 bg-default-200" />
</Skeleton>
</div>
))}
</div>
</div>
</div>
);
};
@@ -0,0 +1,30 @@
import { Skeleton } from "@nextui-org/react";
import React from "react";
interface SkeletonAccordionProps {
itemCount?: number;
className?: string;
isCompact?: boolean;
}
export const SkeletonAccordion = ({
itemCount = 3,
className = "",
isCompact = false,
}: SkeletonAccordionProps) => {
const itemHeight = isCompact ? "h-10" : "h-14";
return (
<div
className={`w-full space-y-2 ${className} rounded-xl border border-gray-300 p-2 dark:border-gray-700`}
>
{[...Array(itemCount)].map((_, index) => (
<Skeleton key={index} className="rounded-lg">
<div className={`${itemHeight} bg-default-300`}></div>
</Skeleton>
))}
</div>
);
};
SkeletonAccordion.displayName = "SkeletonAccordion";
@@ -0,0 +1,28 @@
"use client";
import { Skeleton } from "@nextui-org/react";
export const HeatmapChartSkeleton = () => {
return (
<div className="flex h-[320px] w-[400px] flex-col items-center justify-between lg:w-[400px]">
{/* Title skeleton */}
<Skeleton className="h-4 w-36 rounded-lg">
<div className="h-4 bg-default-200" />
</Skeleton>
{/* Heatmap area skeleton - 3x3 grid like the real component */}
<div className="h-full w-full p-4">
<div className="grid h-full w-full grid-cols-3 gap-1">
{Array.from({ length: 9 }).map((_, index) => (
<Skeleton
key={index}
className="flex items-center justify-center rounded border"
>
<div className="h-full w-full bg-default-200" />
</Skeleton>
))}
</div>
</div>
</div>
);
};
@@ -0,0 +1,63 @@
"use client";
import { Skeleton } from "@nextui-org/react";
export const PieChartSkeleton = () => {
return (
<div className="flex h-[320px] flex-col items-center justify-between">
{/* Title skeleton */}
<Skeleton className="h-4 w-32 rounded-lg">
<div className="h-4 bg-default-200" />
</Skeleton>
{/* Pie chart skeleton */}
<div className="relative flex aspect-square w-[200px] min-w-[200px] items-center justify-center">
{/* Outer circle */}
<Skeleton className="absolute h-[200px] w-[200px] rounded-full">
<div className="h-[200px] w-[200px] bg-default-200" />
</Skeleton>
{/* Inner circle (donut hole) */}
<div className="absolute h-[140px] w-[140px] rounded-full bg-background"></div>
{/* Center text skeleton */}
<div className="absolute flex flex-col items-center">
<Skeleton className="h-6 w-8 rounded-lg">
<div className="h-6 bg-default-300" />
</Skeleton>
<Skeleton className="mt-1 h-3 w-6 rounded-lg">
<div className="h-3 bg-default-300" />
</Skeleton>
</div>
</div>
{/* Bottom stats skeleton */}
<div className="mt-2 grid grid-cols-3 gap-4">
<div className="flex flex-col items-center">
<Skeleton className="h-4 w-8 rounded-lg">
<div className="h-4 bg-default-200" />
</Skeleton>
<Skeleton className="mt-1 h-5 w-6 rounded-lg">
<div className="h-5 bg-default-200" />
</Skeleton>
</div>
<div className="flex flex-col items-center">
<Skeleton className="h-4 w-6 rounded-lg">
<div className="h-4 bg-default-200" />
</Skeleton>
<Skeleton className="mt-1 h-5 w-6 rounded-lg">
<div className="h-5 bg-default-200" />
</Skeleton>
</div>
<div className="flex flex-col items-center">
<Skeleton className="h-4 w-12 rounded-lg">
<div className="h-4 bg-default-200" />
</Skeleton>
<Skeleton className="mt-1 h-5 w-6 rounded-lg">
<div className="h-5 bg-default-200" />
</Skeleton>
</div>
</div>
</div>
);
};
@@ -1,65 +1,11 @@
import { Card, Skeleton } from "@nextui-org/react";
import React from "react";
import { SkeletonTable } from "../../ui/skeleton/skeleton";
export const SkeletonTableFindings = () => {
return (
<Card className="h-full w-full space-y-5 p-4" radius="sm">
{/* Table headers */}
<div className="hidden justify-between md:flex">
<Skeleton className="w-1/12 rounded-lg">
<div className="h-8 bg-default-200"></div>
</Skeleton>
<Skeleton className="w-2/12 rounded-lg">
<div className="h-8 bg-default-200"></div>
</Skeleton>
<Skeleton className="w-2/12 rounded-lg">
<div className="h-8 bg-default-200"></div>
</Skeleton>
<Skeleton className="w-2/12 rounded-lg">
<div className="h-8 bg-default-200"></div>
</Skeleton>
<Skeleton className="w-2/12 rounded-lg">
<div className="h-8 bg-default-200"></div>
</Skeleton>
<Skeleton className="w-1/12 rounded-lg">
<div className="h-8 bg-default-200"></div>
</Skeleton>
<Skeleton className="w-1/12 rounded-lg">
<div className="h-8 bg-default-200"></div>
</Skeleton>
</div>
{/* Table body */}
<div className="space-y-3">
{[...Array(3)].map((_, index) => (
<div
key={index}
className="flex flex-col items-center justify-between space-x-0 md:flex-row md:space-x-4"
>
<Skeleton className="mb-2 w-full rounded-lg md:mb-0 md:w-1/12">
<div className="h-12 bg-default-300"></div>
</Skeleton>
<Skeleton className="mb-2 w-full rounded-lg md:mb-0 md:w-2/12">
<div className="h-12 bg-default-300"></div>
</Skeleton>
<Skeleton className="mb-2 hidden rounded-lg sm:flex md:mb-0 md:w-2/12">
<div className="h-12 bg-default-300"></div>
</Skeleton>
<Skeleton className="mb-2 hidden rounded-lg sm:flex md:mb-0 md:w-2/12">
<div className="h-12 bg-default-300"></div>
</Skeleton>
<Skeleton className="mb-2 hidden rounded-lg sm:flex md:mb-0 md:w-2/12">
<div className="h-12 bg-default-300"></div>
</Skeleton>
<Skeleton className="mb-2 hidden rounded-lg sm:flex md:mb-0 md:w-1/12">
<div className="h-12 bg-default-300"></div>
</Skeleton>
<Skeleton className="mb-2 hidden rounded-lg sm:flex md:mb-0 md:w-1/12">
<div className="h-12 bg-default-300"></div>
</Skeleton>
</div>
))}
</div>
</Card>
<div className="bg-card rounded-xl border p-4 shadow-sm">
<SkeletonTable rows={4} columns={7} />
</div>
);
};
@@ -116,9 +116,9 @@ export const FindingsBySeverityChart = ({
>
<LabelList
position="insideRight"
offset={10}
offset={5}
className="fill-foreground font-bold"
fontSize={12}
fontSize={11}
/>
</Bar>
</BarChart>
@@ -146,9 +146,9 @@ export const FindingsByStatusChart: React.FC<FindingsByStatusChartProps> = ({
</PieChart>
</ChartContainer>
<div className="grid w-full grid-cols-2 justify-items-center gap-6">
<div className="flex flex-col gap-6">
<div className="flex flex-col gap-2">
<div className="flex items-center space-x-2 self-end">
<div className="flex items-center space-x-2">
<Link
href="/findings?filter[status]=PASS"
className="flex items-center space-x-2"
@@ -1,65 +1,11 @@
import { Card, Skeleton } from "@nextui-org/react";
import React from "react";
import { SkeletonTable } from "@/components/ui/skeleton/skeleton";
export const SkeletonTableNewFindings = () => {
return (
<Card className="h-full w-full space-y-5 p-4" radius="sm">
{/* Table headers */}
<div className="hidden justify-between md:flex">
<Skeleton className="w-1/12 rounded-lg">
<div className="h-8 bg-default-200"></div>
</Skeleton>
<Skeleton className="w-2/12 rounded-lg">
<div className="h-8 bg-default-200"></div>
</Skeleton>
<Skeleton className="w-2/12 rounded-lg">
<div className="h-8 bg-default-200"></div>
</Skeleton>
<Skeleton className="w-2/12 rounded-lg">
<div className="h-8 bg-default-200"></div>
</Skeleton>
<Skeleton className="w-2/12 rounded-lg">
<div className="h-8 bg-default-200"></div>
</Skeleton>
<Skeleton className="w-1/12 rounded-lg">
<div className="h-8 bg-default-200"></div>
</Skeleton>
<Skeleton className="w-1/12 rounded-lg">
<div className="h-8 bg-default-200"></div>
</Skeleton>
</div>
{/* Table body */}
<div className="space-y-3">
{[...Array(3)].map((_, index) => (
<div
key={index}
className="flex flex-col items-center justify-between space-x-0 md:flex-row md:space-x-4"
>
<Skeleton className="mb-2 w-full rounded-lg md:mb-0 md:w-1/12">
<div className="h-12 bg-default-300"></div>
</Skeleton>
<Skeleton className="mb-2 w-full rounded-lg md:mb-0 md:w-2/12">
<div className="h-12 bg-default-300"></div>
</Skeleton>
<Skeleton className="mb-2 hidden rounded-lg sm:flex md:mb-0 md:w-2/12">
<div className="h-12 bg-default-300"></div>
</Skeleton>
<Skeleton className="mb-2 hidden rounded-lg sm:flex md:mb-0 md:w-2/12">
<div className="h-12 bg-default-300"></div>
</Skeleton>
<Skeleton className="mb-2 hidden rounded-lg sm:flex md:mb-0 md:w-2/12">
<div className="h-12 bg-default-300"></div>
</Skeleton>
<Skeleton className="mb-2 hidden rounded-lg sm:flex md:mb-0 md:w-1/12">
<div className="h-12 bg-default-300"></div>
</Skeleton>
<Skeleton className="mb-2 hidden rounded-lg sm:flex md:mb-0 md:w-1/12">
<div className="h-12 bg-default-300"></div>
</Skeleton>
</div>
))}
</div>
</Card>
<div className="bg-card rounded-xl border p-4 shadow-sm">
<SkeletonTable rows={3} columns={7} />
</div>
);
};
+62 -11
View File
@@ -6,7 +6,7 @@ import {
Selection,
} from "@nextui-org/react";
import { ChevronDown } from "lucide-react";
import React, { ReactNode, useCallback, useState } from "react";
import React, { ReactNode, useCallback, useMemo, useState } from "react";
import { cn } from "@/lib/utils";
@@ -24,17 +24,24 @@ export interface AccordionProps {
variant?: "light" | "shadow" | "bordered" | "splitted";
className?: string;
defaultExpandedKeys?: string[];
selectedKeys?: string[];
selectionMode?: "single" | "multiple";
isCompact?: boolean;
showDivider?: boolean;
onItemExpand?: (key: string) => void;
onSelectionChange?: (keys: string[]) => void;
}
const AccordionContent = ({
content,
items,
selectedKeys,
onSelectionChange,
}: {
content: ReactNode;
items?: AccordionItemProps[];
selectedKeys?: string[];
onSelectionChange?: (keys: string[]) => void;
}) => {
return (
<div className="text-sm text-gray-700 dark:text-gray-300">
@@ -46,6 +53,8 @@ const AccordionContent = ({
variant="light"
isCompact
selectionMode="multiple"
selectedKeys={selectedKeys}
onSelectionChange={onSelectionChange}
/>
</div>
)}
@@ -58,21 +67,58 @@ export const Accordion = ({
variant = "light",
className,
defaultExpandedKeys = [],
selectedKeys,
selectionMode = "single",
isCompact = false,
showDivider = true,
onItemExpand,
onSelectionChange,
}: AccordionProps) => {
const [expandedKeys, setExpandedKeys] = useState<Selection>(
// Determine if component is in controlled or uncontrolled mode
const isControlled = selectedKeys !== undefined;
const [internalExpandedKeys, setInternalExpandedKeys] = useState<Selection>(
new Set(defaultExpandedKeys),
);
const handleSelectionChange = useCallback((keys: Selection) => {
setExpandedKeys(keys);
}, []);
// Use selectedKeys if controlled, otherwise use internal state
const expandedKeys = useMemo(
() => (isControlled ? new Set(selectedKeys) : internalExpandedKeys),
[isControlled, selectedKeys, internalExpandedKeys],
);
const handleSelectionChange = useCallback(
(keys: Selection) => {
const keysArray = Array.from(keys as Set<string>);
// If controlled mode, call parent callback
if (isControlled && onSelectionChange) {
onSelectionChange(keysArray);
} else {
// If uncontrolled, update internal state
setInternalExpandedKeys(keys);
}
// Handle onItemExpand for backward compatibility
if (onItemExpand && keys !== expandedKeys) {
const currentKeys = Array.from(expandedKeys as Set<string>);
const newKeys = keysArray;
const newlyExpandedKeys = newKeys.filter(
(key) => !currentKeys.includes(key),
);
newlyExpandedKeys.forEach((key) => {
onItemExpand(key);
});
}
},
[expandedKeys, onItemExpand, isControlled, onSelectionChange],
);
return (
<NextUIAccordion
className={cn("w-full", className)}
className={cn("w-full !px-0", className)}
variant={variant}
selectionMode={selectionMode}
selectedKeys={expandedKeys}
@@ -91,15 +137,20 @@ export const Accordion = ({
isDisabled={item.isDisabled}
indicator={<ChevronDown className="text-gray-500" />}
classNames={{
base: index === 0 || index === 1 ? "my-2" : "my-1",
title: "text-sm font-medium",
base: index === 0 || index === 1 ? "my-1" : "my-1",
title: "text-sm",
subtitle: "text-xs text-gray-500",
trigger:
"p-2 rounded-lg data-[hover=true]:bg-gray-50 dark:data-[hover=true]:bg-gray-800/50",
content: "p-2",
"py-2 px-2 rounded-lg data-[hover=true]:bg-gray-50 dark:data-[hover=true]:bg-gray-800/50 w-full flex items-center",
content: "px-0 py-1",
}}
>
<AccordionContent content={item.content} items={item.items} />
<AccordionContent
content={item.content}
items={item.items}
selectedKeys={selectedKeys}
onSelectionChange={onSelectionChange}
/>
</AccordionItem>
))}
</NextUIAccordion>
@@ -70,6 +70,16 @@ interface HorizontalSplitBarProps {
* @default "text-gray-700"
*/
labelColor?: string;
/**
* Growth ratio multiplier (pixels per value unit)
* @default 1
*/
ratio?: number;
/**
* Show zero values in labels
* @default true
*/
showZero?: boolean;
}
/**
@@ -99,6 +109,8 @@ export const HorizontalSplitBar = ({
tooltipContentA,
tooltipContentB,
labelColor = "text-gray-700",
ratio = 1,
showZero = true,
}: HorizontalSplitBarProps) => {
// Reference to the container to measure its width
const containerRef = React.useRef<HTMLDivElement>(null);
@@ -150,8 +162,9 @@ export const HorizontalSplitBar = ({
const halfWidth = availableWidth / 2;
const separatorWidth = 1;
let rawWidthA = valA;
let rawWidthB = valB;
// Apply ratio multiplier to raw widths
let rawWidthA = valA * ratio;
let rawWidthB = valB * ratio;
// Determine if we need to scale to fit in available space
const maxSideWidth = halfWidth - separatorWidth / 2;
@@ -183,7 +196,7 @@ export const HorizontalSplitBar = ({
className={cn("text-xs font-medium", labelColor)}
aria-label={`${formattedValueA} ${tooltipContentA ? tooltipContentA : ""}`}
>
{valA > 0 ? formattedValueA : "0"}
{valA > 0 ? formattedValueA : showZero ? "0" : ""}
</div>
{/* Left bar */}
{valA > 0 && (
@@ -230,7 +243,7 @@ export const HorizontalSplitBar = ({
className={cn("text-xs font-medium", labelColor)}
aria-label={`${formattedValueB} ${tooltipContentB ? tooltipContentB : ""}`}
>
{valB > 0 ? formattedValueB : "0"}
{valB > 0 ? formattedValueB : showZero ? "0" : ""}
</div>
</div>
</div>
@@ -1,12 +1,13 @@
import { Suspense, use } from "react";
import { ReactNode, Suspense, use } from "react";
import { getUserInfo } from "@/actions/users/users";
import { Navbar } from "../nav-bar/navbar";
import { SkeletonContentLayout } from "./skeleton-content-layout";
interface ContentLayoutProps {
title: string;
icon: string;
icon: string | ReactNode;
children: React.ReactNode;
}
@@ -12,11 +12,23 @@ import {
} from "@nextui-org/react";
import { ChevronDown, X } from "lucide-react";
import { useSearchParams } from "next/navigation";
import React, { useCallback, useEffect, useMemo, useState } from "react";
import React, {
useCallback,
useEffect,
useMemo,
useRef,
useState,
} from "react";
import { CustomDropdownFilterProps } from "@/types";
import { EntityInfoShort } from "../entities";
import { ComplianceScanInfo } from "@/components/compliance/compliance-header/compliance-scan-info";
import { EntityInfoShort } from "@/components/ui/entities";
import { isScanEntity } from "@/lib/helper-filters";
import {
CustomDropdownFilterProps,
FilterEntity,
ProviderEntity,
ScanEntity,
} from "@/types";
export const CustomDropdownFilter = ({
filter,
@@ -25,6 +37,7 @@ export const CustomDropdownFilter = ({
const searchParams = useSearchParams();
const [groupSelected, setGroupSelected] = useState(new Set<string>());
const [isOpen, setIsOpen] = useState(false);
const hasUserInteracted = useRef(false);
const filterValues = useMemo(() => filter?.values || [], [filter?.values]);
const selectedValues = Array.from(groupSelected).filter(
@@ -38,28 +51,88 @@ export const CustomDropdownFilter = ({
return filterParam ? filterParam.split(",") : [];
}, [searchParams, filter?.key]);
// Sync URL state with component state
useEffect(() => {
if (activeFilterValue.length > 0) {
const newSelection = new Set(activeFilterValue);
if (newSelection.size === filterValues.length) {
// Helper function to handle URL filter values sync
const syncWithActiveFilters = useCallback(() => {
const newSelection = new Set(activeFilterValue);
if (
newSelection.size === filterValues.length &&
filter?.showSelectAll !== false
) {
newSelection.add("all");
}
setGroupSelected(newSelection);
}, [activeFilterValue, filterValues, filter?.showSelectAll]);
const resetComponentState = useCallback(() => {
setGroupSelected(new Set());
hasUserInteracted.current = false;
}, []);
const applyDefaultValues = useCallback(() => {
if (filter?.defaultToSelectAll && filterValues.length > 0) {
const newSelection = new Set(filterValues);
if (filter?.showSelectAll !== false) {
newSelection.add("all");
}
setGroupSelected(newSelection);
} else if (filter?.defaultValues && filter.defaultValues.length > 0) {
const validDefaultValues = filter.defaultValues.filter((value) =>
filterValues.includes(value),
);
const newSelection = new Set(validDefaultValues);
// Add "all" if all items are selected and showSelectAll is not false
if (
validDefaultValues.length === filterValues.length &&
filter?.showSelectAll !== false
) {
newSelection.add("all");
}
setGroupSelected(newSelection);
} else {
setGroupSelected(new Set());
}
}, [activeFilterValue, filterValues.length]);
}, [
filterValues,
filter?.defaultToSelectAll,
filter?.defaultValues,
filter?.showSelectAll,
]);
useEffect(() => {
const hasActiveFilters = activeFilterValue.length > 0;
const userHasInteracted = hasUserInteracted.current;
if (hasActiveFilters) {
// URL has filter values - sync component state with URL
syncWithActiveFilters();
} else if (userHasInteracted) {
// URL has no filters but user had interacted - reset component state
resetComponentState();
} else {
// URL has no filters and user hasn't interacted - apply defaults
applyDefaultValues();
}
}, [
activeFilterValue,
syncWithActiveFilters,
resetComponentState,
applyDefaultValues,
]);
const updateSelection = useCallback(
(newValues: string[]) => {
// Mark that user has interacted with the filter
hasUserInteracted.current = true;
const actualValues = newValues.filter((key) => key !== "all");
const newSelection = new Set(actualValues);
// Auto-add "all" if all items are selected
// Auto-add "all" if all items are selected and showSelectAll is not false
if (
actualValues.length === filterValues.length &&
filterValues.length > 0
filterValues.length > 0 &&
filter?.showSelectAll !== false
) {
newSelection.add("all");
}
@@ -69,7 +142,7 @@ export const CustomDropdownFilter = ({
// Notify parent with actual values (excluding "all")
onFilterChange?.(filter.key, actualValues);
},
[filterValues.length, onFilterChange, filter.key],
[filterValues.length, onFilterChange, filter.key, filter?.showSelectAll],
);
const onSelectionChange = useCallback(
@@ -111,10 +184,25 @@ export const CustomDropdownFilter = ({
const getDisplayLabel = useCallback(
(value: string) => {
const entity = filter.valueLabelMapping?.find((entry) => entry[value])?.[
value
];
return entity?.alias || entity?.uid || value;
const entity: FilterEntity | undefined = filter.valueLabelMapping?.find(
(entry) => entry[value],
)?.[value];
if (!entity) return value;
if (isScanEntity(entity as ScanEntity)) {
return (
(entity as ScanEntity).attributes?.name ||
(entity as ScanEntity).providerInfo?.alias ||
(entity as ScanEntity).providerInfo?.uid ||
value
);
} else {
return (
(entity as ProviderEntity).alias ||
(entity as ProviderEntity).uid ||
value
);
}
},
[filter.valueLabelMapping],
);
@@ -173,7 +261,7 @@ export const CustomDropdownFilter = ({
onKeyDown={(e) => {
if (e.key === "Enter" || e.key === " ") {
e.preventDefault();
handleClearAll(e as any);
handleClearAll(e as unknown as React.MouseEvent);
}
}}
>
@@ -185,7 +273,7 @@ export const CustomDropdownFilter = ({
</div>
</Button>
</PopoverTrigger>
<PopoverContent className="w-80 dark:bg-prowler-blue-800">
<PopoverContent className="w-auto min-w-80 dark:bg-prowler-blue-800">
<div className="flex w-full flex-col gap-4 p-2">
<CheckboxGroup
color="default"
@@ -194,24 +282,29 @@ export const CustomDropdownFilter = ({
onValueChange={onSelectionChange}
className="font-bold"
>
<Checkbox
classNames={{
label: "text-small font-normal",
wrapper: "checkbox-update",
}}
value="all"
>
Select All
</Checkbox>
<Divider orientation="horizontal" className="mt-2" />
{filter?.showSelectAll !== false && (
<>
<Checkbox
classNames={{
label: "text-small font-normal",
wrapper: "checkbox-update",
}}
value="all"
>
Select All
</Checkbox>
<Divider orientation="horizontal" className="mt-2" />
</>
)}
<ScrollShadow
hideScrollBar
className="flex max-h-96 max-w-full flex-col gap-y-2 py-2"
>
{filterValues.map((value) => {
const entity = filter.valueLabelMapping?.find(
(entry) => entry[value],
)?.[value];
const entity: FilterEntity | undefined =
filter.valueLabelMapping?.find((entry) => entry[value])?.[
value
];
return (
<Checkbox
@@ -223,12 +316,18 @@ export const CustomDropdownFilter = ({
value={value}
>
{entity ? (
<EntityInfoShort
cloudProvider={entity.provider}
entityAlias={entity.alias ?? undefined}
entityId={entity.uid}
hideCopyButton
/>
isScanEntity(entity as ScanEntity) ? (
<ComplianceScanInfo scan={entity as ScanEntity} />
) : (
<EntityInfoShort
cloudProvider={(entity as ProviderEntity).provider}
entityAlias={
(entity as ProviderEntity).alias ?? undefined
}
entityId={(entity as ProviderEntity).uid}
hideCopyButton
/>
)
) : (
value
)}
+6 -2
View File
@@ -30,9 +30,13 @@ export const DateWithTime: React.FC<DateWithTimeProps> = ({
<div
className={`flex ${inline ? "flex-row items-center gap-2" : "flex-col"}`}
>
<span className="text-xs font-semibold">{formattedDate}</span>
<span className="whitespace-nowrap text-xs font-semibold">
{formattedDate}
</span>
{showTime && (
<span className="text-xs text-gray-500">{formattedTime}</span>
<span className="whitespace-nowrap text-xs text-gray-500">
{formattedTime}
</span>
)}
</div>
</div>
@@ -1,3 +1,4 @@
import { Tooltip } from "@nextui-org/react";
import React from "react";
import { IdIcon } from "@/components/icons";
@@ -11,6 +12,7 @@ interface EntityInfoProps {
entityAlias?: string;
entityId?: string;
hideCopyButton?: boolean;
snippetWidth?: string;
}
export const EntityInfoShort: React.FC<EntityInfoProps> = ({
@@ -20,12 +22,16 @@ export const EntityInfoShort: React.FC<EntityInfoProps> = ({
hideCopyButton = false,
}) => {
return (
<div className="flex w-full items-center justify-between space-x-2">
<div className="flex items-center gap-x-2">
<div className="flex items-center justify-start">
<div className="flex items-center justify-between gap-x-2">
<div className="flex-shrink-0">{getProviderLogo(cloudProvider)}</div>
<div className="flex flex-col">
<div className="flex max-w-[120px] flex-col">
{entityAlias && (
<span className="text-xs text-default-500">{entityAlias}</span>
<Tooltip content={entityAlias} placement="top" size="sm">
<span className="truncate text-ellipsis text-xs text-default-500">
{entityAlias}
</span>
</Tooltip>
)}
<SnippetChip
value={entityId ?? ""}
+9 -2
View File
@@ -23,6 +23,10 @@ export const SnippetChip = ({
return (
<Snippet
className={cn("h-6", className)}
classNames={{
content: "min-w-0 overflow-hidden",
pre: "min-w-0 overflow-hidden text-ellipsis whitespace-nowrap",
}}
color="default"
size="sm"
variant="flat"
@@ -34,10 +38,13 @@ export const SnippetChip = ({
codeString={value}
{...props}
>
<div className="flex items-center space-x-2" aria-label={ariaLabel}>
<div
className="flex min-w-0 items-center space-x-2"
aria-label={ariaLabel}
>
{icon}
<Tooltip content={value} placement="top" size="sm">
<span className="no-scrollbar max-w-24 overflow-hidden overflow-x-scroll text-ellipsis whitespace-nowrap text-xs">
<span className="min-w-0 flex-1 truncate text-xs">
{formatter ? formatter(value) : value}
</span>
</Tooltip>
+15 -7
View File
@@ -1,13 +1,15 @@
import { Icon } from "@iconify/react";
import { ReactNode } from "react";
import { ThemeSwitch } from "@/components/ThemeSwitch";
import { UserProfileProps } from "@/types";
import { SheetMenu } from "../sidebar/sheet-menu";
import { UserNav } from "../user-nav/user-nav";
interface NavbarProps {
title: string;
icon: string;
icon: string | ReactNode;
user: UserProfileProps;
}
@@ -17,12 +19,18 @@ export function Navbar({ title, icon, user }: NavbarProps) {
<div className="mx-4 flex h-14 items-center sm:mx-8">
<div className="flex items-center space-x-2">
<SheetMenu />
<Icon
className="text-default-500"
height={24}
icon={icon}
width={24}
/>
{typeof icon === "string" ? (
<Icon
className="text-default-500"
height={24}
icon={icon}
width={24}
/>
) : (
<div className="flex h-10 w-10 items-center justify-center [&>*]:h-full [&>*]:w-full">
{icon}
</div>
)}
<h1 className="text-sm font-bold text-default-700">{title}</h1>
</div>
<div className="flex flex-1 items-center justify-end gap-3">
+123
View File
@@ -0,0 +1,123 @@
import { cn } from "@/lib/utils";
interface SkeletonProps {
className?: string;
variant?: "default" | "card" | "table" | "text" | "circle" | "rectangular";
width?: string | number;
height?: string | number;
animate?: boolean;
}
export function Skeleton({
className,
variant = "default",
width,
height,
animate = true,
}: SkeletonProps) {
const variantClasses = {
default: "w-full h-4 rounded-lg",
card: "w-full h-40 rounded-xl",
table: "w-full h-60 rounded-lg",
text: "w-24 h-4 rounded-full",
circle: "rounded-full w-8 h-8",
rectangular: "rounded-md",
};
return (
<div
style={{
width: width
? typeof width === "number"
? `${width}px`
: width
: undefined,
height: height
? typeof height === "number"
? `${height}px`
: height
: undefined,
}}
className={cn(
"animate-pulse bg-gray-200 dark:bg-prowler-blue-800",
variantClasses[variant],
!animate && "animate-none",
className,
)}
/>
);
}
export function SkeletonTable({
rows = 5,
columns = 4,
className,
roundedCells = true,
}: {
rows?: number;
columns?: number;
className?: string;
roundedCells?: boolean;
}) {
return (
<div className={cn("w-full space-y-4", className)}>
{/* Header */}
<div className="flex items-center space-x-4 pb-4">
{Array.from({ length: columns }).map((_, index) => (
<Skeleton
key={`header-${index}`}
className={cn("h-8", roundedCells && "rounded-lg")}
width={`${100 / columns}%`}
variant={roundedCells ? "default" : "rectangular"}
/>
))}
</div>
{/* Rows */}
{Array.from({ length: rows }).map((_, rowIndex) => (
<div
key={`row-${rowIndex}`}
className="flex items-center space-x-4 py-3"
>
{Array.from({ length: columns }).map((_, colIndex) => (
<Skeleton
key={`cell-${rowIndex}-${colIndex}`}
className={cn("h-6", roundedCells && "rounded-lg")}
width={`${100 / columns}%`}
variant={roundedCells ? "default" : "rectangular"}
/>
))}
</div>
))}
</div>
);
}
export function SkeletonCard({ className }: { className?: string }) {
return (
<div className={cn("space-y-3", className)}>
<Skeleton variant="card" />
<Skeleton className="h-4 w-2/3" />
<Skeleton className="h-4 w-1/2" />
</div>
);
}
export function SkeletonText({
lines = 3,
className,
lastLineWidth = "w-1/2",
}: {
lines?: number;
className?: string;
lastLineWidth?: string;
}) {
return (
<div className={cn("space-y-2", className)}>
{Array.from({ length: lines - 1 }).map((_, index) => (
<Skeleton key={index} className="h-4 w-full" variant="text" />
))}
<Skeleton className={cn("h-4", lastLineWidth)} variant="text" />
</div>
);
}
@@ -55,7 +55,7 @@ export const DataTableFilterCustom = ({
size="md"
startContent={<CustomFilterIcon size={16} />}
onPress={() => setShowFilters(!showFilters)}
className="w-fit"
className="w-full max-w-fit"
>
<h3 className="text-small">
{showFilters ? "Hide Filters" : "Show Filters"}
+138 -70
View File
@@ -23,9 +23,19 @@ import {
interface DataTablePaginationProps {
metadata?: MetaDataProps;
disableScroll?: boolean;
}
export function DataTablePagination({ metadata }: DataTablePaginationProps) {
const baseLinkClass =
"relative block rounded border-0 bg-transparent px-3 py-1.5 text-gray-800 outline-none transition-all duration-300 hover:bg-gray-200 hover:text-gray-800 focus:shadow-none dark:text-prowler-theme-green";
const disabledLinkClass =
"text-gray-300 dark:text-gray-600 hover:bg-transparent hover:text-gray-300 dark:hover:text-gray-600 cursor-default pointer-events-none";
export function DataTablePagination({
metadata,
disableScroll = false,
}: DataTablePaginationProps) {
const pathname = usePathname();
const searchParams = useSearchParams();
const router = useRouter();
@@ -41,90 +51,148 @@ export function DataTablePagination({ metadata }: DataTablePaginationProps) {
const createPageUrl = (pageNumber: number | string) => {
const params = new URLSearchParams(searchParams);
if (pageNumber === "...") return `${pathname}?${params.toString()}`;
// Preserve all important parameters
const scanId = searchParams.get("scanId");
const id = searchParams.get("id");
const version = searchParams.get("version");
if (+pageNumber > totalPages) {
return `${pathname}?${params.toString()}`;
}
params.set("page", pageNumber.toString());
// Ensure that scanId, id and version are preserved
if (scanId) params.set("scanId", scanId);
if (id) params.set("id", id);
if (version) params.set("version", version);
return `${pathname}?${params.toString()}`;
};
const isFirstPage = currentPage === 1;
const isLastPage = currentPage === totalPages;
return (
<div className="flex w-full flex-col-reverse items-center justify-between gap-4 overflow-auto p-1 sm:flex-row sm:gap-8">
<div className="whitespace-nowrap text-sm font-medium">
{totalEntries} entries in Total.
<div className="whitespace-nowrap text-sm">
{totalEntries} entries in total
</div>
<div className="flex flex-col-reverse items-center gap-4 sm:flex-row sm:gap-6 lg:gap-8">
{/* Rows per page selector */}
<div className="flex items-center space-x-2">
<p className="whitespace-nowrap text-sm font-medium">Rows per page</p>
<Select
value={selectedPageSize}
onValueChange={(value) => {
setSelectedPageSize(value);
{totalEntries > 10 && (
<div className="flex flex-col-reverse items-center gap-4 sm:flex-row sm:gap-6 lg:gap-8">
{/* Rows per page selector */}
<div className="flex items-center space-x-2">
<p className="whitespace-nowrap text-sm font-medium">
Rows per page
</p>
<Select
value={selectedPageSize}
onValueChange={(value) => {
setSelectedPageSize(value);
const params = new URLSearchParams(searchParams);
params.set("pageSize", value);
params.set("page", "1");
const params = new URLSearchParams(searchParams);
// This pushes the URL without reloading the page
router.push(`${pathname}?${params.toString()}`);
}}
>
<SelectTrigger className="h-8 w-[4.5rem]">
<SelectValue />
</SelectTrigger>
<SelectContent side="top">
{itemsPerPageOptions.map((pageSize) => (
<SelectItem
key={pageSize}
value={`${pageSize}`}
className="cursor-pointer"
>
{pageSize}
</SelectItem>
))}
</SelectContent>
</Select>
// Preserve all important parameters
const scanId = searchParams.get("scanId");
const id = searchParams.get("id");
const version = searchParams.get("version");
params.set("pageSize", value);
params.set("page", "1");
// Ensure that scanId, id and version are preserved
if (scanId) params.set("scanId", scanId);
if (id) params.set("id", id);
if (version) params.set("version", version);
// This pushes the URL without reloading the page
if (disableScroll) {
const url = `${pathname}?${params.toString()}`;
router.push(url, { scroll: false });
} else {
router.push(`${pathname}?${params.toString()}`);
}
}}
>
<SelectTrigger className="h-8 w-[4.5rem]">
<SelectValue />
</SelectTrigger>
<SelectContent side="top">
{itemsPerPageOptions.map((pageSize) => (
<SelectItem
key={pageSize}
value={`${pageSize}`}
className="cursor-pointer"
>
{pageSize}
</SelectItem>
))}
</SelectContent>
</Select>
</div>
<div className="flex items-center justify-center text-sm font-medium">
Page {currentPage} of {totalPages}
</div>
<div className="flex items-center space-x-2">
<Link
aria-label="Go to first page"
className={`${baseLinkClass} ${isFirstPage ? disabledLinkClass : ""}`}
href={
isFirstPage
? pathname + "?" + searchParams.toString()
: createPageUrl(1)
}
scroll={!disableScroll}
aria-disabled={isFirstPage}
onClick={(e) => isFirstPage && e.preventDefault()}
>
<DoubleArrowLeftIcon className="size-4" aria-hidden="true" />
</Link>
<Link
aria-label="Go to previous page"
className={`${baseLinkClass} ${isFirstPage ? disabledLinkClass : ""}`}
href={
isFirstPage
? pathname + "?" + searchParams.toString()
: createPageUrl(currentPage - 1)
}
scroll={!disableScroll}
aria-disabled={isFirstPage}
onClick={(e) => isFirstPage && e.preventDefault()}
>
<ChevronLeftIcon className="size-4" aria-hidden="true" />
</Link>
<Link
aria-label="Go to next page"
className={`${baseLinkClass} ${isLastPage ? disabledLinkClass : ""}`}
href={
isLastPage
? pathname + "?" + searchParams.toString()
: createPageUrl(currentPage + 1)
}
scroll={!disableScroll}
aria-disabled={isLastPage}
onClick={(e) => isLastPage && e.preventDefault()}
>
<ChevronRightIcon className="size-4" aria-hidden="true" />
</Link>
<Link
aria-label="Go to last page"
className={`${baseLinkClass} ${isLastPage ? disabledLinkClass : ""}`}
href={
isLastPage
? pathname + "?" + searchParams.toString()
: createPageUrl(totalPages)
}
scroll={!disableScroll}
aria-disabled={isLastPage}
onClick={(e) => isLastPage && e.preventDefault()}
>
<DoubleArrowRightIcon className="size-4" aria-hidden="true" />
</Link>
</div>
</div>
<div className="flex items-center justify-center text-sm font-medium">
Page {currentPage} of {totalPages}
</div>
<div className="flex items-center space-x-2">
<Link
aria-label="Go to first page"
className="page-link relative block rounded border-0 bg-transparent px-3 py-1.5 text-gray-800 outline-none transition-all duration-300 hover:bg-gray-200 hover:text-gray-800 focus:shadow-none dark:text-prowler-theme-green"
href={createPageUrl(1)}
aria-disabled="true"
>
<DoubleArrowLeftIcon className="size-4" aria-hidden="true" />
</Link>
<Link
aria-label="Go to previous page"
className="page-link relative block rounded border-0 bg-transparent px-3 py-1.5 text-gray-800 outline-none transition-all duration-300 hover:bg-gray-200 hover:text-gray-800 focus:shadow-none dark:text-prowler-theme-green"
href={createPageUrl(currentPage - 1)}
aria-disabled="true"
>
<ChevronLeftIcon className="size-4" aria-hidden="true" />
</Link>
<Link
aria-label="Go to next page"
className="page-link relative block rounded border-0 bg-transparent px-3 py-1.5 text-gray-800 outline-none transition-all duration-300 hover:bg-gray-200 hover:text-gray-800 focus:shadow-none dark:text-prowler-theme-green"
href={createPageUrl(currentPage + 1)}
>
<ChevronRightIcon className="size-4" aria-hidden="true" />
</Link>
<Link
aria-label="Go to last page"
className="page-link relative block rounded border-0 bg-transparent px-3 py-1.5 text-gray-800 outline-none transition-all duration-300 hover:bg-gray-200 hover:text-gray-800 focus:shadow-none dark:text-prowler-theme-green"
href={createPageUrl(totalPages)}
>
<DoubleArrowRightIcon className="size-4" aria-hidden="true" />
</Link>
</div>
</div>
)}
</div>
);
}
+6 -1
View File
@@ -29,12 +29,14 @@ interface DataTableProviderProps<TData, TValue> {
data: TData[];
metadata?: MetaDataProps;
customFilters?: FilterOption[];
disableScroll?: boolean;
}
export function DataTable<TData, TValue>({
columns,
data,
metadata,
disableScroll = false,
}: DataTableProviderProps<TData, TValue>) {
const [sorting, setSorting] = useState<SortingState>([]);
const [columnFilters, setColumnFilters] = useState<ColumnFiltersState>([]);
@@ -109,7 +111,10 @@ export function DataTable<TData, TValue>({
</div>
{metadata && (
<div className="flex w-full items-center space-x-2 py-4">
<DataTablePagination metadata={metadata} />
<DataTablePagination
metadata={metadata}
disableScroll={disableScroll}
/>
</div>
)}
</>
@@ -16,10 +16,12 @@ const statusColorMap: Record<
export const StatusFindingBadge = ({
status,
size = "sm",
value,
...props
}: {
status: FindingStatus;
size?: "sm" | "md" | "lg";
value?: string | number;
}) => {
const color = statusColorMap[status];
@@ -33,6 +35,7 @@ export const StatusFindingBadge = ({
>
<span className="text-xs font-light tracking-wide text-default-600">
{status.charAt(0).toUpperCase() + status.slice(1).toLowerCase()}
{value !== undefined && `: ${value}`}
</span>
</Chip>
);
@@ -5,7 +5,7 @@ import { useState } from "react";
import { CustomAlertModal, CustomButton } from "@/components/ui/custom";
import { DateWithTime, InfoField } from "@/components/ui/entities";
import { MembershipDetailData } from "@/types/users/users";
import { MembershipDetailData } from "@/types/users";
import { EditTenantForm } from "../forms";
@@ -1,6 +1,6 @@
import { Card, CardBody, CardHeader } from "@nextui-org/react";
import { MembershipDetailData, TenantDetailData } from "@/types/users/users";
import { MembershipDetailData, TenantDetailData } from "@/types/users";
import { MembershipItem } from "./membership-item";
+1 -1
View File
@@ -6,7 +6,7 @@ import { useState } from "react";
import { CustomButton } from "@/components/ui/custom/custom-button";
import { getRolePermissions } from "@/lib/permissions";
import { RoleData, RoleDetail } from "@/types/users/users";
import { RoleData, RoleDetail } from "@/types/users";
interface PermissionItemProps {
enabled: boolean;
+1 -1
View File
@@ -1,6 +1,6 @@
import { Card, CardBody, CardHeader } from "@nextui-org/react";
import { RoleData, RoleDetail } from "@/types/users/users";
import { RoleData, RoleDetail } from "@/types/users";
import { RoleItem } from "./role-item";
@@ -3,7 +3,7 @@
import { Card, CardBody, Divider } from "@nextui-org/react";
import { DateWithTime, InfoField, SnippetChip } from "@/components/ui/entities";
import { UserDataWithRoles } from "@/types/users/users";
import { UserDataWithRoles } from "@/types/users";
import { ProwlerShort } from "../../icons";
+9 -3
View File
@@ -16,8 +16,10 @@ export const useUrlFilters = () => {
(key: string, value: string | string[] | null) => {
const params = new URLSearchParams(searchParams.toString());
// Always reset page to 1 when a filter is applied
params.set("page", "1");
// Only reset page to 1 if page parameter already exists
if (params.has("page")) {
params.set("page", "1");
}
const filterKey = key.startsWith("filter[") ? key : `filter[${key}]`;
@@ -40,7 +42,11 @@ export const useUrlFilters = () => {
const filterKey = key.startsWith("filter[") ? key : `filter[${key}]`;
params.delete(filterKey);
params.set("page", "1");
// Only reset page to 1 if page parameter already exists
if (params.has("page")) {
params.set("page", "1");
}
router.push(`${pathname}?${params.toString()}`, { scroll: false });
},
+221
View File
@@ -0,0 +1,221 @@
import { ClientAccordionContent } from "@/components/compliance/compliance-accordion/client-accordion-content";
import { ComplianceAccordionRequirementTitle } from "@/components/compliance/compliance-accordion/compliance-accordion-requeriment-title";
import { ComplianceAccordionTitle } from "@/components/compliance/compliance-accordion/compliance-accordion-title";
import { AccordionItemProps } from "@/components/ui/accordion/Accordion";
import { FindingStatus } from "@/components/ui/table/status-finding-badge";
import {
AttributesData,
AWSWellArchitectedAttributesMetadata,
Framework,
Requirement,
RequirementItemData,
RequirementsData,
RequirementStatus,
} from "@/types/compliance";
export const mapComplianceData = (
attributesData: AttributesData,
requirementsData: RequirementsData,
): Framework[] => {
const attributes = attributesData?.data || [];
const requirements = requirementsData?.data || [];
// Create a map for quick lookup of requirements by id
const requirementsMap = new Map<string, RequirementItemData>();
requirements.forEach((req: RequirementItemData) => {
requirementsMap.set(req.id, req);
});
const frameworks: Framework[] = [];
// Process attributes and merge with requirements data
for (const attributeItem of attributes) {
const id = attributeItem.id;
const metadataArray = attributeItem.attributes?.attributes
?.metadata as unknown as AWSWellArchitectedAttributesMetadata[];
const attrs = metadataArray?.[0];
if (!attrs) continue;
// Get corresponding requirement data
const requirementData = requirementsMap.get(id);
if (!requirementData) continue;
const frameworkName = attributeItem.attributes.framework;
const sectionName = attrs.Section || "";
const subSectionName = attrs.SubSection || "";
const description = attributeItem.attributes.description;
const status = requirementData.attributes.status || "";
const checks = attributeItem.attributes.attributes.check_ids || [];
const requirementName = id;
if (!sectionName || !subSectionName) {
continue;
}
// Find or create framework
let framework = frameworks.find((f) => f.name === frameworkName);
if (!framework) {
framework = {
name: frameworkName,
pass: 0,
fail: 0,
manual: 0,
categories: [],
};
frameworks.push(framework);
}
// Find or create category (Section)
let category = framework.categories.find((c) => c.name === sectionName);
if (!category) {
category = {
name: sectionName,
pass: 0,
fail: 0,
manual: 0,
controls: [],
};
framework.categories.push(category);
}
// Find or create control (SubSection)
let control = category.controls.find((c) => c.label === subSectionName);
if (!control) {
control = {
label: subSectionName,
pass: 0,
fail: 0,
manual: 0,
requirements: [],
};
category.controls.push(control);
}
// Create requirement
const finalStatus: RequirementStatus = status as RequirementStatus;
const requirement: Requirement = {
name: requirementName,
description: description,
status: finalStatus,
check_ids: checks,
pass: finalStatus === "PASS" ? 1 : 0,
fail: finalStatus === "FAIL" ? 1 : 0,
manual: finalStatus === "MANUAL" ? 1 : 0,
well_architected_name: attrs.Name,
well_architected_question_id: attrs.WellArchitectedQuestionId,
well_architected_practice_id: attrs.WellArchitectedPracticeId,
level_of_risk: attrs.LevelOfRisk,
assessment_method: attrs.AssessmentMethod,
implementation_guidance_url: attrs.ImplementationGuidanceUrl,
};
control.requirements.push(requirement);
}
// Calculate counters
frameworks.forEach((framework) => {
framework.pass = 0;
framework.fail = 0;
framework.manual = 0;
framework.categories.forEach((category) => {
category.pass = 0;
category.fail = 0;
category.manual = 0;
category.controls.forEach((control) => {
control.pass = 0;
control.fail = 0;
control.manual = 0;
control.requirements.forEach((requirement) => {
if (requirement.status === "MANUAL") {
control.manual++;
} else if (requirement.status === "PASS") {
control.pass++;
} else if (requirement.status === "FAIL") {
control.fail++;
}
});
category.pass += control.pass;
category.fail += control.fail;
category.manual += control.manual;
});
framework.pass += category.pass;
framework.fail += category.fail;
framework.manual += category.manual;
});
});
return frameworks;
};
export const toAccordionItems = (
data: Framework[],
scanId: string | undefined,
): AccordionItemProps[] => {
return data.flatMap((framework) =>
framework.categories.map((category) => {
return {
key: `${framework.name}-${category.name}`,
title: (
<ComplianceAccordionTitle
label={category.name}
pass={category.pass}
fail={category.fail}
manual={category.manual}
isParentLevel={true}
/>
),
content: "",
items: category.controls.map((control, i: number) => {
return {
key: `${framework.name}-${category.name}-control-${i}`,
title: (
<ComplianceAccordionTitle
label={control.label}
pass={control.pass}
fail={control.fail}
manual={control.manual}
/>
),
content: "",
items: control.requirements.map((requirement, j: number) => {
const itemKey = `${framework.name}-${category.name}-control-${i}-req-${j}`;
return {
key: itemKey,
title: (
<ComplianceAccordionRequirementTitle
type=""
name={
(requirement.well_architected_name as string) ||
requirement.name
}
status={requirement.status as FindingStatus}
/>
),
content: (
<ClientAccordionContent
requirement={requirement}
scanId={scanId || ""}
framework={framework.name}
disableFindings={
requirement.check_ids.length === 0 &&
requirement.manual === 0
}
/>
),
items: [],
};
}),
isDisabled:
control.pass === 0 && control.fail === 0 && control.manual === 0,
};
}),
};
}),
);
};
+204
View File
@@ -0,0 +1,204 @@
import { ClientAccordionContent } from "@/components/compliance/compliance-accordion/client-accordion-content";
import { ComplianceAccordionRequirementTitle } from "@/components/compliance/compliance-accordion/compliance-accordion-requeriment-title";
import { ComplianceAccordionTitle } from "@/components/compliance/compliance-accordion/compliance-accordion-title";
import { AccordionItemProps } from "@/components/ui/accordion/Accordion";
import { FindingStatus } from "@/components/ui/table/status-finding-badge";
import {
AttributesData,
CISAttributesMetadata,
Framework,
Requirement,
RequirementItemData,
RequirementsData,
RequirementStatus,
} from "@/types/compliance";
export const mapComplianceData = (
attributesData: AttributesData,
requirementsData: RequirementsData,
filter?: string, // "Level 1" or "Level 2" or undefined (show all)
): Framework[] => {
const attributes = attributesData?.data || [];
const requirements = requirementsData?.data || [];
// Create a map for quick lookup of requirements by id
const requirementsMap = new Map<string, RequirementItemData>();
requirements.forEach((req: RequirementItemData) => {
requirementsMap.set(req.id, req);
});
const frameworks: Framework[] = [];
// Process attributes and merge with requirements data
for (const attributeItem of attributes) {
const id = attributeItem.id;
const metadataArray = attributeItem.attributes?.attributes
?.metadata as unknown as CISAttributesMetadata[];
const attrs = metadataArray?.[0];
if (!attrs) continue;
// Apply profile filter
if (filter === "Level 1" && attrs.Profile !== "Level 1") {
continue; // Skip Level 2 requirements when Level 1 is selected
}
// Get corresponding requirement data
const requirementData = requirementsMap.get(id);
if (!requirementData) continue;
const frameworkName = attributeItem.attributes.framework;
const sectionName = attrs.Section;
const description = attributeItem.attributes.description;
const status = requirementData.attributes.status || "";
const checks = attributeItem.attributes.attributes.check_ids || [];
const requirementName = id;
// Find or create framework
let framework = frameworks.find((f) => f.name === frameworkName);
if (!framework) {
framework = {
name: frameworkName,
pass: 0,
fail: 0,
manual: 0,
categories: [],
};
frameworks.push(framework);
}
const normalizedSectionName = sectionName.replace(/^(\d+)\s/, "$1. ");
let category = framework.categories.find(
(c) => c.name === normalizedSectionName,
);
if (!category) {
category = {
name: normalizedSectionName,
pass: 0,
fail: 0,
manual: 0,
controls: [],
};
framework.categories.push(category);
}
// Create a control for this requirement (each requirement is its own control)
const controlLabel = `${id} - ${description}`;
const control = {
label: controlLabel,
pass: 0,
fail: 0,
manual: 0,
requirements: [] as Requirement[],
};
// Create requirement
const finalStatus: RequirementStatus = status as RequirementStatus;
const requirement: Requirement = {
name: requirementName,
description: attrs.Description,
status: finalStatus,
check_ids: checks,
pass: finalStatus === "PASS" ? 1 : 0,
fail: finalStatus === "FAIL" ? 1 : 0,
manual: finalStatus === "MANUAL" ? 1 : 0,
profile: attrs.Profile,
subsection: attrs.SubSection || "",
assessment_status: attrs.AssessmentStatus,
rationale_statement: attrs.RationaleStatement,
impact_statement: attrs.ImpactStatement,
remediation_procedure: attrs.RemediationProcedure,
audit_procedure: attrs.AuditProcedure,
additional_information: attrs.AdditionalInformation,
default_value: attrs.DefaultValue || "",
references: attrs.References,
};
control.requirements.push(requirement);
// Update control counters
if (requirement.status === "MANUAL") {
control.manual++;
} else if (requirement.status === "PASS") {
control.pass++;
} else if (requirement.status === "FAIL") {
control.fail++;
}
category.controls.push(control);
}
// Calculate counters for categories and frameworks
frameworks.forEach((framework) => {
framework.pass = 0;
framework.fail = 0;
framework.manual = 0;
framework.categories.forEach((category) => {
category.pass = 0;
category.fail = 0;
category.manual = 0;
category.controls.forEach((control) => {
category.pass += control.pass;
category.fail += control.fail;
category.manual += control.manual;
});
framework.pass += category.pass;
framework.fail += category.fail;
framework.manual += category.manual;
});
});
return frameworks;
};
export const toAccordionItems = (
data: Framework[],
scanId: string | undefined,
): AccordionItemProps[] => {
return data.flatMap((framework) =>
framework.categories.map((category) => {
return {
key: `${framework.name}-${category.name}`,
title: (
<ComplianceAccordionTitle
label={category.name}
pass={category.pass}
fail={category.fail}
manual={category.manual}
isParentLevel={true}
/>
),
content: "",
items: category.controls.map((control, i: number) => {
const requirement = control.requirements[0]; // Each control has one requirement
const itemKey = `${framework.name}-${category.name}-control-${i}`;
return {
key: itemKey,
title: (
<ComplianceAccordionRequirementTitle
type=""
name={control.label}
status={requirement.status as FindingStatus}
/>
),
content: (
<ClientAccordionContent
requirement={requirement}
scanId={scanId || ""}
framework={framework.name}
disableFindings={
requirement.check_ids.length === 0 && requirement.manual === 0
}
/>
),
items: [],
};
}),
};
}),
);
};
+196
View File
@@ -0,0 +1,196 @@
import React from "react";
import { AWSWellArchitectedCustomDetails } from "@/components/compliance/compliance-custom-details/aws-well-architected-details";
import { CISCustomDetails } from "@/components/compliance/compliance-custom-details/cis-details";
import { ENSCustomDetails } from "@/components/compliance/compliance-custom-details/ens-details";
import { ISOCustomDetails } from "@/components/compliance/compliance-custom-details/iso-details";
import { AccordionItemProps } from "@/components/ui/accordion/Accordion";
import {
AttributesData,
CategoryData,
FailedSection,
Framework,
Requirement,
RequirementsData,
} from "@/types/compliance";
import {
mapComplianceData as mapAWSWellArchitectedComplianceData,
toAccordionItems as toAWSWellArchitectedAccordionItems,
} from "./aws-well-architected";
import {
mapComplianceData as mapCISComplianceData,
toAccordionItems as toCISAccordionItems,
} from "./cis";
import {
mapComplianceData as mapENSComplianceData,
toAccordionItems as toENSAccordionItems,
} from "./ens";
import {
mapComplianceData as mapISOComplianceData,
toAccordionItems as toISOAccordionItems,
} from "./iso";
export interface ComplianceMapper {
mapComplianceData: (
attributesData: AttributesData,
requirementsData: RequirementsData,
filter?: string,
) => Framework[];
toAccordionItems: (
data: Framework[],
scanId: string | undefined,
) => AccordionItemProps[];
getTopFailedSections: (mappedData: Framework[]) => FailedSection[];
getDetailsComponent: (requirement: Requirement) => React.ReactNode;
}
// Common function for getting top failed sections
export const getTopFailedSections = (
mappedData: Framework[],
): FailedSection[] => {
const failedSectionMap = new Map();
mappedData.forEach((framework) => {
framework.categories.forEach((category) => {
category.controls.forEach((control) => {
control.requirements.forEach((requirement) => {
if (requirement.status === "FAIL") {
const sectionName = category.name;
if (!failedSectionMap.has(sectionName)) {
failedSectionMap.set(sectionName, { total: 0, types: {} });
}
const sectionData = failedSectionMap.get(sectionName);
sectionData.total += 1;
const type = requirement.type || "Fails";
sectionData.types[type as string] =
(sectionData.types[type as string] || 0) + 1;
}
});
});
});
});
// Convert in descending order and slice top 5
return Array.from(failedSectionMap.entries())
.map(([name, data]) => ({ name, ...data }))
.sort((a, b) => b.total - a.total)
.slice(0, 5); // Top 5
};
// Registry of compliance mappers
const complianceMappers: Record<string, ComplianceMapper> = {
ENS: {
mapComplianceData: mapENSComplianceData,
toAccordionItems: toENSAccordionItems,
getTopFailedSections,
getDetailsComponent: (requirement: Requirement) =>
React.createElement(ENSCustomDetails, { requirement }),
},
ISO27001: {
mapComplianceData: mapISOComplianceData,
toAccordionItems: toISOAccordionItems,
getTopFailedSections,
getDetailsComponent: (requirement: Requirement) =>
React.createElement(ISOCustomDetails, { requirement }),
},
CIS: {
mapComplianceData: mapCISComplianceData,
toAccordionItems: toCISAccordionItems,
getTopFailedSections,
getDetailsComponent: (requirement: Requirement) =>
React.createElement(CISCustomDetails, { requirement }),
},
"AWS-Well-Architected-Framework-Security-Pillar": {
mapComplianceData: mapAWSWellArchitectedComplianceData,
toAccordionItems: toAWSWellArchitectedAccordionItems,
getTopFailedSections,
getDetailsComponent: (requirement: Requirement) =>
React.createElement(AWSWellArchitectedCustomDetails, { requirement }),
},
"AWS-Well-Architected-Framework-Reliability-Pillar": {
mapComplianceData: mapAWSWellArchitectedComplianceData,
toAccordionItems: toAWSWellArchitectedAccordionItems,
getTopFailedSections,
getDetailsComponent: (requirement: Requirement) =>
React.createElement(AWSWellArchitectedCustomDetails, { requirement }),
},
};
// Default mapper (fallback to ENS for backward compatibility)
const defaultMapper: ComplianceMapper = complianceMappers.ENS;
/**
* Get the appropriate compliance mapper based on the framework name
* @param framework - The framework name (e.g., "ENS", "ISO27001", "CIS")
* @returns ComplianceMapper object with specific functions for the framework
*/
export const getComplianceMapper = (framework?: string): ComplianceMapper => {
if (!framework) {
return defaultMapper;
}
return complianceMappers[framework] || defaultMapper;
};
export const calculateCategoryHeatmapData = (
complianceData: Framework[],
): CategoryData[] => {
if (!complianceData?.length) {
return [];
}
try {
const categoryMap = new Map<
string,
{ pass: number; fail: number; manual: number }
>();
// Aggregate data by category
complianceData.forEach((framework) => {
framework.categories.forEach((category) => {
const existing = categoryMap.get(category.name) || {
pass: 0,
fail: 0,
manual: 0,
};
categoryMap.set(category.name, {
pass: existing.pass + category.pass,
fail: existing.fail + category.fail,
manual: existing.manual + category.manual,
});
});
});
const categoryData: CategoryData[] = Array.from(categoryMap.entries()).map(
([name, stats]) => {
const totalRequirements = stats.pass + stats.fail + stats.manual;
const failurePercentage =
totalRequirements > 0
? Math.round((stats.fail / totalRequirements) * 100)
: 0;
return {
name,
failurePercentage,
totalRequirements,
failedRequirements: stats.fail,
};
},
);
const filteredData = categoryData
.filter((category) => category.totalRequirements > 0)
.sort((a, b) => b.failurePercentage - a.failurePercentage)
.slice(0, 9); // Show top 9 categories
return filteredData;
} catch (error) {
console.error("Error calculating category heatmap data:", error);
return [];
}
};
+249
View File
@@ -0,0 +1,249 @@
import { ClientAccordionContent } from "@/components/compliance/compliance-accordion/client-accordion-content";
import { ComplianceAccordionRequirementTitle } from "@/components/compliance/compliance-accordion/compliance-accordion-requeriment-title";
import { ComplianceAccordionTitle } from "@/components/compliance/compliance-accordion/compliance-accordion-title";
import { AccordionItemProps } from "@/components/ui/accordion/Accordion";
import { FindingStatus } from "@/components/ui/table/status-finding-badge";
import {
AttributesData,
ENSAttributesMetadata,
Framework,
Requirement,
RequirementItemData,
RequirementsData,
RequirementStatus,
} from "@/types/compliance";
export const translateType = (type: string) => {
if (!type) {
return "";
}
switch (type.toLowerCase()) {
case "requisito":
return "Requirement";
case "recomendacion":
return "Recommendation";
case "refuerzo":
return "Reinforcement";
case "medida":
return "Measure";
default:
return type;
}
};
export const mapComplianceData = (
attributesData: AttributesData,
requirementsData: RequirementsData,
): Framework[] => {
const attributes = attributesData?.data || [];
const requirements = requirementsData?.data || [];
// Create a map for quick lookup of requirements by id
const requirementsMap = new Map<string, RequirementItemData>();
requirements.forEach((req: RequirementItemData) => {
requirementsMap.set(req.id, req);
});
const frameworks: Framework[] = [];
// Process attributes and merge with requirements data
for (const attributeItem of attributes) {
const id = attributeItem.id;
const attrs = attributeItem.attributes?.attributes
?.metadata?.[0] as ENSAttributesMetadata;
if (!attrs) continue;
// Get corresponding requirement data
const requirementData = requirementsMap.get(id);
if (!requirementData) continue;
const frameworkName = attrs.Marco;
const categoryName = attrs.Categoria;
const groupControl = attrs.IdGrupoControl;
const type = attrs.Tipo;
const description = attributeItem.attributes.description;
const status = requirementData.attributes.status || "";
const controlDescription = attrs.DescripcionControl || "";
const checks = attributeItem.attributes.attributes.check_ids || [];
const isManual = attrs.ModoEjecucion === "manual";
const requirementName = id;
const groupControlLabel = `${groupControl} - ${description}`;
// Find or create framework
let framework = frameworks.find((f) => f.name === frameworkName);
if (!framework) {
framework = {
name: frameworkName,
pass: 0,
fail: 0,
manual: 0,
categories: [],
};
frameworks.push(framework);
}
// Find or create category
let category = framework.categories.find((c) => c.name === categoryName);
if (!category) {
category = {
name: categoryName,
pass: 0,
fail: 0,
manual: 0,
controls: [],
};
framework.categories.push(category);
}
// Find or create control
let control = category.controls.find((c) => c.label === groupControlLabel);
if (!control) {
control = {
label: groupControlLabel,
pass: 0,
fail: 0,
manual: 0,
requirements: [],
};
category.controls.push(control);
}
// Create requirement
const finalStatus: RequirementStatus = isManual
? "MANUAL"
: (status as RequirementStatus);
const requirement: Requirement = {
name: requirementName,
description: controlDescription,
status: finalStatus,
type,
check_ids: checks,
pass: finalStatus === "PASS" ? 1 : 0,
fail: finalStatus === "FAIL" ? 1 : 0,
manual: finalStatus === "MANUAL" ? 1 : 0,
nivel: attrs.Nivel || "",
dimensiones: attrs.Dimensiones || [],
};
control.requirements.push(requirement);
}
// Calculate counters
frameworks.forEach((framework) => {
framework.pass = 0;
framework.fail = 0;
framework.manual = 0;
framework.categories.forEach((category) => {
category.pass = 0;
category.fail = 0;
category.manual = 0;
category.controls.forEach((control) => {
control.pass = 0;
control.fail = 0;
control.manual = 0;
control.requirements.forEach((requirement) => {
if (requirement.status === "MANUAL") {
control.manual++;
} else if (requirement.status === "PASS") {
control.pass++;
} else if (requirement.status === "FAIL") {
control.fail++;
}
});
category.pass += control.pass;
category.fail += control.fail;
category.manual += control.manual;
});
framework.pass += category.pass;
framework.fail += category.fail;
framework.manual += category.manual;
});
});
return frameworks;
};
export const toAccordionItems = (
data: Framework[],
scanId: string | undefined,
): AccordionItemProps[] => {
return data.map((framework) => {
return {
key: framework.name,
title: (
<ComplianceAccordionTitle
label={framework.name}
pass={framework.pass}
fail={framework.fail}
manual={framework.manual}
isParentLevel={true}
/>
),
content: "",
items: framework.categories.map((category) => {
return {
key: `${framework.name}-${category.name}`,
title: (
<ComplianceAccordionTitle
label={category.name}
pass={category.pass}
fail={category.fail}
manual={category.manual}
/>
),
content: "",
items: category.controls.map((control, i: number) => {
return {
key: `${framework.name}-${category.name}-control-${i}`,
title: (
<ComplianceAccordionTitle
label={control.label}
pass={control.pass}
fail={control.fail}
manual={control.manual}
/>
),
content: "",
items: control.requirements.map((requirement, j: number) => {
const itemKey = `${framework.name}-${category.name}-control-${i}-req-${j}`;
return {
key: itemKey,
title: (
<ComplianceAccordionRequirementTitle
type={requirement.type as string}
name={requirement.name}
status={requirement.status as FindingStatus}
/>
),
content: (
<ClientAccordionContent
requirement={requirement}
scanId={scanId || ""}
framework={framework.name}
disableFindings={
requirement.check_ids.length === 0 &&
requirement.manual === 0
}
/>
),
};
}),
isDisabled:
control.pass === 0 &&
control.fail === 0 &&
control.manual === 0,
};
}),
};
}),
};
});
};
+212
View File
@@ -0,0 +1,212 @@
import { ClientAccordionContent } from "@/components/compliance/compliance-accordion/client-accordion-content";
import { ComplianceAccordionRequirementTitle } from "@/components/compliance/compliance-accordion/compliance-accordion-requeriment-title";
import { ComplianceAccordionTitle } from "@/components/compliance/compliance-accordion/compliance-accordion-title";
import { AccordionItemProps } from "@/components/ui/accordion/Accordion";
import { FindingStatus } from "@/components/ui/table/status-finding-badge";
import {
AttributesData,
Framework,
ISO27001AttributesMetadata,
Requirement,
RequirementItemData,
RequirementsData,
RequirementStatus,
} from "@/types/compliance";
export const mapComplianceData = (
attributesData: AttributesData,
requirementsData: RequirementsData,
): Framework[] => {
const attributes = attributesData?.data || [];
const requirements = requirementsData?.data || [];
// Create a map for quick lookup of requirements by id
const requirementsMap = new Map<string, RequirementItemData>();
requirements.forEach((req: RequirementItemData) => {
requirementsMap.set(req.id, req);
});
const frameworks: Framework[] = [];
// Process attributes and merge with requirements data
for (const attributeItem of attributes) {
const id = attributeItem.id;
const metadataArray = attributeItem.attributes?.attributes
?.metadata as unknown as ISO27001AttributesMetadata[];
const attrs = metadataArray?.[0];
if (!attrs) continue;
// Get corresponding requirement data
const requirementData = requirementsMap.get(id);
if (!requirementData) continue;
const frameworkName = attributeItem.attributes.framework;
const categoryName = attrs.Category;
const controlLabel = `${attrs.Objetive_ID} - ${attrs.Objetive_Name}`;
const description = attributeItem.attributes.description;
const status = requirementData.attributes.status || "";
const checks = attributeItem.attributes.attributes.check_ids || [];
const requirementName = id;
const objetiveName = attrs.Objetive_Name;
const checkSummary = attrs.Check_Summary;
// Find or create framework
let framework = frameworks.find((f) => f.name === frameworkName);
if (!framework) {
framework = {
name: frameworkName,
pass: 0,
fail: 0,
manual: 0,
categories: [],
};
frameworks.push(framework);
}
// Find or create category
let category = framework.categories.find((c) => c.name === categoryName);
if (!category) {
category = {
name: categoryName,
pass: 0,
fail: 0,
manual: 0,
controls: [],
};
framework.categories.push(category);
}
// Find or create control
let control = category.controls.find((c) => c.label === controlLabel);
if (!control) {
control = {
label: controlLabel,
pass: 0,
fail: 0,
manual: 0,
requirements: [],
};
category.controls.push(control);
}
// Create requirement
const finalStatus: RequirementStatus = status as RequirementStatus;
const requirement: Requirement = {
name: requirementName,
description: description,
status: finalStatus,
check_ids: checks,
pass: finalStatus === "PASS" ? 1 : 0,
fail: finalStatus === "FAIL" ? 1 : 0,
manual: finalStatus === "MANUAL" ? 1 : 0,
objetive_name: objetiveName,
check_summary: checkSummary,
};
control.requirements.push(requirement);
}
// Calculate counters
frameworks.forEach((framework) => {
framework.pass = 0;
framework.fail = 0;
framework.manual = 0;
framework.categories.forEach((category) => {
category.pass = 0;
category.fail = 0;
category.manual = 0;
category.controls.forEach((control) => {
control.pass = 0;
control.fail = 0;
control.manual = 0;
control.requirements.forEach((requirement) => {
if (requirement.status === "MANUAL") {
control.manual++;
} else if (requirement.status === "PASS") {
control.pass++;
} else if (requirement.status === "FAIL") {
control.fail++;
}
});
category.pass += control.pass;
category.fail += control.fail;
category.manual += control.manual;
});
framework.pass += category.pass;
framework.fail += category.fail;
framework.manual += category.manual;
});
});
return frameworks;
};
export const toAccordionItems = (
data: Framework[],
scanId: string | undefined,
): AccordionItemProps[] => {
return data.flatMap((framework) =>
framework.categories.map((category) => {
return {
key: `${framework.name}-${category.name}`,
title: (
<ComplianceAccordionTitle
label={category.name}
pass={category.pass}
fail={category.fail}
manual={category.manual}
isParentLevel={true}
/>
),
content: "",
items: category.controls.map((control, i: number) => {
return {
key: `${framework.name}-${category.name}-control-${i}`,
title: (
<ComplianceAccordionTitle
label={control.label}
pass={control.pass}
fail={control.fail}
manual={control.manual}
/>
),
content: "",
items: control.requirements.map((requirement, j: number) => {
const itemKey = `${framework.name}-${category.name}-control-${i}-req-${j}`;
return {
key: itemKey,
title: (
<ComplianceAccordionRequirementTitle
type=""
name={requirement.name}
status={requirement.status as FindingStatus}
/>
),
content: (
<ClientAccordionContent
requirement={requirement}
scanId={scanId || ""}
framework={framework.name}
disableFindings={
requirement.check_ids.length === 0 &&
requirement.manual === 0
}
/>
),
items: [],
};
}),
isDisabled:
control.pass === 0 && control.fail === 0 && control.manual === 0,
};
}),
};
}),
);
};
+54
View File
@@ -1,3 +1,6 @@
import { ProviderProps, ProvidersApiResponse, ScanProps } from "@/types";
import { ScanEntity } from "@/types/scans";
/**
* Extracts normalized filters and search query from the URL search params.
* Used Server Side Rendering (SSR). There is a hook (useUrlFilters) for client side.
@@ -44,3 +47,54 @@ export const extractSortAndKey = (searchParams: Record<string, unknown>) => {
return { searchParamsKey, rawSort, encodedSort };
};
export const isScanEntity = (entity: ScanEntity) => {
return entity && entity.providerInfo && entity.attributes;
};
/**
* Creates a scan details mapping for filters from completed scans.
* Used to provide detailed information for scan filters in the UI.
*/
export const createScanDetailsMapping = (
completedScans: ScanProps[],
providersData?: ProvidersApiResponse,
) => {
if (!completedScans || completedScans.length === 0) {
return [];
}
const scanMappings = completedScans.map((scan: ScanProps) => {
// Get provider info from providerInfo if available, or find from providers data
let providerInfo = scan.providerInfo;
if (!providerInfo && scan.relationships?.provider?.data?.id) {
const provider = providersData?.data?.find(
(p: ProviderProps) => p.id === scan.relationships.provider.data.id,
);
if (provider) {
providerInfo = {
provider: provider.attributes.provider,
alias: provider.attributes.alias,
uid: provider.attributes.uid,
};
}
}
return {
[scan.id]: {
providerInfo: {
provider: providerInfo?.provider || "aws",
alias: providerInfo?.alias,
uid: providerInfo?.uid,
},
attributes: {
name: scan.attributes.name,
completed_at: scan.attributes.completed_at,
},
},
};
});
return scanMappings;
};
+1 -1
View File
@@ -1,4 +1,4 @@
import { RolePermissionAttributes } from "@/types/users/users";
import { RolePermissionAttributes } from "@/types/users";
export const isUserOwnerAndHasManageAccount = (
roles: any[],
+2 -2
View File
@@ -1,5 +1,5 @@
import {
ProviderAccountProps,
ProviderEntity,
ProviderProps,
ProvidersApiResponse,
} from "@/types/providers";
@@ -21,7 +21,7 @@ export const extractProviderUIDs = (
export const createProviderDetailsMapping = (
providerUIDs: string[],
providersData: ProvidersApiResponse,
): Array<{ [uid: string]: ProviderAccountProps }> => {
): Array<{ [uid: string]: ProviderEntity }> => {
if (!providersData?.data) return [];
return providerUIDs.map((uid) => {
+1263 -51
View File
File diff suppressed because it is too large Load Diff
+3 -1
View File
@@ -14,6 +14,7 @@
"@radix-ui/react-toast": "^1.2.4",
"@react-aria/ssr": "3.9.4",
"@react-aria/visually-hidden": "3.8.12",
"@tailwindcss/typography": "^0.5.16",
"@tanstack/react-table": "^8.19.3",
"add": "^2.0.6",
"alert": "^6.0.2",
@@ -28,13 +29,14 @@
"jose": "^5.9.3",
"jwt-decode": "^4.0.0",
"lucide-react": "^0.471.0",
"next": "^14.2.26",
"next": "14.2.29",
"next-auth": "^5.0.0-beta.25",
"next-themes": "^0.2.1",
"radix-ui": "^1.1.3",
"react": "^18.3.1",
"react-dom": "^18.3.1",
"react-hook-form": "^7.52.2",
"react-markdown": "^10.1.0",
"recharts": "^2.15.2",
"server-only": "^0.0.1",
"shadcn-ui": "^0.2.3",
BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 95 KiB

Some files were not shown because too many files have changed in this diff Show More