393 lines
14 KiB
Python
393 lines
14 KiB
Python
|
|
"""Idempotent seeder for local AI provider credentials.
|
||
|
|
|
||
|
|
Reads session tokens from the local Claude Code and Codex CLI credential
|
||
|
|
files and upserts them as ProviderCredential rows so Pipeline recognises
|
||
|
|
the providers as configured.
|
||
|
|
|
||
|
|
Sources
|
||
|
|
───────
|
||
|
|
Anthropic / Claude
|
||
|
|
~/.claude/.credentials.json (claudeAiOauth.accessToken)
|
||
|
|
Override path: CLAUDE_CREDENTIALS_PATH env var
|
||
|
|
|
||
|
|
OpenAI / GPT
|
||
|
|
~/.codex/auth.json (tokens.access_token)
|
||
|
|
Override path: CODEX_CREDENTIALS_PATH env var
|
||
|
|
|
||
|
|
API keys (optional fallback / supplement)
|
||
|
|
ANTHROPIC_API_KEY
|
||
|
|
OPENAI_API_KEY
|
||
|
|
OPENAI_BASE_URL
ANTHROPIC_BASE_URL
|
||
|
|
|
||
|
|
Safe to run on every boot — rows are only created or updated when a token
|
||
|
|
or key value differs from what is already stored.
|
||
|
|
|
||
|
|
Usage
|
||
|
|
─────
|
||
|
|
# standalone
|
||
|
|
python scripts/seed_provider_credentials.py
|
||
|
|
|
||
|
|
# called automatically from main.py lifespan when AUTH_MODE=local
|
||
|
|
"""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import base64
|
||
|
|
import json
|
||
|
|
import os
|
||
|
|
import sys
|
||
|
|
import time
|
||
|
|
from dataclasses import dataclass, field
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
# Parent of the directory that contains this file (parents[0] is the
# containing directory, parents[1] is one level above it).
BACKEND_ROOT = Path(__file__).resolve().parents[1]
# Put that directory first on sys.path; presumably this is what lets the
# function-local `app.*` imports in seed() resolve when the file is run as a
# standalone script — TODO confirm against the project layout.
sys.path.insert(0, str(BACKEND_ROOT))
|
||
|
|
|
||
|
|
|
||
|
|
# ── credential-file readers ───────────────────────────────────────────────────
|
||
|
|
|
||
|
|
def _read_claude_session_key() -> str | None:
|
||
|
|
"""Read the Claude Code OAuth access token from ~/.claude/.credentials.json."""
|
||
|
|
path = os.environ.get("CLAUDE_CREDENTIALS_PATH", "").strip()
|
||
|
|
if not path:
|
||
|
|
path = os.path.join(os.path.expanduser("~"), ".claude", ".credentials.json")
|
||
|
|
try:
|
||
|
|
with open(path) as fh:
|
||
|
|
data = json.load(fh)
|
||
|
|
except (FileNotFoundError, PermissionError, ValueError, OSError):
|
||
|
|
return None
|
||
|
|
|
||
|
|
oauth = data.get("claudeAiOauth")
|
||
|
|
if not isinstance(oauth, dict):
|
||
|
|
return None
|
||
|
|
|
||
|
|
token = oauth.get("accessToken")
|
||
|
|
expires_at = oauth.get("expiresAt")
|
||
|
|
|
||
|
|
if not isinstance(token, str) or not token:
|
||
|
|
return None
|
||
|
|
if isinstance(expires_at, (int, float)) and expires_at > 0:
|
||
|
|
if expires_at <= time.time() * 1000:
|
||
|
|
return None # expired
|
||
|
|
|
||
|
|
return token
|
||
|
|
|
||
|
|
|
||
|
|
def _read_openai_session_key() -> str | None:
|
||
|
|
"""Read the Codex CLI JWT access token from ~/.codex/auth.json."""
|
||
|
|
path = os.environ.get("CODEX_CREDENTIALS_PATH", "").strip()
|
||
|
|
if not path:
|
||
|
|
path = os.path.join(os.path.expanduser("~"), ".codex", "auth.json")
|
||
|
|
try:
|
||
|
|
with open(path) as fh:
|
||
|
|
data = json.load(fh)
|
||
|
|
except (FileNotFoundError, PermissionError, ValueError, OSError):
|
||
|
|
return None
|
||
|
|
|
||
|
|
tokens = data.get("tokens")
|
||
|
|
if not isinstance(tokens, dict):
|
||
|
|
return None
|
||
|
|
|
||
|
|
token = tokens.get("access_token")
|
||
|
|
if not isinstance(token, str) or not token:
|
||
|
|
return None
|
||
|
|
|
||
|
|
# Check JWT expiry without verifying the signature.
|
||
|
|
parts = token.split(".")
|
||
|
|
if len(parts) >= 2:
|
||
|
|
try:
|
||
|
|
pad = parts[1] + "=" * (4 - len(parts[1]) % 4)
|
||
|
|
payload = json.loads(base64.urlsafe_b64decode(pad))
|
||
|
|
exp = payload.get("exp")
|
||
|
|
if isinstance(exp, (int, float)) and exp <= time.time():
|
||
|
|
return None # expired
|
||
|
|
except Exception:
|
||
|
|
pass # Can't decode — proceed optimistically
|
||
|
|
|
||
|
|
return token
|
||
|
|
|
||
|
|
|
||
|
|
# ── spec dataclass ────────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
@dataclass
|
||
|
|
class _ProviderSpec:
|
||
|
|
provider: str
|
||
|
|
account_key: str
|
||
|
|
display_name: str
|
||
|
|
api_key: str | None = None
|
||
|
|
session_key: str | None = None
|
||
|
|
base_url: str | None = None
|
||
|
|
|
||
|
|
@property
|
||
|
|
def has_credentials(self) -> bool:
|
||
|
|
return bool(self.api_key or self.session_key or self.base_url)
|
||
|
|
|
||
|
|
|
||
|
|
def _collect_specs() -> list[_ProviderSpec]:
    """Build one spec per provider and keep those with something to store.

    For each provider the session token comes from its local credential
    file reader, while the API key and base URL come from environment
    variables (blank values count as unset).
    """

    def _env(name: str) -> str | None:
        # Whitespace-only or missing variables collapse to None.
        return os.environ.get(name, "").strip() or None

    candidates = (
        # Anthropic: session token from local file, API key from env.
        _ProviderSpec(
            provider="anthropic",
            account_key="default",
            display_name="Anthropic (local)",
            session_key=_read_claude_session_key(),
            api_key=_env("ANTHROPIC_API_KEY"),
            base_url=_env("ANTHROPIC_BASE_URL"),
        ),
        # OpenAI: session token from local Codex file, API key from env.
        _ProviderSpec(
            provider="openai",
            account_key="default",
            display_name="OpenAI (local)",
            session_key=_read_openai_session_key(),
            api_key=_env("OPENAI_API_KEY"),
            base_url=_env("OPENAI_BASE_URL"),
        ),
    )
    return [candidate for candidate in candidates if candidate.has_credentials]
