fix: handle invitations in social and SAML auth

- Preserve invitation callbacks through social and SAML login
- Accept invited users without creating a default tenant
- Add regression coverage for invitation acceptance flows
This commit is contained in:
Adrián Jesús Peña Rodríguez
2026-07-01 12:41:12 +02:00
parent a212916a49
commit 20f849f58c
17 changed files with 502 additions and 126 deletions
+47 -21
View File
@@ -9,6 +9,7 @@ from api.models import (
User,
UserRoleRelationship,
)
from api.utils import accept_invitation_for_user
from django.db import transaction
@@ -20,6 +21,22 @@ class ProwlerSocialAccountAdapter(DefaultSocialAccountAdapter):
except User.DoesNotExist:
return None
@staticmethod
def _get_invitation_token(request):
for source_name in ("data", "POST"):
data = getattr(request, source_name, None) or {}
if not hasattr(data, "get"):
continue
invitation_token = data.get("invitation_token")
if invitation_token:
return invitation_token
wrapped_request = getattr(request, "_request", None)
if wrapped_request and wrapped_request is not request:
return ProwlerSocialAccountAdapter._get_invitation_token(wrapped_request)
return None
def pre_social_login(self, request, sociallogin):
# Link existing accounts with the same email address
email = sociallogin.account.extra_data.get("email")
@@ -83,29 +100,38 @@ class ProwlerSocialAccountAdapter(DefaultSocialAccountAdapter):
user.name = social_account_name
user.save(using=MainRouter.admin_db)
tenant = Tenant.objects.using(MainRouter.admin_db).create(
name=f"{user.email.split('@')[0]} default tenant"
)
with rls_transaction(str(tenant.id)):
Membership.objects.using(MainRouter.admin_db).create(
user=user, tenant=tenant, role=Membership.RoleChoices.OWNER
)
role = Role.objects.using(MainRouter.admin_db).create(
name="admin",
tenant_id=tenant.id,
manage_users=True,
manage_account=True,
manage_billing=True,
manage_providers=True,
manage_integrations=True,
manage_scans=True,
unlimited_visibility=True,
)
UserRoleRelationship.objects.using(MainRouter.admin_db).create(
invitation_token = self._get_invitation_token(request)
if invitation_token:
invitation, _ = accept_invitation_for_user(
user=user,
role=role,
tenant_id=tenant.id,
invitation_token=invitation_token,
)
request.prowler_invitation_token = invitation_token
request.prowler_invitation_tenant_id = str(invitation.tenant_id)
else:
tenant = Tenant.objects.using(MainRouter.admin_db).create(
name=f"{user.email.split('@')[0]} default tenant"
)
with rls_transaction(str(tenant.id)):
Membership.objects.using(MainRouter.admin_db).create(
user=user, tenant=tenant, role=Membership.RoleChoices.OWNER
)
role = Role.objects.using(MainRouter.admin_db).create(
name="admin",
tenant_id=tenant.id,
manage_users=True,
manage_account=True,
manage_billing=True,
manage_providers=True,
manage_integrations=True,
manage_scans=True,
unlimited_visibility=True,
)
UserRoleRelationship.objects.using(MainRouter.admin_db).create(
user=user,
role=role,
tenant_id=tenant.id,
)
else:
request.session["saml_user_created"] = str(user.id)
+39 -1
View File
@@ -5,7 +5,7 @@ import pytest
from allauth.socialaccount.models import SocialLogin
from api.adapters import ProwlerSocialAccountAdapter
from api.db_router import MainRouter
from api.models import SAMLConfiguration
from api.models import Invitation, Membership, SAMLConfiguration, Tenant
from django.contrib.auth import get_user_model
User = get_user_model()
@@ -188,6 +188,44 @@ class TestProwlerSocialAccountAdapter:
_, called_user = call_args[0]
assert called_user.email == create_test_user.email
def test_save_user_social_with_invitation_joins_invited_tenant(
self, rf, create_test_user, tenants_fixture
):
adapter = ProwlerSocialAccountAdapter()
invited_tenant = tenants_fixture[2]
invited_email = "frank-invited@example.com"
invitation = Invitation.objects.create(
tenant=invited_tenant,
email=invited_email,
inviter=create_test_user,
)
request = rf.post("/", data={"invitation_token": invitation.token})
request.session = {}
sociallogin = MagicMock(spec=SocialLogin)
sociallogin.provider = MagicMock()
sociallogin.provider.id = "google"
sociallogin.account = MagicMock()
sociallogin.account.extra_data = {"name": "Frank"}
real_user = User.objects.create_user(
name="Frank", email=invited_email, password="Secret123!"
)
tenants_before = Tenant.objects.count()
with patch("api.adapters.super") as mock_super:
mock_super.return_value.save_user.return_value = real_user
adapter.save_user(request, sociallogin)
invitation.refresh_from_db()
assert invitation.state == Invitation.State.ACCEPTED
assert Tenant.objects.count() == tenants_before
assert Membership.objects.filter(
user=real_user,
tenant=invited_tenant,
role=Membership.RoleChoices.MEMBER,
).exists()
def test_save_user_saml_sets_session_flag(self, rf):
adapter = ProwlerSocialAccountAdapter()
request = rf.get("/")
+67 -1
View File
@@ -8356,6 +8356,39 @@ class TestInvitationViewSet:
).exists()
assert invitation.state == Invitation.State.ACCEPTED.value
def test_invitations_accept_invitation_existing_membership(
self,
authenticated_client,
create_test_user,
tenants_fixture,
):
*_, tenant = tenants_fixture
user = create_test_user
invitation = Invitation.objects.create(
tenant=tenant,
email=TEST_USER,
inviter=user,
expires_at=self.TOMORROW,
)
Membership.objects.create(user=user, tenant=tenant)
response = authenticated_client.post(
reverse("invitation-accept"),
data={"invitation_token": invitation.token},
format="json",
)
assert response.status_code == status.HTTP_201_CREATED
invitation.refresh_from_db()
assert invitation.state == Invitation.State.ACCEPTED.value
assert (
Membership.objects.filter(
user__email__iexact=user.email, tenant=tenant
).count()
== 1
)
def test_invitations_accept_invitation_invalid_token(self, authenticated_client):
data = {
"invitation_token": "invalid_token",
@@ -13458,6 +13491,37 @@ class TestSAMLInitiateAPIView:
)
assert "SAMLRequest" not in response.url
def test_valid_email_domain_preserves_safe_callback_url(
self, authenticated_client, saml_setup
):
url = reverse("api_saml_initiate")
callback_url = "/invitation/accept?invitation_token=test-token"
payload = {
"email_domain": saml_setup["email"],
"callback_url": callback_url,
}
response = authenticated_client.post(url, data=payload, format="json")
assert response.status_code == status.HTTP_302_FOUND
query_params = parse_qs(urlparse(response.url).query)
assert query_params["callback_url"] == [callback_url]
def test_valid_email_domain_rejects_external_callback_url(
self, authenticated_client, saml_setup
):
url = reverse("api_saml_initiate")
payload = {
"email_domain": saml_setup["email"],
"callback_url": "https://attacker.example/invitation",
}
response = authenticated_client.post(url, data=payload, format="json")
assert response.status_code == status.HTTP_302_FOUND
query_params = parse_qs(urlparse(response.url).query)
assert "callback_url" not in query_params
def test_invalid_email_domain(self, authenticated_client):
url = reverse("api_saml_initiate")
payload = {"email_domain": "user@unauthorized.com"}
@@ -13645,7 +13709,8 @@ class TestTenantFinishACSView:
)
)
request.user = user
request.session = {}
callback_url = "/invitation/accept?invitation_token=test-token"
request.session = {"saml_callback_url": callback_url}
with (
patch(
@@ -13687,6 +13752,7 @@ class TestTenantFinishACSView:
assert parsed_url.netloc == expected_callback_host
query_params = parse_qs(parsed_url.query)
assert "id" in query_params
assert query_params["callbackUrl"] == [callback_url]
token_id = query_params["id"][0]
token_obj = SAMLToken.objects.get(id=token_id)
+36 -1
View File
@@ -7,7 +7,16 @@ from allauth.socialaccount.providers.oauth2.client import OAuth2Client
from api.db_router import MainRouter
from api.db_utils import rls_transaction
from api.exceptions import InvitationTokenExpiredException
from api.models import Integration, Invitation, Processor, Provider, Resource
from api.models import (
Integration,
Invitation,
Membership,
Processor,
Provider,
Resource,
Role,
UserRoleRelationship,
)
from api.v1.serializers import FindingMetadataSerializer
from django.contrib.postgres.aggregates import ArrayAgg
from django.db.models import Subquery
@@ -538,6 +547,32 @@ def validate_invitation(
return invitation
def accept_invitation_for_user(
*, user, invitation_token: str, raise_not_found: bool = False
):
invitation = validate_invitation(
invitation_token, user.email, raise_not_found=raise_not_found
)
membership, _ = Membership.objects.using(MainRouter.admin_db).get_or_create(
user=user,
tenant=invitation.tenant,
defaults={"role": Membership.RoleChoices.MEMBER},
)
invitation_roles = Role.objects.using(MainRouter.admin_db).filter(
invitations=invitation
)
for role in invitation_roles:
UserRoleRelationship.objects.using(MainRouter.admin_db).get_or_create(
user=user,
role=role,
defaults={"tenant": invitation.tenant},
)
invitation.state = Invitation.State.ACCEPTED
invitation.save(using=MainRouter.admin_db)
return invitation, membership
# ToRemove after removing the fallback mechanism in /findings/metadata
def get_findings_metadata_no_aggregations(tenant_id: str, filtered_queryset):
filtered_ids = filtered_queryset.order_by().values("id")
+1
View File
@@ -3147,6 +3147,7 @@ class ProcessorUpdateSerializer(BaseWriteSerializer):
class SamlInitiateSerializer(BaseSerializerV1):
email_domain = serializers.CharField()
callback_url = serializers.CharField(required=False, allow_blank=True)
class JSONAPIMeta:
resource_name = "saml-initiate"
+70 -20
View File
@@ -9,7 +9,7 @@ from collections import defaultdict
from copy import deepcopy
from datetime import UTC, datetime, timedelta
from decimal import ROUND_HALF_UP, Decimal, InvalidOperation
from urllib.parse import urljoin
from urllib.parse import urlencode, urljoin
import sentry_sdk
from allauth.socialaccount.models import SocialAccount, SocialApp
@@ -128,6 +128,7 @@ from api.renderers import APIJSONRenderer, PlainTextRenderer
from api.rls import Tenant
from api.utils import (
CustomOAuth2Client,
accept_invitation_for_user,
get_findings_metadata_no_aggregations,
initialize_prowler_integration,
initialize_prowler_provider,
@@ -541,6 +542,46 @@ class SchemaView(SpectacularAPIView):
return super().get(request, *args, **kwargs)
SAML_CALLBACK_SESSION_KEY = "saml_callback_url"
def _safe_callback_path(value):
if not value or not isinstance(value, str):
return None
if not value.startswith("/") or value.startswith("//"):
return None
return value
def _get_request_invitation_token(request):
for source_name in ("data", "POST"):
data = getattr(request, source_name, None) or {}
if not hasattr(data, "get"):
continue
invitation_token = data.get("invitation_token")
if invitation_token:
return invitation_token
wrapped_request = getattr(request, "_request", None)
if wrapped_request and wrapped_request is not request:
return _get_request_invitation_token(wrapped_request)
return None
def _accept_social_invitation(request, user):
invitation_token = _get_request_invitation_token(request)
tenant_id = getattr(request, "prowler_invitation_tenant_id", None)
if invitation_token and not tenant_id:
invitation, _ = accept_invitation_for_user(
user=user,
invitation_token=invitation_token,
raise_not_found=True,
)
tenant_id = str(invitation.tenant_id)
return tenant_id
@extend_schema(exclude=True)
class GoogleSocialLoginView(SocialLoginView):
adapter_class = GoogleOAuth2Adapter
@@ -551,7 +592,11 @@ class GoogleSocialLoginView(SocialLoginView):
original_response = super().get_response()
if self.user and self.user.is_authenticated:
serializer = TokenSocialLoginSerializer(data={"email": self.user.email})
tenant_id = _accept_social_invitation(self.request, self.user)
serializer_data = {"email": self.user.email}
if tenant_id:
serializer_data["tenant_id"] = tenant_id
serializer = TokenSocialLoginSerializer(data=serializer_data)
try:
serializer.is_valid(raise_exception=True)
except TokenError as e:
@@ -576,7 +621,11 @@ class GithubSocialLoginView(SocialLoginView):
original_response = super().get_response()
if self.user and self.user.is_authenticated:
serializer = TokenSocialLoginSerializer(data={"email": self.user.email})
tenant_id = _accept_social_invitation(self.request, self.user)
serializer_data = {"email": self.user.email}
if tenant_id:
serializer_data["tenant_id"] = tenant_id
serializer = TokenSocialLoginSerializer(data=serializer_data)
try:
serializer.is_valid(raise_exception=True)
@@ -636,6 +685,9 @@ class CustomSAMLLoginView(LoginView):
This approach maintains security while providing better UX.
"""
callback_url = _safe_callback_path(request.GET.get("callback_url"))
if callback_url:
request.session[SAML_CALLBACK_SESSION_KEY] = callback_url
if request.method == "GET":
# Convert GET to POST while preserving parameters
request.method = "POST"
@@ -680,6 +732,11 @@ class SAMLInitiateAPIView(GenericAPIView):
"saml_login", kwargs={"organization_slug": config.email_domain}
)
login_url = urljoin(api_host, login_path)
callback_url = _safe_callback_path(
serializer.validated_data.get("callback_url")
)
if callback_url:
login_url = f"{login_url}?{urlencode({'callback_url': callback_url})}"
return redirect(login_url)
@@ -895,7 +952,13 @@ class TenantFinishACSView(FinishACSView):
token=token_data, user=user
)
callback_url = env.str("SAML_SSO_CALLBACK_URL")
redirect_url = f"{callback_url}?id={saml_token.id}"
redirect_params = {"id": str(saml_token.id)}
saml_callback_url = _safe_callback_path(
request.session.pop(SAML_CALLBACK_SESSION_KEY, None)
)
if saml_callback_url:
redirect_params["callbackUrl"] = saml_callback_url
redirect_url = f"{callback_url}?{urlencode(redirect_params)}"
request.session.pop("saml_user_created", None)
return redirect(redirect_url)
@@ -4368,25 +4431,12 @@ class InvitationAcceptViewSet(BaseRLSViewSet):
invitation_token = serializer.validated_data["invitation_token"]
user_email = request.user.email
invitation = validate_invitation(
invitation_token, user_email, raise_not_found=True
)
# Proceed with accepting the invitation
user = User.objects.using(MainRouter.admin_db).get(email=user_email)
membership = Membership.objects.using(MainRouter.admin_db).create(
invitation, membership = accept_invitation_for_user(
user=user,
tenant=invitation.tenant,
invitation_token=invitation_token,
raise_not_found=True,
)
user_role = []
for role in invitation.roles.all():
user_role.append(
UserRoleRelationship.objects.using(MainRouter.admin_db).create(
user=user, role=role, tenant=invitation.tenant
)
)
invitation.state = Invitation.State.ACCEPTED
invitation.save(using=MainRouter.admin_db)
self.response_serializer_class = MembershipSerializer
membership_serializer = self.get_serializer(membership)
+7 -4
View File
@@ -166,8 +166,13 @@ export const deleteSamlConfig = async (id: string) => {
}
};
export const initiateSamlAuth = async (email: string) => {
export const initiateSamlAuth = async (email: string, callbackUrl = "/") => {
try {
const attributes = {
email_domain: email,
...(callbackUrl !== "/" && { callback_url: callbackUrl }),
};
const response = await fetch(`${apiBaseUrl}/auth/saml/initiate/`, {
method: "POST",
headers: {
@@ -176,9 +181,7 @@ export const initiateSamlAuth = async (email: string) => {
body: JSON.stringify({
data: {
type: "saml-initiate",
attributes: {
email_domain: email,
},
attributes,
},
}),
redirect: "manual",
@@ -13,6 +13,7 @@ const SignUp = async ({
typeof resolvedSearchParams?.invitation_token === "string"
? resolvedSearchParams.invitation_token
: null;
const isCloudEnv = process.env.NEXT_PUBLIC_IS_CLOUD_ENV === "true";
const GOOGLE_AUTH_URL = getAuthUrl("google");
const GITHUB_AUTH_URL = getAuthUrl("github");
@@ -21,6 +22,7 @@ const SignUp = async ({
<AuthForm
type="sign-up"
invitationToken={invitationToken}
isCloudEnv={isCloudEnv}
googleAuthUrl={GOOGLE_AUTH_URL}
githubAuthUrl={GITHUB_AUTH_URL}
isGoogleOAuthEnabled={isGoogleOAuthEnabled}
+12 -2
View File
@@ -3,15 +3,24 @@
import { NextResponse } from "next/server";
import { signIn } from "@/auth.config";
import {
getInvitationTokenFromCallbackPath,
getSafeCallbackPath,
} from "@/lib/auth-callback-url";
import { apiBaseUrl, baseUrl } from "@/lib/helper";
export async function GET(req: Request) {
const { searchParams } = new URL(req.url);
const code = searchParams.get("code");
const callbackPath = getSafeCallbackPath(searchParams);
const invitationToken = getInvitationTokenFromCallbackPath(callbackPath);
const params = new URLSearchParams();
params.append("code", code || "");
if (invitationToken) {
params.append("invitation_token", invitationToken);
}
if (!code) {
return NextResponse.json(
@@ -37,18 +46,19 @@ export async function GET(req: Request) {
const { access, refresh } = data.data.attributes;
try {
const redirectPath = invitationToken ? "/" : callbackPath;
const result = await signIn("social-oauth", {
accessToken: access,
refreshToken: refresh,
redirect: false,
callbackUrl: `${baseUrl}/`,
callbackUrl: new URL(redirectPath, baseUrl).toString(),
});
if (result?.error) {
throw new Error(result.error);
}
return NextResponse.redirect(new URL("/", baseUrl));
return NextResponse.redirect(new URL(redirectPath, baseUrl));
} catch (error) {
console.error("SignIn error:", error);
return NextResponse.redirect(
+12 -2
View File
@@ -3,15 +3,24 @@
import { NextResponse } from "next/server";
import { signIn } from "@/auth.config";
import {
getInvitationTokenFromCallbackPath,
getSafeCallbackPath,
} from "@/lib/auth-callback-url";
import { apiBaseUrl, baseUrl } from "@/lib/helper";
export async function GET(req: Request) {
const { searchParams } = new URL(req.url);
const code = searchParams.get("code");
const callbackPath = getSafeCallbackPath(searchParams);
const invitationToken = getInvitationTokenFromCallbackPath(callbackPath);
const params = new URLSearchParams();
params.append("code", code || "");
if (invitationToken) {
params.append("invitation_token", invitationToken);
}
if (!code) {
return NextResponse.json(
@@ -37,18 +46,19 @@ export async function GET(req: Request) {
const { access, refresh } = data.data.attributes;
try {
const redirectPath = invitationToken ? "/" : callbackPath;
const result = await signIn("social-oauth", {
accessToken: access,
refreshToken: refresh,
redirect: false,
callbackUrl: `${baseUrl}/`,
callbackUrl: new URL(redirectPath, baseUrl).toString(),
});
if (result?.error) {
throw new Error(result.error);
}
return NextResponse.redirect(new URL("/", baseUrl));
return NextResponse.redirect(new URL(redirectPath, baseUrl));
} catch (error) {
console.error("SignIn error:", error);
return NextResponse.redirect(
+4 -2
View File
@@ -3,11 +3,13 @@
import { NextResponse } from "next/server";
import { signIn } from "@/auth.config";
import { getSafeCallbackPath } from "@/lib/auth-callback-url";
import { apiBaseUrl, baseUrl } from "@/lib/helper";
export async function GET(req: Request) {
const { searchParams } = new URL(req.url);
const id = searchParams.get("id");
const callbackPath = getSafeCallbackPath(searchParams, "callbackUrl");
if (!id) {
return NextResponse.json(
@@ -40,14 +42,14 @@ export async function GET(req: Request) {
accessToken: access,
refreshToken: refresh,
redirect: false,
callbackUrl: `${baseUrl}/`,
callbackUrl: new URL(callbackPath, baseUrl).toString(),
});
if (result?.error) {
throw new Error(result.error);
}
return NextResponse.redirect(new URL("/", baseUrl));
return NextResponse.redirect(new URL(callbackPath, baseUrl));
} catch (error) {
console.error("SAML authentication failed:", error);
return NextResponse.redirect(new URL("/sign-in", baseUrl));
+3
View File
@@ -4,6 +4,7 @@ import { SignUpForm } from "@/components/auth/oss/sign-up-form";
export const AuthForm = ({
type,
invitationToken,
isCloudEnv,
googleAuthUrl,
githubAuthUrl,
isGoogleOAuthEnabled,
@@ -11,6 +12,7 @@ export const AuthForm = ({
}: {
type: string;
invitationToken?: string | null;
isCloudEnv?: boolean;
googleAuthUrl?: string;
githubAuthUrl?: string;
isGoogleOAuthEnabled?: boolean;
@@ -30,6 +32,7 @@ export const AuthForm = ({
return (
<SignUpForm
invitationToken={invitationToken}
isCloudEnv={isCloudEnv}
googleAuthUrl={googleAuthUrl}
githubAuthUrl={githubAuthUrl}
isGoogleOAuthEnabled={isGoogleOAuthEnabled}
+4 -2
View File
@@ -16,6 +16,7 @@ import { Button } from "@/components/shadcn";
import { useToast } from "@/components/ui";
import { CustomInput } from "@/components/ui/custom";
import { Form } from "@/components/ui/form";
import { getSafeCallbackPath } from "@/lib/auth-callback-url";
import { SignInFormData, signInSchema } from "@/types";
export const SignInForm = ({
@@ -32,7 +33,7 @@ export const SignInForm = ({
const router = useRouter();
const searchParams = useSearchParams();
const { toast } = useToast();
const callbackUrl = searchParams.get("callbackUrl") || "/";
const callbackUrl = getSafeCallbackPath(searchParams, "callbackUrl");
useEffect(() => {
const samlError = searchParams.get("sso_saml_failed");
@@ -102,7 +103,7 @@ export const SignInForm = ({
form.setValue("password", "");
}
const result = await initiateSamlAuth(email);
const result = await initiateSamlAuth(email, callbackUrl);
if (result.success && result.redirectUrl) {
window.location.href = result.redirectUrl;
@@ -181,6 +182,7 @@ export const SignInForm = ({
<SocialButtons
googleAuthUrl={googleAuthUrl}
githubAuthUrl={githubAuthUrl}
callbackUrl={callbackUrl}
isGoogleOAuthEnabled={isGoogleOAuthEnabled}
isGithubOAuthEnabled={isGithubOAuthEnabled}
/>
+9 -3
View File
@@ -41,12 +41,14 @@ const FORM_ERROR_TYPE = {
export const SignUpForm = ({
invitationToken,
isCloudEnv,
googleAuthUrl,
githubAuthUrl,
isGoogleOAuthEnabled,
isGithubOAuthEnabled,
}: {
invitationToken?: string | null;
isCloudEnv?: boolean;
googleAuthUrl?: string;
githubAuthUrl?: string;
isGoogleOAuthEnabled?: boolean;
@@ -54,6 +56,9 @@ export const SignUpForm = ({
}) => {
const router = useRouter();
const { toast } = useToast();
const callbackUrl = invitationToken
? `/invitation/accept?invitation_token=${encodeURIComponent(invitationToken)}`
: "/";
const form = useForm<SignUpFormData>({
resolver: zodResolver(signUpSchema),
@@ -88,7 +93,7 @@ export const SignUpForm = ({
});
form.reset();
if (process.env.NEXT_PUBLIC_IS_CLOUD_ENV === "true") {
if (isCloudEnv) {
router.push("/email-verification");
} else {
router.push("/sign-in");
@@ -200,7 +205,7 @@ export const SignUpForm = ({
/>
)}
{process.env.NEXT_PUBLIC_IS_CLOUD_ENV === "true" && (
{isCloudEnv && (
<FormField
control={form.control}
name="termsAndConditions"
@@ -243,13 +248,14 @@ export const SignUpForm = ({
</form>
</Form>
{!invitationToken && (
{(!invitationToken || isCloudEnv) && (
<>
<AuthDivider />
<div className="flex flex-col gap-2">
<SocialButtons
googleAuthUrl={googleAuthUrl}
githubAuthUrl={githubAuthUrl}
callbackUrl={callbackUrl}
isGoogleOAuthEnabled={isGoogleOAuthEnabled}
isGithubOAuthEnabled={isGithubOAuthEnabled}
/>
+79 -67
View File
@@ -3,81 +3,93 @@ import { Icon } from "@iconify/react";
import { Button } from "@/components/shadcn";
import { CustomLink } from "@/components/ui/custom/custom-link";
import { appendCallbackState } from "@/lib/auth-callback-url";
export const SocialButtons = ({
googleAuthUrl,
githubAuthUrl,
callbackUrl = "/",
isGoogleOAuthEnabled,
isGithubOAuthEnabled,
}: {
googleAuthUrl?: string;
githubAuthUrl?: string;
callbackUrl?: string;
isGoogleOAuthEnabled?: boolean;
isGithubOAuthEnabled?: boolean;
}) => (
<>
<Tooltip
content={
<div className="flex-inline text-small">
Social Login with Google is not enabled.{" "}
<CustomLink href="https://docs.prowler.com/projects/prowler-open-source/en/latest/tutorials/prowler-app-social-login/#google-oauth-configuration">
Read the docs
</CustomLink>
</div>
}
placement="top"
shadow="sm"
isDisabled={isGoogleOAuthEnabled}
className="w-96"
>
<span>
<Button
variant="outline"
className="w-full"
asChild={isGoogleOAuthEnabled}
disabled={!isGoogleOAuthEnabled}
>
<a href={googleAuthUrl} className="flex items-center gap-2">
<Icon
icon={
isGoogleOAuthEnabled
? "flat-color-icons:google"
: "simple-icons:google"
}
width={24}
/>
Continue with Google
</a>
</Button>
</span>
</Tooltip>
<Tooltip
content={
<div className="flex-inline text-small">
Social Login with Github is not enabled.{" "}
<CustomLink href="https://docs.prowler.com/projects/prowler-open-source/en/latest/tutorials/prowler-app-social-login/#github-oauth-configuration">
Read the docs
</CustomLink>
</div>
}
placement="top"
shadow="sm"
isDisabled={isGithubOAuthEnabled}
className="w-96"
>
<span>
<Button
variant="outline"
className="w-full"
asChild={isGithubOAuthEnabled}
disabled={!isGithubOAuthEnabled}
>
<a href={githubAuthUrl} className="flex items-center gap-2">
<Icon icon="simple-icons:github" width={24} />
Continue with Github
</a>
</Button>
</span>
</Tooltip>
</>
);
}) => {
const googleUrl = googleAuthUrl
? appendCallbackState(googleAuthUrl, callbackUrl)
: undefined;
const githubUrl = githubAuthUrl
? appendCallbackState(githubAuthUrl, callbackUrl)
: undefined;
return (
<>
<Tooltip
content={
<div className="flex-inline text-small">
Social Login with Google is not enabled.{" "}
<CustomLink href="https://docs.prowler.com/projects/prowler-open-source/en/latest/tutorials/prowler-app-social-login/#google-oauth-configuration">
Read the docs
</CustomLink>
</div>
}
placement="top"
shadow="sm"
isDisabled={isGoogleOAuthEnabled}
className="w-96"
>
<span>
<Button
variant="outline"
className="w-full"
asChild={isGoogleOAuthEnabled}
disabled={!isGoogleOAuthEnabled}
>
<a href={googleUrl} className="flex items-center gap-2">
<Icon
icon={
isGoogleOAuthEnabled
? "flat-color-icons:google"
: "simple-icons:google"
}
width={24}
/>
Continue with Google
</a>
</Button>
</span>
</Tooltip>
<Tooltip
content={
<div className="flex-inline text-small">
Social Login with Github is not enabled.{" "}
<CustomLink href="https://docs.prowler.com/projects/prowler-open-source/en/latest/tutorials/prowler-app-social-login/#github-oauth-configuration">
Read the docs
</CustomLink>
</div>
}
placement="top"
shadow="sm"
isDisabled={isGithubOAuthEnabled}
className="w-96"
>
<span>
<Button
variant="outline"
className="w-full"
asChild={isGithubOAuthEnabled}
disabled={!isGithubOAuthEnabled}
>
<a href={githubUrl} className="flex items-center gap-2">
<Icon icon="simple-icons:github" width={24} />
Continue with Github
</a>
</Button>
</span>
</Tooltip>
</>
);
};
+60
View File
@@ -0,0 +1,60 @@
import { describe, expect, it } from "vitest";
import {
appendCallbackState,
getInvitationTokenFromCallbackPath,
getSafeCallbackPath,
} from "@/lib/auth-callback-url";
describe("auth callback URL helpers", () => {
describe("when appending OAuth state", () => {
it("should add a relative callback path as provider state", () => {
const authUrl = "https://accounts.example.com/oauth?client_id=client";
const callbackPath = "/invitation/accept?invitation_token=test-token";
const result = appendCallbackState(authUrl, callbackPath);
expect(new URL(result).searchParams.get("state")).toBe(callbackPath);
});
it("should not add state for the default callback path", () => {
const authUrl = "https://accounts.example.com/oauth?client_id=client";
const result = appendCallbackState(authUrl, "/");
expect(new URL(result).searchParams.has("state")).toBe(false);
});
});
describe("when reading callback paths", () => {
it("should return relative callback paths", () => {
const params = new URLSearchParams({
state: "/invitation/accept?invitation_token=test-token",
});
const result = getSafeCallbackPath(params);
expect(result).toBe("/invitation/accept?invitation_token=test-token");
});
it("should reject external callback URLs", () => {
const params = new URLSearchParams({
state: "https://attacker.example/phishing",
});
const result = getSafeCallbackPath(params);
expect(result).toBe("/");
});
});
describe("when reading invitation tokens", () => {
it("should return invitation tokens from safe callback paths", () => {
const callbackPath = "/invitation/accept?invitation_token=test-token";
const result = getInvitationTokenFromCallbackPath(callbackPath);
expect(result).toBe("test-token");
});
});
});
+50
View File
@@ -0,0 +1,50 @@
const DEFAULT_CALLBACK_PATH = "/";
const INVITATION_TOKEN_PARAM = "invitation_token";
type CallbackSearchParams = {
get(name: string): string | null;
};
export const getSafeCallbackPathFromValue = (
value: string | null | undefined,
) => {
if (!value || !value.startsWith("/") || value.startsWith("//")) {
return DEFAULT_CALLBACK_PATH;
}
return value;
};
export const getSafeCallbackPath = (
searchParams: CallbackSearchParams,
key = "state",
) => getSafeCallbackPathFromValue(searchParams.get(key));
export const appendCallbackState = (authUrl: string, callbackPath: string) => {
const safeCallbackPath = getSafeCallbackPathFromValue(callbackPath);
if (safeCallbackPath === DEFAULT_CALLBACK_PATH) {
return authUrl;
}
try {
const url = new URL(authUrl);
url.searchParams.set("state", safeCallbackPath);
return url.toString();
} catch (_error) {
return authUrl;
}
};
export const getInvitationTokenFromCallbackPath = (callbackPath: string) => {
const safeCallbackPath = getSafeCallbackPathFromValue(callbackPath);
if (safeCallbackPath === DEFAULT_CALLBACK_PATH) {
return null;
}
try {
const url = new URL(safeCallbackPath, "http://localhost");
return url.searchParams.get(INVITATION_TOKEN_PARAM);
} catch (_error) {
return null;
}
};