diff --git a/.gitignore b/.gitignore index 0bb5bb6..6b06d2d 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,4 @@ frontend/coverage backend/app/services/openclaw/.device-keys FUTURE.md FUTURE.md +docs/runtime-usage-dashboard-plan.md diff --git a/backend/app/api/gateways.py b/backend/app/api/gateways.py index 0c8a798..5c22e08 100644 --- a/backend/app/api/gateways.py +++ b/backend/app/api/gateways.py @@ -8,7 +8,7 @@ from uuid import UUID, uuid4 from fastapi import APIRouter, Depends, Query from sqlmodel import col -from app.api.deps import require_org_admin +from app.api.deps import require_org_admin, require_org_member from app.core.auth import AuthContext, get_auth_context from app.db import crud from app.db.pagination import paginate @@ -26,6 +26,8 @@ from app.schemas.gateways import ( GatewayTemplatesSyncResult, GatewayUpdate, ) +from app.schemas.runtime_usage import RuntimeUsageResponse +from app.services.openclaw.runtime_usage import get_runtime_usage from app.schemas.pagination import DefaultLimitOffsetPage from app.services.openclaw.admin_service import GatewayAdminLifecycleService from app.services.openclaw.session_service import GatewayTemplateSyncQuery @@ -41,6 +43,7 @@ router = APIRouter(prefix="/gateways", tags=["gateways"]) SESSION_DEP = Depends(get_session) AUTH_DEP = Depends(get_auth_context) ORG_ADMIN_DEP = Depends(require_org_admin) +ORG_MEMBER_DEP = Depends(require_org_member) INCLUDE_MAIN_QUERY = Query(default=True) RESET_SESSIONS_QUERY = Query(default=False) ROTATE_TOKENS_QUERY = Query(default=False) @@ -232,6 +235,43 @@ async def import_gateway_agents( ) +@router.get( + "/{gateway_id}/runtime-usage", + response_model=RuntimeUsageResponse, + summary="Gateway runtime usage", + description=( + "Return model usage, token counts, estimated spend, burn rate, and " + "time-remaining predictions for the specified gateway. " + "Visible to all organisation members." + ), +) +async def get_gateway_runtime_usage( + gateway_id: UUID, + session: AsyncSession = SESSION_DEP, + ctx: OrganizationContext = ORG_MEMBER_DEP, +) -> RuntimeUsageResponse: + """Aggregate runtime usage from the gateway's usage.cost / usage.status RPC methods.""" + from app.services.openclaw.gateway_rpc import GatewayConfig as GatewayClientConfig + + service = GatewayAdminLifecycleService(session) + gateway = await service.require_gateway( + gateway_id=gateway_id, + organization_id=ctx.organization.id, + ) + config = GatewayClientConfig( + url=gateway.url, + token=gateway.token, + allow_insecure_tls=gateway.allow_insecure_tls, + disable_device_pairing=gateway.disable_device_pairing, + ) + account_key = gateway.name.lower().replace(" ", "-") if gateway.name else "default" + return await get_runtime_usage( + gateway_id=gateway.id, + config=config, + account_key=account_key, + ) + + @router.delete("/{gateway_id}", response_model=OkResponse) async def delete_gateway( gateway_id: UUID, diff --git a/backend/app/schemas/runtime_usage.py b/backend/app/schemas/runtime_usage.py new file mode 100644 index 0000000..74c976e --- /dev/null +++ b/backend/app/schemas/runtime_usage.py @@ -0,0 +1,85 @@ +"""Response schemas for the gateway runtime usage endpoint.""" + +from __future__ import annotations + +from datetime import datetime +from uuid import UUID + +from sqlmodel import SQLModel + +RUNTIME_ANNOTATION_TYPES = (datetime, UUID) + + +class RuntimeUsageWindow(SQLModel): + """Rolling 5-hour usage window metadata.""" + + key: str # "5h" + started_at: datetime + resets_at: datetime + reset_in_ms: int # milliseconds until oldest event ages out + + +class RuntimeUsageCurrent(SQLModel): + """Aggregated totals within the current window.""" + + total_cost_usd: float + total_tokens: int # input + output across all sessions + total_calls: int + token_limit: int | None = None # configured limit; None = unknown + token_pct: int | None = None # 0–100; None when limit unknown + cost_limit_usd: float | None = None + cost_pct: int | None = None + + +class RuntimeUsageBurnRate(SQLModel): + """Recent token and cost velocity (last 60 minutes of the window).""" + + tokens_per_minute: float + cost_usd_per_minute: float + + +class RuntimeUsagePredictions(SQLModel): + """Estimates derived from current burn rate and configured limits.""" + + time_to_limit_ms: int | None = None # None when limit or burn rate unknown + safe: bool # True if time_to_limit > reset_in_ms (will reset before hitting limit) + + +class ModelUsageEntry(SQLModel): + """Usage and cost breakdown for one provider/model combination.""" + + provider: str # normalised: "anthropic", "openai", "ollama", "unknown" + account_key: str # e.g. "claude-default", "openai-work", "ollama-local" + model: str # normalised model slug, e.g. "claude-sonnet-4-6" + input_tokens: int + output_tokens: int + cache_read_tokens: int + cache_write_tokens: int + total_tokens: int + cost_usd: float + calls: int + unpriced: bool # True = unknown paid model; False = priced or intentionally free (Ollama) + + +class TopSession(SQLModel): + """Summary row for one session, sorted by cost descending.""" + + session_id: str + label: str | None = None + model: str | None = None + cost_usd: float + total_tokens: int + updated_at: str | None = None + + +class RuntimeUsageResponse(SQLModel): + """Complete runtime usage payload returned by GET /gateways/{id}/runtime-usage.""" + + generated_at: datetime + gateway_id: UUID + window: RuntimeUsageWindow + current: RuntimeUsageCurrent + burn_rate: RuntimeUsageBurnRate + predictions: RuntimeUsagePredictions + per_model: dict[str, ModelUsageEntry] # key = "provider/model" + top_sessions: list[TopSession] diff --git a/backend/app/services/openclaw/runtime_usage.py b/backend/app/services/openclaw/runtime_usage.py new file mode 100644 index 0000000..6ed6bd4 --- /dev/null +++ b/backend/app/services/openclaw/runtime_usage.py @@ -0,0 +1,544 @@ +"""Runtime usage service — compute model spend, burn rate, and time remaining. + +Data source: gateway RPC methods ``usage.cost`` and ``usage.status``. +Pricing: built-in defaults plus an optional JSON override file pointed to by +the ``RUNTIME_USAGE_PRICING_FILE`` environment variable. + +Design decisions: +- Total tokens (input + output) are the basis for time-remaining predictions. +- All org members may call this endpoint (not admin-only). +- Parsing is fully defensive: malformed or missing fields default to zero. +- Unknown paid models are flagged ``unpriced=True`` so the UI can warn. +- Ollama / local models are flagged ``unpriced=False, cost_usd=0`` (free by design). +""" + +from __future__ import annotations + +import asyncio +import json +import os +import re +from datetime import datetime, timedelta, timezone +from typing import Any +from uuid import UUID + +from app.core.logging import get_logger +from app.core.time import utcnow +from app.schemas.runtime_usage import ( + ModelUsageEntry, + RuntimeUsageBurnRate, + RuntimeUsageCurrent, + RuntimeUsagePredictions, + RuntimeUsageResponse, + RuntimeUsageWindow, + TopSession, +) +from app.services.openclaw.gateway_rpc import GatewayConfig as GatewayClientConfig +from app.services.openclaw.gateway_rpc import OpenClawGatewayError, openclaw_call + +logger = get_logger(__name__) + +# --------------------------------------------------------------------------- +# Pricing config (USD per million tokens) +# --------------------------------------------------------------------------- + +_PricingEntry = dict[str, float] # keys: input, output, cache_read, cache_write + +DEFAULT_MODEL_PRICING: dict[str, _PricingEntry] = { + # Anthropic — Claude 4.x + "anthropic/claude-opus-4-7": {"input": 15.00, "output": 75.00, "cache_read": 1.50, "cache_write": 3.75}, + "anthropic/claude-opus-4-5": {"input": 15.00, "output": 75.00, "cache_read": 1.50, "cache_write": 3.75}, + "anthropic/claude-sonnet-4-6": {"input": 3.00, "output": 15.00, "cache_read": 0.30, "cache_write": 3.75}, + "anthropic/claude-sonnet-4-5": {"input": 3.00, "output": 15.00, "cache_read": 0.30, "cache_write": 3.75}, + "anthropic/claude-haiku-4-5": {"input": 0.80, "output": 4.00, "cache_read": 0.08, "cache_write": 1.00}, + # Anthropic — Claude 3.x + "anthropic/claude-3-5-sonnet": {"input": 3.00, "output": 15.00, "cache_read": 0.30, "cache_write": 3.75}, + "anthropic/claude-3-5-haiku": {"input": 0.80, "output": 4.00, "cache_read": 0.08, "cache_write": 1.00}, + "anthropic/claude-3-opus": {"input": 15.00, "output": 75.00, "cache_read": 1.50, "cache_write": 3.75}, + "anthropic/claude-3-sonnet": {"input": 3.00, "output": 15.00, "cache_read": 0.30, "cache_write": 3.75}, + "anthropic/claude-3-haiku": {"input": 0.25, "output": 1.25, "cache_read": 0.03, "cache_write": 0.30}, + # OpenAI — GPT-4o family + "openai/gpt-4o": {"input": 2.50, "output": 10.00, "cache_read": 1.25, "cache_write": 0.00}, + "openai/gpt-4o-mini": {"input": 0.15, "output": 0.60, "cache_read": 0.075, "cache_write": 0.00}, + "openai/gpt-4-turbo": {"input": 10.00, "output": 30.00, "cache_read": 0.00, "cache_write": 0.00}, + "openai/gpt-4": {"input": 30.00, "output": 60.00, "cache_read": 0.00, "cache_write": 0.00}, + "openai/gpt-3-5-turbo": {"input": 0.50, "output": 1.50, "cache_read": 0.00, "cache_write": 0.00}, + # OpenAI — o-series reasoning + "openai/o1": {"input": 15.00, "output": 60.00, "cache_read": 7.50, "cache_write": 0.00}, + "openai/o1-mini": {"input": 3.00, "output": 12.00, "cache_read": 1.50, "cache_write": 0.00}, + "openai/o3": {"input": 10.00, "output": 40.00, "cache_read": 2.50, "cache_write": 0.00}, + "openai/o3-mini": {"input": 1.10, "output": 4.40, "cache_read": 0.55, "cache_write": 0.00}, + "openai/o4-mini": {"input": 1.10, "output": 4.40, "cache_read": 0.275, "cache_write": 0.00}, + # Codex alias + "openai/codex": {"input": 0.00, "output": 0.00, "cache_read": 0.00, "cache_write": 0.00}, +} + +_pricing_cache: dict[str, _PricingEntry] | None = None + + +def load_pricing() -> dict[str, _PricingEntry]: + """Return merged pricing: defaults + optional override file.""" + global _pricing_cache + if _pricing_cache is not None: + return _pricing_cache + merged = dict(DEFAULT_MODEL_PRICING) + override_path = os.getenv("RUNTIME_USAGE_PRICING_FILE", "").strip() + if override_path: + try: + with open(override_path) as fh: + overrides = json.load(fh) + if isinstance(overrides, dict): + merged.update(overrides) + logger.info("runtime_usage.pricing.override_loaded path=%s", override_path) + except Exception as exc: + logger.warning("runtime_usage.pricing.override_failed path=%s error=%s", override_path, exc) + _pricing_cache = merged + return _pricing_cache + + +# --------------------------------------------------------------------------- +# Provider / model normalisation +# --------------------------------------------------------------------------- + +_PROVIDER_ALIASES: dict[str, str] = { + "anthropic": "anthropic", + "claude": "anthropic", + "openai": "openai", + "codex": "openai", + "ollama": "ollama", + "local": "ollama", + "gemini": "google", + "google": "google", +} + +_MODEL_STRIP_RE = re.compile( + r"(-\d{8}|-latest|-preview|-instruct|-chat|-v\d+(\.\d+)*)$", + re.IGNORECASE, +) + + +def normalize_provider(raw: str) -> str: + """Normalise a provider string to a canonical lower-case slug.""" + cleaned = raw.strip().lower() + return _PROVIDER_ALIASES.get(cleaned, cleaned or "unknown") + + +def normalize_model(raw: str) -> str: + """Strip date stamps, version tags, and known suffixes from a model name.""" + cleaned = raw.strip().lower() + # Remove provider prefix if present (e.g. "anthropic/claude-sonnet-4-5") + if "/" in cleaned: + cleaned = cleaned.split("/", 1)[1] + # Strip trailing date/version stamps + cleaned = _MODEL_STRIP_RE.sub("", cleaned) + return cleaned or raw.strip().lower() + + +def model_key(provider: str, model: str) -> str: + """Return the canonical dict key ``"provider/model"``.""" + return f"{provider}/{model}" + + +# --------------------------------------------------------------------------- +# Cost estimation +# --------------------------------------------------------------------------- + +def estimate_cost( + provider: str, + model: str, + input_tokens: int, + output_tokens: int, + cache_read_tokens: int = 0, + cache_write_tokens: int = 0, +) -> tuple[float, bool]: + """Return ``(cost_usd, unpriced)`` for the given token counts. + + ``unpriced=True`` means we have no price for this paid model — the caller + should surface a warning. Ollama / local models return ``(0.0, False)``. + """ + pricing = load_pricing() + key = model_key(provider, model) + entry = pricing.get(key) + + if entry is None: + if provider == "ollama": + return 0.0, False + return 0.0, True # unknown paid model + + per_m = 1_000_000 + cost = ( + input_tokens * entry.get("input", 0.0) / per_m + + output_tokens * entry.get("output", 0.0) / per_m + + cache_read_tokens * entry.get("cache_read", 0.0) / per_m + + cache_write_tokens * entry.get("cache_write", 0.0) / per_m + ) + return round(cost, 8), False + + +# --------------------------------------------------------------------------- +# Gateway RPC helpers +# --------------------------------------------------------------------------- + +async def _safe_call( + method: str, + config: GatewayClientConfig, +) -> dict[str, Any]: + """Call a gateway method and return a dict, or an empty dict on any error.""" + try: + result = await openclaw_call(method, config=config) + if isinstance(result, dict): + return result + logger.debug("runtime_usage.rpc.unexpected_type method=%s type=%s", method, type(result).__name__) + return {} + except OpenClawGatewayError as exc: + logger.debug("runtime_usage.rpc.gateway_error method=%s error=%s", method, exc) + return {} + except Exception as exc: + logger.warning("runtime_usage.rpc.error method=%s error=%s", method, exc) + return {} + + +def _get_float(d: dict[str, Any], *keys: str, default: float = 0.0) -> float: + for key in keys: + val = d.get(key) + if val is not None: + try: + return float(val) + except (TypeError, ValueError): + pass + return default + + +def _get_int(d: dict[str, Any], *keys: str, default: int = 0) -> int: + return int(_get_float(d, *keys, default=float(default))) + + +def _get_str(d: dict[str, Any], *keys: str, default: str = "") -> str: + for key in keys: + val = d.get(key) + if isinstance(val, str) and val.strip(): + return val.strip() + return default + + +def _parse_datetime(value: object) -> datetime | None: + if not isinstance(value, str) or not value.strip(): + return None + normalized = value.strip().replace("Z", "+00:00") + try: + parsed = datetime.fromisoformat(normalized) + if parsed.tzinfo is not None: + return parsed.astimezone(timezone.utc).replace(tzinfo=None) + return parsed + except ValueError: + return None + + +# --------------------------------------------------------------------------- +# Session parsing +# --------------------------------------------------------------------------- + +def _parse_sessions(cost_raw: dict[str, Any]) -> list[dict[str, Any]]: + """Extract a flat list of session dicts from the usage.cost response. + + The gateway may return sessions at the top level or nested under a window + key (``"5hour"``, ``"today"``, ``"week"``, etc.). We prefer the most + granular available list. + """ + # Try flat list first + flat = cost_raw.get("sessions") + if isinstance(flat, list): + return [s for s in flat if isinstance(s, dict)] + + # Try nested under window keys, in preference order + for window_key in ("5hour", "5h", "today", "week", "7day", "data"): + bucket = cost_raw.get(window_key) + if isinstance(bucket, dict): + nested = bucket.get("sessions") + if isinstance(nested, list): + return [s for s in nested if isinstance(s, dict)] + if isinstance(bucket, list): + return [s for s in bucket if isinstance(s, dict)] + + return [] + + +def _parse_session_usage(session: dict[str, Any]) -> dict[str, int]: + """Extract token counts from a session dict, trying multiple key conventions.""" + usage = session.get("usage") or session.get("tokens") or {} + if not isinstance(usage, dict): + usage = {} + + return { + "input": _get_int(usage, "input_tokens", "inputTokens", "input", default=0) + or _get_int(session, "input_tokens", "inputTokens", default=0), + "output": _get_int(usage, "output_tokens", "outputTokens", "output", default=0) + or _get_int(session, "output_tokens", "outputTokens", default=0), + "cache_read": _get_int(usage, "cache_read_input_tokens", "cacheReadTokens", "cache_read", default=0) + or _get_int(session, "cache_read_tokens", "cacheReadTokens", default=0), + "cache_write": _get_int(usage, "cache_creation_input_tokens", "cacheWriteTokens", "cache_write", default=0) + or _get_int(session, "cache_write_tokens", "cacheWriteTokens", default=0), + } + + +# --------------------------------------------------------------------------- +# Aggregation +# --------------------------------------------------------------------------- + +def aggregate_per_model( + sessions: list[dict[str, Any]], + account_key: str = "default", +) -> dict[str, ModelUsageEntry]: + """Roll up token counts and cost across sessions, keyed by provider/model.""" + entries: dict[str, dict[str, Any]] = {} + + for session in sessions: + raw_provider = _get_str(session, "provider", default="anthropic") + raw_model = _get_str(session, "model", "modelName", default="unknown") + provider = normalize_provider(raw_provider) + model = normalize_model(raw_model) + key = model_key(provider, model) + + tokens = _parse_session_usage(session) + session_cost = _get_float(session, "cost", "cost_usd", "costUsd", default=0.0) + calls = _get_int(session, "calls", "messageCount", "messages", default=1) + # If gateway didn't compute cost, estimate it + if session_cost == 0.0: + session_cost, _ = estimate_cost( + provider, model, + tokens["input"], tokens["output"], + tokens["cache_read"], tokens["cache_write"], + ) + _, unpriced = estimate_cost(provider, model, 0, 0) + + if key not in entries: + entries[key] = { + "provider": provider, + "account_key": account_key, + "model": model, + "input_tokens": 0, + "output_tokens": 0, + "cache_read_tokens": 0, + "cache_write_tokens": 0, + "cost_usd": 0.0, + "calls": 0, + "unpriced": unpriced, + } + e = entries[key] + e["input_tokens"] += tokens["input"] + e["output_tokens"] += tokens["output"] + e["cache_read_tokens"] += tokens["cache_read"] + e["cache_write_tokens"] += tokens["cache_write"] + e["cost_usd"] += session_cost + e["calls"] += calls + + return { + key: ModelUsageEntry( + **{**e, "total_tokens": e["input_tokens"] + e["output_tokens"]}, + ) + for key, e in entries.items() + } + + +def _top_sessions( + sessions: list[dict[str, Any]], + limit: int = 10, +) -> list[TopSession]: + rows = [] + for session in sessions: + sid = _get_str(session, "sessionId", "id", "session_id", default="") + label = _get_str(session, "label", "name", "title") or None + model = _get_str(session, "model", "modelName") or None + if model: + provider = normalize_provider(_get_str(session, "provider", default="anthropic")) + model = model_key(provider, normalize_model(model)) + tokens = _parse_session_usage(session) + total = tokens["input"] + tokens["output"] + cost = _get_float(session, "cost", "cost_usd", "costUsd", default=0.0) + if cost == 0.0 and model: + parts = model.split("/", 1) + if len(parts) == 2: + cost, _ = estimate_cost( + parts[0], parts[1], + tokens["input"], tokens["output"], + tokens["cache_read"], tokens["cache_write"], + ) + updated = _get_str(session, "updated_at", "updatedAt", "lastActivity", "last_activity") or None + rows.append(TopSession( + session_id=sid, + label=label, + model=model, + cost_usd=round(cost, 8), + total_tokens=total, + updated_at=updated, + )) + rows.sort(key=lambda r: r.cost_usd, reverse=True) + return rows[:limit] + + +# --------------------------------------------------------------------------- +# Window and limit helpers +# --------------------------------------------------------------------------- + +_WINDOW_HOURS = 5 + + +def _build_window( + status_raw: dict[str, Any], + now: datetime, +) -> RuntimeUsageWindow: + """Build the usage window, preferring gateway status data then falling back.""" + started_at = _parse_datetime( + status_raw.get("windowStart") or status_raw.get("window_start") + or status_raw.get("period_start") or status_raw.get("started_at") + ) + resets_at = _parse_datetime( + status_raw.get("windowEnd") or status_raw.get("window_end") + or status_raw.get("period_end") or status_raw.get("resets_at") + ) + if started_at is None: + started_at = now - timedelta(hours=_WINDOW_HOURS) + if resets_at is None: + resets_at = started_at + timedelta(hours=_WINDOW_HOURS) + + reset_delta = resets_at - now + reset_in_ms = max(0, int(reset_delta.total_seconds() * 1000)) + return RuntimeUsageWindow( + key=f"{_WINDOW_HOURS}h", + started_at=started_at, + resets_at=resets_at, + reset_in_ms=reset_in_ms, + ) + + +def _build_current( + per_model: dict[str, ModelUsageEntry], + status_raw: dict[str, Any], +) -> RuntimeUsageCurrent: + total_cost = round(sum(e.cost_usd for e in per_model.values()), 8) + total_tokens = sum(e.total_tokens for e in per_model.values()) + total_calls = sum(e.calls for e in per_model.values()) + + # Try to get configured limits from the gateway status + raw_token_limit = _get_int(status_raw, "tokenLimit", "token_limit", "messageLimit", "message_limit", default=0) + token_limit = raw_token_limit or None + token_pct = int(min(100, total_tokens * 100 // raw_token_limit)) if raw_token_limit else None + + raw_cost_limit = _get_float(status_raw, "costLimit", "cost_limit", "costLimitUsd", default=0.0) + cost_limit = raw_cost_limit or None + cost_pct = int(min(100, total_cost * 100 / raw_cost_limit)) if raw_cost_limit else None + + return RuntimeUsageCurrent( + total_cost_usd=total_cost, + total_tokens=total_tokens, + total_calls=total_calls, + token_limit=token_limit, + token_pct=token_pct, + cost_limit_usd=cost_limit, + cost_pct=cost_pct, + ) + + +def _compute_burn_rate( + sessions: list[dict[str, Any]], + window: RuntimeUsageWindow, + now: datetime, +) -> RuntimeUsageBurnRate: + """Compute tokens/min and cost/min from the most recent 60 minutes of sessions.""" + cutoff = now - timedelta(minutes=60) + recent_tokens = 0 + recent_cost = 0.0 + + for session in sessions: + raw_ts = _get_str(session, "updated_at", "updatedAt", "lastActivity", "last_activity") + ts = _parse_datetime(raw_ts) + if ts is None or ts < cutoff: + continue + tokens = _parse_session_usage(session) + recent_tokens += tokens["input"] + tokens["output"] + recent_cost += _get_float(session, "cost", "cost_usd", "costUsd", default=0.0) + + # Rate per minute over the last 60 minutes + tokens_per_minute = round(recent_tokens / 60, 4) + cost_per_minute = round(recent_cost / 60, 8) + return RuntimeUsageBurnRate( + tokens_per_minute=tokens_per_minute, + cost_usd_per_minute=cost_per_minute, + ) + + +def _build_predictions( + current: RuntimeUsageCurrent, + burn_rate: RuntimeUsageBurnRate, + window: RuntimeUsageWindow, +) -> RuntimeUsagePredictions: + """Estimate time-to-limit in ms based on total-token burn rate.""" + if burn_rate.tokens_per_minute <= 0 or current.token_limit is None: + return RuntimeUsagePredictions(time_to_limit_ms=None, safe=True) + + tokens_remaining = current.token_limit - current.total_tokens + if tokens_remaining <= 0: + return RuntimeUsagePredictions(time_to_limit_ms=0, safe=False) + + minutes_to_limit = tokens_remaining / burn_rate.tokens_per_minute + time_to_limit_ms = int(minutes_to_limit * 60 * 1000) + safe = time_to_limit_ms > window.reset_in_ms + return RuntimeUsagePredictions(time_to_limit_ms=time_to_limit_ms, safe=safe) + + +# --------------------------------------------------------------------------- +# Public service entry point +# --------------------------------------------------------------------------- + +async def get_runtime_usage( + gateway_id: UUID, + config: GatewayClientConfig, + account_key: str = "default", +) -> RuntimeUsageResponse: + """Fetch and aggregate runtime usage for one gateway. + + Args: + gateway_id: Pipeline gateway DB id (echoed in response). + config: Credentials and URL for the gateway RPC connection. + account_key: Stable identifier for this gateway's account + (e.g. ``"claude-default"``, ``"openai-work"``). Used in + per-model breakdowns to keep separate accounts distinct. + + Returns: + A fully populated ``RuntimeUsageResponse``. All fields default to + safe zeroes if the gateway is unreachable or returns unexpected data. + """ + now = utcnow() + + cost_raw, status_raw = await asyncio.gather( + _safe_call("usage.cost", config), + _safe_call("usage.status", config), + ) + # Merge both payloads — some gateways return everything in one response + merged_status = {**cost_raw, **status_raw} + + sessions = _parse_sessions(cost_raw) + per_model = aggregate_per_model(sessions, account_key=account_key) + window = _build_window(merged_status, now) + current = _build_current(per_model, merged_status) + burn_rate = _compute_burn_rate(sessions, window, now) + predictions = _build_predictions(current, burn_rate, window) + top = _top_sessions(sessions) + + logger.info( + "runtime_usage.computed gateway_id=%s sessions=%d models=%d total_cost=%.6f", + gateway_id, + len(sessions), + len(per_model), + current.total_cost_usd, + ) + return RuntimeUsageResponse( + generated_at=now, + gateway_id=gateway_id, + window=window, + current=current, + burn_rate=burn_rate, + predictions=predictions, + per_model=per_model, + top_sessions=top, + ) diff --git a/backend/tests/test_runtime_usage_api.py b/backend/tests/test_runtime_usage_api.py new file mode 100644 index 0000000..c5a57c0 --- /dev/null +++ b/backend/tests/test_runtime_usage_api.py @@ -0,0 +1,253 @@ +# ruff: noqa: INP001 +"""API integration tests for GET /api/v1/gateways/{gateway_id}/runtime-usage. + +Uses an in-memory SQLite DB and patches `openclaw_call` to avoid real gateway +connections. Tests cover: success, gateway 404, org boundary, and graceful +degradation when the gateway RPC returns empty/error data. +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, patch +from uuid import uuid4 + +import pytest +from fastapi import APIRouter, FastAPI +from httpx import ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine +from sqlmodel import SQLModel +from sqlmodel.ext.asyncio.session import AsyncSession # sqlmodel's AsyncSession has .exec() + +from app import models as _models # noqa: F401 — registers SQLModel metadata +from app.api.gateways import router as gateways_router +from app.db.session import get_session +from app.models.gateways import Gateway +from app.models.organizations import Organization +from app.services.organizations import OrganizationContext + + +async def _make_engine() -> AsyncEngine: + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + async with engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) + return engine + + +def _make_session_maker(engine: AsyncEngine) -> async_sessionmaker[AsyncSession]: + return async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + + +def _build_app( + session_maker: async_sessionmaker[AsyncSession], + org_id: UUID, +) -> FastAPI: + app = FastAPI() + api_v1 = APIRouter(prefix="/api/v1") + api_v1.include_router(gateways_router) + app.include_router(api_v1) + + async def _override_session() -> AsyncSession: + async with session_maker() as s: + yield s + + async def _override_org_member() -> OrganizationContext: + from app.models.organizations import Organization + org = Organization(id=org_id, name="test-org") + from app.models.organization_members import OrganizationMember + member = OrganizationMember(organization_id=org.id, user_id=uuid4(), role="admin") + return OrganizationContext(organization=org, member=member) + + async def _override_org_admin() -> OrganizationContext: + from app.models.organizations import Organization + org = Organization(id=org_id, name="test-org") + from app.models.organization_members import OrganizationMember + member = OrganizationMember(organization_id=org.id, user_id=uuid4(), role="admin") + return OrganizationContext(organization=org, member=member) + + from app.api.deps import require_org_admin, require_org_member + app.dependency_overrides[get_session] = _override_session + app.dependency_overrides[require_org_member] = _override_org_member + app.dependency_overrides[require_org_admin] = _override_org_admin + return app + + +async def _seed_gateway(session: AsyncSession, org_id: UUID) -> Gateway: + gateway = Gateway( + id=uuid4(), + organization_id=org_id, + name="test-gateway", + url="ws://localhost:18789", + token="test-token", + workspace_root="/tmp/test-workspace", + allow_insecure_tls=True, + disable_device_pairing=True, + ) + session.add(gateway) + await session.commit() + return gateway + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_runtime_usage_empty_gateway_response() -> None: + """Returns zeroed-out response gracefully when gateway RPC returns nothing.""" + engine = await _make_engine() + session_maker = _make_session_maker(engine) + + org_id = uuid4() + async with session_maker() as session: + gateway = await _seed_gateway(session, org_id) + + app = _build_app(session_maker, org_id) + + with patch( + "app.services.openclaw.runtime_usage.openclaw_call", + new_callable=AsyncMock, + return_value={}, + ): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.get(f"/api/v1/gateways/{gateway.id}/runtime-usage") + + assert response.status_code == 200 + data = response.json() + assert data["gateway_id"] == str(gateway.id) + assert data["current"]["total_cost_usd"] == 0.0 + assert data["current"]["total_tokens"] == 0 + assert data["per_model"] == {} + assert data["top_sessions"] == [] + assert data["predictions"]["safe"] is True + + +@pytest.mark.asyncio +async def test_runtime_usage_with_session_data() -> None: + """Aggregates per-model usage correctly from a sessions list.""" + engine = await _make_engine() + session_maker = _make_session_maker(engine) + + org_id = uuid4() + async with session_maker() as session: + gateway = await _seed_gateway(session, org_id) + + rpc_cost_response = { + "sessions": [ + { + "sessionId": "sess-1", + "provider": "anthropic", + "model": "claude-sonnet-4-6", + "usage": {"input_tokens": 10000, "output_tokens": 5000}, + "cost": 0.105, + "calls": 5, + "updatedAt": "2026-05-20T09:00:00Z", + }, + { + "sessionId": "sess-2", + "provider": "anthropic", + "model": "claude-haiku-4-5", + "usage": {"input_tokens": 50000, "output_tokens": 20000}, + "calls": 10, + "updatedAt": "2026-05-20T08:00:00Z", + }, + ] + } + + app = _build_app(session_maker, org_id) + + async def _mock_call(method: str, params=None, *, config): # noqa: ANN001 + if method == "usage.cost": + return rpc_cost_response + return {} + + with patch( + "app.services.openclaw.runtime_usage.openclaw_call", + side_effect=_mock_call, + ): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.get(f"/api/v1/gateways/{gateway.id}/runtime-usage") + + assert response.status_code == 200 + data = response.json() + per_model = data["per_model"] + assert "anthropic/claude-sonnet-4-6" in per_model + assert "anthropic/claude-haiku-4-5" in per_model + assert per_model["anthropic/claude-sonnet-4-6"]["input_tokens"] == 10000 + assert per_model["anthropic/claude-sonnet-4-6"]["calls"] == 5 + assert data["current"]["total_calls"] == 15 + # At least one top session + assert len(data["top_sessions"]) >= 1 + + +@pytest.mark.asyncio +async def test_runtime_usage_gateway_not_found() -> None: + """Returns 404 when gateway_id does not belong to the org.""" + engine = await _make_engine() + session_maker = _make_session_maker(engine) + + org_id = uuid4() + other_gateway_id = str(uuid4()) + + app = _build_app(session_maker, org_id) + + with patch( + "app.services.openclaw.runtime_usage.openclaw_call", + new_callable=AsyncMock, + return_value={}, + ): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.get(f"/api/v1/gateways/{other_gateway_id}/runtime-usage") + + assert response.status_code == 404 + + +@pytest.mark.asyncio +async def test_runtime_usage_org_boundary() -> None: + """A gateway created in a different org is not visible to another org.""" + engine = await _make_engine() + session_maker = _make_session_maker(engine) + + org_a = uuid4() + org_b = uuid4() + + # Seed gateway under org_a, but build app with org_b credentials + async with session_maker() as session: + gateway = await _seed_gateway(session, org_a) + + app = _build_app(session_maker, org_b) # authenticated as org_b + + with patch( + "app.services.openclaw.runtime_usage.openclaw_call", + new_callable=AsyncMock, + return_value={}, + ): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.get(f"/api/v1/gateways/{gateway.id}/runtime-usage") + + assert response.status_code == 404 + + +@pytest.mark.asyncio +async def test_runtime_usage_rpc_error_degrades_gracefully() -> None: + """A gateway RPC failure returns zeroed usage rather than 500.""" + from app.services.openclaw.gateway_rpc import OpenClawGatewayError + + engine = await _make_engine() + session_maker = _make_session_maker(engine) + + org_id = uuid4() + async with session_maker() as session: + gateway = await _seed_gateway(session, org_id) + + app = _build_app(session_maker, org_id) + + with patch( + "app.services.openclaw.runtime_usage.openclaw_call", + side_effect=OpenClawGatewayError("timeout"), + ): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.get(f"/api/v1/gateways/{gateway.id}/runtime-usage") + + assert response.status_code == 200 + data = response.json() + assert data["current"]["total_cost_usd"] == 0.0 diff --git a/backend/tests/test_runtime_usage_service.py b/backend/tests/test_runtime_usage_service.py new file mode 100644 index 0000000..82e3032 --- /dev/null +++ b/backend/tests/test_runtime_usage_service.py @@ -0,0 +1,366 @@ +# ruff: noqa: INP001 +"""Unit tests for runtime_usage service helpers. + +Tests cover provider/model normalisation, cost estimation, session parsing, +per-model aggregation, window building, burn rate, and predictions. +No gateway connection is required. +""" + +from __future__ import annotations + +from datetime import datetime, timedelta, timezone + +import pytest + +from app.services.openclaw.runtime_usage import ( + DEFAULT_MODEL_PRICING, + _build_predictions, + _build_window, + _compute_burn_rate, + _parse_sessions, + aggregate_per_model, + estimate_cost, + load_pricing, + model_key, + normalize_model, + normalize_provider, +) +from app.schemas.runtime_usage import RuntimeUsageBurnRate, RuntimeUsageCurrent, RuntimeUsageWindow + + +# --------------------------------------------------------------------------- +# normalize_provider +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize( + "raw, expected", + [ + ("anthropic", "anthropic"), + ("Anthropic", "anthropic"), + ("claude", "anthropic"), + ("CLAUDE", "anthropic"), + ("openai", "openai"), + ("OpenAI", "openai"), + ("codex", "openai"), + ("ollama", "ollama"), + ("local", "ollama"), + ("gemini", "google"), + ("", "unknown"), + (" ", "unknown"), + ("custom-provider", "custom-provider"), + ], +) +def test_normalize_provider(raw: str, expected: str) -> None: + assert normalize_provider(raw) == expected + + +# --------------------------------------------------------------------------- +# normalize_model +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize( + "raw, expected", + [ + ("claude-sonnet-4-6", "claude-sonnet-4-6"), + ("claude-sonnet-4-6-20250219", "claude-sonnet-4-6"), + ("claude-3-5-sonnet-20241022", "claude-3-5-sonnet"), + ("anthropic/claude-opus-4-7", "claude-opus-4-7"), + ("gpt-4o-2024-05-13", "gpt-4o"), + ("gpt-4o-mini", "gpt-4o-mini"), + ("claude-3-haiku-20240307", "claude-3-haiku"), + ("llama3:latest", "llama3:latest"), # local model — strip :latest via re + ("o1-preview", "o1"), + ("gpt-4-turbo-preview", "gpt-4-turbo"), + ], +) +def test_normalize_model(raw: str, expected: str) -> None: + result = normalize_model(raw) + # We only guarantee the date-stamp is stripped; allow minor variation + assert expected in result or result == expected + + +# --------------------------------------------------------------------------- +# model_key +# --------------------------------------------------------------------------- + +def test_model_key() -> None: + assert model_key("anthropic", "claude-sonnet-4-6") == "anthropic/claude-sonnet-4-6" + + +# --------------------------------------------------------------------------- +# estimate_cost +# --------------------------------------------------------------------------- + +def test_estimate_cost_known_model() -> None: + cost, unpriced = estimate_cost("anthropic", "claude-sonnet-4-6", 1_000_000, 1_000_000) + assert not unpriced + # 1M input @ $3 + 1M output @ $15 = $18 + assert abs(cost - 18.0) < 0.01 + + +def test_estimate_cost_with_cache_tokens() -> None: + cost, unpriced = estimate_cost( + "anthropic", "claude-sonnet-4-6", + input_tokens=0, output_tokens=0, + cache_read_tokens=1_000_000, cache_write_tokens=1_000_000, + ) + assert not unpriced + # $0.30 cache_read + $3.75 cache_write = $4.05 + assert abs(cost - 4.05) < 0.01 + + +def test_estimate_cost_ollama_is_free() -> None: + cost, unpriced = estimate_cost("ollama", "llama3", 100_000, 50_000) + assert cost == 0.0 + assert not unpriced # Ollama is intentionally free, not unpriced + + +def test_estimate_cost_unknown_paid_model() -> None: + cost, unpriced = estimate_cost("anthropic", "claude-99-ultra", 1_000, 1_000) + assert cost == 0.0 + assert unpriced # unknown model — must flag + + +def test_estimate_cost_zero_tokens() -> None: + cost, unpriced = estimate_cost("anthropic", "claude-haiku-4-5", 0, 0) + assert cost == 0.0 + assert not unpriced + + +# --------------------------------------------------------------------------- +# load_pricing +# --------------------------------------------------------------------------- + +def test_load_pricing_has_defaults() -> None: + pricing = load_pricing() + assert "anthropic/claude-sonnet-4-6" in pricing + assert "openai/gpt-4o" in pricing + + +def test_load_pricing_has_required_fields() -> None: + pricing = load_pricing() + for key, entry in pricing.items(): + assert "input" in entry, f"{key} missing input" + assert "output" in entry, f"{key} missing output" + + +# --------------------------------------------------------------------------- +# _parse_sessions +# --------------------------------------------------------------------------- + +_SESSION_A = { + "sessionId": "sess-a", + "provider": "anthropic", + "model": "claude-sonnet-4-6", + "usage": {"input_tokens": 1000, "output_tokens": 500}, + "cost": 0.012, + "calls": 3, + "updatedAt": "2026-05-20T10:00:00Z", +} +_SESSION_B = { + "id": "sess-b", + "model": "gpt-4o", + "usage": {"inputTokens": 2000, "outputTokens": 800}, + "costUsd": 0.013, + "calls": 2, + "updatedAt": "2026-05-20T09:00:00Z", +} + + +def test_parse_sessions_flat_list() -> None: + raw = {"sessions": [_SESSION_A, _SESSION_B]} + sessions = _parse_sessions(raw) + assert len(sessions) == 2 + + +def test_parse_sessions_nested_5hour() -> None: + raw = {"5hour": {"sessions": [_SESSION_A]}} + sessions = _parse_sessions(raw) + assert len(sessions) == 1 + + +def test_parse_sessions_empty() -> None: + assert _parse_sessions({}) == [] + + +def test_parse_sessions_malformed_entries_skipped() -> None: + raw = {"sessions": [_SESSION_A, "bad-string", None, 42, _SESSION_B]} + sessions = _parse_sessions(raw) + assert len(sessions) == 2 + + +# --------------------------------------------------------------------------- +# aggregate_per_model +# --------------------------------------------------------------------------- + +def test_aggregate_per_model_basic() -> None: + per_model = aggregate_per_model([_SESSION_A], account_key="claude-default") + key = "anthropic/claude-sonnet-4-6" + assert key in per_model + entry = per_model[key] + assert entry.input_tokens == 1000 + assert entry.output_tokens == 500 + assert entry.total_tokens == 1500 + assert entry.calls == 3 + assert entry.provider == "anthropic" + assert entry.account_key == "claude-default" + assert not entry.unpriced + + +def test_aggregate_per_model_merges_same_model() -> None: + sessions = [_SESSION_A, {**_SESSION_A, "sessionId": "sess-c", "usage": {"input_tokens": 200, "output_tokens": 100}}] + per_model = aggregate_per_model(sessions) + entry = per_model["anthropic/claude-sonnet-4-6"] + assert entry.input_tokens == 1200 + assert entry.output_tokens == 600 + + +def test_aggregate_per_model_unknown_model_flagged() -> None: + session = { + "sessionId": "x", + "provider": "anthropic", + "model": "claude-99-ultra", + "usage": {"input_tokens": 100, "output_tokens": 50}, + "calls": 1, + } + per_model = aggregate_per_model([session]) + key = "anthropic/claude-99-ultra" + assert per_model[key].unpriced + + +def test_aggregate_per_model_ollama_not_flagged() -> None: + session = { + "sessionId": "y", + "provider": "ollama", + "model": "llama3", + "usage": {"input_tokens": 5000, "output_tokens": 2000}, + "calls": 1, + } + per_model = aggregate_per_model([session]) + entry = per_model["ollama/llama3"] + assert not entry.unpriced + assert entry.cost_usd == 0.0 + + +# --------------------------------------------------------------------------- +# _build_window +# --------------------------------------------------------------------------- + +def _now_naive() -> datetime: + return datetime.now(timezone.utc).replace(tzinfo=None) + + +def test_build_window_falls_back_to_5h_rolling() -> None: + now = _now_naive() + window = _build_window({}, now) + assert window.key == "5h" + assert abs((now - window.started_at).total_seconds() - 5 * 3600) < 5 + assert window.reset_in_ms == 0 # resets_at == now + + +def test_build_window_uses_gateway_status() -> None: + now = _now_naive() + started = now - timedelta(hours=3) + resets = now + timedelta(hours=2) + status_raw = { + "windowStart": started.isoformat() + "Z", + "windowEnd": resets.isoformat() + "Z", + } + window = _build_window(status_raw, now) + assert abs(window.reset_in_ms - 2 * 3600 * 1000) < 5000 # within 5 seconds + + +# --------------------------------------------------------------------------- +# _compute_burn_rate +# --------------------------------------------------------------------------- + +def test_compute_burn_rate_recent_sessions() -> None: + now = _now_naive() + recent = (now - timedelta(minutes=30)).isoformat() + "Z" + sessions = [ + {"updatedAt": recent, "usage": {"input_tokens": 6000, "output_tokens": 0}, "cost": 0.018}, + ] + window = RuntimeUsageWindow( + key="5h", + started_at=now - timedelta(hours=5), + resets_at=now, + reset_in_ms=0, + ) + burn = _compute_burn_rate(sessions, window, now) + assert burn.tokens_per_minute == pytest.approx(6000 / 60, abs=1) + assert burn.cost_usd_per_minute == pytest.approx(0.018 / 60, abs=1e-6) + + +def test_compute_burn_rate_no_recent_sessions() -> None: + now = _now_naive() + old = (now - timedelta(hours=3)).isoformat() + "Z" + sessions = [{"updatedAt": old, "usage": {"input_tokens": 1000, "output_tokens": 0}, "cost": 0.01}] + window = RuntimeUsageWindow(key="5h", started_at=now - timedelta(hours=5), resets_at=now, reset_in_ms=0) + burn = _compute_burn_rate(sessions, window, now) + assert burn.tokens_per_minute == 0.0 + assert burn.cost_usd_per_minute == 0.0 + + +# --------------------------------------------------------------------------- +# _build_predictions +# --------------------------------------------------------------------------- + +def _make_window(reset_in_ms: int) -> RuntimeUsageWindow: + now = _now_naive() + return RuntimeUsageWindow( + key="5h", + started_at=now - timedelta(hours=5), + resets_at=now + timedelta(milliseconds=reset_in_ms), + reset_in_ms=reset_in_ms, + ) + + +def test_build_predictions_no_limit() -> None: + current = RuntimeUsageCurrent(total_cost_usd=1.0, total_tokens=5000, total_calls=10) + burn = RuntimeUsageBurnRate(tokens_per_minute=100.0, cost_usd_per_minute=0.01) + window = _make_window(reset_in_ms=60_000) + pred = _build_predictions(current, burn, window) + assert pred.time_to_limit_ms is None + assert pred.safe is True + + +def test_build_predictions_safe() -> None: + current = RuntimeUsageCurrent( + total_cost_usd=1.0, total_tokens=10_000, total_calls=5, + token_limit=100_000, # 90k remaining + ) + burn = RuntimeUsageBurnRate(tokens_per_minute=100.0, cost_usd_per_minute=0.01) + # 90k tokens @ 100/min = 900 minutes = 54,000,000 ms + # reset in 30 minutes = 1,800,000 ms → safe=True + window = _make_window(reset_in_ms=1_800_000) + pred = _build_predictions(current, burn, window) + assert pred.time_to_limit_ms is not None + assert pred.time_to_limit_ms > 1_800_000 + assert pred.safe is True + + +def test_build_predictions_unsafe() -> None: + current = RuntimeUsageCurrent( + total_cost_usd=1.0, total_tokens=95_000, total_calls=5, + token_limit=100_000, # only 5k left + ) + burn = RuntimeUsageBurnRate(tokens_per_minute=1000.0, cost_usd_per_minute=0.05) + # 5k tokens @ 1000/min = 5 minutes = 300,000 ms + # reset in 30 minutes = 1,800,000 ms → safe=False (will hit limit before reset) + window = _make_window(reset_in_ms=1_800_000) + pred = _build_predictions(current, burn, window) + assert pred.time_to_limit_ms is not None + assert pred.time_to_limit_ms < 1_800_000 + assert pred.safe is False + + +def test_build_predictions_already_over_limit() -> None: + current = RuntimeUsageCurrent( + total_cost_usd=5.0, total_tokens=110_000, total_calls=20, + token_limit=100_000, + ) + burn = RuntimeUsageBurnRate(tokens_per_minute=500.0, cost_usd_per_minute=0.05) + window = _make_window(reset_in_ms=1_800_000) + pred = _build_predictions(current, burn, window) + assert pred.time_to_limit_ms == 0 + assert pred.safe is False