refactor: harden ChatGPT token lifecycle with startup recovery, single-writer locking, and faster auth flow
This commit is contained in:
parent
71d1050adb
commit
533e382e0e
9 changed files with 313 additions and 178 deletions
|
|
@ -1,3 +1,4 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from typing import Callable
|
||||
from typing import Any
|
||||
|
|
@ -7,11 +8,12 @@ from playwright.async_api import BrowserContext
|
|||
from providers.base import Provider, ProviderTokens
|
||||
from email_providers import BaseProvider
|
||||
from email_providers import TempMailOrgProvider
|
||||
from .tokens import load_tokens, save_tokens, get_valid_tokens
|
||||
from .usage import get_usage_percent
|
||||
from .tokens import load_tokens, save_tokens, refresh_tokens
|
||||
from .usage import get_usage_data
|
||||
from .registration import register_chatgpt_account
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
MAX_REGISTRATION_ATTEMPTS = 4
|
||||
|
||||
|
||||
class ChatGPTProvider(Provider):
|
||||
|
|
@ -22,23 +24,63 @@ class ChatGPTProvider(Provider):
|
|||
email_provider_factory: Callable[[BrowserContext], BaseProvider] | None = None,
|
||||
):
|
||||
self.email_provider_factory = email_provider_factory or TempMailOrgProvider
|
||||
self._token_write_lock = asyncio.Lock()
|
||||
|
||||
async def _register_with_retries(self) -> bool:
|
||||
for attempt in range(1, MAX_REGISTRATION_ATTEMPTS + 1):
|
||||
logger.info(
|
||||
"Registration attempt %s/%s",
|
||||
attempt,
|
||||
MAX_REGISTRATION_ATTEMPTS,
|
||||
)
|
||||
success = await self.register_new_account()
|
||||
if success:
|
||||
return True
|
||||
logger.warning("Registration attempt %s failed", attempt)
|
||||
return False
|
||||
|
||||
async def force_recreate_token(self) -> str | None:
|
||||
async with self._token_write_lock:
|
||||
success = await self._register_with_retries()
|
||||
if not success:
|
||||
return None
|
||||
tokens = load_tokens()
|
||||
if not tokens:
|
||||
return None
|
||||
return tokens.access_token
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "chatgpt"
|
||||
|
||||
async def get_token(self) -> str | None:
|
||||
"""Get valid access token, refreshing if needed"""
|
||||
tokens = await get_valid_tokens()
|
||||
if not tokens:
|
||||
logger.info("No valid tokens, registering new account")
|
||||
success = await self.register_new_account()
|
||||
"""Get valid access token with single-writer refresh/register path."""
|
||||
tokens = load_tokens()
|
||||
if tokens and not tokens.is_expired:
|
||||
return tokens.access_token
|
||||
|
||||
async with self._token_write_lock:
|
||||
tokens = load_tokens()
|
||||
if tokens and not tokens.is_expired:
|
||||
return tokens.access_token
|
||||
|
||||
if tokens and tokens.refresh_token:
|
||||
logger.info("Token expired, refreshing under lock")
|
||||
refreshed = await refresh_tokens(tokens.refresh_token)
|
||||
if refreshed:
|
||||
save_tokens(refreshed)
|
||||
return refreshed.access_token
|
||||
logger.warning("Token refresh failed, falling back to registration")
|
||||
|
||||
logger.info("No valid tokens, registering new account under lock")
|
||||
success = await self._register_with_retries()
|
||||
if not success:
|
||||
return None
|
||||
tokens = await get_valid_tokens()
|
||||
|
||||
tokens = load_tokens()
|
||||
if not tokens:
|
||||
return None
|
||||
return tokens.access_token
|
||||
return tokens.access_token
|
||||
|
||||
async def register_new_account(self) -> bool:
|
||||
"""Register a new ChatGPT account"""
|
||||
|
|
@ -48,15 +90,18 @@ class ChatGPTProvider(Provider):
|
|||
|
||||
async def get_usage_info(self, access_token: str) -> dict[str, Any]:
|
||||
"""Get usage information for the current token"""
|
||||
usage_percent = get_usage_percent(access_token)
|
||||
if usage_percent < 0:
|
||||
usage_data = get_usage_data(access_token)
|
||||
if not usage_data:
|
||||
return {"error": "Failed to get usage"}
|
||||
|
||||
remaining = max(0, 100 - usage_percent)
|
||||
return {
|
||||
"used_percent": usage_percent,
|
||||
"remaining_percent": remaining,
|
||||
"exhausted": usage_percent >= 100,
|
||||
"used_percent": int(usage_data["used_percent"]),
|
||||
"remaining_percent": int(usage_data["remaining_percent"]),
|
||||
"exhausted": int(usage_data["used_percent"]) >= 100,
|
||||
"primary_window": usage_data.get("primary_window"),
|
||||
"secondary_window": usage_data.get("secondary_window"),
|
||||
"limit_reached": bool(usage_data.get("limit_reached")),
|
||||
"allowed": bool(usage_data.get("allowed", True)),
|
||||
}
|
||||
|
||||
def load_tokens(self) -> ProviderTokens | None:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue