Pipeline/backend/app/api/provider_credentials.py

332 lines
12 KiB
Python

"""CRUD endpoints for AI provider credentials.
Every org member can read credentials (to see what providers are configured).
Creating, updating, and deleting requires org-admin access.
The full api_key is never returned in any response.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlmodel import select
from app.api.deps import require_org_admin, require_org_member
from app.core.time import utcnow
from app.db import crud
from app.db.session import get_session
from app.models.provider_credentials import ProviderCredential, SUPPORTED_PROVIDERS
from app.schemas.provider_credentials import (
ProviderCredentialCreate,
ProviderCredentialRead,
ProviderCredentialTestRequest,
ProviderCredentialUpdate,
ProviderUsageLiveRead,
RequestWindowRead,
SubscriptionWindowRead,
TokenWindowRead,
)
from app.services.provider_usage import fetch_provider_usage
from app.services.organizations import OrganizationContext
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
# Every endpoint in this module is mounted under /provider-credentials.
router = APIRouter(tags=["provider-credentials"], prefix="/provider-credentials")

# Dependency markers built once at import time and reused as parameter
# defaults by the endpoints below.
SESSION_DEP = Depends(get_session)  # database session
ORG_MEMBER_DEP = Depends(require_org_member)  # read access: any org member
ORG_ADMIN_DEP = Depends(require_org_admin)  # write access: org admins only
def _to_read(cred: ProviderCredential) -> ProviderCredentialRead:
    """Project a ``ProviderCredential`` row onto its public read schema.

    The secrets themselves (``api_key`` / ``session_key``) are never copied
    into the response. Each is reduced to a boolean presence flag, alongside
    the stored last-four hint.
    """
    has_api_key = bool(cred.api_key)
    has_session_key = bool(cred.session_key)
    return ProviderCredentialRead(
        id=cred.id,
        organization_id=cred.organization_id,
        provider=cred.provider,
        account_key=cred.account_key,
        display_name=cred.display_name,
        api_key_last_four=cred.api_key_last_four,
        has_api_key=has_api_key,
        session_key_last_four=cred.session_key_last_four,
        has_session_key=has_session_key,
        base_url=cred.base_url,
        active=cred.active,
        created_at=cred.created_at,
        updated_at=cred.updated_at,
    )
@router.get("", response_model=list[ProviderCredentialRead])
async def list_provider_credentials(
    session: AsyncSession = SESSION_DEP,
    ctx: OrganizationContext = ORG_MEMBER_DEP,
) -> list[ProviderCredentialRead]:
    """Return every provider credential owned by the caller's organisation.

    Results are sorted by provider, then by account key. Any org member may
    call this. Secrets are masked by ``_to_read``.
    """
    org_id = ctx.organization.id
    statement = (
        select(ProviderCredential)
        .where(ProviderCredential.organization_id == org_id)
        .order_by(ProviderCredential.provider, ProviderCredential.account_key)
    )
    result = await session.exec(statement)
    return list(map(_to_read, result.all()))
@router.post("", response_model=ProviderCredentialRead, status_code=status.HTTP_201_CREATED)
async def create_provider_credential(
    payload: ProviderCredentialCreate,
    session: AsyncSession = SESSION_DEP,
    ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> ProviderCredentialRead:
    """Create a provider credential for the caller's organisation. Admin-only.

    Blank ``api_key`` / ``session_key`` / ``base_url`` values are stored as
    ``None``. A blank ``display_name`` falls back to the account key.

    Raises:
        HTTPException 422: ``payload.provider`` is not in ``SUPPORTED_PROVIDERS``.
        HTTPException 409: the organisation already has a credential with the
            same provider and account key.
    """
    provider = payload.provider
    account_key = payload.account_key
    org_id = ctx.organization.id

    if provider not in SUPPORTED_PROVIDERS:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Unsupported provider '{provider}'. Supported: {sorted(SUPPORTED_PROVIDERS)}",
        )

    # (provider, account_key) must be unique within an organisation.
    duplicate_query = select(ProviderCredential).where(
        ProviderCredential.organization_id == org_id,
        ProviderCredential.provider == provider,
        ProviderCredential.account_key == account_key,
    )
    if (await session.exec(duplicate_query)).first() is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"A '{provider}' credential with account key '{account_key}' already exists.",
        )

    def _hint(secret: str | None) -> str | None:
        # Only the trailing four characters are kept for display. Missing or
        # very short secrets get no hint at all.
        return secret[-4:] if secret and len(secret) >= 4 else None

    cred = ProviderCredential(
        organization_id=org_id,
        provider=provider,
        account_key=account_key,
        display_name=payload.display_name or account_key,
        api_key=payload.api_key or None,
        api_key_last_four=_hint(payload.api_key),
        session_key=payload.session_key or None,
        session_key_last_four=_hint(payload.session_key),
        base_url=payload.base_url or None,
        active=payload.active,
    )
    session.add(cred)
    await session.commit()
    await session.refresh(cred)
    return _to_read(cred)
@router.post("/test", response_model=ProviderUsageLiveRead)
async def test_provider_credential(
    payload: ProviderCredentialTestRequest,
    ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> ProviderUsageLiveRead:
    """Validate provider credentials without saving them. Admin-only.

    The payload is normalised the same way ``create_provider_credential``
    normalises it before saving: blank ``api_key`` / ``base_url`` become
    ``None``. A passing test therefore reflects how the stored credential
    would behave. The usage service is always called with
    ``force_refresh=True``.

    Raises:
        HTTPException 422: ``payload.provider`` is not in ``SUPPORTED_PROVIDERS``.
    """
    if payload.provider not in SUPPORTED_PROVIDERS:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Unsupported provider '{payload.provider}'. Supported: {sorted(SUPPORTED_PROVIDERS)}",
        )
    # A missing or blank account key falls back to the "test" placeholder.
    # The `or ""` guards against a None value before calling .strip().
    account_key = (payload.account_key or "").strip() or "test"
    # NOTE(review): the request schema is not visible from this module. Read
    # session_key defensively so providers that need it can be tested when
    # the schema carries it; otherwise None is passed, as it would be for a
    # stored credential without one.
    session_key = getattr(payload, "session_key", None) or None
    live = await fetch_provider_usage(
        credential_id=f"test:{ctx.organization.id}:{payload.provider}:{account_key}",
        provider=payload.provider,
        account_key=account_key,
        # Same "" -> None normalisation create_provider_credential applies.
        api_key=payload.api_key or None,
        base_url=payload.base_url or None,
        session_key=session_key,
        force_refresh=True,
    )

    def _tok(w) -> TokenWindowRead:
        return TokenWindowRead(
            limit=w.limit,
            remaining=w.remaining,
            used=w.used,
            pct_used=w.pct_used,
            reset_at=w.reset_at.isoformat() if w.reset_at else None,
            reset_in_ms=w.reset_in_ms,
        )

    def _req(w) -> RequestWindowRead:
        return RequestWindowRead(
            limit=w.limit,
            remaining=w.remaining,
            reset_at=w.reset_at.isoformat() if w.reset_at else None,
            reset_in_ms=w.reset_in_ms,
        )

    return ProviderUsageLiveRead(
        provider=live.provider,
        account_key=live.account_key,
        checked_at=live.checked_at.isoformat(),
        reachable=live.reachable,
        source=live.source,
        confidence=live.confidence,
        error=live.error,
        tokens=_tok(live.tokens),
        input_tokens=_tok(live.input_tokens),
        output_tokens=_tok(live.output_tokens),
        requests=_req(live.requests),
        subscription_windows=[
            SubscriptionWindowRead(
                key=w.key,
                label=w.label,
                pct_used=w.pct_used,
                reset_in_ms=w.reset_in_ms,
            )
            for w in live.subscription_windows
        ],
        subscription_plan=live.subscription_plan,
        models=live.models,
        sample_model=live.sample_model,
        sample_input_tokens=live.sample_input_tokens,
        sample_output_tokens=live.sample_output_tokens,
        sample_latency_ms=live.sample_latency_ms,
        # Only header names are exposed, never their values.
        debug_rate_limit_headers=sorted(live.raw_headers.keys()) if live.raw_headers else None,
    )
@router.get("/{credential_id}", response_model=ProviderCredentialRead)
async def get_provider_credential(
    credential_id: UUID,
    session: AsyncSession = SESSION_DEP,
    ctx: OrganizationContext = ORG_MEMBER_DEP,
) -> ProviderCredentialRead:
    """Return one credential by id. Any org member may call this.

    Raises:
        HTTPException 404: no credential has this id, or it belongs to a
            different organisation. Both cases get the same response.
    """
    cred = await crud.get_by_id(session, ProviderCredential, credential_id)
    visible = cred is not None and cred.organization_id == ctx.organization.id
    if not visible:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
    return _to_read(cred)
@router.patch("/{credential_id}", response_model=ProviderCredentialRead)
async def update_provider_credential(
    credential_id: UUID,
    payload: ProviderCredentialUpdate,
    session: AsyncSession = SESSION_DEP,
    ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> ProviderCredentialRead:
    """Partially update a credential. Admin-only.

    Payload fields left as ``None`` are not touched. For ``api_key``,
    ``session_key`` and ``base_url``, an empty string clears the stored value
    (for the two secrets, their last-four hint is cleared too). An empty
    ``display_name`` falls back to the account key, matching
    ``create_provider_credential``. ``updated_at`` is always refreshed.

    Raises:
        HTTPException 404: the credential does not exist or belongs to a
            different organisation.
    """
    cred = await crud.get_by_id(session, ProviderCredential, credential_id)
    if cred is None or cred.organization_id != ctx.organization.id:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)

    def _last_four(secret: str) -> str | None:
        # "" and secrets shorter than four characters get no display hint.
        return secret[-4:] if len(secret) >= 4 else None

    if payload.display_name is not None:
        # Never store a blank display name; create applies the same fallback.
        cred.display_name = payload.display_name or cred.account_key
    if payload.base_url is not None:
        cred.base_url = payload.base_url or None
    if payload.active is not None:
        cred.active = payload.active
    if payload.api_key is not None:
        # "" clears the key (and its hint); any other value replaces it.
        cred.api_key = payload.api_key or None
        cred.api_key_last_four = _last_four(payload.api_key)
    if payload.session_key is not None:
        cred.session_key = payload.session_key or None
        cred.session_key_last_four = _last_four(payload.session_key)
    cred.updated_at = utcnow()
    await crud.save(session, cred)
    return _to_read(cred)
@router.get("/{credential_id}/usage", response_model=ProviderUsageLiveRead)
async def get_provider_usage_live(
    credential_id: UUID,
    refresh: bool = Query(default=False, description="Bypass cache and fetch fresh data"),
    session: AsyncSession = SESSION_DEP,
    ctx: OrganizationContext = ORG_MEMBER_DEP,
) -> ProviderUsageLiveRead:
    """Report live usage and rate-limit data for one stored credential.

    Any org member may call this. The stored secrets are passed to
    ``fetch_provider_usage``. Only its summarised result is returned, never
    the secrets themselves. ``?refresh=true`` is forwarded as
    ``force_refresh`` to bypass the service's cache.

    Per the usage service, the probe reads the provider's rate-limit response
    headers and results are cached (60 seconds at the time of writing). Those
    details live in ``fetch_provider_usage``, not here.

    Raises:
        HTTPException 404: the credential does not exist or belongs to a
            different organisation.
    """
    cred = await crud.get_by_id(session, ProviderCredential, credential_id)
    if cred is None or cred.organization_id != ctx.organization.id:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)

    live = await fetch_provider_usage(
        credential_id=str(credential_id),
        provider=cred.provider,
        account_key=cred.account_key,
        api_key=cred.api_key,
        base_url=cred.base_url,
        session_key=cred.session_key,
        force_refresh=refresh,
    )

    def _iso_or_none(moment):
        return moment.isoformat() if moment else None

    def _token_window(window) -> TokenWindowRead:
        return TokenWindowRead(
            limit=window.limit,
            remaining=window.remaining,
            used=window.used,
            pct_used=window.pct_used,
            reset_at=_iso_or_none(window.reset_at),
            reset_in_ms=window.reset_in_ms,
        )

    request_window = RequestWindowRead(
        limit=live.requests.limit,
        remaining=live.requests.remaining,
        reset_at=_iso_or_none(live.requests.reset_at),
        reset_in_ms=live.requests.reset_in_ms,
    )
    plan_windows = [
        SubscriptionWindowRead(
            key=sub.key,
            label=sub.label,
            pct_used=sub.pct_used,
            reset_in_ms=sub.reset_in_ms,
        )
        for sub in live.subscription_windows
    ]
    # Only header names are exposed, never their values.
    header_names = sorted(live.raw_headers.keys()) if live.raw_headers else None

    return ProviderUsageLiveRead(
        provider=live.provider,
        account_key=live.account_key,
        checked_at=live.checked_at.isoformat(),
        reachable=live.reachable,
        source=live.source,
        confidence=live.confidence,
        error=live.error,
        tokens=_token_window(live.tokens),
        input_tokens=_token_window(live.input_tokens),
        output_tokens=_token_window(live.output_tokens),
        requests=request_window,
        subscription_windows=plan_windows,
        subscription_plan=live.subscription_plan,
        models=live.models,
        sample_model=live.sample_model,
        sample_input_tokens=live.sample_input_tokens,
        sample_output_tokens=live.sample_output_tokens,
        sample_latency_ms=live.sample_latency_ms,
        debug_rate_limit_headers=header_names,
    )
@router.delete("/{credential_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_provider_credential(
    credential_id: UUID,
    session: AsyncSession = SESSION_DEP,
    ctx: OrganizationContext = ORG_ADMIN_DEP,
) -> None:
    """Delete a credential and commit the change. Admin-only.

    Raises:
        HTTPException 404: the credential does not exist or belongs to a
            different organisation.
    """
    cred = await crud.get_by_id(session, ProviderCredential, credential_id)
    owned_by_caller = cred is not None and cred.organization_id == ctx.organization.id
    if not owned_by_caller:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
    await session.delete(cred)
    await session.commit()