Pipeline/backend/scripts/seed_provider_credentials.py

393 lines
14 KiB
Python
Raw Normal View History

"""Idempotent seeder for local AI provider credentials.
Reads session tokens from the local Claude Code and Codex CLI credential
files and upserts them as ProviderCredential rows so Pipeline recognises
the providers as configured.
Sources
Anthropic / Claude
~/.claude/.credentials.json (claudeAiOauth.accessToken)
Override path: CLAUDE_CREDENTIALS_PATH env var
OpenAI / GPT
~/.codex/auth.json (tokens.access_token)
Override path: CODEX_CREDENTIALS_PATH env var
API keys (optional fallback / supplement)
ANTHROPIC_API_KEY
OPENAI_API_KEY
OPENAI_BASE_URL
Safe to run on every boot rows are only created or updated when a token
or key value differs from what is already stored.
Usage
# standalone
python scripts/seed_provider_credentials.py
# called automatically from main.py lifespan when AUTH_MODE=local
"""
from __future__ import annotations
import asyncio
import base64
import json
import os
import sys
import time
from dataclasses import dataclass, field
from pathlib import Path
# Directory one level above this script's folder; it is put at the front of
# sys.path so the deferred `app.*` imports inside seed() can presumably be
# resolved when the script is executed directly.
BACKEND_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, f"{BACKEND_ROOT}")
# ── credential-file readers ───────────────────────────────────────────────────
def _read_claude_session_key() -> str | None:
    """Return the Claude Code OAuth access token, or ``None`` if it is unusable.

    The token is read from ``claudeAiOauth.accessToken`` in the file named by
    the ``CLAUDE_CREDENTIALS_PATH`` env var (default
    ``~/.claude/.credentials.json``).

    Returns ``None`` (never raises for bad file contents) when:
      * the file is missing, unreadable, or not valid UTF-8 JSON;
      * the document is not a JSON object, or the expected keys are absent
        or of the wrong type;
      * ``expiresAt`` (epoch milliseconds) is positive and already past.
    """
    path = os.environ.get("CLAUDE_CREDENTIALS_PATH", "").strip()
    if not path:
        path = os.path.join(os.path.expanduser("~"), ".claude", ".credentials.json")
    try:
        with open(path, encoding="utf-8") as fh:
            data = json.load(fh)
    except (OSError, ValueError):
        # OSError covers missing/unreadable files; ValueError covers invalid
        # JSON and undecodable bytes (JSONDecodeError / UnicodeDecodeError).
        return None
    # Valid JSON that is not an object (e.g. a list) must not crash the seeder.
    if not isinstance(data, dict):
        return None
    oauth = data.get("claudeAiOauth")
    if not isinstance(oauth, dict):
        return None
    token = oauth.get("accessToken")
    if not isinstance(token, str) or not token:
        return None
    expires_at = oauth.get("expiresAt")
    # expiresAt is compared against wall-clock time in milliseconds.
    if isinstance(expires_at, (int, float)) and expires_at > 0:
        if expires_at <= time.time() * 1000:
            return None  # expired
    return token
def _read_openai_session_key() -> str | None:
    """Return the Codex CLI access token, or ``None`` if it is unusable.

    The token is read from ``tokens.access_token`` in the file named by the
    ``CODEX_CREDENTIALS_PATH`` env var (default ``~/.codex/auth.json``).

    When the token has at least two dot-separated segments it is treated as
    a JWT. Its ``exp`` claim (epoch seconds) is checked *without* verifying
    the signature, and an expired token yields ``None``. A payload that
    cannot be decoded, or is not a JSON object, is accepted optimistically.

    Returns ``None`` (never raises for bad file contents) when the file is
    missing, unreadable, or not valid UTF-8 JSON. It also returns ``None``
    when the document is not a JSON object or the expected keys are absent
    or malformed.
    """
    path = os.environ.get("CODEX_CREDENTIALS_PATH", "").strip()
    if not path:
        path = os.path.join(os.path.expanduser("~"), ".codex", "auth.json")
    try:
        with open(path, encoding="utf-8") as fh:
            data = json.load(fh)
    except (OSError, ValueError):
        # OSError: missing/unreadable file; ValueError: bad JSON or bytes.
        return None
    # Valid JSON that is not an object (e.g. a list) must not crash the seeder.
    if not isinstance(data, dict):
        return None
    tokens = data.get("tokens")
    if not isinstance(tokens, dict):
        return None
    token = tokens.get("access_token")
    if not isinstance(token, str) or not token:
        return None
    # Check JWT expiry without verifying the signature; this only decides
    # whether the token is worth seeding.
    parts = token.split(".")
    if len(parts) >= 2:
        segment = parts[1]
        # JWT segments are unpadded base64url; add exactly the 0-3 '=' needed.
        padded = segment + "=" * (-len(segment) % 4)
        try:
            payload = json.loads(base64.urlsafe_b64decode(padded))
        except ValueError:
            # binascii.Error, UnicodeDecodeError and JSONDecodeError are all
            # ValueError subclasses: can't decode, so proceed optimistically.
            payload = None
        if isinstance(payload, dict):
            exp = payload.get("exp")
            if isinstance(exp, (int, float)) and exp <= time.time():
                return None  # expired
    return token
