"""Idempotent seeder for local AI provider credentials.

Reads session tokens from the local Claude Code and Codex CLI credential files
and upserts them as ProviderCredential rows so Pipeline recognises the
providers as configured.

Sources
───────
Anthropic / Claude
    ~/.claude/.credentials.json   (claudeAiOauth.accessToken)
    Override path: CLAUDE_CREDENTIALS_PATH env var

OpenAI / GPT
    ~/.codex/auth.json            (tokens.access_token)
    Override path: CODEX_CREDENTIALS_PATH env var

API keys / endpoints (optional fallback / supplement)
    ANTHROPIC_API_KEY
    ANTHROPIC_BASE_URL
    OPENAI_API_KEY
    OPENAI_BASE_URL

Expired session tokens are ignored, so a row keeps whatever token it already
holds until a fresh one appears locally.

Safe to run on every boot — rows are only created or updated when a token,
key, or base URL differs from what is already stored. A value that is not
available locally never clears what is already stored.

Usage
─────
    # standalone
    python scripts/seed_provider_credentials.py

    # called automatically from main.py lifespan when AUTH_MODE=local
"""

from __future__ import annotations

import asyncio
import base64
import json
import os
import sys
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING

BACKEND_ROOT = Path(__file__).resolve().parents[1]
# Put the backend root on sys.path so the `app.*` imports inside seed() resolve
# when this file is run as a standalone script.
sys.path.insert(0, str(BACKEND_ROOT))

if TYPE_CHECKING:  # annotation-only; the runtime import happens inside seed()
    from app.models.provider_credentials import ProviderCredential


# ── credential-file readers ───────────────────────────────────────────────────


def _claude_credentials_path() -> str:
    """Return the Claude Code credentials path, honouring CLAUDE_CREDENTIALS_PATH."""
    override = os.environ.get("CLAUDE_CREDENTIALS_PATH", "").strip()
    return override or os.path.join(os.path.expanduser("~"), ".claude", ".credentials.json")


def _codex_credentials_path() -> str:
    """Return the Codex CLI credentials path, honouring CODEX_CREDENTIALS_PATH."""
    override = os.environ.get("CODEX_CREDENTIALS_PATH", "").strip()
    return override or os.path.join(os.path.expanduser("~"), ".codex", "auth.json")


def _load_json_object(path: str) -> dict | None:
    """Parse *path* as JSON and return it only if the top level is an object.

    Missing or unreadable files, malformed JSON, and non-object top-level
    values (e.g. a list) all yield ``None``. A bad credentials file therefore
    degrades to "no credential" instead of raising.
    """
    try:
        with open(path, encoding="utf-8") as fh:
            data = json.load(fh)
    except (OSError, ValueError):
        # OSError covers FileNotFoundError / PermissionError; ValueError covers
        # json.JSONDecodeError and UnicodeDecodeError.
        return None
    return data if isinstance(data, dict) else None


def _read_claude_session_key() -> str | None:
    """Read the Claude Code OAuth access token from ~/.claude/.credentials.json.

    Returns ``claudeAiOauth.accessToken``, or ``None`` in these cases:
    - the file is missing or invalid;
    - the token is absent or empty;
    - ``claudeAiOauth.expiresAt`` (compared as epoch milliseconds) is in the past.
    A missing or non-positive ``expiresAt`` is treated as "no known expiry".
    """
    data = _load_json_object(_claude_credentials_path())
    if data is None:
        return None
    oauth = data.get("claudeAiOauth")
    if not isinstance(oauth, dict):
        return None
    token = oauth.get("accessToken")
    if not isinstance(token, str) or not token:
        return None
    expires_at = oauth.get("expiresAt")
    if isinstance(expires_at, (int, float)) and 0 < expires_at <= time.time() * 1000:
        return None  # expired
    return token


