Pipeline/backend/app/services/openclaw/runtime_usage.py

871 lines
35 KiB
Python
Raw Normal View History

"""Runtime usage service — compute model spend, burn rate, and time remaining.
Data source: gateway RPC methods ``usage.cost`` and ``usage.status``.
Pricing: built-in defaults plus an optional JSON override file pointed to by
the ``RUNTIME_USAGE_PRICING_FILE`` environment variable.
Design decisions:
- Total tokens (input + output) are the basis for time-remaining predictions.
- All org members may call this endpoint (not admin-only).
- Parsing is fully defensive: malformed or missing fields default to zero.
- Unknown paid models are flagged ``unpriced=True`` so the UI can warn.
- Ollama / local models are flagged ``unpriced=False, cost_usd=0`` (free by design).
"""
from __future__ import annotations

import asyncio
import json
import math
import os
import re
from datetime import datetime, timedelta, timezone
from typing import Any
from uuid import UUID

from app.core.logging import get_logger
from app.core.time import utcnow
from app.schemas.runtime_usage import (
    ModelUsageEntry,
    RuntimeUsageBurnRate,
    RuntimeUsageCurrent,
    RuntimeUsagePredictions,
    RuntimeUsageResponse,
    RuntimeUsageWindow,
    TopSession,
)
from app.services.openclaw.gateway_rpc import GatewayConfig as GatewayClientConfig
from app.services.openclaw.gateway_rpc import OpenClawGatewayError, openclaw_call
# Module-level logger. Events below use dotted names (e.g.
# "runtime_usage.rpc.error") with lazy %-style arguments.
logger = get_logger(__name__)
# ---------------------------------------------------------------------------
# Pricing config (USD per million tokens)
# ---------------------------------------------------------------------------
_PricingEntry = dict[str, float]  # keys: input, output, cache_read, cache_write
DEFAULT_MODEL_PRICING: dict[str, _PricingEntry] = {
    # Anthropic — Claude 4.x
    # cache_write is the 5-minute cache-write rate (~1.25x the input price,
    # e.g. $15 input -> $18.75 write); cache_read is ~0.1x the input price.
    # Opus 4.5 launched at $5/$25 per MTok and Haiku 4.5 at $1/$5 per MTok.
    # NOTE(review): opus-4-7 and sonnet-4-6 rates are not verified here —
    # confirm against the current Anthropic pricing page.
    "anthropic/claude-opus-4-7": {"input": 15.00, "output": 75.00, "cache_read": 1.50, "cache_write": 18.75},
    "anthropic/claude-opus-4-5": {"input": 5.00, "output": 25.00, "cache_read": 0.50, "cache_write": 6.25},
    "anthropic/claude-sonnet-4-6": {"input": 3.00, "output": 15.00, "cache_read": 0.30, "cache_write": 3.75},
    "anthropic/claude-sonnet-4-5": {"input": 3.00, "output": 15.00, "cache_read": 0.30, "cache_write": 3.75},
    "anthropic/claude-haiku-4-5": {"input": 1.00, "output": 5.00, "cache_read": 0.10, "cache_write": 1.25},
    # Anthropic — Claude 3.x
    "anthropic/claude-3-5-sonnet": {"input": 3.00, "output": 15.00, "cache_read": 0.30, "cache_write": 3.75},
    "anthropic/claude-3-5-haiku": {"input": 0.80, "output": 4.00, "cache_read": 0.08, "cache_write": 1.00},
    "anthropic/claude-3-opus": {"input": 15.00, "output": 75.00, "cache_read": 1.50, "cache_write": 18.75},
    "anthropic/claude-3-sonnet": {"input": 3.00, "output": 15.00, "cache_read": 0.30, "cache_write": 3.75},
    "anthropic/claude-3-haiku": {"input": 0.25, "output": 1.25, "cache_read": 0.03, "cache_write": 0.30},
    # OpenAI — GPT-4.1 family (2025)
    # OpenAI does not bill cache writes separately, hence cache_write = 0.
    "openai/gpt-4.1": {"input": 2.00, "output": 8.00, "cache_read": 0.50, "cache_write": 0.00},
    "openai/gpt-4.1-mini": {"input": 0.40, "output": 1.60, "cache_read": 0.10, "cache_write": 0.00},
    "openai/gpt-4.1-nano": {"input": 0.10, "output": 0.40, "cache_read": 0.025, "cache_write": 0.00},
    # OpenAI — GPT-4o family
    "openai/gpt-4o": {"input": 2.50, "output": 10.00, "cache_read": 1.25, "cache_write": 0.00},
    "openai/gpt-4o-mini": {"input": 0.15, "output": 0.60, "cache_read": 0.075, "cache_write": 0.00},
    "openai/gpt-4-5": {"input": 75.0, "output": 150.0, "cache_read": 37.50, "cache_write": 0.00},
    "openai/gpt-4-turbo": {"input": 10.00, "output": 30.00, "cache_read": 0.00, "cache_write": 0.00},
    "openai/gpt-4": {"input": 30.00, "output": 60.00, "cache_read": 0.00, "cache_write": 0.00},
    "openai/gpt-3-5-turbo": {"input": 0.50, "output": 1.50, "cache_read": 0.00, "cache_write": 0.00},
    # OpenAI — o-series reasoning
    "openai/o1": {"input": 15.00, "output": 60.00, "cache_read": 7.50, "cache_write": 0.00},
    # NOTE(review): o1-mini's list price may have been reduced since these
    # rates were entered — confirm before relying on them.
    "openai/o1-mini": {"input": 3.00, "output": 12.00, "cache_read": 1.50, "cache_write": 0.00},
    # o3's list price was cut from $10/$40 to $2/$8 per MTok in June 2025.
    "openai/o3": {"input": 2.00, "output": 8.00, "cache_read": 0.50, "cache_write": 0.00},
    "openai/o3-mini": {"input": 1.10, "output": 4.40, "cache_read": 0.55, "cache_write": 0.00},
    "openai/o4-mini": {"input": 1.10, "output": 4.40, "cache_read": 0.275, "cache_write": 0.00},
    # Codex alias (free under Codex plan)
    "openai/codex": {"input": 0.00, "output": 0.00, "cache_read": 0.00, "cache_write": 0.00},
}
_pricing_cache: dict[str, _PricingEntry] | None = None
def load_pricing() -> dict[str, _PricingEntry]:
"""Return merged pricing: defaults + optional override file.
Override file may use either shape:
{ "provider/model": { "input": X, "output": Y, ... } }
{ "rates_usd_per_million": { "provider/model": { ... } } }
"""
global _pricing_cache
if _pricing_cache is not None:
return _pricing_cache
merged = dict(DEFAULT_MODEL_PRICING)
override_path = os.getenv("RUNTIME_USAGE_PRICING_FILE", "").strip()
if override_path:
try:
with open(override_path) as fh:
raw = json.load(fh)
if isinstance(raw, dict):
# Unwrap reference-dashboard shape if present
overrides: dict[str, Any] = raw.get("rates_usd_per_million", raw)
if isinstance(overrides, dict):
merged.update(overrides)
logger.info("runtime_usage.pricing.override_loaded path=%s count=%d", override_path, len(overrides))
except Exception as exc:
logger.warning("runtime_usage.pricing.override_failed path=%s error=%s", override_path, exc)
_pricing_cache = merged
return _pricing_cache
# ---------------------------------------------------------------------------
# Provider / model normalisation
# ---------------------------------------------------------------------------
_PROVIDER_ALIASES: dict[str, str] = {
"anthropic": "anthropic",
"claude": "anthropic",
"openai": "openai",
"codex": "openai",
"ollama": "ollama",
"local": "ollama",
"gemini": "google",
"google": "google",
}
# Trailing suffixes that do not affect pricing:
#   - compact (YYYYMMDD) and ISO (YYYY-MM-DD) date stamps
#   - release-channel tags
#   - ``-vN[.N...]`` version tags
# Anthropic snapshots look like ``claude-3-5-sonnet-20241022``; OpenAI snapshots
# look like ``gpt-4o-2024-08-06``.
_MODEL_STRIP_RE = re.compile(
    r"(-\d{8}|-\d{4}-\d{2}-\d{2}|-latest|-preview|-instruct|-chat|-v\d+(\.\d+)*)$",
    re.IGNORECASE,
)


