From bf3b5c2ba713e533014927141b64948c82c8f32e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Pe=C3=B1a?= Date: Fri, 19 Jun 2026 13:38:51 +0200 Subject: [PATCH] Merge commit from fork * fix(saml): cross-tenant account takeover via SAML domain claiming * chore(changelog): add PR # * fix(api): bind SAML tokens to validated domain - Reject SAML assertions with mismatched email domains - Issue SAML tokens from the validated ACS tenant - Add regression coverage for cross-tenant SAML token issuance * fix(api): resolve SAML tenant inside RLS context - Load the SAML tenant relation before leaving the RLS transaction - Avoid lazy tenant lookups during the SAML ACS finish flow --------- Co-authored-by: Pepe Fagoaga --- api/CHANGELOG.md | 8 + api/src/backend/api/adapters.py | 44 ++++- api/src/backend/api/tests/test_adapters.py | 162 +++++++++++++++++-- api/src/backend/api/tests/test_views.py | 178 +++++++++++++++++---- api/src/backend/api/v1/views.py | 21 ++- 5 files changed, 362 insertions(+), 51 deletions(-) diff --git a/api/CHANGELOG.md b/api/CHANGELOG.md index a684591651..fe93dcf1b2 100644 --- a/api/CHANGELOG.md +++ b/api/CHANGELOG.md @@ -28,6 +28,14 @@ All notable changes to the **Prowler API** are documented in this file. --- +## [1.31.3] (Prowler v5.30.3) + +### 🔐 Security + +- SAML logins now link to an existing account only when the asserted email domain matches the ACS endpoint and the user is already a member of that domain's tenant, fixing a cross-tenant account takeover [(GHSA-h8m9-jgf8-vwvp)](https://github.com/prowler-cloud/prowler/security/advisories/GHSA-h8m9-jgf8-vwvp) [(#XXXXX)](https://github.com/prowler-cloud/prowler/pull/XXXXX) + +--- + ## [1.31.2] (Prowler v5.30.2) ### 🔄 Changed diff --git a/api/src/backend/api/adapters.py b/api/src/backend/api/adapters.py index e09dc972b4..1d0d2ace00 100644 --- a/api/src/backend/api/adapters.py +++ b/api/src/backend/api/adapters.py @@ -3,7 +3,14 @@ from django.db import transaction from api.db_router import MainRouter from api.db_utils import rls_transaction -from api.models import Membership, Role, Tenant, User, UserRoleRelationship +from api.models import ( + Membership, + Role, + SAMLConfiguration, + Tenant, + User, + UserRoleRelationship, +) class ProwlerSocialAccountAdapter(DefaultSocialAccountAdapter): @@ -18,7 +25,42 @@ class ProwlerSocialAccountAdapter(DefaultSocialAccountAdapter): # Link existing accounts with the same email address email = sociallogin.account.extra_data.get("email") if sociallogin.provider.id == "saml": + # For SAML, the asserted NameID email cannot be trusted on its own: + # any tenant can claim any email domain in its SAML configuration. To + # prevent cross-tenant account takeover (GHSA-h8m9-jgf8-vwvp), only link + # the incoming SAML session to an existing account when (1) the email + # domain matches the tenant whose ACS endpoint is being used and (2) the + # existing user is already a member of that tenant. email = sociallogin.user.email + if not email: + return + + domain = email.rsplit("@", 1)[-1].lower() + resolver_match = getattr(request, "resolver_match", None) + organization_slug = ( + (resolver_match.kwargs or {}).get("organization_slug", "") + if resolver_match + else "" + ).lower() + # The ACS endpoint is scoped per email domain; reject mismatches so an + # attacker cannot replay an assertion through another tenant's endpoint. + if organization_slug != domain: + return + + try: + saml_config = SAMLConfiguration.objects.using(MainRouter.admin_db).get( + email_domain=domain + ) + except SAMLConfiguration.DoesNotExist: + return + + existing_user = self.get_user_by_email(email) + if existing_user and existing_user.is_member_of_tenant( + str(saml_config.tenant_id) + ): + sociallogin.connect(request, existing_user) + return + if email: existing_user = self.get_user_by_email(email) if existing_user: diff --git a/api/src/backend/api/tests/test_adapters.py b/api/src/backend/api/tests/test_adapters.py index 22b44b3506..182e4ddc25 100644 --- a/api/src/backend/api/tests/test_adapters.py +++ b/api/src/backend/api/tests/test_adapters.py @@ -1,3 +1,4 @@ +from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest @@ -5,9 +6,48 @@ from allauth.socialaccount.models import SocialLogin from django.contrib.auth import get_user_model from api.adapters import ProwlerSocialAccountAdapter +from api.db_router import MainRouter +from api.models import SAMLConfiguration User = get_user_model() +# Minimal, well-formed IdP metadata accepted by SAMLConfiguration._parse_metadata. +VALID_METADATA = """ + + + + + + FAKECERTDATA + + + + + + +""" + + +def _saml_request(rf, organization_slug): + """Build an ACS request whose resolver_match carries the organization slug, + mirroring how Django populates it after routing the SAML ACS URL.""" + request = rf.post(f"/api/v1/accounts/saml/{organization_slug}/acs/finish/") + request.resolver_match = SimpleNamespace( + kwargs={"organization_slug": organization_slug} + ) + return request + + +def _saml_sociallogin(user): + sociallogin = MagicMock(spec=SocialLogin) + sociallogin.account = MagicMock() + sociallogin.provider = MagicMock() + sociallogin.provider.id = "saml" + sociallogin.account.extra_data = {} + sociallogin.user = user + sociallogin.connect = MagicMock() + return sociallogin + @pytest.mark.django_db class TestProwlerSocialAccountAdapter: @@ -20,26 +60,99 @@ class TestProwlerSocialAccountAdapter: adapter = ProwlerSocialAccountAdapter() assert adapter.get_user_by_email("notfound@example.com") is None - def test_pre_social_login_links_existing_user(self, create_test_user, rf): + def test_pre_social_login_links_member_of_saml_tenant( + self, create_test_user, tenants_fixture, rf + ): + """A SAML login links to an existing account only when that user is + already a member of the tenant that owns the asserted email domain.""" adapter = ProwlerSocialAccountAdapter() + # create_test_user (dev@prowler.com) is a member of tenant1. + domain = create_test_user.email.rsplit("@", 1)[-1] + SAMLConfiguration.objects.using(MainRouter.admin_db).create( + email_domain=domain, + metadata_xml=VALID_METADATA, + tenant=tenants_fixture[0], + ) - sociallogin = MagicMock(spec=SocialLogin) - sociallogin.account = MagicMock() - sociallogin.provider = MagicMock() - sociallogin.provider.id = "saml" - sociallogin.account.extra_data = {} - sociallogin.user = create_test_user - sociallogin.connect = MagicMock() - - adapter.pre_social_login(rf.get("/"), sociallogin) + sociallogin = _saml_sociallogin(create_test_user) + adapter.pre_social_login(_saml_request(rf, domain), sociallogin) call_args = sociallogin.connect.call_args assert call_args is not None - - called_request, called_user = call_args[0] - assert called_request.path == "/" + _, called_user = call_args[0] assert called_user.email == create_test_user.email + def test_pre_social_login_blocks_cross_tenant_takeover( + self, create_test_user, tenants_fixture, rf + ): + """GHSA-h8m9-jgf8-vwvp: an attacker tenant that claims the victim's + email domain must NOT be able to link to the victim's account, because + the victim is not a member of the attacker's tenant.""" + adapter = ProwlerSocialAccountAdapter() + domain = create_test_user.email.rsplit("@", 1)[-1] + # tenant3 is the attacker tenant; create_test_user is NOT a member of it. + attacker_tenant = tenants_fixture[2] + assert not create_test_user.is_member_of_tenant(str(attacker_tenant.id)) + SAMLConfiguration.objects.using(MainRouter.admin_db).create( + email_domain=domain, + metadata_xml=VALID_METADATA, + tenant=attacker_tenant, + ) + + sociallogin = _saml_sociallogin(create_test_user) + adapter.pre_social_login(_saml_request(rf, domain), sociallogin) + + sociallogin.connect.assert_not_called() + + def test_pre_social_login_blocks_domain_slug_mismatch( + self, create_test_user, tenants_fixture, rf + ): + """The asserted email domain must match the ACS endpoint's slug, so an + assertion cannot be replayed through a different tenant's endpoint.""" + adapter = ProwlerSocialAccountAdapter() + domain = create_test_user.email.rsplit("@", 1)[-1] + SAMLConfiguration.objects.using(MainRouter.admin_db).create( + email_domain=domain, + metadata_xml=VALID_METADATA, + tenant=tenants_fixture[0], + ) + + sociallogin = _saml_sociallogin(create_test_user) + # Slug points at a different domain than the asserted email. + adapter.pre_social_login(_saml_request(rf, "attacker.com"), sociallogin) + + sociallogin.connect.assert_not_called() + + def test_pre_social_login_blocks_when_no_saml_config( + self, create_test_user, tenants_fixture, rf + ): + """No SAML configuration for the domain means nothing to link against.""" + adapter = ProwlerSocialAccountAdapter() + domain = create_test_user.email.rsplit("@", 1)[-1] + + sociallogin = _saml_sociallogin(create_test_user) + adapter.pre_social_login(_saml_request(rf, domain), sociallogin) + + sociallogin.connect.assert_not_called() + + def test_pre_social_login_blocks_without_resolver_match( + self, create_test_user, tenants_fixture, rf + ): + """Fail closed: if the request has no resolver_match we cannot bind the + assertion to a tenant, so no linking happens.""" + adapter = ProwlerSocialAccountAdapter() + domain = create_test_user.email.rsplit("@", 1)[-1] + SAMLConfiguration.objects.using(MainRouter.admin_db).create( + email_domain=domain, + metadata_xml=VALID_METADATA, + tenant=tenants_fixture[0], + ) + + sociallogin = _saml_sociallogin(create_test_user) + adapter.pre_social_login(rf.post("/"), sociallogin) + + sociallogin.connect.assert_not_called() + def test_pre_social_login_no_link_if_email_missing(self, rf): adapter = ProwlerSocialAccountAdapter() @@ -47,14 +160,35 @@ class TestProwlerSocialAccountAdapter: sociallogin.account = MagicMock() sociallogin.provider = MagicMock() sociallogin.user = MagicMock() + sociallogin.user.email = "" sociallogin.provider.id = "saml" sociallogin.account.extra_data = {} sociallogin.connect = MagicMock() - adapter.pre_social_login(rf.get("/"), sociallogin) + adapter.pre_social_login(_saml_request(rf, "prowler.com"), sociallogin) sociallogin.connect.assert_not_called() + def test_pre_social_login_non_saml_links_by_email(self, create_test_user, rf): + """Non-SAML providers (e.g. Google/GitHub) still link to an existing + local account by email; the tenant binding only applies to SAML.""" + adapter = ProwlerSocialAccountAdapter() + + sociallogin = MagicMock(spec=SocialLogin) + sociallogin.account = MagicMock() + sociallogin.provider = MagicMock() + sociallogin.provider.id = "google" + sociallogin.account.extra_data = {"email": create_test_user.email} + sociallogin.user = create_test_user + sociallogin.connect = MagicMock() + + adapter.pre_social_login(rf.get("/"), sociallogin) + + call_args = sociallogin.connect.call_args + assert call_args is not None + _, called_user = call_args[0] + assert called_user.email == create_test_user.email + def test_save_user_saml_sets_session_flag(self, rf): adapter = ProwlerSocialAccountAdapter() request = rf.get("/") diff --git a/api/src/backend/api/tests/test_views.py b/api/src/backend/api/tests/test_views.py index 35098c1907..49234777da 100644 --- a/api/src/backend/api/tests/test_views.py +++ b/api/src/backend/api/tests/test_views.py @@ -13586,7 +13586,9 @@ class TestTenantFinishACSView: ) request = RequestFactory().get( - reverse("saml_finish_acs", kwargs={"organization_slug": "testtenant"}) + reverse( + "saml_finish_acs", kwargs={"organization_slug": saml_setup["domain"]} + ) ) request.user = user request.session = {} @@ -13606,18 +13608,23 @@ class TestTenantFinishACSView: patch("api.models.User.objects.get") as mock_user_get, ): mock_get_app_or_404.return_value = MagicMock( - provider="saml", client_id="testtenant", name="Test App", settings={} + provider="saml", + client_id=saml_setup["domain"], + name="Test App", + settings={}, ) mock_sa_get.return_value = social_account mock_socialapp_get.return_value = MagicMock(provider_id="saml") mock_saml_domain_get.return_value = SimpleNamespace( tenant_id=tenants_fixture[0].id ) - mock_saml_config_get.return_value = MagicMock() + mock_saml_config_get.return_value = SimpleNamespace( + email_domain=saml_setup["domain"], tenant=tenants_fixture[0] + ) mock_user_get.return_value = user view = TenantFinishACSView.as_view() - response = view(request, organization_slug="testtenant") + response = view(request, organization_slug=saml_setup["domain"]) assert response.status_code == 302 @@ -13665,6 +13672,79 @@ class TestTenantFinishACSView: user.company_name = original_company user.save() + def test_dispatch_rejects_assertion_email_domain_that_differs_from_slug( + self, tenants_fixture, saml_setup, monkeypatch + ): + monkeypatch.setenv("AUTH_URL", "http://localhost") + monkeypatch.setenv("SAML_SSO_CALLBACK_URL", "http://localhost/sso-complete") + victim_tenant = tenants_fixture[0] + attacker_tenant = tenants_fixture[1] + attacker_domain = "attacker.com" + + SAMLConfiguration.objects.using(MainRouter.admin_db).create( + email_domain=attacker_domain, + metadata_xml=""" + + + + + + TEST + + + + + + + """, + tenant=attacker_tenant, + ) + user = User.objects.using(MainRouter.admin_db).create( + email=f"intruder@{saml_setup['domain']}", name="Intruder" + ) + social_account = SocialAccount( + user=user, + provider="ATTACKER", + extra_data={ + "firstName": ["Mallory"], + "lastName": ["Example"], + }, + ) + request = RequestFactory().get( + reverse("saml_finish_acs", kwargs={"organization_slug": attacker_domain}) + ) + request.user = user + request.session = {} + + with ( + patch( + "allauth.socialaccount.providers.saml.views.get_app_or_404" + ) as mock_get_app_or_404, + patch( + "allauth.socialaccount.models.SocialAccount.objects.get" + ) as mock_sa_get, + ): + mock_get_app_or_404.return_value = MagicMock( + provider="saml", + provider_id="ATTACKER", + client_id=attacker_domain, + name="Attacker App", + settings={}, + ) + mock_sa_get.return_value = social_account + + view = TenantFinishACSView.as_view() + response = view(request, organization_slug=attacker_domain) + + assert response.status_code == 302 + assert "sso_saml_failed=true" in response.url + assert not ( + Membership.objects.using(MainRouter.admin_db) + .filter(user=user, tenant=victim_tenant) + .exists() + ) + assert not SAMLToken.objects.using(MainRouter.admin_db).filter(user=user).exists() + def test_rollback_saml_user_when_error_occurs(self, users_fixture, monkeypatch): """Test that a user is properly deleted when created during SAML flow and an error occurs""" monkeypatch.setenv("AUTH_URL", "http://localhost") @@ -13734,7 +13814,9 @@ class TestTenantFinishACSView: ) request = RequestFactory().get( - reverse("saml_finish_acs", kwargs={"organization_slug": "testtenant"}) + reverse( + "saml_finish_acs", kwargs={"organization_slug": saml_setup["domain"]} + ) ) request.user = user request.session = {} @@ -13754,16 +13836,21 @@ class TestTenantFinishACSView: patch("api.models.User.objects.get") as mock_user_get, ): mock_get_app_or_404.return_value = MagicMock( - provider="saml", client_id="testtenant", name="Test App", settings={} + provider="saml", + client_id=saml_setup["domain"], + name="Test App", + settings={}, ) mock_sa_get.return_value = social_account mock_socialapp_get.return_value = MagicMock(provider_id="saml") mock_saml_domain_get.return_value = SimpleNamespace(tenant_id=tenant.id) - mock_saml_config_get.return_value = MagicMock() + mock_saml_config_get.return_value = SimpleNamespace( + email_domain=saml_setup["domain"], tenant=tenant + ) mock_user_get.return_value = user view = TenantFinishACSView.as_view() - response = view(request, organization_slug="testtenant") + response = view(request, organization_slug=saml_setup["domain"]) assert response.status_code == 302 @@ -13802,7 +13889,9 @@ class TestTenantFinishACSView: ) request = RequestFactory().get( - reverse("saml_finish_acs", kwargs={"organization_slug": "testtenant"}) + reverse( + "saml_finish_acs", kwargs={"organization_slug": saml_setup["domain"]} + ) ) request.user = user request.session = {} @@ -13822,16 +13911,21 @@ class TestTenantFinishACSView: patch("api.models.User.objects.get") as mock_user_get, ): mock_get_app_or_404.return_value = MagicMock( - provider="saml", client_id="testtenant", name="Test App", settings={} + provider="saml", + client_id=saml_setup["domain"], + name="Test App", + settings={}, ) mock_sa_get.return_value = social_account mock_socialapp_get.return_value = MagicMock(provider_id="saml") mock_saml_domain_get.return_value = SimpleNamespace(tenant_id=tenant.id) - mock_saml_config_get.return_value = MagicMock() + mock_saml_config_get.return_value = SimpleNamespace( + email_domain=saml_setup["domain"], tenant=tenant + ) mock_user_get.return_value = user view = TenantFinishACSView.as_view() - response = view(request, organization_slug="testtenant") + response = view(request, organization_slug=saml_setup["domain"]) assert response.status_code == 302 @@ -13881,7 +13975,9 @@ class TestTenantFinishACSView: ) request = RequestFactory().get( - reverse("saml_finish_acs", kwargs={"organization_slug": "testtenant"}) + reverse( + "saml_finish_acs", kwargs={"organization_slug": saml_setup["domain"]} + ) ) request.user = user request.session = {} @@ -13901,16 +13997,21 @@ class TestTenantFinishACSView: patch("api.models.User.objects.get") as mock_user_get, ): mock_get_app_or_404.return_value = MagicMock( - provider="saml", client_id="testtenant", name="Test App", settings={} + provider="saml", + client_id=saml_setup["domain"], + name="Test App", + settings={}, ) mock_sa_get.return_value = social_account mock_socialapp_get.return_value = MagicMock(provider_id="saml") mock_saml_domain_get.return_value = SimpleNamespace(tenant_id=tenant.id) - mock_saml_config_get.return_value = MagicMock() + mock_saml_config_get.return_value = SimpleNamespace( + email_domain=saml_setup["domain"], tenant=tenant + ) mock_user_get.return_value = user view = TenantFinishACSView.as_view() - response = view(request, organization_slug="testtenant") + response = view(request, organization_slug=saml_setup["domain"]) assert response.status_code == 302 @@ -13959,7 +14060,9 @@ class TestTenantFinishACSView: ) request = RequestFactory().get( - reverse("saml_finish_acs", kwargs={"organization_slug": "testtenant"}) + reverse( + "saml_finish_acs", kwargs={"organization_slug": saml_setup["domain"]} + ) ) request.user = user request.session = {} @@ -13979,16 +14082,21 @@ class TestTenantFinishACSView: patch("api.models.User.objects.get") as mock_user_get, ): mock_get_app_or_404.return_value = MagicMock( - provider="saml", client_id="testtenant", name="Test App", settings={} + provider="saml", + client_id=saml_setup["domain"], + name="Test App", + settings={}, ) mock_sa_get.return_value = social_account mock_socialapp_get.return_value = MagicMock(provider_id="saml") mock_saml_domain_get.return_value = SimpleNamespace(tenant_id=tenant.id) - mock_saml_config_get.return_value = MagicMock() + mock_saml_config_get.return_value = SimpleNamespace( + email_domain=saml_setup["domain"], tenant=tenant + ) mock_user_get.return_value = user view = TenantFinishACSView.as_view() - response = view(request, organization_slug="testtenant") + response = view(request, organization_slug=saml_setup["domain"]) assert response.status_code == 302 @@ -14043,7 +14151,9 @@ class TestTenantFinishACSView: ) request = RequestFactory().get( - reverse("saml_finish_acs", kwargs={"organization_slug": "testtenant"}) + reverse( + "saml_finish_acs", kwargs={"organization_slug": saml_setup["domain"]} + ) ) request.user = user request.session = {} @@ -14063,16 +14173,21 @@ class TestTenantFinishACSView: patch("api.models.User.objects.get") as mock_user_get, ): mock_get_app_or_404.return_value = MagicMock( - provider="saml", client_id="testtenant", name="Test App", settings={} + provider="saml", + client_id=saml_setup["domain"], + name="Test App", + settings={}, ) mock_sa_get.return_value = social_account mock_socialapp_get.return_value = MagicMock(provider_id="saml") mock_saml_domain_get.return_value = SimpleNamespace(tenant_id=tenant.id) - mock_saml_config_get.return_value = MagicMock() + mock_saml_config_get.return_value = SimpleNamespace( + email_domain=saml_setup["domain"], tenant=tenant + ) mock_user_get.return_value = user view = TenantFinishACSView.as_view() - response = view(request, organization_slug="testtenant") + response = view(request, organization_slug=saml_setup["domain"]) assert response.status_code == 302 @@ -14126,7 +14241,9 @@ class TestTenantFinishACSView: ) request = RequestFactory().get( - reverse("saml_finish_acs", kwargs={"organization_slug": "testtenant"}) + reverse( + "saml_finish_acs", kwargs={"organization_slug": saml_setup["domain"]} + ) ) request.user = non_admin_user request.session = {} @@ -14146,16 +14263,21 @@ class TestTenantFinishACSView: patch("api.models.User.objects.get") as mock_user_get, ): mock_get_app_or_404.return_value = MagicMock( - provider="saml", client_id="testtenant", name="Test App", settings={} + provider="saml", + client_id=saml_setup["domain"], + name="Test App", + settings={}, ) mock_sa_get.return_value = social_account mock_socialapp_get.return_value = MagicMock(provider_id="saml") mock_saml_domain_get.return_value = SimpleNamespace(tenant_id=tenant.id) - mock_saml_config_get.return_value = MagicMock() + mock_saml_config_get.return_value = SimpleNamespace( + email_domain=saml_setup["domain"], tenant=tenant + ) mock_user_get.return_value = non_admin_user view = TenantFinishACSView.as_view() - response = view(request, organization_slug="testtenant") + response = view(request, organization_slug=saml_setup["domain"]) assert response.status_code == 302 diff --git a/api/src/backend/api/v1/views.py b/api/src/backend/api/v1/views.py index 4cf39cdce1..6b2616b289 100644 --- a/api/src/backend/api/v1/views.py +++ b/api/src/backend/api/v1/views.py @@ -767,7 +767,10 @@ class TenantFinishACSView(FinishACSView): try: check = SAMLDomainIndex.objects.get(email_domain=organization_slug) with rls_transaction(str(check.tenant_id)): - SAMLConfiguration.objects.get(tenant_id=str(check.tenant_id)) + saml_config = SAMLConfiguration.objects.select_related("tenant").get( + tenant_id=str(check.tenant_id) + ) + tenant = saml_config.tenant social_app = SocialApp.objects.get( provider="saml", client_id=organization_slug ) @@ -787,6 +790,15 @@ class TenantFinishACSView(FinishACSView): callback_url = env.str("AUTH_URL") return redirect(f"{callback_url}?sso_saml_failed=true") + requested_domain = organization_slug.lower() + configured_domain = saml_config.email_domain.lower() + email_domain = user.email.rsplit("@", 1)[-1].lower() + if configured_domain != requested_domain or email_domain != configured_domain: + logger.error("SAML email domain does not match requested organization") + self._rollback_saml_user(request) + callback_url = env.str("AUTH_URL") + return redirect(f"{callback_url}?sso_saml_failed=true") + extra = social_account.extra_data user.first_name = ( extra.get("firstName", [""])[0] if extra.get("firstName") else "" @@ -800,13 +812,6 @@ class TenantFinishACSView(FinishACSView): user.name = "N/A" user.save() - email_domain = user.email.split("@")[-1] - tenant = ( - SAMLConfiguration.objects.using(MainRouter.admin_db) - .get(email_domain=email_domain) - .tenant - ) - # Only remap roles when the IdP provides a userType attribute. # Without it, the user's current roles are left untouched. role_name = (