Source code for regstack.auth.jwt

from __future__ import annotations

import secrets
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Protocol

import jwt as pyjwt

from regstack.config.secrets import derive_secret

if TYPE_CHECKING:
    from regstack.auth.clock import Clock
    from regstack.config.schema import RegStackConfig

# Token purpose used by JwtCodec.encode / JwtCodec.decode when the caller
# does not pass an explicit ``purpose``.
_DEFAULT_PURPOSE = "session"


class TokenError(Exception):
    """Signal that a token could not be decoded or is no longer valid."""
[docs] @dataclass(slots=True) class TokenPayload: """Decoded JWT claims for a regstack-issued token. Returned by :meth:`JwtCodec.decode` and produced as a side-effect of :meth:`JwtCodec.encode`. Carries everything callers need to enforce revocation and audit trails without re-decoding the raw token. """ sub: str """The token subject — for session tokens, the user id.""" jti: str """Per-token unique id used by the blacklist repo for explicit revocation on logout.""" iat: datetime """When the token was issued. Tz-aware. Stored as a float (RFC 7519 NumericDate) so the bulk-revoke comparison ``iat <= cutoff`` is precise even when a login completes microseconds after a password change.""" exp: datetime """Expiry timestamp. Tz-aware.""" purpose: str """Token kind — ``session``, ``password_reset``, ``email_change``, ``phone_setup``, or ``login_mfa``. Each kind is signed with a different derived key."""
[docs] def to_claims(self, audience: str | None) -> dict[str, Any]: """Serialize as a JWT claims dict, including ``aud`` if set. Args: audience: Optional audience claim. When non-``None``, added as ``aud``; ``JwtCodec.decode`` then validates it. Returns: A dict suitable for ``pyjwt.encode``. """ # iat is emitted as a float so a within-the-same-second login after a # bulk-revoke cutoff isn't falsely rejected. RFC 7519 NumericDate # explicitly allows non-integer values. claims: dict[str, Any] = { "sub": self.sub, "jti": self.jti, "iat": self.iat.timestamp(), "exp": int(self.exp.timestamp()), "purpose": self.purpose, } if audience is not None: claims["aud"] = audience return claims
class RevocationChecker(Protocol):
    """Structural interface for anything that can say whether a ``jti`` is revoked.

    :class:`~regstack.backends.protocols.BlacklistRepoProtocol` fulfils this
    contract. Auth dependencies are written against this abstraction rather
    than a concrete repo so that tests can plug in a stub.
    """

    async def is_revoked(self, jti: str) -> bool:
        """Report whether the token identified by ``jti`` has been revoked."""
        ...
