mirror of
https://github.com/prowler-cloud/prowler.git
synced 2026-07-04 19:21:51 +00:00
Merge commit from fork
* fix(saml): cross-tenant account takeover via SAML domain claiming * chore(changelog): add PR # * fix(api): bind SAML tokens to validated domain - Reject SAML assertions with mismatched email domains - Issue SAML tokens from the validated ACS tenant - Add regression coverage for cross-tenant SAML token issuance * fix(api): resolve SAML tenant inside RLS context - Load the SAML tenant relation before leaving the RLS transaction - Avoid lazy tenant lookups during the SAML ACS finish flow --------- Co-authored-by: Pepe Fagoaga <pepe@prowler.com>
This commit is contained in:
@@ -28,6 +28,14 @@ All notable changes to the **Prowler API** are documented in this file.
|
||||
|
||||
---
|
||||
|
||||
## [1.31.3] (Prowler v5.30.3)
|
||||
|
||||
### 🔐 Security
|
||||
|
||||
- SAML logins now link to an existing account only when the asserted email domain matches the ACS endpoint and the user is already a member of that domain's tenant, fixing a cross-tenant account takeover [(GHSA-h8m9-jgf8-vwvp)](https://github.com/prowler-cloud/prowler/security/advisories/GHSA-h8m9-jgf8-vwvp) [(#XXXXX)](https://github.com/prowler-cloud/prowler/pull/XXXXX)
|
||||
|
||||
---
|
||||
|
||||
## [1.31.2] (Prowler v5.30.2)
|
||||
|
||||
### 🔄 Changed
|
||||
|
||||
@@ -3,7 +3,14 @@ from django.db import transaction
|
||||
|
||||
from api.db_router import MainRouter
|
||||
from api.db_utils import rls_transaction
|
||||
from api.models import Membership, Role, Tenant, User, UserRoleRelationship
|
||||
from api.models import (
|
||||
Membership,
|
||||
Role,
|
||||
SAMLConfiguration,
|
||||
Tenant,
|
||||
User,
|
||||
UserRoleRelationship,
|
||||
)
|
||||
|
||||
|
||||
class ProwlerSocialAccountAdapter(DefaultSocialAccountAdapter):
|
||||
@@ -18,7 +25,42 @@ class ProwlerSocialAccountAdapter(DefaultSocialAccountAdapter):
|
||||
# Link existing accounts with the same email address
|
||||
email = sociallogin.account.extra_data.get("email")
|
||||
if sociallogin.provider.id == "saml":
|
||||
# For SAML, the asserted NameID email cannot be trusted on its own:
|
||||
# any tenant can claim any email domain in its SAML configuration. To
|
||||
# prevent cross-tenant account takeover (GHSA-h8m9-jgf8-vwvp), only link
|
||||
# the incoming SAML session to an existing account when (1) the email
|
||||
# domain matches the tenant whose ACS endpoint is being used and (2) the
|
||||
# existing user is already a member of that tenant.
|
||||
email = sociallogin.user.email
|
||||
if not email:
|
||||
return
|
||||
|
||||
domain = email.rsplit("@", 1)[-1].lower()
|
||||
resolver_match = getattr(request, "resolver_match", None)
|
||||
organization_slug = (
|
||||
(resolver_match.kwargs or {}).get("organization_slug", "")
|
||||
if resolver_match
|
||||
else ""
|
||||
).lower()
|
||||
# The ACS endpoint is scoped per email domain; reject mismatches so an
|
||||
# attacker cannot replay an assertion through another tenant's endpoint.
|
||||
if organization_slug != domain:
|
||||
return
|
||||
|
||||
try:
|
||||
saml_config = SAMLConfiguration.objects.using(MainRouter.admin_db).get(
|
||||
email_domain=domain
|
||||
)
|
||||
except SAMLConfiguration.DoesNotExist:
|
||||
return
|
||||
|
||||
existing_user = self.get_user_by_email(email)
|
||||
if existing_user and existing_user.is_member_of_tenant(
|
||||
str(saml_config.tenant_id)
|
||||
):
|
||||
sociallogin.connect(request, existing_user)
|
||||
return
|
||||
|
||||
if email:
|
||||
existing_user = self.get_user_by_email(email)
|
||||
if existing_user:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -5,9 +6,48 @@ from allauth.socialaccount.models import SocialLogin
|
||||
from django.contrib.auth import get_user_model
|
||||
|
||||
from api.adapters import ProwlerSocialAccountAdapter
|
||||
from api.db_router import MainRouter
|
||||
from api.models import SAMLConfiguration
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
# Minimal, well-formed IdP metadata accepted by SAMLConfiguration._parse_metadata.
|
||||
VALID_METADATA = """<?xml version='1.0' encoding='UTF-8'?>
|
||||
<md:EntityDescriptor entityID='TEST' xmlns:md='urn:oasis:names:tc:SAML:2.0:metadata'>
|
||||
<md:IDPSSODescriptor WantAuthnRequestsSigned='false' protocolSupportEnumeration='urn:oasis:names:tc:SAML:2.0:protocol'>
|
||||
<md:KeyDescriptor use='signing'>
|
||||
<ds:KeyInfo xmlns:ds='http://www.w3.org/2000/09/xmldsig#'>
|
||||
<ds:X509Data>
|
||||
<ds:X509Certificate>FAKECERTDATA</ds:X509Certificate>
|
||||
</ds:X509Data>
|
||||
</ds:KeyInfo>
|
||||
</md:KeyDescriptor>
|
||||
<md:SingleSignOnService Binding='urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST' Location='https://idp.test/sso'/>
|
||||
</md:IDPSSODescriptor>
|
||||
</md:EntityDescriptor>
|
||||
"""
|
||||
|
||||
|
||||
def _saml_request(rf, organization_slug):
|
||||
"""Build an ACS request whose resolver_match carries the organization slug,
|
||||
mirroring how Django populates it after routing the SAML ACS URL."""
|
||||
request = rf.post(f"/api/v1/accounts/saml/{organization_slug}/acs/finish/")
|
||||
request.resolver_match = SimpleNamespace(
|
||||
kwargs={"organization_slug": organization_slug}
|
||||
)
|
||||
return request
|
||||
|
||||
|
||||
def _saml_sociallogin(user):
|
||||
sociallogin = MagicMock(spec=SocialLogin)
|
||||
sociallogin.account = MagicMock()
|
||||
sociallogin.provider = MagicMock()
|
||||
sociallogin.provider.id = "saml"
|
||||
sociallogin.account.extra_data = {}
|
||||
sociallogin.user = user
|
||||
sociallogin.connect = MagicMock()
|
||||
return sociallogin
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestProwlerSocialAccountAdapter:
|
||||
@@ -20,26 +60,99 @@ class TestProwlerSocialAccountAdapter:
|
||||
adapter = ProwlerSocialAccountAdapter()
|
||||
assert adapter.get_user_by_email("notfound@example.com") is None
|
||||
|
||||
def test_pre_social_login_links_existing_user(self, create_test_user, rf):
|
||||
def test_pre_social_login_links_member_of_saml_tenant(
|
||||
self, create_test_user, tenants_fixture, rf
|
||||
):
|
||||
"""A SAML login links to an existing account only when that user is
|
||||
already a member of the tenant that owns the asserted email domain."""
|
||||
adapter = ProwlerSocialAccountAdapter()
|
||||
# create_test_user (dev@prowler.com) is a member of tenant1.
|
||||
domain = create_test_user.email.rsplit("@", 1)[-1]
|
||||
SAMLConfiguration.objects.using(MainRouter.admin_db).create(
|
||||
email_domain=domain,
|
||||
metadata_xml=VALID_METADATA,
|
||||
tenant=tenants_fixture[0],
|
||||
)
|
||||
|
||||
sociallogin = MagicMock(spec=SocialLogin)
|
||||
sociallogin.account = MagicMock()
|
||||
sociallogin.provider = MagicMock()
|
||||
sociallogin.provider.id = "saml"
|
||||
sociallogin.account.extra_data = {}
|
||||
sociallogin.user = create_test_user
|
||||
sociallogin.connect = MagicMock()
|
||||
|
||||
adapter.pre_social_login(rf.get("/"), sociallogin)
|
||||
sociallogin = _saml_sociallogin(create_test_user)
|
||||
adapter.pre_social_login(_saml_request(rf, domain), sociallogin)
|
||||
|
||||
call_args = sociallogin.connect.call_args
|
||||
assert call_args is not None
|
||||
|
||||
called_request, called_user = call_args[0]
|
||||
assert called_request.path == "/"
|
||||
_, called_user = call_args[0]
|
||||
assert called_user.email == create_test_user.email
|
||||
|
||||
def test_pre_social_login_blocks_cross_tenant_takeover(
|
||||
self, create_test_user, tenants_fixture, rf
|
||||
):
|
||||
"""GHSA-h8m9-jgf8-vwvp: an attacker tenant that claims the victim's
|
||||
email domain must NOT be able to link to the victim's account, because
|
||||
the victim is not a member of the attacker's tenant."""
|
||||
adapter = ProwlerSocialAccountAdapter()
|
||||
domain = create_test_user.email.rsplit("@", 1)[-1]
|
||||
# tenant3 is the attacker tenant; create_test_user is NOT a member of it.
|
||||
attacker_tenant = tenants_fixture[2]
|
||||
assert not create_test_user.is_member_of_tenant(str(attacker_tenant.id))
|
||||
SAMLConfiguration.objects.using(MainRouter.admin_db).create(
|
||||
email_domain=domain,
|
||||
metadata_xml=VALID_METADATA,
|
||||
tenant=attacker_tenant,
|
||||
)
|
||||
|
||||
sociallogin = _saml_sociallogin(create_test_user)
|
||||
adapter.pre_social_login(_saml_request(rf, domain), sociallogin)
|
||||
|
||||
sociallogin.connect.assert_not_called()
|
||||
|
||||
def test_pre_social_login_blocks_domain_slug_mismatch(
|
||||
self, create_test_user, tenants_fixture, rf
|
||||
):
|
||||
"""The asserted email domain must match the ACS endpoint's slug, so an
|
||||
assertion cannot be replayed through a different tenant's endpoint."""
|
||||
adapter = ProwlerSocialAccountAdapter()
|
||||
domain = create_test_user.email.rsplit("@", 1)[-1]
|
||||
SAMLConfiguration.objects.using(MainRouter.admin_db).create(
|
||||
email_domain=domain,
|
||||
metadata_xml=VALID_METADATA,
|
||||
tenant=tenants_fixture[0],
|
||||
)
|
||||
|
||||
sociallogin = _saml_sociallogin(create_test_user)
|
||||
# Slug points at a different domain than the asserted email.
|
||||
adapter.pre_social_login(_saml_request(rf, "attacker.com"), sociallogin)
|
||||
|
||||
sociallogin.connect.assert_not_called()
|
||||
|
||||
def test_pre_social_login_blocks_when_no_saml_config(
|
||||
self, create_test_user, tenants_fixture, rf
|
||||
):
|
||||
"""No SAML configuration for the domain means nothing to link against."""
|
||||
adapter = ProwlerSocialAccountAdapter()
|
||||
domain = create_test_user.email.rsplit("@", 1)[-1]
|
||||
|
||||
sociallogin = _saml_sociallogin(create_test_user)
|
||||
adapter.pre_social_login(_saml_request(rf, domain), sociallogin)
|
||||
|
||||
sociallogin.connect.assert_not_called()
|
||||
|
||||
def test_pre_social_login_blocks_without_resolver_match(
|
||||
self, create_test_user, tenants_fixture, rf
|
||||
):
|
||||
"""Fail closed: if the request has no resolver_match we cannot bind the
|
||||
assertion to a tenant, so no linking happens."""
|
||||
adapter = ProwlerSocialAccountAdapter()
|
||||
domain = create_test_user.email.rsplit("@", 1)[-1]
|
||||
SAMLConfiguration.objects.using(MainRouter.admin_db).create(
|
||||
email_domain=domain,
|
||||
metadata_xml=VALID_METADATA,
|
||||
tenant=tenants_fixture[0],
|
||||
)
|
||||
|
||||
sociallogin = _saml_sociallogin(create_test_user)
|
||||
adapter.pre_social_login(rf.post("/"), sociallogin)
|
||||
|
||||
sociallogin.connect.assert_not_called()
|
||||
|
||||
def test_pre_social_login_no_link_if_email_missing(self, rf):
|
||||
adapter = ProwlerSocialAccountAdapter()
|
||||
|
||||
@@ -47,14 +160,35 @@ class TestProwlerSocialAccountAdapter:
|
||||
sociallogin.account = MagicMock()
|
||||
sociallogin.provider = MagicMock()
|
||||
sociallogin.user = MagicMock()
|
||||
sociallogin.user.email = ""
|
||||
sociallogin.provider.id = "saml"
|
||||
sociallogin.account.extra_data = {}
|
||||
sociallogin.connect = MagicMock()
|
||||
|
||||
adapter.pre_social_login(rf.get("/"), sociallogin)
|
||||
adapter.pre_social_login(_saml_request(rf, "prowler.com"), sociallogin)
|
||||
|
||||
sociallogin.connect.assert_not_called()
|
||||
|
||||
def test_pre_social_login_non_saml_links_by_email(self, create_test_user, rf):
|
||||
"""Non-SAML providers (e.g. Google/GitHub) still link to an existing
|
||||
local account by email; the tenant binding only applies to SAML."""
|
||||
adapter = ProwlerSocialAccountAdapter()
|
||||
|
||||
sociallogin = MagicMock(spec=SocialLogin)
|
||||
sociallogin.account = MagicMock()
|
||||
sociallogin.provider = MagicMock()
|
||||
sociallogin.provider.id = "google"
|
||||
sociallogin.account.extra_data = {"email": create_test_user.email}
|
||||
sociallogin.user = create_test_user
|
||||
sociallogin.connect = MagicMock()
|
||||
|
||||
adapter.pre_social_login(rf.get("/"), sociallogin)
|
||||
|
||||
call_args = sociallogin.connect.call_args
|
||||
assert call_args is not None
|
||||
_, called_user = call_args[0]
|
||||
assert called_user.email == create_test_user.email
|
||||
|
||||
def test_save_user_saml_sets_session_flag(self, rf):
|
||||
adapter = ProwlerSocialAccountAdapter()
|
||||
request = rf.get("/")
|
||||
|
||||
@@ -13586,7 +13586,9 @@ class TestTenantFinishACSView:
|
||||
)
|
||||
|
||||
request = RequestFactory().get(
|
||||
reverse("saml_finish_acs", kwargs={"organization_slug": "testtenant"})
|
||||
reverse(
|
||||
"saml_finish_acs", kwargs={"organization_slug": saml_setup["domain"]}
|
||||
)
|
||||
)
|
||||
request.user = user
|
||||
request.session = {}
|
||||
@@ -13606,18 +13608,23 @@ class TestTenantFinishACSView:
|
||||
patch("api.models.User.objects.get") as mock_user_get,
|
||||
):
|
||||
mock_get_app_or_404.return_value = MagicMock(
|
||||
provider="saml", client_id="testtenant", name="Test App", settings={}
|
||||
provider="saml",
|
||||
client_id=saml_setup["domain"],
|
||||
name="Test App",
|
||||
settings={},
|
||||
)
|
||||
mock_sa_get.return_value = social_account
|
||||
mock_socialapp_get.return_value = MagicMock(provider_id="saml")
|
||||
mock_saml_domain_get.return_value = SimpleNamespace(
|
||||
tenant_id=tenants_fixture[0].id
|
||||
)
|
||||
mock_saml_config_get.return_value = MagicMock()
|
||||
mock_saml_config_get.return_value = SimpleNamespace(
|
||||
email_domain=saml_setup["domain"], tenant=tenants_fixture[0]
|
||||
)
|
||||
mock_user_get.return_value = user
|
||||
|
||||
view = TenantFinishACSView.as_view()
|
||||
response = view(request, organization_slug="testtenant")
|
||||
response = view(request, organization_slug=saml_setup["domain"])
|
||||
|
||||
assert response.status_code == 302
|
||||
|
||||
@@ -13665,6 +13672,79 @@ class TestTenantFinishACSView:
|
||||
user.company_name = original_company
|
||||
user.save()
|
||||
|
||||
def test_dispatch_rejects_assertion_email_domain_that_differs_from_slug(
|
||||
self, tenants_fixture, saml_setup, monkeypatch
|
||||
):
|
||||
monkeypatch.setenv("AUTH_URL", "http://localhost")
|
||||
monkeypatch.setenv("SAML_SSO_CALLBACK_URL", "http://localhost/sso-complete")
|
||||
victim_tenant = tenants_fixture[0]
|
||||
attacker_tenant = tenants_fixture[1]
|
||||
attacker_domain = "attacker.com"
|
||||
|
||||
SAMLConfiguration.objects.using(MainRouter.admin_db).create(
|
||||
email_domain=attacker_domain,
|
||||
metadata_xml="""<?xml version='1.0' encoding='UTF-8'?>
|
||||
<md:EntityDescriptor entityID='ATTACKER' xmlns:md='urn:oasis:names:tc:SAML:2.0:metadata'>
|
||||
<md:IDPSSODescriptor WantAuthnRequestsSigned='false' protocolSupportEnumeration='urn:oasis:names:tc:SAML:2.0:protocol'>
|
||||
<md:KeyDescriptor use='signing'>
|
||||
<ds:KeyInfo xmlns:ds='http://www.w3.org/2000/09/xmldsig#'>
|
||||
<ds:X509Data>
|
||||
<ds:X509Certificate>TEST</ds:X509Certificate>
|
||||
</ds:X509Data>
|
||||
</ds:KeyInfo>
|
||||
</md:KeyDescriptor>
|
||||
<md:SingleSignOnService Binding='urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST' Location='https://ATTACKER/sso/saml'/>
|
||||
</md:IDPSSODescriptor>
|
||||
</md:EntityDescriptor>
|
||||
""",
|
||||
tenant=attacker_tenant,
|
||||
)
|
||||
user = User.objects.using(MainRouter.admin_db).create(
|
||||
email=f"intruder@{saml_setup['domain']}", name="Intruder"
|
||||
)
|
||||
social_account = SocialAccount(
|
||||
user=user,
|
||||
provider="ATTACKER",
|
||||
extra_data={
|
||||
"firstName": ["Mallory"],
|
||||
"lastName": ["Example"],
|
||||
},
|
||||
)
|
||||
request = RequestFactory().get(
|
||||
reverse("saml_finish_acs", kwargs={"organization_slug": attacker_domain})
|
||||
)
|
||||
request.user = user
|
||||
request.session = {}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"allauth.socialaccount.providers.saml.views.get_app_or_404"
|
||||
) as mock_get_app_or_404,
|
||||
patch(
|
||||
"allauth.socialaccount.models.SocialAccount.objects.get"
|
||||
) as mock_sa_get,
|
||||
):
|
||||
mock_get_app_or_404.return_value = MagicMock(
|
||||
provider="saml",
|
||||
provider_id="ATTACKER",
|
||||
client_id=attacker_domain,
|
||||
name="Attacker App",
|
||||
settings={},
|
||||
)
|
||||
mock_sa_get.return_value = social_account
|
||||
|
||||
view = TenantFinishACSView.as_view()
|
||||
response = view(request, organization_slug=attacker_domain)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert "sso_saml_failed=true" in response.url
|
||||
assert not (
|
||||
Membership.objects.using(MainRouter.admin_db)
|
||||
.filter(user=user, tenant=victim_tenant)
|
||||
.exists()
|
||||
)
|
||||
assert not SAMLToken.objects.using(MainRouter.admin_db).filter(user=user).exists()
|
||||
|
||||
def test_rollback_saml_user_when_error_occurs(self, users_fixture, monkeypatch):
|
||||
"""Test that a user is properly deleted when created during SAML flow and an error occurs"""
|
||||
monkeypatch.setenv("AUTH_URL", "http://localhost")
|
||||
@@ -13734,7 +13814,9 @@ class TestTenantFinishACSView:
|
||||
)
|
||||
|
||||
request = RequestFactory().get(
|
||||
reverse("saml_finish_acs", kwargs={"organization_slug": "testtenant"})
|
||||
reverse(
|
||||
"saml_finish_acs", kwargs={"organization_slug": saml_setup["domain"]}
|
||||
)
|
||||
)
|
||||
request.user = user
|
||||
request.session = {}
|
||||
@@ -13754,16 +13836,21 @@ class TestTenantFinishACSView:
|
||||
patch("api.models.User.objects.get") as mock_user_get,
|
||||
):
|
||||
mock_get_app_or_404.return_value = MagicMock(
|
||||
provider="saml", client_id="testtenant", name="Test App", settings={}
|
||||
provider="saml",
|
||||
client_id=saml_setup["domain"],
|
||||
name="Test App",
|
||||
settings={},
|
||||
)
|
||||
mock_sa_get.return_value = social_account
|
||||
mock_socialapp_get.return_value = MagicMock(provider_id="saml")
|
||||
mock_saml_domain_get.return_value = SimpleNamespace(tenant_id=tenant.id)
|
||||
mock_saml_config_get.return_value = MagicMock()
|
||||
mock_saml_config_get.return_value = SimpleNamespace(
|
||||
email_domain=saml_setup["domain"], tenant=tenant
|
||||
)
|
||||
mock_user_get.return_value = user
|
||||
|
||||
view = TenantFinishACSView.as_view()
|
||||
response = view(request, organization_slug="testtenant")
|
||||
response = view(request, organization_slug=saml_setup["domain"])
|
||||
|
||||
assert response.status_code == 302
|
||||
|
||||
@@ -13802,7 +13889,9 @@ class TestTenantFinishACSView:
|
||||
)
|
||||
|
||||
request = RequestFactory().get(
|
||||
reverse("saml_finish_acs", kwargs={"organization_slug": "testtenant"})
|
||||
reverse(
|
||||
"saml_finish_acs", kwargs={"organization_slug": saml_setup["domain"]}
|
||||
)
|
||||
)
|
||||
request.user = user
|
||||
request.session = {}
|
||||
@@ -13822,16 +13911,21 @@ class TestTenantFinishACSView:
|
||||
patch("api.models.User.objects.get") as mock_user_get,
|
||||
):
|
||||
mock_get_app_or_404.return_value = MagicMock(
|
||||
provider="saml", client_id="testtenant", name="Test App", settings={}
|
||||
provider="saml",
|
||||
client_id=saml_setup["domain"],
|
||||
name="Test App",
|
||||
settings={},
|
||||
)
|
||||
mock_sa_get.return_value = social_account
|
||||
mock_socialapp_get.return_value = MagicMock(provider_id="saml")
|
||||
mock_saml_domain_get.return_value = SimpleNamespace(tenant_id=tenant.id)
|
||||
mock_saml_config_get.return_value = MagicMock()
|
||||
mock_saml_config_get.return_value = SimpleNamespace(
|
||||
email_domain=saml_setup["domain"], tenant=tenant
|
||||
)
|
||||
mock_user_get.return_value = user
|
||||
|
||||
view = TenantFinishACSView.as_view()
|
||||
response = view(request, organization_slug="testtenant")
|
||||
response = view(request, organization_slug=saml_setup["domain"])
|
||||
|
||||
assert response.status_code == 302
|
||||
|
||||
@@ -13881,7 +13975,9 @@ class TestTenantFinishACSView:
|
||||
)
|
||||
|
||||
request = RequestFactory().get(
|
||||
reverse("saml_finish_acs", kwargs={"organization_slug": "testtenant"})
|
||||
reverse(
|
||||
"saml_finish_acs", kwargs={"organization_slug": saml_setup["domain"]}
|
||||
)
|
||||
)
|
||||
request.user = user
|
||||
request.session = {}
|
||||
@@ -13901,16 +13997,21 @@ class TestTenantFinishACSView:
|
||||
patch("api.models.User.objects.get") as mock_user_get,
|
||||
):
|
||||
mock_get_app_or_404.return_value = MagicMock(
|
||||
provider="saml", client_id="testtenant", name="Test App", settings={}
|
||||
provider="saml",
|
||||
client_id=saml_setup["domain"],
|
||||
name="Test App",
|
||||
settings={},
|
||||
)
|
||||
mock_sa_get.return_value = social_account
|
||||
mock_socialapp_get.return_value = MagicMock(provider_id="saml")
|
||||
mock_saml_domain_get.return_value = SimpleNamespace(tenant_id=tenant.id)
|
||||
mock_saml_config_get.return_value = MagicMock()
|
||||
mock_saml_config_get.return_value = SimpleNamespace(
|
||||
email_domain=saml_setup["domain"], tenant=tenant
|
||||
)
|
||||
mock_user_get.return_value = user
|
||||
|
||||
view = TenantFinishACSView.as_view()
|
||||
response = view(request, organization_slug="testtenant")
|
||||
response = view(request, organization_slug=saml_setup["domain"])
|
||||
|
||||
assert response.status_code == 302
|
||||
|
||||
@@ -13959,7 +14060,9 @@ class TestTenantFinishACSView:
|
||||
)
|
||||
|
||||
request = RequestFactory().get(
|
||||
reverse("saml_finish_acs", kwargs={"organization_slug": "testtenant"})
|
||||
reverse(
|
||||
"saml_finish_acs", kwargs={"organization_slug": saml_setup["domain"]}
|
||||
)
|
||||
)
|
||||
request.user = user
|
||||
request.session = {}
|
||||
@@ -13979,16 +14082,21 @@ class TestTenantFinishACSView:
|
||||
patch("api.models.User.objects.get") as mock_user_get,
|
||||
):
|
||||
mock_get_app_or_404.return_value = MagicMock(
|
||||
provider="saml", client_id="testtenant", name="Test App", settings={}
|
||||
provider="saml",
|
||||
client_id=saml_setup["domain"],
|
||||
name="Test App",
|
||||
settings={},
|
||||
)
|
||||
mock_sa_get.return_value = social_account
|
||||
mock_socialapp_get.return_value = MagicMock(provider_id="saml")
|
||||
mock_saml_domain_get.return_value = SimpleNamespace(tenant_id=tenant.id)
|
||||
mock_saml_config_get.return_value = MagicMock()
|
||||
mock_saml_config_get.return_value = SimpleNamespace(
|
||||
email_domain=saml_setup["domain"], tenant=tenant
|
||||
)
|
||||
mock_user_get.return_value = user
|
||||
|
||||
view = TenantFinishACSView.as_view()
|
||||
response = view(request, organization_slug="testtenant")
|
||||
response = view(request, organization_slug=saml_setup["domain"])
|
||||
|
||||
assert response.status_code == 302
|
||||
|
||||
@@ -14043,7 +14151,9 @@ class TestTenantFinishACSView:
|
||||
)
|
||||
|
||||
request = RequestFactory().get(
|
||||
reverse("saml_finish_acs", kwargs={"organization_slug": "testtenant"})
|
||||
reverse(
|
||||
"saml_finish_acs", kwargs={"organization_slug": saml_setup["domain"]}
|
||||
)
|
||||
)
|
||||
request.user = user
|
||||
request.session = {}
|
||||
@@ -14063,16 +14173,21 @@ class TestTenantFinishACSView:
|
||||
patch("api.models.User.objects.get") as mock_user_get,
|
||||
):
|
||||
mock_get_app_or_404.return_value = MagicMock(
|
||||
provider="saml", client_id="testtenant", name="Test App", settings={}
|
||||
provider="saml",
|
||||
client_id=saml_setup["domain"],
|
||||
name="Test App",
|
||||
settings={},
|
||||
)
|
||||
mock_sa_get.return_value = social_account
|
||||
mock_socialapp_get.return_value = MagicMock(provider_id="saml")
|
||||
mock_saml_domain_get.return_value = SimpleNamespace(tenant_id=tenant.id)
|
||||
mock_saml_config_get.return_value = MagicMock()
|
||||
mock_saml_config_get.return_value = SimpleNamespace(
|
||||
email_domain=saml_setup["domain"], tenant=tenant
|
||||
)
|
||||
mock_user_get.return_value = user
|
||||
|
||||
view = TenantFinishACSView.as_view()
|
||||
response = view(request, organization_slug="testtenant")
|
||||
response = view(request, organization_slug=saml_setup["domain"])
|
||||
|
||||
assert response.status_code == 302
|
||||
|
||||
@@ -14126,7 +14241,9 @@ class TestTenantFinishACSView:
|
||||
)
|
||||
|
||||
request = RequestFactory().get(
|
||||
reverse("saml_finish_acs", kwargs={"organization_slug": "testtenant"})
|
||||
reverse(
|
||||
"saml_finish_acs", kwargs={"organization_slug": saml_setup["domain"]}
|
||||
)
|
||||
)
|
||||
request.user = non_admin_user
|
||||
request.session = {}
|
||||
@@ -14146,16 +14263,21 @@ class TestTenantFinishACSView:
|
||||
patch("api.models.User.objects.get") as mock_user_get,
|
||||
):
|
||||
mock_get_app_or_404.return_value = MagicMock(
|
||||
provider="saml", client_id="testtenant", name="Test App", settings={}
|
||||
provider="saml",
|
||||
client_id=saml_setup["domain"],
|
||||
name="Test App",
|
||||
settings={},
|
||||
)
|
||||
mock_sa_get.return_value = social_account
|
||||
mock_socialapp_get.return_value = MagicMock(provider_id="saml")
|
||||
mock_saml_domain_get.return_value = SimpleNamespace(tenant_id=tenant.id)
|
||||
mock_saml_config_get.return_value = MagicMock()
|
||||
mock_saml_config_get.return_value = SimpleNamespace(
|
||||
email_domain=saml_setup["domain"], tenant=tenant
|
||||
)
|
||||
mock_user_get.return_value = non_admin_user
|
||||
|
||||
view = TenantFinishACSView.as_view()
|
||||
response = view(request, organization_slug="testtenant")
|
||||
response = view(request, organization_slug=saml_setup["domain"])
|
||||
|
||||
assert response.status_code == 302
|
||||
|
||||
|
||||
@@ -767,7 +767,10 @@ class TenantFinishACSView(FinishACSView):
|
||||
try:
|
||||
check = SAMLDomainIndex.objects.get(email_domain=organization_slug)
|
||||
with rls_transaction(str(check.tenant_id)):
|
||||
SAMLConfiguration.objects.get(tenant_id=str(check.tenant_id))
|
||||
saml_config = SAMLConfiguration.objects.select_related("tenant").get(
|
||||
tenant_id=str(check.tenant_id)
|
||||
)
|
||||
tenant = saml_config.tenant
|
||||
social_app = SocialApp.objects.get(
|
||||
provider="saml", client_id=organization_slug
|
||||
)
|
||||
@@ -787,6 +790,15 @@ class TenantFinishACSView(FinishACSView):
|
||||
callback_url = env.str("AUTH_URL")
|
||||
return redirect(f"{callback_url}?sso_saml_failed=true")
|
||||
|
||||
requested_domain = organization_slug.lower()
|
||||
configured_domain = saml_config.email_domain.lower()
|
||||
email_domain = user.email.rsplit("@", 1)[-1].lower()
|
||||
if configured_domain != requested_domain or email_domain != configured_domain:
|
||||
logger.error("SAML email domain does not match requested organization")
|
||||
self._rollback_saml_user(request)
|
||||
callback_url = env.str("AUTH_URL")
|
||||
return redirect(f"{callback_url}?sso_saml_failed=true")
|
||||
|
||||
extra = social_account.extra_data
|
||||
user.first_name = (
|
||||
extra.get("firstName", [""])[0] if extra.get("firstName") else ""
|
||||
@@ -800,13 +812,6 @@ class TenantFinishACSView(FinishACSView):
|
||||
user.name = "N/A"
|
||||
user.save()
|
||||
|
||||
email_domain = user.email.split("@")[-1]
|
||||
tenant = (
|
||||
SAMLConfiguration.objects.using(MainRouter.admin_db)
|
||||
.get(email_domain=email_domain)
|
||||
.tenant
|
||||
)
|
||||
|
||||
# Only remap roles when the IdP provides a userType attribute.
|
||||
# Without it, the user's current roles are left untouched.
|
||||
role_name = (
|
||||
|
||||
Reference in New Issue
Block a user