"""Database engine, session factory, and startup migration helpers.""" from __future__ import annotations import asyncio from pathlib import Path from typing import TYPE_CHECKING from alembic.config import Config from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine from sqlmodel import SQLModel from sqlmodel.ext.asyncio.session import AsyncSession from app import models as _models from app.core.config import settings from app.core.logging import get_logger if TYPE_CHECKING: from collections.abc import AsyncGenerator # Import model modules so SQLModel metadata is fully registered at startup. _MODEL_REGISTRY = _models def _normalize_database_url(database_url: str) -> str: if "://" not in database_url: return database_url scheme, rest = database_url.split("://", 1) if scheme in ("postgresql", "postgres"): return f"postgresql+psycopg://{rest}" return database_url async_engine: AsyncEngine = create_async_engine( _normalize_database_url(settings.database_url), pool_pre_ping=True, ) async_session_maker = async_sessionmaker( async_engine, class_=AsyncSession, expire_on_commit=False, ) logger = get_logger(__name__) def _alembic_config() -> Config: alembic_ini = Path(__file__).resolve().parents[2] / "alembic.ini" alembic_cfg = Config(str(alembic_ini)) alembic_cfg.attributes["configure_logger"] = False return alembic_cfg def run_migrations() -> None: """Apply Alembic migrations to the latest revision.""" from alembic import command logger.info("Running database migrations.") command.upgrade(_alembic_config(), "head") logger.info("Database migrations complete.") _warn_on_schema_drift() def _warn_on_schema_drift() -> None: """Log an error-level warning if the live schema is missing model columns. This catches the case where Alembic's version table reports 'head' but columns were never actually applied (e.g. a migration was inserted into the chain after the DB had already advanced past it). The service continues to start — the warning is intentionally loud so it surfaces in logs immediately. """ from sqlalchemy import create_engine from sqlalchemy import inspect as sa_inspect sync_url = ( settings.database_url .replace("postgresql+asyncpg://", "postgresql+psycopg://") .replace("postgresql://", "postgresql+psycopg://") .replace("postgres://", "postgresql+psycopg://") ) # _normalize_database_url already adds +psycopg; strip any double-prefix. if "postgresql+psycopg+psycopg" in sync_url: sync_url = sync_url.replace("postgresql+psycopg+psycopg", "postgresql+psycopg") try: engine = create_engine(sync_url, pool_pre_ping=True) inspector = sa_inspect(engine) except Exception as exc: logger.error("schema_drift_check_failed", error=str(exc)) engine.dispose() if "engine" in dir() else None # type: ignore[name-defined] return missing: list[str] = [] try: for table_name, table in SQLModel.metadata.tables.items(): if not inspector.has_table(table_name): missing.append(f"TABLE {table_name}") continue db_cols = {col["name"] for col in inspector.get_columns(table_name)} for col in table.columns: if col.name not in db_cols: missing.append(f"COLUMN {table_name}.{col.name}") except Exception as exc: logger.error("schema_drift_check_failed", error=str(exc)) finally: engine.dispose() if missing: logger.error( "schema_drift_detected", missing=missing, hint="DB schema does not match models. Run scripts/check_schema.py for details.", ) async def init_db() -> None: """Initialize database schema, running migrations when configured.""" if settings.db_auto_migrate: versions_dir = Path(__file__).resolve().parents[2] / "migrations" / "versions" if any(versions_dir.glob("*.py")): logger.info("Running migrations on startup") await asyncio.to_thread(run_migrations) return logger.warning("No migration revisions found; falling back to create_all") async with async_engine.connect() as conn, conn.begin(): await conn.run_sync(SQLModel.metadata.create_all) async def get_session() -> AsyncGenerator[AsyncSession, None]: """Yield a request-scoped async DB session with safe rollback on errors.""" async with async_session_maker() as session: try: yield session finally: in_txn = False try: in_txn = bool(session.in_transaction()) except SQLAlchemyError: logger.exception("Failed to inspect session transaction state.") if in_txn: try: await session.rollback() except SQLAlchemyError: logger.exception("Failed to rollback session after request error.")