Source code for regstack.backends.protocols

"""Storage-layer protocols.

Each backend (`mongo`, `sql`) provides one concrete implementation per
protocol. Routers and services depend only on these protocols, so a
backend swap is a wiring change, not a code change.

Mongo's existing semantics are the contract: anything new (SQLAlchemy,
in-memory, etc.) is judged by behavioural parity, not by whether the SQL
or Mongo idiom feels "right" to the implementer.
"""

from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import StrEnum
from typing import Protocol, runtime_checkable

from regstack.models.mfa_code import MfaCode, MfaKind
from regstack.models.oauth_identity import OAuthIdentity
from regstack.models.oauth_state import OAuthState
from regstack.models.pending_registration import PendingRegistration
from regstack.models.user import BaseUser


[docs] class MfaVerifyOutcome(StrEnum): """Result of submitting an SMS MFA code to the repo. The five possible values: - ``OK`` — the code matched and the row was consumed. - ``WRONG`` — the code didn't match. ``attempts_remaining`` on the paired :class:`MfaVerifyResult` says how many tries are left before the row is deleted (forcing a re-issue). - ``EXPIRED`` — a row exists but its TTL has passed. - ``LOCKED`` — too many wrong guesses; the row was deleted and the user must request a new code. - ``MISSING`` — no outstanding code for this user / kind. """ OK = "ok" WRONG = "wrong" EXPIRED = "expired" LOCKED = "locked" MISSING = "missing"
[docs] @dataclass(slots=True, frozen=True) class MfaVerifyResult: """Outcome of :meth:`MfaCodeRepoProtocol.verify`.""" outcome: MfaVerifyOutcome """Which terminal state the verify call landed in. See :class:`MfaVerifyOutcome`.""" attempts_remaining: int = 0 """For :attr:`MfaVerifyOutcome.WRONG`, how many more guesses the user has before the code is deleted and they must request a new one. ``0`` for any other outcome."""
class UserAlreadyExistsError(Exception):
    """An insert or email change would duplicate an existing user.

    Every backend raises this single type from its integrity-error path, so
    callers can branch on it without importing anything backend-specific.
    The registration and change-email routers turn it into HTTP 409.
    """
class PendingAlreadyExistsError(Exception):
    """A pending registration for this email is already on file.

    Rarely raised in practice, because the registration router upserts
    rather than inserting. The type exists so that any future caller has a
    backend-agnostic name for this failure instead of importing a backend
    module.
    """
@runtime_checkable
class UserRepoProtocol(Protocol):
    """Persistence operations on user accounts.

    Every method is a coroutine. Users are addressed by their string
    ``user_id`` or looked up by email. Email collisions on insert or on an
    email change surface as :class:`UserAlreadyExistsError`.
    """

    # -- creation and lookup ------------------------------------------------

    async def create(self, user: BaseUser) -> BaseUser:
        """Insert ``user`` and return the resulting :class:`BaseUser`.

        Raises :class:`UserAlreadyExistsError` when the email is already taken.
        """

    async def get_by_email(self, email: str) -> BaseUser | None:
        """Return the user registered under ``email``, or ``None``."""

    async def get_by_id(self, user_id: str) -> BaseUser | None:
        """Return the user whose id is ``user_id``, or ``None``."""

    # -- single-field updates -------------------------------------------------

    async def set_last_login(self, user_id: str, when: datetime) -> None:
        """Record ``when`` as the user's most recent login time."""

    async def set_tokens_invalidated_after(
        self, user_id: str, when: datetime
    ) -> None:
        """Set the user's ``tokens_invalidated_after`` cutoff to ``when``.

        Moving this cutoff forward is how sessions get revoked in bulk
        (see :meth:`update_email`).
        """

    async def update_password(self, user_id: str, hashed_password: str) -> None:
        """Replace the stored password with ``hashed_password`` (already hashed)."""

    async def set_active(self, user_id: str, *, is_active: bool) -> None:
        """Set the user's ``is_active`` flag."""

    async def set_superuser(self, user_id: str, *, is_superuser: bool) -> None:
        """Set the user's ``is_superuser`` flag."""

    async def set_full_name(self, user_id: str, full_name: str | None) -> None:
        """Set the user's full name; ``None`` is accepted (presumably clears it)."""

    async def set_phone(self, user_id: str, phone_number: str | None) -> None:
        """Set the user's phone number; ``None`` is accepted (presumably clears it)."""

    async def set_mfa_enabled(self, user_id: str, *, is_mfa_enabled: bool) -> None:
        """Set the user's ``is_mfa_enabled`` flag."""

    async def update_email(self, user_id: str, new_email: str) -> None:
        """Change the email and bump ``tokens_invalidated_after`` atomically.

        Implementations MUST raise :class:`UserAlreadyExistsError` when
        ``new_email`` already belongs to a different user. The caller-visible
        effect of the bump is a bulk revoke: sessions bound to the old email
        stop working immediately.
        """

    # -- removal and reporting ------------------------------------------------

    async def delete(self, user_id: str) -> bool:
        """Delete the user.

        NOTE(review): the ``bool`` presumably reports whether a row was
        removed, as :meth:`OAuthIdentityRepoProtocol.delete` documents for
        its own result — confirm against the backends.
        """

    async def count(
        self,
        *,
        is_active: bool | None = None,
        is_verified: bool | None = None,
        is_superuser: bool | None = None,
    ) -> int:
        """Count users matching every supplied filter; ``None`` disables a filter."""

    async def list_paged(
        self,
        *,
        skip: int = 0,
        limit: int = 50,
        sort_by_created_at_desc: bool = True,
    ) -> list[BaseUser]:
        """Return one page of users.

        ``skip`` and ``limit`` select the page; ``sort_by_created_at_desc``
        requests newest-first ordering by ``created_at``.
        """
