# Source code for regstack.routers.oauth

"""OAuth router — sign-in / sign-up / link / unlink against an external provider.

Mounted by :func:`regstack.routers.build_router` when
``config.enable_oauth`` is on AND at least one provider is registered
on ``rs.oauth``.

Five endpoints (per provider; v1 ships with ``"google"``):

- ``GET  /oauth/{provider}/start`` — public; redirects to the provider.
- ``GET  /oauth/{provider}/callback`` — public; handles the redirect
  back, completes the flow, redirects to ``/account/oauth-complete``.
- ``POST /oauth/exchange`` — single-use; the SPA trades the state-id
  for the freshly-minted session JWT.
- ``POST /oauth/{provider}/link/start`` — authenticated; returns the
  authorization URL the SPA should navigate to.
- ``DELETE /oauth/{provider}/link`` — authenticated; unlinks one
  identity, refusing if it's the only auth method on the account.

The router enforces:

- Same-origin ``redirect_to`` (no open redirect).
- Server-side PKCE (``code_verifier`` never leaves the server).
- Anti-enumeration on linking conflicts (a clean 409 with no leaks).
- ``users.hashed_password is None`` paths for OAuth-only users.
"""

from __future__ import annotations

import base64
import hashlib
import logging
import secrets as _secrets
from datetime import timedelta
from typing import TYPE_CHECKING, Any
from urllib.parse import urlsplit

from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from fastapi.responses import RedirectResponse
from pydantic import BaseModel, ConfigDict, Field

from regstack.backends.protocols import OAuthIdentityAlreadyLinkedError
from regstack.models.oauth_identity import OAuthIdentity
from regstack.models.oauth_state import OAuthState
from regstack.models.user import BaseUser
from regstack.oauth.errors import OAuthError, OAuthIdTokenError, OAuthTokenExchangeError
from regstack.routers._schemas import MessageResponse

if TYPE_CHECKING:
    from regstack.app import RegStack
    from regstack.oauth.base import OAuthProvider, OAuthUserInfo

log = logging.getLogger("regstack.oauth")


class ExchangeRequest(BaseModel):
    """Request body for ``POST /oauth/exchange``."""

    # Unknown keys in the body are rejected instead of being ignored.
    model_config = ConfigDict(extra="forbid")
    # State-id the exchange endpoint passes to ``rs.oauth_states.consume``;
    # accepted length is 8 to 128 characters.
    id: str = Field(min_length=8, max_length=128)


class ExchangeResponse(BaseModel):
    """SPA payload after a successful OAuth callback.

    Two shapes, distinguished by ``mfa_required``:

    - Normal: ``mfa_required=False``, ``access_token`` is the session
      JWT, ``mfa_pending_token`` is ``None``.
    - MFA required (only when ``config.oauth.enforce_mfa_on_oauth_signin``
      is on and the user has SMS MFA set up): ``mfa_required=True``,
      ``access_token`` is empty, ``mfa_pending_token`` is what the SPA
      forwards to ``POST /login/mfa-confirm`` along with the SMS code.
    """

    # Session JWT; left at the empty-string default in the MFA shape.
    access_token: str = ""
    # Post-login destination, copied from the consumed state row.
    redirect_to: str
    # NOTE(review): the exchange endpoint in this module always passes
    # ``False`` here — confirm whether the real value should be carried
    # over from the callback.
    was_new_account: bool
    token_type: str = "bearer"
    # Lifetime in seconds: the exchange endpoint fills this from
    # ``jwt_ttl_seconds`` (normal shape) or
    # ``mfa_pending_token_ttl_seconds`` (MFA shape).
    expires_in: int
    mfa_required: bool = False
    # Pending-MFA JWT; only populated when ``mfa_required`` is true.
    mfa_pending_token: str | None = None


class LinkStartResponse(BaseModel):
    """Response body for ``POST /oauth/{provider}/link/start``."""

    # Provider authorization URL the SPA should navigate the browser to.
    authorization_url: str


class LinkedIdentitySummary(BaseModel):
    """One identity in the ``/oauth/providers`` response."""

    # Name of the provider the identity belongs to.
    provider: str
    # Email stored on the identity; may be absent.
    email: str | None
    # Timestamps are strings: the providers endpoint fills them with
    # ``datetime.isoformat()`` output.
    linked_at: str
    # ``None`` when the identity row has no ``last_used_at`` value.
    last_used_at: str | None


class ProvidersResponse(BaseModel):
    """Available providers + which ones the current user has linked.

    Drives the SSR ``/account/me`` "Connected accounts" panel and the
    ``/account/login`` "Sign in with X" buttons.
    """

    # Provider names as returned by ``rs.oauth.names()``.
    available: list[str]
    # One summary per identity stored for the authenticated user.
    linked: list[LinkedIdentitySummary]


def build_oauth_router(rs: RegStack) -> APIRouter:
    """Build the OAuth router.

    Captures ``rs`` in closures so two
    :class:`~regstack.app.RegStack` instances in one process don't
    share state.

    Args:
        rs: The RegStack instance whose stores, config, hooks and
            provider registry the endpoints operate on.

    Returns:
        An :class:`~fastapi.APIRouter` mounted under ``/oauth`` with the
        start, callback, exchange, providers, link-start and unlink
        endpoints registered.
    """
    router = APIRouter(prefix="/oauth", tags=["regstack-oauth"])

    @router.get(
        "/exchange",
        include_in_schema=False,
    )
    async def _no_get_exchange() -> None:
        # The SPA exchange is POST-only. Reject GET loudly so a
        # misconfigured client can't hit it accidentally.
        raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED)

    @router.get(
        "/providers",
        response_model=ProvidersResponse,
        summary="List configured providers and which the current user has linked",
    )
    async def providers_list(
        user: BaseUser = Depends(rs.deps.current_user()),
    ) -> ProvidersResponse:
        assert user.id is not None
        identities = await rs.oauth_identities.list_for_user(user.id)
        return ProvidersResponse(
            available=rs.oauth.names(),
            linked=[
                LinkedIdentitySummary(
                    provider=i.provider,
                    email=i.email,
                    linked_at=i.linked_at.isoformat(),
                    last_used_at=i.last_used_at.isoformat() if i.last_used_at else None,
                )
                for i in identities
            ],
        )

    @router.post(
        "/exchange",
        response_model=ExchangeResponse,
        summary="Trade an OAuth state-id for a session JWT",
    )
    async def exchange(payload: ExchangeRequest) -> ExchangeResponse:
        # `consume` is what makes the exchange single-use: a second call
        # with the same id finds nothing and gets the 404 below.
        state = await rs.oauth_states.consume(payload.id)
        if state is None or state.result_token is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="OAuth state not found or already consumed.",
            )
        if state.expires_at < rs.clock.now():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="OAuth state has expired.",
            )
        # When the callback hits the MFA branch (controlled by
        # `config.oauth.enforce_mfa_on_oauth_signin`), the result_token
        # we stored is a short-lived MFA pending JWT, not a session
        # token. Peek at the purpose claim — we minted this token
        # ourselves, so trusting the unverified payload is fine; the
        # state-row read is what authenticates it.
        if _is_mfa_pending_token(state.result_token):
            return ExchangeResponse(
                redirect_to=state.redirect_to,
                was_new_account=False,
                expires_in=rs.config.mfa_pending_token_ttl_seconds,
                mfa_required=True,
                mfa_pending_token=state.result_token,
            )
        return ExchangeResponse(
            access_token=state.result_token,
            redirect_to=state.redirect_to,
            was_new_account=False,
            expires_in=rs.config.jwt_ttl_seconds,
        )

    @router.get(
        "/{provider_name}/start",
        summary="Start an OAuth sign-in flow",
        response_class=RedirectResponse,
        status_code=status.HTTP_302_FOUND,
    )
    async def oauth_start(
        provider_name: str,
        redirect_to: str = Query(default="/account/me"),
    ) -> RedirectResponse:
        provider = _resolve_provider(rs, provider_name)
        validated_redirect = _validate_redirect(rs, redirect_to)
        url = await _begin_flow(
            rs,
            provider,
            mode="signin",
            redirect_to=validated_redirect,
            linking_user_id=None,
        )
        await rs.hooks.fire("oauth_signin_started", provider=provider_name, mode="signin")
        return RedirectResponse(url, status_code=status.HTTP_302_FOUND)

    @router.post(
        "/{provider_name}/link/start",
        response_model=LinkStartResponse,
        summary="Start an OAuth flow that links the provider to the current user",
    )
    async def oauth_link_start(
        provider_name: str,
        user: BaseUser = Depends(rs.deps.current_user()),
        redirect_to: str = Query(default="/account/me"),
    ) -> LinkStartResponse:
        assert user.id is not None
        provider = _resolve_provider(rs, provider_name)
        validated_redirect = _validate_redirect(rs, redirect_to)
        url = await _begin_flow(
            rs,
            provider,
            mode="link",
            redirect_to=validated_redirect,
            linking_user_id=user.id,
        )
        await rs.hooks.fire(
            "oauth_signin_started",
            provider=provider_name,
            mode="link",
            user=user,
        )
        return LinkStartResponse(authorization_url=url)

    @router.delete(
        "/{provider_name}/link",
        response_model=MessageResponse,
        summary="Unlink an OAuth provider from the current account",
    )
    async def oauth_unlink(
        provider_name: str,
        user: BaseUser = Depends(rs.deps.current_user()),
    ) -> MessageResponse:
        assert user.id is not None
        identities = await rs.oauth_identities.list_for_user(user.id)
        match = next((i for i in identities if i.provider == provider_name), None)
        if match is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="That provider is not linked to your account.",
            )
        # Refuse to remove the last sign-in method.
        other_identities = len(identities) - 1
        has_password = user.hashed_password is not None
        if not has_password and other_identities == 0:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=(
                    "Cannot unlink your only sign-in method. "
                    "Set a password (use forgot-password) or link another "
                    "provider first."
                ),
            )
        await rs.oauth_identities.delete(user_id=user.id, provider=provider_name)
        await rs.hooks.fire("oauth_account_unlinked", user=user, provider=provider_name)
        return MessageResponse(message=f"Unlinked {provider_name}.")

    @router.get(
        "/{provider_name}/callback",
        summary="Provider redirects the browser here after authorization",
        response_class=RedirectResponse,
        status_code=status.HTTP_302_FOUND,
    )
    async def oauth_callback(
        provider_name: str,
        request: Request,
        code: str | None = Query(default=None),
        state: str | None = Query(default=None),
        error: str | None = Query(default=None),
    ) -> RedirectResponse:
        # Every failure below redirects to the UI login page with a short
        # `error=` code instead of raising, because this endpoint is hit
        # by a browser navigation, not by the SPA's fetch.
        ui_login = _ui_login_url(rs)
        if error:
            # The `error` query string is attacker-controlled (set by the
            # OAuth provider after the redirect). Strip CR/LF/CTRL chars
            # and cap length before logging so a compromised or malicious
            # provider can't inject log lines or ANSI escapes into the
            # server's log stream. Flagged as I-3 in the 2026-05-15 /
            # 2026-05-16 security reviews.
            safe_error = _sanitize_for_log(error)
            # `provider_name` is a raw path parameter at this point — it
            # has not been checked against the provider registry yet — so
            # it gets the same treatment as `error`.
            safe_provider = _sanitize_for_log(provider_name)
            log.info("oauth callback error from provider %s: %s", safe_provider, safe_error)
            return _redirect_with_error(ui_login, "oauth_failed")
        if not code or not state:
            return _redirect_with_error(ui_login, "missing_code_or_state")

        try:
            provider = _resolve_provider(rs, provider_name)
        except HTTPException:
            return _redirect_with_error(ui_login, "unknown_provider")

        state_row = await rs.oauth_states.find(state)
        if state_row is None or state_row.provider != provider_name:
            return _redirect_with_error(ui_login, "bad_state")
        if state_row.expires_at < rs.clock.now():
            return _redirect_with_error(ui_login, "state_expired")

        try:
            tokens = await provider.exchange_code(
                code=code,
                redirect_uri=_callback_url(rs, provider_name),
                code_verifier=state_row.code_verifier,
            )
            user_info = await provider.verify_id_token(
                tokens.id_token, expected_nonce=state_row.nonce
            )
        except OAuthTokenExchangeError as exc:
            log.warning("oauth token exchange failed: %s", exc)
            return _redirect_with_error(ui_login, "token_exchange_failed")
        except OAuthIdTokenError as exc:
            log.warning("oauth id token verification failed: %s", exc)
            return _redirect_with_error(ui_login, "id_token_failed")
        except OAuthError as exc:
            log.warning("oauth error: %s", exc)
            return _redirect_with_error(ui_login, "oauth_failed")

        try:
            user, was_new = await _resolve_user(
                rs, provider_name=provider_name, info=user_info, state_row=state_row
            )
        except _LinkConflictError as exc:
            return _redirect_with_error(ui_login, exc.code)

        # Touch last_used_at on the identity (best-effort).
        try:
            await rs.oauth_identities.touch_last_used(
                provider=provider_name,
                subject_id=user_info.subject_id,
                when=rs.clock.now(),
            )
        except Exception:  # pragma: no cover — best-effort
            log.exception("touch_last_used failed for %s/%s", provider_name, user_info.subject_id)

        assert user.id is not None

        # MFA gate: if the operator enabled `enforce_mfa_on_oauth_signin`
        # and this user has SMS MFA set up, stash a short-lived MFA
        # pending token in the state row instead of the session JWT.
        # The SPA's /oauth/exchange sees `mfa_required=True` and
        # forwards the pending token to POST /login/mfa-confirm with
        # the SMS code. Link flows are exempt — the user was already
        # authenticated when they started the link, and re-MFAing for
        # an attachment step is pointless friction.
        mfa_required = (
            rs.config.oauth.enforce_mfa_on_oauth_signin
            and state_row.mode == "signin"
            and user.is_mfa_enabled
            and user.phone_number is not None
        )

        if mfa_required:
            pending_token = await _start_oauth_mfa_step(rs, user)
            await rs.oauth_states.set_result_token(
                state_row.id,
                pending_token,
                new_expires_at=rs.clock.now()
                + timedelta(seconds=rs.config.mfa_pending_token_ttl_seconds),
            )
        else:
            # Mint the session JWT, stash it on the state row for the
            # SPA's exchange call, and redirect. Shorten the row's expiry
            # to `oauth.completion_ttl_seconds` (default 30s) so the
            # window during which a stolen state_id could yield a session
            # token is bounded — the full `state_ttl_seconds` only had to
            # cover the user's round-trip with the provider.
            token, _payload = rs.jwt.encode(user.id)
            await rs.users.set_last_login(user.id, _payload.iat)
            await rs.oauth_states.set_result_token(
                state_row.id,
                token,
                new_expires_at=rs.clock.now()
                + timedelta(seconds=rs.config.oauth.completion_ttl_seconds),
            )

        await rs.hooks.fire(
            "oauth_signin_completed",
            user=user,
            provider=provider_name,
            mode=state_row.mode,
            was_new=was_new,
            mfa_required=mfa_required,
        )
        if state_row.mode == "link" and not was_new:
            await rs.hooks.fire("oauth_account_linked", user=user, provider=provider_name)

        complete_url = f"{rs.config.ui_prefix.rstrip('/')}/oauth-complete?id={state_row.id}"
        return RedirectResponse(complete_url, status_code=status.HTTP_302_FOUND)

    return router
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

_LOGIN_MFA_PURPOSE = "login_mfa"


def _is_mfa_pending_token(token: str) -> bool:
    """Peek at a JWT's ``purpose`` claim without verifying the signature.

    We minted this token ourselves and just read it out of our own
    state row — the trust anchor is the state-row read, not the JWT
    signature. `/login/mfa-confirm` re-verifies the signature properly
    before acting on it.

    Returns ``False`` on any decode failure so a malformed value falls
    through to the regular session-token path.
    """
    import base64
    import json

    try:
        parts = token.split(".")
        if len(parts) != 3:
            return False
        body = parts[1]
        # JWT segments are base64url without padding; restore it.
        body += "=" * (-len(body) % 4)
        payload = json.loads(base64.urlsafe_b64decode(body))
        return bool(payload.get("purpose") == _LOGIN_MFA_PURPOSE)
    except Exception:
        return False


async def _start_oauth_mfa_step(rs: RegStack, user: BaseUser) -> str:
    """Send the SMS second factor and return a short-lived pending JWT.

    Mirrors :func:`regstack.routers.login._start_mfa_step` but returns
    just the pending token (no FastAPI response). The pending JWT has
    purpose ``login_mfa`` so the SPA can hand it straight to
    ``POST /login/mfa-confirm`` — same downstream entry point as the
    password-login MFA flow.
    """
    import secrets as _secrets

    import jwt as pyjwt

    from regstack.auth.mfa import generate_mfa_code
    from regstack.config.secrets import derive_secret
    from regstack.models.mfa_code import MfaCode
    from regstack.sms.base import SmsMessage

    assert user.id is not None
    assert user.phone_number is not None

    raw_code, code_hash = generate_mfa_code(rs.config)
    ttl = rs.config.sms_code_ttl_seconds
    # Only the hash is persisted; the raw code goes out by SMS.
    await rs.mfa_codes.put(
        MfaCode(
            user_id=user.id,
            kind="login_mfa",
            code_hash=code_hash,
            expires_at=rs.clock.now() + timedelta(seconds=ttl),
            max_attempts=rs.config.sms_code_max_attempts,
        )
    )
    body = rs.mail.sms_body(
        kind="login_mfa",
        code=raw_code,
        ttl_minutes=max(ttl // 60, 1),
    )
    await rs.sms.send(
        SmsMessage(
            to=user.phone_number,
            body=body,
            from_number=rs.config.sms.from_number,
        )
    )
    # No `code=` in the hook payload — same rationale as login.py.
    await rs.hooks.fire("mfa_login_started", user=user)

    pending_ttl = rs.config.mfa_pending_token_ttl_seconds
    now = rs.clock.now()
    claims: dict[str, Any] = {
        "sub": user.id,
        "jti": _secrets.token_urlsafe(16),
        "iat": now.timestamp(),
        "exp": int((now + timedelta(seconds=pending_ttl)).timestamp()),
        "purpose": _LOGIN_MFA_PURPOSE,
    }
    if rs.config.jwt_audience is not None:
        claims["aud"] = rs.config.jwt_audience
    # The signing key is derived from the JWT secret and the purpose
    # string rather than using the JWT secret directly.
    key = derive_secret(rs.config.jwt_secret.get_secret_value(), _LOGIN_MFA_PURPOSE)
    return pyjwt.encode(claims, key, algorithm=rs.config.jwt_algorithm)


class _LinkConflictError(Exception):
    """Internal: a callback couldn't reconcile the identity to a user.

    Carries a short error code that the redirect surfaces so the SPA
    can show a tailored message without us echoing internals.
    """

    def __init__(self, code: str) -> None:
        self.code = code
        super().__init__(code)


def _resolve_provider(rs: RegStack, name: str) -> OAuthProvider:
    """Return the registered provider called ``name``.

    Raises:
        HTTPException: 404 when no provider with that name is registered
            on ``rs.oauth``.
    """
    if name not in rs.oauth:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"OAuth provider {name!r} is not configured.",
        )
    return rs.oauth.get(name)


def _validate_redirect(rs: RegStack, redirect_to: str) -> str:
    """Reject anything that isn't a same-origin path or full URL.

    `urlsplit` is too forgiving for our purposes: a value like
    ``/\\evil.com`` parses with an empty netloc but browsers normalize
    the backslash to a slash, producing the protocol-relative
    ``//evil.com``. Similarly, ``////evil.com`` collapses in the
    browser. Both are open-redirect vectors. We pre-screen the raw
    string for those shapes before trusting `urlsplit`.

    ASCII control characters are rejected for the same reason: browsers
    drop tab and newline characters while parsing a URL, so
    ``/<TAB>/evil.com`` is navigated to as ``//evil.com``.

    Args:
        rs: Supplies ``config.base_url``, the origin absolute URLs must
            match.
        redirect_to: Caller-supplied post-login destination.

    Returns:
        The whitespace-stripped ``redirect_to``, or ``"/account/me"``
        when it is empty.

    Raises:
        HTTPException: 400 when the value is not same-origin.
    """
    if not redirect_to:
        return "/account/me"
    redirect_to = redirect_to.strip()

    # Browsers normalize backslashes to forward slashes; treat any
    # backslash anywhere as hostile.
    if "\\" in redirect_to:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="redirect_to must be same-origin.",
        )

    # Browsers strip tab/CR/LF from URLs before resolving them. Whether
    # `urlsplit` does the same depends on the Python version, so reject
    # control characters up front instead of relying on it.
    if any(ord(ch) < 0x20 or ord(ch) == 0x7F for ch in redirect_to):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="redirect_to must be same-origin.",
        )

    parts = urlsplit(redirect_to)
    if not parts.scheme and not parts.netloc:
        # Plain path like "/account/me" — fine, but reject anything
        # that looks like a protocol-relative URL (``//host``) or
        # doesn't start with a single leading slash.
        if not redirect_to.startswith("/") or redirect_to.startswith("//"):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="redirect_to must be same-origin.",
            )
        return redirect_to

    base_parts = urlsplit(str(rs.config.base_url))
    if (parts.scheme, parts.netloc) != (base_parts.scheme, base_parts.netloc):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="redirect_to must be same-origin.",
        )
    return redirect_to


def _ui_login_url(rs: RegStack) -> str:
    """Return the UI login page path, ``<ui_prefix>/login``."""
    return f"{rs.config.ui_prefix.rstrip('/')}/login"


def _callback_url(rs: RegStack, provider_name: str) -> str:
    """Build the absolute callback URL the provider redirects back to."""
    cfg = rs.config
    # An explicit Google redirect URI in config wins over the derived one.
    if cfg.oauth.google_redirect_uri is not None and provider_name == "google":
        return str(cfg.oauth.google_redirect_uri)
    base = str(cfg.base_url).rstrip("/")
    return f"{base}{cfg.api_prefix.rstrip('/')}/oauth/{provider_name}/callback"


def _redirect_with_error(ui_login_url: str, code: str) -> RedirectResponse:
    """Return a 302 to ``ui_login_url`` with ``error=<code>`` appended."""
    sep = "&" if "?" in ui_login_url else "?"
    return RedirectResponse(
        f"{ui_login_url}{sep}error={code}",
        status_code=status.HTTP_302_FOUND,
    )


def _sanitize_for_log(value: str) -> str:
    """Strip CR/LF and other control characters and cap length for safe
    logging of provider-controlled query strings.

    Defends against log injection where a malicious or compromised
    provider crafts an ``error=...`` value containing newlines or ANSI
    escape sequences intended to forge or obscure log entries.
    """
    cleaned = "".join(ch for ch in value if ch.isprintable() and ch not in "\r\n")
    return cleaned[:200]


async def _begin_flow(
    rs: RegStack,
    provider: OAuthProvider,
    *,
    mode: str,
    redirect_to: str,
    linking_user_id: str | None,
) -> str:
    """Generate PKCE / nonce / state, persist a state row, return the
    authorization URL the browser should be sent to.
    """
    code_verifier = _secrets.token_urlsafe(64)
    # S256 challenge: unpadded base64url of the verifier's SHA-256 digest.
    code_challenge = (
        base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest())
        .rstrip(b"=")
        .decode()
    )
    nonce = _secrets.token_urlsafe(16)
    state_id = _secrets.token_urlsafe(32)

    state = OAuthState(
        id=state_id,
        provider=provider.name,
        code_verifier=code_verifier,
        nonce=nonce,
        redirect_to=redirect_to,
        mode=mode,
        linking_user_id=linking_user_id,
        created_at=rs.clock.now(),
        expires_at=rs.clock.now() + timedelta(seconds=rs.config.oauth.state_ttl_seconds),
    )
    await rs.oauth_states.create(state)

    # Only the challenge is sent to the provider; the verifier stays in
    # the state row until the callback exchanges the code.
    return provider.authorization_url(
        redirect_uri=_callback_url(rs, provider.name),
        state=state_id,
        code_challenge=code_challenge,
        nonce=nonce,
    )