class JwtCodec:
    """Encode and decode regstack's signed JWTs.

    Each *purpose* (``session``, ``password_reset``, ``email_change``,
    ``phone_setup``, ``login_mfa``) signs with a separate key derived from
    ``config.jwt_secret`` via HMAC-SHA256. Compromise of one derived key does
    not compromise the others, and an attacker who captures a session token
    cannot replay it as a password-reset token.

    ``iat`` and ``exp`` are issued from the injected
    :class:`~regstack.auth.clock.Clock`, and expiry is validated against that
    same clock rather than the system wall clock — that's the seam
    ``FrozenClock``-driven tests rely on.
    """

    def __init__(self, config: RegStackConfig, clock: Clock) -> None:
        """Bind the codec to a config and a clock.

        Args:
            config: The active :class:`~regstack.config.schema.RegStackConfig`.
                ``config.jwt_secret`` must be non-empty.
            clock: The clock used for issuing ``iat``/``exp`` and for
                validating expiry on decode.

        Raises:
            ValueError: If ``config.jwt_secret`` is empty. Use
                ``regstack init`` to generate one, or set
                ``REGSTACK_JWT_SECRET``.
        """
        if not config.jwt_secret.get_secret_value():
            raise ValueError(
                "RegStackConfig.jwt_secret is empty. Run `regstack init` to generate one, "
                "or set REGSTACK_JWT_SECRET."
            )
        self._config = config
        self._clock = clock

    def _key(self, purpose: str) -> bytes:
        """Return the signing key derived from the master secret for ``purpose``."""
        return derive_secret(self._config.jwt_secret.get_secret_value(), purpose)

    def encode(
        self,
        subject: str,
        *,
        purpose: str = _DEFAULT_PURPOSE,
        ttl_seconds: int | None = None,
    ) -> tuple[str, TokenPayload]:
        """Sign and return a JWT plus its decoded payload.

        Args:
            subject: The ``sub`` claim — usually a user id.
            purpose: Logical token kind (``session`` by default). The derived
                signing key depends on this string, so a token minted with
                one purpose cannot be decoded with another.
            ttl_seconds: Override for the token lifetime. When ``None``,
                ``config.jwt_ttl_seconds`` is used.

        Returns:
            A ``(token, payload)`` tuple. ``token`` is the encoded string the
            caller hands to the client; ``payload`` is the in-memory
            :class:`TokenPayload` (useful when the caller also needs to
            record the ``jti`` for later revocation).
        """
        now = self._clock.now()
        # An explicit ``ttl_seconds=0`` is honoured; only ``None`` falls back
        # to the configured default.
        ttl = ttl_seconds if ttl_seconds is not None else self._config.jwt_ttl_seconds
        exp = now + timedelta(seconds=ttl)
        payload = TokenPayload(
            sub=subject,
            jti=secrets.token_urlsafe(16),
            iat=now,
            exp=exp,
            purpose=purpose,
        )
        token = pyjwt.encode(
            payload.to_claims(self._config.jwt_audience),
            self._key(purpose),
            algorithm=self._config.jwt_algorithm,
        )
        return token, payload

    def decode(self, token: str, *, purpose: str = _DEFAULT_PURPOSE) -> TokenPayload:
        """Verify a token's signature and decode its claims.

        Verification is strict: signature, required-claims set, and ``aud``
        (when configured) must all pass. ``exp`` is checked against the
        injected :class:`Clock`, **not** ``time.time()``, so frozen-clock
        tests stay deterministic. The ``purpose`` claim must match the
        expected purpose — a session token cannot satisfy
        ``decode(..., purpose="password_reset")``.

        Args:
            token: The encoded JWT string from the client.
            purpose: The expected token kind. Must match the value used at
                :meth:`encode` time.

        Returns:
            The decoded :class:`TokenPayload`.

        Raises:
            TokenError: On any failure — bad signature, expired token,
                purpose mismatch, missing required claim, or an ``iat``/``exp``
                claim that is not a usable timestamp.
        """
        try:
            # pyjwt's own exp/iat/nbf checks are switched off because they
            # read the system wall clock. Expiry is re-checked below against
            # the injected Clock so FrozenClock-driven tests (and any future
            # time-mocking host) get consistent behaviour.
            claims = pyjwt.decode(
                token,
                self._key(purpose),
                algorithms=[self._config.jwt_algorithm],
                audience=self._config.jwt_audience,
                options={
                    "require": ["sub", "exp", "iat", "jti", "purpose"],
                    "verify_exp": False,
                    "verify_iat": False,
                    "verify_nbf": False,
                },
            )
        except pyjwt.PyJWTError as exc:
            raise TokenError(str(exc)) from exc

        if claims.get("purpose") != purpose:
            raise TokenError("token purpose mismatch")

        now = self._clock.now()
        tz = now.tzinfo
        try:
            # With pyjwt's exp/iat verification disabled, nothing upstream
            # guarantees these claims are numeric or in range. Convert
            # defensively so a malformed claim surfaces as TokenError instead
            # of leaking TypeError/ValueError/OverflowError/OSError.
            iat = datetime.fromtimestamp(float(claims["iat"]), tz=tz)
            exp = datetime.fromtimestamp(int(claims["exp"]), tz=tz)
        except (TypeError, ValueError, OverflowError, OSError) as exc:
            raise TokenError("token has a malformed iat or exp claim") from exc

        # A token is treated as expired at the exact ``exp`` instant.
        if now >= exp:
            raise TokenError("Signature has expired")

        return TokenPayload(
            sub=str(claims["sub"]),
            jti=str(claims["jti"]),
            iat=iat,
            exp=exp,
            purpose=str(claims["purpose"]),
        )
def is_payload_bulk_revoked(payload: TokenPayload, cutoff: datetime | None) -> bool:
    """Tell whether a token falls under a bulk revocation.

    Bulk revocation invalidates every outstanding session at once — for
    example when a user changes their password or an admin disables them —
    without listing each ``jti``. The user document stores a
    ``tokens_invalidated_after`` timestamp, and any token whose ``iat`` is at
    or before that cutoff is rejected.

    The comparison is deliberately ``<=`` rather than ``<``: a token issued at
    exactly the cutoff instant counts as issued before the change, which is
    the safe side. Because ``iat`` carries sub-second precision (RFC 7519
    NumericDate), a login that completes microseconds after a password change
    has ``iat > cutoff`` and survives.

    Args:
        payload: The decoded :class:`TokenPayload`.
        cutoff: The user's ``tokens_invalidated_after`` value, or ``None`` if
            the user has never been bulk-revoked.

    Returns:
        ``True`` when the token must be rejected; ``False`` when it is still
        valid (subject to the per-token blacklist check).
    """
    # No cutoff recorded means nothing was ever bulk-revoked for this user.
    return cutoff is not None and payload.iat <= cutoff