# ── spec dataclass ────────────────────────────────────────────────────────────
@dataclass
class _ProviderSpec:
    """Credentials discovered for one provider account, prior to upserting.

    Any of ``api_key``, ``session_key`` and ``base_url`` may be ``None`` when
    the corresponding source (credential file or env var) was absent.
    """

    provider: str  # provider slug; part of the row lookup key
    account_key: str  # second part of the row lookup key alongside provider
    display_name: str  # label written only when a new row is created
    api_key: str | None = None
    session_key: str | None = None
    base_url: str | None = None

    @property
    def has_credentials(self) -> bool:
        """True when at least one of api_key, session_key, base_url is non-empty."""
        return any((self.api_key, self.session_key, self.base_url))
def _collect_specs() -> list[_ProviderSpec]:
    """Build one spec per provider that has at least one credential source.

    Anthropic combines the Claude Code session token with ANTHROPIC_API_KEY
    and ANTHROPIC_BASE_URL. OpenAI combines the Codex CLI token with
    OPENAI_API_KEY and OPENAI_BASE_URL. Blank or whitespace-only env values
    count as unset. Providers with nothing configured are omitted, and
    Anthropic (when present) comes first.
    """

    def env_or_none(name: str) -> str | None:
        # Treat unset, empty and whitespace-only values alike.
        return os.environ.get(name, "").strip() or None

    # Sources are read in the same order as before: Claude file, Anthropic
    # env vars, Codex file, OpenAI env vars.
    candidates = [
        _ProviderSpec(
            provider="anthropic",
            account_key="default",
            display_name="Anthropic (local)",
            session_key=_read_claude_session_key(),
            api_key=env_or_none("ANTHROPIC_API_KEY"),
            base_url=env_or_none("ANTHROPIC_BASE_URL"),
        ),
        _ProviderSpec(
            provider="openai",
            account_key="default",
            display_name="OpenAI (local)",
            session_key=_read_openai_session_key(),
            api_key=env_or_none("OPENAI_API_KEY"),
            base_url=env_or_none("OPENAI_BASE_URL"),
        ),
    ]
    return [candidate for candidate in candidates if candidate.has_credentials]
