refactor: add staged next-account rotation and clarify ChatGPT config
This commit is contained in:
parent
ccd4d82194
commit
d6396e4050
5 changed files with 253 additions and 100 deletions
|
|
@ -1,19 +1,30 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from typing import Callable
|
||||
import os
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
|
||||
from playwright.async_api import BrowserContext
|
||||
|
||||
from providers.base import Provider, ProviderTokens
|
||||
from email_providers import BaseProvider
|
||||
from email_providers import TempMailOrgProvider
|
||||
from .tokens import load_tokens, save_tokens, refresh_tokens
|
||||
from providers.base import Provider, ProviderTokens
|
||||
from .tokens import (
|
||||
clear_next_tokens,
|
||||
load_next_tokens,
|
||||
load_state,
|
||||
load_tokens,
|
||||
promote_next_tokens,
|
||||
refresh_tokens,
|
||||
save_next_tokens,
|
||||
save_tokens,
|
||||
)
|
||||
from .usage import get_usage_data
|
||||
from .registration import register_chatgpt_account
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
MAX_REGISTRATION_ATTEMPTS = 4
|
||||
CHATGPT_REGISTRATION_MAX_ATTEMPTS = 4
|
||||
CHATGPT_SWITCH_THRESHOLD = int(os.environ.get("CHATGPT_SWITCH_THRESHOLD", "95"))
|
||||
|
||||
|
||||
class ChatGPTProvider(Provider):
|
||||
|
|
@ -27,11 +38,11 @@ class ChatGPTProvider(Provider):
|
|||
self._token_write_lock = asyncio.Lock()
|
||||
|
||||
async def _register_with_retries(self) -> bool:
|
||||
for attempt in range(1, MAX_REGISTRATION_ATTEMPTS + 1):
|
||||
for attempt in range(1, CHATGPT_REGISTRATION_MAX_ATTEMPTS + 1):
|
||||
logger.info(
|
||||
"Registration attempt %s/%s",
|
||||
attempt,
|
||||
MAX_REGISTRATION_ATTEMPTS,
|
||||
CHATGPT_REGISTRATION_MAX_ATTEMPTS,
|
||||
)
|
||||
success = await self.register_new_account()
|
||||
if success:
|
||||
|
|
@ -39,16 +50,75 @@ class ChatGPTProvider(Provider):
|
|||
logger.warning("Registration attempt %s failed", attempt)
|
||||
return False
|
||||
|
||||
async def _create_next_account_under_lock(self) -> bool:
|
||||
active_before, next_before = load_state()
|
||||
if next_before:
|
||||
return True
|
||||
|
||||
logger.info("Creating next account")
|
||||
success = await self._register_with_retries()
|
||||
if not success:
|
||||
return False
|
||||
|
||||
generated_active = load_tokens()
|
||||
if not generated_active:
|
||||
return False
|
||||
|
||||
# Registration writes new tokens as active; restore old active and keep
|
||||
# generated account as next.
|
||||
if active_before:
|
||||
save_tokens(active_before)
|
||||
else:
|
||||
clear_next_tokens()
|
||||
save_next_tokens(generated_active)
|
||||
logger.info("Next account is ready")
|
||||
return True
|
||||
|
||||
async def force_recreate_token(self) -> str | None:
|
||||
async with self._token_write_lock:
|
||||
success = await self._register_with_retries()
|
||||
if not success:
|
||||
return None
|
||||
clear_next_tokens()
|
||||
tokens = load_tokens()
|
||||
if not tokens:
|
||||
return None
|
||||
return tokens.access_token
|
||||
|
||||
async def ensure_next_account(self) -> bool:
|
||||
next_tokens = load_next_tokens()
|
||||
if next_tokens and not next_tokens.is_expired:
|
||||
return True
|
||||
|
||||
async with self._token_write_lock:
|
||||
next_tokens = load_next_tokens()
|
||||
if next_tokens and not next_tokens.is_expired:
|
||||
return True
|
||||
return await self._create_next_account_under_lock()
|
||||
|
||||
async def maybe_switch_active_account(self, usage_percent: int) -> bool:
|
||||
if usage_percent < CHATGPT_SWITCH_THRESHOLD:
|
||||
return False
|
||||
|
||||
async with self._token_write_lock:
|
||||
next_tokens = load_next_tokens()
|
||||
if not next_tokens or next_tokens.is_expired:
|
||||
logger.info(
|
||||
"Active usage >= %s%% and next account missing",
|
||||
CHATGPT_SWITCH_THRESHOLD,
|
||||
)
|
||||
created = await self._create_next_account_under_lock()
|
||||
if not created:
|
||||
return False
|
||||
|
||||
switched = promote_next_tokens()
|
||||
if switched:
|
||||
logger.info(
|
||||
"Switched active account (usage >= %s%%)",
|
||||
CHATGPT_SWITCH_THRESHOLD,
|
||||
)
|
||||
return switched
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "chatgpt"
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
import json
|
||||
import time
|
||||
import os
|
||||
import aiohttp
|
||||
from pathlib import Path
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from providers.base import ProviderTokens
|
||||
|
||||
|
|
@ -16,33 +18,108 @@ CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
|||
TOKEN_URL = "https://auth.openai.com/oauth/token"
|
||||
|
||||
|
||||
def load_tokens() -> ProviderTokens | None:
|
||||
if not TOKENS_FILE.exists():
|
||||
def _tokens_to_dict(tokens: ProviderTokens) -> dict[str, Any]:
|
||||
return {
|
||||
"access_token": tokens.access_token,
|
||||
"refresh_token": tokens.refresh_token,
|
||||
"expires_at": tokens.expires_at,
|
||||
}
|
||||
|
||||
|
||||
def _dict_to_tokens(data: dict[str, Any] | None) -> ProviderTokens | None:
|
||||
if not isinstance(data, dict):
|
||||
return None
|
||||
try:
|
||||
with open(TOKENS_FILE) as f:
|
||||
data = json.load(f)
|
||||
return ProviderTokens(
|
||||
access_token=data["access_token"],
|
||||
refresh_token=data["refresh_token"],
|
||||
expires_at=data["expires_at"],
|
||||
)
|
||||
except json.JSONDecodeError, KeyError:
|
||||
except KeyError, TypeError:
|
||||
return None
|
||||
|
||||
|
||||
def save_tokens(tokens: ProviderTokens):
|
||||
def _load_raw() -> dict[str, Any] | None:
|
||||
if not TOKENS_FILE.exists():
|
||||
return None
|
||||
try:
|
||||
with open(TOKENS_FILE) as f:
|
||||
data = json.load(f)
|
||||
if isinstance(data, dict):
|
||||
return data
|
||||
return None
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
|
||||
def _save_raw(data: dict[str, Any]) -> None:
|
||||
TOKENS_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(TOKENS_FILE, "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"access_token": tokens.access_token,
|
||||
"refresh_token": tokens.refresh_token,
|
||||
"expires_at": tokens.expires_at,
|
||||
},
|
||||
f,
|
||||
indent=2,
|
||||
)
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
|
||||
def _normalize_state(data: dict[str, Any] | None) -> dict[str, Any]:
|
||||
if not data:
|
||||
return {"active": None, "next_account": None}
|
||||
|
||||
if "active" in data or "next_account" in data:
|
||||
return {
|
||||
"active": data.get("active"),
|
||||
"next_account": data.get("next_account"),
|
||||
}
|
||||
|
||||
# Backward compatibility with old flat schema
|
||||
return {"active": data, "next_account": None}
|
||||
|
||||
|
||||
def load_state() -> tuple[ProviderTokens | None, ProviderTokens | None]:
|
||||
normalized = _normalize_state(_load_raw())
|
||||
active = _dict_to_tokens(normalized.get("active"))
|
||||
next_account = _dict_to_tokens(normalized.get("next_account"))
|
||||
return active, next_account
|
||||
|
||||
|
||||
def save_state(
|
||||
active: ProviderTokens | None, next_account: ProviderTokens | None
|
||||
) -> None:
|
||||
payload = {
|
||||
"active": _tokens_to_dict(active) if active else None,
|
||||
"next_account": _tokens_to_dict(next_account) if next_account else None,
|
||||
}
|
||||
_save_raw(payload)
|
||||
|
||||
|
||||
def load_tokens() -> ProviderTokens | None:
|
||||
active, _ = load_state()
|
||||
return active
|
||||
|
||||
|
||||
def load_next_tokens() -> ProviderTokens | None:
|
||||
_, next_account = load_state()
|
||||
return next_account
|
||||
|
||||
|
||||
def save_tokens(tokens: ProviderTokens):
|
||||
_, next_account = load_state()
|
||||
save_state(tokens, next_account)
|
||||
|
||||
|
||||
def save_next_tokens(tokens: ProviderTokens):
|
||||
active, _ = load_state()
|
||||
save_state(active, tokens)
|
||||
|
||||
|
||||
def promote_next_tokens() -> bool:
|
||||
active, next_account = load_state()
|
||||
if not next_account:
|
||||
return False
|
||||
save_state(next_account, None)
|
||||
return True
|
||||
|
||||
|
||||
def clear_next_tokens():
|
||||
active, _ = load_state()
|
||||
save_state(active, None)
|
||||
|
||||
|
||||
async def refresh_tokens(refresh_token: str) -> ProviderTokens | None:
|
||||
|
|
@ -55,7 +132,7 @@ async def refresh_tokens(refresh_token: str) -> ProviderTokens | None:
|
|||
async with session.post(TOKEN_URL, data=data) as resp:
|
||||
if not resp.ok:
|
||||
text = await resp.text()
|
||||
print(f"Token refresh failed: {resp.status} {text}")
|
||||
logger.warning("Token refresh failed: %s %s", resp.status, text)
|
||||
return None
|
||||
json_resp = await resp.json()
|
||||
expires_in = json_resp["expires_in"]
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from aiohttp import web
|
|||
from providers.chatgpt import ChatGPTProvider
|
||||
|
||||
PORT = int(os.environ.get("PORT", "8080"))
|
||||
USAGE_REFRESH_THRESHOLD = int(os.environ.get("USAGE_REFRESH_THRESHOLD", "85"))
|
||||
CHATGPT_PREPARE_THRESHOLD = int(os.environ.get("CHATGPT_PREPARE_THRESHOLD", "85"))
|
||||
LIMIT_EXHAUSTED_PERCENT = 100
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
|
@ -37,7 +37,7 @@ def build_limit(usage_percent: int) -> dict[str, int | bool]:
|
|||
"used_percent": usage_percent,
|
||||
"remaining_percent": remaining,
|
||||
"exhausted": usage_percent >= LIMIT_EXHAUSTED_PERCENT,
|
||||
"needs_refresh": usage_percent >= USAGE_REFRESH_THRESHOLD,
|
||||
"needs_refresh": usage_percent >= CHATGPT_PREPARE_THRESHOLD,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -59,6 +59,9 @@ async def ensure_provider_token_ready(provider_name: str):
|
|||
logger.error("[%s] Could not prepare token at startup", provider_name)
|
||||
return
|
||||
|
||||
if isinstance(provider, ChatGPTProvider):
|
||||
await provider.ensure_next_account()
|
||||
|
||||
usage_info = await provider.get_usage_info(token)
|
||||
if "error" not in usage_info:
|
||||
logger.info("[%s] Startup token is ready", provider_name)
|
||||
|
|
@ -157,6 +160,27 @@ async def token_handler(request: web.Request) -> web.Response:
|
|||
usage_percent = usage_info.get("used_percent", 0)
|
||||
remaining_percent = usage_info.get("remaining_percent", max(0, 100 - usage_percent))
|
||||
|
||||
if isinstance(provider, ChatGPTProvider):
|
||||
switched = await provider.maybe_switch_active_account(usage_percent)
|
||||
if switched:
|
||||
token = await provider.get_token()
|
||||
if not token:
|
||||
return web.json_response(
|
||||
{"error": "Failed to get active token after account switch"},
|
||||
status=503,
|
||||
)
|
||||
usage_info = await provider.get_usage_info(token)
|
||||
if "error" in usage_info:
|
||||
return web.json_response(
|
||||
{"error": usage_info["error"]},
|
||||
status=503,
|
||||
)
|
||||
usage_percent = usage_info.get("used_percent", 0)
|
||||
remaining_percent = usage_info.get(
|
||||
"remaining_percent", max(0, 100 - usage_percent)
|
||||
)
|
||||
logger.info("[%s] Active account switched before response", provider_name)
|
||||
|
||||
logger.info(
|
||||
"[%s] token issued, used=%s%% remaining=%s%%",
|
||||
provider_name,
|
||||
|
|
@ -184,11 +208,14 @@ async def token_handler(request: web.Request) -> web.Response:
|
|||
)
|
||||
|
||||
# Trigger background refresh if needed
|
||||
if usage_percent >= USAGE_REFRESH_THRESHOLD:
|
||||
trigger_background_refresh(
|
||||
provider_name,
|
||||
f"usage {usage_percent}% >= threshold {USAGE_REFRESH_THRESHOLD}%",
|
||||
)
|
||||
if usage_percent >= CHATGPT_PREPARE_THRESHOLD:
|
||||
if isinstance(provider, ChatGPTProvider):
|
||||
await provider.ensure_next_account()
|
||||
else:
|
||||
trigger_background_refresh(
|
||||
provider_name,
|
||||
f"usage {usage_percent}% >= threshold {CHATGPT_PREPARE_THRESHOLD}%",
|
||||
)
|
||||
|
||||
return web.json_response(
|
||||
{
|
||||
|
|
@ -214,7 +241,10 @@ def create_app() -> web.Application:
|
|||
|
||||
if __name__ == "__main__":
|
||||
logger.info("Starting token service on port %s", PORT)
|
||||
logger.info("Usage refresh threshold: %s%%", USAGE_REFRESH_THRESHOLD)
|
||||
logger.info(
|
||||
"ChatGPT prepare-next threshold: %s%%",
|
||||
CHATGPT_PREPARE_THRESHOLD,
|
||||
)
|
||||
logger.info("Available providers: %s", ", ".join(PROVIDERS.keys()))
|
||||
app = create_app()
|
||||
web.run_app(app, host="0.0.0.0", port=PORT)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue