diff --git a/backend/app/services/provider_usage.py b/backend/app/services/provider_usage.py index 664fe2b..5a81930 100644 --- a/backend/app/services/provider_usage.py +++ b/backend/app/services/provider_usage.py @@ -38,6 +38,7 @@ avoid hammering provider APIs on every page load. from __future__ import annotations +import asyncio import json as _json_module import os import re @@ -54,6 +55,7 @@ from app.core.time import utcnow logger = get_logger(__name__) CACHE_TTL_SECONDS = 60 +CACHE_TTL_FAILURE_SECONDS = 5 # short TTL for results with no subscription windows REQUEST_TIMEOUT = 8.0 # seconds @@ -772,8 +774,7 @@ async def _fetch_anthropic_subscription(session_key: str) -> list[SubscriptionWi logger.warning("provider_usage.subscription.anthropic.fetch_failed error=%s", exc) return [] if resp.status_code == 429 and attempt == 0: - import asyncio as _asyncio - await _asyncio.sleep(1.5) + await asyncio.sleep(1.5) continue break @@ -908,58 +909,44 @@ async def _fetch_codex_subscription(session_key: str) -> tuple[list[Subscription # --------------------------------------------------------------------------- -# In-memory TTL cache +# In-memory TTL cache + in-flight deduplication # --------------------------------------------------------------------------- -_cache: dict[str, tuple[datetime, ProviderUsageLive]] = {} +_cache: dict[str, tuple[datetime, ProviderUsageLive, int]] = {} +# Tracks in-progress fetches so concurrent requests share one result instead +# of each racing to hit the provider API (which triggers 429 cascades). +_inflight: dict[str, asyncio.Future[ProviderUsageLive]] = {} def _get_cached(credential_id: str) -> ProviderUsageLive | None: entry = _cache.get(credential_id) if entry is None: return None - cached_at, result = entry - if (utcnow() - cached_at).total_seconds() > CACHE_TTL_SECONDS: + cached_at, result, ttl = entry + if (utcnow() - cached_at).total_seconds() > ttl: del _cache[credential_id] return None return result -def _set_cached(credential_id: str, result: ProviderUsageLive) -> None: - _cache[credential_id] = (utcnow(), result) +def _set_cached(credential_id: str, result: ProviderUsageLive, ttl: int = CACHE_TTL_SECONDS) -> None: + _cache[credential_id] = (utcnow(), result, ttl) # --------------------------------------------------------------------------- # Public entry point # --------------------------------------------------------------------------- -async def fetch_provider_usage( +async def _do_fetch_provider_usage( credential_id: str, provider: str, account_key: str, api_key: str | None, base_url: str | None, - *, - session_key: str | None = None, - force_refresh: bool = False, + session_key: str | None, ) -> ProviderUsageLive: - """Fetch live usage from the provider API. - - When ``session_key`` is provided, also fetches subscription-plan usage - windows (e.g. "Current session 77% used, resets in 23 min") in addition - to the standard API rate-limit diagnostics. - - Results are cached for CACHE_TTL_SECONDS. Pass force_refresh=True to - bypass the cache (e.g., when the user clicks Refresh). - """ - if not force_refresh: - cached = _get_cached(credential_id) - if cached is not None: - return cached - - # Tracks whether a subscription-window fetch was attempted so we can skip - # caching on transient empty-window failures (avoids stale navbar state). - _subscription_attempted = False + """Inner fetch — called once per credential even under concurrent load.""" + subscription_attempted = False if provider == "anthropic": # Auto-detected local OAuth token (sk-ant-oat01-) takes precedence — it is @@ -984,7 +971,7 @@ async def fetch_provider_usage( ) # Overlay subscription windows from OAuth token (auto or explicit) if effective_session_key and result.reachable is not False: - _subscription_attempted = True + subscription_attempted = True sub_windows = await _fetch_anthropic_subscription(effective_session_key) if sub_windows: result.subscription_windows = sub_windows @@ -1012,7 +999,7 @@ async def fetch_provider_usage( local_codex = _read_codex_local_token() effective_codex_key = session_key or local_codex if effective_codex_key and result.reachable is not False: - _subscription_attempted = True + subscription_attempted = True sub_windows, plan_label = await _fetch_codex_subscription(effective_codex_key) if sub_windows: result.subscription_windows = sub_windows @@ -1032,18 +1019,60 @@ async def fetch_provider_usage( ) result.account_key = account_key - # Skip caching when subscription windows were expected but came back empty — - # a transient failure at startup would otherwise be served stale for up to - # CACHE_TTL_SECONDS, causing the navbar to show no usage data. - if _subscription_attempted and not result.subscription_windows: - logger.debug( - "provider_usage.skip_cache provider=%s reason=empty_subscription_windows", - provider, - ) - else: - _set_cached(credential_id, result) + # Use a short TTL when subscription windows were expected but came back empty + # (e.g. a transient 429 at startup). This avoids persisting a 60s stale result + # while still preventing the thundering-herd that occurs with no caching at all. + ttl = CACHE_TTL_FAILURE_SECONDS if (subscription_attempted and not result.subscription_windows) else CACHE_TTL_SECONDS + _set_cached(credential_id, result, ttl=ttl) logger.info( - "provider_usage.checked provider=%s account=%s reachable=%s error=%s", - provider, account_key, result.reachable, result.error, + "provider_usage.checked provider=%s account=%s reachable=%s windows=%d error=%s", + provider, account_key, result.reachable, len(result.subscription_windows), result.error, ) return result + + +async def fetch_provider_usage( + credential_id: str, + provider: str, + account_key: str, + api_key: str | None, + base_url: str | None, + *, + session_key: str | None = None, + force_refresh: bool = False, +) -> ProviderUsageLive: + """Fetch live usage from the provider API. + + When ``session_key`` is provided, also fetches subscription-plan usage + windows (e.g. "Current session 77% used, resets in 23 min") in addition + to the standard API rate-limit diagnostics. + + Results are cached for CACHE_TTL_SECONDS (short CACHE_TTL_FAILURE_SECONDS + when subscription windows are unavailable). Concurrent requests for the same + credential share one in-flight fetch to avoid rate-limit cascades. + Pass force_refresh=True to bypass the cache. + """ + if not force_refresh: + cached = _get_cached(credential_id) + if cached is not None: + return cached + + # In-flight deduplication: if another coroutine is already fetching this + # credential, await its result rather than racing to hit the provider API. + loop = asyncio.get_event_loop() + if credential_id in _inflight: + return await asyncio.shield(_inflight[credential_id]) + + fut: asyncio.Future[ProviderUsageLive] = loop.create_future() + _inflight[credential_id] = fut + try: + result = await _do_fetch_provider_usage( + credential_id, provider, account_key, api_key, base_url, session_key, + ) + fut.set_result(result) + return result + except Exception as exc: + fut.set_exception(exc) + raise + finally: + _inflight.pop(credential_id, None)