def _jwt_is_expired(token: str) -> bool:
    """Return True if *token* looks like a JWT whose ``exp`` claim is in the past.

    The signature is NOT verified; this is only a freshness hint. Tokens that
    are not JWT-shaped, or whose payload cannot be decoded to a JSON object,
    are treated as not expired (the caller proceeds optimistically).
    """
    parts = token.split(".")
    if len(parts) < 2:
        return False
    segment = parts[1]
    try:
        # JWT segments are unpadded base64url; restore padding to a multiple of 4.
        payload = json.loads(base64.urlsafe_b64decode(segment + "=" * (-len(segment) % 4)))
    except ValueError:
        # binascii.Error, UnicodeDecodeError and JSONDecodeError subclass ValueError.
        return False
    if not isinstance(payload, dict):
        return False
    exp = payload.get("exp")
    return isinstance(exp, (int, float)) and exp <= time.time()


def _read_openai_session_key() -> str | None:
    """Read the Codex CLI JWT access token from ~/.codex/auth.json.

    Returns ``tokens.access_token``, or ``None`` when the file is missing or
    invalid, the token is absent or empty, or the token's JWT ``exp`` claim is
    in the past.
    """
    data = _load_json_object(_codex_credentials_path())
    if data is None:
        return None
    tokens = data.get("tokens")
    if not isinstance(tokens, dict):
        return None
    token = tokens.get("access_token")
    if not isinstance(token, str) or not token:
        return None
    if _jwt_is_expired(token):
        return None  # expired
    return token


# ── spec dataclass ────────────────────────────────────────────────────────────


@dataclass
class _ProviderSpec:
    """Credentials discovered locally for one provider/account pair."""

    provider: str  # provider slug written to the row, e.g. "anthropic" / "openai"
    account_key: str  # part of the row lookup key; this module always uses "default"
    display_name: str  # label written only when the row is first created
    api_key: str | None = None  # from the <PROVIDER>_API_KEY env var
    session_key: str | None = None  # token read from a CLI credentials file
    base_url: str | None = None  # from the <PROVIDER>_BASE_URL env var

    @property
    def has_credentials(self) -> bool:
        """True when at least one of api_key, session_key or base_url is set."""
        return bool(self.api_key or self.session_key or self.base_url)


def _env(name: str) -> str | None:
    """Return the stripped value of env var *name*, or ``None`` if unset/blank."""
    return os.environ.get(name, "").strip() or None


def _collect_specs() -> list[_ProviderSpec]:
    """Build one spec per provider that has any locally available credential."""
    candidates = [
        # Anthropic: session token from the Claude Code file, key/URL from env.
        _ProviderSpec(
            provider="anthropic",
            account_key="default",
            display_name="Anthropic (local)",
            session_key=_read_claude_session_key(),
            api_key=_env("ANTHROPIC_API_KEY"),
            base_url=_env("ANTHROPIC_BASE_URL"),
        ),
        # OpenAI: session token from the Codex CLI file, key/URL from env.
        _ProviderSpec(
            provider="openai",
            account_key="default",
            display_name="OpenAI (local)",
            session_key=_read_openai_session_key(),
            api_key=_env("OPENAI_API_KEY"),
            base_url=_env("OPENAI_BASE_URL"),
        ),
    ]
    return [spec for spec in candidates if spec.has_credentials]


# ── upsert logic ──────────────────────────────────────────────────────────────


def _apply_spec(existing: ProviderCredential, spec: _ProviderSpec) -> bool:
    """Copy differing values from *spec* onto *existing*; return True if anything changed.

    Only values present in *spec* are written. A key, token, or base URL that
    is not available locally leaves the stored value untouched. A deactivated
    row is re-activated.
    """
    dirty = False
    if spec.api_key is not None and existing.api_key != spec.api_key:
        existing.api_key = spec.api_key
        existing.api_key_last_four = spec.api_key[-4:]
        dirty = True
    if spec.session_key is not None and existing.session_key != spec.session_key:
        existing.session_key = spec.session_key
        existing.session_key_last_four = spec.session_key[-4:]
        dirty = True
    # An unset *_BASE_URL must not wipe a stored base_url on every boot;
    # handle it exactly like the keys above.
    if spec.base_url is not None and existing.base_url != spec.base_url:
        existing.base_url = spec.base_url
        dirty = True
    if not existing.active:
        existing.active = True
        dirty = True
    return dirty