def normalize_provider(raw: str) -> str:
    """Normalise a provider string to a canonical lower-case slug.

    Known aliases are mapped via ``_PROVIDER_ALIASES``. Unknown names are
    returned lower-cased, and a blank name becomes ``"unknown"``.
    """
    cleaned = raw.strip().lower()
    return _PROVIDER_ALIASES.get(cleaned, cleaned or "unknown")


def normalize_model(raw: str) -> str:
    """Strip the provider prefix plus date stamps, version tags and known suffixes.

    Suffixes are removed repeatedly, so stacked ones reduce fully
    (``o1-preview-2024-09-12`` -> ``o1``). If stripping would leave nothing,
    the un-stripped model name is returned instead.
    """
    lowered = raw.strip().lower()
    # Remove provider prefix if present (e.g. "anthropic/claude-sonnet-4-5")
    name = lowered.split("/", 1)[1] if "/" in lowered else lowered
    stripped = name
    # Each successful substitution shortens the string, so this terminates.
    while True:
        shorter = _MODEL_STRIP_RE.sub("", stripped)
        if shorter == stripped:
            break
        stripped = shorter
    return stripped or name or lowered
def model_key(provider: str, model: str) -> str:
    """Join *provider* and *model* into the canonical ``"provider/model"`` key."""
    return "{}/{}".format(provider, model)
