Pipeline/backend/app/services/provider_usage.py

602 lines
23 KiB
Python

"""Live provider usage — fetch real token limits and reset times directly from provider APIs.
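Anthropic and OpenAI expose rate-limit headers on their API responses. Pipeline
first calls a lightweight, zero-token-cost endpoint (the model list) and reads
those headers; only when they are missing does it send a minimal one-token
request (see "Fallback probe" below), which consumes a negligible amount of
quota. No guessing, no JSONL scanning, no estimates. If the provider API is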
unreachable or the key is invalid, an error is returned and all limit fields
are None.
Supported providers
-------------------
anthropic → GET https://api.anthropic.com/v1/models
            Headers: anthropic-ratelimit-tokens-limit/remaining/reset
                     anthropic-ratelimit-requests-limit/remaining/reset
                     anthropic-ratelimit-input-tokens-limit/remaining/reset
            Fallback probe (only when headers are missing):
              POST /v1/messages with max_tokens=1 to surface usage and timing data.
openai    → GET https://api.openai.com/v1/models
(codex)     Headers: x-ratelimit-limit-tokens, x-ratelimit-remaining-tokens,
                     x-ratelimit-reset-tokens, x-ratelimit-limit-requests,
                     x-ratelimit-remaining-requests, x-ratelimit-reset-requests
            Fallback probe (only when headers are missing):
              POST /v1/chat/completions with a 1-token output limit to surface
              usage and timing data.
ollama    → GET {base_url}/api/tags (health check only; no rate limits)
            Returns: model list and a server-reachable flag
            Probe (whenever at least one model is listed):
              POST {base_url}/api/generate with num_predict=1 for usage and timing data.
Caching
-------
Results are cached per credential_id for CACHE_TTL_SECONDS (default 60s) to
avoid hammering provider APIs on every page load.
"""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import Any
import httpx
from app.core.logging import get_logger
from app.core.time import utcnow
logger = get_logger(__name__)
# Seconds a fetched ProviderUsageLive stays in the in-memory cache before the
# next call goes back to the provider (enforced by _get_cached).
CACHE_TTL_SECONDS = 60
# Timeout handed to every httpx.AsyncClient created in this module.
REQUEST_TIMEOUT = 8.0 # seconds
# ---------------------------------------------------------------------------
# Result types
# ---------------------------------------------------------------------------
@dataclass
class TokenWindow:
limit: int | None = None
remaining: int | None = None
reset_at: datetime | None = None # UTC naive datetime
@property
def reset_in_ms(self) -> int | None:
if self.reset_at is None:
return None
delta = (self.reset_at - utcnow()).total_seconds()
return max(0, int(delta * 1000))
@property
def used(self) -> int | None:
if self.limit is not None and self.remaining is not None:
return max(0, self.limit - self.remaining)
return None
@property
def pct_used(self) -> float | None:
if self.limit and self.limit > 0 and self.remaining is not None:
return round((1 - self.remaining / self.limit) * 100, 1)
return None
@dataclass
class RequestWindow:
limit: int | None = None
remaining: int | None = None
reset_at: datetime | None = None
@property
def reset_in_ms(self) -> int | None:
if self.reset_at is None:
return None
delta = (self.reset_at - utcnow()).total_seconds()
return max(0, int(delta * 1000))
@dataclass
class ProviderUsageLive:
    """Live rate-limit snapshot for one provider credential.

    ``reachable`` says whether the provider endpoint answered usably and
    ``error`` holds a human-readable reason when something went wrong.
    The ``sample_*`` fields are only filled in by probe requests.
    """

    provider: str
    account_key: str
    checked_at: datetime
    reachable: bool
    error: str | None = None
    tokens: TokenWindow = field(default_factory=TokenWindow)
    # Anthropic splits input/output; the other providers leave this window empty.
    input_tokens: TokenWindow = field(default_factory=TokenWindow)
    requests: RequestWindow = field(default_factory=RequestWindow)
    models: list[str] = field(default_factory=list)  # model IDs available on this key
    # Lower-cased "ratelimit" headers from the response that supplied the
    # windows; kept for debugging and not included in to_dict().
    raw_headers: dict[str, str] = field(default_factory=dict)
    sample_model: str | None = None
    sample_input_tokens: int | None = None
    sample_output_tokens: int | None = None
    sample_latency_ms: int | None = None

    def to_dict(self) -> dict[str, Any]:
        """Serialise to a JSON-friendly dict (ISO timestamps, at most 20 models)."""

        def _window(w: TokenWindow | RequestWindow) -> dict[str, Any]:
            # Both window types share limit/remaining/reset_at/reset_in_ms.
            payload: dict[str, Any] = {
                "limit": w.limit,
                "remaining": w.remaining,
                "reset_in_ms": w.reset_in_ms,
                "reset_at": w.reset_at.isoformat() if w.reset_at else None,
            }
            # Only token windows can derive consumption figures.
            if isinstance(w, TokenWindow):
                payload["used"] = w.used
                payload["pct_used"] = w.pct_used
            return payload

        serialised: dict[str, Any] = {
            "provider": self.provider,
            "account_key": self.account_key,
            "checked_at": self.checked_at.isoformat(),
            "reachable": self.reachable,
            "error": self.error,
        }
        serialised["tokens"] = _window(self.tokens)
        serialised["input_tokens"] = _window(self.input_tokens)
        serialised["requests"] = _window(self.requests)
        serialised["models"] = self.models[:20]  # cap for response size
        serialised["sample_model"] = self.sample_model
        serialised["sample_input_tokens"] = self.sample_input_tokens
        serialised["sample_output_tokens"] = self.sample_output_tokens
        serialised["sample_latency_ms"] = self.sample_latency_ms
        return serialised
# ---------------------------------------------------------------------------
# Header parsers
# ---------------------------------------------------------------------------
def _parse_int_header(headers: dict[str, str], *names: str) -> int | None:
for name in names:
val = headers.get(name.lower())
if val is not None:
try:
return int(val)
except (ValueError, TypeError):
pass
return None
def _parse_iso_reset(value: str) -> datetime | None:
"""Parse an ISO 8601 reset timestamp → UTC naive datetime."""
if not value:
return None
try:
normalized = value.strip().replace("Z", "+00:00")
dt = datetime.fromisoformat(normalized)
if dt.tzinfo is not None:
dt = dt.astimezone(timezone.utc).replace(tzinfo=None)
return dt
except ValueError:
return None
# OpenAI reports reset times as compound durations such as "1m30s", "2h" or
# "30s".  Hours and minutes are whole numbers; seconds may be fractional.
# Sub-second values such as "150ms" are NOT matched by this pattern.
_OAI_DURATION_RE = re.compile(
    r"""
    (?:(\d+)h)?              # group 1: hours
    (?:(\d+)m)?              # group 2: minutes
    (?:(\d+(?:\.\d+)?)s)?    # group 3: seconds, optionally fractional
    $                        # nothing may follow
    """,
    re.VERBOSE,
)
def _apply_anthropic_ratelimit_headers(result: ProviderUsageLive, headers: dict[str, str]) -> None:
    """Rebuild the token, input-token and request windows on *result*.

    Values come from Anthropic's ``anthropic-ratelimit-*`` headers; *headers*
    is expected to have lower-case keys.  All three windows are replaced, so a
    header that is absent leaves the matching field as None.
    """

    def reset_from(header_name: str) -> datetime | None:
        # Anthropic sends reset times as ISO 8601 timestamps.
        return _parse_iso_reset(headers.get(header_name, ""))

    token_window = TokenWindow(
        limit=_parse_int_header(headers, "anthropic-ratelimit-tokens-limit"),
        remaining=_parse_int_header(headers, "anthropic-ratelimit-tokens-remaining"),
        reset_at=reset_from("anthropic-ratelimit-tokens-reset"),
    )
    input_window = TokenWindow(
        limit=_parse_int_header(headers, "anthropic-ratelimit-input-tokens-limit"),
        remaining=_parse_int_header(headers, "anthropic-ratelimit-input-tokens-remaining"),
        reset_at=reset_from("anthropic-ratelimit-input-tokens-reset"),
    )
    request_window = RequestWindow(
        limit=_parse_int_header(headers, "anthropic-ratelimit-requests-limit"),
        remaining=_parse_int_header(headers, "anthropic-ratelimit-requests-remaining"),
        reset_at=reset_from("anthropic-ratelimit-requests-reset"),
    )
    result.tokens = token_window
    result.input_tokens = input_window
    result.requests = request_window
def _pick_anthropic_probe_model(models: list[str]) -> str | None:
if not models:
return None
priorities = ("haiku", "sonnet", "opus")
lowered = [(m, m.lower()) for m in models]
for priority in priorities:
for original, lowered_name in lowered:
if priority in lowered_name:
return original
return models[0]
def _pick_openai_probe_model(models: list[str]) -> str | None:
if not models:
return None
priorities = ("gpt-4.1-mini", "gpt-4o-mini", "gpt-4.1", "gpt-4o", "o4-mini")
lowered = [(m, m.lower()) for m in models]
for priority in priorities:
for original, lowered_name in lowered:
if priority in lowered_name:
return original
return models[0]
# Sub-second reset durations ("150ms", "20µs", "500ns").  _OAI_DURATION_RE only
# understands h/m/s components and rejects these, so they are parsed separately.
_OAI_SUBSECOND_RE = re.compile(r"(\d+(?:\.\d+)?)(ms|us|µs|μs|ns)")
_OAI_SUBSECOND_SCALE = {"ms": 1e-3, "us": 1e-6, "µs": 1e-6, "μs": 1e-6, "ns": 1e-9}


def _parse_openai_reset(value: str) -> datetime | None:
    """Parse an OpenAI ``x-ratelimit-reset-*`` header into a naive UTC datetime.

    The header is either an ISO 8601 timestamp or a duration relative to now.
    Durations may be compound h/m/s values such as ``"1m30s"``, ``"6m0s"`` or
    ``"0.5s"``, or, for resets under a second, a single millisecond,
    microsecond or nanosecond value such as ``"150ms"``, ``"20µs"`` or
    ``"500ns"``.

    Returns:
        The absolute reset time, or None for empty or unrecognised values.
    """
    if not value:
        return None
    text = value.strip()
    # ISO 8601 timestamp: an absolute reset time.
    if "T" in text or text.endswith("Z"):
        return _parse_iso_reset(text)
    # Sub-second duration relative to now.
    sub_second = _OAI_SUBSECOND_RE.fullmatch(text)
    if sub_second:
        seconds = float(sub_second.group(1)) * _OAI_SUBSECOND_SCALE[sub_second.group(2)]
        return utcnow() + timedelta(seconds=seconds)
    # Compound h/m/s duration relative to now.
    m = _OAI_DURATION_RE.match(text)
    if m and any(m.groups()):
        hours, minutes, seconds = (float(g or 0) for g in m.groups())
        return utcnow() + timedelta(seconds=hours * 3600 + minutes * 60 + seconds)
    return None
# ---------------------------------------------------------------------------
# Provider-specific fetch functions
# ---------------------------------------------------------------------------
def _anthropic_model_ids(resp: httpx.Response) -> list[str]:
    """Return the model IDs from a ``/v1/models`` body shaped ``{"data": [{"id": ...}]}``.

    Non-JSON or unexpectedly shaped bodies yield an empty list instead of raising.
    """
    try:
        data = resp.json()
    except ValueError:
        return []
    items = data.get("data") if isinstance(data, dict) else None
    if not isinstance(items, list):
        return []
    return [m["id"] for m in items if isinstance(m, dict) and m.get("id")]


async def _probe_anthropic(
    client: httpx.AsyncClient,
    base: str,
    auth_headers: dict[str, str],
    result: ProviderUsageLive,
) -> None:
    """Send a one-token ``/v1/messages`` request and merge what it reveals into *result*.

    Best-effort: when no probe model can be chosen or the request fails,
    *result* keeps the data already gathered from ``/v1/models``.  Request
    failures are only logged at debug level.
    """
    probe_model = _pick_anthropic_probe_model(result.models)
    if not probe_model:
        return
    result.sample_model = probe_model
    try:
        probe_resp = await client.post(
            f"{base}/v1/messages",
            headers={**auth_headers, "content-type": "application/json"},
            json={
                "model": probe_model,
                "max_tokens": 1,
                "messages": [{"role": "user", "content": "Usage probe"}],
            },
        )
    except Exception as exc:  # deliberate: the probe must never fail the fetch
        logger.debug("provider_usage.probe_failed provider=%s error=%s", "anthropic", exc)
        return

    probe_headers = {k.lower(): v for k, v in probe_resp.headers.items()}
    probe_rl_headers = {k: v for k, v in probe_headers.items() if "ratelimit" in k}
    if probe_rl_headers:
        result.raw_headers = probe_rl_headers
        _apply_anthropic_ratelimit_headers(result, probe_headers)

    if probe_resp.status_code == 200:
        try:
            payload = probe_resp.json()
        except ValueError:
            payload = None
        usage = payload.get("usage") if isinstance(payload, dict) else None
        if isinstance(usage, dict):
            in_tok = usage.get("input_tokens")
            out_tok = usage.get("output_tokens")
            if isinstance(in_tok, int):
                result.sample_input_tokens = in_tok
            if isinstance(out_tok, int):
                result.sample_output_tokens = out_tok

    # Round-trip time as measured by httpx, recorded whatever the status code.
    elapsed_ms = probe_resp.elapsed.total_seconds() * 1000.0
    result.sample_latency_ms = int(max(0.0, round(elapsed_ms)))


async def _fetch_anthropic(api_key: str, base_url: str | None) -> ProviderUsageLive:
    """Fetch Anthropic rate-limit windows and the model list for one API key.

    Reads the ``anthropic-ratelimit-*`` headers from ``GET /v1/models``.  If
    that response carries none of the tokens / input-tokens / requests
    limits, it sends a one-token ``POST /v1/messages`` probe (this consumes a
    tiny amount of quota).  The probe's headers, token usage and latency are
    then used instead.

    Args:
        api_key: Anthropic API key, sent as ``x-api-key``.
        base_url: Optional API root; defaults to ``https://api.anthropic.com``.

    Returns:
        A ProviderUsageLive whose ``account_key`` is left empty for the caller
        to fill in.  Transport failures and HTTP statuses other than 200/429
        are reported through ``error`` (with ``reachable=False``), not raised.
    """
    base = (base_url or "https://api.anthropic.com").rstrip("/")
    result = ProviderUsageLive(provider="anthropic", account_key="", checked_at=utcnow(), reachable=False)
    auth_headers = {"x-api-key": api_key, "anthropic-version": "2023-06-01"}

    # One client serves both the listing call and the optional probe, so the
    # connection is reused and the pool is closed exactly once.
    async with httpx.AsyncClient(timeout=REQUEST_TIMEOUT) as client:
        try:
            resp = await client.get(f"{base}/v1/models", headers=auth_headers)
        except (httpx.ConnectError, httpx.TimeoutException) as exc:
            result.error = f"Connection failed: {exc}"
            return result
        except Exception as exc:  # any other failure is reported, not raised
            result.error = str(exc)
            return result

        if resp.status_code == 401:
            result.error = "Invalid API key (401)."
            return result
        # A 429 still carries the rate-limit headers, so it is read like a 200.
        if resp.status_code not in (200, 429):
            result.error = f"Provider returned HTTP {resp.status_code}."
            return result

        headers = {k.lower(): v for k, v in resp.headers.items()}
        result.reachable = True
        result.raw_headers = {k: v for k, v in headers.items() if "ratelimit" in k}
        _apply_anthropic_ratelimit_headers(result, headers)
        result.models = _anthropic_model_ids(resp)

        # Some tiers/paths omit ratelimit headers on /v1/models; fall back to a
        # minimal /v1/messages probe so usage and reset times can still be shown.
        windows = (result.tokens, result.input_tokens, result.requests)
        if all(w.limit is None for w in windows):
            await _probe_anthropic(client, base, auth_headers, result)
    return result
def _apply_openai_ratelimit_headers(result: ProviderUsageLive, headers: dict[str, str]) -> None:
    """Rebuild *result*'s token and request windows from OpenAI ``x-ratelimit-*`` headers.

    *headers* must have lower-case keys; a missing header leaves the matching
    field as None.  Reset values go through ``_parse_openai_reset``.
    """
    result.tokens = TokenWindow(
        limit=_parse_int_header(headers, "x-ratelimit-limit-tokens"),
        remaining=_parse_int_header(headers, "x-ratelimit-remaining-tokens"),
        reset_at=_parse_openai_reset(headers.get("x-ratelimit-reset-tokens", "")),
    )
    result.requests = RequestWindow(
        limit=_parse_int_header(headers, "x-ratelimit-limit-requests"),
        remaining=_parse_int_header(headers, "x-ratelimit-remaining-requests"),
        reset_at=_parse_openai_reset(headers.get("x-ratelimit-reset-requests", "")),
    )


def _openai_model_ids(resp: httpx.Response) -> list[str]:
    """Return the model IDs from a ``/v1/models`` body shaped ``{"data": [{"id": ...}]}``.

    Non-JSON or unexpectedly shaped bodies yield an empty list instead of raising.
    """
    try:
        data = resp.json()
    except ValueError:
        return []
    items = data.get("data") if isinstance(data, dict) else None
    if not isinstance(items, list):
        return []
    return [m["id"] for m in items if isinstance(m, dict) and m.get("id")]


async def _probe_openai(
    client: httpx.AsyncClient,
    base: str,
    auth_headers: dict[str, str],
    result: ProviderUsageLive,
) -> None:
    """Send a one-token chat completion and merge what it reveals into *result*.

    Best-effort: when no probe model can be chosen or the request fails,
    *result* keeps the data already gathered from ``/v1/models``.  Request
    failures are only logged at debug level.
    """
    probe_model = _pick_openai_probe_model(result.models)
    if not probe_model:
        return
    result.sample_model = probe_model
    # o-series reasoning models (o1, o3, o4-mini, ...) reject ``max_tokens``.
    # The Chat Completions API requires ``max_completion_tokens`` for them.
    token_param = "max_completion_tokens" if re.match(r"o\d", probe_model.lower()) else "max_tokens"
    try:
        probe_resp = await client.post(
            f"{base}/v1/chat/completions",
            headers={**auth_headers, "content-type": "application/json"},
            json={
                "model": probe_model,
                "messages": [{"role": "user", "content": "Usage probe"}],
                token_param: 1,
            },
        )
    except Exception as exc:  # deliberate: the probe must never fail the fetch
        logger.debug("provider_usage.probe_failed provider=%s error=%s", "openai", exc)
        return

    probe_headers = {k.lower(): v for k, v in probe_resp.headers.items()}
    probe_rl_headers = {k: v for k, v in probe_headers.items() if "ratelimit" in k}
    if probe_rl_headers:
        result.raw_headers = probe_rl_headers
        _apply_openai_ratelimit_headers(result, probe_headers)

    if probe_resp.status_code == 200:
        try:
            payload = probe_resp.json()
        except ValueError:
            payload = None
        usage = payload.get("usage") if isinstance(payload, dict) else None
        if isinstance(usage, dict):
            in_tok = usage.get("prompt_tokens")
            out_tok = usage.get("completion_tokens")
            if isinstance(in_tok, int):
                result.sample_input_tokens = in_tok
            if isinstance(out_tok, int):
                result.sample_output_tokens = out_tok

    # Round-trip time as measured by httpx, recorded whatever the status code.
    elapsed_ms = probe_resp.elapsed.total_seconds() * 1000.0
    result.sample_latency_ms = int(max(0.0, round(elapsed_ms)))


async def _fetch_openai(api_key: str, base_url: str | None) -> ProviderUsageLive:
    """Fetch OpenAI rate-limit windows and the model list for one API key.

    Reads the ``x-ratelimit-*`` headers from ``GET /v1/models``.  If neither a
    token nor a request limit is present, it sends a one-token
    ``POST /v1/chat/completions`` probe (this consumes a tiny amount of quota).
    The probe's headers, token usage and latency are then used instead.

    Args:
        api_key: OpenAI API key, sent as a Bearer token.
        base_url: Optional API root; defaults to ``https://api.openai.com``.

    Returns:
        A ProviderUsageLive whose ``account_key`` is left empty for the caller
        to fill in.  Transport failures and HTTP statuses other than 200/429
        are reported through ``error`` (with ``reachable=False``), not raised.
    """
    base = (base_url or "https://api.openai.com").rstrip("/")
    result = ProviderUsageLive(provider="openai", account_key="", checked_at=utcnow(), reachable=False)
    auth_headers = {"Authorization": f"Bearer {api_key}"}

    # One client serves both the listing call and the optional probe, so the
    # connection is reused and the pool is closed exactly once.
    async with httpx.AsyncClient(timeout=REQUEST_TIMEOUT) as client:
        try:
            resp = await client.get(f"{base}/v1/models", headers=auth_headers)
        except (httpx.ConnectError, httpx.TimeoutException) as exc:
            result.error = f"Connection failed: {exc}"
            return result
        except Exception as exc:  # any other failure is reported, not raised
            result.error = str(exc)
            return result

        if resp.status_code == 401:
            result.error = "Invalid API key (401)."
            return result
        # A 429 still carries the rate-limit headers, so it is read like a 200.
        if resp.status_code not in (200, 429):
            result.error = f"Provider returned HTTP {resp.status_code}."
            return result

        headers = {k.lower(): v for k, v in resp.headers.items()}
        result.reachable = True
        result.raw_headers = {k: v for k, v in headers.items() if "ratelimit" in k}
        _apply_openai_ratelimit_headers(result, headers)
        result.models = _openai_model_ids(resp)

        if result.tokens.limit is None and result.requests.limit is None:
            await _probe_openai(client, base, auth_headers, result)
    return result
def _ollama_model_names(resp: httpx.Response) -> list[str]:
    """Return model names from an ``/api/tags`` body shaped ``{"models": [{"name": ...}]}``.

    Non-JSON or unexpectedly shaped bodies yield an empty list instead of raising.
    """
    try:
        data = resp.json()
    except ValueError:
        return []
    models_raw = data.get("models") if isinstance(data, dict) else None
    if not isinstance(models_raw, list):
        return []
    return [m["name"] for m in models_raw if isinstance(m, dict) and m.get("name")]


async def _fetch_ollama(base_url: str | None, api_key: str | None) -> ProviderUsageLive:
    """Health-check an Ollama server and sample one tiny generation.

    ``GET /api/tags`` decides reachability and supplies the model list (Ollama
    exposes no rate limits).  When at least one model is listed, a
    ``POST /api/generate`` request with ``num_predict=1`` is sent.  Its token
    counts and reported ``total_duration`` fill the ``sample_*`` fields.  The
    probe is best-effort: if it fails, those fields stay None.

    Args:
        base_url: Server root; defaults to ``http://localhost:11434``.
        api_key: Optional token, sent as ``Authorization: Bearer ...`` when set.

    Returns:
        A ProviderUsageLive whose ``account_key`` is left empty for the caller
        to fill in.  Transport failures and non-200 statuses from
        ``/api/tags`` are reported through ``error`` (with ``reachable=False``).
    """
    base = (base_url or "http://localhost:11434").rstrip("/")
    result = ProviderUsageLive(provider="ollama", account_key="", checked_at=utcnow(), reachable=False)
    headers: dict[str, str] = {}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    # One client serves both the health check and the probe.
    async with httpx.AsyncClient(timeout=REQUEST_TIMEOUT) as client:
        try:
            resp = await client.get(f"{base}/api/tags", headers=headers)
        except (httpx.ConnectError, httpx.TimeoutException) as exc:
            result.error = f"Ollama unreachable: {exc}"
            return result
        except Exception as exc:  # any other failure is reported, not raised
            result.error = str(exc)
            return result
        if resp.status_code != 200:
            result.error = f"Ollama returned HTTP {resp.status_code}."
            return result
        result.reachable = True

        # Ollama has no rate limits — just expose available models.
        result.models = _ollama_model_names(resp)
        if not result.models:
            return result

        # Embedding-only models (e.g. "nomic-embed-text") typically reject
        # /api/generate, so prefer the first model that does not look like one.
        result.sample_model = next(
            (name for name in result.models if "embed" not in name.lower()),
            result.models[0],
        )
        try:
            probe_resp = await client.post(
                f"{base}/api/generate",
                headers={**headers, "content-type": "application/json"},
                json={
                    "model": result.sample_model,
                    "prompt": "Usage probe",
                    "stream": False,
                    "options": {"num_predict": 1},
                },
            )
        except Exception as exc:  # deliberate: the probe must never fail the fetch
            logger.debug("provider_usage.probe_failed provider=%s error=%s", "ollama", exc)
            return result

        if probe_resp.status_code == 200:
            try:
                payload = probe_resp.json()
            except ValueError:
                payload = None
            if isinstance(payload, dict):
                in_tok = payload.get("prompt_eval_count")
                out_tok = payload.get("eval_count")
                total_duration_ns = payload.get("total_duration")
                if isinstance(in_tok, int):
                    result.sample_input_tokens = in_tok
                if isinstance(out_tok, int):
                    result.sample_output_tokens = out_tok
                if isinstance(total_duration_ns, int):
                    # Ollama reports durations in nanoseconds.
                    result.sample_latency_ms = max(0, int(round(total_duration_ns / 1_000_000)))
    return result
# ---------------------------------------------------------------------------
# In-memory TTL cache
# ---------------------------------------------------------------------------
# Process-local TTL cache: credential_id -> (utcnow() at store time, result).
# Entries are evicted only when _get_cached finds them expired; no size bound
# is enforced.
_cache: dict[str, tuple[datetime, ProviderUsageLive]] = {}
def _get_cached(credential_id: str) -> ProviderUsageLive | None:
    """Return the cached result for *credential_id*, or None if absent or stale.

    An entry older than CACHE_TTL_SECONDS is deleted from the cache as a side
    effect of the lookup.
    """
    try:
        stored_at, cached_result = _cache[credential_id]
    except KeyError:
        return None
    age_seconds = (utcnow() - stored_at).total_seconds()
    if age_seconds <= CACHE_TTL_SECONDS:
        return cached_result
    del _cache[credential_id]
    return None
def _set_cached(credential_id: str, result: ProviderUsageLive) -> None:
    """Store *result* under *credential_id*, stamped with the current UTC time."""
    stored_at = utcnow()
    _cache[credential_id] = (stored_at, result)
# ---------------------------------------------------------------------------
# Public entry point
# ---------------------------------------------------------------------------
async def fetch_provider_usage(
    credential_id: str,
    provider: str,
    account_key: str,
    api_key: str | None,
    base_url: str | None,
    *,
    force_refresh: bool = False,
) -> ProviderUsageLive:
    """Return live usage for one credential, consulting the TTL cache first.

    Dispatch by *provider*:
      * "anthropic": the Anthropic fetch (requires *api_key*).
      * "openai" / "codex": the OpenAI fetch (requires *api_key*).
      * "ollama": the Ollama health check (*api_key* optional).
      * anything else: an error result saying live usage is unsupported.

    Whatever the outcome, the result's ``account_key`` is set to
    *account_key*; the result is then cached under *credential_id* and logged.
    Cached results are reused for CACHE_TTL_SECONDS unless *force_refresh*
    is True (e.g. when the user clicks Refresh).
    """
    cached = None if force_refresh else _get_cached(credential_id)
    if cached is not None:
        return cached

    def _unavailable(message: str) -> ProviderUsageLive:
        # Error result for cases where no provider call is attempted.
        return ProviderUsageLive(
            provider=provider,
            account_key=account_key,
            checked_at=utcnow(),
            reachable=False,
            error=message,
        )

    if provider == "ollama":
        result = await _fetch_ollama(base_url, api_key)
    elif provider in ("anthropic", "openai", "codex"):
        if not api_key:
            result = _unavailable("No API key configured.")
        elif provider == "anthropic":
            result = await _fetch_anthropic(api_key, base_url)
        else:
            result = await _fetch_openai(api_key, base_url)
    else:
        result = _unavailable(f"Live usage not supported for provider '{provider}'.")

    result.account_key = account_key
    _set_cached(credential_id, result)
    logger.info(
        "provider_usage.checked provider=%s account=%s reachable=%s error=%s",
        provider, account_key, result.reachable, result.error,
    )
    return result