627 lines
24 KiB
Python
627 lines
24 KiB
Python
"""Runtime usage service — compute model spend, burn rate, and time remaining.
|
|
|
|
Data source: gateway RPC methods ``sessions.list`` (preferred source of per-session data), ``usage.cost`` (fallback sessions) and ``usage.status``.
|
|
Pricing: built-in defaults plus an optional JSON override file pointed to by
|
|
the ``RUNTIME_USAGE_PRICING_FILE`` environment variable.
|
|
|
|
Design decisions:
|
|
- Total tokens (input + output) are the basis for time-remaining predictions.
|
|
- All org members may call this endpoint (not admin-only).
|
|
- Parsing is fully defensive: malformed or missing fields default to zero.
|
|
- Unknown paid models are flagged ``unpriced=True`` so the UI can warn.
|
|
- Ollama / local models are flagged ``unpriced=False, cost_usd=0`` (free by design).
|
|
"""
|
|
|
|
from __future__ import annotations

import asyncio
import json
import math
import os
import re
from datetime import datetime, timedelta, timezone
from typing import Any
from uuid import UUID

from app.core.logging import get_logger
from app.core.time import utcnow
from app.schemas.runtime_usage import (
    ModelUsageEntry,
    RuntimeUsageBurnRate,
    RuntimeUsageCurrent,
    RuntimeUsagePredictions,
    RuntimeUsageResponse,
    RuntimeUsageWindow,
    TopSession,
)
from app.services.openclaw.gateway_rpc import GatewayConfig as GatewayClientConfig
from app.services.openclaw.gateway_rpc import OpenClawGatewayError, openclaw_call
|
|
|
|
# Module logger. Events in this file are logged with a "runtime_usage.*"
# prefix and %-style arguments, e.g. "runtime_usage.rpc.error method=%s".
logger = get_logger(__name__)
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Pricing config (USD per million tokens)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Rates for one model, in USD per million tokens. estimate_cost reads the keys
# "input", "output", "cache_read" and "cache_write"; a missing key is priced at 0.
_PricingEntry = dict[str, float]

# Built-in price table keyed by the canonical "provider/model" string that
# model_key() builds from normalize_provider()/normalize_model() output.
#
# Anthropic rows follow the prompt-caching ratios the Sonnet and Haiku 3.5/4.5
# rows already use: cache_write = 1.25x input, cache_read = 0.1x input. The
# $15-input Opus rows therefore carry cache_write = 18.75.
#
# NOTE(review): list prices change over time; re-check this table against the
# providers' pricing pages. In particular the Opus 4.5/4.7 and Haiku 4.5 rows
# reuse earlier-generation rates and may be out of date. Any row can be
# corrected without a code change via RUNTIME_USAGE_PRICING_FILE.
DEFAULT_MODEL_PRICING: dict[str, _PricingEntry] = {
    # Anthropic — Claude 4.x
    "anthropic/claude-opus-4-7": {"input": 15.00, "output": 75.00, "cache_read": 1.50, "cache_write": 18.75},
    "anthropic/claude-opus-4-5": {"input": 15.00, "output": 75.00, "cache_read": 1.50, "cache_write": 18.75},
    "anthropic/claude-sonnet-4-6": {"input": 3.00, "output": 15.00, "cache_read": 0.30, "cache_write": 3.75},
    "anthropic/claude-sonnet-4-5": {"input": 3.00, "output": 15.00, "cache_read": 0.30, "cache_write": 3.75},
    "anthropic/claude-haiku-4-5": {"input": 0.80, "output": 4.00, "cache_read": 0.08, "cache_write": 1.00},
    # Anthropic — Claude 3.x
    "anthropic/claude-3-5-sonnet": {"input": 3.00, "output": 15.00, "cache_read": 0.30, "cache_write": 3.75},
    "anthropic/claude-3-5-haiku": {"input": 0.80, "output": 4.00, "cache_read": 0.08, "cache_write": 1.00},
    "anthropic/claude-3-opus": {"input": 15.00, "output": 75.00, "cache_read": 1.50, "cache_write": 18.75},
    "anthropic/claude-3-sonnet": {"input": 3.00, "output": 15.00, "cache_read": 0.30, "cache_write": 3.75},
    "anthropic/claude-3-haiku": {"input": 0.25, "output": 1.25, "cache_read": 0.03, "cache_write": 0.30},
    # OpenAI — GPT-4o family
    "openai/gpt-4o": {"input": 2.50, "output": 10.00, "cache_read": 1.25, "cache_write": 0.00},
    "openai/gpt-4o-mini": {"input": 0.15, "output": 0.60, "cache_read": 0.075, "cache_write": 0.00},
    "openai/gpt-4-turbo": {"input": 10.00, "output": 30.00, "cache_read": 0.00, "cache_write": 0.00},
    "openai/gpt-4": {"input": 30.00, "output": 60.00, "cache_read": 0.00, "cache_write": 0.00},
    "openai/gpt-3-5-turbo": {"input": 0.50, "output": 1.50, "cache_read": 0.00, "cache_write": 0.00},
    # normalize_model keeps dots, so the provider's own id "gpt-3.5-turbo" only
    # matches this spelling; the hyphenated key above stays for existing users.
    "openai/gpt-3.5-turbo": {"input": 0.50, "output": 1.50, "cache_read": 0.00, "cache_write": 0.00},
    # OpenAI — o-series reasoning
    "openai/o1": {"input": 15.00, "output": 60.00, "cache_read": 7.50, "cache_write": 0.00},
    "openai/o1-mini": {"input": 3.00, "output": 12.00, "cache_read": 1.50, "cache_write": 0.00},
    "openai/o3": {"input": 10.00, "output": 40.00, "cache_read": 2.50, "cache_write": 0.00},
    "openai/o3-mini": {"input": 1.10, "output": 4.40, "cache_read": 0.55, "cache_write": 0.00},
    "openai/o4-mini": {"input": 1.10, "output": 4.40, "cache_read": 0.275, "cache_write": 0.00},
    # Codex alias. All-zero rates mean estimate_cost reports it as priced
    # (unpriced=False) at $0 — presumably subscription-billed; confirm.
    "openai/codex": {"input": 0.00, "output": 0.00, "cache_read": 0.00, "cache_write": 0.00},
}

