# Source path: Pipeline/backend/app/api/provider_credentials.py
# (284 lines, 10 KiB, Python)

"""CRUD endpoints for AI provider credentials.
Every org member can read credentials (to see what providers are configured).
Creating, updating, and deleting requires org-admin access.
The full api_key is never returned in any response.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlmodel import select
from app.api.deps import require_org_admin, require_org_member
from app.core.time import utcnow
from app.db import crud
from app.db.session import get_session
from app.models.provider_credentials import ProviderCredential, SUPPORTED_PROVIDERS
from app.schemas.provider_credentials import (
ProviderCredentialCreate,
ProviderCredentialRead,
# (source-viewer timestamp residue, not part of the import list: 2026-05-20 23:03:19 -05:00)
ProviderCredentialTestRequest,
ProviderCredentialUpdate,
ProviderUsageLiveRead,
RequestWindowRead,
TokenWindowRead,
)
from app.services.provider_usage import fetch_provider_usage
from app.services.organizations import OrganizationContext
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
# Router for the provider-credential endpoints; every route registered on it
# in this module is served under the "/provider-credentials" prefix.
router = APIRouter(prefix="/provider-credentials", tags=["provider-credentials"])
# Dependency markers created once at import time and reused as parameter
# defaults by the handlers in this module.
# SESSION_DEP resolves the handler's `session` argument via get_session.
SESSION_DEP = Depends(get_session)
# ORG_MEMBER_DEP guards the read endpoints (list, get, live usage).
ORG_MEMBER_DEP = Depends(require_org_member)
# ORG_ADMIN_DEP guards the write endpoints (create, test, update, delete).
ORG_ADMIN_DEP = Depends(require_org_admin)
def _to_read(cred: ProviderCredential) -> ProviderCredentialRead:
    """Convert a stored credential row into its API read model.

    The secret itself is never copied into the result: callers only get
    ``api_key_last_four`` plus a ``has_api_key`` flag derived from the
    truthiness of ``cred.api_key``.
    """
    key_is_set = bool(cred.api_key)
    fields = {
        "id": cred.id,
        "organization_id": cred.organization_id,
        "provider": cred.provider,
        "account_key": cred.account_key,
        "display_name": cred.display_name,
        "api_key_last_four": cred.api_key_last_four,
        "has_api_key": key_is_set,
        "base_url": cred.base_url,
        "active": cred.active,
        "created_at": cred.created_at,
        "updated_at": cred.updated_at,
    }
    return ProviderCredentialRead(**fields)
# Read endpoint: any organisation member may call it.
# The docstring is kept verbatim because FastAPI publishes it as the
# operation description in the OpenAPI schema.
@router.get("", response_model=list[ProviderCredentialRead])
async def list_provider_credentials(
    session: AsyncSession = SESSION_DEP,
    ctx: OrganizationContext = ORG_MEMBER_DEP,
) -> list[ProviderCredentialRead]:
    """List all provider credentials for the caller's organisation."""
    # Restrict to the caller's organisation; order by provider, then account key.
    statement = (
        select(ProviderCredential)
        .where(ProviderCredential.organization_id == ctx.organization.id)
        .order_by(ProviderCredential.provider, ProviderCredential.account_key)
    )
    result = await session.exec(statement)
    # Every row goes through _to_read so the raw api_key is not part of the response.
    return list(map(_to_read, result.all()))
# Write endpoint: guarded by the org-admin dependency.
# Raises HTTP 422 for an unsupported provider and HTTP 409 when the
# organisation already has a credential with the same provider + account key.
# The docstring is kept verbatim because FastAPI publishes it as the
# operation description in the OpenAPI schema.
@router.post("", response_model=ProviderCredentialRead, status_code=status.HTTP_201_CREATED)
async def create_provider_credential(
    payload: ProviderCredentialCreate,
    session: AsyncSession = SESSION_DEP,
    ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> ProviderCredentialRead:
    """Create a new provider credential. Admin-only."""
    org_id = ctx.organization.id

    if payload.provider not in SUPPORTED_PROVIDERS:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Unsupported provider '{payload.provider}'. Supported: {sorted(SUPPORTED_PROVIDERS)}",
        )

    # Uniqueness is enforced here per (organisation, provider, account_key).
    duplicate_query = select(ProviderCredential).where(
        ProviderCredential.organization_id == org_id,
        ProviderCredential.provider == payload.provider,
        ProviderCredential.account_key == payload.account_key,
    )
    duplicate = (await session.exec(duplicate_query)).first()
    if duplicate is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"A '{payload.provider}' credential with account key '{payload.account_key}' already exists.",
        )

    # Empty strings are normalised to None before storing.
    secret = payload.api_key or None
    if secret is not None and len(secret) >= 4:
        last_four = secret[-4:]
    else:
        # Keys shorter than four characters get no preview.
        last_four = None

    cred = ProviderCredential(
        organization_id=org_id,
        provider=payload.provider,
        account_key=payload.account_key,
        # Fall back to the account key when no display name was supplied.
        display_name=payload.display_name or payload.account_key,
        api_key=secret,
        api_key_last_four=last_four,
        base_url=payload.base_url or None,
        active=payload.active,
    )
    session.add(cred)
    await session.commit()
    # Reload the row after the commit before building the response from it.
    await session.refresh(cred)
    return _to_read(cred)