# ── upsert logic ──────────────────────────────────────────────────────────────
async def seed(*, verbose: bool = True) -> int:
    """Create or refresh ProviderCredential rows for the local organization.

    Specs come from ``_collect_specs()``. If none are found, a hint is
    printed (when *verbose*) and the database is not touched. Otherwise:

    1. ``init_db()`` is awaited.
    2. The local auth user is fetched or created.
    3. The user's organization membership is ensured.
    4. Each spec is upserted, keyed by
       ``(organization_id, provider, account_key)``, with one commit per
       created or changed row.

    For an existing row, api_key and session_key are only overwritten when
    the spec carries a value for them.

    Args:
        verbose: Print one status line per spec.

    Returns:
        Number of rows created or updated.
    """
    from sqlmodel import select

    from app.core.auth import LOCAL_AUTH_USER_ID, LOCAL_AUTH_EMAIL, LOCAL_AUTH_NAME
    from app.db import crud
    from app.db.session import async_session_maker, init_db
    from app.models.provider_credentials import ProviderCredential
    from app.models.users import User
    from app.services.organizations import ensure_member_for_user

    specs = _collect_specs()
    if not specs:
        if verbose:
            print(
                "seed_provider_credentials: no credentials found — "
                "ensure ~/.claude/.credentials.json or ~/.codex/auth.json exist, "
                "or set ANTHROPIC_API_KEY / OPENAI_API_KEY"
            )
        return 0

    await init_db()
    touched = 0
    async with async_session_maker() as session:
        # Credential rows are scoped to an organization, so the local user and
        # its membership must exist before any row is written.
        local_user, _ = await crud.get_or_create(
            session,
            User,
            clerk_user_id=LOCAL_AUTH_USER_ID,
            defaults={"email": LOCAL_AUTH_EMAIL, "name": LOCAL_AUTH_NAME},
        )
        if not local_user.email:
            local_user.email = LOCAL_AUTH_EMAIL
        if not local_user.name:
            local_user.name = LOCAL_AUTH_NAME
        session.add(local_user)
        await session.commit()
        await session.refresh(local_user)
        org_id = (await ensure_member_for_user(session, local_user)).organization_id

        for spec in specs:
            query = select(ProviderCredential).where(
                ProviderCredential.organization_id == org_id,
                ProviderCredential.provider == spec.provider,
                ProviderCredential.account_key == spec.account_key,
            )
            row = (await session.exec(query)).first()

            if row is None:
                session.add(
                    ProviderCredential(
                        organization_id=org_id,
                        provider=spec.provider,
                        account_key=spec.account_key,
                        display_name=spec.display_name,
                        api_key=spec.api_key,
                        api_key_last_four=spec.api_key[-4:] if spec.api_key else None,
                        session_key=spec.session_key,
                        session_key_last_four=(
                            spec.session_key[-4:] if spec.session_key else None
                        ),
                        base_url=spec.base_url,
                        active=True,
                    )
                )
                await session.commit()
                touched += 1
                if verbose:
                    _log_created(spec)
                continue

            modified = False
            # Secrets are only overwritten when the spec actually carries one.
            if spec.api_key is not None and row.api_key != spec.api_key:
                row.api_key = spec.api_key
                row.api_key_last_four = spec.api_key[-4:]
                modified = True
            if spec.session_key is not None and row.session_key != spec.session_key:
                row.session_key = spec.session_key
                row.session_key_last_four = spec.session_key[-4:]
                modified = True
            # NOTE(review): unlike the secrets above, base_url always mirrors the
            # spec, so an unset *_BASE_URL env var clears a stored base_url.
            # Confirm that is intended.
            if spec.base_url != row.base_url:
                row.base_url = spec.base_url
                modified = True
            if not row.active:
                row.active = True
                modified = True

            if not modified:
                if verbose:
                    print(
                        f"seed_provider_credentials: "
                        f"{spec.provider}/{spec.account_key} already current — skipped"
                    )
                continue

            session.add(row)
            await session.commit()
            touched += 1
            if verbose:
                _log_updated(spec)
    return touched
def _describe_spec(spec: _ProviderSpec) -> str:
    """Return ``provider/account (hints)`` for log lines.

    Each hint shows only the last four characters of a secret. The fallback
    label ``base_url only`` is used when the spec has neither a session key
    nor an API key.
    """
    hints = [
        f"{label} …{secret[-4:]}"
        for label, secret in (("session", spec.session_key), ("api_key", spec.api_key))
        if secret
    ]
    return f"{spec.provider}/{spec.account_key} ({', '.join(hints) or 'base_url only'})"


def _log_created(spec: _ProviderSpec) -> None:
    """Print a one-line notice that a credential row was created for *spec*."""
    print(f"seed_provider_credentials: created {_describe_spec(spec)}")


def _log_updated(spec: _ProviderSpec) -> None:
    """Print a one-line notice that a credential row was updated for *spec*."""
    print(f"seed_provider_credentials: updated {_describe_spec(spec)}")