# Process-wide memo for load_pricing(); None until the first call.
_pricing_cache: dict[str, _PricingEntry] | None = None
|
|
|
|
|
|
def load_pricing() -> dict[str, _PricingEntry]:
|
|
"""Return merged pricing: defaults + optional override file."""
|
|
global _pricing_cache
|
|
if _pricing_cache is not None:
|
|
return _pricing_cache
|
|
merged = dict(DEFAULT_MODEL_PRICING)
|
|
override_path = os.getenv("RUNTIME_USAGE_PRICING_FILE", "").strip()
|
|
if override_path:
|
|
try:
|
|
with open(override_path) as fh:
|
|
overrides = json.load(fh)
|
|
if isinstance(overrides, dict):
|
|
merged.update(overrides)
|
|
logger.info("runtime_usage.pricing.override_loaded path=%s", override_path)
|
|
except Exception as exc:
|
|
logger.warning("runtime_usage.pricing.override_failed path=%s error=%s", override_path, exc)
|
|
_pricing_cache = merged
|
|
return _pricing_cache
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Provider / model normalisation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
_PROVIDER_ALIASES: dict[str, str] = {
|
|
"anthropic": "anthropic",
|
|
"claude": "anthropic",
|
|
"openai": "openai",
|
|
"codex": "openai",
|
|
"ollama": "ollama",
|
|
"local": "ollama",
|
|
"gemini": "google",
|
|
"google": "google",
|
|
}
|
|
|
|
_MODEL_STRIP_RE = re.compile(
|
|
r"(-\d{8}|-latest|-preview|-instruct|-chat|-v\d+(\.\d+)*)$",
|
|
re.IGNORECASE,
|
|
)
|
|
|
|
|
|
def normalize_provider(raw: str) -> str:
    """Map a raw provider label onto its canonical lower-case slug.

    Whitespace is trimmed and the label lower-cased. Known aliases (such as
    ``"claude"`` or ``"codex"``) are translated through ``_PROVIDER_ALIASES``.
    Any other non-empty label is returned as-is, and a blank label becomes
    ``"unknown"``.
    """
    slug = raw.strip().lower()
    if slug in _PROVIDER_ALIASES:
        return _PROVIDER_ALIASES[slug]
    return slug if slug else "unknown"