@runtime_checkable
class PendingRepoProtocol(Protocol):
    """Registrations that are still waiting for their email to be verified."""

    async def upsert(self, pending: PendingRegistration) -> PendingRegistration:
        """Insert or replace the pending registration for this email.

        A resend overwrites the outstanding row, so only the newest token is
        valid and older verification links stop working.
        """

    async def find_by_token_hash(
        self, token_hash: str
    ) -> PendingRegistration | None:
        """Return the pending registration stored under ``token_hash``, or ``None``."""

    async def find_by_email(self, email: str) -> PendingRegistration | None:
        """Return the pending registration for ``email``, or ``None``."""

    async def delete_by_email(self, email: str) -> None:
        """Remove the pending registration for ``email``."""

    async def purge_expired(self, now: datetime | None = None) -> int:
        """Remove expired rows.

        MongoDB expires rows through a TTL index; SQL backends depend on this
        method being called periodically.
        """

    async def count_unexpired(self, now: datetime | None = None) -> int:
        """Count pending-registration rows whose ``expires_at`` is in the future.

        A raw row count is avoided because SQL backends keep dead rows around
        until :meth:`purge_expired` runs; such a count would lump a
        verification email left unanswered for a month together with one
        sent today.

        Args:
            now: Reference instant. Defaults to ``datetime.now(UTC)``.

        Returns:
            Number of pending rows with ``expires_at > now``.
        """
@runtime_checkable
class BlacklistRepoProtocol(Protocol):
    """Revoked tokens, identified by their ``jti``."""

    async def revoke(self, jti: str, exp: datetime) -> None:
        """Add ``jti`` to the blacklist.

        NOTE(review): ``exp`` is presumably the token's own expiry, after
        which :meth:`purge_expired` may drop the entry — confirm against the
        backends.
        """

    async def is_revoked(self, jti: str) -> bool:
        """Report whether ``jti`` has been revoked."""

    async def purge_expired(self, now: datetime | None = None) -> int:
        """Sweep expired blacklist entries."""
@runtime_checkable
class LoginAttemptRepoProtocol(Protocol):
    """Failed-login records, keyed by email."""

    async def record_failure(
        self,
        email: str,
        *,
        when: datetime | None = None,
        ip: str | None = None,
    ) -> None:
        """Record one failed login for ``email``.

        ``when`` and ``ip`` are optional. When ``when`` is omitted the
        implementation chooses the timestamp (presumably "now" — confirm).
        """

    async def count_recent(
        self,
        email: str,
        *,
        window: timedelta,
        now: datetime,
    ) -> int:
        """Count failures for ``email`` within the ``window`` ending at ``now``."""

    async def clear(self, email: str) -> None:
        """Clear the failures recorded for ``email``."""

    async def purge_expired(self, now: datetime, window: timedelta) -> int:
        """Sweep records older than ``window`` as measured from ``now``.

        Unlike the other repos' ``purge_expired``, both arguments are required.
        """