async def seed(*, verbose: bool = True) -> int:
    """Upsert provider credentials from local credential files and env vars.

    Ensures the local-auth user and its organization membership exist. Then it
    creates or updates one ProviderCredential row per provider that has local
    credentials, committing each created or updated row individually.

    Args:
        verbose: Print one line per provider (created / updated / skipped),
            plus a hint when no credentials were found.

    Returns:
        The number of rows created or updated.
    """
    # Deferred imports: the application/database stack is only loaded when
    # seeding actually runs.
    from sqlmodel import select

    from app.core.auth import LOCAL_AUTH_EMAIL, LOCAL_AUTH_NAME, LOCAL_AUTH_USER_ID
    from app.db import crud
    from app.db.session import async_session_maker, init_db
    from app.models.provider_credentials import ProviderCredential
    from app.models.users import User
    from app.services.organizations import ensure_member_for_user

    specs = _collect_specs()
    if not specs:
        if verbose:
            print(
                "seed_provider_credentials: no credentials found — "
                "ensure ~/.claude/.credentials.json or ~/.codex/auth.json exist, "
                "or set ANTHROPIC_API_KEY / OPENAI_API_KEY"
            )
        return 0

    await init_db()

    async with async_session_maker() as session:
        # Ensure the local user + org exist before we can write credentials.
        user, _created = await crud.get_or_create(
            session,
            User,
            clerk_user_id=LOCAL_AUTH_USER_ID,
            defaults={"email": LOCAL_AUTH_EMAIL, "name": LOCAL_AUTH_NAME},
        )
        # Backfill identity fields on a pre-existing user row that lacks them.
        if not user.email:
            user.email = LOCAL_AUTH_EMAIL
        if not user.name:
            user.name = LOCAL_AUTH_NAME
        session.add(user)
        await session.commit()
        await session.refresh(user)

        member = await ensure_member_for_user(session, user)
        organization_id = member.organization_id

        changed = 0
        for spec in specs:
            existing = (
                await session.exec(
                    select(ProviderCredential).where(
                        ProviderCredential.organization_id == organization_id,
                        ProviderCredential.provider == spec.provider,
                        ProviderCredential.account_key == spec.account_key,
                    )
                )
            ).first()

            if existing is None:
                session.add(
                    ProviderCredential(
                        organization_id=organization_id,
                        provider=spec.provider,
                        account_key=spec.account_key,
                        display_name=spec.display_name,
                        api_key=spec.api_key,
                        api_key_last_four=spec.api_key[-4:] if spec.api_key else None,
                        session_key=spec.session_key,
                        session_key_last_four=(
                            spec.session_key[-4:] if spec.session_key else None
                        ),
                        base_url=spec.base_url,
                        active=True,
                    )
                )
                await session.commit()
                changed += 1
                if verbose:
                    _log_created(spec)
            elif _apply_spec(existing, spec):
                session.add(existing)
                await session.commit()
                changed += 1
                if verbose:
                    _log_updated(spec)
            elif verbose:
                print(
                    f"seed_provider_credentials: "
                    f"{spec.provider}/{spec.account_key} already current — skipped"
                )

    return changed


def _credential_hints(spec: _ProviderSpec) -> str:
    """Describe which secrets *spec* carries, showing only their last four characters."""
    hints: list[str] = []
    if spec.session_key:
        hints.append(f"session …{spec.session_key[-4:]}")
    if spec.api_key:
        hints.append(f"api_key …{spec.api_key[-4:]}")
    return ", ".join(hints) or "base_url only"


