"""CRUD endpoints for AI provider credentials. Every org member can read credentials (to see what providers are configured). Creating, updating, and deleting requires org-admin access. The full api_key is never returned in any response. """ from __future__ import annotations from typing import TYPE_CHECKING from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, Query, status from sqlmodel import select from app.api.deps import require_org_admin, require_org_member from app.core.time import utcnow from app.db import crud from app.db.session import get_session from app.models.provider_credentials import ProviderCredential, SUPPORTED_PROVIDERS from app.schemas.provider_credentials import ( ProviderCredentialCreate, ProviderCredentialRead, ProviderCredentialTestRequest, ProviderCredentialUpdate, ProviderUsageLiveRead, RequestWindowRead, TokenWindowRead, ) from app.services.provider_usage import fetch_provider_usage from app.services.organizations import OrganizationContext if TYPE_CHECKING: from sqlmodel.ext.asyncio.session import AsyncSession router = APIRouter(prefix="/provider-credentials", tags=["provider-credentials"]) SESSION_DEP = Depends(get_session) ORG_MEMBER_DEP = Depends(require_org_member) ORG_ADMIN_DEP = Depends(require_org_admin) def _to_read(cred: ProviderCredential) -> ProviderCredentialRead: return ProviderCredentialRead( id=cred.id, organization_id=cred.organization_id, provider=cred.provider, account_key=cred.account_key, display_name=cred.display_name, api_key_last_four=cred.api_key_last_four, has_api_key=bool(cred.api_key), base_url=cred.base_url, active=cred.active, created_at=cred.created_at, updated_at=cred.updated_at, ) @router.get("", response_model=list[ProviderCredentialRead]) async def list_provider_credentials( session: AsyncSession = SESSION_DEP, ctx: OrganizationContext = ORG_MEMBER_DEP, ) -> list[ProviderCredentialRead]: """List all provider credentials for the caller's organisation.""" rows = ( await session.exec( select(ProviderCredential) .where(ProviderCredential.organization_id == ctx.organization.id) .order_by(ProviderCredential.provider, ProviderCredential.account_key) ) ).all() return [_to_read(r) for r in rows] @router.post("", response_model=ProviderCredentialRead, status_code=status.HTTP_201_CREATED) async def create_provider_credential( payload: ProviderCredentialCreate, session: AsyncSession = SESSION_DEP, ctx: OrganizationContext = ORG_ADMIN_DEP, ) -> ProviderCredentialRead: """Create a new provider credential. Admin-only.""" if payload.provider not in SUPPORTED_PROVIDERS: raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=f"Unsupported provider '{payload.provider}'. Supported: {sorted(SUPPORTED_PROVIDERS)}", ) # Check for duplicate (provider + account_key per org) existing = ( await session.exec( select(ProviderCredential).where( ProviderCredential.organization_id == ctx.organization.id, ProviderCredential.provider == payload.provider, ProviderCredential.account_key == payload.account_key, ) ) ).first() if existing is not None: raise HTTPException( status_code=status.HTTP_409_CONFLICT, detail=f"A '{payload.provider}' credential with account key '{payload.account_key}' already exists.", ) last_four = payload.api_key[-4:] if payload.api_key and len(payload.api_key) >= 4 else None cred = ProviderCredential( organization_id=ctx.organization.id, provider=payload.provider, account_key=payload.account_key, display_name=payload.display_name or payload.account_key, api_key=payload.api_key or None, api_key_last_four=last_four, base_url=payload.base_url or None, active=payload.active, ) session.add(cred) await session.commit() await session.refresh(cred) return _to_read(cred) @router.post("/test", response_model=ProviderUsageLiveRead) async def test_provider_credential( payload: ProviderCredentialTestRequest, ctx: OrganizationContext = ORG_ADMIN_DEP, ) -> ProviderUsageLiveRead: """Validate provider credentials without saving them. Admin-only.""" if payload.provider not in SUPPORTED_PROVIDERS: raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=f"Unsupported provider '{payload.provider}'. Supported: {sorted(SUPPORTED_PROVIDERS)}", ) account_key = payload.account_key.strip() or "test" live = await fetch_provider_usage( credential_id=f"test:{ctx.organization.id}:{payload.provider}:{account_key}", provider=payload.provider, account_key=account_key, api_key=payload.api_key, base_url=payload.base_url, force_refresh=True, ) def _tok(w) -> TokenWindowRead: return TokenWindowRead( limit=w.limit, remaining=w.remaining, used=w.used, pct_used=w.pct_used, reset_at=w.reset_at.isoformat() if w.reset_at else None, reset_in_ms=w.reset_in_ms, ) def _req(w) -> RequestWindowRead: return RequestWindowRead( limit=w.limit, remaining=w.remaining, reset_at=w.reset_at.isoformat() if w.reset_at else None, reset_in_ms=w.reset_in_ms, ) return ProviderUsageLiveRead( provider=live.provider, account_key=live.account_key, checked_at=live.checked_at.isoformat(), reachable=live.reachable, source=live.source, confidence=live.confidence, error=live.error, tokens=_tok(live.tokens), input_tokens=_tok(live.input_tokens), requests=_req(live.requests), models=live.models, sample_model=live.sample_model, sample_input_tokens=live.sample_input_tokens, sample_output_tokens=live.sample_output_tokens, sample_latency_ms=live.sample_latency_ms, debug_rate_limit_headers=sorted(live.raw_headers.keys()) if live.raw_headers else None, ) @router.get("/{credential_id}", response_model=ProviderCredentialRead) async def get_provider_credential( credential_id: UUID, session: AsyncSession = SESSION_DEP, ctx: OrganizationContext = ORG_MEMBER_DEP, ) -> ProviderCredentialRead: cred = await crud.get_by_id(session, ProviderCredential, credential_id) if cred is None or cred.organization_id != ctx.organization.id: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) return _to_read(cred) @router.patch("/{credential_id}", response_model=ProviderCredentialRead) async def update_provider_credential( credential_id: UUID, payload: ProviderCredentialUpdate, session: AsyncSession = SESSION_DEP, ctx: OrganizationContext = ORG_ADMIN_DEP, ) -> ProviderCredentialRead: """Update a credential. Pass api_key="" to clear it. Admin-only.""" cred = await crud.get_by_id(session, ProviderCredential, credential_id) if cred is None or cred.organization_id != ctx.organization.id: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) if payload.display_name is not None: cred.display_name = payload.display_name if payload.base_url is not None: cred.base_url = payload.base_url or None if payload.active is not None: cred.active = payload.active if payload.api_key is not None: if payload.api_key == "": cred.api_key = None cred.api_key_last_four = None else: cred.api_key = payload.api_key cred.api_key_last_four = payload.api_key[-4:] if len(payload.api_key) >= 4 else None cred.updated_at = utcnow() await crud.save(session, cred) return _to_read(cred) @router.get("/{credential_id}/usage", response_model=ProviderUsageLiveRead) async def get_provider_usage_live( credential_id: UUID, refresh: bool = Query(default=False, description="Bypass cache and fetch fresh data"), session: AsyncSession = SESSION_DEP, ctx: OrganizationContext = ORG_MEMBER_DEP, ) -> ProviderUsageLiveRead: """Fetch live token usage and rate limits directly from the provider API. Calls the provider's model-list endpoint (zero token cost) and reads the rate-limit response headers. Results are cached for 60 seconds. Pass ?refresh=true to force a fresh fetch. """ cred = await crud.get_by_id(session, ProviderCredential, credential_id) if cred is None or cred.organization_id != ctx.organization.id: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) live = await fetch_provider_usage( credential_id=str(credential_id), provider=cred.provider, account_key=cred.account_key, api_key=cred.api_key, base_url=cred.base_url, force_refresh=refresh, ) def _tok(w) -> TokenWindowRead: return TokenWindowRead( limit=w.limit, remaining=w.remaining, used=w.used, pct_used=w.pct_used, reset_at=w.reset_at.isoformat() if w.reset_at else None, reset_in_ms=w.reset_in_ms, ) def _req(w) -> RequestWindowRead: return RequestWindowRead( limit=w.limit, remaining=w.remaining, reset_at=w.reset_at.isoformat() if w.reset_at else None, reset_in_ms=w.reset_in_ms, ) return ProviderUsageLiveRead( provider=live.provider, account_key=live.account_key, checked_at=live.checked_at.isoformat(), reachable=live.reachable, source=live.source, confidence=live.confidence, error=live.error, tokens=_tok(live.tokens), input_tokens=_tok(live.input_tokens), requests=_req(live.requests), models=live.models, sample_model=live.sample_model, sample_input_tokens=live.sample_input_tokens, sample_output_tokens=live.sample_output_tokens, sample_latency_ms=live.sample_latency_ms, debug_rate_limit_headers=sorted(live.raw_headers.keys()) if live.raw_headers else None, ) @router.delete("/{credential_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_provider_credential( credential_id: UUID, session: AsyncSession = SESSION_DEP, ctx: OrganizationContext = ORG_ADMIN_DEP, ) -> None: """Delete a credential. Admin-only.""" cred = await crud.get_by_id(session, ProviderCredential, credential_id) if cred is None or cred.organization_id != ctx.organization.id: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) await session.delete(cred) await session.commit()