Pipeline/backend/app/db/session.py

147 lines
5.0 KiB
Python
Raw Normal View History

"""Database engine, session factory, and startup migration helpers."""
from __future__ import annotations
import asyncio
from pathlib import Path
from typing import TYPE_CHECKING
from alembic.config import Config
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
from app import models as _models
from app.core.config import settings
from app.core.logging import get_logger
if TYPE_CHECKING:
from collections.abc import AsyncGenerator
# ``app.models`` is imported above (as ``_models``) so that every model's table
# is registered on ``SQLModel.metadata`` before ``create_all`` or the schema
# drift check reads it. Binding it to a module-level name presumably keeps
# linters from flagging the import as unused.
_MODEL_REGISTRY = _models
def _normalize_database_url(database_url: str) -> str:
if "://" not in database_url:
return database_url
scheme, rest = database_url.split("://", 1)
if scheme in ("postgresql", "postgres"):
return f"postgresql+psycopg://{rest}"
return database_url
logger = get_logger(__name__)

# Module-level async engine. Bare postgres URLs are pinned to the psycopg
# driver by _normalize_database_url. pool_pre_ping makes the pool test each
# connection on checkout, so stale connections are detected before use.
async_engine: AsyncEngine = create_async_engine(
    _normalize_database_url(settings.database_url),
    pool_pre_ping=True,
)

# Factory for AsyncSession objects bound to async_engine. With
# expire_on_commit=False, loaded attributes stay readable after commit()
# instead of being expired.
async_session_maker = async_sessionmaker(
    bind=async_engine,
    expire_on_commit=False,
    class_=AsyncSession,
)
def _alembic_config() -> Config:
    """Return an Alembic ``Config`` loaded from the backend's ``alembic.ini``.

    The ini path is resolved relative to this file: it sits in the directory
    that contains the ``app`` package. Lookup therefore does not depend on the
    process's working directory.
    """
    backend_dir = Path(__file__).resolve().parents[2]
    config = Config(str(backend_dir / "alembic.ini"))
    # NOTE(review): presumably read by the migrations env.py to skip its own
    # logging setup, so the application's logging config is kept. Confirm in env.py.
    config.attributes["configure_logger"] = False
    return config
def run_migrations() -> None:
    """Upgrade the database to Alembic's ``head`` revision, then check for drift.

    This is a blocking call; ``init_db`` dispatches it with
    ``asyncio.to_thread``. After the upgrade, ``_warn_on_schema_drift`` logs any
    model tables or columns that are still absent from the live schema.
    """
    from alembic import command as alembic_command

    logger.info("Running database migrations.")
    config = _alembic_config()
    alembic_command.upgrade(config, "head")
    logger.info("Database migrations complete.")
    _warn_on_schema_drift()
def _sync_database_url(database_url: str) -> str:
    """Return *database_url* rewritten for a synchronous SQLAlchemy engine.

    Bare ``postgresql://`` / ``postgres://`` URLs get the same psycopg pinning
    the async engine uses (``_normalize_database_url``). The async-only
    ``postgresql+asyncpg`` scheme is swapped for ``postgresql+psycopg``. Only
    the scheme is rewritten, so credentials, host and query string are never
    touched.
    """
    normalized = _normalize_database_url(database_url)
    scheme, separator, rest = normalized.partition("://")
    if separator and scheme == "postgresql+asyncpg":
        return f"postgresql+psycopg://{rest}"
    return normalized


def _warn_on_schema_drift() -> None:
    """Log an error-level warning if the live schema is missing model tables/columns.

    This catches the case where Alembic's version table reports 'head' but the
    columns were never actually applied. For example, a migration was inserted
    into the chain after the database had already advanced past it. The check
    never blocks startup: failures to connect or inspect are logged as
    ``schema_drift_check_failed`` instead of raised. Detected drift is logged
    loudly so it surfaces in the logs immediately.

    Only objects the models declare but the database lacks are reported. Extra
    database tables/columns and column type differences are not checked.
    """
    from sqlalchemy import create_engine
    from sqlalchemy import inspect as sa_inspect

    try:
        engine = create_engine(
            _sync_database_url(settings.database_url),
            pool_pre_ping=True,
        )
    except Exception as exc:
        logger.error("schema_drift_check_failed: %s", str(exc))
        return

    missing: list[str] = []
    try:
        inspector = sa_inspect(engine)
        for key, table in SQLModel.metadata.tables.items():
            # Metadata keys are schema-qualified ("schema.table") for tables that
            # declare a schema; the inspector needs the name and schema separately.
            if not inspector.has_table(table.name, schema=table.schema):
                missing.append(f"TABLE {key}")
                continue
            db_cols = {
                col["name"]
                for col in inspector.get_columns(table.name, schema=table.schema)
            }
            missing.extend(
                f"COLUMN {key}.{col.name}"
                for col in table.columns
                if col.name not in db_cols
            )
    except Exception as exc:
        # Best-effort diagnostic: anything collected before the failure is
        # still reported below.
        logger.error("schema_drift_check_failed: %s", str(exc))
    finally:
        engine.dispose()

    if missing:
        logger.error(
            "schema_drift_detected: %s (hint: %s)",
            missing,
            "DB schema does not match models. Run scripts/check_schema.py for details.",
        )
async def init_db() -> None:
    """Initialize the database schema, running migrations when configured.

    When ``settings.db_auto_migrate`` is enabled and ``migrations/versions``
    contains at least one revision script, Alembic migrations are applied and
    nothing else is done. They run in a worker thread because
    ``run_migrations`` is synchronous. In every other case the tables are
    created directly with ``SQLModel.metadata.create_all``: either
    auto-migrate is disabled, or it is enabled but no revision scripts exist.
    """
    if settings.db_auto_migrate:
        versions_dir = Path(__file__).resolve().parents[2] / "migrations" / "versions"
        # Alembic does not treat ``__init__.py`` or ``.#`` editor lock files as
        # revisions. Counting them would skip create_all even though there are
        # no migrations to apply, leaving the schema empty.
        has_revisions = any(
            not path.name.startswith(("__init__", ".#"))
            for path in versions_dir.glob("*.py")
        )
        if has_revisions:
            logger.info("Running migrations on startup")
            await asyncio.to_thread(run_migrations)
            return
        logger.warning("No migration revisions found; falling back to create_all")

    async with async_engine.connect() as conn, conn.begin():
        await conn.run_sync(SQLModel.metadata.create_all)
async def get_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield one ``AsyncSession`` and discard any uncommitted work afterwards.

    When the consumer finishes with the session, whether normally or through
    an exception, any transaction still open is rolled back. Anything the
    caller did not commit explicitly is therefore discarded. A
    ``SQLAlchemyError`` raised while checking or rolling back is logged, not
    raised. The session is closed by the surrounding ``async with``.
    """
    async with async_session_maker() as session:
        try:
            yield session
        finally:
            try:
                txn_open = session.in_transaction()
            except SQLAlchemyError:
                logger.exception("Failed to inspect session transaction state.")
            else:
                if txn_open:
                    try:
                        await session.rollback()
                    except SQLAlchemyError:
                        logger.exception("Failed to rollback session after request error.")