|
|
|
|
|
|
def normalize_model(raw: str) -> str:
    """Reduce a model identifier to the bare, lower-case name used in pricing keys.

    Any leading ``provider/`` segment is dropped (only the first slash acts as
    the separator), then one trailing date/version/tag suffix matched by
    ``_MODEL_STRIP_RE`` is removed. If stripping would leave nothing, the
    trimmed, lower-cased input is returned instead.
    """
    lowered = raw.strip().lower()
    _, separator, remainder = lowered.partition("/")
    bare_name = remainder if separator else lowered
    stripped = _MODEL_STRIP_RE.sub("", bare_name)
    return stripped if stripped else lowered
|
|
|
|
|
|
def model_key(provider: str, model: str) -> str:
    """Join a provider slug and model name into the ``"provider/model"`` lookup key."""
    return "{}/{}".format(provider, model)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Cost estimation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def estimate_cost(
    provider: str,
    model: str,
    input_tokens: int,
    output_tokens: int,
    cache_read_tokens: int = 0,
    cache_write_tokens: int = 0,
) -> tuple[float, bool]:
    """Return ``(cost_usd, unpriced)`` for the given token counts.

    Rates come from ``load_pricing()`` under ``model_key(provider, model)`` and
    are USD per million tokens. The cost is rounded to 8 decimal places.

    ``unpriced=True`` means there is no usable price row for this paid model,
    and the caller should surface a warning. Ollama/local models without a row
    return ``(0.0, False)``. A row that is not a dict (e.g. a malformed
    override value) is treated like a missing row rather than raising. A rate
    that cannot be converted to float counts as 0.
    """
    entry = load_pricing().get(model_key(provider, model))
    if not isinstance(entry, dict):
        # Local models are free by design; anything else is an unknown paid model.
        return 0.0, provider != "ollama"

    per_million = 1_000_000
    token_counts = (
        ("input", input_tokens),
        ("output", output_tokens),
        ("cache_read", cache_read_tokens),
        ("cache_write", cache_write_tokens),
    )
    cost = 0.0
    # Summed in the same order as the four-term expression it replaces.
    for field, count in token_counts:
        try:
            rate = float(entry.get(field, 0.0))
        except (TypeError, ValueError):
            rate = 0.0
        cost += count * rate / per_million
    return round(cost, 8), False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Gateway RPC helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
async def _safe_call(
    method: str,
    config: GatewayClientConfig,
) -> dict[str, Any]:
    """Invoke gateway RPC *method* and return its dict payload, or ``{}``.

    Gateway errors are logged at debug level. Any other exception is logged as
    a warning. A non-dict payload is logged at debug level. All three cases
    yield an empty dict, so callers never need to handle failures.
    """
    try:
        payload = await openclaw_call(method, config=config)
    except OpenClawGatewayError as exc:
        logger.debug("runtime_usage.rpc.gateway_error method=%s error=%s", method, exc)
        return {}
    except Exception as exc:
        logger.warning("runtime_usage.rpc.error method=%s error=%s", method, exc)
        return {}

    if not isinstance(payload, dict):
        logger.debug("runtime_usage.rpc.unexpected_type method=%s type=%s", method, type(payload).__name__)
        return {}
    return payload
|
|
|
|
|
|
def _get_float(d: dict[str, Any], *keys: str, default: float = 0.0) -> float:
|
|
for key in keys:
|
|
val = d.get(key)
|
|
if val is not None:
|
|
try:
|
|
return float(val)
|
|
except (TypeError, ValueError):
|
|
pass
|
|
return default
|
|
|
|
|
|
def _get_int(d: dict[str, Any], *keys: str, default: int = 0) -> int:
    """Integer counterpart of ``_get_float``.

    The value found by ``_get_float`` (or *default*) is converted with
    ``int()``, so any fractional part is truncated toward zero.
    """
    as_float = _get_float(d, *keys, default=float(default))
    return int(as_float)
