"""Persistence and OAuth refresh helpers for ChatGPT provider tokens.

Tokens are stored as JSON in ``TOKENS_FILE`` with two slots: the ``active``
account and an optional queued ``next_account``.
"""

import asyncio
import json
import logging
import os
import tempfile
import time
from pathlib import Path
from typing import Any

import aiohttp

from providers.base import ProviderTokens

logger = logging.getLogger(__name__)

DATA_DIR = Path(os.environ.get("DATA_DIR", "./data"))
TOKENS_FILE = DATA_DIR / "chatgpt_tokens.json"

CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
TOKEN_URL = "https://auth.openai.com/oauth/token"


def _tokens_to_dict(tokens: ProviderTokens) -> dict[str, Any]:
    """Serialize *tokens* into the plain dict layout stored on disk."""
    return {
        "access_token": tokens.access_token,
        "refresh_token": tokens.refresh_token,
        "expires_at": tokens.expires_at,
    }


def _dict_to_tokens(data: dict[str, Any] | None) -> ProviderTokens | None:
    """Rebuild ``ProviderTokens`` from a stored dict.

    Returns ``None`` when *data* is not a dict, lacks one of the three
    required keys, or the constructor rejects the values with ``TypeError``.
    """
    if not isinstance(data, dict):
        return None
    try:
        return ProviderTokens(
            access_token=data["access_token"],
            refresh_token=data["refresh_token"],
            expires_at=data["expires_at"],
        )
    except (KeyError, TypeError):
        return None


def _load_raw() -> dict[str, Any] | None:
    """Read ``TOKENS_FILE`` and return its top-level JSON object.

    Returns ``None`` when the file is absent, is not valid UTF-8 JSON, or
    holds a JSON value that is not an object. Other ``OSError``s (e.g.
    permission problems) propagate to the caller.
    """
    if not TOKENS_FILE.exists():
        return None
    try:
        with open(TOKENS_FILE, encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        # The file was removed between the exists() check and open().
        return None
    except ValueError as e:
        # Covers json.JSONDecodeError and UnicodeDecodeError (both ValueError).
        logger.warning("Ignoring unreadable token file %s: %s", TOKENS_FILE, e)
        return None
    return data if isinstance(data, dict) else None


def _save_raw(data: dict[str, Any]) -> None:
    """Atomically write *data* as JSON to ``TOKENS_FILE``.

    The JSON is written to a temp file in the same directory, flushed and
    fsync'd, then moved over the target with ``os.replace`` so readers never
    observe a half-written file. ``tempfile.mkstemp`` creates the temp file
    readable/writable only by the creating user. The temp file is removed if
    anything fails before the replace.
    """
    TOKENS_FILE.parent.mkdir(parents=True, exist_ok=True)
    fd, tmp_path = tempfile.mkstemp(
        prefix=f"{TOKENS_FILE.name}.",
        suffix=".tmp",
        dir=str(TOKENS_FILE.parent),
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, TOKENS_FILE)
    finally:
        # After a successful os.replace the temp path no longer exists.
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)


def _normalize_state(data: dict[str, Any] | None) -> dict[str, Any]:
    """Coerce raw file contents into ``{"active": ..., "next_account": ...}``.

    Empty or missing data yields both slots as ``None``. A dict that has
    neither an ``active`` nor a ``next_account`` key is taken to be a single
    flat token dict (presumably an older file layout) and becomes ``active``.
    """
    if not data:
        return {"active": None, "next_account": None}
    if "active" in data or "next_account" in data:
        return {
            "active": data.get("active"),
            "next_account": data.get("next_account"),
        }
    return {"active": data, "next_account": None}


def load_state() -> tuple[ProviderTokens | None, ProviderTokens | None]:
    """Return ``(active, next_account)`` from disk; invalid slots become ``None``."""
    normalized = _normalize_state(_load_raw())
    active = _dict_to_tokens(normalized.get("active"))
    next_account = _dict_to_tokens(normalized.get("next_account"))
    return active, next_account


def save_state(
    active: ProviderTokens | None, next_account: ProviderTokens | None
) -> None:
    """Persist both token slots, writing ``null`` for any falsy slot."""
    payload = {
        "active": _tokens_to_dict(active) if active else None,
        "next_account": _tokens_to_dict(next_account) if next_account else None,
    }
    _save_raw(payload)


def load_tokens() -> ProviderTokens | None:
    """Return the active account's tokens, or ``None`` if none are stored."""
    active, _ = load_state()
    return active


def load_next_tokens() -> ProviderTokens | None:
    """Return the queued next account's tokens, or ``None`` if none are stored."""
    _, next_account = load_state()
    return next_account


# NOTE: the mutators below do an unlocked load -> modify -> save cycle, so
# concurrent writers (other tasks or processes) can overwrite each other.


def save_tokens(tokens: ProviderTokens) -> None:
    """Replace the active tokens, keeping any stored next account."""
    _, next_account = load_state()
    save_state(tokens, next_account)


def promote_next_tokens() -> bool:
    """Make the queued next account active and clear the next slot.

    Returns ``True`` if a next account existed and was promoted,
    ``False`` (with nothing written) otherwise.
    """
    _, next_account = load_state()
    if not next_account:
        return False
    save_state(next_account, None)
    return True


def clear_next_tokens() -> None:
    """Drop the queued next account, keeping the active one."""
    active, _ = load_state()
    save_state(active, None)


async def refresh_tokens(refresh_token: str) -> ProviderTokens | None:
    """Exchange *refresh_token* for a new token set at ``TOKEN_URL``.

    Sends a form-encoded OAuth ``refresh_token`` grant with a 15-second total
    timeout. If the response omits ``refresh_token`` (RFC 6749 §6 makes
    issuing a new one optional), the token that was sent is kept.

    Returns:
        New ``ProviderTokens`` whose ``expires_at`` is ``time.time()`` plus the
        server's ``expires_in`` (seconds), or ``None`` when the request fails,
        times out, returns a non-2xx status, or the response cannot be parsed.
        Each of those failures is logged as a warning rather than raised.
    """
    data = {
        "grant_type": "refresh_token",
        "refresh_token": refresh_token,
        "client_id": CLIENT_ID,
    }
    timeout = aiohttp.ClientTimeout(total=15)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(TOKEN_URL, data=data) as resp:
                if not resp.ok:
                    text = await resp.text()
                    logger.warning("Token refresh failed: %s %s", resp.status, text)
                    return None
                json_resp = await resp.json()
    # asyncio.TimeoutError is only an alias of TimeoutError from Python 3.11;
    # on 3.10 aiohttp's total timeout raises the asyncio class.
    except (aiohttp.ClientError, asyncio.TimeoutError, TimeoutError) as e:
        logger.warning("Token refresh request error: %s", e)
        return None
    except Exception as e:
        logger.warning("Token refresh unexpected error: %s", e)
        return None

    try:
        # A non-object JSON body fails the subscript below with TypeError.
        expires_in = int(json_resp["expires_in"])
        new_refresh_token = json_resp.get("refresh_token") or refresh_token
        return ProviderTokens(
            access_token=json_resp["access_token"],
            refresh_token=new_refresh_token,
            expires_at=time.time() + expires_in,
        )
    except (KeyError, TypeError, ValueError) as e:
        logger.warning("Token refresh response parse error: %s", e)
        return None