diff --git a/.env b/.env index 447fd3c1b4..3fe4707dcf 100644 --- a/.env +++ b/.env @@ -144,6 +144,7 @@ SOCIAL_GITHUB_OAUTH_CLIENT_SECRET="" # Single Sign-On (SSO) SAML_PUBLIC_CERT="" SAML_PRIVATE_KEY="" +SAML_SSO_CALLBACK_URL="${AUTH_URL}/api/auth/callback/saml" # Lighthouse tracing LANGSMITH_TRACING=false diff --git a/api/src/backend/api/migrations/0033_samltoken.py b/api/src/backend/api/migrations/0033_samltoken.py new file mode 100644 index 0000000000..f7abefd2e6 --- /dev/null +++ b/api/src/backend/api/migrations/0033_samltoken.py @@ -0,0 +1,44 @@ +# Generated by Django 5.1.10 on 2025-06-25 10:06 + +import uuid + +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("api", "0032_scan_disable_on_cascade_periodic_tasks"), + ] + + operations = [ + migrations.CreateModel( + name="SAMLToken", + fields=[ + ( + "id", + models.UUIDField( + default=uuid.uuid4, + editable=False, + primary_key=True, + serialize=False, + ), + ), + ("inserted_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("expires_at", models.DateTimeField(editable=False)), + ("token", models.JSONField(unique=True)), + ( + "user", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "db_table": "saml_tokens", + }, + ), + ] diff --git a/api/src/backend/api/models.py b/api/src/backend/api/models.py index 3f8e1264c9..a91948c43d 100644 --- a/api/src/backend/api/models.py +++ b/api/src/backend/api/models.py @@ -2,6 +2,7 @@ import json import logging import re import xml.etree.ElementTree as ET +from datetime import datetime, timedelta, timezone from uuid import UUID, uuid4 from allauth.socialaccount.models import SocialApp @@ -1370,6 +1371,26 @@ class IntegrationProviderRelationship(RowLevelSecurityProtectedModel): ] +class SAMLToken(models.Model): + id = 
models.UUIDField(primary_key=True, default=uuid4, editable=False) + inserted_at = models.DateTimeField(auto_now_add=True, editable=False) + updated_at = models.DateTimeField(auto_now=True, editable=False) + expires_at = models.DateTimeField(editable=False) + token = models.JSONField(unique=True) + user = models.ForeignKey(User, on_delete=models.CASCADE) + + class Meta: + db_table = "saml_tokens" + + def save(self, *args, **kwargs): + if not self.expires_at: + self.expires_at = datetime.now(timezone.utc) + timedelta(seconds=15) + super().save(*args, **kwargs) + + def is_expired(self) -> bool: + return datetime.now(timezone.utc) >= self.expires_at + + class SAMLDomainIndex(models.Model): """ Public index of SAML domains. No RLS. Used for fast lookup in SAML login flow. diff --git a/api/src/backend/api/tests/test_views.py b/api/src/backend/api/tests/test_views.py index 8771e6a1ee..c1d6b37ba0 100644 --- a/api/src/backend/api/tests/test_views.py +++ b/api/src/backend/api/tests/test_views.py @@ -6,6 +6,8 @@ import tempfile from datetime import datetime, timedelta, timezone from pathlib import Path from unittest.mock import ANY, MagicMock, Mock, patch +from urllib.parse import parse_qs, urlparse +from uuid import uuid4 import jwt import pytest @@ -33,6 +35,7 @@ from api.models import ( Role, RoleProviderGroupRelationship, SAMLConfiguration, + SAMLToken, Scan, StateChoices, Task, @@ -5561,6 +5564,76 @@ class TestIntegrationViewSet: assert response.status_code == status.HTTP_400_BAD_REQUEST +@pytest.mark.django_db +class TestSAMLTokenValidation: + def test_valid_token_returns_tokens(self, authenticated_client, create_test_user): + user = create_test_user + valid_token_data = { + "access": "mock_access_token", + "refresh": "mock_refresh_token", + } + saml_token = SAMLToken.objects.create( + token=valid_token_data, + user=user, + expires_at=datetime.now(timezone.utc) + timedelta(seconds=10), + ) + + url = reverse("token-saml") + response = 
authenticated_client.post(f"{url}?id={saml_token.id}") + + assert response.status_code == status.HTTP_200_OK + assert response.json() == {"data": valid_token_data} + assert not SAMLToken.objects.filter(id=saml_token.id).exists() + + def test_invalid_token_id_returns_404(self, authenticated_client): + url = reverse("token-saml") + response = authenticated_client.post(f"{url}?id={str(uuid4())}") + + assert response.status_code == status.HTTP_404_NOT_FOUND + assert response.json()["errors"]["detail"] == "Invalid token ID." + + def test_expired_token_returns_400(self, authenticated_client, create_test_user): + user = create_test_user + expired_token_data = { + "access": "expired_access_token", + "refresh": "expired_refresh_token", + } + saml_token = SAMLToken.objects.create( + token=expired_token_data, + user=user, + expires_at=datetime.now(timezone.utc) - timedelta(seconds=1), + ) + + url = reverse("token-saml") + response = authenticated_client.post(f"{url}?id={saml_token.id}") + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.json()["errors"]["detail"] == "Token expired." 
+ assert SAMLToken.objects.filter(id=saml_token.id).exists() + + def test_token_can_be_used_only_once(self, authenticated_client, create_test_user): + user = create_test_user + token_data = { + "access": "single_use_token", + "refresh": "single_use_refresh", + } + saml_token = SAMLToken.objects.create( + token=token_data, + user=user, + expires_at=datetime.now(timezone.utc) + timedelta(seconds=10), + ) + + url = reverse("token-saml") + + # First use: should succeed + response1 = authenticated_client.post(f"{url}?id={saml_token.id}") + assert response1.status_code == status.HTTP_200_OK + + # Second use: should fail (already deleted) + response2 = authenticated_client.post(f"{url}?id={saml_token.id}") + assert response2.status_code == status.HTTP_404_NOT_FOUND + + @pytest.mark.django_db class TestSAMLInitiateAPIView: def test_valid_email_domain_and_certificates( @@ -5575,11 +5648,11 @@ class TestSAMLInitiateAPIView: response = authenticated_client.post(url, data=payload, format="json") assert response.status_code == status.HTTP_302_FOUND - assert f"email={saml_setup['email']}" in response.url assert ( reverse("saml_login", kwargs={"organization_slug": saml_setup["domain"]}) in response.url ) + assert "SAMLRequest" not in response.url def test_invalid_email_domain(self, authenticated_client): url = reverse("api_saml_initiate") @@ -5752,9 +5825,10 @@ class TestTenantFinishACSView: assert isinstance(response, JsonResponse) or response.status_code in [200, 302] - def test_dispatch_sets_user_profile_and_assigns_role( - self, create_test_user, tenants_fixture, saml_setup + def test_dispatch_sets_user_profile_and_assigns_role_and_creates_token( + self, create_test_user, tenants_fixture, saml_setup, settings, monkeypatch ): + monkeypatch.setenv("SAML_SSO_CALLBACK_URL", "http://localhost/sso-complete") user = create_test_user original_email = user.email original_name = user.name @@ -5784,26 +5858,29 @@ class TestTenantFinishACSView: 
patch("allauth.socialaccount.models.SocialApp.objects.get"), patch( "allauth.socialaccount.models.SocialAccount.objects.get" - ) as mock_socialaccount_get, - patch("api.v1.serializers.TokenSocialLoginSerializer") as mock_serializer, + ) as mock_sa_get, ): mock_get_app_or_404.return_value = MagicMock( provider="saml", client_id="testtenant", name="Test App", settings={} ) - - mock_socialaccount_get.return_value = social_account - - mock_instance = mock_serializer.return_value - mock_instance.is_valid.return_value = True - mock_instance.validated_data = { - "token": "mocktoken", - "refresh_token": "mockrefresh", - } + mock_sa_get.return_value = social_account view = TenantFinishACSView.as_view() response = view(request, organization_slug="testtenant") - assert response.status_code == 200 + assert response.status_code == 302 + + expected_callback_host = "localhost" + parsed_url = urlparse(response.url) + assert parsed_url.netloc == expected_callback_host + query_params = parse_qs(parsed_url.query) + assert "id" in query_params + + token_id = query_params["id"][0] + token_obj = SAMLToken.objects.get(id=token_id) + assert token_obj.user == user + assert not token_obj.is_expired() + user.refresh_from_db() assert user.name == "John Doe" assert user.company_name == "TestOrg" @@ -5816,6 +5893,7 @@ class TestTenantFinishACSView: .filter(user=user, tenant_id=tenants_fixture[0].id) .exists() ) + user.email = original_email user.name = original_name user.company_name = original_company diff --git a/api/src/backend/api/v1/urls.py b/api/src/backend/api/v1/urls.py index c96a6f56d0..bd5b8b60d1 100644 --- a/api/src/backend/api/v1/urls.py +++ b/api/src/backend/api/v1/urls.py @@ -4,6 +4,7 @@ from rest_framework_nested import routers from api.v1.views import ( ComplianceOverviewViewSet, + CustomSAMLLoginView, CustomTokenObtainView, CustomTokenRefreshView, CustomTokenSwitchTenantView, @@ -25,6 +26,7 @@ from api.v1.views import ( RoleViewSet, SAMLConfigurationViewSet, 
SAMLInitiateAPIView, + SAMLTokenValidateView, ScanViewSet, ScheduleViewSet, SchemaView, @@ -129,10 +131,16 @@ urlpatterns = [ # Allauth SAML endpoints for tenants path("accounts/", include("allauth.urls")), path( - "api/v1/accounts/saml/<organization_slug>/acs/finish/", + "accounts/saml/<organization_slug>/login/", + CustomSAMLLoginView.as_view(), + name="saml_login", + ), + path( + "accounts/saml/<organization_slug>/acs/finish/", TenantFinishACSView.as_view(), name="saml_finish_acs", ), + path("tokens/saml", SAMLTokenValidateView.as_view(), name="token-saml"), path("tokens/google", GoogleSocialLoginView.as_view(), name="token-google"), path("tokens/github", GithubSocialLoginView.as_view(), name="token-github"), path("", include(router.urls)), diff --git a/api/src/backend/api/v1/views.py b/api/src/backend/api/v1/views.py index 7d6cf1bfa0..799b1b5f2b 100644 --- a/api/src/backend/api/v1/views.py +++ b/api/src/backend/api/v1/views.py @@ -6,7 +6,7 @@ import sentry_sdk from allauth.socialaccount.models import SocialAccount, SocialApp from allauth.socialaccount.providers.github.views import GitHubOAuth2Adapter from allauth.socialaccount.providers.google.views import GoogleOAuth2Adapter -from allauth.socialaccount.providers.saml.views import FinishACSView +from allauth.socialaccount.providers.saml.views import FinishACSView, LoginView from botocore.exceptions import ClientError, NoCredentialsError, ParamValidationError from celery.result import AsyncResult from config.env import env @@ -21,7 +21,7 @@ from django.contrib.postgres.search import SearchQuery from django.db import transaction from django.db.models import Count, Exists, F, OuterRef, Prefetch, Q, Sum from django.db.models.functions import Coalesce -from django.http import HttpResponse, JsonResponse +from django.http import HttpResponse from django.shortcuts import redirect from django.urls import reverse from django.utils.dateparse import parse_date @@ -107,6 +107,7 @@ from api.models import ( RoleProviderGroupRelationship, SAMLConfiguration, SAMLDomainIndex, + 
SAMLToken, Scan, ScanSummary, SeverityChoices, @@ -401,17 +402,84 @@ class GithubSocialLoginView(SocialLoginView): return original_response +@extend_schema(exclude=True) +class SAMLTokenValidateView(GenericAPIView): + resource_name = "tokens" + http_method_names = ["post"] + + def post(self, request): + """ + Exchange a one-time SAML token ID (``?id=``) for the stored token data. + + Responds 404 for a missing, malformed, unknown or already redeemed ID + and 400 for an expired one. + """ + # Local import so this check does not rely on the module-level imports + from uuid import UUID + + not_found = Response({"detail": "Invalid token ID."}, status=404) + tokens = SAMLToken.objects.using(MainRouter.admin_db) + try: + # UUID() rejects missing/malformed IDs before they reach the ORM lookup + token_id = UUID(request.query_params.get("id", "")) + saml_token = tokens.get(id=token_id) + except (ValueError, SAMLToken.DoesNotExist): + return not_found + + if saml_token.is_expired(): + return Response({"detail": "Token expired."}, status=400) + + # Tokens are single use. The delete itself is the claim: if two requests + # race, only the one whose DELETE removed the row receives the tokens. + deleted, _ = tokens.filter(id=token_id).delete() + if not deleted: + return not_found + + return Response(saml_token.token, status=200) + + +@extend_schema(exclude=True) +class CustomSAMLLoginView(LoginView): + def dispatch(self, request, *args, **kwargs): + """ + Convert GET requests to POST to bypass allauth's confirmation screen. + + Why this is necessary: + - django-allauth requires POST for social logins to prevent open redirect attacks + - SAML login links typically use GET requests (e.g., ) + - This conversion allows seamless login without user-facing confirmation + + Security considerations: + 1. Preserves CSRF protection: Original POST handling remains intact + 2. Avoids global SOCIALACCOUNT_LOGIN_ON_GET=True which would: + - Enable GET logins for ALL providers (security risk) + - Potentially expose open redirect vulnerabilities + 3. SAML payloads remain signed/encrypted regardless of HTTP method + 4. No sensitive parameters are exposed in URLs (copied to POST body) + + This approach maintains security while providing better UX. 
+ """ + if request.method == "GET": + # Convert GET to POST while preserving parameters + request.method = "POST" + # Safe because SAML validates signatures + request.POST = request.GET.copy() + return super().dispatch(request, *args, **kwargs) + + @extend_schema(exclude=True) class SAMLInitiateAPIView(GenericAPIView): serializer_class = SamlInitiateSerializer permission_classes = [] def post(self, request, *args, **kwargs): + # Validate the input payload and extract the domain serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) email = serializer.validated_data["email_domain"] domain = email.split("@", 1)[-1].lower() + # Retrieve the SAML configuration for the given email domain try: check = SAMLDomainIndex.objects.get(email_domain=domain) with rls_transaction(str(check.tenant_id)): @@ -431,10 +499,12 @@ class SAMLInitiateAPIView(GenericAPIView): status=status.HTTP_403_FORBIDDEN, ) - saml_login_url = reverse( + relative = reverse( "saml_login", kwargs={"organization_slug": config.email_domain} ) - return redirect(f"{saml_login_url}?email={email}") + login_url = request.build_absolute_uri(relative) + + return redirect(login_url) @extend_schema_view( @@ -560,12 +630,15 @@ class TenantFinishACSView(FinishACSView): serializer = TokenSocialLoginSerializer(data={"email": user.email}) serializer.is_valid(raise_exception=True) - return JsonResponse( - { - "type": "saml-social-tokens", - "attributes": serializer.validated_data, - } + + token_data = serializer.validated_data + saml_token = SAMLToken.objects.using(MainRouter.admin_db).create( + token=token_data, user=user ) + callback_url = env.str("SAML_SSO_CALLBACK_URL") + redirect_url = f"{callback_url}?id={saml_token.id}" + + return redirect(redirect_url) @extend_schema_view( diff --git a/api/src/backend/config/django/base.py b/api/src/backend/config/django/base.py index 47c6320ce1..c7aea17c72 100644 --- a/api/src/backend/config/django/base.py +++ 
b/api/src/backend/config/django/base.py @@ -248,3 +248,6 @@ X_FRAME_OPTIONS = "DENY" SECURE_REFERRER_POLICY = "strict-origin-when-cross-origin" DJANGO_DELETION_BATCH_SIZE = env.int("DJANGO_DELETION_BATCH_SIZE", 5000) + +# Prefer the X-Forwarded-Host header over Host when a proxy sets it, so request.build_absolute_uri() yields the public host +USE_X_FORWARDED_HOST = True diff --git a/docs/tutorials/prowler-app-sso.md b/docs/tutorials/prowler-app-sso.md index a82971c8a9..e94290fd9a 100644 --- a/docs/tutorials/prowler-app-sso.md +++ b/docs/tutorials/prowler-app-sso.md @@ -24,7 +24,7 @@ DJANGO_ALLOWED_HOSTS=localhost,127.0.0.1,prowler-api,mycompany.prowler To enable SAML support, you must provide a public certificate and private key to allow Prowler to sign SAML requests and validate responses. -### Why is this necessary? +### Why is this necessary? SAML relies on digital signatures to verify trust between the Identity Provider (IdP) and the Service Provider (SP). Prowler acts as the SP and must use a certificate to sign outbound authentication requests. @@ -121,7 +121,7 @@ Start ngrok on port 8080: ngrok http 8080 ``` -Then, copy the generated ngrok URL and include it in the ALLOWED_HOSTS setting. If you’re using the development environment, it usually defaults to *, but in some cases this may not work properly, like in my tests (investigate): +Then, copy the generated ngrok URL and include it in the ALLOWED_HOSTS setting. If you're using the development environment, it usually defaults to `*`, but in some cases the wildcard may not work properly; if requests are rejected, add the ngrok host explicitly to `DJANGO_ALLOWED_HOSTS`, which feeds this setting: ``` ALLOWED_HOSTS = env.list("DJANGO_ALLOWED_HOSTS", default=["*"]) @@ -129,7 +129,7 @@ ALLOWED_HOSTS = env.list("DJANGO_ALLOWED_HOSTS", default=["*"]) ## 4. Configure the Identity Provider (IdP) -Start your environment and configure your IdP. You will need to download the IdP’s metadata XML file. +Start your environment and configure your IdP. You will need to download the IdP's metadata XML file. 
Your Assertion Consumer Service (ACS) URL must follow this format: @@ -147,7 +147,7 @@ The following fields are expected from the IdP: - userType (this is the name of the role the user should be assigned) -- companyName (this is filled automatically if the IdP includes an “organization” field) +- companyName (this is filled automatically if the IdP includes an "organization" field) These values are dynamic. If the values change in the IdP, they will be updated on the next login. @@ -171,7 +171,37 @@ curl --location 'http://localhost:8080/api/v1/saml-config' \ }' ``` -## 7. Start SAML Login Flow +## 7. SAML SSO Callback Configuration + +### Environment Variable Configuration + +The SAML authentication flow requires proper callback URL configuration to handle post-authentication redirects. Configure the following environment variables: + +#### `SAML_SSO_CALLBACK_URL` + +Specifies the callback endpoint that will be invoked upon successful SAML authentication completion. This URL directs users back to the web application interface. + +```env +SAML_SSO_CALLBACK_URL="${AUTH_URL}/api/auth/callback/saml" +``` + +#### `AUTH_URL` + +Defines the base URL of the web user interface application that serves as the authentication callback destination. + +```env +AUTH_URL="" +``` + +### Configuration Notes + +- The `SAML_SSO_CALLBACK_URL` dynamically references the `AUTH_URL` variable to construct the complete callback endpoint +- Ensure the `AUTH_URL` points to the correct web UI deployment (development, staging, or production) +- The callback endpoint `/api/auth/callback/saml` must be accessible and properly configured to handle SAML authentication responses +- Both environment variables are required for proper SAML SSO functionality +- Verify that the `API_BASE_URL` environment variable is properly configured to reference the correct API server base URL corresponding to your target deployment environment. 
This ensures proper routing of SAML callback requests to the appropriate backend services. + +## 8. Start SAML Login Flow Once everything is configured, start the SAML login process by visiting the following URL: @@ -181,6 +211,6 @@ https:///api/v1/accounts/saml//login/?email=