# ---------------------------------------------------------------------------
# Cost estimation
# ---------------------------------------------------------------------------
def estimate_cost(
    provider: str,
    model: str,
    input_tokens: int,
    output_tokens: int,
    cache_read_tokens: int = 0,
    cache_write_tokens: int = 0,
) -> tuple[float, bool]:
    """Return ``(cost_usd, unpriced)`` for the given token counts.

    Rates come from ``load_pricing()`` in USD per million tokens, keyed by
    ``"provider/model"``. The price table mixes dotted and dashed spellings
    (``gpt-4.1`` vs ``gpt-3-5-turbo``). So when the exact key is missing and
    the model name contains dots, a dashed variant is tried too
    (``gpt-3.5-turbo`` -> ``gpt-3-5-turbo``).

    ``unpriced=True`` means there is no usable price for this paid model, and
    the caller should surface a warning. Unpriced ``ollama`` models return
    ``(0.0, False)`` because they are free by design. Costs are rounded to
    8 decimal places.
    """
    pricing = load_pricing()
    entry = pricing.get(model_key(provider, model))
    if entry is None and "." in model:
        entry = pricing.get(model_key(provider, model.replace(".", "-")))
    # A non-dict entry can only come from a malformed override; treat it as
    # missing instead of crashing on ``entry.get``.
    if not isinstance(entry, dict):
        if provider == "ollama":
            return 0.0, False
        return 0.0, True  # unknown paid model
    per_m = 1_000_000
    cost = (
        input_tokens * entry.get("input", 0.0) / per_m
        + output_tokens * entry.get("output", 0.0) / per_m
        + cache_read_tokens * entry.get("cache_read", 0.0) / per_m
        + cache_write_tokens * entry.get("cache_write", 0.0) / per_m
    )
    return round(cost, 8), False
# ---------------------------------------------------------------------------
# Gateway RPC helpers
# ---------------------------------------------------------------------------
async def _safe_call(
    method: str,
    config: GatewayClientConfig,
) -> dict[str, Any]:
    """Invoke gateway RPC *method* and return its dict result, or ``{}``.

    Failures never propagate as ordinary exceptions, and each yields an
    empty dict:
      - ``OpenClawGatewayError`` is logged at debug level.
      - Any other ``Exception`` is logged at warning level.
      - A non-dict result is logged at debug level.
    """
    try:
        payload = await openclaw_call(method, config=config)
        if not isinstance(payload, dict):
            logger.debug("runtime_usage.rpc.unexpected_type method=%s type=%s", method, type(payload).__name__)
            payload = {}
    except OpenClawGatewayError as exc:
        logger.debug("runtime_usage.rpc.gateway_error method=%s error=%s", method, exc)
        payload = {}
    except Exception as exc:
        logger.warning("runtime_usage.rpc.error method=%s error=%s", method, exc)
        payload = {}
    return payload
def _get_float(d: dict[str, Any], *keys: str, default: float = 0.0) -> float:
for key in keys:
val = d.get(key)
if val is not None:
try:
return float(val)
except (TypeError, ValueError):
pass
return default
def _get_int(d: dict[str, Any], *keys: str, default: int = 0) -> int:
return int(_get_float(d, *keys, default=float(default)))
def _get_str(d: dict[str, Any], *keys: str, default: str = "") -> str:
for key in keys:
val = d.get(key)
if isinstance(val, str) and val.strip():
return val.strip()
return default
def _get_explicit_cost(session: dict[str, Any]) -> float:
"""Return the explicit provider/runtime cost for a session, or 0.0 if absent.
Priority (reference dashboard order):
1. session["usage"]["cost"]["total"] explicit cost from provider/runtime
2. session["usage"]["cost"] flat cost in usage block
3. session["cost"] / session["cost_usd"] / session["costUsd"]
A value of 0.0 means "not present or zero" callers should fall back to a
local price-table estimate in that case. Never overwrite a positive explicit
cost with a local estimate.
"""
usage = session.get("usage")
if isinstance(usage, dict):
cost_block = usage.get("cost")
if isinstance(cost_block, dict):
total = cost_block.get("total")
if isinstance(total, (int, float)) and total > 0:
return float(total)
if isinstance(cost_block, (int, float)) and cost_block > 0:
return float(cost_block)
return _get_float(session, "cost", "cost_usd", "costUsd", default=0.0)
def _parse_datetime(value: object) -> datetime | None:
if not isinstance(value, str) or not value.strip():
return None
normalized = value.strip().replace("Z", "+00:00")
try:
parsed = datetime.fromisoformat(normalized)
if parsed.tzinfo is not None:
return parsed.astimezone(timezone.utc).replace(tzinfo=None)
return parsed
except ValueError:
return None
def _parse_rate_limit_reset_value(value: object, now: datetime) -> datetime | None:
"""Parse rate-limit reset values from headers into a UTC naive datetime.
Supports:
- RFC3339/ISO timestamps
- Unix epoch seconds / milliseconds
- Delta-seconds strings
"""
parsed = _parse_datetime(value)
if parsed is not None:
return parsed
if value is None:
return None
raw = str(value).strip()
if not raw:
return None
try:
numeric = float(raw)
except ValueError:
return None
# Very large values are likely epoch milliseconds.
if numeric >= 1_000_000_000_000:
return datetime.fromtimestamp(numeric / 1000, tz=timezone.utc).replace(tzinfo=None)
# Epoch seconds (typical range today ~1.7e9).
if numeric >= 1_000_000_000:
return datetime.fromtimestamp(numeric, tz=timezone.utc).replace(tzinfo=None)
# Otherwise treat as delta-seconds until reset.
if numeric >= 0:
return now + timedelta(seconds=numeric)
return None
def _extract_rate_limit_reset_at(status_raw: dict[str, Any], now: datetime) -> datetime | None:
    """Return the first parseable rate-limit reset time found in *status_raw*.

    Candidates are checked in this order:
      1. A fixed list of well-known reset keys, when present.
      2. Every key whose lower-cased, dash-to-underscore form contains both
         "ratelimit" and "reset", in dict order.

    Each candidate goes through ``_parse_rate_limit_reset_value``. Returns
    None when nothing parses.
    """
    preferred_keys = (
        "x_ratelimit_reset",
        "x_ratelimit_reset_tokens",
        "x_ratelimit_reset_requests",
        "ratelimit_reset",
        "anthropic_ratelimit_reset",
        "anthropic_ratelimit_tokens_reset",
        "anthropic_ratelimit_requests_reset",
        "anthropic_ratelimit_input_tokens_reset",
    )

    def candidate_values():
        # Well-known keys first, then a defensive sweep of any reset-like key.
        for name in preferred_keys:
            if name in status_raw:
                yield status_raw.get(name)
        for name, raw_value in status_raw.items():
            folded = str(name).lower().replace("-", "_")
            if "ratelimit" in folded and "reset" in folded:
                yield raw_value

    for candidate in candidate_values():
        reset_at = _parse_rate_limit_reset_value(candidate, now)
        if reset_at is not None:
            return reset_at
    return None
# ---------------------------------------------------------------------------
# Session parsing
# ---------------------------------------------------------------------------
def _parse_sessions(cost_raw: dict[str, Any]) -> list[dict[str, Any]]:
"""Extract a flat list of session dicts from the usage.cost response.
The gateway may return sessions at the top level or nested under a window
key (``"5hour"``, ``"today"``, ``"week"``, etc.). We prefer the most
granular available list.
"""
# Try flat list first
flat = cost_raw.get("sessions")
if isinstance(flat, list):
return [s for s in flat if isinstance(s, dict)]
# Try nested under window keys, in preference order
for window_key in ("5hour", "5h", "today", "week", "7day", "data"):
bucket = cost_raw.get(window_key)
if isinstance(bucket, dict):
nested = bucket.get("sessions")
if isinstance(nested, list):
return [s for s in nested if isinstance(s, dict)]
if isinstance(bucket, list):
return [s for s in bucket if isinstance(s, dict)]
return []
def _parse_session_usage(session: dict[str, Any]) -> dict[str, int]:
    """Return ``input``/``output``/``cache_read``/``cache_write`` token counts.

    Counts are read from the session's ``"usage"`` dict, or from ``"tokens"``
    if ``"usage"`` is missing or empty. Several key spellings are accepted.
    When that lookup yields 0, the same count is looked up on the session
    itself.
    """
    block = session.get("usage") or session.get("tokens") or {}
    if not isinstance(block, dict):
        block = {}
    # (result key, keys inside the usage block, fallback keys on the session)
    fields = (
        ("input", ("input_tokens", "inputTokens", "input"), ("input_tokens", "inputTokens")),
        ("output", ("output_tokens", "outputTokens", "output"), ("output_tokens", "outputTokens")),
        (
            "cache_read",
            ("cache_read_input_tokens", "cacheReadTokens", "cache_read"),
            ("cache_read_tokens", "cacheReadTokens"),
        ),
        (
            "cache_write",
            ("cache_creation_input_tokens", "cacheWriteTokens", "cache_write"),
            ("cache_write_tokens", "cacheWriteTokens"),
        ),
    )
    return {
        name: _get_int(block, *block_keys, default=0) or _get_int(session, *session_keys, default=0)
        for name, block_keys, session_keys in fields
    }