|
|
|
|
|
|
def _get_str(d: dict[str, Any], *keys: str, default: str = "") -> str:
|
|
for key in keys:
|
|
val = d.get(key)
|
|
if isinstance(val, str) and val.strip():
|
|
return val.strip()
|
|
return default
|
|
|
|
|
|
def _parse_datetime(value: object) -> datetime | None:
|
|
if not isinstance(value, str) or not value.strip():
|
|
return None
|
|
normalized = value.strip().replace("Z", "+00:00")
|
|
try:
|
|
parsed = datetime.fromisoformat(normalized)
|
|
if parsed.tzinfo is not None:
|
|
return parsed.astimezone(timezone.utc).replace(tzinfo=None)
|
|
return parsed
|
|
except ValueError:
|
|
return None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Session parsing
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _parse_sessions(cost_raw: dict[str, Any]) -> list[dict[str, Any]]:
|
|
"""Extract a flat list of session dicts from the usage.cost response.
|
|
|
|
The gateway may return sessions at the top level or nested under a window
|
|
key (``"5hour"``, ``"today"``, ``"week"``, etc.). We prefer the most
|
|
granular available list.
|
|
"""
|
|
# Try flat list first
|
|
flat = cost_raw.get("sessions")
|
|
if isinstance(flat, list):
|
|
return [s for s in flat if isinstance(s, dict)]
|
|
|
|
# Try nested under window keys, in preference order
|
|
for window_key in ("5hour", "5h", "today", "week", "7day", "data"):
|
|
bucket = cost_raw.get(window_key)
|
|
if isinstance(bucket, dict):
|
|
nested = bucket.get("sessions")
|
|
if isinstance(nested, list):
|
|
return [s for s in nested if isinstance(s, dict)]
|
|
if isinstance(bucket, list):
|
|
return [s for s in bucket if isinstance(s, dict)]
|
|
|
|
return []
|
|
|
|
|
|
def _parse_session_usage(session: dict[str, Any]) -> dict[str, int]:
    """Return ``input``/``output``/``cache_read``/``cache_write`` token counts.

    Counts are read first from the nested ``usage`` (or ``tokens``) dict using
    several naming conventions. When that yields 0, the same quantity is
    looked up on the session dict itself. A nested value that is not a dict
    is ignored.
    """
    usage = session.get("usage") or session.get("tokens") or {}
    if not isinstance(usage, dict):
        usage = {}

    # (result key, keys inside the usage dict, fallback keys on the session)
    lookups = (
        ("input", ("input_tokens", "inputTokens", "input"), ("input_tokens", "inputTokens")),
        ("output", ("output_tokens", "outputTokens", "output"), ("output_tokens", "outputTokens")),
        (
            "cache_read",
            ("cache_read_input_tokens", "cacheReadTokens", "cache_read"),
            ("cache_read_tokens", "cacheReadTokens"),
        ),
        (
            "cache_write",
            ("cache_creation_input_tokens", "cacheWriteTokens", "cache_write"),
            ("cache_write_tokens", "cacheWriteTokens"),
        ),
    )
    return {
        name: _get_int(usage, *usage_keys, default=0) or _get_int(session, *session_keys, default=0)
        for name, usage_keys, session_keys in lookups
    }
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Aggregation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def aggregate_per_model(
    sessions: list[dict[str, Any]],
    account_key: str = "default",
) -> dict[str, ModelUsageEntry]:
    """Roll up token counts, calls and cost across sessions, keyed by provider/model.

    Args:
        sessions: Session dicts from the gateway; every field is read defensively.
        account_key: Copied onto each entry so separate accounts stay distinct.

    Returns:
        ``{"provider/model": ModelUsageEntry}``. ``total_tokens`` is input +
        output only; cache tokens are summed separately but not included.

    Provider resolution follows this order:
    1. An explicit ``provider`` field.
    2. Otherwise, a ``"provider/model"`` prefix on the model name.
    3. Otherwise, ``"anthropic"``.

    This way ``{"model": "openai/gpt-4o"}`` is filed and priced under
    ``openai/gpt-4o``. When a session reports no cost (0), the cost is
    estimated from the price table.
    """
    entries: dict[str, dict[str, Any]] = {}

    for session in sessions:
        raw_model = _get_str(session, "model", "modelName", default="unknown")
        raw_provider = _get_str(session, "provider")
        if not raw_provider:
            prefix, sep, _ = raw_model.partition("/")
            raw_provider = prefix if sep and prefix.strip() else "anthropic"
        provider = normalize_provider(raw_provider)
        model = normalize_model(raw_model)
        key = model_key(provider, model)

        tokens = _parse_session_usage(session)
        calls = _get_int(session, "calls", "messageCount", "messages", default=1)
        session_cost = _get_float(session, "cost", "cost_usd", "costUsd", default=0.0)
        if session_cost == 0.0:
            # Gateway did not compute a cost for this session: estimate it.
            session_cost, _ = estimate_cost(
                provider, model,
                tokens["input"], tokens["output"],
                tokens["cache_read"], tokens["cache_write"],
            )

        entry = entries.get(key)
        if entry is None:
            # Price-table presence depends only on the key, so check it once per key.
            _, unpriced = estimate_cost(provider, model, 0, 0)
            entry = entries[key] = {
                "provider": provider,
                "account_key": account_key,
                "model": model,
                "input_tokens": 0,
                "output_tokens": 0,
                "cache_read_tokens": 0,
                "cache_write_tokens": 0,
                "cost_usd": 0.0,
                "calls": 0,
                "unpriced": unpriced,
                "source": "local_jsonl_estimate",  # default source for aggregated session data
            }

        entry["input_tokens"] += tokens["input"]
        entry["output_tokens"] += tokens["output"]
        entry["cache_read_tokens"] += tokens["cache_read"]
        entry["cache_write_tokens"] += tokens["cache_write"]
        entry["cost_usd"] += session_cost
        entry["calls"] += calls

    return {
        key: ModelUsageEntry(**{**e, "total_tokens": e["input_tokens"] + e["output_tokens"]})
        for key, e in entries.items()
    }
