226 lines
7.6 KiB
Python
226 lines
7.6 KiB
Python
import asyncio
|
|
import logging
|
|
from typing import Any
|
|
from typing import Callable
|
|
|
|
from playwright.async_api import BrowserContext
|
|
|
|
from email_providers import BaseProvider
|
|
from email_providers import MailTmProvider
|
|
from providers.base import Provider, ProviderTokens
|
|
from utils.env import parse_int_env
|
|
from .tokens import (
|
|
clear_next_tokens,
|
|
load_next_tokens,
|
|
load_state,
|
|
load_tokens,
|
|
promote_next_tokens,
|
|
refresh_tokens,
|
|
save_state,
|
|
save_tokens,
|
|
)
|
|
from .usage import get_usage_data
|
|
from .registration import register_chatgpt_account
|
|
|
|
logger = logging.getLogger(__name__)

# Upper bound on registration attempts made by a single retry loop.
CHATGPT_REGISTRATION_MAX_ATTEMPTS = 4

# Usage thresholds in percent, read from the environment.
# - PREPARE: at or above this usage a standby ("next") account is prepared.
# - SWITCH: at or above this usage the active account is rotated.
# NOTE(review): 85 / 95 look like the fallback defaults and 0, 100 like the
# accepted bounds of parse_int_env — confirm against utils.env.
CHATGPT_PREPARE_THRESHOLD = parse_int_env("CHATGPT_PREPARE_THRESHOLD", 85, 0, 100)
CHATGPT_SWITCH_THRESHOLD = parse_int_env("CHATGPT_SWITCH_THRESHOLD", 95, 0, 100)


class ChatGPTProvider(Provider):
    """ChatGPT account provider.

    Manages the stored ChatGPT tokens plus an optional standby ("next")
    account. Once usage of the current account reaches ``switch_threshold``,
    the next account is promoted via ``promote_next_tokens``.

    The async writers in this class (refresh, registration, standby creation,
    promotion) run under ``_token_write_lock``. The synchronous
    ``save_tokens`` pass-through does not take the lock.
    """

    def __init__(
        self,
        email_provider_factory: Callable[[BrowserContext], BaseProvider] | None = None,
    ):
        # Factory handed to the registration flow; defaults to MailTmProvider.
        self.email_provider_factory = email_provider_factory or MailTmProvider
        # Serializes token-state writes made by this instance.
        # asyncio.Lock is NOT reentrant: code holding it must only call the
        # private helpers below, never a public method that acquires it again.
        self._token_write_lock = asyncio.Lock()

    @property
    def prepare_threshold(self) -> int:
        """Usage percent at or above which a standby account is prepared."""
        return CHATGPT_PREPARE_THRESHOLD

    @property
    def switch_threshold(self) -> int | None:
        """Usage percent at or above which the active account is rotated."""
        return CHATGPT_SWITCH_THRESHOLD

    async def _register_tokens_with_retries(self, label: str) -> ProviderTokens | None:
        """Run the registration flow up to CHATGPT_REGISTRATION_MAX_ATTEMPTS times.

        Args:
            label: Prefix for the per-attempt log lines, e.g. "Registration".

        Returns:
            Tokens from the first successful registration, or ``None`` if every
            attempt failed. Nothing is persisted here; callers decide where the
            new account is stored.
        """
        for attempt in range(1, CHATGPT_REGISTRATION_MAX_ATTEMPTS + 1):
            logger.info(
                "%s attempt %s/%s",
                label,
                attempt,
                CHATGPT_REGISTRATION_MAX_ATTEMPTS,
            )
            generated_tokens = await register_chatgpt_account(
                email_provider_factory=self.email_provider_factory,
            )
            if generated_tokens:
                return generated_tokens
            logger.warning("%s attempt %s failed", label, attempt)
            # Linear back-off between attempts. Skipped after the last attempt,
            # where it would only delay the caller's failure handling.
            if attempt < CHATGPT_REGISTRATION_MAX_ATTEMPTS:
                await asyncio.sleep(1.5 * attempt)
        return None

    async def _register_with_retries(self) -> bool:
        """Register a new account and save it with ``save_tokens``.

        Callers in this class hold ``_token_write_lock`` while calling it.

        Returns:
            ``True`` if an account was registered and saved, else ``False``.
        """
        generated_tokens = await self._register_tokens_with_retries("Registration")
        if not generated_tokens:
            return False
        save_tokens(generated_tokens)
        return True

    async def _create_next_account_under_lock(
        self,
        *,
        replace_expired: bool = False,
    ) -> bool:
        """Register a standby ("next") account unless one already exists.

        Must be called with ``_token_write_lock`` held.

        Args:
            replace_expired: When true, an existing next account whose tokens
                report ``is_expired`` is treated as missing and replaced.

        Returns:
            ``True`` if a next account already existed (and was kept) or a new
            account was registered, ``False`` if every attempt failed.
            When there is no active account, the newly registered account is
            saved as the active one and the next slot stays empty.
        """
        active_before, next_before = load_state()
        if next_before and not (replace_expired and next_before.is_expired):
            return True

        logger.info("Creating next account")
        generated_tokens = await self._register_tokens_with_retries(
            "Next-account registration"
        )
        if generated_tokens:
            if active_before:
                save_state(active_before, generated_tokens)
            else:
                # No active account yet: the fresh account becomes active.
                save_state(generated_tokens, None)
            logger.info("Next account is ready")
            return True

        # Write back the snapshot taken before registration.
        # NOTE(review): presumably guards against the registration flow
        # altering stored state — confirm before removing.
        if active_before or next_before:
            save_state(active_before, next_before)
        return False

    async def force_recreate_token(self) -> str | None:
        """Replace the current account with a freshly registered one.

        Any prepared next account is cleared afterwards.

        Returns:
            The new access token, or ``None`` if registration failed or no
            tokens could be loaded after saving.
        """
        async with self._token_write_lock:
            if not await self._register_with_retries():
                return None
            clear_next_tokens()
            tokens = load_tokens()
            return tokens.access_token if tokens else None

    async def startup_prepare(self) -> None:
        """Prepare a next account at startup (see ``ensure_next_account``)."""
        await self.ensure_next_account()

    async def ensure_next_account(self) -> bool:
        """Ensure a next (standby) account exists, registering one if needed.

        Checks once without the lock (fast path) and again after acquiring it,
        so concurrent callers trigger at most one registration.

        Returns:
            ``True`` if a next account exists or was created, else ``False``.
        """
        if load_next_tokens():
            return True

        async with self._token_write_lock:
            if load_next_tokens():
                return True
            return await self._create_next_account_under_lock()

    def should_prepare_standby(self, usage_percent: int) -> bool:
        """Return whether usage has reached the standby-preparation threshold."""
        return usage_percent >= self.prepare_threshold

    async def ensure_standby_account(
        self,
        usage_percent: int,
    ) -> None:
        """Prepare a next account once usage reaches ``prepare_threshold``."""
        if self.should_prepare_standby(usage_percent):
            await self.ensure_next_account()

    async def maybe_switch_active_account(self, usage_percent: int) -> bool:
        """Promote the next account to active once usage reaches ``switch_threshold``.

        If the next account is missing or its tokens are expired, a new one is
        registered first.

        Returns:
            ``True`` if the active account was switched, else ``False``.
        """
        threshold = self.switch_threshold
        # switch_threshold is annotated ``int | None``; treat None as
        # "switching disabled" instead of failing on the comparison.
        if threshold is None or usage_percent < threshold:
            return False

        async with self._token_write_lock:
            next_tokens = load_next_tokens()
            if not next_tokens or next_tokens.is_expired:
                logger.info(
                    "Active usage >= %s%% and next account missing",
                    threshold,
                )
                # replace_expired=True: otherwise the helper returns early on
                # the existing (expired) next account and it gets promoted as-is.
                created = await self._create_next_account_under_lock(
                    replace_expired=True,
                )
                if not created:
                    return False

            switched = promote_next_tokens()
            if switched:
                logger.info(
                    "Switched active account (usage >= %s%%)",
                    threshold,
                )
            return switched

    async def maybe_rotate_account(self, usage_percent: int) -> bool:
        """Alias for ``maybe_switch_active_account``."""
        return await self.maybe_switch_active_account(usage_percent)

    @property
    def name(self) -> str:
        """Provider identifier."""
        return "chatgpt"

    async def get_token(self) -> str | None:
        """Get valid access token with single-writer refresh/register path.

        The fallback order is:

        1. Use the stored token if it is not expired.
        2. Refresh it with its refresh token.
        3. Register a brand-new account.

        The lock is only taken when the stored token is missing or expired.
        State is re-read after acquiring it, so a token refreshed by a
        concurrent caller is reused instead of refreshed again.

        Returns:
            An access token, or ``None`` if refresh and registration failed.
        """
        tokens = load_tokens()
        if tokens and not tokens.is_expired:
            return tokens.access_token

        async with self._token_write_lock:
            tokens = load_tokens()
            if tokens and not tokens.is_expired:
                return tokens.access_token

            if tokens and tokens.refresh_token:
                logger.info("Token expired, refreshing under lock")
                refreshed = await refresh_tokens(tokens.refresh_token)
                if refreshed:
                    save_tokens(refreshed)
                    return refreshed.access_token
                logger.warning("Token refresh failed, falling back to registration")

            logger.info("No valid tokens, registering new account under lock")
            if not await self._register_with_retries():
                return None

            tokens = load_tokens()
            return tokens.access_token if tokens else None

    async def register_new_account(self) -> bool:
        """Register a new ChatGPT account and save it with ``save_tokens``.

        Makes a single attempt (no retries). Holds ``_token_write_lock`` like
        the other writers so it cannot interleave with a refresh, registration
        or rotation running on another task.

        Returns:
            ``True`` if an account was registered and saved, else ``False``.
        """
        async with self._token_write_lock:
            generated_tokens = await register_chatgpt_account(
                email_provider_factory=self.email_provider_factory,
            )
            if not generated_tokens:
                return False
            save_tokens(generated_tokens)
            return True

    async def get_usage_info(self, access_token: str) -> dict[str, Any]:
        """Get usage information for the given access token.

        Returns:
            ``{"error": "Failed to get usage"}`` when no usage data came back.
            Otherwise a dict with these keys:

            - ``used_percent`` and ``remaining_percent``, as ints.
            - ``exhausted``: whether used is at least 100.
            - ``primary_window`` and ``secondary_window``: raw values.
            - ``limit_reached``, as a bool.
            - ``allowed``: a bool, True when the field is absent.

        Raises:
            KeyError: If ``used_percent`` or ``remaining_percent`` is missing
                from the usage data.
        """
        usage_data = await get_usage_data(access_token)
        if not usage_data:
            return {"error": "Failed to get usage"}

        used_percent = int(usage_data["used_percent"])
        return {
            "used_percent": used_percent,
            "remaining_percent": int(usage_data["remaining_percent"]),
            "exhausted": used_percent >= 100,
            "primary_window": usage_data.get("primary_window"),
            "secondary_window": usage_data.get("secondary_window"),
            "limit_reached": bool(usage_data.get("limit_reached")),
            "allowed": bool(usage_data.get("allowed", True)),
        }

    def load_tokens(self) -> ProviderTokens | None:
        """Load tokens from storage (delegates to the module-level load_tokens)."""
        return load_tokens()

    def save_tokens(self, tokens: ProviderTokens) -> None:
        """Save tokens to storage (delegates to the module-level save_tokens).

        Does not take ``_token_write_lock``.
        """
        save_tokens(tokens)