# ---------------------------------------------------------------------------
# Aggregation
# ---------------------------------------------------------------------------
def aggregate_per_model(
    sessions: list[dict[str, Any]],
    account_key: str = "default",
) -> dict[str, ModelUsageEntry]:
    """Roll up token counts and cost across sessions, keyed by ``provider/model``.

    For each session:
      - Provider defaults to ``"anthropic"`` and model to ``"unknown"``; both
        are normalised.
      - Cost is the gateway's explicit cost when positive, otherwise a local
        price-table estimate.
      - ``calls`` defaults to 1 per session.

    ``unpriced`` reflects whether the price table knows the model, and is
    recorded even when an explicit cost was supplied. Per-model costs are
    rounded to 8 decimal places, like the other cost figures in this module.
    """
    entries: dict[str, dict[str, Any]] = {}
    for session in sessions:
        provider = normalize_provider(_get_str(session, "provider", default="anthropic"))
        model = normalize_model(_get_str(session, "model", "modelName", default="unknown"))
        key = model_key(provider, model)
        tokens = _parse_session_usage(session)
        calls = _get_int(session, "calls", "messageCount", "messages", default=1)
        # Single price lookup per session. The unpriced flag depends only on
        # provider/model, so one call yields both the estimate and the flag.
        estimated_cost, unpriced = estimate_cost(
            provider, model,
            tokens["input"], tokens["output"],
            tokens["cache_read"], tokens["cache_write"],
        )
        session_cost = _get_explicit_cost(session)
        # Only estimate when the gateway provided no explicit cost
        if session_cost == 0.0:
            session_cost = estimated_cost
        if key not in entries:
            entries[key] = {
                "provider": provider,
                "account_key": account_key,
                "model": model,
                "input_tokens": 0,
                "output_tokens": 0,
                "cache_read_tokens": 0,
                "cache_write_tokens": 0,
                "cost_usd": 0.0,
                "calls": 0,
                "unpriced": unpriced,
                "source": "local_jsonl_estimate",  # default source for aggregated session data
            }
        e = entries[key]
        e["input_tokens"] += tokens["input"]
        e["output_tokens"] += tokens["output"]
        e["cache_read_tokens"] += tokens["cache_read"]
        e["cache_write_tokens"] += tokens["cache_write"]
        e["cost_usd"] += session_cost
        e["calls"] += calls
    return {
        key: ModelUsageEntry(
            **{
                **e,
                # Strip float-accumulation noise (e.g. 0.30000000000000004).
                "cost_usd": round(e["cost_usd"], 8),
                "total_tokens": e["input_tokens"] + e["output_tokens"],
            },
        )
        for key, e in entries.items()
    }
def _top_sessions(
    sessions: list[dict[str, Any]],
    limit: int = 10,
) -> list[TopSession]:
    """Return up to *limit* sessions, most expensive first.

    Each row's model is the normalised ``provider/model`` key, or None when
    the session names no model. Cost is the explicit session cost. When that
    is 0.0 and a model is known, a local price-table estimate is used instead.
    Token totals are input + output. The sort is stable, so equal-cost
    sessions keep their input order.
    """
    ranked: list[TopSession] = []
    for session in sessions:
        raw_model = _get_str(session, "model", "modelName") or None
        provider = model_name = None
        if raw_model:
            provider = normalize_provider(_get_str(session, "provider", default="anthropic"))
            model_name = normalize_model(raw_model)
        usage = _parse_session_usage(session)
        cost = _get_explicit_cost(session)
        if cost == 0.0 and model_name is not None:
            cost, _unpriced = estimate_cost(
                provider, model_name,
                usage["input"], usage["output"],
                usage["cache_read"], usage["cache_write"],
            )
        ranked.append(
            TopSession(
                session_id=_get_str(session, "sessionId", "id", "session_id", default=""),
                label=_get_str(session, "label", "name", "title") or None,
                model=model_key(provider, model_name) if model_name is not None else None,
                cost_usd=round(cost, 8),
                total_tokens=usage["input"] + usage["output"],
                updated_at=_get_str(session, "updated_at", "updatedAt", "lastActivity", "last_activity") or None,
                source="local_jsonl_estimate",  # default source for session data
            )
        )
    return sorted(ranked, key=lambda row: row.cost_usd, reverse=True)[:limit]
# ---------------------------------------------------------------------------
# Window and limit helpers
# ---------------------------------------------------------------------------
# Length of the rolling usage window, in hours.
_WINDOW_HOURS = 5


def _oldest_active_ts(
    sessions: list[dict[str, Any]],
    now: datetime,
) -> datetime | None:
    """Return the earliest session timestamp inside the rolling window.

    A session's timestamp is the first non-blank string among ``updated_at``,
    ``updatedAt``, ``lastActivity``, ``last_activity``, ``created_at`` and
    ``createdAt``, parsed by ``_parse_datetime``. Unparseable timestamps and
    those before ``now - _WINDOW_HOURS`` are ignored. Returns None when no
    session qualifies.
    """
    window_start = now - timedelta(hours=_WINDOW_HOURS)
    stamps = (
        _parse_datetime(
            _get_str(
                session,
                "updated_at", "updatedAt", "lastActivity", "last_activity",
                "created_at", "createdAt",
            )
        )
        for session in sessions
    )
    return min((ts for ts in stamps if ts is not None and ts >= window_start), default=None)