# (source-viewer timestamp residue, not part of the module: 2026-05-20 23:03:19 -05:00)
# Dry-run endpoint: guarded by the org-admin dependency. It takes no DB
# session and writes nothing; it only calls fetch_provider_usage.
# Raises HTTP 422 for an unsupported provider.
# The docstring is kept verbatim because FastAPI publishes it as the
# operation description in the OpenAPI schema.
@router.post("/test", response_model=ProviderUsageLiveRead)
async def test_provider_credential(
    payload: ProviderCredentialTestRequest,
    ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> ProviderUsageLiveRead:
    """Validate provider credentials without saving them. Admin-only."""
    if payload.provider not in SUPPORTED_PROVIDERS:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Unsupported provider '{payload.provider}'. Supported: {sorted(SUPPORTED_PROVIDERS)}",
        )

    # A blank (whitespace-only) account key falls back to the literal "test".
    account_key = payload.account_key.strip() or "test"
    # There is no stored row, so a synthetic id is built from org, provider and account key.
    synthetic_id = f"test:{ctx.organization.id}:{payload.provider}:{account_key}"
    live = await fetch_provider_usage(
        credential_id=synthetic_id,
        provider=payload.provider,
        account_key=account_key,
        api_key=payload.api_key,
        base_url=payload.base_url,
        force_refresh=True,
    )

    def _iso(moment):
        # Timestamps are serialised as ISO-8601 strings; a missing value stays None.
        return moment.isoformat() if moment else None

    def _token_window(window) -> TokenWindowRead:
        return TokenWindowRead(
            limit=window.limit,
            remaining=window.remaining,
            used=window.used,
            pct_used=window.pct_used,
            reset_at=_iso(window.reset_at),
            reset_in_ms=window.reset_in_ms,
        )

    def _request_window(window) -> RequestWindowRead:
        return RequestWindowRead(
            limit=window.limit,
            remaining=window.remaining,
            reset_at=_iso(window.reset_at),
            reset_in_ms=window.reset_in_ms,
        )

    # Only the header names are returned, never their values.
    if live.raw_headers:
        header_names = sorted(live.raw_headers.keys())
    else:
        header_names = None

    return ProviderUsageLiveRead(
        provider=live.provider,
        account_key=live.account_key,
        checked_at=live.checked_at.isoformat(),
        reachable=live.reachable,
        error=live.error,
        tokens=_token_window(live.tokens),
        input_tokens=_token_window(live.input_tokens),
        requests=_request_window(live.requests),
        models=live.models,
        debug_rate_limit_headers=header_names,
    )
