mirror of
https://github.com/prowler-cloud/prowler.git
synced 2026-07-04 19:21:51 +00:00
fix: handle invitations in social and SAML auth
- Preserve invitation callbacks through social and SAML login - Accept invited users without creating a default tenant - Add regression coverage for invitation acceptance flows
This commit is contained in:
@@ -9,6 +9,7 @@ from api.models import (
|
||||
User,
|
||||
UserRoleRelationship,
|
||||
)
|
||||
from api.utils import accept_invitation_for_user
|
||||
from django.db import transaction
|
||||
|
||||
|
||||
@@ -20,6 +21,22 @@ class ProwlerSocialAccountAdapter(DefaultSocialAccountAdapter):
|
||||
except User.DoesNotExist:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_invitation_token(request):
|
||||
for source_name in ("data", "POST"):
|
||||
data = getattr(request, source_name, None) or {}
|
||||
if not hasattr(data, "get"):
|
||||
continue
|
||||
invitation_token = data.get("invitation_token")
|
||||
if invitation_token:
|
||||
return invitation_token
|
||||
|
||||
wrapped_request = getattr(request, "_request", None)
|
||||
if wrapped_request and wrapped_request is not request:
|
||||
return ProwlerSocialAccountAdapter._get_invitation_token(wrapped_request)
|
||||
|
||||
return None
|
||||
|
||||
def pre_social_login(self, request, sociallogin):
|
||||
# Link existing accounts with the same email address
|
||||
email = sociallogin.account.extra_data.get("email")
|
||||
@@ -83,29 +100,38 @@ class ProwlerSocialAccountAdapter(DefaultSocialAccountAdapter):
|
||||
user.name = social_account_name
|
||||
user.save(using=MainRouter.admin_db)
|
||||
|
||||
tenant = Tenant.objects.using(MainRouter.admin_db).create(
|
||||
name=f"{user.email.split('@')[0]} default tenant"
|
||||
)
|
||||
with rls_transaction(str(tenant.id)):
|
||||
Membership.objects.using(MainRouter.admin_db).create(
|
||||
user=user, tenant=tenant, role=Membership.RoleChoices.OWNER
|
||||
)
|
||||
role = Role.objects.using(MainRouter.admin_db).create(
|
||||
name="admin",
|
||||
tenant_id=tenant.id,
|
||||
manage_users=True,
|
||||
manage_account=True,
|
||||
manage_billing=True,
|
||||
manage_providers=True,
|
||||
manage_integrations=True,
|
||||
manage_scans=True,
|
||||
unlimited_visibility=True,
|
||||
)
|
||||
UserRoleRelationship.objects.using(MainRouter.admin_db).create(
|
||||
invitation_token = self._get_invitation_token(request)
|
||||
if invitation_token:
|
||||
invitation, _ = accept_invitation_for_user(
|
||||
user=user,
|
||||
role=role,
|
||||
tenant_id=tenant.id,
|
||||
invitation_token=invitation_token,
|
||||
)
|
||||
request.prowler_invitation_token = invitation_token
|
||||
request.prowler_invitation_tenant_id = str(invitation.tenant_id)
|
||||
else:
|
||||
tenant = Tenant.objects.using(MainRouter.admin_db).create(
|
||||
name=f"{user.email.split('@')[0]} default tenant"
|
||||
)
|
||||
with rls_transaction(str(tenant.id)):
|
||||
Membership.objects.using(MainRouter.admin_db).create(
|
||||
user=user, tenant=tenant, role=Membership.RoleChoices.OWNER
|
||||
)
|
||||
role = Role.objects.using(MainRouter.admin_db).create(
|
||||
name="admin",
|
||||
tenant_id=tenant.id,
|
||||
manage_users=True,
|
||||
manage_account=True,
|
||||
manage_billing=True,
|
||||
manage_providers=True,
|
||||
manage_integrations=True,
|
||||
manage_scans=True,
|
||||
unlimited_visibility=True,
|
||||
)
|
||||
UserRoleRelationship.objects.using(MainRouter.admin_db).create(
|
||||
user=user,
|
||||
role=role,
|
||||
tenant_id=tenant.id,
|
||||
)
|
||||
else:
|
||||
request.session["saml_user_created"] = str(user.id)
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import pytest
|
||||
from allauth.socialaccount.models import SocialLogin
|
||||
from api.adapters import ProwlerSocialAccountAdapter
|
||||
from api.db_router import MainRouter
|
||||
from api.models import SAMLConfiguration
|
||||
from api.models import Invitation, Membership, SAMLConfiguration, Tenant
|
||||
from django.contrib.auth import get_user_model
|
||||
|
||||
User = get_user_model()
|
||||
@@ -188,6 +188,44 @@ class TestProwlerSocialAccountAdapter:
|
||||
_, called_user = call_args[0]
|
||||
assert called_user.email == create_test_user.email
|
||||
|
||||
def test_save_user_social_with_invitation_joins_invited_tenant(
|
||||
self, rf, create_test_user, tenants_fixture
|
||||
):
|
||||
adapter = ProwlerSocialAccountAdapter()
|
||||
invited_tenant = tenants_fixture[2]
|
||||
invited_email = "frank-invited@example.com"
|
||||
invitation = Invitation.objects.create(
|
||||
tenant=invited_tenant,
|
||||
email=invited_email,
|
||||
inviter=create_test_user,
|
||||
)
|
||||
request = rf.post("/", data={"invitation_token": invitation.token})
|
||||
request.session = {}
|
||||
|
||||
sociallogin = MagicMock(spec=SocialLogin)
|
||||
sociallogin.provider = MagicMock()
|
||||
sociallogin.provider.id = "google"
|
||||
sociallogin.account = MagicMock()
|
||||
sociallogin.account.extra_data = {"name": "Frank"}
|
||||
|
||||
real_user = User.objects.create_user(
|
||||
name="Frank", email=invited_email, password="Secret123!"
|
||||
)
|
||||
tenants_before = Tenant.objects.count()
|
||||
|
||||
with patch("api.adapters.super") as mock_super:
|
||||
mock_super.return_value.save_user.return_value = real_user
|
||||
adapter.save_user(request, sociallogin)
|
||||
|
||||
invitation.refresh_from_db()
|
||||
assert invitation.state == Invitation.State.ACCEPTED
|
||||
assert Tenant.objects.count() == tenants_before
|
||||
assert Membership.objects.filter(
|
||||
user=real_user,
|
||||
tenant=invited_tenant,
|
||||
role=Membership.RoleChoices.MEMBER,
|
||||
).exists()
|
||||
|
||||
def test_save_user_saml_sets_session_flag(self, rf):
|
||||
adapter = ProwlerSocialAccountAdapter()
|
||||
request = rf.get("/")
|
||||
|
||||
@@ -8356,6 +8356,39 @@ class TestInvitationViewSet:
|
||||
).exists()
|
||||
assert invitation.state == Invitation.State.ACCEPTED.value
|
||||
|
||||
def test_invitations_accept_invitation_existing_membership(
|
||||
self,
|
||||
authenticated_client,
|
||||
create_test_user,
|
||||
tenants_fixture,
|
||||
):
|
||||
*_, tenant = tenants_fixture
|
||||
user = create_test_user
|
||||
|
||||
invitation = Invitation.objects.create(
|
||||
tenant=tenant,
|
||||
email=TEST_USER,
|
||||
inviter=user,
|
||||
expires_at=self.TOMORROW,
|
||||
)
|
||||
Membership.objects.create(user=user, tenant=tenant)
|
||||
|
||||
response = authenticated_client.post(
|
||||
reverse("invitation-accept"),
|
||||
data={"invitation_token": invitation.token},
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
invitation.refresh_from_db()
|
||||
assert invitation.state == Invitation.State.ACCEPTED.value
|
||||
assert (
|
||||
Membership.objects.filter(
|
||||
user__email__iexact=user.email, tenant=tenant
|
||||
).count()
|
||||
== 1
|
||||
)
|
||||
|
||||
def test_invitations_accept_invitation_invalid_token(self, authenticated_client):
|
||||
data = {
|
||||
"invitation_token": "invalid_token",
|
||||
@@ -13458,6 +13491,37 @@ class TestSAMLInitiateAPIView:
|
||||
)
|
||||
assert "SAMLRequest" not in response.url
|
||||
|
||||
def test_valid_email_domain_preserves_safe_callback_url(
|
||||
self, authenticated_client, saml_setup
|
||||
):
|
||||
url = reverse("api_saml_initiate")
|
||||
callback_url = "/invitation/accept?invitation_token=test-token"
|
||||
payload = {
|
||||
"email_domain": saml_setup["email"],
|
||||
"callback_url": callback_url,
|
||||
}
|
||||
|
||||
response = authenticated_client.post(url, data=payload, format="json")
|
||||
|
||||
assert response.status_code == status.HTTP_302_FOUND
|
||||
query_params = parse_qs(urlparse(response.url).query)
|
||||
assert query_params["callback_url"] == [callback_url]
|
||||
|
||||
def test_valid_email_domain_rejects_external_callback_url(
|
||||
self, authenticated_client, saml_setup
|
||||
):
|
||||
url = reverse("api_saml_initiate")
|
||||
payload = {
|
||||
"email_domain": saml_setup["email"],
|
||||
"callback_url": "https://attacker.example/invitation",
|
||||
}
|
||||
|
||||
response = authenticated_client.post(url, data=payload, format="json")
|
||||
|
||||
assert response.status_code == status.HTTP_302_FOUND
|
||||
query_params = parse_qs(urlparse(response.url).query)
|
||||
assert "callback_url" not in query_params
|
||||
|
||||
def test_invalid_email_domain(self, authenticated_client):
|
||||
url = reverse("api_saml_initiate")
|
||||
payload = {"email_domain": "user@unauthorized.com"}
|
||||
@@ -13645,7 +13709,8 @@ class TestTenantFinishACSView:
|
||||
)
|
||||
)
|
||||
request.user = user
|
||||
request.session = {}
|
||||
callback_url = "/invitation/accept?invitation_token=test-token"
|
||||
request.session = {"saml_callback_url": callback_url}
|
||||
|
||||
with (
|
||||
patch(
|
||||
@@ -13687,6 +13752,7 @@ class TestTenantFinishACSView:
|
||||
assert parsed_url.netloc == expected_callback_host
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
assert "id" in query_params
|
||||
assert query_params["callbackUrl"] == [callback_url]
|
||||
|
||||
token_id = query_params["id"][0]
|
||||
token_obj = SAMLToken.objects.get(id=token_id)
|
||||
|
||||
@@ -7,7 +7,16 @@ from allauth.socialaccount.providers.oauth2.client import OAuth2Client
|
||||
from api.db_router import MainRouter
|
||||
from api.db_utils import rls_transaction
|
||||
from api.exceptions import InvitationTokenExpiredException
|
||||
from api.models import Integration, Invitation, Processor, Provider, Resource
|
||||
from api.models import (
|
||||
Integration,
|
||||
Invitation,
|
||||
Membership,
|
||||
Processor,
|
||||
Provider,
|
||||
Resource,
|
||||
Role,
|
||||
UserRoleRelationship,
|
||||
)
|
||||
from api.v1.serializers import FindingMetadataSerializer
|
||||
from django.contrib.postgres.aggregates import ArrayAgg
|
||||
from django.db.models import Subquery
|
||||
@@ -538,6 +547,32 @@ def validate_invitation(
|
||||
return invitation
|
||||
|
||||
|
||||
def accept_invitation_for_user(
|
||||
*, user, invitation_token: str, raise_not_found: bool = False
|
||||
):
|
||||
invitation = validate_invitation(
|
||||
invitation_token, user.email, raise_not_found=raise_not_found
|
||||
)
|
||||
membership, _ = Membership.objects.using(MainRouter.admin_db).get_or_create(
|
||||
user=user,
|
||||
tenant=invitation.tenant,
|
||||
defaults={"role": Membership.RoleChoices.MEMBER},
|
||||
)
|
||||
invitation_roles = Role.objects.using(MainRouter.admin_db).filter(
|
||||
invitations=invitation
|
||||
)
|
||||
for role in invitation_roles:
|
||||
UserRoleRelationship.objects.using(MainRouter.admin_db).get_or_create(
|
||||
user=user,
|
||||
role=role,
|
||||
defaults={"tenant": invitation.tenant},
|
||||
)
|
||||
|
||||
invitation.state = Invitation.State.ACCEPTED
|
||||
invitation.save(using=MainRouter.admin_db)
|
||||
return invitation, membership
|
||||
|
||||
|
||||
# ToRemove after removing the fallback mechanism in /findings/metadata
|
||||
def get_findings_metadata_no_aggregations(tenant_id: str, filtered_queryset):
|
||||
filtered_ids = filtered_queryset.order_by().values("id")
|
||||
|
||||
@@ -3147,6 +3147,7 @@ class ProcessorUpdateSerializer(BaseWriteSerializer):
|
||||
|
||||
class SamlInitiateSerializer(BaseSerializerV1):
|
||||
email_domain = serializers.CharField()
|
||||
callback_url = serializers.CharField(required=False, allow_blank=True)
|
||||
|
||||
class JSONAPIMeta:
|
||||
resource_name = "saml-initiate"
|
||||
|
||||
@@ -9,7 +9,7 @@ from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from decimal import ROUND_HALF_UP, Decimal, InvalidOperation
|
||||
from urllib.parse import urljoin
|
||||
from urllib.parse import urlencode, urljoin
|
||||
|
||||
import sentry_sdk
|
||||
from allauth.socialaccount.models import SocialAccount, SocialApp
|
||||
@@ -128,6 +128,7 @@ from api.renderers import APIJSONRenderer, PlainTextRenderer
|
||||
from api.rls import Tenant
|
||||
from api.utils import (
|
||||
CustomOAuth2Client,
|
||||
accept_invitation_for_user,
|
||||
get_findings_metadata_no_aggregations,
|
||||
initialize_prowler_integration,
|
||||
initialize_prowler_provider,
|
||||
@@ -541,6 +542,46 @@ class SchemaView(SpectacularAPIView):
|
||||
return super().get(request, *args, **kwargs)
|
||||
|
||||
|
||||
SAML_CALLBACK_SESSION_KEY = "saml_callback_url"
|
||||
|
||||
|
||||
def _safe_callback_path(value):
|
||||
if not value or not isinstance(value, str):
|
||||
return None
|
||||
if not value.startswith("/") or value.startswith("//"):
|
||||
return None
|
||||
return value
|
||||
|
||||
|
||||
def _get_request_invitation_token(request):
|
||||
for source_name in ("data", "POST"):
|
||||
data = getattr(request, source_name, None) or {}
|
||||
if not hasattr(data, "get"):
|
||||
continue
|
||||
invitation_token = data.get("invitation_token")
|
||||
if invitation_token:
|
||||
return invitation_token
|
||||
|
||||
wrapped_request = getattr(request, "_request", None)
|
||||
if wrapped_request and wrapped_request is not request:
|
||||
return _get_request_invitation_token(wrapped_request)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _accept_social_invitation(request, user):
|
||||
invitation_token = _get_request_invitation_token(request)
|
||||
tenant_id = getattr(request, "prowler_invitation_tenant_id", None)
|
||||
if invitation_token and not tenant_id:
|
||||
invitation, _ = accept_invitation_for_user(
|
||||
user=user,
|
||||
invitation_token=invitation_token,
|
||||
raise_not_found=True,
|
||||
)
|
||||
tenant_id = str(invitation.tenant_id)
|
||||
return tenant_id
|
||||
|
||||
|
||||
@extend_schema(exclude=True)
|
||||
class GoogleSocialLoginView(SocialLoginView):
|
||||
adapter_class = GoogleOAuth2Adapter
|
||||
@@ -551,7 +592,11 @@ class GoogleSocialLoginView(SocialLoginView):
|
||||
original_response = super().get_response()
|
||||
|
||||
if self.user and self.user.is_authenticated:
|
||||
serializer = TokenSocialLoginSerializer(data={"email": self.user.email})
|
||||
tenant_id = _accept_social_invitation(self.request, self.user)
|
||||
serializer_data = {"email": self.user.email}
|
||||
if tenant_id:
|
||||
serializer_data["tenant_id"] = tenant_id
|
||||
serializer = TokenSocialLoginSerializer(data=serializer_data)
|
||||
try:
|
||||
serializer.is_valid(raise_exception=True)
|
||||
except TokenError as e:
|
||||
@@ -576,7 +621,11 @@ class GithubSocialLoginView(SocialLoginView):
|
||||
original_response = super().get_response()
|
||||
|
||||
if self.user and self.user.is_authenticated:
|
||||
serializer = TokenSocialLoginSerializer(data={"email": self.user.email})
|
||||
tenant_id = _accept_social_invitation(self.request, self.user)
|
||||
serializer_data = {"email": self.user.email}
|
||||
if tenant_id:
|
||||
serializer_data["tenant_id"] = tenant_id
|
||||
serializer = TokenSocialLoginSerializer(data=serializer_data)
|
||||
|
||||
try:
|
||||
serializer.is_valid(raise_exception=True)
|
||||
@@ -636,6 +685,9 @@ class CustomSAMLLoginView(LoginView):
|
||||
|
||||
This approach maintains security while providing better UX.
|
||||
"""
|
||||
callback_url = _safe_callback_path(request.GET.get("callback_url"))
|
||||
if callback_url:
|
||||
request.session[SAML_CALLBACK_SESSION_KEY] = callback_url
|
||||
if request.method == "GET":
|
||||
# Convert GET to POST while preserving parameters
|
||||
request.method = "POST"
|
||||
@@ -680,6 +732,11 @@ class SAMLInitiateAPIView(GenericAPIView):
|
||||
"saml_login", kwargs={"organization_slug": config.email_domain}
|
||||
)
|
||||
login_url = urljoin(api_host, login_path)
|
||||
callback_url = _safe_callback_path(
|
||||
serializer.validated_data.get("callback_url")
|
||||
)
|
||||
if callback_url:
|
||||
login_url = f"{login_url}?{urlencode({'callback_url': callback_url})}"
|
||||
|
||||
return redirect(login_url)
|
||||
|
||||
@@ -895,7 +952,13 @@ class TenantFinishACSView(FinishACSView):
|
||||
token=token_data, user=user
|
||||
)
|
||||
callback_url = env.str("SAML_SSO_CALLBACK_URL")
|
||||
redirect_url = f"{callback_url}?id={saml_token.id}"
|
||||
redirect_params = {"id": str(saml_token.id)}
|
||||
saml_callback_url = _safe_callback_path(
|
||||
request.session.pop(SAML_CALLBACK_SESSION_KEY, None)
|
||||
)
|
||||
if saml_callback_url:
|
||||
redirect_params["callbackUrl"] = saml_callback_url
|
||||
redirect_url = f"{callback_url}?{urlencode(redirect_params)}"
|
||||
request.session.pop("saml_user_created", None)
|
||||
|
||||
return redirect(redirect_url)
|
||||
@@ -4368,25 +4431,12 @@ class InvitationAcceptViewSet(BaseRLSViewSet):
|
||||
invitation_token = serializer.validated_data["invitation_token"]
|
||||
user_email = request.user.email
|
||||
|
||||
invitation = validate_invitation(
|
||||
invitation_token, user_email, raise_not_found=True
|
||||
)
|
||||
|
||||
# Proceed with accepting the invitation
|
||||
user = User.objects.using(MainRouter.admin_db).get(email=user_email)
|
||||
membership = Membership.objects.using(MainRouter.admin_db).create(
|
||||
invitation, membership = accept_invitation_for_user(
|
||||
user=user,
|
||||
tenant=invitation.tenant,
|
||||
invitation_token=invitation_token,
|
||||
raise_not_found=True,
|
||||
)
|
||||
user_role = []
|
||||
for role in invitation.roles.all():
|
||||
user_role.append(
|
||||
UserRoleRelationship.objects.using(MainRouter.admin_db).create(
|
||||
user=user, role=role, tenant=invitation.tenant
|
||||
)
|
||||
)
|
||||
invitation.state = Invitation.State.ACCEPTED
|
||||
invitation.save(using=MainRouter.admin_db)
|
||||
|
||||
self.response_serializer_class = MembershipSerializer
|
||||
membership_serializer = self.get_serializer(membership)
|
||||
|
||||
@@ -166,8 +166,13 @@ export const deleteSamlConfig = async (id: string) => {
|
||||
}
|
||||
};
|
||||
|
||||
export const initiateSamlAuth = async (email: string) => {
|
||||
export const initiateSamlAuth = async (email: string, callbackUrl = "/") => {
|
||||
try {
|
||||
const attributes = {
|
||||
email_domain: email,
|
||||
...(callbackUrl !== "/" && { callback_url: callbackUrl }),
|
||||
};
|
||||
|
||||
const response = await fetch(`${apiBaseUrl}/auth/saml/initiate/`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
@@ -176,9 +181,7 @@ export const initiateSamlAuth = async (email: string) => {
|
||||
body: JSON.stringify({
|
||||
data: {
|
||||
type: "saml-initiate",
|
||||
attributes: {
|
||||
email_domain: email,
|
||||
},
|
||||
attributes,
|
||||
},
|
||||
}),
|
||||
redirect: "manual",
|
||||
|
||||
@@ -13,6 +13,7 @@ const SignUp = async ({
|
||||
typeof resolvedSearchParams?.invitation_token === "string"
|
||||
? resolvedSearchParams.invitation_token
|
||||
: null;
|
||||
const isCloudEnv = process.env.NEXT_PUBLIC_IS_CLOUD_ENV === "true";
|
||||
|
||||
const GOOGLE_AUTH_URL = getAuthUrl("google");
|
||||
const GITHUB_AUTH_URL = getAuthUrl("github");
|
||||
@@ -21,6 +22,7 @@ const SignUp = async ({
|
||||
<AuthForm
|
||||
type="sign-up"
|
||||
invitationToken={invitationToken}
|
||||
isCloudEnv={isCloudEnv}
|
||||
googleAuthUrl={GOOGLE_AUTH_URL}
|
||||
githubAuthUrl={GITHUB_AUTH_URL}
|
||||
isGoogleOAuthEnabled={isGoogleOAuthEnabled}
|
||||
|
||||
@@ -3,15 +3,24 @@
|
||||
import { NextResponse } from "next/server";
|
||||
|
||||
import { signIn } from "@/auth.config";
|
||||
import {
|
||||
getInvitationTokenFromCallbackPath,
|
||||
getSafeCallbackPath,
|
||||
} from "@/lib/auth-callback-url";
|
||||
import { apiBaseUrl, baseUrl } from "@/lib/helper";
|
||||
|
||||
export async function GET(req: Request) {
|
||||
const { searchParams } = new URL(req.url);
|
||||
|
||||
const code = searchParams.get("code");
|
||||
const callbackPath = getSafeCallbackPath(searchParams);
|
||||
const invitationToken = getInvitationTokenFromCallbackPath(callbackPath);
|
||||
|
||||
const params = new URLSearchParams();
|
||||
params.append("code", code || "");
|
||||
if (invitationToken) {
|
||||
params.append("invitation_token", invitationToken);
|
||||
}
|
||||
|
||||
if (!code) {
|
||||
return NextResponse.json(
|
||||
@@ -37,18 +46,19 @@ export async function GET(req: Request) {
|
||||
const { access, refresh } = data.data.attributes;
|
||||
|
||||
try {
|
||||
const redirectPath = invitationToken ? "/" : callbackPath;
|
||||
const result = await signIn("social-oauth", {
|
||||
accessToken: access,
|
||||
refreshToken: refresh,
|
||||
redirect: false,
|
||||
callbackUrl: `${baseUrl}/`,
|
||||
callbackUrl: new URL(redirectPath, baseUrl).toString(),
|
||||
});
|
||||
|
||||
if (result?.error) {
|
||||
throw new Error(result.error);
|
||||
}
|
||||
|
||||
return NextResponse.redirect(new URL("/", baseUrl));
|
||||
return NextResponse.redirect(new URL(redirectPath, baseUrl));
|
||||
} catch (error) {
|
||||
console.error("SignIn error:", error);
|
||||
return NextResponse.redirect(
|
||||
|
||||
@@ -3,15 +3,24 @@
|
||||
import { NextResponse } from "next/server";
|
||||
|
||||
import { signIn } from "@/auth.config";
|
||||
import {
|
||||
getInvitationTokenFromCallbackPath,
|
||||
getSafeCallbackPath,
|
||||
} from "@/lib/auth-callback-url";
|
||||
import { apiBaseUrl, baseUrl } from "@/lib/helper";
|
||||
|
||||
export async function GET(req: Request) {
|
||||
const { searchParams } = new URL(req.url);
|
||||
|
||||
const code = searchParams.get("code");
|
||||
const callbackPath = getSafeCallbackPath(searchParams);
|
||||
const invitationToken = getInvitationTokenFromCallbackPath(callbackPath);
|
||||
|
||||
const params = new URLSearchParams();
|
||||
params.append("code", code || "");
|
||||
if (invitationToken) {
|
||||
params.append("invitation_token", invitationToken);
|
||||
}
|
||||
|
||||
if (!code) {
|
||||
return NextResponse.json(
|
||||
@@ -37,18 +46,19 @@ export async function GET(req: Request) {
|
||||
const { access, refresh } = data.data.attributes;
|
||||
|
||||
try {
|
||||
const redirectPath = invitationToken ? "/" : callbackPath;
|
||||
const result = await signIn("social-oauth", {
|
||||
accessToken: access,
|
||||
refreshToken: refresh,
|
||||
redirect: false,
|
||||
callbackUrl: `${baseUrl}/`,
|
||||
callbackUrl: new URL(redirectPath, baseUrl).toString(),
|
||||
});
|
||||
|
||||
if (result?.error) {
|
||||
throw new Error(result.error);
|
||||
}
|
||||
|
||||
return NextResponse.redirect(new URL("/", baseUrl));
|
||||
return NextResponse.redirect(new URL(redirectPath, baseUrl));
|
||||
} catch (error) {
|
||||
console.error("SignIn error:", error);
|
||||
return NextResponse.redirect(
|
||||
|
||||
@@ -3,11 +3,13 @@
|
||||
import { NextResponse } from "next/server";
|
||||
|
||||
import { signIn } from "@/auth.config";
|
||||
import { getSafeCallbackPath } from "@/lib/auth-callback-url";
|
||||
import { apiBaseUrl, baseUrl } from "@/lib/helper";
|
||||
|
||||
export async function GET(req: Request) {
|
||||
const { searchParams } = new URL(req.url);
|
||||
const id = searchParams.get("id");
|
||||
const callbackPath = getSafeCallbackPath(searchParams, "callbackUrl");
|
||||
|
||||
if (!id) {
|
||||
return NextResponse.json(
|
||||
@@ -40,14 +42,14 @@ export async function GET(req: Request) {
|
||||
accessToken: access,
|
||||
refreshToken: refresh,
|
||||
redirect: false,
|
||||
callbackUrl: `${baseUrl}/`,
|
||||
callbackUrl: new URL(callbackPath, baseUrl).toString(),
|
||||
});
|
||||
|
||||
if (result?.error) {
|
||||
throw new Error(result.error);
|
||||
}
|
||||
|
||||
return NextResponse.redirect(new URL("/", baseUrl));
|
||||
return NextResponse.redirect(new URL(callbackPath, baseUrl));
|
||||
} catch (error) {
|
||||
console.error("SAML authentication failed:", error);
|
||||
return NextResponse.redirect(new URL("/sign-in", baseUrl));
|
||||
|
||||
@@ -4,6 +4,7 @@ import { SignUpForm } from "@/components/auth/oss/sign-up-form";
|
||||
export const AuthForm = ({
|
||||
type,
|
||||
invitationToken,
|
||||
isCloudEnv,
|
||||
googleAuthUrl,
|
||||
githubAuthUrl,
|
||||
isGoogleOAuthEnabled,
|
||||
@@ -11,6 +12,7 @@ export const AuthForm = ({
|
||||
}: {
|
||||
type: string;
|
||||
invitationToken?: string | null;
|
||||
isCloudEnv?: boolean;
|
||||
googleAuthUrl?: string;
|
||||
githubAuthUrl?: string;
|
||||
isGoogleOAuthEnabled?: boolean;
|
||||
@@ -30,6 +32,7 @@ export const AuthForm = ({
|
||||
return (
|
||||
<SignUpForm
|
||||
invitationToken={invitationToken}
|
||||
isCloudEnv={isCloudEnv}
|
||||
googleAuthUrl={googleAuthUrl}
|
||||
githubAuthUrl={githubAuthUrl}
|
||||
isGoogleOAuthEnabled={isGoogleOAuthEnabled}
|
||||
|
||||
@@ -16,6 +16,7 @@ import { Button } from "@/components/shadcn";
|
||||
import { useToast } from "@/components/ui";
|
||||
import { CustomInput } from "@/components/ui/custom";
|
||||
import { Form } from "@/components/ui/form";
|
||||
import { getSafeCallbackPath } from "@/lib/auth-callback-url";
|
||||
import { SignInFormData, signInSchema } from "@/types";
|
||||
|
||||
export const SignInForm = ({
|
||||
@@ -32,7 +33,7 @@ export const SignInForm = ({
|
||||
const router = useRouter();
|
||||
const searchParams = useSearchParams();
|
||||
const { toast } = useToast();
|
||||
const callbackUrl = searchParams.get("callbackUrl") || "/";
|
||||
const callbackUrl = getSafeCallbackPath(searchParams, "callbackUrl");
|
||||
|
||||
useEffect(() => {
|
||||
const samlError = searchParams.get("sso_saml_failed");
|
||||
@@ -102,7 +103,7 @@ export const SignInForm = ({
|
||||
form.setValue("password", "");
|
||||
}
|
||||
|
||||
const result = await initiateSamlAuth(email);
|
||||
const result = await initiateSamlAuth(email, callbackUrl);
|
||||
|
||||
if (result.success && result.redirectUrl) {
|
||||
window.location.href = result.redirectUrl;
|
||||
@@ -181,6 +182,7 @@ export const SignInForm = ({
|
||||
<SocialButtons
|
||||
googleAuthUrl={googleAuthUrl}
|
||||
githubAuthUrl={githubAuthUrl}
|
||||
callbackUrl={callbackUrl}
|
||||
isGoogleOAuthEnabled={isGoogleOAuthEnabled}
|
||||
isGithubOAuthEnabled={isGithubOAuthEnabled}
|
||||
/>
|
||||
|
||||
@@ -41,12 +41,14 @@ const FORM_ERROR_TYPE = {
|
||||
|
||||
export const SignUpForm = ({
|
||||
invitationToken,
|
||||
isCloudEnv,
|
||||
googleAuthUrl,
|
||||
githubAuthUrl,
|
||||
isGoogleOAuthEnabled,
|
||||
isGithubOAuthEnabled,
|
||||
}: {
|
||||
invitationToken?: string | null;
|
||||
isCloudEnv?: boolean;
|
||||
googleAuthUrl?: string;
|
||||
githubAuthUrl?: string;
|
||||
isGoogleOAuthEnabled?: boolean;
|
||||
@@ -54,6 +56,9 @@ export const SignUpForm = ({
|
||||
}) => {
|
||||
const router = useRouter();
|
||||
const { toast } = useToast();
|
||||
const callbackUrl = invitationToken
|
||||
? `/invitation/accept?invitation_token=${encodeURIComponent(invitationToken)}`
|
||||
: "/";
|
||||
|
||||
const form = useForm<SignUpFormData>({
|
||||
resolver: zodResolver(signUpSchema),
|
||||
@@ -88,7 +93,7 @@ export const SignUpForm = ({
|
||||
});
|
||||
form.reset();
|
||||
|
||||
if (process.env.NEXT_PUBLIC_IS_CLOUD_ENV === "true") {
|
||||
if (isCloudEnv) {
|
||||
router.push("/email-verification");
|
||||
} else {
|
||||
router.push("/sign-in");
|
||||
@@ -200,7 +205,7 @@ export const SignUpForm = ({
|
||||
/>
|
||||
)}
|
||||
|
||||
{process.env.NEXT_PUBLIC_IS_CLOUD_ENV === "true" && (
|
||||
{isCloudEnv && (
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="termsAndConditions"
|
||||
@@ -243,13 +248,14 @@ export const SignUpForm = ({
|
||||
</form>
|
||||
</Form>
|
||||
|
||||
{!invitationToken && (
|
||||
{(!invitationToken || isCloudEnv) && (
|
||||
<>
|
||||
<AuthDivider />
|
||||
<div className="flex flex-col gap-2">
|
||||
<SocialButtons
|
||||
googleAuthUrl={googleAuthUrl}
|
||||
githubAuthUrl={githubAuthUrl}
|
||||
callbackUrl={callbackUrl}
|
||||
isGoogleOAuthEnabled={isGoogleOAuthEnabled}
|
||||
isGithubOAuthEnabled={isGithubOAuthEnabled}
|
||||
/>
|
||||
|
||||
@@ -3,81 +3,93 @@ import { Icon } from "@iconify/react";
|
||||
|
||||
import { Button } from "@/components/shadcn";
|
||||
import { CustomLink } from "@/components/ui/custom/custom-link";
|
||||
import { appendCallbackState } from "@/lib/auth-callback-url";
|
||||
|
||||
export const SocialButtons = ({
|
||||
googleAuthUrl,
|
||||
githubAuthUrl,
|
||||
callbackUrl = "/",
|
||||
isGoogleOAuthEnabled,
|
||||
isGithubOAuthEnabled,
|
||||
}: {
|
||||
googleAuthUrl?: string;
|
||||
githubAuthUrl?: string;
|
||||
callbackUrl?: string;
|
||||
isGoogleOAuthEnabled?: boolean;
|
||||
isGithubOAuthEnabled?: boolean;
|
||||
}) => (
|
||||
<>
|
||||
<Tooltip
|
||||
content={
|
||||
<div className="flex-inline text-small">
|
||||
Social Login with Google is not enabled.{" "}
|
||||
<CustomLink href="https://docs.prowler.com/projects/prowler-open-source/en/latest/tutorials/prowler-app-social-login/#google-oauth-configuration">
|
||||
Read the docs
|
||||
</CustomLink>
|
||||
</div>
|
||||
}
|
||||
placement="top"
|
||||
shadow="sm"
|
||||
isDisabled={isGoogleOAuthEnabled}
|
||||
className="w-96"
|
||||
>
|
||||
<span>
|
||||
<Button
|
||||
variant="outline"
|
||||
className="w-full"
|
||||
asChild={isGoogleOAuthEnabled}
|
||||
disabled={!isGoogleOAuthEnabled}
|
||||
>
|
||||
<a href={googleAuthUrl} className="flex items-center gap-2">
|
||||
<Icon
|
||||
icon={
|
||||
isGoogleOAuthEnabled
|
||||
? "flat-color-icons:google"
|
||||
: "simple-icons:google"
|
||||
}
|
||||
width={24}
|
||||
/>
|
||||
Continue with Google
|
||||
</a>
|
||||
</Button>
|
||||
</span>
|
||||
</Tooltip>
|
||||
<Tooltip
|
||||
content={
|
||||
<div className="flex-inline text-small">
|
||||
Social Login with Github is not enabled.{" "}
|
||||
<CustomLink href="https://docs.prowler.com/projects/prowler-open-source/en/latest/tutorials/prowler-app-social-login/#github-oauth-configuration">
|
||||
Read the docs
|
||||
</CustomLink>
|
||||
</div>
|
||||
}
|
||||
placement="top"
|
||||
shadow="sm"
|
||||
isDisabled={isGithubOAuthEnabled}
|
||||
className="w-96"
|
||||
>
|
||||
<span>
|
||||
<Button
|
||||
variant="outline"
|
||||
className="w-full"
|
||||
asChild={isGithubOAuthEnabled}
|
||||
disabled={!isGithubOAuthEnabled}
|
||||
>
|
||||
<a href={githubAuthUrl} className="flex items-center gap-2">
|
||||
<Icon icon="simple-icons:github" width={24} />
|
||||
Continue with Github
|
||||
</a>
|
||||
</Button>
|
||||
</span>
|
||||
</Tooltip>
|
||||
</>
|
||||
);
|
||||
}) => {
|
||||
const googleUrl = googleAuthUrl
|
||||
? appendCallbackState(googleAuthUrl, callbackUrl)
|
||||
: undefined;
|
||||
const githubUrl = githubAuthUrl
|
||||
? appendCallbackState(githubAuthUrl, callbackUrl)
|
||||
: undefined;
|
||||
|
||||
return (
|
||||
<>
|
||||
<Tooltip
|
||||
content={
|
||||
<div className="flex-inline text-small">
|
||||
Social Login with Google is not enabled.{" "}
|
||||
<CustomLink href="https://docs.prowler.com/projects/prowler-open-source/en/latest/tutorials/prowler-app-social-login/#google-oauth-configuration">
|
||||
Read the docs
|
||||
</CustomLink>
|
||||
</div>
|
||||
}
|
||||
placement="top"
|
||||
shadow="sm"
|
||||
isDisabled={isGoogleOAuthEnabled}
|
||||
className="w-96"
|
||||
>
|
||||
<span>
|
||||
<Button
|
||||
variant="outline"
|
||||
className="w-full"
|
||||
asChild={isGoogleOAuthEnabled}
|
||||
disabled={!isGoogleOAuthEnabled}
|
||||
>
|
||||
<a href={googleUrl} className="flex items-center gap-2">
|
||||
<Icon
|
||||
icon={
|
||||
isGoogleOAuthEnabled
|
||||
? "flat-color-icons:google"
|
||||
: "simple-icons:google"
|
||||
}
|
||||
width={24}
|
||||
/>
|
||||
Continue with Google
|
||||
</a>
|
||||
</Button>
|
||||
</span>
|
||||
</Tooltip>
|
||||
<Tooltip
|
||||
content={
|
||||
<div className="flex-inline text-small">
|
||||
Social Login with Github is not enabled.{" "}
|
||||
<CustomLink href="https://docs.prowler.com/projects/prowler-open-source/en/latest/tutorials/prowler-app-social-login/#github-oauth-configuration">
|
||||
Read the docs
|
||||
</CustomLink>
|
||||
</div>
|
||||
}
|
||||
placement="top"
|
||||
shadow="sm"
|
||||
isDisabled={isGithubOAuthEnabled}
|
||||
className="w-96"
|
||||
>
|
||||
<span>
|
||||
<Button
|
||||
variant="outline"
|
||||
className="w-full"
|
||||
asChild={isGithubOAuthEnabled}
|
||||
disabled={!isGithubOAuthEnabled}
|
||||
>
|
||||
<a href={githubUrl} className="flex items-center gap-2">
|
||||
<Icon icon="simple-icons:github" width={24} />
|
||||
Continue with Github
|
||||
</a>
|
||||
</Button>
|
||||
</span>
|
||||
</Tooltip>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
|
||||
import {
|
||||
appendCallbackState,
|
||||
getInvitationTokenFromCallbackPath,
|
||||
getSafeCallbackPath,
|
||||
} from "@/lib/auth-callback-url";
|
||||
|
||||
describe("auth callback URL helpers", () => {
|
||||
describe("when appending OAuth state", () => {
|
||||
it("should add a relative callback path as provider state", () => {
|
||||
const authUrl = "https://accounts.example.com/oauth?client_id=client";
|
||||
const callbackPath = "/invitation/accept?invitation_token=test-token";
|
||||
|
||||
const result = appendCallbackState(authUrl, callbackPath);
|
||||
|
||||
expect(new URL(result).searchParams.get("state")).toBe(callbackPath);
|
||||
});
|
||||
|
||||
it("should not add state for the default callback path", () => {
|
||||
const authUrl = "https://accounts.example.com/oauth?client_id=client";
|
||||
|
||||
const result = appendCallbackState(authUrl, "/");
|
||||
|
||||
expect(new URL(result).searchParams.has("state")).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("when reading callback paths", () => {
|
||||
it("should return relative callback paths", () => {
|
||||
const params = new URLSearchParams({
|
||||
state: "/invitation/accept?invitation_token=test-token",
|
||||
});
|
||||
|
||||
const result = getSafeCallbackPath(params);
|
||||
|
||||
expect(result).toBe("/invitation/accept?invitation_token=test-token");
|
||||
});
|
||||
|
||||
it("should reject external callback URLs", () => {
|
||||
const params = new URLSearchParams({
|
||||
state: "https://attacker.example/phishing",
|
||||
});
|
||||
|
||||
const result = getSafeCallbackPath(params);
|
||||
|
||||
expect(result).toBe("/");
|
||||
});
|
||||
});
|
||||
|
||||
describe("when reading invitation tokens", () => {
|
||||
it("should return invitation tokens from safe callback paths", () => {
|
||||
const callbackPath = "/invitation/accept?invitation_token=test-token";
|
||||
|
||||
const result = getInvitationTokenFromCallbackPath(callbackPath);
|
||||
|
||||
expect(result).toBe("test-token");
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,50 @@
|
||||
const DEFAULT_CALLBACK_PATH = "/";
|
||||
const INVITATION_TOKEN_PARAM = "invitation_token";
|
||||
|
||||
type CallbackSearchParams = {
|
||||
get(name: string): string | null;
|
||||
};
|
||||
|
||||
export const getSafeCallbackPathFromValue = (
|
||||
value: string | null | undefined,
|
||||
) => {
|
||||
if (!value || !value.startsWith("/") || value.startsWith("//")) {
|
||||
return DEFAULT_CALLBACK_PATH;
|
||||
}
|
||||
|
||||
return value;
|
||||
};
|
||||
|
||||
export const getSafeCallbackPath = (
|
||||
searchParams: CallbackSearchParams,
|
||||
key = "state",
|
||||
) => getSafeCallbackPathFromValue(searchParams.get(key));
|
||||
|
||||
export const appendCallbackState = (authUrl: string, callbackPath: string) => {
|
||||
const safeCallbackPath = getSafeCallbackPathFromValue(callbackPath);
|
||||
if (safeCallbackPath === DEFAULT_CALLBACK_PATH) {
|
||||
return authUrl;
|
||||
}
|
||||
|
||||
try {
|
||||
const url = new URL(authUrl);
|
||||
url.searchParams.set("state", safeCallbackPath);
|
||||
return url.toString();
|
||||
} catch (_error) {
|
||||
return authUrl;
|
||||
}
|
||||
};
|
||||
|
||||
export const getInvitationTokenFromCallbackPath = (callbackPath: string) => {
|
||||
const safeCallbackPath = getSafeCallbackPathFromValue(callbackPath);
|
||||
if (safeCallbackPath === DEFAULT_CALLBACK_PATH) {
|
||||
return null;
|
||||
}
|
||||
|
||||
try {
|
||||
const url = new URL(safeCallbackPath, "http://localhost");
|
||||
return url.searchParams.get(INVITATION_TOKEN_PARAM);
|
||||
} catch (_error) {
|
||||
return null;
|
||||
}
|
||||
};
|
||||
Reference in New Issue
Block a user