|
|
|
|
|
|
def _top_sessions(
    sessions: list[dict[str, Any]],
    limit: int = 10,
) -> list[TopSession]:
    """Return the *limit* most expensive sessions, highest cost first.

    Each row carries:
    - the session id (``""`` if absent);
    - an optional label;
    - the canonical ``"provider/model"`` key (``None`` if no model is known);
    - the cost;
    - input+output tokens;
    - the last-activity timestamp string.

    Cost is the gateway-reported value. When that is 0 and a model is known,
    a price-table estimate is used instead. The provider comes from the
    ``provider`` field, else from a ``"provider/model"`` prefix on the model
    name, else ``"anthropic"``. A non-positive *limit* returns an empty list.
    """
    rows: list[TopSession] = []
    for session in sessions:
        tokens = _parse_session_usage(session)
        cost = _get_float(session, "cost", "cost_usd", "costUsd", default=0.0)

        model_label: str | None = None
        raw_model = _get_str(session, "model", "modelName")
        if raw_model:
            raw_provider = _get_str(session, "provider")
            if not raw_provider:
                prefix, sep, _ = raw_model.partition("/")
                raw_provider = prefix if sep and prefix.strip() else "anthropic"
            provider = normalize_provider(raw_provider)
            model = normalize_model(raw_model)
            model_label = model_key(provider, model)
            if cost == 0.0:
                cost, _ = estimate_cost(
                    provider, model,
                    tokens["input"], tokens["output"],
                    tokens["cache_read"], tokens["cache_write"],
                )

        rows.append(TopSession(
            session_id=_get_str(session, "sessionId", "id", "session_id", default=""),
            label=_get_str(session, "label", "name", "title") or None,
            model=model_label,
            cost_usd=round(cost, 8),
            total_tokens=tokens["input"] + tokens["output"],
            updated_at=_get_str(session, "updated_at", "updatedAt", "lastActivity", "last_activity") or None,
            source="local_jsonl_estimate",  # default source for session data
        ))

    rows.sort(key=lambda row: row.cost_usd, reverse=True)
    # Guard against negative limits, which would otherwise slice from the end.
    return rows[: max(limit, 0)]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Window and limit helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Length of the usage window in hours; also used for the window key ("5h").
_WINDOW_HOURS = 5


