Compare commits

..

28 Commits

Author SHA1 Message Date
sumit_chaturvedi 98047e690d Merge branch 'master' into PRWLR-7512-create-custom-link-component 2025-07-08 15:16:46 +05:30
Pepe Fagoaga fe00b788cc fix: Remove type validation while updating provider credentials (#8197)
Co-authored-by: Víctor Fernández Poyatos <victor@prowler.com>
2025-07-08 15:27:02 +05:45
sumit_chaturvedi fa11e98a55 chore(ui): addressed PR comments 2025-07-08 14:33:56 +05:30
Rubén De la Torre Vico 4c50f4d811 feat(azure/vm): add new check vm_backup_enabled (#8182)
Co-authored-by: Sergio Garcia <hello@mistercloudsec.com>
2025-07-08 17:01:22 +08:00
Rubén De la Torre Vico c0c736bffe chore: ignore some files from AI editors (#8209) 2025-07-08 10:43:38 +02:00
dependabot[bot] a3aa7d0a63 chore(deps): bump python from 3.12.10-slim-bookworm to 3.12.11-slim-bookworm (#8157)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-07-08 16:43:13 +08:00
Rubén De la Torre Vico 3ceb86c4d9 feat(azure/vm): add new check vm_scaleset_associated_load_balancer (#8181) 2025-07-08 16:40:43 +08:00
Rubén De la Torre Vico 3628e7b3e8 feat(azure/vm): add new check vm_ensure_using_approved_images (#8168) 2025-07-08 16:40:33 +08:00
Chandrapal Badshah f29c2ac9f0 docs(lighthouse): Add Lighthouse Docs (#8196)
Co-authored-by: Chandrapal Badshah <12944530+Chan9390@users.noreply.github.com>
Co-authored-by: Pepe Fagoaga <pepe@prowler.com>
2025-07-08 11:56:23 +05:45
sumit_chaturvedi ff691f1d37 Merge branch 'master' into PRWLR-7512-create-custom-link-component
# Conflicts:
#	ui/CHANGELOG.md
2025-07-08 08:22:04 +05:30
Pablo Lara b4927c3ad1 chore: Update CHANGELOG UI (#8204) 2025-07-07 17:54:44 +02:00
Adrián Jesús Peña Rodríguez 19f3c1d310 chore(saml): restore SAML button (#8203) 2025-07-07 17:34:05 +02:00
Adrián Jesús Peña Rodríguez cd97e57521 fix(saml): restore SAML, deactivate urls, enable idp-initiate (#8175) 2025-07-07 16:42:11 +02:00
Hugo Pereira Brito b38207507a chore(docs): enhance M365 auth documentation (#8199)
Co-authored-by: Sergio Garcia <hello@mistercloudsec.com>
2025-07-07 22:01:41 +08:00
Rubén De la Torre Vico ab96e0aac0 feat(azure/vm): add new check vm_linux_enforce_ssh_authentication (#8149) 2025-07-07 22:01:11 +08:00
Prowler Bot 4477cecc59 chore(regions_update): Changes in regions for AWS services (#8198)
Co-authored-by: prowler-bot <179230569+prowler-bot@users.noreply.github.com>
2025-07-07 18:04:49 +08:00
sumit_chaturvedi 819c9306ee docs: changelog update 2025-07-07 10:56:30 +05:30
sumit_chaturvedi bfc72170c5 feat(ui): create CustomLink component decoupled from CustomButton 2025-07-07 10:19:50 +05:30
Pablo Lara 641d671312 chore: upgrade to Next.js 14.2.30 and lock TypeScript to 5.5.4 for ES… (#8189) 2025-07-04 13:20:30 +02:00
Víctor Fernández Poyatos e7c2fa0699 fix(findings): avoid backfill on empty scans (#8183) 2025-07-04 12:24:49 +02:00
Pedro Martín 7eb08b0f14 fix(ec2): allow empty values for http_endpoint in templates (#8184) 2025-07-04 18:03:51 +08:00
Rubén De la Torre Vico 6f3112f754 feat(storage): add new check storage_smb_channel_encryption_with_secure_algorithm (#8123)
Co-authored-by: Sergio Garcia <hello@mistercloudsec.com>
2025-07-04 15:26:33 +08:00
Kay Agahd f5ecae6da1 fix(iam): detect wildcarded ARNs in sts:AssumeRole policy resources (#8164)
Co-authored-by: Sergio Garcia <hello@mistercloudsec.com>
2025-07-03 23:09:48 +08:00
Prowler Bot 1c75f6b804 chore(release): Bump version to v5.9.0 (#8178)
Co-authored-by: prowler-bot <179230569+prowler-bot@users.noreply.github.com>
2025-07-03 23:08:37 +08:00
Daniel Barranquero 91b64d8572 chore(docs): update m365 docs for app auth in cloud (#8147)
Co-authored-by: Sergio Garcia <hello@mistercloudsec.com>
2025-07-03 23:08:15 +08:00
Pablo Lara 233ae74560 fix: disable dynamic filters for now (#8177) 2025-07-03 14:17:02 +02:00
Alejandro Bailo fac97f9785 fix: remove duplicated calls during promise all resolving (#8176) 2025-07-03 14:02:57 +02:00
Pablo Lara e81c7a3893 fix: bug when updating credentials for m365 (#8173) 2025-07-03 11:31:40 +02:00
125 changed files with 8264 additions and 7772 deletions
+10
View File
@@ -44,6 +44,16 @@ junit-reports/
# Cursor files
.cursorignore
.cursor/
# RooCode files
.roo/
.rooignore
.roomodes
# Cline files
.cline/
.clineignore
# Terraform
.terraform*
+1 -1
View File
@@ -1,4 +1,4 @@
FROM python:3.12.10-slim-bookworm AS build
FROM python:3.12.11-slim-bookworm AS build
LABEL maintainer="https://github.com/prowler-cloud/prowler"
LABEL org.opencontainers.image.source="https://github.com/prowler-cloud/prowler"
+3 -7
View File
@@ -4,19 +4,15 @@ All notable changes to the **Prowler API** are documented in this file.
## [v1.10.0] (Prowler UNRELEASED)
### Added
- SSO with SAML support [(#8175)](https://github.com/prowler-cloud/prowler/pull/8175)
---
## [v1.9.1] (Prowler v5.8.1)
### Added
- Custom exception for provider connection errors during scans [(#8234)](https://github.com/prowler-cloud/prowler/pull/8234)
### Changed
- Summary and overview tasks now use a dedicated queue and no longer propagate errors to compliance tasks [(#8214)](https://github.com/prowler-cloud/prowler/pull/8214)
### Fixed
- Scan with no resources will not trigger legacy code for findings metadata [(#8183)](https://github.com/prowler-cloud/prowler/pull/8183)
- Invitation email comparison case-insensitive [(#8206)](https://github.com/prowler-cloud/prowler/pull/8206)
### Removed
- Validation of the provider's secret type during updates [(#8197)](https://github.com/prowler-cloud/prowler/pull/8197)
+1 -1
View File
@@ -32,7 +32,7 @@ start_prod_server() {
start_worker() {
echo "Starting the worker..."
poetry run python -m celery -A config.celery worker -l "${DJANGO_LOGGING_LEVEL:-info}" -Q celery,scans,scan-reports,deletion,backfill,overview -E --max-tasks-per-child 1
poetry run python -m celery -A config.celery worker -l "${DJANGO_LOGGING_LEVEL:-info}" -Q celery,scans,scan-reports,deletion,backfill -E --max-tasks-per-child 1
}
start_worker_beat() {
+104 -1313
View File
File diff suppressed because it is too large Load Diff
+2 -2
View File
@@ -23,7 +23,7 @@ dependencies = [
"drf-spectacular==0.27.2",
"drf-spectacular-jsonapi==0.5.1",
"gunicorn==23.0.0",
"prowler @ git+https://github.com/prowler-cloud/prowler.git@v5.8",
"prowler @ git+https://github.com/prowler-cloud/prowler.git@master",
"psycopg2-binary==2.9.9",
"pytest-celery[redis] (>=1.0.1,<2.0.0)",
"sentry-sdk[django] (>=2.20.0,<3.0.0)",
@@ -36,7 +36,7 @@ name = "prowler-api"
package-mode = false
# Needed for the SDK compatibility
requires-python = ">=3.11,<3.13"
version = "1.9.1"
version = "1.9.0"
[project.scripts]
celery = "src.backend.config.settings.celery"
+31 -90
View File
@@ -17,8 +17,8 @@ class ProwlerSocialAccountAdapter(DefaultSocialAccountAdapter):
def pre_social_login(self, request, sociallogin):
# Link existing accounts with the same email address
email = sociallogin.account.extra_data.get("email")
# if sociallogin.account.provider == "saml":
# email = sociallogin.user.email
if sociallogin.provider.id == "saml":
email = sociallogin.user.email
if email:
existing_user = self.get_user_by_email(email)
if existing_user:
@@ -31,98 +31,39 @@ class ProwlerSocialAccountAdapter(DefaultSocialAccountAdapter):
"""
with transaction.atomic(using=MainRouter.admin_db):
user = super().save_user(request, sociallogin, form)
provider = sociallogin.provider.id
extra = sociallogin.account.extra_data
# if provider == "saml":
# # Handle SAML-specific logic
# user.first_name = (
# extra.get("firstName", [""])[0] if extra.get("firstName") else ""
# )
# user.last_name = (
# extra.get("lastName", [""])[0] if extra.get("lastName") else ""
# )
# user.company_name = (
# extra.get("organization", [""])[0]
# if extra.get("organization")
# else ""
# )
# user.name = f"{user.first_name} {user.last_name}".strip()
# if user.name == "":
# user.name = "N/A"
# user.save(using=MainRouter.admin_db)
# email_domain = user.email.split("@")[-1]
# tenant = (
# SAMLConfiguration.objects.using(MainRouter.admin_db)
# .get(email_domain=email_domain)
# .tenant
# )
# with rls_transaction(str(tenant.id)):
# role_name = (
# extra.get("userType", ["saml_default_role"])[0].strip()
# if extra.get("userType")
# else "saml_default_role"
# )
# try:
# role = Role.objects.using(MainRouter.admin_db).get(
# name=role_name, tenant_id=tenant.id
# )
# except Role.DoesNotExist:
# role = Role.objects.using(MainRouter.admin_db).create(
# name=role_name,
# tenant_id=tenant.id,
# manage_users=False,
# manage_account=False,
# manage_billing=False,
# manage_providers=False,
# manage_integrations=False,
# manage_scans=False,
# unlimited_visibility=False,
# )
# Membership.objects.using(MainRouter.admin_db).create(
# user=user,
# tenant=tenant,
# role=Membership.RoleChoices.MEMBER,
# )
# UserRoleRelationship.objects.using(MainRouter.admin_db).create(
# user=user,
# role=role,
# tenant_id=tenant.id,
# )
# Handle other providers (e.g., GitHub, Google)
user.save(using=MainRouter.admin_db)
social_account_name = extra.get("name")
if social_account_name:
user.name = social_account_name
if provider != "saml":
# Handle other providers (e.g., GitHub, Google)
user.save(using=MainRouter.admin_db)
social_account_name = extra.get("name")
if social_account_name:
user.name = social_account_name
user.save(using=MainRouter.admin_db)
tenant = Tenant.objects.using(MainRouter.admin_db).create(
name=f"{user.email.split('@')[0]} default tenant"
)
with rls_transaction(str(tenant.id)):
Membership.objects.using(MainRouter.admin_db).create(
user=user, tenant=tenant, role=Membership.RoleChoices.OWNER
)
role = Role.objects.using(MainRouter.admin_db).create(
name="admin",
tenant_id=tenant.id,
manage_users=True,
manage_account=True,
manage_billing=True,
manage_providers=True,
manage_integrations=True,
manage_scans=True,
unlimited_visibility=True,
)
UserRoleRelationship.objects.using(MainRouter.admin_db).create(
user=user,
role=role,
tenant_id=tenant.id,
tenant = Tenant.objects.using(MainRouter.admin_db).create(
name=f"{user.email.split('@')[0]} default tenant"
)
with rls_transaction(str(tenant.id)):
Membership.objects.using(MainRouter.admin_db).create(
user=user, tenant=tenant, role=Membership.RoleChoices.OWNER
)
role = Role.objects.using(MainRouter.admin_db).create(
name="admin",
tenant_id=tenant.id,
manage_users=True,
manage_account=True,
manage_billing=True,
manage_providers=True,
manage_integrations=True,
manage_scans=True,
unlimited_visibility=True,
)
UserRoleRelationship.objects.using(MainRouter.admin_db).create(
user=user,
role=role,
tenant_id=tenant.id,
)
return user
-5
View File
@@ -57,11 +57,6 @@ class TaskInProgressException(TaskManagementError):
super().__init__()
# Provider connection errors
class ProviderConnectionError(Exception):
"""Base exception for provider connection errors."""
def custom_exception_handler(exc, context):
if isinstance(exc, django_validation_error):
if hasattr(exc, "error_dict"):
+150
View File
@@ -0,0 +1,150 @@
# Generated by Django 5.1.10 on 2025-07-02 15:47
import uuid
import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
import api.db_utils
import api.rls
class Migration(migrations.Migration):
dependencies = [
("api", "0031_scan_disable_on_cascade_periodic_tasks"),
]
operations = [
migrations.AlterField(
model_name="integration",
name="integration_type",
field=api.db_utils.IntegrationTypeEnumField(
choices=[
("amazon_s3", "Amazon S3"),
("aws_security_hub", "AWS Security Hub"),
("jira", "JIRA"),
("slack", "Slack"),
]
),
),
migrations.CreateModel(
name="SAMLToken",
fields=[
(
"id",
models.UUIDField(
default=uuid.uuid4,
editable=False,
primary_key=True,
serialize=False,
),
),
("inserted_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("expires_at", models.DateTimeField(editable=False)),
("token", models.JSONField(unique=True)),
(
"user",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
to=settings.AUTH_USER_MODEL,
),
),
],
options={
"db_table": "saml_tokens",
},
),
migrations.CreateModel(
name="SAMLConfiguration",
fields=[
(
"id",
models.UUIDField(
default=uuid.uuid4,
editable=False,
primary_key=True,
serialize=False,
),
),
(
"email_domain",
models.CharField(
help_text="Email domain used to identify the tenant, e.g. prowlerdemo.com",
max_length=254,
unique=True,
),
),
(
"metadata_xml",
models.TextField(
help_text="Raw IdP metadata XML to configure SingleSignOnService, certificates, etc."
),
),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
(
"tenant",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
),
),
],
options={
"db_table": "saml_configurations",
},
),
migrations.AddConstraint(
model_name="samlconfiguration",
constraint=api.rls.RowLevelSecurityConstraint(
"tenant_id",
name="rls_on_samlconfiguration",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
),
),
migrations.AddConstraint(
model_name="samlconfiguration",
constraint=models.UniqueConstraint(
fields=("tenant",), name="unique_samlconfig_per_tenant"
),
),
migrations.CreateModel(
name="SAMLDomainIndex",
fields=[
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("email_domain", models.CharField(max_length=254, unique=True)),
(
"tenant",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="api.tenant"
),
),
],
options={
"db_table": "saml_domain_index",
},
),
migrations.AddConstraint(
model_name="samldomainindex",
constraint=models.UniqueConstraint(
fields=("email_domain", "tenant"),
name="unique_resources_by_email_domain",
),
),
migrations.AddConstraint(
model_name="samldomainindex",
constraint=api.rls.BaseSecurityConstraint(
name="statements_on_samldomainindex",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
),
),
]
+207 -203
View File
@@ -1,15 +1,21 @@
import json
import logging
import re
import xml.etree.ElementTree as ET
from datetime import datetime, timedelta, timezone
from uuid import UUID, uuid4
from allauth.socialaccount.models import SocialApp
from config.custom_logging import BackendLogger
from config.settings.social_login import SOCIALACCOUNT_PROVIDERS
from cryptography.fernet import Fernet, InvalidToken
from django.conf import settings
from django.contrib.auth.models import AbstractBaseUser
from django.contrib.postgres.fields import ArrayField
from django.contrib.postgres.indexes import GinIndex
from django.contrib.postgres.search import SearchVector, SearchVectorField
from django.contrib.sites.models import Site
from django.core.exceptions import ValidationError
from django.core.validators import MinLengthValidator
from django.db import models
from django.db.models import Q
@@ -21,6 +27,7 @@ from psqlextra.models import PostgresPartitionedModel
from psqlextra.types import PostgresPartitioningMethod
from uuid6 import uuid7
from api.db_router import MainRouter
from api.db_utils import (
CustomUserManager,
FindingDeltaEnumField,
@@ -936,11 +943,6 @@ class Invitation(RowLevelSecurityProtectedModel):
null=True,
)
def save(self, *args, **kwargs):
if self.email:
self.email = self.email.strip().lower()
super().save(*args, **kwargs)
class Meta(RowLevelSecurityProtectedModel.Meta):
db_table = "invitations"
@@ -1369,242 +1371,244 @@ class IntegrationProviderRelationship(RowLevelSecurityProtectedModel):
]
# class SAMLToken(models.Model):
# id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
# inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
# updated_at = models.DateTimeField(auto_now=True, editable=False)
# expires_at = models.DateTimeField(editable=False)
# token = models.JSONField(unique=True)
# user = models.ForeignKey(User, on_delete=models.CASCADE)
class SAMLToken(models.Model):
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
updated_at = models.DateTimeField(auto_now=True, editable=False)
expires_at = models.DateTimeField(editable=False)
token = models.JSONField(unique=True)
user = models.ForeignKey(User, on_delete=models.CASCADE)
# class Meta:
# db_table = "saml_tokens"
class Meta:
db_table = "saml_tokens"
# def save(self, *args, **kwargs):
# if not self.expires_at:
# self.expires_at = datetime.now(timezone.utc) + timedelta(seconds=15)
# super().save(*args, **kwargs)
def save(self, *args, **kwargs):
if not self.expires_at:
self.expires_at = datetime.now(timezone.utc) + timedelta(seconds=15)
super().save(*args, **kwargs)
# def is_expired(self) -> bool:
# return datetime.now(timezone.utc) >= self.expires_at
def is_expired(self) -> bool:
return datetime.now(timezone.utc) >= self.expires_at
# class SAMLDomainIndex(models.Model):
# """
# Public index of SAML domains. No RLS. Used for fast lookup in SAML login flow.
# """
class SAMLDomainIndex(models.Model):
"""
Public index of SAML domains. No RLS. Used for fast lookup in SAML login flow.
"""
# email_domain = models.CharField(max_length=254, unique=True)
# tenant = models.ForeignKey("Tenant", on_delete=models.CASCADE)
email_domain = models.CharField(max_length=254, unique=True)
tenant = models.ForeignKey("Tenant", on_delete=models.CASCADE)
# class Meta:
# db_table = "saml_domain_index"
class Meta:
db_table = "saml_domain_index"
# constraints = [
# models.UniqueConstraint(
# fields=("email_domain", "tenant"),
# name="unique_resources_by_email_domain",
# ),
# BaseSecurityConstraint(
# name="statements_on_%(class)s",
# statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
# ),
# ]
constraints = [
models.UniqueConstraint(
fields=("email_domain", "tenant"),
name="unique_resources_by_email_domain",
),
BaseSecurityConstraint(
name="statements_on_%(class)s",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
),
]
# class SAMLConfiguration(RowLevelSecurityProtectedModel):
# """
# Stores per-tenant SAML settings, including email domain and IdP metadata.
# Automatically syncs to a SocialApp instance on save.
class SAMLConfiguration(RowLevelSecurityProtectedModel):
"""
Stores per-tenant SAML settings, including email domain and IdP metadata.
Automatically syncs to a SocialApp instance on save.
# Note:
# This model exists to provide a tenant-aware abstraction over SAML configuration.
# It supports row-level security, custom validation, and metadata parsing, enabling
# Prowler to expose a clean API and admin interface for managing SAML integrations.
Note:
This model exists to provide a tenant-aware abstraction over SAML configuration.
It supports row-level security, custom validation, and metadata parsing, enabling
Prowler to expose a clean API and admin interface for managing SAML integrations.
# Although Django Allauth uses the SocialApp model to store provider configuration,
# it is not designed for multi-tenant use. SocialApp lacks support for tenant scoping,
# email domain mapping, and structured metadata handling.
Although Django Allauth uses the SocialApp model to store provider configuration,
it is not designed for multi-tenant use. SocialApp lacks support for tenant scoping,
email domain mapping, and structured metadata handling.
# By managing SAMLConfiguration separately, we ensure:
# - Strong isolation between tenants via RLS.
# - Ownership of raw IdP metadata and its validation.
# - An explicit link between SAML config and business-level identifiers (e.g. email domain).
# - Programmatic transformation into the SocialApp format used by Allauth.
By managing SAMLConfiguration separately, we ensure:
- Strong isolation between tenants via RLS.
- Ownership of raw IdP metadata and its validation.
- An explicit link between SAML config and business-level identifiers (e.g. email domain).
- Programmatic transformation into the SocialApp format used by Allauth.
# In short, this model acts as a secure and user-friendly layer over Allauth's lower-level primitives.
# """
In short, this model acts as a secure and user-friendly layer over Allauth's lower-level primitives.
"""
# id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
# email_domain = models.CharField(
# max_length=254,
# unique=True,
# help_text="Email domain used to identify the tenant, e.g. prowlerdemo.com",
# )
# metadata_xml = models.TextField(
# help_text="Raw IdP metadata XML to configure SingleSignOnService, certificates, etc."
# )
# created_at = models.DateTimeField(auto_now_add=True)
# updated_at = models.DateTimeField(auto_now=True)
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
email_domain = models.CharField(
max_length=254,
unique=True,
help_text="Email domain used to identify the tenant, e.g. prowlerdemo.com",
)
metadata_xml = models.TextField(
help_text="Raw IdP metadata XML to configure SingleSignOnService, certificates, etc."
)
created_at = models.DateTimeField(auto_now_add=True)
updated_at = models.DateTimeField(auto_now=True)
# class JSONAPIMeta:
# resource_name = "saml-configurations"
class JSONAPIMeta:
resource_name = "saml-configurations"
# class Meta:
# db_table = "saml_configurations"
class Meta:
db_table = "saml_configurations"
# constraints = [
# RowLevelSecurityConstraint(
# field="tenant_id",
# name="rls_on_%(class)s",
# statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
# ),
# # 1 config per tenant
# models.UniqueConstraint(
# fields=["tenant"],
# name="unique_samlconfig_per_tenant",
# ),
# ]
constraints = [
RowLevelSecurityConstraint(
field="tenant_id",
name="rls_on_%(class)s",
statements=["SELECT", "INSERT", "UPDATE", "DELETE"],
),
# 1 config per tenant
models.UniqueConstraint(
fields=["tenant"],
name="unique_samlconfig_per_tenant",
),
]
# def clean(self, old_email_domain=None):
# # Domain must not contain @
# if "@" in self.email_domain:
# raise ValidationError({"email_domain": "Domain must not contain @"})
def clean(self, old_email_domain=None):
# Domain must not contain @
if "@" in self.email_domain:
raise ValidationError({"email_domain": "Domain must not contain @"})
# # Enforce at most one config per tenant
# qs = SAMLConfiguration.objects.filter(tenant=self.tenant)
# # Exclude ourselves in case of update
# if self.pk:
# qs = qs.exclude(pk=self.pk)
# if qs.exists():
# raise ValidationError(
# {"tenant": "A SAML configuration already exists for this tenant."}
# )
# Enforce at most one config per tenant
qs = SAMLConfiguration.objects.filter(tenant=self.tenant)
# Exclude ourselves in case of update
if self.pk:
qs = qs.exclude(pk=self.pk)
if qs.exists():
raise ValidationError(
{"tenant": "A SAML configuration already exists for this tenant."}
)
# # The email domain must be unique in the entire system
# qs = SAMLConfiguration.objects.using(MainRouter.admin_db).filter(
# email_domain__iexact=self.email_domain
# )
# if qs.exists() and old_email_domain != self.email_domain:
# raise ValidationError(
# {"tenant": "There is a problem with your email domain."}
# )
# The email domain must be unique in the entire system
qs = SAMLConfiguration.objects.using(MainRouter.admin_db).filter(
email_domain__iexact=self.email_domain
)
if qs.exists() and old_email_domain != self.email_domain:
raise ValidationError(
{"tenant": "There is a problem with your email domain."}
)
# def save(self, *args, **kwargs):
# self.email_domain = self.email_domain.strip().lower()
# is_create = not SAMLConfiguration.objects.filter(pk=self.pk).exists()
def save(self, *args, **kwargs):
self.email_domain = self.email_domain.strip().lower()
is_create = not SAMLConfiguration.objects.filter(pk=self.pk).exists()
# if not is_create:
# old = SAMLConfiguration.objects.get(pk=self.pk)
# old_email_domain = old.email_domain
# old_metadata_xml = old.metadata_xml
# else:
# old_email_domain = None
# old_metadata_xml = None
if not is_create:
old = SAMLConfiguration.objects.get(pk=self.pk)
old_email_domain = old.email_domain
old_metadata_xml = old.metadata_xml
else:
old_email_domain = None
old_metadata_xml = None
# self.clean(old_email_domain)
# super().save(*args, **kwargs)
self.clean(old_email_domain)
super().save(*args, **kwargs)
# if is_create or (
# old_email_domain != self.email_domain
# or old_metadata_xml != self.metadata_xml
# ):
# self._sync_social_app(old_email_domain)
if is_create or (
old_email_domain != self.email_domain
or old_metadata_xml != self.metadata_xml
):
self._sync_social_app(old_email_domain)
# # Sync the public index
# if not is_create and old_email_domain and old_email_domain != self.email_domain:
# SAMLDomainIndex.objects.filter(email_domain=old_email_domain).delete()
# Sync the public index
if not is_create and old_email_domain and old_email_domain != self.email_domain:
SAMLDomainIndex.objects.filter(email_domain=old_email_domain).delete()
# # Create/update the new domain index
# SAMLDomainIndex.objects.update_or_create(
# email_domain=self.email_domain, defaults={"tenant": self.tenant}
# )
# Create/update the new domain index
SAMLDomainIndex.objects.update_or_create(
email_domain=self.email_domain, defaults={"tenant": self.tenant}
)
# def _parse_metadata(self):
# """
# Parse the raw IdP metadata XML and extract:
# - entity_id
# - sso_url
# - slo_url (may be None)
# - x509cert (required)
# """
# ns = {
# "md": "urn:oasis:names:tc:SAML:2.0:metadata",
# "ds": "http://www.w3.org/2000/09/xmldsig#",
# }
# try:
# root = ET.fromstring(self.metadata_xml)
# except ET.ParseError as e:
# raise ValidationError({"metadata_xml": f"Invalid XML: {e}"})
def _parse_metadata(self):
"""
Parse the raw IdP metadata XML and extract:
- entity_id
- sso_url
- slo_url (may be None)
- x509cert (required)
"""
ns = {
"md": "urn:oasis:names:tc:SAML:2.0:metadata",
"ds": "http://www.w3.org/2000/09/xmldsig#",
}
try:
root = ET.fromstring(self.metadata_xml)
except ET.ParseError as e:
raise ValidationError({"metadata_xml": f"Invalid XML: {e}"})
# # Entity ID
# entity_id = root.attrib.get("entityID")
# Entity ID
entity_id = root.attrib.get("entityID")
# # SSO endpoint (must exist)
# sso = root.find(".//md:IDPSSODescriptor/md:SingleSignOnService", ns)
# if sso is None or "Location" not in sso.attrib:
# raise ValidationError(
# {"metadata_xml": "Missing SingleSignOnService in metadata."}
# )
# sso_url = sso.attrib["Location"]
# SSO endpoint (must exist)
sso = root.find(".//md:IDPSSODescriptor/md:SingleSignOnService", ns)
if sso is None or "Location" not in sso.attrib:
raise ValidationError(
{"metadata_xml": "Missing SingleSignOnService in metadata."}
)
sso_url = sso.attrib["Location"]
# # SLO endpoint (optional)
# slo = root.find(".//md:IDPSSODescriptor/md:SingleLogoutService", ns)
# slo_url = slo.attrib.get("Location") if slo is not None else None
# SLO endpoint (optional)
slo = root.find(".//md:IDPSSODescriptor/md:SingleLogoutService", ns)
slo_url = slo.attrib.get("Location") if slo is not None else None
# # X.509 certificate (required)
# cert = root.find(
# './/md:KeyDescriptor[@use="signing"]/ds:KeyInfo/ds:X509Data/ds:X509Certificate',
# ns,
# )
# if cert is None or not cert.text or not cert.text.strip():
# raise ValidationError(
# {
# "metadata_xml": 'Metadata must include a <ds:X509Certificate> under <KeyDescriptor use="signing">.'
# }
# )
# x509cert = cert.text.strip()
# X.509 certificate (required)
cert = root.find(
'.//md:KeyDescriptor[@use="signing"]/ds:KeyInfo/ds:X509Data/ds:X509Certificate',
ns,
)
if cert is None or not cert.text or not cert.text.strip():
raise ValidationError(
{
"metadata_xml": 'Metadata must include a <ds:X509Certificate> under <KeyDescriptor use="signing">.'
}
)
x509cert = cert.text.strip()
# return {
# "entity_id": entity_id,
# "sso_url": sso_url,
# "slo_url": slo_url,
# "x509cert": x509cert,
# }
return {
"entity_id": entity_id,
"sso_url": sso_url,
"slo_url": slo_url,
"x509cert": x509cert,
}
# def _sync_social_app(self, previous_email_domain=None):
# """
# Create or update the corresponding SocialApp based on email_domain.
# If the domain changed, update the matching SocialApp.
# """
# idp_settings = self._parse_metadata()
# settings_dict = SOCIALACCOUNT_PROVIDERS["saml"].copy()
# settings_dict["idp"] = idp_settings
def _sync_social_app(self, previous_email_domain=None):
"""
Create or update the corresponding SocialApp based on email_domain.
If the domain changed, update the matching SocialApp.
"""
idp_settings = self._parse_metadata()
settings_dict = SOCIALACCOUNT_PROVIDERS["saml"].copy()
settings_dict["idp"] = idp_settings
# current_site = Site.objects.get(id=settings.SITE_ID)
current_site = Site.objects.get(id=settings.SITE_ID)
# social_app_qs = SocialApp.objects.filter(
# provider="saml", client_id=previous_email_domain or self.email_domain
# )
social_app_qs = SocialApp.objects.filter(
provider="saml", client_id=previous_email_domain or self.email_domain
)
# client_id = self.email_domain[:191]
# name = f"SAML-{self.email_domain}"[:40]
client_id = self.email_domain[:191]
name = f"SAML-{self.email_domain}"[:40]
# if social_app_qs.exists():
# social_app = social_app_qs.first()
# social_app.client_id = client_id
# social_app.name = name
# social_app.settings = settings_dict
# social_app.save()
# social_app.sites.set([current_site])
# else:
# social_app = SocialApp.objects.create(
# provider="saml",
# client_id=client_id,
# name=name,
# settings=settings_dict,
# )
# social_app.sites.set([current_site])
if social_app_qs.exists():
social_app = social_app_qs.first()
social_app.client_id = client_id
social_app.name = name
social_app.settings = settings_dict
social_app.provider_id = idp_settings["entity_id"]
social_app.save()
social_app.sites.set([current_site])
else:
social_app = SocialApp.objects.create(
provider="saml",
client_id=client_id,
name=name,
settings=settings_dict,
provider_id=idp_settings["entity_id"],
)
social_app.sites.set([current_site])
class ResourceScanSummary(RowLevelSecurityProtectedModel):
+340 -1
View File
@@ -1,7 +1,7 @@
openapi: 3.0.3
info:
title: Prowler API
version: 1.9.1
version: 1.9.0
description: |-
Prowler API specification.
@@ -5152,6 +5152,199 @@ paths:
responses:
'204':
description: Relationship deleted successfully
/api/v1/saml-config:
get:
operationId: saml_config_list
description: Returns all the SAML-based SSO configurations associated with the
current tenant.
summary: List all SSO configurations
parameters:
- in: query
name: fields[saml-configurations]
schema:
type: array
items:
type: string
enum:
- email_domain
- metadata_xml
- created_at
- updated_at
description: endpoint return only specific fields in the response on a per-type
basis by including a fields[TYPE] query parameter.
explode: false
- name: filter[search]
required: false
in: query
description: A search term.
schema:
type: string
- name: page[number]
required: false
in: query
description: A page number within the paginated result set.
schema:
type: integer
- name: page[size]
required: false
in: query
description: Number of results to return per page.
schema:
type: integer
- name: sort
required: false
in: query
description: '[list of fields to sort by](https://jsonapi.org/format/#fetching-sorting)'
schema:
type: array
items:
type: string
enum:
- id
- -id
- email_domain
- -email_domain
- metadata_xml
- -metadata_xml
- created_at
- -created_at
- updated_at
- -updated_at
explode: false
tags:
- SAML
security:
- jwtAuth: []
responses:
'200':
content:
application/vnd.api+json:
schema:
$ref: '#/components/schemas/PaginatedSAMLConfigurationList'
description: ''
post:
operationId: saml_config_create
description: Creates a new SAML SSO configuration for the current tenant, including
email domain and metadata XML.
summary: Create the SSO configuration
tags:
- SAML
requestBody:
content:
application/vnd.api+json:
schema:
$ref: '#/components/schemas/SAMLConfigurationRequest'
application/x-www-form-urlencoded:
schema:
$ref: '#/components/schemas/SAMLConfigurationRequest'
multipart/form-data:
schema:
$ref: '#/components/schemas/SAMLConfigurationRequest'
required: true
security:
- jwtAuth: []
responses:
'201':
content:
application/vnd.api+json:
schema:
$ref: '#/components/schemas/SAMLConfigurationResponse'
description: ''
/api/v1/saml-config/{id}:
get:
operationId: saml_config_retrieve
description: Returns the details of a specific SAML configuration belonging
to the current tenant.
summary: Retrieve SSO configuration details
parameters:
- in: query
name: fields[saml-configurations]
schema:
type: array
items:
type: string
enum:
- email_domain
- metadata_xml
- created_at
- updated_at
description: endpoint return only specific fields in the response on a per-type
basis by including a fields[TYPE] query parameter.
explode: false
- in: path
name: id
schema:
type: string
format: uuid
description: A UUID string identifying this saml configuration.
required: true
tags:
- SAML
security:
- jwtAuth: []
responses:
'200':
content:
application/vnd.api+json:
schema:
$ref: '#/components/schemas/SAMLConfigurationResponse'
description: ''
patch:
operationId: saml_config_partial_update
description: Partially updates an existing SAML SSO configuration. Supports
changes to email domain and metadata XML.
summary: Update the SSO configuration
parameters:
- in: path
name: id
schema:
type: string
format: uuid
description: A UUID string identifying this saml configuration.
required: true
tags:
- SAML
requestBody:
content:
application/vnd.api+json:
schema:
$ref: '#/components/schemas/PatchedSAMLConfigurationRequest'
application/x-www-form-urlencoded:
schema:
$ref: '#/components/schemas/PatchedSAMLConfigurationRequest'
multipart/form-data:
schema:
$ref: '#/components/schemas/PatchedSAMLConfigurationRequest'
required: true
security:
- jwtAuth: []
responses:
'200':
content:
application/vnd.api+json:
schema:
$ref: '#/components/schemas/SAMLConfigurationResponse'
description: ''
delete:
operationId: saml_config_destroy
description: Deletes an existing SAML SSO configuration associated with the
current tenant.
summary: Delete the SSO configuration
parameters:
- in: path
name: id
schema:
type: string
format: uuid
description: A UUID string identifying this saml configuration.
required: true
tags:
- SAML
security:
- jwtAuth: []
responses:
'204':
description: No response body
/api/v1/scans:
get:
operationId: scans_list
@@ -9158,6 +9351,15 @@ components:
$ref: '#/components/schemas/Role'
required:
- data
PaginatedSAMLConfigurationList:
type: object
properties:
data:
type: array
items:
$ref: '#/components/schemas/SAMLConfiguration'
required:
- data
PaginatedScanList:
type: object
properties:
@@ -10041,6 +10243,52 @@ components:
readOnly: true
required:
- data
PatchedSAMLConfigurationRequest:
type: object
properties:
data:
type: object
required:
- type
- id
additionalProperties: false
properties:
type:
type: string
description: The [type](https://jsonapi.org/format/#document-resource-object-identification)
member is used to describe resource objects that share common attributes
and relationships.
enum:
- saml-configurations
id:
type: string
format: uuid
attributes:
type: object
properties:
email_domain:
type: string
minLength: 1
description: Email domain used to identify the tenant, e.g. prowlerdemo.com
maxLength: 254
metadata_xml:
type: string
minLength: 1
description: Raw IdP metadata XML to configure SingleSignOnService,
certificates, etc.
created_at:
type: string
format: date-time
readOnly: true
updated_at:
type: string
format: date-time
readOnly: true
required:
- email_domain
- metadata_xml
required:
- data
PatchedScanUpdateRequest:
type: object
properties:
@@ -12153,6 +12401,97 @@ components:
$ref: '#/components/schemas/Role'
required:
- data
SAMLConfiguration:
type: object
required:
- type
- id
additionalProperties: false
properties:
type:
allOf:
- $ref: '#/components/schemas/SAMLConfigurationTypeEnum'
description: The [type](https://jsonapi.org/format/#document-resource-object-identification)
member is used to describe resource objects that share common attributes
and relationships.
id:
type: string
format: uuid
attributes:
type: object
properties:
email_domain:
type: string
description: Email domain used to identify the tenant, e.g. prowlerdemo.com
maxLength: 254
metadata_xml:
type: string
description: Raw IdP metadata XML to configure SingleSignOnService,
certificates, etc.
created_at:
type: string
format: date-time
readOnly: true
updated_at:
type: string
format: date-time
readOnly: true
required:
- email_domain
- metadata_xml
SAMLConfigurationRequest:
type: object
properties:
data:
type: object
required:
- type
additionalProperties: false
properties:
type:
type: string
description: The [type](https://jsonapi.org/format/#document-resource-object-identification)
member is used to describe resource objects that share common attributes
and relationships.
enum:
- saml-configurations
attributes:
type: object
properties:
email_domain:
type: string
minLength: 1
description: Email domain used to identify the tenant, e.g. prowlerdemo.com
maxLength: 254
metadata_xml:
type: string
minLength: 1
description: Raw IdP metadata XML to configure SingleSignOnService,
certificates, etc.
created_at:
type: string
format: date-time
readOnly: true
updated_at:
type: string
format: date-time
readOnly: true
required:
- email_domain
- metadata_xml
required:
- data
SAMLConfigurationResponse:
type: object
properties:
data:
$ref: '#/components/schemas/SAMLConfiguration'
required:
- data
SAMLConfigurationTypeEnum:
type: string
enum:
- saml-configurations
Scan:
type: object
required:
+18 -50
View File
@@ -20,69 +20,37 @@ class TestProwlerSocialAccountAdapter:
adapter = ProwlerSocialAccountAdapter()
assert adapter.get_user_by_email("notfound@example.com") is None
# def test_pre_social_login_links_existing_user(self, create_test_user, rf):
# adapter = ProwlerSocialAccountAdapter()
def test_pre_social_login_links_existing_user(self, create_test_user, rf):
adapter = ProwlerSocialAccountAdapter()
# sociallogin = MagicMock(spec=SocialLogin)
# sociallogin.account = MagicMock()
# sociallogin.account.provider = "saml"
# sociallogin.account.extra_data = {}
# sociallogin.user = create_test_user
# sociallogin.connect = MagicMock()
sociallogin = MagicMock(spec=SocialLogin)
sociallogin.account = MagicMock()
sociallogin.provider = MagicMock()
sociallogin.provider.id = "saml"
sociallogin.account.extra_data = {}
sociallogin.user = create_test_user
sociallogin.connect = MagicMock()
# adapter.pre_social_login(rf.get("/"), sociallogin)
adapter.pre_social_login(rf.get("/"), sociallogin)
# call_args = sociallogin.connect.call_args
# assert call_args is not None
call_args = sociallogin.connect.call_args
assert call_args is not None
# called_request, called_user = call_args[0]
# assert called_request.path == "/"
# assert called_user.email == create_test_user.email
called_request, called_user = call_args[0]
assert called_request.path == "/"
assert called_user.email == create_test_user.email
def test_pre_social_login_no_link_if_email_missing(self, rf):
adapter = ProwlerSocialAccountAdapter()
sociallogin = MagicMock(spec=SocialLogin)
sociallogin.account = MagicMock()
sociallogin.account.provider = "github"
sociallogin.provider = MagicMock()
sociallogin.user = MagicMock()
sociallogin.provider.id = "saml"
sociallogin.account.extra_data = {}
sociallogin.connect = MagicMock()
adapter.pre_social_login(rf.get("/"), sociallogin)
sociallogin.connect.assert_not_called()
# def test_save_user_saml_flow(
# self,
# rf,
# saml_setup,
# saml_sociallogin,
# ):
# adapter = ProwlerSocialAccountAdapter()
# request = rf.get("/")
# saml_sociallogin.user.email = saml_setup["email"]
# saml_sociallogin.account.extra_data = {
# "firstName": [],
# "lastName": [],
# "organization": [],
# "userType": [],
# }
# tenant = Tenant.objects.using(MainRouter.admin_db).get(
# id=saml_setup["tenant_id"]
# )
# saml_config = SAMLConfiguration.objects.using(MainRouter.admin_db).get(
# tenant=tenant
# )
# assert saml_config.email_domain == saml_setup["domain"]
# user = adapter.save_user(request, saml_sociallogin)
# assert user.name == "N/A"
# assert user.company_name == ""
# assert user.email == saml_setup["email"]
# assert (
# Membership.objects.using(MainRouter.admin_db)
# .filter(user=user, tenant=tenant)
# .exists()
# )
+124 -121
View File
@@ -1,6 +1,9 @@
import pytest
from allauth.socialaccount.models import SocialApp
from django.core.exceptions import ValidationError
from api.models import Resource, ResourceTag
from api.db_router import MainRouter
from api.models import Resource, ResourceTag, SAMLConfiguration, Tenant
@pytest.mark.django_db
@@ -122,147 +125,147 @@ class TestResourceModel:
# assert Finding.objects.filter(uid=long_uid).exists()
# @pytest.mark.django_db
# class TestSAMLConfigurationModel:
# VALID_METADATA = """<?xml version='1.0' encoding='UTF-8'?>
# <md:EntityDescriptor entityID='TEST' xmlns:md='urn:oasis:names:tc:SAML:2.0:metadata'>
# <md:IDPSSODescriptor WantAuthnRequestsSigned='false' protocolSupportEnumeration='urn:oasis:names:tc:SAML:2.0:protocol'>
# <md:KeyDescriptor use='signing'>
# <ds:KeyInfo xmlns:ds='http://www.w3.org/2000/09/xmldsig#'>
# <ds:X509Data>
# <ds:X509Certificate>FAKECERTDATA</ds:X509Certificate>
# </ds:X509Data>
# </ds:KeyInfo>
# </md:KeyDescriptor>
# <md:SingleSignOnService Binding='urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST' Location='https://idp.test/sso'/>
# </md:IDPSSODescriptor>
# </md:EntityDescriptor>
# """
@pytest.mark.django_db
class TestSAMLConfigurationModel:
VALID_METADATA = """<?xml version='1.0' encoding='UTF-8'?>
<md:EntityDescriptor entityID='TEST' xmlns:md='urn:oasis:names:tc:SAML:2.0:metadata'>
<md:IDPSSODescriptor WantAuthnRequestsSigned='false' protocolSupportEnumeration='urn:oasis:names:tc:SAML:2.0:protocol'>
<md:KeyDescriptor use='signing'>
<ds:KeyInfo xmlns:ds='http://www.w3.org/2000/09/xmldsig#'>
<ds:X509Data>
<ds:X509Certificate>FAKECERTDATA</ds:X509Certificate>
</ds:X509Data>
</ds:KeyInfo>
</md:KeyDescriptor>
<md:SingleSignOnService Binding='urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST' Location='https://idp.test/sso'/>
</md:IDPSSODescriptor>
</md:EntityDescriptor>
"""
# def test_creates_valid_configuration(self):
# tenant = Tenant.objects.using(MainRouter.admin_db).create(name="Tenant A")
# config = SAMLConfiguration.objects.using(MainRouter.admin_db).create(
# email_domain="ssoexample.com",
# metadata_xml=TestSAMLConfigurationModel.VALID_METADATA,
# tenant=tenant,
# )
def test_creates_valid_configuration(self):
tenant = Tenant.objects.using(MainRouter.admin_db).create(name="Tenant A")
config = SAMLConfiguration.objects.using(MainRouter.admin_db).create(
email_domain="ssoexample.com",
metadata_xml=TestSAMLConfigurationModel.VALID_METADATA,
tenant=tenant,
)
# assert config.email_domain == "ssoexample.com"
# assert SocialApp.objects.filter(client_id="ssoexample.com").exists()
assert config.email_domain == "ssoexample.com"
assert SocialApp.objects.filter(client_id="ssoexample.com").exists()
# def test_email_domain_with_at_symbol_fails(self):
# tenant = Tenant.objects.using(MainRouter.admin_db).create(name="Tenant B")
# config = SAMLConfiguration(
# email_domain="invalid@domain.com",
# metadata_xml=TestSAMLConfigurationModel.VALID_METADATA,
# tenant=tenant,
# )
def test_email_domain_with_at_symbol_fails(self):
tenant = Tenant.objects.using(MainRouter.admin_db).create(name="Tenant B")
config = SAMLConfiguration(
email_domain="invalid@domain.com",
metadata_xml=TestSAMLConfigurationModel.VALID_METADATA,
tenant=tenant,
)
# with pytest.raises(ValidationError) as exc_info:
# config.clean()
with pytest.raises(ValidationError) as exc_info:
config.clean()
# errors = exc_info.value.message_dict
# assert "email_domain" in errors
# assert "Domain must not contain @" in errors["email_domain"][0]
errors = exc_info.value.message_dict
assert "email_domain" in errors
assert "Domain must not contain @" in errors["email_domain"][0]
# def test_duplicate_email_domain_fails(self):
# tenant1 = Tenant.objects.using(MainRouter.admin_db).create(name="Tenant C1")
# tenant2 = Tenant.objects.using(MainRouter.admin_db).create(name="Tenant C2")
def test_duplicate_email_domain_fails(self):
tenant1 = Tenant.objects.using(MainRouter.admin_db).create(name="Tenant C1")
tenant2 = Tenant.objects.using(MainRouter.admin_db).create(name="Tenant C2")
# SAMLConfiguration.objects.using(MainRouter.admin_db).create(
# email_domain="duplicate.com",
# metadata_xml=TestSAMLConfigurationModel.VALID_METADATA,
# tenant=tenant1,
# )
SAMLConfiguration.objects.using(MainRouter.admin_db).create(
email_domain="duplicate.com",
metadata_xml=TestSAMLConfigurationModel.VALID_METADATA,
tenant=tenant1,
)
# config = SAMLConfiguration(
# email_domain="duplicate.com",
# metadata_xml=TestSAMLConfigurationModel.VALID_METADATA,
# tenant=tenant2,
# )
config = SAMLConfiguration(
email_domain="duplicate.com",
metadata_xml=TestSAMLConfigurationModel.VALID_METADATA,
tenant=tenant2,
)
# with pytest.raises(ValidationError) as exc_info:
# config.clean()
with pytest.raises(ValidationError) as exc_info:
config.clean()
# errors = exc_info.value.message_dict
# assert "tenant" in errors
# assert "There is a problem with your email domain." in errors["tenant"][0]
errors = exc_info.value.message_dict
assert "tenant" in errors
assert "There is a problem with your email domain." in errors["tenant"][0]
# def test_duplicate_tenant_config_fails(self):
# tenant = Tenant.objects.using(MainRouter.admin_db).create(name="Tenant D")
def test_duplicate_tenant_config_fails(self):
tenant = Tenant.objects.using(MainRouter.admin_db).create(name="Tenant D")
# SAMLConfiguration.objects.using(MainRouter.admin_db).create(
# email_domain="unique1.com",
# metadata_xml=TestSAMLConfigurationModel.VALID_METADATA,
# tenant=tenant,
# )
SAMLConfiguration.objects.using(MainRouter.admin_db).create(
email_domain="unique1.com",
metadata_xml=TestSAMLConfigurationModel.VALID_METADATA,
tenant=tenant,
)
# config = SAMLConfiguration(
# email_domain="unique2.com",
# metadata_xml=TestSAMLConfigurationModel.VALID_METADATA,
# tenant=tenant,
# )
config = SAMLConfiguration(
email_domain="unique2.com",
metadata_xml=TestSAMLConfigurationModel.VALID_METADATA,
tenant=tenant,
)
# with pytest.raises(ValidationError) as exc_info:
# config.clean()
with pytest.raises(ValidationError) as exc_info:
config.clean()
# errors = exc_info.value.message_dict
# assert "tenant" in errors
# assert (
# "A SAML configuration already exists for this tenant."
# in errors["tenant"][0]
# )
errors = exc_info.value.message_dict
assert "tenant" in errors
assert (
"A SAML configuration already exists for this tenant."
in errors["tenant"][0]
)
# def test_invalid_metadata_xml_fails(self):
# tenant = Tenant.objects.using(MainRouter.admin_db).create(name="Tenant E")
# config = SAMLConfiguration(
# email_domain="brokenxml.com",
# metadata_xml="<bad<xml>",
# tenant=tenant,
# )
def test_invalid_metadata_xml_fails(self):
tenant = Tenant.objects.using(MainRouter.admin_db).create(name="Tenant E")
config = SAMLConfiguration(
email_domain="brokenxml.com",
metadata_xml="<bad<xml>",
tenant=tenant,
)
# with pytest.raises(ValidationError) as exc_info:
# config._parse_metadata()
with pytest.raises(ValidationError) as exc_info:
config._parse_metadata()
# errors = exc_info.value.message_dict
# assert "metadata_xml" in errors
# assert "Invalid XML" in errors["metadata_xml"][0]
# assert "not well-formed" in errors["metadata_xml"][0]
errors = exc_info.value.message_dict
assert "metadata_xml" in errors
assert "Invalid XML" in errors["metadata_xml"][0]
assert "not well-formed" in errors["metadata_xml"][0]
# def test_metadata_missing_sso_fails(self):
# tenant = Tenant.objects.using(MainRouter.admin_db).create(name="Tenant F")
# xml = """<md:EntityDescriptor entityID="x" xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata">
# <md:IDPSSODescriptor></md:IDPSSODescriptor>
# </md:EntityDescriptor>"""
# config = SAMLConfiguration(
# email_domain="nosso.com",
# metadata_xml=xml,
# tenant=tenant,
# )
def test_metadata_missing_sso_fails(self):
tenant = Tenant.objects.using(MainRouter.admin_db).create(name="Tenant F")
xml = """<md:EntityDescriptor entityID="x" xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata">
<md:IDPSSODescriptor></md:IDPSSODescriptor>
</md:EntityDescriptor>"""
config = SAMLConfiguration(
email_domain="nosso.com",
metadata_xml=xml,
tenant=tenant,
)
# with pytest.raises(ValidationError) as exc_info:
# config._parse_metadata()
with pytest.raises(ValidationError) as exc_info:
config._parse_metadata()
# errors = exc_info.value.message_dict
# assert "metadata_xml" in errors
# assert "Missing SingleSignOnService" in errors["metadata_xml"][0]
errors = exc_info.value.message_dict
assert "metadata_xml" in errors
assert "Missing SingleSignOnService" in errors["metadata_xml"][0]
# def test_metadata_missing_certificate_fails(self):
# tenant = Tenant.objects.using(MainRouter.admin_db).create(name="Tenant G")
# xml = """<md:EntityDescriptor entityID="x" xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata">
# <md:IDPSSODescriptor>
# <md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="https://example.com/sso"/>
# </md:IDPSSODescriptor>
# </md:EntityDescriptor>"""
# config = SAMLConfiguration(
# email_domain="nocert.com",
# metadata_xml=xml,
# tenant=tenant,
# )
def test_metadata_missing_certificate_fails(self):
tenant = Tenant.objects.using(MainRouter.admin_db).create(name="Tenant G")
xml = """<md:EntityDescriptor entityID="x" xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata">
<md:IDPSSODescriptor>
<md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="https://example.com/sso"/>
</md:IDPSSODescriptor>
</md:EntityDescriptor>"""
config = SAMLConfiguration(
email_domain="nocert.com",
metadata_xml=xml,
tenant=tenant,
)
# with pytest.raises(ValidationError) as exc_info:
# config._parse_metadata()
with pytest.raises(ValidationError) as exc_info:
config._parse_metadata()
# errors = exc_info.value.message_dict
# assert "metadata_xml" in errors
# assert "X509Certificate" in errors["metadata_xml"][0]
errors = exc_info.value.message_dict
assert "metadata_xml" in errors
assert "X509Certificate" in errors["metadata_xml"][0]
+4 -26
View File
@@ -254,7 +254,7 @@ class TestValidateInvitation:
assert result == invitation
mock_db.get.assert_called_once_with(
token="VALID_TOKEN", email__iexact="user@example.com"
token="VALID_TOKEN", email="user@example.com"
)
def test_invitation_not_found_raises_validation_error(self):
@@ -269,7 +269,7 @@ class TestValidateInvitation:
"invitation_token": "Invalid invitation code."
}
mock_db.get.assert_called_once_with(
token="INVALID_TOKEN", email__iexact="user@example.com"
token="INVALID_TOKEN", email="user@example.com"
)
def test_invitation_not_found_raises_not_found(self):
@@ -284,7 +284,7 @@ class TestValidateInvitation:
assert exc_info.value.detail == "Invitation is not valid."
mock_db.get.assert_called_once_with(
token="INVALID_TOKEN", email__iexact="user@example.com"
token="INVALID_TOKEN", email="user@example.com"
)
def test_invitation_expired(self, invitation):
@@ -332,27 +332,5 @@ class TestValidateInvitation:
"invitation_token": "Invalid invitation code."
}
mock_db.get.assert_called_once_with(
token="VALID_TOKEN", email__iexact="different@example.com"
)
def test_valid_invitation_uppercase_email(self):
"""Test that validate_invitation works with case-insensitive email lookup."""
uppercase_email = "USER@example.com"
invitation = MagicMock(spec=Invitation)
invitation.token = "VALID_TOKEN"
invitation.email = uppercase_email
invitation.expires_at = datetime.now(timezone.utc) + timedelta(days=1)
invitation.state = Invitation.State.PENDING
invitation.tenant = MagicMock()
with patch("api.utils.Invitation.objects.using") as mock_using:
mock_db = mock_using.return_value
mock_db.get.return_value = invitation
result = validate_invitation("VALID_TOKEN", "user@example.com")
assert result == invitation
mock_db.get.assert_called_once_with(
token="VALID_TOKEN", email__iexact="user@example.com"
token="VALID_TOKEN", email="different@example.com"
)
+299 -282
View File
@@ -5,19 +5,26 @@ import os
import tempfile
from datetime import datetime, timedelta, timezone
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import ANY, MagicMock, Mock, patch
from urllib.parse import parse_qs, urlparse
from uuid import uuid4
import jwt
import pytest
from allauth.socialaccount.models import SocialAccount, SocialApp
from botocore.exceptions import ClientError, NoCredentialsError
from conftest import API_JSON_CONTENT_TYPE, TEST_PASSWORD, TEST_USER
from django.conf import settings
from django.http import JsonResponse
from django.test import RequestFactory
from django.urls import reverse
from django_celery_results.models import TaskResult
from rest_framework import status
from rest_framework.response import Response
from api.compliance import get_compliance_frameworks
from api.db_router import MainRouter
from api.models import (
Integration,
Invitation,
@@ -28,6 +35,8 @@ from api.models import (
ProviderSecret,
Role,
RoleProviderGroupRelationship,
SAMLConfiguration,
SAMLToken,
Scan,
StateChoices,
Task,
@@ -35,7 +44,7 @@ from api.models import (
UserRoleRelationship,
)
from api.rls import Tenant
from api.v1.views import ComplianceOverviewViewSet
from api.v1.views import ComplianceOverviewViewSet, TenantFinishACSView
TODAY = str(datetime.today().date())
@@ -5756,334 +5765,342 @@ class TestIntegrationViewSet:
assert response.status_code == status.HTTP_400_BAD_REQUEST
# @pytest.mark.django_db
# class TestSAMLTokenValidation:
# def test_valid_token_returns_tokens(self, authenticated_client, create_test_user):
# user = create_test_user
# valid_token_data = {
# "access": "mock_access_token",
# "refresh": "mock_refresh_token",
# }
# saml_token = SAMLToken.objects.create(
# token=valid_token_data,
# user=user,
# expires_at=datetime.now(timezone.utc) + timedelta(seconds=10),
# )
@pytest.mark.django_db
class TestSAMLTokenValidation:
def test_valid_token_returns_tokens(self, authenticated_client, create_test_user):
user = create_test_user
valid_token_data = {
"access": "mock_access_token",
"refresh": "mock_refresh_token",
}
saml_token = SAMLToken.objects.create(
token=valid_token_data,
user=user,
expires_at=datetime.now(timezone.utc) + timedelta(seconds=10),
)
# url = reverse("token-saml")
# response = authenticated_client.post(f"{url}?id={saml_token.id}")
url = reverse("token-saml")
response = authenticated_client.post(f"{url}?id={saml_token.id}")
# assert response.status_code == status.HTTP_200_OK
# assert response.json() == {"data": valid_token_data}
# assert not SAMLToken.objects.filter(id=saml_token.id).exists()
assert response.status_code == status.HTTP_200_OK
assert response.json() == {"data": valid_token_data}
assert not SAMLToken.objects.filter(id=saml_token.id).exists()
# def test_invalid_token_id_returns_404(self, authenticated_client):
# url = reverse("token-saml")
# response = authenticated_client.post(f"{url}?id={str(uuid4())}")
def test_invalid_token_id_returns_404(self, authenticated_client):
url = reverse("token-saml")
response = authenticated_client.post(f"{url}?id={str(uuid4())}")
# assert response.status_code == status.HTTP_404_NOT_FOUND
# assert response.json()["errors"]["detail"] == "Invalid token ID."
assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.json()["errors"]["detail"] == "Invalid token ID."
# def test_expired_token_returns_400(self, authenticated_client, create_test_user):
# user = create_test_user
# expired_token_data = {
# "access": "expired_access_token",
# "refresh": "expired_refresh_token",
# }
# saml_token = SAMLToken.objects.create(
# token=expired_token_data,
# user=user,
# expires_at=datetime.now(timezone.utc) - timedelta(seconds=1),
# )
def test_expired_token_returns_400(self, authenticated_client, create_test_user):
user = create_test_user
expired_token_data = {
"access": "expired_access_token",
"refresh": "expired_refresh_token",
}
saml_token = SAMLToken.objects.create(
token=expired_token_data,
user=user,
expires_at=datetime.now(timezone.utc) - timedelta(seconds=1),
)
# url = reverse("token-saml")
# response = authenticated_client.post(f"{url}?id={saml_token.id}")
url = reverse("token-saml")
response = authenticated_client.post(f"{url}?id={saml_token.id}")
# assert response.status_code == status.HTTP_400_BAD_REQUEST
# assert response.json()["errors"]["detail"] == "Token expired."
# assert SAMLToken.objects.filter(id=saml_token.id).exists()
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.json()["errors"]["detail"] == "Token expired."
assert SAMLToken.objects.filter(id=saml_token.id).exists()
# def test_token_can_be_used_only_once(self, authenticated_client, create_test_user):
# user = create_test_user
# token_data = {
# "access": "single_use_token",
# "refresh": "single_use_refresh",
# }
# saml_token = SAMLToken.objects.create(
# token=token_data,
# user=user,
# expires_at=datetime.now(timezone.utc) + timedelta(seconds=10),
# )
def test_token_can_be_used_only_once(self, authenticated_client, create_test_user):
user = create_test_user
token_data = {
"access": "single_use_token",
"refresh": "single_use_refresh",
}
saml_token = SAMLToken.objects.create(
token=token_data,
user=user,
expires_at=datetime.now(timezone.utc) + timedelta(seconds=10),
)
# url = reverse("token-saml")
url = reverse("token-saml")
# # First use: should succeed
# response1 = authenticated_client.post(f"{url}?id={saml_token.id}")
# assert response1.status_code == status.HTTP_200_OK
# First use: should succeed
response1 = authenticated_client.post(f"{url}?id={saml_token.id}")
assert response1.status_code == status.HTTP_200_OK
# # Second use: should fail (already deleted)
# response2 = authenticated_client.post(f"{url}?id={saml_token.id}")
# assert response2.status_code == status.HTTP_404_NOT_FOUND
# Second use: should fail (already deleted)
response2 = authenticated_client.post(f"{url}?id={saml_token.id}")
assert response2.status_code == status.HTTP_404_NOT_FOUND
# @pytest.mark.django_db
# class TestSAMLInitiateAPIView:
# def test_valid_email_domain_and_certificates(
# self, authenticated_client, saml_setup, monkeypatch
# ):
# monkeypatch.setenv("SAML_PUBLIC_CERT", "fake_cert")
# monkeypatch.setenv("SAML_PRIVATE_KEY", "fake_key")
@pytest.mark.django_db
class TestSAMLInitiateAPIView:
def test_valid_email_domain_and_certificates(
self, authenticated_client, saml_setup, monkeypatch
):
monkeypatch.setenv("SAML_PUBLIC_CERT", "fake_cert")
monkeypatch.setenv("SAML_PRIVATE_KEY", "fake_key")
# url = reverse("api_saml_initiate")
# payload = {"email_domain": saml_setup["email"]}
url = reverse("api_saml_initiate")
payload = {"email_domain": saml_setup["email"]}
# response = authenticated_client.post(url, data=payload, format="json")
response = authenticated_client.post(url, data=payload, format="json")
# assert response.status_code == status.HTTP_302_FOUND
# assert (
# reverse("saml_login", kwargs={"organization_slug": saml_setup["domain"]})
# in response.url
# )
# assert "SAMLRequest" not in response.url
assert response.status_code == status.HTTP_302_FOUND
assert (
reverse("saml_login", kwargs={"organization_slug": saml_setup["domain"]})
in response.url
)
assert "SAMLRequest" not in response.url
# def test_invalid_email_domain(self, authenticated_client):
# url = reverse("api_saml_initiate")
# payload = {"email_domain": "user@unauthorized.com"}
def test_invalid_email_domain(self, authenticated_client):
url = reverse("api_saml_initiate")
payload = {"email_domain": "user@unauthorized.com"}
# response = authenticated_client.post(url, data=payload, format="json")
response = authenticated_client.post(url, data=payload, format="json")
# assert response.status_code == status.HTTP_403_FORBIDDEN
# assert response.json()["errors"]["detail"] == "Unauthorized domain."
assert response.status_code == status.HTTP_403_FORBIDDEN
assert response.json()["errors"]["detail"] == "Unauthorized domain."
# @pytest.mark.django_db
# class TestSAMLConfigurationViewSet:
# def test_list_saml_configurations(self, authenticated_client, saml_setup):
# config = SAMLConfiguration.objects.get(
# email_domain=saml_setup["email"].split("@")[-1]
# )
# response = authenticated_client.get(reverse("saml-config-list"))
# assert response.status_code == status.HTTP_200_OK
# assert (
# response.json()["data"][0]["attributes"]["email_domain"]
# == config.email_domain
# )
@pytest.mark.django_db
class TestSAMLConfigurationViewSet:
def test_list_saml_configurations(self, authenticated_client, saml_setup):
config = SAMLConfiguration.objects.get(
email_domain=saml_setup["email"].split("@")[-1]
)
response = authenticated_client.get(reverse("saml-config-list"))
assert response.status_code == status.HTTP_200_OK
assert (
response.json()["data"][0]["attributes"]["email_domain"]
== config.email_domain
)
# def test_retrieve_saml_configuration(self, authenticated_client, saml_setup):
# config = SAMLConfiguration.objects.get(
# email_domain=saml_setup["email"].split("@")[-1]
# )
# response = authenticated_client.get(
# reverse("saml-config-detail", kwargs={"pk": config.id})
# )
# assert response.status_code == status.HTTP_200_OK
# assert (
# response.json()["data"]["attributes"]["metadata_xml"] == config.metadata_xml
# )
def test_retrieve_saml_configuration(self, authenticated_client, saml_setup):
config = SAMLConfiguration.objects.get(
email_domain=saml_setup["email"].split("@")[-1]
)
response = authenticated_client.get(
reverse("saml-config-detail", kwargs={"pk": config.id})
)
assert response.status_code == status.HTTP_200_OK
assert (
response.json()["data"]["attributes"]["metadata_xml"] == config.metadata_xml
)
# def test_create_saml_configuration(self, authenticated_client, tenants_fixture):
# payload = {
# "email_domain": "newdomain.com",
# "metadata_xml": """<?xml version='1.0' encoding='UTF-8'?>
# <md:EntityDescriptor entityID='TEST' xmlns:md='urn:oasis:names:tc:SAML:2.0:metadata'>
# <md:IDPSSODescriptor WantAuthnRequestsSigned='false' protocolSupportEnumeration='urn:oasis:names:tc:SAML:2.0:protocol'>
# <md:KeyDescriptor use='signing'>
# <ds:KeyInfo xmlns:ds='http://www.w3.org/2000/09/xmldsig#'>
# <ds:X509Data>
# <ds:X509Certificate>TEST</ds:X509Certificate>
# </ds:X509Data>
# </ds:KeyInfo>
# </md:KeyDescriptor>
# <md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress</md:NameIDFormat>
# <md:SingleSignOnService Binding='urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST' Location='https://TEST/sso/saml'/>
# <md:SingleSignOnService Binding='urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect' Location='https://TEST/sso/saml'/>
# </md:IDPSSODescriptor>
# </md:EntityDescriptor>
# """,
# }
# response = authenticated_client.post(
# reverse("saml-config-list"), data=payload, format="json"
# )
# assert response.status_code == status.HTTP_201_CREATED
# assert SAMLConfiguration.objects.filter(email_domain="newdomain.com").exists()
def test_create_saml_configuration(self, authenticated_client, tenants_fixture):
payload = {
"email_domain": "newdomain.com",
"metadata_xml": """<?xml version='1.0' encoding='UTF-8'?>
<md:EntityDescriptor entityID='TEST' xmlns:md='urn:oasis:names:tc:SAML:2.0:metadata'>
<md:IDPSSODescriptor WantAuthnRequestsSigned='false' protocolSupportEnumeration='urn:oasis:names:tc:SAML:2.0:protocol'>
<md:KeyDescriptor use='signing'>
<ds:KeyInfo xmlns:ds='http://www.w3.org/2000/09/xmldsig#'>
<ds:X509Data>
<ds:X509Certificate>TEST</ds:X509Certificate>
</ds:X509Data>
</ds:KeyInfo>
</md:KeyDescriptor>
<md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress</md:NameIDFormat>
<md:SingleSignOnService Binding='urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST' Location='https://TEST/sso/saml'/>
<md:SingleSignOnService Binding='urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect' Location='https://TEST/sso/saml'/>
</md:IDPSSODescriptor>
</md:EntityDescriptor>
""",
}
response = authenticated_client.post(
reverse("saml-config-list"), data=payload, format="json"
)
assert response.status_code == status.HTTP_201_CREATED
assert SAMLConfiguration.objects.filter(email_domain="newdomain.com").exists()
# def test_update_saml_configuration(self, authenticated_client, saml_setup):
# config = SAMLConfiguration.objects.get(
# email_domain=saml_setup["email"].split("@")[-1]
# )
# payload = {
# "data": {
# "type": "saml-configurations",
# "id": str(config.id),
# "attributes": {
# "metadata_xml": """<?xml version='1.0' encoding='UTF-8'?>
# <md:EntityDescriptor entityID='TEST' xmlns:md='urn:oasis:names:tc:SAML:2.0:metadata'>
# <md:IDPSSODescriptor WantAuthnRequestsSigned='false' protocolSupportEnumeration='urn:oasis:names:tc:SAML:2.0:protocol'>
# <md:KeyDescriptor use='signing'>
# <ds:KeyInfo xmlns:ds='http://www.w3.org/2000/09/xmldsig#'>
# <ds:X509Data>
# <ds:X509Certificate>TEST2</ds:X509Certificate>
# </ds:X509Data>
# </ds:KeyInfo>
# </md:KeyDescriptor>
# <md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress</md:NameIDFormat>
# <md:SingleSignOnService Binding='urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST' Location='https://TEST/sso/saml'/>
# <md:SingleSignOnService Binding='urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect' Location='https://TEST/sso/saml'/>
# </md:IDPSSODescriptor>
# </md:EntityDescriptor>
# """
# },
# }
# }
# response = authenticated_client.patch(
# reverse("saml-config-detail", kwargs={"pk": config.id}),
# data=payload,
# content_type="application/vnd.api+json",
# )
# assert response.status_code == status.HTTP_200_OK
# config.refresh_from_db()
# assert (
# config.metadata_xml.strip()
# == payload["data"]["attributes"]["metadata_xml"].strip()
# )
def test_update_saml_configuration(self, authenticated_client, saml_setup):
config = SAMLConfiguration.objects.get(
email_domain=saml_setup["email"].split("@")[-1]
)
payload = {
"data": {
"type": "saml-configurations",
"id": str(config.id),
"attributes": {
"metadata_xml": """<?xml version='1.0' encoding='UTF-8'?>
<md:EntityDescriptor entityID='TEST' xmlns:md='urn:oasis:names:tc:SAML:2.0:metadata'>
<md:IDPSSODescriptor WantAuthnRequestsSigned='false' protocolSupportEnumeration='urn:oasis:names:tc:SAML:2.0:protocol'>
<md:KeyDescriptor use='signing'>
<ds:KeyInfo xmlns:ds='http://www.w3.org/2000/09/xmldsig#'>
<ds:X509Data>
<ds:X509Certificate>TEST2</ds:X509Certificate>
</ds:X509Data>
</ds:KeyInfo>
</md:KeyDescriptor>
<md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress</md:NameIDFormat>
<md:SingleSignOnService Binding='urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST' Location='https://TEST/sso/saml'/>
<md:SingleSignOnService Binding='urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect' Location='https://TEST/sso/saml'/>
</md:IDPSSODescriptor>
</md:EntityDescriptor>
"""
},
}
}
response = authenticated_client.patch(
reverse("saml-config-detail", kwargs={"pk": config.id}),
data=payload,
content_type="application/vnd.api+json",
)
assert response.status_code == status.HTTP_200_OK
config.refresh_from_db()
assert (
config.metadata_xml.strip()
== payload["data"]["attributes"]["metadata_xml"].strip()
)
# def test_delete_saml_configuration(self, authenticated_client, saml_setup):
# config = SAMLConfiguration.objects.get(
# email_domain=saml_setup["email"].split("@")[-1]
# )
# response = authenticated_client.delete(
# reverse("saml-config-detail", kwargs={"pk": config.id})
# )
# assert response.status_code == status.HTTP_204_NO_CONTENT
# assert not SAMLConfiguration.objects.filter(id=config.id).exists()
def test_delete_saml_configuration(self, authenticated_client, saml_setup):
config = SAMLConfiguration.objects.get(
email_domain=saml_setup["email"].split("@")[-1]
)
response = authenticated_client.delete(
reverse("saml-config-detail", kwargs={"pk": config.id})
)
assert response.status_code == status.HTTP_204_NO_CONTENT
assert not SAMLConfiguration.objects.filter(id=config.id).exists()
# @pytest.mark.django_db
# class TestTenantFinishACSView:
# def test_dispatch_skips_if_user_not_authenticated(self):
# request = RequestFactory().get(
# reverse("saml_finish_acs", kwargs={"organization_slug": "testtenant"})
# )
# request.user = type("Anonymous", (), {"is_authenticated": False})()
@pytest.mark.django_db
class TestTenantFinishACSView:
def test_dispatch_skips_if_user_not_authenticated(self):
request = RequestFactory().get(
reverse("saml_finish_acs", kwargs={"organization_slug": "testtenant"})
)
request.user = type("Anonymous", (), {"is_authenticated": False})()
# with patch(
# "allauth.socialaccount.providers.saml.views.get_app_or_404"
# ) as mock_get_app:
# mock_get_app.return_value = SocialApp(
# provider="saml",
# client_id="testtenant",
# name="Test App",
# settings={},
# )
with patch(
"allauth.socialaccount.providers.saml.views.get_app_or_404"
) as mock_get_app:
mock_get_app.return_value = SocialApp(
provider="saml",
client_id="testtenant",
name="Test App",
settings={},
)
# view = TenantFinishACSView.as_view()
# response = view(request, organization_slug="testtenant")
view = TenantFinishACSView.as_view()
response = view(request, organization_slug="testtenant")
# assert response.status_code in [200, 302]
assert response.status_code in [200, 302]
# def test_dispatch_skips_if_social_app_not_found(self, users_fixture):
# request = RequestFactory().get(
# reverse("saml_finish_acs", kwargs={"organization_slug": "testtenant"})
# )
# request.user = users_fixture[0]
def test_dispatch_skips_if_social_app_not_found(self, users_fixture):
request = RequestFactory().get(
reverse("saml_finish_acs", kwargs={"organization_slug": "testtenant"})
)
request.user = users_fixture[0]
# with patch(
# "allauth.socialaccount.providers.saml.views.get_app_or_404"
# ) as mock_get_app:
# mock_get_app.return_value = SocialApp(
# provider="saml",
# client_id="testtenant",
# name="Test App",
# settings={},
# )
with patch(
"allauth.socialaccount.providers.saml.views.get_app_or_404"
) as mock_get_app:
mock_get_app.return_value = SocialApp(
provider="saml",
client_id="testtenant",
name="Test App",
settings={},
)
# view = TenantFinishACSView.as_view()
# response = view(request, organization_slug="testtenant")
view = TenantFinishACSView.as_view()
response = view(request, organization_slug="testtenant")
# assert isinstance(response, JsonResponse) or response.status_code in [200, 302]
assert isinstance(response, JsonResponse) or response.status_code in [200, 302]
# def test_dispatch_sets_user_profile_and_assigns_role_and_creates_token(
# self, create_test_user, tenants_fixture, saml_setup, settings, monkeypatch
# ):
# monkeypatch.setenv("SAML_SSO_CALLBACK_URL", "http://localhost/sso-complete")
# user = create_test_user
# original_email = user.email
# original_name = user.name
# original_company = user.company_name
# user.email = f"doe@{saml_setup['email']}"
def test_dispatch_sets_user_profile_and_assigns_role_and_creates_token(
self, create_test_user, tenants_fixture, saml_setup, settings, monkeypatch
):
monkeypatch.setenv("SAML_SSO_CALLBACK_URL", "http://localhost/sso-complete")
user = create_test_user
original_name = user.name
original_company = user.company_name
user.company_name = "testing_company"
user.is_authenticate = True
# social_account = SocialAccount(
# user=user,
# provider="saml",
# extra_data={
# "firstName": ["John"],
# "lastName": ["Doe"],
# "organization": ["TestOrg"],
# "userType": ["saml_default_role"],
# },
# )
social_account = SocialAccount(
user=user,
provider="saml",
extra_data={
"firstName": ["John"],
"lastName": ["Doe"],
"organization": ["testing_company"],
"userType": ["saml_default_role"],
},
)
# request = RequestFactory().get(
# reverse("saml_finish_acs", kwargs={"organization_slug": "testtenant"})
# )
# request.user = user
request = RequestFactory().get(
reverse("saml_finish_acs", kwargs={"organization_slug": "testtenant"})
)
request.user = user
# with (
# patch(
# "allauth.socialaccount.providers.saml.views.get_app_or_404"
# ) as mock_get_app_or_404,
# patch("allauth.socialaccount.models.SocialApp.objects.get"),
# patch(
# "allauth.socialaccount.models.SocialAccount.objects.get"
# ) as mock_sa_get,
# ):
# mock_get_app_or_404.return_value = MagicMock(
# provider="saml", client_id="testtenant", name="Test App", settings={}
# )
# mock_sa_get.return_value = social_account
with (
patch(
"allauth.socialaccount.providers.saml.views.get_app_or_404"
) as mock_get_app_or_404,
patch(
"allauth.socialaccount.models.SocialApp.objects.get"
) as mock_socialapp_get,
patch(
"allauth.socialaccount.models.SocialAccount.objects.get"
) as mock_sa_get,
patch("api.models.SAMLDomainIndex.objects.get") as mock_saml_domain_get,
patch("api.models.SAMLConfiguration.objects.get") as mock_saml_config_get,
patch("api.models.User.objects.get") as mock_user_get,
):
mock_get_app_or_404.return_value = MagicMock(
provider="saml", client_id="testtenant", name="Test App", settings={}
)
mock_sa_get.return_value = social_account
mock_socialapp_get.return_value = MagicMock(provider_id="saml")
mock_saml_domain_get.return_value = SimpleNamespace(
tenant_id=tenants_fixture[0].id
)
mock_saml_config_get.return_value = MagicMock()
mock_user_get.return_value = user
# view = TenantFinishACSView.as_view()
# response = view(request, organization_slug="testtenant")
view = TenantFinishACSView.as_view()
response = view(request, organization_slug="testtenant")
# assert response.status_code == 302
assert response.status_code == 302
# expected_callback_host = "localhost"
# parsed_url = urlparse(response.url)
# assert parsed_url.netloc == expected_callback_host
# query_params = parse_qs(parsed_url.query)
# assert "id" in query_params
expected_callback_host = "localhost"
parsed_url = urlparse(response.url)
assert parsed_url.netloc == expected_callback_host
query_params = parse_qs(parsed_url.query)
assert "id" in query_params
# token_id = query_params["id"][0]
# token_obj = SAMLToken.objects.get(id=token_id)
# assert token_obj.user == user
# assert not token_obj.is_expired()
token_id = query_params["id"][0]
token_obj = SAMLToken.objects.get(id=token_id)
assert token_obj.user == user
assert not token_obj.is_expired()
# user.refresh_from_db()
# assert user.name == "John Doe"
# assert user.company_name == "TestOrg"
user.refresh_from_db()
assert user.name == "John Doe"
assert user.company_name == "testing_company"
# role = Role.objects.using(MainRouter.admin_db).get(name="saml_default_role")
# assert role.tenant == tenants_fixture[0]
role = Role.objects.using(MainRouter.admin_db).get(name="saml_default_role")
assert role.tenant == tenants_fixture[0]
# assert (
# UserRoleRelationship.objects.using(MainRouter.admin_db)
# .filter(user=user, tenant_id=tenants_fixture[0].id)
# .exists()
# )
assert (
UserRoleRelationship.objects.using(MainRouter.admin_db)
.filter(user=user, tenant_id=tenants_fixture[0].id)
.exists()
)
# # Membership should have been created with default role
# membership = Membership.objects.using(MainRouter.admin_db).get(
# user=user, tenant=tenants_fixture[0]
# )
# assert membership.role == Membership.RoleChoices.MEMBER
# assert membership.user == user
# assert membership.tenant == tenants_fixture[0]
membership = Membership.objects.using(MainRouter.admin_db).get(
user=user, tenant=tenants_fixture[0]
)
assert membership.role == Membership.RoleChoices.MEMBER
assert membership.user == user
assert membership.tenant == tenants_fixture[0]
# # Restore original user state
# user.email = original_email
# user.name = original_name
# user.company_name = original_company
# user.save()
user.name = original_name
user.company_name = original_company
user.save()
@pytest.mark.django_db
+1 -1
View File
@@ -187,7 +187,7 @@ def validate_invitation(
# Admin DB connector is used to bypass RLS protection since the invitation belongs to a tenant the user
# is not a member of yet
invitation = Invitation.objects.using(MainRouter.admin_db).get(
token=invitation_token, email__iexact=email
token=invitation_token, email=email
)
except Invitation.DoesNotExist:
if raise_not_found:
+19 -12
View File
@@ -29,6 +29,7 @@ from api.models import (
ResourceTag,
Role,
RoleProviderGroupRelationship,
SAMLConfiguration,
Scan,
StateChoices,
StatusChoices,
@@ -129,6 +130,12 @@ class TokenSerializer(BaseTokenSerializer):
class TokenSocialLoginSerializer(BaseTokenSerializer):
email = serializers.EmailField(write_only=True)
tenant_id = serializers.UUIDField(
write_only=True,
required=False,
help_text="If not provided, the tenant ID of the first membership that was added"
" to the user will be used.",
)
# Output tokens
refresh = serializers.CharField(read_only=True)
@@ -2068,23 +2075,23 @@ class IntegrationUpdateSerializer(BaseWriteIntegrationSerializer):
# SSO
# class SamlInitiateSerializer(serializers.Serializer):
# email_domain = serializers.CharField()
class SamlInitiateSerializer(serializers.Serializer):
email_domain = serializers.CharField()
# class JSONAPIMeta:
# resource_name = "saml-initiate"
class JSONAPIMeta:
resource_name = "saml-initiate"
# class SamlMetadataSerializer(serializers.Serializer):
# class JSONAPIMeta:
# resource_name = "saml-meta"
class SamlMetadataSerializer(serializers.Serializer):
class JSONAPIMeta:
resource_name = "saml-meta"
# class SAMLConfigurationSerializer(RLSSerializer):
# class Meta:
# model = SAMLConfiguration
# fields = ["id", "email_domain", "metadata_xml", "created_at", "updated_at"]
# read_only_fields = ["id", "created_at", "updated_at"]
class SAMLConfigurationSerializer(RLSSerializer):
class Meta:
model = SAMLConfiguration
fields = ["id", "email_domain", "metadata_xml", "created_at", "updated_at"]
read_only_fields = ["id", "created_at", "updated_at"]
class LighthouseConfigSerializer(RLSSerializer):
+37 -19
View File
@@ -1,9 +1,11 @@
from allauth.socialaccount.providers.saml.views import ACSView, MetadataView, SLSView
from django.urls import include, path
from drf_spectacular.views import SpectacularRedocView
from rest_framework_nested import routers
from api.v1.views import (
ComplianceOverviewViewSet,
CustomSAMLLoginView,
CustomTokenObtainView,
CustomTokenRefreshView,
CustomTokenSwitchTenantView,
@@ -23,10 +25,14 @@ from api.v1.views import (
ResourceViewSet,
RoleProviderGroupRelationshipView,
RoleViewSet,
SAMLConfigurationViewSet,
SAMLInitiateAPIView,
SAMLTokenValidateView,
ScanViewSet,
ScheduleViewSet,
SchemaView,
TaskViewSet,
TenantFinishACSView,
TenantMembersViewSet,
TenantViewSet,
UserRoleRelationshipView,
@@ -50,7 +56,7 @@ router.register(
router.register(r"overviews", OverviewViewSet, basename="overview")
router.register(r"schedules", ScheduleViewSet, basename="schedule")
router.register(r"integrations", IntegrationViewSet, basename="integration")
# router.register(r"saml-config", SAMLConfigurationViewSet, basename="saml-config")
router.register(r"saml-config", SAMLConfigurationViewSet, basename="saml-config")
router.register(
r"lighthouse-configurations",
LighthouseConfigViewSet,
@@ -119,24 +125,36 @@ urlpatterns = [
),
name="provider_group-providers-relationship",
),
# API endpoint to start SAML SSO flow (WIP)
# path(
# "auth/saml/initiate/", SAMLInitiateAPIView.as_view(), name="api_saml_initiate"
# ),
# # Custom SAML endpoints (must come before allauth.urls) (WIP)
# path(
# "accounts/saml/<organization_slug>/login/",
# CustomSAMLLoginView.as_view(),
# name="saml_login",
# ),
# path(
# "accounts/saml/<organization_slug>/acs/finish/",
# TenantFinishACSView.as_view(),
# name="saml_finish_acs",
# ),
# Allauth SAML endpoints for tenants (WIP)
# path("accounts/", include("allauth.urls")),
# path("tokens/saml", SAMLTokenValidateView.as_view(), name="token-saml"),
# API endpoint to start SAML SSO flow
path(
"auth/saml/initiate/", SAMLInitiateAPIView.as_view(), name="api_saml_initiate"
),
path(
"accounts/saml/<organization_slug>/login/",
CustomSAMLLoginView.as_view(),
name="saml_login",
),
path(
"accounts/saml/<organization_slug>/acs/",
ACSView.as_view(),
name="saml_acs",
),
path(
"accounts/saml/<organization_slug>/acs/finish/",
TenantFinishACSView.as_view(),
name="saml_finish_acs",
),
path(
"accounts/saml/<organization_slug>/sls/",
SLSView.as_view(),
name="saml_sls",
),
path(
"accounts/saml/<organization_slug>/metadata/",
MetadataView.as_view(),
name="saml_metadata",
),
path("tokens/saml", SAMLTokenValidateView.as_view(), name="token-saml"),
path("tokens/google", GoogleSocialLoginView.as_view(), name="token-google"),
path("tokens/github", GithubSocialLoginView.as_view(), name="token-github"),
path("", include(router.urls)),
+226 -204
View File
@@ -1,10 +1,13 @@
import glob
import os
from datetime import datetime, timedelta, timezone
from urllib.parse import urljoin
import sentry_sdk
from allauth.socialaccount.models import SocialAccount, SocialApp
from allauth.socialaccount.providers.github.views import GitHubOAuth2Adapter
from allauth.socialaccount.providers.google.views import GoogleOAuth2Adapter
from allauth.socialaccount.providers.saml.views import FinishACSView, LoginView
from botocore.exceptions import ClientError, NoCredentialsError, ParamValidationError
from celery.result import AsyncResult
from config.env import env
@@ -20,6 +23,7 @@ from django.db import transaction
from django.db.models import Count, Exists, F, OuterRef, Prefetch, Q, Sum
from django.db.models.functions import Coalesce
from django.http import HttpResponse
from django.shortcuts import redirect
from django.urls import reverse
from django.utils.dateparse import parse_date
from django.utils.decorators import method_decorator
@@ -64,6 +68,7 @@ from api.compliance import (
get_compliance_frameworks,
)
from api.db_router import MainRouter
from api.db_utils import rls_transaction
from api.exceptions import TaskFailedException
from api.filters import (
ComplianceOverviewFilter,
@@ -101,6 +106,9 @@ from api.models import (
ResourceScanSummary,
Role,
RoleProviderGroupRelationship,
SAMLConfiguration,
SAMLDomainIndex,
SAMLToken,
Scan,
ScanSummary,
SeverityChoices,
@@ -157,6 +165,8 @@ from api.v1.serializers import (
RoleProviderGroupRelationshipSerializer,
RoleSerializer,
RoleUpdateSerializer,
SAMLConfigurationSerializer,
SamlInitiateSerializer,
ScanComplianceReportSerializer,
ScanCreateSerializer,
ScanReportSerializer,
@@ -271,7 +281,7 @@ class SchemaView(SpectacularAPIView):
def get(self, request, *args, **kwargs):
spectacular_settings.TITLE = "Prowler API"
spectacular_settings.VERSION = "1.9.1"
spectacular_settings.VERSION = "1.9.0"
spectacular_settings.DESCRIPTION = (
"Prowler API specification.\n\nThis file is auto-generated."
)
@@ -393,240 +403,252 @@ class GithubSocialLoginView(SocialLoginView):
return original_response
# @extend_schema(exclude=True)
# class SAMLTokenValidateView(GenericAPIView):
# resource_name = "tokens"
# http_method_names = ["post"]
@extend_schema(exclude=True)
class SAMLTokenValidateView(GenericAPIView):
resource_name = "tokens"
http_method_names = ["post"]
# def post(self, request):
# token_id = request.query_params.get("id", "invalid")
# try:
# saml_token = SAMLToken.objects.using(MainRouter.admin_db).get(id=token_id)
# except SAMLToken.DoesNotExist:
# return Response({"detail": "Invalid token ID."}, status=404)
def post(self, request):
token_id = request.query_params.get("id", "invalid")
try:
saml_token = SAMLToken.objects.using(MainRouter.admin_db).get(id=token_id)
except SAMLToken.DoesNotExist:
return Response({"detail": "Invalid token ID."}, status=404)
# if saml_token.is_expired():
# return Response({"detail": "Token expired."}, status=400)
if saml_token.is_expired():
return Response({"detail": "Token expired."}, status=400)
# token_data = saml_token.token
# # Currently we don't store the tokens in the database, so we delete the token after use
# saml_token.delete()
token_data = saml_token.token
# Currently we don't store the tokens in the database, so we delete the token after use
saml_token.delete()
# return Response(token_data, status=200)
return Response(token_data, status=200)
# @extend_schema(exclude=True)
# class CustomSAMLLoginView(LoginView):
# def dispatch(self, request, *args, **kwargs):
# """
# Convert GET requests to POST to bypass allauth's confirmation screen.
@extend_schema(exclude=True)
class CustomSAMLLoginView(LoginView):
def dispatch(self, request, *args, **kwargs):
"""
Convert GET requests to POST to bypass allauth's confirmation screen.
# Why this is necessary:
# - django-allauth requires POST for social logins to prevent open redirect attacks
# - SAML login links typically use GET requests (e.g., <a href="...">)
# - This conversion allows seamless login without user-facing confirmation
Why this is necessary:
- django-allauth requires POST for social logins to prevent open redirect attacks
- SAML login links typically use GET requests (e.g., <a href="...">)
- This conversion allows seamless login without user-facing confirmation
# Security considerations:
# 1. Preserves CSRF protection: Original POST handling remains intact
# 2. Avoids global SOCIALACCOUNT_LOGIN_ON_GET=True which would:
# - Enable GET logins for ALL providers (security risk)
# - Potentially expose open redirect vulnerabilities
# 3. SAML payloads remain signed/encrypted regardless of HTTP method
# 4. No sensitive parameters are exposed in URLs (copied to POST body)
Security considerations:
1. Preserves CSRF protection: Original POST handling remains intact
2. Avoids global SOCIALACCOUNT_LOGIN_ON_GET=True which would:
- Enable GET logins for ALL providers (security risk)
- Potentially expose open redirect vulnerabilities
3. SAML payloads remain signed/encrypted regardless of HTTP method
4. No sensitive parameters are exposed in URLs (copied to POST body)
# This approach maintains security while providing better UX.
# """
# if request.method == "GET":
# # Convert GET to POST while preserving parameters
# request.method = "POST"
# # Safe because SAML validates signatures
# request.POST = request.GET.copy()
# return super().dispatch(request, *args, **kwargs)
This approach maintains security while providing better UX.
"""
if request.method == "GET":
# Convert GET to POST while preserving parameters
request.method = "POST"
return super().dispatch(request, *args, **kwargs)
# @extend_schema(exclude=True)
# class SAMLInitiateAPIView(GenericAPIView):
# serializer_class = SamlInitiateSerializer
# permission_classes = []
@extend_schema(exclude=True)
class SAMLInitiateAPIView(GenericAPIView):
serializer_class = SamlInitiateSerializer
permission_classes = []
# def post(self, request, *args, **kwargs):
# # Validate the input payload and extract the domain
# serializer = self.get_serializer(data=request.data)
# serializer.is_valid(raise_exception=True)
# email = serializer.validated_data["email_domain"]
# domain = email.split("@", 1)[-1].lower()
def post(self, request, *args, **kwargs):
# Validate the input payload and extract the domain
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
email = serializer.validated_data["email_domain"]
domain = email.split("@", 1)[-1].lower()
# # Retrieve the SAML configuration for the given email domain
# try:
# check = SAMLDomainIndex.objects.get(email_domain=domain)
# with rls_transaction(str(check.tenant_id)):
# config = SAMLConfiguration.objects.get(tenant_id=str(check.tenant_id))
# except (SAMLDomainIndex.DoesNotExist, SAMLConfiguration.DoesNotExist):
# return Response(
# {"detail": "Unauthorized domain."}, status=status.HTTP_403_FORBIDDEN
# )
# Retrieve the SAML configuration for the given email domain
try:
check = SAMLDomainIndex.objects.get(email_domain=domain)
with rls_transaction(str(check.tenant_id)):
config = SAMLConfiguration.objects.get(tenant_id=str(check.tenant_id))
except (SAMLDomainIndex.DoesNotExist, SAMLConfiguration.DoesNotExist):
return Response(
{"detail": "Unauthorized domain."}, status=status.HTTP_403_FORBIDDEN
)
# # Check certificates are not empty (TODO: Validate certificates)
# # saml_public_cert = os.getenv("SAML_PUBLIC_CERT", "").strip()
# # saml_private_key = os.getenv("SAML_PRIVATE_KEY", "").strip()
# Check certificates are not empty (TODO: Validate certificates)
# saml_public_cert = os.getenv("SAML_PUBLIC_CERT", "").strip()
# saml_private_key = os.getenv("SAML_PRIVATE_KEY", "").strip()
# # if not saml_public_cert or not saml_private_key:
# # return Response(
# # {"detail": "SAML configuration is invalid: missing certificates."},
# # status=status.HTTP_403_FORBIDDEN,
# # )
# if not saml_public_cert or not saml_private_key:
# return Response(
# {"detail": "SAML configuration is invalid: missing certificates."},
# status=status.HTTP_403_FORBIDDEN,
# )
# # Build the SAML login URL using the configured API host
# api_host = os.getenv("API_BASE_URL")
# login_path = reverse(
# "saml_login", kwargs={"organization_slug": config.email_domain}
# )
# login_url = urljoin(api_host, login_path)
# Build the SAML login URL using the configured API host
api_host = os.getenv("API_BASE_URL")
login_path = reverse(
"saml_login", kwargs={"organization_slug": config.email_domain}
)
login_url = urljoin(api_host, login_path)
# return redirect(login_url)
return redirect(login_url)
# @extend_schema_view(
# list=extend_schema(
# tags=["SAML"],
# summary="List all SSO configurations",
# description="Returns all the SAML-based SSO configurations associated with the current tenant.",
# ),
# retrieve=extend_schema(
# tags=["SAML"],
# summary="Retrieve SSO configuration details",
# description="Returns the details of a specific SAML configuration belonging to the current tenant.",
# ),
# create=extend_schema(
# tags=["SAML"],
# summary="Create the SSO configuration",
# description="Creates a new SAML SSO configuration for the current tenant, including email domain and metadata XML.",
# ),
# partial_update=extend_schema(
# tags=["SAML"],
# summary="Update the SSO configuration",
# description="Partially updates an existing SAML SSO configuration. Supports changes to email domain and metadata XML.",
# ),
# destroy=extend_schema(
# tags=["SAML"],
# summary="Delete the SSO configuration",
# description="Deletes an existing SAML SSO configuration associated with the current tenant.",
# ),
# )
# @method_decorator(CACHE_DECORATOR, name="retrieve")
# @method_decorator(CACHE_DECORATOR, name="list")
# class SAMLConfigurationViewSet(BaseRLSViewSet):
# """
# ViewSet for managing SAML SSO configurations per tenant.
@extend_schema_view(
list=extend_schema(
tags=["SAML"],
summary="List all SSO configurations",
description="Returns all the SAML-based SSO configurations associated with the current tenant.",
),
retrieve=extend_schema(
tags=["SAML"],
summary="Retrieve SSO configuration details",
description="Returns the details of a specific SAML configuration belonging to the current tenant.",
),
create=extend_schema(
tags=["SAML"],
summary="Create the SSO configuration",
description="Creates a new SAML SSO configuration for the current tenant, including email domain and metadata XML.",
),
partial_update=extend_schema(
tags=["SAML"],
summary="Update the SSO configuration",
description="Partially updates an existing SAML SSO configuration. Supports changes to email domain and metadata XML.",
),
destroy=extend_schema(
tags=["SAML"],
summary="Delete the SSO configuration",
description="Deletes an existing SAML SSO configuration associated with the current tenant.",
),
)
@method_decorator(CACHE_DECORATOR, name="retrieve")
@method_decorator(CACHE_DECORATOR, name="list")
class SAMLConfigurationViewSet(BaseRLSViewSet):
"""
ViewSet for managing SAML SSO configurations per tenant.
# This endpoint allows authorized users to perform CRUD operations on SAMLConfiguration,
# which define how a tenant integrates with an external SAML Identity Provider (IdP).
This endpoint allows authorized users to perform CRUD operations on SAMLConfiguration,
which define how a tenant integrates with an external SAML Identity Provider (IdP).
# Typical use cases include:
# - Listing all existing configurations for auditing or UI display.
# - Retrieving a single configuration to show setup details.
# - Creating or updating a configuration to onboard or modify SAML integration.
# - Deleting a configuration when deactivating SAML for a tenant.
# """
Typical use cases include:
- Listing all existing configurations for auditing or UI display.
- Retrieving a single configuration to show setup details.
- Creating or updating a configuration to onboard or modify SAML integration.
- Deleting a configuration when deactivating SAML for a tenant.
"""
# serializer_class = SAMLConfigurationSerializer
# required_permissions = [Permissions.MANAGE_INTEGRATIONS]
# queryset = SAMLConfiguration.objects.all()
serializer_class = SAMLConfigurationSerializer
required_permissions = [Permissions.MANAGE_INTEGRATIONS]
queryset = SAMLConfiguration.objects.all()
# def get_queryset(self):
# # If called during schema generation, return an empty queryset
# if getattr(self, "swagger_fake_view", False):
# return SAMLConfiguration.objects.none()
# return SAMLConfiguration.objects.filter(tenant=self.request.tenant_id)
def get_queryset(self):
# If called during schema generation, return an empty queryset
if getattr(self, "swagger_fake_view", False):
return SAMLConfiguration.objects.none()
return SAMLConfiguration.objects.filter(tenant=self.request.tenant_id)
# class TenantFinishACSView(FinishACSView):
# def dispatch(self, request, organization_slug):
# response = super().dispatch(request, organization_slug)
# user = getattr(request, "user", None)
# if not user or not user.is_authenticated:
# return response
class TenantFinishACSView(FinishACSView):
def dispatch(self, request, organization_slug):
response = super().dispatch(request, organization_slug)
user = getattr(request, "user", None)
if not user or not user.is_authenticated:
return response
# try:
# social_app = SocialApp.objects.get(
# provider="saml", client_id=organization_slug
# )
# social_account = SocialAccount.objects.get(
# user=user, provider=social_app.provider
# )
# except (SocialApp.DoesNotExist, SocialAccount.DoesNotExist):
# return response
# Defensive check to avoid edge case failures due to inconsistent or incomplete data in the database
# This handles scenarios like partially deleted or missing related objects
try:
check = SAMLDomainIndex.objects.get(email_domain=organization_slug)
with rls_transaction(str(check.tenant_id)):
SAMLConfiguration.objects.get(tenant_id=str(check.tenant_id))
social_app = SocialApp.objects.get(
provider="saml", client_id=organization_slug
)
user_id = User.objects.get(email=str(user)).id
social_account = SocialAccount.objects.get(
user=str(user_id), provider=social_app.provider_id
)
except (
SAMLDomainIndex.DoesNotExist,
SAMLConfiguration.DoesNotExist,
SocialApp.DoesNotExist,
SocialAccount.DoesNotExist,
User.DoesNotExist,
):
return response
# extra = social_account.extra_data
# user.first_name = (
# extra.get("firstName", [""])[0] if extra.get("firstName") else ""
# )
# user.last_name = extra.get("lastName", [""])[0] if extra.get("lastName") else ""
# user.company_name = (
# extra.get("organization", [""])[0] if extra.get("organization") else ""
# )
# user.name = f"{user.first_name} {user.last_name}".strip()
# if user.name == "":
# user.name = "N/A"
# user.save()
extra = social_account.extra_data
user.first_name = (
extra.get("firstName", [""])[0] if extra.get("firstName") else ""
)
user.last_name = extra.get("lastName", [""])[0] if extra.get("lastName") else ""
user.company_name = (
extra.get("organization", [""])[0] if extra.get("organization") else ""
)
user.name = f"{user.first_name} {user.last_name}".strip()
if user.name == "":
user.name = "N/A"
user.save()
# email_domain = user.email.split("@")[-1]
# tenant = (
# SAMLConfiguration.objects.using(MainRouter.admin_db)
# .get(email_domain=email_domain)
# .tenant
# )
# role_name = (
# extra.get("userType", ["saml_default_role"])[0].strip()
# if extra.get("userType")
# else "saml_default_role"
# )
# try:
# role = Role.objects.using(MainRouter.admin_db).get(
# name=role_name, tenant=tenant
# )
# except Role.DoesNotExist:
# role = Role.objects.using(MainRouter.admin_db).create(
# name=role_name,
# tenant=tenant,
# manage_users=False,
# manage_account=False,
# manage_billing=False,
# manage_providers=False,
# manage_integrations=False,
# manage_scans=False,
# unlimited_visibility=False,
# )
# UserRoleRelationship.objects.using(MainRouter.admin_db).filter(
# user=user,
# tenant_id=tenant.id,
# ).delete()
# UserRoleRelationship.objects.using(MainRouter.admin_db).create(
# user=user,
# role=role,
# tenant_id=tenant.id,
# )
# membership, _ = Membership.objects.using(MainRouter.admin_db).get_or_create(
# user=user,
# tenant=tenant,
# defaults={
# "user": user,
# "tenant": tenant,
# "role": Membership.RoleChoices.MEMBER,
# },
# )
email_domain = user.email.split("@")[-1]
tenant = (
SAMLConfiguration.objects.using(MainRouter.admin_db)
.get(email_domain=email_domain)
.tenant
)
role_name = (
extra.get("userType", ["saml_default_role"])[0].strip()
if extra.get("userType")
else "saml_default_role"
)
try:
role = Role.objects.using(MainRouter.admin_db).get(
name=role_name, tenant=tenant
)
except Role.DoesNotExist:
role = Role.objects.using(MainRouter.admin_db).create(
name=role_name,
tenant=tenant,
manage_users=False,
manage_account=False,
manage_billing=False,
manage_providers=False,
manage_integrations=False,
manage_scans=False,
unlimited_visibility=False,
)
UserRoleRelationship.objects.using(MainRouter.admin_db).filter(
user=user,
tenant_id=tenant.id,
).delete()
UserRoleRelationship.objects.using(MainRouter.admin_db).create(
user=user,
role=role,
tenant_id=tenant.id,
)
membership, _ = Membership.objects.using(MainRouter.admin_db).get_or_create(
user=user,
tenant=tenant,
defaults={
"user": user,
"tenant": tenant,
"role": Membership.RoleChoices.MEMBER,
},
)
# serializer = TokenSocialLoginSerializer(data={"email": user.email})
# serializer.is_valid(raise_exception=True)
serializer = TokenSocialLoginSerializer(
data={"email": user.email, "tenant_id": str(tenant.id)}
)
serializer.is_valid(raise_exception=True)
# token_data = serializer.validated_data
# saml_token = SAMLToken.objects.using(MainRouter.admin_db).create(
# token=token_data, user=user
# )
# callback_url = env.str("SAML_SSO_CALLBACK_URL")
# redirect_url = f"{callback_url}?id={saml_token.id}"
token_data = serializer.validated_data
saml_token = SAMLToken.objects.using(MainRouter.admin_db).create(
token=token_data, user=user
)
callback_url = env.str("SAML_SSO_CALLBACK_URL")
redirect_url = f"{callback_url}?id={saml_token.id}"
# return redirect(redirect_url)
return redirect(redirect_url)
@extend_schema_view(
+5
View File
@@ -11,6 +11,7 @@ SECRET_KEY = env("SECRET_KEY", default="secret")
DEBUG = env.bool("DJANGO_DEBUG", default=False)
ALLOWED_HOSTS = ["localhost", "127.0.0.1"]
SECURE_PROXY_SSL_HEADER = ("HTTP_X_FORWARDED_PROTO", "https")
USE_X_FORWARDED_HOST = True
# Application definition
@@ -248,3 +249,7 @@ X_FRAME_OPTIONS = "DENY"
SECURE_REFERRER_POLICY = "strict-origin-when-cross-origin"
DJANGO_DELETION_BATCH_SIZE = env.int("DJANGO_DELETION_BATCH_SIZE", 5000)
# SAML requirement
CSRF_COOKIE_SECURE = True
SESSION_COOKIE_SECURE = True
+3 -6
View File
@@ -4,7 +4,6 @@ from config.env import env
IGNORED_EXCEPTIONS = [
# Provider is not connected due to credentials errors
"is not connected",
"ProviderConnectionError",
# Authentication Errors from AWS
"InvalidToken",
"AccessDeniedException",
@@ -17,7 +16,7 @@ IGNORED_EXCEPTIONS = [
"InternalServerErrorException",
"AccessDenied",
"No Shodan API Key", # Shodan Check
"RequestLimitExceeded", # For now, we don't want to log the RequestLimitExceeded errors
"RequestLimitExceeded", # For now we don't want to log the RequestLimitExceeded errors
"ThrottlingException",
"Rate exceeded",
"SubscriptionRequiredException",
@@ -43,9 +42,7 @@ IGNORED_EXCEPTIONS = [
"AWSAccessKeyIDInvalidError",
"AWSSessionTokenExpiredError",
"EndpointConnectionError", # AWS Service is not available in a region
# The following comes from urllib3: eu-west-1 -- HTTPClientError[126]: An HTTP Client raised an
# unhandled exception: AWSHTTPSConnectionPool(host='hostname.s3.eu-west-1.amazonaws.com', port=443): Pool is closed.
"Pool is closed",
"Pool is closed", # The following comes from urllib3: eu-west-1 -- HTTPClientError[126]: An HTTP Client raised an unhandled exception: AWSHTTPSConnectionPool(host='hostname.s3.eu-west-1.amazonaws.com', port=443): Pool is closed.
# Authentication Errors from GCP
"ClientAuthenticationError",
"AuthorizationFailed",
@@ -74,7 +71,7 @@ IGNORED_EXCEPTIONS = [
def before_send(event, hint):
"""
before_send handles the Sentry events in order to send them or not
before_send handles the Sentry events in order to sent them or not
"""
# Ignore logs with the ignored_exceptions
# https://docs.python.org/3/library/logging.html#logrecord-objects
@@ -25,9 +25,18 @@ SOCIALACCOUNT_EMAIL_AUTHENTICATION = True
SOCIALACCOUNT_EMAIL_AUTHENTICATION_AUTO_CONNECT = True
SOCIALACCOUNT_ADAPTER = "api.adapters.ProwlerSocialAccountAdapter"
# SAML keys (TODO: Validate certificates)
# SAML_PUBLIC_CERT = env("SAML_PUBLIC_CERT", default="")
# SAML_PRIVATE_KEY = env("SAML_PRIVATE_KEY", default="")
# def inline(pem: str) -> str:
# return "".join(
# line.strip()
# for line in pem.splitlines()
# if "CERTIFICATE" not in line and "KEY" not in line
# )
# # SAML keys (TODO: Validate certificates)
# SAML_PUBLIC_CERT = inline(env("SAML_PUBLIC_CERT", default=""))
# SAML_PRIVATE_KEY = inline(env("SAML_PRIVATE_KEY", default=""))
SOCIALACCOUNT_PROVIDERS = {
"google": {
@@ -60,17 +69,14 @@ SOCIALACCOUNT_PROVIDERS = {
"entity_id": "urn:prowler.com:sp",
},
"advanced": {
# TODO: Validate certificates
# "x509cert": SAML_PUBLIC_CERT,
# "private_key": SAML_PRIVATE_KEY,
# "authn_request_signed": True,
# "want_assertion_signed": True,
# "want_message_signed": True,
# "want_assertion_signed": True,
"reject_idp_initiated_sso": False,
"name_id_format": "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress",
"authn_request_signed": False,
"logout_request_signed": False,
"logout_response_signed": False,
"want_assertion_encrypted": False,
"want_name_id_encrypted": False,
},
},
}
+52 -49
View File
@@ -1,8 +1,9 @@
import logging
from datetime import datetime, timedelta, timezone
from unittest.mock import patch
from unittest.mock import MagicMock, patch
import pytest
from allauth.socialaccount.models import SocialLogin
from django.conf import settings
from django.db import connection as django_connection
from django.db import connections as django_connections
@@ -28,6 +29,8 @@ from api.models import (
Resource,
ResourceTag,
Role,
SAMLConfiguration,
SAMLDomainIndex,
Scan,
ScanSummary,
StateChoices,
@@ -1118,62 +1121,62 @@ def latest_scan_finding(authenticated_client, providers_fixture, resources_fixtu
return finding
# @pytest.fixture
# def saml_setup(tenants_fixture):
# tenant_id = tenants_fixture[0].id
# domain = "example.com"
@pytest.fixture
def saml_setup(tenants_fixture):
tenant_id = tenants_fixture[0].id
domain = "prowler.com"
# SAMLDomainIndex.objects.create(email_domain=domain, tenant_id=tenant_id)
SAMLDomainIndex.objects.create(email_domain=domain, tenant_id=tenant_id)
# metadata_xml = """<?xml version='1.0' encoding='UTF-8'?>
# <md:EntityDescriptor entityID='TEST' xmlns:md='urn:oasis:names:tc:SAML:2.0:metadata'>
# <md:IDPSSODescriptor WantAuthnRequestsSigned='false' protocolSupportEnumeration='urn:oasis:names:tc:SAML:2.0:protocol'>
# <md:KeyDescriptor use='signing'>
# <ds:KeyInfo xmlns:ds='http://www.w3.org/2000/09/xmldsig#'>
# <ds:X509Data>
# <ds:X509Certificate>TEST</ds:X509Certificate>
# </ds:X509Data>
# </ds:KeyInfo>
# </md:KeyDescriptor>
# <md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress</md:NameIDFormat>
# <md:SingleSignOnService Binding='urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST' Location='https://TEST/sso/saml'/>
# <md:SingleSignOnService Binding='urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect' Location='https://TEST/sso/saml'/>
# </md:IDPSSODescriptor>
# </md:EntityDescriptor>
# """
# SAMLConfiguration.objects.create(
# tenant_id=str(tenant_id),
# email_domain=domain,
# metadata_xml=metadata_xml,
# )
metadata_xml = """<?xml version='1.0' encoding='UTF-8'?>
<md:EntityDescriptor entityID='TEST' xmlns:md='urn:oasis:names:tc:SAML:2.0:metadata'>
<md:IDPSSODescriptor WantAuthnRequestsSigned='false' protocolSupportEnumeration='urn:oasis:names:tc:SAML:2.0:protocol'>
<md:KeyDescriptor use='signing'>
<ds:KeyInfo xmlns:ds='http://www.w3.org/2000/09/xmldsig#'>
<ds:X509Data>
<ds:X509Certificate>TEST</ds:X509Certificate>
</ds:X509Data>
</ds:KeyInfo>
</md:KeyDescriptor>
<md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress</md:NameIDFormat>
<md:SingleSignOnService Binding='urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST' Location='https://TEST/sso/saml'/>
<md:SingleSignOnService Binding='urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect' Location='https://TEST/sso/saml'/>
</md:IDPSSODescriptor>
</md:EntityDescriptor>
"""
SAMLConfiguration.objects.create(
tenant_id=str(tenant_id),
email_domain=domain,
metadata_xml=metadata_xml,
)
# return {
# "email": f"user@{domain}",
# "domain": domain,
# "tenant_id": tenant_id,
# }
return {
"email": f"user@{domain}",
"domain": domain,
"tenant_id": tenant_id,
}
# @pytest.fixture
# def saml_sociallogin(users_fixture):
# user = users_fixture[0]
# user.email = "samlsso@acme.com"
# extra_data = {
# "firstName": ["Test"],
# "lastName": ["User"],
# "organization": ["Prowler"],
# "userType": ["member"],
# }
@pytest.fixture
def saml_sociallogin(users_fixture):
user = users_fixture[0]
user.email = "samlsso@acme.com"
extra_data = {
"firstName": ["Test"],
"lastName": ["User"],
"organization": ["Prowler"],
"userType": ["member"],
}
# account = MagicMock()
# account.provider = "saml"
# account.extra_data = extra_data
account = MagicMock()
account.provider = "saml"
account.extra_data = extra_data
# sociallogin = MagicMock(spec=SocialLogin)
# sociallogin.account = account
# sociallogin.user = user
sociallogin = MagicMock(spec=SocialLogin)
sociallogin.account = account
sociallogin.user = user
# return sociallogin
return sociallogin
def get_authorization_header(access_token: str) -> dict:
+1 -2
View File
@@ -14,7 +14,6 @@ from api.compliance import (
generate_scan_compliance,
)
from api.db_utils import create_objects_in_batches, rls_transaction
from api.exceptions import ProviderConnectionError
from api.models import (
ComplianceRequirementOverview,
Finding,
@@ -140,7 +139,7 @@ def perform_prowler_scan(
provider_instance.connected = True
except Exception as e:
provider_instance.connected = False
exc = ProviderConnectionError(
exc = ValueError(
f"Provider {provider_instance.provider} is not connected: {e}"
)
finally:
+19 -25
View File
@@ -37,26 +37,6 @@ from prowler.lib.outputs.finding import Finding as FindingOutput
logger = get_task_logger(__name__)
def _perform_scan_complete_tasks(tenant_id: str, scan_id: str, provider_id: str):
"""
Helper function to perform tasks after a scan is completed.
Args:
tenant_id (str): The tenant ID under which the scan was performed.
scan_id (str): The ID of the scan that was performed.
provider_id (str): The primary key of the Provider instance that was scanned.
"""
create_compliance_requirements_task.apply_async(
kwargs={"tenant_id": tenant_id, "scan_id": scan_id}
)
chain(
perform_scan_summary_task.si(tenant_id=tenant_id, scan_id=scan_id),
generate_outputs_task.si(
scan_id=scan_id, provider_id=provider_id, tenant_id=tenant_id
),
).apply_async()
@shared_task(base=RLSTask, name="provider-connection-check")
@set_tenant
def check_provider_connection_task(provider_id: str):
@@ -123,7 +103,13 @@ def perform_scan_task(
checks_to_execute=checks_to_execute,
)
_perform_scan_complete_tasks(tenant_id, scan_id, provider_id)
chain(
perform_scan_summary_task.si(tenant_id, scan_id),
create_compliance_requirements_task.si(tenant_id=tenant_id, scan_id=scan_id),
generate_outputs.si(
scan_id=scan_id, provider_id=provider_id, tenant_id=tenant_id
),
).apply_async()
return result
@@ -228,12 +214,20 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
scheduler_task_id=periodic_task_instance.id,
)
_perform_scan_complete_tasks(tenant_id, str(scan_instance.id), provider_id)
chain(
perform_scan_summary_task.si(tenant_id, scan_instance.id),
create_compliance_requirements_task.si(
tenant_id=tenant_id, scan_id=str(scan_instance.id)
),
generate_outputs.si(
scan_id=str(scan_instance.id), provider_id=provider_id, tenant_id=tenant_id
),
).apply_async()
return result
@shared_task(name="scan-summary", queue="overview")
@shared_task(name="scan-summary")
def perform_scan_summary_task(tenant_id: str, scan_id: str):
return aggregate_findings(tenant_id=tenant_id, scan_id=scan_id)
@@ -249,7 +243,7 @@ def delete_tenant_task(tenant_id: str):
queue="scan-reports",
)
@set_tenant(keep_tenant=True)
def generate_outputs_task(scan_id: str, provider_id: str, tenant_id: str):
def generate_outputs(scan_id: str, provider_id: str, tenant_id: str):
"""
Process findings in batches and generate output files in multiple formats.
@@ -387,7 +381,7 @@ def backfill_scan_resource_summaries_task(tenant_id: str, scan_id: str):
return backfill_resource_scan_summaries(tenant_id=tenant_id, scan_id=scan_id)
@shared_task(base=RLSTask, name="scan-compliance-overviews", queue="overview")
@shared_task(base=RLSTask, name="scan-compliance-overviews")
def create_compliance_requirements_task(tenant_id: str, scan_id: str):
"""
Creates detailed compliance requirement records for a scan.
+1 -2
View File
@@ -12,7 +12,6 @@ from tasks.jobs.scan import (
)
from tasks.utils import CustomEncoder
from api.exceptions import ProviderConnectionError
from api.models import (
ComplianceRequirementOverview,
Finding,
@@ -204,7 +203,7 @@ class TestPerformScan:
provider_id = str(provider.id)
checks_to_execute = ["check1", "check2"]
with pytest.raises(ProviderConnectionError):
with pytest.raises(ValueError):
perform_prowler_scan(tenant_id, scan_id, provider_id, checks_to_execute)
scan.refresh_from_db()
+8 -31
View File
@@ -3,10 +3,9 @@ from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from tasks.tasks import _perform_scan_complete_tasks, generate_outputs_task
from tasks.tasks import generate_outputs
# TODO Move this to outputs/reports jobs
@pytest.mark.django_db
class TestGenerateOutputs:
def setup_method(self):
@@ -18,7 +17,7 @@ class TestGenerateOutputs:
with patch("tasks.tasks.ScanSummary.objects.filter") as mock_filter:
mock_filter.return_value.exists.return_value = False
result = generate_outputs_task(
result = generate_outputs(
scan_id=self.scan_id,
provider_id=self.provider_id,
tenant_id=self.tenant_id,
@@ -100,7 +99,7 @@ class TestGenerateOutputs:
mock_compress.return_value = "/tmp/zipped.zip"
mock_upload.return_value = "s3://bucket/zipped.zip"
result = generate_outputs_task(
result = generate_outputs(
scan_id=self.scan_id,
provider_id=self.provider_id,
tenant_id=self.tenant_id,
@@ -151,7 +150,7 @@ class TestGenerateOutputs:
True,
]
result = generate_outputs_task(
result = generate_outputs(
scan_id="scan",
provider_id="provider",
tenant_id=self.tenant_id,
@@ -209,7 +208,7 @@ class TestGenerateOutputs:
{"aws": [(lambda x: True, MagicMock())]},
),
):
generate_outputs_task(
generate_outputs(
scan_id=self.scan_id,
provider_id=self.provider_id,
tenant_id=self.tenant_id,
@@ -277,7 +276,7 @@ class TestGenerateOutputs:
}
},
):
result = generate_outputs_task(
result = generate_outputs(
scan_id=self.scan_id,
provider_id=self.provider_id,
tenant_id=self.tenant_id,
@@ -347,7 +346,7 @@ class TestGenerateOutputs:
):
mock_summary.return_value.exists.return_value = True
result = generate_outputs_task(
result = generate_outputs(
scan_id=self.scan_id,
provider_id=self.provider_id,
tenant_id=self.tenant_id,
@@ -408,31 +407,9 @@ class TestGenerateOutputs:
),
):
with caplog.at_level("ERROR"):
generate_outputs_task(
generate_outputs(
scan_id=self.scan_id,
provider_id=self.provider_id,
tenant_id=self.tenant_id,
)
assert "Error deleting output files" in caplog.text
class TestScanCompleteTasks:
@patch("tasks.tasks.create_compliance_requirements_task.apply_async")
@patch("tasks.tasks.perform_scan_summary_task.si")
@patch("tasks.tasks.generate_outputs_task.si")
def test_scan_complete_tasks(
self, mock_outputs_task, mock_scan_summary_task, mock_compliance_tasks
):
_perform_scan_complete_tasks("tenant-id", "scan-id", "provider-id")
mock_compliance_tasks.assert_called_once_with(
kwargs={"tenant_id": "tenant-id", "scan_id": "scan-id"},
)
mock_scan_summary_task.assert_called_once_with(
scan_id="scan-id",
tenant_id="tenant-id",
)
mock_outputs_task.assert_called_once_with(
scan_id="scan-id",
provider_id="provider-id",
tenant_id="tenant-id",
)
+134
View File
@@ -0,0 +1,134 @@
# Extending Prowler Lighthouse
This guide helps developers customize and extend Prowler Lighthouse by adding or modifying AI agents.
## Understanding AI Agents
AI agents combine Large Language Models (LLMs) with specialized tools that provide environmental context. These tools can include API calls, system command execution, or any function-wrapped capability.
### Types of AI Agents
AI agents fall into two main categories:
- **Autonomous Agents**: Freely chooses from available tools to complete tasks, adapting their approach based on context. They decide which tools to use and when.
- **Workflow Agents**: Follows structured paths with predefined logic. They execute specific tool sequences and can include conditional logic.
Prowler Lighthouse is an autonomous agent - selecting the right tool(s) based on the users query.
???+ note
To learn more about AI agents, read [Anthropic's blog post on building effective agents](https://www.anthropic.com/engineering/building-effective-agents).
### LLM Dependency
The autonomous nature of agents depends on the underlying LLM. Autonomous agents using identical system prompts and tools but powered by different LLM providers might approach user queries differently. Agent with one LLM might solve a problem efficiently, while with another it might take a different route or fail entirely.
After evaluating multiple LLM providers (OpenAI, Gemini, Claude, LLama) based on tool calling features and response accuracy, we recommend using the `gpt-4o` model.
## Prowler Lighthouse Architecture
Prowler Lighthouse uses a multi-agent architecture orchestrated by the [Langgraph-Supervisor](https://www.npmjs.com/package/@langchain/langgraph-supervisor) library.
### Architecture Components
<img src="../../tutorials/img/lighthouse-architecture.png" alt="Prowler Lighthouse architecture">
Prowler Lighthouse integrates with the NextJS application:
- The [Langgraph-Supervisor](https://www.npmjs.com/package/@langchain/langgraph-supervisor) library integrates directly with NextJS
- The system uses the authenticated user session to interact with the Prowler API server
- Agents only access data the current user is authorized to view
- Session management operates automatically, ensuring Role-Based Access Control (RBAC) is maintained
## Available Prowler AI Agents
The following specialized AI agents are available in Prowler:
### Agent Overview
- **provider_agent**: Fetches information about cloud providers connected to Prowler
- **user_info_agent**: Retrieves information about Prowler users
- **scans_agent**: Fetches information about Prowler scans
- **compliance_agent**: Retrieves compliance overviews across scans
- **findings_agent**: Fetches information about individual findings across scans
- **overview_agent**: Retrieves overview information (providers, findings by status and severity, etc.)
## How to Add New Capabilities
### Updating the Supervisor Prompt
The supervisor agent controls system behavior, tone, and capabilities. You can find the supervisor prompt at: [https://github.com/prowler-cloud/prowler/blob/master/ui/lib/lighthouse/prompts.ts](https://github.com/prowler-cloud/prowler/blob/master/ui/lib/lighthouse/prompts.ts)
#### Supervisor Prompt Modifications
Modifying the supervisor prompt allows you to:
- Change personality or response style
- Add new high-level capabilities
- Modify task delegation to specialized agents
- Set up guardrails (query types to answer or decline)
???+ note
The supervisor agent should not have its own tools. This design keeps the system modular and maintainable.
### How to Create New Specialized Agents
The supervisor agent and all specialized agents are defined in the `route.ts` file. The supervisor agent uses [langgraph-supervisor](https://www.npmjs.com/package/@langchain/langgraph-supervisor), while other agents use the prebuilt [create-react-agent](https://langchain-ai.github.io/langgraphjs/how-tos/create-react-agent/).
To add new capabilities or all Lighthouse to interact with other APIs, create additional specialized agents:
1. First determine what the new agent would do. Create a detailed prompt defining the agent's purpose and capabilities. You can see an example from [here](https://github.com/prowler-cloud/prowler/blob/master/ui/lib/lighthouse/prompts.ts#L359-L385).
???+ note
Ensure that the new agent's capabilities don't collide with existing agents. For example, if there's already a *findings_agent* that talks to findings APIs don't create a new agent to do the same.
2. Create necessary tools for the agents to access specific data or perform actions. A tool is a specialized function that extends the capabilities of LLM by allowing it to access external data or APIs. A tool is triggered by LLM based on the description of the tool and the user's query.
For example, the description of `getScanTool` is "Fetches detailed information about a specific scan by its ID." If the description doesn't convey what the tool is capable of doing, LLM will not invoke the function. If the description of `getScanTool` was set to something random or not set at all, LLM will not answer queries like "Give me the critical issues from the scan ID xxxxxxxxxxxxxxx"
???+ note
Ensure that one tool is added to one agent only. Adding tools is optional. There can be agents with no tools at all.
3. Use the `createReactAgent` function to define a new agent. For example, the rolesAgent name is "roles_agent" and has access to call tools "*getRolesTool*" and "*getRoleTool*"
```js
const rolesAgent = createReactAgent({
llm: llm,
tools: [getRolesTool, getRoleTool],
name: "roles_agent",
prompt: rolesAgentPrompt,
});
```
4. Create a detailed prompt defining the agent's purpose and capabilities.
5. Add the new agent to the available agents list:
```js
const agents = [
userInfoAgent,
providerAgent,
overviewAgent,
scansAgent,
complianceAgent,
findingsAgent,
rolesAgent, // New agent added here
];
// Create supervisor workflow
const workflow = createSupervisor({
agents: agents,
llm: supervisorllm,
prompt: supervisorPrompt,
outputMode: "last_message",
});
```
6. Update the supervisor's system prompt to summarize the new agent's capabilities.
### Best Practices for Agent Development
When developing new agents or capabilities:
- **Clear Responsibility Boundaries**: Each agent should have a defined purpose with minimal overlap. No two agents should access the same tools or different tools accessing the same Prowler APIs.
- **Minimal Data Access**: Agents should only request the data they need, keeping requests specific to minimize context window usage, cost, and response time.
- **Thorough Prompting:** Ensure agent prompts include clear instructions about:
- The agent's purpose and limitations
- How to use its tools
- How to format responses for the supervisor
- Error handling procedures (Optional)
- **Security Considerations:** Agents should never modify data or access sensitive information like secrets or credentials.
- **Testing:** Thoroughly test new agents with various queries before deploying to production.
+45 -18
View File
@@ -156,7 +156,7 @@ Follow the instructions in the [Create Prowler Service Principal](../tutorials/m
If you don't add the external API permissions described in the mentioned section above you will only be able to run the checks that work through MS Graph. This means that you won't run all the provider.
If you want to scan all the checks from M365 you will need to use the recommended authentication method or add the external API permissions.
If you want to scan all the checks from M365 you will need to add the required permissions to the service principal application. Refer to the [Needed permissions](/docs/tutorials/microsoft365/getting-started-m365.md#needed-permissions) section for more information.
### Service Principal and User Credentials authentication
@@ -172,9 +172,10 @@ export M365_USER="your_email@example.com"
export M365_PASSWORD="examplepassword"
```
These two new environment variables are **required** to execute the PowerShell modules needed to retrieve information from M365 services. Prowler uses Service Principal authentication to access Microsoft Graph and user credentials to authenticate to Microsoft PowerShell modules.
These two new environment variables are **required** in this authentication method to execute the PowerShell modules needed to retrieve information from M365 services. Prowler uses Service Principal authentication to access Microsoft Graph and user credentials to authenticate to Microsoft PowerShell modules.
- `M365_USER` should be your Microsoft account email using the **assigned domain in the tenant**. This means it must look like `example@YourCompany.onmicrosoft.com` or `example@YourCompany.com`, but it must be the exact domain assigned to that user in the tenant.
???+ warning
If the user is newly created, you need to sign in with that account first, as Microsoft will prompt you to change the password. If you dont complete this step, user authentication will fail because Microsoft marks the initial password as expired.
@@ -207,30 +208,56 @@ Since this is a delegated permission authentication method, necessary permission
### Needed permissions
Prowler for M365 requires two types of permission scopes to be set (if you want to run the full provider including PowerShell checks). Both must be configured using Microsoft Entra ID:
Prowler for M365 requires different permission scopes depending on the authentication method you choose. The permissions must be configured using Microsoft Entra ID:
- **Service Principal Application Permissions**: These are set at the **application** level and are used to retrieve data from the identity being assessed:
- `AuditLog.Read.All`: Required for Entra service.
- `Directory.Read.All`: Required for all services.
- `Policy.Read.All`: Required for all services.
- `SharePointTenantSettings.Read.All`: Required for SharePoint service.
- `User.Read` (IMPORTANT: this must be set as **delegated**): Required for the sign-in.
- `Exchange.ManageAsApp` from external API `Office 365 Exchange Online`: Required for Exchange PowerShell module app authentication. You also need to assign the `Exchange Administrator` role to the app.
- `application_access` from external API `Skype and Teams Tenant Admin API`: Required for Teams PowerShell module app authentication.
#### For Service Principal Authentication (`--sp-env-auth`) - Recommended
???+ note
You can replace `Directory.Read.All` with `Domain.Read.All` is a more restrictive permission but you won't be able to run the Entra checks related with DirectoryRoles and GetUsers.
When using service principal authentication, you need to add the following **Application Permissions** configured to:
> If you do this you will need to add also the `Organization.Read.All` permission to the service principal application in order to authenticate.
**Microsoft Graph API Permissions:**
- `AuditLog.Read.All`: Required for Entra service.
- `Directory.Read.All`: Required for all services.
- `Policy.Read.All`: Required for all services.
- `SharePointTenantSettings.Read.All`: Required for SharePoint service.
- `User.Read` (IMPORTANT: this must be set as **delegated**): Required for the sign-in.
**External API Permissions:**
- `Exchange.ManageAsApp` from external API `Office 365 Exchange Online`: Required for Exchange PowerShell module app authentication. You also need to assign the `Exchange Administrator` role to the app.
- `application_access` from external API `Skype and Teams Tenant Admin API`: Required for Teams PowerShell module app authentication.
???+ note
You can replace `Directory.Read.All` with `Domain.Read.All` is a more restrictive permission but you won't be able to run the Entra checks related with DirectoryRoles and GetUsers.
> If you do this you will need to add also the `Organization.Read.All` permission to the service principal application in order to authenticate.
- **Powershell Modules Permissions** (if using user credentials): These are set at the `M365_USER` level, so the user used to run Prowler must have one of the following roles:
- `Global Reader` (recommended): this allows you to read all roles needed.
- `Exchange Administrator` and `Teams Administrator`: user needs both roles but with this [roles](https://learn.microsoft.com/en-us/exchange/permissions-exo/permissions-exo#microsoft-365-permissions-in-exchange-online) you can access to the same information as a Global Reader (since only read access is needed, Global Reader is recommended).
???+ warning
With service principal only authentication, you can only run checks that work through MS Graph API. Some checks that require PowerShell modules will not be executed.
In order to know how to assign those permissions and roles follow the instructions in the Microsoft Entra ID [permissions](../tutorials/microsoft365/getting-started-m365.md#grant-required-api-permissions) and [roles](../tutorials/microsoft365/getting-started-m365.md#assign-required-roles-to-your-user) section.
#### For Service Principal + User Credentials Authentication (`--env-auth`)
When using service principal with user credentials authentication, you need **both** sets of permissions:
**1. Service Principal Application Permissions**:
- You **will need** all the Microsoft Graph API permissions listed above.
- You **won't need** the External API permissions listed above.
**2. User-Level Permissions**: These are set at the `M365_USER` level, so the user used to run Prowler must have one of the following roles:
- `Global Reader` (recommended): this allows you to read all roles needed.
- `Exchange Administrator` and `Teams Administrator`: user needs both roles but with this [roles](https://learn.microsoft.com/en-us/exchange/permissions-exo/permissions-exo#microsoft-365-permissions-in-exchange-online) you can access to the same information as a Global Reader (since only read access is needed, Global Reader is recommended).
???+ note
This is the **recommended authentication method** because it allows you to run the full M365 provider including PowerShell checks, providing complete coverage of all available security checks.
#### For Browser Authentication (`--browser-auth`)
When using browser authentication, permissions are delegated to the user, so the user must have the appropriate permissions rather than the application.
???+ warning
With browser authentication, you will only be able to run checks that work through MS Graph API. PowerShell module checks will not be executed.
---
**To assign these permissions and roles**, follow the instructions in the Microsoft Entra ID [permissions](../tutorials/microsoft365/getting-started-m365.md#grant-required-api-permissions) and [roles](../tutorials/microsoft365/getting-started-m365.md#assign-required-roles-to-your-user) section.
### Supported PowerShell versions
Binary file not shown.

Before

Width:  |  Height:  |  Size: 119 KiB

After

Width:  |  Height:  |  Size: 433 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 178 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 197 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 204 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 241 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 268 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 404 KiB

@@ -12,7 +12,7 @@ This allows Prowler to authenticate against Microsoft 365 using the following me
To launch the tool first you need to specify which method is used through the following flags:
```console
# To use service principal (app) authentication and Microsoft user credentials (to use PowerShell)
# To use service principal (app) authentication and Microsoft user credentials
prowler m365 --env-auth
# To use service principal authentication
@@ -25,4 +25,4 @@ prowler m365 --az-cli-auth
prowler m365 --browser-auth --tenant-id "XXXXXXXX"
```
To use Prowler you need to set up also the permissions required to access your resources in your Microsoft 365 account, to more details refer to [Requirements](../../getting-started/requirements.md#microsoft-365)
To use Prowler you need to set up also the permissions required to access your resources in your Microsoft 365 account, to more details refer to [Requirements](../../getting-started/requirements.md#needed-permissions-2)
@@ -30,7 +30,7 @@ Go to the Entra ID portal, then you can search for `Domain` or go to Identity >
![Custom Domain Names](./img/custom-domain-names.png)
Once you are there just select the domain you want to use.
Once you are there just select the domain you want to use as unique identifier for your M365 account in Prowler Cloud/App.
---
@@ -139,10 +139,7 @@ The permissions you need to grant depends on whether you are using user credenti
Make sure you add the correct set of permissions for the authentication method you are using.
#### If using application(service principal) authentication
???+ warning "Warning"
Currently Prowler Cloud only supports user authentication.
#### If using application(service principal) authentication (Recommended)
To grant the permissions for the PowerShell modules via application authentication, you need to add the necessary APIs to your app registration.
@@ -191,12 +188,15 @@ To grant the permissions for the PowerShell modules via application authenticati
![Final Permission Assignment](./img/final-permissions.png)
---
#### If using user authentication
This method is not recommended because it requires a user with MFA enabled and Microsoft will not allow MFA capable users to authenticate programmatically after 1st September 2025. See [Microsoft documentation](https://learn.microsoft.com/en-us/entra/identity/authentication/concept-mandatory-multifactor-authentication?tabs=dotnet) for more information.
???+ warning
Remember that if the user is newly created, you need to sign in with that account first, as Microsoft will prompt you to change the password. If you dont complete this step, user authentication will fail because Microsoft marks the initial password as expired.
---
#### If using user authentication (Currently Prowler Cloud only supports this method)
1. Search and select:
@@ -253,6 +253,8 @@ To grant the permissions for the PowerShell modules via application authenticati
- `Client ID`
- `Tenant ID`
- `AZURE_CLIENT_SECRET` from earlier
If you are using user authentication, also add:
- `M365_USER` the user using the correct assigned domain, more info [here](../../getting-started/requirements.md#service-principal-and-user-credentials-authentication-recommended)
- `M365_PASSWORD` the password of the user
Binary file not shown.

Before

Width:  |  Height:  |  Size: 119 KiB

After

Width:  |  Height:  |  Size: 433 KiB

+204
View File
@@ -0,0 +1,204 @@
# Prowler Lighthouse
Prowler Lighthouse is an AI Cloud Security Analyst chatbot that helps you understand, prioritize, and remediate security findings in your cloud environments. It's designed to provide security expertise for teams without dedicated resources, acting as your 24/7 virtual cloud security analyst.
<img src="../img/lighthouse-intro.png" alt="Prowler Lighthouse">
## How It Works
Prowler Lighthouse uses OpenAI's language models and integrates with your Prowler security findings data.
Here's what's happening behind the scenes:
- The system uses a multi-agent architecture built with [LanggraphJS](https://github.com/langchain-ai/langgraphjs) for LLM logic and [Vercel AI SDK UI](https://sdk.vercel.ai/docs/ai-sdk-ui/overview) for frontend chatbot.
- It uses a ["supervisor" architecture](https://langchain-ai.lang.chat/langgraphjs/tutorials/multi_agent/agent_supervisor/) that interacts with different agents for specialized tasks. For example, `findings_agent` can analyze detected security findings, while `overview_agent` provides a summary of connected cloud accounts.
- The system connects to OpenAI models to understand, fetch the right data, and respond to the user's query.
???+ note
Lighthouse is tested against `gpt-4o` and `gpt-4o-mini` OpenAI models.
- The supervisor agent is the main contact point. It is what users interact with directly from the chat interface. It coordinates with other agents to answer users' questions comprehensively.
<img src="../img/lighthouse-architecture.png" alt="Lighthouse Architecture">
???+ note
All agents can only read relevant security data. They cannot modify your data or access sensitive information like configured secrets or tenant details.
## Set up
Getting started with Prowler Lighthouse is easy:
1. Go to the configuration page in your Prowler dashboard.
2. Enter your OpenAI API key.
3. Select your preferred model. The recommended one for best results is `gpt-4o`.
4. (Optional) Add business context to improve response quality and prioritization.
<img src="../img/lighthouse-config.png" alt="Lighthouse Configuration">
### Adding Business Context
The optional business context field lets you provide additional information to help Lighthouse understand your environment and priorities, including:
- Your organization's cloud security goals
- Information about account owners or responsible teams
- Compliance requirements for your organization
- Current security initiatives or focus areas
Better context leads to more relevant responses and prioritization that aligns with your needs.
## Capabilities
Prowler Lighthouse is designed to be your AI security team member, with capabilities including:
### Natural Language Querying
Ask questions in plain English about your security findings. Examples:
- "What are my highest risk findings?"
- "Show me all S3 buckets with public access."
- "What security issues were found in my production accounts?"
<img src="../img/lighthouse-feature1.png" alt="Natural language querying">
### Detailed Remediation Guidance
Get tailored step-by-step instructions for fixing security issues:
- Clear explanations of the problem and its impact
- Commands or console steps to implement fixes
- Alternative approaches with different solutions
<img src="../img/lighthouse-feature2.png" alt="Detailed Remediation">
### Enhanced Context and Analysis
Lighthouse can provide additional context to help you understand the findings:
- Explain security concepts related to findings in simple terms
- Provide risk assessments based on your environment and context
- Connect related findings to show broader security patterns
<img src="../img/lighthouse-config.png" alt="Business Context">
<img src="../img/lighthouse-feature3.png" alt="Contextual Responses">
## Important Notes
Prowler Lighthouse is powerful, but there are limitations:
- **Continuous improvement**: Please report any issues, as the feature may make mistakes or encounter errors, despite extensive testing.
- **Access limitations**: Lighthouse can only access data the logged-in user can view. If you can't see certain information, Lighthouse can't see it either.
- **NextJS session dependence**: If your Prowler application session expires or logs out, Lighthouse will error out. Refresh and log back in to continue.
- **Response quality**: The response quality depends on the selected OpenAI model. For best results, use gpt-4o.
### Getting Help
If you encounter issues with Prowler Lighthouse or have suggestions for improvements, please [reach out through our Slack channel](https://goto.prowler.com/slack).
### What Data Is Shared to OpenAI?
The following API endpoints are accessible to Prowler Lighthouse. Data from the following API endpoints could be shared with OpenAI depending on the scope of user's query:
#### Accessible API Endpoints
**User Management:**
- List all users - `/api/v1/users`
- Retrieve the current user's information - `/api/v1/users/me`
**Provider Management:**
- List all providers - `/api/v1/providers`
- Retrieve data from a provider - `/api/v1/providers/{id}`
**Scan Management:**
- List all scans - `/api/v1/scans`
- Retrieve data from a specific scan - `/api/v1/scans/{id}`
**Resource Management:**
- List all resources - `/api/v1/resources`
- Retrieve data for a resource - `/api/v1/resources/{id}`
**Findings Management:**
- List all findings - `/api/v1/findings`
- Retrieve data from a specific finding - `/api/v1/findings/{id}`
- Retrieve metadata values from findings - `/api/v1/findings/metadata`
**Overview Data:**
- Get aggregated findings data - `/api/v1/overviews/findings`
- Get findings data by severity - `/api/v1/overviews/findings_severity`
- Get aggregated provider data - `/api/v1/overviews/providers`
- Get findings data by service - `/api/v1/overviews/services`
**Compliance Management:**
- List compliance overviews for a scan - `/api/v1/compliance-overviews`
- Retrieve data from a specific compliance overview - `/api/v1/compliance-overviews/{id}`
#### Excluded API Endpoints
Not all Prowler API endpoints are integrated with Lighthouse. They are intentionally excluded for the following reasons:
- OpenAI/other LLM providers shouldn't have access to sensitive data (like fetching provider secrets and other sensitive config)
- Users queries don't need responses from those API endpoints (ex: tasks, tenant details, downloading zip file, etc.)
**Excluded Endpoints:**
**User Management:**
- List specific users information - `/api/v1/users/{id}`
- List user memberships - `/api/v1/users/{user_pk}/memberships`
- Retrieve membership data from the user - `/api/v1/users/{user_pk}/memberships/{id}`
**Tenant Management:**
- List all tenants - `/api/v1/tenants`
- Retrieve data from a tenant - `/api/v1/tenants/{id}`
- List tenant memberships - `/api/v1/tenants/{tenant_pk}/memberships`
- List all invitations - `/api/v1/tenants/invitations`
- Retrieve data from tenant invitation - `/api/v1/tenants/invitations/{id}`
**Security and Configuration:**
- List all secrets - `/api/v1/providers/secrets`
- Retrieve data from a secret - `/api/v1/providers/secrets/{id}`
- List all provider groups - `/api/v1/provider-groups`
- Retrieve data from a provider group - `/api/v1/provider-groups/{id}`
**Reports and Tasks:**
- Download zip report - `/api/v1/scans/{v1}/report`
- List all tasks - `/api/v1/tasks`
- Retrieve data from a specific task - `/api/v1/tasks/{id}`
**Lighthouse Configuration:**
- List OpenAI configuration - `/api/v1/lighthouse-config`
- Retrieve OpenAI key and configuration - `/api/v1/lighthouse-config/{id}`
???+ note
Agents only have access to hit GET endpoints. They don't have access to other HTTP methods.
## FAQs
**1. Why only OpenAI models?**
During feature development, we evaluated other LLM models.
- **Claude AI** - Claude models have [tier-based ratelimits](https://docs.anthropic.com/en/api/rate-limits#requirements-to-advance-tier). For Lighthouse to answer slightly complex questions, there are a handful of API calls to the LLM provider within few seconds. With Claude's tiering system, users must purchase $400 credits or convert their subscription to monthly invoicing after talking to their sales team. This pricing may not suit all Prowler users.
- **Gemini Models** - Gemini lacks a solid tool calling feature like OpenAI. It calls functions recursively until exceeding limits. Gemini-2.5-Pro-Experimental is better than previous models regarding tool calling and responding, but it's still experimental.
- **Deepseek V3** - Doesn't support system prompt messages.
**2. Why a multi-agent supervisor model?**
Context windows are limited. While demo data fits inside the context window, querying real-world data often exceeds it. A multi-agent architecture is used so different agents fetch different sizes of data and respond with the minimum required data to the supervisor. This spreads the context window usage across agents.
**3. Is my security data shared with OpenAI?**
Minimal data is shared to generate useful responses. Agents can access security findings and remediation details when needed. Provider secrets are protected by design and cannot be read. The Lighthouse key is only accessible to our NextJS server and is never sent to LLMs. Resource metadata (names, tags, account/project IDs, etc) may be shared with OpenAI based on your query requirements.
**4. Can the Lighthouse change my cloud environment?**
No. The agent doesn't have the tools to make the changes, even if the configured cloud provider API keys contain permissions to modify resources.
+10 -51
View File
@@ -1,6 +1,6 @@
# Configuring SAML Single Sign-On (SSO) in Prowler
This guide explains how to enable and test SAML SSO integration in Prowler. It includes environment setup, certificate configuration, API endpoints, and how to configure Okta as your Identity Provider (IdP).
This guide explains how to enable and test SAML SSO integration in Prowler. It includes environment setup, API endpoints, and how to configure Okta as your Identity Provider (IdP).
---
@@ -20,26 +20,6 @@ Update this variable to specify which domains Django should accept incoming requ
DJANGO_ALLOWED_HOSTS=localhost,127.0.0.1,prowler-api,mycompany.prowler
```
# SAML Certificates
To enable SAML support, you must provide a public certificate and private key to allow Prowler to sign SAML requests and validate responses.
### Why is this necessary?
SAML relies on digital signatures to verify trust between the Identity Provider (IdP) and the Service Provider (SP). Prowler acts as the SP and must use a certificate to sign outbound authentication requests.
### Add to your .env file:
```env
SAML_PUBLIC_CERT="-----BEGIN CERTIFICATE-----
...your certificate here...
-----END CERTIFICATE-----"
SAML_PRIVATE_KEY="-----BEGIN PRIVATE KEY-----
...your private key here...
-----END PRIVATE KEY-----"
```
# SAML Configuration API
You can manage SAML settings via the API. Prowler provides full CRUD support for tenant-specific SAML configuration.
@@ -60,7 +40,7 @@ You can manage SAML settings via the API. Prowler provides full CRUD support for
### Description
This endpoint receives an email and checks if there is an active SAML configuration for the associated domain (i.e., the part after the @). If a configuration exists and the required certificates are present, it responds with an HTTP 302 redirect to the appropriate saml_login endpoint for the organization.
This endpoint receives an email and checks if there is an active SAML configuration for the associated domain (i.e., the part after the @). If a configuration exists it responds with an HTTP 302 redirect to the appropriate saml_login endpoint for the organization.
- POST /api/v1/accounts/saml/initiate/
@@ -78,7 +58,7 @@ This endpoint receives an email and checks if there is an active SAML configurat
• 302 FOUND: Redirects to the SAML login URL associated with the organization.
• 403 FORBIDDEN: The domain is not authorized or SAML certificates are missing from the configuration.
• 403 FORBIDDEN: The domain is not authorized.
### Validation logic
@@ -86,8 +66,6 @@ This endpoint receives an email and checks if there is an active SAML configurat
• Retrieves the related SAMLConfiguration object via tenant_id.
• Verifies that SAML_PUBLIC_CERT and SAML_PRIVATE_KEY environment variables are set.
# SAML Integration: Testing Guide
@@ -95,26 +73,7 @@ This document outlines the process for testing the SAML integration functionalit
---
## 1. Generate Self-Signed Certificate and Private Key
First, generate a self-signed certificate and corresponding private key using OpenSSL:
```bash
openssl req -x509 -nodes -days 3650 -newkey rsa:2048 \
-keyout saml_private_key.pem \
-out saml_public_cert.pem \
-subj "/C=US/ST=Test/L=Test/O=Test/OU=Test/CN=localhost"
```
## 2. Add Certificate Values to .env
Paste the generated values into your .env file:
```
SAML_PUBLIC_CERT=<paste certificate content here>
SAML_PRIVATE_KEY=<paste private key content here>
```
## 3. Start Ngrok and Update ALLOWED_HOSTS
## 1. Start Ngrok and Update ALLOWED_HOSTS
Start ngrok on port 8080:
```
@@ -127,7 +86,7 @@ Then, copy the generated ngrok URL and include it in the ALLOWED_HOSTS setting.
ALLOWED_HOSTS = env.list("DJANGO_ALLOWED_HOSTS", default=["*"])
```
## 4. Configure the Identity Provider (IdP)
## 2. Configure the Identity Provider (IdP)
Start your environment and configure your IdP. You will need to download the IdP's metadata XML file.
@@ -137,7 +96,7 @@ Your Assertion Consumer Service (ACS) URL must follow this format:
https://<PROXY_URL>/api/v1/accounts/saml/<CONFIGURED_DOMAIN>/acs/
```
## 5. IdP Attribute Mapping
## 3. IdP Attribute Mapping
The following fields are expected from the IdP:
@@ -151,7 +110,7 @@ The following fields are expected from the IdP:
These values are dynamic. If the values change in the IdP, they will be updated on the next login.
## 6. SAML Configuration API (POST)
## 4. SAML Configuration API (POST)
SAML configuration is managed via a CRUD API. Use the following POST request to create a new configuration:
@@ -171,7 +130,7 @@ curl --location 'http://localhost:8080/api/v1/saml-config' \
}'
```
## 7. SAML SSO Callback Configuration
## 5. SAML SSO Callback Configuration
### Environment Variable Configuration
@@ -201,7 +160,7 @@ AUTH_URL="<WEB_UI_URL>"
- Both environment variables are required for proper SAML SSO functionality
- Verify that the `NEXT_PUBLIC_API_BASE_URL` environment variable is properly configured to reference the correct API server base URL corresponding to your target deployment environment. This ensures proper routing of SAML callback requests to the appropriate backend services.
## 8. Start SAML Login Flow
## 6. Start SAML Login Flow
Once everything is configured, start the SAML login process by visiting the following URL:
@@ -211,6 +170,6 @@ https://<PROXY_IP>/api/v1/accounts/saml/<CONFIGURED_DOMAIN>/login/?email=<USER_E
At the end you will get a valid access and refresh token
## 9. Notes on the initiate Endpoint
## 7. Notes on the initiate Endpoint
The initiate endpoint is not strictly required. It was created to allow extra checks or behavior modifications (like enumeration mitigation). It also simplifies UI integration with SAML, but again, it's optional.
+9 -1
View File
@@ -170,7 +170,15 @@ By default, the `kubeconfig` file is located at `~/.kube/config`.
---
### **Step 4.5: M365 Credentials**
For M365, Prowler App uses a service principal application with user and password to authenticate, for more information about the requirements needed for this provider check this [section](../getting-started/requirements.md#microsoft-365). Also, the detailed steps of how to add this provider to Prowler Cloud and start using it are [here](./microsoft365/getting-started-m365.md).
For M365, you must enter your Domain ID and choose the authentication method you want to use:
- Service Principal Authentication (Recommended)
- User Authentication (only works if the user does not have MFA enabled)
???+ note
User authentication with M365_USER and M365_PASSWORD is optional and will only work if the account does not enforce MFA.
For full setup instructions and requirements, check the [Microsoft 365 provider requirements](./microsoft365/getting-started-m365.md).
<img src="../../img/m365-credentials.png" alt="Prowler Cloud M365 Credentials" width="700"/>
+2
View File
@@ -54,6 +54,7 @@ nav:
- Role-Based Access Control: tutorials/prowler-app-rbac.md
- Social Login: tutorials/prowler-app-social-login.md
- SSO with SAML: tutorials/prowler-app-sso.md
- Lighthouse: tutorials/prowler-app-lighthouse.md
- CLI:
- Miscellaneous: tutorials/misc.md
- Reporting: tutorials/reporting.md
@@ -117,6 +118,7 @@ nav:
- Outputs: developer-guide/outputs.md
- Integrations: developer-guide/integrations.md
- Compliance: developer-guide/security-compliance-framework.md
- Lighthouse: developer-guide/lighthouse.md
- Provider Specific Details:
- AWS: developer-guide/aws-details.md
- Azure: developer-guide/azure-details.md
Generated
+38 -2
View File
@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand.
[[package]]
name = "about-time"
@@ -666,6 +666,42 @@ azure-common = ">=1.1,<2.0"
azure-mgmt-core = ">=1.3.0,<2.0.0"
msrest = ">=0.6.21"
[[package]]
name = "azure-mgmt-recoveryservices"
version = "3.1.0"
description = "Microsoft Azure Recovery Services Client Library for Python"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "azure_mgmt_recoveryservices-3.1.0-py3-none-any.whl", hash = "sha256:21c58afdf4ae66806783e95f8cd17e3bec31be7178c48784db21f0b05de7fa66"},
{file = "azure_mgmt_recoveryservices-3.1.0.tar.gz", hash = "sha256:7f2db98401708cf145322f50bc491caf7967bec4af3bf7b0984b9f07d3092687"},
]
[package.dependencies]
azure-common = ">=1.1"
azure-mgmt-core = ">=1.5.0"
isodate = ">=0.6.1"
typing-extensions = ">=4.6.0"
[[package]]
name = "azure-mgmt-recoveryservicesbackup"
version = "9.2.0"
description = "Microsoft Azure Recovery Services Backup Management Client Library for Python"
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "azure_mgmt_recoveryservicesbackup-9.2.0-py3-none-any.whl", hash = "sha256:c0002858d0166b6a10189a1fd580a49c83dc31b111e98010a5b2ea0f767dfff1"},
{file = "azure_mgmt_recoveryservicesbackup-9.2.0.tar.gz", hash = "sha256:c402b3e22a6c3879df56bc37e0063142c3352c5102599ff102d19824f1b32b29"},
]
[package.dependencies]
azure-common = ">=1.1"
azure-mgmt-core = ">=1.3.2"
isodate = ">=0.6.1"
typing-extensions = ">=4.6.0"
[[package]]
name = "azure-mgmt-resource"
version = "23.3.0"
@@ -6624,4 +6660,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.1"
python-versions = ">3.9.1,<3.13"
content-hash = "c442552635c8e904d1c7a50f4787c8e90ec90787960ee1867f2235a7aa2205f0"
content-hash = "a0c3e917dcedf073426ae47c942c1db1e04e14ea1ab1a81d7fb91f2873daf1cb"
+7 -4
View File
@@ -5,6 +5,11 @@ All notable changes to the **Prowler SDK** are documented in this file.
## [v5.9.0] (Prowler UNRELEASED)
### Added
- `storage_smb_channel_encryption_with_secure_algorithm` check for Azure provider [(#8123)](https://github.com/prowler-cloud/prowler/pull/8123)
- `vm_backup_enabled` check for Azure provider [(#8182)](https://github.com/prowler-cloud/prowler/pull/8182)
- `vm_linux_enforce_ssh_authentication` check for Azure provider [(#8149)](https://github.com/prowler-cloud/prowler/pull/8149)
- `vm_ensure_using_approved_images` check for Azure provider [(#8168)](https://github.com/prowler-cloud/prowler/pull/8168)
- `vm_scaleset_associated_load_balancer` check for Azure provider [(#8181)](https://github.com/prowler-cloud/prowler/pull/8181)
### Changed
@@ -15,10 +20,8 @@ All notable changes to the **Prowler SDK** are documented in this file.
## [v5.8.1] (Prowler 5.8.1)
### Fixed
- Detect wildcarded ARNs in sts:AssumeRole policy resources [(#8164)](https://github.com/prowler-cloud/prowler/pull/8164)
- List all streams and `firehose_stream_encrypted_at_rest` logic [(#8213)](https://github.com/prowler-cloud/prowler/pull/8213)
- Allow empty values for http_endpoint in templates [(#8184)](https://github.com/prowler-cloud/prowler/pull/8184)
- Convert all Azure Storage models to Pydantic models to avoid serialization issues [(#8222)](https://github.com/prowler-cloud/prowler/pull/8222)
- fix(iam): detect wildcarded ARNs in sts:AssumeRole policy resources [(#8164)](https://github.com/prowler-cloud/prowler/pull/8164)
- fix(ec2): allow empty values for http_endpoint in templates [(#8184)](https://github.com/prowler-cloud/prowler/pull/8184)
---
+1 -1
View File
@@ -12,7 +12,7 @@ from prowler.lib.logger import logger
timestamp = datetime.today()
timestamp_utc = datetime.now(timezone.utc).replace(tzinfo=timezone.utc)
prowler_version = "5.8.1"
prowler_version = "5.9.0"
html_logo_url = "https://github.com/prowler-cloud/prowler/"
square_logo_img = "https://prowler.com/wp-content/uploads/logo-html.png"
aws_logo = "https://user-images.githubusercontent.com/38561120/235953920-3e3fba08-0795-41dc-b480-9bea57db9f2e.png"
@@ -1985,6 +1985,7 @@
"eu-central-2",
"eu-north-1",
"eu-south-1",
"eu-south-2",
"eu-west-1",
"eu-west-2",
"eu-west-3",
@@ -5479,25 +5480,37 @@
"ap-northeast-2",
"ap-northeast-3",
"ap-south-1",
"ap-south-2",
"ap-southeast-1",
"ap-southeast-2",
"ap-southeast-3",
"ap-southeast-4",
"ap-southeast-5",
"ap-southeast-7",
"ca-central-1",
"ca-west-1",
"eu-central-1",
"eu-central-2",
"eu-north-1",
"eu-south-1",
"eu-south-2",
"eu-west-1",
"eu-west-2",
"eu-west-3",
"il-central-1",
"me-central-1",
"me-south-1",
"mx-central-1",
"sa-east-1",
"us-east-1",
"us-east-2",
"us-west-1",
"us-west-2"
],
"aws-cn": [],
"aws-cn": [
"cn-north-1",
"cn-northwest-1"
],
"aws-us-gov": [
"us-gov-east-1",
"us-gov-west-1"
@@ -5513,18 +5526,27 @@
"ap-northeast-2",
"ap-northeast-3",
"ap-south-1",
"ap-south-2",
"ap-southeast-1",
"ap-southeast-2",
"ap-southeast-3",
"ap-southeast-4",
"ap-southeast-5",
"ap-southeast-7",
"ca-central-1",
"ca-west-1",
"eu-central-1",
"eu-central-2",
"eu-north-1",
"eu-south-1",
"eu-south-2",
"eu-west-1",
"eu-west-2",
"eu-west-3",
"il-central-1",
"me-central-1",
"me-south-1",
"mx-central-1",
"sa-east-1",
"us-east-1",
"us-east-2",
@@ -7946,6 +7968,7 @@
"sa-east-1",
"us-east-1",
"us-east-2",
"us-west-1",
"us-west-2"
],
"aws-cn": [],
@@ -9829,6 +9852,13 @@
]
}
},
"sagemakerautopilot": {
"regions": {
"aws": [],
"aws-cn": [],
"aws-us-gov": []
}
},
"savingsplans": {
"regions": {
"aws": [
@@ -25,47 +25,18 @@ class Firehose(AWSService):
def _list_delivery_streams(self, regional_client):
logger.info("Firehose - Listing delivery streams...")
try:
# Manual pagination using ExclusiveStartDeliveryStreamName
# This ensures we get all streams alphabetically without duplicates
exclusive_start_delivery_stream_name = None
processed_streams = set()
while True:
kwargs = {}
if exclusive_start_delivery_stream_name:
kwargs["ExclusiveStartDeliveryStreamName"] = (
exclusive_start_delivery_stream_name
for stream_name in regional_client.list_delivery_streams()[
"DeliveryStreamNames"
]:
stream_arn = f"arn:{self.audited_partition}:firehose:{regional_client.region}:{self.audited_account}:deliverystream/{stream_name}"
if not self.audit_resources or (
is_resource_filtered(stream_arn, self.audit_resources)
):
self.delivery_streams[stream_arn] = DeliveryStream(
arn=stream_arn,
name=stream_name,
region=regional_client.region,
)
response = regional_client.list_delivery_streams(**kwargs)
stream_names = response.get("DeliveryStreamNames", [])
for stream_name in stream_names:
if stream_name in processed_streams:
continue
processed_streams.add(stream_name)
stream_arn = f"arn:{self.audited_partition}:firehose:{regional_client.region}:{self.audited_account}:deliverystream/{stream_name}"
if not self.audit_resources or (
is_resource_filtered(stream_arn, self.audit_resources)
):
self.delivery_streams[stream_arn] = DeliveryStream(
arn=stream_arn,
name=stream_name,
region=regional_client.region,
)
if not response.get("HasMoreDeliveryStreams", False):
break
# Set the starting point for the next page (last stream name from current batch)
# ExclusiveStartDeliveryStreamName will start after this stream alphabetically
if stream_names:
exclusive_start_delivery_stream_name = stream_names[-1]
else:
break
except ClientError as error:
logger.error(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
@@ -90,45 +61,13 @@ class Firehose(AWSService):
describe_stream = self.regional_clients[
stream.region
].describe_delivery_stream(DeliveryStreamName=stream.name)
encryption_config = describe_stream.get(
"DeliveryStreamDescription", {}
).get("DeliveryStreamEncryptionConfiguration", {})
stream.kms_encryption = EncryptionStatus(
encryption_config.get("Status", "DISABLED")
)
stream.kms_key_arn = encryption_config.get("KeyARN", "")
stream.delivery_stream_type = describe_stream.get(
"DeliveryStreamDescription", {}
).get("DeliveryStreamType", "")
source_config = describe_stream.get("DeliveryStreamDescription", {}).get(
"Source", {}
)
stream.source = Source(
direct_put=DirectPutSourceDescription(
troughput_hint_in_mb_per_sec=source_config.get(
"DirectPutSourceDescription", {}
).get("TroughputHintInMBPerSec", 0)
),
kinesis_stream=KinesisStreamSourceDescription(
kinesis_stream_arn=source_config.get(
"KinesisStreamSourceDescription", {}
).get("KinesisStreamARN", "")
),
msk=MSKSourceDescription(
msk_cluster_arn=source_config.get("MSKSourceDescription", {}).get(
"MSKClusterARN", ""
)
),
database=DatabaseSourceDescription(
endpoint=source_config.get("DatabaseSourceDescription", {}).get(
"Endpoint", ""
)
),
)
except ClientError as error:
logger.error(
f"{stream.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
@@ -146,39 +85,6 @@ class EncryptionStatus(Enum):
DISABLING_FAILED = "DISABLING_FAILED"
class DirectPutSourceDescription(BaseModel):
"""Model for the DirectPut source of a Firehose stream"""
troughput_hint_in_mb_per_sec: int = Field(default_factory=int)
class KinesisStreamSourceDescription(BaseModel):
"""Model for the KinesisStream source of a Firehose stream"""
kinesis_stream_arn: str = Field(default_factory=str)
class MSKSourceDescription(BaseModel):
"""Model for the MSK source of a Firehose stream"""
msk_cluster_arn: str = Field(default_factory=str)
class DatabaseSourceDescription(BaseModel):
"""Model for the Database source of a Firehose stream"""
endpoint: str = Field(default_factory=str)
class Source(BaseModel):
"""Model for the source of a Firehose stream"""
direct_put: Optional[DirectPutSourceDescription]
kinesis_stream: Optional[KinesisStreamSourceDescription]
msk: Optional[MSKSourceDescription]
database: Optional[DatabaseSourceDescription]
class DeliveryStream(BaseModel):
"""Model for a Firehose Delivery Stream"""
@@ -188,5 +94,3 @@ class DeliveryStream(BaseModel):
kms_key_arn: Optional[str] = Field(default_factory=str)
kms_encryption: Optional[str] = Field(default_factory=str)
tags: Optional[List[Dict[str, str]]] = Field(default_factory=list)
delivery_stream_type: Optional[str] = Field(default_factory=str)
source: Source = Field(default_factory=Source)
@@ -3,8 +3,6 @@ from typing import List
from prowler.lib.check.models import Check, Check_Report_AWS
from prowler.providers.aws.services.firehose.firehose_client import firehose_client
from prowler.providers.aws.services.firehose.firehose_service import EncryptionStatus
from prowler.providers.aws.services.kinesis.kinesis_client import kinesis_client
from prowler.providers.aws.services.kinesis.kinesis_service import EncryptionType
class firehose_stream_encrypted_at_rest(Check):
@@ -24,22 +22,14 @@ class firehose_stream_encrypted_at_rest(Check):
findings = []
for stream in firehose_client.delivery_streams.values():
report = Check_Report_AWS(metadata=self.metadata(), resource=stream)
report.status = "FAIL"
report.status_extended = f"Firehose Stream {stream.name} does not have at rest encryption enabled or the source stream is not encrypted."
report.status = "PASS"
report.status_extended = (
f"Firehose Stream {stream.name} does have at rest encryption enabled."
)
# Encrypted Kinesis Stream source
if stream.delivery_stream_type == "KinesisStreamAsSource":
source_stream = kinesis_client.streams.get(
stream.source.kinesis_stream.kinesis_stream_arn
)
if source_stream.encrypted_at_rest != EncryptionType.NONE:
report.status = "PASS"
report.status_extended = f"Firehose Stream {stream.name} does not have at rest encryption enabled but the source stream {source_stream.name} has at rest encryption enabled."
# Check if the stream has encryption enabled directly
elif stream.kms_encryption == EncryptionStatus.ENABLED:
report.status = "PASS"
report.status_extended = f"Firehose Stream {stream.name} does have at rest encryption enabled."
if stream.kms_encryption != EncryptionStatus.ENABLED:
report.status = "FAIL"
report.status_extended = f"Firehose Stream {stream.name} does not have at rest encryption enabled."
findings.append(report)
@@ -0,0 +1,4 @@
from prowler.providers.azure.services.recovery.recovery_service import Recovery
from prowler.providers.common.provider import Provider
recovery_client = Recovery(Provider.get_global_provider())
@@ -0,0 +1,101 @@
from typing import Optional
from azure.mgmt.recoveryservices import RecoveryServicesClient
from azure.mgmt.recoveryservicesbackup import RecoveryServicesBackupClient
from azure.mgmt.recoveryservicesbackup.activestamp.models import DataSourceType
from pydantic import BaseModel, Field
from prowler.lib.logger import logger
from prowler.providers.azure.azure_provider import AzureProvider
from prowler.providers.azure.lib.service.service import AzureService
class BackupItem(BaseModel):
"""Minimal BackupItem: only essential identifying and descriptive fields."""
id: str
name: str
workload_type: Optional[DataSourceType]
class BackupVault(BaseModel):
"""Minimal BackupVault: only essential identifying fields and its backup items."""
id: str
name: str
location: str
backup_protected_items: dict[str, BackupItem] = Field(default_factory=dict)
class Recovery(AzureService):
def __init__(self, provider: AzureProvider):
super().__init__(RecoveryServicesClient, provider)
self.vaults: dict[str, dict[str, BackupVault]] = self._get_vaults()
RecoveryBackup(provider, self.vaults)
def _get_vaults(self) -> dict[str, dict[str, BackupVault]]:
"""
Retrieve all Recovery Services vaults for each subscription.
Returns:
Nested dictionary of vaults by subscription.
"""
logger.info("Recovery - Getting Recovery Services vaults...")
vaults_dict: dict[str, dict[str, BackupVault]] = {}
try:
vaults_dict: dict[str, dict[str, BackupVault]] = {}
for subscription_name, client in self.clients.items():
vaults = client.vaults.list_by_subscription_id()
vaults_dict[subscription_name] = {}
for vault in vaults:
vault_obj = BackupVault(
id=vault.id,
name=vault.name,
location=vault.location,
)
vaults_dict[subscription_name][vault_obj.id] = vault_obj
except Exception as error:
logger.error(
f"Subscription name: {subscription_name} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
return vaults_dict
class RecoveryBackup(AzureService):
def __init__(
self, provider: AzureProvider, vaults: dict[str, dict[str, BackupVault]]
):
super().__init__(RecoveryServicesBackupClient, provider)
for subscription_name, vaults in vaults.items():
for vault in vaults.values():
vault.backup_protected_items = self._get_backup_protected_items(
subscription_name=subscription_name, vault=vault
)
def _get_backup_protected_items(
self, subscription_name: str, vault: BackupVault
) -> dict[str, BackupItem]:
"""
Retrieve all backup protected items for a given vault.
"""
logger.info("Recovery - Getting backup protected items...")
backup_protected_items_dict: dict[str, BackupItem] = {}
try:
backup_protected_items = self.clients[
subscription_name
].backup_protected_items.list(
vault_name=vault.name,
resource_group_name=vault.id.split("/")[4],
)
for item in backup_protected_items:
item_properties = getattr(item, "properties", None)
backup_protected_items_dict[item.id] = BackupItem(
id=item.id,
name=item.name,
workload_type=(
item_properties.workload_type if item_properties else None
),
)
except Exception as e:
logger.error(f"Recovery - Error getting backup protected items: {e}")
return backup_protected_items_dict
@@ -1,5 +1,6 @@
from dataclasses import dataclass
from enum import Enum
from typing import Optional
from typing import List, Optional
from azure.mgmt.storage import StorageManagementClient
from pydantic import BaseModel
@@ -32,7 +33,7 @@ class Storage(AzureService):
resouce_group_name = None
key_expiration_period_in_days = None
if storage_account.key_policy:
key_expiration_period_in_days = int(
key_expiration_period_in_days = (
storage_account.key_policy.key_expiration_period_in_days
)
replication_settings = ReplicationSettings(storage_account.sku.name)
@@ -157,6 +158,21 @@ class Storage(AzureService):
"share_delete_retention_policy",
None,
)
smb_channel_encryption_raw = getattr(
getattr(
getattr(
file_service_properties,
"protocol_settings",
None,
),
"smb",
None,
),
"channel_encryption",
None,
)
account.file_service_properties = FileServiceProperties(
id=file_service_properties.id,
name=file_service_properties.name,
@@ -173,6 +189,13 @@ class Storage(AzureService):
0,
),
),
smb_protocol_settings=SMBProtocolSettings(
channel_encryption=(
smb_channel_encryption_raw.rstrip(";").split(";")
if smb_channel_encryption_raw
else []
)
),
)
except Exception as error:
logger.error(
@@ -180,26 +203,30 @@ class Storage(AzureService):
)
class DeleteRetentionPolicy(BaseModel):
@dataclass
class DeleteRetentionPolicy:
enabled: bool
days: int
class BlobProperties(BaseModel):
@dataclass
class BlobProperties:
id: str
name: str
type: str
default_service_version: str
container_delete_retention_policy: DeleteRetentionPolicy
default_service_version: Optional[str] = None
versioning_enabled: Optional[bool] = None
versioning_enabled: bool = False
class NetworkRuleSet(BaseModel):
@dataclass
class NetworkRuleSet:
bypass: str
default_action: str
class PrivateEndpointConnection(BaseModel):
@dataclass
class PrivateEndpointConnection:
id: str
name: str
type: str
@@ -216,26 +243,32 @@ class ReplicationSettings(Enum):
STANDARD_RAGZRS = "Standard_RAGZRS"
class SMBProtocolSettings(BaseModel):
channel_encryption: list[str]
class FileServiceProperties(BaseModel):
id: str
name: str
type: str
share_delete_retention_policy: DeleteRetentionPolicy
smb_protocol_settings: SMBProtocolSettings
class Account(BaseModel):
@dataclass
class Account:
id: str
name: str
location: str
resouce_group_name: str
enable_https_traffic_only: bool
infrastructure_encryption: Optional[bool] = None
infrastructure_encryption: bool
allow_blob_public_access: bool
network_rule_set: NetworkRuleSet
encryption_type: str
minimum_tls_version: str
private_endpoint_connections: list[PrivateEndpointConnection]
key_expiration_period_in_days: Optional[int] = None
private_endpoint_connections: List[PrivateEndpointConnection]
key_expiration_period_in_days: str
location: str
replication_settings: ReplicationSettings = ReplicationSettings.STANDARD_LRS
allow_cross_tenant_replication: bool = True
allow_shared_key_access: bool = True
@@ -0,0 +1,30 @@
{
"Provider": "azure",
"CheckID": "storage_smb_channel_encryption_with_secure_algorithm",
"CheckTitle": "Ensure SMB channel encryption uses a secure algorithm for SMB file shares",
"CheckType": [],
"ServiceName": "storage",
"SubServiceName": "",
"ResourceIdTemplate": "/subscriptions/{subscription_id}/resourceGroups/{resource_group}/providers/Microsoft.Storage/storageAccounts/{storageAccountName}/fileServices/default",
"Severity": "medium",
"ResourceType": "AzureStorageAccount",
"Description": "Implement SMB channel encryption with a secure algorithm for SMB file shares to ensure data confidentiality and integrity in transit.",
"Risk": "Not using the recommended SMB channel encryption may expose data transmitted over SMB channels to unauthorized interception and tampering.",
"RelatedUrl": "https://learn.microsoft.com/en-us/azure/well-architected/service-guides/azure-files#recommendations-for-smb-file-shares",
"Remediation": {
"Code": {
"CLI": "az storage account file-service-properties update --resource-group <resource-group> --account-name <storage-account> --channel-encryption AES-256-GCM",
"NativeIaC": "",
"Other": "",
"Terraform": ""
},
"Recommendation": {
"Text": "Use the portal, CLI or PowerShell to set the SMB channel encryption to a secure algorithm.",
"Url": "https://learn.microsoft.com/en-us/azure/storage/files/files-smb-protocol?tabs=azure-portal#smb-security-settings"
}
},
"Categories": [],
"DependsOn": [],
"RelatedTo": [],
"Notes": "This check passes if SMB channel encryption is set to a secure algorithm."
}
@@ -0,0 +1,51 @@
from prowler.lib.check.models import Check, Check_Report_Azure
from prowler.providers.azure.services.storage.storage_client import storage_client
SECURE_ENCRYPTION_ALGORITHMS = ["AES-256-GCM"]
class storage_smb_channel_encryption_with_secure_algorithm(Check):
"""
Ensure SMB channel encryption for file shares is set to the recommended algorithm (AES-256-GCM or higher).
This check evaluates whether SMB file shares are configured to use only the recommended SMB channel encryption algorithms.
- PASS: Storage account has the recommended SMB channel encryption (AES-256-GCM or higher) enabled for file shares.
- FAIL: Storage account does not have the recommended SMB channel encryption enabled for file shares or uses an unsupported algorithm.
"""
def execute(self) -> list[Check_Report_Azure]:
findings = []
for subscription, storage_accounts in storage_client.storage_accounts.items():
for account in storage_accounts:
if account.file_service_properties:
pretty_current_algorithms = (
", ".join(
account.file_service_properties.smb_protocol_settings.channel_encryption
)
if account.file_service_properties.smb_protocol_settings.channel_encryption
else "none"
)
report = Check_Report_Azure(
metadata=self.metadata(),
resource=account.file_service_properties,
)
report.subscription = subscription
report.resource_name = account.name
if (
not account.file_service_properties.smb_protocol_settings.channel_encryption
):
report.status = "FAIL"
report.status_extended = f"Storage account {account.name} from subscription {subscription} does not have SMB channel encryption enabled for file shares."
elif any(
algorithm in SECURE_ENCRYPTION_ALGORITHMS
for algorithm in account.file_service_properties.smb_protocol_settings.channel_encryption
):
report.status = "PASS"
report.status_extended = f"Storage account {account.name} from subscription {subscription} has a secure algorithm for SMB channel encryption ({', '.join(SECURE_ENCRYPTION_ALGORITHMS)}) enabled for file shares since it supports {pretty_current_algorithms}."
else:
report.status = "FAIL"
report.status_extended = f"Storage account {account.name} from subscription {subscription} does not have SMB channel encryption with a secure algorithm for file shares since it supports {pretty_current_algorithms}."
findings.append(report)
return findings
@@ -0,0 +1,30 @@
{
"Provider": "azure",
"CheckID": "vm_backup_enabled",
"CheckTitle": "Ensure Backups are enabled for Azure Virtual Machines",
"CheckType": [],
"ServiceName": "vm",
"SubServiceName": "",
"ResourceIdTemplate": "",
"Severity": "high",
"ResourceType": "Microsoft.Compute/virtualMachines",
"Description": "Ensure that Microsoft Azure Backup service is in use for your Azure virtual machines (VMs) to protect against accidental deletion or corruption.",
"Risk": "Without Azure Backup enabled, VMs are at risk of data loss due to accidental deletion, corruption, or other failures, and recovery options are limited.",
"RelatedUrl": "https://docs.microsoft.com/en-us/azure/backup/backup-overview",
"Remediation": {
"Code": {
"CLI": "az backup protection enable-for-vm --resource-group <resource-group> --vm <vm-name> --vault-name <vault-name> --policy-name DefaultPolicy",
"NativeIaC": "",
"Other": "https://learn.microsoft.com/en-us/azure/backup/quick-backup-vm-portal",
"Terraform": ""
},
"Recommendation": {
"Text": "Enable Azure Backup for each VM by associating it with a Recovery Services vault and a backup policy.",
"Url": "https://docs.microsoft.com/en-us/azure/backup/quick-backup-vm-portal"
}
},
"Categories": [],
"DependsOn": [],
"RelatedTo": [],
"Notes": ""
}
@@ -0,0 +1,50 @@
from azure.mgmt.recoveryservicesbackup.activestamp.models import DataSourceType
from prowler.lib.check.models import Check, Check_Report_Azure
from prowler.providers.azure.services.recovery.recovery_client import recovery_client
from prowler.providers.azure.services.vm.vm_client import vm_client
class vm_backup_enabled(Check):
"""
Ensure that Microsoft Azure Backup service is in use for your Azure virtual machines (VMs).
This check evaluates whether each Azure VM in the subscription is protected by Azure Backup.
- PASS: The VM is protected by Azure Backup (present in a Recovery Services vault).
- FAIL: The VM is not protected by Azure Backup (not present in any Recovery Services vault).
"""
def execute(self) -> list[Check_Report_Azure]:
"""Execute Azure VM backup enabled check.
Returns:
A list of reports containing the result of the check.
"""
findings = []
for subscription_name, vms in vm_client.virtual_machines.items():
vaults = recovery_client.vaults.get(subscription_name, {})
for vm in vms.values():
found = False
found_vault_name = None
for vault in vaults.values():
for backup_item in vault.backup_protected_items.values():
if (
backup_item.workload_type == DataSourceType.VM
and backup_item.name.split(";")[-1] == vm.resource_name
):
found = True
found_vault_name = vault.name
break
if found:
break
report = Check_Report_Azure(metadata=self.metadata(), resource=vm)
report.subscription = subscription_name
if found:
report.status = "PASS"
report.status_extended = f"VM {vm.resource_name} in subscription {subscription_name} is protected by Azure Backup (vault: {found_vault_name})."
else:
report.status = "FAIL"
report.status_extended = f"VM {vm.resource_name} in subscription {subscription_name} is not protected by Azure Backup."
findings.append(report)
return findings
@@ -0,0 +1,30 @@
{
"Provider": "azure",
"CheckID": "vm_ensure_using_approved_images",
"CheckTitle": "Ensure that Azure VMs are using an approved machine image.",
"CheckType": [],
"ServiceName": "vm",
"SubServiceName": "image",
"ResourceIdTemplate": "/subscriptions/<subscription-id>/resourceGroups/<resource-group-name>/providers/Microsoft.Compute/images/<virtual-machine-image-id>",
"Severity": "medium",
"ResourceType": "Microsoft.Compute/images",
"Description": "Ensure that all your Azure virtual machine instances are launched from approved machine images only.",
"Risk": "An approved machine image is a custom virtual machine (VM) image that contains a pre-configured OS and a well-defined stack of server software approved by Azure, fully configured to run your application. Using approved (golden) machine images to launch new VM instances within your Azure cloud environment brings major benefits such as fast and stable application deployment and scaling, secure application stack upgrades, and versioning.",
"RelatedUrl": "https://learn.microsoft.com/en-us/azure/virtual-machines/windows/create-vm-generalized-managed",
"Remediation": {
"Code": {
"CLI": "",
"NativeIaC": "",
"Other": "https://www.trendmicro.com/cloudoneconformity/knowledge-base/azure/VirtualMachines/approved-machine-image.html",
"Terraform": ""
},
"Recommendation": {
"Text": "Re-create the required VM instances using the approved (golden) machine image.",
"Url": "https://docs.microsoft.com/en-us/azure/virtual-machines/windows/create-vm-generalized-managed"
}
},
"Categories": [],
"DependsOn": [],
"RelatedTo": [],
"Notes": "This check only validates if the VM was launched from a custom image. It does not validate the image content or security baseline."
}
@@ -0,0 +1,33 @@
from prowler.lib.check.models import Check, Check_Report_Azure
from prowler.providers.azure.services.vm.vm_client import vm_client
class vm_ensure_using_approved_images(Check):
"""
Ensure that Azure VMs are using an approved (custom) machine image.
This check evaluates whether Azure Virtual Machines are launched from an approved (custom) machine image by checking the image reference ID format.
- PASS: The Azure VM is using an approved custom machine image.
- FAIL: The Azure VM is not using an approved custom machine image.
"""
def execute(self):
findings = []
for subscription_name, vms in vm_client.virtual_machines.items():
for vm in vms.values():
report = Check_Report_Azure(metadata=self.metadata(), resource=vm)
report.subscription = subscription_name
image_id = getattr(vm, "image_reference", None)
if (
image_id
and image_id.startswith("/subscriptions/")
and "/providers/Microsoft.Compute/images/" in image_id
):
report.status = "PASS"
report.status_extended = f"VM {vm.resource_name} in subscription {subscription_name} is using an approved machine image: {image_id.split('/')[-1]}."
else:
report.status = "FAIL"
report.status_extended = f"VM {vm.resource_name} in subscription {subscription_name} is not using an approved machine image."
findings.append(report)
return findings
@@ -0,0 +1,30 @@
{
"Provider": "azure",
"CheckID": "vm_linux_enforce_ssh_authentication",
"CheckTitle": "Ensure SSH key authentication is enforced on Linux-based Virtual Machines",
"CheckType": [],
"ServiceName": "vm",
"SubServiceName": "",
"ResourceIdTemplate": "/subscriptions/{subscriptionId}/resourceGroups/{resourceGroupName}/providers/Microsoft.Compute/virtualMachines/{vmName}",
"Severity": "high",
"ResourceType": "Microsoft.Compute/virtualMachines",
"Description": "Ensure that Azure Linux-based virtual machines are configured to use SSH keys by disabling password authentication.",
"Risk": "Allowing password-based SSH authentication increases the risk of brute-force attacks and unauthorized access. Enforcing SSH key authentication ensures only users with the private key can access the VM.",
"RelatedUrl": "https://docs.microsoft.com/en-us/azure/virtual-machines/linux/create-ssh-keys-detailed",
"Remediation": {
"Code": {
"CLI": "",
"NativeIaC": "",
"Other": "https://www.trendmicro.com/cloudoneconformity/knowledge-base/azure/VirtualMachines/ssh-authentication-type.html",
"Terraform": ""
},
"Recommendation": {
"Text": "Recreate Linux VMs with SSH key authentication enabled and password authentication disabled.",
"Url": "https://docs.microsoft.com/en-us/azure/virtual-machines/linux/create-ssh-keys-detailed"
}
},
"Categories": [],
"DependsOn": [],
"RelatedTo": [],
"Notes": ""
}
@@ -0,0 +1,29 @@
from prowler.lib.check.models import Check, Check_Report_Azure
from prowler.providers.azure.services.vm.vm_client import vm_client
class vm_linux_enforce_ssh_authentication(Check):
"""
Ensure that Azure Linux-based virtual machines are configured to use SSH keys (password authentication is disabled).
This check evaluates whether disablePasswordAuthentication is set to True for Linux VMs to ensure only SSH key authentication is allowed.
- PASS: VM has password authentication disabled (SSH key authentication enforced).
- FAIL: VM has password authentication enabled (password-based SSH allowed).
"""
def execute(self) -> list[Check_Report_Azure]:
findings = []
for subscription_name, vms in vm_client.virtual_machines.items():
for vm in vms.values():
if vm.linux_configuration:
report = Check_Report_Azure(metadata=self.metadata(), resource=vm)
report.subscription = subscription_name
if vm.linux_configuration.disable_password_authentication:
report.status = "PASS"
report.status_extended = f"VM {vm.resource_name} in subscription {subscription_name} has password authentication disabled (SSH key authentication enforced)."
else:
report.status = "FAIL"
report.status_extended = f"VM {vm.resource_name} in subscription {subscription_name} has password authentication enabled (password-based SSH allowed)."
findings.append(report)
return findings
@@ -0,0 +1,30 @@
{
"Provider": "azure",
"CheckID": "vm_scaleset_associated_with_load_balancer",
"CheckTitle": "VM Scale Set Is Associated With Load Balancer",
"CheckType": [],
"ServiceName": "vm",
"SubServiceName": "scaleset",
"ResourceIdTemplate": "/subscriptions/{subscriptionId}/resourceGroups/{resourceGroupName}/providers/Microsoft.Compute/virtualMachineScaleSets/{vmScaleSetName}",
"Severity": "medium",
"ResourceType": "Microsoft.Compute/virtualMachineScaleSets",
"Description": "Ensure that your Azure virtual machine scale sets are using load balancers for traffic distribution.",
"Risk": "Without load balancer integration, Azure virtual machine scale sets may experience reduced availability and potential service disruptions during traffic spikes or instance failures, leading to degraded user experience and potential business impact.",
"RelatedUrl": "https://learn.microsoft.com/en-us/azure/virtual-network/network-overview",
"Remediation": {
"Code": {
"CLI": "",
"NativeIaC": "",
"Other": "https://www.trendmicro.com/cloudoneconformity/knowledge-base/azure/VirtualMachines/associated-load-balancers.html",
"Terraform": ""
},
"Recommendation": {
"Text": "Attach a load balancer to your Azure virtual machine scale set to ensure high availability and optimal traffic distribution.",
"Url": "https://docs.microsoft.com/en-us/azure/load-balancer/load-balancer-overview"
}
},
"Categories": [],
"DependsOn": [],
"RelatedTo": [],
"Notes": ""
}
@@ -0,0 +1,36 @@
from prowler.lib.check.models import Check, Check_Report_Azure
from prowler.providers.azure.services.vm.vm_client import vm_client
class vm_scaleset_associated_with_load_balancer(Check):
"""
Ensure that Azure virtual machine scale sets are associated with a load balancer backend pool.
This check evaluates whether each VM scale set is associated with at least one load balancer backend pool.
- PASS: The scale set is associated with a load balancer backend pool.
- FAIL: The scale set is not associated with any load balancer backend pool.
"""
def execute(self):
findings = []
for subscription, scale_sets in vm_client.vm_scale_sets.items():
for scale_set in scale_sets.values():
report = Check_Report_Azure(
metadata=self.metadata(), resource=scale_set
)
report.subscription = subscription
report.resource_id = scale_set.resource_id
report.resource_name = scale_set.resource_name
report.location = scale_set.location
if scale_set.load_balancer_backend_pools:
report.status = "PASS"
backend_pool_names = [
pool.split("/")[-1]
for pool in scale_set.load_balancer_backend_pools
]
report.status_extended = f"Scale set '{scale_set.resource_name}' in subscription '{subscription}' is associated with load balancer backend pool(s): {', '.join(backend_pool_names)}."
else:
report.status = "FAIL"
report.status_extended = f"Scale set '{scale_set.resource_name}' in subscription '{subscription}' is not associated with any load balancer backend pool."
findings.append(report)
return findings
@@ -15,6 +15,7 @@ class VirtualMachines(AzureService):
super().__init__(ComputeManagementClient, provider)
self.virtual_machines = self._get_virtual_machines()
self.disks = self._get_disks()
self.vm_scale_sets = self._get_vm_scale_sets()
def _get_virtual_machines(self):
logger.info("VirtualMachines - Getting virtual machines...")
@@ -62,6 +63,18 @@ class VirtualMachines(AzureService):
if extension
]
# Collect LinuxConfiguration.disablePasswordAuthentication if available
linux_configuration = None
os_profile = getattr(vm, "os_profile", None)
if os_profile:
linux_conf = getattr(os_profile, "linux_configuration", None)
if linux_conf:
linux_configuration = LinuxConfiguration(
disable_password_authentication=getattr(
linux_conf, "disable_password_authentication", False
)
)
virtual_machines[subscription_name].update(
{
vm.id: VirtualMachine(
@@ -92,6 +105,12 @@ class VirtualMachines(AzureService):
location=vm.location,
security_profile=getattr(vm, "security_profile", None),
extensions=extensions,
image_reference=getattr(
getattr(storage_profile, "image_reference", None),
"id",
None,
),
linux_configuration=linux_configuration,
)
}
)
@@ -137,6 +156,69 @@ class VirtualMachines(AzureService):
return disks
def _get_vm_scale_sets(self) -> dict[str, dict]:
"""
Get all needed information about VM scale sets.
Returns:
A nested dictionary with the following structure:
{
"subscription_name": {
"vm_scale_set_id": VirtualMachineScaleSet()
}
}
"""
logger.info(
"VirtualMachines - Getting VM scale sets and their load balancer associations..."
)
vm_scale_sets = {}
for subscription_name, client in self.clients.items():
try:
scale_sets = client.virtual_machine_scale_sets.list_all()
vm_scale_sets[subscription_name] = {}
for scale_set in scale_sets:
backend_pools = []
nic_configs = []
virtual_machine_profile = getattr(
scale_set, "virtual_machine_profile", None
)
if virtual_machine_profile:
network_profile = getattr(
virtual_machine_profile, "network_profile", None
)
if network_profile:
nic_configs = (
getattr(
network_profile,
"network_interface_configurations",
[],
)
or []
)
for nic in nic_configs:
ip_confs = getattr(nic, "ip_configurations", [])
for ipconf in ip_confs:
pools = getattr(
ipconf, "load_balancer_backend_address_pools", []
)
if pools:
for pool in pools:
if getattr(pool, "id", None):
backend_pools.append(pool.id)
vm_scale_sets[subscription_name][scale_set.id] = (
VirtualMachineScaleSet(
resource_id=scale_set.id,
resource_name=scale_set.name,
location=scale_set.location,
load_balancer_backend_pools=backend_pools,
)
)
except Exception as error:
logger.error(
f"Subscription name: {subscription_name} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
return vm_scale_sets
@dataclass
class UefiSettings:
@@ -180,6 +262,10 @@ class VirtualMachineExtension(BaseModel):
id: str
class LinuxConfiguration(BaseModel):
disable_password_authentication: bool
class VirtualMachine(BaseModel):
resource_id: str
resource_name: str
@@ -187,6 +273,8 @@ class VirtualMachine(BaseModel):
security_profile: Optional[SecurityProfile]
extensions: list[VirtualMachineExtension]
storage_profile: Optional[StorageProfile] = None
image_reference: Optional[str] = None
linux_configuration: Optional[LinuxConfiguration] = None
class Disk(BaseModel):
@@ -195,3 +283,10 @@ class Disk(BaseModel):
vms_attached: list[str]
encryption_type: str
location: str
class VirtualMachineScaleSet(BaseModel):
resource_id: str
resource_name: str
location: str
load_balancer_backend_pools: list[str]
+3 -1
View File
@@ -26,6 +26,8 @@ dependencies = [
"azure-mgmt-monitor==6.0.2",
"azure-mgmt-network==28.1.0",
"azure-mgmt-rdbms==10.1.0",
"azure-mgmt-recoveryservices==3.1.0",
"azure-mgmt-recoveryservicesbackup==9.2.0",
"azure-mgmt-resource==23.3.0",
"azure-mgmt-search==9.1.0",
"azure-mgmt-security==7.0.0",
@@ -68,7 +70,7 @@ maintainers = [{name = "Prowler Engineering", email = "engineering@prowler.com"}
name = "prowler"
readme = "README.md"
requires-python = ">3.9.1,<3.13"
version = "5.8.1"
version = "5.9.0"
[project.scripts]
prowler = "prowler.__main__:prowler"
@@ -2,13 +2,8 @@ from boto3 import client
from moto import mock_aws
from prowler.providers.aws.services.firehose.firehose_service import (
DatabaseSourceDescription,
DirectPutSourceDescription,
EncryptionStatus,
Firehose,
KinesisStreamSourceDescription,
MSKSourceDescription,
Source,
)
from tests.providers.aws.utils import (
AWS_ACCOUNT_NUMBER,
@@ -157,102 +152,3 @@ class Test_Firehose_Service:
firehose.delivery_streams[arn].kms_key_arn
== f"arn:aws:kms:{AWS_REGION_EU_WEST_1}:{AWS_ACCOUNT_NUMBER}:key/test-kms-key-id"
)
@mock_aws
def test_describe_delivery_stream_source_direct_put(self):
# Generate S3 client
s3_client = client("s3", region_name=AWS_REGION_EU_WEST_1)
s3_client.create_bucket(
Bucket="test-bucket",
CreateBucketConfiguration={"LocationConstraint": AWS_REGION_EU_WEST_1},
)
# Generate Firehose client
firehose_client = client("firehose", region_name=AWS_REGION_EU_WEST_1)
delivery_stream = firehose_client.create_delivery_stream(
DeliveryStreamName="test-delivery-stream",
DeliveryStreamType="DirectPut",
S3DestinationConfiguration={
"RoleARN": "arn:aws:iam::012345678901:role/firehose-role",
"BucketARN": "arn:aws:s3:::test-bucket",
"Prefix": "",
"BufferingHints": {"IntervalInSeconds": 300, "SizeInMBs": 5},
"CompressionFormat": "UNCOMPRESSED",
},
Tags=[{"Key": "key", "Value": "value"}],
)
arn = delivery_stream["DeliveryStreamARN"]
# Firehose Client for this test class
aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1])
firehose = Firehose(aws_provider)
assert len(firehose.delivery_streams) == 1
assert firehose.delivery_streams[arn].delivery_stream_type == "DirectPut"
# Test Source structure
assert isinstance(firehose.delivery_streams[arn].source, Source)
assert isinstance(
firehose.delivery_streams[arn].source.direct_put, DirectPutSourceDescription
)
assert isinstance(
firehose.delivery_streams[arn].source.kinesis_stream,
KinesisStreamSourceDescription,
)
assert isinstance(
firehose.delivery_streams[arn].source.msk, MSKSourceDescription
)
assert isinstance(
firehose.delivery_streams[arn].source.database, DatabaseSourceDescription
)
@mock_aws
def test_describe_delivery_stream_source_kinesis_stream(self):
# Generate Kinesis client
kinesis_client = client("kinesis", region_name=AWS_REGION_EU_WEST_1)
kinesis_client.create_stream(
StreamName="test-kinesis-stream",
ShardCount=1,
)
kinesis_stream_arn = f"arn:aws:kinesis:{AWS_REGION_EU_WEST_1}:{AWS_ACCOUNT_NUMBER}:stream/test-kinesis-stream"
# Generate Firehose client
firehose_client = client("firehose", region_name=AWS_REGION_EU_WEST_1)
delivery_stream = firehose_client.create_delivery_stream(
DeliveryStreamName="test-delivery-stream",
DeliveryStreamType="KinesisStreamAsSource",
KinesisStreamSourceConfiguration={
"KinesisStreamARN": kinesis_stream_arn,
"RoleARN": "arn:aws:iam::012345678901:role/firehose-role",
},
S3DestinationConfiguration={
"RoleARN": "arn:aws:iam::012345678901:role/firehose-role",
"BucketARN": "arn:aws:s3:::test-bucket",
"Prefix": "",
"BufferingHints": {"IntervalInSeconds": 300, "SizeInMBs": 5},
"CompressionFormat": "UNCOMPRESSED",
},
Tags=[{"Key": "key", "Value": "value"}],
)
arn = delivery_stream["DeliveryStreamARN"]
# Firehose Client for this test class
aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1])
firehose = Firehose(aws_provider)
assert len(firehose.delivery_streams) == 1
assert (
firehose.delivery_streams[arn].delivery_stream_type
== "KinesisStreamAsSource"
)
# Test Source structure
assert isinstance(firehose.delivery_streams[arn].source, Source)
assert isinstance(
firehose.delivery_streams[arn].source.kinesis_stream,
KinesisStreamSourceDescription,
)
assert (
firehose.delivery_streams[arn].source.kinesis_stream.kinesis_stream_arn
== kinesis_stream_arn
)
@@ -198,7 +198,7 @@ class Test_firehose_stream_encrypted_at_rest:
assert result[0].status == "FAIL"
assert (
result[0].status_extended
== f"Firehose Stream {stream_name} does not have at rest encryption enabled or the source stream is not encrypted."
== f"Firehose Stream {stream_name} does not have at rest encryption enabled."
)
@mock_aws
@@ -253,74 +253,5 @@ class Test_firehose_stream_encrypted_at_rest:
assert result[0].status == "FAIL"
assert (
result[0].status_extended
== f"Firehose Stream {stream_name} does not have at rest encryption enabled or the source stream is not encrypted."
== f"Firehose Stream {stream_name} does not have at rest encryption enabled."
)
@mock_aws
def test_stream_kinesis_source_encrypted(self):
# Generate Kinesis client
kinesis_client = client("kinesis", region_name=AWS_REGION_EU_WEST_1)
kinesis_client.create_stream(
StreamName="test-kinesis-stream",
ShardCount=1,
)
kinesis_stream_arn = f"arn:aws:kinesis:{AWS_REGION_EU_WEST_1}:{AWS_ACCOUNT_NUMBER}:stream/test-kinesis-stream"
# Enable encryption on the Kinesis stream
kinesis_client.start_stream_encryption(
StreamName="test-kinesis-stream",
EncryptionType="KMS",
KeyId=f"arn:aws:kms:{AWS_REGION_EU_WEST_1}:{AWS_ACCOUNT_NUMBER}:key/test-kms-key-id",
)
# Generate Firehose client
firehose_client = client("firehose", region_name=AWS_REGION_EU_WEST_1)
delivery_stream = firehose_client.create_delivery_stream(
DeliveryStreamName="test-delivery-stream",
DeliveryStreamType="KinesisStreamAsSource",
KinesisStreamSourceConfiguration={
"KinesisStreamARN": kinesis_stream_arn,
"RoleARN": "arn:aws:iam::012345678901:role/firehose-role",
},
S3DestinationConfiguration={
"RoleARN": "arn:aws:iam::012345678901:role/firehose-role",
"BucketARN": "arn:aws:s3:::test-bucket",
"Prefix": "",
"BufferingHints": {"IntervalInSeconds": 300, "SizeInMBs": 5},
"CompressionFormat": "UNCOMPRESSED",
},
Tags=[{"Key": "key", "Value": "value"}],
)
arn = delivery_stream["DeliveryStreamARN"]
stream_name = arn.split("/")[-1]
from prowler.providers.aws.services.firehose.firehose_service import Firehose
from prowler.providers.aws.services.kinesis.kinesis_service import Kinesis
aws_provider = set_mocked_aws_provider([AWS_REGION_EU_WEST_1])
with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
):
with mock.patch(
"prowler.providers.aws.services.firehose.firehose_stream_encrypted_at_rest.firehose_stream_encrypted_at_rest.firehose_client",
new=Firehose(aws_provider),
):
with mock.patch(
"prowler.providers.aws.services.firehose.firehose_stream_encrypted_at_rest.firehose_stream_encrypted_at_rest.kinesis_client",
new=Kinesis(aws_provider),
):
# Test Check
from prowler.providers.aws.services.firehose.firehose_stream_encrypted_at_rest.firehose_stream_encrypted_at_rest import (
firehose_stream_encrypted_at_rest,
)
check = firehose_stream_encrypted_at_rest()
result = check.execute()
assert len(result) == 1
assert result[0].status == "PASS"
assert (
result[0].status_extended
== f"Firehose Stream {stream_name} does not have at rest encryption enabled but the source stream test-kinesis-stream has at rest encryption enabled."
)
@@ -110,7 +110,6 @@ class TestAzureProvider:
return_value={},
),
):
with pytest.raises(AzureNoAuthenticationMethodError) as exception:
_ = AzureProvider(
az_cli_auth,
@@ -151,7 +150,6 @@ class TestAzureProvider:
return_value={},
),
):
with pytest.raises(AzureBrowserAuthNoTenantIDError) as exception:
_ = AzureProvider(
az_cli_auth,
@@ -193,7 +191,6 @@ class TestAzureProvider:
return_value={},
),
):
with pytest.raises(AzureTenantIDNoBrowserAuthError) as exception:
_ = AzureProvider(
az_cli_auth,
@@ -224,7 +221,6 @@ class TestAzureProvider:
"prowler.providers.azure.azure_provider.SubscriptionClient"
) as mock_resource_client,
):
# Mock the return value of DefaultAzureCredential
mock_credentials = MagicMock()
mock_credentials.get_token.return_value = AccessToken(
@@ -266,7 +262,6 @@ class TestAzureProvider:
"prowler.providers.azure.azure_provider.AzureProvider.validate_static_credentials"
) as mock_validate_static_credentials,
):
# Mock the return value of DefaultAzureCredential
mock_credentials = MagicMock()
mock_credentials.get_token.return_value = AccessToken(
@@ -317,7 +312,6 @@ class TestAzureProvider:
"prowler.providers.azure.azure_provider.AzureProvider.validate_static_credentials"
) as mock_validate_static_credentials,
):
# Mock the return value of DefaultAzureCredential
mock_default_credential.return_value = {
"client_id": str(uuid4()),
@@ -368,7 +362,6 @@ class TestAzureProvider:
"prowler.providers.azure.azure_provider.AzureProvider.validate_static_credentials"
) as mock_validate_static_credentials,
):
# Mock the return value of DefaultAzureCredential
mock_default_credential.return_value = {
"client_id": str(uuid4()),
@@ -442,7 +435,6 @@ class TestAzureProvider:
"prowler.providers.azure.azure_provider.AzureProvider.setup_session"
) as mock_setup_session,
):
mock_setup_session.side_effect = AzureHTTPResponseError(
file="test_file", original_exception="Simulated HttpResponseError"
)
@@ -463,7 +455,6 @@ class TestAzureProvider:
with patch(
"prowler.providers.azure.azure_provider.AzureProvider.setup_session"
) as mock_setup_session:
mock_setup_session.side_effect = Exception("Simulated Exception")
with pytest.raises(Exception) as exception:
@@ -91,21 +91,6 @@ class Test_monitor_diagnostic_settings_exists:
)
from prowler.providers.azure.services.storage.storage_service import (
Account,
BlobProperties,
DeleteRetentionPolicy,
NetworkRuleSet,
)
# Create a valid BlobProperties instance
valid_blob_properties = BlobProperties(
id="id",
name="name",
type="type",
default_service_version="default_service_version",
container_delete_retention_policy=DeleteRetentionPolicy(
enabled=False, days=0
),
versioning_enabled=True,
)
monitor_client.diagnostics_settings = {
@@ -153,34 +138,42 @@ class Test_monitor_diagnostic_settings_exists:
name="storageaccountname1",
resouce_group_name="rg",
enable_https_traffic_only=True,
infrastructure_encryption=True,
infrastructure_encryption="Enabled",
allow_blob_public_access=True,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
network_rule_set="AllowAll",
encryption_type="Microsoft.CustomerManagedKeyVault",
minimum_tls_version="TLS1_2",
private_endpoint_connections=[],
key_expiration_period_in_days="365",
key_expiration_period_in_days=365,
location="euwest",
blob_properties=valid_blob_properties,
blob_properties=mock.MagicMock(
id="id",
name="name",
type="type",
default_service_version="default_service_version",
container_delete_retention_policy="container_delete_retention_policy",
),
),
Account(
id="/subscriptions/1224a5-123a-123a-123a-1234567890ab/resourceGroups/rg/providers/Microsoft.Storage/storageAccounts/storageaccountname2",
name="storageaccountname2",
resouce_group_name="rg",
enable_https_traffic_only=False,
infrastructure_encryption=True,
infrastructure_encryption="Enabled",
allow_blob_public_access=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
network_rule_set="AllowAll",
encryption_type="Microsoft.Storage",
minimum_tls_version="TLS1_2",
private_endpoint_connections=[],
key_expiration_period_in_days="365",
key_expiration_period_in_days=365,
location="euwest",
blob_properties=valid_blob_properties,
blob_properties=mock.MagicMock(
id="id",
name="name",
type="type",
default_service_version="default_service_version",
container_delete_retention_policy="container_delete_retention_policy",
),
),
]
}
@@ -78,9 +78,6 @@ class Test_monitor_storage_account_with_activity_logs_cmk_encrypted:
)
from prowler.providers.azure.services.storage.storage_service import (
Account,
BlobProperties,
DeleteRetentionPolicy,
NetworkRuleSet,
)
monitor_client.diagnostics_settings = {
@@ -128,25 +125,20 @@ class Test_monitor_storage_account_with_activity_logs_cmk_encrypted:
name="storageaccountname1",
resouce_group_name="rg",
enable_https_traffic_only=True,
infrastructure_encryption=True, # bool
infrastructure_encryption="Enabled",
allow_blob_public_access=True,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
network_rule_set="AllowAll",
encryption_type="Microsoft.CustomerManagedKeyVault",
minimum_tls_version="TLS1_2",
private_endpoint_connections=[],
key_expiration_period_in_days="365", # str
key_expiration_period_in_days=365,
location="euwest",
blob_properties=BlobProperties(
blob_properties=mock.MagicMock(
id="id",
name="name",
type="type",
default_service_version="default_service_version",
container_delete_retention_policy=DeleteRetentionPolicy(
enabled=True, days=7
),
versioning_enabled=True,
container_delete_retention_policy="container_delete_retention_policy",
),
),
Account(
@@ -154,25 +146,20 @@ class Test_monitor_storage_account_with_activity_logs_cmk_encrypted:
name="storageaccountname2",
resouce_group_name="rg",
enable_https_traffic_only=False,
infrastructure_encryption=True, # bool
infrastructure_encryption="Enabled",
allow_blob_public_access=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
network_rule_set="AllowAll",
encryption_type="Microsoft.Storage",
minimum_tls_version="TLS1_2",
private_endpoint_connections=[],
key_expiration_period_in_days="365", # str
key_expiration_period_in_days=365,
location="euwest",
blob_properties=BlobProperties(
blob_properties=mock.MagicMock(
id="id",
name="name",
type="type",
default_service_version="default_service_version",
container_delete_retention_policy=DeleteRetentionPolicy(
enabled=True, days=7
),
versioning_enabled=False,
container_delete_retention_policy="container_delete_retention_policy",
),
),
]
@@ -78,9 +78,6 @@ class Test_monitor_storage_account_with_activity_logs_is_private:
)
from prowler.providers.azure.services.storage.storage_service import (
Account,
BlobProperties,
DeleteRetentionPolicy,
NetworkRuleSet,
)
monitor_client.diagnostics_settings = {
@@ -128,25 +125,20 @@ class Test_monitor_storage_account_with_activity_logs_is_private:
name="storageaccountname1",
resouce_group_name="rg",
enable_https_traffic_only=True,
infrastructure_encryption=True,
infrastructure_encryption="Enabled",
allow_blob_public_access=True,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
network_rule_set="AllowAll",
encryption_type="Microsoft.Storage",
minimum_tls_version="TLS1_2",
private_endpoint_connections=[],
key_expiration_period_in_days=365,
location="euwest",
blob_properties=BlobProperties(
blob_properties=mock.MagicMock(
id="id",
name="name",
type="type",
default_service_version="default_service_version",
container_delete_retention_policy=DeleteRetentionPolicy(
enabled=True, days=7
),
versioning_enabled=True,
container_delete_retention_policy="container_delete_retention_policy",
),
),
Account(
@@ -154,25 +146,20 @@ class Test_monitor_storage_account_with_activity_logs_is_private:
name="storageaccountname2",
resouce_group_name="rg",
enable_https_traffic_only=False,
infrastructure_encryption=True,
infrastructure_encryption="Enabled",
allow_blob_public_access=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
network_rule_set="AllowAll",
encryption_type="Microsoft.Storage",
minimum_tls_version="TLS1_2",
private_endpoint_connections=[],
key_expiration_period_in_days=365,
location="euwest",
blob_properties=BlobProperties(
blob_properties=mock.MagicMock(
id="id",
name="name",
type="type",
default_service_version="default_service_version",
container_delete_retention_policy=DeleteRetentionPolicy(
enabled=True, days=7
),
versioning_enabled=False,
container_delete_retention_policy="container_delete_retention_policy",
),
),
]
@@ -1,10 +1,7 @@
from unittest import mock
from uuid import uuid4
from prowler.providers.azure.services.storage.storage_service import (
Account,
NetworkRuleSet,
)
from prowler.providers.azure.services.storage.storage_service import Account
from tests.providers.azure.azure_fixtures import (
AZURE_SUBSCRIPTION_ID,
set_mocked_azure_provider,
@@ -43,18 +40,16 @@ class Test_storage_account_key_access_disabled:
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=True,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
encryption_type="None",
minimum_tls_version="TLS1_2",
private_endpoint_connections=[],
allow_blob_public_access=None,
network_rule_set=None,
encryption_type=None,
minimum_tls_version=None,
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=None,
allow_shared_key_access=True,
)
]
@@ -96,18 +91,16 @@ class Test_storage_account_key_access_disabled:
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
encryption_type="None",
minimum_tls_version="TLS1_2",
private_endpoint_connections=[],
allow_blob_public_access=None,
network_rule_set=None,
encryption_type=None,
minimum_tls_version=None,
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=None,
allow_shared_key_access=False,
)
]
@@ -1,10 +1,7 @@
from unittest import mock
from uuid import uuid4
from prowler.providers.azure.services.storage.storage_service import (
Account,
NetworkRuleSet,
)
from prowler.providers.azure.services.storage.storage_service import Account
from tests.providers.azure.azure_fixtures import (
AZURE_SUBSCRIPTION_ID,
set_mocked_azure_provider,
@@ -43,18 +40,16 @@ class Test_storage_blob_public_access_level_is_disabled:
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
encryption_type="None",
minimum_tls_version="TLS1_2",
private_endpoint_connections=[],
allow_blob_public_access=True,
network_rule_set=None,
encryption_type=None,
minimum_tls_version=None,
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=None,
)
]
}
@@ -95,18 +90,16 @@ class Test_storage_blob_public_access_level_is_disabled:
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
encryption_type="None",
minimum_tls_version="TLS1_2",
private_endpoint_connections=[],
allow_blob_public_access=False,
network_rule_set=None,
encryption_type=None,
minimum_tls_version=None,
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=None,
)
]
}
@@ -45,28 +45,23 @@ class Test_storage_blob_versioning_is_enabled:
new=storage_client,
),
):
from prowler.providers.azure.services.storage.storage_service import (
Account,
NetworkRuleSet,
)
from prowler.providers.azure.services.storage.storage_service import Account
storage_client.storage_accounts = {
AZURE_SUBSCRIPTION_ID: [
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
allow_blob_public_access=None,
network_rule_set=None,
encryption_type="None",
minimum_tls_version="TLS1_2",
minimum_tls_version=None,
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=[],
private_endpoint_connections=None,
blob_properties=storage_account_blob_properties,
)
]
@@ -97,13 +92,12 @@ class Test_storage_blob_versioning_is_enabled:
Account,
BlobProperties,
DeleteRetentionPolicy,
NetworkRuleSet,
)
storage_account_blob_properties = BlobProperties(
id="id",
name="name",
type="type",
id=None,
name=None,
type=None,
default_service_version=None,
container_delete_retention_policy=DeleteRetentionPolicy(
enabled=False, days=0
@@ -115,18 +109,16 @@ class Test_storage_blob_versioning_is_enabled:
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
allow_blob_public_access=None,
network_rule_set=None,
encryption_type="None",
minimum_tls_version="TLS1_2",
minimum_tls_version=None,
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=[],
private_endpoint_connections=None,
blob_properties=storage_account_blob_properties,
)
]
@@ -166,13 +158,12 @@ class Test_storage_blob_versioning_is_enabled:
Account,
BlobProperties,
DeleteRetentionPolicy,
NetworkRuleSet,
)
storage_account_blob_properties = BlobProperties(
id="id",
name="name",
type="type",
id=None,
name=None,
type=None,
default_service_version=None,
container_delete_retention_policy=DeleteRetentionPolicy(
enabled=False, days=0
@@ -184,18 +175,16 @@ class Test_storage_blob_versioning_is_enabled:
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
allow_blob_public_access=None,
network_rule_set=None,
encryption_type="None",
minimum_tls_version="TLS1_2",
minimum_tls_version=None,
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=[],
private_endpoint_connections=None,
blob_properties=storage_account_blob_properties,
)
]
@@ -1,10 +1,7 @@
from unittest import mock
from uuid import uuid4
from prowler.providers.azure.services.storage.storage_service import (
Account,
NetworkRuleSet,
)
from prowler.providers.azure.services.storage.storage_service import Account
from tests.providers.azure.azure_fixtures import (
AZURE_SUBSCRIPTION_ID,
set_mocked_azure_provider,
@@ -43,18 +40,16 @@ class Test_storage_cross_tenant_replication_disabled:
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
encryption_type="None",
minimum_tls_version="TLS1_2",
private_endpoint_connections=[],
allow_blob_public_access=None,
network_rule_set=None,
encryption_type=None,
minimum_tls_version=None,
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=None,
allow_cross_tenant_replication=True,
)
]
@@ -96,18 +91,16 @@ class Test_storage_cross_tenant_replication_disabled:
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
encryption_type="None",
minimum_tls_version="TLS1_2",
private_endpoint_connections=[],
allow_blob_public_access=None,
network_rule_set=None,
encryption_type=None,
minimum_tls_version=None,
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=None,
allow_cross_tenant_replication=False,
)
]
@@ -43,18 +43,18 @@ class Test_storage_default_network_access_rule_is_denied:
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=False,
allow_blob_public_access=None,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
default_action="Allow", bypass="AzureServices"
),
encryption_type="None",
minimum_tls_version="TLS1_2",
encryption_type=None,
minimum_tls_version=None,
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=[],
private_endpoint_connections=None,
)
]
}
@@ -95,18 +95,18 @@ class Test_storage_default_network_access_rule_is_denied:
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=False,
allow_blob_public_access=None,
network_rule_set=NetworkRuleSet(
default_action="Deny", bypass="AzureServices"
),
encryption_type="None",
minimum_tls_version="TLS1_2",
encryption_type=None,
minimum_tls_version=None,
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=[],
private_endpoint_connections=None,
)
]
}
@@ -1,10 +1,7 @@
from unittest import mock
from uuid import uuid4
from prowler.providers.azure.services.storage.storage_service import (
Account,
NetworkRuleSet,
)
from prowler.providers.azure.services.storage.storage_service import Account
from tests.providers.azure.azure_fixtures import (
AZURE_SUBSCRIPTION_ID,
set_mocked_azure_provider,
@@ -43,18 +40,16 @@ class Test_storage_default_to_entra_authorization_enabled:
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
encryption_type="None",
minimum_tls_version="TLS1_2",
private_endpoint_connections=[],
network_rule_set=None,
encryption_type=None,
minimum_tls_version=None,
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=None,
default_to_entra_authorization=True,
)
]
@@ -96,18 +91,16 @@ class Test_storage_default_to_entra_authorization_enabled:
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
encryption_type="None",
minimum_tls_version="TLS1_2",
private_endpoint_connections=[],
network_rule_set=None,
encryption_type=None,
minimum_tls_version=None,
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=None,
default_to_entra_authorization=False,
)
]
@@ -43,18 +43,18 @@ class Test_storage_ensure_azure_services_are_trusted_to_access_is_enabled:
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=False,
allow_blob_public_access=None,
network_rule_set=NetworkRuleSet(
bypass="None", default_action="Deny"
),
encryption_type="None",
minimum_tls_version="TLS1_2",
encryption_type=None,
minimum_tls_version=None,
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=[],
private_endpoint_connections=None,
)
]
}
@@ -95,18 +95,18 @@ class Test_storage_ensure_azure_services_are_trusted_to_access_is_enabled:
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=False,
allow_blob_public_access=None,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
encryption_type="None",
minimum_tls_version="TLS1_2",
encryption_type=None,
minimum_tls_version=None,
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=[],
private_endpoint_connections=None,
)
]
}
@@ -1,10 +1,7 @@
from unittest import mock
from uuid import uuid4
from prowler.providers.azure.services.storage.storage_service import (
Account,
NetworkRuleSet,
)
from prowler.providers.azure.services.storage.storage_service import Account
from tests.providers.azure.azure_fixtures import (
AZURE_SUBSCRIPTION_ID,
set_mocked_azure_provider,
@@ -43,18 +40,16 @@ class Test_storage_ensure_encryption_with_customer_managed_keys:
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
minimum_tls_version="TLS1_2",
private_endpoint_connections=[],
allow_blob_public_access=None,
network_rule_set=None,
encryption_type="None",
minimum_tls_version=None,
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=None,
)
]
}
@@ -95,18 +90,16 @@ class Test_storage_ensure_encryption_with_customer_managed_keys:
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
minimum_tls_version="TLS1_2",
private_endpoint_connections=[],
allow_blob_public_access=None,
network_rule_set=None,
encryption_type="Microsoft.Keyvault",
minimum_tls_version=None,
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=None,
)
]
}
@@ -5,7 +5,7 @@ from prowler.providers.azure.services.storage.storage_service import (
Account,
DeleteRetentionPolicy,
FileServiceProperties,
NetworkRuleSet,
SMBProtocolSettings,
)
from tests.providers.azure.azure_fixtures import (
AZURE_SUBSCRIPTION_ID,
@@ -45,18 +45,16 @@ class Test_storage_ensure_file_shares_soft_delete_is_enabled:
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
allow_blob_public_access=None,
network_rule_set=None,
encryption_type="None",
minimum_tls_version="TLS1_2",
minimum_tls_version=None,
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=[],
private_endpoint_connections=None,
file_service_properties=None,
)
]
@@ -90,24 +88,23 @@ class Test_storage_ensure_file_shares_soft_delete_is_enabled:
name="default",
type="Microsoft.Storage/storageAccounts/fileServices",
share_delete_retention_policy=retention_policy,
smb_protocol_settings=SMBProtocolSettings(channel_encryption=[]),
)
storage_client.storage_accounts = {
AZURE_SUBSCRIPTION_ID: [
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
allow_blob_public_access=None,
network_rule_set=None,
encryption_type="None",
minimum_tls_version="TLS1_2",
minimum_tls_version=None,
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=[],
private_endpoint_connections=None,
file_service_properties=file_service_properties,
)
]
@@ -150,24 +147,23 @@ class Test_storage_ensure_file_shares_soft_delete_is_enabled:
name="default",
type="Microsoft.Storage/storageAccounts/fileServices",
share_delete_retention_policy=retention_policy,
smb_protocol_settings=SMBProtocolSettings(channel_encryption=[]),
)
storage_client.storage_accounts = {
AZURE_SUBSCRIPTION_ID: [
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
allow_blob_public_access=None,
network_rule_set=None,
encryption_type="None",
minimum_tls_version="TLS1_2",
minimum_tls_version=None,
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=[],
private_endpoint_connections=None,
file_service_properties=file_service_properties,
)
]
@@ -1,10 +1,7 @@
from unittest import mock
from uuid import uuid4
from prowler.providers.azure.services.storage.storage_service import (
Account,
NetworkRuleSet,
)
from prowler.providers.azure.services.storage.storage_service import Account
from tests.providers.azure.azure_fixtures import (
AZURE_SUBSCRIPTION_ID,
set_mocked_azure_provider,
@@ -43,18 +40,16 @@ class Test_storage_ensure_minimum_tls_version_12:
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
allow_blob_public_access=None,
network_rule_set=None,
encryption_type="None",
minimum_tls_version="TLS1_1",
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=[],
private_endpoint_connections=None,
)
]
}
@@ -95,18 +90,16 @@ class Test_storage_ensure_minimum_tls_version_12:
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
allow_blob_public_access=None,
network_rule_set=None,
encryption_type="None",
minimum_tls_version="TLS1_2",
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=[],
private_endpoint_connections=None,
)
]
}
@@ -3,7 +3,6 @@ from uuid import uuid4
from prowler.providers.azure.services.storage.storage_service import (
Account,
NetworkRuleSet,
PrivateEndpointConnection,
)
from tests.providers.azure.azure_fixtures import (
@@ -46,18 +45,16 @@ class Test_storage_ensure_private_endpoints_in_storage_accounts:
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
allow_blob_public_access=None,
network_rule_set=None,
encryption_type="None",
minimum_tls_version="TLS1_2",
minimum_tls_version=None,
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=[],
private_endpoint_connections=None,
)
]
}
@@ -100,24 +97,22 @@ class Test_storage_ensure_private_endpoints_in_storage_accounts:
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
allow_blob_public_access=None,
network_rule_set=None,
encryption_type="None",
minimum_tls_version="TLS1_2",
minimum_tls_version=None,
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=[
PrivateEndpointConnection(
id="f1ef2e48-978a-4b0e-b34f-e6c34a9e0724",
name="Test Private Endpoint Connection",
type="Test Type",
)
],
private_endpoint_connections=PrivateEndpointConnection(
id=str(
uuid4(),
),
name="Test Private Endpoint Connection",
type="Test Type",
),
)
]
}
@@ -5,7 +5,6 @@ from prowler.providers.azure.services.storage.storage_service import (
Account,
BlobProperties,
DeleteRetentionPolicy,
NetworkRuleSet,
)
from tests.providers.azure.azure_fixtures import (
AZURE_SUBSCRIPTION_ID,
@@ -46,18 +45,16 @@ class Test_storage_ensure_soft_delete_is_enabled:
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
allow_blob_public_access=None,
network_rule_set=None,
encryption_type="None",
minimum_tls_version="TLS1_2",
minimum_tls_version=None,
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=[],
private_endpoint_connections=None,
blob_properties=storage_account_blob_properties,
)
]
@@ -88,32 +85,29 @@ class Test_storage_ensure_soft_delete_is_enabled:
storage_account_name = "Test Storage Account"
storage_client = mock.MagicMock
storage_account_blob_properties = BlobProperties(
id="id",
name="name",
type="type",
id=None,
name=None,
type=None,
default_service_version=None,
container_delete_retention_policy=DeleteRetentionPolicy(
enabled=False, days=7
),
versioning_enabled=False,
)
storage_client.storage_accounts = {
AZURE_SUBSCRIPTION_ID: [
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
allow_blob_public_access=None,
network_rule_set=None,
encryption_type="None",
minimum_tls_version="TLS1_2",
minimum_tls_version=None,
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=[],
private_endpoint_connections=None,
blob_properties=storage_account_blob_properties,
)
]
@@ -153,32 +147,29 @@ class Test_storage_ensure_soft_delete_is_enabled:
storage_account_name = "Test Storage Account"
storage_client = mock.MagicMock
storage_account_blob_properties = BlobProperties(
id="id",
name="name",
type="type",
id=None,
name=None,
type=None,
default_service_version=None,
container_delete_retention_policy=DeleteRetentionPolicy(
enabled=True, days=7
),
versioning_enabled=True,
)
storage_client.storage_accounts = {
AZURE_SUBSCRIPTION_ID: [
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
allow_blob_public_access=None,
network_rule_set=None,
encryption_type="None",
minimum_tls_version="TLS1_2",
minimum_tls_version=None,
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=[],
private_endpoint_connections=None,
blob_properties=storage_account_blob_properties,
)
]
@@ -3,7 +3,6 @@ from uuid import uuid4
from prowler.providers.azure.services.storage.storage_service import (
Account,
NetworkRuleSet,
ReplicationSettings,
)
from tests.providers.azure.azure_fixtures import (
@@ -44,18 +43,16 @@ class Test_storage_geo_redundant_enabled:
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
encryption_type="None",
minimum_tls_version="TLS1_2",
private_endpoint_connections=[],
network_rule_set=None,
encryption_type=None,
minimum_tls_version=None,
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=None,
replication_settings=ReplicationSettings.STANDARD_GRS,
)
]
@@ -97,18 +94,16 @@ class Test_storage_geo_redundant_enabled:
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
encryption_type="None",
minimum_tls_version="TLS1_2",
private_endpoint_connections=[],
network_rule_set=None,
encryption_type=None,
minimum_tls_version=None,
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=None,
replication_settings=ReplicationSettings.STANDARD_LRS,
)
]
@@ -1,10 +1,7 @@
from unittest import mock
from uuid import uuid4
from prowler.providers.azure.services.storage.storage_service import (
Account,
NetworkRuleSet,
)
from prowler.providers.azure.services.storage.storage_service import Account
from tests.providers.azure.azure_fixtures import (
AZURE_SUBSCRIPTION_ID,
set_mocked_azure_provider,
@@ -43,18 +40,16 @@ class Test_storage_infrastructure_encryption_is_enabled:
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
allow_blob_public_access=None,
network_rule_set=None,
encryption_type="None",
minimum_tls_version="TLS1_1",
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=[],
private_endpoint_connections=None,
)
]
}
@@ -95,18 +90,16 @@ class Test_storage_infrastructure_encryption_is_enabled:
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=True,
allow_blob_public_access=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
allow_blob_public_access=None,
network_rule_set=None,
encryption_type="None",
minimum_tls_version="TLS1_2",
minimum_tls_version="TLS1_1",
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=[],
private_endpoint_connections=None,
)
]
}
@@ -1,10 +1,7 @@
from unittest import mock
from uuid import uuid4
from prowler.providers.azure.services.storage.storage_service import (
Account,
NetworkRuleSet,
)
from prowler.providers.azure.services.storage.storage_service import Account
from tests.providers.azure.azure_fixtures import (
AZURE_SUBSCRIPTION_ID,
set_mocked_azure_provider,
@@ -44,18 +41,16 @@ class Test_storage_key_rotation_90_dayss:
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
key_expiration_period_in_days="91",
allow_blob_public_access=None,
network_rule_set=None,
encryption_type="None",
minimum_tls_version="TLS1_2",
private_endpoint_connections=[],
minimum_tls_version="TLS1_1",
key_expiration_period_in_days=expiration_days,
location="westeurope",
private_endpoint_connections=None,
)
]
}
@@ -97,18 +92,16 @@ class Test_storage_key_rotation_90_dayss:
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
key_expiration_period_in_days=90,
allow_blob_public_access=None,
network_rule_set=None,
encryption_type="None",
minimum_tls_version="TLS1_2",
private_endpoint_connections=[],
key_expiration_period_in_days=expiration_days,
location="westeurope",
private_endpoint_connections=None,
)
]
}
@@ -149,18 +142,16 @@ class Test_storage_key_rotation_90_dayss:
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
key_expiration_period_in_days=None,
allow_blob_public_access=None,
network_rule_set=None,
encryption_type="None",
minimum_tls_version="TLS1_2",
private_endpoint_connections=[],
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=None,
)
]
}
@@ -1,10 +1,7 @@
from unittest import mock
from uuid import uuid4
from prowler.providers.azure.services.storage.storage_service import (
Account,
NetworkRuleSet,
)
from prowler.providers.azure.services.storage.storage_service import Account
from tests.providers.azure.azure_fixtures import (
AZURE_SUBSCRIPTION_ID,
set_mocked_azure_provider,
@@ -43,18 +40,16 @@ class Test_storage_secure_transfer_required_is_enabled:
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
allow_blob_public_access=None,
network_rule_set=None,
encryption_type="None",
minimum_tls_version="TLS1_2",
minimum_tls_version="TLS1_1",
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=[],
private_endpoint_connections=None,
)
]
}
@@ -95,18 +90,16 @@ class Test_storage_secure_transfer_required_is_enabled:
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=True,
infrastructure_encryption=True,
allow_blob_public_access=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
allow_blob_public_access=None,
network_rule_set=None,
encryption_type="None",
minimum_tls_version="TLS1_2",
minimum_tls_version="TLS1_1",
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=[],
private_endpoint_connections=None,
)
]
}
@@ -5,8 +5,8 @@ from prowler.providers.azure.services.storage.storage_service import (
BlobProperties,
DeleteRetentionPolicy,
FileServiceProperties,
NetworkRuleSet,
ReplicationSettings,
SMBProtocolSettings,
Storage,
)
from tests.providers.azure.azure_fixtures import (
@@ -21,7 +21,7 @@ def mock_storage_get_storage_accounts(_):
name="name",
type="type",
default_service_version=None,
container_delete_retention_policy=DeleteRetentionPolicy(enabled=True, days=7),
container_delete_retention_policy=None,
)
retention_policy = DeleteRetentionPolicy(enabled=True, days=7)
file_service_properties = FileServiceProperties(
@@ -29,23 +29,22 @@ def mock_storage_get_storage_accounts(_):
name="name",
type="type",
share_delete_retention_policy=retention_policy,
smb_protocol_settings=SMBProtocolSettings(channel_encryption=[]),
)
return {
AZURE_SUBSCRIPTION_ID: [
Account(
id="id",
name="name",
resouce_group_name="rg",
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=False,
network_rule_set=NetworkRuleSet(
bypass="AzureServices", default_action="Allow"
),
allow_blob_public_access=None,
network_rule_set=None,
encryption_type="None",
minimum_tls_version="TLS1_2",
minimum_tls_version=None,
key_expiration_period_in_days=None,
private_endpoint_connections=[],
private_endpoint_connections=None,
location="westeurope",
blob_properties=blob_properties,
default_to_entra_authorization=True,
@@ -80,7 +79,7 @@ class Test_Storage_Service:
assert storage.storage_accounts[AZURE_SUBSCRIPTION_ID][0].name == "name"
assert (
storage.storage_accounts[AZURE_SUBSCRIPTION_ID][0].resouce_group_name
== "rg"
is None
)
assert (
storage.storage_accounts[AZURE_SUBSCRIPTION_ID][0].enable_https_traffic_only
@@ -92,21 +91,10 @@ class Test_Storage_Service:
)
assert (
storage.storage_accounts[AZURE_SUBSCRIPTION_ID][0].allow_blob_public_access
is False
is None
)
assert (
storage.storage_accounts[AZURE_SUBSCRIPTION_ID][0].network_rule_set
is not None
)
assert (
storage.storage_accounts[AZURE_SUBSCRIPTION_ID][0].network_rule_set.bypass
== "AzureServices"
)
assert (
storage.storage_accounts[AZURE_SUBSCRIPTION_ID][
0
].network_rule_set.default_action
== "Allow"
storage.storage_accounts[AZURE_SUBSCRIPTION_ID][0].network_rule_set is None
)
assert (
storage.storage_accounts[AZURE_SUBSCRIPTION_ID][0].encryption_type == "None"
@@ -116,7 +104,7 @@ class Test_Storage_Service:
)
assert (
storage.storage_accounts[AZURE_SUBSCRIPTION_ID][0].minimum_tls_version
== "TLS1_2"
is None
)
assert (
storage.storage_accounts[AZURE_SUBSCRIPTION_ID][
@@ -128,7 +116,7 @@ class Test_Storage_Service:
storage.storage_accounts[AZURE_SUBSCRIPTION_ID][
0
].private_endpoint_connections
== []
is None
)
assert storage.storage_accounts[AZURE_SUBSCRIPTION_ID][
0
@@ -137,9 +125,7 @@ class Test_Storage_Service:
name="name",
type="type",
default_service_version=None,
container_delete_retention_policy=DeleteRetentionPolicy(
enabled=True, days=7
),
container_delete_retention_policy=None,
)
assert storage.storage_accounts[AZURE_SUBSCRIPTION_ID][
0
@@ -189,19 +175,7 @@ class Test_Storage_Service:
storage.storage_accounts[AZURE_SUBSCRIPTION_ID][
0
].blob_properties.container_delete_retention_policy
is not None
)
assert (
storage.storage_accounts[AZURE_SUBSCRIPTION_ID][
0
].blob_properties.container_delete_retention_policy.enabled
is True
)
assert (
storage.storage_accounts[AZURE_SUBSCRIPTION_ID][
0
].blob_properties.container_delete_retention_policy.days
== 7
is None
)
def test_get_file_service_properties(self):
@@ -0,0 +1,237 @@
from unittest import mock
from uuid import uuid4
from prowler.providers.azure.services.storage.storage_service import (
Account,
DeleteRetentionPolicy,
FileServiceProperties,
SMBProtocolSettings,
)
from tests.providers.azure.azure_fixtures import (
AZURE_SUBSCRIPTION_ID,
set_mocked_azure_provider,
)
class Test_storage_smb_channel_encryption_with_secure_algorithm:
def test_no_storage_accounts(self):
storage_client = mock.MagicMock()
storage_client.storage_accounts = {}
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_azure_provider(),
),
mock.patch(
"prowler.providers.azure.services.storage.storage_smb_channel_encryption_with_secure_algorithm.storage_smb_channel_encryption_with_secure_algorithm.storage_client",
new=storage_client,
),
):
from prowler.providers.azure.services.storage.storage_smb_channel_encryption_with_secure_algorithm.storage_smb_channel_encryption_with_secure_algorithm import (
storage_smb_channel_encryption_with_secure_algorithm,
)
check = storage_smb_channel_encryption_with_secure_algorithm()
result = check.execute()
assert len(result) == 0
def test_no_file_service_properties(self):
storage_account_id = str(uuid4())
storage_account_name = "Test Storage Account"
storage_client = mock.MagicMock()
storage_client.storage_accounts = {
AZURE_SUBSCRIPTION_ID: [
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=None,
network_rule_set=None,
encryption_type="None",
minimum_tls_version=None,
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=None,
file_service_properties=None,
)
]
}
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_azure_provider(),
),
mock.patch(
"prowler.providers.azure.services.storage.storage_smb_channel_encryption_with_secure_algorithm.storage_smb_channel_encryption_with_secure_algorithm.storage_client",
new=storage_client,
),
):
from prowler.providers.azure.services.storage.storage_smb_channel_encryption_with_secure_algorithm.storage_smb_channel_encryption_with_secure_algorithm import (
storage_smb_channel_encryption_with_secure_algorithm,
)
check = storage_smb_channel_encryption_with_secure_algorithm()
result = check.execute()
assert len(result) == 0
def test_no_smb_protocol_settings(self):
storage_account_id = str(uuid4())
storage_account_name = "Test Storage Account"
file_service_properties = FileServiceProperties(
id="id1",
name="fs1",
type="type1",
share_delete_retention_policy=DeleteRetentionPolicy(enabled=True, days=7),
smb_protocol_settings=SMBProtocolSettings(channel_encryption=[]),
)
storage_client = mock.MagicMock()
storage_client.storage_accounts = {
AZURE_SUBSCRIPTION_ID: [
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=None,
network_rule_set=None,
encryption_type="None",
minimum_tls_version=None,
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=None,
file_service_properties=file_service_properties,
)
]
}
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_azure_provider(),
),
mock.patch(
"prowler.providers.azure.services.storage.storage_smb_channel_encryption_with_secure_algorithm.storage_smb_channel_encryption_with_secure_algorithm.storage_client",
new=storage_client,
),
):
from prowler.providers.azure.services.storage.storage_smb_channel_encryption_with_secure_algorithm.storage_smb_channel_encryption_with_secure_algorithm import (
storage_smb_channel_encryption_with_secure_algorithm,
)
check = storage_smb_channel_encryption_with_secure_algorithm()
result = check.execute()
assert len(result) == 1
assert result[0].status == "FAIL"
assert result[0].status_extended == (
f"Storage account {storage_account_name} from subscription {AZURE_SUBSCRIPTION_ID} does not have SMB channel encryption enabled for file shares."
)
def test_not_recommended_encryption(self):
storage_account_id = str(uuid4())
storage_account_name = "Test Storage Account"
file_service_properties = FileServiceProperties(
id="id1",
name="fs1",
type="type1",
share_delete_retention_policy=DeleteRetentionPolicy(enabled=True, days=7),
smb_protocol_settings=SMBProtocolSettings(
channel_encryption=["AES-128-GCM"]
),
)
storage_client = mock.MagicMock()
storage_client.storage_accounts = {
AZURE_SUBSCRIPTION_ID: [
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=None,
network_rule_set=None,
encryption_type="None",
minimum_tls_version=None,
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=None,
file_service_properties=file_service_properties,
)
]
}
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_azure_provider(),
),
mock.patch(
"prowler.providers.azure.services.storage.storage_smb_channel_encryption_with_secure_algorithm.storage_smb_channel_encryption_with_secure_algorithm.storage_client",
new=storage_client,
),
):
from prowler.providers.azure.services.storage.storage_smb_channel_encryption_with_secure_algorithm.storage_smb_channel_encryption_with_secure_algorithm import (
storage_smb_channel_encryption_with_secure_algorithm,
)
check = storage_smb_channel_encryption_with_secure_algorithm()
result = check.execute()
assert len(result) == 1
assert result[0].status == "FAIL"
assert result[0].status_extended == (
f"Storage account {storage_account_name} from subscription {AZURE_SUBSCRIPTION_ID} does not have SMB channel encryption with a secure algorithm for file shares since it supports AES-128-GCM."
)
def test_recommended_encryption(self):
storage_account_id = str(uuid4())
storage_account_name = "Test Storage Account"
file_service_properties = FileServiceProperties(
id="id1",
name="fs1",
type="type1",
share_delete_retention_policy=DeleteRetentionPolicy(enabled=True, days=7),
smb_protocol_settings=SMBProtocolSettings(
channel_encryption=["AES-256-GCM"]
),
)
storage_client = mock.MagicMock()
storage_client.storage_accounts = {
AZURE_SUBSCRIPTION_ID: [
Account(
id=storage_account_id,
name=storage_account_name,
resouce_group_name=None,
enable_https_traffic_only=False,
infrastructure_encryption=False,
allow_blob_public_access=None,
network_rule_set=None,
encryption_type="None",
minimum_tls_version=None,
key_expiration_period_in_days=None,
location="westeurope",
private_endpoint_connections=None,
file_service_properties=file_service_properties,
)
]
}
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_azure_provider(),
),
mock.patch(
"prowler.providers.azure.services.storage.storage_smb_channel_encryption_with_secure_algorithm.storage_smb_channel_encryption_with_secure_algorithm.storage_client",
new=storage_client,
),
):
from prowler.providers.azure.services.storage.storage_smb_channel_encryption_with_secure_algorithm.storage_smb_channel_encryption_with_secure_algorithm import (
storage_smb_channel_encryption_with_secure_algorithm,
)
check = storage_smb_channel_encryption_with_secure_algorithm()
result = check.execute()
assert len(result) == 1
assert result[0].status == "PASS"
assert result[0].status_extended == (
f"Storage account {storage_account_name} from subscription {AZURE_SUBSCRIPTION_ID} has a secure algorithm for SMB channel encryption (AES-256-GCM) enabled for file shares since it supports AES-256-GCM."
)
@@ -0,0 +1,301 @@
from unittest import mock
from uuid import uuid4
from tests.providers.azure.azure_fixtures import (
AZURE_SUBSCRIPTION_ID,
set_mocked_azure_provider,
)
class Test_vm_backup_enabled:
def test_vm_backup_enabled_no_subscriptions(self):
vm_client = mock.MagicMock
recovery_client = mock.MagicMock
vm_client.virtual_machines = {}
recovery_client.vaults = {}
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_azure_provider(),
),
mock.patch(
"prowler.providers.azure.services.vm.vm_backup_enabled.vm_backup_enabled.vm_client",
new=vm_client,
),
mock.patch(
"prowler.providers.azure.services.vm.vm_backup_enabled.vm_backup_enabled.recovery_client",
new=recovery_client,
),
):
from prowler.providers.azure.services.vm.vm_backup_enabled.vm_backup_enabled import (
vm_backup_enabled,
)
check = vm_backup_enabled()
result = check.execute()
assert len(result) == 0
def test_no_vms(self):
mock_vm_client = mock.MagicMock()
mock_vm_client.virtual_machines = {AZURE_SUBSCRIPTION_ID: {}}
mock_recovery_client = mock.MagicMock()
mock_recovery_client.vaults = {AZURE_SUBSCRIPTION_ID: {}}
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_azure_provider(),
),
mock.patch(
"prowler.providers.azure.services.vm.vm_backup_enabled.vm_backup_enabled.vm_client",
new=mock_vm_client,
),
mock.patch(
"prowler.providers.azure.services.vm.vm_backup_enabled.vm_backup_enabled.recovery_client",
new=mock_recovery_client,
),
):
from prowler.providers.azure.services.vm.vm_backup_enabled.vm_backup_enabled import (
vm_backup_enabled,
)
check = vm_backup_enabled()
result = check.execute()
assert len(result) == 0
def test_vm_protected_by_backup(self):
vm_id = str(uuid4())
vm_name = "VMTest"
vault_id = str(uuid4())
vault_name = "vault1"
mock_vm_client = mock.MagicMock()
mock_recovery_client = mock.MagicMock()
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_azure_provider(),
),
mock.patch(
"prowler.providers.azure.services.vm.vm_backup_enabled.vm_backup_enabled.vm_client",
new=mock_vm_client,
),
mock.patch(
"prowler.providers.azure.services.vm.vm_backup_enabled.vm_backup_enabled.recovery_client",
new=mock_recovery_client,
),
):
from azure.mgmt.recoveryservicesbackup.activestamp.models import (
DataSourceType,
)
from prowler.providers.azure.services.recovery.recovery_service import (
BackupItem,
BackupVault,
)
from prowler.providers.azure.services.vm.vm_backup_enabled.vm_backup_enabled import (
vm_backup_enabled,
)
from prowler.providers.azure.services.vm.vm_service import (
ManagedDiskParameters,
OSDisk,
StorageProfile,
VirtualMachine,
)
vm = VirtualMachine(
resource_id=vm_id,
resource_name=vm_name,
location="eastus",
security_profile=None,
extensions=[],
storage_profile=StorageProfile(
os_disk=OSDisk(
name="os_disk_name",
operating_system_type="Linux",
managed_disk=ManagedDiskParameters(id="managed_disk_id"),
),
data_disks=[],
),
)
backup_item = BackupItem(
id=str(uuid4()),
name=f"someprefix;{vm_name}",
workload_type=DataSourceType.VM,
)
vault = BackupVault(
id=vault_id,
name=vault_name,
location="eastus",
backup_protected_items={backup_item.id: backup_item},
)
mock_vm_client.virtual_machines = {AZURE_SUBSCRIPTION_ID: {vm_id: vm}}
mock_recovery_client.vaults = {AZURE_SUBSCRIPTION_ID: {vault_id: vault}}
check = vm_backup_enabled()
result = check.execute()
assert len(result) == 1
assert result[0].status == "PASS"
assert result[0].subscription == AZURE_SUBSCRIPTION_ID
assert result[0].resource_name == vm_name
assert result[0].resource_id == vm_id
assert (
result[0].status_extended
== f"VM {vm_name} in subscription {AZURE_SUBSCRIPTION_ID} is protected by Azure Backup (vault: {vault_name})."
)
def test_vm_not_protected_by_backup(self):
vm_id = str(uuid4())
vm_name = "VMTest"
vault_id = str(uuid4())
vault_name = "vault1"
mock_vm_client = mock.MagicMock()
mock_recovery_client = mock.MagicMock()
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_azure_provider(),
),
mock.patch(
"prowler.providers.azure.services.vm.vm_backup_enabled.vm_backup_enabled.vm_client",
new=mock_vm_client,
),
mock.patch(
"prowler.providers.azure.services.vm.vm_backup_enabled.vm_backup_enabled.recovery_client",
new=mock_recovery_client,
),
):
from azure.mgmt.recoveryservicesbackup.activestamp.models import (
DataSourceType,
)
from prowler.providers.azure.services.recovery.recovery_service import (
BackupItem,
BackupVault,
)
from prowler.providers.azure.services.vm.vm_backup_enabled.vm_backup_enabled import (
vm_backup_enabled,
)
from prowler.providers.azure.services.vm.vm_service import (
ManagedDiskParameters,
OSDisk,
StorageProfile,
VirtualMachine,
)
vm = VirtualMachine(
resource_id=vm_id,
resource_name=vm_name,
location="eastus",
security_profile=None,
extensions=[],
storage_profile=StorageProfile(
os_disk=OSDisk(
name="os_disk_name",
operating_system_type="Linux",
managed_disk=ManagedDiskParameters(id="managed_disk_id"),
),
data_disks=[],
),
)
backup_item = BackupItem(
id=str(uuid4()),
name="someprefix;OtherVM",
workload_type=DataSourceType.VM,
)
vault = BackupVault(
id=vault_id,
name=vault_name,
location="eastus",
backup_protected_items={backup_item.id: backup_item},
)
mock_vm_client.virtual_machines = {AZURE_SUBSCRIPTION_ID: {vm_id: vm}}
mock_recovery_client.vaults = {AZURE_SUBSCRIPTION_ID: {vault_id: vault}}
check = vm_backup_enabled()
result = check.execute()
assert len(result) == 1
assert result[0].status == "FAIL"
assert result[0].subscription == AZURE_SUBSCRIPTION_ID
assert result[0].resource_name == vm_name
assert result[0].resource_id == vm_id
assert (
result[0].status_extended
== f"VM {vm_name} in subscription {AZURE_SUBSCRIPTION_ID} is not protected by Azure Backup."
)
def test_vm_protected_by_backup_non_vm_workload(self):
vm_id = str(uuid4())
vm_name = "VMTest"
vault_id = str(uuid4())
vault_name = "vault1"
mock_vm_client = mock.MagicMock()
mock_recovery_client = mock.MagicMock()
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_azure_provider(),
),
mock.patch(
"prowler.providers.azure.services.vm.vm_backup_enabled.vm_backup_enabled.vm_client",
new=mock_vm_client,
),
mock.patch(
"prowler.providers.azure.services.vm.vm_backup_enabled.vm_backup_enabled.recovery_client",
new=mock_recovery_client,
),
):
from azure.mgmt.recoveryservicesbackup.activestamp.models import (
DataSourceType,
)
from prowler.providers.azure.services.recovery.recovery_service import (
BackupItem,
BackupVault,
)
from prowler.providers.azure.services.vm.vm_backup_enabled.vm_backup_enabled import (
vm_backup_enabled,
)
from prowler.providers.azure.services.vm.vm_service import (
ManagedDiskParameters,
OSDisk,
StorageProfile,
VirtualMachine,
)
vm = VirtualMachine(
resource_id=vm_id,
resource_name=vm_name,
location="eastus",
security_profile=None,
extensions=[],
storage_profile=StorageProfile(
os_disk=OSDisk(
name="os_disk_name",
operating_system_type="Linux",
managed_disk=ManagedDiskParameters(id="managed_disk_id"),
),
data_disks=[],
),
)
backup_item = BackupItem(
id=str(uuid4()),
name=f"someprefix;{vm_name}",
workload_type=DataSourceType.FILE_FOLDER,
)
vault = BackupVault(
id=vault_id,
name=vault_name,
location="eastus",
backup_protected_items={backup_item.id: backup_item},
)
mock_vm_client.virtual_machines = {AZURE_SUBSCRIPTION_ID: {vm_id: vm}}
mock_recovery_client.vaults = {AZURE_SUBSCRIPTION_ID: {vault_id: vault}}
check = vm_backup_enabled()
result = check.execute()
assert len(result) == 1
assert result[0].status == "FAIL"
assert result[0].subscription == AZURE_SUBSCRIPTION_ID
assert result[0].resource_name == vm_name
assert result[0].resource_id == vm_id
assert (
result[0].status_extended
== f"VM {vm_name} in subscription {AZURE_SUBSCRIPTION_ID} is not protected by Azure Backup."
)
@@ -0,0 +1,165 @@
from unittest import mock
from uuid import uuid4
from prowler.providers.azure.services.vm.vm_service import VirtualMachine
from tests.providers.azure.azure_fixtures import (
AZURE_SUBSCRIPTION_ID,
set_mocked_azure_provider,
)
class Test_vm_ensure_using_approved_images:
def test_no_subscriptions(self):
vm_client = mock.MagicMock()
vm_client.virtual_machines = {}
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_azure_provider(),
),
mock.patch(
"prowler.providers.azure.services.vm.vm_ensure_using_approved_images.vm_ensure_using_approved_images.vm_client",
new=vm_client,
),
):
from prowler.providers.azure.services.vm.vm_ensure_using_approved_images.vm_ensure_using_approved_images import (
vm_ensure_using_approved_images,
)
check = vm_ensure_using_approved_images()
result = check.execute()
assert len(result) == 0
def test_empty_vms_in_subscription(self):
vm_client = mock.MagicMock()
vm_client.virtual_machines = {AZURE_SUBSCRIPTION_ID: {}}
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_azure_provider(),
),
mock.patch(
"prowler.providers.azure.services.vm.vm_ensure_using_approved_images.vm_ensure_using_approved_images.vm_client",
new=vm_client,
),
):
from prowler.providers.azure.services.vm.vm_ensure_using_approved_images.vm_ensure_using_approved_images import (
vm_ensure_using_approved_images,
)
check = vm_ensure_using_approved_images()
result = check.execute()
assert len(result) == 0
def test_vm_with_approved_image(self):
vm_id = str(uuid4())
approved_image_id = f"/subscriptions/{AZURE_SUBSCRIPTION_ID}/resourceGroups/rg/providers/Microsoft.Compute/images/custom-image"
vm = VirtualMachine(
resource_id=vm_id,
resource_name="VMTestApproved",
location="westeurope",
security_profile=None,
extensions=[],
storage_profile=None,
image_reference=approved_image_id,
)
vm_client = mock.MagicMock()
vm_client.virtual_machines = {AZURE_SUBSCRIPTION_ID: {vm_id: vm}}
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_azure_provider(),
),
mock.patch(
"prowler.providers.azure.services.vm.vm_ensure_using_approved_images.vm_ensure_using_approved_images.vm_client",
new=vm_client,
),
):
from prowler.providers.azure.services.vm.vm_ensure_using_approved_images.vm_ensure_using_approved_images import (
vm_ensure_using_approved_images,
)
check = vm_ensure_using_approved_images()
result = check.execute()
assert len(result) == 1
assert result[0].status == "PASS"
assert result[0].resource_name == "VMTestApproved"
assert result[0].resource_id == vm_id
assert result[0].subscription == AZURE_SUBSCRIPTION_ID
expected_status_extended = f"VM VMTestApproved in subscription {AZURE_SUBSCRIPTION_ID} is using an approved machine image: custom-image."
assert result[0].status_extended == expected_status_extended
def test_vm_with_not_approved_image(self):
vm_id = str(uuid4())
not_approved_image_id = "/subscriptions/other/resourceGroups/rg/providers/Microsoft.Compute/otherResource/other-image"
vm = VirtualMachine(
resource_id=vm_id,
resource_name="VMTestNotApproved",
location="westeurope",
security_profile=None,
extensions=[],
storage_profile=None,
image_reference=not_approved_image_id,
)
vm_client = mock.MagicMock()
vm_client.virtual_machines = {AZURE_SUBSCRIPTION_ID: {vm_id: vm}}
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_azure_provider(),
),
mock.patch(
"prowler.providers.azure.services.vm.vm_ensure_using_approved_images.vm_ensure_using_approved_images.vm_client",
new=vm_client,
),
):
from prowler.providers.azure.services.vm.vm_ensure_using_approved_images.vm_ensure_using_approved_images import (
vm_ensure_using_approved_images,
)
check = vm_ensure_using_approved_images()
result = check.execute()
assert len(result) == 1
assert result[0].status == "FAIL"
assert result[0].resource_name == "VMTestNotApproved"
assert result[0].resource_id == vm_id
assert result[0].subscription == AZURE_SUBSCRIPTION_ID
expected_status_extended = f"VM VMTestNotApproved in subscription {AZURE_SUBSCRIPTION_ID} is not using an approved machine image."
assert result[0].status_extended == expected_status_extended
def test_vm_with_missing_image_reference(self):
vm_id = str(uuid4())
vm = VirtualMachine(
resource_id=vm_id,
resource_name="VMTestNoImageRef",
location="westeurope",
security_profile=None,
extensions=[],
storage_profile=None,
image_reference=None,
)
vm_client = mock.MagicMock()
vm_client.virtual_machines = {AZURE_SUBSCRIPTION_ID: {vm_id: vm}}
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_azure_provider(),
),
mock.patch(
"prowler.providers.azure.services.vm.vm_ensure_using_approved_images.vm_ensure_using_approved_images.vm_client",
new=vm_client,
),
):
from prowler.providers.azure.services.vm.vm_ensure_using_approved_images.vm_ensure_using_approved_images import (
vm_ensure_using_approved_images,
)
check = vm_ensure_using_approved_images()
result = check.execute()
assert len(result) == 1
assert result[0].status == "FAIL"
assert result[0].resource_name == "VMTestNoImageRef"
assert result[0].resource_id == vm_id
assert result[0].subscription == AZURE_SUBSCRIPTION_ID
expected_status_extended = f"VM VMTestNoImageRef in subscription {AZURE_SUBSCRIPTION_ID} is not using an approved machine image."
assert result[0].status_extended == expected_status_extended
@@ -86,6 +86,7 @@ class Test_vm_ensure_using_managed_disks:
),
data_disks=[],
),
linux_configuration=None,
),
}
}
@@ -142,6 +143,7 @@ class Test_vm_ensure_using_managed_disks:
),
data_disks=[],
),
linux_configuration=None,
)
}
}
@@ -200,6 +202,7 @@ class Test_vm_ensure_using_managed_disks:
DataDisk(lun=0, name="data_disk_1", managed_disk=None)
],
),
linux_configuration=None,
)
}
}
@@ -0,0 +1,173 @@
from unittest import mock
from uuid import uuid4
from prowler.providers.azure.services.vm.vm_service import (
LinuxConfiguration,
VirtualMachine,
)
from tests.providers.azure.azure_fixtures import (
AZURE_SUBSCRIPTION_ID,
set_mocked_azure_provider,
)
class Test_vm_linux_enforce_ssh_authentication:
def test_no_subscriptions(self):
vm_client = mock.MagicMock
vm_client.virtual_machines = {}
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_azure_provider(),
),
mock.patch(
"prowler.providers.azure.services.vm.vm_linux_enforce_ssh_authentication.vm_linux_enforce_ssh_authentication.vm_client",
new=vm_client,
),
):
from prowler.providers.azure.services.vm.vm_linux_enforce_ssh_authentication.vm_linux_enforce_ssh_authentication import (
vm_linux_enforce_ssh_authentication,
)
check = vm_linux_enforce_ssh_authentication()
result = check.execute()
assert len(result) == 0
def test_empty_subscription(self):
vm_client = mock.MagicMock
vm_client.virtual_machines = {AZURE_SUBSCRIPTION_ID: {}}
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_azure_provider(),
),
mock.patch(
"prowler.providers.azure.services.vm.vm_linux_enforce_ssh_authentication.vm_linux_enforce_ssh_authentication.vm_client",
new=vm_client,
),
):
from prowler.providers.azure.services.vm.vm_linux_enforce_ssh_authentication.vm_linux_enforce_ssh_authentication import (
vm_linux_enforce_ssh_authentication,
)
check = vm_linux_enforce_ssh_authentication()
result = check.execute()
assert len(result) == 0
def test_linux_vm_password_auth_disabled(self):
vm_id = str(uuid4())
vm_client = mock.MagicMock
vm_client.virtual_machines = {
AZURE_SUBSCRIPTION_ID: {
vm_id: VirtualMachine(
resource_id=vm_id,
resource_name="LinuxVM",
location="westeurope",
security_profile=None,
extensions=[],
storage_profile=None,
linux_configuration=LinuxConfiguration(
disable_password_authentication=True
),
)
}
}
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_azure_provider(),
),
mock.patch(
"prowler.providers.azure.services.vm.vm_linux_enforce_ssh_authentication.vm_linux_enforce_ssh_authentication.vm_client",
new=vm_client,
),
):
from prowler.providers.azure.services.vm.vm_linux_enforce_ssh_authentication.vm_linux_enforce_ssh_authentication import (
vm_linux_enforce_ssh_authentication,
)
check = vm_linux_enforce_ssh_authentication()
result = check.execute()
assert len(result) == 1
assert result[0].status == "PASS"
assert result[0].subscription == AZURE_SUBSCRIPTION_ID
assert result[0].resource_name == "LinuxVM"
assert result[0].resource_id == vm_id
assert "password authentication disabled" in result[0].status_extended
def test_linux_vm_password_auth_enabled(self):
vm_id = str(uuid4())
vm_client = mock.MagicMock
vm_client.virtual_machines = {
AZURE_SUBSCRIPTION_ID: {
vm_id: VirtualMachine(
resource_id=vm_id,
resource_name="LinuxVM",
location="westeurope",
security_profile=None,
extensions=[],
storage_profile=None,
linux_configuration=LinuxConfiguration(
disable_password_authentication=False
),
)
}
}
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_azure_provider(),
),
mock.patch(
"prowler.providers.azure.services.vm.vm_linux_enforce_ssh_authentication.vm_linux_enforce_ssh_authentication.vm_client",
new=vm_client,
),
):
from prowler.providers.azure.services.vm.vm_linux_enforce_ssh_authentication.vm_linux_enforce_ssh_authentication import (
vm_linux_enforce_ssh_authentication,
)
check = vm_linux_enforce_ssh_authentication()
result = check.execute()
assert len(result) == 1
assert result[0].status == "FAIL"
assert result[0].subscription == AZURE_SUBSCRIPTION_ID
assert result[0].resource_name == "LinuxVM"
assert result[0].resource_id == vm_id
assert "password authentication enabled" in result[0].status_extended
def test_non_linux_vm(self):
vm_id = str(uuid4())
vm_client = mock.MagicMock
vm_client.virtual_machines = {
AZURE_SUBSCRIPTION_ID: {
vm_id: VirtualMachine(
resource_id=vm_id,
resource_name="WindowsVM",
location="westeurope",
security_profile=None,
extensions=[],
storage_profile=None,
linux_configuration=None, # Not a Linux VM
)
}
}
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_azure_provider(),
),
mock.patch(
"prowler.providers.azure.services.vm.vm_linux_enforce_ssh_authentication.vm_linux_enforce_ssh_authentication.vm_client",
new=vm_client,
),
):
from prowler.providers.azure.services.vm.vm_linux_enforce_ssh_authentication.vm_linux_enforce_ssh_authentication import (
vm_linux_enforce_ssh_authentication,
)
check = vm_linux_enforce_ssh_authentication()
result = check.execute()
assert len(result) == 0
@@ -0,0 +1,216 @@
from unittest import mock
from uuid import uuid4
from prowler.providers.azure.services.vm.vm_service import VirtualMachineScaleSet
from tests.providers.azure.azure_fixtures import (
AZURE_SUBSCRIPTION_ID,
set_mocked_azure_provider,
)
class Test_vm_scaleset_associated_with_load_balancer:
def test_no_subscriptions(self):
vm_scale_sets = {}
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_azure_provider(),
),
mock.patch(
"prowler.providers.azure.services.vm.vm_client.vm_client.vm_scale_sets",
new=vm_scale_sets,
),
):
from prowler.providers.azure.services.vm.vm_scaleset_associated_with_load_balancer.vm_scaleset_associated_with_load_balancer import (
vm_scaleset_associated_with_load_balancer,
)
check = vm_scaleset_associated_with_load_balancer()
result = check.execute()
assert len(result) == 0
def test_empty_scale_sets(self):
vm_scale_sets = {AZURE_SUBSCRIPTION_ID: {}}
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_azure_provider(),
),
mock.patch(
"prowler.providers.azure.services.vm.vm_client.vm_client.vm_scale_sets",
new=vm_scale_sets,
),
):
from prowler.providers.azure.services.vm.vm_scaleset_associated_with_load_balancer.vm_scaleset_associated_with_load_balancer import (
vm_scaleset_associated_with_load_balancer,
)
check = vm_scaleset_associated_with_load_balancer()
result = check.execute()
assert len(result) == 0
def test_compliant_scale_set(self):
vmss_id = str(uuid4())
backend_pool_id = f"/subscriptions/{AZURE_SUBSCRIPTION_ID}/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lb/backendAddressPools/bepool"
vm_scale_sets = {
AZURE_SUBSCRIPTION_ID: {
vmss_id: VirtualMachineScaleSet(
resource_id=vmss_id,
resource_name="compliant-vmss",
location="eastus",
load_balancer_backend_pools=[backend_pool_id],
)
}
}
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_azure_provider(),
),
mock.patch(
"prowler.providers.azure.services.vm.vm_client.vm_client.vm_scale_sets",
new=vm_scale_sets,
),
):
from prowler.providers.azure.services.vm.vm_scaleset_associated_with_load_balancer.vm_scaleset_associated_with_load_balancer import (
vm_scaleset_associated_with_load_balancer,
)
check = vm_scaleset_associated_with_load_balancer()
result = check.execute()
assert len(result) == 1
assert result[0].status == "PASS"
assert result[0].resource_id == vmss_id
assert result[0].resource_name == "compliant-vmss"
assert result[0].location == "eastus"
expected_status_extended = (
f"Scale set 'compliant-vmss' in subscription '{AZURE_SUBSCRIPTION_ID}' "
f"is associated with load balancer backend pool(s): bepool."
)
assert result[0].status_extended == expected_status_extended
def test_noncompliant_scale_set(self):
vmss_id = str(uuid4())
vm_scale_sets = {
AZURE_SUBSCRIPTION_ID: {
vmss_id: VirtualMachineScaleSet(
resource_id=vmss_id,
resource_name="noncompliant-vmss",
location="westeurope",
load_balancer_backend_pools=[],
)
}
}
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_azure_provider(),
),
mock.patch(
"prowler.providers.azure.services.vm.vm_client.vm_client.vm_scale_sets",
new=vm_scale_sets,
),
):
from prowler.providers.azure.services.vm.vm_scaleset_associated_with_load_balancer.vm_scaleset_associated_with_load_balancer import (
vm_scaleset_associated_with_load_balancer,
)
check = vm_scaleset_associated_with_load_balancer()
result = check.execute()
assert len(result) == 1
assert result[0].status == "FAIL"
assert result[0].resource_id == vmss_id
assert result[0].resource_name == "noncompliant-vmss"
assert result[0].location == "westeurope"
expected_status_extended = (
f"Scale set 'noncompliant-vmss' in subscription '{AZURE_SUBSCRIPTION_ID}' "
f"is not associated with any load balancer backend pool."
)
assert result[0].status_extended == expected_status_extended
def test_multiple_scale_sets(self):
compliant_id = str(uuid4())
noncompliant_id = str(uuid4())
backend_pool_id = f"/subscriptions/{AZURE_SUBSCRIPTION_ID}/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lb/backendAddressPools/bepool"
vm_scale_sets = {
AZURE_SUBSCRIPTION_ID: {
compliant_id: VirtualMachineScaleSet(
resource_id=compliant_id,
resource_name="compliant-vmss",
location="eastus",
load_balancer_backend_pools=[backend_pool_id],
),
noncompliant_id: VirtualMachineScaleSet(
resource_id=noncompliant_id,
resource_name="noncompliant-vmss",
location="westeurope",
load_balancer_backend_pools=[],
),
}
}
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_azure_provider(),
),
mock.patch(
"prowler.providers.azure.services.vm.vm_client.vm_client.vm_scale_sets",
new=vm_scale_sets,
),
):
from prowler.providers.azure.services.vm.vm_scaleset_associated_with_load_balancer.vm_scaleset_associated_with_load_balancer import (
vm_scaleset_associated_with_load_balancer,
)
check = vm_scaleset_associated_with_load_balancer()
result = check.execute()
assert len(result) == 2
for r in result:
if r.resource_name == "compliant-vmss":
expected_status_extended = (
f"Scale set 'compliant-vmss' in subscription '{AZURE_SUBSCRIPTION_ID}' "
f"is associated with load balancer backend pool(s): bepool."
)
assert r.status == "PASS"
assert r.status_extended == expected_status_extended
elif r.resource_name == "noncompliant-vmss":
expected_status_extended = (
f"Scale set 'noncompliant-vmss' in subscription '{AZURE_SUBSCRIPTION_ID}' "
f"is not associated with any load balancer backend pool."
)
assert r.status == "FAIL"
assert r.status_extended == expected_status_extended
def test_missing_attributes(self):
# Simulate a scale set with missing optional attributes
vmss_id = str(uuid4())
vm_scale_sets = {
AZURE_SUBSCRIPTION_ID: {
vmss_id: VirtualMachineScaleSet(
resource_id=vmss_id,
resource_name="",
location="",
load_balancer_backend_pools=[],
)
}
}
with (
mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_azure_provider(),
),
mock.patch(
"prowler.providers.azure.services.vm.vm_client.vm_client.vm_scale_sets",
new=vm_scale_sets,
),
):
from prowler.providers.azure.services.vm.vm_scaleset_associated_with_load_balancer.vm_scaleset_associated_with_load_balancer import (
vm_scaleset_associated_with_load_balancer,
)
check = vm_scaleset_associated_with_load_balancer()
result = check.execute()
assert len(result) == 1
expected_status_extended = f"Scale set '' in subscription '{AZURE_SUBSCRIPTION_ID}' is not associated with any load balancer backend pool."
assert result[0].status == "FAIL"
assert result[0].status_extended == expected_status_extended
@@ -2,6 +2,7 @@ from unittest.mock import patch
from prowler.providers.azure.services.vm.vm_service import (
Disk,
LinuxConfiguration,
ManagedDiskParameters,
OperatingSystemType,
OSDisk,
@@ -40,6 +41,7 @@ def mock_vm_get_virtual_machines(_):
),
data_disks=[],
),
linux_configuration=None,
)
}
}
@@ -55,6 +57,7 @@ def mock_vm_get_virtual_machines_with_none(_):
security_profile=None,
extensions=[],
storage_profile=None,
linux_configuration=None,
),
"vm_id-2": VirtualMachine(
resource_id="/subscriptions/resource_id2",
@@ -66,6 +69,7 @@ def mock_vm_get_virtual_machines_with_none(_):
os_disk=None,
data_disks=[],
),
linux_configuration=None,
),
}
}
@@ -85,6 +89,24 @@ def mock_vm_get_disks(_):
}
def mock_vm_get_virtual_machines_with_linux(_):
return {
AZURE_SUBSCRIPTION_ID: {
"vm_id-linux": VirtualMachine(
resource_id="/subscriptions/resource_id_linux",
resource_name="LinuxVM",
location="location",
security_profile=None,
extensions=[],
storage_profile=None,
linux_configuration=LinuxConfiguration(
disable_password_authentication=True
),
)
}
}
@patch(
"prowler.providers.azure.services.vm.vm_service.VirtualMachines._get_virtual_machines",
new=mock_vm_get_virtual_machines,
@@ -186,3 +208,14 @@ class Test_VirtualMachines_NoneCases:
assert vm_2.storage_profile.os_disk is None
assert vm_2.storage_profile.data_disks == []
assert vm_2.resource_name == "VMWithPartialNone"
@patch(
"prowler.providers.azure.services.vm.vm_service.VirtualMachines._get_virtual_machines",
new=mock_vm_get_virtual_machines_with_linux,
)
def test_virtual_machine_with_linux_configuration():
virtual_machines = VirtualMachines(set_mocked_azure_provider())
vm = virtual_machines.virtual_machines[AZURE_SUBSCRIPTION_ID]["vm_id-linux"]
assert vm.linux_configuration is not None
assert vm.linux_configuration.disable_password_authentication is True
+7 -4
View File
@@ -5,7 +5,14 @@ All notable changes to the **Prowler UI** are documented in this file.
## [v1.9.0] (Prowler v5.9.0) UNRELEASED
### 🚀 Added
- SAML login integration [(#8203)](https://github.com/prowler-cloud/prowler/pull/8203)
- Introduced new `CustomLink` component for handling all navigation and link-related behavior [(#8195)] (https://github.com/prowler-cloud/prowler/pull/8195)
### 🔄 Changed
- Upgrade to Next.js 14.2.30 and lock TypeScript to 5.5.4 for ESLint compatibility [(#8189)](https://github.com/prowler-cloud/prowler/pull/8189)
### 🐞 Fixed
### Removed
@@ -13,9 +20,6 @@ All notable changes to the **Prowler UI** are documented in this file.
## [v1.8.1] (Prowler 5.8.1)
### 🔄 Changed
- Latest new failed findings now use `GET /findings/latest` [(#8219)](https://github.com/prowler-cloud/prowler/pull/8219)
### Removed
- Validation of the provider's secret type during updates [(#8197)](https://github.com/prowler-cloud/prowler/pull/8197)
@@ -40,7 +44,6 @@ All notable changes to the **Prowler UI** are documented in this file.
- Improve `Scan ID` filter by adding more context and enhancing the UI/UX [(#8046)](https://github.com/prowler-cloud/prowler/pull/8046)
- Lighthouse chat interface [(#7878)](https://github.com/prowler-cloud/prowler/pull/7878)
- Google Tag Manager integration [(#8058)](https://github.com/prowler-cloud/prowler/pull/8058)
<!-- - SAML login integration [(#8094)](https://github.com/prowler-cloud/prowler/pull/8094) -->
### 🔄 Changed

Some files were not shown because too many files have changed in this diff Show More