def _build_window(
    status_raw: dict[str, Any],
    now: datetime,
    account_key: str = "default",
    oldest_event_ts: datetime | None = None,
) -> RuntimeUsageWindow:
    """Describe the current rolling usage window and when it resets.

    Start time, in order of preference:
      1. An explicit start field in *status_raw*
         (windowStart/window_start/period_start/started_at).
      2. *oldest_event_ts*, so the reset reflects real remaining time.
      3. ``now - _WINDOW_HOURS``, a conservative fallback that gives
         reset_in_ms == 0.

    Reset time and its labels, in order of preference:
      1. An explicit end field -> ``provider_native`` / ``high``.
      2. A rate-limit reset header -> ``provider_api_rate_limit`` / ``medium``.
      3. ``start + _WINDOW_HOURS`` -> ``local_jsonl_estimate`` / ``low``.

    *account_key* is accepted for interface symmetry and is not used.
    """
    def first_parsed(*keys: str) -> datetime | None:
        # Mirrors ``a or b or c or d``: only the first truthy value is parsed.
        for key in keys:
            candidate = status_raw.get(key)
            if candidate:
                return _parse_datetime(candidate)
        return None

    explicit_start = first_parsed("windowStart", "window_start", "period_start", "started_at")
    explicit_reset = first_parsed("windowEnd", "window_end", "period_end", "resets_at")
    header_reset = _extract_rate_limit_reset_at(status_raw, now)

    if explicit_start is not None:
        started_at = explicit_start
    elif oldest_event_ts is not None:
        started_at = oldest_event_ts
    else:
        started_at = now - timedelta(hours=_WINDOW_HOURS)

    if explicit_reset is not None:
        resets_at, source, confidence = explicit_reset, "provider_native", "high"
    elif header_reset is not None:
        resets_at, source, confidence = header_reset, "provider_api_rate_limit", "medium"
    else:
        resets_at = started_at + timedelta(hours=_WINDOW_HOURS)
        source, confidence = "local_jsonl_estimate", "low"

    remaining_ms = int((resets_at - now).total_seconds() * 1000)
    return RuntimeUsageWindow(
        key=f"{_WINDOW_HOURS}h",
        started_at=started_at,
        resets_at=resets_at,
        reset_in_ms=max(0, remaining_ms),
        source=source,
        confidence=confidence,
    )