def _build_window(
    status_raw: dict[str, Any],
    now: datetime,
    account_key: str = "default",
) -> RuntimeUsageWindow:
    """Build the usage window, preferring explicit gateway data.

    Bounds are read from ``windowStart``/``window_start``/``period_start``/
    ``started_at`` and ``windowEnd``/``window_end``/``period_end``/``resets_at``.
    A bound that is missing or unparseable is derived from the other one,
    ``_WINDOW_HOURS`` apart. When neither parses, the window is the trailing
    ``_WINDOW_HOURS`` ending at *now*.

    Source / confidence:
    - both bounds parsed from status -> ``provider_native`` / ``high``.
    - otherwise, a truthy API rate-limit field -> ``provider_api_rate_limit`` / ``medium``.
    - otherwise -> ``local_jsonl_estimate`` / ``low``.

    Only *parsed* bounds count as provider data, so a garbled timestamp is not
    reported as a high-confidence window built from fallback values.

    NOTE(review): in the both-missing fallback, ``resets_at`` equals *now*, so
    ``reset_in_ms`` is 0. Confirm this is the intended rolling-window semantics.
    This also assumes ``now`` is naive UTC, like ``_parse_datetime`` output;
    confirm against ``utcnow()``.

    ``account_key`` is accepted for signature parity and is not used.
    """
    started_at = _parse_datetime(
        status_raw.get("windowStart") or status_raw.get("window_start")
        or status_raw.get("period_start") or status_raw.get("started_at")
    )
    resets_at = _parse_datetime(
        status_raw.get("windowEnd") or status_raw.get("window_end")
        or status_raw.get("period_end") or status_raw.get("resets_at")
    )

    # API rate-limit fields indicate throttling, not subscription usage.
    has_rate_limit_headers = any(
        status_raw.get(name)
        for name in (
            "x_ratelimit_remaining",
            "x_ratelimit_limit",
            "x_ratelimit_reset",
            "anthropic_ratelimit_remaining",
            "anthropic_ratelimit_limit",
        )
    )

    if started_at is not None and resets_at is not None:
        source, confidence = "provider_native", "high"
    elif has_rate_limit_headers:
        source, confidence = "provider_api_rate_limit", "medium"
    else:
        source, confidence = "local_jsonl_estimate", "low"

    span = timedelta(hours=_WINDOW_HOURS)
    if started_at is None:
        # A known reset pins the window; otherwise use the trailing span.
        started_at = resets_at - span if resets_at is not None else now - span
    if resets_at is None:
        resets_at = started_at + span

    reset_in_ms = max(0, int((resets_at - now).total_seconds() * 1000))
    return RuntimeUsageWindow(
        key=f"{_WINDOW_HOURS}h",
        started_at=started_at,
        resets_at=resets_at,
        reset_in_ms=reset_in_ms,
        source=source,
        confidence=confidence,
    )
