1079 lines
41 KiB
Python
1079 lines
41 KiB
Python
"""Live provider usage — fetch token limits and reset times from provider APIs.
|
||
|
||
Some providers expose rate-limit headers on regular API responses. Pipeline first
calls a lightweight endpoint (the model list) and then, only when those headers
are missing, performs a minimal generation probe to surface usage and latency
details.

No guessing, no JSONL scanning, no estimates. If the provider API is
unreachable or the key is invalid, an error is returned and all limit fields
are None.

Supported providers
-------------------
anthropic → GET https://api.anthropic.com/v1/models
            Headers: anthropic-ratelimit-tokens-limit/remaining/reset
                     anthropic-ratelimit-requests-limit/remaining/reset
                     anthropic-ratelimit-input-tokens-limit/remaining/reset
                     anthropic-ratelimit-output-tokens-limit/remaining/reset
            Fallback probe (only when headers are missing):
              POST /v1/messages with max_tokens=1 to surface usage and timing data.

openai    → GET https://api.openai.com/v1/models
  (codex)   Headers: x-ratelimit-limit-tokens, x-ratelimit-remaining-tokens,
                     x-ratelimit-reset-tokens, x-ratelimit-limit-requests,
                     x-ratelimit-remaining-requests, x-ratelimit-reset-requests
            Fallback probe (only when headers are missing):
              POST /v1/responses with max_output_tokens=16 (preferred),
              then /v1/chat/completions with max_tokens=16 (compatibility).
              16 is used because gpt-5.x models reject smaller output budgets.

ollama    → GET {base_url}/api/tags (health check only; no rate limits)
            Returns: the model list and a server-reachable flag.
            Fallback probe:
              POST {base_url}/api/generate with num_predict=1 for usage and timing.

Caching
-------
Results are cached per credential_id for CACHE_TTL_SECONDS (default 60s) to
avoid hammering provider APIs on every page load. Results without subscription
windows use the shorter CACHE_TTL_FAILURE_SECONDS.
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json as _json_module
|
||
import os
|
||
import re
|
||
import time as _time_module
|
||
from dataclasses import dataclass, field
|
||
from datetime import datetime, timedelta, timezone
|
||
from typing import Any
|
||
|
||
import httpx
|
||
|
||
from app.core.logging import get_logger
|
||
from app.core.time import utcnow
|
||
|
||
# Module-level logger from the project's logging factory.
logger = get_logger(__name__)

# Default lifetime (seconds) of a cached ProviderUsageLive entry.
CACHE_TTL_SECONDS = 60
CACHE_TTL_FAILURE_SECONDS = 5  # short TTL for results with no subscription windows
# Timeout passed to every httpx.AsyncClient created in this module.
REQUEST_TIMEOUT = 8.0  # seconds
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Result types
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@dataclass
class TokenWindow:
    """Token rate-limit window (limit / remaining / reset) parsed from response headers."""

    limit: int | None = None
    remaining: int | None = None
    reset_at: datetime | None = None  # UTC naive datetime

    @property
    def reset_in_ms(self) -> int | None:
        """Milliseconds until ``reset_at`` (never negative), or None when unknown."""
        if self.reset_at is None:
            return None
        delta = (self.reset_at - utcnow()).total_seconds()
        return max(0, int(delta * 1000))

    @property
    def used(self) -> int | None:
        """``limit - remaining`` floored at 0, or None if either value is unknown."""
        if self.limit is not None and self.remaining is not None:
            return max(0, self.limit - self.remaining)
        return None

    @property
    def pct_used(self) -> float | None:
        """Percentage of ``limit`` consumed, rounded to one decimal and never negative.

        Clamped at 0 so it agrees with ``used`` when a provider reports
        ``remaining > limit``. None when the limit is unknown or not positive,
        or when ``remaining`` is unknown.
        """
        if self.limit is not None and self.limit > 0 and self.remaining is not None:
            return max(0.0, round((1 - self.remaining / self.limit) * 100, 1))
        return None
|
||
|
||
|
||
@dataclass
class RequestWindow:
    """Request-count rate-limit window: how many calls are allowed, left, and when it resets."""

    limit: int | None = None
    remaining: int | None = None
    reset_at: datetime | None = None  # UTC naive, compared against utcnow()

    @property
    def reset_in_ms(self) -> int | None:
        """Milliseconds left until ``reset_at``, floored at 0; None if no reset time is known."""
        target = self.reset_at
        if target is not None:
            seconds_left = (target - utcnow()).total_seconds()
            return max(0, int(seconds_left * 1000))
        return None
|
||
|
||
|
||
@dataclass
class SubscriptionWindow:
    """One subscription-plan usage window (e.g. 5h session, 7-day all-models)."""

    key: str  # "five_hour" | "seven_day" | "seven_day_sonnet" | "seven_day_opus"
    label: str  # human label: "Current session" | "All models" | "Sonnet" | "Opus"
    pct_used: float  # 0–100
    reset_at: datetime | None = None  # UTC naive datetime

    @property
    def reset_in_ms(self) -> int | None:
        """Milliseconds until the window resets, never below 0; None without a reset time."""
        target = self.reset_at
        if target is not None:
            seconds_left = (target - utcnow()).total_seconds()
            return max(0, int(seconds_left * 1000))
        return None
|
||
|
||
|
||
@dataclass
class ProviderUsageLive:
    """Live usage snapshot for one provider credential, as returned by the fetchers."""

    provider: str
    account_key: str
    checked_at: datetime
    reachable: bool
    # Phase 1 semantics: this service reports provider API diagnostics
    # (rate-limit windows / probe metadata), not subscription usage windows.
    source: str = "provider_api_rate_limit"
    confidence: str = "high"
    error: str | None = None
    tokens: TokenWindow = field(default_factory=TokenWindow)
    input_tokens: TokenWindow = field(default_factory=TokenWindow)  # Anthropic input-only window
    output_tokens: TokenWindow = field(default_factory=TokenWindow)  # Anthropic output-only window
    requests: RequestWindow = field(default_factory=RequestWindow)
    # Provider subscription windows — populated when session_key is provided
    subscription_windows: list[SubscriptionWindow] = field(default_factory=list)
    subscription_plan: str | None = None  # e.g. "pro", "plus ($15.00)"
    models: list[str] = field(default_factory=list)  # model IDs available on this key
    raw_headers: dict[str, str] = field(default_factory=dict)
    sample_model: str | None = None
    sample_input_tokens: int | None = None
    sample_output_tokens: int | None = None
    sample_latency_ms: int | None = None

    def to_dict(self) -> dict[str, Any]:
        """Serialise to a JSON-friendly dict with datetimes rendered as ISO strings.

        Subscription windows and the plan label are included alongside the
        rate-limit windows. ``models`` is capped at 20 entries to bound the
        response size. ``raw_headers`` is not serialised.
        """

        def _window(w: TokenWindow | RequestWindow) -> dict[str, Any]:
            d: dict[str, Any] = {
                "limit": w.limit,
                "remaining": w.remaining,
                "reset_in_ms": w.reset_in_ms,
                "reset_at": w.reset_at.isoformat() if w.reset_at else None,
            }
            # Only token windows derive usage figures.
            if isinstance(w, TokenWindow):
                d["used"] = w.used
                d["pct_used"] = w.pct_used
            return d

        def _subscription(sw: SubscriptionWindow) -> dict[str, Any]:
            return {
                "key": sw.key,
                "label": sw.label,
                "pct_used": sw.pct_used,
                "reset_in_ms": sw.reset_in_ms,
                "reset_at": sw.reset_at.isoformat() if sw.reset_at else None,
            }

        return {
            "provider": self.provider,
            "account_key": self.account_key,
            "checked_at": self.checked_at.isoformat(),
            "reachable": self.reachable,
            "source": self.source,
            "confidence": self.confidence,
            "error": self.error,
            "tokens": _window(self.tokens),
            "input_tokens": _window(self.input_tokens),
            "output_tokens": _window(self.output_tokens),
            "requests": _window(self.requests),
            "subscription_windows": [_subscription(sw) for sw in self.subscription_windows],
            "subscription_plan": self.subscription_plan,
            "models": self.models[:20],  # cap for response size
            "sample_model": self.sample_model,
            "sample_input_tokens": self.sample_input_tokens,
            "sample_output_tokens": self.sample_output_tokens,
            "sample_latency_ms": self.sample_latency_ms,
        }
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Header parsers
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _parse_int_header(headers: dict[str, str], *names: str) -> int | None:
|
||
for name in names:
|
||
val = headers.get(name.lower())
|
||
if val is not None:
|
||
try:
|
||
return int(val)
|
||
except (ValueError, TypeError):
|
||
pass
|
||
return None
|
||
|
||
|
||
def _parse_iso_reset(value: str) -> datetime | None:
|
||
"""Parse an ISO 8601 reset timestamp → UTC naive datetime."""
|
||
if not value:
|
||
return None
|
||
try:
|
||
normalized = value.strip().replace("Z", "+00:00")
|
||
dt = datetime.fromisoformat(normalized)
|
||
if dt.tzinfo is not None:
|
||
dt = dt.astimezone(timezone.utc).replace(tzinfo=None)
|
||
return dt
|
||
except ValueError:
|
||
return None
|
||
|
||
|
||
# OpenAI encodes reset as a duration string like "1m30s", "2h", "30s".
# The pattern allows optional hours, minutes and (possibly fractional) seconds,
# in that order, and is anchored at the end with `$`. It has no group for
# sub-second units, so values such as "20ms" do not match it.
_OAI_DURATION_RE = re.compile(
    r"(?:(\d+)h)?(?:(\d+)m)?(?:(\d+(?:\.\d+)?)s)?$"
)
|
||
|
||
|
||
def _apply_anthropic_ratelimit_headers(result: ProviderUsageLive, headers: dict[str, str]) -> None:
    """Populate Anthropic limit windows from response headers.

    Overwrites ``tokens``, ``input_tokens``, ``output_tokens`` and ``requests``
    on *result* from the matching ``anthropic-ratelimit-*-limit/remaining/reset``
    headers. A window whose headers are absent ends up with all-None fields.
    *headers* must be keyed by lowercase names.
    """
    layout = (
        (
            "tokens",
            TokenWindow,
            "anthropic-ratelimit-tokens-limit",
            "anthropic-ratelimit-tokens-remaining",
            "anthropic-ratelimit-tokens-reset",
        ),
        (
            "input_tokens",
            TokenWindow,
            "anthropic-ratelimit-input-tokens-limit",
            "anthropic-ratelimit-input-tokens-remaining",
            "anthropic-ratelimit-input-tokens-reset",
        ),
        (
            "output_tokens",
            TokenWindow,
            "anthropic-ratelimit-output-tokens-limit",
            "anthropic-ratelimit-output-tokens-remaining",
            "anthropic-ratelimit-output-tokens-reset",
        ),
        (
            "requests",
            RequestWindow,
            "anthropic-ratelimit-requests-limit",
            "anthropic-ratelimit-requests-remaining",
            "anthropic-ratelimit-requests-reset",
        ),
    )
    for attr, window_cls, limit_name, remaining_name, reset_name in layout:
        window = window_cls(
            limit=_parse_int_header(headers, limit_name),
            remaining=_parse_int_header(headers, remaining_name),
            reset_at=_parse_iso_reset(headers.get(reset_name, "")),
        )
        setattr(result, attr, window)
|
||
|
||
|
||
def _pick_anthropic_probe_model(models: list[str]) -> str | None:
|
||
if not models:
|
||
return None
|
||
priorities = ("haiku", "sonnet", "opus")
|
||
lowered = [(m, m.lower()) for m in models]
|
||
for priority in priorities:
|
||
for original, lowered_name in lowered:
|
||
if priority in lowered_name:
|
||
return original
|
||
return models[0]
|
||
|
||
|
||
def _pick_openai_probe_model(models: list[str]) -> str | None:
|
||
if not models:
|
||
return None
|
||
priorities = (
|
||
"gpt-5.5",
|
||
"gpt-5.4",
|
||
"gpt-5.3-codex",
|
||
"gpt-5.2-codex",
|
||
"gpt-5.1-codex",
|
||
"gpt-5-codex",
|
||
"codex",
|
||
"gpt-4.1-mini",
|
||
"gpt-4o-mini",
|
||
"gpt-4.1",
|
||
"gpt-4o",
|
||
"o4-mini",
|
||
)
|
||
lowered = [(m, m.lower()) for m in models]
|
||
for priority in priorities:
|
||
for original, lowered_name in lowered:
|
||
if priority in lowered_name:
|
||
return original
|
||
return models[0]
|
||
|
||
|
||
def _normalize_base(base_url: str | None, default_base: str, *, strip_suffixes: tuple[str, ...]) -> str:
|
||
base = (base_url or default_base).strip().rstrip("/")
|
||
lowered = base.lower()
|
||
for suffix in strip_suffixes:
|
||
if lowered.endswith(suffix):
|
||
base = base[: -len(suffix)]
|
||
break
|
||
return base.rstrip("/")
|
||
|
||
|
||
def _parse_openai_reset(value: str) -> datetime | None:
|
||
"""Parse an OpenAI reset header: ISO datetime OR duration like '1m30s'."""
|
||
if not value:
|
||
return None
|
||
# ISO format first
|
||
if "T" in value or value.endswith("Z"):
|
||
return _parse_iso_reset(value)
|
||
# Duration
|
||
m = _OAI_DURATION_RE.match(value.strip())
|
||
if m and any(m.groups()):
|
||
h = float(m.group(1) or 0)
|
||
mn = float(m.group(2) or 0)
|
||
s = float(m.group(3) or 0)
|
||
total_seconds = h * 3600 + mn * 60 + s
|
||
return utcnow() + timedelta(seconds=total_seconds)
|
||
return None
|
||
|
||
|
||
def _apply_openai_ratelimit_headers(result: ProviderUsageLive, headers: dict[str, str]) -> None:
    """Populate the token and request windows from OpenAI ``x-ratelimit-*`` headers.

    Reset values go through _parse_openai_reset (ISO timestamp or duration).
    Both windows are replaced even when their headers are missing. *headers*
    must be keyed by lowercase names.
    """

    def reset_at(name: str) -> datetime | None:
        return _parse_openai_reset(headers.get(name, ""))

    token_limit = _parse_int_header(headers, "x-ratelimit-limit-tokens")
    token_remaining = _parse_int_header(headers, "x-ratelimit-remaining-tokens")
    result.tokens = TokenWindow(
        limit=token_limit,
        remaining=token_remaining,
        reset_at=reset_at("x-ratelimit-reset-tokens"),
    )

    request_limit = _parse_int_header(headers, "x-ratelimit-limit-requests")
    request_remaining = _parse_int_header(headers, "x-ratelimit-remaining-requests")
    result.requests = RequestWindow(
        limit=request_limit,
        remaining=request_remaining,
        reset_at=reset_at("x-ratelimit-reset-requests"),
    )
|
||
|
||
|
||
def _extract_openai_usage(payload: Any) -> tuple[int | None, int | None]:
|
||
if not isinstance(payload, dict):
|
||
return (None, None)
|
||
usage = payload.get("usage")
|
||
if not isinstance(usage, dict):
|
||
return (None, None)
|
||
|
||
# Responses API style
|
||
in_tok = usage.get("input_tokens")
|
||
out_tok = usage.get("output_tokens")
|
||
if isinstance(in_tok, int) or isinstance(out_tok, int):
|
||
return (
|
||
in_tok if isinstance(in_tok, int) else None,
|
||
out_tok if isinstance(out_tok, int) else None,
|
||
)
|
||
|
||
# Chat Completions style
|
||
in_tok = usage.get("prompt_tokens")
|
||
out_tok = usage.get("completion_tokens")
|
||
return (
|
||
in_tok if isinstance(in_tok, int) else None,
|
||
out_tok if isinstance(out_tok, int) else None,
|
||
)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Provider-specific fetch functions
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _anthropic_model_ids(resp: httpx.Response) -> list[str]:
    """Return the model IDs listed in a ``GET /v1/models`` body, or [] if unparseable.

    Expects ``{"data": [{"id": ...}, ...]}``. Entries that are not objects or
    have an empty ``id`` are skipped.
    """
    try:
        data = resp.json()
    except ValueError:
        return []
    items = data.get("data") if isinstance(data, dict) else None
    if not isinstance(items, list):
        return []
    return [m["id"] for m in items if isinstance(m, dict) and m.get("id")]


async def _probe_anthropic_messages(
    client: httpx.AsyncClient,
    base: str,
    auth_headers: dict[str, str],
    result: ProviderUsageLive,
) -> None:
    """Send a ``max_tokens=1`` message to surface rate-limit headers, usage and latency.

    Best effort: a transport error is logged and leaves *result* untouched,
    apart from ``sample_model``. Rate-limit windows are overwritten only when
    the probe response carries ``ratelimit`` headers.
    """
    probe_model = _pick_anthropic_probe_model(result.models)
    if not probe_model:
        return
    result.sample_model = probe_model
    try:
        probe_resp = await client.post(
            f"{base}/v1/messages",
            headers={**auth_headers, "content-type": "application/json"},
            json={
                "model": probe_model,
                "max_tokens": 1,
                "messages": [{"role": "user", "content": "Usage probe"}],
            },
        )
    except Exception as exc:
        logger.debug("provider_usage.anthropic.probe_failed error=%s", exc)
        return

    probe_headers = {k.lower(): v for k, v in probe_resp.headers.items()}
    probe_rl_headers = {k: v for k, v in probe_headers.items() if "ratelimit" in k}
    if probe_rl_headers:
        result.raw_headers = probe_rl_headers
        _apply_anthropic_ratelimit_headers(result, probe_headers)

    if probe_resp.status_code == 200:
        try:
            payload = probe_resp.json()
        except ValueError:
            payload = None
        usage = payload.get("usage") if isinstance(payload, dict) else None
        if isinstance(usage, dict):
            in_tok = usage.get("input_tokens")
            out_tok = usage.get("output_tokens")
            if isinstance(in_tok, int):
                result.sample_input_tokens = in_tok
            if isinstance(out_tok, int):
                result.sample_output_tokens = out_tok

    # Latency is recorded for any answered probe, including error statuses.
    elapsed_ms = probe_resp.elapsed.total_seconds() * 1000.0
    result.sample_latency_ms = int(max(0.0, round(elapsed_ms)))


async def _fetch_anthropic(api_key: str, base_url: str | None) -> ProviderUsageLive:
    """Read Anthropic rate-limit windows for *api_key*.

    Calls ``GET {base}/v1/models`` and parses the ``anthropic-ratelimit-*``
    response headers. When none of the tokens / input-tokens / requests
    limits are present, a minimal ``POST /v1/messages`` probe
    (``max_tokens=1``) is sent on the same client to surface headers, usage
    and latency.

    Args:
        api_key: Sent as the ``x-api-key`` header.
        base_url: Optional API root; a trailing ``/v1`` is stripped. Defaults
            to ``https://api.anthropic.com``.

    Returns:
        A ProviderUsageLive with ``provider="anthropic"`` and an empty
        ``account_key``. Connection failures, a 401, and any status other than
        200/429 on the model list are reported via ``error`` with
        ``reachable=False``.
    """
    base = _normalize_base(
        base_url,
        "https://api.anthropic.com",
        strip_suffixes=("/v1",),
    )
    result = ProviderUsageLive(provider="anthropic", account_key="", checked_at=utcnow(), reachable=False)
    auth_headers = {"x-api-key": api_key, "anthropic-version": "2023-06-01"}

    # One client serves both the model listing and the optional probe, so the
    # probe reuses the pooled connection instead of opening a new one.
    async with httpx.AsyncClient(timeout=REQUEST_TIMEOUT) as client:
        try:
            resp = await client.get(f"{base}/v1/models", headers=auth_headers)
        except (httpx.ConnectError, httpx.TimeoutException) as exc:
            result.error = f"Connection failed: {exc}"
            return result
        except Exception as exc:
            result.error = str(exc)
            return result

        if resp.status_code == 401:
            result.error = "Invalid API key (401)."
            return result
        # 429 is deliberately accepted: rate-limited responses still carry the
        # ratelimit headers this function is after.
        if resp.status_code not in (200, 429):
            result.error = f"Provider returned HTTP {resp.status_code}."
            return result

        headers = {k.lower(): v for k, v in resp.headers.items()}
        result.reachable = True
        result.raw_headers = {k: v for k, v in headers.items() if "ratelimit" in k}
        _apply_anthropic_ratelimit_headers(result, headers)
        result.models = _anthropic_model_ids(resp)

        # Some tiers/paths may omit ratelimit headers on /v1/models. Fall back
        # to a minimal /v1/messages probe so we can still surface usage/time.
        if (
            result.tokens.limit is None
            and result.input_tokens.limit is None
            and result.requests.limit is None
        ):
            await _probe_anthropic_messages(client, base, auth_headers, result)

    return result
|
||
|
||
|
||
async def _fetch_openai(api_key: str, base_url: str | None) -> ProviderUsageLive:
    """Read OpenAI rate-limit windows for *api_key*.

    Calls ``GET {base}/v1/models`` and parses the ``x-ratelimit-*`` headers.
    When neither the token nor the request limit is present, a minimal
    generation probe is sent to the model chosen by _pick_openai_probe_model.
    It tries ``POST /v1/responses`` first, then ``POST /v1/chat/completions``.
    Probing stops at the first endpoint that yields limits or token usage, or
    when a 429 reports exhausted quota.

    Args:
        api_key: Sent as ``Authorization: Bearer <api_key>``.
        base_url: Optional API root; a trailing ``/v1`` is stripped. Defaults
            to ``https://api.openai.com``.

    Returns:
        A ProviderUsageLive with ``provider="openai"`` and an empty
        ``account_key``. Connection failures, a 401, and any status other than
        200/429 on the model list yield ``reachable=False`` plus ``error``.
    """
    base = _normalize_base(
        base_url,
        "https://api.openai.com",
        strip_suffixes=("/v1",),
    )
    now = utcnow()
    result = ProviderUsageLive(provider="openai", account_key="", checked_at=now, reachable=False)

    async with httpx.AsyncClient(timeout=REQUEST_TIMEOUT) as client:
        try:
            resp = await client.get(
                f"{base}/v1/models",
                headers={"Authorization": f"Bearer {api_key}"},
            )
        except (httpx.ConnectError, httpx.TimeoutException) as exc:
            result.error = f"Connection failed: {exc}"
            return result
        except Exception as exc:
            result.error = str(exc)
            return result

    if resp.status_code == 401:
        result.error = "Invalid API key (401)."
        return result
    if resp.status_code not in (200, 429):
        result.error = f"Provider returned HTTP {resp.status_code}."
        # NOTE(review): inside this branch the status is never 429, so this
        # check always returns. A 429 skips the branch entirely and its
        # headers are parsed below.
        if resp.status_code != 429:
            return result

    # Lowercase header names so later lookups are case-insensitive.
    h = {k.lower(): v for k, v in resp.headers.items()}
    result.reachable = True
    result.raw_headers = {k: v for k, v in h.items() if "ratelimit" in k}

    _apply_openai_ratelimit_headers(result, h)

    # The model list is optional: a malformed body leaves result.models empty.
    try:
        data = resp.json()
        items = data.get("data") or []
        result.models = [m.get("id", "") for m in items if isinstance(m, dict) and m.get("id")]
    except Exception:
        pass

    # No limits in the /v1/models headers: fall back to a minimal generation probe.
    if result.tokens.limit is None and result.requests.limit is None:
        probe_model = _pick_openai_probe_model(result.models)
        if probe_model:
            result.sample_model = probe_model
            # min 16 tokens: gpt-5.x models reject max_output_tokens < 16
            probe_endpoints: list[tuple[str, dict[str, Any]]] = [
                (
                    f"{base}/v1/responses",
                    {
                        "model": probe_model,
                        "input": "Usage probe",
                        "max_output_tokens": 16,
                    },
                ),
                (
                    f"{base}/v1/chat/completions",
                    {
                        "model": probe_model,
                        "messages": [{"role": "user", "content": "Usage probe"}],
                        "max_tokens": 16,
                    },
                ),
            ]
            for endpoint, body in probe_endpoints:
                # Each probe uses its own short-lived client.
                async with httpx.AsyncClient(timeout=REQUEST_TIMEOUT) as client:
                    try:
                        probe_resp = await client.post(
                            endpoint,
                            headers={
                                "Authorization": f"Bearer {api_key}",
                                "content-type": "application/json",
                            },
                            json=body,
                        )
                    except Exception:
                        continue  # transport error: try the next endpoint
                # Quota exhaustion is a 429 with a distinct error code — surface it
                # clearly rather than treating it as a transient rate limit.
                if probe_resp.status_code == 429:
                    try:
                        err = probe_resp.json().get("error", {})
                        if err.get("code") == "insufficient_quota" or "exceeded your current quota" in str(err.get("message", "")):
                            result.error = "Quota exhausted — add credits at platform.openai.com/billing."
                            elapsed_ms = probe_resp.elapsed.total_seconds() * 1000.0
                            result.sample_latency_ms = int(max(0.0, round(elapsed_ms)))
                            break
                    except Exception:
                        pass  # unparseable 429 body: treat like any other response
                probe_headers = {k.lower(): v for k, v in probe_resp.headers.items()}
                probe_rl_headers = {k: v for k, v in probe_headers.items() if "ratelimit" in k}
                # Only overwrite windows when the probe actually returned limits.
                if probe_rl_headers:
                    result.raw_headers = probe_rl_headers
                    _apply_openai_ratelimit_headers(result, probe_headers)
                if probe_resp.status_code == 200:
                    try:
                        payload = probe_resp.json()
                        in_tok, out_tok = _extract_openai_usage(payload)
                        if in_tok is not None:
                            result.sample_input_tokens = in_tok
                        if out_tok is not None:
                            result.sample_output_tokens = out_tok
                    except Exception:
                        pass
                # Latency of the most recent answered probe, whatever its status.
                elapsed_ms = probe_resp.elapsed.total_seconds() * 1000.0
                result.sample_latency_ms = int(max(0.0, round(elapsed_ms)))
                # Stop once this endpoint produced anything useful.
                if (
                    result.tokens.limit is not None
                    or result.requests.limit is not None
                    or result.sample_input_tokens is not None
                    or result.sample_output_tokens is not None
                ):
                    break

    return result
|
||
|
||
|
||
async def _fetch_ollama(base_url: str | None, api_key: str | None) -> ProviderUsageLive:
    """Health-check an Ollama server and time a one-token generation.

    ``GET {base}/api/tags`` decides reachability and supplies the model list
    (Ollama exposes no rate limits). If any model is listed, the first one is
    asked for a single token via ``POST {base}/api/generate``. Its
    ``prompt_eval_count``, ``eval_count`` and ``total_duration`` (nanoseconds,
    converted to ms) fill the sample_* fields.

    Args:
        base_url: Server root; a trailing ``/api`` is stripped. Defaults to
            ``http://localhost:11434``.
        api_key: Optional Bearer token, sent only when non-empty.

    Returns:
        A ProviderUsageLive with ``provider="ollama"``. ``reachable`` is True
        only when the tags request answered HTTP 200.
    """
    root = _normalize_base(
        base_url,
        "http://localhost:11434",
        strip_suffixes=("/api",),
    )
    result = ProviderUsageLive(provider="ollama", account_key="", checked_at=utcnow(), reachable=False)
    auth_headers: dict[str, str] = {"Authorization": f"Bearer {api_key}"} if api_key else {}

    async with httpx.AsyncClient(timeout=REQUEST_TIMEOUT) as client:
        try:
            tags_resp = await client.get(f"{root}/api/tags", headers=auth_headers)
        except (httpx.ConnectError, httpx.TimeoutException) as exc:
            result.error = f"Ollama unreachable: {exc}"
            return result
        except Exception as exc:
            result.error = str(exc)
            return result

    if tags_resp.status_code != 200:
        result.error = f"Ollama returned HTTP {tags_resp.status_code}."
        return result
    result.reachable = True

    # Ollama has no rate limits — just expose available models.
    try:
        entries = tags_resp.json().get("models") or []
        result.models = [e.get("name", "") for e in entries if isinstance(e, dict) and e.get("name")]
    except Exception:
        pass

    if not result.models:
        return result

    result.sample_model = result.models[0]
    async with httpx.AsyncClient(timeout=REQUEST_TIMEOUT) as client:
        try:
            gen_resp = await client.post(
                f"{root}/api/generate",
                headers={**auth_headers, "content-type": "application/json"},
                json={
                    "model": result.sample_model,
                    "prompt": "Usage probe",
                    "stream": False,
                    "options": {"num_predict": 1},
                },
            )
        except Exception:
            gen_resp = None

    if gen_resp is None or gen_resp.status_code != 200:
        return result

    try:
        stats = gen_resp.json()
        # Read all three fields before assigning any of them.
        prompt_count = stats.get("prompt_eval_count")
        completion_count = stats.get("eval_count")
        duration_ns = stats.get("total_duration")
        if isinstance(prompt_count, int):
            result.sample_input_tokens = prompt_count
        if isinstance(completion_count, int):
            result.sample_output_tokens = completion_count
        if isinstance(duration_ns, int):
            result.sample_latency_ms = max(0, int(round(duration_ns / 1_000_000)))
    except Exception:
        pass

    return result
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Subscription usage fetchers — require session/OAuth tokens, not API keys
|
||
# ---------------------------------------------------------------------------
|
||
|
||
# Subscription usage endpoints. Both are called with an
# "Authorization: Bearer <token>" header rather than a provider API key.
_ANTHROPIC_SUBSCRIPTION_URL = "https://api.anthropic.com/api/oauth/usage"
_CODEX_SUBSCRIPTION_URL = "https://chatgpt.com/backend-api/wham/usage"

# Path to the Claude Code OAuth credentials file, mounted read-only from the host.
# Empty string when the CLAUDE_CREDENTIALS_PATH environment variable is unset.
_CLAUDE_CREDENTIALS_PATH = os.environ.get("CLAUDE_CREDENTIALS_PATH", "")

# Path to the Codex CLI auth file, mounted read-only from the host.
# Empty string when the CODEX_CREDENTIALS_PATH environment variable is unset.
_CODEX_CREDENTIALS_PATH = os.environ.get("CODEX_CREDENTIALS_PATH", "")

# NOTE(review): not read or assigned by any function visible in this part of
# the module — possibly vestigial; verify before relying on it.
_claude_oauth_cache: tuple[float, str] | None = None  # (expires_at_ms, access_token)
|
||
|
||
|
||
def _read_claude_local_oauth_token() -> str | None:
|
||
"""Read the Claude Code OAuth access token from the host credentials file.
|
||
|
||
Returns a valid access token, or None if the file is absent, unreadable,
|
||
or the token has expired. The caller should fall back to the manually
|
||
configured session_key when this returns None.
|
||
"""
|
||
global _claude_oauth_cache
|
||
|
||
path = _CLAUDE_CREDENTIALS_PATH
|
||
if not path:
|
||
# Fall back to the default XDG location when no explicit path is set.
|
||
home = os.path.expanduser("~")
|
||
path = os.path.join(home, ".claude", ".credentials.json")
|
||
|
||
try:
|
||
with open(path) as fh:
|
||
data = _json_module.load(fh)
|
||
except (FileNotFoundError, PermissionError, ValueError, OSError):
|
||
return None
|
||
|
||
oauth = data.get("claudeAiOauth")
|
||
if not isinstance(oauth, dict):
|
||
return None
|
||
|
||
access_token = oauth.get("accessToken")
|
||
expires_at = oauth.get("expiresAt")
|
||
|
||
if not isinstance(access_token, str) or not access_token:
|
||
return None
|
||
if not isinstance(expires_at, (int, float)) or expires_at <= 0:
|
||
return None
|
||
|
||
now_ms = _time_module.time() * 1000
|
||
if expires_at <= now_ms:
|
||
logger.debug("provider_usage.claude_oauth.token_expired expires_at=%s", expires_at)
|
||
return None
|
||
|
||
return access_token
|
||
|
||
|
||
def _read_codex_local_token() -> str | None:
|
||
"""Read the Codex CLI JWT access token from the host auth file (~/.codex/auth.json).
|
||
|
||
The Codex CLI stores a long-lived JWT (typically ~9 days) issued by OpenAI's
|
||
OAuth flow. This token works as a Bearer credential for the ChatGPT backend
|
||
API, including the wham/usage subscription endpoint.
|
||
|
||
Returns the access_token string, or None if the file is absent, unreadable,
|
||
or the token has expired. The caller should fall back to the manually
|
||
configured session_key when this returns None.
|
||
"""
|
||
path = _CODEX_CREDENTIALS_PATH
|
||
if not path:
|
||
home = os.path.expanduser("~")
|
||
path = os.path.join(home, ".codex", "auth.json")
|
||
|
||
try:
|
||
with open(path) as fh:
|
||
data = _json_module.load(fh)
|
||
except (FileNotFoundError, PermissionError, ValueError, OSError):
|
||
return None
|
||
|
||
tokens = data.get("tokens")
|
||
if not isinstance(tokens, dict):
|
||
return None
|
||
|
||
access_token = tokens.get("access_token")
|
||
if not isinstance(access_token, str) or not access_token:
|
||
return None
|
||
|
||
# Decode the JWT payload (no signature verification needed here — we just
|
||
# want to avoid using a token we know is expired before making a network call).
|
||
parts = access_token.split(".")
|
||
if len(parts) >= 2:
|
||
import base64 as _base64
|
||
try:
|
||
pad = parts[1] + "=" * (4 - len(parts[1]) % 4)
|
||
payload = _json_module.loads(_base64.urlsafe_b64decode(pad))
|
||
exp = payload.get("exp")
|
||
if isinstance(exp, (int, float)) and exp <= _time_module.time():
|
||
logger.debug("provider_usage.codex_oauth.token_expired exp=%s", exp)
|
||
return None
|
||
except Exception:
|
||
pass # Can't decode JWT — proceed optimistically
|
||
|
||
return access_token
|
||
|
||
|
||
# Top-level keys of the Anthropic oauth/usage response that become
# SubscriptionWindow entries, mapped to the label shown for each. The windows
# are produced in this dict's iteration order.
_ANTHROPIC_WINDOW_LABELS: dict[str, str] = {
    "five_hour": "Current session",
    "seven_day": "All models",
    "seven_day_sonnet": "Sonnet",
    "seven_day_opus": "Opus",
}
|
||
|
||
|
||
async def _fetch_anthropic_subscription(session_key: str) -> list[SubscriptionWindow]:
    """Fetch Claude subscription usage windows via the Anthropic OAuth usage endpoint.

    *session_key* is sent as a Bearer token with the
    ``anthropic-beta: oauth-2025-04-20`` header. It is presumably a Claude
    OAuth access token (``sk-ant-oat01-…``, e.g. read from the Claude Code
    credentials file). NOTE(review): confirm whether claude.ai ``sessionKey``
    cookies are also accepted.

    One retry is made after a 429. Each block listed in
    _ANTHROPIC_WINDOW_LABELS that has a numeric ``utilization`` (already a
    0–100 percentage) becomes a SubscriptionWindow, clamped to 0–100.

    Returns an empty list on transport errors, non-200 responses, or a
    malformed payload.
    """
    headers = {
        "Authorization": f"Bearer {session_key}",
        "User-Agent": "pipeline",
        "Accept": "application/json",
        "anthropic-version": "2023-06-01",
        "anthropic-beta": "oauth-2025-04-20",
    }
    # Retry once on 429 — the subscription endpoint can be rate-limited when
    # called right after another api.anthropic.com request (e.g. /v1/models).
    async with httpx.AsyncClient(timeout=REQUEST_TIMEOUT) as client:
        for attempt in range(2):
            try:
                resp = await client.get(_ANTHROPIC_SUBSCRIPTION_URL, headers=headers)
            except Exception as exc:
                logger.warning("provider_usage.subscription.anthropic.fetch_failed error=%s", exc)
                return []
            if resp.status_code == 429 and attempt == 0:
                await asyncio.sleep(1.5)
                continue
            break

    if resp.status_code != 200:
        logger.debug(
            "provider_usage.subscription.anthropic.http_error status=%s body=%s",
            resp.status_code,
            resp.text[:200],
        )
        return []

    try:
        data = resp.json()
    except ValueError:
        return []
    if not isinstance(data, dict):
        logger.debug(
            "provider_usage.subscription.anthropic.unexpected_payload type=%s",
            type(data).__name__,
        )
        return []

    windows: list[SubscriptionWindow] = []
    for key, label in _ANTHROPIC_WINDOW_LABELS.items():
        block = data.get(key)
        if not isinstance(block, dict):
            continue
        utilization = block.get("utilization")
        if utilization is None:
            continue
        try:
            pct = float(utilization)
        except (TypeError, ValueError):
            continue  # skip a malformed window instead of failing the whole fetch
        resets_at = block.get("resets_at")
        windows.append(SubscriptionWindow(
            key=key,
            label=label,
            pct_used=min(100.0, max(0.0, round(pct, 1))),  # already 0–100
            reset_at=_parse_iso_reset(resets_at) if isinstance(resets_at, str) else None,
        ))

    logger.info(
        "provider_usage.subscription.anthropic.fetched windows=%d",
        len(windows),
    )
    return windows
|
||
|
||
|
||
def _parse_codex_window(raw: Any, default_seconds: int) -> tuple[float, datetime | None, int] | None:
    """Extract ``(pct_used, reset_at, window_hours)`` from one wham/usage window object.

    Returns None when *raw* is not an object or lacks a numeric ``used_percent``.
    ``pct_used`` is clamped to 0–100, matching the SubscriptionWindow contract.
    ``reset_at`` (epoch seconds in the payload) becomes a UTC naive datetime.
    A missing, non-positive or non-numeric ``limit_window_seconds`` falls back
    to *default_seconds*.
    """
    if not isinstance(raw, dict) or raw.get("used_percent") is None:
        return None
    try:
        pct = float(raw["used_percent"])
    except (TypeError, ValueError):
        return None

    seconds = raw.get("limit_window_seconds")
    if not isinstance(seconds, (int, float)) or seconds <= 0:
        seconds = default_seconds

    reset_dt: datetime | None = None
    reset_ts = raw.get("reset_at")
    if isinstance(reset_ts, (int, float)):
        try:
            reset_dt = datetime.fromtimestamp(reset_ts, tz=timezone.utc).replace(tzinfo=None)
        except (OverflowError, OSError, ValueError):
            reset_dt = None

    return min(100.0, max(0.0, round(pct, 1))), reset_dt, round(seconds / 3600)


async def _fetch_codex_subscription(session_key: str) -> tuple[list[SubscriptionWindow], str | None]:
    """Fetch Codex/ChatGPT subscription usage windows via the wham/usage endpoint.

    Requires a ChatGPT bearer token (from the browser session or the Codex CLI
    auth file, not the standard API key). ``rate_limit.primary_window`` becomes
    the "Current session" window (default 3h); ``rate_limit.secondary_window``
    becomes "All models" (default 24h, labelled weekly at ≥168h). The plan
    label combines ``plan_type`` with a positive ``credits.balance``.

    Returns (windows, plan_label), or ([], None) on transport errors,
    non-200 responses, or a malformed payload.
    """
    async with httpx.AsyncClient(timeout=REQUEST_TIMEOUT) as client:
        try:
            resp = await client.get(
                _CODEX_SUBSCRIPTION_URL,
                headers={
                    "Authorization": f"Bearer {session_key}",
                    "Accept": "application/json",
                },
            )
        except Exception as exc:
            logger.warning("provider_usage.subscription.codex.fetch_failed error=%s", exc)
            return [], None

    if resp.status_code != 200:
        logger.debug(
            "provider_usage.subscription.codex.http_error status=%s",
            resp.status_code,
        )
        return [], None

    try:
        data = resp.json()
    except ValueError:
        return [], None
    if not isinstance(data, dict):
        logger.debug(
            "provider_usage.subscription.codex.unexpected_payload type=%s",
            type(data).__name__,
        )
        return [], None

    rate_limit = data.get("rate_limit")
    if not isinstance(rate_limit, dict):
        rate_limit = {}

    windows: list[SubscriptionWindow] = []

    primary = _parse_codex_window(rate_limit.get("primary_window"), 10800)
    if primary is not None:
        pct, reset_dt, window_hours = primary
        windows.append(SubscriptionWindow(
            key="primary",
            label=f"Current session ({window_hours}h)",
            pct_used=pct,
            reset_at=reset_dt,
        ))

    secondary = _parse_codex_window(rate_limit.get("secondary_window"), 86400)
    if secondary is not None:
        pct, reset_dt, sec_hours = secondary
        sec_label = "All models (weekly)" if sec_hours >= 168 else f"All models ({sec_hours}h)"
        windows.append(SubscriptionWindow(
            key="secondary",
            label=sec_label,
            pct_used=pct,
            reset_at=reset_dt,
        ))

    # Plan label
    plan_type = data.get("plan_type")
    credits = data.get("credits")
    balance = credits.get("balance") if isinstance(credits, dict) else None
    plan_label: str | None = None
    if isinstance(balance, (int, float)) and balance > 0:
        plan_label = f"{plan_type} (${float(balance):.2f})" if plan_type else f"${float(balance):.2f}"
    elif plan_type:
        plan_label = str(plan_type)

    logger.info(
        "provider_usage.subscription.codex.fetched windows=%d plan=%s",
        len(windows),
        plan_label,
    )
    return windows, plan_label
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# In-memory TTL cache + in-flight deduplication
|
||
# ---------------------------------------------------------------------------
|
||
|
||
# credential_id -> (cached_at, result, ttl_seconds). cached_at is the utcnow()
# value taken when the entry is written. Entries older than their ttl are
# evicted on lookup.
_cache: dict[str, tuple[datetime, ProviderUsageLive, int]] = {}
# Tracks in-progress fetches so concurrent requests share one result instead
# of each racing to hit the provider API (which triggers 429 cascades).
# NOTE(review): not touched by any function visible here — presumably managed
# by the public entry point; confirm.
_inflight: dict[str, asyncio.Future[ProviderUsageLive]] = {}
|
||
|
||
|
||
def _get_cached(credential_id: str) -> ProviderUsageLive | None:
    """Return the cached result for *credential_id*, or None when absent or stale.

    An entry whose age (measured with ``utcnow()``) exceeds its stored TTL is
    evicted from ``_cache`` as a side effect of the lookup.
    """
    try:
        stored_at, cached_result, ttl_seconds = _cache[credential_id]
    except KeyError:
        return None

    age_seconds = (utcnow() - stored_at).total_seconds()
    if age_seconds <= ttl_seconds:
        return cached_result

    # Stale: drop it so the next caller performs a fresh fetch.
    del _cache[credential_id]
    return None
|
||
|
||
|
||
def _set_cached(credential_id: str, result: ProviderUsageLive, ttl: int = CACHE_TTL_SECONDS) -> None:
    """Store *result* under *credential_id*, stamped with ``utcnow()`` and a TTL in seconds."""
    entry = (utcnow(), result, ttl)
    _cache[credential_id] = entry
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Public entry point
|
||
# ---------------------------------------------------------------------------
|
||
|
||
async def _do_fetch_provider_usage(
    credential_id: str,
    provider: str,
    account_key: str,
    api_key: str | None,
    base_url: str | None,
    session_key: str | None,
) -> ProviderUsageLive:
    """Inner fetch — called once per credential even under concurrent load.

    Dispatches on *provider* and builds a ``ProviderUsageLive`` result. For
    anthropic and openai/codex, it overlays subscription windows when a
    subscription token (explicit or auto-detected locally) is available. The
    result is stored in the TTL cache and returned.

    Args:
        credential_id: Key under which the result is cached.
        provider: Provider name ("anthropic", "openai", "codex", "ollama", ...).
        account_key: Account identifier written onto the returned result.
        api_key: API key used for the rate-limit / model-list probe, if any.
        base_url: Optional base URL override passed to the provider fetcher.
        session_key: Explicitly stored subscription/OAuth token, if any.

    Returns:
        The ``ProviderUsageLive`` result. When nothing could be fetched, this is
        an unreachable/error result rather than an exception.
    """
    subscription_attempted = False

    if provider == "anthropic":
        # Auto-detected local OAuth token (sk-ant-oat01-) takes precedence — it is
        # the correct credential type for the oauth/usage endpoint and is always
        # fresher than a manually stored key. Fall back to session_key only when
        # no local token is found (e.g. Pipeline running on a different machine).
        local_oauth = _read_claude_local_oauth_token()
        effective_session_key = local_oauth or session_key

        if not api_key and not effective_session_key:
            result = ProviderUsageLive(
                provider=provider, account_key=account_key,
                checked_at=utcnow(), reachable=False,
                error="No API key configured.",
            )
        elif api_key:
            result = await _fetch_anthropic(api_key, base_url)
        else:
            # Subscription-only credential: mark reachable so the overlay runs.
            result = ProviderUsageLive(
                provider=provider, account_key=account_key,
                checked_at=utcnow(), reachable=True,
            )

        # Overlay subscription windows from OAuth token (auto or explicit)
        if effective_session_key and result.reachable is not False:
            subscription_attempted = True
            sub_windows = await _fetch_anthropic_subscription(effective_session_key)
            if sub_windows:
                result.subscription_windows = sub_windows
                result.reachable = True
            else:
                result.error = result.error or "No subscription data returned."

    elif provider in ("openai", "codex"):
        # Explicit session_key wins (allows a second account to override
        # auto-detection); otherwise fall back to the auto-detected Codex CLI
        # token. This is resolved BEFORE the missing-credential check, so a
        # machine with only a local Codex login (no stored key) still gets
        # subscription windows instead of "No API key configured.".
        # The short-circuit skips the local read when an explicit key exists.
        effective_codex_key = session_key or _read_codex_local_token()

        if not api_key and not effective_codex_key:
            result = ProviderUsageLive(
                provider=provider, account_key=account_key,
                checked_at=utcnow(), reachable=False,
                error="No API key configured.",
            )
        elif api_key:
            result = await _fetch_openai(api_key, base_url)
        else:
            # Subscription-only credential: mark reachable so the overlay runs.
            result = ProviderUsageLive(
                provider=provider, account_key=account_key,
                checked_at=utcnow(), reachable=True,
            )

        if effective_codex_key and result.reachable is not False:
            subscription_attempted = True
            sub_windows, plan_label = await _fetch_codex_subscription(effective_codex_key)
            if sub_windows:
                result.subscription_windows = sub_windows
                result.subscription_plan = plan_label
                result.reachable = True
            else:
                result.error = result.error or "No subscription data returned."

    elif provider == "ollama":
        result = await _fetch_ollama(base_url, api_key)

    else:
        result = ProviderUsageLive(
            provider=provider, account_key=account_key,
            checked_at=utcnow(), reachable=False,
            error=f"Live usage not supported for provider '{provider}'.",
        )

    result.account_key = account_key
    # Use a short TTL when subscription windows were expected but came back empty
    # (e.g. a transient 429 at startup). This avoids persisting a 60s stale result
    # while still preventing the thundering-herd that occurs with no caching at all.
    ttl = CACHE_TTL_FAILURE_SECONDS if (subscription_attempted and not result.subscription_windows) else CACHE_TTL_SECONDS
    _set_cached(credential_id, result, ttl=ttl)
    logger.info(
        "provider_usage.checked provider=%s account=%s reachable=%s windows=%d error=%s",
        provider, account_key, result.reachable, len(result.subscription_windows), result.error,
    )
    return result
|
||
|
||
|
||
async def fetch_provider_usage(
    credential_id: str,
    provider: str,
    account_key: str,
    api_key: str | None,
    base_url: str | None,
    *,
    session_key: str | None = None,
    force_refresh: bool = False,
) -> ProviderUsageLive:
    """Fetch live usage from the provider API.

    When a subscription token is available — ``session_key``, or for some
    providers a locally auto-detected CLI token — this also fetches
    subscription-plan usage windows (e.g. "Current session 77% used, resets in
    23 min"). These come in addition to the standard API rate-limit diagnostics.

    Results are cached for CACHE_TTL_SECONDS (short CACHE_TTL_FAILURE_SECONDS
    when subscription windows are unavailable). Concurrent requests for the same
    credential share one in-flight fetch to avoid rate-limit cascades.

    Pass force_refresh=True to bypass the cache. An already in-flight fetch for
    the same credential is still joined rather than duplicated.

    Raises:
        Whatever the underlying fetch raises. Concurrent callers sharing the
        fetch receive the same exception.
    """
    if not force_refresh:
        cached = _get_cached(credential_id)
        if cached is not None:
            return cached

    # In-flight deduplication: if another coroutine is already fetching this
    # credential, await its result rather than racing to hit the provider API.
    while True:
        existing = _inflight.get(credential_id)
        if existing is None:
            break
        try:
            # shield(): our own cancellation must not cancel the shared fetch.
            return await asyncio.shield(existing)
        except asyncio.CancelledError:
            if not existing.cancelled():
                raise  # We were cancelled ourselves; propagate.
            # The leading task was cancelled. Drop the dead future defensively
            # (avoids spinning on it), then loop to join a newer fetch or lead.
            if _inflight.get(credential_id) is existing:
                del _inflight[credential_id]

    fut: asyncio.Future[ProviderUsageLive] = asyncio.get_running_loop().create_future()
    _inflight[credential_id] = fut
    try:
        result = await _do_fetch_provider_usage(
            credential_id, provider, account_key, api_key, base_url, session_key,
        )
    except asyncio.CancelledError:
        # Resolve the shared future so waiters do not hang forever; they see it
        # cancelled and retry on their own.
        fut.cancel()
        raise
    except Exception as exc:
        fut.set_exception(exc)
        # Mark the exception as retrieved. Otherwise asyncio logs "Future
        # exception was never retrieved" when no concurrent waiter awaited it.
        fut.exception()
        raise
    else:
        fut.set_result(result)
        return result
    finally:
        if _inflight.get(credential_id) is fut:
            del _inflight[credential_id]
|