async def _resolve_user(
    rs: RegStack,
    *,
    provider_name: str,
    info: OAuthUserInfo,
    state_row: OAuthState,
) -> tuple[BaseUser, bool]:
    """Find or create the user this OAuth login should resolve to.

    Returns ``(user, was_new_account)``. Raises ``_LinkConflictError``
    with a short error code on ambiguous / refused linking.
    """
    # 1. Already-linked identity? Sign that user in.
    identity = await rs.oauth_identities.find_by_subject(
        provider=provider_name, subject_id=info.subject_id
    )
    if identity is not None:
        if state_row.mode == "link":
            # Linking an identity that's already on a different (or even
            # the same) account is a 409 — surface it; don't silently
            # take over.
            if (
                state_row.linking_user_id is not None
                and identity.user_id != state_row.linking_user_id
            ):
                raise _LinkConflictError("identity_in_use")
            raise _LinkConflictError("already_linked")
        user = await rs.users.get_by_id(identity.user_id)
        if user is None or not user.is_active:
            raise _LinkConflictError("user_inactive")
        return user, False

    # 2. Authenticated link flow — attach the identity to the linking user.
    if state_row.mode == "link":
        assert state_row.linking_user_id is not None
        target = await rs.users.get_by_id(state_row.linking_user_id)
        if target is None or not target.is_active or target.id is None:
            raise _LinkConflictError("user_inactive")
        try:
            await rs.oauth_identities.create(
                OAuthIdentity(
                    user_id=target.id,
                    provider=provider_name,
                    subject_id=info.subject_id,
                    email=info.email,
                    linked_at=rs.clock.now(),
                )
            )
        except OAuthIdentityAlreadyLinkedError as exc:
            raise _LinkConflictError("identity_in_use") from exc
        return target, False

    # 3. Sign-in flow with no existing identity.
    if info.email:
        existing = await rs.users.get_by_email(info.email)
    else:
        existing = None

    if existing is not None:
        if rs.config.oauth.auto_link_verified_emails and info.email_verified:
            # Same gate as steps 1 and 2: a deactivated account must not
            # be signed in just because the provider vouches for its
            # email address.
            if not existing.is_active:
                raise _LinkConflictError("user_inactive")
            assert existing.id is not None
            try:
                await rs.oauth_identities.create(
                    OAuthIdentity(
                        user_id=existing.id,
                        provider=provider_name,
                        subject_id=info.subject_id,
                        email=info.email,
                        linked_at=rs.clock.now(),
                    )
                )
            except OAuthIdentityAlreadyLinkedError as exc:
                raise _LinkConflictError("identity_in_use") from exc
            return existing, False
        raise _LinkConflictError("email_in_use")

    # 4. Brand-new account.
    # `/register` honours `allow_registration=False`; OAuth must too or
    # operators who disabled self-service signup still get accounts
    # created via "Sign in with <provider>".
    if not rs.config.allow_registration:
        raise _LinkConflictError("registration_disabled")
    new_user = BaseUser(
        email=info.email or _placeholder_email(provider_name, info.subject_id),
        hashed_password=None,
        full_name=info.full_name,
        is_active=True,
        is_verified=bool(info.email_verified),
    )
    new_user = await rs.users.create(new_user)
    assert new_user.id is not None
    await rs.oauth_identities.create(
        OAuthIdentity(
            user_id=new_user.id,
            provider=provider_name,
            subject_id=info.subject_id,
            email=info.email,
            linked_at=rs.clock.now(),
        )
    )
    await rs.hooks.fire("user_registered", user=new_user)
    return new_user, True


def _placeholder_email(provider_name: str, subject_id: str) -> str:
    """Last-resort email for providers that didn't return one.

    Hosts that hit this should turn on the provider scopes that
    actually return email; we don't want a regstack user without an
    email for password-reset purposes. The placeholder uses the
    `.invalid` TLD so it definitely won't deliver.
    """
    return f"{provider_name}-{subject_id}@oauth.invalid"


__all__ = ["ExchangeRequest", "ExchangeResponse", "LinkStartResponse", "build_oauth_router"]