Compare commits


26 Commits

Author SHA1 Message Date
Prowler Bot
5d41c6a0a5 feat(celery): Add configurable broker visibility timeout setting (#6246)
Co-authored-by: Víctor Fernández Poyatos <victor@prowler.com>
2024-12-19 00:05:38 +05:45
Prowler Bot
29dad4e8aa fix(.env): remove comment (#6242)
Co-authored-by: Pepe Fagoaga <pepe@prowler.com>
2024-12-18 11:15:59 -05:00
Prowler Bot
a1e53ef0fc chore(rls): rename tenant_transaction to rls_transaction (#6203)
Co-authored-by: Pepe Fagoaga <pepe@prowler.com>
2024-12-16 12:40:02 +01:00
Prowler Bot
dfed6ac248 fix(RLS): enforce config security (#6190)
Co-authored-by: Pepe Fagoaga <pepe@prowler.com>
2024-12-16 11:39:05 +01:00
Sergio Garcia
c930416260 chore(version): update Prowler version (#6196) 2024-12-16 08:31:16 +01:00
Prowler Bot
83ffd78e63 chore(deps): bump cross-spawn from 7.0.3 to 7.0.6 in /ui (#6176)
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-12-13 15:46:06 +01:00
Prowler Bot
1045ffe489 fix(aws): set unique resource IDs (#6192)
Co-authored-by: Sergio Garcia <hello@mistercloudsec.com>
2024-12-13 09:07:57 -04:00
Prowler Bot
5af81b9b6d chore(deps): bump nanoid from 3.3.7 to 3.3.8 in /ui (#6175)
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-12-13 09:13:20 +01:00
Prowler Bot
f95394bec0 chore: delete unneeded requirements file (#6058)
Co-authored-by: Pepe Fagoaga <pepe@prowler.com>
2024-12-13 07:58:21 +01:00
Prowler Bot
0a865f8950 fix(tenant): fix delete tenants behavior (#6014)
Co-authored-by: Víctor Fernández Poyatos <victor@prowler.com>
2024-12-13 07:56:46 +01:00
Prowler Bot
68d7f140ff fix(deploy): temporary fix for the alpine-python segmentation fault (#6115)
Co-authored-by: Adrián Jesús Peña Rodríguez <adrianjpr@gmail.com>
2024-12-13 07:56:19 +01:00
Prowler Bot
6ed237b49c feat(users): user details can now be edited properly (#6137)
Co-authored-by: Pablo Lara <larabjj@gmail.com>
2024-12-13 07:55:35 +01:00
Prowler Bot
51c2158563 fix(rds): add invalid SG to status_extended (#6170)
Co-authored-by: Pedro Martín <pedromarting3@gmail.com>
2024-12-12 12:47:11 -04:00
Prowler Bot
dbb348fb09 fix(aurora): Add default ports to the checks for non-default ports (#6151)
Co-authored-by: Mads Brouer Lundholm <mads@madslundholm.dk>
2024-12-11 13:49:02 -04:00
Prowler Bot
405dc9c507 fix(autoscaling): autoscaling_group_launch_configuration_requires_imdsv2 fails if Launch Template is used (#6147)
Co-authored-by: Daniel Barranquero <74871504+danibarranqueroo@users.noreply.github.com>
2024-12-11 12:06:39 -04:00
Prowler Bot
40004ebb99 fix(app): add support for TLS 1.3 to Web Apps check (#6144)
Co-authored-by: Rubén De la Torre Vico <rubendltv22@gmail.com>
2024-12-11 10:28:41 -04:00
Prowler Bot
0556f30670 fix(iam): set unique resource id for each user access key (#6134)
Co-authored-by: Sergio Garcia <hello@mistercloudsec.com>
2024-12-11 09:25:29 -04:00
Prowler Bot
1723ac6a6a fix(compliance_tables): add correct values for findings (#6127)
Co-authored-by: Pedro Martín <pedromarting3@gmail.com>
2024-12-10 16:47:28 -04:00
Prowler Bot
7b308bf5f4 fix(aws): get firewall manager managed rule groups (#6124)
Co-authored-by: Hugo Pereira Brito <101209179+HugoPBrito@users.noreply.github.com>
2024-12-10 16:46:48 -04:00
Prowler Bot
d4e9940beb fix(aws): check AWS Owned keys in firehose_stream_encrypted_at_rest (#6121)
Co-authored-by: Hugo Pereira Brito <101209179+HugoPBrito@users.noreply.github.com>
2024-12-10 14:30:41 -04:00
Prowler Bot
8558034eae fix(aws): set IAM identity as resource in threat detection (#6118)
Co-authored-by: Sergio Garcia <hello@mistercloudsec.com>
2024-12-10 13:37:38 -04:00
Prowler Bot
a6b4c27262 fix(gcp): make sure default project is active (#6113)
Co-authored-by: Sergio Garcia <hello@mistercloudsec.com>
2024-12-10 11:53:32 -04:00
Prowler Bot
159aa8b464 fix(aws): set same severity for EC2 IMDSv2 checks (#6104)
Co-authored-by: Sergio Garcia <hello@mistercloudsec.com>
2024-12-10 09:30:17 -04:00
Prowler Bot
293c822c3d fix(backup): modify list recovery points call (#6096)
Co-authored-by: Daniel Barranquero <74871504+danibarranqueroo@users.noreply.github.com>
2024-12-09 17:26:11 -04:00
Prowler Bot
649ec19012 chore(actions): standardize names (#6092)
Co-authored-by: Pepe Fagoaga <pepe@prowler.com>
Co-authored-by: Sergio Garcia <hello@mistercloudsec.com>
2024-12-09 16:33:43 -04:00
Prowler Bot
e04e5d3b18 fix(invitations): remove wrong url (#6012)
Co-authored-by: Pablo Lara <larabjj@gmail.com>
2024-12-05 10:56:46 +01:00
85 changed files with 2757 additions and 1604 deletions

.env

@@ -40,9 +40,12 @@ DJANGO_LOGGING_FORMATTER=human_readable
# Select one of [DEBUG|INFO|WARNING|ERROR|CRITICAL]
# Applies to both Django and Celery Workers
DJANGO_LOGGING_LEVEL=INFO
DJANGO_WORKERS=4 # Defaults to the maximum available based on CPU cores if not set.
DJANGO_ACCESS_TOKEN_LIFETIME=30 # Token lifetime is in minutes
DJANGO_REFRESH_TOKEN_LIFETIME=1440 # Token lifetime is in minutes
# Defaults to the maximum available based on CPU cores if not set.
DJANGO_WORKERS=4
# Token lifetime is in minutes
DJANGO_ACCESS_TOKEN_LIFETIME=30
# Token lifetime is in minutes
DJANGO_REFRESH_TOKEN_LIFETIME=1440
DJANGO_CACHE_MAX_AGE=3600
DJANGO_STALE_WHILE_REVALIDATE=60
DJANGO_MANAGE_DB_PARTITIONS=True
@@ -87,3 +90,4 @@ jQIDAQAB
-----END PUBLIC KEY-----"
# openssl rand -base64 32
DJANGO_SECRETS_ENCRYPTION_KEY="oE/ltOhp/n1TdbHjVmzcjDPLcLA41CVI/4Rk+UB5ESc="
DJANGO_BROKER_VISIBILITY_TIMEOUT=86400
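
The two token lifetimes above are plain integers in minutes, so the settings layer has to convert them into timedeltas before handing them to the JWT library. A minimal sketch of that conversion, assuming django-environ and rest_framework_simplejwt (the actual settings module is not part of this diff):

    # settings sketch -- names and module layout assumed, not taken from this diff
    from datetime import timedelta

    import environ

    env = environ.Env()

    SIMPLE_JWT = {
        # Convert the minute-based .env values into timedeltas
        "ACCESS_TOKEN_LIFETIME": timedelta(minutes=env.int("DJANGO_ACCESS_TOKEN_LIFETIME", default=30)),
        "REFRESH_TOKEN_LIFETIME": timedelta(minutes=env.int("DJANGO_REFRESH_TOKEN_LIFETIME", default=1440)),
    }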


@@ -1,3 +1,3 @@
name: "Custom CodeQL Config for API"
name: "API - CodeQL Config"
paths:
- 'api/'
- "api/"


@@ -1,4 +0,0 @@
name: "Custom CodeQL Config"
paths-ignore:
- 'api/'
- 'ui/'

.github/codeql/sdk-codeql-config.yml (new file)

@@ -0,0 +1,4 @@
name: "SDK - CodeQL Config"
paths-ignore:
- "api/"
- "ui/"


@@ -1,3 +1,3 @@
name: "Custom CodeQL Config for UI"
name: "UI - CodeQL Config"
paths:
- "ui/"


@@ -9,11 +9,11 @@
# the `language` matrix defined below to confirm you have the correct set of
# supported CodeQL languages.
#
name: "API - CodeQL"
name: API - CodeQL
on:
push:
branches:
branches:
- "master"
- "v3"
- "v4.*"
@@ -21,7 +21,7 @@ on:
paths:
- "api/**"
pull_request:
branches:
branches:
- "master"
- "v3"
- "v4.*"


@@ -1,4 +1,4 @@
name: "API - Pull Request"
name: API - Pull Request
on:
push:


@@ -1,4 +1,4 @@
name: Automatic Backport
name: Prowler - Automatic Backport
on:
pull_request_target:


@@ -1,4 +1,4 @@
name: Pull Request Documentation Link
name: Prowler - Pull Request Documentation Link
on:
pull_request:


@@ -1,4 +1,4 @@
name: Find secrets
name: Prowler - Find secrets
on: pull_request
@@ -16,4 +16,4 @@ jobs:
path: ./
base: ${{ github.event.repository.default_branch }}
head: HEAD
extra_args: --only-verified
extra_args: --only-verified


@@ -1,4 +1,4 @@
name: "Pull Request Labeler"
name: Prowler - PR Labeler
on:
pull_request_target:


@@ -1,4 +1,4 @@
name: Build and Push containers
name: SDK - Build and Push containers
on:
push:
@@ -85,8 +85,8 @@ jobs:
echo "STABLE_TAG=v3-stable" >> "${GITHUB_ENV}"
;;
4)
4)
echo "LATEST_TAG=v4-latest" >> "${GITHUB_ENV}"
echo "STABLE_TAG=v4-stable" >> "${GITHUB_ENV}"
;;


@@ -9,22 +9,24 @@
# the `language` matrix defined below to confirm you have the correct set of
# supported CodeQL languages.
#
name: "CodeQL"
name: SDK - CodeQL
on:
push:
branches:
branches:
- "master"
- "v3"
- "v4.*"
- "v5.*"
paths-ignore:
- 'ui/**'
- 'api/**'
pull_request:
branches:
branches:
- "master"
- "v3"
- "v4.*"
- "v5.*"
paths-ignore:
- 'ui/**'
- 'api/**'
@@ -55,7 +57,7 @@ jobs:
uses: github/codeql-action/init@v3
with:
languages: ${{ matrix.language }}
config-file: ./.github/codeql/codeql-config.yml
config-file: ./.github/codeql/sdk-codeql-config.yml
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v3


@@ -1,4 +1,4 @@
name: "Pull Request"
name: SDK - Pull Request
on:
push:


@@ -1,4 +1,4 @@
name: PyPI release
name: SDK - PyPI release
on:
release:


@@ -1,6 +1,6 @@
# This is a basic workflow to help you get started with Actions
name: Refresh regions of AWS services
name: SDK - Refresh AWS services' regions
on:
schedule:


@@ -9,7 +9,7 @@
# the `language` matrix defined below to confirm you have the correct set of
# supported CodeQL languages.
#
name: "UI - CodeQL"
name: UI - CodeQL
on:
push:


@@ -1,4 +1,4 @@
name: "UI - Pull Request"
name: UI - Pull Request
on:
pull_request:
@@ -31,4 +31,4 @@ jobs:
run: npm run healthcheck
- name: Build the application
working-directory: ./ui
run: npm run build
run: npm run build


@@ -1,4 +1,4 @@
FROM python:3.12-alpine
FROM python:3.12.8-alpine3.20
LABEL maintainer="https://github.com/prowler-cloud/prowler"


@@ -22,6 +22,7 @@ DJANGO_SECRETS_ENCRYPTION_KEY=""
# Decide whether to allow Django manage database table partitions
DJANGO_MANAGE_DB_PARTITIONS=[True|False]
DJANGO_CELERY_DEADLOCK_ATTEMPTS=5
DJANGO_BROKER_VISIBILITY_TIMEOUT=86400
# PostgreSQL settings
# If running django and celery on host, use 'localhost', else use 'postgres-db'

api/poetry.lock (generated)

File diff suppressed because it is too large


@@ -27,7 +27,7 @@ drf-nested-routers = "^0.94.1"
drf-spectacular = "0.27.2"
drf-spectacular-jsonapi = "0.5.1"
gunicorn = "23.0.0"
prowler = {git = "https://github.com/prowler-cloud/prowler.git", branch = "master"}
prowler = {git = "https://github.com/prowler-cloud/prowler.git", tag = "5.0.0"}
psycopg2-binary = "2.9.9"
pytest-celery = {extras = ["redis"], version = "^1.0.1"}
# Needed for prowler compatibility


@@ -1,14 +1,12 @@
import uuid
from django.db import transaction, connection
from django.db import transaction
from rest_framework import permissions
from rest_framework.exceptions import NotAuthenticated
from rest_framework.filters import SearchFilter
from rest_framework_json_api import filters
from rest_framework_json_api.serializers import ValidationError
from rest_framework_json_api.views import ModelViewSet
from rest_framework_simplejwt.authentication import JWTAuthentication
from api.db_utils import POSTGRES_USER_VAR, rls_transaction
from api.filters import CustomDjangoFilterBackend
@@ -47,13 +45,7 @@ class BaseRLSViewSet(BaseViewSet):
if tenant_id is None:
raise NotAuthenticated("Tenant ID is not present in token")
try:
uuid.UUID(tenant_id)
except ValueError:
raise ValidationError("Tenant ID must be a valid UUID")
with connection.cursor() as cursor:
cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
with rls_transaction(tenant_id):
self.request.tenant_id = tenant_id
return super().initial(request, *args, **kwargs)
@@ -69,10 +61,25 @@ class BaseTenantViewset(BaseViewSet):
return super().dispatch(request, *args, **kwargs)
def initial(self, request, *args, **kwargs):
user_id = str(request.user.id)
if (
request.resolver_match.url_name != "tenant-detail"
and request.method != "DELETE"
):
user_id = str(request.user.id)
with connection.cursor() as cursor:
cursor.execute(f"SELECT set_config('api.user_id', '{user_id}', TRUE);")
with rls_transaction(value=user_id, parameter=POSTGRES_USER_VAR):
return super().initial(request, *args, **kwargs)
# TODO: DRY this when we have time
if request.auth is None:
raise NotAuthenticated
tenant_id = request.auth.get("tenant_id")
if tenant_id is None:
raise NotAuthenticated("Tenant ID is not present in token")
with rls_transaction(tenant_id):
self.request.tenant_id = tenant_id
return super().initial(request, *args, **kwargs)
@@ -92,12 +99,6 @@ class BaseUserViewset(BaseViewSet):
if tenant_id is None:
raise NotAuthenticated("Tenant ID is not present in token")
try:
uuid.UUID(tenant_id)
except ValueError:
raise ValidationError("Tenant ID must be a valid UUID")
with connection.cursor() as cursor:
cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
with rls_transaction(tenant_id):
self.request.tenant_id = tenant_id
return super().initial(request, *args, **kwargs)


@@ -1,4 +1,5 @@
import secrets
import uuid
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
@@ -8,6 +9,7 @@ from django.core.paginator import Paginator
from django.db import connection, models, transaction
from psycopg2 import connect as psycopg2_connect
from psycopg2.extensions import AsIs, new_type, register_adapter, register_type
from rest_framework_json_api.serializers import ValidationError
DB_USER = settings.DATABASES["default"]["USER"] if not settings.TESTING else "test"
DB_PASSWORD = (
@@ -23,6 +25,8 @@ TASK_RUNNER_DB_TABLE = "django_celery_results_taskresult"
POSTGRES_TENANT_VAR = "api.tenant_id"
POSTGRES_USER_VAR = "api.user_id"
SET_CONFIG_QUERY = "SELECT set_config(%s, %s::text, TRUE);"
@contextmanager
def psycopg_connection(database_alias: str):
@@ -44,10 +48,23 @@ def psycopg_connection(database_alias: str):
@contextmanager
def tenant_transaction(tenant_id: str):
def rls_transaction(value: str, parameter: str = POSTGRES_TENANT_VAR):
"""
Creates a new database transaction setting the given configuration value for Postgres RLS. It validates
that the value is a valid UUID.
Args:
value (str): Database configuration parameter value.
parameter (str): Database configuration parameter name; defaults to 'api.tenant_id'.
"""
with transaction.atomic():
with connection.cursor() as cursor:
cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
try:
# just in case the value is a UUID object
uuid.UUID(str(value))
except ValueError:
raise ValidationError("Must be a valid UUID")
cursor.execute(SET_CONFIG_QUERY, [parameter, value])
yield cursor
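
After the rename, every call site gets the UUID validation and the parameterized SET_CONFIG_QUERY for free, instead of interpolating the value into raw SQL. A minimal usage sketch, matching the two call shapes that appear in the views diff above:

    from api.db_utils import POSTGRES_USER_VAR, rls_transaction

    tenant_id = "8db7ca86-03cc-4d42-99f6-5e480baf6ab5"
    user_id = "3f1d2c44-9b1e-4c6a-8f0a-2a7c9d5e6b10"  # hypothetical UUID

    # Default parameter: runs the block with api.tenant_id set (tenant-scoped RLS)
    with rls_transaction(tenant_id):
        pass  # queries here only see this tenant's rows

    # Explicit parameter: runs the block with api.user_id set instead
    with rls_transaction(value=user_id, parameter=POSTGRES_USER_VAR):
        pass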


@@ -1,6 +1,10 @@
import uuid
from functools import wraps
from django.db import connection, transaction
from rest_framework_json_api.serializers import ValidationError
from api.db_utils import POSTGRES_TENANT_VAR, SET_CONFIG_QUERY
def set_tenant(func):
@@ -31,7 +35,7 @@ def set_tenant(func):
pass
# When calling the task
some_task.delay(arg1, tenant_id="1234-abcd-5678")
some_task.delay(arg1, tenant_id="8db7ca86-03cc-4d42-99f6-5e480baf6ab5")
# The tenant context will be set before the task logic executes.
"""
@@ -43,9 +47,12 @@ def set_tenant(func):
tenant_id = kwargs.pop("tenant_id")
except KeyError:
raise KeyError("This task requires the tenant_id")
try:
uuid.UUID(tenant_id)
except ValueError:
raise ValidationError("Tenant ID must be a valid UUID")
with connection.cursor() as cursor:
cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
cursor.execute(SET_CONFIG_QUERY, [POSTGRES_TENANT_VAR, tenant_id])
return func(*args, **kwargs)
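
For reference, a short sketch of how a task decorated with set_tenant is defined and called, following the docstring fragment above (the task name and argument are illustrative):

    from celery import shared_task

    from api.decorators import set_tenant

    @shared_task
    @set_tenant
    def some_task(arg1):
        # tenant_id was popped and validated by the decorator;
        # the RLS config parameter is already set on this connection.
        return arg1

    # tenant_id must be passed as a keyword argument and be a valid UUID string
    some_task.delay("value", tenant_id="8db7ca86-03cc-4d42-99f6-5e480baf6ab5")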


@@ -2,7 +2,7 @@ from contextlib import nullcontext
from rest_framework_json_api.renderers import JSONRenderer
from api.db_utils import tenant_transaction
from api.db_utils import rls_transaction
class APIJSONRenderer(JSONRenderer):
@@ -13,9 +13,9 @@ class APIJSONRenderer(JSONRenderer):
tenant_id = getattr(request, "tenant_id", None) if request else None
include_param_present = "include" in request.query_params if request else False
# Use tenant_transaction if needed for included resources, otherwise do nothing
# Use rls_transaction if needed for included resources, otherwise do nothing
context_manager = (
tenant_transaction(tenant_id)
rls_transaction(tenant_id)
if tenant_id and include_param_present
else nullcontext()
)


@@ -1,7 +1,9 @@
from unittest.mock import patch, call
import uuid
from unittest.mock import call, patch
import pytest
from api.db_utils import POSTGRES_TENANT_VAR, SET_CONFIG_QUERY
from api.decorators import set_tenant
@@ -15,12 +17,12 @@ class TestSetTenantDecorator:
def random_func(arg):
return arg
tenant_id = "1234-abcd-5678"
tenant_id = str(uuid.uuid4())
result = random_func("test_arg", tenant_id=tenant_id)
assert (
call(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
call(SET_CONFIG_QUERY, [POSTGRES_TENANT_VAR, tenant_id])
in mock_cursor.execute.mock_calls
)
assert result == "test_arg"


@@ -418,13 +418,24 @@ class TestTenantViewSet:
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
def test_tenants_delete(self, authenticated_client, tenants_fixture):
@patch("api.db_router.MainRouter.admin_db", new="default")
@patch("api.v1.views.delete_tenant_task.apply_async")
def test_tenants_delete(
self, delete_tenant_mock, authenticated_client, tenants_fixture
):
def _delete_tenant(kwargs):
Tenant.objects.filter(pk=kwargs.get("tenant_id")).delete()
delete_tenant_mock.side_effect = _delete_tenant
tenant1, *_ = tenants_fixture
response = authenticated_client.delete(
reverse("tenant-detail", kwargs={"pk": tenant1.id})
)
assert response.status_code == status.HTTP_204_NO_CONTENT
assert Tenant.objects.count() == len(tenants_fixture) - 1
assert Membership.objects.filter(tenant_id=tenant1.id).count() == 0
# User is not deleted because it has another membership
assert User.objects.count() == 1
def test_tenants_delete_invalid(self, authenticated_client):
response = authenticated_client.delete(


@@ -31,6 +31,7 @@ from tasks.beat import schedule_provider_scan
from tasks.tasks import (
check_provider_connection_task,
delete_provider_task,
delete_tenant_task,
perform_scan_summary_task,
perform_scan_task,
)
@@ -171,7 +172,7 @@ class SchemaView(SpectacularAPIView):
def get(self, request, *args, **kwargs):
spectacular_settings.TITLE = "Prowler API"
spectacular_settings.VERSION = "1.0.0"
spectacular_settings.VERSION = "1.0.1"
spectacular_settings.DESCRIPTION = (
"Prowler API specification.\n\nThis file is auto-generated."
)
@@ -401,6 +402,25 @@ class TenantViewSet(BaseTenantViewset):
)
return Response(data=serializer.data, status=status.HTTP_201_CREATED)
def destroy(self, request, *args, **kwargs):
# This will perform validation and raise a 404 if the tenant does not exist
tenant_id = kwargs.get("pk")
get_object_or_404(Tenant, id=tenant_id)
with transaction.atomic():
# Delete memberships
Membership.objects.using(MainRouter.admin_db).filter(
tenant_id=tenant_id
).delete()
# Delete users without memberships
User.objects.using(MainRouter.admin_db).filter(
membership__isnull=True
).delete()
# Delete tenant in batches
delete_tenant_task.apply_async(kwargs={"tenant_id": tenant_id})
return Response(status=status.HTTP_204_NO_CONTENT)
@extend_schema_view(
list=extend_schema(


@@ -1,10 +1,21 @@
from celery import Celery, Task
from config.env import env
BROKER_VISIBILITY_TIMEOUT = env.int("DJANGO_BROKER_VISIBILITY_TIMEOUT", default=86400)
celery_app = Celery("tasks")
celery_app.config_from_object("django.conf:settings", namespace="CELERY")
celery_app.conf.update(result_extended=True, result_expires=None)
celery_app.conf.broker_transport_options = {
"visibility_timeout": BROKER_VISIBILITY_TIMEOUT
}
celery_app.conf.result_backend_transport_options = {
"visibility_timeout": BROKER_VISIBILITY_TIMEOUT
}
celery_app.conf.visibility_timeout = BROKER_VISIBILITY_TIMEOUT
celery_app.autodiscover_tasks(["api"])
@@ -35,10 +46,10 @@ class RLSTask(Task):
**options,
)
task_result_instance = TaskResult.objects.get(task_id=result.task_id)
from api.db_utils import tenant_transaction
from api.db_utils import rls_transaction
tenant_id = kwargs.get("tenant_id")
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
APITask.objects.create(
id=task_result_instance.task_id,
tenant_id=tenant_id,
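
The DJANGO_BROKER_VISIBILITY_TIMEOUT setting introduced at the top of this diff (default 86400 seconds, i.e. 24 hours) controls how long the broker waits before redelivering a task that has not been acknowledged, so long-running scans are less likely to be handed to a second worker. A quick way to confirm the value in effect (a sketch, assuming a configured Django environment; the settings module name is assumed):

    import os

    os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.django.production")  # assumed name

    from config.celery import celery_app

    print(celery_app.conf.broker_transport_options)          # {'visibility_timeout': 86400}
    print(celery_app.conf.result_backend_transport_options)  # {'visibility_timeout': 86400}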


@@ -1,35 +1,34 @@
import logging
from datetime import datetime, timedelta, timezone
import pytest
from django.conf import settings
from datetime import datetime, timezone, timedelta
from django.db import connections as django_connections, connection as django_connection
from django.db import connection as django_connection
from django.db import connections as django_connections
from django.urls import reverse
from django_celery_results.models import TaskResult
from prowler.lib.check.models import Severity
from prowler.lib.outputs.finding import Status
from rest_framework import status
from rest_framework.test import APIClient
from api.models import (
ComplianceOverview,
Finding,
)
from api.models import (
User,
Invitation,
Membership,
Provider,
ProviderGroup,
ProviderSecret,
Resource,
ResourceTag,
Scan,
StateChoices,
Task,
Membership,
ProviderSecret,
Invitation,
ComplianceOverview,
User,
)
from api.rls import Tenant
from api.v1.serializers import TokenSerializer
from prowler.lib.check.models import Severity
from prowler.lib.outputs.finding import Status
API_JSON_CONTENT_TYPE = "application/vnd.api+json"
NO_TENANT_HTTP_STATUS = status.HTTP_401_UNAUTHORIZED
@@ -537,9 +536,10 @@ def get_api_tokens(
data=json_body,
format="vnd.api+json",
)
return response.json()["data"]["attributes"]["access"], response.json()["data"][
"attributes"
]["refresh"]
return (
response.json()["data"]["attributes"]["access"],
response.json()["data"]["attributes"]["refresh"],
)
def get_authorization_header(access_token: str) -> dict:


@@ -1,8 +1,9 @@
from celery.utils.log import get_task_logger
from django.db import transaction
from api.db_utils import batch_delete
from api.models import Finding, Provider, Resource, Scan, ScanSummary
from api.db_router import MainRouter
from api.db_utils import batch_delete, rls_transaction
from api.models import Finding, Provider, Resource, Scan, ScanSummary, Tenant
logger = get_task_logger(__name__)
@@ -49,3 +50,26 @@ def delete_provider(pk: str):
deletion_summary.update(provider_summary)
return deletion_summary
def delete_tenant(pk: str):
"""
Gracefully deletes an instance of a tenant along with its related data.
Args:
pk (str): The primary key of the Tenant instance to delete.
Returns:
dict: A dictionary with the count of deleted objects per model,
including related models.
"""
deletion_summary = {}
for provider in Provider.objects.using(MainRouter.admin_db).filter(tenant_id=pk):
with rls_transaction(pk):
summary = delete_provider(provider.id)
deletion_summary.update(summary)
Tenant.objects.using(MainRouter.admin_db).filter(id=pk).delete()
return deletion_summary


@@ -11,7 +11,7 @@ from api.compliance import (
PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE,
generate_scan_compliance,
)
from api.db_utils import tenant_transaction
from api.db_utils import rls_transaction
from api.models import (
ComplianceOverview,
Finding,
@@ -69,7 +69,7 @@ def _store_resources(
- tuple[str, str]: A tuple containing the resource UID and region.
"""
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
resource_instance, created = Resource.objects.get_or_create(
tenant_id=tenant_id,
provider=provider_instance,
@@ -86,7 +86,7 @@ def _store_resources(
resource_instance.service = finding.service_name
resource_instance.type = finding.resource_type
resource_instance.save()
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
tags = [
ResourceTag.objects.get_or_create(
tenant_id=tenant_id, key=key, value=value
@@ -122,7 +122,7 @@ def perform_prowler_scan(
unique_resources = set()
start_time = time.time()
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
provider_instance = Provider.objects.get(pk=provider_id)
scan_instance = Scan.objects.get(pk=scan_id)
scan_instance.state = StateChoices.EXECUTING
@@ -130,7 +130,7 @@ def perform_prowler_scan(
scan_instance.save()
try:
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
try:
prowler_provider = initialize_prowler_provider(provider_instance)
provider_instance.connected = True
@@ -156,7 +156,7 @@ def perform_prowler_scan(
for finding in findings:
for attempt in range(CELERY_DEADLOCK_ATTEMPTS):
try:
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
# Process resource
resource_uid = finding.resource_uid
if resource_uid not in resource_cache:
@@ -188,7 +188,7 @@ def perform_prowler_scan(
resource_instance.type = finding.resource_type
updated_fields.append("type")
if updated_fields:
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
resource_instance.save(update_fields=updated_fields)
except (OperationalError, IntegrityError) as db_err:
if attempt < CELERY_DEADLOCK_ATTEMPTS - 1:
@@ -203,7 +203,7 @@ def perform_prowler_scan(
# Update tags
tags = []
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
for key, value in finding.resource_tags.items():
tag_key = (key, value)
if tag_key not in tag_cache:
@@ -219,7 +219,7 @@ def perform_prowler_scan(
unique_resources.add((resource_instance.uid, resource_instance.region))
# Process finding
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
finding_uid = finding.uid
if finding_uid not in last_status_cache:
most_recent_finding = (
@@ -267,7 +267,7 @@ def perform_prowler_scan(
region_dict[finding.check_id] = finding.status.value
# Update scan progress
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
scan_instance.progress = progress
scan_instance.save()
@@ -279,7 +279,7 @@ def perform_prowler_scan(
scan_instance.state = StateChoices.FAILED
finally:
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
scan_instance.duration = time.time() - start_time
scan_instance.completed_at = datetime.now(tz=timezone.utc)
scan_instance.unique_resource_count = len(unique_resources)
@@ -330,7 +330,7 @@ def perform_prowler_scan(
total_requirements=compliance["total_requirements"],
)
)
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
ComplianceOverview.objects.bulk_create(compliance_overview_objects)
if exception is not None:
@@ -368,7 +368,7 @@ def aggregate_findings(tenant_id: str, scan_id: str):
- muted_new: Muted findings with a delta of 'new'.
- muted_changed: Muted findings with a delta of 'changed'.
"""
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
findings = Finding.objects.filter(scan_id=scan_id)
aggregation = findings.values(
@@ -464,7 +464,7 @@ def aggregate_findings(tenant_id: str, scan_id: str):
),
)
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
scan_aggregations = {
ScanSummary(
tenant_id=tenant_id,


@@ -4,10 +4,10 @@ from celery import shared_task
from config.celery import RLSTask
from django_celery_beat.models import PeriodicTask
from tasks.jobs.connection import check_provider_connection
from tasks.jobs.deletion import delete_provider
from tasks.jobs.deletion import delete_provider, delete_tenant
from tasks.jobs.scan import aggregate_findings, perform_prowler_scan
from api.db_utils import tenant_transaction
from api.db_utils import rls_transaction
from api.decorators import set_tenant
from api.models import Provider, Scan
@@ -99,7 +99,7 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
"""
task_id = self.request.id
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
provider_instance = Provider.objects.get(pk=provider_id)
periodic_task_instance = PeriodicTask.objects.get(
name=f"scan-perform-scheduled-{provider_id}"
@@ -134,3 +134,8 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
@shared_task(name="scan-summary")
def perform_scan_summary_task(tenant_id: str, scan_id: str):
return aggregate_findings(tenant_id=tenant_id, scan_id=scan_id)
@shared_task(name="tenant-deletion")
def delete_tenant_task(tenant_id: str):
return delete_tenant(pk=tenant_id)


@@ -1,13 +1,15 @@
from unittest.mock import patch
import pytest
from django.core.exceptions import ObjectDoesNotExist
from tasks.jobs.deletion import delete_provider
from tasks.jobs.deletion import delete_provider, delete_tenant
from api.models import Provider
from api.models import Provider, Tenant
@pytest.mark.django_db
class TestDeleteInstance:
def test_delete_instance_success(self, providers_fixture):
class TestDeleteProvider:
def test_delete_provider_success(self, providers_fixture):
instance = providers_fixture[0]
result = delete_provider(instance.id)
@@ -15,8 +17,47 @@ class TestDeleteInstance:
with pytest.raises(ObjectDoesNotExist):
Provider.objects.get(pk=instance.id)
def test_delete_instance_does_not_exist(self):
def test_delete_provider_does_not_exist(self):
non_existent_pk = "babf6796-cfcc-4fd3-9dcf-88d012247645"
with pytest.raises(ObjectDoesNotExist):
delete_provider(non_existent_pk)
@patch("api.db_router.MainRouter.admin_db", new="default")
@pytest.mark.django_db
class TestDeleteTenant:
def test_delete_tenant_success(self, tenants_fixture, providers_fixture):
"""
Test successful deletion of a tenant and its related data.
"""
tenant = tenants_fixture[0]
providers = Provider.objects.filter(tenant_id=tenant.id)
# Ensure the tenant and related providers exist before deletion
assert Tenant.objects.filter(id=tenant.id).exists()
assert providers.exists()
# Call the function and validate the result
deletion_summary = delete_tenant(tenant.id)
assert deletion_summary is not None
assert not Tenant.objects.filter(id=tenant.id).exists()
assert not Provider.objects.filter(tenant_id=tenant.id).exists()
def test_delete_tenant_with_no_providers(self, tenants_fixture):
"""
Test deletion of a tenant with no related providers.
"""
tenant = tenants_fixture[1] # Assume this tenant has no providers
providers = Provider.objects.filter(tenant_id=tenant.id)
# Ensure the tenant exists but has no related providers
assert Tenant.objects.filter(id=tenant.id).exists()
assert not providers.exists()
# Call the function and validate the result
deletion_summary = delete_tenant(tenant.id)
assert deletion_summary == {} # No providers, so empty summary
assert not Tenant.objects.filter(id=tenant.id).exists()


@@ -1,3 +1,4 @@
import uuid
from unittest.mock import MagicMock, patch
import pytest
@@ -26,7 +27,7 @@ class TestPerformScan:
providers_fixture,
):
with (
patch("api.db_utils.tenant_transaction"),
patch("api.db_utils.rls_transaction"),
patch(
"tasks.jobs.scan.initialize_prowler_provider"
) as mock_initialize_prowler_provider,
@@ -165,10 +166,10 @@ class TestPerformScan:
"tasks.jobs.scan.initialize_prowler_provider",
side_effect=Exception("Connection error"),
)
@patch("api.db_utils.tenant_transaction")
@patch("api.db_utils.rls_transaction")
def test_perform_prowler_scan_no_connection(
self,
mock_tenant_transaction,
mock_rls_transaction,
mock_initialize_prowler_provider,
mock_prowler_scan_class,
tenants_fixture,
@@ -205,14 +206,14 @@ class TestPerformScan:
@patch("api.models.ResourceTag.objects.get_or_create")
@patch("api.models.Resource.objects.get_or_create")
@patch("api.db_utils.tenant_transaction")
@patch("api.db_utils.rls_transaction")
def test_store_resources_new_resource(
self,
mock_tenant_transaction,
mock_rls_transaction,
mock_get_or_create_resource,
mock_get_or_create_tag,
):
tenant_id = "tenant123"
tenant_id = uuid.uuid4()
provider_instance = MagicMock()
provider_instance.id = "provider456"
@@ -253,14 +254,14 @@ class TestPerformScan:
@patch("api.models.ResourceTag.objects.get_or_create")
@patch("api.models.Resource.objects.get_or_create")
@patch("api.db_utils.tenant_transaction")
@patch("api.db_utils.rls_transaction")
def test_store_resources_existing_resource(
self,
mock_tenant_transaction,
mock_rls_transaction,
mock_get_or_create_resource,
mock_get_or_create_tag,
):
tenant_id = "tenant123"
tenant_id = uuid.uuid4()
provider_instance = MagicMock()
provider_instance.id = "provider456"
@@ -310,14 +311,14 @@ class TestPerformScan:
@patch("api.models.ResourceTag.objects.get_or_create")
@patch("api.models.Resource.objects.get_or_create")
@patch("api.db_utils.tenant_transaction")
@patch("api.db_utils.rls_transaction")
def test_store_resources_with_tags(
self,
mock_tenant_transaction,
mock_rls_transaction,
mock_get_or_create_resource,
mock_get_or_create_tag,
):
tenant_id = "tenant123"
tenant_id = uuid.uuid4()
provider_instance = MagicMock()
provider_instance.id = "provider456"


@@ -37,7 +37,7 @@ services:
- 3000:3000
postgres:
image: postgres:16.3-alpine
image: postgres:16.3-alpine3.20
hostname: "postgres-db"
volumes:
- ./_data/postgres:/var/lib/postgresql/data


@@ -25,7 +25,7 @@ services:
- ${UI_PORT:-3000}:${UI_PORT:-3000}
postgres:
image: postgres:16.3-alpine
image: postgres:16.3-alpine3.20
hostname: "postgres-db"
volumes:
- ./_data/postgres:/var/lib/postgresql/data


@@ -12,7 +12,7 @@ from prowler.lib.logger import logger
timestamp = datetime.today()
timestamp_utc = datetime.now(timezone.utc).replace(tzinfo=timezone.utc)
prowler_version = "5.0.0"
prowler_version = "5.0.2"
html_logo_url = "https://github.com/prowler-cloud/prowler/"
square_logo_img = "https://prowler.com/wp-content/uploads/logo-html.png"
aws_logo = "https://user-images.githubusercontent.com/38561120/235953920-3e3fba08-0795-41dc-b480-9bea57db9f2e.png"


@@ -94,11 +94,12 @@ def get_cis_table(
print(
f"\nCompliance Status of {Fore.YELLOW}{compliance_framework.upper()}{Style.RESET_ALL} Framework:"
)
total_findings_count = len(fail_count) + len(pass_count) + len(muted_count)
overview_table = [
[
f"{Fore.RED}{round(len(fail_count) / len(findings) * 100, 2)}% ({len(fail_count)}) FAIL{Style.RESET_ALL}",
f"{Fore.GREEN}{round(len(pass_count) / len(findings) * 100, 2)}% ({len(pass_count)}) PASS{Style.RESET_ALL}",
f"{orange_color}{round(len(muted_count) / len(findings) * 100, 2)}% ({len(muted_count)}) MUTED{Style.RESET_ALL}",
f"{Fore.RED}{round(len(fail_count) / total_findings_count * 100, 2)}% ({len(fail_count)}) FAIL{Style.RESET_ALL}",
f"{Fore.GREEN}{round(len(pass_count) / total_findings_count * 100, 2)}% ({len(pass_count)}) PASS{Style.RESET_ALL}",
f"{orange_color}{round(len(muted_count) / total_findings_count * 100, 2)}% ({len(muted_count)}) MUTED{Style.RESET_ALL}",
]
]
print(tabulate(overview_table, tablefmt="rounded_grid"))
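
The new denominator matters whenever len(findings) differs from the sum of the three buckets. For example, with 2 FAIL, 3 PASS, and 5 MUTED findings out of 20 total, the old len(findings) denominator prints 10% / 15% / 25% (summing to only 50%), while the new fail + pass + muted denominator (10) prints 20% / 30% / 50%, which sums to 100% of the findings shown in the table. The same fix is applied to the ENS, generic, KISA ISMS-P, and MITRE ATT&CK tables below.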


@@ -95,11 +95,12 @@ def get_ens_table(
print(
f"\nEstado de Cumplimiento de {Fore.YELLOW}{compliance_framework.upper()}{Style.RESET_ALL}:"
)
total_findings_count = len(fail_count) + len(pass_count) + len(muted_count)
overview_table = [
[
f"{Fore.RED}{round(len(fail_count) / len(findings) * 100, 2)}% ({len(fail_count)}) NO CUMPLE{Style.RESET_ALL}",
f"{Fore.GREEN}{round(len(pass_count) / len(findings) * 100, 2)}% ({len(pass_count)}) CUMPLE{Style.RESET_ALL}",
f"{orange_color}{round(len(muted_count) / len(findings) * 100, 2)}% ({len(muted_count)}) MUTED{Style.RESET_ALL}",
f"{Fore.RED}{round(len(fail_count) / total_findings_count * 100, 2)}% ({len(fail_count)}) NO CUMPLE{Style.RESET_ALL}",
f"{Fore.GREEN}{round(len(pass_count) / total_findings_count * 100, 2)}% ({len(pass_count)}) CUMPLE{Style.RESET_ALL}",
f"{orange_color}{round(len(muted_count) / total_findings_count * 100, 2)}% ({len(muted_count)}) MUTED{Style.RESET_ALL}",
]
]
print(tabulate(overview_table, tablefmt="rounded_grid"))


@@ -39,11 +39,12 @@ def get_generic_compliance_table(
print(
f"\nCompliance Status of {Fore.YELLOW}{compliance_framework.upper()}{Style.RESET_ALL} Framework:"
)
total_findings_count = len(fail_count) + len(pass_count) + len(muted_count)
overview_table = [
[
f"{Fore.RED}{round(len(fail_count) / len(findings) * 100, 2)}% ({len(fail_count)}) FAIL{Style.RESET_ALL}",
f"{Fore.GREEN}{round(len(pass_count) / len(findings) * 100, 2)}% ({len(pass_count)}) PASS{Style.RESET_ALL}",
f"{orange_color}{round(len(muted_count) / len(findings) * 100, 2)}% ({len(muted_count)}) MUTED{Style.RESET_ALL}",
f"{Fore.RED}{round(len(fail_count) / total_findings_count * 100, 2)}% ({len(fail_count)}) FAIL{Style.RESET_ALL}",
f"{Fore.GREEN}{round(len(pass_count) / total_findings_count * 100, 2)}% ({len(pass_count)}) PASS{Style.RESET_ALL}",
f"{orange_color}{round(len(muted_count) / total_findings_count * 100, 2)}% ({len(muted_count)}) MUTED{Style.RESET_ALL}",
]
]
print(tabulate(overview_table, tablefmt="rounded_grid"))


@@ -61,11 +61,12 @@ def get_kisa_ismsp_table(
print(
f"\nCompliance Status of {Fore.YELLOW}{compliance_framework.upper()}{Style.RESET_ALL} Framework:"
)
total_findings_count = len(fail_count) + len(pass_count) + len(muted_count)
overview_table = [
[
f"{Fore.RED}{round(len(fail_count) / len(findings) * 100, 2)}% ({len(fail_count)}) FAIL{Style.RESET_ALL}",
f"{Fore.GREEN}{round(len(pass_count) / len(findings) * 100, 2)}% ({len(pass_count)}) PASS{Style.RESET_ALL}",
f"{orange_color}{round(len(muted_count) / len(findings) * 100, 2)}% ({len(muted_count)}) MUTED{Style.RESET_ALL}",
f"{Fore.RED}{round(len(fail_count) / total_findings_count * 100, 2)}% ({len(fail_count)}) FAIL{Style.RESET_ALL}",
f"{Fore.GREEN}{round(len(pass_count) / total_findings_count * 100, 2)}% ({len(pass_count)}) PASS{Style.RESET_ALL}",
f"{orange_color}{round(len(muted_count) / total_findings_count * 100, 2)}% ({len(muted_count)}) MUTED{Style.RESET_ALL}",
]
]
print(tabulate(overview_table, tablefmt="rounded_grid"))


@@ -69,11 +69,12 @@ def get_mitre_attack_table(
print(
f"\nCompliance Status of {Fore.YELLOW}{compliance_framework.upper()}{Style.RESET_ALL} Framework:"
)
total_findings_count = len(fail_count) + len(pass_count) + len(muted_count)
overview_table = [
[
f"{Fore.RED}{round(len(fail_count) / len(findings) * 100, 2)}% ({len(fail_count)}) FAIL{Style.RESET_ALL}",
f"{Fore.GREEN}{round(len(pass_count) / len(findings) * 100, 2)}% ({len(pass_count)}) PASS{Style.RESET_ALL}",
f"{orange_color}{round(len(muted_count) / len(findings) * 100, 2)}% ({len(muted_count)}) MUTED{Style.RESET_ALL}",
f"{Fore.RED}{round(len(fail_count) / total_findings_count * 100, 2)}% ({len(fail_count)}) FAIL{Style.RESET_ALL}",
f"{Fore.GREEN}{round(len(pass_count) / total_findings_count * 100, 2)}% ({len(pass_count)}) PASS{Style.RESET_ALL}",
f"{orange_color}{round(len(muted_count) / total_findings_count * 100, 2)}% ({len(muted_count)}) MUTED{Style.RESET_ALL}",
]
]
print(tabulate(overview_table, tablefmt="rounded_grid"))


@@ -8,19 +8,20 @@ class autoscaling_group_launch_configuration_no_public_ip(Check):
def execute(self):
findings = []
for group in autoscaling_client.groups:
report = Check_Report_AWS(self.metadata())
report.region = group.region
report.resource_id = group.name
report.resource_arn = group.arn
report.resource_tags = group.tags
report.status = "PASS"
report.status_extended = f"Autoscaling group {group.name} does not have an associated launch configuration assigning a public IP address."
for lc in autoscaling_client.launch_configurations.values():
if lc.name == group.launch_configuration_name and lc.public_ip:
report.status = "FAIL"
report.status_extended = f"Autoscaling group {group.name} has an associated launch configuration assigning a public IP address."
if lc.name == group.launch_configuration_name:
report = Check_Report_AWS(self.metadata())
report.region = group.region
report.resource_id = group.name
report.resource_arn = group.arn
report.resource_tags = group.tags
report.status = "PASS"
report.status_extended = f"Autoscaling group {group.name} does not have an associated launch configuration assigning a public IP address."
findings.append(report)
if lc.public_ip:
report.status = "FAIL"
report.status_extended = f"Autoscaling group {group.name} has an associated launch configuration assigning a public IP address."
findings.append(report)
return findings


@@ -8,20 +8,17 @@ class autoscaling_group_launch_configuration_requires_imdsv2(Check):
def execute(self):
findings = []
for group in autoscaling_client.groups:
report = Check_Report_AWS(self.metadata())
report.region = group.region
report.resource_id = group.name
report.resource_arn = group.arn
report.resource_tags = group.tags
report.status = "FAIL"
report.status_extended = (
f"Autoscaling group {group.name} has IMDSv2 disabled or not required."
)
for (
launch_configuration
) in autoscaling_client.launch_configurations.values():
if launch_configuration.name == group.launch_configuration_name:
report = Check_Report_AWS(self.metadata())
report.region = group.region
report.resource_id = group.name
report.resource_arn = group.arn
report.resource_tags = group.tags
report.status = "FAIL"
report.status_extended = f"Autoscaling group {group.name} has IMDSv2 disabled or not required."
if (
launch_configuration.http_endpoint == "enabled"
and launch_configuration.http_tokens == "required"
@@ -32,6 +29,6 @@ class autoscaling_group_launch_configuration_requires_imdsv2(Check):
report.status = "PASS"
report.status_extended = f"Autoscaling group {group.name} has metadata service disabled."
findings.append(report)
findings.append(report)
return findings
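
In both autoscaling checks above, the report is now created and appended only when the group's launch_configuration_name matches an existing launch configuration, so groups created from a Launch Template produce no finding at all instead of a spurious default one; the updated tests further down assert exactly that with len(result) == 0.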


@@ -8,14 +8,14 @@ class backup_recovery_point_encrypted(Check):
for recovery_point in backup_client.recovery_points:
report = Check_Report_AWS(self.metadata())
report.region = recovery_point.backup_vault_region
report.resource_id = recovery_point.backup_vault_name
report.resource_id = recovery_point.id
report.resource_arn = recovery_point.arn
report.resource_tags = recovery_point.tags
report.status = "FAIL"
report.status_extended = f"Backup Recovery Point {recovery_point.arn} for Backup Vault {recovery_point.backup_vault_name} is not encrypted at rest."
report.status_extended = f"Backup Recovery Point {recovery_point.id} for Backup Vault {recovery_point.backup_vault_name} is not encrypted at rest."
if recovery_point.encrypted:
report.status = "PASS"
report.status_extended = f"Backup Recovery Point {recovery_point.arn} for Backup Vault {recovery_point.backup_vault_name} is encrypted at rest."
report.status_extended = f"Backup Recovery Point {recovery_point.id} for Backup Vault {recovery_point.backup_vault_name} is encrypted at rest."
findings.append(report)


@@ -183,21 +183,27 @@ class Backup(AWSService):
def _list_recovery_points(self, regional_client):
logger.info("Backup - Listing Recovery Points...")
try:
for backup_vault in self.backup_vaults:
paginator = regional_client.get_paginator(
"list_recovery_points_by_backup_vault"
)
for page in paginator.paginate(BackupVaultName=backup_vault.name):
for recovery_point in page.get("RecoveryPoints", []):
self.recovery_points.append(
RecoveryPoint(
arn=recovery_point.get("RecoveryPointArn"),
backup_vault_name=backup_vault.name,
encrypted=recovery_point.get("IsEncrypted", False),
backup_vault_region=backup_vault.region,
tags=[],
)
)
if self.backup_vaults:
for backup_vault in self.backup_vaults:
paginator = regional_client.get_paginator(
"list_recovery_points_by_backup_vault"
)
for page in paginator.paginate(BackupVaultName=backup_vault.name):
for recovery_point in page.get("RecoveryPoints", []):
arn = recovery_point.get("RecoveryPointArn")
if arn:
self.recovery_points.append(
RecoveryPoint(
arn=arn,
id=arn.split(":")[-1],
backup_vault_name=backup_vault.name,
encrypted=recovery_point.get(
"IsEncrypted", False
),
backup_vault_region=backup_vault.region,
tags=[],
)
)
except ClientError as error:
logger.error(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
@@ -241,6 +247,7 @@ class BackupReportPlan(BaseModel):
class RecoveryPoint(BaseModel):
arn: str
id: str
backup_vault_name: str
encrypted: bool
backup_vault_region: str
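
The new id field is the last segment of the recovery point ARN, which gives each finding a short, unique resource ID instead of the shared vault name. For the ARN used in the tests below, a one-line sketch of the derivation:

    arn = "arn:aws:backup:eu-west-1:123456789012:recovery-point:1"
    recovery_point_id = arn.split(":")[-1]  # -> "1"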


@@ -48,7 +48,7 @@ class cloudtrail_multi_region_enabled_logging_management_events(Check):
report.resource_id = trail.name
report.resource_arn = trail.arn
report.resource_tags = trail.tags
report.region = trail.home_region
report.region = region
report.status = "PASS"
if trail.is_multiregion:
report.status_extended = f"Trail {trail.name} from home region {trail.home_region} is multi-region, is logging and have management events enabled."


@@ -67,10 +67,8 @@ class cloudtrail_threat_detection_enumeration(Check):
found_potential_enumeration = True
report = Check_Report_AWS(self.metadata())
report.region = cloudtrail_client.region
report.resource_id = cloudtrail_client.audited_account
report.resource_arn = cloudtrail_client._get_trail_arn_template(
cloudtrail_client.region
)
report.resource_id = aws_identity_arn.split("/")[-1]
report.resource_arn = aws_identity_arn
report.status = "FAIL"
report.status_extended = f"Potential enumeration attack detected from AWS {aws_identity_type} {aws_identity_arn.split('/')[-1]} with an threshold of {identity_threshold}."
findings.append(report)


@@ -67,10 +67,8 @@ class cloudtrail_threat_detection_llm_jacking(Check):
found_potential_llm_jacking = True
report = Check_Report_AWS(self.metadata())
report.region = cloudtrail_client.region
report.resource_id = cloudtrail_client.audited_account
report.resource_arn = cloudtrail_client._get_trail_arn_template(
cloudtrail_client.region
)
report.resource_id = aws_identity_arn.split("/")[-1]
report.resource_arn = aws_identity_arn
report.status = "FAIL"
report.status_extended = f"Potential LLM Jacking attack detected from AWS {aws_identity_type} {aws_identity_arn.split('/')[-1]} with an threshold of {identity_threshold}."
findings.append(report)


@@ -69,10 +69,8 @@ class cloudtrail_threat_detection_privilege_escalation(Check):
found_potential_privilege_escalation = True
report = Check_Report_AWS(self.metadata())
report.region = cloudtrail_client.region
report.resource_id = cloudtrail_client.audited_account
report.resource_arn = cloudtrail_client._get_trail_arn_template(
cloudtrail_client.region
)
report.resource_id = aws_identity_arn.split("/")[-1]
report.resource_arn = aws_identity_arn
report.status = "FAIL"
report.status_extended = f"Potential privilege escalation attack detected from AWS {aws_identity_type} {aws_identity_arn.split('/')[-1]} with an threshold of {identity_threshold}."
findings.append(report)


@@ -8,7 +8,7 @@
"ServiceName": "ec2",
"SubServiceName": "",
"ResourceIdTemplate": "arn:partition:service:region:account-id",
"Severity": "medium",
"Severity": "high",
"ResourceType": "AwsEc2Instance",
"Description": "Ensure Instance Metadata Service Version 2 (IMDSv2) is enforced for EC2 instances at the account level to protect against SSRF vulnerabilities.",
"Risk": "EC2 instances that use IMDSv1 are vulnerable to SSRF attacks.",


@@ -8,7 +8,7 @@
"ServiceName": "ec2",
"SubServiceName": "",
"ResourceIdTemplate": "arn:partition:service:region:account-id:resource-id",
"Severity": "medium",
"Severity": "high",
"ResourceType": "AwsEc2Instance",
"Description": "Check if EC2 Instance Metadata Service Version 2 (IMDSv2) is Enabled and Required.",
"Risk": "Using IMDSv2 will protect from misconfiguration and SSRF vulnerabilities. IMDSv1 will not.",


@@ -31,10 +31,7 @@ class firehose_stream_encrypted_at_rest(Check):
f"Firehose Stream {stream.name} does have at rest encryption enabled."
)
if (
stream.kms_encryption != EncryptionStatus.ENABLED
or not stream.kms_key_arn
):
if stream.kms_encryption != EncryptionStatus.ENABLED:
report.status = "FAIL"
report.status_extended = f"Firehose Stream {stream.name} does not have at rest encryption enabled."


@@ -49,7 +49,7 @@ class iam_rotate_access_key_90_days(Check):
old_access_keys = True
report = Check_Report_AWS(self.metadata())
report.region = iam_client.region
report.resource_id = user["user"]
report.resource_id = f"{user['user']}-access-key-1"
report.resource_arn = user["arn"]
report.resource_tags = user_tags
report.status = "FAIL"
@@ -66,7 +66,7 @@ class iam_rotate_access_key_90_days(Check):
old_access_keys = True
report = Check_Report_AWS(self.metadata())
report.region = iam_client.region
report.resource_id = user["user"]
report.resource_id = f"{user['user']}-access-key-2"
report.resource_arn = user["arn"]
report.resource_tags = user_tags
report.status = "FAIL"


@@ -6,8 +6,8 @@ class rds_cluster_non_default_port(Check):
def execute(self):
findings = []
default_ports = {
3306: ["mysql", "mariadb"],
5432: ["postgres"],
3306: ["mysql", "mariadb", "aurora-mysql"],
5432: ["postgres", "aurora-postgresql"],
1521: ["oracle"],
1433: ["sqlserver"],
50000: ["db2"],


@@ -37,18 +37,21 @@ class rds_instance_no_public_access(Check):
):
report.status_extended = f"RDS Instance {db_instance.id} is set as publicly accessible and security group {security_group.name} ({security_group.id}) has {db_instance.engine} port {db_instance_port} open to the Internet at endpoint {db_instance.endpoint.get('Address')} but is not in a public subnet."
public_sg = True
if db_instance.subnet_ids:
for subnet_id in db_instance.subnet_ids:
if (
subnet_id in vpc_client.vpc_subnets
and vpc_client.vpc_subnets[
subnet_id
].public
):
report.status = "FAIL"
report.status_extended = f"RDS Instance {db_instance.id} is set as publicly accessible and security group {security_group.name} ({security_group.id}) has {db_instance.engine} port {db_instance_port} open to the Internet at endpoint {db_instance.endpoint.get('Address')} in a public subnet {subnet_id}."
break
if public_sg:
break
if public_sg:
break
if db_instance.subnet_ids:
for subnet_id in db_instance.subnet_ids:
if (
subnet_id in vpc_client.vpc_subnets
and vpc_client.vpc_subnets[subnet_id].public
):
report.status = "FAIL"
report.status_extended = f"RDS Instance {db_instance.id} is set as publicly accessible and security group {security_group.name} ({security_group.id}) has {db_instance.engine} port {db_instance_port} open to the Internet at endpoint {db_instance.endpoint.get('Address')} in a public subnet {subnet_id}."
break
findings.append(report)


@@ -6,8 +6,8 @@ class rds_instance_non_default_port(Check):
def execute(self):
findings = []
default_ports = {
3306: ["mysql", "mariadb"],
5432: ["postgres"],
3306: ["mysql", "mariadb", "aurora-mysql"],
5432: ["postgres", "aurora-postgresql"],
1521: ["oracle"],
1433: ["sqlserver"],
50000: ["db2"],


@@ -29,7 +29,9 @@ class route53_dangling_ip_subdomain_takeover(Check):
# Check if record is an IP Address
if validate_ip_address(record):
report = Check_Report_AWS(self.metadata())
report.resource_id = f"{record_set.hosted_zone_id}/{record}"
report.resource_id = (
f"{record_set.hosted_zone_id}/{record_set.name}/{record}"
)
report.resource_arn = route53_client.hosted_zones[
record_set.hosted_zone_id
].arn


@@ -150,6 +150,22 @@ class WAFv2(AWSService):
else:
acl.rules.append(new_rule)
firewall_manager_managed_rg = get_web_acl.get("WebACL", {}).get(
"PreProcessFirewallManagerRuleGroups", []
) + get_web_acl.get("WebACL", {}).get(
"PostProcessFirewallManagerRuleGroups", []
)
for rule in firewall_manager_managed_rg:
acl.rule_groups.append(
Rule(
name=rule.get("Name", ""),
cloudwatch_metrics_enabled=rule.get(
"VisibilityConfig", {}
).get("CloudWatchMetricsEnabled", False),
)
)
except Exception as error:
logger.error(
f"{acl.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
@@ -193,13 +209,6 @@ class Rule(BaseModel):
cloudwatch_metrics_enabled: bool = False
class FirewallManagerRuleGroup(BaseModel):
"""Model representing a rule group for the Web ACL."""
name: str
cloudwatch_metrics_enabled: bool = False
class WebAclv2(BaseModel):
"""Model representing a Web ACL for WAFv2."""


@@ -19,12 +19,11 @@ class app_minimum_tls_version_12(Check):
report.location = app.location
report.status_extended = f"Minimum TLS version is not set to 1.2 for app '{app_name}' in subscription '{subscription_name}'."
if (
app.configurations
and getattr(app.configurations, "min_tls_version", "") == "1.2"
):
if app.configurations and getattr(
app.configurations, "min_tls_version", ""
) in ["1.2", "1.3"]:
report.status = "PASS"
report.status_extended = f"Minimum TLS version is set to 1.2 for app '{app_name}' in subscription '{subscription_name}'."
report.status_extended = f"Minimum TLS version is set to {app.configurations.min_tls_version} for app '{app_name}' in subscription '{subscription_name}'."
findings.append(report)


@@ -181,8 +181,6 @@ class GcpProvider(Provider):
message="No Project IDs can be accessed via Google Credentials.",
)
if project_ids:
if self._default_project_id not in project_ids:
self._default_project_id = project_ids[0]
for input_project in project_ids:
for (
accessible_project_id,
@@ -203,6 +201,10 @@ class GcpProvider(Provider):
self._projects[project_id] = project
self._project_ids.append(project_id)
# Change default project if not in active projects
if self._project_ids and self._default_project_id not in self._project_ids:
self._default_project_id = self._project_ids[0]
# Remove excluded projects if any input
if excluded_project_ids:
for excluded_project in excluded_project_ids:


@@ -23,7 +23,7 @@ packages = [
{include = "dashboard"}
]
readme = "README.md"
version = "5.0.0"
version = "5.0.2"
[tool.poetry.dependencies]
alive-progress = "3.2.0"


@@ -171,10 +171,6 @@ class Test_autoscaling_group_launch_configuration_no_public_ip:
AvailabilityZones=["us-east-1a", "us-east-1b"],
)
autoscaling_group_arn = autoscaling_client.describe_auto_scaling_groups(
AutoScalingGroupNames=[autoscaling_group_name]
)["AutoScalingGroups"][0]["AutoScalingGroupARN"]
from prowler.providers.aws.services.autoscaling.autoscaling_service import (
AutoScaling,
)
@@ -196,12 +192,4 @@ class Test_autoscaling_group_launch_configuration_no_public_ip:
check = autoscaling_group_launch_configuration_no_public_ip()
result = check.execute()
assert len(result) == 1
assert result[0].status == "PASS"
assert (
result[0].status_extended
== f"Autoscaling group {autoscaling_group_name} does not have an associated launch configuration assigning a public IP address."
)
assert result[0].resource_id == autoscaling_group_name
assert result[0].resource_tags == []
assert result[0].resource_arn == autoscaling_group_arn
assert len(result) == 0


@@ -119,10 +119,6 @@ class Test_autoscaling_group_launch_configuration_requires_imdsv2:
AvailabilityZones=["us-east-1a", "us-east-1b"],
)
autoscaling_group_arn = autoscaling_client.describe_auto_scaling_groups(
AutoScalingGroupNames=[autoscaling_group_name]
)["AutoScalingGroups"][0]["AutoScalingGroupARN"]
from prowler.providers.aws.services.autoscaling.autoscaling_service import (
AutoScaling,
)
@@ -144,15 +140,7 @@ class Test_autoscaling_group_launch_configuration_requires_imdsv2:
check = autoscaling_group_launch_configuration_requires_imdsv2()
result = check.execute()
assert len(result) == 1
assert result[0].status == "FAIL"
assert (
result[0].status_extended
== f"Autoscaling group {autoscaling_group_name} has IMDSv2 disabled or not required."
)
assert result[0].resource_id == autoscaling_group_name
assert result[0].resource_tags == []
assert result[0].resource_arn == autoscaling_group_arn
assert len(result) == 0
@mock_aws
def test_groups_with_imdsv2_disabled(self):


@@ -94,12 +94,15 @@ class Test_backup_recovery_point_encrypted:
aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1])
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
), mock.patch(
"prowler.providers.aws.services.backup.backup_recovery_point_encrypted.backup_recovery_point_encrypted.backup_client",
new=Backup(aws_provider),
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
),
mock.patch(
"prowler.providers.aws.services.backup.backup_recovery_point_encrypted.backup_recovery_point_encrypted.backup_client",
new=Backup(aws_provider),
),
):
# Test Check
from prowler.providers.aws.services.backup.backup_recovery_point_encrypted.backup_recovery_point_encrypted import (
@@ -124,12 +127,15 @@ class Test_backup_recovery_point_encrypted:
aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1])
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
), mock.patch(
"prowler.providers.aws.services.backup.backup_recovery_point_encrypted.backup_recovery_point_encrypted.backup_client",
new=Backup(aws_provider),
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
),
mock.patch(
"prowler.providers.aws.services.backup.backup_recovery_point_encrypted.backup_recovery_point_encrypted.backup_client",
new=Backup(aws_provider),
),
):
# Test Check
from prowler.providers.aws.services.backup.backup_recovery_point_encrypted.backup_recovery_point_encrypted import (
@@ -142,9 +148,9 @@ class Test_backup_recovery_point_encrypted:
assert len(result) == 1
assert result[0].status == "FAIL"
assert result[0].status_extended == (
"Backup Recovery Point arn:aws:backup:eu-west-1:123456789012:recovery-point:1 for Backup Vault Test Vault is not encrypted at rest."
"Backup Recovery Point 1 for Backup Vault Test Vault is not encrypted at rest."
)
assert result[0].resource_id == "Test Vault"
assert result[0].resource_id == "1"
assert (
result[0].resource_arn
== "arn:aws:backup:eu-west-1:123456789012:recovery-point:1"
@@ -165,12 +171,15 @@ class Test_backup_recovery_point_encrypted:
aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1])
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
), mock.patch(
"prowler.providers.aws.services.backup.backup_recovery_point_encrypted.backup_recovery_point_encrypted.backup_client",
new=Backup(aws_provider),
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
),
mock.patch(
"prowler.providers.aws.services.backup.backup_recovery_point_encrypted.backup_recovery_point_encrypted.backup_client",
new=Backup(aws_provider),
),
):
# Test Check
from prowler.providers.aws.services.backup.backup_recovery_point_encrypted.backup_recovery_point_encrypted import (
@@ -183,9 +192,9 @@ class Test_backup_recovery_point_encrypted:
assert len(result) == 1
assert result[0].status == "PASS"
assert result[0].status_extended == (
"Backup Recovery Point arn:aws:backup:eu-west-1:123456789012:recovery-point:1 for Backup Vault Test Vault is encrypted at rest."
"Backup Recovery Point 1 for Backup Vault Test Vault is encrypted at rest."
)
assert result[0].resource_id == "Test Vault"
assert result[0].resource_id == "1"
assert (
result[0].resource_arn
== "arn:aws:backup:eu-west-1:123456789012:recovery-point:1"

View File

@@ -19,10 +19,10 @@ def mock_get_trail_arn_template(region=None, *_) -> str:
def mock__get_lookup_events__(trail=None, event_name=None, minutes=None, *_) -> list:
return [
{
"CloudTrailEvent": '{"eventName": "DescribeAccessEntry", "userIdentity": {"type": "IAMUser", "principalId": "EXAMPLE6E4XEGITWATV6R", "arn": "arn:aws:iam::123456789012:user/Mateo", "accountId": "123456789012", "accessKeyId": "AKIAIOSFODNN7EXAMPLE", "userName": "Mateo", "sessionContext": {"sessionIssuer": {}, "webIdFederationData": {}, "attributes": {"creationDate": "2023-07-19T21:11:57Z", "mfaAuthenticated": "false"}}}}'
"CloudTrailEvent": '{"eventName": "DescribeAccessEntry", "userIdentity": {"type": "IAMUser", "principalId": "EXAMPLE6E4XEGITWATV6R", "arn": "arn:aws:iam::123456789012:user/Attacker", "accountId": "123456789012", "accessKeyId": "AKIAIOSFODNN7EXAMPLE", "userName": "Attacker", "sessionContext": {"sessionIssuer": {}, "webIdFederationData": {}, "attributes": {"creationDate": "2023-07-19T21:11:57Z", "mfaAuthenticated": "false"}}}}'
},
{
"CloudTrailEvent": '{"eventName": "DescribeAccountAttributes", "userIdentity": {"type": "IAMUser", "principalId": "EXAMPLE6E4XEGITWATV6R", "arn": "arn:aws:iam::123456789012:user/Mateo", "accountId": "123456789012", "accessKeyId": "AKIAIOSFODNN7EXAMPLE", "userName": "Mateo", "sessionContext": {"sessionIssuer": {}, "webIdFederationData": {}, "attributes": {"creationDate": "2023-07-19T21:11:57Z", "mfaAuthenticated": "false"}}}}'
"CloudTrailEvent": '{"eventName": "DescribeAccountAttributes", "userIdentity": {"type": "IAMUser", "principalId": "EXAMPLE6E4XEGITWATV6R", "arn": "arn:aws:iam::123456789012:user/Attacker", "accountId": "123456789012", "accessKeyId": "AKIAIOSFODNN7EXAMPLE", "userName": "Attacker", "sessionContext": {"sessionIssuer": {}, "webIdFederationData": {}, "attributes": {"creationDate": "2023-07-19T21:11:57Z", "mfaAuthenticated": "false"}}}}'
},
]
@@ -50,12 +50,15 @@ class Test_cloudtrail_threat_detection_enumeration:
cloudtrail_client.audited_account = AWS_ACCOUNT_NUMBER
cloudtrail_client.region = AWS_REGION_US_EAST_1
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_aws_provider(),
), mock.patch(
"prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_enumeration.cloudtrail_threat_detection_enumeration.cloudtrail_client",
new=cloudtrail_client,
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_aws_provider(),
),
mock.patch(
"prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_enumeration.cloudtrail_threat_detection_enumeration.cloudtrail_client",
new=cloudtrail_client,
),
):
# Test Check
from prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_enumeration.cloudtrail_threat_detection_enumeration import (
@@ -99,12 +102,15 @@ class Test_cloudtrail_threat_detection_enumeration:
cloudtrail_client._lookup_events = mock__get_lookup_events__
cloudtrail_client._get_trail_arn_template = mock_get_trail_arn_template
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_aws_provider(),
), mock.patch(
"prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_enumeration.cloudtrail_threat_detection_enumeration.cloudtrail_client",
new=cloudtrail_client,
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_aws_provider(),
),
mock.patch(
"prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_enumeration.cloudtrail_threat_detection_enumeration.cloudtrail_client",
new=cloudtrail_client,
),
):
# Test Check
from prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_enumeration.cloudtrail_threat_detection_enumeration import (
@@ -148,12 +154,15 @@ class Test_cloudtrail_threat_detection_enumeration:
cloudtrail_client._lookup_events = mock__get_lookup_events__
cloudtrail_client._get_trail_arn_template = mock_get_trail_arn_template
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_aws_provider(),
), mock.patch(
"prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_enumeration.cloudtrail_threat_detection_enumeration.cloudtrail_client",
new=cloudtrail_client,
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_aws_provider(),
),
mock.patch(
"prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_enumeration.cloudtrail_threat_detection_enumeration.cloudtrail_client",
new=cloudtrail_client,
),
):
# Test Check
from prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_enumeration.cloudtrail_threat_detection_enumeration import (
@@ -167,13 +176,13 @@ class Test_cloudtrail_threat_detection_enumeration:
assert result[0].status == "FAIL"
assert (
result[0].status_extended
== "Potential enumeration attack detected from AWS IAMUser Mateo with an threshold of 1.0."
== "Potential enumeration attack detected from AWS IAMUser Attacker with an threshold of 1.0."
)
assert result[0].resource_id == AWS_ACCOUNT_NUMBER
assert result[0].resource_id == "Attacker"
assert result[0].region == AWS_REGION_US_EAST_1
assert (
result[0].resource_arn
== f"arn:aws:cloudtrail:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:trail"
== f"arn:aws:iam::{AWS_ACCOUNT_NUMBER}:user/Attacker"
)
@mock_aws
@@ -198,12 +207,15 @@ class Test_cloudtrail_threat_detection_enumeration:
cloudtrail_client._lookup_events = mock__get_lookup_events__
cloudtrail_client._get_trail_arn_template = mock_get_trail_arn_template
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_aws_provider(),
), mock.patch(
"prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_enumeration.cloudtrail_threat_detection_enumeration.cloudtrail_client",
new=cloudtrail_client,
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_aws_provider(),
),
mock.patch(
"prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_enumeration.cloudtrail_threat_detection_enumeration.cloudtrail_client",
new=cloudtrail_client,
),
):
# Test Check
from prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_enumeration.cloudtrail_threat_detection_enumeration import (
@@ -247,12 +259,15 @@ class Test_cloudtrail_threat_detection_enumeration:
cloudtrail_client._lookup_events = mock__get_lookup_events_aws_service__
cloudtrail_client._get_trail_arn_template = mock_get_trail_arn_template
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_aws_provider(),
), mock.patch(
"prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_enumeration.cloudtrail_threat_detection_enumeration.cloudtrail_client",
new=cloudtrail_client,
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_aws_provider(),
),
mock.patch(
"prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_enumeration.cloudtrail_threat_detection_enumeration.cloudtrail_client",
new=cloudtrail_client,
),
):
# Test Check
from prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_enumeration.cloudtrail_threat_detection_enumeration import (
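The assertion changes above show the finding's resource moving from the audited account and a synthetic trail ARN to the IAM identity observed in the events. A hedged sketch of pulling that identity out of a mocked record like the ones at the top of this file, with field names mirroring the CloudTrailEvent payload shown above:

import json

record = {
    "CloudTrailEvent": '{"eventName": "DescribeAccessEntry", "userIdentity": {"type": "IAMUser", "userName": "Attacker", "arn": "arn:aws:iam::123456789012:user/Attacker"}}'
}
event = json.loads(record["CloudTrailEvent"])
identity = event["userIdentity"]
resource_id = identity["userName"]  # "Attacker", matching the new assertion
resource_arn = identity["arn"]      # "arn:aws:iam::123456789012:user/Attacker"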

View File

@@ -19,10 +19,10 @@ def mock_get_trail_arn_template(region=None, *_) -> str:
def mock__get_lookup_events__(trail=None, event_name=None, minutes=None, *_) -> list:
return [
{
"CloudTrailEvent": '{"eventName": "InvokeModel", "userIdentity": {"type": "IAMUser", "principalId": "EXAMPLE6E4XEGITWATV6R", "arn": "arn:aws:iam::123456789012:user/Mateo", "accountId": "123456789012", "accessKeyId": "AKIAIOSFODNN7EXAMPLE", "userName": "Mateo", "sessionContext": {"sessionIssuer": {}, "webIdFederationData": {}, "attributes": {"creationDate": "2023-07-19T21:11:57Z", "mfaAuthenticated": "false"}}}}'
"CloudTrailEvent": '{"eventName": "InvokeModel", "userIdentity": {"type": "IAMUser", "principalId": "EXAMPLE6E4XEGITWATV6R", "arn": "arn:aws:iam::123456789012:user/Attacker", "accountId": "123456789012", "accessKeyId": "AKIAIOSFODNN7EXAMPLE", "userName": "Attacker", "sessionContext": {"sessionIssuer": {}, "webIdFederationData": {}, "attributes": {"creationDate": "2023-07-19T21:11:57Z", "mfaAuthenticated": "false"}}}}'
},
{
"CloudTrailEvent": '{"eventName": "InvokeModelWithResponseStream", "userIdentity": {"type": "IAMUser", "principalId": "EXAMPLE6E4XEGITWATV6R", "arn": "arn:aws:iam::123456789012:user/Mateo", "accountId": "123456789012", "accessKeyId": "AKIAIOSFODNN7EXAMPLE", "userName": "Mateo", "sessionContext": {"sessionIssuer": {}, "webIdFederationData": {}, "attributes": {"creationDate": "2023-07-19T21:11:57Z", "mfaAuthenticated": "false"}}}}'
"CloudTrailEvent": '{"eventName": "InvokeModelWithResponseStream", "userIdentity": {"type": "IAMUser", "principalId": "EXAMPLE6E4XEGITWATV6R", "arn": "arn:aws:iam::123456789012:user/Attacker", "accountId": "123456789012", "accessKeyId": "AKIAIOSFODNN7EXAMPLE", "userName": "Attacker", "sessionContext": {"sessionIssuer": {}, "webIdFederationData": {}, "attributes": {"creationDate": "2023-07-19T21:11:57Z", "mfaAuthenticated": "false"}}}}'
},
]
@@ -50,12 +50,15 @@ class Test_cloudtrail_threat_detection_llm_jacking:
cloudtrail_client.audited_account = AWS_ACCOUNT_NUMBER
cloudtrail_client.region = AWS_REGION_US_EAST_1
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_aws_provider(),
), mock.patch(
"prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_llm_jacking.cloudtrail_threat_detection_llm_jacking.cloudtrail_client",
new=cloudtrail_client,
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_aws_provider(),
),
mock.patch(
"prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_llm_jacking.cloudtrail_threat_detection_llm_jacking.cloudtrail_client",
new=cloudtrail_client,
),
):
# Test Check
from prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_llm_jacking.cloudtrail_threat_detection_llm_jacking import (
@@ -96,12 +99,15 @@ class Test_cloudtrail_threat_detection_llm_jacking:
cloudtrail_client._lookup_events = mock__get_lookup_events__
cloudtrail_client._get_trail_arn_template = mock_get_trail_arn_template
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_aws_provider(),
), mock.patch(
"prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_llm_jacking.cloudtrail_threat_detection_llm_jacking.cloudtrail_client",
new=cloudtrail_client,
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_aws_provider(),
),
mock.patch(
"prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_llm_jacking.cloudtrail_threat_detection_llm_jacking.cloudtrail_client",
new=cloudtrail_client,
),
):
# Test Check
from prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_llm_jacking.cloudtrail_threat_detection_llm_jacking import (
@@ -145,12 +151,15 @@ class Test_cloudtrail_threat_detection_llm_jacking:
cloudtrail_client._lookup_events = mock__get_lookup_events__
cloudtrail_client._get_trail_arn_template = mock_get_trail_arn_template
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_aws_provider(),
), mock.patch(
"prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_llm_jacking.cloudtrail_threat_detection_llm_jacking.cloudtrail_client",
new=cloudtrail_client,
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_aws_provider(),
),
mock.patch(
"prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_llm_jacking.cloudtrail_threat_detection_llm_jacking.cloudtrail_client",
new=cloudtrail_client,
),
):
# Test Check
from prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_llm_jacking.cloudtrail_threat_detection_llm_jacking import (
@@ -164,13 +173,13 @@ class Test_cloudtrail_threat_detection_llm_jacking:
assert result[0].status == "FAIL"
assert (
result[0].status_extended
== "Potential LLM Jacking attack detected from AWS IAMUser Mateo with an threshold of 1.0."
== "Potential LLM Jacking attack detected from AWS IAMUser Attacker with an threshold of 1.0."
)
assert result[0].resource_id == AWS_ACCOUNT_NUMBER
assert result[0].resource_id == "Attacker"
assert result[0].region == AWS_REGION_US_EAST_1
assert (
result[0].resource_arn
== f"arn:aws:cloudtrail:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:trail"
== f"arn:aws:iam::{AWS_ACCOUNT_NUMBER}:user/Attacker"
)
@mock_aws
@@ -195,12 +204,15 @@ class Test_cloudtrail_threat_detection_llm_jacking:
cloudtrail_client._lookup_events = mock__get_lookup_events__
cloudtrail_client._get_trail_arn_template = mock_get_trail_arn_template
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_aws_provider(),
), mock.patch(
"prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_llm_jacking.cloudtrail_threat_detection_llm_jacking.cloudtrail_client",
new=cloudtrail_client,
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_aws_provider(),
),
mock.patch(
"prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_llm_jacking.cloudtrail_threat_detection_llm_jacking.cloudtrail_client",
new=cloudtrail_client,
),
):
# Test Check
from prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_llm_jacking.cloudtrail_threat_detection_llm_jacking import (
@@ -244,12 +256,15 @@ class Test_cloudtrail_threat_detection_llm_jacking:
cloudtrail_client._lookup_events = mock__get_lookup_events_aws_service__
cloudtrail_client._get_trail_arn_template = mock_get_trail_arn_template
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_aws_provider(),
), mock.patch(
"prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_llm_jacking.cloudtrail_threat_detection_llm_jacking.cloudtrail_client",
new=cloudtrail_client,
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_aws_provider(),
),
mock.patch(
"prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_llm_jacking.cloudtrail_threat_detection_llm_jacking.cloudtrail_client",
new=cloudtrail_client,
),
):
# Test Check
from prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_llm_jacking.cloudtrail_threat_detection_llm_jacking import (
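The status message asserts a threshold of 1.0, i.e. every watched action was seen from the same identity. A hypothetical illustration of that ratio, assuming the check scores an identity by the fraction of monitored event names it has invoked; the real action list and threshold live in the check's configuration, not here:

monitored = {"InvokeModel", "InvokeModelWithResponseStream"}
seen = {"InvokeModel", "InvokeModelWithResponseStream"}  # from the mocked events above
score = len(seen & monitored) / len(monitored)  # 1.0 -> the finding is raised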

View File

@@ -19,10 +19,10 @@ def mock_get_trail_arn_template(region=None, *_) -> str:
def mock__get_lookup_events__(trail=None, event_name=None, minutes=None, *_) -> list:
return [
{
"CloudTrailEvent": '{"eventName": "CreateLoginProfile", "userIdentity": {"type": "IAMUser", "principalId": "EXAMPLE6E4XEGITWATV6R", "arn": "arn:aws:iam::123456789012:user/Mateo", "accountId": "123456789012", "accessKeyId": "AKIAIOSFODNN7EXAMPLE", "userName": "Mateo", "sessionContext": {"sessionIssuer": {}, "webIdFederationData": {}, "attributes": {"creationDate": "2023-07-19T21:11:57Z", "mfaAuthenticated": "false"}}}}'
"CloudTrailEvent": '{"eventName": "CreateLoginProfile", "userIdentity": {"type": "IAMUser", "principalId": "EXAMPLE6E4XEGITWATV6R", "arn": "arn:aws:iam::123456789012:user/Attacker", "accountId": "123456789012", "accessKeyId": "AKIAIOSFODNN7EXAMPLE", "userName": "Attacker", "sessionContext": {"sessionIssuer": {}, "webIdFederationData": {}, "attributes": {"creationDate": "2023-07-19T21:11:57Z", "mfaAuthenticated": "false"}}}}'
},
{
"CloudTrailEvent": '{"eventName": "UpdateLoginProfile", "userIdentity": {"type": "IAMUser", "principalId": "EXAMPLE6E4XEGITWATV6R", "arn": "arn:aws:iam::123456789012:user/Mateo", "accountId": "123456789012", "accessKeyId": "AKIAIOSFODNN7EXAMPLE", "userName": "Mateo", "sessionContext": {"sessionIssuer": {}, "webIdFederationData": {}, "attributes": {"creationDate": "2023-07-19T21:11:57Z", "mfaAuthenticated": "false"}}}}'
"CloudTrailEvent": '{"eventName": "UpdateLoginProfile", "userIdentity": {"type": "IAMUser", "principalId": "EXAMPLE6E4XEGITWATV6R", "arn": "arn:aws:iam::123456789012:user/Attacker", "accountId": "123456789012", "accessKeyId": "AKIAIOSFODNN7EXAMPLE", "userName": "Attacker", "sessionContext": {"sessionIssuer": {}, "webIdFederationData": {}, "attributes": {"creationDate": "2023-07-19T21:11:57Z", "mfaAuthenticated": "false"}}}}'
},
]
@@ -50,12 +50,15 @@ class Test_cloudtrail_threat_detection_privilege_escalation:
cloudtrail_client.audited_account = AWS_ACCOUNT_NUMBER
cloudtrail_client.region = AWS_REGION_US_EAST_1
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_aws_provider(),
), mock.patch(
"prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_privilege_escalation.cloudtrail_threat_detection_privilege_escalation.cloudtrail_client",
new=cloudtrail_client,
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_aws_provider(),
),
mock.patch(
"prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_privilege_escalation.cloudtrail_threat_detection_privilege_escalation.cloudtrail_client",
new=cloudtrail_client,
),
):
# Test Check
from prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_privilege_escalation.cloudtrail_threat_detection_privilege_escalation import (
@@ -97,12 +100,15 @@ class Test_cloudtrail_threat_detection_privilege_escalation:
cloudtrail_client._lookup_events = mock__get_lookup_events__
cloudtrail_client._get_trail_arn_template = mock_get_trail_arn_template
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_aws_provider(),
), mock.patch(
"prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_privilege_escalation.cloudtrail_threat_detection_privilege_escalation.cloudtrail_client",
new=cloudtrail_client,
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_aws_provider(),
),
mock.patch(
"prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_privilege_escalation.cloudtrail_threat_detection_privilege_escalation.cloudtrail_client",
new=cloudtrail_client,
),
):
# Test Check
from prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_privilege_escalation.cloudtrail_threat_detection_privilege_escalation import (
@@ -147,12 +153,15 @@ class Test_cloudtrail_threat_detection_privilege_escalation:
cloudtrail_client._lookup_events = mock__get_lookup_events__
cloudtrail_client._get_trail_arn_template = mock_get_trail_arn_template
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_aws_provider(),
), mock.patch(
"prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_privilege_escalation.cloudtrail_threat_detection_privilege_escalation.cloudtrail_client",
new=cloudtrail_client,
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_aws_provider(),
),
mock.patch(
"prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_privilege_escalation.cloudtrail_threat_detection_privilege_escalation.cloudtrail_client",
new=cloudtrail_client,
),
):
# Test Check
from prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_privilege_escalation.cloudtrail_threat_detection_privilege_escalation import (
@@ -166,13 +175,13 @@ class Test_cloudtrail_threat_detection_privilege_escalation:
assert result[0].status == "FAIL"
assert (
result[0].status_extended
== "Potential privilege escalation attack detected from AWS IAMUser Mateo with an threshold of 1.0."
== "Potential privilege escalation attack detected from AWS IAMUser Attacker with an threshold of 1.0."
)
assert result[0].resource_id == AWS_ACCOUNT_NUMBER
assert result[0].resource_id == "Attacker"
assert result[0].region == AWS_REGION_US_EAST_1
assert (
result[0].resource_arn
== f"arn:aws:cloudtrail:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:trail"
== f"arn:aws:iam::{AWS_ACCOUNT_NUMBER}:user/Attacker"
)
@mock_aws
@@ -197,12 +206,15 @@ class Test_cloudtrail_threat_detection_privilege_escalation:
cloudtrail_client._lookup_events = mock__get_lookup_events__
cloudtrail_client._get_trail_arn_template = mock_get_trail_arn_template
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_aws_provider(),
), mock.patch(
"prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_privilege_escalation.cloudtrail_threat_detection_privilege_escalation.cloudtrail_client",
new=cloudtrail_client,
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_aws_provider(),
),
mock.patch(
"prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_privilege_escalation.cloudtrail_threat_detection_privilege_escalation.cloudtrail_client",
new=cloudtrail_client,
),
):
# Test Check
from prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_privilege_escalation.cloudtrail_threat_detection_privilege_escalation import (
@@ -247,12 +259,15 @@ class Test_cloudtrail_threat_detection_privilege_escalation:
cloudtrail_client._lookup_events = mock__get_lookup_events_aws_service__
cloudtrail_client._get_trail_arn_template = mock_get_trail_arn_template
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_aws_provider(),
), mock.patch(
"prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_privilege_escalation.cloudtrail_threat_detection_privilege_escalation.cloudtrail_client",
new=cloudtrail_client,
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_aws_provider(),
),
mock.patch(
"prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_privilege_escalation.cloudtrail_threat_detection_privilege_escalation.cloudtrail_client",
new=cloudtrail_client,
),
):
# Test Check
from prowler.providers.aws.services.cloudtrail.cloudtrail_threat_detection_privilege_escalation.cloudtrail_threat_detection_privilege_escalation import (
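A detail worth noting across these threat-detection tests: helper methods such as _lookup_events and _get_trail_arn_template are swapped out by plain attribute assignment on the service instance rather than with mock.patch, which works because the check reaches them through the injected cloudtrail_client object. A standalone illustration of the technique:

class Service:
    def _lookup(self):
        return ["real", "events"]

def fake_lookup(*_, **__):
    return []  # deterministic stand-in for the AWS call

svc = Service()
svc._lookup = fake_lookup  # instance-level override, no patching machinery
assert svc._lookup() == []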

View File

@@ -17,12 +17,15 @@ class Test_firehose_stream_encrypted_at_rest:
aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1])
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
), mock.patch(
"prowler.providers.aws.services.firehose.firehose_stream_encrypted_at_rest.firehose_stream_encrypted_at_rest.firehose_client",
new=Firehose(aws_provider),
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
),
mock.patch(
"prowler.providers.aws.services.firehose.firehose_stream_encrypted_at_rest.firehose_stream_encrypted_at_rest.firehose_client",
new=Firehose(aws_provider),
),
):
# Test Check
from prowler.providers.aws.services.firehose.firehose_stream_encrypted_at_rest.firehose_stream_encrypted_at_rest import (
@@ -94,6 +97,65 @@ class Test_firehose_stream_encrypted_at_rest:
== f"Firehose Stream {stream_name} does have at rest encryption enabled."
)
@mock_aws
def test_stream_kms_encryption_enabled_aws_managed_key(self):
# Generate S3 client
s3_client = client("s3", region_name=AWS_REGION_EU_WEST_1)
s3_client.create_bucket(
Bucket="test-bucket",
CreateBucketConfiguration={"LocationConstraint": AWS_REGION_EU_WEST_1},
)
# Generate Firehose client
firehose_client = client("firehose", region_name=AWS_REGION_EU_WEST_1)
delivery_stream = firehose_client.create_delivery_stream(
DeliveryStreamName="test-delivery-stream",
DeliveryStreamType="DirectPut",
S3DestinationConfiguration={
"RoleARN": "arn:aws:iam::012345678901:role/firehose-role",
"BucketARN": "arn:aws:s3:::test-bucket",
"Prefix": "",
"BufferingHints": {"IntervalInSeconds": 300, "SizeInMBs": 5},
"CompressionFormat": "UNCOMPRESSED",
},
Tags=[{"Key": "key", "Value": "value"}],
)
arn = delivery_stream["DeliveryStreamARN"]
stream_name = arn.split("/")[-1]
firehose_client.start_delivery_stream_encryption(
DeliveryStreamName=stream_name,
DeliveryStreamEncryptionConfigurationInput={
"KeyType": "AWS_OWNED_CMK",
},
)
from prowler.providers.aws.services.firehose.firehose_service import Firehose
aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1])
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
):
with mock.patch(
"prowler.providers.aws.services.firehose.firehose_stream_encrypted_at_rest.firehose_stream_encrypted_at_rest.firehose_client",
new=Firehose(aws_provider),
):
# Test Check
from prowler.providers.aws.services.firehose.firehose_stream_encrypted_at_rest.firehose_stream_encrypted_at_rest import (
firehose_stream_encrypted_at_rest,
)
check = firehose_stream_encrypted_at_rest()
result = check.execute()
assert len(result) == 1
assert result[0].status == "PASS"
assert (
result[0].status_extended
== f"Firehose Stream {stream_name} does have at rest encryption enabled."
)
@mock_aws
def test_stream_kms_encryption_not_enabled(self):
# Generate Firehose client
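The new test above enables encryption with KeyType AWS_OWNED_CMK via start_delivery_stream_encryption and expects PASS, so the check evidently treats AWS-owned keys as valid at-rest encryption, not only customer-managed ones. A hypothetical sketch of that decision; the attribute names are illustrative, not the Firehose service model:

# Assumed logic: any enabled server-side encryption passes, whichever party owns the key.
ENCRYPTED_KEY_TYPES = {"AWS_OWNED_CMK", "CUSTOMER_MANAGED_CMK"}
key_type = "AWS_OWNED_CMK"  # as configured by the test above
status = "PASS" if key_type in ENCRYPTED_KEY_TYPES else "FAIL"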

View File

@@ -21,13 +21,16 @@ class Test_iam_rotate_access_key_90_days_test:
aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1])
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
), mock.patch(
"prowler.providers.aws.services.iam.iam_rotate_access_key_90_days.iam_rotate_access_key_90_days.iam_client",
new=IAM(aws_provider),
) as service_client:
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
),
mock.patch(
"prowler.providers.aws.services.iam.iam_rotate_access_key_90_days.iam_rotate_access_key_90_days.iam_client",
new=IAM(aws_provider),
) as service_client,
):
from prowler.providers.aws.services.iam.iam_rotate_access_key_90_days.iam_rotate_access_key_90_days import (
iam_rotate_access_key_90_days,
)
@@ -62,13 +65,16 @@ class Test_iam_rotate_access_key_90_days_test:
aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1])
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
), mock.patch(
"prowler.providers.aws.services.iam.iam_rotate_access_key_90_days.iam_rotate_access_key_90_days.iam_client",
new=IAM(aws_provider),
) as service_client:
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
),
mock.patch(
"prowler.providers.aws.services.iam.iam_rotate_access_key_90_days.iam_rotate_access_key_90_days.iam_client",
new=IAM(aws_provider),
) as service_client,
):
from prowler.providers.aws.services.iam.iam_rotate_access_key_90_days.iam_rotate_access_key_90_days import (
iam_rotate_access_key_90_days,
)
@@ -86,7 +92,7 @@ class Test_iam_rotate_access_key_90_days_test:
result[0].status_extended
== f"User {user} has not rotated access key 1 in over 90 days (100 days)."
)
assert result[0].resource_id == user
assert result[0].resource_id == f"{user}-access-key-1"
assert result[0].resource_arn == arn
assert result[0].region == AWS_REGION_US_EAST_1
assert result[0].resource_tags == [{"Key": "test-tag", "Value": "test"}]
@@ -106,13 +112,16 @@ class Test_iam_rotate_access_key_90_days_test:
aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1])
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
), mock.patch(
"prowler.providers.aws.services.iam.iam_rotate_access_key_90_days.iam_rotate_access_key_90_days.iam_client",
new=IAM(aws_provider),
) as service_client:
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
),
mock.patch(
"prowler.providers.aws.services.iam.iam_rotate_access_key_90_days.iam_rotate_access_key_90_days.iam_client",
new=IAM(aws_provider),
) as service_client,
):
from prowler.providers.aws.services.iam.iam_rotate_access_key_90_days.iam_rotate_access_key_90_days import (
iam_rotate_access_key_90_days,
)
@@ -130,7 +139,7 @@ class Test_iam_rotate_access_key_90_days_test:
result[0].status_extended
== f"User {user} has not rotated access key 2 in over 90 days (100 days)."
)
assert result[0].resource_id == user
assert result[0].resource_id == f"{user}-access-key-2"
assert result[0].resource_arn == arn
assert result[0].region == AWS_REGION_US_EAST_1
assert result[0].resource_tags == [{"Key": "test-tag", "Value": "test"}]
@@ -150,13 +159,16 @@ class Test_iam_rotate_access_key_90_days_test:
aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1])
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
), mock.patch(
"prowler.providers.aws.services.iam.iam_rotate_access_key_90_days.iam_rotate_access_key_90_days.iam_client",
new=IAM(aws_provider),
) as service_client:
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
),
mock.patch(
"prowler.providers.aws.services.iam.iam_rotate_access_key_90_days.iam_rotate_access_key_90_days.iam_client",
new=IAM(aws_provider),
) as service_client,
):
from prowler.providers.aws.services.iam.iam_rotate_access_key_90_days.iam_rotate_access_key_90_days import (
iam_rotate_access_key_90_days,
)
@@ -179,7 +191,7 @@ class Test_iam_rotate_access_key_90_days_test:
result[0].status_extended
== f"User {user} has not rotated access key 1 in over 90 days (100 days)."
)
assert result[0].resource_id == user
assert result[0].resource_id == f"{user}-access-key-1"
assert result[0].resource_arn == arn
assert result[0].region == AWS_REGION_US_EAST_1
assert result[0].resource_tags == [{"Key": "test-tag", "Value": "test"}]
@@ -188,7 +200,7 @@ class Test_iam_rotate_access_key_90_days_test:
result[1].status_extended
== f"User {user} has not rotated access key 2 in over 90 days (100 days)."
)
assert result[1].resource_id == user
assert result[1].resource_id == f"{user}-access-key-2"
assert result[1].resource_arn == arn
assert result[1].region == AWS_REGION_US_EAST_1
assert result[1].resource_tags == [{"Key": "test-tag", "Value": "test"}]
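Across these hunks the resource_id gains an -access-key-N suffix, so a user's two stale keys produce two findings with distinct identifiers instead of colliding on the bare user name; resource_arn stays the user ARN, so uniqueness comes entirely from the composed ID. The pattern, as the new assertions spell it out:

user = "test-user"  # any IAM user name
key_index = 1       # 1 or 2, one finding per access key
resource_id = f"{user}-access-key-{key_index}"  # e.g. "test-user-access-key-1"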

View File

@@ -35,7 +35,7 @@ class Test_rds_cluster_non_default_port:
assert len(result) == 0
@mock_aws
def test_rds_cluster_using_default_port(self):
def test_rds_cluster_aurora_postgres_using_default_port(self):
conn = client("rds", region_name=AWS_REGION_US_EAST_1)
conn.create_db_cluster(
DBClusterIdentifier="db-cluster-1",
@@ -82,10 +82,10 @@ class Test_rds_cluster_non_default_port:
assert result[0].resource_tags == [{"Key": "test", "Value": "test"}]
@mock_aws
def test_rds_cluster_using_non_default_port(self):
def test_rds_cluster_aurora_postgres_using_non_default_port(self):
conn = client("rds", region_name=AWS_REGION_US_EAST_1)
conn.create_db_cluster(
DBClusterIdentifier="db-cluster-1",
DBClusterIdentifier="db-cluster-2",
Engine="aurora-postgresql",
StorageEncrypted=True,
DeletionProtection=True,
@@ -118,13 +118,205 @@ class Test_rds_cluster_non_default_port:
assert result[0].status == "PASS"
assert (
result[0].status_extended
== "RDS Cluster db-cluster-1 is not using the default port 5433 for aurora-postgresql."
== "RDS Cluster db-cluster-2 is not using the default port 5433 for aurora-postgresql."
)
assert result[0].resource_id == "db-cluster-1"
assert result[0].resource_id == "db-cluster-2"
assert result[0].region == AWS_REGION_US_EAST_1
assert (
result[0].resource_arn
== f"arn:aws:rds:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:cluster:db-cluster-1"
== f"arn:aws:rds:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:cluster:db-cluster-2"
)
assert result[0].resource_tags == [
{"Key": "env", "Value": "production"}
]
@mock_aws
def test_rds_cluster_postgres_using_default_port(self):
conn = client("rds", region_name=AWS_REGION_US_EAST_1)
conn.create_db_cluster(
DBClusterIdentifier="db-cluster-3",
Engine="postgres",
StorageEncrypted=True,
DeletionProtection=True,
MasterUsername="cluster",
MasterUserPassword="password",
Port=5432,
Tags=[{"Key": "test", "Value": "test"}],
)
from prowler.providers.aws.services.rds.rds_service import RDS
aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1])
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
):
with mock.patch(
"prowler.providers.aws.services.rds.rds_cluster_non_default_port.rds_cluster_non_default_port.rds_client",
new=RDS(aws_provider),
):
from prowler.providers.aws.services.rds.rds_cluster_non_default_port.rds_cluster_non_default_port import (
rds_cluster_non_default_port,
)
check = rds_cluster_non_default_port()
result = check.execute()
assert len(result) == 1
assert result[0].status == "FAIL"
assert (
result[0].status_extended
== "RDS Cluster db-cluster-3 is using the default port 5432 for postgres."
)
assert result[0].resource_id == "db-cluster-3"
assert result[0].region == AWS_REGION_US_EAST_1
assert (
result[0].resource_arn
== f"arn:aws:rds:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:cluster:db-cluster-3"
)
assert result[0].resource_tags == [{"Key": "test", "Value": "test"}]
@mock_aws
def test_rds_cluster_postgres_using_non_default_port(self):
conn = client("rds", region_name=AWS_REGION_US_EAST_1)
conn.create_db_cluster(
DBClusterIdentifier="db-cluster-4",
Engine="postgres",
StorageEncrypted=True,
DeletionProtection=True,
MasterUsername="cluster",
MasterUserPassword="password",
Port=5433,
Tags=[{"Key": "env", "Value": "production"}],
)
from prowler.providers.aws.services.rds.rds_service import RDS
aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1])
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
):
with mock.patch(
"prowler.providers.aws.services.rds.rds_cluster_non_default_port.rds_cluster_non_default_port.rds_client",
new=RDS(aws_provider),
):
from prowler.providers.aws.services.rds.rds_cluster_non_default_port.rds_cluster_non_default_port import (
rds_cluster_non_default_port,
)
check = rds_cluster_non_default_port()
result = check.execute()
assert len(result) == 1
assert result[0].status == "PASS"
assert (
result[0].status_extended
== "RDS Cluster db-cluster-4 is not using the default port 5433 for postgres."
)
assert result[0].resource_id == "db-cluster-4"
assert result[0].region == AWS_REGION_US_EAST_1
assert (
result[0].resource_arn
== f"arn:aws:rds:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:cluster:db-cluster-4"
)
assert result[0].resource_tags == [
{"Key": "env", "Value": "production"}
]
@mock_aws
def test_rds_cluster_aurora_mysql_default_port(self):
conn = client("rds", region_name=AWS_REGION_US_EAST_1)
conn.create_db_cluster(
DBClusterIdentifier="db-cluster-5",
Engine="aurora-mysql",
StorageEncrypted=True,
DeletionProtection=True,
MasterUsername="cluster",
MasterUserPassword="password",
Port=3306,
Tags=[{"Key": "env", "Value": "staging"}],
)
from prowler.providers.aws.services.rds.rds_service import RDS
aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1])
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
):
with mock.patch(
"prowler.providers.aws.services.rds.rds_cluster_non_default_port.rds_cluster_non_default_port.rds_client",
new=RDS(aws_provider),
):
from prowler.providers.aws.services.rds.rds_cluster_non_default_port.rds_cluster_non_default_port import (
rds_cluster_non_default_port,
)
check = rds_cluster_non_default_port()
result = check.execute()
assert len(result) == 1
assert result[0].status == "FAIL"
assert (
result[0].status_extended
== "RDS Cluster db-cluster-5 is using the default port 3306 for aurora-mysql."
)
assert result[0].resource_id == "db-cluster-5"
assert result[0].region == AWS_REGION_US_EAST_1
assert (
result[0].resource_arn
== f"arn:aws:rds:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:cluster:db-cluster-5"
)
assert result[0].resource_tags == [{"Key": "env", "Value": "staging"}]
@mock_aws
def test_rds_cluster_aurora_mysql_non_default_port(self):
conn = client("rds", region_name=AWS_REGION_US_EAST_1)
conn.create_db_cluster(
DBClusterIdentifier="db-cluster-6",
Engine="aurora-mysql",
StorageEncrypted=True,
DeletionProtection=True,
MasterUsername="cluster",
MasterUserPassword="password",
Port=3307,
Tags=[{"Key": "env", "Value": "production"}],
)
from prowler.providers.aws.services.rds.rds_service import RDS
aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1])
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
):
with mock.patch(
"prowler.providers.aws.services.rds.rds_cluster_non_default_port.rds_cluster_non_default_port.rds_client",
new=RDS(aws_provider),
):
from prowler.providers.aws.services.rds.rds_cluster_non_default_port.rds_cluster_non_default_port import (
rds_cluster_non_default_port,
)
check = rds_cluster_non_default_port()
result = check.execute()
assert len(result) == 1
assert result[0].status == "PASS"
assert (
result[0].status_extended
== "RDS Cluster db-cluster-6 is not using the default port 3307 for aurora-mysql."
)
assert result[0].resource_id == "db-cluster-6"
assert result[0].region == AWS_REGION_US_EAST_1
assert (
result[0].resource_arn
== f"arn:aws:rds:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:cluster:db-cluster-6"
)
assert result[0].resource_tags == [
{"Key": "env", "Value": "production"}
@@ -134,7 +326,7 @@ class Test_rds_cluster_non_default_port:
def test_rds_cluster_mysql_default_port(self):
conn = client("rds", region_name=AWS_REGION_US_EAST_1)
conn.create_db_cluster(
DBClusterIdentifier="db-cluster-1",
DBClusterIdentifier="db-cluster-7",
Engine="mysql",
StorageEncrypted=True,
DeletionProtection=True,
@@ -167,13 +359,13 @@ class Test_rds_cluster_non_default_port:
assert result[0].status == "FAIL"
assert (
result[0].status_extended
== "RDS Cluster db-cluster-1 is using the default port 3306 for mysql."
== "RDS Cluster db-cluster-7 is using the default port 3306 for mysql."
)
assert result[0].resource_id == "db-cluster-1"
assert result[0].resource_id == "db-cluster-7"
assert result[0].region == AWS_REGION_US_EAST_1
assert (
result[0].resource_arn
== f"arn:aws:rds:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:cluster:db-cluster-1"
== f"arn:aws:rds:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:cluster:db-cluster-7"
)
assert result[0].resource_tags == [{"Key": "env", "Value": "staging"}]
@@ -181,7 +373,7 @@ class Test_rds_cluster_non_default_port:
def test_rds_cluster_mysql_non_default_port(self):
conn = client("rds", region_name=AWS_REGION_US_EAST_1)
conn.create_db_cluster(
DBClusterIdentifier="db-cluster-1",
DBClusterIdentifier="db-cluster-8",
Engine="mysql",
StorageEncrypted=True,
DeletionProtection=True,
@@ -214,13 +406,13 @@ class Test_rds_cluster_non_default_port:
assert result[0].status == "PASS"
assert (
result[0].status_extended
== "RDS Cluster db-cluster-1 is not using the default port 3307 for mysql."
== "RDS Cluster db-cluster-8 is not using the default port 3307 for mysql."
)
assert result[0].resource_id == "db-cluster-1"
assert result[0].resource_id == "db-cluster-8"
assert result[0].region == AWS_REGION_US_EAST_1
assert (
result[0].resource_arn
== f"arn:aws:rds:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:cluster:db-cluster-1"
== f"arn:aws:rds:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:cluster:db-cluster-8"
)
assert result[0].resource_tags == [
{"Key": "env", "Value": "production"}

View File

@@ -35,11 +35,115 @@ class Test_rds_instance_non_default_port:
assert len(result) == 0
@mock_aws
def test_rds_instance_using_default_port(self):
def test_rds_instance_aurora_postgres_using_default_port(self):
conn = client("rds", region_name=AWS_REGION_US_EAST_1)
conn.create_db_instance(
DBInstanceIdentifier="db-master-1",
AllocatedStorage=10,
Engine="aurora-postgresql",
DBName="staging-postgres",
DBInstanceClass="db.m1.small",
StorageEncrypted=True,
DeletionProtection=True,
PubliclyAccessible=True,
AutoMinorVersionUpgrade=True,
BackupRetentionPeriod=10,
Port=5432,
Tags=[{"Key": "test", "Value": "test"}],
)
from prowler.providers.aws.services.rds.rds_service import RDS
aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1])
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
):
with mock.patch(
"prowler.providers.aws.services.rds.rds_instance_non_default_port.rds_instance_non_default_port.rds_client",
new=RDS(aws_provider),
):
from prowler.providers.aws.services.rds.rds_instance_non_default_port.rds_instance_non_default_port import (
rds_instance_non_default_port,
)
check = rds_instance_non_default_port()
result = check.execute()
assert len(result) == 1
assert result[0].status == "FAIL"
assert (
result[0].status_extended
== "RDS Instance db-master-1 is using the default port 5432 for aurora-postgresql."
)
assert result[0].resource_id == "db-master-1"
assert result[0].region == AWS_REGION_US_EAST_1
assert (
result[0].resource_arn
== f"arn:aws:rds:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:db:db-master-1"
)
assert result[0].resource_tags == [{"Key": "test", "Value": "test"}]
@mock_aws
def test_rds_instance_aurora_postgres_using_non_default_port(self):
conn = client("rds", region_name=AWS_REGION_US_EAST_1)
conn.create_db_instance(
DBInstanceIdentifier="db-master-2",
AllocatedStorage=10,
Engine="aurora-postgresql",
DBName="production-postgres",
DBInstanceClass="db.m1.small",
StorageEncrypted=True,
DeletionProtection=True,
PubliclyAccessible=True,
AutoMinorVersionUpgrade=True,
BackupRetentionPeriod=10,
Port=5433,
Tags=[{"Key": "env", "Value": "production"}],
)
from prowler.providers.aws.services.rds.rds_service import RDS
aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1])
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
):
with mock.patch(
"prowler.providers.aws.services.rds.rds_instance_non_default_port.rds_instance_non_default_port.rds_client",
new=RDS(aws_provider),
):
from prowler.providers.aws.services.rds.rds_instance_non_default_port.rds_instance_non_default_port import (
rds_instance_non_default_port,
)
check = rds_instance_non_default_port()
result = check.execute()
assert len(result) == 1
assert result[0].status == "PASS"
assert (
result[0].status_extended
== "RDS Instance db-master-2 is not using the default port 5433 for aurora-postgresql."
)
assert result[0].resource_id == "db-master-2"
assert result[0].region == AWS_REGION_US_EAST_1
assert (
result[0].resource_arn
== f"arn:aws:rds:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:db:db-master-2"
)
assert result[0].resource_tags == [
{"Key": "env", "Value": "production"}
]
@mock_aws
def test_rds_instance_postgres_using_default_port(self):
conn = client("rds", region_name=AWS_REGION_US_EAST_1)
conn.create_db_instance(
DBInstanceIdentifier="db-master-3",
AllocatedStorage=10,
Engine="postgres",
DBName="staging-postgres",
DBInstanceClass="db.m1.small",
@@ -75,21 +179,21 @@ class Test_rds_instance_non_default_port:
assert result[0].status == "FAIL"
assert (
result[0].status_extended
== "RDS Instance db-master-1 is using the default port 5432 for postgres."
== "RDS Instance db-master-3 is using the default port 5432 for postgres."
)
assert result[0].resource_id == "db-master-1"
assert result[0].resource_id == "db-master-3"
assert result[0].region == AWS_REGION_US_EAST_1
assert (
result[0].resource_arn
== f"arn:aws:rds:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:db:db-master-1"
== f"arn:aws:rds:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:db:db-master-3"
)
assert result[0].resource_tags == [{"Key": "test", "Value": "test"}]
@mock_aws
def test_rds_instance_using_non_default_port(self):
def test_rds_instance_postgres_using_non_default_port(self):
conn = client("rds", region_name=AWS_REGION_US_EAST_1)
conn.create_db_instance(
DBInstanceIdentifier="db-master-2",
DBInstanceIdentifier="db-master-4",
AllocatedStorage=10,
Engine="postgres",
DBName="production-postgres",
@@ -126,13 +230,221 @@ class Test_rds_instance_non_default_port:
assert result[0].status == "PASS"
assert (
result[0].status_extended
== "RDS Instance db-master-2 is not using the default port 5433 for postgres."
== "RDS Instance db-master-4 is not using the default port 5433 for postgres."
)
assert result[0].resource_id == "db-master-2"
assert result[0].resource_id == "db-master-4"
assert result[0].region == AWS_REGION_US_EAST_1
assert (
result[0].resource_arn
== f"arn:aws:rds:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:db:db-master-2"
== f"arn:aws:rds:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:db:db-master-4"
)
assert result[0].resource_tags == [
{"Key": "env", "Value": "production"}
]
@mock_aws
def test_rds_instance_mysql_default_port(self):
conn = client("rds", region_name=AWS_REGION_US_EAST_1)
conn.create_db_instance(
DBInstanceIdentifier="db-master-5",
AllocatedStorage=10,
Engine="mysql",
DBName="staging-mariadb",
DBInstanceClass="db.m1.small",
StorageEncrypted=True,
DeletionProtection=True,
PubliclyAccessible=True,
AutoMinorVersionUpgrade=True,
BackupRetentionPeriod=10,
Port=3306,
Tags=[{"Key": "env", "Value": "staging"}],
)
from prowler.providers.aws.services.rds.rds_service import RDS
aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1])
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
):
with mock.patch(
"prowler.providers.aws.services.rds.rds_instance_non_default_port.rds_instance_non_default_port.rds_client",
new=RDS(aws_provider),
):
from prowler.providers.aws.services.rds.rds_instance_non_default_port.rds_instance_non_default_port import (
rds_instance_non_default_port,
)
check = rds_instance_non_default_port()
result = check.execute()
assert len(result) == 1
assert result[0].status == "FAIL"
assert (
result[0].status_extended
== "RDS Instance db-master-5 is using the default port 3306 for mysql."
)
assert result[0].resource_id == "db-master-5"
assert result[0].region == AWS_REGION_US_EAST_1
assert (
result[0].resource_arn
== f"arn:aws:rds:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:db:db-master-5"
)
assert result[0].resource_tags == [{"Key": "env", "Value": "staging"}]
@mock_aws
def test_rds_instance_mysql_non_default_port(self):
conn = client("rds", region_name=AWS_REGION_US_EAST_1)
conn.create_db_instance(
DBInstanceIdentifier="db-master-6",
AllocatedStorage=10,
Engine="mysql",
DBName="production-mariadb",
DBInstanceClass="db.m1.small",
StorageEncrypted=True,
DeletionProtection=True,
PubliclyAccessible=True,
AutoMinorVersionUpgrade=True,
BackupRetentionPeriod=10,
Port=3307,
Tags=[{"Key": "env", "Value": "production"}],
)
from prowler.providers.aws.services.rds.rds_service import RDS
aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1])
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
):
with mock.patch(
"prowler.providers.aws.services.rds.rds_instance_non_default_port.rds_instance_non_default_port.rds_client",
new=RDS(aws_provider),
):
from prowler.providers.aws.services.rds.rds_instance_non_default_port.rds_instance_non_default_port import (
rds_instance_non_default_port,
)
check = rds_instance_non_default_port()
result = check.execute()
assert len(result) == 1
assert result[0].status == "PASS"
assert (
result[0].status_extended
== "RDS Instance db-master-6 is not using the default port 3307 for mysql."
)
assert result[0].resource_id == "db-master-6"
assert result[0].region == AWS_REGION_US_EAST_1
assert (
result[0].resource_arn
== f"arn:aws:rds:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:db:db-master-6"
)
assert result[0].resource_tags == [
{"Key": "env", "Value": "production"}
]
@mock_aws
def test_rds_instance_aurora_mysql_default_port(self):
conn = client("rds", region_name=AWS_REGION_US_EAST_1)
conn.create_db_instance(
DBInstanceIdentifier="db-master-7",
AllocatedStorage=10,
Engine="aurora-mysql",
DBName="staging-mariadb",
DBInstanceClass="db.m1.small",
StorageEncrypted=True,
DeletionProtection=True,
PubliclyAccessible=True,
AutoMinorVersionUpgrade=True,
BackupRetentionPeriod=10,
Port=3306,
Tags=[{"Key": "env", "Value": "staging"}],
)
from prowler.providers.aws.services.rds.rds_service import RDS
aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1])
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
):
with mock.patch(
"prowler.providers.aws.services.rds.rds_instance_non_default_port.rds_instance_non_default_port.rds_client",
new=RDS(aws_provider),
):
from prowler.providers.aws.services.rds.rds_instance_non_default_port.rds_instance_non_default_port import (
rds_instance_non_default_port,
)
check = rds_instance_non_default_port()
result = check.execute()
assert len(result) == 1
assert result[0].status == "FAIL"
assert (
result[0].status_extended
== "RDS Instance db-master-7 is using the default port 3306 for aurora-mysql."
)
assert result[0].resource_id == "db-master-7"
assert result[0].region == AWS_REGION_US_EAST_1
assert (
result[0].resource_arn
== f"arn:aws:rds:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:db:db-master-7"
)
assert result[0].resource_tags == [{"Key": "env", "Value": "staging"}]
@mock_aws
def test_rds_instance_aurora_mysql_non_default_port(self):
conn = client("rds", region_name=AWS_REGION_US_EAST_1)
conn.create_db_instance(
DBInstanceIdentifier="db-master-8",
AllocatedStorage=10,
Engine="aurora-mysql",
DBName="production-mariadb",
DBInstanceClass="db.m1.small",
StorageEncrypted=True,
DeletionProtection=True,
PubliclyAccessible=True,
AutoMinorVersionUpgrade=True,
BackupRetentionPeriod=10,
Port=3307,
Tags=[{"Key": "env", "Value": "production"}],
)
from prowler.providers.aws.services.rds.rds_service import RDS
aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1])
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
):
with mock.patch(
"prowler.providers.aws.services.rds.rds_instance_non_default_port.rds_instance_non_default_port.rds_client",
new=RDS(aws_provider),
):
from prowler.providers.aws.services.rds.rds_instance_non_default_port.rds_instance_non_default_port import (
rds_instance_non_default_port,
)
check = rds_instance_non_default_port()
result = check.execute()
assert len(result) == 1
assert result[0].status == "PASS"
assert (
result[0].status_extended
== "RDS Instance db-master-8 is not using the default port 3307 for aurora-mysql."
)
assert result[0].resource_id == "db-master-8"
assert result[0].region == AWS_REGION_US_EAST_1
assert (
result[0].resource_arn
== f"arn:aws:rds:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:db:db-master-8"
)
assert result[0].resource_tags == [
{"Key": "env", "Value": "production"}
@@ -142,7 +454,7 @@ class Test_rds_instance_non_default_port:
def test_rds_instance_mariadb_default_port(self):
conn = client("rds", region_name=AWS_REGION_US_EAST_1)
conn.create_db_instance(
DBInstanceIdentifier="db-master-3",
DBInstanceIdentifier="db-master-9",
AllocatedStorage=10,
Engine="mariadb",
DBName="staging-mariadb",
@@ -179,13 +491,13 @@ class Test_rds_instance_non_default_port:
assert result[0].status == "FAIL"
assert (
result[0].status_extended
== "RDS Instance db-master-3 is using the default port 3306 for mariadb."
== "RDS Instance db-master-9 is using the default port 3306 for mariadb."
)
assert result[0].resource_id == "db-master-3"
assert result[0].resource_id == "db-master-9"
assert result[0].region == AWS_REGION_US_EAST_1
assert (
result[0].resource_arn
== f"arn:aws:rds:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:db:db-master-3"
== f"arn:aws:rds:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:db:db-master-9"
)
assert result[0].resource_tags == [{"Key": "env", "Value": "staging"}]
@@ -193,7 +505,7 @@ class Test_rds_instance_non_default_port:
def test_rds_instance_mariadb_non_default_port(self):
conn = client("rds", region_name=AWS_REGION_US_EAST_1)
conn.create_db_instance(
DBInstanceIdentifier="db-master-4",
DBInstanceIdentifier="db-master-10",
AllocatedStorage=10,
Engine="mariadb",
DBName="production-mariadb",
@@ -230,13 +542,13 @@ class Test_rds_instance_non_default_port:
assert result[0].status == "PASS"
assert (
result[0].status_extended
== "RDS Instance db-master-4 is not using the default port 3307 for mariadb."
== "RDS Instance db-master-10 is not using the default port 3307 for mariadb."
)
assert result[0].resource_id == "db-master-4"
assert result[0].resource_id == "db-master-10"
assert result[0].region == AWS_REGION_US_EAST_1
assert (
result[0].resource_arn
== f"arn:aws:rds:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:db:db-master-4"
== f"arn:aws:rds:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:db:db-master-10"
)
assert result[0].resource_tags == [
{"Key": "env", "Value": "production"}

View File

@@ -131,7 +131,10 @@ class Test_route53_dangling_ip_subdomain_takeover:
)
assert (
result[0].resource_id
== zone_id.replace("/hostedzone/", "") + "/192.168.1.1"
== zone_id.replace("/hostedzone/", "")
+ "/"
+ record_set_name
+ "/192.168.1.1"
)
assert (
result[0].resource_arn
@@ -196,7 +199,10 @@ class Test_route53_dangling_ip_subdomain_takeover:
)
assert (
result[0].resource_id
== zone_id.replace("/hostedzone/", "") + "/17.5.7.3"
== zone_id.replace("/hostedzone/", "")
+ "/"
+ record_set_name
+ "/17.5.7.3"
)
assert (
result[0].resource_arn
@@ -261,7 +267,10 @@ class Test_route53_dangling_ip_subdomain_takeover:
)
assert (
result[0].resource_id
== zone_id.replace("/hostedzone/", "") + "/54.152.12.70"
== zone_id.replace("/hostedzone/", "")
+ "/"
+ record_set_name
+ "/54.152.12.70"
)
assert (
result[0].resource_arn
@@ -330,7 +339,10 @@ class Test_route53_dangling_ip_subdomain_takeover:
)
assert (
result[0].resource_id
== zone_id.replace("/hostedzone/", "") + "/17.5.7.3"
== zone_id.replace("/hostedzone/", "")
+ "/"
+ record_set_name
+ "/17.5.7.3"
)
assert (
result[0].resource_arn
@@ -405,7 +417,10 @@ class Test_route53_dangling_ip_subdomain_takeover:
)
assert (
result[0].resource_id
== zone_id.replace("/hostedzone/", "") + "/17.5.7.3"
== zone_id.replace("/hostedzone/", "")
+ "/"
+ record_set_name
+ "/17.5.7.3"
)
assert (
result[0].resource_arn
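Each hunk in this file widens the composed resource_id from zone-id/ip to zone-id/record-name/ip, so two record sets pointing at the same dangling address no longer share one finding. The composition, mirroring the assertions (the zone ID and record name values below are made up for illustration):

zone_id = "/hostedzone/Z0000000EXAMPLE"
record_set_name = "www.example.com"
ip = "17.5.7.3"
resource_id = zone_id.replace("/hostedzone/", "") + "/" + record_set_name + "/" + ip
# -> "Z0000000EXAMPLE/www.example.com/17.5.7.3"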

View File

@@ -1,10 +1,61 @@
from unittest import mock
from unittest.mock import patch
import botocore
from boto3 import client
from moto import mock_aws
from tests.providers.aws.utils import AWS_REGION_US_EAST_1, set_mocked_aws_provider
# Original botocore _make_api_call function
orig = botocore.client.BaseClient._make_api_call
FM_RG_NAME = "test-firewall-managed-rule-group"
FM_RG_ARN = "arn:aws:wafv2:us-east-1:123456789012:regional/webacl/test-firewall-managed-rule-group"
# Mocked botocore _make_api_call function
def mock_make_api_call(self, operation_name, kwarg):
if operation_name == "ListWebACLs":
return {
"WebACLs": [
{
"Name": FM_RG_NAME,
"Id": FM_RG_NAME,
"ARN": FM_RG_ARN,
}
]
}
elif operation_name == "GetWebACL":
return {
"WebACL": {
"PostProcessFirewallManagerRuleGroups": [
{
"Name": FM_RG_NAME,
"VisibilityConfig": {
"SampledRequestsEnabled": True,
"CloudWatchMetricsEnabled": True,
"MetricName": "web-acl-test-metric",
},
}
]
}
}
elif operation_name == "ListResourcesForWebACL":
return {
"ResourceArns": [
FM_RG_ARN,
]
}
elif operation_name == "ListTagsForResource":
return {
"TagInfoForResource": {
"ResourceARN": FM_RG_ARN,
"TagList": [{"Key": "Name", "Value": FM_RG_NAME}],
}
}
return orig(self, operation_name, kwarg)
class Test_wafv2_webacl_with_rules:
@mock_aws
@@ -13,12 +64,15 @@ class Test_wafv2_webacl_with_rules:
aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1])
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
), mock.patch(
"prowler.providers.aws.services.wafv2.wafv2_webacl_with_rules.wafv2_webacl_with_rules.wafv2_client",
new=WAFv2(aws_provider),
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
),
mock.patch(
"prowler.providers.aws.services.wafv2.wafv2_webacl_with_rules.wafv2_webacl_with_rules.wafv2_client",
new=WAFv2(aws_provider),
),
):
from prowler.providers.aws.services.wafv2.wafv2_webacl_with_rules.wafv2_webacl_with_rules import (
wafv2_webacl_with_rules,
@@ -69,12 +123,15 @@ class Test_wafv2_webacl_with_rules:
aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1])
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
), mock.patch(
"prowler.providers.aws.services.wafv2.wafv2_webacl_with_rules.wafv2_webacl_with_rules.wafv2_client",
new=WAFv2(aws_provider),
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
),
mock.patch(
"prowler.providers.aws.services.wafv2.wafv2_webacl_with_rules.wafv2_webacl_with_rules.wafv2_client",
new=WAFv2(aws_provider),
),
):
from prowler.providers.aws.services.wafv2.wafv2_webacl_with_rules.wafv2_webacl_with_rules import (
wafv2_webacl_with_rules,
@@ -137,12 +194,15 @@ class Test_wafv2_webacl_with_rules:
aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1])
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
), mock.patch(
"prowler.providers.aws.services.wafv2.wafv2_webacl_with_rules.wafv2_webacl_with_rules.wafv2_client",
new=WAFv2(aws_provider),
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
),
mock.patch(
"prowler.providers.aws.services.wafv2.wafv2_webacl_with_rules.wafv2_webacl_with_rules.wafv2_client",
new=WAFv2(aws_provider),
),
):
from prowler.providers.aws.services.wafv2.wafv2_webacl_with_rules.wafv2_webacl_with_rules import (
wafv2_webacl_with_rules,
@@ -161,6 +221,43 @@ class Test_wafv2_webacl_with_rules:
assert result[0].region == AWS_REGION_US_EAST_1
assert result[0].resource_tags == [{"Key": "Name", "Value": waf_name}]

    @patch(
        "botocore.client.BaseClient._make_api_call",
        new=mock_make_api_call,
    )
    @mock_aws
    def test_wafv2_web_acl_with_firewall_manager_managed_rule_group(self):
        from prowler.providers.aws.services.wafv2.wafv2_service import WAFv2

        aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1])

        with (
            mock.patch(
                "prowler.providers.common.provider.Provider.get_global_provider",
                return_value=aws_provider,
            ),
            mock.patch(
                "prowler.providers.aws.services.wafv2.wafv2_webacl_with_rules.wafv2_webacl_with_rules.wafv2_client",
                new=WAFv2(aws_provider),
            ),
        ):
            from prowler.providers.aws.services.wafv2.wafv2_webacl_with_rules.wafv2_webacl_with_rules import (
                wafv2_webacl_with_rules,
            )

            check = wafv2_webacl_with_rules()
            result = check.execute()

            assert len(result) == 1
            assert result[0].status == "PASS"
            assert (
                result[0].status_extended
                == f"AWS WAFv2 Web ACL {FM_RG_NAME} does have rules or rule groups attached."
            )
            assert result[0].resource_id == FM_RG_NAME
            assert result[0].resource_arn == FM_RG_ARN
            assert result[0].region == AWS_REGION_US_EAST_1
            assert result[0].resource_tags == [{"Key": "Name", "Value": FM_RG_NAME}]
@mock_aws
def test_wafv2_web_acl_without_rule_or_rule_group(self):
wafv2_client = client("wafv2", region_name=AWS_REGION_US_EAST_1)
@@ -184,12 +281,15 @@ class Test_wafv2_webacl_with_rules:
aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1])
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
), mock.patch(
"prowler.providers.aws.services.wafv2.wafv2_webacl_with_rules.wafv2_webacl_with_rules.wafv2_client",
new=WAFv2(aws_provider),
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
),
mock.patch(
"prowler.providers.aws.services.wafv2.wafv2_webacl_with_rules.wafv2_webacl_with_rules.wafv2_client",
new=WAFv2(aws_provider),
),
):
from prowler.providers.aws.services.wafv2.wafv2_webacl_with_rules.wafv2_webacl_with_rules import (
wafv2_webacl_with_rules,

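The mock above relies on a standard moto escape hatch: patch botocore.client.BaseClient._make_api_call, intercept only the operations moto does not model (here, the Firewall Manager rule groups returned by GetWebACL), and delegate everything else to the saved original. A stripped-down sketch of the pattern, with a made-up operation name:

import botocore.client

orig = botocore.client.BaseClient._make_api_call  # saved original to fall back on

def fake_api_call(self, operation_name, kwargs):
    # Stub only what moto cannot answer; "SomeUnmodeledOperation" is hypothetical.
    if operation_name == "SomeUnmodeledOperation":
        return {"Stubbed": True}
    # Every other call goes through the real (or moto-patched) implementation.
    return orig(self, operation_name, kwargs)

# Applied exactly as in the test above:
#   @patch("botocore.client.BaseClient._make_api_call", new=fake_api_call)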

@@ -171,3 +171,45 @@ class Test_app_minimum_tls_version_12:
assert result[0].resource_name == "app_id-1"
assert result[0].subscription == AZURE_SUBSCRIPTION_ID
assert result[0].location == "West Europe"

    def test_app_min_tls_version_13(self):
        resource_id = f"/subscriptions/{uuid4()}"
        app_client = mock.MagicMock

        with mock.patch(
            "prowler.providers.common.provider.Provider.get_global_provider",
            return_value=set_mocked_azure_provider(),
        ), mock.patch(
            "prowler.providers.azure.services.app.app_minimum_tls_version_12.app_minimum_tls_version_12.app_client",
            new=app_client,
        ):
            from prowler.providers.azure.services.app.app_minimum_tls_version_12.app_minimum_tls_version_12 import (
                app_minimum_tls_version_12,
            )
            from prowler.providers.azure.services.app.app_service import WebApp

            app_client.apps = {
                AZURE_SUBSCRIPTION_ID: {
                    "app_id-1": WebApp(
                        resource_id=resource_id,
                        auth_enabled=False,
                        configurations=mock.MagicMock(min_tls_version="1.3"),
                        client_cert_mode="Ignore",
                        https_only=False,
                        identity=None,
                        location="West Europe",
                    )
                }
            }

            check = app_minimum_tls_version_12()
            result = check.execute()

            assert len(result) == 1
            assert result[0].status == "PASS"
            assert (
                result[0].status_extended
                == f"Minimum TLS version is set to 1.3 for app 'app_id-1' in subscription '{AZURE_SUBSCRIPTION_ID}'."
            )
            assert result[0].resource_id == resource_id
            assert result[0].resource_name == "app_id-1"
            assert result[0].subscription == AZURE_SUBSCRIPTION_ID
            assert result[0].location == "West Europe"

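This new PASS case pins down the TLS 1.3 fix: a check named app_minimum_tls_version_12 must accept anything at or above 1.2, which a naive string equality against "1.2" would reject. A hedged sketch of the comparison such a check needs (assumed logic, not necessarily Prowler's implementation):

def meets_minimum_tls(min_tls_version: str, required: str = "1.2") -> bool:
    # Compare numerically rather than by string equality so "1.3" clears a ">= 1.2" bar.
    return float(min_tls_version) >= float(required)

assert meets_minimum_tls("1.2") and meets_minimum_tls("1.3")
assert not meets_minimum_tls("1.0")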

@@ -51,18 +51,23 @@ class TestGCPProvider:
execute=MagicMock(return_value={"projects": projects})
)
with patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.setup_session",
return_value=(None, "test-project"),
), patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.get_projects",
return_value=projects,
), patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.update_projects_with_organizations",
return_value=None,
), patch(
"prowler.providers.gcp.gcp_provider.discovery.build",
return_value=mocked_service,
with (
patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.setup_session",
return_value=(None, "test-project"),
),
patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.get_projects",
return_value=projects,
),
patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.update_projects_with_organizations",
return_value=None,
),
patch(
"prowler.providers.gcp.gcp_provider.discovery.build",
return_value=mocked_service,
),
):
gcp_provider = GcpProvider(
project_id,
@@ -119,18 +124,23 @@ class TestGCPProvider:
mocked_service.projects.list.return_value = MagicMock(
execute=MagicMock(return_value={"projects": projects})
)
with patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.setup_session",
return_value=(None, None),
), patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.get_projects",
return_value=projects,
), patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.update_projects_with_organizations",
return_value=None,
), patch(
"prowler.providers.gcp.gcp_provider.discovery.build",
return_value=mocked_service,
with (
patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.setup_session",
return_value=(None, None),
),
patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.get_projects",
return_value=projects,
),
patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.update_projects_with_organizations",
return_value=None,
),
patch(
"prowler.providers.gcp.gcp_provider.discovery.build",
return_value=mocked_service,
),
):
gcp_provider = GcpProvider(
arguments.organization_id,
@@ -193,21 +203,27 @@ class TestGCPProvider:
mocked_service.projects.list.return_value = MagicMock(
execute=MagicMock(return_value={"projects": projects})
)
with patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.get_projects",
return_value=projects,
), patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.update_projects_with_organizations",
return_value=None,
), patch(
"os.path.abspath",
return_value="test_credentials_file",
), patch(
"prowler.providers.gcp.gcp_provider.default",
return_value=(mocked_credentials, MagicMock()),
), patch(
"prowler.providers.gcp.gcp_provider.discovery.build",
return_value=mocked_service,
with (
patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.get_projects",
return_value=projects,
),
patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.update_projects_with_organizations",
return_value=None,
),
patch(
"os.path.abspath",
return_value="test_credentials_file",
),
patch(
"prowler.providers.gcp.gcp_provider.default",
return_value=(mocked_credentials, MagicMock()),
),
patch(
"prowler.providers.gcp.gcp_provider.discovery.build",
return_value=mocked_service,
),
):
gcp_provider = GcpProvider(
arguments.organization_id,
@@ -257,21 +273,27 @@ class TestGCPProvider:
mocked_service.projects.list.return_value = MagicMock(
execute=MagicMock(return_value={"projects": projects})
)
with patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.get_projects",
return_value=projects,
), patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.update_projects_with_organizations",
return_value=None,
), patch(
"os.path.abspath",
return_value="test_credentials_file",
), patch(
"prowler.providers.gcp.gcp_provider.default",
return_value=(mocked_credentials, MagicMock()),
), patch(
"prowler.providers.gcp.gcp_provider.discovery.build",
return_value=mocked_service,
with (
patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.get_projects",
return_value=projects,
),
patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.update_projects_with_organizations",
return_value=None,
),
patch(
"os.path.abspath",
return_value="test_credentials_file",
),
patch(
"prowler.providers.gcp.gcp_provider.default",
return_value=(mocked_credentials, MagicMock()),
),
patch(
"prowler.providers.gcp.gcp_provider.discovery.build",
return_value=mocked_service,
),
):
gcp_provider = GcpProvider(
arguments.organization_id,
@@ -334,21 +356,27 @@ class TestGCPProvider:
mocked_service.projects.list.return_value = MagicMock(
execute=MagicMock(return_value={"projects": projects})
)
with patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.get_projects",
return_value=projects,
), patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.update_projects_with_organizations",
return_value=None,
), patch(
"os.path.abspath",
return_value="test_credentials_file",
), patch(
"prowler.providers.gcp.gcp_provider.default",
return_value=(mocked_credentials, MagicMock()),
), patch(
"prowler.providers.gcp.gcp_provider.discovery.build",
return_value=mocked_service,
with (
patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.get_projects",
return_value=projects,
),
patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.update_projects_with_organizations",
return_value=None,
),
patch(
"os.path.abspath",
return_value="test_credentials_file",
),
patch(
"prowler.providers.gcp.gcp_provider.default",
return_value=(mocked_credentials, MagicMock()),
),
patch(
"prowler.providers.gcp.gcp_provider.discovery.build",
return_value=mocked_service,
),
):
gcp_provider = GcpProvider(
arguments.organization_id,
@@ -401,21 +429,27 @@ class TestGCPProvider:
mocked_service.projects.list.return_value = MagicMock(
execute=MagicMock(return_value={"projects": projects})
)
with patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.get_projects",
return_value=projects,
), patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.update_projects_with_organizations",
return_value=None,
), patch(
"os.path.abspath",
return_value="test_credentials_file",
), patch(
"prowler.providers.gcp.gcp_provider.default",
return_value=(mocked_credentials, MagicMock()),
), patch(
"prowler.providers.gcp.gcp_provider.discovery.build",
return_value=mocked_service,
with (
patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.get_projects",
return_value=projects,
),
patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.update_projects_with_organizations",
return_value=None,
),
patch(
"os.path.abspath",
return_value="test_credentials_file",
),
patch(
"prowler.providers.gcp.gcp_provider.default",
return_value=(mocked_credentials, MagicMock()),
),
patch(
"prowler.providers.gcp.gcp_provider.discovery.build",
return_value=mocked_service,
),
):
with pytest.raises(Exception) as e:
GcpProvider(
@@ -433,6 +467,81 @@ class TestGCPProvider:
)
assert e.type == GCPNoAccesibleProjectsError

    def test_setup_session_with_inactive_default_project(self):
        mocked_credentials = MagicMock()
        mocked_credentials.refresh.return_value = None
        mocked_credentials._service_account_email = "test-service-account-email"

        arguments = Namespace()
        arguments.project_id = ["default_project", "active_project"]
        arguments.excluded_project_id = []
        arguments.organization_id = None
        arguments.list_project_id = False
        arguments.credentials_file = "test_credentials_file"
        arguments.impersonate_service_account = ""
        arguments.config_file = default_config_file_path
        arguments.fixer_config = default_fixer_config_file_path

        projects = {
            "default_project": GCPProject(
                number="55555555",
                id="default_project",
                name="default_project",
                labels={"test": "value"},
                lifecycle_state="DELETE_REQUESTED",
            ),
            "active_project": GCPProject(
                number="12345678",
                id="active_project",
                name="active_project",
                labels={"test": "value"},
                lifecycle_state="ACTIVE",
            ),
        }

        mocked_service = MagicMock()
        mocked_service.projects.list.return_value = MagicMock(
            execute=MagicMock(return_value={"projects": projects})
        )

        with (
            patch(
                "prowler.providers.gcp.gcp_provider.GcpProvider.get_projects",
                return_value=projects,
            ),
            patch(
                "prowler.providers.gcp.gcp_provider.GcpProvider.update_projects_with_organizations",
                return_value=None,
            ),
            patch(
                "os.path.abspath",
                return_value="test_credentials_file",
            ),
            patch(
                "prowler.providers.gcp.gcp_provider.default",
                return_value=(mocked_credentials, "default_project"),
            ),
            patch(
                "prowler.providers.gcp.gcp_provider.discovery.build",
                return_value=mocked_service,
            ),
        ):
            gcp_provider = GcpProvider(
                arguments.organization_id,
                arguments.project_id,
                arguments.excluded_project_id,
                arguments.credentials_file,
                arguments.impersonate_service_account,
                arguments.list_project_id,
                arguments.config_file,
                arguments.fixer_config,
                client_id=None,
                client_secret=None,
                refresh_token=None,
            )
            assert gcp_provider.default_project_id == "active_project"

def test_print_credentials_default_options(self, capsys):
mocked_credentials = MagicMock()
@@ -464,21 +573,27 @@ class TestGCPProvider:
mocked_service.projects.list.return_value = MagicMock(
execute=MagicMock(return_value={"projects": projects})
)
with patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.get_projects",
return_value=projects,
), patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.update_projects_with_organizations",
return_value=None,
), patch(
"os.path.abspath",
return_value="test_credentials_file",
), patch(
"prowler.providers.gcp.gcp_provider.default",
return_value=(mocked_credentials, MagicMock()),
), patch(
"prowler.providers.gcp.gcp_provider.discovery.build",
return_value=mocked_service,
with (
patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.get_projects",
return_value=projects,
),
patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.update_projects_with_organizations",
return_value=None,
),
patch(
"os.path.abspath",
return_value="test_credentials_file",
),
patch(
"prowler.providers.gcp.gcp_provider.default",
return_value=(mocked_credentials, MagicMock()),
),
patch(
"prowler.providers.gcp.gcp_provider.discovery.build",
return_value=mocked_service,
),
):
gcp_provider = GcpProvider(
arguments.organization_id,
@@ -535,21 +650,27 @@ class TestGCPProvider:
mocked_service.projects.list.return_value = MagicMock(
execute=MagicMock(return_value={"projects": projects})
)
with patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.get_projects",
return_value=projects,
), patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.update_projects_with_organizations",
return_value=None,
), patch(
"os.path.abspath",
return_value="test_credentials_file",
), patch(
"prowler.providers.gcp.gcp_provider.default",
return_value=(mocked_credentials, MagicMock()),
), patch(
"prowler.providers.gcp.gcp_provider.discovery.build",
return_value=mocked_service,
with (
patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.get_projects",
return_value=projects,
),
patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.update_projects_with_organizations",
return_value=None,
),
patch(
"os.path.abspath",
return_value="test_credentials_file",
),
patch(
"prowler.providers.gcp.gcp_provider.default",
return_value=(mocked_credentials, MagicMock()),
),
patch(
"prowler.providers.gcp.gcp_provider.discovery.build",
return_value=mocked_service,
),
):
gcp_provider = GcpProvider(
arguments.organization_id,
@@ -614,21 +735,27 @@ class TestGCPProvider:
execute=MagicMock(return_value={"projects": projects})
)
with patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.get_projects",
return_value=projects,
), patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.update_projects_with_organizations",
return_value=None,
), patch(
"os.path.abspath",
return_value="test_credentials_file",
), patch(
"prowler.providers.gcp.gcp_provider.default",
return_value=(mocked_credentials, MagicMock()),
), patch(
"prowler.providers.gcp.gcp_provider.discovery.build",
return_value=mocked_service,
with (
patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.get_projects",
return_value=projects,
),
patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.update_projects_with_organizations",
return_value=None,
),
patch(
"os.path.abspath",
return_value="test_credentials_file",
),
patch(
"prowler.providers.gcp.gcp_provider.default",
return_value=(mocked_credentials, MagicMock()),
),
patch(
"prowler.providers.gcp.gcp_provider.discovery.build",
return_value=mocked_service,
),
):
gcp_provider = GcpProvider(
arguments.organization_id,
@@ -698,12 +825,15 @@ class TestGCPProvider:
execute=MagicMock(return_value={"projectId": project_id})
)
with patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.setup_session",
return_value=(None, project_id),
), patch(
"prowler.providers.gcp.gcp_provider.discovery.build",
return_value=mocked_service,
with (
patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.setup_session",
return_value=(None, project_id),
),
patch(
"prowler.providers.gcp.gcp_provider.discovery.build",
return_value=mocked_service,
),
):
output = GcpProvider.test_connection(
client_id="test-client-id",
@@ -730,16 +860,19 @@ class TestGCPProvider:
execute=MagicMock(return_value={"projects": projects})
)
with patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.setup_session",
return_value=(None, "test-valid-project"),
), patch(
"prowler.providers.gcp.gcp_provider.discovery.build",
return_value=mocked_service,
), patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.validate_project_id"
) as mock_validate_project_id:
with (
patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.setup_session",
return_value=(None, "test-valid-project"),
),
patch(
"prowler.providers.gcp.gcp_provider.discovery.build",
return_value=mocked_service,
),
patch(
"prowler.providers.gcp.gcp_provider.GcpProvider.validate_project_id"
) as mock_validate_project_id,
):
mock_validate_project_id.side_effect = GCPInvalidProviderIdError(
"Invalid project ID"
)

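Every hunk in this file (and in the WAFv2 test above) makes the same mechanical change: comma-chained context managers become a parenthesized group, a form Python supports since 3.10. The two spellings are equivalent at runtime; the parenthesized one stays readable and produces smaller diffs as patches are added or removed:

from unittest.mock import patch

# Pre-3.10 spelling: managers chained with commas on one logical line.
with patch("os.path.abspath", return_value="a"), patch("os.getcwd", return_value="b"):
    pass

# 3.10+ spelling: the same managers grouped in parentheses, one per line.
with (
    patch("os.path.abspath", return_value="a"),
    patch("os.getcwd", return_value="b"),
):
    pass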

@@ -51,14 +51,28 @@ export const updateUser = async (formData: FormData) => {
const session = await auth();
const keyServer = process.env.API_BASE_URL;
const userId = formData.get("userId");
const userName = formData.get("name");
const userPassword = formData.get("password");
const userEmail = formData.get("email");
const userCompanyName = formData.get("company_name");
const userId = formData.get("userId") as string; // Ensure userId is a string
const userName = formData.get("name") as string | null;
const userPassword = formData.get("password") as string | null;
const userEmail = formData.get("email") as string | null;
const userCompanyName = formData.get("company_name") as string | null;
const url = new URL(`${keyServer}/users/${userId}`);
// Prepare attributes to send based on changes
const attributes: Record<string, any> = {};
// Add only changed fields
if (userName !== null) attributes.name = userName;
if (userEmail !== null) attributes.email = userEmail;
if (userCompanyName !== null) attributes.company_name = userCompanyName;
if (userPassword !== null) attributes.password = userPassword;
// If no fields have changed, don't send the request
if (Object.keys(attributes).length === 0) {
return { error: "No changes detected" };
}
try {
const response = await fetch(url.toString(), {
method: "PATCH",
@@ -71,15 +85,11 @@ export const updateUser = async (formData: FormData) => {
data: {
type: "users",
id: userId,
attributes: {
name: userName,
password: userPassword,
email: userEmail,
company_name: userCompanyName,
},
attributes: attributes,
},
}),
});
const data = await response.json();
revalidatePath("/users");
return parseStringify(data);

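The reworked server action sends a JSON:API PATCH carrying only the attributes that were actually supplied, and skips the request entirely when nothing changed. The same build-only-what-changed idea, sketched in Python for illustration (the payload shape matches the diff above; the helper itself is made up):

def build_patch_payload(user_id, **fields):
    # Keep only the fields the caller actually provided.
    attributes = {key: value for key, value in fields.items() if value is not None}
    if not attributes:
        return None  # nothing changed; the caller should skip the request
    return {"data": {"type": "users", "id": user_id, "attributes": attributes}}

assert build_patch_payload("42", name=None, email=None) is None
assert build_patch_payload("42", name="Ada")["data"]["attributes"] == {"name": "Ada"}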

@@ -26,7 +26,7 @@ interface InvitationDetailsProps {
}
export const InvitationDetails = ({ attributes }: InvitationDetailsProps) => {
const baseURL = process.env.SITE_URL || "http://localhost:3000";
const baseURL = process.env.SITE_URL;
const invitationLink = `${baseURL}/sign-up?invitation_token=${attributes.token}`;
return (
<div className="flex flex-col gap-x-4 gap-y-8">


@@ -76,7 +76,7 @@ export function DataTableRowActions<InvitationProps>({
>
<DropdownSection title="Actions">
<DropdownItem
href={`http://localhost:3000/invitations/check-details?id=${invitationId}`}
href={`/invitations/check-details?id=${invitationId}`}
key="check-details"
description="View invitation details"
textValue="Check Details"


@@ -25,11 +25,7 @@ export const EditForm = ({
userCompanyName?: string;
setIsOpen: Dispatch<SetStateAction<boolean>>;
}) => {
const formSchema = editUserFormSchema(
userName ?? "",
userEmail ?? "",
userCompanyName ?? "",
);
const formSchema = editUserFormSchema();
const form = useForm<z.infer<typeof formSchema>>({
resolver: zodResolver(formSchema),
@@ -48,16 +44,26 @@ export const EditForm = ({
const onSubmitClient = async (values: z.infer<typeof formSchema>) => {
const formData = new FormData();
Object.entries(values).forEach(
([key, value]) => value !== undefined && formData.append(key, value),
);
// Check if the value is not undefined before appending to FormData
if (values.name !== undefined) {
formData.append("name", values.name);
}
if (values.email !== undefined) {
formData.append("email", values.email);
}
if (values.company_name !== undefined) {
formData.append("company_name", values.company_name);
}
// Always include userId
formData.append("userId", userId);
const data = await updateUser(formData);
if (data?.errors && data.errors.length > 0) {
const error = data.errors[0];
const errorMessage = `${error.detail}`;
// show error
// Show error
toast({
variant: "destructive",
title: "Oops! Something went wrong",

ui/package-lock.json (generated)

@@ -6965,9 +6965,9 @@
}
},
"node_modules/cross-spawn": {
"version": "7.0.3",
"resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz",
"integrity": "sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w==",
"version": "7.0.6",
"resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz",
"integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==",
"dependencies": {
"path-key": "^3.1.0",
"shebang-command": "^2.0.0",
@@ -10239,9 +10239,9 @@
}
},
"node_modules/nanoid": {
"version": "3.3.7",
"resolved": "https://registry.npmjs.org/nanoid/-/nanoid-3.3.7.tgz",
"integrity": "sha512-eSRppjcPIatRIMC1U6UngP8XFcz8MQWGQdt1MTBQ7NaAmvXDfvNxbvWV3x2y6CdEUciCSsDHDQZbhYaB8QEo2g==",
"version": "3.3.8",
"resolved": "https://registry.npmjs.org/nanoid/-/nanoid-3.3.8.tgz",
"integrity": "sha512-WNLf5Sd8oZxOm+TzppcYk8gVOgP+l58xNy58D0nbUnOxOWRWvlcCV4kUF7ltmI6PsrLl/BgKEyS4mqsGChFN0w==",
"funding": [
{
"type": "github",


@@ -1,3 +0,0 @@
Django==5.0.9
djangorestframework==3.15.2
django-cors-headers==4.3.1


@@ -160,36 +160,21 @@ export const editInviteFormSchema = z.object({
expires_at: z.string().optional(),
});
export const editUserFormSchema = (
currentName: string,
currentEmail: string,
currentCompanyName: string,
) =>
export const editUserFormSchema = () =>
z.object({
name: z
.string()
.min(3, { message: "The name must have at least 3 characters." })
.max(150, { message: "The name cannot exceed 150 characters." })
.refine((val) => val !== currentName, {
message: "The new name must be different from the current one.",
})
.optional(),
email: z
.string()
.email({ message: "Please enter a valid email address." })
.refine((val) => val !== currentEmail, {
message: "The new email must be different from the current one.",
})
.optional(),
password: z
.string()
.min(1, { message: "The password cannot be empty." })
.optional(),
company_name: z
.string()
.refine((val) => val !== currentCompanyName, {
message: "The new company name must be different from the current one.",
})
.optional(),
company_name: z.string().optional(),
userId: z.string(),
});