|
||
|
|
|
||
|
|
|
||
|
|
# ── upsert logic ──────────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
async def seed(*, verbose: bool = True) -> int:
    """Synchronise ProviderCredential rows with locally discovered credentials.

    Specs are gathered with ``_collect_specs()``.  When none are found the
    database is not touched and 0 is returned.  Otherwise the local user is
    fetched or created, its organization membership is ensured, and each
    spec is matched to a row by (organization, provider, account_key):

    * no row    -> a new active row is inserted;
    * row found -> api_key / session_key are overwritten only when the spec
      supplies a different value, base_url is always aligned with the spec
      (including being cleared), and an inactive row is re-activated.

    Args:
        verbose: When true, print one progress line per spec (and a hint
            when no credentials were found).

    Returns the number of rows created or updated.
    """
    # Project/ORM imports stay function-local, so importing this module does
    # not import the `app` package.
    from sqlmodel import select

    from app.core.auth import LOCAL_AUTH_USER_ID, LOCAL_AUTH_EMAIL, LOCAL_AUTH_NAME
    from app.db import crud
    from app.db.session import async_session_maker, init_db
    from app.models.provider_credentials import ProviderCredential
    from app.models.users import User
    from app.services.organizations import ensure_member_for_user

    specs = _collect_specs()
    if not specs:
        if verbose:
            print(
                "seed_provider_credentials: no credentials found — "
                "ensure ~/.claude/.credentials.json or ~/.codex/auth.json exist, "
                "or set ANTHROPIC_API_KEY / OPENAI_API_KEY"
            )
        return 0

    await init_db()

    changed = 0
    async with async_session_maker() as session:
        # Ensure the local user + org exist before we can write credentials.
        user, _ = await crud.get_or_create(
            session,
            User,
            clerk_user_id=LOCAL_AUTH_USER_ID,
            defaults={"email": LOCAL_AUTH_EMAIL, "name": LOCAL_AUTH_NAME},
        )
        # Backfill blank profile fields on a pre-existing user row.
        if not user.email:
            user.email = LOCAL_AUTH_EMAIL
        if not user.name:
            user.name = LOCAL_AUTH_NAME
        session.add(user)
        await session.commit()
        await session.refresh(user)

        member = await ensure_member_for_user(session, user)
        organization_id = member.organization_id

        for spec in specs:
            lookup = select(ProviderCredential).where(
                ProviderCredential.organization_id == organization_id,
                ProviderCredential.provider == spec.provider,
                ProviderCredential.account_key == spec.account_key,
            )
            row = (await session.exec(lookup)).first()

            # ── insert path ──────────────────────────────────────────────
            if row is None:
                session.add(
                    ProviderCredential(
                        organization_id=organization_id,
                        provider=spec.provider,
                        account_key=spec.account_key,
                        display_name=spec.display_name,
                        api_key=spec.api_key,
                        api_key_last_four=spec.api_key[-4:] if spec.api_key else None,
                        session_key=spec.session_key,
                        session_key_last_four=spec.session_key[-4:] if spec.session_key else None,
                        base_url=spec.base_url,
                        active=True,
                    )
                )
                await session.commit()
                changed += 1
                if verbose:
                    _log_created(spec)
                continue

            # ── update path ──────────────────────────────────────────────
            needs_write = False

            # A secret missing from the spec never erases a stored one.
            if spec.api_key is not None and row.api_key != spec.api_key:
                row.api_key = spec.api_key
                row.api_key_last_four = spec.api_key[-4:]
                needs_write = True

            if spec.session_key is not None and row.session_key != spec.session_key:
                row.session_key = spec.session_key
                row.session_key_last_four = spec.session_key[-4:]
                needs_write = True

            # base_url, unlike the secrets, always follows the spec.
            if spec.base_url != row.base_url:
                row.base_url = spec.base_url
                needs_write = True

            if not row.active:
                row.active = True
                needs_write = True

            if not needs_write:
                if verbose:
                    print(
                        f"seed_provider_credentials: {spec.provider}/{spec.account_key} already current — skipped"
                    )
                continue

            session.add(row)
            await session.commit()
            changed += 1
            if verbose:
                _log_updated(spec)

    return changed
