147 lines
5.1 KiB
Python
147 lines
5.1 KiB
Python
"""Database engine, session factory, and startup migration helpers."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING
|
|
|
|
from alembic.config import Config
|
|
from sqlalchemy.exc import SQLAlchemyError
|
|
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
|
|
from sqlmodel import SQLModel
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
|
|
from app import models as _models
|
|
from app.core.config import settings
|
|
from app.core.logging import get_logger
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import AsyncGenerator
|
|
|
|
# Import model modules so SQLModel metadata is fully registered at startup.
|
|
_MODEL_REGISTRY = _models
|
|
|
|
|
|
def _normalize_database_url(database_url: str) -> str:
|
|
if "://" not in database_url:
|
|
return database_url
|
|
scheme, rest = database_url.split("://", 1)
|
|
if scheme in ("postgresql", "postgres"):
|
|
return f"postgresql+psycopg://{rest}"
|
|
return database_url
|
|
|
|
|
|
async_engine: AsyncEngine = create_async_engine(
|
|
_normalize_database_url(settings.database_url),
|
|
pool_pre_ping=True,
|
|
)
|
|
async_session_maker = async_sessionmaker(
|
|
async_engine,
|
|
class_=AsyncSession,
|
|
expire_on_commit=False,
|
|
)
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
def _alembic_config() -> Config:
|
|
alembic_ini = Path(__file__).resolve().parents[2] / "alembic.ini"
|
|
|
|
alembic_cfg = Config(str(alembic_ini))
|
|
alembic_cfg.attributes["configure_logger"] = False
|
|
return alembic_cfg
|
|
|
|
|
|
def run_migrations() -> None:
|
|
"""Apply Alembic migrations to the latest revision."""
|
|
from alembic import command
|
|
|
|
logger.info("Running database migrations.")
|
|
command.upgrade(_alembic_config(), "head")
|
|
logger.info("Database migrations complete.")
|
|
_warn_on_schema_drift()
|
|
|
|
|
|
def _warn_on_schema_drift() -> None:
|
|
"""Log an error-level warning if the live schema is missing model columns.
|
|
|
|
This catches the case where Alembic's version table reports 'head' but
|
|
columns were never actually applied (e.g. a migration was inserted into the
|
|
chain after the DB had already advanced past it). The service continues to
|
|
start — the warning is intentionally loud so it surfaces in logs immediately.
|
|
"""
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy import inspect as sa_inspect
|
|
|
|
sync_url = (
|
|
settings.database_url
|
|
.replace("postgresql+asyncpg://", "postgresql+psycopg://")
|
|
.replace("postgresql://", "postgresql+psycopg://")
|
|
.replace("postgres://", "postgresql+psycopg://")
|
|
)
|
|
# _normalize_database_url already adds +psycopg; strip any double-prefix.
|
|
if "postgresql+psycopg+psycopg" in sync_url:
|
|
sync_url = sync_url.replace("postgresql+psycopg+psycopg", "postgresql+psycopg")
|
|
|
|
try:
|
|
engine = create_engine(sync_url, pool_pre_ping=True)
|
|
inspector = sa_inspect(engine)
|
|
except Exception as exc:
|
|
logger.error("schema_drift_check_failed", error=str(exc))
|
|
engine.dispose() if "engine" in dir() else None # type: ignore[name-defined]
|
|
return
|
|
|
|
missing: list[str] = []
|
|
try:
|
|
for table_name, table in SQLModel.metadata.tables.items():
|
|
if not inspector.has_table(table_name):
|
|
missing.append(f"TABLE {table_name}")
|
|
continue
|
|
db_cols = {col["name"] for col in inspector.get_columns(table_name)}
|
|
for col in table.columns:
|
|
if col.name not in db_cols:
|
|
missing.append(f"COLUMN {table_name}.{col.name}")
|
|
except Exception as exc:
|
|
logger.error("schema_drift_check_failed", error=str(exc))
|
|
finally:
|
|
engine.dispose()
|
|
|
|
if missing:
|
|
logger.error(
|
|
"schema_drift_detected",
|
|
missing=missing,
|
|
hint="DB schema does not match models. Run scripts/check_schema.py for details.",
|
|
)
|
|
|
|
|
|
async def init_db() -> None:
|
|
"""Initialize database schema, running migrations when configured."""
|
|
if settings.db_auto_migrate:
|
|
versions_dir = Path(__file__).resolve().parents[2] / "migrations" / "versions"
|
|
if any(versions_dir.glob("*.py")):
|
|
logger.info("Running migrations on startup")
|
|
await asyncio.to_thread(run_migrations)
|
|
return
|
|
logger.warning("No migration revisions found; falling back to create_all")
|
|
|
|
async with async_engine.connect() as conn, conn.begin():
|
|
await conn.run_sync(SQLModel.metadata.create_all)
|
|
|
|
|
|
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
|
"""Yield a request-scoped async DB session with safe rollback on errors."""
|
|
async with async_session_maker() as session:
|
|
try:
|
|
yield session
|
|
finally:
|
|
in_txn = False
|
|
try:
|
|
in_txn = bool(session.in_transaction())
|
|
except SQLAlchemyError:
|
|
logger.exception("Failed to inspect session transaction state.")
|
|
if in_txn:
|
|
try:
|
|
await session.rollback()
|
|
except SQLAlchemyError:
|
|
logger.exception("Failed to rollback session after request error.")
|