def _log_change(verb: str, spec: _ProviderSpec) -> None:
    """Print a one-line seeding report for *spec* (e.g. verb="created")."""
    print(
        f"seed_provider_credentials: {verb} "
        f"{spec.provider}/{spec.account_key} ({_credential_hints(spec)})"
    )


def _log_created(spec: _ProviderSpec) -> None:
    """Report that a new credential row was created for *spec*."""
    _log_change("created", spec)


def _log_updated(spec: _ProviderSpec) -> None:
    """Report that an existing credential row was updated from *spec*."""
    _log_change("updated", spec)


WATCH_INTERVAL_SECONDS = 60  # re-check credential files every 60 s


def _credential_file_paths() -> list[str]:
    """Return the resolved paths for all watched credential files."""
    return [_claude_credentials_path(), _codex_credentials_path()]


def _snapshot_mtimes() -> dict[str, float]:
    """Return {path: mtime} for each watched file, 0.0 if missing."""
    result: dict[str, float] = {}
    for path in _credential_file_paths():
        try:
            result[path] = os.stat(path).st_mtime
        except OSError:
            result[path] = 0.0
    return result


async def watch(*, interval: int = WATCH_INTERVAL_SECONDS) -> None:
    """Background task: re-seed whenever a credential file changes.

    Polls file mtimes every `interval` seconds (default 60). Re-seeds only
    when at least one file's mtime has changed — no DB or file-read overhead
    on quiet cycles.

    If a re-seed raises, it is retried on every following poll until it
    succeeds. A failure (e.g. a briefly unavailable database) therefore cannot
    drop a credential change until the file happens to change again.

    Designed to be run as a long-lived asyncio task inside the FastAPI
    lifespan. Exits cleanly on CancelledError.
    """
    import logging

    logger = logging.getLogger(__name__)

    last_mtimes = _snapshot_mtimes()
    # True while a detected file change has not yet been successfully re-seeded.
    reseed_pending = False
    logger.info(
        "provider_credentials.watcher started interval=%ds paths=%s",
        interval,
        list(last_mtimes.keys()),
    )

    try:
        while True:
            await asyncio.sleep(interval)

            current = _snapshot_mtimes()
            changed_paths = [p for p, mtime in current.items() if mtime != last_mtimes.get(p)]
            for path in changed_paths:
                prev = last_mtimes.get(path, 0.0)
                curr = current[path]
                if curr == 0.0:
                    logger.info("provider_credentials.watcher file_removed path=%s", path)
                elif prev == 0.0:
                    logger.info("provider_credentials.watcher file_appeared path=%s", path)
                else:
                    logger.info("provider_credentials.watcher file_changed path=%s", path)
            last_mtimes = current

            if changed_paths:
                reseed_pending = True
            if not reseed_pending:
                continue

            try:
                n = await seed(verbose=False)
            except Exception as exc:
                # Keep reseed_pending set so the next poll retries the reseed.
                logger.warning("provider_credentials.watcher reseed_failed error=%s", exc)
                continue

            reseed_pending = False
            if n:
                logger.info("provider_credentials.watcher reseeded count=%d", n)
            else:
                logger.debug("provider_credentials.watcher reseed no_changes")
    except asyncio.CancelledError:
        logger.info("provider_credentials.watcher stopped")
        raise


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Seed or watch local AI provider credentials.")
    parser.add_argument(
        "--watch",
        action="store_true",
        help=f"Run continuously, re-seeding every {WATCH_INTERVAL_SECONDS}s when files change.",
    )
    parser.add_argument(
        "--interval",
        type=int,
        default=WATCH_INTERVAL_SECONDS,
        metavar="SECONDS",
        help=f"Poll interval in seconds (default: {WATCH_INTERVAL_SECONDS}).",
    )
    args = parser.parse_args()

    if args.watch:
        print(f"seed_provider_credentials: watching every {args.interval}s — Ctrl-C to stop")
        asyncio.run(watch(interval=args.interval))
    else:
        result = asyncio.run(seed(verbose=True))
        print(f"seed_provider_credentials: done — {result} row(s) affected")