# Return one credential of the caller's organisation (any org member).
# Responds with HTTP 404 both when the id is unknown and when the row belongs
# to a different organisation.
# Documented with comments rather than a docstring: FastAPI would publish a
# docstring as the operation's OpenAPI description, which is currently unset.
@router.get("/{credential_id}", response_model=ProviderCredentialRead)
async def get_provider_credential(
    credential_id: UUID,
    session: AsyncSession = SESSION_DEP,
    ctx: OrganizationContext = ORG_MEMBER_DEP,
) -> ProviderCredentialRead:
    cred = await crud.get_by_id(session, ProviderCredential, credential_id)
    if cred is not None and cred.organization_id == ctx.organization.id:
        return _to_read(cred)
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
# Partial update: guarded by the org-admin dependency. Fields left as None in
# the payload are not touched. Responds with HTTP 404 when the credential is
# unknown or belongs to another organisation.
# The docstring is kept verbatim because FastAPI publishes it as the
# operation description in the OpenAPI schema.
@router.patch("/{credential_id}", response_model=ProviderCredentialRead)
async def update_provider_credential(
    credential_id: UUID,
    payload: ProviderCredentialUpdate,
    session: AsyncSession = SESSION_DEP,
    ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> ProviderCredentialRead:
    """Update a credential. Pass api_key="" to clear it. Admin-only."""
    cred = await crud.get_by_id(session, ProviderCredential, credential_id)
    owned_by_caller = cred is not None and cred.organization_id == ctx.organization.id
    if not owned_by_caller:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)

    if payload.display_name is not None:
        cred.display_name = payload.display_name
    if payload.base_url is not None:
        # An empty string clears the stored base URL.
        cred.base_url = payload.base_url or None
    if payload.active is not None:
        cred.active = payload.active

    new_key = payload.api_key
    if new_key is not None:
        # "" clears the key: it is stored as None and, being shorter than
        # four characters, also resets the preview to None.
        cred.api_key = new_key or None
        cred.api_key_last_four = new_key[-4:] if len(new_key) >= 4 else None

    # The timestamp is refreshed on every call, even if no field changed.
    cred.updated_at = utcnow()
    await crud.save(session, cred)
    return _to_read(cred)
# Live usage lookup for a stored credential (any org member).
# Responds with HTTP 404 when the credential is unknown or belongs to another
# organisation.
# The docstring is kept verbatim because FastAPI publishes it as the
# operation description in the OpenAPI schema.
@router.get("/{credential_id}/usage", response_model=ProviderUsageLiveRead)
async def get_provider_usage_live(
    credential_id: UUID,
    refresh: bool = Query(default=False, description="Bypass cache and fetch fresh data"),
    session: AsyncSession = SESSION_DEP,
    ctx: OrganizationContext = ORG_MEMBER_DEP,
) -> ProviderUsageLiveRead:
    """Fetch live token usage and rate limits directly from the provider API.
    Calls the provider's model-list endpoint (zero token cost) and reads the
    rate-limit response headers. Results are cached for 60 seconds.
    Pass ?refresh=true to force a fresh fetch.
    """
    cred = await crud.get_by_id(session, ProviderCredential, credential_id)
    if cred is None or cred.organization_id != ctx.organization.id:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)

    # The stored secret is handed to the usage service here; it is not part
    # of the response built below.
    live = await fetch_provider_usage(
        credential_id=str(credential_id),
        provider=cred.provider,
        account_key=cred.account_key,
        api_key=cred.api_key,
        base_url=cred.base_url,
        force_refresh=refresh,
    )

    def _tok(w) -> TokenWindowRead:
        # Token window -> read model; reset_at becomes an ISO string or None.
        return TokenWindowRead(
            limit=w.limit,
            remaining=w.remaining,
            used=w.used,
            pct_used=w.pct_used,
            reset_at=w.reset_at.isoformat() if w.reset_at else None,
            reset_in_ms=w.reset_in_ms,
        )

    def _req(w) -> RequestWindowRead:
        # Request window -> read model; reset_at becomes an ISO string or None.
        return RequestWindowRead(
            limit=w.limit,
            remaining=w.remaining,
            reset_at=w.reset_at.isoformat() if w.reset_at else None,
            reset_in_ms=w.reset_in_ms,
        )

    return ProviderUsageLiveRead(
        provider=live.provider,
        account_key=live.account_key,
        checked_at=live.checked_at.isoformat(),
        reachable=live.reachable,
        error=live.error,
        tokens=_tok(live.tokens),
        input_tokens=_tok(live.input_tokens),
        requests=_req(live.requests),
        models=live.models,
        # Only the header names are returned, never their values.
        debug_rate_limit_headers=sorted(live.raw_headers.keys()) if live.raw_headers else None,
    )
# Delete endpoint: guarded by the org-admin dependency; answers 204 with no
# body. Responds with HTTP 404 when the credential is unknown or belongs to
# another organisation.
# The docstring is kept verbatim because FastAPI publishes it as the
# operation description in the OpenAPI schema.
@router.delete("/{credential_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_provider_credential(
    credential_id: UUID,
    session: AsyncSession = SESSION_DEP,
    ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> None:
    """Delete a credential. Admin-only."""
    cred = await crud.get_by_id(session, ProviderCredential, credential_id)
    if cred is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
    if cred.organization_id != ctx.organization.id:
        # Same status as "unknown id", so the two cases look identical to the caller.
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)

    await session.delete(cred)
    await session.commit()