122 lines
4.1 KiB
Python
122 lines
4.1 KiB
Python
|
|
"""check_schema.py — verify the live DB schema matches SQLModel model definitions.
|
||
|
|
|
||
|
|
Connects to the database, inspects every table and column defined in the
|
||
|
|
SQLModel metadata, and reports anything that is absent in the real schema.
|
||
|
|
Exits 1 if drift is found, 0 if clean.
|
||
|
|
|
||
|
|
Usage:
|
||
|
|
python scripts/check_schema.py
|
||
|
|
DATABASE_URL=postgresql+psycopg://... python scripts/check_schema.py
|
||
|
|
|
||
|
|
This catches the class of problem where Alembic's version table says the DB is
|
||
|
|
up-to-date but columns were never actually created (e.g. a migration was
|
||
|
|
inserted into the chain after the DB had already advanced past it).
|
||
|
|
"""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import os
|
||
|
|
import sys
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
# Prepend the project root (the parent of this script's directory) to the
# module search path so `import app.models` resolves when the file is run
# directly, e.g. `python scripts/check_schema.py`.
sys.path[0:0] = [str(Path(__file__).resolve().parents[1])]
|
||
|
|
|
||
|
|
|
||
|
|
def _resolve_database_url() -> str | None:
|
||
|
|
url = os.getenv("DATABASE_URL")
|
||
|
|
if url:
|
||
|
|
return url
|
||
|
|
env_file = Path(__file__).resolve().parents[1] / ".env"
|
||
|
|
if env_file.exists():
|
||
|
|
for line in env_file.read_text().splitlines():
|
||
|
|
line = line.strip()
|
||
|
|
if line.startswith("DATABASE_URL=") and not line.startswith("#"):
|
||
|
|
return line.split("=", 1)[1].strip()
|
||
|
|
return None
|
||
|
|
|
||
|
|
|
||
|
|
def _to_sync_url(url: str) -> str:
    """Return *url* rewritten to use a synchronous SQLAlchemy driver.

    The application may be configured with an async driver (asyncpg,
    psycopg_async, aiosqlite), but ``sqlalchemy.inspect`` needs a sync engine.
    Bare ``postgres://`` / ``postgresql://`` URLs are pinned to the psycopg
    driver. URLs matching none of the known prefixes are returned unchanged.

    Args:
        url: SQLAlchemy database URL, possibly using an async driver.

    Returns:
        The equivalent URL using a sync driver.
    """
    # First matching prefix wins. "postgresql+...://" entries cannot collide
    # with the bare "postgresql://" entry because of the "+driver" part.
    replacements = (
        ("postgresql+asyncpg://", "postgresql+psycopg://"),
        ("postgresql+psycopg_async://", "postgresql+psycopg://"),
        # aiosqlite URLs are "sqlite+aiosqlite:///path". The previous
        # "postgresql+aiosqlite://" prefix never matched a real URL. Keeping
        # "sqlite://" preserves the third slash and the path.
        ("sqlite+aiosqlite://", "sqlite://"),
        ("postgresql://", "postgresql+psycopg://"),
        ("postgres://", "postgresql+psycopg://"),
    )
    for old, new in replacements:
        if url.startswith(old):
            return new + url[len(old):]
    return url
|
||
|
|
|
||
|
|
|
||
|
|
def _collect_drift(inspector, metadata) -> tuple[list[str], list[str]]:
    """Compare *metadata* with the live database seen through *inspector*.

    Returns ``(missing_tables, missing_columns)``:
    - ``missing_tables``: metadata table keys that have no matching DB table;
    - ``missing_columns``: ``"<table key>.<column>"`` entries for model columns
      that the DB table lacks.
    Tables are visited in sorted key order and columns are sorted within each
    table, so the output is deterministic. Columns present in the DB but
    absent from the models are not reported.
    """
    missing_tables: list[str] = []
    missing_columns: list[str] = []
    for key in sorted(metadata.tables):
        table = metadata.tables[key]
        # Look tables up by bare name plus schema. For schema-qualified models
        # the metadata key is "schema.table", which has_table() would treat
        # as a (non-existent) bare table name.
        if not inspector.has_table(table.name, schema=table.schema):
            missing_tables.append(key)
            continue
        db_columns = {
            col["name"]
            for col in inspector.get_columns(table.name, schema=table.schema)
        }
        model_columns = {col.name for col in table.columns}
        missing_columns.extend(
            f"{key}.{name}" for name in sorted(model_columns - db_columns)
        )
    return missing_tables, missing_columns


def _print_drift_report(missing_tables: list[str], missing_columns: list[str]) -> None:
    """Print the drift listing and remediation hints to stdout."""
    print("SCHEMA DRIFT DETECTED\n")
    if missing_tables:
        print("Missing tables (exist in models, not in DB):")
        for t in missing_tables:
            print(f" - {t}")
        print()
    if missing_columns:
        print("Missing columns (exist in models, not in DB):")
        for c in missing_columns:
            print(f" - {c}")
        print()
    print("Possible causes:")
    print(" 1. A migration was inserted into the chain after the DB had already")
    print(" advanced past it — run the migration manually or add the columns.")
    print(" 2. A new model field has no migration yet — run:")
    print(" alembic revision --autogenerate -m 'add <field>'")
    print(" 3. The DB was populated from an older schema — run:")
    print(" alembic upgrade head")


def main() -> int:
    """Check that every SQLModel table and column exists in the live database.

    Returns the process exit code:
    - 0 when nothing is missing;
    - 1 when drift is found, when no DATABASE_URL is configured, or when
      connecting to or reflecting the database fails.
    Error messages are written to stderr. The OK line and the drift report
    are written to stdout.
    """
    from sqlalchemy import create_engine
    from sqlalchemy import inspect as sa_inspect
    from sqlmodel import SQLModel

    # Import all models so SQLModel.metadata is fully populated.
    import app.models as _  # noqa: F401

    database_url = _resolve_database_url()
    if not database_url:
        print("ERROR: DATABASE_URL not set and not found in .env", file=sys.stderr)
        return 1

    engine = None
    # One handler covers both phases. `stage` records which phase failed so
    # the original error wording is kept for each.
    stage = "could not connect to database"
    try:
        engine = create_engine(_to_sync_url(database_url), pool_pre_ping=True)
        # Building an Inspector from an Engine opens a connection, so bad
        # hosts or credentials surface here rather than in create_engine().
        inspector = sa_inspect(engine)
        stage = "failed to inspect schema"
        missing_tables, missing_columns = _collect_drift(inspector, SQLModel.metadata)
    except Exception as exc:  # CLI boundary: report any failure, exit 1
        print(f"ERROR: {stage}: {exc}", file=sys.stderr)
        return 1
    finally:
        # Release pooled connections on every path, including when the
        # initial connection made by inspect() fails.
        if engine is not None:
            engine.dispose()

    if not missing_tables and not missing_columns:
        model_table_count = len(SQLModel.metadata.tables)
        print(f"OK: all {model_table_count} model tables and their columns exist in the database")
        return 0

    _print_drift_report(missing_tables, missing_columns)
    return 1
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Use main()'s integer result as the process exit status.
    sys.exit(main())
|