|
|
|
|
|
|
def _build_current(
    per_model: dict[str, ModelUsageEntry],
    status_raw: dict[str, Any],
    account_key: str = "default",
) -> RuntimeUsageCurrent:
    """Summarise totals across *per_model* and measure them against gateway limits.

    Totals are summed over all entries: cost (rounded to 8 dp), input+output
    tokens, and calls.

    Token limit:
    - Taken from the first parseable value among ``tokenLimit``/``token_limit``/
      ``messageLimit``/``message_limit``.
    - Its source is ``provider_api_rate_limit`` when any
      x_ratelimit/anthropic_ratelimit remaining/limit field is truthy, else
      ``configured_limit``.
    - ``token_pct`` is the floor of tokens*100/limit, capped at 100.

    Cost limit:
    - Taken from ``costLimit``/``cost_limit``/``costLimitUsd``.
    - Its source is ``configured_limit``.
    - ``cost_pct`` is capped at 100.

    A zero or missing limit leaves its limit, pct and source as ``None``.

    NOTE(review): ``messageLimit``/``message_limit`` are compared against
    *token* totals here. Confirm the gateway really reports tokens under those
    names. ``account_key`` is accepted but not used.
    """
    entries = list(per_model.values())
    total_cost = round(sum(entry.cost_usd for entry in entries), 8)
    total_tokens = sum(entry.total_tokens for entry in entries)
    total_calls = sum(entry.calls for entry in entries)

    token_limit_value = _get_int(
        status_raw, "tokenLimit", "token_limit", "messageLimit", "message_limit", default=0
    )
    token_limit_source: str | None = None
    token_pct: int | None = None
    if token_limit_value:
        rate_limited = any(
            status_raw.get(name)
            for name in (
                "x_ratelimit_remaining",
                "x_ratelimit_limit",
                "anthropic_ratelimit_remaining",
                "anthropic_ratelimit_limit",
            )
        )
        token_limit_source = "provider_api_rate_limit" if rate_limited else "configured_limit"
        token_pct = int(min(100, total_tokens * 100 // token_limit_value))

    cost_limit_value = _get_float(status_raw, "costLimit", "cost_limit", "costLimitUsd", default=0.0)
    cost_limit_source: str | None = None
    cost_pct: int | None = None
    if cost_limit_value:
        cost_limit_source = "configured_limit"
        cost_pct = int(min(100, total_cost * 100 / cost_limit_value))

    return RuntimeUsageCurrent(
        total_cost_usd=total_cost,
        total_tokens=total_tokens,
        total_calls=total_calls,
        token_limit=token_limit_value or None,
        token_pct=token_pct,
        cost_limit_usd=cost_limit_value or None,
        cost_pct=cost_pct,
        token_limit_source=token_limit_source,
        cost_limit_source=cost_limit_source,
    )
|
|
|
|
|
|
def _compute_burn_rate(
    sessions: list[dict[str, Any]],
    window: RuntimeUsageWindow,
    now: datetime,
) -> RuntimeUsageBurnRate:
    """Approximate tokens/min and cost/min over the 60 minutes before *now*.

    A session counts when its last-activity timestamp (``updated_at``/
    ``updatedAt``/``lastActivity``/``last_activity``) parses and is not older
    than the cutoff. Its *entire* input+output token total and cost are then
    attributed to the last hour, because per-event timestamps are not
    available here. Long-running sessions therefore inflate the rate.

    When a session reports no cost (0), the cost is estimated from the price
    table, as ``aggregate_per_model`` does. This keeps the cost rate consistent
    with the reported totals instead of silently dropping to 0.

    ``window`` is accepted for interface compatibility and is not used.
    """
    cutoff = now - timedelta(minutes=60)
    recent_tokens = 0
    recent_cost = 0.0

    for session in sessions:
        raw_ts = _get_str(session, "updated_at", "updatedAt", "lastActivity", "last_activity")
        ts = _parse_datetime(raw_ts)
        if ts is None or ts < cutoff:
            continue
        tokens = _parse_session_usage(session)
        recent_tokens += tokens["input"] + tokens["output"]

        cost = _get_float(session, "cost", "cost_usd", "costUsd", default=0.0)
        if cost == 0.0:
            raw_model = _get_str(session, "model", "modelName", default="unknown")
            raw_provider = _get_str(session, "provider")
            if not raw_provider:
                prefix, sep, _ = raw_model.partition("/")
                raw_provider = prefix if sep and prefix.strip() else "anthropic"
            cost, _ = estimate_cost(
                normalize_provider(raw_provider), normalize_model(raw_model),
                tokens["input"], tokens["output"],
                tokens["cache_read"], tokens["cache_write"],
            )
        recent_cost += cost

    # Rate per minute over the last 60 minutes
    return RuntimeUsageBurnRate(
        tokens_per_minute=round(recent_tokens / 60, 4),
        cost_usd_per_minute=round(recent_cost / 60, 8),
    )
|
|
|
|
|
|
def _build_predictions(
    current: RuntimeUsageCurrent,
    burn_rate: RuntimeUsageBurnRate,
    window: RuntimeUsageWindow,
) -> RuntimeUsagePredictions:
    """Estimate time-to-limit in ms from the total-token burn rate.

    Returns, in order of precedence:
    - ``(None, safe=True)`` when no token limit is known.
    - ``(0, safe=False)`` when usage is already at or over the limit. This is
      checked before the burn rate, so an exhausted quota is reported even if
      nothing was consumed in the last hour.
    - ``(None, safe=True)`` when there is headroom but no recent burn.
    - otherwise, the ms until the limit at the current rate. ``safe`` is True
      only if that moment falls after the window reset.
    """
    if current.token_limit is None:
        return RuntimeUsagePredictions(time_to_limit_ms=None, safe=True)

    tokens_remaining = current.token_limit - current.total_tokens
    if tokens_remaining <= 0:
        return RuntimeUsagePredictions(time_to_limit_ms=0, safe=False)

    if burn_rate.tokens_per_minute <= 0:
        return RuntimeUsagePredictions(time_to_limit_ms=None, safe=True)

    minutes_to_limit = tokens_remaining / burn_rate.tokens_per_minute
    time_to_limit_ms = int(minutes_to_limit * 60 * 1000)
    safe = time_to_limit_ms > window.reset_in_ms
    return RuntimeUsagePredictions(time_to_limit_ms=time_to_limit_ms, safe=safe)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Public service entry point
|
|
# ---------------------------------------------------------------------------
|
|
|
|
async def _fetch_session_list(config: GatewayClientConfig) -> list[dict[str, Any]]:
    """Call ``sessions.list`` and return its session dicts (``[]`` on any failure).

    Unlike ``_safe_call``, this accepts either a bare JSON list or an object
    with a ``"sessions"`` list, since the gateway may return either shape.
    Non-dict items are dropped.
    """
    method = "sessions.list"
    try:
        payload = await openclaw_call(method, config=config)
    except OpenClawGatewayError as exc:
        logger.debug("runtime_usage.rpc.gateway_error method=%s error=%s", method, exc)
        return []
    except Exception as exc:
        logger.warning("runtime_usage.rpc.error method=%s error=%s", method, exc)
        return []

    if isinstance(payload, dict):
        items = payload.get("sessions")
    elif isinstance(payload, list):
        items = payload
    else:
        logger.debug("runtime_usage.rpc.unexpected_type method=%s type=%s", method, type(payload).__name__)
        return []
    if not isinstance(items, list):
        return []
    return [item for item in items if isinstance(item, dict)]


async def get_runtime_usage(
    gateway_id: UUID,
    config: GatewayClientConfig,
    account_key: str = "default",
) -> RuntimeUsageResponse:
    """Fetch and aggregate runtime usage for one gateway.

    Session data comes from ``sessions.list`` when that yields at least one
    session dict. Otherwise it comes from sessions embedded in the
    ``usage.cost`` response. Window and limit fields are read from the
    ``usage.cost`` and ``usage.status`` payloads merged together, with
    ``usage.status`` winning on key clashes.

    Args:
        gateway_id: Pipeline gateway DB id (echoed in response).
        config: Credentials and URL for the gateway RPC connection.
        account_key: Stable identifier for this gateway's account
            (e.g. ``"claude-default"``, ``"openai-work"``). Used in
            per-model breakdowns to keep separate accounts distinct.

    Returns:
        A fully populated ``RuntimeUsageResponse``. All fields default to
        safe zeroes if the gateway is unreachable or returns unexpected data.
    """
    now = utcnow()

    cost_raw, status_raw, listed_sessions = await asyncio.gather(
        _safe_call("usage.cost", config),
        _safe_call("usage.status", config),
        _fetch_session_list(config),
    )

    # Prefer sessions.list; fall back to usage.cost when it yields no session dicts.
    sessions: list[dict[str, Any]] = listed_sessions or _parse_sessions(cost_raw)

    # Merge both payloads — some gateways return everything in one response.
    merged_status = {**cost_raw, **status_raw}

    per_model = aggregate_per_model(sessions, account_key=account_key)
    window = _build_window(merged_status, now, account_key=account_key)
    current = _build_current(per_model, merged_status, account_key=account_key)
    burn_rate = _compute_burn_rate(sessions, window, now)
    predictions = _build_predictions(current, burn_rate, window)
    top = _top_sessions(sessions)

    logger.info(
        "runtime_usage.computed gateway_id=%s sessions=%d models=%d total_cost=%.6f",
        gateway_id,
        len(sessions),
        len(per_model),
        current.total_cost_usd,
    )
    return RuntimeUsageResponse(
        generated_at=now,
        gateway_id=gateway_id,
        window=window,
        current=current,
        burn_rate=burn_rate,
        predictions=predictions,
        per_model=per_model,
        top_sessions=top,
    )
|