def _limit_source(status_raw: dict[str, Any]) -> str:
"""Return the appropriate source label for a limit read from gateway status."""
has_rate_limit_headers = (
status_raw.get("x_ratelimit_remaining") or
status_raw.get("x_ratelimit_limit") or
status_raw.get("anthropic_ratelimit_remaining") or
status_raw.get("anthropic_ratelimit_limit")
)
return "provider_api_rate_limit" if has_rate_limit_headers else "configured_limit"
def _pct(numerator: int | float, denominator: int | float) -> int | None:
if not denominator:
return None
return int(min(100, numerator * 100 // denominator))
def _build_current(
    per_model: dict[str, ModelUsageEntry],
    status_raw: dict[str, Any],
    account_key: str = "default",
) -> RuntimeUsageCurrent:
    """Summarise window totals and compare them with any limits in *status_raw*.

    Totals are summed over *per_model*; total cost is rounded to 8 decimal
    places. Each limit is read under several key spellings. A missing or zero
    limit yields None for the limit, its percentage, and its source.

    Units are kept separate: output-token, total-token and message/request
    limits are independent. The legacy ``token_limit`` comes from
    ``tokenLimit``/``token_limit``. When that is absent, it is backfilled from
    the output-token limit, or else the total-token limit, so older dashboard
    code keeps working.

    *account_key* is accepted for interface symmetry and is not used.
    """
    models = list(per_model.values())
    total_cost = round(sum(m.cost_usd for m in models), 8)
    total_tokens = sum(m.total_tokens for m in models)
    total_output_tokens = sum(m.output_tokens for m in models)
    total_calls = sum(m.calls for m in models)
    limit_source = _limit_source(status_raw)

    def describe(raw_limit, used):
        # (limit or None, percent used or None, source label or None)
        return raw_limit or None, _pct(used, raw_limit), (limit_source if raw_limit else None)

    out_limit, out_pct, out_src = describe(
        _get_int(status_raw, "outputTokenLimit", "output_token_limit", default=0),
        total_output_tokens,
    )
    tot_limit, tot_pct, tot_src = describe(
        _get_int(status_raw, "totalTokenLimit", "total_token_limit", default=0),
        total_tokens,
    )
    # Count-based limit (messages/requests); never compared with tokens.
    msg_limit, msg_pct, msg_src = describe(
        _get_int(status_raw, "messageLimit", "message_limit", "requestLimit", "request_limit", default=0),
        total_calls,
    )
    # Legacy, ambiguous-kind token limit. messageLimit is deliberately not folded in.
    legacy_limit, legacy_pct, legacy_src = describe(
        _get_int(status_raw, "tokenLimit", "token_limit", default=0),
        total_tokens,
    )
    if legacy_limit is None:
        if out_limit is not None:
            legacy_limit, legacy_pct, legacy_src = out_limit, out_pct, out_src
        elif tot_limit is not None:
            legacy_limit, legacy_pct, legacy_src = tot_limit, tot_pct, tot_src

    cost_limit, cost_pct, cost_src = describe(
        _get_float(status_raw, "costLimit", "cost_limit", "costLimitUsd", default=0.0),
        total_cost,
    )

    return RuntimeUsageCurrent(
        total_cost_usd=total_cost,
        total_tokens=total_tokens,
        total_output_tokens=total_output_tokens,
        total_calls=total_calls,
        # legacy
        token_limit=legacy_limit,
        token_pct=legacy_pct,
        cost_limit_usd=cost_limit,
        cost_pct=cost_pct,
        token_limit_source=legacy_src,
        cost_limit_source=cost_src,
        # typed
        output_token_limit=out_limit,
        output_token_limit_pct=out_pct,
        output_token_limit_source=out_src,
        total_token_limit=tot_limit,
        total_token_limit_pct=tot_pct,
        total_token_limit_source=tot_src,
        message_limit=msg_limit,
        message_pct=msg_pct,
        message_limit_source=msg_src,
    )
def _compute_burn_rate(
    sessions: list[dict[str, Any]],
    window: RuntimeUsageWindow,
    now: datetime,
) -> RuntimeUsageBurnRate:
    """Compute tokens/min and cost/min over the trailing 60 minutes.

    A session counts when its last-activity timestamp (updated_at /
    updatedAt / lastActivity / last_activity) is no older than one hour.
    NOTE: the session's whole token and cost totals are then attributed to
    that hour, so long-running sessions likely overstate the rate.

    Cost follows the same rule as ``aggregate_per_model``: the explicit
    gateway cost (``_get_explicit_cost``) when positive, otherwise a local
    price-table estimate. Previously only a top-level ``cost`` field was read,
    so nested or estimated costs showed up as zero burn.

    Total tokens (input+output) and output tokens are tracked separately so
    predictions against output-token limits use the correct numerator.
    *window* is accepted for interface compatibility and is not used.
    """
    cutoff = now - timedelta(minutes=60)
    recent_tokens = 0
    recent_output_tokens = 0
    recent_cost = 0.0
    for session in sessions:
        raw_ts = _get_str(session, "updated_at", "updatedAt", "lastActivity", "last_activity")
        ts = _parse_datetime(raw_ts)
        if ts is None or ts < cutoff:
            continue
        tokens = _parse_session_usage(session)
        recent_tokens += tokens["input"] + tokens["output"]
        recent_output_tokens += tokens["output"]
        session_cost = _get_explicit_cost(session)
        if session_cost == 0.0:
            provider = normalize_provider(_get_str(session, "provider", default="anthropic"))
            model = normalize_model(_get_str(session, "model", "modelName", default="unknown"))
            session_cost, _unpriced = estimate_cost(
                provider, model,
                tokens["input"], tokens["output"],
                tokens["cache_read"], tokens["cache_write"],
            )
        recent_cost += session_cost
    return RuntimeUsageBurnRate(
        tokens_per_minute=round(recent_tokens / 60, 4),
        output_tokens_per_minute=round(recent_output_tokens / 60, 4),
        cost_usd_per_minute=round(recent_cost / 60, 8),
    )
def _build_predictions(
    current: RuntimeUsageCurrent,
    burn_rate: RuntimeUsageBurnRate,
    window: RuntimeUsageWindow,
) -> RuntimeUsagePredictions:
    """Estimate time-to-limit in ms from the most constrained applicable limit.

    Limits considered (only positive limits count):
      1. output_token_limit vs total_output_tokens (burn: output_tokens_per_minute)
      2. total_token_limit vs total_tokens (burn: tokens_per_minute)
      3. legacy token_limit vs total_tokens (burn: tokens_per_minute). Used only
         when neither typed limit produced a prediction. Skipped when
         token_limit is just the output-token limit backfilled by
         ``_build_current``, because that is a different unit.
      4. message_limit (messages/requests) vs total_calls. The rate is calls
         per elapsed window minute.

    A limit that is already used up returns ``time_to_limit_ms=0, safe=False``
    immediately, whatever the current burn rate. Otherwise the smallest
    predicted time wins. ``safe`` is True when it exceeds the window's
    ``reset_in_ms``. Cost limits are not used, since they require billing data.
    """
    def exhausted(kind: str) -> RuntimeUsagePredictions:
        return RuntimeUsagePredictions(time_to_limit_ms=0, safe=False, limit_kind=kind)

    candidates: list[tuple[int, str]] = []  # (time_to_limit_ms, kind)

    # ── Output-token limit ────────────────────────────────────────────────────
    output_limit = current.output_token_limit
    if output_limit is not None and output_limit > 0:
        remaining = output_limit - current.total_output_tokens
        if remaining <= 0:
            return exhausted("output_tokens")
        if burn_rate.output_tokens_per_minute > 0:
            candidates.append((
                int(remaining / burn_rate.output_tokens_per_minute * 60_000),
                "output_tokens",
            ))

    # ── Total-token limit ─────────────────────────────────────────────────────
    total_limit = current.total_token_limit
    if total_limit is not None and total_limit > 0:
        remaining = total_limit - current.total_tokens
        if remaining <= 0:
            return exhausted("total_tokens")
        if burn_rate.tokens_per_minute > 0:
            candidates.append((
                int(remaining / burn_rate.tokens_per_minute * 60_000),
                "total_tokens",
            ))

    # ── Legacy token_limit (only when no typed token limit predicted) ─────────
    legacy_limit = current.token_limit
    # _build_current backfills token_limit from output_token_limit. Comparing
    # that value against total tokens would mix units.
    is_output_backfill = legacy_limit is not None and legacy_limit == current.output_token_limit
    if not candidates and legacy_limit is not None and legacy_limit > 0 and not is_output_backfill:
        remaining = legacy_limit - current.total_tokens
        if remaining <= 0:
            return exhausted("total_tokens")
        if burn_rate.tokens_per_minute > 0:
            candidates.append((
                int(remaining / burn_rate.tokens_per_minute * 60_000),
                "total_tokens",
            ))

    # ── Message limit ─────────────────────────────────────────────────────────
    message_limit = current.message_limit
    if message_limit is not None and message_limit > 0:
        remaining = message_limit - current.total_calls
        if remaining <= 0:
            return exhausted("messages")
        # Elapsed = window length - time left, so the rate reflects actual
        # usage density rather than the time remaining in the window.
        total_window_ms = max(1, int((window.resets_at - window.started_at).total_seconds() * 1000))
        elapsed_ms = max(1, total_window_ms - window.reset_in_ms)
        calls_per_minute = current.total_calls / (elapsed_ms / 60_000)
        if calls_per_minute > 0:
            candidates.append((
                int(remaining / calls_per_minute * 60_000),
                "messages",
            ))

    if not candidates:
        return RuntimeUsagePredictions(time_to_limit_ms=None, safe=True, limit_kind="total_tokens")
    # Pick the most constrained (smallest time) — that is what will actually block work.
    time_to_limit_ms, kind = min(candidates, key=lambda c: c[0])
    safe = time_to_limit_ms > window.reset_in_ms
    return RuntimeUsagePredictions(time_to_limit_ms=time_to_limit_ms, safe=safe, limit_kind=kind)
# ---------------------------------------------------------------------------
# Public service entry point
# ---------------------------------------------------------------------------
async def get_runtime_usage(
gateway_id: UUID,
config: GatewayClientConfig,
account_key: str = "default",
) -> RuntimeUsageResponse:
"""Fetch and aggregate runtime usage for one gateway.
Args:
gateway_id: Pipeline gateway DB id (echoed in response).
config: Credentials and URL for the gateway RPC connection.
account_key: Stable identifier for this gateway's account
(e.g. ``"claude-default"``, ``"openai-work"``). Used in
per-model breakdowns to keep separate accounts distinct.
Returns:
A fully populated ``RuntimeUsageResponse``. All fields default to
safe zeroes if the gateway is unreachable or returns unexpected data.
"""
now = utcnow()
cost_raw, status_raw = await asyncio.gather(
_safe_call("usage.cost", config),
_safe_call("usage.status", config),
)
# sessions.list can block the gateway on large session stores; cap at 8s.
try:
sessions_raw = await asyncio.wait_for(
_safe_call("sessions.list", config), timeout=8
)
except asyncio.TimeoutError:
logger.warning("runtime_usage.sessions_list.timeout gateway_id=%s", gateway_id)
sessions_raw = {}
# Extract sessions from sessions.list response (primary source)
# Fallback to usage.cost if sessions.list fails
if isinstance(sessions_raw, dict):
raw_sessions = sessions_raw.get("sessions") or []
elif isinstance(sessions_raw, list):
raw_sessions = sessions_raw
else:
raw_sessions = []
sessions: list[dict[str, Any]] = []
if raw_sessions:
sessions = [s for s in raw_sessions if isinstance(s, dict)]
else:
sessions = _parse_sessions(cost_raw)
merged_status = {**cost_raw, **status_raw}
oldest_event_ts = _oldest_active_ts(sessions, now)
per_model = aggregate_per_model(sessions, account_key=account_key)
window = _build_window(merged_status, now, oldest_event_ts=oldest_event_ts)
current = _build_current(per_model, merged_status)
burn_rate = _compute_burn_rate(sessions, window, now)
predictions = _build_predictions(current, burn_rate, window)
top = _top_sessions(sessions)
logger.info(
"runtime_usage.computed gateway_id=%s sessions=%d models=%d total_cost=%.6f",
gateway_id,
len(sessions),
len(per_model),
current.total_cost_usd,
)
return RuntimeUsageResponse(
generated_at=now,
gateway_id=gateway_id,
window=window,
current=current,
burn_rate=burn_rate,
predictions=predictions,
per_model=per_model,
top_sessions=top,
)