# Default number of seconds watch() sleeps between credential-file mtime polls.
WATCH_INTERVAL_SECONDS: int = 60
def _credential_file_paths() -> list[str]:
    """List the credential files to watch, in [Claude, Codex] order.

    Each entry is the env-var override when it is set to a non-blank value,
    otherwise the default location under the user's home directory.
    """
    sources = (
        ("CLAUDE_CREDENTIALS_PATH", (".claude", ".credentials.json")),
        ("CODEX_CREDENTIALS_PATH", (".codex", "auth.json")),
    )
    return [
        os.environ.get(env_name, "").strip()
        or os.path.join(os.path.expanduser("~"), *default_parts)
        for env_name, default_parts in sources
    ]
def _snapshot_mtimes() -> dict[str, float]:
    """Map each watched credential path to its mtime, or 0.0 if it can't be stat'ed."""

    def mtime_or_zero(target: str) -> float:
        try:
            return os.stat(target).st_mtime
        except OSError:
            # Missing or inaccessible files are recorded as 0.0.
            return 0.0

    return {target: mtime_or_zero(target) for target in _credential_file_paths()}
async def watch(*, interval: int = WATCH_INTERVAL_SECONDS) -> None:
    """Background task: re-seed whenever a credential file changes.

    Polls file mtimes every ``interval`` seconds (default 60). Credential
    files are only read, and the database only touched, when at least one
    watched file's mtime differs from the last *successfully seeded*
    snapshot.

    If a re-seed raises, the snapshot is left unchanged. The same change is
    then retried on the next poll instead of being dropped until the file
    changes again.

    Designed to run as a long-lived asyncio task (e.g. inside the FastAPI
    lifespan). On cancellation it logs and re-raises CancelledError.
    """
    import logging

    logger = logging.getLogger(__name__)
    last_mtimes = _snapshot_mtimes()
    logger.info(
        "provider_credentials.watcher started interval=%ds paths=%s",
        interval,
        list(last_mtimes.keys()),
    )
    try:
        while True:
            await asyncio.sleep(interval)
            current = _snapshot_mtimes()
            changed_paths = [p for p, mtime in current.items() if mtime != last_mtimes.get(p)]
            if not changed_paths:
                continue
            for path in changed_paths:
                prev = last_mtimes.get(path, 0.0)
                curr = current[path]
                if curr == 0.0:
                    logger.info("provider_credentials.watcher file_removed path=%s", path)
                elif prev == 0.0:
                    logger.info("provider_credentials.watcher file_appeared path=%s", path)
                else:
                    logger.info("provider_credentials.watcher file_changed path=%s", path)
            try:
                n = await seed(verbose=False)
            except Exception as exc:
                # Keep last_mtimes as-is so this change is retried on the next
                # poll; exc_info preserves the traceback for diagnosis.
                logger.warning(
                    "provider_credentials.watcher reseed_failed error=%s",
                    exc,
                    exc_info=True,
                )
                continue
            # Advance the snapshot only once the change has been seeded.
            last_mtimes = current
            if n:
                logger.info("provider_credentials.watcher reseeded count=%d", n)
            else:
                logger.debug("provider_credentials.watcher reseed no_changes")
    except asyncio.CancelledError:
        logger.info("provider_credentials.watcher stopped")
        raise
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Seed or watch local AI provider credentials.")
    parser.add_argument(
        "--watch",
        action="store_true",
        help=f"Run continuously, re-seeding every {WATCH_INTERVAL_SECONDS}s when files change.",
    )
    parser.add_argument(
        "--interval",
        type=int,
        default=WATCH_INTERVAL_SECONDS,
        metavar="SECONDS",
        help=f"Poll interval in seconds (default: {WATCH_INTERVAL_SECONDS}).",
    )
    args = parser.parse_args()
    if args.watch:
        # A zero or negative interval would make the poll loop spin without sleeping.
        if args.interval <= 0:
            parser.error("--interval must be a positive number of seconds")
        print(f"seed_provider_credentials: watching every {args.interval}s — Ctrl-C to stop")
        try:
            asyncio.run(watch(interval=args.interval))
        except KeyboardInterrupt:
            # Ctrl-C is the documented way to stop watch mode; exit without a traceback.
            print("seed_provider_credentials: stopped")
    else:
        result = asyncio.run(seed(verbose=True))
        print(f"seed_provider_credentials: done — {result} row(s) affected")