161 lines
4.5 KiB
Python
161 lines
4.5 KiB
Python
import asyncio
import json
import logging
import os
import tempfile
import time
from pathlib import Path
from typing import Any

import aiohttp

from providers.base import ProviderTokens
|
|
|
|
logger = logging.getLogger(__name__)

# Token storage lives under $DATA_DIR, falling back to ./data (relative to
# the current working directory) when the variable is unset.
DATA_DIR = Path(os.getenv("DATA_DIR", "./data"))
TOKENS_FILE = DATA_DIR.joinpath("chatgpt_tokens.json")

# OAuth client id and token endpoint used by refresh_tokens() for the
# refresh_token grant.
CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
TOKEN_URL = "https://auth.openai.com/oauth/token"
|
|
|
|
|
|
def _tokens_to_dict(tokens: ProviderTokens) -> dict[str, Any]:
    """Flatten *tokens* into the plain dict form stored on disk.

    The keys produced here are exactly the ones ``_dict_to_tokens`` reads
    back, in the order access token, refresh token, expiry.
    """
    field_names = ("access_token", "refresh_token", "expires_at")
    return {name: getattr(tokens, name) for name in field_names}
|
|
|
|
|
|
def _dict_to_tokens(data: dict[str, Any] | None) -> ProviderTokens | None:
    """Rebuild a ``ProviderTokens`` from its stored dict form.

    Returns ``None`` instead of raising when *data* is not a dict, lacks one
    of ``access_token`` / ``refresh_token`` / ``expires_at``, or the
    ``ProviderTokens`` constructor rejects the values with ``TypeError``.
    """
    if isinstance(data, dict):
        try:
            access = data["access_token"]
            refresh = data["refresh_token"]
            expiry = data["expires_at"]
            return ProviderTokens(
                access_token=access,
                refresh_token=refresh,
                expires_at=expiry,
            )
        except (KeyError, TypeError):
            pass
    return None
|
|
|
|
|
|
def _load_raw() -> dict[str, Any] | None:
|
|
if not TOKENS_FILE.exists():
|
|
return None
|
|
try:
|
|
with open(TOKENS_FILE) as f:
|
|
data = json.load(f)
|
|
if isinstance(data, dict):
|
|
return data
|
|
return None
|
|
except json.JSONDecodeError:
|
|
return None
|
|
|
|
|
|
def _save_raw(data: dict[str, Any]) -> None:
    """Atomically replace ``TOKENS_FILE`` with *data* serialised as JSON.

    The JSON is written to a ``tempfile.mkstemp`` file in the same directory,
    flushed and fsync'd, then moved over the target with ``os.replace``.
    A reader therefore sees either the old file or the complete new one.
    If serialisation or writing fails, the temporary file is removed,
    the existing ``TOKENS_FILE`` is left untouched, and the exception
    propagates.
    """
    directory = TOKENS_FILE.parent
    directory.mkdir(parents=True, exist_ok=True)
    fd, tmp_path = tempfile.mkstemp(
        prefix=f"{TOKENS_FILE.name}.",
        suffix=".tmp",
        dir=str(directory),
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, TOKENS_FILE)
    finally:
        # After a successful os.replace the temp name is gone; it only still
        # exists if something above failed. Unlinking directly avoids the
        # exists()-then-unlink check-then-act race.
        try:
            os.unlink(tmp_path)
        except FileNotFoundError:
            pass

    # os.replace only rewrites the directory entry. On POSIX the directory
    # itself must be fsync'd for the rename to survive a crash. This is
    # best-effort: platforms without O_DIRECTORY (e.g. Windows), and
    # filesystems that refuse to fsync a directory, are skipped.
    if hasattr(os, "O_DIRECTORY"):
        try:
            dir_fd = os.open(directory, os.O_RDONLY | os.O_DIRECTORY)
        except OSError:
            return
        try:
            os.fsync(dir_fd)
        except OSError:
            pass
        finally:
            os.close(dir_fd)
|
|
|
|
|
|
def _normalize_state(data: dict[str, Any] | None) -> dict[str, Any]:
|
|
if not data:
|
|
return {"active": None, "next_account": None}
|
|
|
|
if "active" in data or "next_account" in data:
|
|
return {
|
|
"active": data.get("active"),
|
|
"next_account": data.get("next_account"),
|
|
}
|
|
|
|
return {"active": data, "next_account": None}
|
|
|
|
|
|
def load_state() -> tuple[ProviderTokens | None, ProviderTokens | None]:
    """Return the persisted ``(active, next_account)`` token pair.

    Either element is ``None`` when its slot is empty or cannot be parsed
    into ``ProviderTokens``.
    """
    state = _normalize_state(_load_raw())
    return (
        _dict_to_tokens(state.get("active")),
        _dict_to_tokens(state.get("next_account")),
    )
|
|
|
|
|
|
def save_state(active: ProviderTokens | None, next_account: ProviderTokens | None) -> None:
    """Persist both token slots, replacing whatever is currently stored.

    A falsy argument (normally ``None``) is written as ``None`` for its slot.
    """
    payload: dict[str, Any] = {}
    for slot, tokens in (("active", active), ("next_account", next_account)):
        payload[slot] = _tokens_to_dict(tokens) if tokens else None
    _save_raw(payload)
|
|
|
|
|
|
def load_tokens() -> ProviderTokens | None:
    """Return the active account's tokens, or ``None`` if none are stored."""
    return load_state()[0]
|
|
|
|
|
|
def load_next_tokens() -> ProviderTokens | None:
    """Return the queued next account's tokens, or ``None`` if none are stored."""
    return load_state()[1]
|
|
|
|
|
|
def save_tokens(tokens: ProviderTokens):
    """Store *tokens* as the active account, preserving the queued next account."""
    queued = load_state()[1]
    save_state(tokens, queued)
|
|
|
|
|
|
def promote_next_tokens() -> bool:
    """Make the queued next account active and empty the next-account slot.

    Returns:
        ``True`` if a next account was stored and has been promoted;
        ``False`` otherwise, in which case nothing is written.
    """
    queued = load_state()[1]
    has_queued = bool(queued)
    if has_queued:
        save_state(queued, None)
    return has_queued
|
|
|
|
|
|
def clear_next_tokens():
    """Drop the queued next account, leaving the active account unchanged."""
    save_state(load_state()[0], None)
|
|
|
|
|
|
async def refresh_tokens(refresh_token: str) -> ProviderTokens | None:
    """Exchange *refresh_token* for a fresh token set.

    Posts a form-encoded OAuth ``refresh_token`` grant, carrying
    ``CLIENT_ID``, to ``TOKEN_URL`` with a 15-second total timeout.

    Args:
        refresh_token: The refresh token currently on record.

    Returns:
        New ``ProviderTokens`` whose ``expires_at`` is ``time.time()`` plus
        the response's ``expires_in``. If the response carries no new
        refresh token, the one passed in is kept. Returns ``None`` when the
        request fails (non-2xx status, network error, timeout) or the body is
        not a usable token object. Failures are logged as warnings rather
        than raised.
    """
    form = {
        "grant_type": "refresh_token",
        "refresh_token": refresh_token,
        "client_id": CLIENT_ID,
    }
    timeout = aiohttp.ClientTimeout(total=15)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(TOKEN_URL, data=form) as resp:
                if not resp.ok:
                    text = await resp.text()
                    logger.warning("Token refresh failed: %s %s", resp.status, text)
                    return None
                json_resp = await resp.json()
    except (aiohttp.ClientError, asyncio.TimeoutError, TimeoutError) as e:
        # Before Python 3.11, asyncio.TimeoutError (what aiohttp raises when
        # the total timeout expires) is distinct from the builtin
        # TimeoutError, so both are listed to log timeouts as request errors.
        logger.warning("Token refresh request error: %s", e)
        return None
    except Exception as e:
        logger.warning("Token refresh unexpected error: %s", e)
        return None

    if not isinstance(json_resp, dict):
        logger.warning(
            "Token refresh response parse error: expected a JSON object, got %s",
            type(json_resp).__name__,
        )
        return None

    try:
        expires_in = int(json_resp["expires_in"])
        # RFC 6749 section 6: the server MAY issue a new refresh token. When
        # it does not, the client keeps using the one it just presented.
        new_refresh_token = json_resp.get("refresh_token") or refresh_token
        return ProviderTokens(
            access_token=json_resp["access_token"],
            refresh_token=new_refresh_token,
            expires_at=time.time() + expires_in,
        )
    except (KeyError, TypeError, ValueError) as e:
        logger.warning("Token refresh response parse error: %s", e)
        return None
|