refactor: add staged next-account rotation and clarify ChatGPT config
This commit is contained in:
parent
ccd4d82194
commit
d6396e4050
5 changed files with 253 additions and 100 deletions
|
|
@ -1,9 +1,11 @@
|
|||
import json
|
||||
import time
|
||||
import os
|
||||
import aiohttp
|
||||
from pathlib import Path
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from providers.base import ProviderTokens
|
||||
|
||||
|
|
@ -16,33 +18,108 @@ CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
|||
TOKEN_URL = "https://auth.openai.com/oauth/token"
|
||||
|
||||
|
||||
def load_tokens() -> ProviderTokens | None:
|
||||
if not TOKENS_FILE.exists():
|
||||
def _tokens_to_dict(tokens: ProviderTokens) -> dict[str, Any]:
|
||||
return {
|
||||
"access_token": tokens.access_token,
|
||||
"refresh_token": tokens.refresh_token,
|
||||
"expires_at": tokens.expires_at,
|
||||
}
|
||||
|
||||
|
||||
def _dict_to_tokens(data: dict[str, Any] | None) -> ProviderTokens | None:
|
||||
if not isinstance(data, dict):
|
||||
return None
|
||||
try:
|
||||
with open(TOKENS_FILE) as f:
|
||||
data = json.load(f)
|
||||
return ProviderTokens(
|
||||
access_token=data["access_token"],
|
||||
refresh_token=data["refresh_token"],
|
||||
expires_at=data["expires_at"],
|
||||
)
|
||||
except json.JSONDecodeError, KeyError:
|
||||
except KeyError, TypeError:
|
||||
return None
|
||||
|
||||
|
||||
def save_tokens(tokens: ProviderTokens):
|
||||
def _load_raw() -> dict[str, Any] | None:
|
||||
if not TOKENS_FILE.exists():
|
||||
return None
|
||||
try:
|
||||
with open(TOKENS_FILE) as f:
|
||||
data = json.load(f)
|
||||
if isinstance(data, dict):
|
||||
return data
|
||||
return None
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
|
||||
def _save_raw(data: dict[str, Any]) -> None:
|
||||
TOKENS_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(TOKENS_FILE, "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"access_token": tokens.access_token,
|
||||
"refresh_token": tokens.refresh_token,
|
||||
"expires_at": tokens.expires_at,
|
||||
},
|
||||
f,
|
||||
indent=2,
|
||||
)
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
|
||||
def _normalize_state(data: dict[str, Any] | None) -> dict[str, Any]:
|
||||
if not data:
|
||||
return {"active": None, "next_account": None}
|
||||
|
||||
if "active" in data or "next_account" in data:
|
||||
return {
|
||||
"active": data.get("active"),
|
||||
"next_account": data.get("next_account"),
|
||||
}
|
||||
|
||||
# Backward compatibility with old flat schema
|
||||
return {"active": data, "next_account": None}
|
||||
|
||||
|
||||
def load_state() -> tuple[ProviderTokens | None, ProviderTokens | None]:
|
||||
normalized = _normalize_state(_load_raw())
|
||||
active = _dict_to_tokens(normalized.get("active"))
|
||||
next_account = _dict_to_tokens(normalized.get("next_account"))
|
||||
return active, next_account
|
||||
|
||||
|
||||
def save_state(
|
||||
active: ProviderTokens | None, next_account: ProviderTokens | None
|
||||
) -> None:
|
||||
payload = {
|
||||
"active": _tokens_to_dict(active) if active else None,
|
||||
"next_account": _tokens_to_dict(next_account) if next_account else None,
|
||||
}
|
||||
_save_raw(payload)
|
||||
|
||||
|
||||
def load_tokens() -> ProviderTokens | None:
|
||||
active, _ = load_state()
|
||||
return active
|
||||
|
||||
|
||||
def load_next_tokens() -> ProviderTokens | None:
|
||||
_, next_account = load_state()
|
||||
return next_account
|
||||
|
||||
|
||||
def save_tokens(tokens: ProviderTokens):
|
||||
_, next_account = load_state()
|
||||
save_state(tokens, next_account)
|
||||
|
||||
|
||||
def save_next_tokens(tokens: ProviderTokens):
|
||||
active, _ = load_state()
|
||||
save_state(active, tokens)
|
||||
|
||||
|
||||
def promote_next_tokens() -> bool:
|
||||
active, next_account = load_state()
|
||||
if not next_account:
|
||||
return False
|
||||
save_state(next_account, None)
|
||||
return True
|
||||
|
||||
|
||||
def clear_next_tokens():
|
||||
active, _ = load_state()
|
||||
save_state(active, None)
|
||||
|
||||
|
||||
async def refresh_tokens(refresh_token: str) -> ProviderTokens | None:
|
||||
|
|
@ -55,7 +132,7 @@ async def refresh_tokens(refresh_token: str) -> ProviderTokens | None:
|
|||
async with session.post(TOKEN_URL, data=data) as resp:
|
||||
if not resp.ok:
|
||||
text = await resp.text()
|
||||
print(f"Token refresh failed: {resp.status} {text}")
|
||||
logger.warning("Token refresh failed: %s %s", resp.status, text)
|
||||
return None
|
||||
json_resp = await resp.json()
|
||||
expires_in = json_resp["expires_in"]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue