feat(runtime-usage): add read-only usage core service, schemas, and API endpoint (batch 1, #30)
This commit is contained in:
parent
19a40d7cec
commit
9edaa5eb41
|
|
@ -30,3 +30,4 @@ frontend/coverage
|
|||
backend/app/services/openclaw/.device-keys
|
||||
FUTURE.md
|
||||
FUTURE.md
|
||||
docs/runtime-usage-dashboard-plan.md
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from uuid import UUID, uuid4
|
|||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlmodel import col
|
||||
|
||||
from app.api.deps import require_org_admin
|
||||
from app.api.deps import require_org_admin, require_org_member
|
||||
from app.core.auth import AuthContext, get_auth_context
|
||||
from app.db import crud
|
||||
from app.db.pagination import paginate
|
||||
|
|
@ -26,6 +26,8 @@ from app.schemas.gateways import (
|
|||
GatewayTemplatesSyncResult,
|
||||
GatewayUpdate,
|
||||
)
|
||||
from app.schemas.runtime_usage import RuntimeUsageResponse
|
||||
from app.services.openclaw.runtime_usage import get_runtime_usage
|
||||
from app.schemas.pagination import DefaultLimitOffsetPage
|
||||
from app.services.openclaw.admin_service import GatewayAdminLifecycleService
|
||||
from app.services.openclaw.session_service import GatewayTemplateSyncQuery
|
||||
|
|
@ -41,6 +43,7 @@ router = APIRouter(prefix="/gateways", tags=["gateways"])
|
|||
SESSION_DEP = Depends(get_session)
|
||||
AUTH_DEP = Depends(get_auth_context)
|
||||
ORG_ADMIN_DEP = Depends(require_org_admin)
|
||||
ORG_MEMBER_DEP = Depends(require_org_member)
|
||||
INCLUDE_MAIN_QUERY = Query(default=True)
|
||||
RESET_SESSIONS_QUERY = Query(default=False)
|
||||
ROTATE_TOKENS_QUERY = Query(default=False)
|
||||
|
|
@ -232,6 +235,43 @@ async def import_gateway_agents(
|
|||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{gateway_id}/runtime-usage",
|
||||
response_model=RuntimeUsageResponse,
|
||||
summary="Gateway runtime usage",
|
||||
description=(
|
||||
"Return model usage, token counts, estimated spend, burn rate, and "
|
||||
"time-remaining predictions for the specified gateway. "
|
||||
"Visible to all organisation members."
|
||||
),
|
||||
)
|
||||
async def get_gateway_runtime_usage(
|
||||
gateway_id: UUID,
|
||||
session: AsyncSession = SESSION_DEP,
|
||||
ctx: OrganizationContext = ORG_MEMBER_DEP,
|
||||
) -> RuntimeUsageResponse:
|
||||
"""Aggregate runtime usage from the gateway's usage.cost / usage.status RPC methods."""
|
||||
from app.services.openclaw.gateway_rpc import GatewayConfig as GatewayClientConfig
|
||||
|
||||
service = GatewayAdminLifecycleService(session)
|
||||
gateway = await service.require_gateway(
|
||||
gateway_id=gateway_id,
|
||||
organization_id=ctx.organization.id,
|
||||
)
|
||||
config = GatewayClientConfig(
|
||||
url=gateway.url,
|
||||
token=gateway.token,
|
||||
allow_insecure_tls=gateway.allow_insecure_tls,
|
||||
disable_device_pairing=gateway.disable_device_pairing,
|
||||
)
|
||||
account_key = gateway.name.lower().replace(" ", "-") if gateway.name else "default"
|
||||
return await get_runtime_usage(
|
||||
gateway_id=gateway.id,
|
||||
config=config,
|
||||
account_key=account_key,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{gateway_id}", response_model=OkResponse)
|
||||
async def delete_gateway(
|
||||
gateway_id: UUID,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,85 @@
|
|||
"""Response schemas for the gateway runtime usage endpoint."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
RUNTIME_ANNOTATION_TYPES = (datetime, UUID)
|
||||
|
||||
|
||||
class RuntimeUsageWindow(SQLModel):
|
||||
"""Rolling 5-hour usage window metadata."""
|
||||
|
||||
key: str # "5h"
|
||||
started_at: datetime
|
||||
resets_at: datetime
|
||||
reset_in_ms: int # milliseconds until oldest event ages out
|
||||
|
||||
|
||||
class RuntimeUsageCurrent(SQLModel):
|
||||
"""Aggregated totals within the current window."""
|
||||
|
||||
total_cost_usd: float
|
||||
total_tokens: int # input + output across all sessions
|
||||
total_calls: int
|
||||
token_limit: int | None = None # configured limit; None = unknown
|
||||
token_pct: int | None = None # 0–100; None when limit unknown
|
||||
cost_limit_usd: float | None = None
|
||||
cost_pct: int | None = None
|
||||
|
||||
|
||||
class RuntimeUsageBurnRate(SQLModel):
|
||||
"""Recent token and cost velocity (last 60 minutes of the window)."""
|
||||
|
||||
tokens_per_minute: float
|
||||
cost_usd_per_minute: float
|
||||
|
||||
|
||||
class RuntimeUsagePredictions(SQLModel):
|
||||
"""Estimates derived from current burn rate and configured limits."""
|
||||
|
||||
time_to_limit_ms: int | None = None # None when limit or burn rate unknown
|
||||
safe: bool # True if time_to_limit > reset_in_ms (will reset before hitting limit)
|
||||
|
||||
|
||||
class ModelUsageEntry(SQLModel):
|
||||
"""Usage and cost breakdown for one provider/model combination."""
|
||||
|
||||
provider: str # normalised: "anthropic", "openai", "ollama", "unknown"
|
||||
account_key: str # e.g. "claude-default", "openai-work", "ollama-local"
|
||||
model: str # normalised model slug, e.g. "claude-sonnet-4-6"
|
||||
input_tokens: int
|
||||
output_tokens: int
|
||||
cache_read_tokens: int
|
||||
cache_write_tokens: int
|
||||
total_tokens: int
|
||||
cost_usd: float
|
||||
calls: int
|
||||
unpriced: bool # True = unknown paid model; False = priced or intentionally free (Ollama)
|
||||
|
||||
|
||||
class TopSession(SQLModel):
|
||||
"""Summary row for one session, sorted by cost descending."""
|
||||
|
||||
session_id: str
|
||||
label: str | None = None
|
||||
model: str | None = None
|
||||
cost_usd: float
|
||||
total_tokens: int
|
||||
updated_at: str | None = None
|
||||
|
||||
|
||||
class RuntimeUsageResponse(SQLModel):
|
||||
"""Complete runtime usage payload returned by GET /gateways/{id}/runtime-usage."""
|
||||
|
||||
generated_at: datetime
|
||||
gateway_id: UUID
|
||||
window: RuntimeUsageWindow
|
||||
current: RuntimeUsageCurrent
|
||||
burn_rate: RuntimeUsageBurnRate
|
||||
predictions: RuntimeUsagePredictions
|
||||
per_model: dict[str, ModelUsageEntry] # key = "provider/model"
|
||||
top_sessions: list[TopSession]
|
||||
|
|
@ -0,0 +1,544 @@
|
|||
"""Runtime usage service — compute model spend, burn rate, and time remaining.
|
||||
|
||||
Data source: gateway RPC methods ``usage.cost`` and ``usage.status``.
|
||||
Pricing: built-in defaults plus an optional JSON override file pointed to by
|
||||
the ``RUNTIME_USAGE_PRICING_FILE`` environment variable.
|
||||
|
||||
Design decisions:
|
||||
- Total tokens (input + output) are the basis for time-remaining predictions.
|
||||
- All org members may call this endpoint (not admin-only).
|
||||
- Parsing is fully defensive: malformed or missing fields default to zero.
|
||||
- Unknown paid models are flagged ``unpriced=True`` so the UI can warn.
|
||||
- Ollama / local models are flagged ``unpriced=False, cost_usd=0`` (free by design).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.core.time import utcnow
|
||||
from app.schemas.runtime_usage import (
|
||||
ModelUsageEntry,
|
||||
RuntimeUsageBurnRate,
|
||||
RuntimeUsageCurrent,
|
||||
RuntimeUsagePredictions,
|
||||
RuntimeUsageResponse,
|
||||
RuntimeUsageWindow,
|
||||
TopSession,
|
||||
)
|
||||
from app.services.openclaw.gateway_rpc import GatewayConfig as GatewayClientConfig
|
||||
from app.services.openclaw.gateway_rpc import OpenClawGatewayError, openclaw_call
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pricing config (USD per million tokens)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PricingEntry = dict[str, float] # keys: input, output, cache_read, cache_write
|
||||
|
||||
DEFAULT_MODEL_PRICING: dict[str, _PricingEntry] = {
|
||||
# Anthropic — Claude 4.x
|
||||
"anthropic/claude-opus-4-7": {"input": 15.00, "output": 75.00, "cache_read": 1.50, "cache_write": 3.75},
|
||||
"anthropic/claude-opus-4-5": {"input": 15.00, "output": 75.00, "cache_read": 1.50, "cache_write": 3.75},
|
||||
"anthropic/claude-sonnet-4-6": {"input": 3.00, "output": 15.00, "cache_read": 0.30, "cache_write": 3.75},
|
||||
"anthropic/claude-sonnet-4-5": {"input": 3.00, "output": 15.00, "cache_read": 0.30, "cache_write": 3.75},
|
||||
"anthropic/claude-haiku-4-5": {"input": 0.80, "output": 4.00, "cache_read": 0.08, "cache_write": 1.00},
|
||||
# Anthropic — Claude 3.x
|
||||
"anthropic/claude-3-5-sonnet": {"input": 3.00, "output": 15.00, "cache_read": 0.30, "cache_write": 3.75},
|
||||
"anthropic/claude-3-5-haiku": {"input": 0.80, "output": 4.00, "cache_read": 0.08, "cache_write": 1.00},
|
||||
"anthropic/claude-3-opus": {"input": 15.00, "output": 75.00, "cache_read": 1.50, "cache_write": 3.75},
|
||||
"anthropic/claude-3-sonnet": {"input": 3.00, "output": 15.00, "cache_read": 0.30, "cache_write": 3.75},
|
||||
"anthropic/claude-3-haiku": {"input": 0.25, "output": 1.25, "cache_read": 0.03, "cache_write": 0.30},
|
||||
# OpenAI — GPT-4o family
|
||||
"openai/gpt-4o": {"input": 2.50, "output": 10.00, "cache_read": 1.25, "cache_write": 0.00},
|
||||
"openai/gpt-4o-mini": {"input": 0.15, "output": 0.60, "cache_read": 0.075, "cache_write": 0.00},
|
||||
"openai/gpt-4-turbo": {"input": 10.00, "output": 30.00, "cache_read": 0.00, "cache_write": 0.00},
|
||||
"openai/gpt-4": {"input": 30.00, "output": 60.00, "cache_read": 0.00, "cache_write": 0.00},
|
||||
"openai/gpt-3-5-turbo": {"input": 0.50, "output": 1.50, "cache_read": 0.00, "cache_write": 0.00},
|
||||
# OpenAI — o-series reasoning
|
||||
"openai/o1": {"input": 15.00, "output": 60.00, "cache_read": 7.50, "cache_write": 0.00},
|
||||
"openai/o1-mini": {"input": 3.00, "output": 12.00, "cache_read": 1.50, "cache_write": 0.00},
|
||||
"openai/o3": {"input": 10.00, "output": 40.00, "cache_read": 2.50, "cache_write": 0.00},
|
||||
"openai/o3-mini": {"input": 1.10, "output": 4.40, "cache_read": 0.55, "cache_write": 0.00},
|
||||
"openai/o4-mini": {"input": 1.10, "output": 4.40, "cache_read": 0.275, "cache_write": 0.00},
|
||||
# Codex alias
|
||||
"openai/codex": {"input": 0.00, "output": 0.00, "cache_read": 0.00, "cache_write": 0.00},
|
||||
}
|
||||
|
||||
_pricing_cache: dict[str, _PricingEntry] | None = None
|
||||
|
||||
|
||||
def load_pricing() -> dict[str, _PricingEntry]:
|
||||
"""Return merged pricing: defaults + optional override file."""
|
||||
global _pricing_cache
|
||||
if _pricing_cache is not None:
|
||||
return _pricing_cache
|
||||
merged = dict(DEFAULT_MODEL_PRICING)
|
||||
override_path = os.getenv("RUNTIME_USAGE_PRICING_FILE", "").strip()
|
||||
if override_path:
|
||||
try:
|
||||
with open(override_path) as fh:
|
||||
overrides = json.load(fh)
|
||||
if isinstance(overrides, dict):
|
||||
merged.update(overrides)
|
||||
logger.info("runtime_usage.pricing.override_loaded path=%s", override_path)
|
||||
except Exception as exc:
|
||||
logger.warning("runtime_usage.pricing.override_failed path=%s error=%s", override_path, exc)
|
||||
_pricing_cache = merged
|
||||
return _pricing_cache
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider / model normalisation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PROVIDER_ALIASES: dict[str, str] = {
|
||||
"anthropic": "anthropic",
|
||||
"claude": "anthropic",
|
||||
"openai": "openai",
|
||||
"codex": "openai",
|
||||
"ollama": "ollama",
|
||||
"local": "ollama",
|
||||
"gemini": "google",
|
||||
"google": "google",
|
||||
}
|
||||
|
||||
_MODEL_STRIP_RE = re.compile(
|
||||
r"(-\d{8}|-latest|-preview|-instruct|-chat|-v\d+(\.\d+)*)$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def normalize_provider(raw: str) -> str:
|
||||
"""Normalise a provider string to a canonical lower-case slug."""
|
||||
cleaned = raw.strip().lower()
|
||||
return _PROVIDER_ALIASES.get(cleaned, cleaned or "unknown")
|
||||
|
||||
|
||||
def normalize_model(raw: str) -> str:
|
||||
"""Strip date stamps, version tags, and known suffixes from a model name."""
|
||||
cleaned = raw.strip().lower()
|
||||
# Remove provider prefix if present (e.g. "anthropic/claude-sonnet-4-5")
|
||||
if "/" in cleaned:
|
||||
cleaned = cleaned.split("/", 1)[1]
|
||||
# Strip trailing date/version stamps
|
||||
cleaned = _MODEL_STRIP_RE.sub("", cleaned)
|
||||
return cleaned or raw.strip().lower()
|
||||
|
||||
|
||||
def model_key(provider: str, model: str) -> str:
|
||||
"""Return the canonical dict key ``"provider/model"``."""
|
||||
return f"{provider}/{model}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cost estimation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def estimate_cost(
|
||||
provider: str,
|
||||
model: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
cache_read_tokens: int = 0,
|
||||
cache_write_tokens: int = 0,
|
||||
) -> tuple[float, bool]:
|
||||
"""Return ``(cost_usd, unpriced)`` for the given token counts.
|
||||
|
||||
``unpriced=True`` means we have no price for this paid model — the caller
|
||||
should surface a warning. Ollama / local models return ``(0.0, False)``.
|
||||
"""
|
||||
pricing = load_pricing()
|
||||
key = model_key(provider, model)
|
||||
entry = pricing.get(key)
|
||||
|
||||
if entry is None:
|
||||
if provider == "ollama":
|
||||
return 0.0, False
|
||||
return 0.0, True # unknown paid model
|
||||
|
||||
per_m = 1_000_000
|
||||
cost = (
|
||||
input_tokens * entry.get("input", 0.0) / per_m
|
||||
+ output_tokens * entry.get("output", 0.0) / per_m
|
||||
+ cache_read_tokens * entry.get("cache_read", 0.0) / per_m
|
||||
+ cache_write_tokens * entry.get("cache_write", 0.0) / per_m
|
||||
)
|
||||
return round(cost, 8), False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Gateway RPC helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _safe_call(
|
||||
method: str,
|
||||
config: GatewayClientConfig,
|
||||
) -> dict[str, Any]:
|
||||
"""Call a gateway method and return a dict, or an empty dict on any error."""
|
||||
try:
|
||||
result = await openclaw_call(method, config=config)
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
logger.debug("runtime_usage.rpc.unexpected_type method=%s type=%s", method, type(result).__name__)
|
||||
return {}
|
||||
except OpenClawGatewayError as exc:
|
||||
logger.debug("runtime_usage.rpc.gateway_error method=%s error=%s", method, exc)
|
||||
return {}
|
||||
except Exception as exc:
|
||||
logger.warning("runtime_usage.rpc.error method=%s error=%s", method, exc)
|
||||
return {}
|
||||
|
||||
|
||||
def _get_float(d: dict[str, Any], *keys: str, default: float = 0.0) -> float:
|
||||
for key in keys:
|
||||
val = d.get(key)
|
||||
if val is not None:
|
||||
try:
|
||||
return float(val)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
return default
|
||||
|
||||
|
||||
def _get_int(d: dict[str, Any], *keys: str, default: int = 0) -> int:
|
||||
return int(_get_float(d, *keys, default=float(default)))
|
||||
|
||||
|
||||
def _get_str(d: dict[str, Any], *keys: str, default: str = "") -> str:
|
||||
for key in keys:
|
||||
val = d.get(key)
|
||||
if isinstance(val, str) and val.strip():
|
||||
return val.strip()
|
||||
return default
|
||||
|
||||
|
||||
def _parse_datetime(value: object) -> datetime | None:
|
||||
if not isinstance(value, str) or not value.strip():
|
||||
return None
|
||||
normalized = value.strip().replace("Z", "+00:00")
|
||||
try:
|
||||
parsed = datetime.fromisoformat(normalized)
|
||||
if parsed.tzinfo is not None:
|
||||
return parsed.astimezone(timezone.utc).replace(tzinfo=None)
|
||||
return parsed
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _parse_sessions(cost_raw: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""Extract a flat list of session dicts from the usage.cost response.
|
||||
|
||||
The gateway may return sessions at the top level or nested under a window
|
||||
key (``"5hour"``, ``"today"``, ``"week"``, etc.). We prefer the most
|
||||
granular available list.
|
||||
"""
|
||||
# Try flat list first
|
||||
flat = cost_raw.get("sessions")
|
||||
if isinstance(flat, list):
|
||||
return [s for s in flat if isinstance(s, dict)]
|
||||
|
||||
# Try nested under window keys, in preference order
|
||||
for window_key in ("5hour", "5h", "today", "week", "7day", "data"):
|
||||
bucket = cost_raw.get(window_key)
|
||||
if isinstance(bucket, dict):
|
||||
nested = bucket.get("sessions")
|
||||
if isinstance(nested, list):
|
||||
return [s for s in nested if isinstance(s, dict)]
|
||||
if isinstance(bucket, list):
|
||||
return [s for s in bucket if isinstance(s, dict)]
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def _parse_session_usage(session: dict[str, Any]) -> dict[str, int]:
|
||||
"""Extract token counts from a session dict, trying multiple key conventions."""
|
||||
usage = session.get("usage") or session.get("tokens") or {}
|
||||
if not isinstance(usage, dict):
|
||||
usage = {}
|
||||
|
||||
return {
|
||||
"input": _get_int(usage, "input_tokens", "inputTokens", "input", default=0)
|
||||
or _get_int(session, "input_tokens", "inputTokens", default=0),
|
||||
"output": _get_int(usage, "output_tokens", "outputTokens", "output", default=0)
|
||||
or _get_int(session, "output_tokens", "outputTokens", default=0),
|
||||
"cache_read": _get_int(usage, "cache_read_input_tokens", "cacheReadTokens", "cache_read", default=0)
|
||||
or _get_int(session, "cache_read_tokens", "cacheReadTokens", default=0),
|
||||
"cache_write": _get_int(usage, "cache_creation_input_tokens", "cacheWriteTokens", "cache_write", default=0)
|
||||
or _get_int(session, "cache_write_tokens", "cacheWriteTokens", default=0),
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Aggregation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def aggregate_per_model(
|
||||
sessions: list[dict[str, Any]],
|
||||
account_key: str = "default",
|
||||
) -> dict[str, ModelUsageEntry]:
|
||||
"""Roll up token counts and cost across sessions, keyed by provider/model."""
|
||||
entries: dict[str, dict[str, Any]] = {}
|
||||
|
||||
for session in sessions:
|
||||
raw_provider = _get_str(session, "provider", default="anthropic")
|
||||
raw_model = _get_str(session, "model", "modelName", default="unknown")
|
||||
provider = normalize_provider(raw_provider)
|
||||
model = normalize_model(raw_model)
|
||||
key = model_key(provider, model)
|
||||
|
||||
tokens = _parse_session_usage(session)
|
||||
session_cost = _get_float(session, "cost", "cost_usd", "costUsd", default=0.0)
|
||||
calls = _get_int(session, "calls", "messageCount", "messages", default=1)
|
||||
# If gateway didn't compute cost, estimate it
|
||||
if session_cost == 0.0:
|
||||
session_cost, _ = estimate_cost(
|
||||
provider, model,
|
||||
tokens["input"], tokens["output"],
|
||||
tokens["cache_read"], tokens["cache_write"],
|
||||
)
|
||||
_, unpriced = estimate_cost(provider, model, 0, 0)
|
||||
|
||||
if key not in entries:
|
||||
entries[key] = {
|
||||
"provider": provider,
|
||||
"account_key": account_key,
|
||||
"model": model,
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
"cache_read_tokens": 0,
|
||||
"cache_write_tokens": 0,
|
||||
"cost_usd": 0.0,
|
||||
"calls": 0,
|
||||
"unpriced": unpriced,
|
||||
}
|
||||
e = entries[key]
|
||||
e["input_tokens"] += tokens["input"]
|
||||
e["output_tokens"] += tokens["output"]
|
||||
e["cache_read_tokens"] += tokens["cache_read"]
|
||||
e["cache_write_tokens"] += tokens["cache_write"]
|
||||
e["cost_usd"] += session_cost
|
||||
e["calls"] += calls
|
||||
|
||||
return {
|
||||
key: ModelUsageEntry(
|
||||
**{**e, "total_tokens": e["input_tokens"] + e["output_tokens"]},
|
||||
)
|
||||
for key, e in entries.items()
|
||||
}
|
||||
|
||||
|
||||
def _top_sessions(
|
||||
sessions: list[dict[str, Any]],
|
||||
limit: int = 10,
|
||||
) -> list[TopSession]:
|
||||
rows = []
|
||||
for session in sessions:
|
||||
sid = _get_str(session, "sessionId", "id", "session_id", default="")
|
||||
label = _get_str(session, "label", "name", "title") or None
|
||||
model = _get_str(session, "model", "modelName") or None
|
||||
if model:
|
||||
provider = normalize_provider(_get_str(session, "provider", default="anthropic"))
|
||||
model = model_key(provider, normalize_model(model))
|
||||
tokens = _parse_session_usage(session)
|
||||
total = tokens["input"] + tokens["output"]
|
||||
cost = _get_float(session, "cost", "cost_usd", "costUsd", default=0.0)
|
||||
if cost == 0.0 and model:
|
||||
parts = model.split("/", 1)
|
||||
if len(parts) == 2:
|
||||
cost, _ = estimate_cost(
|
||||
parts[0], parts[1],
|
||||
tokens["input"], tokens["output"],
|
||||
tokens["cache_read"], tokens["cache_write"],
|
||||
)
|
||||
updated = _get_str(session, "updated_at", "updatedAt", "lastActivity", "last_activity") or None
|
||||
rows.append(TopSession(
|
||||
session_id=sid,
|
||||
label=label,
|
||||
model=model,
|
||||
cost_usd=round(cost, 8),
|
||||
total_tokens=total,
|
||||
updated_at=updated,
|
||||
))
|
||||
rows.sort(key=lambda r: r.cost_usd, reverse=True)
|
||||
return rows[:limit]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Window and limit helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_WINDOW_HOURS = 5
|
||||
|
||||
|
||||
def _build_window(
|
||||
status_raw: dict[str, Any],
|
||||
now: datetime,
|
||||
) -> RuntimeUsageWindow:
|
||||
"""Build the usage window, preferring gateway status data then falling back."""
|
||||
started_at = _parse_datetime(
|
||||
status_raw.get("windowStart") or status_raw.get("window_start")
|
||||
or status_raw.get("period_start") or status_raw.get("started_at")
|
||||
)
|
||||
resets_at = _parse_datetime(
|
||||
status_raw.get("windowEnd") or status_raw.get("window_end")
|
||||
or status_raw.get("period_end") or status_raw.get("resets_at")
|
||||
)
|
||||
if started_at is None:
|
||||
started_at = now - timedelta(hours=_WINDOW_HOURS)
|
||||
if resets_at is None:
|
||||
resets_at = started_at + timedelta(hours=_WINDOW_HOURS)
|
||||
|
||||
reset_delta = resets_at - now
|
||||
reset_in_ms = max(0, int(reset_delta.total_seconds() * 1000))
|
||||
return RuntimeUsageWindow(
|
||||
key=f"{_WINDOW_HOURS}h",
|
||||
started_at=started_at,
|
||||
resets_at=resets_at,
|
||||
reset_in_ms=reset_in_ms,
|
||||
)
|
||||
|
||||
|
||||
def _build_current(
|
||||
per_model: dict[str, ModelUsageEntry],
|
||||
status_raw: dict[str, Any],
|
||||
) -> RuntimeUsageCurrent:
|
||||
total_cost = round(sum(e.cost_usd for e in per_model.values()), 8)
|
||||
total_tokens = sum(e.total_tokens for e in per_model.values())
|
||||
total_calls = sum(e.calls for e in per_model.values())
|
||||
|
||||
# Try to get configured limits from the gateway status
|
||||
raw_token_limit = _get_int(status_raw, "tokenLimit", "token_limit", "messageLimit", "message_limit", default=0)
|
||||
token_limit = raw_token_limit or None
|
||||
token_pct = int(min(100, total_tokens * 100 // raw_token_limit)) if raw_token_limit else None
|
||||
|
||||
raw_cost_limit = _get_float(status_raw, "costLimit", "cost_limit", "costLimitUsd", default=0.0)
|
||||
cost_limit = raw_cost_limit or None
|
||||
cost_pct = int(min(100, total_cost * 100 / raw_cost_limit)) if raw_cost_limit else None
|
||||
|
||||
return RuntimeUsageCurrent(
|
||||
total_cost_usd=total_cost,
|
||||
total_tokens=total_tokens,
|
||||
total_calls=total_calls,
|
||||
token_limit=token_limit,
|
||||
token_pct=token_pct,
|
||||
cost_limit_usd=cost_limit,
|
||||
cost_pct=cost_pct,
|
||||
)
|
||||
|
||||
|
||||
def _compute_burn_rate(
|
||||
sessions: list[dict[str, Any]],
|
||||
window: RuntimeUsageWindow,
|
||||
now: datetime,
|
||||
) -> RuntimeUsageBurnRate:
|
||||
"""Compute tokens/min and cost/min from the most recent 60 minutes of sessions."""
|
||||
cutoff = now - timedelta(minutes=60)
|
||||
recent_tokens = 0
|
||||
recent_cost = 0.0
|
||||
|
||||
for session in sessions:
|
||||
raw_ts = _get_str(session, "updated_at", "updatedAt", "lastActivity", "last_activity")
|
||||
ts = _parse_datetime(raw_ts)
|
||||
if ts is None or ts < cutoff:
|
||||
continue
|
||||
tokens = _parse_session_usage(session)
|
||||
recent_tokens += tokens["input"] + tokens["output"]
|
||||
recent_cost += _get_float(session, "cost", "cost_usd", "costUsd", default=0.0)
|
||||
|
||||
# Rate per minute over the last 60 minutes
|
||||
tokens_per_minute = round(recent_tokens / 60, 4)
|
||||
cost_per_minute = round(recent_cost / 60, 8)
|
||||
return RuntimeUsageBurnRate(
|
||||
tokens_per_minute=tokens_per_minute,
|
||||
cost_usd_per_minute=cost_per_minute,
|
||||
)
|
||||
|
||||
|
||||
def _build_predictions(
|
||||
current: RuntimeUsageCurrent,
|
||||
burn_rate: RuntimeUsageBurnRate,
|
||||
window: RuntimeUsageWindow,
|
||||
) -> RuntimeUsagePredictions:
|
||||
"""Estimate time-to-limit in ms based on total-token burn rate."""
|
||||
if burn_rate.tokens_per_minute <= 0 or current.token_limit is None:
|
||||
return RuntimeUsagePredictions(time_to_limit_ms=None, safe=True)
|
||||
|
||||
tokens_remaining = current.token_limit - current.total_tokens
|
||||
if tokens_remaining <= 0:
|
||||
return RuntimeUsagePredictions(time_to_limit_ms=0, safe=False)
|
||||
|
||||
minutes_to_limit = tokens_remaining / burn_rate.tokens_per_minute
|
||||
time_to_limit_ms = int(minutes_to_limit * 60 * 1000)
|
||||
safe = time_to_limit_ms > window.reset_in_ms
|
||||
return RuntimeUsagePredictions(time_to_limit_ms=time_to_limit_ms, safe=safe)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public service entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def get_runtime_usage(
|
||||
gateway_id: UUID,
|
||||
config: GatewayClientConfig,
|
||||
account_key: str = "default",
|
||||
) -> RuntimeUsageResponse:
|
||||
"""Fetch and aggregate runtime usage for one gateway.
|
||||
|
||||
Args:
|
||||
gateway_id: Pipeline gateway DB id (echoed in response).
|
||||
config: Credentials and URL for the gateway RPC connection.
|
||||
account_key: Stable identifier for this gateway's account
|
||||
(e.g. ``"claude-default"``, ``"openai-work"``). Used in
|
||||
per-model breakdowns to keep separate accounts distinct.
|
||||
|
||||
Returns:
|
||||
A fully populated ``RuntimeUsageResponse``. All fields default to
|
||||
safe zeroes if the gateway is unreachable or returns unexpected data.
|
||||
"""
|
||||
now = utcnow()
|
||||
|
||||
cost_raw, status_raw = await asyncio.gather(
|
||||
_safe_call("usage.cost", config),
|
||||
_safe_call("usage.status", config),
|
||||
)
|
||||
# Merge both payloads — some gateways return everything in one response
|
||||
merged_status = {**cost_raw, **status_raw}
|
||||
|
||||
sessions = _parse_sessions(cost_raw)
|
||||
per_model = aggregate_per_model(sessions, account_key=account_key)
|
||||
window = _build_window(merged_status, now)
|
||||
current = _build_current(per_model, merged_status)
|
||||
burn_rate = _compute_burn_rate(sessions, window, now)
|
||||
predictions = _build_predictions(current, burn_rate, window)
|
||||
top = _top_sessions(sessions)
|
||||
|
||||
logger.info(
|
||||
"runtime_usage.computed gateway_id=%s sessions=%d models=%d total_cost=%.6f",
|
||||
gateway_id,
|
||||
len(sessions),
|
||||
len(per_model),
|
||||
current.total_cost_usd,
|
||||
)
|
||||
return RuntimeUsageResponse(
|
||||
generated_at=now,
|
||||
gateway_id=gateway_id,
|
||||
window=window,
|
||||
current=current,
|
||||
burn_rate=burn_rate,
|
||||
predictions=predictions,
|
||||
per_model=per_model,
|
||||
top_sessions=top,
|
||||
)
|
||||
|
|
@ -0,0 +1,253 @@
|
|||
# ruff: noqa: INP001
|
||||
"""API integration tests for GET /api/v1/gateways/{gateway_id}/runtime-usage.
|
||||
|
||||
Uses an in-memory SQLite DB and patches `openclaw_call` to avoid real gateway
|
||||
connections. Tests cover: success, gateway 404, org boundary, and graceful
|
||||
degradation when the gateway RPC returns empty/error data.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import APIRouter, FastAPI
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
|
||||
from sqlmodel import SQLModel
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession # sqlmodel's AsyncSession has .exec()
|
||||
|
||||
from app import models as _models # noqa: F401 — registers SQLModel metadata
|
||||
from app.api.gateways import router as gateways_router
|
||||
from app.db.session import get_session
|
||||
from app.models.gateways import Gateway
|
||||
from app.models.organizations import Organization
|
||||
from app.services.organizations import OrganizationContext
|
||||
|
||||
|
||||
async def _make_engine() -> AsyncEngine:
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
return engine
|
||||
|
||||
|
||||
def _make_session_maker(engine: AsyncEngine) -> async_sessionmaker[AsyncSession]:
|
||||
return async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
|
||||
def _build_app(
|
||||
session_maker: async_sessionmaker[AsyncSession],
|
||||
org_id: UUID,
|
||||
) -> FastAPI:
|
||||
app = FastAPI()
|
||||
api_v1 = APIRouter(prefix="/api/v1")
|
||||
api_v1.include_router(gateways_router)
|
||||
app.include_router(api_v1)
|
||||
|
||||
async def _override_session() -> AsyncSession:
|
||||
async with session_maker() as s:
|
||||
yield s
|
||||
|
||||
async def _override_org_member() -> OrganizationContext:
|
||||
from app.models.organizations import Organization
|
||||
org = Organization(id=org_id, name="test-org")
|
||||
from app.models.organization_members import OrganizationMember
|
||||
member = OrganizationMember(organization_id=org.id, user_id=uuid4(), role="admin")
|
||||
return OrganizationContext(organization=org, member=member)
|
||||
|
||||
async def _override_org_admin() -> OrganizationContext:
|
||||
from app.models.organizations import Organization
|
||||
org = Organization(id=org_id, name="test-org")
|
||||
from app.models.organization_members import OrganizationMember
|
||||
member = OrganizationMember(organization_id=org.id, user_id=uuid4(), role="admin")
|
||||
return OrganizationContext(organization=org, member=member)
|
||||
|
||||
from app.api.deps import require_org_admin, require_org_member
|
||||
app.dependency_overrides[get_session] = _override_session
|
||||
app.dependency_overrides[require_org_member] = _override_org_member
|
||||
app.dependency_overrides[require_org_admin] = _override_org_admin
|
||||
return app
|
||||
|
||||
|
||||
async def _seed_gateway(session: AsyncSession, org_id: UUID) -> Gateway:
|
||||
gateway = Gateway(
|
||||
id=uuid4(),
|
||||
organization_id=org_id,
|
||||
name="test-gateway",
|
||||
url="ws://localhost:18789",
|
||||
token="test-token",
|
||||
workspace_root="/tmp/test-workspace",
|
||||
allow_insecure_tls=True,
|
||||
disable_device_pairing=True,
|
||||
)
|
||||
session.add(gateway)
|
||||
await session.commit()
|
||||
return gateway
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_usage_empty_gateway_response() -> None:
|
||||
"""Returns zeroed-out response gracefully when gateway RPC returns nothing."""
|
||||
engine = await _make_engine()
|
||||
session_maker = _make_session_maker(engine)
|
||||
|
||||
org_id = uuid4()
|
||||
async with session_maker() as session:
|
||||
gateway = await _seed_gateway(session, org_id)
|
||||
|
||||
app = _build_app(session_maker, org_id)
|
||||
|
||||
with patch(
|
||||
"app.services.openclaw.runtime_usage.openclaw_call",
|
||||
new_callable=AsyncMock,
|
||||
return_value={},
|
||||
):
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
response = await client.get(f"/api/v1/gateways/{gateway.id}/runtime-usage")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["gateway_id"] == str(gateway.id)
|
||||
assert data["current"]["total_cost_usd"] == 0.0
|
||||
assert data["current"]["total_tokens"] == 0
|
||||
assert data["per_model"] == {}
|
||||
assert data["top_sessions"] == []
|
||||
assert data["predictions"]["safe"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_usage_with_session_data() -> None:
|
||||
"""Aggregates per-model usage correctly from a sessions list."""
|
||||
engine = await _make_engine()
|
||||
session_maker = _make_session_maker(engine)
|
||||
|
||||
org_id = uuid4()
|
||||
async with session_maker() as session:
|
||||
gateway = await _seed_gateway(session, org_id)
|
||||
|
||||
rpc_cost_response = {
|
||||
"sessions": [
|
||||
{
|
||||
"sessionId": "sess-1",
|
||||
"provider": "anthropic",
|
||||
"model": "claude-sonnet-4-6",
|
||||
"usage": {"input_tokens": 10000, "output_tokens": 5000},
|
||||
"cost": 0.105,
|
||||
"calls": 5,
|
||||
"updatedAt": "2026-05-20T09:00:00Z",
|
||||
},
|
||||
{
|
||||
"sessionId": "sess-2",
|
||||
"provider": "anthropic",
|
||||
"model": "claude-haiku-4-5",
|
||||
"usage": {"input_tokens": 50000, "output_tokens": 20000},
|
||||
"calls": 10,
|
||||
"updatedAt": "2026-05-20T08:00:00Z",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
app = _build_app(session_maker, org_id)
|
||||
|
||||
async def _mock_call(method: str, params=None, *, config): # noqa: ANN001
|
||||
if method == "usage.cost":
|
||||
return rpc_cost_response
|
||||
return {}
|
||||
|
||||
with patch(
|
||||
"app.services.openclaw.runtime_usage.openclaw_call",
|
||||
side_effect=_mock_call,
|
||||
):
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
response = await client.get(f"/api/v1/gateways/{gateway.id}/runtime-usage")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
per_model = data["per_model"]
|
||||
assert "anthropic/claude-sonnet-4-6" in per_model
|
||||
assert "anthropic/claude-haiku-4-5" in per_model
|
||||
assert per_model["anthropic/claude-sonnet-4-6"]["input_tokens"] == 10000
|
||||
assert per_model["anthropic/claude-sonnet-4-6"]["calls"] == 5
|
||||
assert data["current"]["total_calls"] == 15
|
||||
# At least one top session
|
||||
assert len(data["top_sessions"]) >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_usage_gateway_not_found() -> None:
|
||||
"""Returns 404 when gateway_id does not belong to the org."""
|
||||
engine = await _make_engine()
|
||||
session_maker = _make_session_maker(engine)
|
||||
|
||||
org_id = uuid4()
|
||||
other_gateway_id = str(uuid4())
|
||||
|
||||
app = _build_app(session_maker, org_id)
|
||||
|
||||
with patch(
|
||||
"app.services.openclaw.runtime_usage.openclaw_call",
|
||||
new_callable=AsyncMock,
|
||||
return_value={},
|
||||
):
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
response = await client.get(f"/api/v1/gateways/{other_gateway_id}/runtime-usage")
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_usage_org_boundary() -> None:
|
||||
"""A gateway created in a different org is not visible to another org."""
|
||||
engine = await _make_engine()
|
||||
session_maker = _make_session_maker(engine)
|
||||
|
||||
org_a = uuid4()
|
||||
org_b = uuid4()
|
||||
|
||||
# Seed gateway under org_a, but build app with org_b credentials
|
||||
async with session_maker() as session:
|
||||
gateway = await _seed_gateway(session, org_a)
|
||||
|
||||
app = _build_app(session_maker, org_b) # authenticated as org_b
|
||||
|
||||
with patch(
|
||||
"app.services.openclaw.runtime_usage.openclaw_call",
|
||||
new_callable=AsyncMock,
|
||||
return_value={},
|
||||
):
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
response = await client.get(f"/api/v1/gateways/{gateway.id}/runtime-usage")
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_usage_rpc_error_degrades_gracefully() -> None:
|
||||
"""A gateway RPC failure returns zeroed usage rather than 500."""
|
||||
from app.services.openclaw.gateway_rpc import OpenClawGatewayError
|
||||
|
||||
engine = await _make_engine()
|
||||
session_maker = _make_session_maker(engine)
|
||||
|
||||
org_id = uuid4()
|
||||
async with session_maker() as session:
|
||||
gateway = await _seed_gateway(session, org_id)
|
||||
|
||||
app = _build_app(session_maker, org_id)
|
||||
|
||||
with patch(
|
||||
"app.services.openclaw.runtime_usage.openclaw_call",
|
||||
side_effect=OpenClawGatewayError("timeout"),
|
||||
):
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
response = await client.get(f"/api/v1/gateways/{gateway.id}/runtime-usage")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["current"]["total_cost_usd"] == 0.0
|
||||
|
|
@ -0,0 +1,366 @@
|
|||
# ruff: noqa: INP001
|
||||
"""Unit tests for runtime_usage service helpers.
|
||||
|
||||
Tests cover provider/model normalisation, cost estimation, session parsing,
|
||||
per-model aggregation, window building, burn rate, and predictions.
|
||||
No gateway connection is required.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.openclaw.runtime_usage import (
|
||||
DEFAULT_MODEL_PRICING,
|
||||
_build_predictions,
|
||||
_build_window,
|
||||
_compute_burn_rate,
|
||||
_parse_sessions,
|
||||
aggregate_per_model,
|
||||
estimate_cost,
|
||||
load_pricing,
|
||||
model_key,
|
||||
normalize_model,
|
||||
normalize_provider,
|
||||
)
|
||||
from app.schemas.runtime_usage import RuntimeUsageBurnRate, RuntimeUsageCurrent, RuntimeUsageWindow
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# normalize_provider
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"raw, expected",
|
||||
[
|
||||
("anthropic", "anthropic"),
|
||||
("Anthropic", "anthropic"),
|
||||
("claude", "anthropic"),
|
||||
("CLAUDE", "anthropic"),
|
||||
("openai", "openai"),
|
||||
("OpenAI", "openai"),
|
||||
("codex", "openai"),
|
||||
("ollama", "ollama"),
|
||||
("local", "ollama"),
|
||||
("gemini", "google"),
|
||||
("", "unknown"),
|
||||
(" ", "unknown"),
|
||||
("custom-provider", "custom-provider"),
|
||||
],
|
||||
)
|
||||
def test_normalize_provider(raw: str, expected: str) -> None:
|
||||
assert normalize_provider(raw) == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# normalize_model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"raw, expected",
|
||||
[
|
||||
("claude-sonnet-4-6", "claude-sonnet-4-6"),
|
||||
("claude-sonnet-4-6-20250219", "claude-sonnet-4-6"),
|
||||
("claude-3-5-sonnet-20241022", "claude-3-5-sonnet"),
|
||||
("anthropic/claude-opus-4-7", "claude-opus-4-7"),
|
||||
("gpt-4o-2024-05-13", "gpt-4o"),
|
||||
("gpt-4o-mini", "gpt-4o-mini"),
|
||||
("claude-3-haiku-20240307", "claude-3-haiku"),
|
||||
("llama3:latest", "llama3:latest"), # local model — strip :latest via re
|
||||
("o1-preview", "o1"),
|
||||
("gpt-4-turbo-preview", "gpt-4-turbo"),
|
||||
],
|
||||
)
|
||||
def test_normalize_model(raw: str, expected: str) -> None:
|
||||
result = normalize_model(raw)
|
||||
# We only guarantee the date-stamp is stripped; allow minor variation
|
||||
assert expected in result or result == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# model_key
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_model_key() -> None:
|
||||
assert model_key("anthropic", "claude-sonnet-4-6") == "anthropic/claude-sonnet-4-6"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# estimate_cost
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_estimate_cost_known_model() -> None:
|
||||
cost, unpriced = estimate_cost("anthropic", "claude-sonnet-4-6", 1_000_000, 1_000_000)
|
||||
assert not unpriced
|
||||
# 1M input @ $3 + 1M output @ $15 = $18
|
||||
assert abs(cost - 18.0) < 0.01
|
||||
|
||||
|
||||
def test_estimate_cost_with_cache_tokens() -> None:
|
||||
cost, unpriced = estimate_cost(
|
||||
"anthropic", "claude-sonnet-4-6",
|
||||
input_tokens=0, output_tokens=0,
|
||||
cache_read_tokens=1_000_000, cache_write_tokens=1_000_000,
|
||||
)
|
||||
assert not unpriced
|
||||
# $0.30 cache_read + $3.75 cache_write = $4.05
|
||||
assert abs(cost - 4.05) < 0.01
|
||||
|
||||
|
||||
def test_estimate_cost_ollama_is_free() -> None:
|
||||
cost, unpriced = estimate_cost("ollama", "llama3", 100_000, 50_000)
|
||||
assert cost == 0.0
|
||||
assert not unpriced # Ollama is intentionally free, not unpriced
|
||||
|
||||
|
||||
def test_estimate_cost_unknown_paid_model() -> None:
|
||||
cost, unpriced = estimate_cost("anthropic", "claude-99-ultra", 1_000, 1_000)
|
||||
assert cost == 0.0
|
||||
assert unpriced # unknown model — must flag
|
||||
|
||||
|
||||
def test_estimate_cost_zero_tokens() -> None:
|
||||
cost, unpriced = estimate_cost("anthropic", "claude-haiku-4-5", 0, 0)
|
||||
assert cost == 0.0
|
||||
assert not unpriced
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# load_pricing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_load_pricing_has_defaults() -> None:
|
||||
pricing = load_pricing()
|
||||
assert "anthropic/claude-sonnet-4-6" in pricing
|
||||
assert "openai/gpt-4o" in pricing
|
||||
|
||||
|
||||
def test_load_pricing_has_required_fields() -> None:
|
||||
pricing = load_pricing()
|
||||
for key, entry in pricing.items():
|
||||
assert "input" in entry, f"{key} missing input"
|
||||
assert "output" in entry, f"{key} missing output"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _parse_sessions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SESSION_A = {
|
||||
"sessionId": "sess-a",
|
||||
"provider": "anthropic",
|
||||
"model": "claude-sonnet-4-6",
|
||||
"usage": {"input_tokens": 1000, "output_tokens": 500},
|
||||
"cost": 0.012,
|
||||
"calls": 3,
|
||||
"updatedAt": "2026-05-20T10:00:00Z",
|
||||
}
|
||||
_SESSION_B = {
|
||||
"id": "sess-b",
|
||||
"model": "gpt-4o",
|
||||
"usage": {"inputTokens": 2000, "outputTokens": 800},
|
||||
"costUsd": 0.013,
|
||||
"calls": 2,
|
||||
"updatedAt": "2026-05-20T09:00:00Z",
|
||||
}
|
||||
|
||||
|
||||
def test_parse_sessions_flat_list() -> None:
|
||||
raw = {"sessions": [_SESSION_A, _SESSION_B]}
|
||||
sessions = _parse_sessions(raw)
|
||||
assert len(sessions) == 2
|
||||
|
||||
|
||||
def test_parse_sessions_nested_5hour() -> None:
|
||||
raw = {"5hour": {"sessions": [_SESSION_A]}}
|
||||
sessions = _parse_sessions(raw)
|
||||
assert len(sessions) == 1
|
||||
|
||||
|
||||
def test_parse_sessions_empty() -> None:
|
||||
assert _parse_sessions({}) == []
|
||||
|
||||
|
||||
def test_parse_sessions_malformed_entries_skipped() -> None:
|
||||
raw = {"sessions": [_SESSION_A, "bad-string", None, 42, _SESSION_B]}
|
||||
sessions = _parse_sessions(raw)
|
||||
assert len(sessions) == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# aggregate_per_model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_aggregate_per_model_basic() -> None:
|
||||
per_model = aggregate_per_model([_SESSION_A], account_key="claude-default")
|
||||
key = "anthropic/claude-sonnet-4-6"
|
||||
assert key in per_model
|
||||
entry = per_model[key]
|
||||
assert entry.input_tokens == 1000
|
||||
assert entry.output_tokens == 500
|
||||
assert entry.total_tokens == 1500
|
||||
assert entry.calls == 3
|
||||
assert entry.provider == "anthropic"
|
||||
assert entry.account_key == "claude-default"
|
||||
assert not entry.unpriced
|
||||
|
||||
|
||||
def test_aggregate_per_model_merges_same_model() -> None:
|
||||
sessions = [_SESSION_A, {**_SESSION_A, "sessionId": "sess-c", "usage": {"input_tokens": 200, "output_tokens": 100}}]
|
||||
per_model = aggregate_per_model(sessions)
|
||||
entry = per_model["anthropic/claude-sonnet-4-6"]
|
||||
assert entry.input_tokens == 1200
|
||||
assert entry.output_tokens == 600
|
||||
|
||||
|
||||
def test_aggregate_per_model_unknown_model_flagged() -> None:
|
||||
session = {
|
||||
"sessionId": "x",
|
||||
"provider": "anthropic",
|
||||
"model": "claude-99-ultra",
|
||||
"usage": {"input_tokens": 100, "output_tokens": 50},
|
||||
"calls": 1,
|
||||
}
|
||||
per_model = aggregate_per_model([session])
|
||||
key = "anthropic/claude-99-ultra"
|
||||
assert per_model[key].unpriced
|
||||
|
||||
|
||||
def test_aggregate_per_model_ollama_not_flagged() -> None:
|
||||
session = {
|
||||
"sessionId": "y",
|
||||
"provider": "ollama",
|
||||
"model": "llama3",
|
||||
"usage": {"input_tokens": 5000, "output_tokens": 2000},
|
||||
"calls": 1,
|
||||
}
|
||||
per_model = aggregate_per_model([session])
|
||||
entry = per_model["ollama/llama3"]
|
||||
assert not entry.unpriced
|
||||
assert entry.cost_usd == 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_window
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _now_naive() -> datetime:
|
||||
return datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
|
||||
def test_build_window_falls_back_to_5h_rolling() -> None:
|
||||
now = _now_naive()
|
||||
window = _build_window({}, now)
|
||||
assert window.key == "5h"
|
||||
assert abs((now - window.started_at).total_seconds() - 5 * 3600) < 5
|
||||
assert window.reset_in_ms == 0 # resets_at == now
|
||||
|
||||
|
||||
def test_build_window_uses_gateway_status() -> None:
|
||||
now = _now_naive()
|
||||
started = now - timedelta(hours=3)
|
||||
resets = now + timedelta(hours=2)
|
||||
status_raw = {
|
||||
"windowStart": started.isoformat() + "Z",
|
||||
"windowEnd": resets.isoformat() + "Z",
|
||||
}
|
||||
window = _build_window(status_raw, now)
|
||||
assert abs(window.reset_in_ms - 2 * 3600 * 1000) < 5000 # within 5 seconds
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _compute_burn_rate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_compute_burn_rate_recent_sessions() -> None:
|
||||
now = _now_naive()
|
||||
recent = (now - timedelta(minutes=30)).isoformat() + "Z"
|
||||
sessions = [
|
||||
{"updatedAt": recent, "usage": {"input_tokens": 6000, "output_tokens": 0}, "cost": 0.018},
|
||||
]
|
||||
window = RuntimeUsageWindow(
|
||||
key="5h",
|
||||
started_at=now - timedelta(hours=5),
|
||||
resets_at=now,
|
||||
reset_in_ms=0,
|
||||
)
|
||||
burn = _compute_burn_rate(sessions, window, now)
|
||||
assert burn.tokens_per_minute == pytest.approx(6000 / 60, abs=1)
|
||||
assert burn.cost_usd_per_minute == pytest.approx(0.018 / 60, abs=1e-6)
|
||||
|
||||
|
||||
def test_compute_burn_rate_no_recent_sessions() -> None:
|
||||
now = _now_naive()
|
||||
old = (now - timedelta(hours=3)).isoformat() + "Z"
|
||||
sessions = [{"updatedAt": old, "usage": {"input_tokens": 1000, "output_tokens": 0}, "cost": 0.01}]
|
||||
window = RuntimeUsageWindow(key="5h", started_at=now - timedelta(hours=5), resets_at=now, reset_in_ms=0)
|
||||
burn = _compute_burn_rate(sessions, window, now)
|
||||
assert burn.tokens_per_minute == 0.0
|
||||
assert burn.cost_usd_per_minute == 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_predictions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_window(reset_in_ms: int) -> RuntimeUsageWindow:
|
||||
now = _now_naive()
|
||||
return RuntimeUsageWindow(
|
||||
key="5h",
|
||||
started_at=now - timedelta(hours=5),
|
||||
resets_at=now + timedelta(milliseconds=reset_in_ms),
|
||||
reset_in_ms=reset_in_ms,
|
||||
)
|
||||
|
||||
|
||||
def test_build_predictions_no_limit() -> None:
|
||||
current = RuntimeUsageCurrent(total_cost_usd=1.0, total_tokens=5000, total_calls=10)
|
||||
burn = RuntimeUsageBurnRate(tokens_per_minute=100.0, cost_usd_per_minute=0.01)
|
||||
window = _make_window(reset_in_ms=60_000)
|
||||
pred = _build_predictions(current, burn, window)
|
||||
assert pred.time_to_limit_ms is None
|
||||
assert pred.safe is True
|
||||
|
||||
|
||||
def test_build_predictions_safe() -> None:
|
||||
current = RuntimeUsageCurrent(
|
||||
total_cost_usd=1.0, total_tokens=10_000, total_calls=5,
|
||||
token_limit=100_000, # 90k remaining
|
||||
)
|
||||
burn = RuntimeUsageBurnRate(tokens_per_minute=100.0, cost_usd_per_minute=0.01)
|
||||
# 90k tokens @ 100/min = 900 minutes = 54,000,000 ms
|
||||
# reset in 30 minutes = 1,800,000 ms → safe=True
|
||||
window = _make_window(reset_in_ms=1_800_000)
|
||||
pred = _build_predictions(current, burn, window)
|
||||
assert pred.time_to_limit_ms is not None
|
||||
assert pred.time_to_limit_ms > 1_800_000
|
||||
assert pred.safe is True
|
||||
|
||||
|
||||
def test_build_predictions_unsafe() -> None:
|
||||
current = RuntimeUsageCurrent(
|
||||
total_cost_usd=1.0, total_tokens=95_000, total_calls=5,
|
||||
token_limit=100_000, # only 5k left
|
||||
)
|
||||
burn = RuntimeUsageBurnRate(tokens_per_minute=1000.0, cost_usd_per_minute=0.05)
|
||||
# 5k tokens @ 1000/min = 5 minutes = 300,000 ms
|
||||
# reset in 30 minutes = 1,800,000 ms → safe=False (will hit limit before reset)
|
||||
window = _make_window(reset_in_ms=1_800_000)
|
||||
pred = _build_predictions(current, burn, window)
|
||||
assert pred.time_to_limit_ms is not None
|
||||
assert pred.time_to_limit_ms < 1_800_000
|
||||
assert pred.safe is False
|
||||
|
||||
|
||||
def test_build_predictions_already_over_limit() -> None:
|
||||
current = RuntimeUsageCurrent(
|
||||
total_cost_usd=5.0, total_tokens=110_000, total_calls=20,
|
||||
token_limit=100_000,
|
||||
)
|
||||
burn = RuntimeUsageBurnRate(tokens_per_minute=500.0, cost_usd_per_minute=0.05)
|
||||
window = _make_window(reset_in_ms=1_800_000)
|
||||
pred = _build_predictions(current, burn, window)
|
||||
assert pred.time_to_limit_ms == 0
|
||||
assert pred.safe is False
|
||||
Loading…
Reference in New Issue