|
||
|
|
|
||
|
|
|
||
|
|
def _log_created(spec: _ProviderSpec) -> None:
|
||
|
|
hints: list[str] = []
|
||
|
|
if spec.session_key:
|
||
|
|
hints.append(f"session …{spec.session_key[-4:]}")
|
||
|
|
if spec.api_key:
|
||
|
|
hints.append(f"api_key …{spec.api_key[-4:]}")
|
||
|
|
print(
|
||
|
|
f"seed_provider_credentials: created "
|
||
|
|
f"{spec.provider}/{spec.account_key} ({', '.join(hints) or 'base_url only'})"
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
def _log_updated(spec: _ProviderSpec) -> None:
|
||
|
|
hints: list[str] = []
|
||
|
|
if spec.session_key:
|
||
|
|
hints.append(f"session …{spec.session_key[-4:]}")
|
||
|
|
if spec.api_key:
|
||
|
|
hints.append(f"api_key …{spec.api_key[-4:]}")
|
||
|
|
print(
|
||
|
|
f"seed_provider_credentials: updated "
|
||
|
|
f"{spec.provider}/{spec.account_key} ({', '.join(hints) or 'base_url only'})"
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
WATCH_INTERVAL_SECONDS = 60  # default poll period in seconds, used by watch() and the --interval CLI flag
|
||
|
|
|
||
|
|
|
||
|
|
def _credential_file_paths() -> list[str]:
|
||
|
|
"""Return the resolved paths for all watched credential files."""
|
||
|
|
claude_path = os.environ.get("CLAUDE_CREDENTIALS_PATH", "").strip()
|
||
|
|
if not claude_path:
|
||
|
|
claude_path = os.path.join(os.path.expanduser("~"), ".claude", ".credentials.json")
|
||
|
|
codex_path = os.environ.get("CODEX_CREDENTIALS_PATH", "").strip()
|
||
|
|
if not codex_path:
|
||
|
|
codex_path = os.path.join(os.path.expanduser("~"), ".codex", "auth.json")
|
||
|
|
return [claude_path, codex_path]
|
||
|
|
|
||
|
|
|
||
|
|
def _snapshot_mtimes() -> dict[str, float]:
    """Map every watched credential path to its modification time.

    A path that cannot be stat'ed (e.g. it does not exist) maps to 0.0.
    """

    def _mtime_or_zero(target: str) -> float:
        try:
            return os.path.getmtime(target)
        except OSError:
            return 0.0

    return {watched: _mtime_or_zero(watched) for watched in _credential_file_paths()}
|
||
|
|
|
||
|
|
|
||
|
|
async def watch(*, interval: int = WATCH_INTERVAL_SECONDS) -> None:
    """Poll the credential files forever and re-seed when one changes.

    Every `interval` seconds the files' mtimes are compared with the
    previous snapshot.  If none differ the cycle ends there, without calling
    seed().  Otherwise each difference is logged as removed / appeared /
    changed, the snapshot is advanced, and ``seed(verbose=False)`` runs.  A
    failing seed is logged as a warning and does not stop the loop.

    Meant to run as a long-lived asyncio task (e.g. inside the FastAPI
    lifespan); on cancellation it logs and re-raises CancelledError.
    """
    import logging

    log = logging.getLogger(__name__)

    known = _snapshot_mtimes()
    log.info(
        "provider_credentials.watcher started interval=%ds paths=%s",
        interval,
        list(known),
    )

    try:
        while True:
            await asyncio.sleep(interval)

            latest = _snapshot_mtimes()
            touched = [path for path, stamp in latest.items() if stamp != known.get(path)]
            if touched:
                for path in touched:
                    before = known.get(path, 0.0)
                    # 0.0 is the snapshot's marker for "file could not be stat'ed".
                    if latest[path] == 0.0:
                        template = "provider_credentials.watcher file_removed path=%s"
                    elif before == 0.0:
                        template = "provider_credentials.watcher file_appeared path=%s"
                    else:
                        template = "provider_credentials.watcher file_changed path=%s"
                    log.info(template, path)

                # The snapshot moves forward before seeding, so a failed seed
                # is not retried until a file changes again.
                known = latest

                try:
                    reseeded = await seed(verbose=False)
                except Exception as exc:
                    log.warning("provider_credentials.watcher reseed_failed error=%s", exc)
                else:
                    if reseeded:
                        log.info("provider_credentials.watcher reseeded count=%d", reseeded)
                    else:
                        log.debug("provider_credentials.watcher reseed no_changes")
    except asyncio.CancelledError:
        log.info("provider_credentials.watcher stopped")
        raise
|
||
|
|
|
||
|
|
|
||
|
|
# Script entry point: seed once by default, or poll continuously with --watch.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Seed or watch local AI provider credentials.")
    parser.add_argument(
        "--watch",
        action="store_true",
        help=f"Run continuously, re-seeding every {WATCH_INTERVAL_SECONDS}s when files change.",
    )
    parser.add_argument(
        "--interval",
        type=int,
        default=WATCH_INTERVAL_SECONDS,
        metavar="SECONDS",
        help=f"Poll interval in seconds (default: {WATCH_INTERVAL_SECONDS}).",
    )
    args = parser.parse_args()

    if args.watch:
        # A zero or negative interval would make the watcher poll in a tight
        # loop without ever sleeping; reject it up front (exits with status 2).
        if args.interval <= 0:
            parser.error("--interval must be a positive number of seconds")

        print(f"seed_provider_credentials: watching every {args.interval}s — Ctrl-C to stop")
        try:
            asyncio.run(watch(interval=args.interval))
        except KeyboardInterrupt:
            # Ctrl-C is the advertised way to stop; exit without a traceback.
            print("seed_provider_credentials: watcher stopped")
    else:
        result = asyncio.run(seed(verbose=True))
        print(f"seed_provider_credentials: done — {result} row(s) affected")
|