Compare commits

...

47 Commits

Author SHA1 Message Date
Pablo Lara
2d190eb020 chore: add manage group actions 2024-12-18 17:38:09 +01:00
Pablo Lara
0459a4d6f6 Merge branch 'PRWLR-5824-Update-resource-name-and-enable-relationships-on-role-and-provider_group-create-update' into PRWLR-4669-Roles-Page-UI-with-API-changes 2024-12-18 12:10:45 +01:00
Pablo Lara
19df649554 chore: add manage group actions 2024-12-18 11:15:37 +01:00
Adrián Jesús Peña Rodríguez
737550eb05 ref(rbac): update spec 2024-12-18 11:08:19 +01:00
Adrián Jesús Peña Rodríguez
68d7d9f998 ref(rbac): enable relationship creation when objects is created 2024-12-18 11:05:29 +01:00
Adrián Jesús Peña Rodríguez
fa400ded7d ref(rbac): improve rbac implementation for views (#6226) 2024-12-17 18:11:48 +01:00
dependabot[bot]
ec9455ff75 chore(deps): bump boto3 from 1.35.80 to 1.35.81 (#6218)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-12-17 11:32:30 -05:00
Daniel Barranquero
2183f31ff5 feat(ec2): add new fixers for internet exposed ports (#6223) 2024-12-17 10:04:00 -05:00
Prowler Bot
67257a4212 chore(regions_update): Changes in regions for AWS services (#6222)
Co-authored-by: MrCloudSec <38561120+MrCloudSec@users.noreply.github.com>
2024-12-17 10:00:52 -05:00
Pedro Martín
001fa60a11 feat(mutelist): add description field (#6221)
Co-authored-by: Pepe Fagoaga <pepe@prowler.com>
2024-12-17 15:13:55 +01:00
Pablo Lara
3c9a8b3634 chore: add manage group component 2024-12-17 11:13:37 +01:00
Víctor Fernández Poyatos
0ec3ed8be7 feat(services): Add GET /overviews/services to API (#6029) 2024-12-17 08:47:44 +01:00
dependabot[bot]
3ed0b8a464 chore(deps-dev): bump mkdocs-material from 9.5.48 to 9.5.49 (#6217)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-12-17 08:42:55 +01:00
Pedro Martín
fd610d44c0 refactor(gcp): use always <client>.region for checks (#6206) 2024-12-16 18:21:42 -05:00
Adrián Jesús Peña Rodríguez
b8cc4b4f0f feat(stepfunctions): add stepfunctions service and check stepfunctions_statemachine_logging_enabled (#5466)
Co-authored-by: Sergio Garcia <hello@mistercloudsec.com>
Co-authored-by: Rubén De la Torre Vico <rubendltv22@gmail.com>
2024-12-16 11:34:02 -05:00
Pedro Martín
396e51c27d feat(gcp): add service account credentials (#6165) 2024-12-16 10:11:32 -05:00
Daniel Barranquero
36e61cb7a2 feat(ec2): add new fixer ec2_ami_public_fixer (#6177) 2024-12-16 10:09:14 -05:00
Daniel Barranquero
78c6484ddb feat(cloudtrail): add new fixer cloudtrail_logs_s3_bucket_is_not_publicly_accessible_fixer (#6174) 2024-12-16 10:05:34 -05:00
Daniel Barranquero
3f1e90a5b3 feat(s3): add new fixer s3_bucket_policy_public_write_access_fixer (#6173) 2024-12-16 10:01:38 -05:00
dependabot[bot]
e1bfec898f chore(deps): bump botocore from 1.35.80 to 1.35.81 (#6199)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-12-16 09:57:03 -05:00
dependabot[bot]
b5b816dac9 chore(deps): bump boto3 from 1.35.79 to 1.35.80 (#6198)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-12-16 07:51:44 -05:00
Pablo Lara
81f970f2d3 chore: refactor updateInvite action form 2024-12-16 13:34:16 +01:00
Pablo Lara
3d9cd177a2 chore: add role column to the table users 2024-12-16 12:59:40 +01:00
Pepe Fagoaga
57854f23b7 chore(rls): rename tenant_transaction to rls_transaction (#6202) 2024-12-16 12:27:55 +01:00
Rubén De la Torre Vico
9d7499b74f fix(azure): custom Prowler Role for Azure assignableScopes (#6149) 2024-12-16 08:34:17 +01:00
Pablo Lara
c49fdc114a chore: report an error related to RBAC API side 2024-12-15 11:43:46 +01:00
Pablo Lara
95fd9d6b5e WIP: add change role to the user's invitations 2024-12-15 10:52:41 +01:00
Pablo Lara
6a5bc75252 chore: add change role to the user's invitations 2024-12-15 10:52:33 +01:00
Pablo Lara
858c04b0b0 chore: fix error with exports 2024-12-15 10:52:09 +01:00
Pablo Lara
2d6f20e84b feat: add role when invite an user 2024-12-15 10:51:58 +01:00
Pablo Lara
b0a98b1a87 feat: add permission column to roles table 2024-12-15 10:51:49 +01:00
Pablo Lara
577530ac69 chore: add and edit roles is working now 2024-12-15 10:51:10 +01:00
Pablo Lara
c1a8d47e5b feat: edit role feature 2024-12-15 10:50:39 +01:00
Pablo Lara
e80704d6f0 feat: add new role feature 2024-12-15 10:50:32 +01:00
Pablo Lara
010de4b415 feat: add roles page 2024-12-15 10:50:24 +01:00
Pablo Lara
0a2b8e4315 chore: add roles item to the sidebar 2024-12-15 10:50:16 +01:00
dependabot[bot]
5b0b85c0f8 chore(deps): bump actions/setup-node from 3 to 4 (#5893)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-12-13 14:57:27 +01:00
Pedro Martín
f7e8df618b chore(labeler): add provider github (#6194) 2024-12-13 09:43:49 -04:00
Adrián Jesús Peña Rodríguez
d00d254c90 feat(api): RBAC system (#6114) 2024-12-13 14:14:40 +01:00
dependabot[bot]
f9fbde6637 chore(deps): bump botocore from 1.35.79 to 1.35.80 (#6172)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-12-13 13:20:40 +01:00
Sergio Garcia
7b1a0474db fix(aws): set unique resource IDs (#6152) 2024-12-13 13:00:38 +01:00
Pepe Fagoaga
da4f9b8e5f fix(RLS): enforce config security (#6066) 2024-12-13 12:55:09 +01:00
Pepe Fagoaga
32f69d24b6 fix: dependabot syntax (#6181) 2024-12-13 12:20:43 +01:00
Pepe Fagoaga
d032a61a9e chore(dependabot): Add docker (#6180) 2024-12-13 12:13:53 +01:00
dependabot[bot]
07e0dc2ef5 chore(deps): bump cross-spawn from 7.0.3 to 7.0.6 in /ui (#5881)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Pepe Fagoaga <pepe@prowler.com>
2024-12-13 08:25:57 +01:00
dependabot[bot]
9e175e8504 chore(deps): bump nanoid from 3.3.7 to 3.3.8 in /ui (#6110)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-12-13 07:59:50 +01:00
dependabot[bot]
6b8a434cda chore(deps): bump boto3 from 1.35.78 to 1.35.79 (#6171)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-12-13 07:58:58 +01:00
190 changed files with 16839 additions and 612 deletions


@@ -36,6 +36,16 @@ updates:
- "dependencies"
- "npm"
- package-ecosystem: "docker"
directory: "/"
schedule:
interval: "weekly"
open-pull-requests-limit: 10
target-branch: master
labels:
- "dependencies"
- "docker"
# v4.6
- package-ecosystem: "pip"
directory: "/"
@@ -59,6 +69,17 @@ updates:
- "github_actions"
- "v4"
- package-ecosystem: "docker"
directory: "/"
schedule:
interval: "weekly"
open-pull-requests-limit: 10
target-branch: v4.6
labels:
- "dependencies"
- "docker"
- "v4"
# v3
- package-ecosystem: "pip"
directory: "/"

.github/labeler.yml vendored

@@ -22,6 +22,11 @@ provider/kubernetes:
- any-glob-to-any-file: "prowler/providers/kubernetes/**"
- any-glob-to-any-file: "tests/providers/kubernetes/**"
provider/github:
- changed-files:
- any-glob-to-any-file: "prowler/providers/github/**"
- any-glob-to-any-file: "tests/providers/github/**"
github_actions:
- changed-files:
- any-glob-to-any-file: ".github/workflows/*"


@@ -20,7 +20,7 @@ jobs:
with:
persist-credentials: false
- name: Setup Node.js ${{ matrix.node-version }}
uses: actions/setup-node@v3
uses: actions/setup-node@v4
with:
node-version: ${{ matrix.node-version }}
- name: Install dependencies


@@ -1,4 +1,4 @@
FROM python:3.12-alpine AS build
FROM python:3.12.8-alpine3.20 AS build
LABEL maintainer="https://github.com/prowler-cloud/api"


@@ -8,7 +8,7 @@ description = "Prowler's API (Django/DRF)"
license = "Apache-2.0"
name = "prowler-api"
package-mode = false
version = "1.0.0"
version = "1.1.0"
[tool.poetry.dependencies]
celery = {extras = ["pytest"], version = "^5.4.0"}


@@ -1,20 +1,23 @@
import uuid
from django.db import connection, transaction
from django.core.exceptions import ObjectDoesNotExist
from django.db import transaction
from rest_framework import permissions
from rest_framework.exceptions import NotAuthenticated
from rest_framework.filters import SearchFilter
from rest_framework_json_api import filters
from rest_framework_json_api.serializers import ValidationError
from rest_framework_json_api.views import ModelViewSet
from rest_framework_simplejwt.authentication import JWTAuthentication
from api.db_router import MainRouter
from api.db_utils import POSTGRES_USER_VAR, rls_transaction
from api.filters import CustomDjangoFilterBackend
from api.models import Role, Tenant
from api.rbac.permissions import HasPermissions
class BaseViewSet(ModelViewSet):
authentication_classes = [JWTAuthentication]
permission_classes = [permissions.IsAuthenticated]
required_permissions = []
permission_classes = [permissions.IsAuthenticated, HasPermissions]
filter_backends = [
filters.QueryParameterValidationFilter,
filters.OrderingFilter,
@@ -28,6 +31,17 @@ class BaseViewSet(ModelViewSet):
ordering_fields = "__all__"
ordering = ["id"]
def initial(self, request, *args, **kwargs):
"""
Sets required_permissions before permissions are checked.
"""
self.set_required_permissions()
super().initial(request, *args, **kwargs)
def set_required_permissions(self):
"""This is an abstract method that must be implemented by subclasses."""
NotImplemented
def get_queryset(self):
raise NotImplementedError
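The hunk above adds a template-method hook to `BaseViewSet`: `initial()` calls `set_required_permissions()` so each subclass can declare its `required_permissions` before DRF runs the permission checks. A minimal standalone sketch of that ordering, with no DRF and illustrative class names (the `request` dict and `check_permissions` helper are stand-ins, not the real API surface):

```python
# Sketch of the set-permissions-before-checking hook used by BaseViewSet.
# Assumes no DRF: request is a plain dict, PermissionError stands in for
# DRF's PermissionDenied.
class BaseView:
    required_permissions = []

    def initial(self, request):
        # Populate required_permissions *before* the permission check runs.
        self.set_required_permissions()
        self.check_permissions(request)

    def set_required_permissions(self):
        """Subclasses override this; the base implementation is a no-op."""

    def check_permissions(self, request):
        missing = [p for p in self.required_permissions if p not in request["perms"]]
        if missing:
            raise PermissionError(f"missing permissions: {missing}")


class ScanView(BaseView):
    def set_required_permissions(self):
        self.required_permissions = ["manage_scans"]
```

Because the hook runs first, a subclass never has to duplicate the check itself; it only declares what it needs.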
@@ -47,13 +61,7 @@ class BaseRLSViewSet(BaseViewSet):
if tenant_id is None:
raise NotAuthenticated("Tenant ID is not present in token")
try:
uuid.UUID(tenant_id)
except ValueError:
raise ValidationError("Tenant ID must be a valid UUID")
with connection.cursor() as cursor:
cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
with rls_transaction(tenant_id):
self.request.tenant_id = tenant_id
return super().initial(request, *args, **kwargs)
@@ -66,7 +74,39 @@ class BaseRLSViewSet(BaseViewSet):
class BaseTenantViewset(BaseViewSet):
def dispatch(self, request, *args, **kwargs):
with transaction.atomic():
return super().dispatch(request, *args, **kwargs)
tenant = super().dispatch(request, *args, **kwargs)
try:
# If the request is a POST, create the admin role
if request.method == "POST":
isinstance(tenant, dict) and self._create_admin_role(tenant.data["id"])
except Exception as e:
self._handle_creation_error(e, tenant)
raise
return tenant
def _create_admin_role(self, tenant_id):
Role.objects.using(MainRouter.admin_db).create(
name="admin",
tenant_id=tenant_id,
manage_users=True,
manage_account=True,
manage_billing=True,
manage_providers=True,
manage_integrations=True,
manage_scans=True,
unlimited_visibility=True,
)
def _handle_creation_error(self, error, tenant):
if tenant.data.get("id"):
try:
Tenant.objects.using(MainRouter.admin_db).filter(
id=tenant.data["id"]
).delete()
except ObjectDoesNotExist:
pass # Tenant might not exist, handle gracefully
def initial(self, request, *args, **kwargs):
if (
@@ -75,8 +115,7 @@ class BaseTenantViewset(BaseViewSet):
):
user_id = str(request.user.id)
with connection.cursor() as cursor:
cursor.execute(f"SELECT set_config('api.user_id', '{user_id}', TRUE);")
with rls_transaction(value=user_id, parameter=POSTGRES_USER_VAR):
return super().initial(request, *args, **kwargs)
# TODO: DRY this when we have time
@@ -87,13 +126,7 @@ class BaseTenantViewset(BaseViewSet):
if tenant_id is None:
raise NotAuthenticated("Tenant ID is not present in token")
try:
uuid.UUID(tenant_id)
except ValueError:
raise ValidationError("Tenant ID must be a valid UUID")
with connection.cursor() as cursor:
cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
with rls_transaction(tenant_id):
self.request.tenant_id = tenant_id
return super().initial(request, *args, **kwargs)
@@ -114,12 +147,6 @@ class BaseUserViewset(BaseViewSet):
if tenant_id is None:
raise NotAuthenticated("Tenant ID is not present in token")
try:
uuid.UUID(tenant_id)
except ValueError:
raise ValidationError("Tenant ID must be a valid UUID")
with connection.cursor() as cursor:
cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
with rls_transaction(tenant_id):
self.request.tenant_id = tenant_id
return super().initial(request, *args, **kwargs)


@@ -1,4 +1,5 @@
import secrets
import uuid
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
@@ -8,6 +9,7 @@ from django.core.paginator import Paginator
from django.db import connection, models, transaction
from psycopg2 import connect as psycopg2_connect
from psycopg2.extensions import AsIs, new_type, register_adapter, register_type
from rest_framework_json_api.serializers import ValidationError
DB_USER = settings.DATABASES["default"]["USER"] if not settings.TESTING else "test"
DB_PASSWORD = (
@@ -23,6 +25,8 @@ TASK_RUNNER_DB_TABLE = "django_celery_results_taskresult"
POSTGRES_TENANT_VAR = "api.tenant_id"
POSTGRES_USER_VAR = "api.user_id"
SET_CONFIG_QUERY = "SELECT set_config(%s, %s::text, TRUE);"
@contextmanager
def psycopg_connection(database_alias: str):
@@ -44,10 +48,23 @@ def psycopg_connection(database_alias: str):
@contextmanager
def tenant_transaction(tenant_id: str):
def rls_transaction(value: str, parameter: str = POSTGRES_TENANT_VAR):
"""
Creates a new database transaction setting the given configuration value for Postgres RLS. It validates
that the value is a valid UUID.
Args:
value (str): Database configuration parameter value.
parameter (str): Database configuration parameter name, by default is 'api.tenant_id'.
"""
with transaction.atomic():
with connection.cursor() as cursor:
cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
try:
# just in case the value is an UUID object
uuid.UUID(str(value))
except ValueError:
raise ValidationError("Must be a valid UUID")
cursor.execute(SET_CONFIG_QUERY, [parameter, value])
yield cursor
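The `rls_transaction` change above swaps an f-string `set_config` call for a parameterized query plus UUID validation, which closes off SQL injection through the tenant ID. A minimal standalone sketch of that pattern (plain Python, no Django — the fake cursor in the usage below and the function name are illustrative):

```python
import uuid
from contextlib import contextmanager

POSTGRES_TENANT_VAR = "api.tenant_id"
SET_CONFIG_QUERY = "SELECT set_config(%s, %s::text, TRUE);"


@contextmanager
def rls_transaction_sketch(cursor, value, parameter=POSTGRES_TENANT_VAR):
    """Validate the value, then set the RLS session variable via a
    parameterized query instead of string interpolation. Sketch only:
    the real helper also wraps this in transaction.atomic()."""
    try:
        # str() handles the case where value is already a UUID object.
        uuid.UUID(str(value))
    except ValueError:
        raise ValueError("Must be a valid UUID")
    # The driver, not string formatting, binds the parameters.
    cursor.execute(SET_CONFIG_QUERY, [parameter, value])
    yield cursor
```

Validation happens on entry to the `with` block, so a malformed tenant ID never reaches the database at all.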


@@ -1,6 +1,10 @@
import uuid
from functools import wraps
from django.db import connection, transaction
from rest_framework_json_api.serializers import ValidationError
from api.db_utils import POSTGRES_TENANT_VAR, SET_CONFIG_QUERY
def set_tenant(func):
@@ -31,7 +35,7 @@ def set_tenant(func):
pass
# When calling the task
some_task.delay(arg1, tenant_id="1234-abcd-5678")
some_task.delay(arg1, tenant_id="8db7ca86-03cc-4d42-99f6-5e480baf6ab5")
# The tenant context will be set before the task logic executes.
"""
@@ -43,9 +47,12 @@ def set_tenant(func):
tenant_id = kwargs.pop("tenant_id")
except KeyError:
raise KeyError("This task requires the tenant_id")
try:
uuid.UUID(tenant_id)
except ValueError:
raise ValidationError("Tenant ID must be a valid UUID")
with connection.cursor() as cursor:
cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
cursor.execute(SET_CONFIG_QUERY, [POSTGRES_TENANT_VAR, tenant_id])
return func(*args, **kwargs)
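The `set_tenant` hunk applies the same hardening to Celery tasks: pop `tenant_id` from the task kwargs, validate it as a UUID, then set the RLS variable with the shared parameterized query. A self-contained sketch of the decorator's control flow (the database step is reduced to a comment, and the name is illustrative):

```python
import uuid
from functools import wraps


def set_tenant_sketch(func):
    """Sketch of the @set_tenant decorator: require and validate tenant_id
    before the task body runs (database step omitted)."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            tenant_id = kwargs.pop("tenant_id")
        except KeyError:
            raise KeyError("This task requires the tenant_id")
        # Raises ValueError if the ID is not a well-formed UUID.
        uuid.UUID(str(tenant_id))
        # Real code: cursor.execute(SET_CONFIG_QUERY, [POSTGRES_TENANT_VAR, tenant_id])
        return func(*args, **kwargs)
    return wrapper
```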


@@ -26,11 +26,13 @@ from api.models import (
Finding,
Invitation,
Membership,
PermissionChoices,
Provider,
ProviderGroup,
ProviderSecret,
Resource,
ResourceTag,
Role,
Scan,
ScanSummary,
SeverityChoices,
@@ -481,6 +483,26 @@ class UserFilter(FilterSet):
}
class RoleFilter(FilterSet):
inserted_at = DateFilter(field_name="inserted_at", lookup_expr="date")
updated_at = DateFilter(field_name="updated_at", lookup_expr="date")
permission_state = ChoiceFilter(
choices=PermissionChoices.choices, method="filter_permission_state"
)
def filter_permission_state(self, queryset, name, value):
return Role.filter_by_permission_state(queryset, value)
class Meta:
model = Role
fields = {
"id": ["exact", "in"],
"name": ["exact", "in"],
"inserted_at": ["gte", "lte"],
"updated_at": ["gte", "lte"],
}
class ComplianceOverviewFilter(FilterSet):
inserted_at = DateFilter(field_name="inserted_at", lookup_expr="date")
provider_type = ChoiceFilter(choices=Provider.ProviderChoices.choices)
@@ -521,3 +543,25 @@ class ScanSummaryFilter(FilterSet):
"inserted_at": ["date", "gte", "lte"],
"region": ["exact", "icontains", "in"],
}
class ServiceOverviewFilter(ScanSummaryFilter):
muted_findings = None
def is_valid(self):
# Check if at least one of the inserted_at filters is present
inserted_at_filters = [
self.data.get("inserted_at"),
self.data.get("inserted_at__gte"),
self.data.get("inserted_at__lte"),
]
if not any(inserted_at_filters):
raise ValidationError(
{
"inserted_at": [
"At least one of filter[inserted_at], filter[inserted_at__gte], or "
"filter[inserted_at__lte] is required."
]
}
)
return super().is_valid()
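`ServiceOverviewFilter.is_valid()` above enforces that at least one `inserted_at` variant is supplied before the query runs. The core of that check is a plain `any()` over the candidate keys; sketched standalone (a dict stands in for the querystring data, and `ValueError` for DRF's `ValidationError`):

```python
def require_inserted_at(data):
    """Sketch of the required-filter check: at least one of the inserted_at
    filters must be present and non-empty."""
    keys = ("inserted_at", "inserted_at__gte", "inserted_at__lte")
    if not any(data.get(k) for k in keys):
        raise ValueError(
            "At least one of filter[inserted_at], filter[inserted_at__gte], "
            "or filter[inserted_at__lte] is required."
        )
    return True
```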


@@ -58,5 +58,96 @@
"provider_group": "525e91e7-f3f3-4254-bbc3-27ce1ade86b1",
"inserted_at": "2024-11-13T11:55:41.237Z"
}
},
{
"model": "api.role",
"pk": "3f01e759-bdf9-4a99-8888-1ab805b79f93",
"fields": {
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
"name": "admin_test",
"manage_users": true,
"manage_account": true,
"manage_billing": true,
"manage_providers": true,
"manage_integrations": true,
"manage_scans": true,
"unlimited_visibility": true,
"inserted_at": "2024-11-20T15:32:42.402Z",
"updated_at": "2024-11-20T15:32:42.402Z"
}
},
{
"model": "api.role",
"pk": "845ff03a-87ef-42ba-9786-6577c70c4df0",
"fields": {
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
"name": "first_role",
"manage_users": true,
"manage_account": true,
"manage_billing": true,
"manage_providers": true,
"manage_integrations": false,
"manage_scans": false,
"unlimited_visibility": true,
"inserted_at": "2024-11-20T15:31:53.239Z",
"updated_at": "2024-11-20T15:31:53.239Z"
}
},
{
"model": "api.role",
"pk": "902d726c-4bd5-413a-a2a4-f7b4754b6b20",
"fields": {
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
"name": "third_role",
"manage_users": false,
"manage_account": false,
"manage_billing": false,
"manage_providers": false,
"manage_integrations": false,
"manage_scans": true,
"unlimited_visibility": false,
"inserted_at": "2024-11-20T15:34:05.440Z",
"updated_at": "2024-11-20T15:34:05.440Z"
}
},
{
"model": "api.roleprovidergrouprelationship",
"pk": "57fd024a-0a7f-49b4-a092-fa0979a07aaf",
"fields": {
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
"role": "3f01e759-bdf9-4a99-8888-1ab805b79f93",
"provider_group": "3fe28fb8-e545-424c-9b8f-69aff638f430",
"inserted_at": "2024-11-20T15:32:42.402Z"
}
},
{
"model": "api.roleprovidergrouprelationship",
"pk": "a3cd0099-1c13-4df1-a5e5-ecdfec561b35",
"fields": {
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
"role": "3f01e759-bdf9-4a99-8888-1ab805b79f93",
"provider_group": "481769f5-db2b-447b-8b00-1dee18db90ec",
"inserted_at": "2024-11-20T15:32:42.402Z"
}
},
{
"model": "api.roleprovidergrouprelationship",
"pk": "cfd84182-a058-40c2-af3c-0189b174940f",
"fields": {
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
"role": "3f01e759-bdf9-4a99-8888-1ab805b79f93",
"provider_group": "525e91e7-f3f3-4254-bbc3-27ce1ade86b1",
"inserted_at": "2024-11-20T15:32:42.402Z"
}
},
{
"model": "api.userrolerelationship",
"pk": "92339663-e954-4fd8-98fb-8bfe15949975",
"fields": {
"tenant": "12646005-9067-4d2a-a098-8bb378604362",
"role": "3f01e759-bdf9-4a99-8888-1ab805b79f93",
"user": "8b38e2eb-6689-4f1e-a4ba-95b275130200",
"inserted_at": "2024-11-20T15:36:14.302Z"
}
}
]


@@ -0,0 +1,246 @@
# Generated by Django 5.1.1 on 2024-12-05 12:29
import api.rls
import django.db.models.deletion
import uuid
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("api", "0002_token_migrations"),
]
operations = [
migrations.CreateModel(
name="Role",
fields=[
(
"id",
models.UUIDField(
default=uuid.uuid4,
editable=False,
primary_key=True,
serialize=False,
),
),
("name", models.CharField(max_length=255)),
("manage_users", models.BooleanField(default=False)),
("manage_account", models.BooleanField(default=False)),
("manage_billing", models.BooleanField(default=False)),
("manage_providers", models.BooleanField(default=False)),
("manage_integrations", models.BooleanField(default=False)),
("manage_scans", models.BooleanField(default=False)),
("unlimited_visibility", models.BooleanField(default=False)),
("inserted_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
(
"tenant",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
),
),
],
options={
"db_table": "roles",
},
),
migrations.CreateModel(
name="RoleProviderGroupRelationship",
fields=[
(
"id",
models.UUIDField(
default=uuid.uuid4,
editable=False,
primary_key=True,
serialize=False,
),
),
("inserted_at", models.DateTimeField(auto_now_add=True)),
(
"tenant",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
),
),
],
options={
"db_table": "role_provider_group_relationship",
},
),
migrations.CreateModel(
name="UserRoleRelationship",
fields=[
(
"id",
models.UUIDField(
default=uuid.uuid4,
editable=False,
primary_key=True,
serialize=False,
),
),
("inserted_at", models.DateTimeField(auto_now_add=True)),
(
"tenant",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
),
),
],
options={
"db_table": "role_user_relationship",
},
),
migrations.AddField(
model_name="roleprovidergrouprelationship",
name="provider_group",
field=models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="api.providergroup"
),
),
migrations.AddField(
model_name="roleprovidergrouprelationship",
name="role",
field=models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="api.role"
),
),
migrations.AddField(
model_name="role",
name="provider_groups",
field=models.ManyToManyField(
related_name="roles",
through="api.RoleProviderGroupRelationship",
to="api.providergroup",
),
),
migrations.AddField(
model_name="userrolerelationship",
name="role",
field=models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="api.role"
),
),
migrations.AddField(
model_name="userrolerelationship",
name="user",
field=models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL
),
),
migrations.AddField(
model_name="role",
name="users",
field=models.ManyToManyField(
related_name="roles",
through="api.UserRoleRelationship",
to=settings.AUTH_USER_MODEL,
),
),
migrations.AddConstraint(
model_name="roleprovidergrouprelationship",
constraint=models.UniqueConstraint(
fields=("role_id", "provider_group_id"),
name="unique_role_provider_group_relationship",
),
),
migrations.AddConstraint(
model_name="roleprovidergrouprelationship",
constraint=api.rls.RowLevelSecurityConstraint(
"tenant_id",
name="rls_on_roleprovidergrouprelationship",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
),
),
migrations.AddConstraint(
model_name="userrolerelationship",
constraint=models.UniqueConstraint(
fields=("role_id", "user_id"), name="unique_role_user_relationship"
),
),
migrations.AddConstraint(
model_name="userrolerelationship",
constraint=api.rls.RowLevelSecurityConstraint(
"tenant_id",
name="rls_on_userrolerelationship",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
),
),
migrations.AddConstraint(
model_name="role",
constraint=models.UniqueConstraint(
fields=("tenant_id", "name"), name="unique_role_per_tenant"
),
),
migrations.AddConstraint(
model_name="role",
constraint=api.rls.RowLevelSecurityConstraint(
"tenant_id",
name="rls_on_role",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
),
),
migrations.CreateModel(
name="InvitationRoleRelationship",
fields=[
(
"id",
models.UUIDField(
default=uuid.uuid4,
editable=False,
primary_key=True,
serialize=False,
),
),
("inserted_at", models.DateTimeField(auto_now_add=True)),
(
"invitation",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="api.invitation"
),
),
(
"role",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="api.role"
),
),
(
"tenant",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
),
),
],
options={
"db_table": "role_invitation_relationship",
},
),
migrations.AddConstraint(
model_name="invitationrolerelationship",
constraint=models.UniqueConstraint(
fields=("role_id", "invitation_id"),
name="unique_role_invitation_relationship",
),
),
migrations.AddConstraint(
model_name="invitationrolerelationship",
constraint=api.rls.RowLevelSecurityConstraint(
"tenant_id",
name="rls_on_invitationrolerelationship",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
),
),
migrations.AddField(
model_name="role",
name="invitations",
field=models.ManyToManyField(
related_name="roles",
through="api.InvitationRoleRelationship",
to="api.invitation",
),
),
]


@@ -0,0 +1,43 @@
from django.db import migrations
from api.db_router import MainRouter
def create_admin_role(apps, schema_editor):
Tenant = apps.get_model("api", "Tenant")
Role = apps.get_model("api", "Role")
User = apps.get_model("api", "User")
UserRoleRelationship = apps.get_model("api", "UserRoleRelationship")
for tenant in Tenant.objects.using(MainRouter.admin_db).all():
admin_role, _ = Role.objects.using(MainRouter.admin_db).get_or_create(
name="admin",
tenant=tenant,
defaults={
"manage_users": True,
"manage_account": True,
"manage_billing": True,
"manage_providers": True,
"manage_integrations": True,
"manage_scans": True,
"unlimited_visibility": True,
},
)
users = User.objects.using(MainRouter.admin_db).filter(
membership__tenant=tenant
)
for user in users:
UserRoleRelationship.objects.using(MainRouter.admin_db).get_or_create(
user=user,
role=admin_role,
tenant=tenant,
)
class Migration(migrations.Migration):
dependencies = [
("api", "0003_rbac"),
]
operations = [
migrations.RunPython(create_admin_role),
]
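The data migration above is deliberately idempotent: `get_or_create` keyed on `(tenant, name)` means re-running it creates nothing new. The same property, sketched without Django (plain dicts and tenant strings stand in for the models):

```python
def backfill_admin_roles(tenants, roles):
    """Sketch of the migration's idempotency: create the admin role per
    tenant only if it does not already exist, mirroring get_or_create."""
    created = []
    for tenant in tenants:
        key = (tenant, "admin")
        if key not in roles:
            roles[key] = {"manage_users": True, "unlimited_visibility": True}
            created.append(key)
    return created
```

Running the backfill twice against the same store is a no-op the second time, which is exactly the behavior a re-runnable data migration needs.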


@@ -69,6 +69,21 @@ class StateChoices(models.TextChoices):
CANCELLED = "cancelled", _("Cancelled")
class PermissionChoices(models.TextChoices):
"""
Represents the different permission states that a role can have.
Attributes:
UNLIMITED: Indicates that the role possesses all permissions.
LIMITED: Indicates that the role has some permissions but not all.
NONE: Indicates that the role does not have any permissions.
"""
UNLIMITED = "unlimited", _("Unlimited permissions")
LIMITED = "limited", _("Limited permissions")
NONE = "none", _("No permissions")
class ActiveProviderManager(models.Manager):
def get_queryset(self):
return super().get_queryset().filter(self.active_provider_filter())
@@ -298,19 +313,10 @@ class ProviderGroup(RowLevelSecurityProtectedModel):
class ProviderGroupMembership(RowLevelSecurityProtectedModel):
objects = ActiveProviderManager()
all_objects = models.Manager()
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
provider = models.ForeignKey(
Provider,
on_delete=models.CASCADE,
)
provider_group = models.ForeignKey(
ProviderGroup,
on_delete=models.CASCADE,
)
inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
provider_group = models.ForeignKey(ProviderGroup, on_delete=models.CASCADE)
provider = models.ForeignKey(Provider, on_delete=models.CASCADE)
inserted_at = models.DateTimeField(auto_now_add=True)
class Meta:
db_table = "provider_group_memberships"
@@ -327,7 +333,7 @@ class ProviderGroupMembership(RowLevelSecurityProtectedModel):
]
class JSONAPIMeta:
resource_name = "provider-group-memberships"
resource_name = "provider_groups-provider"
class Task(RowLevelSecurityProtectedModel):
@@ -851,6 +857,150 @@ class Invitation(RowLevelSecurityProtectedModel):
resource_name = "invitations"
class Role(RowLevelSecurityProtectedModel):
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
name = models.CharField(max_length=255)
manage_users = models.BooleanField(default=False)
manage_account = models.BooleanField(default=False)
manage_billing = models.BooleanField(default=False)
manage_providers = models.BooleanField(default=False)
manage_integrations = models.BooleanField(default=False)
manage_scans = models.BooleanField(default=False)
unlimited_visibility = models.BooleanField(default=False)
inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
updated_at = models.DateTimeField(auto_now=True, editable=False)
provider_groups = models.ManyToManyField(
ProviderGroup, through="RoleProviderGroupRelationship", related_name="roles"
)
users = models.ManyToManyField(
User, through="UserRoleRelationship", related_name="roles"
)
invitations = models.ManyToManyField(
Invitation, through="InvitationRoleRelationship", related_name="roles"
)
# Filter permission_state
PERMISSION_FIELDS = [
"manage_users",
"manage_account",
"manage_billing",
"manage_providers",
"manage_integrations",
"manage_scans",
]
@property
def permission_state(self):
values = [getattr(self, field) for field in self.PERMISSION_FIELDS]
if all(values):
return PermissionChoices.UNLIMITED
elif not any(values):
return PermissionChoices.NONE
else:
return PermissionChoices.LIMITED
@classmethod
def filter_by_permission_state(cls, queryset, value):
q_all_true = Q(**{field: True for field in cls.PERMISSION_FIELDS})
q_all_false = Q(**{field: False for field in cls.PERMISSION_FIELDS})
if value == PermissionChoices.UNLIMITED:
return queryset.filter(q_all_true)
elif value == PermissionChoices.NONE:
return queryset.filter(q_all_false)
else:
return queryset.exclude(q_all_true | q_all_false)
class Meta:
db_table = "roles"
constraints = [
models.UniqueConstraint(
fields=["tenant_id", "name"],
name="unique_role_per_tenant",
),
RowLevelSecurityConstraint(
field="tenant_id",
name="rls_on_%(class)s",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
),
]
class JSONAPIMeta:
resource_name = "roles"
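The `permission_state` property above derives a tri-state value from the six boolean flags with `all()`/`any()`. The same derivation, sketched standalone (a dict stands in for the `Role` model; missing keys default to `False`):

```python
from enum import Enum


class PermissionState(Enum):
    UNLIMITED = "unlimited"
    LIMITED = "limited"
    NONE = "none"


PERMISSION_FIELDS = [
    "manage_users", "manage_account", "manage_billing",
    "manage_providers", "manage_integrations", "manage_scans",
]


def permission_state(role: dict) -> PermissionState:
    """Sketch of Role.permission_state: all flags set -> UNLIMITED,
    none set -> NONE, otherwise LIMITED."""
    values = [role.get(f, False) for f in PERMISSION_FIELDS]
    if all(values):
        return PermissionState.UNLIMITED
    if not any(values):
        return PermissionState.NONE
    return PermissionState.LIMITED
```

The queryset counterpart, `filter_by_permission_state`, expresses the same three cases with `Q` objects so the classification happens in SQL rather than per-row in Python.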
class RoleProviderGroupRelationship(RowLevelSecurityProtectedModel):
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
role = models.ForeignKey(Role, on_delete=models.CASCADE)
provider_group = models.ForeignKey(ProviderGroup, on_delete=models.CASCADE)
inserted_at = models.DateTimeField(auto_now_add=True)
class Meta:
db_table = "role_provider_group_relationship"
constraints = [
models.UniqueConstraint(
fields=["role_id", "provider_group_id"],
name="unique_role_provider_group_relationship",
),
RowLevelSecurityConstraint(
field="tenant_id",
name="rls_on_%(class)s",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
),
]
class JSONAPIMeta:
resource_name = "role-provider_groups"
class UserRoleRelationship(RowLevelSecurityProtectedModel):
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
role = models.ForeignKey(Role, on_delete=models.CASCADE)
user = models.ForeignKey(User, on_delete=models.CASCADE)
inserted_at = models.DateTimeField(auto_now_add=True)
class Meta:
db_table = "role_user_relationship"
constraints = [
models.UniqueConstraint(
fields=["role_id", "user_id"],
name="unique_role_user_relationship",
),
RowLevelSecurityConstraint(
field="tenant_id",
name="rls_on_%(class)s",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
),
]
class JSONAPIMeta:
resource_name = "user-roles"
class InvitationRoleRelationship(RowLevelSecurityProtectedModel):
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
role = models.ForeignKey(Role, on_delete=models.CASCADE)
invitation = models.ForeignKey(Invitation, on_delete=models.CASCADE)
inserted_at = models.DateTimeField(auto_now_add=True)
class Meta:
db_table = "role_invitation_relationship"
constraints = [
models.UniqueConstraint(
fields=["role_id", "invitation_id"],
name="unique_role_invitation_relationship",
),
RowLevelSecurityConstraint(
field="tenant_id",
name="rls_on_%(class)s",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
),
]
class JSONAPIMeta:
resource_name = "invitation-roles"
class ComplianceOverview(RowLevelSecurityProtectedModel):
objects = ActiveProviderManager()
all_objects = models.Manager()


@@ -0,0 +1,70 @@
from enum import Enum
from rest_framework.permissions import BasePermission
from api.models import Provider, Role, User
from api.db_router import MainRouter
from typing import Optional
from django.db.models import QuerySet
class Permissions(Enum):
MANAGE_USERS = "manage_users"
MANAGE_ACCOUNT = "manage_account"
MANAGE_BILLING = "manage_billing"
MANAGE_PROVIDERS = "manage_providers"
MANAGE_INTEGRATIONS = "manage_integrations"
MANAGE_SCANS = "manage_scans"
UNLIMITED_VISIBILITY = "unlimited_visibility"
class HasPermissions(BasePermission):
"""
Custom permission to check if the user's role has the required permissions.
The required permissions should be specified in the view as a list in `required_permissions`.
"""
def has_permission(self, request, view):
required_permissions = getattr(view, "required_permissions", [])
if not required_permissions:
return True
user_roles = (
User.objects.using(MainRouter.admin_db).get(id=request.user.id).roles.all()
)
if not user_roles:
return False
for perm in required_permissions:
if not getattr(user_roles[0], perm.value, False):
return False
return True
def get_role(user: User) -> Optional[Role]:
"""
Retrieve the first role assigned to the given user.
Returns:
The user's first Role instance if the user has any roles, otherwise None.
"""
return user.roles.first()
def get_providers(role: Role) -> QuerySet[Provider]:
"""
Return a distinct queryset of Providers accessible by the given role.
If the role has no associated provider groups, an empty queryset is returned.
Args:
role: A Role instance.
Returns:
A QuerySet of Provider objects filtered by the role's provider groups.
If the role has no provider groups, returns an empty queryset.
"""
provider_groups = role.provider_groups.all()
if not provider_groups.exists():
return Provider.objects.none()
return Provider.objects.filter(provider_groups__in=provider_groups).distinct()
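The permission class and the two helpers in this file reduce to a small amount of pure logic: `HasPermissions` denies unless the user's first role grants every flag in `required_permissions`, and `get_providers` is a distinct union over the role's provider groups. A stdlib-only sketch of that flow (`FakeRole` and the dict-of-sets are illustrative stand-ins, not the real Django models):

```python
from dataclasses import dataclass
from enum import Enum


class Permissions(Enum):
    MANAGE_USERS = "manage_users"
    MANAGE_PROVIDERS = "manage_providers"
    UNLIMITED_VISIBILITY = "unlimited_visibility"


@dataclass
class FakeRole:
    # Stand-in for the Role model's boolean permission columns.
    manage_users: bool = False
    manage_providers: bool = False
    unlimited_visibility: bool = False
    provider_group_ids: frozenset = frozenset()


def has_permission(roles, required_permissions):
    """Mirrors HasPermissions: deny unless the first role grants every flag."""
    if not required_permissions:
        return True
    if not roles:
        return False
    return all(getattr(roles[0], perm.value, False) for perm in required_permissions)


def visible_providers(role, providers_by_group):
    """Mirrors get_providers: distinct union over the role's provider groups."""
    if not role.provider_group_ids:
        return set()
    return set().union(
        *(providers_by_group.get(g, set()) for g in role.provider_group_ids)
    )


admin = FakeRole(manage_users=True, manage_providers=True,
                 provider_group_ids=frozenset({"pg1", "pg2"}))
limited = FakeRole()

print(has_permission([admin], [Permissions.MANAGE_USERS]))    # True
print(has_permission([limited], [Permissions.MANAGE_USERS]))  # False
print(sorted(visible_providers(admin, {"pg1": {"p1"}, "pg2": {"p1", "p2"}})))  # ['p1', 'p2']
```

Note that, like the real `has_permission`, only the first role is consulted, so a second role with broader flags would not widen access under this check.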

@@ -2,7 +2,7 @@ from contextlib import nullcontext
from rest_framework_json_api.renderers import JSONRenderer
from api.db_utils import tenant_transaction
from api.db_utils import rls_transaction
class APIJSONRenderer(JSONRenderer):
@@ -13,9 +13,9 @@ class APIJSONRenderer(JSONRenderer):
tenant_id = getattr(request, "tenant_id", None) if request else None
include_param_present = "include" in request.query_params if request else False
# Use tenant_transaction if needed for included resources, otherwise do nothing
# Use rls_transaction if needed for included resources, otherwise do nothing
context_manager = (
tenant_transaction(tenant_id)
rls_transaction(tenant_id)
if tenant_id and include_param_present
else nullcontext()
)
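The renderer change keeps the same shape: wrap rendering in `rls_transaction` only when a tenant id and the `include` query parameter are both present, otherwise fall back to a no-op `nullcontext()`. The pattern, with a recording stand-in for the real transaction context manager:

```python
from contextlib import contextmanager, nullcontext

events = []


@contextmanager
def fake_rls_transaction(tenant_id):
    # Stand-in for rls_transaction: records enter/exit instead of running
    # set_config against Postgres.
    events.append(f"begin {tenant_id}")
    try:
        yield
    finally:
        events.append("end")


def render(tenant_id, include_param_present):
    # Same conditional wrapping as APIJSONRenderer: open the tenant context
    # only when both conditions hold, else use a no-op nullcontext.
    context_manager = (
        fake_rls_transaction(tenant_id)
        if tenant_id and include_param_present
        else nullcontext()
    )
    with context_manager:
        return "rendered"


render("t1", True)   # wrapped: emits begin/end
render(None, True)   # no tenant id: nullcontext, nothing recorded
print(events)        # ['begin t1', 'end']
```

`nullcontext()` keeps the `with` statement unconditional, which avoids duplicating the rendering body across two branches.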

File diff suppressed because it is too large

@@ -1,12 +1,10 @@
import pytest
from django.urls import reverse
from unittest.mock import patch
from rest_framework.test import APIClient
from conftest import TEST_PASSWORD, get_api_tokens, get_authorization_header
@patch("api.v1.views.MainRouter.admin_db", new="default")
@pytest.mark.django_db
def test_basic_authentication():
client = APIClient()

@@ -13,6 +13,7 @@ def test_check_resources_between_different_tenants(
enforce_test_user_db_connection,
authenticated_api_client,
tenants_fixture,
set_user_admin_roles_fixture,
):
client = authenticated_api_client

@@ -6,8 +6,10 @@ from django.db.utils import ConnectionRouter
from api.db_router import MainRouter
from api.rls import Tenant
from config.django.base import DATABASE_ROUTERS as PROD_DATABASE_ROUTERS
from unittest.mock import patch
@patch("api.db_router.MainRouter.admin_db", new="admin")
class TestMainDatabaseRouter:
@pytest.fixture(scope="module")
def router(self):

@@ -1,7 +1,9 @@
from unittest.mock import patch, call
import uuid
from unittest.mock import call, patch
import pytest
from api.db_utils import POSTGRES_TENANT_VAR, SET_CONFIG_QUERY
from api.decorators import set_tenant
@@ -15,12 +17,12 @@ class TestSetTenantDecorator:
def random_func(arg):
return arg
tenant_id = "1234-abcd-5678"
tenant_id = str(uuid.uuid4())
result = random_func("test_arg", tenant_id=tenant_id)
assert (
call(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
call(SET_CONFIG_QUERY, [POSTGRES_TENANT_VAR, tenant_id])
in mock_cursor.execute.mock_calls
)
assert result == "test_arg"
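The updated assertion swaps an f-string SQL statement for a parameterized one (`SET_CONFIG_QUERY` plus a parameter list), so the driver binds the tenant id instead of interpolating it into SQL text. The same binding behavior, demonstrated with stdlib `sqlite3` (its `?` placeholders standing in for psycopg's `%s`):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE config (key TEXT, value TEXT)")

# A hostile value is inert when bound as a parameter rather than interpolated.
tenant_id = "1234'; DROP TABLE config; --"

conn.execute(
    "INSERT INTO config (key, value) VALUES (?, ?)",
    ("api.tenant_id", tenant_id),
)
row = conn.execute(
    "SELECT value FROM config WHERE key = ?", ("api.tenant_id",)
).fetchone()

print(row[0] == tenant_id)  # True: stored verbatim, never parsed as SQL
```

Parameterization is also what makes the test's `mock_calls` assertion stable: it compares a fixed query string plus an argument list instead of a formatted statement.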

@@ -0,0 +1,306 @@
import pytest
from django.urls import reverse
from rest_framework import status
from unittest.mock import patch, ANY, Mock
@pytest.mark.django_db
class TestUserViewSet:
def test_list_users_with_all_permissions(self, authenticated_client_rbac):
response = authenticated_client_rbac.get(reverse("user-list"))
assert response.status_code == status.HTTP_200_OK
assert isinstance(response.json()["data"], list)
def test_list_users_with_no_permissions(
self, authenticated_client_no_permissions_rbac
):
response = authenticated_client_no_permissions_rbac.get(reverse("user-list"))
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_retrieve_user_with_all_permissions(
self, authenticated_client_rbac, create_test_user_rbac
):
response = authenticated_client_rbac.get(
reverse("user-detail", kwargs={"pk": create_test_user_rbac.id})
)
assert response.status_code == status.HTTP_200_OK
assert (
response.json()["data"]["attributes"]["email"]
== create_test_user_rbac.email
)
def test_retrieve_user_with_no_roles(
self, authenticated_client_rbac_noroles, create_test_user_rbac_no_roles
):
response = authenticated_client_rbac_noroles.get(
reverse("user-detail", kwargs={"pk": create_test_user_rbac_no_roles.id})
)
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_retrieve_user_with_no_permissions(
self, authenticated_client_no_permissions_rbac, create_test_user
):
response = authenticated_client_no_permissions_rbac.get(
reverse("user-detail", kwargs={"pk": create_test_user.id})
)
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_create_user_with_all_permissions(self, authenticated_client_rbac):
valid_user_payload = {
"name": "test",
"password": "newpassword123",
"email": "new_user@test.com",
}
response = authenticated_client_rbac.post(
reverse("user-list"), data=valid_user_payload, format="vnd.api+json"
)
assert response.status_code == status.HTTP_201_CREATED
assert response.json()["data"]["attributes"]["email"] == "new_user@test.com"
def test_create_user_with_no_permissions(
self, authenticated_client_no_permissions_rbac
):
valid_user_payload = {
"name": "test",
"password": "newpassword123",
"email": "new_user@test.com",
}
response = authenticated_client_no_permissions_rbac.post(
reverse("user-list"), data=valid_user_payload, format="vnd.api+json"
)
assert response.status_code == status.HTTP_201_CREATED
assert response.json()["data"]["attributes"]["email"] == "new_user@test.com"
def test_partial_update_user_with_all_permissions(
self, authenticated_client_rbac, create_test_user_rbac
):
updated_data = {
"data": {
"type": "users",
"id": str(create_test_user_rbac.id),
"attributes": {"name": "Updated Name"},
},
}
response = authenticated_client_rbac.patch(
reverse("user-detail", kwargs={"pk": create_test_user_rbac.id}),
data=updated_data,
content_type="application/vnd.api+json",
)
assert response.status_code == status.HTTP_200_OK
assert response.json()["data"]["attributes"]["name"] == "Updated Name"
def test_partial_update_user_with_no_permissions(
self, authenticated_client_no_permissions_rbac, create_test_user
):
updated_data = {
"data": {
"type": "users",
"attributes": {"name": "Updated Name"},
}
}
response = authenticated_client_no_permissions_rbac.patch(
reverse("user-detail", kwargs={"pk": create_test_user.id}),
data=updated_data,
format="vnd.api+json",
)
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_delete_user_with_all_permissions(
self, authenticated_client_rbac, create_test_user_rbac
):
response = authenticated_client_rbac.delete(
reverse("user-detail", kwargs={"pk": create_test_user_rbac.id})
)
assert response.status_code == status.HTTP_204_NO_CONTENT
def test_delete_user_with_no_permissions(
self, authenticated_client_no_permissions_rbac, create_test_user
):
response = authenticated_client_no_permissions_rbac.delete(
reverse("user-detail", kwargs={"pk": create_test_user.id})
)
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_me_with_all_permissions(
self, authenticated_client_rbac, create_test_user_rbac
):
response = authenticated_client_rbac.get(reverse("user-me"))
assert response.status_code == status.HTTP_200_OK
assert (
response.json()["data"]["attributes"]["email"]
== create_test_user_rbac.email
)
def test_me_with_no_permissions(
self, authenticated_client_no_permissions_rbac, create_test_user
):
response = authenticated_client_no_permissions_rbac.get(reverse("user-me"))
assert response.status_code == status.HTTP_200_OK
assert response.json()["data"]["attributes"]["email"] == "rbac_limited@rbac.com"
@pytest.mark.django_db
class TestProviderViewSet:
def test_list_providers_with_all_permissions(
self, authenticated_client_rbac, providers_fixture
):
response = authenticated_client_rbac.get(reverse("provider-list"))
assert response.status_code == status.HTTP_200_OK
assert len(response.json()["data"]) == len(providers_fixture)
def test_list_providers_with_no_permissions(
self, authenticated_client_no_permissions_rbac
):
response = authenticated_client_no_permissions_rbac.get(
reverse("provider-list")
)
assert response.status_code == status.HTTP_200_OK
assert len(response.json()["data"]) == 0
def test_retrieve_provider_with_all_permissions(
self, authenticated_client_rbac, providers_fixture
):
provider = providers_fixture[0]
response = authenticated_client_rbac.get(
reverse("provider-detail", kwargs={"pk": provider.id})
)
assert response.status_code == status.HTTP_200_OK
assert response.json()["data"]["attributes"]["alias"] == provider.alias
def test_retrieve_provider_with_no_permissions(
self, authenticated_client_no_permissions_rbac, providers_fixture
):
provider = providers_fixture[0]
response = authenticated_client_no_permissions_rbac.get(
reverse("provider-detail", kwargs={"pk": provider.id})
)
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_create_provider_with_all_permissions(self, authenticated_client_rbac):
payload = {"provider": "aws", "uid": "111111111111", "alias": "new_alias"}
response = authenticated_client_rbac.post(
reverse("provider-list"), data=payload, format="json"
)
assert response.status_code == status.HTTP_201_CREATED
assert response.json()["data"]["attributes"]["alias"] == "new_alias"
def test_create_provider_with_no_permissions(
self, authenticated_client_no_permissions_rbac
):
payload = {"provider": "aws", "uid": "111111111111", "alias": "new_alias"}
response = authenticated_client_no_permissions_rbac.post(
reverse("provider-list"), data=payload, format="json"
)
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_partial_update_provider_with_all_permissions(
self, authenticated_client_rbac, providers_fixture
):
provider = providers_fixture[0]
payload = {
"data": {
"type": "providers",
"id": provider.id,
"attributes": {"alias": "updated_alias"},
},
}
response = authenticated_client_rbac.patch(
reverse("provider-detail", kwargs={"pk": provider.id}),
data=payload,
content_type="application/vnd.api+json",
)
assert response.status_code == status.HTTP_200_OK
assert response.json()["data"]["attributes"]["alias"] == "updated_alias"
def test_partial_update_provider_with_no_permissions(
self, authenticated_client_no_permissions_rbac, providers_fixture
):
provider = providers_fixture[0]
update_payload = {
"data": {
"type": "providers",
"attributes": {"alias": "updated_alias"},
}
}
response = authenticated_client_no_permissions_rbac.patch(
reverse("provider-detail", kwargs={"pk": provider.id}),
data=update_payload,
format="vnd.api+json",
)
assert response.status_code == status.HTTP_403_FORBIDDEN
@patch("api.v1.views.Task.objects.get")
@patch("api.v1.views.delete_provider_task.delay")
def test_delete_provider_with_all_permissions(
self,
mock_delete_task,
mock_task_get,
authenticated_client_rbac,
providers_fixture,
tasks_fixture,
):
prowler_task = tasks_fixture[0]
task_mock = Mock()
task_mock.id = prowler_task.id
mock_delete_task.return_value = task_mock
mock_task_get.return_value = prowler_task
provider1, *_ = providers_fixture
response = authenticated_client_rbac.delete(
reverse("provider-detail", kwargs={"pk": provider1.id})
)
assert response.status_code == status.HTTP_202_ACCEPTED
mock_delete_task.assert_called_once_with(
provider_id=str(provider1.id), tenant_id=ANY
)
assert "Content-Location" in response.headers
assert response.headers["Content-Location"] == f"/api/v1/tasks/{task_mock.id}"
def test_delete_provider_with_no_permissions(
self, authenticated_client_no_permissions_rbac, providers_fixture
):
provider = providers_fixture[0]
response = authenticated_client_no_permissions_rbac.delete(
reverse("provider-detail", kwargs={"pk": provider.id})
)
assert response.status_code == status.HTTP_403_FORBIDDEN
@patch("api.v1.views.Task.objects.get")
@patch("api.v1.views.check_provider_connection_task.delay")
def test_connection_with_all_permissions(
self,
mock_provider_connection,
mock_task_get,
authenticated_client_rbac,
providers_fixture,
tasks_fixture,
):
prowler_task = tasks_fixture[0]
task_mock = Mock()
task_mock.id = prowler_task.id
task_mock.status = "PENDING"
mock_provider_connection.return_value = task_mock
mock_task_get.return_value = prowler_task
provider1, *_ = providers_fixture
assert provider1.connected is None
assert provider1.connection_last_checked_at is None
response = authenticated_client_rbac.post(
reverse("provider-connection", kwargs={"pk": provider1.id})
)
assert response.status_code == status.HTTP_202_ACCEPTED
mock_provider_connection.assert_called_once_with(
provider_id=str(provider1.id), tenant_id=ANY
)
assert "Content-Location" in response.headers
assert response.headers["Content-Location"] == f"/api/v1/tasks/{task_mock.id}"
def test_connection_with_no_permissions(
self, authenticated_client_no_permissions_rbac, providers_fixture
):
provider = providers_fixture[0]
response = authenticated_client_no_permissions_rbac.post(
reverse("provider-connection", kwargs={"pk": provider.id})
)
assert response.status_code == status.HTTP_403_FORBIDDEN

File diff suppressed because it is too large

View File

@@ -17,6 +17,7 @@ from api.models import (
ComplianceOverview,
Finding,
Invitation,
InvitationRoleRelationship,
Membership,
Provider,
ProviderGroup,
@@ -24,10 +25,13 @@ from api.models import (
ProviderSecret,
Resource,
ResourceTag,
Role,
RoleProviderGroupRelationship,
Scan,
StateChoices,
Task,
User,
UserRoleRelationship,
)
from api.rls import Tenant
@@ -176,10 +180,26 @@ class UserSerializer(BaseSerializerV1):
"""
memberships = serializers.ResourceRelatedField(many=True, read_only=True)
roles = serializers.ResourceRelatedField(many=True, read_only=True)
class Meta:
model = User
fields = ["id", "name", "email", "company_name", "date_joined", "memberships"]
fields = [
"id",
"name",
"email",
"company_name",
"date_joined",
"memberships",
"roles",
]
extra_kwargs = {
"roles": {"read_only": True},
}
included_serializers = {
"roles": "api.v1.serializers.RoleSerializer",
}
class UserCreateSerializer(BaseWriteSerializer):
@@ -215,10 +235,13 @@ class UserCreateSerializer(BaseWriteSerializer):
class UserUpdateSerializer(BaseWriteSerializer):
password = serializers.CharField(write_only=True, required=False)
roles = serializers.ResourceRelatedField(
queryset=Role.objects.all(), many=True, required=False
)
class Meta:
model = User
fields = ["id", "name", "password", "email", "company_name"]
fields = ["id", "name", "password", "email", "company_name", "roles"]
extra_kwargs = {
"id": {"read_only": True},
}
@@ -235,6 +258,73 @@ class UserUpdateSerializer(BaseWriteSerializer):
return super().update(instance, validated_data)
class RoleResourceIdentifierSerializer(serializers.Serializer):
resource_type = serializers.CharField(source="type")
id = serializers.UUIDField()
class JSONAPIMeta:
resource_name = "role-identifier"
def to_representation(self, instance):
"""
Ensure 'type' is used in the output instead of 'resource_type'.
"""
representation = super().to_representation(instance)
representation["type"] = representation.pop("resource_type", None)
return representation
def to_internal_value(self, data):
"""
Map 'type' back to 'resource_type' during input.
"""
data["resource_type"] = data.pop("type", None)
return super().to_internal_value(data)
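The two overrides above are inverse renames between JSON:API's reserved `type` member and the serializer's internal `resource_type` field. Stripped of DRF, the round trip is a pair of dict rewrites:

```python
def to_internal_value(data):
    # Incoming payloads carry JSON:API's 'type'; store it as 'resource_type'.
    data = dict(data)
    data["resource_type"] = data.pop("type", None)
    return data


def to_representation(internal):
    # Outgoing: rename 'resource_type' back to the JSON:API 'type' member.
    rep = dict(internal)
    rep["type"] = rep.pop("resource_type", None)
    return rep


payload = {"type": "roles", "id": "example-id"}  # placeholder id
internal = to_internal_value(payload)
assert internal == {"resource_type": "roles", "id": "example-id"}
assert to_representation(internal) == payload  # exact round trip
```

The indirection exists because `type` clashes with the Python builtin and with DRF field naming, so the serializer sources it from `resource_type` instead.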
class UserRoleRelationshipSerializer(RLSSerializer, BaseWriteSerializer):
"""
Serializer for modifying a user's role relationships.
"""
roles = serializers.ListField(
child=RoleResourceIdentifierSerializer(),
help_text="List of resource identifier objects representing roles.",
)
def create(self, validated_data):
role_ids = [item["id"] for item in validated_data["roles"]]
roles = Role.objects.filter(id__in=role_ids)
tenant_id = self.context.get("tenant_id")
new_relationships = [
UserRoleRelationship(
user=self.context.get("user"), role=r, tenant_id=tenant_id
)
for r in roles
]
UserRoleRelationship.objects.bulk_create(new_relationships)
return self.context.get("user")
def update(self, instance, validated_data):
role_ids = [item["id"] for item in validated_data["roles"]]
roles = Role.objects.filter(id__in=role_ids)
tenant_id = self.context.get("tenant_id")
instance.roles.clear()
new_relationships = [
UserRoleRelationship(user=instance, role=r, tenant_id=tenant_id)
for r in roles
]
UserRoleRelationship.objects.bulk_create(new_relationships)
return instance
class Meta:
model = UserRoleRelationship
fields = ["id", "roles"]
# Tasks
class TaskBase(serializers.ModelSerializer):
state_mapping = {
@@ -358,89 +448,200 @@ class MembershipSerializer(serializers.ModelSerializer):
# Provider Groups
class ProviderGroupSerializer(RLSSerializer, BaseWriteSerializer):
providers = serializers.ResourceRelatedField(many=True, read_only=True)
providers = serializers.ResourceRelatedField(
queryset=Provider.objects.all(), many=True, required=False
)
roles = serializers.ResourceRelatedField(
queryset=Role.objects.all(), many=True, required=False
)
def validate(self, attrs):
tenant = self.context["tenant_id"]
name = attrs.get("name", self.instance.name if self.instance else None)
# Exclude the current instance when checking for uniqueness during updates
queryset = ProviderGroup.objects.filter(tenant=tenant, name=name)
if self.instance:
queryset = queryset.exclude(pk=self.instance.pk)
if queryset.exists():
if ProviderGroup.objects.filter(name=attrs.get("name")).exists():
raise serializers.ValidationError(
{
"name": "A provider group with this name already exists for this tenant."
}
{"name": "A provider group with this name already exists."}
)
return super().validate(attrs)
class Meta:
model = ProviderGroup
fields = ["id", "name", "inserted_at", "updated_at", "providers", "url"]
read_only_fields = ["id", "inserted_at", "updated_at"]
fields = [
"id",
"name",
"inserted_at",
"updated_at",
"providers",
"roles",
"url",
]
extra_kwargs = {
"id": {"read_only": True},
"inserted_at": {"read_only": True},
"updated_at": {"read_only": True},
"roles": {"read_only": True},
"url": {"read_only": True},
}
class ProviderGroupIncludedSerializer(RLSSerializer, BaseWriteSerializer):
class ProviderGroupIncludedSerializer(ProviderGroupSerializer):
class Meta:
model = ProviderGroup
fields = ["id", "name"]
class ProviderGroupUpdateSerializer(RLSSerializer, BaseWriteSerializer):
"""
Serializer for updating the ProviderGroup model.
Only allows the "name" field to be updated.
"""
class Meta:
model = ProviderGroup
fields = ["id", "name"]
class ProviderGroupMembershipUpdateSerializer(RLSSerializer, BaseWriteSerializer):
"""
Serializer for modifying provider group memberships
"""
provider_ids = serializers.ListField(
child=serializers.UUIDField(),
help_text="List of provider UUIDs to add to the group",
class ProviderGroupCreateSerializer(ProviderGroupSerializer):
providers = serializers.ResourceRelatedField(
queryset=Provider.objects.all(), many=True, required=False
)
roles = serializers.ResourceRelatedField(
queryset=Role.objects.all(), many=True, required=False
)
def validate(self, attrs):
tenant_id = self.context["tenant_id"]
provider_ids = attrs.get("provider_ids", [])
class Meta:
model = ProviderGroup
fields = [
"id",
"name",
"inserted_at",
"updated_at",
"providers",
"roles",
"url",
]
existing_provider_ids = set(
Provider.objects.filter(
id__in=provider_ids, tenant_id=tenant_id
).values_list("id", flat=True)
def create(self, validated_data):
providers = validated_data.pop("providers", [])
roles = validated_data.pop("roles", [])
tenant_id = self.context.get("tenant_id")
provider_group = ProviderGroup.objects.create(
tenant_id=tenant_id, **validated_data
)
provided_provider_ids = set(provider_ids)
missing_provider_ids = provided_provider_ids - existing_provider_ids
if missing_provider_ids:
raise serializers.ValidationError(
{
"provider_ids": f"The following provider IDs do not exist: {', '.join(str(id) for id in missing_provider_ids)}"
}
through_model_instances = [
ProviderGroupMembership(
provider_group=provider_group,
provider=provider,
tenant_id=tenant_id,
)
for provider in providers
]
ProviderGroupMembership.objects.bulk_create(through_model_instances)
return super().validate(attrs)
through_model_instances = [
RoleProviderGroupRelationship(
provider_group=provider_group,
role=role,
tenant_id=tenant_id,
)
for role in roles
]
RoleProviderGroupRelationship.objects.bulk_create(through_model_instances)
return provider_group
class ProviderGroupUpdateSerializer(ProviderGroupSerializer):
def update(self, instance, validated_data):
tenant_id = self.context.get("tenant_id")
if "providers" in validated_data:
providers = validated_data.pop("providers")
instance.providers.clear()
through_model_instances = [
ProviderGroupMembership(
provider_group=instance,
provider=provider,
tenant_id=tenant_id,
)
for provider in providers
]
ProviderGroupMembership.objects.bulk_create(through_model_instances)
if "roles" in validated_data:
roles = validated_data.pop("roles")
instance.roles.clear()
through_model_instances = [
RoleProviderGroupRelationship(
provider_group=instance,
role=role,
tenant_id=tenant_id,
)
for role in roles
]
RoleProviderGroupRelationship.objects.bulk_create(through_model_instances)
return super().update(instance, validated_data)
class ProviderResourceIdentifierSerializer(serializers.Serializer):
resource_type = serializers.CharField(source="type")
id = serializers.UUIDField()
class JSONAPIMeta:
resource_name = "provider-identifier"
def to_representation(self, instance):
"""
Ensure 'type' is used in the output instead of 'resource_type'.
"""
representation = super().to_representation(instance)
representation["type"] = representation.pop("resource_type", None)
return representation
def to_internal_value(self, data):
"""
Map 'type' back to 'resource_type' during input.
"""
data["resource_type"] = data.pop("type", None)
return super().to_internal_value(data)
class ProviderGroupMembershipSerializer(RLSSerializer, BaseWriteSerializer):
"""
Serializer for modifying a provider group's provider memberships.
"""
providers = serializers.ListField(
child=ProviderResourceIdentifierSerializer(),
help_text="List of resource identifier objects representing providers.",
)
def create(self, validated_data):
provider_ids = [item["id"] for item in validated_data["providers"]]
providers = Provider.objects.filter(id__in=provider_ids)
tenant_id = self.context.get("tenant_id")
new_relationships = [
ProviderGroupMembership(
provider_group=self.context.get("provider_group"),
provider=p,
tenant_id=tenant_id,
)
for p in providers
]
ProviderGroupMembership.objects.bulk_create(new_relationships)
return self.context.get("provider_group")
def update(self, instance, validated_data):
provider_ids = [item["id"] for item in validated_data["providers"]]
providers = Provider.objects.filter(id__in=provider_ids)
tenant_id = self.context.get("tenant_id")
instance.providers.clear()
new_relationships = [
ProviderGroupMembership(
provider_group=instance, provider=p, tenant_id=tenant_id
)
for p in providers
]
ProviderGroupMembership.objects.bulk_create(new_relationships)
return instance
class Meta:
model = ProviderGroupMembership
fields = ["id", "provider_ids"]
fields = ["id", "providers"]
# Providers
@@ -1034,6 +1235,8 @@ class InvitationSerializer(RLSSerializer):
Serializer for the Invitation model.
"""
roles = serializers.ResourceRelatedField(many=True, queryset=Role.objects.all())
class Meta:
model = Invitation
fields = [
@@ -1043,6 +1246,7 @@ class InvitationSerializer(RLSSerializer):
"email",
"state",
"token",
"roles",
"expires_at",
"inviter",
"url",
@@ -1050,6 +1254,8 @@ class InvitationSerializer(RLSSerializer):
class InvitationBaseWriteSerializer(BaseWriteSerializer):
roles = serializers.ResourceRelatedField(many=True, queryset=Role.objects.all())
def validate_email(self, value):
user = User.objects.filter(email=value).first()
tenant_id = self.context["tenant_id"]
@@ -1086,31 +1292,54 @@ class InvitationCreateSerializer(InvitationBaseWriteSerializer, RLSSerializer):
class Meta:
model = Invitation
fields = ["email", "expires_at", "state", "token", "inviter"]
fields = ["email", "expires_at", "state", "token", "inviter", "roles"]
extra_kwargs = {
"token": {"read_only": True},
"state": {"read_only": True},
"inviter": {"read_only": True},
"expires_at": {"required": False},
"roles": {"required": False},
}
def create(self, validated_data):
inviter = self.context.get("request").user
tenant_id = self.context.get("tenant_id")
validated_data["inviter"] = inviter
return super().create(validated_data)
roles = validated_data.pop("roles", [])
invitation = super().create(validated_data)
for role in roles:
InvitationRoleRelationship.objects.create(
role=role, invitation=invitation, tenant_id=tenant_id
)
return invitation
class InvitationUpdateSerializer(InvitationBaseWriteSerializer):
class Meta:
model = Invitation
fields = ["id", "email", "expires_at", "state", "token"]
fields = ["id", "email", "expires_at", "state", "token", "roles"]
extra_kwargs = {
"token": {"read_only": True},
"state": {"read_only": True},
"expires_at": {"required": False},
"email": {"required": False},
"roles": {"required": False},
}
def update(self, instance, validated_data):
tenant_id = self.context.get("tenant_id")
invitation = super().update(instance, validated_data)
if "roles" in validated_data:
roles = validated_data.pop("roles")
instance.roles.clear()
for role in roles:
InvitationRoleRelationship.objects.create(
role=role, invitation=invitation, tenant_id=tenant_id
)
return invitation
class InvitationAcceptSerializer(RLSSerializer):
"""Serializer for accepting an invitation."""
@@ -1122,6 +1351,205 @@ class InvitationAcceptSerializer(RLSSerializer):
fields = ["invitation_token"]
# Roles
class RoleSerializer(RLSSerializer, BaseWriteSerializer):
permission_state = serializers.SerializerMethodField()
users = serializers.ResourceRelatedField(
queryset=User.objects.all(), many=True, required=False
)
provider_groups = serializers.ResourceRelatedField(
queryset=ProviderGroup.objects.all(), many=True, required=False
)
def get_permission_state(self, obj) -> str:
return obj.permission_state
def validate(self, attrs):
if Role.objects.filter(name=attrs.get("name")).exists():
raise serializers.ValidationError(
{"name": "A role with this name already exists."}
)
if attrs.get("manage_providers"):
attrs["unlimited_visibility"] = True
# Prevent updates to the admin role
if getattr(self.instance, "name", None) == "admin":
raise serializers.ValidationError(
{"name": "The admin role cannot be updated."}
)
return super().validate(attrs)
class Meta:
model = Role
fields = [
"id",
"name",
"manage_users",
"manage_account",
"manage_billing",
"manage_providers",
"manage_integrations",
"manage_scans",
"permission_state",
"unlimited_visibility",
"inserted_at",
"updated_at",
"provider_groups",
"users",
"invitations",
"url",
]
extra_kwargs = {
"id": {"read_only": True},
"inserted_at": {"read_only": True},
"updated_at": {"read_only": True},
"url": {"read_only": True},
}
class RoleCreateSerializer(RoleSerializer):
provider_groups = serializers.ResourceRelatedField(
many=True, queryset=ProviderGroup.objects.all(), required=False
)
users = serializers.ResourceRelatedField(
many=True, queryset=User.objects.all(), required=False
)
def create(self, validated_data):
provider_groups = validated_data.pop("provider_groups", [])
users = validated_data.pop("users", [])
tenant_id = self.context.get("tenant_id")
role = Role.objects.create(tenant_id=tenant_id, **validated_data)
through_model_instances = [
RoleProviderGroupRelationship(
role=role,
provider_group=provider_group,
tenant_id=tenant_id,
)
for provider_group in provider_groups
]
RoleProviderGroupRelationship.objects.bulk_create(through_model_instances)
through_model_instances = [
UserRoleRelationship(
role=role,
user=user,
tenant_id=tenant_id,
)
for user in users
]
UserRoleRelationship.objects.bulk_create(through_model_instances)
return role
class RoleUpdateSerializer(RoleSerializer):
def update(self, instance, validated_data):
tenant_id = self.context.get("tenant_id")
if "provider_groups" in validated_data:
provider_groups = validated_data.pop("provider_groups")
instance.provider_groups.clear()
through_model_instances = [
RoleProviderGroupRelationship(
role=instance,
provider_group=provider_group,
tenant_id=tenant_id,
)
for provider_group in provider_groups
]
RoleProviderGroupRelationship.objects.bulk_create(through_model_instances)
if "users" in validated_data:
users = validated_data.pop("users")
instance.users.clear()
through_model_instances = [
UserRoleRelationship(
role=instance,
user=user,
tenant_id=tenant_id,
)
for user in users
]
UserRoleRelationship.objects.bulk_create(through_model_instances)
return super().update(instance, validated_data)
class ProviderGroupResourceIdentifierSerializer(serializers.Serializer):
resource_type = serializers.CharField(source="type")
id = serializers.UUIDField()
class JSONAPIMeta:
resource_name = "provider-group-identifier"
def to_representation(self, instance):
"""
Ensure 'type' is used in the output instead of 'resource_type'.
"""
representation = super().to_representation(instance)
representation["type"] = representation.pop("resource_type", None)
return representation
def to_internal_value(self, data):
"""
Map 'type' back to 'resource_type' during input.
"""
data["resource_type"] = data.pop("type", None)
return super().to_internal_value(data)
class RoleProviderGroupRelationshipSerializer(RLSSerializer, BaseWriteSerializer):
"""
Serializer for modifying a role's provider group relationships.
"""
provider_groups = serializers.ListField(
child=ProviderGroupResourceIdentifierSerializer(),
help_text="List of resource identifier objects representing provider groups.",
)
def create(self, validated_data):
provider_group_ids = [item["id"] for item in validated_data["provider_groups"]]
provider_groups = ProviderGroup.objects.filter(id__in=provider_group_ids)
tenant_id = self.context.get("tenant_id")
new_relationships = [
RoleProviderGroupRelationship(
role=self.context.get("role"), provider_group=pg, tenant_id=tenant_id
)
for pg in provider_groups
]
RoleProviderGroupRelationship.objects.bulk_create(new_relationships)
return self.context.get("role")
def update(self, instance, validated_data):
provider_group_ids = [item["id"] for item in validated_data["provider_groups"]]
provider_groups = ProviderGroup.objects.filter(id__in=provider_group_ids)
tenant_id = self.context.get("tenant_id")
instance.provider_groups.clear()
new_relationships = [
RoleProviderGroupRelationship(
role=instance, provider_group=pg, tenant_id=tenant_id
)
for pg in provider_groups
]
RoleProviderGroupRelationship.objects.bulk_create(new_relationships)
return instance
class Meta:
model = RoleProviderGroupRelationship
fields = ["id", "provider_groups"]
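For reference, this serializer consumes a list of resource identifier objects, each a `{"type": ..., "id": ...}` pair. A sketch of such a request body (the `data` envelope, the `provider-groups` type string, and the UUIDs are assumptions for illustration, not taken from the API spec):

```python
import json
from uuid import uuid4

# Hypothetical body for the role <-> provider_groups relationship endpoint;
# the UUIDs are random placeholders.
body = {
    "data": [
        {"type": "provider-groups", "id": str(uuid4())},
        {"type": "provider-groups", "id": str(uuid4())},
    ]
}

# The serializer's create()/update() reduce the identifiers to an id list
# before filtering ProviderGroup objects.
provider_group_ids = [item["id"] for item in body["data"]]

print(json.dumps(body, indent=2))
```

On update, the serializer clears existing relationships first, so the list is a full replacement rather than an append.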
# Compliance overview
@@ -1334,6 +1762,24 @@ class OverviewSeveritySerializer(serializers.Serializer):
return {"version": "v1"}
class OverviewServiceSerializer(serializers.Serializer):
id = serializers.CharField(source="service")
total = serializers.IntegerField()
_pass = serializers.IntegerField()
fail = serializers.IntegerField()
muted = serializers.IntegerField()
class JSONAPIMeta:
resource_name = "services-overview"
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.fields["pass"] = self.fields.pop("_pass")
def get_root_meta(self, _resource, _many):
return {"version": "v1"}
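`pass` is a reserved word in Python, so the serializer declares the counter as `_pass` and swaps the key on `__init__`. The same rename on a plain mapping:

```python
def expose_pass_field(fields: dict) -> dict:
    # "pass" cannot be used as a Python attribute name, so the serializer
    # stores the field as "_pass" and renames it for the serialized output.
    fields = dict(fields)
    fields["pass"] = fields.pop("_pass")
    return fields
```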
# Schedules

View File

@@ -3,16 +3,20 @@ from drf_spectacular.views import SpectacularRedocView
from rest_framework_nested import routers
from api.v1.views import (
ComplianceOverviewViewSet,
CustomTokenObtainView,
CustomTokenRefreshView,
FindingViewSet,
InvitationAcceptViewSet,
InvitationViewSet,
MembershipViewSet,
OverviewViewSet,
ProviderGroupViewSet,
ProviderGroupProvidersRelationshipView,
ProviderSecretViewSet,
InvitationViewSet,
InvitationAcceptViewSet,
RoleViewSet,
RoleProviderGroupRelationshipView,
UserRoleRelationshipView,
OverviewViewSet,
ComplianceOverviewViewSet,
ProviderViewSet,
ResourceViewSet,
ScanViewSet,
@@ -29,11 +33,12 @@ router = routers.DefaultRouter(trailing_slash=False)
router.register(r"users", UserViewSet, basename="user")
router.register(r"tenants", TenantViewSet, basename="tenant")
router.register(r"providers", ProviderViewSet, basename="provider")
router.register(r"provider_groups", ProviderGroupViewSet, basename="providergroup")
router.register(r"provider-groups", ProviderGroupViewSet, basename="providergroup")
router.register(r"scans", ScanViewSet, basename="scan")
router.register(r"tasks", TaskViewSet, basename="task")
router.register(r"resources", ResourceViewSet, basename="resource")
router.register(r"findings", FindingViewSet, basename="finding")
router.register(r"roles", RoleViewSet, basename="role")
router.register(
r"compliance-overviews", ComplianceOverviewViewSet, basename="complianceoverview"
)
@@ -80,6 +85,27 @@ urlpatterns = [
InvitationAcceptViewSet.as_view({"post": "accept"}),
name="invitation-accept",
),
path(
"roles/<uuid:pk>/relationships/provider_groups",
RoleProviderGroupRelationshipView.as_view(
{"post": "create", "patch": "partial_update", "delete": "destroy"}
),
name="role-provider-groups-relationship",
),
path(
"users/<uuid:pk>/relationships/roles",
UserRoleRelationshipView.as_view(
{"post": "create", "patch": "partial_update", "delete": "destroy"}
),
name="user-roles-relationship",
),
path(
"provider-groups/<uuid:pk>/relationships/providers",
ProviderGroupProvidersRelationshipView.as_view(
{"post": "create", "patch": "partial_update", "delete": "destroy"}
),
name="provider_group-providers-relationship",
),
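Each relationship endpoint above consumes a JSON:API resource-identifier list; for the role→provider_groups route the body would look roughly like the following (the `type` value is assumed from the view's `resource_name`, and the UUID is a hypothetical placeholder):

```python
# Illustrative JSON:API relationship payload; the view reads request.data
# as a list of {"id": ...} resource identifiers after JSON:API parsing.
payload = {
    "data": [
        {"type": "provider_groups", "id": "3fa85f64-5717-4562-b3fc-2c963f66afa6"},
    ]
}
```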
path("", include(router.urls)),
path("", include(tenants_router.urls)),
path("", include(users_router.urls)),

View File

@@ -16,6 +16,7 @@ from drf_spectacular.utils import (
extend_schema_view,
)
from drf_spectacular.views import SpectacularAPIView
from drf_spectacular_jsonapi.schemas.openapi import JsonApiAutoSchema
from rest_framework import permissions, status
from rest_framework.decorators import action
from rest_framework.exceptions import (
@@ -25,7 +26,8 @@ from rest_framework.exceptions import (
ValidationError,
)
from rest_framework.generics import GenericAPIView, get_object_or_404
from rest_framework_json_api.views import Response
from rest_framework.permissions import SAFE_METHODS
from rest_framework_json_api.views import RelationshipView, Response
from rest_framework_simplejwt.exceptions import InvalidToken, TokenError
from tasks.beat import schedule_provider_scan
from tasks.tasks import (
@@ -47,8 +49,10 @@ from api.filters import (
ProviderGroupFilter,
ProviderSecretFilter,
ResourceFilter,
RoleFilter,
ScanFilter,
ScanSummaryFilter,
ServiceOverviewFilter,
TaskFilter,
TenantFilter,
UserFilter,
@@ -63,6 +67,8 @@ from api.models import (
ProviderGroupMembership,
ProviderSecret,
Resource,
Role,
RoleProviderGroupRelationship,
Scan,
ScanSummary,
SeverityChoices,
@@ -70,8 +76,10 @@ from api.models import (
StatusChoices,
Task,
User,
UserRoleRelationship,
)
from api.pagination import ComplianceOverviewPagination
from api.rbac.permissions import Permissions, get_providers, get_role
from api.rls import Tenant
from api.utils import validate_invitation
from api.uuid_utils import datetime_to_uuid7
@@ -87,10 +95,12 @@ from api.v1.serializers import (
MembershipSerializer,
OverviewFindingSerializer,
OverviewProviderSerializer,
OverviewServiceSerializer,
OverviewSeveritySerializer,
ProviderCreateSerializer,
ProviderGroupMembershipUpdateSerializer,
ProviderGroupMembershipSerializer,
ProviderGroupSerializer,
ProviderGroupCreateSerializer,
ProviderGroupUpdateSerializer,
ProviderSecretCreateSerializer,
ProviderSecretSerializer,
@@ -98,6 +108,10 @@ from api.v1.serializers import (
ProviderSerializer,
ProviderUpdateSerializer,
ResourceSerializer,
RoleCreateSerializer,
RoleProviderGroupRelationshipSerializer,
RoleSerializer,
RoleUpdateSerializer,
ScanCreateSerializer,
ScanSerializer,
ScanUpdateSerializer,
@@ -107,6 +121,7 @@ from api.v1.serializers import (
TokenRefreshSerializer,
TokenSerializer,
UserCreateSerializer,
UserRoleRelationshipSerializer,
UserSerializer,
UserUpdateSerializer,
)
@@ -117,6 +132,11 @@ CACHE_DECORATOR = cache_control(
)
class RelationshipViewSchema(JsonApiAutoSchema):
def _resolve_path_parameters(self, _path_variables):
return []
@extend_schema(
tags=["Token"],
summary="Obtain a token",
@@ -172,7 +192,7 @@ class SchemaView(SpectacularAPIView):
def get(self, request, *args, **kwargs):
spectacular_settings.TITLE = "Prowler API"
spectacular_settings.VERSION = "1.0.1"
spectacular_settings.VERSION = "1.1.0"
spectacular_settings.DESCRIPTION = (
"Prowler API specification.\n\nThis file is auto-generated."
)
@@ -271,6 +291,19 @@ class UserViewSet(BaseUserViewset):
filterset_class = UserFilter
ordering = ["-date_joined"]
ordering_fields = ["name", "email", "company_name", "date_joined", "is_active"]
# RBAC required permissions
required_permissions = [Permissions.MANAGE_USERS]
def set_required_permissions(self):
"""
Set the required permissions based on the action.
"""
if self.action == "me":
# No permissions required for me request
self.required_permissions = []
else:
# Require permission for the rest of the requests
self.required_permissions = [Permissions.MANAGE_USERS]
def get_queryset(self):
# If called during schema generation, return an empty queryset
@@ -347,11 +380,125 @@ class UserViewSet(BaseUserViewset):
user=user, tenant=tenant, role=role
)
if invitation:
user_role = []
for role in invitation.roles.all():
user_role.append(
UserRoleRelationship.objects.using(MainRouter.admin_db).create(
user=user, role=role, tenant=invitation.tenant
)
)
invitation.state = Invitation.State.ACCEPTED
invitation.save(using=MainRouter.admin_db)
else:
role = Role.objects.using(MainRouter.admin_db).create(
name="admin",
tenant_id=tenant.id,
manage_users=True,
manage_account=True,
manage_billing=True,
manage_providers=True,
manage_integrations=True,
manage_scans=True,
unlimited_visibility=True,
)
UserRoleRelationship.objects.using(MainRouter.admin_db).create(
user=user,
role=role,
tenant_id=tenant.id,
)
return Response(data=UserSerializer(user).data, status=status.HTTP_201_CREATED)
@extend_schema_view(
create=extend_schema(
tags=["User"],
summary="Create a new user-roles relationship",
description="Add a new user-roles relationship to the system by providing the required user-roles details.",
responses={
204: OpenApiResponse(description="Relationship created successfully"),
400: OpenApiResponse(
description="Bad request (e.g., relationship already exists)"
),
},
),
partial_update=extend_schema(
tags=["User"],
summary="Partially update a user-roles relationship",
description="Update the user-roles relationship information without affecting other fields.",
responses={
204: OpenApiResponse(
response=None, description="Relationship updated successfully"
)
},
),
destroy=extend_schema(
tags=["User"],
summary="Delete a user-roles relationship",
description="Remove the user-roles relationship from the system by its ID.",
responses={
204: OpenApiResponse(
response=None, description="Relationship deleted successfully"
)
},
),
)
class UserRoleRelationshipView(RelationshipView, BaseRLSViewSet):
queryset = User.objects.all()
serializer_class = UserRoleRelationshipSerializer
resource_name = "roles"
http_method_names = ["post", "patch", "delete"]
schema = RelationshipViewSchema()
# RBAC required permissions
required_permissions = [Permissions.MANAGE_USERS]
def get_queryset(self):
return User.objects.all()
def create(self, request, *args, **kwargs):
user = self.get_object()
role_ids = [item["id"] for item in request.data]
existing_relationships = UserRoleRelationship.objects.filter(
user=user, role_id__in=role_ids
)
if existing_relationships.exists():
return Response(
{"detail": "One or more roles are already associated with the user."},
status=status.HTTP_400_BAD_REQUEST,
)
serializer = self.get_serializer(
data={"roles": request.data},
context={
"user": user,
"tenant_id": self.request.tenant_id,
"request": request,
},
)
serializer.is_valid(raise_exception=True)
serializer.save()
return Response(status=status.HTTP_204_NO_CONTENT)
def partial_update(self, request, *args, **kwargs):
user = self.get_object()
serializer = self.get_serializer(
instance=user,
data={"roles": request.data},
context={"tenant_id": self.request.tenant_id, "request": request},
)
serializer.is_valid(raise_exception=True)
serializer.save()
return Response(status=status.HTTP_204_NO_CONTENT)
def destroy(self, request, *args, **kwargs):
user = self.get_object()
user.roles.clear()
return Response(status=status.HTTP_204_NO_CONTENT)
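The duplicate guard in `create()` reduces to a set intersection between the requested role IDs and the already-linked ones — a minimal sketch, independent of the ORM (helper name is illustrative):

```python
def already_linked(requested: list, existing_ids: set) -> set:
    # Mirrors the view's check: the request is rejected if any requested
    # role ID is already associated with the user.
    requested_ids = {item["id"] for item in requested}
    return requested_ids & existing_ids
```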
@extend_schema_view(
list=extend_schema(
tags=["Tenant"],
@@ -389,6 +536,8 @@ class TenantViewSet(BaseTenantViewset):
search_fields = ["name"]
ordering = ["-inserted_at"]
ordering_fields = ["name", "inserted_at", "updated_at"]
# RBAC required permissions
required_permissions = [Permissions.MANAGE_ACCOUNT]
def get_queryset(self):
return Tenant.objects.all()
@@ -446,6 +595,8 @@ class MembershipViewSet(BaseTenantViewset):
"role",
"date_joined",
]
# RBAC required permissions
required_permissions = [Permissions.MANAGE_ACCOUNT]
def get_queryset(self):
user = self.request.user
@@ -479,6 +630,8 @@ class TenantMembersViewSet(BaseTenantViewset):
http_method_names = ["get", "delete"]
serializer_class = MembershipSerializer
queryset = Membership.objects.none()
# RBAC required permissions
required_permissions = [Permissions.MANAGE_ACCOUNT]
def get_queryset(self):
tenant = self.get_tenant()
@@ -562,66 +715,128 @@ class ProviderGroupViewSet(BaseRLSViewSet):
queryset = ProviderGroup.objects.all()
serializer_class = ProviderGroupSerializer
filterset_class = ProviderGroupFilter
http_method_names = ["get", "post", "patch", "put", "delete"]
http_method_names = ["get", "post", "patch", "delete"]
ordering = ["inserted_at"]
# RBAC required permissions
required_permissions = [Permissions.MANAGE_PROVIDERS]
def set_required_permissions(self):
"""
Set the required permissions based on the request method.
"""
if self.request.method in SAFE_METHODS:
# No permissions required for GET requests
self.required_permissions = []
else:
# Require permission for non-GET requests
self.required_permissions = [Permissions.MANAGE_PROVIDERS]
def get_queryset(self):
return ProviderGroup.objects.prefetch_related("providers")
user_roles = get_role(self.request.user)
# Check if any of the user's roles have UNLIMITED_VISIBILITY
if user_roles.unlimited_visibility:
# User has unlimited visibility, return all provider groups
return ProviderGroup.objects.prefetch_related("providers")
# Collect provider groups associated with the user's roles
return user_roles.provider_groups.all()
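The RBAC branch above is the pattern repeated across these view sets: unlimited visibility sees everything, otherwise visibility narrows to the role's provider groups. Sketched with plain objects (names are illustrative, not the project's API):

```python
from dataclasses import dataclass, field

@dataclass
class FakeRole:
    unlimited_visibility: bool
    provider_groups: list = field(default_factory=list)

def visible_provider_groups(role, all_groups):
    # Unlimited visibility short-circuits to the full set; otherwise
    # only the groups attached to the role are visible.
    if role.unlimited_visibility:
        return all_groups
    return role.provider_groups
```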
def get_serializer_class(self):
if self.action == "partial_update":
if self.action == "create":
return ProviderGroupCreateSerializer
elif self.action == "partial_update":
return ProviderGroupUpdateSerializer
elif self.action == "providers":
if hasattr(self, "response_serializer_class"):
return self.response_serializer_class
return ProviderGroupMembershipUpdateSerializer
return super().get_serializer_class()
@extend_schema(
tags=["Provider Group"],
summary="Add providers to a provider group",
description="Add one or more providers to an existing provider group.",
request=ProviderGroupMembershipUpdateSerializer,
responses={200: OpenApiResponse(response=ProviderGroupSerializer)},
)
@action(detail=True, methods=["put"], url_name="providers")
def providers(self, request, pk=None):
@extend_schema(tags=["Provider Group"])
@extend_schema_view(
create=extend_schema(
summary="Create a new provider_group-providers relationship",
description="Add a new provider_group-providers relationship to the system by providing the required provider_group-providers details.",
responses={
204: OpenApiResponse(description="Relationship created successfully"),
400: OpenApiResponse(
description="Bad request (e.g., relationship already exists)"
),
},
),
partial_update=extend_schema(
summary="Partially update a provider_group-providers relationship",
description="Update the provider_group-providers relationship information without affecting other fields.",
responses={
204: OpenApiResponse(
response=None, description="Relationship updated successfully"
)
},
),
destroy=extend_schema(
summary="Delete a provider_group-providers relationship",
description="Remove the provider_group-providers relationship from the system by its ID.",
responses={
204: OpenApiResponse(
response=None, description="Relationship deleted successfully"
)
},
),
)
class ProviderGroupProvidersRelationshipView(RelationshipView, BaseRLSViewSet):
queryset = ProviderGroup.objects.all()
serializer_class = ProviderGroupMembershipSerializer
resource_name = "providers"
http_method_names = ["post", "patch", "delete"]
schema = RelationshipViewSchema()
# RBAC required permissions
required_permissions = [Permissions.MANAGE_PROVIDERS]
def get_queryset(self):
return ProviderGroup.objects.all()
def create(self, request, *args, **kwargs):
provider_group = self.get_object()
# Validate input data
serializer = self.get_serializer_class()(
data=request.data,
context=self.get_serializer_context(),
provider_ids = [item["id"] for item in request.data]
existing_relationships = ProviderGroupMembership.objects.filter(
provider_group=provider_group, provider_id__in=provider_ids
)
if existing_relationships.exists():
return Response(
{
"detail": "One or more providers are already associated with the provider_group."
},
status=status.HTTP_400_BAD_REQUEST,
)
serializer = self.get_serializer(
data={"providers": request.data},
context={
"provider_group": provider_group,
"tenant_id": self.request.tenant_id,
"request": request,
},
)
serializer.is_valid(raise_exception=True)
serializer.save()
provider_ids = serializer.validated_data["provider_ids"]
return Response(status=status.HTTP_204_NO_CONTENT)
# Update memberships
ProviderGroupMembership.objects.filter(
provider_group=provider_group, tenant_id=request.tenant_id
).delete()
provider_group_memberships = [
ProviderGroupMembership(
tenant_id=self.request.tenant_id,
provider_group=provider_group,
provider_id=provider_id,
)
for provider_id in provider_ids
]
ProviderGroupMembership.objects.bulk_create(
provider_group_memberships, ignore_conflicts=True
def partial_update(self, request, *args, **kwargs):
provider_group = self.get_object()
serializer = self.get_serializer(
instance=provider_group,
data={"providers": request.data},
context={"tenant_id": self.request.tenant_id, "request": request},
)
serializer.is_valid(raise_exception=True)
serializer.save()
return Response(status=status.HTTP_204_NO_CONTENT)
# Return the updated provider group with providers
provider_group.refresh_from_db()
self.response_serializer_class = ProviderGroupSerializer
response_serializer = ProviderGroupSerializer(
provider_group, context=self.get_serializer_context()
)
return Response(data=response_serializer.data, status=status.HTTP_200_OK)
def destroy(self, request, *args, **kwargs):
provider_group = self.get_object()
provider_group.providers.clear()
return Response(status=status.HTTP_204_NO_CONTENT)
@extend_schema_view(
@@ -671,9 +886,28 @@ class ProviderViewSet(BaseRLSViewSet):
"inserted_at",
"updated_at",
]
# RBAC required permissions
required_permissions = [Permissions.MANAGE_PROVIDERS]
def set_required_permissions(self):
"""
Set the required permissions based on the request method.
"""
if self.request.method in SAFE_METHODS:
# No permissions required for GET requests
self.required_permissions = []
else:
# Require permission for non-GET requests
self.required_permissions = [Permissions.MANAGE_PROVIDERS]
def get_queryset(self):
return Provider.objects.all()
user_roles = get_role(self.request.user)
if user_roles.unlimited_visibility:
# User has unlimited visibility, return all providers
return Provider.objects.all()
# User lacks unlimited visibility; filter providers by the provider groups associated with the role
return get_providers(user_roles)
def get_serializer_class(self):
if self.action == "create":
@@ -793,9 +1027,28 @@ class ScanViewSet(BaseRLSViewSet):
"inserted_at",
"updated_at",
]
# RBAC required permissions
required_permissions = [Permissions.MANAGE_SCANS]
def set_required_permissions(self):
"""
Set the required permissions based on the request method.
"""
if self.request.method in SAFE_METHODS:
# GET requests follow provider visibility
self.required_permissions = [Permissions.MANAGE_PROVIDERS]
else:
# Require permission for non-GET requests
self.required_permissions = [Permissions.MANAGE_SCANS]
def get_queryset(self):
return Scan.objects.all()
user_roles = get_role(self.request.user)
if user_roles.unlimited_visibility:
# User has unlimited visibility, return all scans
return Scan.objects.all()
# User lacks unlimited visibility; filter scans by the providers visible through the role's provider groups
return Scan.objects.filter(provider__in=get_providers(user_roles))
def get_serializer_class(self):
if self.action == "create":
@@ -885,10 +1138,13 @@ class TaskViewSet(BaseRLSViewSet):
search_fields = ["name"]
ordering = ["-inserted_at"]
ordering_fields = ["inserted_at", "completed_at", "name", "state"]
# RBAC required permissions
required_permissions = []
def get_queryset(self):
return Task.objects.annotate(
name=F("task_runner_task__task_name"), state=F("task_runner_task__status")
name=F("task_runner_task__task_name"),
state=F("task_runner_task__status"),
)
def destroy(self, request, *args, pk=None, **kwargs):
@@ -950,11 +1206,19 @@ class ResourceViewSet(BaseRLSViewSet):
"inserted_at",
"updated_at",
]
# RBAC required permissions (implicit -> MANAGE_PROVIDERS enables unlimited visibility, or visibility is resolved through the provider group)
required_permissions = []
def get_queryset(self):
queryset = Resource.objects.all()
search_value = self.request.query_params.get("filter[search]", None)
user_roles = get_role(self.request.user)
if user_roles.unlimited_visibility:
# User has unlimited visibility, return all resources
queryset = Resource.objects.all()
else:
# User lacks unlimited visibility; filter resources by the providers visible through the role's provider groups
queryset = Resource.objects.filter(provider__in=get_providers(user_roles))
search_value = self.request.query_params.get("filter[search]", None)
if search_value:
# Django's ORM will build a LEFT JOIN and OUTER JOIN on the "through" table, resulting in duplicates
# The duplicates then require a `distinct` query
@@ -1025,11 +1289,8 @@ class FindingViewSet(BaseRLSViewSet):
"inserted_at",
"updated_at",
]
def inserted_at_to_uuidv7(self, inserted_at):
if inserted_at is None:
return None
return datetime_to_uuid7(inserted_at)
# RBAC required permissions (implicit -> MANAGE_PROVIDERS enables unlimited visibility, or visibility is resolved through the provider group)
required_permissions = []
def get_serializer_class(self):
if self.action == "findings_services_regions":
@@ -1038,9 +1299,17 @@ class FindingViewSet(BaseRLSViewSet):
return super().get_serializer_class()
def get_queryset(self):
queryset = Finding.objects.all()
search_value = self.request.query_params.get("filter[search]", None)
user_roles = get_role(self.request.user)
if user_roles.unlimited_visibility:
# User has unlimited visibility, return all findings
queryset = Finding.objects.all()
else:
# User lacks unlimited visibility; filter findings by the providers visible through the role's provider groups
queryset = Finding.objects.filter(
scan__provider__in=get_providers(user_roles)
)
search_value = self.request.query_params.get("filter[search]", None)
if search_value:
# Django's ORM will build a LEFT JOIN and OUTER JOIN on any "through" tables, resulting in duplicates
# The duplicates then require a `distinct` query
@@ -1068,6 +1337,11 @@ class FindingViewSet(BaseRLSViewSet):
return queryset
def inserted_at_to_uuidv7(self, inserted_at):
if inserted_at is None:
return None
return datetime_to_uuid7(inserted_at)
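UUIDv7 stores the Unix timestamp in milliseconds in its top 48 bits, which is what makes a `datetime`-to-UUID conversion useful for time-range filtering on UUIDv7 primary keys. A simplified floor conversion, assuming the RFC 9562 layout (not the project's `datetime_to_uuid7`, whose handling of the random bits may differ):

```python
import uuid
from datetime import datetime, timezone

def datetime_to_uuid7_floor(dt: datetime) -> uuid.UUID:
    # Top 48 bits: Unix timestamp in ms; then the version (7) nibble and the
    # RFC 4122 variant bits; all random bits zeroed, giving the smallest
    # UUIDv7 for that millisecond -- suitable as a range-filter lower bound.
    ms = int(dt.timestamp() * 1000)
    value = (ms << 80) | (0x7 << 76) | (0x2 << 62)
    return uuid.UUID(int=value)
```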
@action(detail=False, methods=["get"], url_name="findings_services_regions")
def findings_services_regions(self, request):
queryset = self.get_queryset()
@@ -1131,6 +1405,8 @@ class ProviderSecretViewSet(BaseRLSViewSet):
"inserted_at",
"updated_at",
]
# RBAC required permissions
required_permissions = [Permissions.MANAGE_PROVIDERS]
def get_queryset(self):
return ProviderSecret.objects.all()
@@ -1188,6 +1464,8 @@ class InvitationViewSet(BaseRLSViewSet):
"state",
"inviter",
]
# RBAC required permissions
required_permissions = [Permissions.MANAGE_ACCOUNT]
def get_queryset(self):
return Invitation.objects.all()
@@ -1275,6 +1553,13 @@ class InvitationAcceptViewSet(BaseRLSViewSet):
user=user,
tenant=invitation.tenant,
)
user_role = []
for role in invitation.roles.all():
user_role.append(
UserRoleRelationship.objects.using(MainRouter.admin_db).create(
user=user, role=role, tenant=invitation.tenant
)
)
invitation.state = Invitation.State.ACCEPTED
invitation.save(using=MainRouter.admin_db)
@@ -1283,6 +1568,154 @@ class InvitationAcceptViewSet(BaseRLSViewSet):
return Response(data=membership_serializer.data, status=status.HTTP_201_CREATED)
@extend_schema(tags=["Role"])
@extend_schema_view(
list=extend_schema(
tags=["Role"],
summary="List all roles",
description="Retrieve a list of all roles with options for filtering by various criteria.",
),
retrieve=extend_schema(
tags=["Role"],
summary="Retrieve data from a role",
description="Fetch detailed information about a specific role by its ID.",
),
create=extend_schema(
tags=["Role"],
summary="Create a new role",
description="Add a new role to the system by providing the required role details.",
),
partial_update=extend_schema(
tags=["Role"],
summary="Partially update a role",
description="Update certain fields of an existing role's information without affecting other fields.",
responses={200: RoleSerializer},
),
destroy=extend_schema(
tags=["Role"],
summary="Delete a role",
description="Remove a role from the system by its ID.",
),
)
class RoleViewSet(BaseRLSViewSet):
queryset = Role.objects.all()
serializer_class = RoleSerializer
filterset_class = RoleFilter
http_method_names = ["get", "post", "patch", "delete"]
ordering = ["inserted_at"]
# RBAC required permissions
required_permissions = [Permissions.MANAGE_ACCOUNT]
def get_queryset(self):
return Role.objects.all()
def get_serializer_class(self):
if self.action == "create":
return RoleCreateSerializer
elif self.action == "partial_update":
return RoleUpdateSerializer
return super().get_serializer_class()
def partial_update(self, request, *args, **kwargs):
user_role = get_role(request.user)
# If the user is editing their own role, the manage_account field is not editable
if user_role and kwargs["pk"] == str(user_role.id):
request.data["manage_account"] = str(user_role.manage_account).lower()
return super().partial_update(request, *args, **kwargs)
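The guard in `partial_update` pins `manage_account` to its current value whenever a user edits the role they themselves hold — equivalently, on plain data (illustrative helper, not project code):

```python
def protect_manage_account(payload: dict, edited_role_id: str, own_role) -> dict:
    # A user editing their own role cannot flip manage_account;
    # the field is forced back to its stored value (lowercased string,
    # matching the request's serialized boolean form).
    if own_role and edited_role_id == str(own_role["id"]):
        payload = dict(payload)
        payload["manage_account"] = str(own_role["manage_account"]).lower()
    return payload
```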
@extend_schema_view(
create=extend_schema(
tags=["Role"],
summary="Create a new role-provider_groups relationship",
description="Add a new role-provider_groups relationship to the system by providing the required role-provider_groups details.",
responses={
204: OpenApiResponse(description="Relationship created successfully"),
400: OpenApiResponse(
description="Bad request (e.g., relationship already exists)"
),
},
),
partial_update=extend_schema(
tags=["Role"],
summary="Partially update a role-provider_groups relationship",
description="Update the role-provider_groups relationship information without affecting other fields.",
responses={
204: OpenApiResponse(
response=None, description="Relationship updated successfully"
)
},
),
destroy=extend_schema(
tags=["Role"],
summary="Delete a role-provider_groups relationship",
description="Remove the role-provider_groups relationship from the system by its ID.",
responses={
204: OpenApiResponse(
response=None, description="Relationship deleted successfully"
)
},
),
)
class RoleProviderGroupRelationshipView(RelationshipView, BaseRLSViewSet):
queryset = Role.objects.all()
serializer_class = RoleProviderGroupRelationshipSerializer
resource_name = "provider_groups"
http_method_names = ["post", "patch", "delete"]
schema = RelationshipViewSchema()
# RBAC required permissions
required_permissions = [Permissions.MANAGE_ACCOUNT]
def get_queryset(self):
return Role.objects.all()
def create(self, request, *args, **kwargs):
role = self.get_object()
provider_group_ids = [item["id"] for item in request.data]
existing_relationships = RoleProviderGroupRelationship.objects.filter(
role=role, provider_group_id__in=provider_group_ids
)
if existing_relationships.exists():
return Response(
{
"detail": "One or more provider groups are already associated with the role."
},
status=status.HTTP_400_BAD_REQUEST,
)
serializer = self.get_serializer(
data={"provider_groups": request.data},
context={
"role": role,
"tenant_id": self.request.tenant_id,
"request": request,
},
)
serializer.is_valid(raise_exception=True)
serializer.save()
return Response(status=status.HTTP_204_NO_CONTENT)
def partial_update(self, request, *args, **kwargs):
role = self.get_object()
serializer = self.get_serializer(
instance=role,
data={"provider_groups": request.data},
context={"tenant_id": self.request.tenant_id, "request": request},
)
serializer.is_valid(raise_exception=True)
serializer.save()
return Response(status=status.HTTP_204_NO_CONTENT)
def destroy(self, request, *args, **kwargs):
role = self.get_object()
role.provider_groups.clear()
return Response(status=status.HTTP_204_NO_CONTENT)
@extend_schema_view(
list=extend_schema(
tags=["Compliance Overview"],
@@ -1317,12 +1750,32 @@ class ComplianceOverviewViewSet(BaseRLSViewSet):
search_fields = ["compliance_id"]
ordering = ["compliance_id"]
ordering_fields = ["inserted_at", "compliance_id", "framework", "region"]
# RBAC required permissions (implicit -> MANAGE_PROVIDERS enables unlimited visibility, or visibility is resolved through the provider group)
required_permissions = []
def get_queryset(self):
if self.action == "retrieve":
return ComplianceOverview.objects.all()
role = get_role(self.request.user)
unlimited_visibility = getattr(
role, Permissions.UNLIMITED_VISIBILITY.value, False
)
base_queryset = self.filter_queryset(ComplianceOverview.objects.all())
if self.action == "retrieve":
if unlimited_visibility:
# User has unlimited visibility, return all compliance overviews
return ComplianceOverview.objects.all()
providers = get_providers(role)
return ComplianceOverview.objects.filter(scan__provider__in=providers)
if unlimited_visibility:
base_queryset = self.filter_queryset(ComplianceOverview.objects.all())
else:
providers = Provider.objects.filter(
provider_groups__in=role.provider_groups.all()
).distinct()
base_queryset = self.filter_queryset(
ComplianceOverview.objects.filter(scan__provider__in=providers)
)
max_failed_ids = (
base_queryset.filter(compliance_id=OuterRef("compliance_id"))
@@ -1330,12 +1783,10 @@ class ComplianceOverviewViewSet(BaseRLSViewSet):
.values("id")[:1]
)
queryset = base_queryset.filter(id__in=Subquery(max_failed_ids)).order_by(
return base_queryset.filter(id__in=Subquery(max_failed_ids)).order_by(
"compliance_id"
)
return queryset
def get_serializer_class(self):
if self.action == "retrieve":
return ComplianceOverviewFullSerializer
@@ -1387,20 +1838,37 @@ class ComplianceOverviewViewSet(BaseRLSViewSet):
),
filters=True,
),
services=extend_schema(
summary="Get findings data by service",
description=(
"Retrieve an aggregated summary of findings grouped by service. The response includes the total count "
"of findings for each service, as long as there is at least one finding for that service. At least "
"one of the `inserted_at` filters must be provided."
),
filters=True,
),
)
@method_decorator(CACHE_DECORATOR, name="list")
class OverviewViewSet(BaseRLSViewSet):
queryset = ComplianceOverview.objects.all()
http_method_names = ["get"]
ordering = ["-id"]
# RBAC required permissions (implicit -> MANAGE_PROVIDERS enables unlimited visibility, or visibility is resolved through the provider group)
required_permissions = []
def get_queryset(self):
role = get_role(self.request.user)
providers = get_providers(role)
def _get_filtered_queryset(model):
if role.unlimited_visibility:
return model.objects.all()
return model.objects.filter(scan__provider__in=providers)
if self.action == "providers":
return Finding.objects.all()
elif self.action == "findings":
return ScanSummary.objects.all()
elif self.action == "findings_severity":
return ScanSummary.objects.all()
return _get_filtered_queryset(Finding)
elif self.action in ("findings", "findings_severity", "services"):
return _get_filtered_queryset(ScanSummary)
else:
return super().get_queryset()
@@ -1411,6 +1879,8 @@ class OverviewViewSet(BaseRLSViewSet):
return OverviewFindingSerializer
elif self.action == "findings_severity":
return OverviewSeveritySerializer
elif self.action == "services":
return OverviewServiceSerializer
return super().get_serializer_class()
def get_filterset_class(self):
@@ -1418,6 +1888,8 @@ class OverviewViewSet(BaseRLSViewSet):
return None
elif self.action in ["findings", "findings_severity"]:
return ScanSummaryFilter
elif self.action == "services":
return ServiceOverviewFilter
return None
@extend_schema(exclude=True)
@@ -1563,6 +2035,38 @@ class OverviewViewSet(BaseRLSViewSet):
serializer = OverviewSeveritySerializer(severity_data)
return Response(serializer.data, status=status.HTTP_200_OK)
@action(detail=False, methods=["get"], url_name="services")
def services(self, request):
queryset = self.get_queryset()
filtered_queryset = self.filter_queryset(queryset)
latest_scan_subquery = (
Scan.objects.filter(
state=StateChoices.COMPLETED, provider_id=OuterRef("scan__provider_id")
)
.order_by("-id")
.values("id")[:1]
)
annotated_queryset = filtered_queryset.annotate(
latest_scan_id=Subquery(latest_scan_subquery)
)
filtered_queryset = annotated_queryset.filter(scan_id=F("latest_scan_id"))
services_data = (
filtered_queryset.values("service")
.annotate(_pass=Sum("_pass"))
.annotate(fail=Sum("fail"))
.annotate(muted=Sum("muted"))
.annotate(total=Sum("total"))
.order_by("service")
)
serializer = OverviewServiceSerializer(services_data, many=True)
return Response(serializer.data, status=status.HTTP_200_OK)
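The ORM aggregation above (keep only rows from each provider's latest completed scan, then group by service and sum each counter) can be sketched without Django; the grouping-and-summing step on already-filtered rows looks like this (shapes follow the `values()` dicts used above):

```python
from collections import defaultdict

def aggregate_services(rows):
    # rows: dicts shaped like the ScanSummary values in the queryset,
    # with the pass counter stored under "_pass" (a reserved-word workaround).
    totals = defaultdict(lambda: {"pass": 0, "fail": 0, "muted": 0, "total": 0})
    for row in rows:
        agg = totals[row["service"]]
        agg["pass"] += row["_pass"]
        agg["fail"] += row["fail"]
        agg["muted"] += row["muted"]
        agg["total"] += row["total"]
    # Match the queryset's order_by("service")
    return dict(sorted(totals.items()))
```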
@extend_schema(tags=["Schedule"])
@extend_schema_view(
@@ -1578,6 +2082,8 @@ class ScheduleViewSet(BaseRLSViewSet):
# TODO: change to Schedule when implemented
queryset = Task.objects.none()
http_method_names = ["post"]
# RBAC required permissions
required_permissions = [Permissions.MANAGE_SCANS]
def get_queryset(self):
return super().get_queryset()

View File

@@ -35,10 +35,10 @@ class RLSTask(Task):
**options,
)
task_result_instance = TaskResult.objects.get(task_id=result.task_id)
from api.db_utils import tenant_transaction
from api.db_utils import rls_transaction
tenant_id = kwargs.get("tenant_id")
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
APITask.objects.create(
id=task_result_instance.task_id,
tenant_id=tenant_id,

View File

@@ -10,8 +10,8 @@ DATABASES = {
"default": {
"ENGINE": "psqlextra.backend",
"NAME": "prowler_db_test",
"USER": env("POSTGRES_USER", default="prowler"),
"PASSWORD": env("POSTGRES_PASSWORD", default="S3cret"),
"USER": env("POSTGRES_USER", default="prowler_admin"),
"PASSWORD": env("POSTGRES_PASSWORD", default="postgres"),
"HOST": env("POSTGRES_HOST", default="localhost"),
"PORT": env("POSTGRES_PORT", default="5432"),
},


@@ -1,35 +1,39 @@
import logging
from datetime import datetime, timedelta, timezone
from unittest.mock import patch
import pytest
from django.conf import settings
from datetime import datetime, timezone, timedelta
from django.db import connections as django_connections, connection as django_connection
from django.db import connection as django_connection
from django.db import connections as django_connections
from django.urls import reverse
from django_celery_results.models import TaskResult
from prowler.lib.check.models import Severity
from prowler.lib.outputs.finding import Status
from rest_framework import status
from rest_framework.test import APIClient
from api.db_utils import rls_transaction
from api.models import (
ComplianceOverview,
Finding,
)
from api.models import (
User,
Invitation,
Membership,
Provider,
ProviderGroup,
ProviderSecret,
Resource,
ResourceTag,
Role,
Scan,
ScanSummary,
StateChoices,
Task,
Membership,
ProviderSecret,
Invitation,
ComplianceOverview,
User,
UserRoleRelationship,
)
from api.rls import Tenant
from api.v1.serializers import TokenSerializer
from prowler.lib.check.models import Severity
from prowler.lib.outputs.finding import Status
API_JSON_CONTENT_TYPE = "application/vnd.api+json"
NO_TENANT_HTTP_STATUS = status.HTTP_401_UNAUTHORIZED
@@ -83,8 +87,148 @@ def create_test_user(django_db_setup, django_db_blocker):
return user
@pytest.fixture(scope="function")
def create_test_user_rbac(django_db_setup, django_db_blocker):
with django_db_blocker.unblock():
user = User.objects.create_user(
name="testing",
email="rbac@rbac.com",
password=TEST_PASSWORD,
)
tenant = Tenant.objects.create(
name="Tenant Test",
)
Membership.objects.create(
user=user,
tenant=tenant,
role=Membership.RoleChoices.OWNER,
)
Role.objects.create(
name="admin",
tenant_id=tenant.id,
manage_users=True,
manage_account=True,
manage_billing=True,
manage_providers=True,
manage_integrations=True,
manage_scans=True,
unlimited_visibility=True,
)
UserRoleRelationship.objects.create(
user=user,
role=Role.objects.get(name="admin"),
tenant_id=tenant.id,
)
return user
@pytest.fixture(scope="function")
def create_test_user_rbac_no_roles(django_db_setup, django_db_blocker):
with django_db_blocker.unblock():
user = User.objects.create_user(
name="testing",
email="rbac_noroles@rbac.com",
password=TEST_PASSWORD,
)
tenant = Tenant.objects.create(
name="Tenant Test",
)
Membership.objects.create(
user=user,
tenant=tenant,
role=Membership.RoleChoices.OWNER,
)
return user
@pytest.fixture(scope="function")
def create_test_user_rbac_limited(django_db_setup, django_db_blocker):
with django_db_blocker.unblock():
user = User.objects.create_user(
name="testing_limited",
email="rbac_limited@rbac.com",
password=TEST_PASSWORD,
)
tenant = Tenant.objects.create(
name="Tenant Test",
)
Membership.objects.create(
user=user,
tenant=tenant,
role=Membership.RoleChoices.OWNER,
)
Role.objects.create(
name="limited",
tenant_id=tenant.id,
manage_users=False,
manage_account=False,
manage_billing=False,
manage_providers=False,
manage_integrations=False,
manage_scans=False,
unlimited_visibility=False,
)
UserRoleRelationship.objects.create(
user=user,
role=Role.objects.get(name="limited"),
tenant_id=tenant.id,
)
return user
@pytest.fixture
def authenticated_client(create_test_user, tenants_fixture, client):
def authenticated_client_rbac(create_test_user_rbac, tenants_fixture, client):
client.user = create_test_user_rbac
serializer = TokenSerializer(
data={"type": "tokens", "email": "rbac@rbac.com", "password": TEST_PASSWORD}
)
serializer.is_valid()
access_token = serializer.validated_data["access"]
client.defaults["HTTP_AUTHORIZATION"] = f"Bearer {access_token}"
return client
@pytest.fixture
def authenticated_client_rbac_noroles(
create_test_user_rbac_no_roles, tenants_fixture, client
):
client.user = create_test_user_rbac_no_roles
serializer = TokenSerializer(
data={
"type": "tokens",
"email": "rbac_noroles@rbac.com",
"password": TEST_PASSWORD,
}
)
serializer.is_valid()
access_token = serializer.validated_data["access"]
client.defaults["HTTP_AUTHORIZATION"] = f"Bearer {access_token}"
return client
@pytest.fixture
def authenticated_client_no_permissions_rbac(
create_test_user_rbac_limited, tenants_fixture, client
):
client.user = create_test_user_rbac_limited
serializer = TokenSerializer(
data={
"type": "tokens",
"email": "rbac_limited@rbac.com",
"password": TEST_PASSWORD,
}
)
serializer.is_valid()
access_token = serializer.validated_data["access"]
client.defaults["HTTP_AUTHORIZATION"] = f"Bearer {access_token}"
return client
@pytest.fixture
def authenticated_client(
create_test_user, tenants_fixture, set_user_admin_roles_fixture, client
):
client.user = create_test_user
serializer = TokenSerializer(
data={"type": "tokens", "email": TEST_USER, "password": TEST_PASSWORD}
@@ -104,6 +248,7 @@ def authenticated_api_client(create_test_user, tenants_fixture):
serializer.is_valid()
access_token = serializer.validated_data["access"]
client.defaults["HTTP_AUTHORIZATION"] = f"Bearer {access_token}"
return client
@@ -128,9 +273,33 @@ def tenants_fixture(create_test_user):
tenant3 = Tenant.objects.create(
name="Tenant Three",
)
return tenant1, tenant2, tenant3
@pytest.fixture
def set_user_admin_roles_fixture(create_test_user, tenants_fixture):
user = create_test_user
for tenant in tenants_fixture[:2]:
with rls_transaction(str(tenant.id)):
role = Role.objects.create(
name="admin",
tenant_id=tenant.id,
manage_users=True,
manage_account=True,
manage_billing=True,
manage_providers=True,
manage_integrations=True,
manage_scans=True,
unlimited_visibility=True,
)
UserRoleRelationship.objects.create(
user=user,
role=role,
tenant_id=tenant.id,
)
@pytest.fixture
def invitations_fixture(create_test_user, tenants_fixture):
user = create_test_user
@@ -153,6 +322,20 @@ def invitations_fixture(create_test_user, tenants_fixture):
return valid_invitation, expired_invitation
@pytest.fixture
def users_fixture(django_user_model):
user1 = User.objects.create_user(
name="user1", email="test_unit0@prowler.com", password="S3cret"
)
user2 = User.objects.create_user(
name="user2", email="test_unit1@prowler.com", password="S3cret"
)
user3 = User.objects.create_user(
name="user3", email="test_unit2@prowler.com", password="S3cret"
)
return user1, user2, user3
@pytest.fixture
def providers_fixture(tenants_fixture):
tenant, *_ = tenants_fixture
@@ -210,6 +393,57 @@ def provider_groups_fixture(tenants_fixture):
return pgroup1, pgroup2, pgroup3
@pytest.fixture
def roles_fixture(tenants_fixture):
tenant, *_ = tenants_fixture
role1 = Role.objects.create(
name="Role One",
tenant_id=tenant.id,
manage_users=True,
manage_account=True,
manage_billing=True,
manage_providers=True,
manage_integrations=False,
manage_scans=True,
unlimited_visibility=False,
)
role2 = Role.objects.create(
name="Role Two",
tenant_id=tenant.id,
manage_users=False,
manage_account=False,
manage_billing=False,
manage_providers=True,
manage_integrations=True,
manage_scans=True,
unlimited_visibility=True,
)
role3 = Role.objects.create(
name="Role Three",
tenant_id=tenant.id,
manage_users=True,
manage_account=True,
manage_billing=True,
manage_providers=True,
manage_integrations=True,
manage_scans=True,
unlimited_visibility=True,
)
role4 = Role.objects.create(
name="Role Four",
tenant_id=tenant.id,
manage_users=False,
manage_account=False,
manage_billing=False,
manage_providers=False,
manage_integrations=False,
manage_scans=False,
unlimited_visibility=False,
)
return role1, role2, role3, role4
@pytest.fixture
def provider_secret_fixture(providers_fixture):
return tuple(
@@ -537,10 +771,107 @@ def get_api_tokens(
data=json_body,
format="vnd.api+json",
)
return response.json()["data"]["attributes"]["access"], response.json()["data"][
"attributes"
]["refresh"]
return (
response.json()["data"]["attributes"]["access"],
response.json()["data"]["attributes"]["refresh"],
)
@pytest.fixture
def scan_summaries_fixture(tenants_fixture, providers_fixture):
tenant = tenants_fixture[0]
provider = providers_fixture[0]
scan = Scan.objects.create(
name="overview scan",
provider=provider,
trigger=Scan.TriggerChoices.MANUAL,
state=StateChoices.COMPLETED,
tenant=tenant,
)
ScanSummary.objects.create(
tenant=tenant,
check_id="check1",
service="service1",
severity="high",
region="region1",
_pass=1,
fail=0,
muted=0,
total=1,
new=1,
changed=0,
unchanged=0,
fail_new=0,
fail_changed=0,
pass_new=1,
pass_changed=0,
muted_new=0,
muted_changed=0,
scan=scan,
)
ScanSummary.objects.create(
tenant=tenant,
check_id="check1",
service="service1",
severity="high",
region="region2",
_pass=0,
fail=1,
muted=1,
total=2,
new=2,
changed=0,
unchanged=0,
fail_new=1,
fail_changed=0,
pass_new=0,
pass_changed=0,
muted_new=1,
muted_changed=0,
scan=scan,
)
ScanSummary.objects.create(
tenant=tenant,
check_id="check2",
service="service2",
severity="critical",
region="region1",
_pass=1,
fail=0,
muted=0,
total=1,
new=1,
changed=0,
unchanged=0,
fail_new=0,
fail_changed=0,
pass_new=1,
pass_changed=0,
muted_new=0,
muted_changed=0,
scan=scan,
)
def get_authorization_header(access_token: str) -> dict:
return {"Authorization": f"Bearer {access_token}"}
def pytest_collection_modifyitems(items):
"""Ensure test_rbac.py is executed first."""
items.sort(key=lambda item: 0 if "test_rbac.py" in item.nodeid else 1)
def pytest_configure(config):
# Apply the mock before the test session starts. This is necessary to avoid admin error when running the
# 0004_rbac_missing_admin_roles migration
patch("api.db_router.MainRouter.admin_db", new="default").start()
def pytest_unconfigure(config):
# Stop all patches after the test session ends. This is necessary to avoid admin error when running the
# 0004_rbac_missing_admin_roles migration
patch.stopall()


@@ -2,7 +2,7 @@ from celery.utils.log import get_task_logger
from django.db import transaction
from api.db_router import MainRouter
from api.db_utils import batch_delete, tenant_transaction
from api.db_utils import batch_delete, rls_transaction
from api.models import Finding, Provider, Resource, Scan, ScanSummary, Tenant
logger = get_task_logger(__name__)
@@ -66,7 +66,7 @@ def delete_tenant(pk: str):
deletion_summary = {}
for provider in Provider.objects.using(MainRouter.admin_db).filter(tenant_id=pk):
with tenant_transaction(pk):
with rls_transaction(pk):
summary = delete_provider(provider.id)
deletion_summary.update(summary)


@@ -11,7 +11,7 @@ from api.compliance import (
PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE,
generate_scan_compliance,
)
from api.db_utils import tenant_transaction
from api.db_utils import rls_transaction
from api.models import (
ComplianceOverview,
Finding,
@@ -69,7 +69,7 @@ def _store_resources(
- tuple[str, str]: A tuple containing the resource UID and region.
"""
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
resource_instance, created = Resource.objects.get_or_create(
tenant_id=tenant_id,
provider=provider_instance,
@@ -86,7 +86,7 @@ def _store_resources(
resource_instance.service = finding.service_name
resource_instance.type = finding.resource_type
resource_instance.save()
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
tags = [
ResourceTag.objects.get_or_create(
tenant_id=tenant_id, key=key, value=value
@@ -122,7 +122,7 @@ def perform_prowler_scan(
unique_resources = set()
start_time = time.time()
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
provider_instance = Provider.objects.get(pk=provider_id)
scan_instance = Scan.objects.get(pk=scan_id)
scan_instance.state = StateChoices.EXECUTING
@@ -130,7 +130,7 @@ def perform_prowler_scan(
scan_instance.save()
try:
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
try:
prowler_provider = initialize_prowler_provider(provider_instance)
provider_instance.connected = True
@@ -156,7 +156,7 @@ def perform_prowler_scan(
for finding in findings:
for attempt in range(CELERY_DEADLOCK_ATTEMPTS):
try:
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
# Process resource
resource_uid = finding.resource_uid
if resource_uid not in resource_cache:
@@ -188,7 +188,7 @@ def perform_prowler_scan(
resource_instance.type = finding.resource_type
updated_fields.append("type")
if updated_fields:
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
resource_instance.save(update_fields=updated_fields)
except (OperationalError, IntegrityError) as db_err:
if attempt < CELERY_DEADLOCK_ATTEMPTS - 1:
@@ -203,7 +203,7 @@ def perform_prowler_scan(
# Update tags
tags = []
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
for key, value in finding.resource_tags.items():
tag_key = (key, value)
if tag_key not in tag_cache:
@@ -219,7 +219,7 @@ def perform_prowler_scan(
unique_resources.add((resource_instance.uid, resource_instance.region))
# Process finding
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
finding_uid = finding.uid
if finding_uid not in last_status_cache:
most_recent_finding = (
@@ -267,7 +267,7 @@ def perform_prowler_scan(
region_dict[finding.check_id] = finding.status.value
# Update scan progress
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
scan_instance.progress = progress
scan_instance.save()
@@ -279,7 +279,7 @@ def perform_prowler_scan(
scan_instance.state = StateChoices.FAILED
finally:
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
scan_instance.duration = time.time() - start_time
scan_instance.completed_at = datetime.now(tz=timezone.utc)
scan_instance.unique_resource_count = len(unique_resources)
@@ -330,7 +330,7 @@ def perform_prowler_scan(
total_requirements=compliance["total_requirements"],
)
)
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
ComplianceOverview.objects.bulk_create(compliance_overview_objects)
if exception is not None:
@@ -368,7 +368,7 @@ def aggregate_findings(tenant_id: str, scan_id: str):
- muted_new: Muted findings with a delta of 'new'.
- muted_changed: Muted findings with a delta of 'changed'.
"""
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
findings = Finding.objects.filter(scan_id=scan_id)
aggregation = findings.values(
@@ -464,7 +464,7 @@ def aggregate_findings(tenant_id: str, scan_id: str):
),
)
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
scan_aggregations = {
ScanSummary(
tenant_id=tenant_id,


@@ -7,7 +7,7 @@ from tasks.jobs.connection import check_provider_connection
from tasks.jobs.deletion import delete_provider, delete_tenant
from tasks.jobs.scan import aggregate_findings, perform_prowler_scan
from api.db_utils import tenant_transaction
from api.db_utils import rls_transaction
from api.decorators import set_tenant
from api.models import Provider, Scan
@@ -99,7 +99,7 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
"""
task_id = self.request.id
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
provider_instance = Provider.objects.get(pk=provider_id)
periodic_task_instance = PeriodicTask.objects.get(
name=f"scan-perform-scheduled-{provider_id}"


@@ -1,5 +1,3 @@
from unittest.mock import patch
import pytest
from django.core.exceptions import ObjectDoesNotExist
from tasks.jobs.deletion import delete_provider, delete_tenant
@@ -24,7 +22,6 @@ class TestDeleteProvider:
delete_provider(non_existent_pk)
@patch("api.db_router.MainRouter.admin_db", new="default")
@pytest.mark.django_db
class TestDeleteTenant:
def test_delete_tenant_success(self, tenants_fixture, providers_fixture):


@@ -1,3 +1,4 @@
import uuid
from unittest.mock import MagicMock, patch
import pytest
@@ -26,7 +27,7 @@ class TestPerformScan:
providers_fixture,
):
with (
patch("api.db_utils.tenant_transaction"),
patch("api.db_utils.rls_transaction"),
patch(
"tasks.jobs.scan.initialize_prowler_provider"
) as mock_initialize_prowler_provider,
@@ -165,10 +166,10 @@ class TestPerformScan:
"tasks.jobs.scan.initialize_prowler_provider",
side_effect=Exception("Connection error"),
)
@patch("api.db_utils.tenant_transaction")
@patch("api.db_utils.rls_transaction")
def test_perform_prowler_scan_no_connection(
self,
mock_tenant_transaction,
mock_rls_transaction,
mock_initialize_prowler_provider,
mock_prowler_scan_class,
tenants_fixture,
@@ -205,14 +206,14 @@ class TestPerformScan:
@patch("api.models.ResourceTag.objects.get_or_create")
@patch("api.models.Resource.objects.get_or_create")
@patch("api.db_utils.tenant_transaction")
@patch("api.db_utils.rls_transaction")
def test_store_resources_new_resource(
self,
mock_tenant_transaction,
mock_rls_transaction,
mock_get_or_create_resource,
mock_get_or_create_tag,
):
tenant_id = "tenant123"
tenant_id = uuid.uuid4()
provider_instance = MagicMock()
provider_instance.id = "provider456"
@@ -253,14 +254,14 @@ class TestPerformScan:
@patch("api.models.ResourceTag.objects.get_or_create")
@patch("api.models.Resource.objects.get_or_create")
@patch("api.db_utils.tenant_transaction")
@patch("api.db_utils.rls_transaction")
def test_store_resources_existing_resource(
self,
mock_tenant_transaction,
mock_rls_transaction,
mock_get_or_create_resource,
mock_get_or_create_tag,
):
tenant_id = "tenant123"
tenant_id = uuid.uuid4()
provider_instance = MagicMock()
provider_instance.id = "provider456"
@@ -310,14 +311,14 @@ class TestPerformScan:
@patch("api.models.ResourceTag.objects.get_or_create")
@patch("api.models.Resource.objects.get_or_create")
@patch("api.db_utils.tenant_transaction")
@patch("api.db_utils.rls_transaction")
def test_store_resources_with_tags(
self,
mock_tenant_transaction,
mock_rls_transaction,
mock_get_or_create_resource,
mock_get_or_create_tag,
):
tenant_id = "tenant123"
tenant_id = uuid.uuid4()
provider_instance = MagicMock()
provider_instance.id = "provider456"


@@ -73,6 +73,8 @@ To use each one you need to pass the proper flag to the execution. Prowler for A
- **Subscription scope permissions**: Required to launch the checks against your resources, mandatory to launch the tool. It is required to add the following RBAC builtin roles per subscription to the entity that is going to be assumed by the tool:
- `Reader`
- `ProwlerRole` (custom role defined in [prowler-azure-custom-role](https://github.com/prowler-cloud/prowler/blob/master/permissions/prowler-azure-custom-role.json))
???+ note
Note that the `assignableScopes` field in the JSON custom role file must be changed to the subscription or management group where the role will be assigned. The valid formats for the field are `/subscriptions/<subscription-id>` or `/providers/Microsoft.Management/managementGroups/<management-group-id>`.
To assign the permissions, follow the instructions in the [Microsoft Entra ID permissions](../tutorials/azure/create-prowler-service-principal.md#assigning-the-proper-permissions) section and the [Azure subscriptions permissions](../tutorials/azure/subscriptions.md#assigning-proper-permissions) section, respectively.
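As a quick illustration (not part of Prowler), the two accepted `assignableScopes` formats can be sanity-checked with a small regex; the function name and pattern here are assumptions made for this sketch:

```python
import re

# Illustrative only: accept the two scope formats described in the note above,
# /subscriptions/<subscription-id> or
# /providers/Microsoft.Management/managementGroups/<management-group-id>.
SCOPE_RE = re.compile(
    r"^/(?:subscriptions|providers/Microsoft\.Management/managementGroups)/[^/]+$"
)

def is_valid_assignable_scope(scope: str) -> bool:
    """Return True if the scope string matches either accepted format."""
    return bool(SCOPE_RE.fullmatch(scope))
```

A scope of just `/` (the default in the custom role file) would be rejected by this check, which is exactly the situation the note warns about.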


@@ -42,6 +42,7 @@ Mutelist:
Resources:
- "user-1" # Will mute user-1 in check iam_user_hardware_mfa_enabled
- "user-2" # Will mute user-2 in check iam_user_hardware_mfa_enabled
Description: "Findings related to the check iam_user_hardware_mfa_enabled will be muted for us-east-1 region and user-1, user-2 resources"
"ec2_*":
Regions:
- "*"
@@ -140,6 +141,9 @@ Mutelist:
| `resource` | The resource identifier. Use `*` to apply the mutelist to all resources. | `ANDed` |
| `tag` | The tag value. | `ORed` |
### Description
This optional field can be used to add context or hints explaining why a Mutelist rule exists.
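As an illustration of how a consumer of the mutelist might surface this optional field, here is a minimal sketch (not Prowler's implementation; the `Mutelist`/`Accounts`/`Checks` nesting is assumed from the YAML examples above):

```python
def rule_description(mutelist: dict, account: str, check: str) -> str:
    """Return a mutelist rule's optional Description, or '' when absent."""
    checks = (
        mutelist.get("Mutelist", {})
        .get("Accounts", {})
        .get(account, {})
        .get("Checks", {})
    )
    return checks.get(check, {}).get("Description", "")

# Hypothetical mutelist mirroring the structure of the examples above.
example_mutelist = {
    "Mutelist": {
        "Accounts": {
            "*": {
                "Checks": {
                    "iam_user_hardware_mfa_enabled": {
                        "Regions": ["us-east-1"],
                        "Resources": ["user-1", "user-2"],
                        "Description": "Muted during MFA rollout",
                    }
                }
            }
        }
    }
}
```

Because the field is optional, a safe default (the empty string here) keeps consumers working with rules written before the field existed.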
## How to Use the Mutelist
@@ -171,6 +175,7 @@ If you want to mute failed findings only in specific regions, create a file with
- "ap-southeast-2"
Resources:
- "*"
Description: "Description related with the muted findings for the check"
### Default Mutelist
For the AWS Provider, Prowler is executed with a default AWS Mutelist with the AWS Resources that should be muted such as all resources created by AWS Control Tower when setting up a landing zone that can be found in [AWS Documentation](https://docs.aws.amazon.com/controltower/latest/userguide/shared-account-resources.html).


@@ -95,6 +95,7 @@ Resources:
- 'servicecatalog:List*'
- 'ssm:GetDocument'
- 'ssm-incidents:List*'
- 'states:ListTagsForResource'
- 'support:Describe*'
- 'tag:GetTagKeys'
- 'wellarchitected:List*'


@@ -45,6 +45,7 @@
"servicecatalog:List*",
"ssm:GetDocument",
"ssm-incidents:List*",
"states:ListTagsForResource",
"support:Describe*",
"tag:GetTagKeys",
"wellarchitected:List*"


@@ -3,7 +3,7 @@
"roleName": "ProwlerRole",
"description": "Role used for checks that require read-only access to Azure resources and are not covered by the Reader role.",
"assignableScopes": [
"/"
"/{'subscriptions', 'providers/Microsoft.Management/managementGroups'}/{Your Subscription or Management Group ID}"
],
"permissions": [
{

poetry.lock (generated)

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand.
[[package]]
name = "about-time"
@@ -775,17 +775,17 @@ files = [
[[package]]
name = "boto3"
version = "1.35.78"
version = "1.35.81"
description = "The AWS SDK for Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "boto3-1.35.78-py3-none-any.whl", hash = "sha256:5ef7166fe5060637b92af8dc152cd7acecf96b3fc9c5456706a886cadb534391"},
{file = "boto3-1.35.78.tar.gz", hash = "sha256:fc8001519c8842e766ad3793bde3fbd0bb39e821a582fc12cf67876b8f3cf7f1"},
{file = "boto3-1.35.81-py3-none-any.whl", hash = "sha256:742941b2424c0223d2d94a08c3485462fa7c58d816b62ca80f08e555243acee1"},
{file = "boto3-1.35.81.tar.gz", hash = "sha256:d2e95fa06f095b8e0c545dd678c6269d253809b2997c30f5ce8a956c410b4e86"},
]
[package.dependencies]
botocore = ">=1.35.78,<1.36.0"
botocore = ">=1.35.81,<1.36.0"
jmespath = ">=0.7.1,<2.0.0"
s3transfer = ">=0.10.0,<0.11.0"
@@ -794,13 +794,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
[[package]]
name = "botocore"
version = "1.35.79"
version = "1.35.81"
description = "Low-level, data-driven core of boto 3."
optional = false
python-versions = ">=3.8"
files = [
{file = "botocore-1.35.79-py3-none-any.whl", hash = "sha256:e6b10bb9a357e3f5ca2e60f6dd15a85d311b9a476eb21b3c0c2a3b364a2897c8"},
{file = "botocore-1.35.79.tar.gz", hash = "sha256:245bfdda1b1508539ddd1819c67a8a2cc81780adf0715d3de418d64c4247f346"},
{file = "botocore-1.35.81-py3-none-any.whl", hash = "sha256:a7b13bbd959bf2d6f38f681676aab408be01974c46802ab997617b51399239f7"},
{file = "botocore-1.35.81.tar.gz", hash = "sha256:564c2478e50179e0b766e6a87e5e0cdd35e1bc37eb375c1cf15511f5dd13600d"},
]
[package.dependencies]
@@ -2583,13 +2583,13 @@ dev = ["click", "codecov", "mkdocs-gen-files", "mkdocs-git-authors-plugin", "mkd
[[package]]
name = "mkdocs-material"
version = "9.5.48"
version = "9.5.49"
description = "Documentation that simply works"
optional = false
python-versions = ">=3.8"
files = [
{file = "mkdocs_material-9.5.48-py3-none-any.whl", hash = "sha256:b695c998f4b939ce748adbc0d3bff73fa886a670ece948cf27818fa115dc16f8"},
{file = "mkdocs_material-9.5.48.tar.gz", hash = "sha256:a582531e8b34f4c7ed38c29d5c44763053832cf2a32f7409567e0c74749a47db"},
{file = "mkdocs_material-9.5.49-py3-none-any.whl", hash = "sha256:c3c2d8176b18198435d3a3e119011922f3e11424074645c24019c2dcf08a360e"},
{file = "mkdocs_material-9.5.49.tar.gz", hash = "sha256:3671bb282b4f53a1c72e08adbe04d2481a98f85fed392530051f80ff94a9621d"},
]
[package.dependencies]
@@ -5199,4 +5199,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<3.13"
content-hash = "e00da6013a01923ac8e79017e7fdb221e09a3dcf581ad8d74e39550be64cc2f3"
content-hash = "aff89cb4b5b66d79efcf4e1f16fdc76e4a1019e4196cd8961299d479b99cf7fa"


@@ -10,6 +10,7 @@ Mutelist:
- "*"
Resources:
- "aws-controltower-NotificationForwarder"
Description: "Checks from AWS Lambda functions muted by default"
"cloudformation_stack*":
Regions:
- "*"


@@ -14,6 +14,7 @@ Mutelist:
Resources:
- "user-1" # Will ignore user-1 in check iam_user_hardware_mfa_enabled
- "user-2" # Will ignore user-2 in check iam_user_hardware_mfa_enabled
Description: "Check iam_user_hardware_mfa_enabled muted for region us-east-1 and resources user-1, user-2"
"ec2_*":
Regions:
- "*"


@@ -15,6 +15,7 @@ Mutelist:
Resources:
- "sqlserver1" # Will ignore sqlserver1 in check sqlserver_tde_encryption_enabled located in westeurope
- "sqlserver2" # Will ignore sqlserver2 in check sqlserver_tde_encryption_enabled located in westeurope
Description: "Findings related to the check sqlserver_tde_encryption_enabled are muted for westeurope region and sqlserver1, sqlserver2 resources"
"defender_*":
Regions:
- "*"


@@ -15,6 +15,7 @@ Mutelist:
Resources:
- "instance1" # Will ignore instance1 in check compute_instance_public_ip located in europe-southwest1
- "instance2" # Will ignore instance2 in check compute_instance_public_ip located in europe-southwest1
Description: "Findings related to the check compute_instance_public_ip will be muted for europe-southwest1 region and instance1, instance2 resources"
"iam_*":
Regions:
- "*"


@@ -15,6 +15,7 @@ Mutelist:
Resources:
- "prowler-pod1" # Will ignore prowler-pod1 in check core_minimize_allowPrivilegeEscalation_containers located in namespace1
- "prowler-pod2" # Will ignore prowler-pod2 in check core_minimize_allowPrivilegeEscalation_containers located in namespace1
Description: "Findings related to the check core_minimize_allowPrivilegeEscalation_containers will be muted for namespace1 region and prowler-pod1, prowler-pod2 resources"
"kubelet_*":
Regions:
- "*"


@@ -15,6 +15,7 @@ mutelist_schema = Schema(
Optional("Resources"): list,
Optional("Tags"): list,
},
Optional("Description"): str,
}
}
}


@@ -106,6 +106,7 @@ class Mutelist(ABC):
- 'i-123456789'
Tags:
- 'Name=AdminInstance | Environment=Prod'
Description: 'Field to describe why the findings associated with these values are muted'
```
The check `ec2_instance_detailed_monitoring_enabled` will be muted for all accounts and regions and for the resource_id 'i-123456789' with at least one of the tags 'Name=AdminInstance' or 'Environment=Prod'.
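The ORed tag semantics described above can be sketched as follows (illustrative only, not Prowler's actual matcher, which may use regex-based matching):

```python
def tags_match(rule_tags: str, resource_tags: dict) -> bool:
    """True if the resource carries at least one tag listed in the rule.

    A rule entry like 'Name=AdminInstance | Environment=Prod' is ORed:
    any single matching key=value pair on the resource is enough to match.
    """
    alternatives = {t.strip() for t in rule_tags.split("|")}
    resource_pairs = {f"{key}={value}" for key, value in resource_tags.items()}
    return bool(alternatives & resource_pairs)
```

For the example above, an instance tagged only `Name=AdminInstance` matches, as does one tagged only `Environment=Prod`; an instance with neither tag does not.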


@@ -908,6 +908,7 @@
"ap-southeast-2",
"ap-southeast-3",
"ap-southeast-4",
"ap-southeast-5",
"ca-central-1",
"ca-west-1",
"eu-central-1",
@@ -1260,6 +1261,15 @@
"aws-us-gov": []
}
},
"bcm-pricing-calculator": {
"regions": {
"aws": [
"us-east-1"
],
"aws-cn": [],
"aws-us-gov": []
}
},
"bedrock": {
"regions": {
"aws": [
@@ -7345,6 +7355,31 @@
]
}
},
"networkflowmonitor": {
"regions": {
"aws": [
"ap-northeast-1",
"ap-northeast-2",
"ap-northeast-3",
"ap-south-1",
"ap-southeast-1",
"ap-southeast-2",
"ca-central-1",
"eu-central-1",
"eu-north-1",
"eu-west-1",
"eu-west-2",
"eu-west-3",
"sa-east-1",
"us-east-1",
"us-east-2",
"us-west-1",
"us-west-2"
],
"aws-cn": [],
"aws-us-gov": []
}
},
"networkmanager": {
"regions": {
"aws": [
@@ -7442,6 +7477,15 @@
"aws-us-gov": []
}
},
"notificationscontacts": {
"regions": {
"aws": [
"us-east-1"
],
"aws-cn": [],
"aws-us-gov": []
}
},
"oam": {
"regions": {
"aws": [
@@ -7486,6 +7530,23 @@
]
}
},
"observabilityadmin": {
"regions": {
"aws": [
"ap-northeast-1",
"ap-southeast-1",
"ap-southeast-2",
"eu-central-1",
"eu-north-1",
"eu-west-1",
"us-east-1",
"us-east-2",
"us-west-2"
],
"aws-cn": [],
"aws-us-gov": []
}
},
"omics": {
"regions": {
"aws": [
@@ -11100,10 +11161,12 @@
"ap-southeast-3",
"ap-southeast-4",
"ca-central-1",
"ca-west-1",
"eu-central-1",
"eu-central-2",
"eu-north-1",
"eu-south-1",
"eu-south-2",
"eu-west-1",
"eu-west-2",
"eu-west-3",


@@ -8,14 +8,14 @@ class backup_recovery_point_encrypted(Check):
for recovery_point in backup_client.recovery_points:
report = Check_Report_AWS(self.metadata())
report.region = recovery_point.backup_vault_region
report.resource_id = recovery_point.backup_vault_name
report.resource_id = recovery_point.id
report.resource_arn = recovery_point.arn
report.resource_tags = recovery_point.tags
report.status = "FAIL"
report.status_extended = f"Backup Recovery Point {recovery_point.arn} for Backup Vault {recovery_point.backup_vault_name} is not encrypted at rest."
report.status_extended = f"Backup Recovery Point {recovery_point.id} for Backup Vault {recovery_point.backup_vault_name} is not encrypted at rest."
if recovery_point.encrypted:
report.status = "PASS"
report.status_extended = f"Backup Recovery Point {recovery_point.arn} for Backup Vault {recovery_point.backup_vault_name} is encrypted at rest."
report.status_extended = f"Backup Recovery Point {recovery_point.id} for Backup Vault {recovery_point.backup_vault_name} is encrypted at rest."
findings.append(report)


@@ -195,6 +195,7 @@ class Backup(AWSService):
self.recovery_points.append(
RecoveryPoint(
arn=arn,
id=arn.split(":")[-1],
backup_vault_name=backup_vault.name,
encrypted=recovery_point.get(
"IsEncrypted", False
@@ -246,6 +247,7 @@ class BackupReportPlan(BaseModel):
class RecoveryPoint(BaseModel):
arn: str
id: str
backup_vault_name: str
encrypted: bool
backup_vault_region: str


@@ -0,0 +1,52 @@
from prowler.lib.logger import logger
from prowler.providers.aws.services.cloudtrail.cloudtrail_client import (
cloudtrail_client,
)
from prowler.providers.aws.services.s3.s3_client import s3_client
def fixer(resource_id: str, region: str) -> bool:
"""
Modify the public access settings of the S3 bucket associated with the CloudTrail trail to ensure the bucket is not publicly accessible.
Specifically, this fixer configures the S3 bucket's public access block settings to block all public access.
Requires the s3:PutBucketPublicAccessBlock permission.
Permissions:
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": "s3:PutBucketPublicAccessBlock",
"Resource": "*"
}
]
}
Args:
resource_id (str): The CloudTrail name.
region (str): AWS region where the CloudTrail and S3 bucket exist.
Returns:
bool: True if the operation is successful (public access block configured), False otherwise.
"""
try:
regional_client = s3_client.regional_clients[region]
for trail in cloudtrail_client.trails.values():
if trail.name == resource_id:
trail_bucket = trail.s3_bucket
regional_client.put_public_access_block(
Bucket=trail_bucket,
PublicAccessBlockConfiguration={
"BlockPublicAcls": True,
"IgnorePublicAcls": True,
"BlockPublicPolicy": True,
"RestrictPublicBuckets": True,
},
)
except Exception as error:
logger.error(
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
return False
else:
return True


@@ -48,7 +48,7 @@ class cloudtrail_multi_region_enabled_logging_management_events(Check):
report.resource_id = trail.name
report.resource_arn = trail.arn
report.resource_tags = trail.tags
report.region = trail.home_region
report.region = region
report.status = "PASS"
if trail.is_multiregion:
report.status_extended = f"Trail {trail.name} from home region {trail.home_region} is multi-region, is logging, and has management events enabled."


@@ -0,0 +1,40 @@
from prowler.lib.logger import logger
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
def fixer(resource_id: str, region: str) -> bool:
"""
Modify the attributes of an EC2 AMI to remove public access.
Specifically, this fixer removes the 'all' value from the 'LaunchPermission' attribute
to prevent the AMI from being publicly accessible.
Requires the ec2:ModifyImageAttribute permission.
Permissions:
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": "ec2:ModifyImageAttribute",
"Resource": "*"
}
]
}
Args:
resource_id (str): The ID of the EC2 AMI to make private.
region (str): AWS region where the AMI exists.
Returns:
bool: True if the operation is successful (the AMI is no longer publicly accessible), False otherwise.
"""
try:
regional_client = ec2_client.regional_clients[region]
regional_client.modify_image_attribute(
ImageId=resource_id,
LaunchPermission={"Remove": [{"Group": "all"}]},
)
except Exception as error:
logger.error(
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
return False
else:
return True
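The remediation boils down to a single ModifyImageAttribute call that removes the "all" launch-permission group. A minimal sketch, verified against a stubbed client (`make_ami_private` is an illustrative name, not prowler's):

```python
from unittest.mock import MagicMock

def make_ami_private(regional_client, image_id: str) -> None:
    # Removing the "all" group from LaunchPermission is what revokes
    # public launch access on an AMI.
    regional_client.modify_image_attribute(
        ImageId=image_id,
        LaunchPermission={"Remove": [{"Group": "all"}]},
    )

client = MagicMock()
make_ami_private(client, "ami-0123456789abcdef0")
client.modify_image_attribute.assert_called_once_with(
    ImageId="ami-0123456789abcdef0",
    LaunchPermission={"Remove": [{"Group": "all"}]},
)
```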

View File

@@ -0,0 +1,51 @@
from prowler.lib.logger import logger
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
def fixer(resource_id: str, region: str) -> bool:
"""
Revokes any ingress rule allowing Cassandra ports (7000, 7001, 7199, 9042, 9160) from any address (0.0.0.0/0)
for the EC2 instance's security groups.
This fixer will only be triggered if the check identifies Cassandra ports open to the Internet.
Requires the ec2:RevokeSecurityGroupIngress permission.
Permissions:
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": "ec2:RevokeSecurityGroupIngress",
"Resource": "*"
}
]
}
Args:
resource_id (str): The EC2 instance ID.
region (str): The AWS region where the EC2 instance exists.
Returns:
bool: True if the operation is successful (ingress rule revoked), False otherwise.
"""
try:
regional_client = ec2_client.regional_clients[region]
check_ports = [7000, 7001, 7199, 9042, 9160]
for instance in ec2_client.instances:
if instance.id == resource_id:
for sg in ec2_client.security_groups.values():
if sg.id in instance.security_groups:
for ingress_rule in sg.ingress_rules:
if check_security_group(
ingress_rule, "tcp", check_ports, any_address=True
):
regional_client.revoke_security_group_ingress(
GroupId=sg.id,
IpPermissions=[ingress_rule],
)
except Exception as error:
logger.error(
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
return False
else:
return True
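This fixer (and the near-identical port fixers that follow) delegates the matching to prowler's `check_security_group` helper. A simplified, hypothetical sketch of the predicate such a helper presumably applies to a boto3-shaped ingress rule — not the real implementation from `prowler.providers.aws.services.ec2.lib.security_groups`:

```python
def rule_is_open_to_world(ingress_rule: dict, protocol: str, ports: list) -> bool:
    """Sketch: True if the rule allows any of `ports` over `protocol`
    from 0.0.0.0/0 (or ::/0)."""
    if ingress_rule.get("IpProtocol") not in (protocol, "-1"):
        return False
    open_to_any = any(
        r.get("CidrIp") == "0.0.0.0/0" for r in ingress_rule.get("IpRanges", [])
    ) or any(
        r.get("CidrIpv6") == "::/0" for r in ingress_rule.get("Ipv6Ranges", [])
    )
    if not open_to_any:
        return False
    if ingress_rule.get("IpProtocol") == "-1":
        return True  # "all traffic" rules match regardless of port
    from_port = ingress_rule.get("FromPort")
    to_port = ingress_rule.get("ToPort")
    if from_port is None or to_port is None:
        return False
    return any(from_port <= port <= to_port for port in ports)

rule = {
    "IpProtocol": "tcp",
    "FromPort": 9042,
    "ToPort": 9042,
    "IpRanges": [{"CidrIp": "0.0.0.0/0"}],
}
assert rule_is_open_to_world(rule, "tcp", [7000, 7001, 7199, 9042, 9160])
```

Passing the matched `ingress_rule` back as `IpPermissions` revokes only the offending rule, leaving the security group's other rules intact.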

View File

@@ -0,0 +1,51 @@
from prowler.lib.logger import logger
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
def fixer(resource_id: str, region: str) -> bool:
"""
Revokes any ingress rule allowing Elasticsearch and Kibana ports (9200, 9300, 5601)
from any address (0.0.0.0/0) for the EC2 instance's security groups.
This fixer will only be triggered if the check identifies those ports open to the Internet.
Requires the ec2:RevokeSecurityGroupIngress permission.
Permissions:
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": "ec2:RevokeSecurityGroupIngress",
"Resource": "*"
}
]
}
Args:
resource_id (str): The EC2 instance ID.
region (str): The AWS region where the EC2 instance exists.
Returns:
bool: True if the operation is successful (ingress rule revoked), False otherwise.
"""
try:
regional_client = ec2_client.regional_clients[region]
check_ports = [9200, 9300, 5601]
for instance in ec2_client.instances:
if instance.id == resource_id:
for sg in ec2_client.security_groups.values():
if sg.id in instance.security_groups:
for ingress_rule in sg.ingress_rules:
if check_security_group(
ingress_rule, "tcp", check_ports, any_address=True
):
regional_client.revoke_security_group_ingress(
GroupId=sg.id,
IpPermissions=[ingress_rule],
)
except Exception as error:
logger.error(
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
return False
else:
return True

View File

@@ -0,0 +1,51 @@
from prowler.lib.logger import logger
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
def fixer(resource_id: str, region: str) -> bool:
"""
Revokes any ingress rule allowing FTP ports (20, 21) from any address (0.0.0.0/0)
for the EC2 instance's security groups.
This fixer will only be triggered if the check identifies FTP ports open to the Internet.
Requires the ec2:RevokeSecurityGroupIngress permission.
Permissions:
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": "ec2:RevokeSecurityGroupIngress",
"Resource": "*"
}
]
}
Args:
resource_id (str): The EC2 instance ID.
region (str): The AWS region where the EC2 instance exists.
Returns:
bool: True if the operation is successful (ingress rule revoked), False otherwise.
"""
try:
regional_client = ec2_client.regional_clients[region]
check_ports = [20, 21]
for instance in ec2_client.instances:
if instance.id == resource_id:
for sg in ec2_client.security_groups.values():
if sg.id in instance.security_groups:
for ingress_rule in sg.ingress_rules:
if check_security_group(
ingress_rule, "tcp", check_ports, any_address=True
):
regional_client.revoke_security_group_ingress(
GroupId=sg.id,
IpPermissions=[ingress_rule],
)
except Exception as error:
logger.error(
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
return False
else:
return True

View File

@@ -0,0 +1,51 @@
from prowler.lib.logger import logger
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
def fixer(resource_id: str, region: str) -> bool:
"""
Revokes any ingress rule allowing the Kafka port (9092) from any address (0.0.0.0/0)
for the EC2 instance's security groups.
This fixer will only be triggered if the check identifies the Kafka port open to the Internet.
Requires the ec2:RevokeSecurityGroupIngress permission.
Permissions:
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": "ec2:RevokeSecurityGroupIngress",
"Resource": "*"
}
]
}
Args:
resource_id (str): The EC2 instance ID.
region (str): The AWS region where the EC2 instance exists.
Returns:
bool: True if the operation is successful (ingress rule revoked), False otherwise.
"""
try:
regional_client = ec2_client.regional_clients[region]
check_ports = [9092]
for instance in ec2_client.instances:
if instance.id == resource_id:
for sg in ec2_client.security_groups.values():
if sg.id in instance.security_groups:
for ingress_rule in sg.ingress_rules:
if check_security_group(
ingress_rule, "tcp", check_ports, any_address=True
):
regional_client.revoke_security_group_ingress(
GroupId=sg.id,
IpPermissions=[ingress_rule],
)
except Exception as error:
logger.error(
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
return False
else:
return True

View File

@@ -0,0 +1,51 @@
from prowler.lib.logger import logger
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
def fixer(resource_id: str, region: str) -> bool:
"""
Revokes any ingress rule allowing Kerberos ports (88, 464, 749, 750) from any address (0.0.0.0/0)
for the EC2 instance's security groups.
This fixer will only be triggered if the check identifies Kerberos ports open to the Internet.
Requires the ec2:RevokeSecurityGroupIngress permission.
Permissions:
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": "ec2:RevokeSecurityGroupIngress",
"Resource": "*"
}
]
}
Args:
resource_id (str): The EC2 instance ID.
region (str): The AWS region where the EC2 instance exists.
Returns:
bool: True if the operation is successful (ingress rule revoked), False otherwise.
"""
try:
regional_client = ec2_client.regional_clients[region]
check_ports = [88, 464, 749, 750]
for instance in ec2_client.instances:
if instance.id == resource_id:
for sg in ec2_client.security_groups.values():
if sg.id in instance.security_groups:
for ingress_rule in sg.ingress_rules:
if check_security_group(
ingress_rule, "tcp", check_ports, any_address=True
):
regional_client.revoke_security_group_ingress(
GroupId=sg.id,
IpPermissions=[ingress_rule],
)
except Exception as error:
logger.error(
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
return False
else:
return True

View File

@@ -0,0 +1,51 @@
from prowler.lib.logger import logger
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
def fixer(resource_id: str, region: str) -> bool:
"""
Revokes any ingress rule allowing LDAP ports (389, 636) from any address (0.0.0.0/0)
for the EC2 instance's security groups.
This fixer will only be triggered if the check identifies LDAP ports open to the Internet.
Requires the ec2:RevokeSecurityGroupIngress permission.
Permissions:
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": "ec2:RevokeSecurityGroupIngress",
"Resource": "*"
}
]
}
Args:
resource_id (str): The EC2 instance ID.
region (str): The AWS region where the EC2 instance exists.
Returns:
bool: True if the operation is successful (ingress rule revoked), False otherwise.
"""
try:
regional_client = ec2_client.regional_clients[region]
check_ports = [389, 636]
for instance in ec2_client.instances:
if instance.id == resource_id:
for sg in ec2_client.security_groups.values():
if sg.id in instance.security_groups:
for ingress_rule in sg.ingress_rules:
if check_security_group(
ingress_rule, "tcp", check_ports, any_address=True
):
regional_client.revoke_security_group_ingress(
GroupId=sg.id,
IpPermissions=[ingress_rule],
)
except Exception as error:
logger.error(
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
return False
else:
return True

View File

@@ -0,0 +1,51 @@
from prowler.lib.logger import logger
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
def fixer(resource_id: str, region: str) -> bool:
"""
Revokes any ingress rule allowing the Memcached port (11211) from any address (0.0.0.0/0)
for the EC2 instance's security groups.
This fixer will only be triggered if the check identifies the Memcached port open to the Internet.
Requires the ec2:RevokeSecurityGroupIngress permission.
Permissions:
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": "ec2:RevokeSecurityGroupIngress",
"Resource": "*"
}
]
}
Args:
resource_id (str): The EC2 instance ID.
region (str): The AWS region where the EC2 instance exists.
Returns:
bool: True if the operation is successful (ingress rule revoked), False otherwise.
"""
try:
regional_client = ec2_client.regional_clients[region]
check_ports = [11211]
for instance in ec2_client.instances:
if instance.id == resource_id:
for sg in ec2_client.security_groups.values():
if sg.id in instance.security_groups:
for ingress_rule in sg.ingress_rules:
if check_security_group(
ingress_rule, "tcp", check_ports, any_address=True
):
regional_client.revoke_security_group_ingress(
GroupId=sg.id,
IpPermissions=[ingress_rule],
)
except Exception as error:
logger.error(
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
return False
else:
return True

View File

@@ -0,0 +1,51 @@
from prowler.lib.logger import logger
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
def fixer(resource_id: str, region: str) -> bool:
"""
Revokes any ingress rule allowing MongoDB ports (27017, 27018) from any address (0.0.0.0/0)
for the EC2 instance's security groups.
This fixer will only be triggered if the check identifies MongoDB ports open to the Internet.
Requires the ec2:RevokeSecurityGroupIngress permission.
Permissions:
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": "ec2:RevokeSecurityGroupIngress",
"Resource": "*"
}
]
}
Args:
resource_id (str): The EC2 instance ID.
region (str): The AWS region where the EC2 instance exists.
Returns:
bool: True if the operation is successful (ingress rule revoked), False otherwise.
"""
try:
regional_client = ec2_client.regional_clients[region]
check_ports = [27017, 27018]
for instance in ec2_client.instances:
if instance.id == resource_id:
for sg in ec2_client.security_groups.values():
if sg.id in instance.security_groups:
for ingress_rule in sg.ingress_rules:
if check_security_group(
ingress_rule, "tcp", check_ports, any_address=True
):
regional_client.revoke_security_group_ingress(
GroupId=sg.id,
IpPermissions=[ingress_rule],
)
except Exception as error:
logger.error(
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
return False
else:
return True

View File

@@ -0,0 +1,51 @@
from prowler.lib.logger import logger
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
def fixer(resource_id: str, region: str) -> bool:
"""
Revokes any ingress rule allowing the MySQL port (3306) from any address (0.0.0.0/0)
for the EC2 instance's security groups.
This fixer will only be triggered if the check identifies the MySQL port open to the Internet.
Requires the ec2:RevokeSecurityGroupIngress permission.
Permissions:
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": "ec2:RevokeSecurityGroupIngress",
"Resource": "*"
}
]
}
Args:
resource_id (str): The EC2 instance ID.
region (str): The AWS region where the EC2 instance exists.
Returns:
bool: True if the operation is successful (ingress rule revoked), False otherwise.
"""
try:
regional_client = ec2_client.regional_clients[region]
check_ports = [3306]
for instance in ec2_client.instances:
if instance.id == resource_id:
for sg in ec2_client.security_groups.values():
if sg.id in instance.security_groups:
for ingress_rule in sg.ingress_rules:
if check_security_group(
ingress_rule, "tcp", check_ports, any_address=True
):
regional_client.revoke_security_group_ingress(
GroupId=sg.id,
IpPermissions=[ingress_rule],
)
except Exception as error:
logger.error(
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
return False
else:
return True

View File

@@ -0,0 +1,51 @@
from prowler.lib.logger import logger
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
def fixer(resource_id: str, region: str) -> bool:
"""
Revokes any ingress rule allowing Oracle ports (1521, 2483, 2484) from any address (0.0.0.0/0)
for the EC2 instance's security groups.
This fixer will only be triggered if the check identifies Oracle ports open to the Internet.
Requires the ec2:RevokeSecurityGroupIngress permission.
Permissions:
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": "ec2:RevokeSecurityGroupIngress",
"Resource": "*"
}
]
}
Args:
resource_id (str): The EC2 instance ID.
region (str): The AWS region where the EC2 instance exists.
Returns:
bool: True if the operation is successful (ingress rule revoked), False otherwise.
"""
try:
regional_client = ec2_client.regional_clients[region]
check_ports = [1521, 2483, 2484]
for instance in ec2_client.instances:
if instance.id == resource_id:
for sg in ec2_client.security_groups.values():
if sg.id in instance.security_groups:
for ingress_rule in sg.ingress_rules:
if check_security_group(
ingress_rule, "tcp", check_ports, any_address=True
):
regional_client.revoke_security_group_ingress(
GroupId=sg.id,
IpPermissions=[ingress_rule],
)
except Exception as error:
logger.error(
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
return False
else:
return True

View File

@@ -0,0 +1,51 @@
from prowler.lib.logger import logger
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
def fixer(resource_id: str, region: str) -> bool:
"""
Revokes any ingress rule allowing the PostgreSQL port (5432) from any address (0.0.0.0/0)
for the EC2 instance's security groups.
This fixer will only be triggered if the check identifies the PostgreSQL port open to the Internet.
Requires the ec2:RevokeSecurityGroupIngress permission.
Permissions:
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": "ec2:RevokeSecurityGroupIngress",
"Resource": "*"
}
]
}
Args:
resource_id (str): The EC2 instance ID.
region (str): The AWS region where the EC2 instance exists.
Returns:
bool: True if the operation is successful (ingress rule revoked), False otherwise.
"""
try:
regional_client = ec2_client.regional_clients[region]
check_ports = [5432]
for instance in ec2_client.instances:
if instance.id == resource_id:
for sg in ec2_client.security_groups.values():
if sg.id in instance.security_groups:
for ingress_rule in sg.ingress_rules:
if check_security_group(
ingress_rule, "tcp", check_ports, any_address=True
):
regional_client.revoke_security_group_ingress(
GroupId=sg.id,
IpPermissions=[ingress_rule],
)
except Exception as error:
logger.error(
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
return False
else:
return True

View File

@@ -0,0 +1,51 @@
from prowler.lib.logger import logger
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
def fixer(resource_id: str, region: str) -> bool:
"""
Revokes any ingress rule allowing the RDP port (3389) from any address (0.0.0.0/0)
for the EC2 instance's security groups.
This fixer will only be triggered if the check identifies the RDP port open to the Internet.
Requires the ec2:RevokeSecurityGroupIngress permission.
Permissions:
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": "ec2:RevokeSecurityGroupIngress",
"Resource": "*"
}
]
}
Args:
resource_id (str): The EC2 instance ID.
region (str): The AWS region where the EC2 instance exists.
Returns:
bool: True if the operation is successful (ingress rule revoked), False otherwise.
"""
try:
regional_client = ec2_client.regional_clients[region]
check_ports = [3389]
for instance in ec2_client.instances:
if instance.id == resource_id:
for sg in ec2_client.security_groups.values():
if sg.id in instance.security_groups:
for ingress_rule in sg.ingress_rules:
if check_security_group(
ingress_rule, "tcp", check_ports, any_address=True
):
regional_client.revoke_security_group_ingress(
GroupId=sg.id,
IpPermissions=[ingress_rule],
)
except Exception as error:
logger.error(
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
return False
else:
return True

View File

@@ -0,0 +1,51 @@
from prowler.lib.logger import logger
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
def fixer(resource_id: str, region: str) -> bool:
"""
Revokes any ingress rule allowing the Redis port (6379) from any address (0.0.0.0/0)
for the EC2 instance's security groups.
This fixer will only be triggered if the check identifies the Redis port open to the Internet.
Requires the ec2:RevokeSecurityGroupIngress permission.
Permissions:
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": "ec2:RevokeSecurityGroupIngress",
"Resource": "*"
}
]
}
Args:
resource_id (str): The EC2 instance ID.
region (str): The AWS region where the EC2 instance exists.
Returns:
bool: True if the operation is successful (ingress rule revoked), False otherwise.
"""
try:
regional_client = ec2_client.regional_clients[region]
check_ports = [6379]
for instance in ec2_client.instances:
if instance.id == resource_id:
for sg in ec2_client.security_groups.values():
if sg.id in instance.security_groups:
for ingress_rule in sg.ingress_rules:
if check_security_group(
ingress_rule, "tcp", check_ports, any_address=True
):
regional_client.revoke_security_group_ingress(
GroupId=sg.id,
IpPermissions=[ingress_rule],
)
except Exception as error:
logger.error(
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
return False
else:
return True

View File

@@ -0,0 +1,51 @@
from prowler.lib.logger import logger
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
def fixer(resource_id: str, region: str) -> bool:
"""
Revokes any ingress rule allowing SQLServer ports (1433, 1434) from any address (0.0.0.0/0)
for the EC2 instance's security groups.
This fixer will only be triggered if the check identifies SQLServer ports open to the Internet.
Requires the ec2:RevokeSecurityGroupIngress permission.
Permissions:
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": "ec2:RevokeSecurityGroupIngress",
"Resource": "*"
}
]
}
Args:
resource_id (str): The EC2 instance ID.
region (str): The AWS region where the EC2 instance exists.
Returns:
bool: True if the operation is successful (ingress rule revoked), False otherwise.
"""
try:
regional_client = ec2_client.regional_clients[region]
check_ports = [1433, 1434]
for instance in ec2_client.instances:
if instance.id == resource_id:
for sg in ec2_client.security_groups.values():
if sg.id in instance.security_groups:
for ingress_rule in sg.ingress_rules:
if check_security_group(
ingress_rule, "tcp", check_ports, any_address=True
):
regional_client.revoke_security_group_ingress(
GroupId=sg.id,
IpPermissions=[ingress_rule],
)
except Exception as error:
logger.error(
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
return False
else:
return True

View File

@@ -0,0 +1,51 @@
from prowler.lib.logger import logger
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
def fixer(resource_id: str, region: str) -> bool:
"""
Revokes any ingress rule allowing the SSH port (22) from any address (0.0.0.0/0)
for the EC2 instance's security groups.
This fixer will only be triggered if the check identifies the SSH port open to the Internet.
Requires the ec2:RevokeSecurityGroupIngress permission.
Permissions:
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": "ec2:RevokeSecurityGroupIngress",
"Resource": "*"
}
]
}
Args:
resource_id (str): The EC2 instance ID.
region (str): The AWS region where the EC2 instance exists.
Returns:
bool: True if the operation is successful (ingress rule revoked), False otherwise.
"""
try:
regional_client = ec2_client.regional_clients[region]
check_ports = [22]
for instance in ec2_client.instances:
if instance.id == resource_id:
for sg in ec2_client.security_groups.values():
if sg.id in instance.security_groups:
for ingress_rule in sg.ingress_rules:
if check_security_group(
ingress_rule, "tcp", check_ports, any_address=True
):
regional_client.revoke_security_group_ingress(
GroupId=sg.id,
IpPermissions=[ingress_rule],
)
except Exception as error:
logger.error(
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
return False
else:
return True

View File

@@ -0,0 +1,51 @@
from prowler.lib.logger import logger
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
def fixer(resource_id: str, region: str) -> bool:
"""
Revokes any ingress rule allowing the Telnet port (23) from any address (0.0.0.0/0)
for the EC2 instance's security groups.
This fixer will only be triggered if the check identifies the Telnet port open to the Internet.
Requires the ec2:RevokeSecurityGroupIngress permission.
Permissions:
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": "ec2:RevokeSecurityGroupIngress",
"Resource": "*"
}
]
}
Args:
resource_id (str): The EC2 instance ID.
region (str): The AWS region where the EC2 instance exists.
Returns:
bool: True if the operation is successful (ingress rule revoked), False otherwise.
"""
try:
regional_client = ec2_client.regional_clients[region]
check_ports = [23]
for instance in ec2_client.instances:
if instance.id == resource_id:
for sg in ec2_client.security_groups.values():
if sg.id in instance.security_groups:
for ingress_rule in sg.ingress_rules:
if check_security_group(
ingress_rule, "tcp", check_ports, any_address=True
):
regional_client.revoke_security_group_ingress(
GroupId=sg.id,
IpPermissions=[ingress_rule],
)
except Exception as error:
logger.error(
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
return False
else:
return True

View File

@@ -0,0 +1,52 @@
from prowler.lib.logger import logger
from prowler.providers.aws.services.ec2.ec2_client import ec2_client
from prowler.providers.aws.services.ec2.lib.security_groups import check_security_group
def fixer(resource_id: str, region: str) -> bool:
"""
Revokes any ingress rule allowing high-risk ports (25, 110, 135, 143, 445, 3000, 4333, 5000, 5500, 8080, 8088)
from any address (0.0.0.0/0) for the security groups.
This fixer will only be triggered if the check identifies high-risk ports open to the Internet.
Requires the ec2:RevokeSecurityGroupIngress permission.
Permissions:
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": "ec2:RevokeSecurityGroupIngress",
"Resource": "*"
}
]
}
Args:
resource_id (str): The Security Group ID.
region (str): The AWS region where the Security Group exists.
Returns:
bool: True if the operation is successful (ingress rule revoked), False otherwise.
"""
try:
regional_client = ec2_client.regional_clients[region]
check_ports = ec2_client.audit_config.get(
"ec2_high_risk_ports",
[25, 110, 135, 143, 445, 3000, 4333, 5000, 5500, 8080, 8088],
)
for security_group in ec2_client.security_groups.values():
if security_group.id == resource_id:
for ingress_rule in security_group.ingress_rules:
if check_security_group(
ingress_rule, "tcp", check_ports, any_address=True
):
regional_client.revoke_security_group_ingress(
GroupId=security_group.id,
IpPermissions=[ingress_rule],
)
except Exception as error:
logger.error(
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
return False
else:
return True
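Unlike the other port fixers, this one reads the port list from the scanner's audit configuration, falling back to a built-in default. The fallback semantics of `dict.get` are worth noting (a sketch of that pattern in isolation):

```python
DEFAULT_HIGH_RISK_PORTS = [25, 110, 135, 143, 445, 3000, 4333, 5000, 5500, 8080, 8088]

def resolve_check_ports(audit_config: dict) -> list:
    # dict.get falls back to the default only when the key is absent;
    # an explicit empty list in the config yields no ports to check.
    return audit_config.get("ec2_high_risk_ports", DEFAULT_HIGH_RISK_PORTS)

assert resolve_check_ports({}) == DEFAULT_HIGH_RISK_PORTS
assert resolve_check_ports({"ec2_high_risk_ports": [8443]}) == [8443]
```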

View File

@@ -29,7 +29,9 @@ class route53_dangling_ip_subdomain_takeover(Check):
# Check if record is an IP Address
if validate_ip_address(record):
report = Check_Report_AWS(self.metadata())
report.resource_id = f"{record_set.hosted_zone_id}/{record}"
report.resource_id = (
f"{record_set.hosted_zone_id}/{record_set.name}/{record}"
)
report.resource_arn = route53_client.hosted_zones[
record_set.hosted_zone_id
].arn
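This diff widens the resource_id from zone/record to zone/name/record, so two record sets in the same hosted zone that resolve to the same IP no longer produce colliding ids. A minimal sketch of the new composite id (the helper name is illustrative):

```python
def build_resource_id(hosted_zone_id: str, record_set_name: str, record: str) -> str:
    # Including the record set name keeps ids unique when different
    # record sets in one hosted zone point at the same IP address.
    return f"{hosted_zone_id}/{record_set_name}/{record}"

assert (
    build_resource_id("Z123456", "api.example.com.", "203.0.113.10")
    == "Z123456/api.example.com./203.0.113.10"
)
```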

View File

@@ -0,0 +1,38 @@
from prowler.lib.logger import logger
from prowler.providers.aws.services.s3.s3_client import s3_client
def fixer(resource_id: str, region: str) -> bool:
"""
Modify the S3 bucket's policy to remove public access.
Specifically, this fixer deletes the public bucket's policy.
Requires the s3:DeleteBucketPolicy permission.
Permissions:
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": "s3:DeleteBucketPolicy",
"Resource": "*"
}
]
}
Args:
resource_id (str): The S3 bucket name.
region (str): AWS region where the S3 bucket exists.
Returns:
bool: True if the operation is successful (policy updated), False otherwise.
"""
try:
regional_client = s3_client.regional_clients[region]
regional_client.delete_bucket_policy(Bucket=resource_id)
except Exception as error:
logger.error(
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
return False
else:
return True

View File

@@ -0,0 +1,6 @@
from prowler.providers.aws.services.stepfunctions.stepfunctions_service import (
StepFunctions,
)
from prowler.providers.common.provider import Provider
stepfunctions_client = StepFunctions(Provider.get_global_provider())

View File

@@ -0,0 +1,320 @@
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional
from botocore.exceptions import ClientError
from pydantic import BaseModel, Field
from prowler.lib.logger import logger
from prowler.lib.scan_filters.scan_filters import is_resource_filtered
from prowler.providers.aws.lib.service.service import AWSService
class StateMachineStatus(str, Enum):
"""Enumeration of possible State Machine statuses."""
ACTIVE = "ACTIVE"
DELETING = "DELETING"
class StateMachineType(str, Enum):
"""Enumeration of possible State Machine types."""
STANDARD = "STANDARD"
EXPRESS = "EXPRESS"
class LoggingLevel(str, Enum):
"""Enumeration of possible logging levels."""
ALL = "ALL"
ERROR = "ERROR"
FATAL = "FATAL"
OFF = "OFF"
class EncryptionType(str, Enum):
"""Enumeration of possible encryption types."""
AWS_OWNED_KEY = "AWS_OWNED_KEY"
CUSTOMER_MANAGED_KMS_KEY = "CUSTOMER_MANAGED_KMS_KEY"
class CloudWatchLogsLogGroup(BaseModel):
"""
Represents a CloudWatch Logs Log Group configuration for a State Machine.
Attributes:
log_group_arn (str): The ARN of the CloudWatch Logs Log Group.
"""
log_group_arn: str
class LoggingDestination(BaseModel):
"""
Represents a logging destination for a State Machine.
Attributes:
cloud_watch_logs_log_group (CloudWatchLogsLogGroup): The CloudWatch Logs Log Group configuration.
"""
cloud_watch_logs_log_group: CloudWatchLogsLogGroup
class LoggingConfiguration(BaseModel):
"""
Represents the logging configuration for a State Machine.
Attributes:
level (LoggingLevel): The logging level.
include_execution_data (bool): Whether to include execution data in the logs.
destinations (List[LoggingDestination]): List of logging destinations.
"""
level: LoggingLevel
include_execution_data: bool
destinations: List[LoggingDestination]
class TracingConfiguration(BaseModel):
"""
Represents the tracing configuration for a State Machine.
Attributes:
enabled (bool): Whether X-Ray tracing is enabled.
"""
enabled: bool
class EncryptionConfiguration(BaseModel):
"""
Represents the encryption configuration for a State Machine.
Attributes:
kms_key_id (Optional[str]): The KMS key ID used for encryption.
kms_data_key_reuse_period_seconds (Optional[int]): The time in seconds that a KMS data key can be reused.
type (EncryptionType): The type of encryption used.
"""
kms_key_id: Optional[str]
kms_data_key_reuse_period_seconds: Optional[int]
type: EncryptionType
class StateMachine(BaseModel):
"""
Represents an AWS Step Functions State Machine.
Attributes:
id (str): The unique identifier of the state machine.
arn (str): The ARN of the state machine.
name (Optional[str]): The name of the state machine.
status (StateMachineStatus): The current status of the state machine.
definition (str): The Amazon States Language definition of the state machine.
role_arn (str): The ARN of the IAM role used by the state machine.
type (StateMachineType): The type of the state machine (STANDARD or EXPRESS).
creation_date (datetime): The creation date and time of the state machine.
region (str): The AWS region where the state machine is located.
logging_configuration (Optional[LoggingConfiguration]): The logging configuration of the state machine.
tracing_configuration (Optional[TracingConfiguration]): The tracing configuration of the state machine.
label (Optional[str]): The label associated with the state machine.
revision_id (Optional[str]): The revision ID of the state machine.
description (Optional[str]): A description of the state machine.
encryption_configuration (Optional[EncryptionConfiguration]): The encryption configuration of the state machine.
tags (List[Dict]): A list of tags associated with the state machine.
"""
id: str
arn: str
name: Optional[str] = None
status: StateMachineStatus
definition: Optional[str] = None
role_arn: Optional[str] = None
type: StateMachineType
creation_date: datetime
region: str
logging_configuration: Optional[LoggingConfiguration] = None
tracing_configuration: Optional[TracingConfiguration] = None
label: Optional[str] = None
revision_id: Optional[str] = None
description: Optional[str] = None
encryption_configuration: Optional[EncryptionConfiguration] = None
tags: List[Dict] = Field(default_factory=list)
class StepFunctions(AWSService):
"""
AWS Step Functions service class to manage state machines.
This class provides methods to list state machines, describe their details,
and list their associated tags across different AWS regions.
"""
def __init__(self, provider):
"""
Initialize the StepFunctions service.
Args:
provider: The AWS provider instance containing regional clients and audit configurations.
"""
super().__init__(__class__.__name__, provider)
self.state_machines: Dict[str, StateMachine] = {}
self.__threading_call__(self._list_state_machines)
self.__threading_call__(
self._describe_state_machine, self.state_machines.values()
)
self.__threading_call__(
self._list_state_machine_tags, self.state_machines.values()
)
def _list_state_machines(self, regional_client) -> None:
"""
List AWS Step Functions state machines in the specified region and populate the state_machines dictionary.
This function retrieves all state machines using pagination, filters them based on audit_resources if provided,
and creates StateMachine instances to store their basic information.
Args:
regional_client: The regional AWS Step Functions client used to interact with the AWS API.
"""
logger.info("StepFunctions - Listing state machines...")
try:
list_state_machines_paginator = regional_client.get_paginator(
"list_state_machines"
)
for page in list_state_machines_paginator.paginate():
for state_machine_data in page.get("stateMachines", []):
try:
arn = state_machine_data.get("stateMachineArn")
state_machine_id = (
arn.split(":")[-1].split("/")[-1] if arn else None
)
if not self.audit_resources or is_resource_filtered(
arn, self.audit_resources
):
state_machine = StateMachine(
id=state_machine_id,
arn=arn,
name=state_machine_data.get("name"),
type=StateMachineType(
state_machine_data.get("type", "STANDARD")
),
creation_date=state_machine_data.get("creationDate"),
region=regional_client.region,
status=StateMachineStatus.ACTIVE,
)
self.state_machines[arn] = state_machine
except Exception as error:
logger.error(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
except Exception as error:
logger.error(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
def _describe_state_machine(self, state_machine: StateMachine) -> None:
"""
Describe an AWS Step Functions state machine and update its details.
Args:
state_machine (StateMachine): The StateMachine instance to describe and update.
"""
logger.info(
f"StepFunctions - Describing state machine with ID {state_machine.id} ..."
)
try:
regional_client = self.regional_clients[state_machine.region]
response = regional_client.describe_state_machine(
stateMachineArn=state_machine.arn
)
state_machine.status = StateMachineStatus(response.get("status"))
state_machine.definition = response.get("definition")
state_machine.role_arn = response.get("roleArn")
state_machine.label = response.get("label")
state_machine.revision_id = response.get("revisionId")
state_machine.description = response.get("description")
logging_config = response.get("loggingConfiguration")
if logging_config:
state_machine.logging_configuration = LoggingConfiguration(
level=LoggingLevel(logging_config.get("level")),
include_execution_data=logging_config.get("includeExecutionData"),
destinations=[
LoggingDestination(
cloud_watch_logs_log_group=CloudWatchLogsLogGroup(
log_group_arn=dest["cloudWatchLogsLogGroup"][
"logGroupArn"
]
)
)
for dest in logging_config.get("destinations", [])
],
)
tracing_config = response.get("tracingConfiguration")
if tracing_config:
state_machine.tracing_configuration = TracingConfiguration(
enabled=tracing_config.get("enabled")
)
encryption_config = response.get("encryptionConfiguration")
if encryption_config:
state_machine.encryption_configuration = EncryptionConfiguration(
kms_key_id=encryption_config.get("kmsKeyId"),
kms_data_key_reuse_period_seconds=encryption_config.get(
"kmsDataKeyReusePeriodSeconds"
),
type=EncryptionType(encryption_config.get("type")),
)
except ClientError as error:
if error.response["Error"]["Code"] == "ResourceNotFoundException":
logger.warning(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
else:
logger.error(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
except Exception as error:
logger.error(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
def _list_state_machine_tags(self, state_machine: StateMachine) -> None:
"""
List tags for an AWS Step Functions state machine and update the StateMachine instance.
Args:
state_machine (StateMachine): The StateMachine instance to list and update tags for.
"""
logger.info(
f"StepFunctions - Listing tags for state machine with ID {state_machine.id} ..."
)
try:
regional_client = self.regional_clients[state_machine.region]
response = regional_client.list_tags_for_resource(
resourceArn=state_machine.arn
)
state_machine.tags = response.get("tags", [])
except ClientError as error:
if error.response["Error"]["Code"] == "ResourceNotFoundException":
logger.warning(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
else:
logger.error(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
except Exception as error:
logger.error(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
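The ID extraction in `_list_state_machines` takes the last slash-separated segment of the final ARN component. A minimal sketch with a made-up ARN (the ARN value is illustrative only):

```python
# Hypothetical state machine ARN, for illustration only.
arn = "arn:aws:states:eu-west-1:123456789012:stateMachine:my-machine"

# Same parsing as _list_state_machines: take the last colon-separated
# component, then the last slash-separated segment of it.
state_machine_id = arn.split(":")[-1].split("/")[-1]
print(state_machine_id)  # my-machine
```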


@@ -0,0 +1,34 @@
{
"Provider": "aws",
"CheckID": "stepfunctions_statemachine_logging_enabled",
"CheckTitle": "Step Functions state machines should have logging enabled",
"CheckType": [
"Software and Configuration Checks/AWS Security Best Practices"
],
"ServiceName": "stepfunctions",
"SubServiceName": "",
"ResourceIdTemplate": "arn:aws:states:{region}:{account-id}:stateMachine/{stateMachine-id}",
"Severity": "medium",
"ResourceType": "AwsStepFunctionStateMachine",
"Description": "This control checks if AWS Step Functions state machines have logging enabled. The control fails if the state machine doesn't have the loggingConfiguration property defined.",
"Risk": "Without logging enabled, important operational data may be lost, making it difficult to troubleshoot issues, monitor performance, and ensure compliance with auditing requirements.",
"RelatedUrl": "https://docs.aws.amazon.com/step-functions/latest/dg/logging.html",
"Remediation": {
"Code": {
"CLI": "aws stepfunctions update-state-machine --state-machine-arn <state-machine-arn> --logging-configuration file://logging-config.json",
"NativeIaC": "",
"Other": "https://docs.aws.amazon.com/securityhub/latest/userguide/stepfunctions-controls.html#stepfunctions-1",
"Terraform": "https://registry.terraform.io/providers/hashicorp/aws/latest/docs/resources/sfn_state_machine#logging_configuration"
},
"Recommendation": {
"Text": "Configure logging for your Step Functions state machines to ensure that operational data is captured and available for debugging, monitoring, and auditing purposes.",
"Url": "https://docs.aws.amazon.com/step-functions/latest/dg/logging.html"
}
},
"Categories": [
"logging"
],
"DependsOn": [],
"RelatedTo": [],
"Notes": ""
}
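The remediation CLI above passes a `logging-config.json` file. A sketch of the payload shape the Step Functions `update-state-machine` call expects for `--logging-configuration` (the log group ARN is a placeholder, not a real resource):

```python
import json

# Placeholder log group ARN; substitute a real CloudWatch Logs group.
log_group_arn = (
    "arn:aws:logs:eu-west-1:123456789012:log-group:/aws/states/demo-logs:*"
)

# Payload shape for the loggingConfiguration argument.
logging_configuration = {
    "level": "ALL",  # ALL, ERROR, FATAL or OFF
    "includeExecutionData": True,
    "destinations": [
        {"cloudWatchLogsLogGroup": {"logGroupArn": log_group_arn}}
    ],
}
print(json.dumps(logging_configuration, indent=2))
```

Saved as `logging-config.json`, this is the file the CLI remediation references.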


@@ -0,0 +1,45 @@
from typing import List
from prowler.lib.check.models import Check, Check_Report_AWS
from prowler.providers.aws.services.stepfunctions.stepfunctions_client import (
stepfunctions_client,
)
from prowler.providers.aws.services.stepfunctions.stepfunctions_service import (
LoggingLevel,
)
class stepfunctions_statemachine_logging_enabled(Check):
"""
Check if AWS Step Functions state machines have logging enabled.
This class verifies whether each AWS Step Functions state machine has logging enabled by checking
for the presence of a loggingConfiguration property in the state machine's configuration.
"""
def execute(self) -> List[Check_Report_AWS]:
"""
Execute the Step Functions state machines logging enabled check.
Iterates over all Step Functions state machines and generates a report indicating whether
each state machine has logging enabled.
Returns:
List[Check_Report_AWS]: A list of report objects with the results of the check.
"""
findings = []
for state_machine in stepfunctions_client.state_machines.values():
report = Check_Report_AWS(self.metadata())
report.region = state_machine.region
report.resource_id = state_machine.id
report.resource_arn = state_machine.arn
report.resource_tags = state_machine.tags
report.status = "PASS"
report.status_extended = f"Step Functions state machine {state_machine.name} has logging enabled."
            if (
                state_machine.logging_configuration is None
                or state_machine.logging_configuration.level == LoggingLevel.OFF
            ):
report.status = "FAIL"
report.status_extended = f"Step Functions state machine {state_machine.name} does not have logging enabled."
findings.append(report)
return findings
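The intended pass/fail rule can be sketched in isolation. `LoggingLevel` here is a minimal stand-in mirroring the service enum, and treating an absent configuration as disabled is an assumption consistent with the check's description:

```python
from enum import Enum
from typing import Optional

class LoggingLevel(str, Enum):
    # Stand-in mirroring the Step Functions logging levels.
    ALL = "ALL"
    ERROR = "ERROR"
    FATAL = "FATAL"
    OFF = "OFF"

def logging_enabled(level: Optional[LoggingLevel]) -> bool:
    # A state machine counts as logged unless logging is absent or OFF.
    return level is not None and level != LoggingLevel.OFF

print(logging_enabled(LoggingLevel.ALL))  # True
print(logging_enabled(LoggingLevel.OFF))  # False
print(logging_enabled(None))              # False
```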


@@ -30,9 +30,9 @@ class GCPBaseException(ProwlerException):
"message": "Error testing connection to GCP",
"remediation": "Check the connection and ensure it is properly set up.",
},
(3006, "GCPLoadCredentialsFromDictError"): {
"message": "Error loading credentials from dictionary",
"remediation": "Check the credentials and ensure they are properly set up. client_id, client_secret and refresh_token are required.",
(3006, "GCPLoadADCFromDictError"): {
"message": "Error loading Application Default Credentials from dictionary",
"remediation": "Check the dictionary and ensure a valid Application Default Credentials are present with client_id, client_secret and refresh_token keys.",
},
(3007, "GCPStaticCredentialsError"): {
"message": "Error loading static credentials",
@@ -46,6 +46,10 @@ class GCPBaseException(ProwlerException):
"message": "Cloud Asset API not used",
"remediation": "Enable the Cloud Asset API for the project.",
},
(3010, "GCPLoadServiceAccountKeyFromDictError"): {
"message": "Error loading Service Account Private Key credentials from dictionary",
"remediation": "Check the dictionary and ensure it contains a Service Account Private Key.",
},
}
def __init__(self, code, file=None, original_exception=None, message=None):
@@ -111,7 +115,7 @@ class GCPTestConnectionError(GCPBaseException):
)
class GCPLoadCredentialsFromDictError(GCPCredentialsError):
class GCPLoadADCFromDictError(GCPCredentialsError):
def __init__(self, file=None, original_exception=None, message=None):
super().__init__(
3006, file=file, original_exception=original_exception, message=message
@@ -137,3 +141,10 @@ class GCPCloudAssetAPINotUsedError(GCPBaseException):
super().__init__(
3009, file=file, original_exception=original_exception, message=message
)
class GCPLoadServiceAccountKeyFromDictError(GCPCredentialsError):
def __init__(self, file=None, original_exception=None, message=None):
super().__init__(
3010, file=file, original_exception=original_exception, message=message
)


@@ -25,7 +25,8 @@ from prowler.providers.gcp.exceptions.exceptions import (
GCPGetProjectError,
GCPHTTPError,
GCPInvalidProviderIdError,
GCPLoadCredentialsFromDictError,
GCPLoadADCFromDictError,
GCPLoadServiceAccountKeyFromDictError,
GCPNoAccesibleProjectsError,
GCPSetUpSessionError,
GCPStaticCredentialsError,
@@ -57,7 +58,6 @@ class GcpProvider(Provider):
- get_projects -> Get the projects accessible by the provided credentials
- update_projects_with_organizations -> Update the projects with organizations
- is_project_matching -> Check if the input project matches the project to match
- validate_static_arguments -> Validate the static arguments
- validate_project_id -> Validate the provider ID
"""
@@ -87,6 +87,7 @@ class GcpProvider(Provider):
client_id: str = None,
client_secret: str = None,
refresh_token: str = None,
service_account_key: dict = None,
):
"""
GCP Provider constructor
@@ -106,11 +107,12 @@ class GcpProvider(Provider):
client_id: str
client_secret: str
refresh_token: str
service_account_key: dict
Raises:
GCPNoAccesibleProjectsError if no project IDs can be accessed via Google Credentials
GCPSetUpSessionError if an error occurs during the setup session
GCPLoadCredentialsFromDictError if an error occurs during the loading credentials from dict
GCPLoadADCFromDictError if an error occurs during the loading credentials from dict
GCPGetProjectError if an error occurs during the get project
Returns:
@@ -130,6 +132,10 @@ class GcpProvider(Provider):
... client_secret="client_secret",
... refresh_token="refresh_token"
... )
- Using the service account key:
>>> GcpProvider(
... service_account_key={"service_account_key": "service_account_key"}
... )
- Using a credentials file:
>>> GcpProvider(
... credentials_file="credentials_file"
@@ -167,7 +173,10 @@ class GcpProvider(Provider):
)
self._session, self._default_project_id = self.setup_session(
credentials_file, self._impersonated_service_account, gcp_credentials
credentials_file=credentials_file,
service_account=self._impersonated_service_account,
gcp_credentials=gcp_credentials,
service_account_key=service_account_key,
)
self._project_ids = []
@@ -312,25 +321,39 @@ class GcpProvider(Provider):
@staticmethod
def setup_session(
credentials_file: str, service_account: str, gcp_credentials: dict = None
credentials_file: str,
service_account: str,
gcp_credentials: dict = None,
service_account_key: dict = None,
) -> tuple:
"""
Setup the GCP session with the provided credentials file or service account to impersonate
Args:
credentials_file: str
service_account: str
credentials_file: str -> The credentials file path used to authenticate
service_account: dict -> The service account to impersonate
gcp_credentials: dict -> The GCP credentials following the format:
{
"client_id": str,
"client_secret": str,
"refresh_token": str,
"type": str
}
service_account_key: dict -> The service account key, used to authenticate
Returns:
Credentials object and default project ID
Raises:
GCPLoadCredentialsFromDictError if an error occurs during the loading credentials from dict
GCPLoadADCFromDictError if an error occurs during the loading credentials from dict
GCPLoadServiceAccountKeyFromDictError if an error occurs during the loading credentials from the service account key
GCPSetUpSessionError if an error occurs during the setup session
Usage:
>>> GcpProvider.setup_session(credentials_file, service_account)
>>> GcpProvider.setup_session(service_account, gcp_credentials)
>>> GcpProvider.setup_session(service_account, service_account_key)
>>> GcpProvider.setup_session(credentials_file, service_account, gcp_credentials)
"""
try:
scopes = ["https://www.googleapis.com/auth/cloud-platform"]
@@ -347,7 +370,24 @@ class GcpProvider(Provider):
logger.critical(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
raise GCPLoadCredentialsFromDictError(
raise GCPLoadADCFromDictError(
file=__file__, original_exception=error
)
if service_account_key:
logger.info(
"GCP provider: Setting credentials from service account key..."
)
try:
credentials, default_project_id = load_credentials_from_dict(
service_account_key, scopes=scopes
)
return credentials, default_project_id
except Exception as error:
logger.critical(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
raise GCPLoadServiceAccountKeyFromDictError(
file=__file__, original_exception=error
)
@@ -388,10 +428,11 @@ class GcpProvider(Provider):
credentials_file: str = None,
service_account: str = None,
raise_on_exception: bool = True,
provider_id: Optional[str] = None,
client_id: str = None,
client_secret: str = None,
refresh_token: str = None,
provider_id: Optional[str] = None,
service_account_key: dict = None,
) -> Connection:
"""
Test the connection to GCP with the provided credentials file or service account to impersonate.
@@ -403,16 +444,18 @@ class GcpProvider(Provider):
credentials_file: str
service_account: str
raise_on_exception: bool
provider_id: Optional[str] -> The provider ID, for GCP it is the project ID
client_id: str
client_secret: str
refresh_token: str
provider_id: Optional[str] -> The provider ID, for GCP it is the project ID
service_account_key: dict
Returns:
Connection object with is_connected set to True if the connection is successful, or error set to the exception if the connection fails
Raises:
GCPLoadCredentialsFromDictError if an error occurs during the loading credentials from dict
GCPLoadADCFromDictError if an error occurs during the loading credentials from dict
GCPLoadServiceAccountKeyFromDictError if an error occurs during the loading credentials from dict
GCPSetUpSessionError if an error occurs during the setup session
GCPCloudResourceManagerAPINotUsedError if the Cloud Resource Manager API has not been used before or it is disabled
GCPInvalidProviderIdError if the provider ID does not match with the expected project_id
@@ -425,10 +468,6 @@ class GcpProvider(Provider):
... client_secret="client_secret",
... refresh_token="refresh_token"
... )
- Using a Service Account credentials file path:
>>> GcpProvider.test_connection(
... credentials_file="credentials_file"
... )
- Using ADC credentials with a Service Account to impersonate:
>>> GcpProvider.test_connection(
... client_id="client_id",
@@ -436,6 +475,14 @@ class GcpProvider(Provider):
... refresh_token="refresh_token",
... service_account="service_account"
... )
- Using service account key:
>>> GcpProvider.test_connection(
... service_account_key={"service_account_key": "service_account_key"}
... )
- Using a Service Account credentials file path:
>>> GcpProvider.test_connection(
... credentials_file="credentials_file"
... )
"""
try:
# Set the GCP credentials using the provided client_id, client_secret and refresh_token from ADC
@@ -444,9 +491,11 @@ class GcpProvider(Provider):
gcp_credentials = GcpProvider.validate_static_arguments(
client_id, client_secret, refresh_token
)
session, project_id = GcpProvider.setup_session(
credentials_file, service_account, gcp_credentials
credentials_file=credentials_file,
service_account=service_account,
gcp_credentials=gcp_credentials,
service_account_key=service_account_key,
)
if provider_id and project_id != provider_id:
# Logic to check if the provider ID matches the project ID
@@ -460,7 +509,7 @@ class GcpProvider(Provider):
return Connection(is_connected=True)
# Errors from setup_session
except GCPLoadCredentialsFromDictError as load_credentials_error:
except GCPLoadServiceAccountKeyFromDictError as load_credentials_error:
logger.critical(
f"{load_credentials_error.__class__.__name__}[{load_credentials_error.__traceback__.tb_lineno}]: {load_credentials_error}"
)
@@ -741,18 +790,14 @@ class GcpProvider(Provider):
) -> dict:
"""
Validate the static arguments client_id, client_secret and refresh_token of ADC credentials
Args:
client_id: str
client_secret: str
refresh_token: str
Returns:
dict
Raises:
GCPStaticCredentialsError if any of the static arguments is missing from the ADC credentials
Usage:
>>> GcpProvider.validate_static_arguments(client_id, client_secret, refresh_token)
"""


@@ -10,7 +10,7 @@ class compute_instance_block_project_wide_ssh_keys_disabled(Check):
report.project_id = instance.project_id
report.resource_id = instance.id
report.resource_name = instance.name
report.location = instance.zone
report.location = instance.region
report.status = "FAIL"
report.status_extended = f"The VM Instance {instance.name} is making use of common/shared project-wide SSH key(s)."
if instance.metadata.get("items"):


@@ -10,7 +10,7 @@ class compute_instance_confidential_computing_enabled(Check):
report.project_id = instance.project_id
report.resource_id = instance.id
report.resource_name = instance.name
report.location = instance.zone
report.location = instance.region
report.status = "PASS"
report.status_extended = (
f"VM Instance {instance.name} has Confidential Computing enabled."


@@ -10,7 +10,7 @@ class compute_instance_default_service_account_in_use(Check):
report.project_id = instance.project_id
report.resource_id = instance.id
report.resource_name = instance.name
report.location = instance.zone
report.location = instance.region
report.status = "PASS"
report.status_extended = f"The default service account is not configured to be used with VM Instance {instance.name}."
if (


@@ -10,7 +10,7 @@ class compute_instance_default_service_account_in_use_with_full_api_access(Check
report.project_id = instance.project_id
report.resource_id = instance.id
report.resource_name = instance.name
report.location = instance.zone
report.location = instance.region
report.status = "PASS"
report.status_extended = f"The VM Instance {instance.name} is not configured to use the default service account with full access to all cloud APIs."
for service_account in instance.service_accounts:


@@ -10,7 +10,7 @@ class compute_instance_encryption_with_csek_enabled(Check):
report.project_id = instance.project_id
report.resource_id = instance.id
report.resource_name = instance.name
report.location = instance.zone
report.location = instance.region
report.status = "FAIL"
report.status_extended = f"The VM Instance {instance.name} has the following unencrypted disks: '{', '.join([i[0] for i in instance.disks_encryption if not i[1]])}'."
if all([i[1] for i in instance.disks_encryption]):


@@ -10,7 +10,7 @@ class compute_instance_ip_forwarding_is_enabled(Check):
report.project_id = instance.project_id
report.resource_id = instance.id
report.resource_name = instance.name
report.location = instance.zone
report.location = instance.region
report.status = "PASS"
report.status_extended = (
f"The IP Forwarding of VM Instance {instance.name} is not enabled."


@@ -10,7 +10,7 @@ class compute_instance_public_ip(Check):
report.project_id = instance.project_id
report.resource_id = instance.id
report.resource_name = instance.name
report.location = instance.zone
report.location = instance.region
report.status = "PASS"
report.status_extended = (
f"VM Instance {instance.name} does not have a public IP."


@@ -10,7 +10,7 @@ class compute_instance_serial_ports_in_use(Check):
report.project_id = instance.project_id
report.resource_id = instance.id
report.resource_name = instance.name
report.location = instance.zone
report.location = instance.region
report.status = "PASS"
report.status_extended = f"VM Instance {instance.name} has Enable Connecting to Serial Ports off."
if instance.metadata.get("items"):


@@ -10,7 +10,7 @@ class compute_instance_shielded_vm_enabled(Check):
report.project_id = instance.project_id
report.resource_id = instance.id
report.resource_name = instance.name
report.location = instance.zone
report.location = instance.region
report.status = "PASS"
report.status_extended = f"VM Instance {instance.name} has vTPM or Integrity Monitoring set to on."
if (


@@ -18,7 +18,7 @@ class compute_network_default_in_use(Check):
report.project_id = project
report.resource_id = "default"
report.resource_name = "default"
report.location = "global"
report.location = compute_client.region
if project in projects_with_default_network:
report.status = "FAIL"


@@ -9,7 +9,7 @@ class compute_project_os_login_enabled(Check):
report = Check_Report_GCP(self.metadata())
report.project_id = project.id
report.resource_id = project.id
report.location = "global"
report.location = compute_client.region
report.status = "PASS"
report.status_extended = f"Project {project.id} has OS Login enabled."
if not project.enable_oslogin:


@@ -101,6 +101,7 @@ class Compute(GCPService):
name=instance["name"],
id=instance["id"],
zone=zone,
region=zone.rsplit("-", 1)[0],
public_ip=public_ip,
metadata=instance.get("metadata", {}),
shielded_enabled_vtpm=instance.get(
@@ -306,6 +307,7 @@ class Instance(BaseModel):
name: str
id: str
zone: str
region: str
public_ip: bool
project_id: str
metadata: dict
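The new `region` field is derived from the zone by dropping the trailing zone letter, exactly as the `rsplit` in the diff does:

```python
# GCP zone names are "<region>-<letter>", so splitting once from the
# right on "-" yields the enclosing region.
zone = "europe-west1-b"
region = zone.rsplit("-", 1)[0]
print(region)  # europe-west1
```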


@@ -10,6 +10,7 @@ class dataproc_encrypted_with_cmks_disabled(Check):
report.project_id = cluster.project_id
report.resource_id = cluster.id
report.resource_name = cluster.name
report.location = dataproc_client.region
report.status = "PASS"
report.status_extended = f"Dataproc cluster {cluster.name} is encrypted with customer managed encryption keys."
if cluster.encryption_config.get("gcePdKmsKeyName") is None:


@@ -10,6 +10,7 @@ class dns_dnssec_disabled(Check):
report.project_id = managed_zone.project_id
report.resource_id = managed_zone.id
report.resource_name = managed_zone.name
report.location = dns_client.region
report.status = "PASS"
report.status_extended = (
f"Cloud DNS {managed_zone.name} has DNSSEC enabled."


@@ -10,6 +10,7 @@ class dns_rsasha1_in_use_to_key_sign_in_dnssec(Check):
report.project_id = managed_zone.project_id
report.resource_id = managed_zone.id
report.resource_name = managed_zone.name
report.location = dns_client.region
report.status = "PASS"
report.status_extended = f"Cloud DNS {managed_zone.name} is not using RSASHA1 algorithm as key signing."
if any(


@@ -10,6 +10,7 @@ class dns_rsasha1_in_use_to_zone_sign_in_dnssec(Check):
report.project_id = managed_zone.project_id
report.resource_id = managed_zone.id
report.resource_name = managed_zone.name
report.location = dns_client.region
report.status = "PASS"
report.status_extended = f"Cloud DNS {managed_zone.name} is not using RSASHA1 algorithm as zone signing."
if any(


@@ -10,7 +10,7 @@ class gke_cluster_no_default_service_account(Check):
report.project_id = cluster.project_id
report.resource_id = cluster.id
report.resource_name = cluster.name
report.location = cluster.location
report.location = cluster.region
report.status = "PASS"
report.status_extended = f"GKE cluster {cluster.name} is not using the Compute Engine default service account."
if not cluster.node_pools and cluster.service_account == "default":


@@ -60,6 +60,7 @@ class GKE(GCPService):
name=cluster["name"],
id=cluster["id"],
location=cluster["location"],
region=cluster["location"].rsplit("-", 1)[0],
service_account=cluster["nodeConfig"]["serviceAccount"],
node_pools=node_pools,
project_id=location.project_id,
@@ -85,6 +86,7 @@ class NodePool(BaseModel):
class Cluster(BaseModel):
name: str
id: str
region: str
location: str
service_account: str
node_pools: list[NodePool]


@@ -48,8 +48,8 @@ azure-mgmt-storage = "21.2.1"
azure-mgmt-subscription = "3.1.1"
azure-mgmt-web = "7.3.1"
azure-storage-blob = "12.24.0"
boto3 = "1.35.78"
botocore = "1.35.79"
boto3 = "1.35.81"
botocore = "1.35.81"
colorama = "0.4.6"
cryptography = "43.0.1"
dash = "2.18.2"
@@ -100,7 +100,7 @@ optional = true
[tool.poetry.group.docs.dependencies]
mkdocs = "1.6.1"
mkdocs-git-revision-date-localized-plugin = "1.3.0"
mkdocs-material = "9.5.48"
mkdocs-material = "9.5.49"
mkdocs-material-extensions = "1.3.1"
[tool.poetry.scripts]


@@ -94,12 +94,15 @@ class Test_backup_recovery_point_encrypted:
aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1])
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
), mock.patch(
"prowler.providers.aws.services.backup.backup_recovery_point_encrypted.backup_recovery_point_encrypted.backup_client",
new=Backup(aws_provider),
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
),
mock.patch(
"prowler.providers.aws.services.backup.backup_recovery_point_encrypted.backup_recovery_point_encrypted.backup_client",
new=Backup(aws_provider),
),
):
# Test Check
from prowler.providers.aws.services.backup.backup_recovery_point_encrypted.backup_recovery_point_encrypted import (
@@ -124,12 +127,15 @@ class Test_backup_recovery_point_encrypted:
aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1])
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
), mock.patch(
"prowler.providers.aws.services.backup.backup_recovery_point_encrypted.backup_recovery_point_encrypted.backup_client",
new=Backup(aws_provider),
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
),
mock.patch(
"prowler.providers.aws.services.backup.backup_recovery_point_encrypted.backup_recovery_point_encrypted.backup_client",
new=Backup(aws_provider),
),
):
# Test Check
from prowler.providers.aws.services.backup.backup_recovery_point_encrypted.backup_recovery_point_encrypted import (
@@ -142,9 +148,9 @@ class Test_backup_recovery_point_encrypted:
assert len(result) == 1
assert result[0].status == "FAIL"
assert result[0].status_extended == (
"Backup Recovery Point arn:aws:backup:eu-west-1:123456789012:recovery-point:1 for Backup Vault Test Vault is not encrypted at rest."
"Backup Recovery Point 1 for Backup Vault Test Vault is not encrypted at rest."
)
assert result[0].resource_id == "Test Vault"
assert result[0].resource_id == "1"
assert (
result[0].resource_arn
== "arn:aws:backup:eu-west-1:123456789012:recovery-point:1"
@@ -165,12 +171,15 @@ class Test_backup_recovery_point_encrypted:
aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1])
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
), mock.patch(
"prowler.providers.aws.services.backup.backup_recovery_point_encrypted.backup_recovery_point_encrypted.backup_client",
new=Backup(aws_provider),
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
),
mock.patch(
"prowler.providers.aws.services.backup.backup_recovery_point_encrypted.backup_recovery_point_encrypted.backup_client",
new=Backup(aws_provider),
),
):
# Test Check
from prowler.providers.aws.services.backup.backup_recovery_point_encrypted.backup_recovery_point_encrypted import (
@@ -183,9 +192,9 @@ class Test_backup_recovery_point_encrypted:
assert len(result) == 1
assert result[0].status == "PASS"
assert result[0].status_extended == (
"Backup Recovery Point arn:aws:backup:eu-west-1:123456789012:recovery-point:1 for Backup Vault Test Vault is encrypted at rest."
"Backup Recovery Point 1 for Backup Vault Test Vault is encrypted at rest."
)
assert result[0].resource_id == "Test Vault"
assert result[0].resource_id == "1"
assert (
result[0].resource_arn
== "arn:aws:backup:eu-west-1:123456789012:recovery-point:1"


@@ -0,0 +1,145 @@
from unittest import mock

import botocore
import botocore.client
from boto3 import client
from moto import mock_aws

from tests.providers.aws.utils import (
    AWS_REGION_EU_WEST_1,
    AWS_REGION_US_EAST_1,
    set_mocked_aws_provider,
)

# Keep a reference to the original botocore API call so the mock can delegate
mock_make_api_call = botocore.client.BaseClient._make_api_call


def mock_make_api_call_error(self, operation_name, kwarg):
    # Simulate an API failure for the operation the fixer performs; any
    # ClientError works here, so the specific code/message is arbitrary.
    if operation_name == "PutPublicAccessBlock":
        raise botocore.exceptions.ClientError(
            {
                "Error": {
                    "Code": "InvalidPermission.NotFound",
                    "Message": "The specified rule does not exist in this security group.",
                }
            },
            operation_name,
        )
    return mock_make_api_call(self, operation_name, kwarg)


class Test_cloudtrail_logs_s3_bucket_is_not_publicly_accessible_fixer:
    @mock_aws
    def test_trail_bucket_public_acl(self):
        aws_provider = set_mocked_aws_provider(
            [AWS_REGION_US_EAST_1, AWS_REGION_EU_WEST_1]
        )
        s3_client = client("s3", region_name=AWS_REGION_US_EAST_1)
        bucket_name_us = "bucket_test_us"
        s3_client.create_bucket(Bucket=bucket_name_us)
        s3_client.put_bucket_acl(
            AccessControlPolicy={
                "Grants": [
                    {
                        "Grantee": {
                            "DisplayName": "test",
                            "EmailAddress": "",
                            "ID": "test_ID",
                            "Type": "Group",
                            "URI": "http://acs.amazonaws.com/groups/global/AllUsers",
                        },
                        "Permission": "READ",
                    },
                ],
                "Owner": {"DisplayName": "test", "ID": "test_id"},
            },
            Bucket=bucket_name_us,
        )
        trail_name_us = "trail_test_us"
        cloudtrail_client = client("cloudtrail", region_name=AWS_REGION_US_EAST_1)
        cloudtrail_client.create_trail(
            Name=trail_name_us, S3BucketName=bucket_name_us, IsMultiRegionTrail=False
        )

        from prowler.providers.aws.services.cloudtrail.cloudtrail_service import (
            Cloudtrail,
        )
        from prowler.providers.aws.services.s3.s3_service import S3

        with mock.patch(
            "prowler.providers.common.provider.Provider.get_global_provider",
            return_value=aws_provider,
        ), mock.patch(
            "prowler.providers.aws.services.cloudtrail.cloudtrail_logs_s3_bucket_is_not_publicly_accessible.cloudtrail_logs_s3_bucket_is_not_publicly_accessible_fixer.cloudtrail_client",
            new=Cloudtrail(aws_provider),
        ), mock.patch(
            "prowler.providers.aws.services.cloudtrail.cloudtrail_logs_s3_bucket_is_not_publicly_accessible.cloudtrail_logs_s3_bucket_is_not_publicly_accessible_fixer.s3_client",
            new=S3(aws_provider),
        ):
            # Test Fixer
            from prowler.providers.aws.services.cloudtrail.cloudtrail_logs_s3_bucket_is_not_publicly_accessible.cloudtrail_logs_s3_bucket_is_not_publicly_accessible_fixer import (
                fixer,
            )

            assert fixer(trail_name_us, AWS_REGION_US_EAST_1)

    @mock_aws
    def test_trail_bucket_public_acl_error(self):
        with mock.patch(
            "botocore.client.BaseClient._make_api_call", new=mock_make_api_call_error
        ):
            aws_provider = set_mocked_aws_provider(
                [AWS_REGION_US_EAST_1, AWS_REGION_EU_WEST_1]
            )
            s3_client = client("s3", region_name=AWS_REGION_US_EAST_1)
            bucket_name_us = "bucket_test_us"
            s3_client.create_bucket(Bucket=bucket_name_us)
            s3_client.put_bucket_acl(
                AccessControlPolicy={
                    "Grants": [
                        {
                            "Grantee": {
                                "DisplayName": "test",
                                "EmailAddress": "",
                                "ID": "test_ID",
                                "Type": "Group",
                                "URI": "http://acs.amazonaws.com/groups/global/AllUsers",
                            },
                            "Permission": "READ",
                        },
                    ],
                    "Owner": {"DisplayName": "test", "ID": "test_id"},
                },
                Bucket=bucket_name_us,
            )
            trail_name_us = "trail_test_us"
            cloudtrail_client = client("cloudtrail", region_name=AWS_REGION_US_EAST_1)
            cloudtrail_client.create_trail(
                Name=trail_name_us,
                S3BucketName=bucket_name_us,
                IsMultiRegionTrail=False,
            )

            from prowler.providers.aws.services.cloudtrail.cloudtrail_service import (
                Cloudtrail,
            )
            from prowler.providers.aws.services.s3.s3_service import S3

            with mock.patch(
                "prowler.providers.common.provider.Provider.get_global_provider",
                return_value=aws_provider,
            ), mock.patch(
                "prowler.providers.aws.services.cloudtrail.cloudtrail_logs_s3_bucket_is_not_publicly_accessible.cloudtrail_logs_s3_bucket_is_not_publicly_accessible_fixer.cloudtrail_client",
                new=Cloudtrail(aws_provider),
            ), mock.patch(
                "prowler.providers.aws.services.cloudtrail.cloudtrail_logs_s3_bucket_is_not_publicly_accessible.cloudtrail_logs_s3_bucket_is_not_publicly_accessible_fixer.s3_client",
                new=S3(aws_provider),
            ):
                # Test Fixer
                from prowler.providers.aws.services.cloudtrail.cloudtrail_logs_s3_bucket_is_not_publicly_accessible.cloudtrail_logs_s3_bucket_is_not_publicly_accessible_fixer import (
                    fixer,
                )

                assert not fixer(trail_name_us, AWS_REGION_US_EAST_1)
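The error-path test above relies on a common botocore mocking pattern: keep a module-level reference to the original `BaseClient._make_api_call`, then patch it with a wrapper that raises only for the operation under test and delegates everything else to the saved original. A minimal sketch of the same pattern follows, using a stand-in client class so it runs without boto3; `FakeClient`, `failing_call`, and the simulated error are illustrative names, not part of the diff:

```python
from unittest import mock


class FakeClient:
    """Stand-in for botocore's BaseClient, for illustration only."""

    def _make_api_call(self, operation_name, kwarg):
        return {"Operation": operation_name, "Status": "ok"}


# Keep a reference to the original so the mock can delegate to it.
original_call = FakeClient._make_api_call


def failing_call(self, operation_name, kwarg):
    # Fail only the operation under test; all other calls behave normally.
    if operation_name == "PutPublicAccessBlock":
        raise RuntimeError("simulated ClientError for PutPublicAccessBlock")
    return original_call(self, operation_name, kwarg)


with mock.patch.object(FakeClient, "_make_api_call", new=failing_call):
    fake = FakeClient()
    # Unrelated operations still succeed via the delegated original.
    assert fake._make_api_call("GetObject", {})["Status"] == "ok"
    # The targeted operation now raises, exercising the error path.
    try:
        fake._make_api_call("PutPublicAccessBlock", {})
        raised = False
    except RuntimeError:
        raised = True
    assert raised
```

Delegating to the saved original is what lets moto keep handling every other AWS call during the test, so only the single failure path being tested changes behavior.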