@runtime_checkable
class MfaCodeRepoProtocol(Protocol):
    """Outstanding one-time MFA codes, addressed by ``(user_id, kind)``."""

    async def put(self, code: MfaCode) -> None:
        """Store ``code``."""

    async def verify(
        self,
        *,
        user_id: str,
        kind: MfaKind,
        raw_code: str,
    ) -> MfaVerifyResult:
        """Check ``raw_code`` against the outstanding code for this user / kind.

        The returned result's :class:`MfaVerifyOutcome` says whether the code
        was accepted (and consumed), wrong, expired, locked out, or missing.
        """

    async def delete(self, *, user_id: str, kind: MfaKind | None = None) -> None:
        """Delete the user's outstanding code(s).

        NOTE(review): ``kind=None`` presumably means "every kind" — confirm
        against the backends.
        """

    async def find(self, *, user_id: str, kind: MfaKind) -> MfaCode | None:
        """Return the outstanding code for ``user_id`` / ``kind``, or ``None``."""

    async def purge_expired(self, now: datetime | None = None) -> int:
        """Sweep expired codes."""
class OAuthIdentityAlreadyLinkedError(Exception):
    """The external identity is already attached to a regstack user.

    :meth:`OAuthIdentityRepoProtocol.create` raises this when either
    uniqueness constraint trips: ``UNIQUE(provider, subject_id)`` or
    ``UNIQUE(user_id, provider)``. Routers map it to HTTP 409.
    """
@runtime_checkable
class OAuthIdentityRepoProtocol(Protocol):
    """Links between external OAuth identities and regstack users.

    Each row is one ``(provider, subject_id)`` pair. Two uniqueness
    constraints apply; :class:`~regstack.models.oauth_identity.OAuthIdentity`
    explains why.
    """

    async def create(self, identity: OAuthIdentity) -> OAuthIdentity:
        """Insert ``identity``.

        Raises :class:`OAuthIdentityAlreadyLinkedError` if either uniqueness
        constraint is violated.
        """

    async def find_by_subject(
        self, *, provider: str, subject_id: str
    ) -> OAuthIdentity | None:
        """Return the identity for ``(provider, subject_id)``, or ``None``."""

    async def list_for_user(self, user_id: str) -> list[OAuthIdentity]:
        """Return every identity linked to ``user_id``, oldest ``linked_at`` first."""

    async def delete(self, *, user_id: str, provider: str) -> bool:
        """Remove one identity; ``True`` means a row was actually deleted."""

    async def delete_by_user_id(self, user_id: str) -> int:
        """Remove every identity belonging to ``user_id``.

        The delete-account path calls this so that identities never outlive
        the user they belong to.
        """

    async def touch_last_used(
        self, *, provider: str, subject_id: str, when: datetime
    ) -> None:
        """Update ``last_used_at``.

        Invoked after each successful sign-in through this identity. It is
        best-effort: a failure is logged rather than raised.
        """
@runtime_checkable
class OAuthStateRepoProtocol(Protocol):
    """Server-side rows tracking OAuth flows that are still in progress.

    The ``state`` parameter round-tripped through the browser is nothing
    more than the row's ``id``. The PKCE ``code_verifier`` and the
    post-callback ``result_token`` are fields on the server-side row.
    """

    async def create(self, state: OAuthState) -> None:
        """Insert ``state``.

        The caller chooses the row id, usually via :func:`secrets.token_urlsafe`.
        """

    async def find(self, state_id: str) -> OAuthState | None:
        """Return the state row with id ``state_id``, or ``None``."""

    async def set_result_token(
        self,
        state_id: str,
        token: str,
        *,
        new_expires_at: datetime | None = None,
    ) -> None:
        """Store the session JWT once the callback has succeeded.

        The SPA later retrieves it through :meth:`consume`. Passing
        ``new_expires_at`` moves the row's ``expires_at`` to that instant at
        the same time. Callers use this to shrink the redemption window from
        ``oauth.state_ttl_seconds`` (long enough for the round-trip with the
        provider) to ``oauth.completion_ttl_seconds`` (just the SPA's exchange
        call after the callback lands).
        """

    async def consume(self, state_id: str) -> OAuthState | None:
        """Atomically read and delete the row.

        The exchange endpoint reads ``result_token`` from the returned row.
        Because the row is gone once this returns, the exchange is single-use.
        Returns ``None`` when the row is missing.
        """

    async def purge_expired(self, now: datetime | None = None) -> int:
        """Sweep expired rows.

        Mongo relies on a TTL index; SQL relies on this being called
        periodically.
        """