mirror of
https://github.com/prowler-cloud/prowler.git
synced 2026-03-26 05:48:03 +00:00
Compare commits
47 Commits
PRWLR-4669
...
PRWLR-4669
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2d190eb020 | ||
|
|
0459a4d6f6 | ||
|
|
19df649554 | ||
|
|
737550eb05 | ||
|
|
68d7d9f998 | ||
|
|
fa400ded7d | ||
|
|
ec9455ff75 | ||
|
|
2183f31ff5 | ||
|
|
67257a4212 | ||
|
|
001fa60a11 | ||
|
|
3c9a8b3634 | ||
|
|
0ec3ed8be7 | ||
|
|
3ed0b8a464 | ||
|
|
fd610d44c0 | ||
|
|
b8cc4b4f0f | ||
|
|
396e51c27d | ||
|
|
36e61cb7a2 | ||
|
|
78c6484ddb | ||
|
|
3f1e90a5b3 | ||
|
|
e1bfec898f | ||
|
|
b5b816dac9 | ||
|
|
81f970f2d3 | ||
|
|
3d9cd177a2 | ||
|
|
57854f23b7 | ||
|
|
9d7499b74f | ||
|
|
c49fdc114a | ||
|
|
95fd9d6b5e | ||
|
|
6a5bc75252 | ||
|
|
858c04b0b0 | ||
|
|
2d6f20e84b | ||
|
|
b0a98b1a87 | ||
|
|
577530ac69 | ||
|
|
c1a8d47e5b | ||
|
|
e80704d6f0 | ||
|
|
010de4b415 | ||
|
|
0a2b8e4315 | ||
|
|
5b0b85c0f8 | ||
|
|
f7e8df618b | ||
|
|
d00d254c90 | ||
|
|
f9fbde6637 | ||
|
|
7b1a0474db | ||
|
|
da4f9b8e5f | ||
|
|
32f69d24b6 | ||
|
|
d032a61a9e | ||
|
|
07e0dc2ef5 | ||
|
|
9e175e8504 | ||
|
|
6b8a434cda |
21
.github/dependabot.yml
vendored
21
.github/dependabot.yml
vendored
@@ -36,6 +36,16 @@ updates:
|
||||
- "dependencies"
|
||||
- "npm"
|
||||
|
||||
- package-ecosystem: "docker"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
open-pull-requests-limit: 10
|
||||
target-branch: master
|
||||
labels:
|
||||
- "dependencies"
|
||||
- "docker"
|
||||
|
||||
# v4.6
|
||||
- package-ecosystem: "pip"
|
||||
directory: "/"
|
||||
@@ -59,6 +69,17 @@ updates:
|
||||
- "github_actions"
|
||||
- "v4"
|
||||
|
||||
- package-ecosystem: "docker"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
open-pull-requests-limit: 10
|
||||
target-branch: v4.6
|
||||
labels:
|
||||
- "dependencies"
|
||||
- "docker"
|
||||
- "v4"
|
||||
|
||||
# v3
|
||||
- package-ecosystem: "pip"
|
||||
directory: "/"
|
||||
|
||||
5
.github/labeler.yml
vendored
5
.github/labeler.yml
vendored
@@ -22,6 +22,11 @@ provider/kubernetes:
|
||||
- any-glob-to-any-file: "prowler/providers/kubernetes/**"
|
||||
- any-glob-to-any-file: "tests/providers/kubernetes/**"
|
||||
|
||||
provider/github:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: "prowler/providers/github/**"
|
||||
- any-glob-to-any-file: "tests/providers/github/**"
|
||||
|
||||
github_actions:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: ".github/workflows/*"
|
||||
|
||||
2
.github/workflows/ui-pull-request.yml
vendored
2
.github/workflows/ui-pull-request.yml
vendored
@@ -20,7 +20,7 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Setup Node.js ${{ matrix.node-version }}
|
||||
uses: actions/setup-node@v3
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: ${{ matrix.node-version }}
|
||||
- name: Install dependencies
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM python:3.12-alpine AS build
|
||||
FROM python:3.12.8-alpine3.20 AS build
|
||||
|
||||
LABEL maintainer="https://github.com/prowler-cloud/api"
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ description = "Prowler's API (Django/DRF)"
|
||||
license = "Apache-2.0"
|
||||
name = "prowler-api"
|
||||
package-mode = false
|
||||
version = "1.0.0"
|
||||
version = "1.1.0"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
celery = {extras = ["pytest"], version = "^5.4.0"}
|
||||
|
||||
@@ -1,20 +1,23 @@
|
||||
import uuid
|
||||
|
||||
from django.db import connection, transaction
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from django.db import transaction
|
||||
from rest_framework import permissions
|
||||
from rest_framework.exceptions import NotAuthenticated
|
||||
from rest_framework.filters import SearchFilter
|
||||
from rest_framework_json_api import filters
|
||||
from rest_framework_json_api.serializers import ValidationError
|
||||
from rest_framework_json_api.views import ModelViewSet
|
||||
from rest_framework_simplejwt.authentication import JWTAuthentication
|
||||
|
||||
from api.db_router import MainRouter
|
||||
from api.db_utils import POSTGRES_USER_VAR, rls_transaction
|
||||
from api.filters import CustomDjangoFilterBackend
|
||||
from api.models import Role, Tenant
|
||||
from api.rbac.permissions import HasPermissions
|
||||
|
||||
|
||||
class BaseViewSet(ModelViewSet):
|
||||
authentication_classes = [JWTAuthentication]
|
||||
permission_classes = [permissions.IsAuthenticated]
|
||||
required_permissions = []
|
||||
permission_classes = [permissions.IsAuthenticated, HasPermissions]
|
||||
filter_backends = [
|
||||
filters.QueryParameterValidationFilter,
|
||||
filters.OrderingFilter,
|
||||
@@ -28,6 +31,17 @@ class BaseViewSet(ModelViewSet):
|
||||
ordering_fields = "__all__"
|
||||
ordering = ["id"]
|
||||
|
||||
def initial(self, request, *args, **kwargs):
|
||||
"""
|
||||
Sets required_permissions before permissions are checked.
|
||||
"""
|
||||
self.set_required_permissions()
|
||||
super().initial(request, *args, **kwargs)
|
||||
|
||||
def set_required_permissions(self):
|
||||
"""This is an abstract method that must be implemented by subclasses."""
|
||||
NotImplemented
|
||||
|
||||
def get_queryset(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -47,13 +61,7 @@ class BaseRLSViewSet(BaseViewSet):
|
||||
if tenant_id is None:
|
||||
raise NotAuthenticated("Tenant ID is not present in token")
|
||||
|
||||
try:
|
||||
uuid.UUID(tenant_id)
|
||||
except ValueError:
|
||||
raise ValidationError("Tenant ID must be a valid UUID")
|
||||
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
|
||||
with rls_transaction(tenant_id):
|
||||
self.request.tenant_id = tenant_id
|
||||
return super().initial(request, *args, **kwargs)
|
||||
|
||||
@@ -66,7 +74,39 @@ class BaseRLSViewSet(BaseViewSet):
|
||||
class BaseTenantViewset(BaseViewSet):
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
with transaction.atomic():
|
||||
return super().dispatch(request, *args, **kwargs)
|
||||
tenant = super().dispatch(request, *args, **kwargs)
|
||||
|
||||
try:
|
||||
# If the request is a POST, create the admin role
|
||||
if request.method == "POST":
|
||||
isinstance(tenant, dict) and self._create_admin_role(tenant.data["id"])
|
||||
except Exception as e:
|
||||
self._handle_creation_error(e, tenant)
|
||||
raise
|
||||
|
||||
return tenant
|
||||
|
||||
def _create_admin_role(self, tenant_id):
|
||||
Role.objects.using(MainRouter.admin_db).create(
|
||||
name="admin",
|
||||
tenant_id=tenant_id,
|
||||
manage_users=True,
|
||||
manage_account=True,
|
||||
manage_billing=True,
|
||||
manage_providers=True,
|
||||
manage_integrations=True,
|
||||
manage_scans=True,
|
||||
unlimited_visibility=True,
|
||||
)
|
||||
|
||||
def _handle_creation_error(self, error, tenant):
|
||||
if tenant.data.get("id"):
|
||||
try:
|
||||
Tenant.objects.using(MainRouter.admin_db).filter(
|
||||
id=tenant.data["id"]
|
||||
).delete()
|
||||
except ObjectDoesNotExist:
|
||||
pass # Tenant might not exist, handle gracefully
|
||||
|
||||
def initial(self, request, *args, **kwargs):
|
||||
if (
|
||||
@@ -75,8 +115,7 @@ class BaseTenantViewset(BaseViewSet):
|
||||
):
|
||||
user_id = str(request.user.id)
|
||||
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(f"SELECT set_config('api.user_id', '{user_id}', TRUE);")
|
||||
with rls_transaction(value=user_id, parameter=POSTGRES_USER_VAR):
|
||||
return super().initial(request, *args, **kwargs)
|
||||
|
||||
# TODO: DRY this when we have time
|
||||
@@ -87,13 +126,7 @@ class BaseTenantViewset(BaseViewSet):
|
||||
if tenant_id is None:
|
||||
raise NotAuthenticated("Tenant ID is not present in token")
|
||||
|
||||
try:
|
||||
uuid.UUID(tenant_id)
|
||||
except ValueError:
|
||||
raise ValidationError("Tenant ID must be a valid UUID")
|
||||
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
|
||||
with rls_transaction(tenant_id):
|
||||
self.request.tenant_id = tenant_id
|
||||
return super().initial(request, *args, **kwargs)
|
||||
|
||||
@@ -114,12 +147,6 @@ class BaseUserViewset(BaseViewSet):
|
||||
if tenant_id is None:
|
||||
raise NotAuthenticated("Tenant ID is not present in token")
|
||||
|
||||
try:
|
||||
uuid.UUID(tenant_id)
|
||||
except ValueError:
|
||||
raise ValidationError("Tenant ID must be a valid UUID")
|
||||
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
|
||||
with rls_transaction(tenant_id):
|
||||
self.request.tenant_id = tenant_id
|
||||
return super().initial(request, *args, **kwargs)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import secrets
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
@@ -8,6 +9,7 @@ from django.core.paginator import Paginator
|
||||
from django.db import connection, models, transaction
|
||||
from psycopg2 import connect as psycopg2_connect
|
||||
from psycopg2.extensions import AsIs, new_type, register_adapter, register_type
|
||||
from rest_framework_json_api.serializers import ValidationError
|
||||
|
||||
DB_USER = settings.DATABASES["default"]["USER"] if not settings.TESTING else "test"
|
||||
DB_PASSWORD = (
|
||||
@@ -23,6 +25,8 @@ TASK_RUNNER_DB_TABLE = "django_celery_results_taskresult"
|
||||
POSTGRES_TENANT_VAR = "api.tenant_id"
|
||||
POSTGRES_USER_VAR = "api.user_id"
|
||||
|
||||
SET_CONFIG_QUERY = "SELECT set_config(%s, %s::text, TRUE);"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def psycopg_connection(database_alias: str):
|
||||
@@ -44,10 +48,23 @@ def psycopg_connection(database_alias: str):
|
||||
|
||||
|
||||
@contextmanager
|
||||
def tenant_transaction(tenant_id: str):
|
||||
def rls_transaction(value: str, parameter: str = POSTGRES_TENANT_VAR):
|
||||
"""
|
||||
Creates a new database transaction setting the given configuration value for Postgres RLS. It validates the
|
||||
if the value is a valid UUID.
|
||||
|
||||
Args:
|
||||
value (str): Database configuration parameter value.
|
||||
parameter (str): Database configuration parameter name, by default is 'api.tenant_id'.
|
||||
"""
|
||||
with transaction.atomic():
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
|
||||
try:
|
||||
# just in case the value is an UUID object
|
||||
uuid.UUID(str(value))
|
||||
except ValueError:
|
||||
raise ValidationError("Must be a valid UUID")
|
||||
cursor.execute(SET_CONFIG_QUERY, [parameter, value])
|
||||
yield cursor
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import uuid
|
||||
from functools import wraps
|
||||
|
||||
from django.db import connection, transaction
|
||||
from rest_framework_json_api.serializers import ValidationError
|
||||
|
||||
from api.db_utils import POSTGRES_TENANT_VAR, SET_CONFIG_QUERY
|
||||
|
||||
|
||||
def set_tenant(func):
|
||||
@@ -31,7 +35,7 @@ def set_tenant(func):
|
||||
pass
|
||||
|
||||
# When calling the task
|
||||
some_task.delay(arg1, tenant_id="1234-abcd-5678")
|
||||
some_task.delay(arg1, tenant_id="8db7ca86-03cc-4d42-99f6-5e480baf6ab5")
|
||||
|
||||
# The tenant context will be set before the task logic executes.
|
||||
"""
|
||||
@@ -43,9 +47,12 @@ def set_tenant(func):
|
||||
tenant_id = kwargs.pop("tenant_id")
|
||||
except KeyError:
|
||||
raise KeyError("This task requires the tenant_id")
|
||||
|
||||
try:
|
||||
uuid.UUID(tenant_id)
|
||||
except ValueError:
|
||||
raise ValidationError("Tenant ID must be a valid UUID")
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
|
||||
cursor.execute(SET_CONFIG_QUERY, [POSTGRES_TENANT_VAR, tenant_id])
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@@ -26,11 +26,13 @@ from api.models import (
|
||||
Finding,
|
||||
Invitation,
|
||||
Membership,
|
||||
PermissionChoices,
|
||||
Provider,
|
||||
ProviderGroup,
|
||||
ProviderSecret,
|
||||
Resource,
|
||||
ResourceTag,
|
||||
Role,
|
||||
Scan,
|
||||
ScanSummary,
|
||||
SeverityChoices,
|
||||
@@ -481,6 +483,26 @@ class UserFilter(FilterSet):
|
||||
}
|
||||
|
||||
|
||||
class RoleFilter(FilterSet):
|
||||
inserted_at = DateFilter(field_name="inserted_at", lookup_expr="date")
|
||||
updated_at = DateFilter(field_name="updated_at", lookup_expr="date")
|
||||
permission_state = ChoiceFilter(
|
||||
choices=PermissionChoices.choices, method="filter_permission_state"
|
||||
)
|
||||
|
||||
def filter_permission_state(self, queryset, name, value):
|
||||
return Role.filter_by_permission_state(queryset, value)
|
||||
|
||||
class Meta:
|
||||
model = Role
|
||||
fields = {
|
||||
"id": ["exact", "in"],
|
||||
"name": ["exact", "in"],
|
||||
"inserted_at": ["gte", "lte"],
|
||||
"updated_at": ["gte", "lte"],
|
||||
}
|
||||
|
||||
|
||||
class ComplianceOverviewFilter(FilterSet):
|
||||
inserted_at = DateFilter(field_name="inserted_at", lookup_expr="date")
|
||||
provider_type = ChoiceFilter(choices=Provider.ProviderChoices.choices)
|
||||
@@ -521,3 +543,25 @@ class ScanSummaryFilter(FilterSet):
|
||||
"inserted_at": ["date", "gte", "lte"],
|
||||
"region": ["exact", "icontains", "in"],
|
||||
}
|
||||
|
||||
|
||||
class ServiceOverviewFilter(ScanSummaryFilter):
|
||||
muted_findings = None
|
||||
|
||||
def is_valid(self):
|
||||
# Check if at least one of the inserted_at filters is present
|
||||
inserted_at_filters = [
|
||||
self.data.get("inserted_at"),
|
||||
self.data.get("inserted_at__gte"),
|
||||
self.data.get("inserted_at__lte"),
|
||||
]
|
||||
if not any(inserted_at_filters):
|
||||
raise ValidationError(
|
||||
{
|
||||
"inserted_at": [
|
||||
"At least one of filter[inserted_at], filter[inserted_at__gte], or "
|
||||
"filter[inserted_at__lte] is required."
|
||||
]
|
||||
}
|
||||
)
|
||||
return super().is_valid()
|
||||
|
||||
@@ -58,5 +58,96 @@
|
||||
"provider_group": "525e91e7-f3f3-4254-bbc3-27ce1ade86b1",
|
||||
"inserted_at": "2024-11-13T11:55:41.237Z"
|
||||
}
|
||||
},
|
||||
{
|
||||
"model": "api.role",
|
||||
"pk": "3f01e759-bdf9-4a99-8888-1ab805b79f93",
|
||||
"fields": {
|
||||
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
|
||||
"name": "admin_test",
|
||||
"manage_users": true,
|
||||
"manage_account": true,
|
||||
"manage_billing": true,
|
||||
"manage_providers": true,
|
||||
"manage_integrations": true,
|
||||
"manage_scans": true,
|
||||
"unlimited_visibility": true,
|
||||
"inserted_at": "2024-11-20T15:32:42.402Z",
|
||||
"updated_at": "2024-11-20T15:32:42.402Z"
|
||||
}
|
||||
},
|
||||
{
|
||||
"model": "api.role",
|
||||
"pk": "845ff03a-87ef-42ba-9786-6577c70c4df0",
|
||||
"fields": {
|
||||
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
|
||||
"name": "first_role",
|
||||
"manage_users": true,
|
||||
"manage_account": true,
|
||||
"manage_billing": true,
|
||||
"manage_providers": true,
|
||||
"manage_integrations": false,
|
||||
"manage_scans": false,
|
||||
"unlimited_visibility": true,
|
||||
"inserted_at": "2024-11-20T15:31:53.239Z",
|
||||
"updated_at": "2024-11-20T15:31:53.239Z"
|
||||
}
|
||||
},
|
||||
{
|
||||
"model": "api.role",
|
||||
"pk": "902d726c-4bd5-413a-a2a4-f7b4754b6b20",
|
||||
"fields": {
|
||||
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
|
||||
"name": "third_role",
|
||||
"manage_users": false,
|
||||
"manage_account": false,
|
||||
"manage_billing": false,
|
||||
"manage_providers": false,
|
||||
"manage_integrations": false,
|
||||
"manage_scans": true,
|
||||
"unlimited_visibility": false,
|
||||
"inserted_at": "2024-11-20T15:34:05.440Z",
|
||||
"updated_at": "2024-11-20T15:34:05.440Z"
|
||||
}
|
||||
},
|
||||
{
|
||||
"model": "api.roleprovidergrouprelationship",
|
||||
"pk": "57fd024a-0a7f-49b4-a092-fa0979a07aaf",
|
||||
"fields": {
|
||||
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
|
||||
"role": "3f01e759-bdf9-4a99-8888-1ab805b79f93",
|
||||
"provider_group": "3fe28fb8-e545-424c-9b8f-69aff638f430",
|
||||
"inserted_at": "2024-11-20T15:32:42.402Z"
|
||||
}
|
||||
},
|
||||
{
|
||||
"model": "api.roleprovidergrouprelationship",
|
||||
"pk": "a3cd0099-1c13-4df1-a5e5-ecdfec561b35",
|
||||
"fields": {
|
||||
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
|
||||
"role": "3f01e759-bdf9-4a99-8888-1ab805b79f93",
|
||||
"provider_group": "481769f5-db2b-447b-8b00-1dee18db90ec",
|
||||
"inserted_at": "2024-11-20T15:32:42.402Z"
|
||||
}
|
||||
},
|
||||
{
|
||||
"model": "api.roleprovidergrouprelationship",
|
||||
"pk": "cfd84182-a058-40c2-af3c-0189b174940f",
|
||||
"fields": {
|
||||
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
|
||||
"role": "3f01e759-bdf9-4a99-8888-1ab805b79f93",
|
||||
"provider_group": "525e91e7-f3f3-4254-bbc3-27ce1ade86b1",
|
||||
"inserted_at": "2024-11-20T15:32:42.402Z"
|
||||
}
|
||||
},
|
||||
{
|
||||
"model": "api.userrolerelationship",
|
||||
"pk": "92339663-e954-4fd8-98fb-8bfe15949975",
|
||||
"fields": {
|
||||
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
|
||||
"role": "3f01e759-bdf9-4a99-8888-1ab805b79f93",
|
||||
"user": "8b38e2eb-6689-4f1e-a4ba-95b275130200",
|
||||
"inserted_at": "2024-11-20T15:36:14.302Z"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
246
api/src/backend/api/migrations/0003_rbac.py
Normal file
246
api/src/backend/api/migrations/0003_rbac.py
Normal file
@@ -0,0 +1,246 @@
|
||||
# Generated by Django 5.1.1 on 2024-12-05 12:29
|
||||
|
||||
import api.rls
|
||||
import django.db.models.deletion
|
||||
import uuid
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("api", "0002_token_migrations"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="Role",
|
||||
fields=[
|
||||
(
|
||||
"id",
|
||||
models.UUIDField(
|
||||
default=uuid.uuid4,
|
||||
editable=False,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
),
|
||||
),
|
||||
("name", models.CharField(max_length=255)),
|
||||
("manage_users", models.BooleanField(default=False)),
|
||||
("manage_account", models.BooleanField(default=False)),
|
||||
("manage_billing", models.BooleanField(default=False)),
|
||||
("manage_providers", models.BooleanField(default=False)),
|
||||
("manage_integrations", models.BooleanField(default=False)),
|
||||
("manage_scans", models.BooleanField(default=False)),
|
||||
("unlimited_visibility", models.BooleanField(default=False)),
|
||||
("inserted_at", models.DateTimeField(auto_now_add=True)),
|
||||
("updated_at", models.DateTimeField(auto_now=True)),
|
||||
(
|
||||
"tenant",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"db_table": "roles",
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="RoleProviderGroupRelationship",
|
||||
fields=[
|
||||
(
|
||||
"id",
|
||||
models.UUIDField(
|
||||
default=uuid.uuid4,
|
||||
editable=False,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
),
|
||||
),
|
||||
("inserted_at", models.DateTimeField(auto_now_add=True)),
|
||||
(
|
||||
"tenant",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"db_table": "role_provider_group_relationship",
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="UserRoleRelationship",
|
||||
fields=[
|
||||
(
|
||||
"id",
|
||||
models.UUIDField(
|
||||
default=uuid.uuid4,
|
||||
editable=False,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
),
|
||||
),
|
||||
("inserted_at", models.DateTimeField(auto_now_add=True)),
|
||||
(
|
||||
"tenant",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"db_table": "role_user_relationship",
|
||||
},
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="roleprovidergrouprelationship",
|
||||
name="provider_group",
|
||||
field=models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE, to="api.providergroup"
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="roleprovidergrouprelationship",
|
||||
name="role",
|
||||
field=models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE, to="api.role"
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="role",
|
||||
name="provider_groups",
|
||||
field=models.ManyToManyField(
|
||||
related_name="roles",
|
||||
through="api.RoleProviderGroupRelationship",
|
||||
to="api.providergroup",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="userrolerelationship",
|
||||
name="role",
|
||||
field=models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE, to="api.role"
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="userrolerelationship",
|
||||
name="user",
|
||||
field=models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="role",
|
||||
name="users",
|
||||
field=models.ManyToManyField(
|
||||
related_name="roles",
|
||||
through="api.UserRoleRelationship",
|
||||
to=settings.AUTH_USER_MODEL,
|
||||
),
|
||||
),
|
||||
migrations.AddConstraint(
|
||||
model_name="roleprovidergrouprelationship",
|
||||
constraint=models.UniqueConstraint(
|
||||
fields=("role_id", "provider_group_id"),
|
||||
name="unique_role_provider_group_relationship",
|
||||
),
|
||||
),
|
||||
migrations.AddConstraint(
|
||||
model_name="roleprovidergrouprelationship",
|
||||
constraint=api.rls.RowLevelSecurityConstraint(
|
||||
"tenant_id",
|
||||
name="rls_on_roleprovidergrouprelationship",
|
||||
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
|
||||
),
|
||||
),
|
||||
migrations.AddConstraint(
|
||||
model_name="userrolerelationship",
|
||||
constraint=models.UniqueConstraint(
|
||||
fields=("role_id", "user_id"), name="unique_role_user_relationship"
|
||||
),
|
||||
),
|
||||
migrations.AddConstraint(
|
||||
model_name="userrolerelationship",
|
||||
constraint=api.rls.RowLevelSecurityConstraint(
|
||||
"tenant_id",
|
||||
name="rls_on_userrolerelationship",
|
||||
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
|
||||
),
|
||||
),
|
||||
migrations.AddConstraint(
|
||||
model_name="role",
|
||||
constraint=models.UniqueConstraint(
|
||||
fields=("tenant_id", "name"), name="unique_role_per_tenant"
|
||||
),
|
||||
),
|
||||
migrations.AddConstraint(
|
||||
model_name="role",
|
||||
constraint=api.rls.RowLevelSecurityConstraint(
|
||||
"tenant_id",
|
||||
name="rls_on_role",
|
||||
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
|
||||
),
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="InvitationRoleRelationship",
|
||||
fields=[
|
||||
(
|
||||
"id",
|
||||
models.UUIDField(
|
||||
default=uuid.uuid4,
|
||||
editable=False,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
),
|
||||
),
|
||||
("inserted_at", models.DateTimeField(auto_now_add=True)),
|
||||
(
|
||||
"invitation",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE, to="api.invitation"
|
||||
),
|
||||
),
|
||||
(
|
||||
"role",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE, to="api.role"
|
||||
),
|
||||
),
|
||||
(
|
||||
"tenant",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"db_table": "role_invitation_relationship",
|
||||
},
|
||||
),
|
||||
migrations.AddConstraint(
|
||||
model_name="invitationrolerelationship",
|
||||
constraint=models.UniqueConstraint(
|
||||
fields=("role_id", "invitation_id"),
|
||||
name="unique_role_invitation_relationship",
|
||||
),
|
||||
),
|
||||
migrations.AddConstraint(
|
||||
model_name="invitationrolerelationship",
|
||||
constraint=api.rls.RowLevelSecurityConstraint(
|
||||
"tenant_id",
|
||||
name="rls_on_invitationrolerelationship",
|
||||
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="role",
|
||||
name="invitations",
|
||||
field=models.ManyToManyField(
|
||||
related_name="roles",
|
||||
through="api.InvitationRoleRelationship",
|
||||
to="api.invitation",
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,43 @@
|
||||
from django.db import migrations
|
||||
from api.db_router import MainRouter
|
||||
|
||||
|
||||
def create_admin_role(apps, schema_editor):
|
||||
Tenant = apps.get_model("api", "Tenant")
|
||||
Role = apps.get_model("api", "Role")
|
||||
User = apps.get_model("api", "User")
|
||||
UserRoleRelationship = apps.get_model("api", "UserRoleRelationship")
|
||||
|
||||
for tenant in Tenant.objects.using(MainRouter.admin_db).all():
|
||||
admin_role, _ = Role.objects.using(MainRouter.admin_db).get_or_create(
|
||||
name="admin",
|
||||
tenant=tenant,
|
||||
defaults={
|
||||
"manage_users": True,
|
||||
"manage_account": True,
|
||||
"manage_billing": True,
|
||||
"manage_providers": True,
|
||||
"manage_integrations": True,
|
||||
"manage_scans": True,
|
||||
"unlimited_visibility": True,
|
||||
},
|
||||
)
|
||||
users = User.objects.using(MainRouter.admin_db).filter(
|
||||
membership__tenant=tenant
|
||||
)
|
||||
for user in users:
|
||||
UserRoleRelationship.objects.using(MainRouter.admin_db).get_or_create(
|
||||
user=user,
|
||||
role=admin_role,
|
||||
tenant=tenant,
|
||||
)
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("api", "0003_rbac"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.RunPython(create_admin_role),
|
||||
]
|
||||
@@ -69,6 +69,21 @@ class StateChoices(models.TextChoices):
|
||||
CANCELLED = "cancelled", _("Cancelled")
|
||||
|
||||
|
||||
class PermissionChoices(models.TextChoices):
|
||||
"""
|
||||
Represents the different permission states that a role can have.
|
||||
|
||||
Attributes:
|
||||
UNLIMITED: Indicates that the role possesses all permissions.
|
||||
LIMITED: Indicates that the role has some permissions but not all.
|
||||
NONE: Indicates that the role does not have any permissions.
|
||||
"""
|
||||
|
||||
UNLIMITED = "unlimited", _("Unlimited permissions")
|
||||
LIMITED = "limited", _("Limited permissions")
|
||||
NONE = "none", _("No permissions")
|
||||
|
||||
|
||||
class ActiveProviderManager(models.Manager):
|
||||
def get_queryset(self):
|
||||
return super().get_queryset().filter(self.active_provider_filter())
|
||||
@@ -298,19 +313,10 @@ class ProviderGroup(RowLevelSecurityProtectedModel):
|
||||
|
||||
|
||||
class ProviderGroupMembership(RowLevelSecurityProtectedModel):
|
||||
objects = ActiveProviderManager()
|
||||
all_objects = models.Manager()
|
||||
|
||||
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
|
||||
provider = models.ForeignKey(
|
||||
Provider,
|
||||
on_delete=models.CASCADE,
|
||||
)
|
||||
provider_group = models.ForeignKey(
|
||||
ProviderGroup,
|
||||
on_delete=models.CASCADE,
|
||||
)
|
||||
inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
|
||||
provider_group = models.ForeignKey(ProviderGroup, on_delete=models.CASCADE)
|
||||
provider = models.ForeignKey(Provider, on_delete=models.CASCADE)
|
||||
inserted_at = models.DateTimeField(auto_now_add=True)
|
||||
|
||||
class Meta:
|
||||
db_table = "provider_group_memberships"
|
||||
@@ -327,7 +333,7 @@ class ProviderGroupMembership(RowLevelSecurityProtectedModel):
|
||||
]
|
||||
|
||||
class JSONAPIMeta:
|
||||
resource_name = "provider-group-memberships"
|
||||
resource_name = "provider_groups-provider"
|
||||
|
||||
|
||||
class Task(RowLevelSecurityProtectedModel):
|
||||
@@ -851,6 +857,150 @@ class Invitation(RowLevelSecurityProtectedModel):
|
||||
resource_name = "invitations"
|
||||
|
||||
|
||||
class Role(RowLevelSecurityProtectedModel):
|
||||
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
|
||||
name = models.CharField(max_length=255)
|
||||
manage_users = models.BooleanField(default=False)
|
||||
manage_account = models.BooleanField(default=False)
|
||||
manage_billing = models.BooleanField(default=False)
|
||||
manage_providers = models.BooleanField(default=False)
|
||||
manage_integrations = models.BooleanField(default=False)
|
||||
manage_scans = models.BooleanField(default=False)
|
||||
unlimited_visibility = models.BooleanField(default=False)
|
||||
inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
|
||||
updated_at = models.DateTimeField(auto_now=True, editable=False)
|
||||
provider_groups = models.ManyToManyField(
|
||||
ProviderGroup, through="RoleProviderGroupRelationship", related_name="roles"
|
||||
)
|
||||
users = models.ManyToManyField(
|
||||
User, through="UserRoleRelationship", related_name="roles"
|
||||
)
|
||||
invitations = models.ManyToManyField(
|
||||
Invitation, through="InvitationRoleRelationship", related_name="roles"
|
||||
)
|
||||
|
||||
# Filter permission_state
|
||||
PERMISSION_FIELDS = [
|
||||
"manage_users",
|
||||
"manage_account",
|
||||
"manage_billing",
|
||||
"manage_providers",
|
||||
"manage_integrations",
|
||||
"manage_scans",
|
||||
]
|
||||
|
||||
@property
|
||||
def permission_state(self):
|
||||
values = [getattr(self, field) for field in self.PERMISSION_FIELDS]
|
||||
if all(values):
|
||||
return PermissionChoices.UNLIMITED
|
||||
elif not any(values):
|
||||
return PermissionChoices.NONE
|
||||
else:
|
||||
return PermissionChoices.LIMITED
|
||||
|
||||
@classmethod
|
||||
def filter_by_permission_state(cls, queryset, value):
|
||||
q_all_true = Q(**{field: True for field in cls.PERMISSION_FIELDS})
|
||||
q_all_false = Q(**{field: False for field in cls.PERMISSION_FIELDS})
|
||||
|
||||
if value == PermissionChoices.UNLIMITED:
|
||||
return queryset.filter(q_all_true)
|
||||
elif value == PermissionChoices.NONE:
|
||||
return queryset.filter(q_all_false)
|
||||
else:
|
||||
return queryset.exclude(q_all_true | q_all_false)
|
||||
|
||||
class Meta:
|
||||
db_table = "roles"
|
||||
constraints = [
|
||||
models.UniqueConstraint(
|
||||
fields=["tenant_id", "name"],
|
||||
name="unique_role_per_tenant",
|
||||
),
|
||||
RowLevelSecurityConstraint(
|
||||
field="tenant_id",
|
||||
name="rls_on_%(class)s",
|
||||
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
|
||||
),
|
||||
]
|
||||
|
||||
class JSONAPIMeta:
|
||||
resource_name = "roles"
|
||||
|
||||
|
||||
class RoleProviderGroupRelationship(RowLevelSecurityProtectedModel):
|
||||
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
|
||||
role = models.ForeignKey(Role, on_delete=models.CASCADE)
|
||||
provider_group = models.ForeignKey(ProviderGroup, on_delete=models.CASCADE)
|
||||
inserted_at = models.DateTimeField(auto_now_add=True)
|
||||
|
||||
class Meta:
|
||||
db_table = "role_provider_group_relationship"
|
||||
constraints = [
|
||||
models.UniqueConstraint(
|
||||
fields=["role_id", "provider_group_id"],
|
||||
name="unique_role_provider_group_relationship",
|
||||
),
|
||||
RowLevelSecurityConstraint(
|
||||
field="tenant_id",
|
||||
name="rls_on_%(class)s",
|
||||
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
|
||||
),
|
||||
]
|
||||
|
||||
class JSONAPIMeta:
|
||||
resource_name = "role-provider_groups"
|
||||
|
||||
|
||||
class UserRoleRelationship(RowLevelSecurityProtectedModel):
|
||||
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
|
||||
role = models.ForeignKey(Role, on_delete=models.CASCADE)
|
||||
user = models.ForeignKey(User, on_delete=models.CASCADE)
|
||||
inserted_at = models.DateTimeField(auto_now_add=True)
|
||||
|
||||
class Meta:
|
||||
db_table = "role_user_relationship"
|
||||
constraints = [
|
||||
models.UniqueConstraint(
|
||||
fields=["role_id", "user_id"],
|
||||
name="unique_role_user_relationship",
|
||||
),
|
||||
RowLevelSecurityConstraint(
|
||||
field="tenant_id",
|
||||
name="rls_on_%(class)s",
|
||||
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
|
||||
),
|
||||
]
|
||||
|
||||
class JSONAPIMeta:
|
||||
resource_name = "user-roles"
|
||||
|
||||
|
||||
class InvitationRoleRelationship(RowLevelSecurityProtectedModel):
|
||||
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
|
||||
role = models.ForeignKey(Role, on_delete=models.CASCADE)
|
||||
invitation = models.ForeignKey(Invitation, on_delete=models.CASCADE)
|
||||
inserted_at = models.DateTimeField(auto_now_add=True)
|
||||
|
||||
class Meta:
|
||||
db_table = "role_invitation_relationship"
|
||||
constraints = [
|
||||
models.UniqueConstraint(
|
||||
fields=["role_id", "invitation_id"],
|
||||
name="unique_role_invitation_relationship",
|
||||
),
|
||||
RowLevelSecurityConstraint(
|
||||
field="tenant_id",
|
||||
name="rls_on_%(class)s",
|
||||
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
|
||||
),
|
||||
]
|
||||
|
||||
class JSONAPIMeta:
|
||||
resource_name = "invitation-roles"
|
||||
|
||||
|
||||
class ComplianceOverview(RowLevelSecurityProtectedModel):
|
||||
objects = ActiveProviderManager()
|
||||
all_objects = models.Manager()
|
||||
|
||||
70
api/src/backend/api/rbac/permissions.py
Normal file
70
api/src/backend/api/rbac/permissions.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from enum import Enum
|
||||
from rest_framework.permissions import BasePermission
|
||||
from api.models import Provider, Role, User
|
||||
from api.db_router import MainRouter
|
||||
from typing import Optional
|
||||
from django.db.models import QuerySet
|
||||
|
||||
|
||||
class Permissions(Enum):
|
||||
MANAGE_USERS = "manage_users"
|
||||
MANAGE_ACCOUNT = "manage_account"
|
||||
MANAGE_BILLING = "manage_billing"
|
||||
MANAGE_PROVIDERS = "manage_providers"
|
||||
MANAGE_INTEGRATIONS = "manage_integrations"
|
||||
MANAGE_SCANS = "manage_scans"
|
||||
UNLIMITED_VISIBILITY = "unlimited_visibility"
|
||||
|
||||
|
||||
class HasPermissions(BasePermission):
|
||||
"""
|
||||
Custom permission to check if the user's role has the required permissions.
|
||||
The required permissions should be specified in the view as a list in `required_permissions`.
|
||||
"""
|
||||
|
||||
def has_permission(self, request, view):
|
||||
required_permissions = getattr(view, "required_permissions", [])
|
||||
if not required_permissions:
|
||||
return True
|
||||
|
||||
user_roles = (
|
||||
User.objects.using(MainRouter.admin_db).get(id=request.user.id).roles.all()
|
||||
)
|
||||
if not user_roles:
|
||||
return False
|
||||
|
||||
for perm in required_permissions:
|
||||
if not getattr(user_roles[0], perm.value, False):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_role(user: User) -> Optional[Role]:
|
||||
"""
|
||||
Retrieve the first role assigned to the given user.
|
||||
|
||||
Returns:
|
||||
The user's first Role instance if the user has any roles, otherwise None.
|
||||
"""
|
||||
return user.roles.first()
|
||||
|
||||
|
||||
def get_providers(role: Role) -> QuerySet[Provider]:
|
||||
"""
|
||||
Return a distinct queryset of Providers accessible by the given role.
|
||||
|
||||
If the role has no associated provider groups, an empty queryset is returned.
|
||||
|
||||
Args:
|
||||
role: A Role instance.
|
||||
|
||||
Returns:
|
||||
A QuerySet of Provider objects filtered by the role's provider groups.
|
||||
If the role has no provider groups, returns an empty queryset.
|
||||
"""
|
||||
provider_groups = role.provider_groups.all()
|
||||
if not provider_groups.exists():
|
||||
return Provider.objects.none()
|
||||
|
||||
return Provider.objects.filter(provider_groups__in=provider_groups).distinct()
|
||||
@@ -2,7 +2,7 @@ from contextlib import nullcontext
|
||||
|
||||
from rest_framework_json_api.renderers import JSONRenderer
|
||||
|
||||
from api.db_utils import tenant_transaction
|
||||
from api.db_utils import rls_transaction
|
||||
|
||||
|
||||
class APIJSONRenderer(JSONRenderer):
|
||||
@@ -13,9 +13,9 @@ class APIJSONRenderer(JSONRenderer):
|
||||
tenant_id = getattr(request, "tenant_id", None) if request else None
|
||||
include_param_present = "include" in request.query_params if request else False
|
||||
|
||||
# Use tenant_transaction if needed for included resources, otherwise do nothing
|
||||
# Use rls_transaction if needed for included resources, otherwise do nothing
|
||||
context_manager = (
|
||||
tenant_transaction(tenant_id)
|
||||
rls_transaction(tenant_id)
|
||||
if tenant_id and include_param_present
|
||||
else nullcontext()
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,12 +1,10 @@
|
||||
import pytest
|
||||
from django.urls import reverse
|
||||
from unittest.mock import patch
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from conftest import TEST_PASSWORD, get_api_tokens, get_authorization_header
|
||||
|
||||
|
||||
@patch("api.v1.views.MainRouter.admin_db", new="default")
|
||||
@pytest.mark.django_db
|
||||
def test_basic_authentication():
|
||||
client = APIClient()
|
||||
|
||||
@@ -13,6 +13,7 @@ def test_check_resources_between_different_tenants(
|
||||
enforce_test_user_db_connection,
|
||||
authenticated_api_client,
|
||||
tenants_fixture,
|
||||
set_user_admin_roles_fixture,
|
||||
):
|
||||
client = authenticated_api_client
|
||||
|
||||
|
||||
@@ -6,8 +6,10 @@ from django.db.utils import ConnectionRouter
|
||||
from api.db_router import MainRouter
|
||||
from api.rls import Tenant
|
||||
from config.django.base import DATABASE_ROUTERS as PROD_DATABASE_ROUTERS
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
@patch("api.db_router.MainRouter.admin_db", new="admin")
|
||||
class TestMainDatabaseRouter:
|
||||
@pytest.fixture(scope="module")
|
||||
def router(self):
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from unittest.mock import patch, call
|
||||
import uuid
|
||||
from unittest.mock import call, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.db_utils import POSTGRES_TENANT_VAR, SET_CONFIG_QUERY
|
||||
from api.decorators import set_tenant
|
||||
|
||||
|
||||
@@ -15,12 +17,12 @@ class TestSetTenantDecorator:
|
||||
def random_func(arg):
|
||||
return arg
|
||||
|
||||
tenant_id = "1234-abcd-5678"
|
||||
tenant_id = str(uuid.uuid4())
|
||||
|
||||
result = random_func("test_arg", tenant_id=tenant_id)
|
||||
|
||||
assert (
|
||||
call(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
|
||||
call(SET_CONFIG_QUERY, [POSTGRES_TENANT_VAR, tenant_id])
|
||||
in mock_cursor.execute.mock_calls
|
||||
)
|
||||
assert result == "test_arg"
|
||||
|
||||
306
api/src/backend/api/tests/test_rbac.py
Normal file
306
api/src/backend/api/tests/test_rbac.py
Normal file
@@ -0,0 +1,306 @@
|
||||
import pytest
|
||||
from django.urls import reverse
|
||||
from rest_framework import status
|
||||
from unittest.mock import patch, ANY, Mock
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestUserViewSet:
|
||||
def test_list_users_with_all_permissions(self, authenticated_client_rbac):
|
||||
response = authenticated_client_rbac.get(reverse("user-list"))
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert isinstance(response.json()["data"], list)
|
||||
|
||||
def test_list_users_with_no_permissions(
|
||||
self, authenticated_client_no_permissions_rbac
|
||||
):
|
||||
response = authenticated_client_no_permissions_rbac.get(reverse("user-list"))
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
def test_retrieve_user_with_all_permissions(
|
||||
self, authenticated_client_rbac, create_test_user_rbac
|
||||
):
|
||||
response = authenticated_client_rbac.get(
|
||||
reverse("user-detail", kwargs={"pk": create_test_user_rbac.id})
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert (
|
||||
response.json()["data"]["attributes"]["email"]
|
||||
== create_test_user_rbac.email
|
||||
)
|
||||
|
||||
def test_retrieve_user_with_no_roles(
|
||||
self, authenticated_client_rbac_noroles, create_test_user_rbac_no_roles
|
||||
):
|
||||
response = authenticated_client_rbac_noroles.get(
|
||||
reverse("user-detail", kwargs={"pk": create_test_user_rbac_no_roles.id})
|
||||
)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
def test_retrieve_user_with_no_permissions(
|
||||
self, authenticated_client_no_permissions_rbac, create_test_user
|
||||
):
|
||||
response = authenticated_client_no_permissions_rbac.get(
|
||||
reverse("user-detail", kwargs={"pk": create_test_user.id})
|
||||
)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
def test_create_user_with_all_permissions(self, authenticated_client_rbac):
|
||||
valid_user_payload = {
|
||||
"name": "test",
|
||||
"password": "newpassword123",
|
||||
"email": "new_user@test.com",
|
||||
}
|
||||
response = authenticated_client_rbac.post(
|
||||
reverse("user-list"), data=valid_user_payload, format="vnd.api+json"
|
||||
)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
assert response.json()["data"]["attributes"]["email"] == "new_user@test.com"
|
||||
|
||||
def test_create_user_with_no_permissions(
|
||||
self, authenticated_client_no_permissions_rbac
|
||||
):
|
||||
valid_user_payload = {
|
||||
"name": "test",
|
||||
"password": "newpassword123",
|
||||
"email": "new_user@test.com",
|
||||
}
|
||||
response = authenticated_client_no_permissions_rbac.post(
|
||||
reverse("user-list"), data=valid_user_payload, format="vnd.api+json"
|
||||
)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
assert response.json()["data"]["attributes"]["email"] == "new_user@test.com"
|
||||
|
||||
def test_partial_update_user_with_all_permissions(
|
||||
self, authenticated_client_rbac, create_test_user_rbac
|
||||
):
|
||||
updated_data = {
|
||||
"data": {
|
||||
"type": "users",
|
||||
"id": str(create_test_user_rbac.id),
|
||||
"attributes": {"name": "Updated Name"},
|
||||
},
|
||||
}
|
||||
response = authenticated_client_rbac.patch(
|
||||
reverse("user-detail", kwargs={"pk": create_test_user_rbac.id}),
|
||||
data=updated_data,
|
||||
content_type="application/vnd.api+json",
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.json()["data"]["attributes"]["name"] == "Updated Name"
|
||||
|
||||
def test_partial_update_user_with_no_permissions(
|
||||
self, authenticated_client_no_permissions_rbac, create_test_user
|
||||
):
|
||||
updated_data = {
|
||||
"data": {
|
||||
"type": "users",
|
||||
"attributes": {"name": "Updated Name"},
|
||||
}
|
||||
}
|
||||
response = authenticated_client_no_permissions_rbac.patch(
|
||||
reverse("user-detail", kwargs={"pk": create_test_user.id}),
|
||||
data=updated_data,
|
||||
format="vnd.api+json",
|
||||
)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
def test_delete_user_with_all_permissions(
|
||||
self, authenticated_client_rbac, create_test_user_rbac
|
||||
):
|
||||
response = authenticated_client_rbac.delete(
|
||||
reverse("user-detail", kwargs={"pk": create_test_user_rbac.id})
|
||||
)
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
|
||||
def test_delete_user_with_no_permissions(
|
||||
self, authenticated_client_no_permissions_rbac, create_test_user
|
||||
):
|
||||
response = authenticated_client_no_permissions_rbac.delete(
|
||||
reverse("user-detail", kwargs={"pk": create_test_user.id})
|
||||
)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
def test_me_with_all_permissions(
|
||||
self, authenticated_client_rbac, create_test_user_rbac
|
||||
):
|
||||
response = authenticated_client_rbac.get(reverse("user-me"))
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert (
|
||||
response.json()["data"]["attributes"]["email"]
|
||||
== create_test_user_rbac.email
|
||||
)
|
||||
|
||||
def test_me_with_no_permissions(
|
||||
self, authenticated_client_no_permissions_rbac, create_test_user
|
||||
):
|
||||
response = authenticated_client_no_permissions_rbac.get(reverse("user-me"))
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.json()["data"]["attributes"]["email"] == "rbac_limited@rbac.com"
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestProviderViewSet:
|
||||
def test_list_providers_with_all_permissions(
|
||||
self, authenticated_client_rbac, providers_fixture
|
||||
):
|
||||
response = authenticated_client_rbac.get(reverse("provider-list"))
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert len(response.json()["data"]) == len(providers_fixture)
|
||||
|
||||
def test_list_providers_with_no_permissions(
|
||||
self, authenticated_client_no_permissions_rbac
|
||||
):
|
||||
response = authenticated_client_no_permissions_rbac.get(
|
||||
reverse("provider-list")
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert len(response.json()["data"]) == 0
|
||||
|
||||
def test_retrieve_provider_with_all_permissions(
|
||||
self, authenticated_client_rbac, providers_fixture
|
||||
):
|
||||
provider = providers_fixture[0]
|
||||
response = authenticated_client_rbac.get(
|
||||
reverse("provider-detail", kwargs={"pk": provider.id})
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.json()["data"]["attributes"]["alias"] == provider.alias
|
||||
|
||||
def test_retrieve_provider_with_no_permissions(
|
||||
self, authenticated_client_no_permissions_rbac, providers_fixture
|
||||
):
|
||||
provider = providers_fixture[0]
|
||||
response = authenticated_client_no_permissions_rbac.get(
|
||||
reverse("provider-detail", kwargs={"pk": provider.id})
|
||||
)
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
def test_create_provider_with_all_permissions(self, authenticated_client_rbac):
|
||||
payload = {"provider": "aws", "uid": "111111111111", "alias": "new_alias"}
|
||||
response = authenticated_client_rbac.post(
|
||||
reverse("provider-list"), data=payload, format="json"
|
||||
)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
assert response.json()["data"]["attributes"]["alias"] == "new_alias"
|
||||
|
||||
def test_create_provider_with_no_permissions(
|
||||
self, authenticated_client_no_permissions_rbac
|
||||
):
|
||||
payload = {"provider": "aws", "uid": "111111111111", "alias": "new_alias"}
|
||||
response = authenticated_client_no_permissions_rbac.post(
|
||||
reverse("provider-list"), data=payload, format="json"
|
||||
)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
def test_partial_update_provider_with_all_permissions(
|
||||
self, authenticated_client_rbac, providers_fixture
|
||||
):
|
||||
provider = providers_fixture[0]
|
||||
payload = {
|
||||
"data": {
|
||||
"type": "providers",
|
||||
"id": provider.id,
|
||||
"attributes": {"alias": "updated_alias"},
|
||||
},
|
||||
}
|
||||
response = authenticated_client_rbac.patch(
|
||||
reverse("provider-detail", kwargs={"pk": provider.id}),
|
||||
data=payload,
|
||||
content_type="application/vnd.api+json",
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.json()["data"]["attributes"]["alias"] == "updated_alias"
|
||||
|
||||
def test_partial_update_provider_with_no_permissions(
|
||||
self, authenticated_client_no_permissions_rbac, providers_fixture
|
||||
):
|
||||
provider = providers_fixture[0]
|
||||
update_payload = {
|
||||
"data": {
|
||||
"type": "providers",
|
||||
"attributes": {"alias": "updated_alias"},
|
||||
}
|
||||
}
|
||||
response = authenticated_client_no_permissions_rbac.patch(
|
||||
reverse("provider-detail", kwargs={"pk": provider.id}),
|
||||
data=update_payload,
|
||||
format="vnd.api+json",
|
||||
)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
@patch("api.v1.views.Task.objects.get")
|
||||
@patch("api.v1.views.delete_provider_task.delay")
|
||||
def test_delete_provider_with_all_permissions(
|
||||
self,
|
||||
mock_delete_task,
|
||||
mock_task_get,
|
||||
authenticated_client_rbac,
|
||||
providers_fixture,
|
||||
tasks_fixture,
|
||||
):
|
||||
prowler_task = tasks_fixture[0]
|
||||
task_mock = Mock()
|
||||
task_mock.id = prowler_task.id
|
||||
mock_delete_task.return_value = task_mock
|
||||
mock_task_get.return_value = prowler_task
|
||||
|
||||
provider1, *_ = providers_fixture
|
||||
response = authenticated_client_rbac.delete(
|
||||
reverse("provider-detail", kwargs={"pk": provider1.id})
|
||||
)
|
||||
assert response.status_code == status.HTTP_202_ACCEPTED
|
||||
mock_delete_task.assert_called_once_with(
|
||||
provider_id=str(provider1.id), tenant_id=ANY
|
||||
)
|
||||
assert "Content-Location" in response.headers
|
||||
assert response.headers["Content-Location"] == f"/api/v1/tasks/{task_mock.id}"
|
||||
|
||||
def test_delete_provider_with_no_permissions(
|
||||
self, authenticated_client_no_permissions_rbac, providers_fixture
|
||||
):
|
||||
provider = providers_fixture[0]
|
||||
response = authenticated_client_no_permissions_rbac.delete(
|
||||
reverse("provider-detail", kwargs={"pk": provider.id})
|
||||
)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
@patch("api.v1.views.Task.objects.get")
|
||||
@patch("api.v1.views.check_provider_connection_task.delay")
|
||||
def test_connection_with_all_permissions(
|
||||
self,
|
||||
mock_provider_connection,
|
||||
mock_task_get,
|
||||
authenticated_client_rbac,
|
||||
providers_fixture,
|
||||
tasks_fixture,
|
||||
):
|
||||
prowler_task = tasks_fixture[0]
|
||||
task_mock = Mock()
|
||||
task_mock.id = prowler_task.id
|
||||
task_mock.status = "PENDING"
|
||||
mock_provider_connection.return_value = task_mock
|
||||
mock_task_get.return_value = prowler_task
|
||||
|
||||
provider1, *_ = providers_fixture
|
||||
assert provider1.connected is None
|
||||
assert provider1.connection_last_checked_at is None
|
||||
|
||||
response = authenticated_client_rbac.post(
|
||||
reverse("provider-connection", kwargs={"pk": provider1.id})
|
||||
)
|
||||
assert response.status_code == status.HTTP_202_ACCEPTED
|
||||
mock_provider_connection.assert_called_once_with(
|
||||
provider_id=str(provider1.id), tenant_id=ANY
|
||||
)
|
||||
assert "Content-Location" in response.headers
|
||||
assert response.headers["Content-Location"] == f"/api/v1/tasks/{task_mock.id}"
|
||||
|
||||
def test_connection_with_no_permissions(
|
||||
self, authenticated_client_no_permissions_rbac, providers_fixture
|
||||
):
|
||||
provider = providers_fixture[0]
|
||||
response = authenticated_client_no_permissions_rbac.post(
|
||||
reverse("provider-connection", kwargs={"pk": provider.id})
|
||||
)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
File diff suppressed because it is too large
Load Diff
@@ -17,6 +17,7 @@ from api.models import (
|
||||
ComplianceOverview,
|
||||
Finding,
|
||||
Invitation,
|
||||
InvitationRoleRelationship,
|
||||
Membership,
|
||||
Provider,
|
||||
ProviderGroup,
|
||||
@@ -24,10 +25,13 @@ from api.models import (
|
||||
ProviderSecret,
|
||||
Resource,
|
||||
ResourceTag,
|
||||
Role,
|
||||
RoleProviderGroupRelationship,
|
||||
Scan,
|
||||
StateChoices,
|
||||
Task,
|
||||
User,
|
||||
UserRoleRelationship,
|
||||
)
|
||||
from api.rls import Tenant
|
||||
|
||||
@@ -176,10 +180,26 @@ class UserSerializer(BaseSerializerV1):
|
||||
"""
|
||||
|
||||
memberships = serializers.ResourceRelatedField(many=True, read_only=True)
|
||||
roles = serializers.ResourceRelatedField(many=True, read_only=True)
|
||||
|
||||
class Meta:
|
||||
model = User
|
||||
fields = ["id", "name", "email", "company_name", "date_joined", "memberships"]
|
||||
fields = [
|
||||
"id",
|
||||
"name",
|
||||
"email",
|
||||
"company_name",
|
||||
"date_joined",
|
||||
"memberships",
|
||||
"roles",
|
||||
]
|
||||
extra_kwargs = {
|
||||
"roles": {"read_only": True},
|
||||
}
|
||||
|
||||
included_serializers = {
|
||||
"roles": "api.v1.serializers.RoleSerializer",
|
||||
}
|
||||
|
||||
|
||||
class UserCreateSerializer(BaseWriteSerializer):
|
||||
@@ -215,10 +235,13 @@ class UserCreateSerializer(BaseWriteSerializer):
|
||||
|
||||
class UserUpdateSerializer(BaseWriteSerializer):
|
||||
password = serializers.CharField(write_only=True, required=False)
|
||||
roles = serializers.ResourceRelatedField(
|
||||
queryset=Role.objects.all(), many=True, required=False
|
||||
)
|
||||
|
||||
class Meta:
|
||||
model = User
|
||||
fields = ["id", "name", "password", "email", "company_name"]
|
||||
fields = ["id", "name", "password", "email", "company_name", "roles"]
|
||||
extra_kwargs = {
|
||||
"id": {"read_only": True},
|
||||
}
|
||||
@@ -235,6 +258,73 @@ class UserUpdateSerializer(BaseWriteSerializer):
|
||||
return super().update(instance, validated_data)
|
||||
|
||||
|
||||
class RoleResourceIdentifierSerializer(serializers.Serializer):
|
||||
resource_type = serializers.CharField(source="type")
|
||||
id = serializers.UUIDField()
|
||||
|
||||
class JSONAPIMeta:
|
||||
resource_name = "role-identifier"
|
||||
|
||||
def to_representation(self, instance):
|
||||
"""
|
||||
Ensure 'type' is used in the output instead of 'resource_type'.
|
||||
"""
|
||||
representation = super().to_representation(instance)
|
||||
representation["type"] = representation.pop("resource_type", None)
|
||||
return representation
|
||||
|
||||
def to_internal_value(self, data):
|
||||
"""
|
||||
Map 'type' back to 'resource_type' during input.
|
||||
"""
|
||||
data["resource_type"] = data.pop("type", None)
|
||||
return super().to_internal_value(data)
|
||||
|
||||
|
||||
class UserRoleRelationshipSerializer(RLSSerializer, BaseWriteSerializer):
|
||||
"""
|
||||
Serializer for modifying user memberships
|
||||
"""
|
||||
|
||||
roles = serializers.ListField(
|
||||
child=RoleResourceIdentifierSerializer(),
|
||||
help_text="List of resource identifier objects representing roles.",
|
||||
)
|
||||
|
||||
def create(self, validated_data):
|
||||
role_ids = [item["id"] for item in validated_data["roles"]]
|
||||
roles = Role.objects.filter(id__in=role_ids)
|
||||
tenant_id = self.context.get("tenant_id")
|
||||
|
||||
new_relationships = [
|
||||
UserRoleRelationship(
|
||||
user=self.context.get("user"), role=r, tenant_id=tenant_id
|
||||
)
|
||||
for r in roles
|
||||
]
|
||||
UserRoleRelationship.objects.bulk_create(new_relationships)
|
||||
|
||||
return self.context.get("user")
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
role_ids = [item["id"] for item in validated_data["roles"]]
|
||||
roles = Role.objects.filter(id__in=role_ids)
|
||||
tenant_id = self.context.get("tenant_id")
|
||||
|
||||
instance.roles.clear()
|
||||
new_relationships = [
|
||||
UserRoleRelationship(user=instance, role=r, tenant_id=tenant_id)
|
||||
for r in roles
|
||||
]
|
||||
UserRoleRelationship.objects.bulk_create(new_relationships)
|
||||
|
||||
return instance
|
||||
|
||||
class Meta:
|
||||
model = UserRoleRelationship
|
||||
fields = ["id", "roles"]
|
||||
|
||||
|
||||
# Tasks
|
||||
class TaskBase(serializers.ModelSerializer):
|
||||
state_mapping = {
|
||||
@@ -358,89 +448,200 @@ class MembershipSerializer(serializers.ModelSerializer):
|
||||
|
||||
# Provider Groups
|
||||
class ProviderGroupSerializer(RLSSerializer, BaseWriteSerializer):
|
||||
providers = serializers.ResourceRelatedField(many=True, read_only=True)
|
||||
providers = serializers.ResourceRelatedField(
|
||||
queryset=Provider.objects.all(), many=True, required=False
|
||||
)
|
||||
roles = serializers.ResourceRelatedField(
|
||||
queryset=Role.objects.all(), many=True, required=False
|
||||
)
|
||||
|
||||
def validate(self, attrs):
|
||||
tenant = self.context["tenant_id"]
|
||||
name = attrs.get("name", self.instance.name if self.instance else None)
|
||||
|
||||
# Exclude the current instance when checking for uniqueness during updates
|
||||
queryset = ProviderGroup.objects.filter(tenant=tenant, name=name)
|
||||
if self.instance:
|
||||
queryset = queryset.exclude(pk=self.instance.pk)
|
||||
|
||||
if queryset.exists():
|
||||
if ProviderGroup.objects.filter(name=attrs.get("name")).exists():
|
||||
raise serializers.ValidationError(
|
||||
{
|
||||
"name": "A provider group with this name already exists for this tenant."
|
||||
}
|
||||
{"name": "A provider group with this name already exists."}
|
||||
)
|
||||
|
||||
return super().validate(attrs)
|
||||
|
||||
class Meta:
|
||||
model = ProviderGroup
|
||||
fields = ["id", "name", "inserted_at", "updated_at", "providers", "url"]
|
||||
read_only_fields = ["id", "inserted_at", "updated_at"]
|
||||
fields = [
|
||||
"id",
|
||||
"name",
|
||||
"inserted_at",
|
||||
"updated_at",
|
||||
"providers",
|
||||
"roles",
|
||||
"url",
|
||||
]
|
||||
extra_kwargs = {
|
||||
"id": {"read_only": True},
|
||||
"inserted_at": {"read_only": True},
|
||||
"updated_at": {"read_only": True},
|
||||
"roles": {"read_only": True},
|
||||
"url": {"read_only": True},
|
||||
}
|
||||
|
||||
|
||||
class ProviderGroupIncludedSerializer(RLSSerializer, BaseWriteSerializer):
|
||||
class ProviderGroupIncludedSerializer(ProviderGroupSerializer):
|
||||
class Meta:
|
||||
model = ProviderGroup
|
||||
fields = ["id", "name"]
|
||||
|
||||
|
||||
class ProviderGroupUpdateSerializer(RLSSerializer, BaseWriteSerializer):
|
||||
"""
|
||||
Serializer for updating the ProviderGroup model.
|
||||
Only allows "name" field to be updated.
|
||||
"""
|
||||
|
||||
class Meta:
|
||||
model = ProviderGroup
|
||||
fields = ["id", "name"]
|
||||
|
||||
|
||||
class ProviderGroupMembershipUpdateSerializer(RLSSerializer, BaseWriteSerializer):
|
||||
"""
|
||||
Serializer for modifying provider group memberships
|
||||
"""
|
||||
|
||||
provider_ids = serializers.ListField(
|
||||
child=serializers.UUIDField(),
|
||||
help_text="List of provider UUIDs to add to the group",
|
||||
class ProviderGroupCreateSerializer(ProviderGroupSerializer):
|
||||
providers = serializers.ResourceRelatedField(
|
||||
queryset=Provider.objects.all(), many=True, required=False
|
||||
)
|
||||
roles = serializers.ResourceRelatedField(
|
||||
queryset=Role.objects.all(), many=True, required=False
|
||||
)
|
||||
|
||||
def validate(self, attrs):
|
||||
tenant_id = self.context["tenant_id"]
|
||||
provider_ids = attrs.get("provider_ids", [])
|
||||
class Meta:
|
||||
model = ProviderGroup
|
||||
fields = [
|
||||
"id",
|
||||
"name",
|
||||
"inserted_at",
|
||||
"updated_at",
|
||||
"providers",
|
||||
"roles",
|
||||
"url",
|
||||
]
|
||||
|
||||
existing_provider_ids = set(
|
||||
Provider.objects.filter(
|
||||
id__in=provider_ids, tenant_id=tenant_id
|
||||
).values_list("id", flat=True)
|
||||
def create(self, validated_data):
|
||||
providers = validated_data.pop("providers", [])
|
||||
roles = validated_data.pop("roles", [])
|
||||
tenant_id = self.context.get("tenant_id")
|
||||
provider_group = ProviderGroup.objects.create(
|
||||
tenant_id=tenant_id, **validated_data
|
||||
)
|
||||
provided_provider_ids = set(provider_ids)
|
||||
|
||||
missing_provider_ids = provided_provider_ids - existing_provider_ids
|
||||
|
||||
if missing_provider_ids:
|
||||
raise serializers.ValidationError(
|
||||
{
|
||||
"provider_ids": f"The following provider IDs do not exist: {', '.join(str(id) for id in missing_provider_ids)}"
|
||||
}
|
||||
through_model_instances = [
|
||||
ProviderGroupMembership(
|
||||
provider_group=provider_group,
|
||||
provider=provider,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
for provider in providers
|
||||
]
|
||||
ProviderGroupMembership.objects.bulk_create(through_model_instances)
|
||||
|
||||
return super().validate(attrs)
|
||||
through_model_instances = [
|
||||
RoleProviderGroupRelationship(
|
||||
provider_group=provider_group,
|
||||
role=role,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
for role in roles
|
||||
]
|
||||
RoleProviderGroupRelationship.objects.bulk_create(through_model_instances)
|
||||
|
||||
return provider_group
|
||||
|
||||
|
||||
class ProviderGroupUpdateSerializer(ProviderGroupSerializer):
|
||||
def update(self, instance, validated_data):
|
||||
tenant_id = self.context.get("tenant_id")
|
||||
|
||||
if "providers" in validated_data:
|
||||
providers = validated_data.pop("providers")
|
||||
instance.providers.clear()
|
||||
through_model_instances = [
|
||||
ProviderGroupMembership(
|
||||
provider_group=instance,
|
||||
provider=provider,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
for provider in providers
|
||||
]
|
||||
ProviderGroupMembership.objects.bulk_create(through_model_instances)
|
||||
|
||||
if "roles" in validated_data:
|
||||
roles = validated_data.pop("roles")
|
||||
instance.roles.clear()
|
||||
through_model_instances = [
|
||||
RoleProviderGroupRelationship(
|
||||
provider_group=instance,
|
||||
role=role,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
for role in roles
|
||||
]
|
||||
RoleProviderGroupRelationship.objects.bulk_create(through_model_instances)
|
||||
|
||||
return super().update(instance, validated_data)
|
||||
|
||||
|
||||
class ProviderResourceIdentifierSerializer(serializers.Serializer):
|
||||
resource_type = serializers.CharField(source="type")
|
||||
id = serializers.UUIDField()
|
||||
|
||||
class JSONAPIMeta:
|
||||
resource_name = "provider-identifier"
|
||||
|
||||
def to_representation(self, instance):
|
||||
"""
|
||||
Ensure 'type' is used in the output instead of 'resource_type'.
|
||||
"""
|
||||
representation = super().to_representation(instance)
|
||||
representation["type"] = representation.pop("resource_type", None)
|
||||
return representation
|
||||
|
||||
def to_internal_value(self, data):
|
||||
"""
|
||||
Map 'type' back to 'resource_type' during input.
|
||||
"""
|
||||
data["resource_type"] = data.pop("type", None)
|
||||
return super().to_internal_value(data)
|
||||
|
||||
|
||||
class ProviderGroupMembershipSerializer(RLSSerializer, BaseWriteSerializer):
|
||||
"""
|
||||
Serializer for modifying provider_group memberships
|
||||
"""
|
||||
|
||||
providers = serializers.ListField(
|
||||
child=ProviderResourceIdentifierSerializer(),
|
||||
help_text="List of resource identifier objects representing providers.",
|
||||
)
|
||||
|
||||
def create(self, validated_data):
|
||||
provider_ids = [item["id"] for item in validated_data["providers"]]
|
||||
providers = Provider.objects.filter(id__in=provider_ids)
|
||||
tenant_id = self.context.get("tenant_id")
|
||||
|
||||
new_relationships = [
|
||||
ProviderGroupMembership(
|
||||
provider_group=self.context.get("provider_group"),
|
||||
provider=p,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
for p in providers
|
||||
]
|
||||
ProviderGroupMembership.objects.bulk_create(new_relationships)
|
||||
|
||||
return self.context.get("provider_group")
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
provider_ids = [item["id"] for item in validated_data["providers"]]
|
||||
providers = Provider.objects.filter(id__in=provider_ids)
|
||||
tenant_id = self.context.get("tenant_id")
|
||||
|
||||
instance.providers.clear()
|
||||
new_relationships = [
|
||||
ProviderGroupMembership(
|
||||
provider_group=instance, provider=p, tenant_id=tenant_id
|
||||
)
|
||||
for p in providers
|
||||
]
|
||||
ProviderGroupMembership.objects.bulk_create(new_relationships)
|
||||
|
||||
return instance
|
||||
|
||||
class Meta:
|
||||
model = ProviderGroupMembership
|
||||
fields = ["id", "provider_ids"]
|
||||
fields = ["id", "providers"]
|
||||
|
||||
|
||||
# Providers
|
||||
@@ -1034,6 +1235,8 @@ class InvitationSerializer(RLSSerializer):
|
||||
Serializer for the Invitation model.
|
||||
"""
|
||||
|
||||
roles = serializers.ResourceRelatedField(many=True, queryset=Role.objects.all())
|
||||
|
||||
class Meta:
|
||||
model = Invitation
|
||||
fields = [
|
||||
@@ -1043,6 +1246,7 @@ class InvitationSerializer(RLSSerializer):
|
||||
"email",
|
||||
"state",
|
||||
"token",
|
||||
"roles",
|
||||
"expires_at",
|
||||
"inviter",
|
||||
"url",
|
||||
@@ -1050,6 +1254,8 @@ class InvitationSerializer(RLSSerializer):
|
||||
|
||||
|
||||
class InvitationBaseWriteSerializer(BaseWriteSerializer):
|
||||
roles = serializers.ResourceRelatedField(many=True, queryset=Role.objects.all())
|
||||
|
||||
def validate_email(self, value):
|
||||
user = User.objects.filter(email=value).first()
|
||||
tenant_id = self.context["tenant_id"]
|
||||
@@ -1086,31 +1292,54 @@ class InvitationCreateSerializer(InvitationBaseWriteSerializer, RLSSerializer):
|
||||
|
||||
class Meta:
|
||||
model = Invitation
|
||||
fields = ["email", "expires_at", "state", "token", "inviter"]
|
||||
fields = ["email", "expires_at", "state", "token", "inviter", "roles"]
|
||||
extra_kwargs = {
|
||||
"token": {"read_only": True},
|
||||
"state": {"read_only": True},
|
||||
"inviter": {"read_only": True},
|
||||
"expires_at": {"required": False},
|
||||
"roles": {"required": False},
|
||||
}
|
||||
|
||||
def create(self, validated_data):
|
||||
inviter = self.context.get("request").user
|
||||
tenant_id = self.context.get("tenant_id")
|
||||
validated_data["inviter"] = inviter
|
||||
return super().create(validated_data)
|
||||
roles = validated_data.pop("roles", [])
|
||||
invitation = super().create(validated_data)
|
||||
for role in roles:
|
||||
InvitationRoleRelationship.objects.create(
|
||||
role=role, invitation=invitation, tenant_id=tenant_id
|
||||
)
|
||||
|
||||
return invitation
|
||||
|
||||
|
||||
class InvitationUpdateSerializer(InvitationBaseWriteSerializer):
|
||||
class Meta:
|
||||
model = Invitation
|
||||
fields = ["id", "email", "expires_at", "state", "token"]
|
||||
fields = ["id", "email", "expires_at", "state", "token", "roles"]
|
||||
extra_kwargs = {
|
||||
"token": {"read_only": True},
|
||||
"state": {"read_only": True},
|
||||
"expires_at": {"required": False},
|
||||
"email": {"required": False},
|
||||
"roles": {"required": False},
|
||||
}
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
tenant_id = self.context.get("tenant_id")
|
||||
invitation = super().update(instance, validated_data)
|
||||
if "roles" in validated_data:
|
||||
roles = validated_data.pop("roles")
|
||||
instance.roles.clear()
|
||||
for role in roles:
|
||||
InvitationRoleRelationship.objects.create(
|
||||
role=role, invitation=invitation, tenant_id=tenant_id
|
||||
)
|
||||
|
||||
return invitation
|
||||
|
||||
|
||||
class InvitationAcceptSerializer(RLSSerializer):
|
||||
"""Serializer for accepting an invitation."""
|
||||
@@ -1122,6 +1351,205 @@ class InvitationAcceptSerializer(RLSSerializer):
|
||||
fields = ["invitation_token"]
|
||||
|
||||
|
||||
# Roles
|
||||
|
||||
|
||||
class RoleSerializer(RLSSerializer, BaseWriteSerializer):
|
||||
permission_state = serializers.SerializerMethodField()
|
||||
users = serializers.ResourceRelatedField(
|
||||
queryset=User.objects.all(), many=True, required=False
|
||||
)
|
||||
provider_groups = serializers.ResourceRelatedField(
|
||||
queryset=ProviderGroup.objects.all(), many=True, required=False
|
||||
)
|
||||
|
||||
def get_permission_state(self, obj) -> str:
|
||||
return obj.permission_state
|
||||
|
||||
def validate(self, attrs):
|
||||
if Role.objects.filter(name=attrs.get("name")).exists():
|
||||
raise serializers.ValidationError(
|
||||
{"name": "A role with this name already exists."}
|
||||
)
|
||||
|
||||
if attrs.get("manage_providers"):
|
||||
attrs["unlimited_visibility"] = True
|
||||
|
||||
# Prevent updates to the admin role
|
||||
if getattr(self.instance, "name", None) == "admin":
|
||||
raise serializers.ValidationError(
|
||||
{"name": "The admin role cannot be updated."}
|
||||
)
|
||||
|
||||
return super().validate(attrs)
|
||||
|
||||
class Meta:
|
||||
model = Role
|
||||
fields = [
|
||||
"id",
|
||||
"name",
|
||||
"manage_users",
|
||||
"manage_account",
|
||||
"manage_billing",
|
||||
"manage_providers",
|
||||
"manage_integrations",
|
||||
"manage_scans",
|
||||
"permission_state",
|
||||
"unlimited_visibility",
|
||||
"inserted_at",
|
||||
"updated_at",
|
||||
"provider_groups",
|
||||
"users",
|
||||
"invitations",
|
||||
"url",
|
||||
]
|
||||
extra_kwargs = {
|
||||
"id": {"read_only": True},
|
||||
"inserted_at": {"read_only": True},
|
||||
"updated_at": {"read_only": True},
|
||||
"url": {"read_only": True},
|
||||
}
|
||||
|
||||
|
||||
class RoleCreateSerializer(RoleSerializer):
|
||||
provider_groups = serializers.ResourceRelatedField(
|
||||
many=True, queryset=ProviderGroup.objects.all(), required=False
|
||||
)
|
||||
users = serializers.ResourceRelatedField(
|
||||
many=True, queryset=User.objects.all(), required=False
|
||||
)
|
||||
|
||||
def create(self, validated_data):
|
||||
provider_groups = validated_data.pop("provider_groups", [])
|
||||
users = validated_data.pop("users", [])
|
||||
tenant_id = self.context.get("tenant_id")
|
||||
role = Role.objects.create(tenant_id=tenant_id, **validated_data)
|
||||
|
||||
through_model_instances = [
|
||||
RoleProviderGroupRelationship(
|
||||
role=role,
|
||||
provider_group=provider_group,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
for provider_group in provider_groups
|
||||
]
|
||||
RoleProviderGroupRelationship.objects.bulk_create(through_model_instances)
|
||||
|
||||
through_model_instances = [
|
||||
UserRoleRelationship(
|
||||
role=role,
|
||||
user=user,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
for user in users
|
||||
]
|
||||
UserRoleRelationship.objects.bulk_create(through_model_instances)
|
||||
|
||||
return role
|
||||
|
||||
|
||||
class RoleUpdateSerializer(RoleSerializer):
|
||||
def update(self, instance, validated_data):
|
||||
tenant_id = self.context.get("tenant_id")
|
||||
|
||||
if "provider_groups" in validated_data:
|
||||
provider_groups = validated_data.pop("provider_groups")
|
||||
instance.provider_groups.clear()
|
||||
through_model_instances = [
|
||||
RoleProviderGroupRelationship(
|
||||
role=instance,
|
||||
provider_group=provider_group,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
for provider_group in provider_groups
|
||||
]
|
||||
RoleProviderGroupRelationship.objects.bulk_create(through_model_instances)
|
||||
|
||||
if "users" in validated_data:
|
||||
users = validated_data.pop("users")
|
||||
instance.users.clear()
|
||||
through_model_instances = [
|
||||
UserRoleRelationship(
|
||||
role=instance,
|
||||
user=user,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
for user in users
|
||||
]
|
||||
UserRoleRelationship.objects.bulk_create(through_model_instances)
|
||||
|
||||
return super().update(instance, validated_data)
|
||||
|
||||
|
||||
class ProviderGroupResourceIdentifierSerializer(serializers.Serializer):
|
||||
resource_type = serializers.CharField(source="type")
|
||||
id = serializers.UUIDField()
|
||||
|
||||
class JSONAPIMeta:
|
||||
resource_name = "provider-group-identifier"
|
||||
|
||||
def to_representation(self, instance):
|
||||
"""
|
||||
Ensure 'type' is used in the output instead of 'resource_type'.
|
||||
"""
|
||||
representation = super().to_representation(instance)
|
||||
representation["type"] = representation.pop("resource_type", None)
|
||||
return representation
|
||||
|
||||
def to_internal_value(self, data):
|
||||
"""
|
||||
Map 'type' back to 'resource_type' during input.
|
||||
"""
|
||||
data["resource_type"] = data.pop("type", None)
|
||||
return super().to_internal_value(data)
|
||||
|
||||
|
||||
class RoleProviderGroupRelationshipSerializer(RLSSerializer, BaseWriteSerializer):
|
||||
"""
|
||||
Serializer for modifying role memberships
|
||||
"""
|
||||
|
||||
provider_groups = serializers.ListField(
|
||||
child=ProviderGroupResourceIdentifierSerializer(),
|
||||
help_text="List of resource identifier objects representing provider groups.",
|
||||
)
|
||||
|
||||
def create(self, validated_data):
|
||||
provider_group_ids = [item["id"] for item in validated_data["provider_groups"]]
|
||||
provider_groups = ProviderGroup.objects.filter(id__in=provider_group_ids)
|
||||
tenant_id = self.context.get("tenant_id")
|
||||
|
||||
new_relationships = [
|
||||
RoleProviderGroupRelationship(
|
||||
role=self.context.get("role"), provider_group=pg, tenant_id=tenant_id
|
||||
)
|
||||
for pg in provider_groups
|
||||
]
|
||||
RoleProviderGroupRelationship.objects.bulk_create(new_relationships)
|
||||
|
||||
return self.context.get("role")
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
provider_group_ids = [item["id"] for item in validated_data["provider_groups"]]
|
||||
provider_groups = ProviderGroup.objects.filter(id__in=provider_group_ids)
|
||||
tenant_id = self.context.get("tenant_id")
|
||||
|
||||
instance.provider_groups.clear()
|
||||
new_relationships = [
|
||||
RoleProviderGroupRelationship(
|
||||
role=instance, provider_group=pg, tenant_id=tenant_id
|
||||
)
|
||||
for pg in provider_groups
|
||||
]
|
||||
RoleProviderGroupRelationship.objects.bulk_create(new_relationships)
|
||||
|
||||
return instance
|
||||
|
||||
class Meta:
|
||||
model = RoleProviderGroupRelationship
|
||||
fields = ["id", "provider_groups"]
|
||||
|
||||
|
||||
# Compliance overview
|
||||
|
||||
|
||||
@@ -1334,6 +1762,24 @@ class OverviewSeveritySerializer(serializers.Serializer):
|
||||
return {"version": "v1"}
|
||||
|
||||
|
||||
class OverviewServiceSerializer(serializers.Serializer):
|
||||
id = serializers.CharField(source="service")
|
||||
total = serializers.IntegerField()
|
||||
_pass = serializers.IntegerField()
|
||||
fail = serializers.IntegerField()
|
||||
muted = serializers.IntegerField()
|
||||
|
||||
class JSONAPIMeta:
|
||||
resource_name = "services-overview"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.fields["pass"] = self.fields.pop("_pass")
|
||||
|
||||
def get_root_meta(self, _resource, _many):
|
||||
return {"version": "v1"}
|
||||
|
||||
|
||||
# Schedules
|
||||
|
||||
|
||||
|
||||
@@ -3,16 +3,20 @@ from drf_spectacular.views import SpectacularRedocView
|
||||
from rest_framework_nested import routers
|
||||
|
||||
from api.v1.views import (
|
||||
ComplianceOverviewViewSet,
|
||||
CustomTokenObtainView,
|
||||
CustomTokenRefreshView,
|
||||
FindingViewSet,
|
||||
InvitationAcceptViewSet,
|
||||
InvitationViewSet,
|
||||
MembershipViewSet,
|
||||
OverviewViewSet,
|
||||
ProviderGroupViewSet,
|
||||
ProviderGroupProvidersRelationshipView,
|
||||
ProviderSecretViewSet,
|
||||
InvitationViewSet,
|
||||
InvitationAcceptViewSet,
|
||||
RoleViewSet,
|
||||
RoleProviderGroupRelationshipView,
|
||||
UserRoleRelationshipView,
|
||||
OverviewViewSet,
|
||||
ComplianceOverviewViewSet,
|
||||
ProviderViewSet,
|
||||
ResourceViewSet,
|
||||
ScanViewSet,
|
||||
@@ -29,11 +33,12 @@ router = routers.DefaultRouter(trailing_slash=False)
|
||||
router.register(r"users", UserViewSet, basename="user")
|
||||
router.register(r"tenants", TenantViewSet, basename="tenant")
|
||||
router.register(r"providers", ProviderViewSet, basename="provider")
|
||||
router.register(r"provider_groups", ProviderGroupViewSet, basename="providergroup")
|
||||
router.register(r"provider-groups", ProviderGroupViewSet, basename="providergroup")
|
||||
router.register(r"scans", ScanViewSet, basename="scan")
|
||||
router.register(r"tasks", TaskViewSet, basename="task")
|
||||
router.register(r"resources", ResourceViewSet, basename="resource")
|
||||
router.register(r"findings", FindingViewSet, basename="finding")
|
||||
router.register(r"roles", RoleViewSet, basename="role")
|
||||
router.register(
|
||||
r"compliance-overviews", ComplianceOverviewViewSet, basename="complianceoverview"
|
||||
)
|
||||
@@ -80,6 +85,27 @@ urlpatterns = [
|
||||
InvitationAcceptViewSet.as_view({"post": "accept"}),
|
||||
name="invitation-accept",
|
||||
),
|
||||
path(
|
||||
"roles/<uuid:pk>/relationships/provider_groups",
|
||||
RoleProviderGroupRelationshipView.as_view(
|
||||
{"post": "create", "patch": "partial_update", "delete": "destroy"}
|
||||
),
|
||||
name="role-provider-groups-relationship",
|
||||
),
|
||||
path(
|
||||
"users/<uuid:pk>/relationships/roles",
|
||||
UserRoleRelationshipView.as_view(
|
||||
{"post": "create", "patch": "partial_update", "delete": "destroy"}
|
||||
),
|
||||
name="user-roles-relationship",
|
||||
),
|
||||
path(
|
||||
"provider-groups/<uuid:pk>/relationships/providers",
|
||||
ProviderGroupProvidersRelationshipView.as_view(
|
||||
{"post": "create", "patch": "partial_update", "delete": "destroy"}
|
||||
),
|
||||
name="provider_group-providers-relationship",
|
||||
),
|
||||
path("", include(router.urls)),
|
||||
path("", include(tenants_router.urls)),
|
||||
path("", include(users_router.urls)),
|
||||
|
||||
@@ -16,6 +16,7 @@ from drf_spectacular.utils import (
|
||||
extend_schema_view,
|
||||
)
|
||||
from drf_spectacular.views import SpectacularAPIView
|
||||
from drf_spectacular_jsonapi.schemas.openapi import JsonApiAutoSchema
|
||||
from rest_framework import permissions, status
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.exceptions import (
|
||||
@@ -25,7 +26,8 @@ from rest_framework.exceptions import (
|
||||
ValidationError,
|
||||
)
|
||||
from rest_framework.generics import GenericAPIView, get_object_or_404
|
||||
from rest_framework_json_api.views import Response
|
||||
from rest_framework.permissions import SAFE_METHODS
|
||||
from rest_framework_json_api.views import RelationshipView, Response
|
||||
from rest_framework_simplejwt.exceptions import InvalidToken, TokenError
|
||||
from tasks.beat import schedule_provider_scan
|
||||
from tasks.tasks import (
|
||||
@@ -47,8 +49,10 @@ from api.filters import (
|
||||
ProviderGroupFilter,
|
||||
ProviderSecretFilter,
|
||||
ResourceFilter,
|
||||
RoleFilter,
|
||||
ScanFilter,
|
||||
ScanSummaryFilter,
|
||||
ServiceOverviewFilter,
|
||||
TaskFilter,
|
||||
TenantFilter,
|
||||
UserFilter,
|
||||
@@ -63,6 +67,8 @@ from api.models import (
|
||||
ProviderGroupMembership,
|
||||
ProviderSecret,
|
||||
Resource,
|
||||
Role,
|
||||
RoleProviderGroupRelationship,
|
||||
Scan,
|
||||
ScanSummary,
|
||||
SeverityChoices,
|
||||
@@ -70,8 +76,10 @@ from api.models import (
|
||||
StatusChoices,
|
||||
Task,
|
||||
User,
|
||||
UserRoleRelationship,
|
||||
)
|
||||
from api.pagination import ComplianceOverviewPagination
|
||||
from api.rbac.permissions import Permissions, get_providers, get_role
|
||||
from api.rls import Tenant
|
||||
from api.utils import validate_invitation
|
||||
from api.uuid_utils import datetime_to_uuid7
|
||||
@@ -87,10 +95,12 @@ from api.v1.serializers import (
|
||||
MembershipSerializer,
|
||||
OverviewFindingSerializer,
|
||||
OverviewProviderSerializer,
|
||||
OverviewServiceSerializer,
|
||||
OverviewSeveritySerializer,
|
||||
ProviderCreateSerializer,
|
||||
ProviderGroupMembershipUpdateSerializer,
|
||||
ProviderGroupMembershipSerializer,
|
||||
ProviderGroupSerializer,
|
||||
ProviderGroupCreateSerializer,
|
||||
ProviderGroupUpdateSerializer,
|
||||
ProviderSecretCreateSerializer,
|
||||
ProviderSecretSerializer,
|
||||
@@ -98,6 +108,10 @@ from api.v1.serializers import (
|
||||
ProviderSerializer,
|
||||
ProviderUpdateSerializer,
|
||||
ResourceSerializer,
|
||||
RoleCreateSerializer,
|
||||
RoleProviderGroupRelationshipSerializer,
|
||||
RoleSerializer,
|
||||
RoleUpdateSerializer,
|
||||
ScanCreateSerializer,
|
||||
ScanSerializer,
|
||||
ScanUpdateSerializer,
|
||||
@@ -107,6 +121,7 @@ from api.v1.serializers import (
|
||||
TokenRefreshSerializer,
|
||||
TokenSerializer,
|
||||
UserCreateSerializer,
|
||||
UserRoleRelationshipSerializer,
|
||||
UserSerializer,
|
||||
UserUpdateSerializer,
|
||||
)
|
||||
@@ -117,6 +132,11 @@ CACHE_DECORATOR = cache_control(
|
||||
)
|
||||
|
||||
|
||||
class RelationshipViewSchema(JsonApiAutoSchema):
|
||||
def _resolve_path_parameters(self, _path_variables):
|
||||
return []
|
||||
|
||||
|
||||
@extend_schema(
|
||||
tags=["Token"],
|
||||
summary="Obtain a token",
|
||||
@@ -172,7 +192,7 @@ class SchemaView(SpectacularAPIView):
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
spectacular_settings.TITLE = "Prowler API"
|
||||
spectacular_settings.VERSION = "1.0.1"
|
||||
spectacular_settings.VERSION = "1.1.0"
|
||||
spectacular_settings.DESCRIPTION = (
|
||||
"Prowler API specification.\n\nThis file is auto-generated."
|
||||
)
|
||||
@@ -271,6 +291,19 @@ class UserViewSet(BaseUserViewset):
|
||||
filterset_class = UserFilter
|
||||
ordering = ["-date_joined"]
|
||||
ordering_fields = ["name", "email", "company_name", "date_joined", "is_active"]
|
||||
# RBAC required permissions
|
||||
required_permissions = [Permissions.MANAGE_USERS]
|
||||
|
||||
def set_required_permissions(self):
|
||||
"""
|
||||
Returns the required permissions based on the request method.
|
||||
"""
|
||||
if self.action == "me":
|
||||
# No permissions required for me request
|
||||
self.required_permissions = []
|
||||
else:
|
||||
# Require permission for the rest of the requests
|
||||
self.required_permissions = [Permissions.MANAGE_USERS]
|
||||
|
||||
def get_queryset(self):
|
||||
# If called during schema generation, return an empty queryset
|
||||
@@ -347,11 +380,125 @@ class UserViewSet(BaseUserViewset):
|
||||
user=user, tenant=tenant, role=role
|
||||
)
|
||||
if invitation:
|
||||
user_role = []
|
||||
for role in invitation.roles.all():
|
||||
user_role.append(
|
||||
UserRoleRelationship.objects.using(MainRouter.admin_db).create(
|
||||
user=user, role=role, tenant=invitation.tenant
|
||||
)
|
||||
)
|
||||
invitation.state = Invitation.State.ACCEPTED
|
||||
invitation.save(using=MainRouter.admin_db)
|
||||
else:
|
||||
role = Role.objects.using(MainRouter.admin_db).create(
|
||||
name="admin",
|
||||
tenant_id=tenant.id,
|
||||
manage_users=True,
|
||||
manage_account=True,
|
||||
manage_billing=True,
|
||||
manage_providers=True,
|
||||
manage_integrations=True,
|
||||
manage_scans=True,
|
||||
unlimited_visibility=True,
|
||||
)
|
||||
UserRoleRelationship.objects.using(MainRouter.admin_db).create(
|
||||
user=user,
|
||||
role=role,
|
||||
tenant_id=tenant.id,
|
||||
)
|
||||
return Response(data=UserSerializer(user).data, status=status.HTTP_201_CREATED)
|
||||
|
||||
|
||||
@extend_schema_view(
|
||||
create=extend_schema(
|
||||
tags=["User"],
|
||||
summary="Create a new user-roles relationship",
|
||||
description="Add a new user-roles relationship to the system by providing the required user-roles details.",
|
||||
responses={
|
||||
204: OpenApiResponse(description="Relationship created successfully"),
|
||||
400: OpenApiResponse(
|
||||
description="Bad request (e.g., relationship already exists)"
|
||||
),
|
||||
},
|
||||
),
|
||||
partial_update=extend_schema(
|
||||
tags=["User"],
|
||||
summary="Partially update a user-roles relationship",
|
||||
description="Update the user-roles relationship information without affecting other fields.",
|
||||
responses={
|
||||
204: OpenApiResponse(
|
||||
response=None, description="Relationship updated successfully"
|
||||
)
|
||||
},
|
||||
),
|
||||
destroy=extend_schema(
|
||||
tags=["User"],
|
||||
summary="Delete a user-roles relationship",
|
||||
description="Remove the user-roles relationship from the system by their ID.",
|
||||
responses={
|
||||
204: OpenApiResponse(
|
||||
response=None, description="Relationship deleted successfully"
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
class UserRoleRelationshipView(RelationshipView, BaseRLSViewSet):
|
||||
queryset = User.objects.all()
|
||||
serializer_class = UserRoleRelationshipSerializer
|
||||
resource_name = "roles"
|
||||
http_method_names = ["post", "patch", "delete"]
|
||||
schema = RelationshipViewSchema()
|
||||
# RBAC required permissions
|
||||
required_permissions = [Permissions.MANAGE_USERS]
|
||||
|
||||
def get_queryset(self):
|
||||
return User.objects.all()
|
||||
|
||||
def create(self, request, *args, **kwargs):
|
||||
user = self.get_object()
|
||||
|
||||
role_ids = [item["id"] for item in request.data]
|
||||
existing_relationships = UserRoleRelationship.objects.filter(
|
||||
user=user, role_id__in=role_ids
|
||||
)
|
||||
|
||||
if existing_relationships.exists():
|
||||
return Response(
|
||||
{"detail": "One or more roles are already associated with the user."},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
serializer = self.get_serializer(
|
||||
data={"roles": request.data},
|
||||
context={
|
||||
"user": user,
|
||||
"tenant_id": self.request.tenant_id,
|
||||
"request": request,
|
||||
},
|
||||
)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
serializer.save()
|
||||
|
||||
return Response(status=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
def partial_update(self, request, *args, **kwargs):
|
||||
user = self.get_object()
|
||||
serializer = self.get_serializer(
|
||||
instance=user,
|
||||
data={"roles": request.data},
|
||||
context={"tenant_id": self.request.tenant_id, "request": request},
|
||||
)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
serializer.save()
|
||||
return Response(status=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
def destroy(self, request, *args, **kwargs):
|
||||
user = self.get_object()
|
||||
user.roles.clear()
|
||||
|
||||
return Response(status=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
|
||||
@extend_schema_view(
|
||||
list=extend_schema(
|
||||
tags=["Tenant"],
|
||||
@@ -389,6 +536,8 @@ class TenantViewSet(BaseTenantViewset):
|
||||
search_fields = ["name"]
|
||||
ordering = ["-inserted_at"]
|
||||
ordering_fields = ["name", "inserted_at", "updated_at"]
|
||||
# RBAC required permissions
|
||||
required_permissions = [Permissions.MANAGE_ACCOUNT]
|
||||
|
||||
def get_queryset(self):
|
||||
return Tenant.objects.all()
|
||||
@@ -446,6 +595,8 @@ class MembershipViewSet(BaseTenantViewset):
|
||||
"role",
|
||||
"date_joined",
|
||||
]
|
||||
# RBAC required permissions
|
||||
required_permissions = [Permissions.MANAGE_ACCOUNT]
|
||||
|
||||
def get_queryset(self):
|
||||
user = self.request.user
|
||||
@@ -479,6 +630,8 @@ class TenantMembersViewSet(BaseTenantViewset):
|
||||
http_method_names = ["get", "delete"]
|
||||
serializer_class = MembershipSerializer
|
||||
queryset = Membership.objects.none()
|
||||
# RBAC required permissions
|
||||
required_permissions = [Permissions.MANAGE_ACCOUNT]
|
||||
|
||||
def get_queryset(self):
|
||||
tenant = self.get_tenant()
|
||||
@@ -562,66 +715,128 @@ class ProviderGroupViewSet(BaseRLSViewSet):
|
||||
queryset = ProviderGroup.objects.all()
|
||||
serializer_class = ProviderGroupSerializer
|
||||
filterset_class = ProviderGroupFilter
|
||||
http_method_names = ["get", "post", "patch", "put", "delete"]
|
||||
http_method_names = ["get", "post", "patch", "delete"]
|
||||
ordering = ["inserted_at"]
|
||||
# RBAC required permissions
|
||||
required_permissions = [Permissions.MANAGE_PROVIDERS]
|
||||
|
||||
def set_required_permissions(self):
|
||||
"""
|
||||
Returns the required permissions based on the request method.
|
||||
"""
|
||||
if self.request.method in SAFE_METHODS:
|
||||
# No permissions required for GET requests
|
||||
self.required_permissions = []
|
||||
else:
|
||||
# Require permission for non-GET requests
|
||||
self.required_permissions = [Permissions.MANAGE_PROVIDERS]
|
||||
|
||||
def get_queryset(self):
|
||||
return ProviderGroup.objects.prefetch_related("providers")
|
||||
user_roles = get_role(self.request.user)
|
||||
# Check if any of the user's roles have UNLIMITED_VISIBILITY
|
||||
if user_roles.unlimited_visibility:
|
||||
# User has unlimited visibility, return all provider groups
|
||||
return ProviderGroup.objects.prefetch_related("providers")
|
||||
|
||||
# Collect provider groups associated with the user's roles
|
||||
return user_roles.provider_groups.all()
|
||||
|
||||
def get_serializer_class(self):
|
||||
if self.action == "partial_update":
|
||||
if self.action == "create":
|
||||
return ProviderGroupCreateSerializer
|
||||
elif self.action == "partial_update":
|
||||
return ProviderGroupUpdateSerializer
|
||||
elif self.action == "providers":
|
||||
if hasattr(self, "response_serializer_class"):
|
||||
return self.response_serializer_class
|
||||
return ProviderGroupMembershipUpdateSerializer
|
||||
return super().get_serializer_class()
|
||||
|
||||
@extend_schema(
|
||||
tags=["Provider Group"],
|
||||
summary="Add providers to a provider group",
|
||||
description="Add one or more providers to an existing provider group.",
|
||||
request=ProviderGroupMembershipUpdateSerializer,
|
||||
responses={200: OpenApiResponse(response=ProviderGroupSerializer)},
|
||||
)
|
||||
@action(detail=True, methods=["put"], url_name="providers")
|
||||
def providers(self, request, pk=None):
|
||||
|
||||
@extend_schema(tags=["Provider Group"])
|
||||
@extend_schema_view(
|
||||
create=extend_schema(
|
||||
summary="Create a new provider_group-providers relationship",
|
||||
description="Add a new provider_group-providers relationship to the system by providing the required provider_group-providers details.",
|
||||
responses={
|
||||
204: OpenApiResponse(description="Relationship created successfully"),
|
||||
400: OpenApiResponse(
|
||||
description="Bad request (e.g., relationship already exists)"
|
||||
),
|
||||
},
|
||||
),
|
||||
partial_update=extend_schema(
|
||||
summary="Partially update a provider_group-providers relationship",
|
||||
description="Update the provider_group-providers relationship information without affecting other fields.",
|
||||
responses={
|
||||
204: OpenApiResponse(
|
||||
response=None, description="Relationship updated successfully"
|
||||
)
|
||||
},
|
||||
),
|
||||
destroy=extend_schema(
|
||||
summary="Delete a provider_group-providers relationship",
|
||||
description="Remove the provider_group-providers relationship from the system by their ID.",
|
||||
responses={
|
||||
204: OpenApiResponse(
|
||||
response=None, description="Relationship deleted successfully"
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
class ProviderGroupProvidersRelationshipView(RelationshipView, BaseRLSViewSet):
|
||||
queryset = ProviderGroup.objects.all()
|
||||
serializer_class = ProviderGroupMembershipSerializer
|
||||
resource_name = "providers"
|
||||
http_method_names = ["post", "patch", "delete"]
|
||||
schema = RelationshipViewSchema()
|
||||
# RBAC required permissions
|
||||
required_permissions = [Permissions.MANAGE_PROVIDERS]
|
||||
|
||||
def get_queryset(self):
|
||||
return ProviderGroup.objects.all()
|
||||
|
||||
def create(self, request, *args, **kwargs):
|
||||
provider_group = self.get_object()
|
||||
|
||||
# Validate input data
|
||||
serializer = self.get_serializer_class()(
|
||||
data=request.data,
|
||||
context=self.get_serializer_context(),
|
||||
provider_ids = [item["id"] for item in request.data]
|
||||
existing_relationships = ProviderGroupMembership.objects.filter(
|
||||
provider_group=provider_group, provider_id__in=provider_ids
|
||||
)
|
||||
|
||||
if existing_relationships.exists():
|
||||
return Response(
|
||||
{
|
||||
"detail": "One or more providers are already associated with the provider_group."
|
||||
},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
serializer = self.get_serializer(
|
||||
data={"providers": request.data},
|
||||
context={
|
||||
"provider_group": provider_group,
|
||||
"tenant_id": self.request.tenant_id,
|
||||
"request": request,
|
||||
},
|
||||
)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
serializer.save()
|
||||
|
||||
provider_ids = serializer.validated_data["provider_ids"]
|
||||
return Response(status=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
# Update memberships
|
||||
ProviderGroupMembership.objects.filter(
|
||||
provider_group=provider_group, tenant_id=request.tenant_id
|
||||
).delete()
|
||||
|
||||
provider_group_memberships = [
|
||||
ProviderGroupMembership(
|
||||
tenant_id=self.request.tenant_id,
|
||||
provider_group=provider_group,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
for provider_id in provider_ids
|
||||
]
|
||||
|
||||
ProviderGroupMembership.objects.bulk_create(
|
||||
provider_group_memberships, ignore_conflicts=True
|
||||
def partial_update(self, request, *args, **kwargs):
|
||||
provider_group = self.get_object()
|
||||
serializer = self.get_serializer(
|
||||
instance=provider_group,
|
||||
data={"providers": request.data},
|
||||
context={"tenant_id": self.request.tenant_id, "request": request},
|
||||
)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
serializer.save()
|
||||
return Response(status=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
# Return the updated provider group with providers
|
||||
provider_group.refresh_from_db()
|
||||
self.response_serializer_class = ProviderGroupSerializer
|
||||
response_serializer = ProviderGroupSerializer(
|
||||
provider_group, context=self.get_serializer_context()
|
||||
)
|
||||
return Response(data=response_serializer.data, status=status.HTTP_200_OK)
|
||||
def destroy(self, request, *args, **kwargs):
|
||||
provider_group = self.get_object()
|
||||
provider_group.providers.clear()
|
||||
|
||||
return Response(status=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
|
||||
@extend_schema_view(
|
||||
@@ -671,9 +886,28 @@ class ProviderViewSet(BaseRLSViewSet):
|
||||
"inserted_at",
|
||||
"updated_at",
|
||||
]
|
||||
# RBAC required permissions
|
||||
required_permissions = [Permissions.MANAGE_PROVIDERS]
|
||||
|
||||
def set_required_permissions(self):
|
||||
"""
|
||||
Returns the required permissions based on the request method.
|
||||
"""
|
||||
if self.request.method in SAFE_METHODS:
|
||||
# No permissions required for GET requests
|
||||
self.required_permissions = []
|
||||
else:
|
||||
# Require permission for non-GET requests
|
||||
self.required_permissions = [Permissions.MANAGE_PROVIDERS]
|
||||
|
||||
def get_queryset(self):
|
||||
return Provider.objects.all()
|
||||
user_roles = get_role(self.request.user)
|
||||
if user_roles.unlimited_visibility:
|
||||
# User has unlimited visibility, return all providers
|
||||
return Provider.objects.all()
|
||||
|
||||
# User lacks permission, filter providers based on provider groups associated with the role
|
||||
return get_providers(user_roles)
|
||||
|
||||
def get_serializer_class(self):
|
||||
if self.action == "create":
|
||||
@@ -793,9 +1027,28 @@ class ScanViewSet(BaseRLSViewSet):
|
||||
"inserted_at",
|
||||
"updated_at",
|
||||
]
|
||||
# RBAC required permissions
|
||||
required_permissions = [Permissions.MANAGE_SCANS]
|
||||
|
||||
def set_required_permissions(self):
|
||||
"""
|
||||
Returns the required permissions based on the request method.
|
||||
"""
|
||||
if self.request.method in SAFE_METHODS:
|
||||
# No permissions required for GET requests
|
||||
self.required_permissions = [Permissions.MANAGE_PROVIDERS]
|
||||
else:
|
||||
# Require permission for non-GET requests
|
||||
self.required_permissions = [Permissions.MANAGE_SCANS]
|
||||
|
||||
def get_queryset(self):
|
||||
return Scan.objects.all()
|
||||
user_roles = get_role(self.request.user)
|
||||
if user_roles.unlimited_visibility:
|
||||
# User has unlimited visibility, return all scans
|
||||
return Scan.objects.all()
|
||||
|
||||
# User lacks permission, filter providers based on provider groups associated with the role
|
||||
return Scan.objects.filter(provider__in=get_providers(user_roles))
|
||||
|
||||
def get_serializer_class(self):
|
||||
if self.action == "create":
|
||||
@@ -885,10 +1138,13 @@ class TaskViewSet(BaseRLSViewSet):
|
||||
search_fields = ["name"]
|
||||
ordering = ["-inserted_at"]
|
||||
ordering_fields = ["inserted_at", "completed_at", "name", "state"]
|
||||
# RBAC required permissions
|
||||
required_permissions = []
|
||||
|
||||
def get_queryset(self):
|
||||
return Task.objects.annotate(
|
||||
name=F("task_runner_task__task_name"), state=F("task_runner_task__status")
|
||||
name=F("task_runner_task__task_name"),
|
||||
state=F("task_runner_task__status"),
|
||||
)
|
||||
|
||||
def destroy(self, request, *args, pk=None, **kwargs):
|
||||
@@ -950,11 +1206,19 @@ class ResourceViewSet(BaseRLSViewSet):
|
||||
"inserted_at",
|
||||
"updated_at",
|
||||
]
|
||||
# RBAC required permissions (implicit -> MANAGE_PROVIDERS enable unlimited visibility or check the visibility of the provider through the provider group)
|
||||
required_permissions = []
|
||||
|
||||
def get_queryset(self):
|
||||
queryset = Resource.objects.all()
|
||||
search_value = self.request.query_params.get("filter[search]", None)
|
||||
user_roles = get_role(self.request.user)
|
||||
if user_roles.unlimited_visibility:
|
||||
# User has unlimited visibility, return all scans
|
||||
queryset = Resource.objects.all()
|
||||
else:
|
||||
# User lacks permission, filter providers based on provider groups associated with the role
|
||||
queryset = Resource.objects.filter(provider__in=get_providers(user_roles))
|
||||
|
||||
search_value = self.request.query_params.get("filter[search]", None)
|
||||
if search_value:
|
||||
# Django's ORM will build a LEFT JOIN and OUTER JOIN on the "through" table, resulting in duplicates
|
||||
# The duplicates then require a `distinct` query
|
||||
@@ -1025,11 +1289,8 @@ class FindingViewSet(BaseRLSViewSet):
|
||||
"inserted_at",
|
||||
"updated_at",
|
||||
]
|
||||
|
||||
def inserted_at_to_uuidv7(self, inserted_at):
|
||||
if inserted_at is None:
|
||||
return None
|
||||
return datetime_to_uuid7(inserted_at)
|
||||
# RBAC required permissions (implicit -> MANAGE_PROVIDERS enable unlimited visibility or check the visibility of the provider through the provider group)
|
||||
required_permissions = []
|
||||
|
||||
def get_serializer_class(self):
|
||||
if self.action == "findings_services_regions":
|
||||
@@ -1038,9 +1299,17 @@ class FindingViewSet(BaseRLSViewSet):
|
||||
return super().get_serializer_class()
|
||||
|
||||
def get_queryset(self):
|
||||
queryset = Finding.objects.all()
|
||||
search_value = self.request.query_params.get("filter[search]", None)
|
||||
user_roles = get_role(self.request.user)
|
||||
if user_roles.unlimited_visibility:
|
||||
# User has unlimited visibility, return all scans
|
||||
queryset = Finding.objects.all()
|
||||
else:
|
||||
# User lacks permission, filter providers based on provider groups associated with the role
|
||||
queryset = Finding.objects.filter(
|
||||
scan__provider__in=get_providers(user_roles)
|
||||
)
|
||||
|
||||
search_value = self.request.query_params.get("filter[search]", None)
|
||||
if search_value:
|
||||
# Django's ORM will build a LEFT JOIN and OUTER JOIN on any "through" tables, resulting in duplicates
|
||||
# The duplicates then require a `distinct` query
|
||||
@@ -1068,6 +1337,11 @@ class FindingViewSet(BaseRLSViewSet):
|
||||
|
||||
return queryset
|
||||
|
||||
def inserted_at_to_uuidv7(self, inserted_at):
|
||||
if inserted_at is None:
|
||||
return None
|
||||
return datetime_to_uuid7(inserted_at)
|
||||
|
||||
@action(detail=False, methods=["get"], url_name="findings_services_regions")
|
||||
def findings_services_regions(self, request):
|
||||
queryset = self.get_queryset()
|
||||
@@ -1131,6 +1405,8 @@ class ProviderSecretViewSet(BaseRLSViewSet):
|
||||
"inserted_at",
|
||||
"updated_at",
|
||||
]
|
||||
# RBAC required permissions
|
||||
required_permissions = [Permissions.MANAGE_PROVIDERS]
|
||||
|
||||
def get_queryset(self):
|
||||
return ProviderSecret.objects.all()
|
||||
@@ -1188,6 +1464,8 @@ class InvitationViewSet(BaseRLSViewSet):
|
||||
"state",
|
||||
"inviter",
|
||||
]
|
||||
# RBAC required permissions
|
||||
required_permissions = [Permissions.MANAGE_ACCOUNT]
|
||||
|
||||
def get_queryset(self):
|
||||
return Invitation.objects.all()
|
||||
@@ -1275,6 +1553,13 @@ class InvitationAcceptViewSet(BaseRLSViewSet):
|
||||
user=user,
|
||||
tenant=invitation.tenant,
|
||||
)
|
||||
user_role = []
|
||||
for role in invitation.roles.all():
|
||||
user_role.append(
|
||||
UserRoleRelationship.objects.using(MainRouter.admin_db).create(
|
||||
user=user, role=role, tenant=invitation.tenant
|
||||
)
|
||||
)
|
||||
invitation.state = Invitation.State.ACCEPTED
|
||||
invitation.save(using=MainRouter.admin_db)
|
||||
|
||||
@@ -1283,6 +1568,154 @@ class InvitationAcceptViewSet(BaseRLSViewSet):
|
||||
return Response(data=membership_serializer.data, status=status.HTTP_201_CREATED)
|
||||
|
||||
|
||||
@extend_schema(tags=["Role"])
|
||||
@extend_schema_view(
|
||||
list=extend_schema(
|
||||
tags=["Role"],
|
||||
summary="List all roles",
|
||||
description="Retrieve a list of all roles with options for filtering by various criteria.",
|
||||
),
|
||||
retrieve=extend_schema(
|
||||
tags=["Role"],
|
||||
summary="Retrieve data from a role",
|
||||
description="Fetch detailed information about a specific role by their ID.",
|
||||
),
|
||||
create=extend_schema(
|
||||
tags=["Role"],
|
||||
summary="Create a new role",
|
||||
description="Add a new role to the system by providing the required role details.",
|
||||
),
|
||||
partial_update=extend_schema(
|
||||
tags=["Role"],
|
||||
summary="Partially update a role",
|
||||
description="Update certain fields of an existing role's information without affecting other fields.",
|
||||
responses={200: RoleSerializer},
|
||||
),
|
||||
destroy=extend_schema(
|
||||
tags=["Role"],
|
||||
summary="Delete a role",
|
||||
description="Remove a role from the system by their ID.",
|
||||
),
|
||||
)
|
||||
class RoleViewSet(BaseRLSViewSet):
|
||||
queryset = Role.objects.all()
|
||||
serializer_class = RoleSerializer
|
||||
filterset_class = RoleFilter
|
||||
http_method_names = ["get", "post", "patch", "delete"]
|
||||
ordering = ["inserted_at"]
|
||||
# RBAC required permissions
|
||||
required_permissions = [Permissions.MANAGE_ACCOUNT]
|
||||
|
||||
def get_queryset(self):
|
||||
return Role.objects.all()
|
||||
|
||||
def get_serializer_class(self):
|
||||
if self.action == "create":
|
||||
return RoleCreateSerializer
|
||||
elif self.action == "partial_update":
|
||||
return RoleUpdateSerializer
|
||||
return super().get_serializer_class()
|
||||
|
||||
def partial_update(self, request, *args, **kwargs):
|
||||
user_role = get_role(request.user)
|
||||
# If the user is the owner of the role, the manage_account field is not editable
|
||||
if user_role and kwargs["pk"] == str(user_role.id):
|
||||
request.data["manage_account"] = str(user_role.manage_account).lower()
|
||||
return super().partial_update(request, *args, **kwargs)
|
||||
|
||||
|
||||
@extend_schema_view(
|
||||
create=extend_schema(
|
||||
tags=["Role"],
|
||||
summary="Create a new role-provider_groups relationship",
|
||||
description="Add a new role-provider_groups relationship to the system by providing the required role-provider_groups details.",
|
||||
responses={
|
||||
204: OpenApiResponse(description="Relationship created successfully"),
|
||||
400: OpenApiResponse(
|
||||
description="Bad request (e.g., relationship already exists)"
|
||||
),
|
||||
},
|
||||
),
|
||||
partial_update=extend_schema(
|
||||
tags=["Role"],
|
||||
summary="Partially update a role-provider_groups relationship",
|
||||
description="Update the role-provider_groups relationship information without affecting other fields.",
|
||||
responses={
|
||||
204: OpenApiResponse(
|
||||
response=None, description="Relationship updated successfully"
|
||||
)
|
||||
},
|
||||
),
|
||||
destroy=extend_schema(
|
||||
tags=["Role"],
|
||||
summary="Delete a role-provider_groups relationship",
|
||||
description="Remove the role-provider_groups relationship from the system by their ID.",
|
||||
responses={
|
||||
204: OpenApiResponse(
|
||||
response=None, description="Relationship deleted successfully"
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
class RoleProviderGroupRelationshipView(RelationshipView, BaseRLSViewSet):
|
||||
queryset = Role.objects.all()
|
||||
serializer_class = RoleProviderGroupRelationshipSerializer
|
||||
resource_name = "provider_groups"
|
||||
http_method_names = ["post", "patch", "delete"]
|
||||
schema = RelationshipViewSchema()
|
||||
# RBAC required permissions
|
||||
required_permissions = [Permissions.MANAGE_ACCOUNT]
|
||||
|
||||
def get_queryset(self):
|
||||
return Role.objects.all()
|
||||
|
||||
def create(self, request, *args, **kwargs):
|
||||
role = self.get_object()
|
||||
|
||||
provider_group_ids = [item["id"] for item in request.data]
|
||||
existing_relationships = RoleProviderGroupRelationship.objects.filter(
|
||||
role=role, provider_group_id__in=provider_group_ids
|
||||
)
|
||||
|
||||
if existing_relationships.exists():
|
||||
return Response(
|
||||
{
|
||||
"detail": "One or more provider groups are already associated with the role."
|
||||
},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
serializer = self.get_serializer(
|
||||
data={"provider_groups": request.data},
|
||||
context={
|
||||
"role": role,
|
||||
"tenant_id": self.request.tenant_id,
|
||||
"request": request,
|
||||
},
|
||||
)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
serializer.save()
|
||||
|
||||
return Response(status=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
def partial_update(self, request, *args, **kwargs):
|
||||
role = self.get_object()
|
||||
serializer = self.get_serializer(
|
||||
instance=role,
|
||||
data={"provider_groups": request.data},
|
||||
context={"tenant_id": self.request.tenant_id, "request": request},
|
||||
)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
serializer.save()
|
||||
return Response(status=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
def destroy(self, request, *args, **kwargs):
|
||||
role = self.get_object()
|
||||
role.provider_groups.clear()
|
||||
|
||||
return Response(status=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
|
||||
@extend_schema_view(
|
||||
list=extend_schema(
|
||||
tags=["Compliance Overview"],
|
||||
@@ -1317,12 +1750,32 @@ class ComplianceOverviewViewSet(BaseRLSViewSet):
|
||||
search_fields = ["compliance_id"]
|
||||
ordering = ["compliance_id"]
|
||||
ordering_fields = ["inserted_at", "compliance_id", "framework", "region"]
|
||||
# RBAC required permissions (implicit -> MANAGE_PROVIDERS enable unlimited visibility or check the visibility of the provider through the provider group)
|
||||
required_permissions = []
|
||||
|
||||
def get_queryset(self):
|
||||
if self.action == "retrieve":
|
||||
return ComplianceOverview.objects.all()
|
||||
role = get_role(self.request.user)
|
||||
unlimited_visibility = getattr(
|
||||
role, Permissions.UNLIMITED_VISIBILITY.value, False
|
||||
)
|
||||
|
||||
base_queryset = self.filter_queryset(ComplianceOverview.objects.all())
|
||||
if self.action == "retrieve":
|
||||
if unlimited_visibility:
|
||||
# User has unlimited visibility, return all compliance compliances
|
||||
return ComplianceOverview.objects.all()
|
||||
|
||||
providers = get_providers(role)
|
||||
return ComplianceOverview.objects.filter(scan__provider__in=providers)
|
||||
|
||||
if unlimited_visibility:
|
||||
base_queryset = self.filter_queryset(ComplianceOverview.objects.all())
|
||||
else:
|
||||
providers = Provider.objects.filter(
|
||||
provider_groups__in=role.provider_groups.all()
|
||||
).distinct()
|
||||
base_queryset = self.filter_queryset(
|
||||
ComplianceOverview.objects.filter(scan__provider__in=providers)
|
||||
)
|
||||
|
||||
max_failed_ids = (
|
||||
base_queryset.filter(compliance_id=OuterRef("compliance_id"))
|
||||
@@ -1330,12 +1783,10 @@ class ComplianceOverviewViewSet(BaseRLSViewSet):
|
||||
.values("id")[:1]
|
||||
)
|
||||
|
||||
queryset = base_queryset.filter(id__in=Subquery(max_failed_ids)).order_by(
|
||||
return base_queryset.filter(id__in=Subquery(max_failed_ids)).order_by(
|
||||
"compliance_id"
|
||||
)
|
||||
|
||||
return queryset
|
||||
|
||||
def get_serializer_class(self):
|
||||
if self.action == "retrieve":
|
||||
return ComplianceOverviewFullSerializer
|
||||
@@ -1387,20 +1838,37 @@ class ComplianceOverviewViewSet(BaseRLSViewSet):
|
||||
),
|
||||
filters=True,
|
||||
),
|
||||
services=extend_schema(
|
||||
summary="Get findings data by service",
|
||||
description=(
|
||||
"Retrieve an aggregated summary of findings grouped by service. The response includes the total count "
|
||||
"of findings for each service, as long as there are at least one finding for that service. At least "
|
||||
"one of the `inserted_at` filters must be provided."
|
||||
),
|
||||
filters=True,
|
||||
),
|
||||
)
|
||||
@method_decorator(CACHE_DECORATOR, name="list")
|
||||
class OverviewViewSet(BaseRLSViewSet):
|
||||
queryset = ComplianceOverview.objects.all()
|
||||
http_method_names = ["get"]
|
||||
ordering = ["-id"]
|
||||
# RBAC required permissions (implicit -> MANAGE_PROVIDERS enable unlimited visibility or check the visibility of the provider through the provider group)
|
||||
required_permissions = []
|
||||
|
||||
def get_queryset(self):
|
||||
role = get_role(self.request.user)
|
||||
providers = get_providers(role)
|
||||
|
||||
def _get_filtered_queryset(model):
|
||||
if role.unlimited_visibility:
|
||||
return model.objects.all()
|
||||
return model.objects.filter(scan__provider__in=providers)
|
||||
|
||||
if self.action == "providers":
|
||||
return Finding.objects.all()
|
||||
elif self.action == "findings":
|
||||
return ScanSummary.objects.all()
|
||||
elif self.action == "findings_severity":
|
||||
return ScanSummary.objects.all()
|
||||
return _get_filtered_queryset(Finding)
|
||||
elif self.action in ("findings", "findings_severity", "services"):
|
||||
return _get_filtered_queryset(ScanSummary)
|
||||
else:
|
||||
return super().get_queryset()
|
||||
|
||||
@@ -1411,6 +1879,8 @@ class OverviewViewSet(BaseRLSViewSet):
|
||||
return OverviewFindingSerializer
|
||||
elif self.action == "findings_severity":
|
||||
return OverviewSeveritySerializer
|
||||
elif self.action == "services":
|
||||
return OverviewServiceSerializer
|
||||
return super().get_serializer_class()
|
||||
|
||||
def get_filterset_class(self):
|
||||
@@ -1418,6 +1888,8 @@ class OverviewViewSet(BaseRLSViewSet):
|
||||
return None
|
||||
elif self.action in ["findings", "findings_severity"]:
|
||||
return ScanSummaryFilter
|
||||
elif self.action == "services":
|
||||
return ServiceOverviewFilter
|
||||
return None
|
||||
|
||||
@extend_schema(exclude=True)
|
||||
@@ -1563,6 +2035,38 @@ class OverviewViewSet(BaseRLSViewSet):
|
||||
serializer = OverviewSeveritySerializer(severity_data)
|
||||
return Response(serializer.data, status=status.HTTP_200_OK)
|
||||
|
||||
@action(detail=False, methods=["get"], url_name="services")
|
||||
def services(self, request):
|
||||
queryset = self.get_queryset()
|
||||
filtered_queryset = self.filter_queryset(queryset)
|
||||
|
||||
latest_scan_subquery = (
|
||||
Scan.objects.filter(
|
||||
state=StateChoices.COMPLETED, provider_id=OuterRef("scan__provider_id")
|
||||
)
|
||||
.order_by("-id")
|
||||
.values("id")[:1]
|
||||
)
|
||||
|
||||
annotated_queryset = filtered_queryset.annotate(
|
||||
latest_scan_id=Subquery(latest_scan_subquery)
|
||||
)
|
||||
|
||||
filtered_queryset = annotated_queryset.filter(scan_id=F("latest_scan_id"))
|
||||
|
||||
services_data = (
|
||||
filtered_queryset.values("service")
|
||||
.annotate(_pass=Sum("_pass"))
|
||||
.annotate(fail=Sum("fail"))
|
||||
.annotate(muted=Sum("muted"))
|
||||
.annotate(total=Sum("total"))
|
||||
.order_by("service")
|
||||
)
|
||||
|
||||
serializer = OverviewServiceSerializer(services_data, many=True)
|
||||
|
||||
return Response(serializer.data, status=status.HTTP_200_OK)
|
||||
|
||||
|
||||
@extend_schema(tags=["Schedule"])
|
||||
@extend_schema_view(
|
||||
@@ -1578,6 +2082,8 @@ class ScheduleViewSet(BaseRLSViewSet):
|
||||
# TODO: change to Schedule when implemented
|
||||
queryset = Task.objects.none()
|
||||
http_method_names = ["post"]
|
||||
# RBAC required permissions
|
||||
required_permissions = [Permissions.MANAGE_SCANS]
|
||||
|
||||
def get_queryset(self):
|
||||
return super().get_queryset()
|
||||
|
||||
@@ -35,10 +35,10 @@ class RLSTask(Task):
|
||||
**options,
|
||||
)
|
||||
task_result_instance = TaskResult.objects.get(task_id=result.task_id)
|
||||
from api.db_utils import tenant_transaction
|
||||
from api.db_utils import rls_transaction
|
||||
|
||||
tenant_id = kwargs.get("tenant_id")
|
||||
with tenant_transaction(tenant_id):
|
||||
with rls_transaction(tenant_id):
|
||||
APITask.objects.create(
|
||||
id=task_result_instance.task_id,
|
||||
tenant_id=tenant_id,
|
||||
|
||||
@@ -10,8 +10,8 @@ DATABASES = {
|
||||
"default": {
|
||||
"ENGINE": "psqlextra.backend",
|
||||
"NAME": "prowler_db_test",
|
||||
"USER": env("POSTGRES_USER", default="prowler"),
|
||||
"PASSWORD": env("POSTGRES_PASSWORD", default="S3cret"),
|
||||
"USER": env("POSTGRES_USER", default="prowler_admin"),
|
||||
"PASSWORD": env("POSTGRES_PASSWORD", default="postgres"),
|
||||
"HOST": env("POSTGRES_HOST", default="localhost"),
|
||||
"PORT": env("POSTGRES_PORT", default="5432"),
|
||||
},
|
||||
|
||||
@@ -1,35 +1,39 @@
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from django.conf import settings
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from django.db import connections as django_connections, connection as django_connection
|
||||
from django.db import connection as django_connection
|
||||
from django.db import connections as django_connections
|
||||
from django.urls import reverse
|
||||
from django_celery_results.models import TaskResult
|
||||
from prowler.lib.check.models import Severity
|
||||
from prowler.lib.outputs.finding import Status
|
||||
from rest_framework import status
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from api.db_utils import rls_transaction
|
||||
from api.models import (
|
||||
ComplianceOverview,
|
||||
Finding,
|
||||
)
|
||||
from api.models import (
|
||||
User,
|
||||
Invitation,
|
||||
Membership,
|
||||
Provider,
|
||||
ProviderGroup,
|
||||
ProviderSecret,
|
||||
Resource,
|
||||
ResourceTag,
|
||||
Role,
|
||||
Scan,
|
||||
ScanSummary,
|
||||
StateChoices,
|
||||
Task,
|
||||
Membership,
|
||||
ProviderSecret,
|
||||
Invitation,
|
||||
ComplianceOverview,
|
||||
User,
|
||||
UserRoleRelationship,
|
||||
)
|
||||
from api.rls import Tenant
|
||||
from api.v1.serializers import TokenSerializer
|
||||
from prowler.lib.check.models import Severity
|
||||
from prowler.lib.outputs.finding import Status
|
||||
|
||||
API_JSON_CONTENT_TYPE = "application/vnd.api+json"
|
||||
NO_TENANT_HTTP_STATUS = status.HTTP_401_UNAUTHORIZED
|
||||
@@ -83,8 +87,148 @@ def create_test_user(django_db_setup, django_db_blocker):
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def create_test_user_rbac(django_db_setup, django_db_blocker):
|
||||
with django_db_blocker.unblock():
|
||||
user = User.objects.create_user(
|
||||
name="testing",
|
||||
email="rbac@rbac.com",
|
||||
password=TEST_PASSWORD,
|
||||
)
|
||||
tenant = Tenant.objects.create(
|
||||
name="Tenant Test",
|
||||
)
|
||||
Membership.objects.create(
|
||||
user=user,
|
||||
tenant=tenant,
|
||||
role=Membership.RoleChoices.OWNER,
|
||||
)
|
||||
Role.objects.create(
|
||||
name="admin",
|
||||
tenant_id=tenant.id,
|
||||
manage_users=True,
|
||||
manage_account=True,
|
||||
manage_billing=True,
|
||||
manage_providers=True,
|
||||
manage_integrations=True,
|
||||
manage_scans=True,
|
||||
unlimited_visibility=True,
|
||||
)
|
||||
UserRoleRelationship.objects.create(
|
||||
user=user,
|
||||
role=Role.objects.get(name="admin"),
|
||||
tenant_id=tenant.id,
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def create_test_user_rbac_no_roles(django_db_setup, django_db_blocker):
|
||||
with django_db_blocker.unblock():
|
||||
user = User.objects.create_user(
|
||||
name="testing",
|
||||
email="rbac_noroles@rbac.com",
|
||||
password=TEST_PASSWORD,
|
||||
)
|
||||
tenant = Tenant.objects.create(
|
||||
name="Tenant Test",
|
||||
)
|
||||
Membership.objects.create(
|
||||
user=user,
|
||||
tenant=tenant,
|
||||
role=Membership.RoleChoices.OWNER,
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def create_test_user_rbac_limited(django_db_setup, django_db_blocker):
|
||||
with django_db_blocker.unblock():
|
||||
user = User.objects.create_user(
|
||||
name="testing_limited",
|
||||
email="rbac_limited@rbac.com",
|
||||
password=TEST_PASSWORD,
|
||||
)
|
||||
tenant = Tenant.objects.create(
|
||||
name="Tenant Test",
|
||||
)
|
||||
Membership.objects.create(
|
||||
user=user,
|
||||
tenant=tenant,
|
||||
role=Membership.RoleChoices.OWNER,
|
||||
)
|
||||
Role.objects.create(
|
||||
name="limited",
|
||||
tenant_id=tenant.id,
|
||||
manage_users=False,
|
||||
manage_account=False,
|
||||
manage_billing=False,
|
||||
manage_providers=False,
|
||||
manage_integrations=False,
|
||||
manage_scans=False,
|
||||
unlimited_visibility=False,
|
||||
)
|
||||
UserRoleRelationship.objects.create(
|
||||
user=user,
|
||||
role=Role.objects.get(name="limited"),
|
||||
tenant_id=tenant.id,
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def authenticated_client(create_test_user, tenants_fixture, client):
|
||||
def authenticated_client_rbac(create_test_user_rbac, tenants_fixture, client):
|
||||
client.user = create_test_user_rbac
|
||||
serializer = TokenSerializer(
|
||||
data={"type": "tokens", "email": "rbac@rbac.com", "password": TEST_PASSWORD}
|
||||
)
|
||||
serializer.is_valid()
|
||||
access_token = serializer.validated_data["access"]
|
||||
client.defaults["HTTP_AUTHORIZATION"] = f"Bearer {access_token}"
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def authenticated_client_rbac_noroles(
|
||||
create_test_user_rbac_no_roles, tenants_fixture, client
|
||||
):
|
||||
client.user = create_test_user_rbac_no_roles
|
||||
serializer = TokenSerializer(
|
||||
data={
|
||||
"type": "tokens",
|
||||
"email": "rbac_noroles@rbac.com",
|
||||
"password": TEST_PASSWORD,
|
||||
}
|
||||
)
|
||||
serializer.is_valid()
|
||||
access_token = serializer.validated_data["access"]
|
||||
client.defaults["HTTP_AUTHORIZATION"] = f"Bearer {access_token}"
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def authenticated_client_no_permissions_rbac(
|
||||
create_test_user_rbac_limited, tenants_fixture, client
|
||||
):
|
||||
client.user = create_test_user_rbac_limited
|
||||
serializer = TokenSerializer(
|
||||
data={
|
||||
"type": "tokens",
|
||||
"email": "rbac_limited@rbac.com",
|
||||
"password": TEST_PASSWORD,
|
||||
}
|
||||
)
|
||||
serializer.is_valid()
|
||||
access_token = serializer.validated_data["access"]
|
||||
client.defaults["HTTP_AUTHORIZATION"] = f"Bearer {access_token}"
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def authenticated_client(
|
||||
create_test_user, tenants_fixture, set_user_admin_roles_fixture, client
|
||||
):
|
||||
client.user = create_test_user
|
||||
serializer = TokenSerializer(
|
||||
data={"type": "tokens", "email": TEST_USER, "password": TEST_PASSWORD}
|
||||
@@ -104,6 +248,7 @@ def authenticated_api_client(create_test_user, tenants_fixture):
|
||||
serializer.is_valid()
|
||||
access_token = serializer.validated_data["access"]
|
||||
client.defaults["HTTP_AUTHORIZATION"] = f"Bearer {access_token}"
|
||||
|
||||
return client
|
||||
|
||||
|
||||
@@ -128,9 +273,33 @@ def tenants_fixture(create_test_user):
|
||||
tenant3 = Tenant.objects.create(
|
||||
name="Tenant Three",
|
||||
)
|
||||
|
||||
return tenant1, tenant2, tenant3
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def set_user_admin_roles_fixture(create_test_user, tenants_fixture):
|
||||
user = create_test_user
|
||||
for tenant in tenants_fixture[:2]:
|
||||
with rls_transaction(str(tenant.id)):
|
||||
role = Role.objects.create(
|
||||
name="admin",
|
||||
tenant_id=tenant.id,
|
||||
manage_users=True,
|
||||
manage_account=True,
|
||||
manage_billing=True,
|
||||
manage_providers=True,
|
||||
manage_integrations=True,
|
||||
manage_scans=True,
|
||||
unlimited_visibility=True,
|
||||
)
|
||||
UserRoleRelationship.objects.create(
|
||||
user=user,
|
||||
role=role,
|
||||
tenant_id=tenant.id,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def invitations_fixture(create_test_user, tenants_fixture):
|
||||
user = create_test_user
|
||||
@@ -153,6 +322,20 @@ def invitations_fixture(create_test_user, tenants_fixture):
|
||||
return valid_invitation, expired_invitation
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def users_fixture(django_user_model):
|
||||
user1 = User.objects.create_user(
|
||||
name="user1", email="test_unit0@prowler.com", password="S3cret"
|
||||
)
|
||||
user2 = User.objects.create_user(
|
||||
name="user2", email="test_unit1@prowler.com", password="S3cret"
|
||||
)
|
||||
user3 = User.objects.create_user(
|
||||
name="user3", email="test_unit2@prowler.com", password="S3cret"
|
||||
)
|
||||
return user1, user2, user3
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def providers_fixture(tenants_fixture):
|
||||
tenant, *_ = tenants_fixture
|
||||
@@ -210,6 +393,57 @@ def provider_groups_fixture(tenants_fixture):
|
||||
return pgroup1, pgroup2, pgroup3
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def roles_fixture(tenants_fixture):
|
||||
tenant, *_ = tenants_fixture
|
||||
role1 = Role.objects.create(
|
||||
name="Role One",
|
||||
tenant_id=tenant.id,
|
||||
manage_users=True,
|
||||
manage_account=True,
|
||||
manage_billing=True,
|
||||
manage_providers=True,
|
||||
manage_integrations=False,
|
||||
manage_scans=True,
|
||||
unlimited_visibility=False,
|
||||
)
|
||||
role2 = Role.objects.create(
|
||||
name="Role Two",
|
||||
tenant_id=tenant.id,
|
||||
manage_users=False,
|
||||
manage_account=False,
|
||||
manage_billing=False,
|
||||
manage_providers=True,
|
||||
manage_integrations=True,
|
||||
manage_scans=True,
|
||||
unlimited_visibility=True,
|
||||
)
|
||||
role3 = Role.objects.create(
|
||||
name="Role Three",
|
||||
tenant_id=tenant.id,
|
||||
manage_users=True,
|
||||
manage_account=True,
|
||||
manage_billing=True,
|
||||
manage_providers=True,
|
||||
manage_integrations=True,
|
||||
manage_scans=True,
|
||||
unlimited_visibility=True,
|
||||
)
|
||||
role4 = Role.objects.create(
|
||||
name="Role Four",
|
||||
tenant_id=tenant.id,
|
||||
manage_users=False,
|
||||
manage_account=False,
|
||||
manage_billing=False,
|
||||
manage_providers=False,
|
||||
manage_integrations=False,
|
||||
manage_scans=False,
|
||||
unlimited_visibility=False,
|
||||
)
|
||||
|
||||
return role1, role2, role3, role4
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider_secret_fixture(providers_fixture):
|
||||
return tuple(
|
||||
@@ -537,10 +771,107 @@ def get_api_tokens(
|
||||
data=json_body,
|
||||
format="vnd.api+json",
|
||||
)
|
||||
return response.json()["data"]["attributes"]["access"], response.json()["data"][
|
||||
"attributes"
|
||||
]["refresh"]
|
||||
return (
|
||||
response.json()["data"]["attributes"]["access"],
|
||||
response.json()["data"]["attributes"]["refresh"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def scan_summaries_fixture(tenants_fixture, providers_fixture):
|
||||
tenant = tenants_fixture[0]
|
||||
provider = providers_fixture[0]
|
||||
scan = Scan.objects.create(
|
||||
name="overview scan",
|
||||
provider=provider,
|
||||
trigger=Scan.TriggerChoices.MANUAL,
|
||||
state=StateChoices.COMPLETED,
|
||||
tenant=tenant,
|
||||
)
|
||||
|
||||
ScanSummary.objects.create(
|
||||
tenant=tenant,
|
||||
check_id="check1",
|
||||
service="service1",
|
||||
severity="high",
|
||||
region="region1",
|
||||
_pass=1,
|
||||
fail=0,
|
||||
muted=0,
|
||||
total=1,
|
||||
new=1,
|
||||
changed=0,
|
||||
unchanged=0,
|
||||
fail_new=0,
|
||||
fail_changed=0,
|
||||
pass_new=1,
|
||||
pass_changed=0,
|
||||
muted_new=0,
|
||||
muted_changed=0,
|
||||
scan=scan,
|
||||
)
|
||||
|
||||
ScanSummary.objects.create(
|
||||
tenant=tenant,
|
||||
check_id="check1",
|
||||
service="service1",
|
||||
severity="high",
|
||||
region="region2",
|
||||
_pass=0,
|
||||
fail=1,
|
||||
muted=1,
|
||||
total=2,
|
||||
new=2,
|
||||
changed=0,
|
||||
unchanged=0,
|
||||
fail_new=1,
|
||||
fail_changed=0,
|
||||
pass_new=0,
|
||||
pass_changed=0,
|
||||
muted_new=1,
|
||||
muted_changed=0,
|
||||
scan=scan,
|
||||
)
|
||||
|
||||
ScanSummary.objects.create(
|
||||
tenant=tenant,
|
||||
check_id="check2",
|
||||
service="service2",
|
||||
severity="critical",
|
||||
region="region1",
|
||||
_pass=1,
|
||||
fail=0,
|
||||
muted=0,
|
||||
total=1,
|
||||
new=1,
|
||||
changed=0,
|
||||
unchanged=0,
|
||||
fail_new=0,
|
||||
fail_changed=0,
|
||||
pass_new=1,
|
||||
pass_changed=0,
|
||||
muted_new=0,
|
||||
muted_changed=0,
|
||||
scan=scan,
|
||||
)
|
||||
|
||||
|
||||
def get_authorization_header(access_token: str) -> dict:
|
||||
return {"Authorization": f"Bearer {access_token}"}
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(items):
|
||||
"""Ensure test_rbac.py is executed first."""
|
||||
items.sort(key=lambda item: 0 if "test_rbac.py" in item.nodeid else 1)
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
# Apply the mock before the test session starts. This is necessary to avoid admin error when running the
|
||||
# 0004_rbac_missing_admin_roles migration
|
||||
patch("api.db_router.MainRouter.admin_db", new="default").start()
|
||||
|
||||
|
||||
def pytest_unconfigure(config):
|
||||
# Stop all patches after the test session ends. This is necessary to avoid admin error when running the
|
||||
# 0004_rbac_missing_admin_roles migration
|
||||
patch.stopall()
|
||||
|
||||
@@ -2,7 +2,7 @@ from celery.utils.log import get_task_logger
|
||||
from django.db import transaction
|
||||
|
||||
from api.db_router import MainRouter
|
||||
from api.db_utils import batch_delete, tenant_transaction
|
||||
from api.db_utils import batch_delete, rls_transaction
|
||||
from api.models import Finding, Provider, Resource, Scan, ScanSummary, Tenant
|
||||
|
||||
logger = get_task_logger(__name__)
|
||||
@@ -66,7 +66,7 @@ def delete_tenant(pk: str):
|
||||
deletion_summary = {}
|
||||
|
||||
for provider in Provider.objects.using(MainRouter.admin_db).filter(tenant_id=pk):
|
||||
with tenant_transaction(pk):
|
||||
with rls_transaction(pk):
|
||||
summary = delete_provider(provider.id)
|
||||
deletion_summary.update(summary)
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from api.compliance import (
|
||||
PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE,
|
||||
generate_scan_compliance,
|
||||
)
|
||||
from api.db_utils import tenant_transaction
|
||||
from api.db_utils import rls_transaction
|
||||
from api.models import (
|
||||
ComplianceOverview,
|
||||
Finding,
|
||||
@@ -69,7 +69,7 @@ def _store_resources(
|
||||
- tuple[str, str]: A tuple containing the resource UID and region.
|
||||
|
||||
"""
|
||||
with tenant_transaction(tenant_id):
|
||||
with rls_transaction(tenant_id):
|
||||
resource_instance, created = Resource.objects.get_or_create(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_instance,
|
||||
@@ -86,7 +86,7 @@ def _store_resources(
|
||||
resource_instance.service = finding.service_name
|
||||
resource_instance.type = finding.resource_type
|
||||
resource_instance.save()
|
||||
with tenant_transaction(tenant_id):
|
||||
with rls_transaction(tenant_id):
|
||||
tags = [
|
||||
ResourceTag.objects.get_or_create(
|
||||
tenant_id=tenant_id, key=key, value=value
|
||||
@@ -122,7 +122,7 @@ def perform_prowler_scan(
|
||||
unique_resources = set()
|
||||
start_time = time.time()
|
||||
|
||||
with tenant_transaction(tenant_id):
|
||||
with rls_transaction(tenant_id):
|
||||
provider_instance = Provider.objects.get(pk=provider_id)
|
||||
scan_instance = Scan.objects.get(pk=scan_id)
|
||||
scan_instance.state = StateChoices.EXECUTING
|
||||
@@ -130,7 +130,7 @@ def perform_prowler_scan(
|
||||
scan_instance.save()
|
||||
|
||||
try:
|
||||
with tenant_transaction(tenant_id):
|
||||
with rls_transaction(tenant_id):
|
||||
try:
|
||||
prowler_provider = initialize_prowler_provider(provider_instance)
|
||||
provider_instance.connected = True
|
||||
@@ -156,7 +156,7 @@ def perform_prowler_scan(
|
||||
for finding in findings:
|
||||
for attempt in range(CELERY_DEADLOCK_ATTEMPTS):
|
||||
try:
|
||||
with tenant_transaction(tenant_id):
|
||||
with rls_transaction(tenant_id):
|
||||
# Process resource
|
||||
resource_uid = finding.resource_uid
|
||||
if resource_uid not in resource_cache:
|
||||
@@ -188,7 +188,7 @@ def perform_prowler_scan(
|
||||
resource_instance.type = finding.resource_type
|
||||
updated_fields.append("type")
|
||||
if updated_fields:
|
||||
with tenant_transaction(tenant_id):
|
||||
with rls_transaction(tenant_id):
|
||||
resource_instance.save(update_fields=updated_fields)
|
||||
except (OperationalError, IntegrityError) as db_err:
|
||||
if attempt < CELERY_DEADLOCK_ATTEMPTS - 1:
|
||||
@@ -203,7 +203,7 @@ def perform_prowler_scan(
|
||||
|
||||
# Update tags
|
||||
tags = []
|
||||
with tenant_transaction(tenant_id):
|
||||
with rls_transaction(tenant_id):
|
||||
for key, value in finding.resource_tags.items():
|
||||
tag_key = (key, value)
|
||||
if tag_key not in tag_cache:
|
||||
@@ -219,7 +219,7 @@ def perform_prowler_scan(
|
||||
unique_resources.add((resource_instance.uid, resource_instance.region))
|
||||
|
||||
# Process finding
|
||||
with tenant_transaction(tenant_id):
|
||||
with rls_transaction(tenant_id):
|
||||
finding_uid = finding.uid
|
||||
if finding_uid not in last_status_cache:
|
||||
most_recent_finding = (
|
||||
@@ -267,7 +267,7 @@ def perform_prowler_scan(
|
||||
region_dict[finding.check_id] = finding.status.value
|
||||
|
||||
# Update scan progress
|
||||
with tenant_transaction(tenant_id):
|
||||
with rls_transaction(tenant_id):
|
||||
scan_instance.progress = progress
|
||||
scan_instance.save()
|
||||
|
||||
@@ -279,7 +279,7 @@ def perform_prowler_scan(
|
||||
scan_instance.state = StateChoices.FAILED
|
||||
|
||||
finally:
|
||||
with tenant_transaction(tenant_id):
|
||||
with rls_transaction(tenant_id):
|
||||
scan_instance.duration = time.time() - start_time
|
||||
scan_instance.completed_at = datetime.now(tz=timezone.utc)
|
||||
scan_instance.unique_resource_count = len(unique_resources)
|
||||
@@ -330,7 +330,7 @@ def perform_prowler_scan(
|
||||
total_requirements=compliance["total_requirements"],
|
||||
)
|
||||
)
|
||||
with tenant_transaction(tenant_id):
|
||||
with rls_transaction(tenant_id):
|
||||
ComplianceOverview.objects.bulk_create(compliance_overview_objects)
|
||||
|
||||
if exception is not None:
|
||||
@@ -368,7 +368,7 @@ def aggregate_findings(tenant_id: str, scan_id: str):
|
||||
- muted_new: Muted findings with a delta of 'new'.
|
||||
- muted_changed: Muted findings with a delta of 'changed'.
|
||||
"""
|
||||
with tenant_transaction(tenant_id):
|
||||
with rls_transaction(tenant_id):
|
||||
findings = Finding.objects.filter(scan_id=scan_id)
|
||||
|
||||
aggregation = findings.values(
|
||||
@@ -464,7 +464,7 @@ def aggregate_findings(tenant_id: str, scan_id: str):
|
||||
),
|
||||
)
|
||||
|
||||
with tenant_transaction(tenant_id):
|
||||
with rls_transaction(tenant_id):
|
||||
scan_aggregations = {
|
||||
ScanSummary(
|
||||
tenant_id=tenant_id,
|
||||
|
||||
@@ -7,7 +7,7 @@ from tasks.jobs.connection import check_provider_connection
|
||||
from tasks.jobs.deletion import delete_provider, delete_tenant
|
||||
from tasks.jobs.scan import aggregate_findings, perform_prowler_scan
|
||||
|
||||
from api.db_utils import tenant_transaction
|
||||
from api.db_utils import rls_transaction
|
||||
from api.decorators import set_tenant
|
||||
from api.models import Provider, Scan
|
||||
|
||||
@@ -99,7 +99,7 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
|
||||
"""
|
||||
task_id = self.request.id
|
||||
|
||||
with tenant_transaction(tenant_id):
|
||||
with rls_transaction(tenant_id):
|
||||
provider_instance = Provider.objects.get(pk=provider_id)
|
||||
periodic_task_instance = PeriodicTask.objects.get(
|
||||
name=f"scan-perform-scheduled-{provider_id}"
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from tasks.jobs.deletion import delete_provider, delete_tenant
|
||||
@@ -24,7 +22,6 @@ class TestDeleteProvider:
|
||||
delete_provider(non_existent_pk)
|
||||
|
||||
|
||||
@patch("api.db_router.MainRouter.admin_db", new="default")
|
||||
@pytest.mark.django_db
|
||||
class TestDeleteTenant:
|
||||
def test_delete_tenant_success(self, tenants_fixture, providers_fixture):
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -26,7 +27,7 @@ class TestPerformScan:
|
||||
providers_fixture,
|
||||
):
|
||||
with (
|
||||
patch("api.db_utils.tenant_transaction"),
|
||||
patch("api.db_utils.rls_transaction"),
|
||||
patch(
|
||||
"tasks.jobs.scan.initialize_prowler_provider"
|
||||
) as mock_initialize_prowler_provider,
|
||||
@@ -165,10 +166,10 @@ class TestPerformScan:
|
||||
"tasks.jobs.scan.initialize_prowler_provider",
|
||||
side_effect=Exception("Connection error"),
|
||||
)
|
||||
@patch("api.db_utils.tenant_transaction")
|
||||
@patch("api.db_utils.rls_transaction")
|
||||
def test_perform_prowler_scan_no_connection(
|
||||
self,
|
||||
mock_tenant_transaction,
|
||||
mock_rls_transaction,
|
||||
mock_initialize_prowler_provider,
|
||||
mock_prowler_scan_class,
|
||||
tenants_fixture,
|
||||
@@ -205,14 +206,14 @@ class TestPerformScan:
|
||||
|
||||
@patch("api.models.ResourceTag.objects.get_or_create")
|
||||
@patch("api.models.Resource.objects.get_or_create")
|
||||
@patch("api.db_utils.tenant_transaction")
|
||||
@patch("api.db_utils.rls_transaction")
|
||||
def test_store_resources_new_resource(
|
||||
self,
|
||||
mock_tenant_transaction,
|
||||
mock_rls_transaction,
|
||||
mock_get_or_create_resource,
|
||||
mock_get_or_create_tag,
|
||||
):
|
||||
tenant_id = "tenant123"
|
||||
tenant_id = uuid.uuid4()
|
||||
provider_instance = MagicMock()
|
||||
provider_instance.id = "provider456"
|
||||
|
||||
@@ -253,14 +254,14 @@ class TestPerformScan:
|
||||
|
||||
@patch("api.models.ResourceTag.objects.get_or_create")
|
||||
@patch("api.models.Resource.objects.get_or_create")
|
||||
@patch("api.db_utils.tenant_transaction")
|
||||
@patch("api.db_utils.rls_transaction")
|
||||
def test_store_resources_existing_resource(
|
||||
self,
|
||||
mock_tenant_transaction,
|
||||
mock_rls_transaction,
|
||||
mock_get_or_create_resource,
|
||||
mock_get_or_create_tag,
|
||||
):
|
||||
tenant_id = "tenant123"
|
||||
tenant_id = uuid.uuid4()
|
||||
provider_instance = MagicMock()
|
||||
provider_instance.id = "provider456"
|
||||
|
||||
@@ -310,14 +311,14 @@ class TestPerformScan:
|
||||
|
||||
@patch("api.models.ResourceTag.objects.get_or_create")
|
||||
@patch("api.models.Resource.objects.get_or_create")
|
||||
@patch("api.db_utils.tenant_transaction")
|
||||
@patch("api.db_utils.rls_transaction")
|
||||
def test_store_resources_with_tags(
|
||||
self,
|
||||
mock_tenant_transaction,
|
||||
mock_rls_transaction,
|
||||
mock_get_or_create_resource,
|
||||
mock_get_or_create_tag,
|
||||
):
|
||||
tenant_id = "tenant123"
|
||||
tenant_id = uuid.uuid4()
|
||||
provider_instance = MagicMock()
|
||||
provider_instance.id = "provider456"
|
||||
|
||||
|
||||
@@ -73,6 +73,8 @@ To use each one you need to pass the proper flag to the execution. Prowler for A
|
||||
- **Subscription scope permissions**: Required to launch the checks against your resources, mandatory to launch the tool. It is required to add the following RBAC builtin roles per subscription to the entity that is going to be assumed by the tool:
|
||||
- `Reader`
|
||||
- `ProwlerRole` (custom role defined in [prowler-azure-custom-role](https://github.com/prowler-cloud/prowler/blob/master/permissions/prowler-azure-custom-role.json))
|
||||
???+ note
|
||||
Please, notice that the field `assignableScopes` in the JSON custom role file must be changed to be the subscription or management group where the role is going to be assigned. The valid formats for the field are `/subscriptions/<subscription-id>` or `/providers/Microsoft.Management/managementGroups/<management-group-id>`.
|
||||
|
||||
To assign the permissions, follow the instructions in the [Microsoft Entra ID permissions](../tutorials/azure/create-prowler-service-principal.md#assigning-the-proper-permissions) section and the [Azure subscriptions permissions](../tutorials/azure/subscriptions.md#assigning-proper-permissions) section, respectively.
|
||||
|
||||
|
||||
@@ -42,6 +42,7 @@ Mutelist:
|
||||
Resources:
|
||||
- "user-1" # Will mute user-1 in check iam_user_hardware_mfa_enabled
|
||||
- "user-2" # Will mute user-2 in check iam_user_hardware_mfa_enabled
|
||||
Description: "Findings related with the check iam_user_hardware_mfa_enabled will be muted for us-east-1 region and user-1, user-2 resources"
|
||||
"ec2_*":
|
||||
Regions:
|
||||
- "*"
|
||||
@@ -140,6 +141,9 @@ Mutelist:
|
||||
| `resource` | The resource identifier. Use `*` to apply the mutelist to all resources. | `ANDed` |
|
||||
| `tag` | The tag value. | `ORed` |
|
||||
|
||||
### Description
|
||||
|
||||
This field can be used to add information or some hints for the Mutelist rule.
|
||||
|
||||
## How to Use the Mutelist
|
||||
|
||||
@@ -171,6 +175,7 @@ If you want to mute failed findings only in specific regions, create a file with
|
||||
- "ap-southeast-2"
|
||||
Resources:
|
||||
- "*"
|
||||
Description: "Description related with the muted findings for the check"
|
||||
|
||||
### Default Mutelist
|
||||
For the AWS Provider, Prowler is executed with a default AWS Mutelist with the AWS Resources that should be muted such as all resources created by AWS Control Tower when setting up a landing zone that can be found in [AWS Documentation](https://docs.aws.amazon.com/controltower/latest/userguide/shared-account-resources.html).
|
||||
|
||||
@@ -95,6 +95,7 @@ Resources:
|
||||
- 'servicecatalog:List*'
|
||||
- 'ssm:GetDocument'
|
||||
- 'ssm-incidents:List*'
|
||||
- 'states:ListTagsForResource'
|
||||
- 'support:Describe*'
|
||||
- 'tag:GetTagKeys'
|
||||
- 'wellarchitected:List*'
|
||||
|
||||
@@ -45,6 +45,7 @@
|
||||
"servicecatalog:List*",
|
||||
"ssm:GetDocument",
|
||||
"ssm-incidents:List*",
|
||||
"states:ListTagsForResource",
|
||||
"support:Describe*",
|
||||
"tag:GetTagKeys",
|
||||
"wellarchitected:List*"
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
"roleName": "ProwlerRole",
|
||||
"description": "Role used for checks that require read-only access to Azure resources and are not covered by the Reader role.",
|
||||
"assignableScopes": [
|
||||
"/"
|
||||
"/{'subscriptions', 'providers/Microsoft.Management/managementGroups'}/{Your Subscription or Management Group ID}"
|
||||
],
|
||||
"permissions": [
|
||||
{
|
||||
|
||||
24
poetry.lock
generated
24
poetry.lock
generated
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "about-time"
|
||||
@@ -775,17 +775,17 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "boto3"
|
||||
version = "1.35.78"
|
||||
version = "1.35.81"
|
||||
description = "The AWS SDK for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "boto3-1.35.78-py3-none-any.whl", hash = "sha256:5ef7166fe5060637b92af8dc152cd7acecf96b3fc9c5456706a886cadb534391"},
|
||||
{file = "boto3-1.35.78.tar.gz", hash = "sha256:fc8001519c8842e766ad3793bde3fbd0bb39e821a582fc12cf67876b8f3cf7f1"},
|
||||
{file = "boto3-1.35.81-py3-none-any.whl", hash = "sha256:742941b2424c0223d2d94a08c3485462fa7c58d816b62ca80f08e555243acee1"},
|
||||
{file = "boto3-1.35.81.tar.gz", hash = "sha256:d2e95fa06f095b8e0c545dd678c6269d253809b2997c30f5ce8a956c410b4e86"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
botocore = ">=1.35.78,<1.36.0"
|
||||
botocore = ">=1.35.81,<1.36.0"
|
||||
jmespath = ">=0.7.1,<2.0.0"
|
||||
s3transfer = ">=0.10.0,<0.11.0"
|
||||
|
||||
@@ -794,13 +794,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
|
||||
|
||||
[[package]]
|
||||
name = "botocore"
|
||||
version = "1.35.79"
|
||||
version = "1.35.81"
|
||||
description = "Low-level, data-driven core of boto 3."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "botocore-1.35.79-py3-none-any.whl", hash = "sha256:e6b10bb9a357e3f5ca2e60f6dd15a85d311b9a476eb21b3c0c2a3b364a2897c8"},
|
||||
{file = "botocore-1.35.79.tar.gz", hash = "sha256:245bfdda1b1508539ddd1819c67a8a2cc81780adf0715d3de418d64c4247f346"},
|
||||
{file = "botocore-1.35.81-py3-none-any.whl", hash = "sha256:a7b13bbd959bf2d6f38f681676aab408be01974c46802ab997617b51399239f7"},
|
||||
{file = "botocore-1.35.81.tar.gz", hash = "sha256:564c2478e50179e0b766e6a87e5e0cdd35e1bc37eb375c1cf15511f5dd13600d"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -2583,13 +2583,13 @@ dev = ["click", "codecov", "mkdocs-gen-files", "mkdocs-git-authors-plugin", "mkd
|
||||
|
||||
[[package]]
|
||||
name = "mkdocs-material"
|
||||
version = "9.5.48"
|
||||
version = "9.5.49"
|
||||
description = "Documentation that simply works"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "mkdocs_material-9.5.48-py3-none-any.whl", hash = "sha256:b695c998f4b939ce748adbc0d3bff73fa886a670ece948cf27818fa115dc16f8"},
|
||||
{file = "mkdocs_material-9.5.48.tar.gz", hash = "sha256:a582531e8b34f4c7ed38c29d5c44763053832cf2a32f7409567e0c74749a47db"},
|
||||
{file = "mkdocs_material-9.5.49-py3-none-any.whl", hash = "sha256:c3c2d8176b18198435d3a3e119011922f3e11424074645c24019c2dcf08a360e"},
|
||||
{file = "mkdocs_material-9.5.49.tar.gz", hash = "sha256:3671bb282b4f53a1c72e08adbe04d2481a98f85fed392530051f80ff94a9621d"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -5199,4 +5199,4 @@ type = ["pytest-mypy"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.9,<3.13"
|
||||
content-hash = "e00da6013a01923ac8e79017e7fdb221e09a3dcf581ad8d74e39550be64cc2f3"
|
||||
content-hash = "aff89cb4b5b66d79efcf4e1f16fdc76e4a1019e4196cd8961299d479b99cf7fa"
|
||||
|
||||
@@ -10,6 +10,7 @@ Mutelist:
|
||||
- "*"
|
||||
Resources:
|
||||
- "aws-controltower-NotificationForwarder"
|
||||
Description: "Checks from AWS lambda functions muted by default"
|
||||
"cloudformation_stack*":
|
||||
Regions:
|
||||
- "*"
|
||||
|
||||
@@ -14,6 +14,7 @@ Mutelist:
|
||||
Resources:
|
||||
- "user-1" # Will ignore user-1 in check iam_user_hardware_mfa_enabled
|
||||
- "user-2" # Will ignore user-2 in check iam_user_hardware_mfa_enabled
|
||||
Description: "Check iam_user_hardware_mfa_enabled muted for region us-east-1 and resources user-1, user-2"
|
||||
"ec2_*":
|
||||
Regions:
|
||||
- "*"
|
||||
|
||||
@@ -15,6 +15,7 @@ Mutelist:
|
||||
Resources:
|
||||
- "sqlserver1" # Will ignore sqlserver1 in check sqlserver_tde_encryption_enabled located in westeurope
|
||||
- "sqlserver2" # Will ignore sqlserver2 in check sqlserver_tde_encryption_enabled located in westeurope
|
||||
Description: "Findings related with the check sqlserver_tde_encryption_enabled is muted for westeurope region and sqlserver1, sqlserver2 resources"
|
||||
"defender_*":
|
||||
Regions:
|
||||
- "*"
|
||||
|
||||
@@ -15,6 +15,7 @@ Mutelist:
|
||||
Resources:
|
||||
- "instance1" # Will ignore instance1 in check compute_instance_public_ip located in europe-southwest1
|
||||
- "instance2" # Will ignore instance2 in check compute_instance_public_ip located in europe-southwest1
|
||||
Description: "Findings related with the check compute_instance_public_ip will be muted for europe-southwest1 region and instance1, instance2 resources"
|
||||
"iam_*":
|
||||
Regions:
|
||||
- "*"
|
||||
|
||||
@@ -15,6 +15,7 @@ Mutelist:
|
||||
Resources:
|
||||
- "prowler-pod1" # Will ignore prowler-pod1 in check core_minimize_allowPrivilegeEscalation_containers located in namespace1
|
||||
- "prowler-pod2" # Will ignore prowler-pod2 in check core_minimize_allowPrivilegeEscalation_containers located in namespace1
|
||||
Description: "Findings related with the check core_minimize_allowPrivilegeEscalation_containers will be muted for namespace1 region and prowler-pod1, prowler-pod2 resources"
|
||||
"kubelet_*":
|
||||
Regions:
|
||||
- "*"
|
||||
|
||||
@@ -15,6 +15,7 @@ mutelist_schema = Schema(
|
||||
Optional("Resources"): list,
|
||||
Optional("Tags"): list,
|
||||
},
|
||||
Optional("Description"): str,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -106,6 +106,7 @@ class Mutelist(ABC):
|
||||
- 'i-123456789'
|
||||
Tags:
|
||||
- 'Name=AdminInstance | Environment=Prod'
|
||||
Description: 'Field to describe why the findings associated with these values are muted'
|
||||
```
|
||||
The check `ec2_instance_detailed_monitoring_enabled` will be muted for all accounts and regions and for the resource_id 'i-123456789' with at least one of the tags 'Name=AdminInstance' or 'Environment=Prod'.
|
||||
|
||||
|
||||
@@ -908,6 +908,7 @@
|
||||
"ap-southeast-2",
|
||||
"ap-southeast-3",
|
||||
"ap-southeast-4",
|
||||
"ap-southeast-5",
|
||||
"ca-central-1",
|
||||
"ca-west-1",
|
||||
"eu-central-1",
|
||||
@@ -1260,6 +1261,15 @@
|
||||
"aws-us-gov": []
|
||||
}
|
||||
},
|
||||
"bcm-pricing-calculator": {
|
||||
"regions": {
|
||||
"aws": [
|
||||
"us-east-1"
|
||||
],
|
||||
"aws-cn": [],
|
||||
"aws-us-gov": []
|
||||
}
|
||||
},
|
||||
"bedrock": {
|
||||
"regions": {
|
||||
"aws": [
|
||||
@@ -7345,6 +7355,31 @@
|
||||
]
|
||||
}
|
||||
},
|
||||
"networkflowmonitor": {
|
||||
"regions": {
|
||||
"aws": [
|
||||
"ap-northeast-1",
|
||||
"ap-northeast-2",
|
||||
"ap-northeast-3",
|
||||
"ap-south-1",
|
||||
"ap-southeast-1",
|
||||
"ap-southeast-2",
|
||||
"ca-central-1",
|
||||
"eu-central-1",
|
||||
"eu-north-1",
|
||||
"eu-west-1",
|
||||
"eu-west-2",
|
||||
"eu-west-3",
|
||||
"sa-east-1",
|
||||
"us-east-1",
|
||||
"us-east-2",
|
||||
"us-west-1",
|
||||
"us-west-2"
|
||||
],
|
||||
"aws-cn": [],
|
||||
"aws-us-gov": []
|
||||
}
|
||||
},
|
||||
"networkmanager": {
|
||||
"regions": {
|
||||
"aws": [
|
||||
@@ -7442,6 +7477,15 @@
|
||||
"aws-us-gov": []
|
||||
}
|
||||
},
|
||||
"notificationscontacts": {
|
||||
"regions": {
|
||||
"aws": [
|
||||
"us-east-1"
|
||||
],
|
||||
"aws-cn": [],
|
||||
"aws-us-gov": []
|
||||
}
|
||||
},
|
||||
"oam": {
|
||||
"regions": {
|
||||
"aws": [
|
||||
@@ -7486,6 +7530,23 @@
|
||||
]
|
||||
}
|
||||
},
|
||||
"observabilityadmin": {
|
||||
"regions": {
|
||||
"aws": [
|
||||
"ap-northeast-1",
|
||||
"ap-southeast-1",
|
||||
"ap-southeast-2",
|
||||
"eu-central-1",
|
||||
"eu-north-1",
|
||||
"eu-west-1",
|
||||
"us-east-1",
|
||||
"us-east-2",
|
||||
"us-west-2"
|
||||
],
|
||||
"aws-cn": [],
|
||||
"aws-us-gov": []
|
||||
}
|
||||
},
|
||||
"omics": {
|
||||
"regions": {
|
||||
"aws": [
|
||||
@@ -11100,10 +11161,12 @@
|
||||
"ap-southeast-3",
|
||||
"ap-southeast-4",
|
||||
"ca-central-1",
|
||||
"ca-west-1",
|
||||
"eu-central-1",
|
||||
"eu-central-2",
|
||||
"eu-north-1",
|
||||
"eu-south-1",
|
||||
"eu-south-2",
|
||||
"eu-west-1",
|
||||
"eu-west-2",
|
||||
"eu-west-3",
|
||||
|
||||
@@ -8,14 +8,14 @@ class backup_recovery_point_encrypted(Check):
|
||||
for recovery_point in backup_client.recovery_points:
|
||||
report = Check_Report_AWS(self.metadata())
|
||||
report.region = recovery_point.backup_vault_region
|
||||
report.resource_id = recovery_point.backup_vault_name
|
||||
report.resource_id = recovery_point.id
|
||||
report.resource_arn = recovery_point.arn
|
||||
report.resource_tags = recovery_point.tags
|
||||
report.status = "FAIL"
|
||||
report.status_extended = f"Backup Recovery Point {recovery_point.arn} for Backup Vault {recovery_point.backup_vault_name} is not encrypted at rest."
|
||||
report.status_extended = f"Backup Recovery Point {recovery_point.id} for Backup Vault {recovery_point.backup_vault_name} is not encrypted at rest."
|
||||
if recovery_point.encrypted:
|
||||
report.status = "PASS"
|
||||
report.status_extended = f"Backup Recovery Point {recovery_point.arn} for Backup Vault {recovery_point.backup_vault_name} is encrypted at rest."
|
||||
report.status_extended = f"Backup Recovery Point {recovery_point.id} for Backup Vault {recovery_point.backup_vault_name} is encrypted at rest."
|
||||
|
||||
findings.append(report)
|
||||
|
||||
|
||||
@@ -195,6 +195,7 @@ class Backup(AWSService):
|
||||
self.recovery_points.append(
|
||||
RecoveryPoint(
|
||||
arn=arn,
|
||||
id=arn.split(":")[-1],
|
||||
backup_vault_name=backup_vault.name,
|
||||
encrypted=recovery_point.get(
|
||||
"IsEncrypted", False
|
||||
@@ -246,6 +247,7 @@ class BackupReportPlan(BaseModel):
|
||||
|
||||
class RecoveryPoint(BaseModel):
|
||||
arn: str
|
||||
id: str
|
||||
backup_vault_name: str
|
||||
encrypted: bool
|
||||
backup_vault_region: str
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
from prowler.lib.logger import logger
|
||||
from prowler.providers.aws.services.cloudtrail.cloudtrail_client import (
|
||||
cloudtrail_client,
|
||||
)
|
||||
from prowler.providers.aws.services.s3.s3_client import s3_client
|
||||
|
||||
|
||||
def fixer(resource_id: str, region: str) -> bool:
|
||||
"""
|
||||
Modify the CloudTrail's associated S3 bucket's public access settings to ensure the bucket is not publicly accessible.
|
||||
Specifically, this fixer configures the S3 bucket's public access block settings to block all public access.
|
||||
Requires the s3:PutBucketPublicAccessBlock permissions.
|
||||
Permissions:
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": "s3:PutBucketPublicAccessBlock",
|
||||
"Resource": "*"
|
||||
}
|
||||
]
|
||||
}
|
||||
Args:
|
||||
resource_id (str): The CloudTrail name.
|
||||
region (str): AWS region where the CloudTrail and S3 bucket exist.
|
||||
Returns:
|
||||
bool: True if the operation is successful (policy and ACL updated), False otherwise.
|
||||
"""
|
||||
try:
|
||||
regional_client = s3_client.regional_clients[region]
|
||||
for trail in cloudtrail_client.trails.values():
|
||||
if trail.name == resource_id:
|
||||
trail_bucket = trail.s3_bucket
|
||||
|
||||
regional_client.put_public_access_block(
|
||||
Bucket=trail_bucket,
|
||||
PublicAccessBlockConfiguration={
|
||||
"BlockPublicAcls": True,
|
||||
"IgnorePublicAcls": True,
|
||||
"BlockPublicPolicy": True,
|
||||
"RestrictPublicBuckets": True,
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
@@ -48,7 +48,7 @@ class cloudtrail_multi_region_enabled_logging_management_events(Check):
|
||||
report.resource_id = trail.name
|
||||
report.resource_arn = trail.arn
|
||||
report.resource_tags = trail.tags
|
||||
report.region = trail.home_region
|
||||
report.region = region
|
||||
report.status = "PASS"
|
||||
if trail.is_multiregion:
|
||||
report.status_extended = f"Trail {trail.name} from home region {trail.home_region} is multi-region, is logging and have management events enabled."
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
from prowler.lib.logger import logger
|
||||
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
|
||||
|
||||
|
||||
def fixer(resource_id: str, region: str) -> bool:
|
||||
"""
|
||||
Modify the attributes of an EC2 AMI to remove public access.
|
||||
Specifically, this fixer removes the 'all' value from the 'LaunchPermission' attribute
|
||||
to prevent the AMI from being publicly accessible.
|
||||
Requires the ec2:ModifyImageAttribute permission.
|
||||
Permissions:
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": "ec2:ModifyImageAttribute",
|
||||
"Resource": "*"
|
||||
}
|
||||
]
|
||||
}
|
||||
Args:
|
||||
resource_id (str): The ID of the EC2 AMI to make private.
|
||||
region (str): AWS region where the AMI exists.
|
||||
Returns:
|
||||
bool: True if the operation is successful (the AMI is no longer publicly accessible), False otherwise.
|
||||
"""
|
||||
try:
|
||||
regional_client = ec2_client.regional_clients[region]
|
||||
regional_client.modify_image_attribute(
|
||||
ImageId=resource_id,
|
||||
LaunchPermission={"Remove": [{"Group": "all"}]},
|
||||
)
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
@@ -0,0 +1,51 @@
|
||||
from prowler.lib.logger import logger
|
||||
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
|
||||
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
|
||||
|
||||
|
||||
def fixer(resource_id: str, region: str) -> bool:
|
||||
"""
|
||||
Revokes any ingress rule allowing Cassandra ports (7000, 7001, 7199, 9042, 9160) from any address (0.0.0.0/0)
|
||||
for the EC2 instance's security groups.
|
||||
This fixer will only be triggered if the check identifies Cassandra ports open to the Internet.
|
||||
Requires the ec2:RevokeSecurityGroupIngress permission.
|
||||
Permissions:
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": "ec2:RevokeSecurityGroupIngress",
|
||||
"Resource": "*"
|
||||
}
|
||||
]
|
||||
}
|
||||
Args:
|
||||
resource_id (str): The EC2 instance ID.
|
||||
region (str): The AWS region where the EC2 instance exists.
|
||||
Returns:
|
||||
bool: True if the operation is successful (ingress rule revoked), False otherwise.
|
||||
"""
|
||||
try:
|
||||
regional_client = ec2_client.regional_clients[region]
|
||||
check_ports = [7000, 7001, 7199, 9042, 9160]
|
||||
for instance in ec2_client.instances:
|
||||
if instance.id == resource_id:
|
||||
for sg in ec2_client.security_groups.values():
|
||||
if sg.id in instance.security_groups:
|
||||
for ingress_rule in sg.ingress_rules:
|
||||
if check_security_group(
|
||||
ingress_rule, "tcp", check_ports, any_address=True
|
||||
):
|
||||
regional_client.revoke_security_group_ingress(
|
||||
GroupId=sg.id,
|
||||
IpPermissions=[ingress_rule],
|
||||
)
|
||||
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
@@ -0,0 +1,51 @@
|
||||
from prowler.lib.logger import logger
|
||||
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
|
||||
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
|
||||
|
||||
|
||||
def fixer(resource_id: str, region: str) -> bool:
|
||||
"""
|
||||
Revokes any ingress rule allowing Elasticsearch and Kibana ports (9200, 9300, 5601)
|
||||
from any address (0.0.0.0/0) for the EC2 instance's security groups.
|
||||
This fixer will only be triggered if the check identifies those ports open to the Internet.
|
||||
Requires the ec2:RevokeSecurityGroupIngress permission.
|
||||
Permissions:
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": "ec2:RevokeSecurityGroupIngress",
|
||||
"Resource": "*"
|
||||
}
|
||||
]
|
||||
}
|
||||
Args:
|
||||
resource_id (str): The EC2 instance ID.
|
||||
region (str): The AWS region where the EC2 instance exists.
|
||||
Returns:
|
||||
bool: True if the operation is successful (ingress rule revoked), False otherwise.
|
||||
"""
|
||||
try:
|
||||
regional_client = ec2_client.regional_clients[region]
|
||||
check_ports = [9200, 9300, 5601]
|
||||
for instance in ec2_client.instances:
|
||||
if instance.id == resource_id:
|
||||
for sg in ec2_client.security_groups.values():
|
||||
if sg.id in instance.security_groups:
|
||||
for ingress_rule in sg.ingress_rules:
|
||||
if check_security_group(
|
||||
ingress_rule, "tcp", check_ports, any_address=True
|
||||
):
|
||||
regional_client.revoke_security_group_ingress(
|
||||
GroupId=sg.id,
|
||||
IpPermissions=[ingress_rule],
|
||||
)
|
||||
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
@@ -0,0 +1,51 @@
|
||||
from prowler.lib.logger import logger
|
||||
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
|
||||
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
|
||||
|
||||
|
||||
def fixer(resource_id: str, region: str) -> bool:
|
||||
"""
|
||||
Revokes any ingress rule allowing FTP ports (20, 21) from any address (0.0.0.0/0)
|
||||
for the EC2 instance's security groups.
|
||||
This fixer will only be triggered if the check identifies FTP ports open to the Internet.
|
||||
Requires the ec2:RevokeSecurityGroupIngress permission.
|
||||
Permissions:
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": "ec2:RevokeSecurityGroupIngress",
|
||||
"Resource": "*"
|
||||
}
|
||||
]
|
||||
}
|
||||
Args:
|
||||
resource_id (str): The EC2 instance ID.
|
||||
region (str): The AWS region where the EC2 instance exists.
|
||||
Returns:
|
||||
bool: True if the operation is successful (ingress rule revoked), False otherwise.
|
||||
"""
|
||||
try:
|
||||
regional_client = ec2_client.regional_clients[region]
|
||||
check_ports = [20, 21]
|
||||
for instance in ec2_client.instances:
|
||||
if instance.id == resource_id:
|
||||
for sg in ec2_client.security_groups.values():
|
||||
if sg.id in instance.security_groups:
|
||||
for ingress_rule in sg.ingress_rules:
|
||||
if check_security_group(
|
||||
ingress_rule, "tcp", check_ports, any_address=True
|
||||
):
|
||||
regional_client.revoke_security_group_ingress(
|
||||
GroupId=sg.id,
|
||||
IpPermissions=[ingress_rule],
|
||||
)
|
||||
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
@@ -0,0 +1,51 @@
|
||||
from prowler.lib.logger import logger
|
||||
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
|
||||
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
|
||||
|
||||
|
||||
def fixer(resource_id: str, region: str) -> bool:
|
||||
"""
|
||||
Revokes any ingress rule allowing Kafka ports (9092) from any address (0.0.0.0/0)
|
||||
for the EC2 instance's security groups.
|
||||
This fixer will only be triggered if the check identifies Kafka ports open to the Internet.
|
||||
Requires the ec2:RevokeSecurityGroupIngress permission.
|
||||
Permissions:
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": "ec2:RevokeSecurityGroupIngress",
|
||||
"Resource": "*"
|
||||
}
|
||||
]
|
||||
}
|
||||
Args:
|
||||
resource_id (str): The EC2 instance ID.
|
||||
region (str): The AWS region where the EC2 instance exists.
|
||||
Returns:
|
||||
bool: True if the operation is successful (ingress rule revoked), False otherwise.
|
||||
"""
|
||||
try:
|
||||
regional_client = ec2_client.regional_clients[region]
|
||||
check_ports = [9092]
|
||||
for instance in ec2_client.instances:
|
||||
if instance.id == resource_id:
|
||||
for sg in ec2_client.security_groups.values():
|
||||
if sg.id in instance.security_groups:
|
||||
for ingress_rule in sg.ingress_rules:
|
||||
if check_security_group(
|
||||
ingress_rule, "tcp", check_ports, any_address=True
|
||||
):
|
||||
regional_client.revoke_security_group_ingress(
|
||||
GroupId=sg.id,
|
||||
IpPermissions=[ingress_rule],
|
||||
)
|
||||
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
@@ -0,0 +1,51 @@
|
||||
from prowler.lib.logger import logger
|
||||
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
|
||||
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
|
||||
|
||||
|
||||
def fixer(resource_id: str, region: str) -> bool:
|
||||
"""
|
||||
Revokes any ingress rule allowing Kerberos ports (88, 464, 749, 750) from any address (0.0.0.0/0)
|
||||
for the EC2 instance's security groups.
|
||||
This fixer will only be triggered if the check identifies Kerberos ports open to the Internet.
|
||||
Requires the ec2:RevokeSecurityGroupIngress permission.
|
||||
Permissions:
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": "ec2:RevokeSecurityGroupIngress",
|
||||
"Resource": "*"
|
||||
}
|
||||
]
|
||||
}
|
||||
Args:
|
||||
resource_id (str): The EC2 instance ID.
|
||||
region (str): The AWS region where the EC2 instance exists.
|
||||
Returns:
|
||||
bool: True if the operation is successful (ingress rule revoked), False otherwise.
|
||||
"""
|
||||
try:
|
||||
regional_client = ec2_client.regional_clients[region]
|
||||
check_ports = [88, 464, 749, 750]
|
||||
for instance in ec2_client.instances:
|
||||
if instance.id == resource_id:
|
||||
for sg in ec2_client.security_groups.values():
|
||||
if sg.id in instance.security_groups:
|
||||
for ingress_rule in sg.ingress_rules:
|
||||
if check_security_group(
|
||||
ingress_rule, "tcp", check_ports, any_address=True
|
||||
):
|
||||
regional_client.revoke_security_group_ingress(
|
||||
GroupId=sg.id,
|
||||
IpPermissions=[ingress_rule],
|
||||
)
|
||||
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
@@ -0,0 +1,51 @@
|
||||
from prowler.lib.logger import logger
|
||||
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
|
||||
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
|
||||
|
||||
|
||||
def fixer(resource_id: str, region: str) -> bool:
|
||||
"""
|
||||
Revokes any ingress rule allowing LDAP ports (389, 636) from any address (0.0.0.0/0)
|
||||
for the EC2 instance's security groups.
|
||||
This fixer will only be triggered if the check identifies LDAP ports open to the Internet.
|
||||
Requires the ec2:RevokeSecurityGroupIngress permission.
|
||||
Permissions:
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": "ec2:RevokeSecurityGroupIngress",
|
||||
"Resource": "*"
|
||||
}
|
||||
]
|
||||
}
|
||||
Args:
|
||||
resource_id (str): The EC2 instance ID.
|
||||
region (str): The AWS region where the EC2 instance exists.
|
||||
Returns:
|
||||
bool: True if the operation is successful (ingress rule revoked), False otherwise.
|
||||
"""
|
||||
try:
|
||||
regional_client = ec2_client.regional_clients[region]
|
||||
check_ports = [389, 636]
|
||||
for instance in ec2_client.instances:
|
||||
if instance.id == resource_id:
|
||||
for sg in ec2_client.security_groups.values():
|
||||
if sg.id in instance.security_groups:
|
||||
for ingress_rule in sg.ingress_rules:
|
||||
if check_security_group(
|
||||
ingress_rule, "tcp", check_ports, any_address=True
|
||||
):
|
||||
regional_client.revoke_security_group_ingress(
|
||||
GroupId=sg.id,
|
||||
IpPermissions=[ingress_rule],
|
||||
)
|
||||
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
@@ -0,0 +1,51 @@
|
||||
from prowler.lib.logger import logger
|
||||
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
|
||||
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
|
||||
|
||||
|
||||
def fixer(resource_id: str, region: str) -> bool:
|
||||
"""
|
||||
Revokes any ingress rule allowing Memcached ports (11211) from any address (0.0.0.0/0)
|
||||
for the EC2 instance's security groups.
|
||||
This fixer will only be triggered if the check identifies Memcached ports open to the Internet.
|
||||
Requires the ec2:RevokeSecurityGroupIngress permission.
|
||||
Permissions:
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": "ec2:RevokeSecurityGroupIngress",
|
||||
"Resource": "*"
|
||||
}
|
||||
]
|
||||
}
|
||||
Args:
|
||||
resource_id (str): The EC2 instance ID.
|
||||
region (str): The AWS region where the EC2 instance exists.
|
||||
Returns:
|
||||
bool: True if the operation is successful (ingress rule revoked), False otherwise.
|
||||
"""
|
||||
try:
|
||||
regional_client = ec2_client.regional_clients[region]
|
||||
check_ports = [11211]
|
||||
for instance in ec2_client.instances:
|
||||
if instance.id == resource_id:
|
||||
for sg in ec2_client.security_groups.values():
|
||||
if sg.id in instance.security_groups:
|
||||
for ingress_rule in sg.ingress_rules:
|
||||
if check_security_group(
|
||||
ingress_rule, "tcp", check_ports, any_address=True
|
||||
):
|
||||
regional_client.revoke_security_group_ingress(
|
||||
GroupId=sg.id,
|
||||
IpPermissions=[ingress_rule],
|
||||
)
|
||||
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
@@ -0,0 +1,51 @@
|
||||
from prowler.lib.logger import logger
|
||||
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
|
||||
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
|
||||
|
||||
|
||||
def fixer(resource_id: str, region: str) -> bool:
|
||||
"""
|
||||
Revokes any ingress rule allowing MongoDB ports (27017, 27018) from any address (0.0.0.0/0)
|
||||
for the EC2 instance's security groups.
|
||||
This fixer will only be triggered if the check identifies MongoDB ports open to the Internet.
|
||||
Requires the ec2:RevokeSecurityGroupIngress permission.
|
||||
Permissions:
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": "ec2:RevokeSecurityGroupIngress",
|
||||
"Resource": "*"
|
||||
}
|
||||
]
|
||||
}
|
||||
Args:
|
||||
resource_id (str): The EC2 instance ID.
|
||||
region (str): The AWS region where the EC2 instance exists.
|
||||
Returns:
|
||||
bool: True if the operation is successful (ingress rule revoked), False otherwise.
|
||||
"""
|
||||
try:
|
||||
regional_client = ec2_client.regional_clients[region]
|
||||
check_ports = [27017, 27018]
|
||||
for instance in ec2_client.instances:
|
||||
if instance.id == resource_id:
|
||||
for sg in ec2_client.security_groups.values():
|
||||
if sg.id in instance.security_groups:
|
||||
for ingress_rule in sg.ingress_rules:
|
||||
if check_security_group(
|
||||
ingress_rule, "tcp", check_ports, any_address=True
|
||||
):
|
||||
regional_client.revoke_security_group_ingress(
|
||||
GroupId=sg.id,
|
||||
IpPermissions=[ingress_rule],
|
||||
)
|
||||
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
@@ -0,0 +1,51 @@
|
||||
from prowler.lib.logger import logger
|
||||
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
|
||||
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
|
||||
|
||||
|
||||
def fixer(resource_id: str, region: str) -> bool:
|
||||
"""
|
||||
Revokes any ingress rule allowing MySQL ports (3306) from any address (0.0.0.0/0)
|
||||
for the EC2 instance's security groups.
|
||||
This fixer will only be triggered if the check identifies MySQL ports open to the Internet.
|
||||
Requires the ec2:RevokeSecurityGroupIngress permission.
|
||||
Permissions:
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": "ec2:RevokeSecurityGroupIngress",
|
||||
"Resource": "*"
|
||||
}
|
||||
]
|
||||
}
|
||||
Args:
|
||||
resource_id (str): The EC2 instance ID.
|
||||
region (str): The AWS region where the EC2 instance exists.
|
||||
Returns:
|
||||
bool: True if the operation is successful (ingress rule revoked), False otherwise.
|
||||
"""
|
||||
try:
|
||||
regional_client = ec2_client.regional_clients[region]
|
||||
check_ports = [3306]
|
||||
for instance in ec2_client.instances:
|
||||
if instance.id == resource_id:
|
||||
for sg in ec2_client.security_groups.values():
|
||||
if sg.id in instance.security_groups:
|
||||
for ingress_rule in sg.ingress_rules:
|
||||
if check_security_group(
|
||||
ingress_rule, "tcp", check_ports, any_address=True
|
||||
):
|
||||
regional_client.revoke_security_group_ingress(
|
||||
GroupId=sg.id,
|
||||
IpPermissions=[ingress_rule],
|
||||
)
|
||||
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
@@ -0,0 +1,51 @@
|
||||
from prowler.lib.logger import logger
|
||||
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
|
||||
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
|
||||
|
||||
|
||||
def fixer(resource_id: str, region: str) -> bool:
|
||||
"""
|
||||
Revokes any ingress rule allowing Oracle ports (1521, 2483, 2484) from any address (0.0.0.0/0)
|
||||
for the EC2 instance's security groups.
|
||||
This fixer will only be triggered if the check identifies Oracle ports open to the Internet.
|
||||
Requires the ec2:RevokeSecurityGroupIngress permission.
|
||||
Permissions:
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": "ec2:RevokeSecurityGroupIngress",
|
||||
"Resource": "*"
|
||||
}
|
||||
]
|
||||
}
|
||||
Args:
|
||||
resource_id (str): The EC2 instance ID.
|
||||
region (str): The AWS region where the EC2 instance exists.
|
||||
Returns:
|
||||
bool: True if the operation is successful (ingress rule revoked), False otherwise.
|
||||
"""
|
||||
try:
|
||||
regional_client = ec2_client.regional_clients[region]
|
||||
check_ports = [1521, 2483, 2484]
|
||||
for instance in ec2_client.instances:
|
||||
if instance.id == resource_id:
|
||||
for sg in ec2_client.security_groups.values():
|
||||
if sg.id in instance.security_groups:
|
||||
for ingress_rule in sg.ingress_rules:
|
||||
if check_security_group(
|
||||
ingress_rule, "tcp", check_ports, any_address=True
|
||||
):
|
||||
regional_client.revoke_security_group_ingress(
|
||||
GroupId=sg.id,
|
||||
IpPermissions=[ingress_rule],
|
||||
)
|
||||
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
@@ -0,0 +1,51 @@
|
||||
from prowler.lib.logger import logger
|
||||
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
|
||||
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
|
||||
|
||||
|
||||
def fixer(resource_id: str, region: str) -> bool:
|
||||
"""
|
||||
Revokes any ingress rule allowing PostgreSQL ports (5432) from any address (0.0.0.0/0)
|
||||
for the EC2 instance's security groups.
|
||||
This fixer will only be triggered if the check identifies PostgreSQL ports open to the Internet.
|
||||
Requires the ec2:RevokeSecurityGroupIngress permission.
|
||||
Permissions:
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": "ec2:RevokeSecurityGroupIngress",
|
||||
"Resource": "*"
|
||||
}
|
||||
]
|
||||
}
|
||||
Args:
|
||||
resource_id (str): The EC2 instance ID.
|
||||
region (str): The AWS region where the EC2 instance exists.
|
||||
Returns:
|
||||
bool: True if the operation is successful (ingress rule revoked), False otherwise.
|
||||
"""
|
||||
try:
|
||||
regional_client = ec2_client.regional_clients[region]
|
||||
check_ports = [5432]
|
||||
for instance in ec2_client.instances:
|
||||
if instance.id == resource_id:
|
||||
for sg in ec2_client.security_groups.values():
|
||||
if sg.id in instance.security_groups:
|
||||
for ingress_rule in sg.ingress_rules:
|
||||
if check_security_group(
|
||||
ingress_rule, "tcp", check_ports, any_address=True
|
||||
):
|
||||
regional_client.revoke_security_group_ingress(
|
||||
GroupId=sg.id,
|
||||
IpPermissions=[ingress_rule],
|
||||
)
|
||||
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
@@ -0,0 +1,51 @@
|
||||
from prowler.lib.logger import logger
|
||||
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
|
||||
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
|
||||
|
||||
|
||||
def fixer(resource_id: str, region: str) -> bool:
|
||||
"""
|
||||
Revokes any ingress rule allowing RDP ports (3389) from any address (0.0.0.0/0)
|
||||
for the EC2 instance's security groups.
|
||||
This fixer will only be triggered if the check identifies RDP ports open to the Internet.
|
||||
Requires the ec2:RevokeSecurityGroupIngress permission.
|
||||
Permissions:
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": "ec2:RevokeSecurityGroupIngress",
|
||||
"Resource": "*"
|
||||
}
|
||||
]
|
||||
}
|
||||
Args:
|
||||
resource_id (str): The EC2 instance ID.
|
||||
region (str): The AWS region where the EC2 instance exists.
|
||||
Returns:
|
||||
bool: True if the operation is successful (ingress rule revoked), False otherwise.
|
||||
"""
|
||||
try:
|
||||
regional_client = ec2_client.regional_clients[region]
|
||||
check_ports = [3389]
|
||||
for instance in ec2_client.instances:
|
||||
if instance.id == resource_id:
|
||||
for sg in ec2_client.security_groups.values():
|
||||
if sg.id in instance.security_groups:
|
||||
for ingress_rule in sg.ingress_rules:
|
||||
if check_security_group(
|
||||
ingress_rule, "tcp", check_ports, any_address=True
|
||||
):
|
||||
regional_client.revoke_security_group_ingress(
|
||||
GroupId=sg.id,
|
||||
IpPermissions=[ingress_rule],
|
||||
)
|
||||
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
@@ -0,0 +1,51 @@
|
||||
from prowler.lib.logger import logger
|
||||
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
|
||||
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
|
||||
|
||||
|
||||
def fixer(resource_id: str, region: str) -> bool:
|
||||
"""
|
||||
Revokes any ingress rule allowing Redis ports (6379) from any address (0.0.0.0/0)
|
||||
for the EC2 instance's security groups.
|
||||
This fixer will only be triggered if the check identifies Redis ports open to the Internet.
|
||||
Requires the ec2:RevokeSecurityGroupIngress permission.
|
||||
Permissions:
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": "ec2:RevokeSecurityGroupIngress",
|
||||
"Resource": "*"
|
||||
}
|
||||
]
|
||||
}
|
||||
Args:
|
||||
resource_id (str): The EC2 instance ID.
|
||||
region (str): The AWS region where the EC2 instance exists.
|
||||
Returns:
|
||||
bool: True if the operation is successful (ingress rule revoked), False otherwise.
|
||||
"""
|
||||
try:
|
||||
regional_client = ec2_client.regional_clients[region]
|
||||
check_ports = [6379]
|
||||
for instance in ec2_client.instances:
|
||||
if instance.id == resource_id:
|
||||
for sg in ec2_client.security_groups.values():
|
||||
if sg.id in instance.security_groups:
|
||||
for ingress_rule in sg.ingress_rules:
|
||||
if check_security_group(
|
||||
ingress_rule, "tcp", check_ports, any_address=True
|
||||
):
|
||||
regional_client.revoke_security_group_ingress(
|
||||
GroupId=sg.id,
|
||||
IpPermissions=[ingress_rule],
|
||||
)
|
||||
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
@@ -0,0 +1,51 @@
|
||||
from prowler.lib.logger import logger
|
||||
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
|
||||
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
|
||||
|
||||
|
||||
def fixer(resource_id: str, region: str) -> bool:
|
||||
"""
|
||||
Revokes any ingress rule allowing SQLServer ports (1433, 1434) from any address (0.0.0.0/0)
|
||||
for the EC2 instance's security groups.
|
||||
This fixer will only be triggered if the check identifies SQLServer ports open to the Internet.
|
||||
Requires the ec2:RevokeSecurityGroupIngress permission.
|
||||
Permissions:
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": "ec2:RevokeSecurityGroupIngress",
|
||||
"Resource": "*"
|
||||
}
|
||||
]
|
||||
}
|
||||
Args:
|
||||
resource_id (str): The EC2 instance ID.
|
||||
region (str): The AWS region where the EC2 instance exists.
|
||||
Returns:
|
||||
bool: True if the operation is successful (ingress rule revoked), False otherwise.
|
||||
"""
|
||||
try:
|
||||
regional_client = ec2_client.regional_clients[region]
|
||||
check_ports = [1433, 1434]
|
||||
for instance in ec2_client.instances:
|
||||
if instance.id == resource_id:
|
||||
for sg in ec2_client.security_groups.values():
|
||||
if sg.id in instance.security_groups:
|
||||
for ingress_rule in sg.ingress_rules:
|
||||
if check_security_group(
|
||||
ingress_rule, "tcp", check_ports, any_address=True
|
||||
):
|
||||
regional_client.revoke_security_group_ingress(
|
||||
GroupId=sg.id,
|
||||
IpPermissions=[ingress_rule],
|
||||
)
|
||||
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
@@ -0,0 +1,51 @@
|
||||
from prowler.lib.logger import logger
|
||||
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
|
||||
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
|
||||
|
||||
|
||||
def fixer(resource_id: str, region: str) -> bool:
|
||||
"""
|
||||
Revokes any ingress rule allowing SSH ports (22) from any address (0.0.0.0/0)
|
||||
for the EC2 instance's security groups.
|
||||
This fixer will only be triggered if the check identifies SSH ports open to the Internet.
|
||||
Requires the ec2:RevokeSecurityGroupIngress permission.
|
||||
Permissions:
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": "ec2:RevokeSecurityGroupIngress",
|
||||
"Resource": "*"
|
||||
}
|
||||
]
|
||||
}
|
||||
Args:
|
||||
resource_id (str): The EC2 instance ID.
|
||||
region (str): The AWS region where the EC2 instance exists.
|
||||
Returns:
|
||||
bool: True if the operation is successful (ingress rule revoked), False otherwise.
|
||||
"""
|
||||
try:
|
||||
regional_client = ec2_client.regional_clients[region]
|
||||
check_ports = [22]
|
||||
for instance in ec2_client.instances:
|
||||
if instance.id == resource_id:
|
||||
for sg in ec2_client.security_groups.values():
|
||||
if sg.id in instance.security_groups:
|
||||
for ingress_rule in sg.ingress_rules:
|
||||
if check_security_group(
|
||||
ingress_rule, "tcp", check_ports, any_address=True
|
||||
):
|
||||
regional_client.revoke_security_group_ingress(
|
||||
GroupId=sg.id,
|
||||
IpPermissions=[ingress_rule],
|
||||
)
|
||||
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
@@ -0,0 +1,51 @@
|
||||
from prowler.lib.logger import logger
|
||||
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
|
||||
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
|
||||
|
||||
|
||||
def fixer(resource_id: str, region: str) -> bool:
|
||||
"""
|
||||
Revokes any ingress rule allowing Telnet ports (23) from any address (0.0.0.0/0)
|
||||
for the EC2 instance's security groups.
|
||||
This fixer will only be triggered if the check identifies Telnet ports open to the Internet.
|
||||
Requires the ec2:RevokeSecurityGroupIngress permission.
|
||||
Permissions:
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": "ec2:RevokeSecurityGroupIngress",
|
||||
"Resource": "*"
|
||||
}
|
||||
]
|
||||
}
|
||||
Args:
|
||||
resource_id (str): The EC2 instance ID.
|
||||
region (str): The AWS region where the EC2 instance exists.
|
||||
Returns:
|
||||
bool: True if the operation is successful (ingress rule revoked), False otherwise.
|
||||
"""
|
||||
try:
|
||||
regional_client = ec2_client.regional_clients[region]
|
||||
check_ports = [23]
|
||||
for instance in ec2_client.instances:
|
||||
if instance.id == resource_id:
|
||||
for sg in ec2_client.security_groups.values():
|
||||
if sg.id in instance.security_groups:
|
||||
for ingress_rule in sg.ingress_rules:
|
||||
if check_security_group(
|
||||
ingress_rule, "tcp", check_ports, any_address=True
|
||||
):
|
||||
regional_client.revoke_security_group_ingress(
|
||||
GroupId=sg.id,
|
||||
IpPermissions=[ingress_rule],
|
||||
)
|
||||
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
@@ -0,0 +1,52 @@
|
||||
from prowler.lib.logger import logger
|
||||
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
|
||||
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
|
||||
|
||||
|
||||
def fixer(resource_id: str, region: str) -> bool:
|
||||
"""
|
||||
Revokes any ingress rule allowing high risk ports (25, 110, 135, 143, 445, 3000, 4333, 5000, 5500, 8080, 8088)
|
||||
from any address (0.0.0.0/0) for the security groups.
|
||||
This fixer will only be triggered if the check identifies high risk ports open to the Internet.
|
||||
Requires the ec2:RevokeSecurityGroupIngress permission.
|
||||
Permissions:
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": "ec2:RevokeSecurityGroupIngress",
|
||||
"Resource": "*"
|
||||
}
|
||||
]
|
||||
}
|
||||
Args:
|
||||
resource_id (str): The Security Group ID.
|
||||
region (str): The AWS region where the Security Group exists.
|
||||
Returns:
|
||||
bool: True if the operation is successful (ingress rule revoked), False otherwise.
|
||||
"""
|
||||
try:
|
||||
regional_client = ec2_client.regional_clients[region]
|
||||
check_ports = ec2_client.audit_config.get(
|
||||
"ec2_high_risk_ports",
|
||||
[25, 110, 135, 143, 445, 3000, 4333, 5000, 5500, 8080, 8088],
|
||||
)
|
||||
for security_group in ec2_client.security_groups.values():
|
||||
if security_group.id == resource_id:
|
||||
for ingress_rule in security_group.ingress_rules:
|
||||
if check_security_group(
|
||||
ingress_rule, "tcp", check_ports, any_address=True
|
||||
):
|
||||
regional_client.revoke_security_group_ingress(
|
||||
GroupId=security_group.id,
|
||||
IpPermissions=[ingress_rule],
|
||||
)
|
||||
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
@@ -29,7 +29,9 @@ class route53_dangling_ip_subdomain_takeover(Check):
|
||||
# Check if record is an IP Address
|
||||
if validate_ip_address(record):
|
||||
report = Check_Report_AWS(self.metadata())
|
||||
report.resource_id = f"{record_set.hosted_zone_id}/{record}"
|
||||
report.resource_id = (
|
||||
f"{record_set.hosted_zone_id}/{record_set.name}/{record}"
|
||||
)
|
||||
report.resource_arn = route53_client.hosted_zones[
|
||||
record_set.hosted_zone_id
|
||||
].arn
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
from prowler.lib.logger import logger
|
||||
from prowler.providers.aws.services.s3.s3_client import s3_client
|
||||
|
||||
|
||||
def fixer(resource_id: str, region: str) -> bool:
|
||||
"""
|
||||
Modify the S3 bucket's policy to remove public access.
|
||||
Specifically, this fixer delete the policy of the public bucket.
|
||||
Requires the s3:DeleteBucketPolicy permission.
|
||||
Permissions:
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": "s3:DeleteBucketPolicy",
|
||||
"Resource": "*"
|
||||
}
|
||||
]
|
||||
}
|
||||
Args:
|
||||
resource_id (str): The S3 bucket name.
|
||||
region (str): AWS region where the S3 bucket exists.
|
||||
Returns:
|
||||
bool: True if the operation is successful (policy updated), False otherwise.
|
||||
"""
|
||||
try:
|
||||
regional_client = s3_client.regional_clients[region]
|
||||
|
||||
regional_client.delete_bucket_policy(Bucket=resource_id)
|
||||
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
@@ -0,0 +1,6 @@
|
||||
from prowler.providers.aws.services.stepfunctions.stepfunctions_service import (
|
||||
StepFunctions,
|
||||
)
|
||||
from prowler.providers.common.provider import Provider
|
||||
|
||||
stepfunctions_client = StepFunctions(Provider.get_global_provider())
|
||||
@@ -0,0 +1,320 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from botocore.exceptions import ClientError
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from prowler.lib.logger import logger
|
||||
from prowler.lib.scan_filters.scan_filters import is_resource_filtered
|
||||
from prowler.providers.aws.lib.service.service import AWSService
|
||||
|
||||
|
||||
class StateMachineStatus(str, Enum):
|
||||
"""Enumeration of possible State Machine statuses."""
|
||||
|
||||
ACTIVE = "ACTIVE"
|
||||
DELETING = "DELETING"
|
||||
|
||||
|
||||
class StateMachineType(str, Enum):
|
||||
"""Enumeration of possible State Machine types."""
|
||||
|
||||
STANDARD = "STANDARD"
|
||||
EXPRESS = "EXPRESS"
|
||||
|
||||
|
||||
class LoggingLevel(str, Enum):
|
||||
"""Enumeration of possible logging levels."""
|
||||
|
||||
ALL = "ALL"
|
||||
ERROR = "ERROR"
|
||||
FATAL = "FATAL"
|
||||
OFF = "OFF"
|
||||
|
||||
|
||||
class EncryptionType(str, Enum):
|
||||
"""Enumeration of possible encryption types."""
|
||||
|
||||
AWS_OWNED_KEY = "AWS_OWNED_KEY"
|
||||
CUSTOMER_MANAGED_KMS_KEY = "CUSTOMER_MANAGED_KMS_KEY"
|
||||
|
||||
|
||||
class CloudWatchLogsLogGroup(BaseModel):
|
||||
"""
|
||||
Represents a CloudWatch Logs Log Group configuration for a State Machine.
|
||||
|
||||
Attributes:
|
||||
log_group_arn (str): The ARN of the CloudWatch Logs Log Group.
|
||||
"""
|
||||
|
||||
log_group_arn: str
|
||||
|
||||
|
||||
class LoggingDestination(BaseModel):
|
||||
"""
|
||||
Represents a logging destination for a State Machine.
|
||||
|
||||
Attributes:
|
||||
cloud_watch_logs_log_group (CloudWatchLogsLogGroup): The CloudWatch Logs Log Group configuration.
|
||||
"""
|
||||
|
||||
cloud_watch_logs_log_group: CloudWatchLogsLogGroup
|
||||
|
||||
|
||||
class LoggingConfiguration(BaseModel):
|
||||
"""
|
||||
Represents the logging configuration for a State Machine.
|
||||
|
||||
Attributes:
|
||||
level (LoggingLevel): The logging level.
|
||||
include_execution_data (bool): Whether to include execution data in the logs.
|
||||
destinations (List[LoggingDestination]): List of logging destinations.
|
||||
"""
|
||||
|
||||
level: LoggingLevel
|
||||
include_execution_data: bool
|
||||
destinations: List[LoggingDestination]
|
||||
|
||||
|
||||
class TracingConfiguration(BaseModel):
|
||||
"""
|
||||
Represents the tracing configuration for a State Machine.
|
||||
|
||||
Attributes:
|
||||
enabled (bool): Whether X-Ray tracing is enabled.
|
||||
"""
|
||||
|
||||
enabled: bool
|
||||
|
||||
|
||||
class EncryptionConfiguration(BaseModel):
|
||||
"""
|
||||
Represents the encryption configuration for a State Machine.
|
||||
|
||||
Attributes:
|
||||
kms_key_id (Optional[str]): The KMS key ID used for encryption.
|
||||
kms_data_key_reuse_period_seconds (Optional[int]): The time in seconds that a KMS data key can be reused.
|
||||
type (EncryptionType): The type of encryption used.
|
||||
"""
|
||||
|
||||
kms_key_id: Optional[str]
|
||||
kms_data_key_reuse_period_seconds: Optional[int]
|
||||
type: EncryptionType
|
||||
|
||||
|
||||
class StateMachine(BaseModel):
|
||||
"""
|
||||
Represents an AWS Step Functions State Machine.
|
||||
|
||||
Attributes:
|
||||
id (str): The unique identifier of the state machine.
|
||||
arn (str): The ARN of the state machine.
|
||||
name (Optional[str]): The name of the state machine.
|
||||
status (StateMachineStatus): The current status of the state machine.
|
||||
definition (str): The Amazon States Language definition of the state machine.
|
||||
role_arn (str): The ARN of the IAM role used by the state machine.
|
||||
type (StateMachineType): The type of the state machine (STANDARD or EXPRESS).
|
||||
creation_date (datetime): The creation date and time of the state machine.
|
||||
region (str): The region where the state machine is.
|
||||
logging_configuration (Optional[LoggingConfiguration]): The logging configuration of the state machine.
|
||||
tracing_configuration (Optional[TracingConfiguration]): The tracing configuration of the state machine.
|
||||
label (Optional[str]): The label associated with the state machine.
|
||||
revision_id (Optional[str]): The revision ID of the state machine.
|
||||
description (Optional[str]): A description of the state machine.
|
||||
encryption_configuration (Optional[EncryptionConfiguration]): The encryption configuration of the state machine.
|
||||
tags (List[Dict]): A list of tags associated with the state machine.
|
||||
"""
|
||||
|
||||
id: str
|
||||
arn: str
|
||||
name: Optional[str] = None
|
||||
status: StateMachineStatus
|
||||
definition: Optional[str] = None
|
||||
role_arn: Optional[str] = None
|
||||
type: StateMachineType
|
||||
creation_date: datetime
|
||||
region: str
|
||||
logging_configuration: Optional[LoggingConfiguration] = None
|
||||
tracing_configuration: Optional[TracingConfiguration] = None
|
||||
label: Optional[str] = None
|
||||
revision_id: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
encryption_configuration: Optional[EncryptionConfiguration] = None
|
||||
tags: List[Dict] = Field(default_factory=list)
|
||||
|
||||
|
||||
class StepFunctions(AWSService):
|
||||
"""
|
||||
AWS Step Functions service class to manage state machines.
|
||||
|
||||
This class provides methods to list state machines, describe their details,
|
||||
and list their associated tags across different AWS regions.
|
||||
"""
|
||||
|
||||
def __init__(self, provider):
|
||||
"""
|
||||
Initialize the StepFunctions service.
|
||||
|
||||
Args:
|
||||
provider: The AWS provider instance containing regional clients and audit configurations.
|
||||
"""
|
||||
super().__init__(__class__.__name__, provider)
|
||||
self.state_machines: Dict[str, StateMachine] = {}
|
||||
self.__threading_call__(self._list_state_machines)
|
||||
self.__threading_call__(
|
||||
self._describe_state_machine, self.state_machines.values()
|
||||
)
|
||||
self.__threading_call__(
|
||||
self._list_state_machine_tags, self.state_machines.values()
|
||||
)
|
||||
|
||||
def _list_state_machines(self, regional_client) -> None:
|
||||
"""
|
||||
List AWS Step Functions state machines in the specified region and populate the state_machines dictionary.
|
||||
|
||||
This function retrieves all state machines using pagination, filters them based on audit_resources if provided,
|
||||
and creates StateMachine instances to store their basic information.
|
||||
|
||||
Args:
|
||||
regional_client: The regional AWS Step Functions client used to interact with the AWS API.
|
||||
"""
|
||||
logger.info("StepFunctions - Listing state machines...")
|
||||
try:
|
||||
list_state_machines_paginator = regional_client.get_paginator(
|
||||
"list_state_machines"
|
||||
)
|
||||
|
||||
for page in list_state_machines_paginator.paginate():
|
||||
for state_machine_data in page.get("stateMachines", []):
|
||||
try:
|
||||
arn = state_machine_data.get("stateMachineArn")
|
||||
state_machine_id = (
|
||||
arn.split(":")[-1].split("/")[-1] if arn else None
|
||||
)
|
||||
if not self.audit_resources or is_resource_filtered(
|
||||
arn, self.audit_resources
|
||||
):
|
||||
state_machine = StateMachine(
|
||||
id=state_machine_id,
|
||||
arn=arn,
|
||||
name=state_machine_data.get("name"),
|
||||
type=StateMachineType(
|
||||
state_machine_data.get("type", "STANDARD")
|
||||
),
|
||||
creation_date=state_machine_data.get("creationDate"),
|
||||
region=regional_client.region,
|
||||
status=StateMachineStatus.ACTIVE,
|
||||
)
|
||||
|
||||
self.state_machines[arn] = state_machine
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
|
||||
def _describe_state_machine(self, state_machine: StateMachine) -> None:
|
||||
"""
|
||||
Describe an AWS Step Functions state machine and update its details.
|
||||
|
||||
Args:
|
||||
state_machine (StateMachine): The StateMachine instance to describe and update.
|
||||
"""
|
||||
logger.info(
|
||||
f"StepFunctions - Describing state machine with ID {state_machine.id} ..."
|
||||
)
|
||||
try:
|
||||
regional_client = self.regional_clients[state_machine.region]
|
||||
response = regional_client.describe_state_machine(
|
||||
stateMachineArn=state_machine.arn
|
||||
)
|
||||
|
||||
state_machine.status = StateMachineStatus(response.get("status"))
|
||||
state_machine.definition = response.get("definition")
|
||||
state_machine.role_arn = response.get("roleArn")
|
||||
state_machine.label = response.get("label")
|
||||
state_machine.revision_id = response.get("revisionId")
|
||||
state_machine.description = response.get("description")
|
||||
|
||||
logging_config = response.get("loggingConfiguration")
|
||||
if logging_config:
|
||||
state_machine.logging_configuration = LoggingConfiguration(
|
||||
level=LoggingLevel(logging_config.get("level")),
|
||||
include_execution_data=logging_config.get("includeExecutionData"),
|
||||
destinations=[
|
||||
LoggingDestination(
|
||||
cloud_watch_logs_log_group=CloudWatchLogsLogGroup(
|
||||
log_group_arn=dest["cloudWatchLogsLogGroup"][
|
||||
"logGroupArn"
|
||||
]
|
||||
)
|
||||
)
|
||||
for dest in logging_config.get("destinations", [])
|
||||
],
|
||||
)
|
||||
|
||||
tracing_config = response.get("tracingConfiguration")
|
||||
if tracing_config:
|
||||
state_machine.tracing_configuration = TracingConfiguration(
|
||||
enabled=tracing_config.get("enabled")
|
||||
)
|
||||
|
||||
encryption_config = response.get("encryptionConfiguration")
|
||||
if encryption_config:
|
||||
state_machine.encryption_configuration = EncryptionConfiguration(
|
||||
kms_key_id=encryption_config.get("kmsKeyId"),
|
||||
kms_data_key_reuse_period_seconds=encryption_config.get(
|
||||
"kmsDataKeyReusePeriodSeconds"
|
||||
),
|
||||
type=EncryptionType(encryption_config.get("type")),
|
||||
)
|
||||
|
||||
except ClientError as error:
|
||||
if error.response["Error"]["Code"] == "ResourceNotFoundException":
|
||||
logger.warning(
|
||||
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
|
||||
def _list_state_machine_tags(self, state_machine: StateMachine) -> None:
|
||||
"""
|
||||
List tags for an AWS Step Functions state machine and update the StateMachine instance.
|
||||
|
||||
Args:
|
||||
state_machine (StateMachine): The StateMachine instance to list and update tags for.
|
||||
"""
|
||||
logger.info(
|
||||
f"StepFunctions - Listing tags for state machine with ID {state_machine.id} ..."
|
||||
)
|
||||
try:
|
||||
regional_client = self.regional_clients[state_machine.region]
|
||||
|
||||
response = regional_client.list_tags_for_resource(
|
||||
resourceArn=state_machine.arn
|
||||
)
|
||||
|
||||
state_machine.tags = response.get("tags", [])
|
||||
except ClientError as error:
|
||||
if error.response["Error"]["Code"] == "ResourceNotFoundException":
|
||||
logger.warning(
|
||||
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
@@ -0,0 +1,34 @@
|
||||
{
|
||||
"Provider": "aws",
|
||||
"CheckID": "stepfunctions_statemachine_logging_enabled",
|
||||
"CheckTitle": "Step Functions state machines should have logging enabled",
|
||||
"CheckType": [
|
||||
"Software and Configuration Checks/AWS Security Best Practices"
|
||||
],
|
||||
"ServiceName": "stepfunctions",
|
||||
"SubServiceName": "",
|
||||
"ResourceIdTemplate": "arn:aws:states:{region}:{account-id}:stateMachine/{stateMachine-id}",
|
||||
"Severity": "medium",
|
||||
"ResourceType": "AwsStepFunctionStateMachine",
|
||||
"Description": "This control checks if AWS Step Functions state machines have logging enabled. The control fails if the state machine doesn't have the loggingConfiguration property defined.",
|
||||
"Risk": "Without logging enabled, important operational data may be lost, making it difficult to troubleshoot issues, monitor performance, and ensure compliance with auditing requirements.",
|
||||
"RelatedUrl": "https://docs.aws.amazon.com/step-functions/latest/dg/logging.html",
|
||||
"Remediation": {
|
||||
"Code": {
|
||||
"CLI": "aws stepfunctions update-state-machine --state-machine-arn <state-machine-arn> --logging-configuration file://logging-config.json",
|
||||
"NativeIaC": "",
|
||||
"Other": "https://docs.aws.amazon.com/securityhub/latest/userguide/stepfunctions-controls.html#stepfunctions-1",
|
||||
"Terraform": "https://registry.terraform.io/providers/hashicorp/aws/latest/docs/resources/sfn_state_machine#logging_configuration"
|
||||
},
|
||||
"Recommendation": {
|
||||
"Text": "Configure logging for your Step Functions state machines to ensure that operational data is captured and available for debugging, monitoring, and auditing purposes.",
|
||||
"Url": "https://docs.aws.amazon.com/step-functions/latest/dg/logging.html"
|
||||
}
|
||||
},
|
||||
"Categories": [
|
||||
"logging"
|
||||
],
|
||||
"DependsOn": [],
|
||||
"RelatedTo": [],
|
||||
"Notes": ""
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
from typing import List
|
||||
|
||||
from prowler.lib.check.models import Check, Check_Report_AWS
|
||||
from prowler.providers.aws.services.stepfunctions.stepfunctions_client import (
|
||||
stepfunctions_client,
|
||||
)
|
||||
from prowler.providers.aws.services.stepfunctions.stepfunctions_service import (
|
||||
LoggingLevel,
|
||||
)
|
||||
|
||||
|
||||
class stepfunctions_statemachine_logging_enabled(Check):
|
||||
"""
|
||||
Check if AWS Step Functions state machines have logging enabled.
|
||||
|
||||
This class verifies whether each AWS Step Functions state machine has logging enabled by checking
|
||||
for the presence of a loggingConfiguration property in the state machine's configuration.
|
||||
"""
|
||||
|
||||
def execute(self) -> List[Check_Report_AWS]:
|
||||
"""
|
||||
Execute the Step Functions state machines logging enabled check.
|
||||
|
||||
Iterates over all Step Functions state machines and generates a report indicating whether
|
||||
each state machine has logging enabled.
|
||||
|
||||
Returns:
|
||||
List[Check_Report_AWS]: A list of report objects with the results of the check.
|
||||
"""
|
||||
findings = []
|
||||
for state_machine in stepfunctions_client.state_machines.values():
|
||||
report = Check_Report_AWS(self.metadata())
|
||||
report.region = state_machine.region
|
||||
report.resource_id = state_machine.id
|
||||
report.resource_arn = state_machine.arn
|
||||
report.resource_tags = state_machine.tags
|
||||
report.status = "PASS"
|
||||
report.status_extended = f"Step Functions state machine {state_machine.name} has logging enabled."
|
||||
|
||||
if state_machine.logging_configuration.level == LoggingLevel.OFF:
|
||||
report.status = "FAIL"
|
||||
report.status_extended = f"Step Functions state machine {state_machine.name} does not have logging enabled."
|
||||
findings.append(report)
|
||||
|
||||
return findings
|
||||
@@ -30,9 +30,9 @@ class GCPBaseException(ProwlerException):
|
||||
"message": "Error testing connection to GCP",
|
||||
"remediation": "Check the connection and ensure it is properly set up.",
|
||||
},
|
||||
(3006, "GCPLoadCredentialsFromDictError"): {
|
||||
"message": "Error loading credentials from dictionary",
|
||||
"remediation": "Check the credentials and ensure they are properly set up. client_id, client_secret and refresh_token are required.",
|
||||
(3006, "GCPLoadADCFromDictError"): {
|
||||
"message": "Error loading Application Default Credentials from dictionary",
|
||||
"remediation": "Check the dictionary and ensure a valid Application Default Credentials are present with client_id, client_secret and refresh_token keys.",
|
||||
},
|
||||
(3007, "GCPStaticCredentialsError"): {
|
||||
"message": "Error loading static credentials",
|
||||
@@ -46,6 +46,10 @@ class GCPBaseException(ProwlerException):
|
||||
"message": "Cloud Asset API not used",
|
||||
"remediation": "Enable the Cloud Asset API for the project.",
|
||||
},
|
||||
(3010, "GCPLoadServiceAccountKeyFromDictError"): {
|
||||
"message": "Error loading Service Account Private Key credentials from dictionary",
|
||||
"remediation": "Check the dictionary and ensure it contains a Service Account Private Key.",
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self, code, file=None, original_exception=None, message=None):
|
||||
@@ -111,7 +115,7 @@ class GCPTestConnectionError(GCPBaseException):
|
||||
)
|
||||
|
||||
|
||||
class GCPLoadCredentialsFromDictError(GCPCredentialsError):
|
||||
class GCPLoadADCFromDictError(GCPCredentialsError):
|
||||
def __init__(self, file=None, original_exception=None, message=None):
|
||||
super().__init__(
|
||||
3006, file=file, original_exception=original_exception, message=message
|
||||
@@ -137,3 +141,10 @@ class GCPCloudAssetAPINotUsedError(GCPBaseException):
|
||||
super().__init__(
|
||||
3009, file=file, original_exception=original_exception, message=message
|
||||
)
|
||||
|
||||
|
||||
class GCPLoadServiceAccountKeyFromDictError(GCPCredentialsError):
|
||||
def __init__(self, file=None, original_exception=None, message=None):
|
||||
super().__init__(
|
||||
3010, file=file, original_exception=original_exception, message=message
|
||||
)
|
||||
|
||||
@@ -25,7 +25,8 @@ from prowler.providers.gcp.exceptions.exceptions import (
|
||||
GCPGetProjectError,
|
||||
GCPHTTPError,
|
||||
GCPInvalidProviderIdError,
|
||||
GCPLoadCredentialsFromDictError,
|
||||
GCPLoadADCFromDictError,
|
||||
GCPLoadServiceAccountKeyFromDictError,
|
||||
GCPNoAccesibleProjectsError,
|
||||
GCPSetUpSessionError,
|
||||
GCPStaticCredentialsError,
|
||||
@@ -57,7 +58,6 @@ class GcpProvider(Provider):
|
||||
- get_projects -> Get the projects accessible by the provided credentials
|
||||
- update_projects_with_organizations -> Update the projects with organizations
|
||||
- is_project_matching -> Check if the input project matches the project to match
|
||||
- validate_static_arguments -> Validate the static arguments
|
||||
- validate_project_id -> Validate the provider ID
|
||||
"""
|
||||
|
||||
@@ -87,6 +87,7 @@ class GcpProvider(Provider):
|
||||
client_id: str = None,
|
||||
client_secret: str = None,
|
||||
refresh_token: str = None,
|
||||
service_account_key: dict = None,
|
||||
):
|
||||
"""
|
||||
GCP Provider constructor
|
||||
@@ -106,11 +107,12 @@ class GcpProvider(Provider):
|
||||
client_id: str
|
||||
client_secret: str
|
||||
refresh_token: str
|
||||
service_account_key: dict
|
||||
|
||||
Raises:
|
||||
GCPNoAccesibleProjectsError if no project IDs can be accessed via Google Credentials
|
||||
GCPSetUpSessionError if an error occurs during the setup session
|
||||
GCPLoadCredentialsFromDictError if an error occurs during the loading credentials from dict
|
||||
GCPLoadADCFromDictError if an error occurs during the loading credentials from dict
|
||||
GCPGetProjectError if an error occurs during the get project
|
||||
|
||||
Returns:
|
||||
@@ -130,6 +132,10 @@ class GcpProvider(Provider):
|
||||
... client_secret="client_secret",
|
||||
... refresh_token="refresh_token"
|
||||
... )
|
||||
- Using the service account key:
|
||||
>>> GcpProvider(
|
||||
... service_account_key={"service_account_key": "service_account_key"}
|
||||
... )
|
||||
- Using a credentials file:
|
||||
>>> GcpProvider(
|
||||
... credentials_file="credentials_file"
|
||||
@@ -167,7 +173,10 @@ class GcpProvider(Provider):
|
||||
)
|
||||
|
||||
self._session, self._default_project_id = self.setup_session(
|
||||
credentials_file, self._impersonated_service_account, gcp_credentials
|
||||
credentials_file=credentials_file,
|
||||
service_account=self._impersonated_service_account,
|
||||
gcp_credentials=gcp_credentials,
|
||||
service_account_key=service_account_key,
|
||||
)
|
||||
|
||||
self._project_ids = []
|
||||
@@ -312,25 +321,39 @@ class GcpProvider(Provider):
|
||||
|
||||
@staticmethod
|
||||
def setup_session(
|
||||
credentials_file: str, service_account: str, gcp_credentials: dict = None
|
||||
credentials_file: str,
|
||||
service_account: str,
|
||||
gcp_credentials: dict = None,
|
||||
service_account_key: dict = None,
|
||||
) -> tuple:
|
||||
"""
|
||||
Setup the GCP session with the provided credentials file or service account to impersonate
|
||||
|
||||
Args:
|
||||
credentials_file: str
|
||||
service_account: str
|
||||
credentials_file: str -> The credentials file path used to authenticate
|
||||
service_account: dict -> The service account to impersonate
|
||||
gcp_credentials: dict -> The GCP credentials following the format:
|
||||
{
|
||||
"client_id": str,
|
||||
"client_secret": str,
|
||||
"refresh_token": str,
|
||||
"type": str
|
||||
}
|
||||
service_account_key: dict -> The service account key, used to authenticate
|
||||
|
||||
Returns:
|
||||
Credentials object and default project ID
|
||||
|
||||
Raises:
|
||||
GCPLoadCredentialsFromDictError if an error occurs during the loading credentials from dict
|
||||
GCPLoadADCFromDictError if an error occurs during the loading credentials from dict
|
||||
GCPLoadServiceAccountKeyFromDictError if an error occurs during the loading credentials from the service account key
|
||||
GCPSetUpSessionError if an error occurs during the setup session
|
||||
|
||||
Usage:
|
||||
>>> GcpProvider.setup_session(credentials_file, service_account)
|
||||
>>> GcpProvider.setup_session(service_account, gcp_credentials)
|
||||
>>> GcpProvider.setup_session(service_account, service_account_key)
|
||||
>>> GcpProvider.setup_session(credentials_file, service_account, gcp_credentials)
|
||||
|
||||
"""
|
||||
try:
|
||||
scopes = ["https://www.googleapis.com/auth/cloud-platform"]
|
||||
@@ -347,7 +370,24 @@ class GcpProvider(Provider):
|
||||
logger.critical(
|
||||
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
raise GCPLoadCredentialsFromDictError(
|
||||
raise GCPLoadADCFromDictError(
|
||||
file=__file__, original_exception=error
|
||||
)
|
||||
|
||||
if service_account_key:
|
||||
logger.info(
|
||||
"GCP provider: Setting credentials from service account key..."
|
||||
)
|
||||
try:
|
||||
credentials, default_project_id = load_credentials_from_dict(
|
||||
service_account_key, scopes=scopes
|
||||
)
|
||||
return credentials, default_project_id
|
||||
except Exception as error:
|
||||
logger.critical(
|
||||
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
|
||||
)
|
||||
raise GCPLoadServiceAccountKeyFromDictError(
|
||||
file=__file__, original_exception=error
|
||||
)
|
||||
|
||||
@@ -388,10 +428,11 @@ class GcpProvider(Provider):
|
||||
credentials_file: str = None,
|
||||
service_account: str = None,
|
||||
raise_on_exception: bool = True,
|
||||
provider_id: Optional[str] = None,
|
||||
client_id: str = None,
|
||||
client_secret: str = None,
|
||||
refresh_token: str = None,
|
||||
provider_id: Optional[str] = None,
|
||||
service_account_key: dict = None,
|
||||
) -> Connection:
|
||||
"""
|
||||
Test the connection to GCP with the provided credentials file or service account to impersonate.
|
||||
@@ -403,16 +444,18 @@ class GcpProvider(Provider):
|
||||
credentials_file: str
|
||||
service_account: str
|
||||
raise_on_exception: bool
|
||||
provider_id: Optional[str] -> The provider ID, for GCP it is the project ID
|
||||
client_id: str
|
||||
client_secret: str
|
||||
refresh_token: str
|
||||
provider_id: Optional[str] -> The provider ID, for GCP it is the project ID
|
||||
service_account_key: dict
|
||||
|
||||
Returns:
|
||||
Connection object with is_connected set to True if the connection is successful, or error set to the exception if the connection fails
|
||||
|
||||
Raises:
|
||||
GCPLoadCredentialsFromDictError if an error occurs during the loading credentials from dict
|
||||
GCPLoadADCFromDictError if an error occurs during the loading credentials from dict
|
||||
GCPLoadServiceAccountKeyFromDictError if an error occurs during the loading credentials from dict
|
||||
GCPSetUpSessionError if an error occurs during the setup session
|
||||
GCPCloudResourceManagerAPINotUsedError if the Cloud Resource Manager API has not been used before or it is disabled
|
||||
GCPInvalidProviderIdError if the provider ID does not match with the expected project_id
|
||||
@@ -425,10 +468,6 @@ class GcpProvider(Provider):
|
||||
... client_secret="client_secret",
|
||||
... refresh_token="refresh_token"
|
||||
... )
|
||||
- Using a Service Account credentials file path:
|
||||
>>> GcpProvider.test_connection(
|
||||
... credentials_file="credentials_file"
|
||||
... )
|
||||
- Using ADC credentials with a Service Account to impersonate:
|
||||
>>> GcpProvider.test_connection(
|
||||
... client_id="client_id",
|
||||
@@ -436,6 +475,14 @@ class GcpProvider(Provider):
|
||||
... refresh_token="refresh_token",
|
||||
... service_account="service_account"
|
||||
... )
|
||||
- Using service account key:
|
||||
>>> GcpProvider.test_connection(
|
||||
... service_account_key={"service_account_key": "service_account_key"}
|
||||
... )
|
||||
- Using a Service Account credentials file path:
|
||||
>>> GcpProvider.test_connection(
|
||||
... credentials_file="credentials_file"
|
||||
... )
|
||||
"""
|
||||
try:
|
||||
# Set the GCP credentials using the provided client_id, client_secret and refresh_token from ADC
|
||||
@@ -444,9 +491,11 @@ class GcpProvider(Provider):
|
||||
gcp_credentials = GcpProvider.validate_static_arguments(
|
||||
client_id, client_secret, refresh_token
|
||||
)
|
||||
|
||||
session, project_id = GcpProvider.setup_session(
|
||||
credentials_file, service_account, gcp_credentials
|
||||
credentials_file=credentials_file,
|
||||
service_account=service_account,
|
||||
gcp_credentials=gcp_credentials,
|
||||
service_account_key=service_account_key,
|
||||
)
|
||||
if provider_id and project_id != provider_id:
|
||||
# Logic to check if the provider ID matches the project ID
|
||||
@@ -460,7 +509,7 @@ class GcpProvider(Provider):
|
||||
return Connection(is_connected=True)
|
||||
|
||||
# Errors from setup_session
|
||||
except GCPLoadCredentialsFromDictError as load_credentials_error:
|
||||
except GCPLoadServiceAccountKeyFromDictError as load_credentials_error:
|
||||
logger.critical(
|
||||
f"{load_credentials_error.__class__.__name__}[{load_credentials_error.__traceback__.tb_lineno}]: {load_credentials_error}"
|
||||
)
|
||||
@@ -741,18 +790,14 @@ class GcpProvider(Provider):
|
||||
) -> dict:
|
||||
"""
|
||||
Validate the static arguments client_id, client_secret and refresh_token of ADC credentials
|
||||
|
||||
Args:
|
||||
client_id: str
|
||||
client_secret: str
|
||||
refresh_token: str
|
||||
|
||||
Returns:
|
||||
dict
|
||||
|
||||
Raises:
|
||||
GCPStaticCredentialsError if any of the static arguments is missing from the ADC credentials
|
||||
|
||||
Usage:
|
||||
>>> GcpProvider.validate_static_arguments(client_id, client_secret, refresh_token)
|
||||
"""
|
||||
|
||||
@@ -10,7 +10,7 @@ class compute_instance_block_project_wide_ssh_keys_disabled(Check):
|
||||
report.project_id = instance.project_id
|
||||
report.resource_id = instance.id
|
||||
report.resource_name = instance.name
|
||||
report.location = instance.zone
|
||||
report.location = instance.region
|
||||
report.status = "FAIL"
|
||||
report.status_extended = f"The VM Instance {instance.name} is making use of common/shared project-wide SSH key(s)."
|
||||
if instance.metadata.get("items"):
|
||||
|
||||
@@ -10,7 +10,7 @@ class compute_instance_confidential_computing_enabled(Check):
|
||||
report.project_id = instance.project_id
|
||||
report.resource_id = instance.id
|
||||
report.resource_name = instance.name
|
||||
report.location = instance.zone
|
||||
report.location = instance.region
|
||||
report.status = "PASS"
|
||||
report.status_extended = (
|
||||
f"VM Instance {instance.name} has Confidential Computing enabled."
|
||||
|
||||
@@ -10,7 +10,7 @@ class compute_instance_default_service_account_in_use(Check):
|
||||
report.project_id = instance.project_id
|
||||
report.resource_id = instance.id
|
||||
report.resource_name = instance.name
|
||||
report.location = instance.zone
|
||||
report.location = instance.region
|
||||
report.status = "PASS"
|
||||
report.status_extended = f"The default service account is not configured to be used with VM Instance {instance.name}."
|
||||
if (
|
||||
|
||||
@@ -10,7 +10,7 @@ class compute_instance_default_service_account_in_use_with_full_api_access(Check
|
||||
report.project_id = instance.project_id
|
||||
report.resource_id = instance.id
|
||||
report.resource_name = instance.name
|
||||
report.location = instance.zone
|
||||
report.location = instance.region
|
||||
report.status = "PASS"
|
||||
report.status_extended = f"The VM Instance {instance.name} is not configured to use the default service account with full access to all cloud APIs."
|
||||
for service_account in instance.service_accounts:
|
||||
|
||||
@@ -10,7 +10,7 @@ class compute_instance_encryption_with_csek_enabled(Check):
|
||||
report.project_id = instance.project_id
|
||||
report.resource_id = instance.id
|
||||
report.resource_name = instance.name
|
||||
report.location = instance.zone
|
||||
report.location = instance.region
|
||||
report.status = "FAIL"
|
||||
report.status_extended = f"The VM Instance {instance.name} has the following unencrypted disks: '{', '.join([i[0] for i in instance.disks_encryption if not i[1]])}'."
|
||||
if all([i[1] for i in instance.disks_encryption]):
|
||||
|
||||
@@ -10,7 +10,7 @@ class compute_instance_ip_forwarding_is_enabled(Check):
|
||||
report.project_id = instance.project_id
|
||||
report.resource_id = instance.id
|
||||
report.resource_name = instance.name
|
||||
report.location = instance.zone
|
||||
report.location = instance.region
|
||||
report.status = "PASS"
|
||||
report.status_extended = (
|
||||
f"The IP Forwarding of VM Instance {instance.name} is not enabled."
|
||||
|
||||
@@ -10,7 +10,7 @@ class compute_instance_public_ip(Check):
|
||||
report.project_id = instance.project_id
|
||||
report.resource_id = instance.id
|
||||
report.resource_name = instance.name
|
||||
report.location = instance.zone
|
||||
report.location = instance.region
|
||||
report.status = "PASS"
|
||||
report.status_extended = (
|
||||
f"VM Instance {instance.name} does not have a public IP."
|
||||
|
||||
@@ -10,7 +10,7 @@ class compute_instance_serial_ports_in_use(Check):
|
||||
report.project_id = instance.project_id
|
||||
report.resource_id = instance.id
|
||||
report.resource_name = instance.name
|
||||
report.location = instance.zone
|
||||
report.location = instance.region
|
||||
report.status = "PASS"
|
||||
report.status_extended = f"VM Instance {instance.name} has Enable Connecting to Serial Ports off."
|
||||
if instance.metadata.get("items"):
|
||||
|
||||
@@ -10,7 +10,7 @@ class compute_instance_shielded_vm_enabled(Check):
|
||||
report.project_id = instance.project_id
|
||||
report.resource_id = instance.id
|
||||
report.resource_name = instance.name
|
||||
report.location = instance.zone
|
||||
report.location = instance.region
|
||||
report.status = "PASS"
|
||||
report.status_extended = f"VM Instance {instance.name} has vTPM or Integrity Monitoring set to on."
|
||||
if (
|
||||
|
||||
@@ -18,7 +18,7 @@ class compute_network_default_in_use(Check):
|
||||
report.project_id = project
|
||||
report.resource_id = "default"
|
||||
report.resource_name = "default"
|
||||
report.location = "global"
|
||||
report.location = compute_client.region
|
||||
|
||||
if project in projects_with_default_network:
|
||||
report.status = "FAIL"
|
||||
|
||||
@@ -9,7 +9,7 @@ class compute_project_os_login_enabled(Check):
|
||||
report = Check_Report_GCP(self.metadata())
|
||||
report.project_id = project.id
|
||||
report.resource_id = project.id
|
||||
report.location = "global"
|
||||
report.location = compute_client.region
|
||||
report.status = "PASS"
|
||||
report.status_extended = f"Project {project.id} has OS Login enabled."
|
||||
if not project.enable_oslogin:
|
||||
|
||||
@@ -101,6 +101,7 @@ class Compute(GCPService):
|
||||
name=instance["name"],
|
||||
id=instance["id"],
|
||||
zone=zone,
|
||||
region=zone.rsplit("-", 1)[0],
|
||||
public_ip=public_ip,
|
||||
metadata=instance.get("metadata", {}),
|
||||
shielded_enabled_vtpm=instance.get(
|
||||
@@ -306,6 +307,7 @@ class Instance(BaseModel):
|
||||
name: str
|
||||
id: str
|
||||
zone: str
|
||||
region: str
|
||||
public_ip: bool
|
||||
project_id: str
|
||||
metadata: dict
|
||||
|
||||
@@ -10,6 +10,7 @@ class dataproc_encrypted_with_cmks_disabled(Check):
|
||||
report.project_id = cluster.project_id
|
||||
report.resource_id = cluster.id
|
||||
report.resource_name = cluster.name
|
||||
report.location = dataproc_client.region
|
||||
report.status = "PASS"
|
||||
report.status_extended = f"Dataproc cluster {cluster.name} is encrypted with customer managed encryption keys."
|
||||
if cluster.encryption_config.get("gcePdKmsKeyName") is None:
|
||||
|
||||
@@ -10,6 +10,7 @@ class dns_dnssec_disabled(Check):
|
||||
report.project_id = managed_zone.project_id
|
||||
report.resource_id = managed_zone.id
|
||||
report.resource_name = managed_zone.name
|
||||
report.location = dns_client.region
|
||||
report.status = "PASS"
|
||||
report.status_extended = (
|
||||
f"Cloud DNS {managed_zone.name} has DNSSEC enabled."
|
||||
|
||||
@@ -10,6 +10,7 @@ class dns_rsasha1_in_use_to_key_sign_in_dnssec(Check):
|
||||
report.project_id = managed_zone.project_id
|
||||
report.resource_id = managed_zone.id
|
||||
report.resource_name = managed_zone.name
|
||||
report.location = dns_client.region
|
||||
report.status = "PASS"
|
||||
report.status_extended = f"Cloud DNS {managed_zone.name} is not using RSASHA1 algorithm as key signing."
|
||||
if any(
|
||||
|
||||
@@ -10,6 +10,7 @@ class dns_rsasha1_in_use_to_zone_sign_in_dnssec(Check):
|
||||
report.project_id = managed_zone.project_id
|
||||
report.resource_id = managed_zone.id
|
||||
report.resource_name = managed_zone.name
|
||||
report.location = dns_client.region
|
||||
report.status = "PASS"
|
||||
report.status_extended = f"Cloud DNS {managed_zone.name} is not using RSASHA1 algorithm as zone signing."
|
||||
if any(
|
||||
|
||||
@@ -10,7 +10,7 @@ class gke_cluster_no_default_service_account(Check):
|
||||
report.project_id = cluster.project_id
|
||||
report.resource_id = cluster.id
|
||||
report.resource_name = cluster.name
|
||||
report.location = cluster.location
|
||||
report.location = cluster.region
|
||||
report.status = "PASS"
|
||||
report.status_extended = f"GKE cluster {cluster.name} is not using the Compute Engine default service account."
|
||||
if not cluster.node_pools and cluster.service_account == "default":
|
||||
|
||||
@@ -60,6 +60,7 @@ class GKE(GCPService):
|
||||
name=cluster["name"],
|
||||
id=cluster["id"],
|
||||
location=cluster["location"],
|
||||
region=cluster["location"].rsplit("-", 1)[0],
|
||||
service_account=cluster["nodeConfig"]["serviceAccount"],
|
||||
node_pools=node_pools,
|
||||
project_id=location.project_id,
|
||||
@@ -85,6 +86,7 @@ class NodePool(BaseModel):
|
||||
class Cluster(BaseModel):
|
||||
name: str
|
||||
id: str
|
||||
region: str
|
||||
location: str
|
||||
service_account: str
|
||||
node_pools: list[NodePool]
|
||||
|
||||
@@ -48,8 +48,8 @@ azure-mgmt-storage = "21.2.1"
|
||||
azure-mgmt-subscription = "3.1.1"
|
||||
azure-mgmt-web = "7.3.1"
|
||||
azure-storage-blob = "12.24.0"
|
||||
boto3 = "1.35.78"
|
||||
botocore = "1.35.79"
|
||||
boto3 = "1.35.81"
|
||||
botocore = "1.35.81"
|
||||
colorama = "0.4.6"
|
||||
cryptography = "43.0.1"
|
||||
dash = "2.18.2"
|
||||
@@ -100,7 +100,7 @@ optional = true
|
||||
[tool.poetry.group.docs.dependencies]
|
||||
mkdocs = "1.6.1"
|
||||
mkdocs-git-revision-date-localized-plugin = "1.3.0"
|
||||
mkdocs-material = "9.5.48"
|
||||
mkdocs-material = "9.5.49"
|
||||
mkdocs-material-extensions = "1.3.1"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
|
||||
@@ -94,12 +94,15 @@ class Test_backup_recovery_point_encrypted:
|
||||
|
||||
aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1])
|
||||
|
||||
with mock.patch(
|
||||
"prowler.providers.common.provider.Provider.get_global_provider",
|
||||
return_value=aws_provider,
|
||||
), mock.patch(
|
||||
"prowler.providers.aws.services.backup.backup_recovery_point_encrypted.backup_recovery_point_encrypted.backup_client",
|
||||
new=Backup(aws_provider),
|
||||
with (
|
||||
mock.patch(
|
||||
"prowler.providers.common.provider.Provider.get_global_provider",
|
||||
return_value=aws_provider,
|
||||
),
|
||||
mock.patch(
|
||||
"prowler.providers.aws.services.backup.backup_recovery_point_encrypted.backup_recovery_point_encrypted.backup_client",
|
||||
new=Backup(aws_provider),
|
||||
),
|
||||
):
|
||||
# Test Check
|
||||
from prowler.providers.aws.services.backup.backup_recovery_point_encrypted.backup_recovery_point_encrypted import (
|
||||
@@ -124,12 +127,15 @@ class Test_backup_recovery_point_encrypted:
|
||||
|
||||
aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1])
|
||||
|
||||
with mock.patch(
|
||||
"prowler.providers.common.provider.Provider.get_global_provider",
|
||||
return_value=aws_provider,
|
||||
), mock.patch(
|
||||
"prowler.providers.aws.services.backup.backup_recovery_point_encrypted.backup_recovery_point_encrypted.backup_client",
|
||||
new=Backup(aws_provider),
|
||||
with (
|
||||
mock.patch(
|
||||
"prowler.providers.common.provider.Provider.get_global_provider",
|
||||
return_value=aws_provider,
|
||||
),
|
||||
mock.patch(
|
||||
"prowler.providers.aws.services.backup.backup_recovery_point_encrypted.backup_recovery_point_encrypted.backup_client",
|
||||
new=Backup(aws_provider),
|
||||
),
|
||||
):
|
||||
# Test Check
|
||||
from prowler.providers.aws.services.backup.backup_recovery_point_encrypted.backup_recovery_point_encrypted import (
|
||||
@@ -142,9 +148,9 @@ class Test_backup_recovery_point_encrypted:
|
||||
assert len(result) == 1
|
||||
assert result[0].status == "FAIL"
|
||||
assert result[0].status_extended == (
|
||||
"Backup Recovery Point arn:aws:backup:eu-west-1:123456789012:recovery-point:1 for Backup Vault Test Vault is not encrypted at rest."
|
||||
"Backup Recovery Point 1 for Backup Vault Test Vault is not encrypted at rest."
|
||||
)
|
||||
assert result[0].resource_id == "Test Vault"
|
||||
assert result[0].resource_id == "1"
|
||||
assert (
|
||||
result[0].resource_arn
|
||||
== "arn:aws:backup:eu-west-1:123456789012:recovery-point:1"
|
||||
@@ -165,12 +171,15 @@ class Test_backup_recovery_point_encrypted:
|
||||
|
||||
aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1])
|
||||
|
||||
with mock.patch(
|
||||
"prowler.providers.common.provider.Provider.get_global_provider",
|
||||
return_value=aws_provider,
|
||||
), mock.patch(
|
||||
"prowler.providers.aws.services.backup.backup_recovery_point_encrypted.backup_recovery_point_encrypted.backup_client",
|
||||
new=Backup(aws_provider),
|
||||
with (
|
||||
mock.patch(
|
||||
"prowler.providers.common.provider.Provider.get_global_provider",
|
||||
return_value=aws_provider,
|
||||
),
|
||||
mock.patch(
|
||||
"prowler.providers.aws.services.backup.backup_recovery_point_encrypted.backup_recovery_point_encrypted.backup_client",
|
||||
new=Backup(aws_provider),
|
||||
),
|
||||
):
|
||||
# Test Check
|
||||
from prowler.providers.aws.services.backup.backup_recovery_point_encrypted.backup_recovery_point_encrypted import (
|
||||
@@ -183,9 +192,9 @@ class Test_backup_recovery_point_encrypted:
|
||||
assert len(result) == 1
|
||||
assert result[0].status == "PASS"
|
||||
assert result[0].status_extended == (
|
||||
"Backup Recovery Point arn:aws:backup:eu-west-1:123456789012:recovery-point:1 for Backup Vault Test Vault is encrypted at rest."
|
||||
"Backup Recovery Point 1 for Backup Vault Test Vault is encrypted at rest."
|
||||
)
|
||||
assert result[0].resource_id == "Test Vault"
|
||||
assert result[0].resource_id == "1"
|
||||
assert (
|
||||
result[0].resource_arn
|
||||
== "arn:aws:backup:eu-west-1:123456789012:recovery-point:1"
|
||||
|
||||
@@ -0,0 +1,145 @@
|
||||
from unittest import mock
|
||||
|
||||
import botocore
|
||||
import botocore.client
|
||||
from boto3 import client
|
||||
from moto import mock_aws
|
||||
|
||||
from tests.providers.aws.utils import (
|
||||
AWS_REGION_EU_WEST_1,
|
||||
AWS_REGION_US_EAST_1,
|
||||
set_mocked_aws_provider,
|
||||
)
|
||||
|
||||
mock_make_api_call = botocore.client.BaseClient._make_api_call
|
||||
|
||||
|
||||
def mock_make_api_call_error(self, operation_name, kwarg):
|
||||
if operation_name == "PutPublicAccessBlock":
|
||||
raise botocore.exceptions.ClientError(
|
||||
{
|
||||
"Error": {
|
||||
"Code": "InvalidPermission.NotFound",
|
||||
"Message": "The specified rule does not exist in this security group.",
|
||||
}
|
||||
},
|
||||
operation_name,
|
||||
)
|
||||
return mock_make_api_call(self, operation_name, kwarg)
|
||||
|
||||
|
||||
class Test_cloudtrail_logs_s3_bucket_is_not_publicly_accessible_fixer:
|
||||
@mock_aws
|
||||
def test_trail_bucket_public_acl(self):
|
||||
aws_provider = set_mocked_aws_provider(
|
||||
[AWS_REGION_US_EAST_1, AWS_REGION_EU_WEST_1]
|
||||
)
|
||||
s3_client = client("s3", region_name=AWS_REGION_US_EAST_1)
|
||||
bucket_name_us = "bucket_test_us"
|
||||
s3_client.create_bucket(Bucket=bucket_name_us)
|
||||
s3_client.put_bucket_acl(
|
||||
AccessControlPolicy={
|
||||
"Grants": [
|
||||
{
|
||||
"Grantee": {
|
||||
"DisplayName": "test",
|
||||
"EmailAddress": "",
|
||||
"ID": "test_ID",
|
||||
"Type": "Group",
|
||||
"URI": "http://acs.amazonaws.com/groups/global/AllUsers",
|
||||
},
|
||||
"Permission": "READ",
|
||||
},
|
||||
],
|
||||
"Owner": {"DisplayName": "test", "ID": "test_id"},
|
||||
},
|
||||
Bucket=bucket_name_us,
|
||||
)
|
||||
|
||||
trail_name_us = "trail_test_us"
|
||||
cloudtrail_client = client("cloudtrail", region_name=AWS_REGION_US_EAST_1)
|
||||
cloudtrail_client.create_trail(
|
||||
Name=trail_name_us, S3BucketName=bucket_name_us, IsMultiRegionTrail=False
|
||||
)
|
||||
|
||||
from prowler.providers.aws.services.cloudtrail.cloudtrail_service import (
|
||||
Cloudtrail,
|
||||
)
|
||||
from prowler.providers.aws.services.s3.s3_service import S3
|
||||
|
||||
with mock.patch(
|
||||
"prowler.providers.common.provider.Provider.get_global_provider",
|
||||
return_value=aws_provider,
|
||||
), mock.patch(
|
||||
"prowler.providers.aws.services.cloudtrail.cloudtrail_logs_s3_bucket_is_not_publicly_accessible.cloudtrail_logs_s3_bucket_is_not_publicly_accessible_fixer.cloudtrail_client",
|
||||
new=Cloudtrail(aws_provider),
|
||||
), mock.patch(
|
||||
"prowler.providers.aws.services.cloudtrail.cloudtrail_logs_s3_bucket_is_not_publicly_accessible.cloudtrail_logs_s3_bucket_is_not_publicly_accessible_fixer.s3_client",
|
||||
new=S3(aws_provider),
|
||||
):
|
||||
# Test Check
|
||||
from prowler.providers.aws.services.cloudtrail.cloudtrail_logs_s3_bucket_is_not_publicly_accessible.cloudtrail_logs_s3_bucket_is_not_publicly_accessible_fixer import (
|
||||
fixer,
|
||||
)
|
||||
|
||||
assert fixer(trail_name_us, AWS_REGION_US_EAST_1)
|
||||
|
||||
@mock_aws
|
||||
def test_trail_bucket_public_acl_error(self):
|
||||
with mock.patch(
|
||||
"botocore.client.BaseClient._make_api_call", new=mock_make_api_call_error
|
||||
):
|
||||
aws_provider = set_mocked_aws_provider(
|
||||
[AWS_REGION_US_EAST_1, AWS_REGION_EU_WEST_1]
|
||||
)
|
||||
s3_client = client("s3", region_name=AWS_REGION_US_EAST_1)
|
||||
bucket_name_us = "bucket_test_us"
|
||||
s3_client.create_bucket(Bucket=bucket_name_us)
|
||||
s3_client.put_bucket_acl(
|
||||
AccessControlPolicy={
|
||||
"Grants": [
|
||||
{
|
||||
"Grantee": {
|
||||
"DisplayName": "test",
|
||||
"EmailAddress": "",
|
||||
"ID": "test_ID",
|
||||
"Type": "Group",
|
||||
"URI": "http://acs.amazonaws.com/groups/global/AllUsers",
|
||||
},
|
||||
"Permission": "READ",
|
||||
},
|
||||
],
|
||||
"Owner": {"DisplayName": "test", "ID": "test_id"},
|
||||
},
|
||||
Bucket=bucket_name_us,
|
||||
)
|
||||
|
||||
trail_name_us = "trail_test_us"
|
||||
cloudtrail_client = client("cloudtrail", region_name=AWS_REGION_US_EAST_1)
|
||||
cloudtrail_client.create_trail(
|
||||
Name=trail_name_us,
|
||||
S3BucketName=bucket_name_us,
|
||||
IsMultiRegionTrail=False,
|
||||
)
|
||||
|
||||
from prowler.providers.aws.services.cloudtrail.cloudtrail_service import (
|
||||
Cloudtrail,
|
||||
)
|
||||
from prowler.providers.aws.services.s3.s3_service import S3
|
||||
|
||||
with mock.patch(
|
||||
"prowler.providers.common.provider.Provider.get_global_provider",
|
||||
return_value=aws_provider,
|
||||
), mock.patch(
|
||||
"prowler.providers.aws.services.cloudtrail.cloudtrail_logs_s3_bucket_is_not_publicly_accessible.cloudtrail_logs_s3_bucket_is_not_publicly_accessible_fixer.cloudtrail_client",
|
||||
new=Cloudtrail(aws_provider),
|
||||
), mock.patch(
|
||||
"prowler.providers.aws.services.cloudtrail.cloudtrail_logs_s3_bucket_is_not_publicly_accessible.cloudtrail_logs_s3_bucket_is_not_publicly_accessible_fixer.s3_client",
|
||||
new=S3(aws_provider),
|
||||
):
|
||||
# Test Check
|
||||
from prowler.providers.aws.services.cloudtrail.cloudtrail_logs_s3_bucket_is_not_publicly_accessible.cloudtrail_logs_s3_bucket_is_not_publicly_accessible_fixer import (
|
||||
fixer,
|
||||
)
|
||||
|
||||
assert not fixer(trail_name_us, AWS_REGION_US_EAST_1)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user