1
0
Fork 0

refactor: harden ChatGPT token lifecycle with startup recovery, single-writer locking, and faster auth flow

This commit is contained in:
Arthur K. 2026-03-01 20:58:24 +03:00
parent 71d1050adb
commit 533e382e0e
Signed by: wzray
GPG key ID: B97F30FDC4636357
9 changed files with 313 additions and 178 deletions

View file

@ -15,19 +15,19 @@ RUN pip install --no-cache-dir uv
RUN uv sync --frozen --no-dev RUN uv sync --frozen --no-dev
RUN /app/.venv/bin/python -m playwright install --with-deps chromium RUN /app/.venv/bin/python -m playwright install --with-deps chromium
COPY src/*.py /app/ COPY entrypoint.sh /entrypoint.sh
COPY src .
ENV PYTHONUNBUFFERED=1 ENV PYTHONUNBUFFERED=1
ENV PORT=8000 ENV PORT=80
ENV DATA_DIR=/data ENV DATA_DIR=/data
ENV PLAYWRIGHT_SKIP_BROWSER_DOWNLOAD=1 ENV PLAYWRIGHT_SKIP_BROWSER_DOWNLOAD=1
VOLUME ["/data"] VOLUME ["/data"]
EXPOSE 8000 EXPOSE 80
STOPSIGNAL SIGINT STOPSIGNAL SIGINT
COPY entrypoint.sh /entrypoint.sh
CMD ["/entrypoint.sh"] CMD ["/entrypoint.sh"]

View file

@ -1,11 +0,0 @@
services:
megapt:
build: src
restart: unless-stopped
environment:
USAGE_THRESHOLD: 85
CHECK_INTERVAL: 60
labels:
traefik.host: megapt
volumes:
- ./data:/data

View file

@ -1,7 +1,6 @@
import asyncio import asyncio
import json import json
import logging import logging
import os
import shutil import shutil
import subprocess import subprocess
import tempfile import tempfile
@ -45,6 +44,8 @@ CHROME_FLAGS = [
"--disable-search-engine-choice-screen", "--disable-search-engine-choice-screen",
] ]
DEFAULT_CDP_PORT = 9222
def _fetch_ws_endpoint(port: int) -> str | None: def _fetch_ws_endpoint(port: int) -> str | None:
try: try:
@ -78,9 +79,10 @@ class ManagedBrowser:
shutil.rmtree(self.profile_dir, ignore_errors=True) shutil.rmtree(self.profile_dir, ignore_errors=True)
async def launch(playwright: Playwright, cdp_port: int | None = None) -> ManagedBrowser: async def launch(
chrome_path = os.environ.get("CHROMIUM_PATH") or playwright.chromium.executable_path playwright: Playwright, cdp_port: int = DEFAULT_CDP_PORT
cdp_port = cdp_port or int(os.environ.get("CDP_PORT", "9222")) ) -> ManagedBrowser:
chrome_path = playwright.chromium.executable_path
profile_dir = Path(tempfile.mkdtemp(prefix="megapt_profile-", dir="/tmp")) profile_dir = Path(tempfile.mkdtemp(prefix="megapt_profile-", dir="/tmp"))
args = [ args = [

View file

@ -1,3 +1,4 @@
import asyncio
import logging import logging
from typing import Callable from typing import Callable
from typing import Any from typing import Any
@ -7,11 +8,12 @@ from playwright.async_api import BrowserContext
from providers.base import Provider, ProviderTokens from providers.base import Provider, ProviderTokens
from email_providers import BaseProvider from email_providers import BaseProvider
from email_providers import TempMailOrgProvider from email_providers import TempMailOrgProvider
from .tokens import load_tokens, save_tokens, get_valid_tokens from .tokens import load_tokens, save_tokens, refresh_tokens
from .usage import get_usage_percent from .usage import get_usage_data
from .registration import register_chatgpt_account from .registration import register_chatgpt_account
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
MAX_REGISTRATION_ATTEMPTS = 4
class ChatGPTProvider(Provider): class ChatGPTProvider(Provider):
@ -22,23 +24,63 @@ class ChatGPTProvider(Provider):
email_provider_factory: Callable[[BrowserContext], BaseProvider] | None = None, email_provider_factory: Callable[[BrowserContext], BaseProvider] | None = None,
): ):
self.email_provider_factory = email_provider_factory or TempMailOrgProvider self.email_provider_factory = email_provider_factory or TempMailOrgProvider
self._token_write_lock = asyncio.Lock()
async def _register_with_retries(self) -> bool:
for attempt in range(1, MAX_REGISTRATION_ATTEMPTS + 1):
logger.info(
"Registration attempt %s/%s",
attempt,
MAX_REGISTRATION_ATTEMPTS,
)
success = await self.register_new_account()
if success:
return True
logger.warning("Registration attempt %s failed", attempt)
return False
async def force_recreate_token(self) -> str | None:
async with self._token_write_lock:
success = await self._register_with_retries()
if not success:
return None
tokens = load_tokens()
if not tokens:
return None
return tokens.access_token
@property @property
def name(self) -> str: def name(self) -> str:
return "chatgpt" return "chatgpt"
async def get_token(self) -> str | None: async def get_token(self) -> str | None:
"""Get valid access token, refreshing if needed""" """Get valid access token with single-writer refresh/register path."""
tokens = await get_valid_tokens() tokens = load_tokens()
if not tokens: if tokens and not tokens.is_expired:
logger.info("No valid tokens, registering new account") return tokens.access_token
success = await self.register_new_account()
async with self._token_write_lock:
tokens = load_tokens()
if tokens and not tokens.is_expired:
return tokens.access_token
if tokens and tokens.refresh_token:
logger.info("Token expired, refreshing under lock")
refreshed = await refresh_tokens(tokens.refresh_token)
if refreshed:
save_tokens(refreshed)
return refreshed.access_token
logger.warning("Token refresh failed, falling back to registration")
logger.info("No valid tokens, registering new account under lock")
success = await self._register_with_retries()
if not success: if not success:
return None return None
tokens = await get_valid_tokens()
tokens = load_tokens()
if not tokens: if not tokens:
return None return None
return tokens.access_token return tokens.access_token
async def register_new_account(self) -> bool: async def register_new_account(self) -> bool:
"""Register a new ChatGPT account""" """Register a new ChatGPT account"""
@ -48,15 +90,18 @@ class ChatGPTProvider(Provider):
async def get_usage_info(self, access_token: str) -> dict[str, Any]: async def get_usage_info(self, access_token: str) -> dict[str, Any]:
"""Get usage information for the current token""" """Get usage information for the current token"""
usage_percent = get_usage_percent(access_token) usage_data = get_usage_data(access_token)
if usage_percent < 0: if not usage_data:
return {"error": "Failed to get usage"} return {"error": "Failed to get usage"}
remaining = max(0, 100 - usage_percent)
return { return {
"used_percent": usage_percent, "used_percent": int(usage_data["used_percent"]),
"remaining_percent": remaining, "remaining_percent": int(usage_data["remaining_percent"]),
"exhausted": usage_percent >= 100, "exhausted": int(usage_data["used_percent"]) >= 100,
"primary_window": usage_data.get("primary_window"),
"secondary_window": usage_data.get("secondary_window"),
"limit_reached": bool(usage_data.get("limit_reached")),
"allowed": bool(usage_data.get("allowed", True)),
} }
def load_tokens(self) -> ProviderTokens | None: def load_tokens(self) -> ProviderTokens | None:

View file

@ -78,6 +78,36 @@ def generate_name() -> str:
"Paul", "Paul",
"Andrew", "Andrew",
"Joshua", "Joshua",
"Kenneth",
"Kevin",
"Brian",
"George",
"Edward",
"Ronald",
"Timothy",
"Jason",
"Jeffrey",
"Ryan",
"Jacob",
"Gary",
"Nicholas",
"Eric",
"Jonathan",
"Stephen",
"Larry",
"Justin",
"Scott",
"Brandon",
"Benjamin",
"Samuel",
"Frank",
"Gregory",
"Raymond",
"Alexander",
"Patrick",
"Jack",
"Dennis",
"Jerry",
] ]
last_names = [ last_names = [
"Smith", "Smith",
@ -100,10 +130,47 @@ def generate_name() -> str:
"Moore", "Moore",
"Jackson", "Jackson",
"Martin", "Martin",
"Lee",
"Perez",
"Thompson",
"White",
"Harris",
"Sanchez",
"Clark",
"Ramirez",
"Lewis",
"Robinson",
"Walker",
"Young",
"Allen",
"King",
"Wright",
"Scott",
"Torres",
"Nguyen",
"Hill",
"Flores",
"Green",
"Adams",
"Nelson",
"Baker",
"Hall",
"Rivera",
"Campbell",
"Mitchell",
"Carter",
"Roberts",
] ]
return f"{random.choice(first_names)} {random.choice(last_names)}" return f"{random.choice(first_names)} {random.choice(last_names)}"
def generate_birthdate_90s() -> tuple[str, str, str]:
year = random.randint(1990, 1999)
month = random.randint(1, 12)
day = random.randint(1, 28)
return f"{month:02d}", f"{day:02d}", str(year)
def extract_verification_code(message: str) -> str | None: def extract_verification_code(message: str) -> str | None:
normalized = re.sub(r"\s+", " ", message) normalized = re.sub(r"\s+", " ", message)
@ -177,24 +244,6 @@ async def exchange_code_for_tokens(code: str, verifier: str) -> ProviderTokens:
) )
async def get_new_verification_code(
email_provider: BaseProvider,
email: str,
used_codes: set[str],
timeout_seconds: int = 240,
) -> str | None:
attempts = max(1, timeout_seconds // 5)
for _ in range(attempts):
message = await email_provider.get_latest_message(email)
if message:
all_codes = re.findall(r"\b(\d{6})\b", message)
for candidate in all_codes:
if candidate not in used_codes:
return candidate
await asyncio.sleep(5)
return None
async def get_latest_code(email_provider: BaseProvider, email: str) -> str | None: async def get_latest_code(email_provider: BaseProvider, email: str) -> str | None:
message = await email_provider.get_latest_message(email) message = await email_provider.get_latest_message(email)
if not message: if not message:
@ -209,24 +258,32 @@ async def fill_date_field(page: Page, month: str, day: str, year: str):
await month_field.scroll_into_view_if_needed() await month_field.scroll_into_view_if_needed()
await month_field.click() await month_field.click()
await page.wait_for_timeout(120) await page.wait_for_timeout(80)
await page.keyboard.type(f"{month}{day}{year}") await page.keyboard.type(f"{month}{day}{year}")
await page.wait_for_timeout(200) await page.wait_for_timeout(120)
async def wait_for_signup_stabilization(page: Page): async def click_continue(page: Page, timeout_ms: int = 10000):
try: btn = page.get_by_role("button", name="Continue", exact=True).first
await page.wait_for_load_state("networkidle", timeout=15000) await btn.wait_for(state="visible", timeout=timeout_ms)
except Exception: await btn.click()
logger.warning(
"Signup page did not reach networkidle quickly; continuing with fallback"
) async def wait_for_signup_stabilization(
try: page: Page,
await page.wait_for_load_state("domcontentloaded", timeout=10000) source_url: str,
except Exception: timeout_seconds: int = 30,
pass ):
await page.wait_for_timeout(3000) end_at = asyncio.get_running_loop().time() + timeout_seconds
while asyncio.get_running_loop().time() < end_at:
current_url = page.url
if current_url != source_url:
logger.info("Signup redirect detected: %s -> %s", source_url, current_url)
return
await asyncio.sleep(0.5)
logger.warning("Signup redirect was not detected within %ss", timeout_seconds)
async def register_chatgpt_account( async def register_chatgpt_account(
@ -238,7 +295,7 @@ async def register_chatgpt_account(
logger.error("No email provider factory configured") logger.error("No email provider factory configured")
return False return False
birth_month, birth_day, birth_year = "01", "15", "1995" birth_month, birth_day, birth_year = generate_birthdate_90s()
current_page: Page | None = None current_page: Page | None = None
redirect_url_captured: str | None = None redirect_url_captured: str | None = None
@ -253,7 +310,7 @@ async def register_chatgpt_account(
) )
email_provider = email_provider_factory(context) email_provider = email_provider_factory(context)
logger.info("[1/6] Getting new email from configured provider...") logger.info("[1/5] Getting new email from configured provider...")
email = await email_provider.get_new_email() email = await email_provider.get_new_email()
if not email: if not email:
raise AutomationError( raise AutomationError(
@ -266,127 +323,57 @@ async def register_chatgpt_account(
oauth_state = generate_state() oauth_state = generate_state()
authorize_url = build_authorize_url(verifier, challenge, oauth_state) authorize_url = build_authorize_url(verifier, challenge, oauth_state)
logger.info("[2/6] Registering ChatGPT for %s", email) logger.info("[2/5] Registering ChatGPT for %s", email)
chatgpt_page = await context.new_page() chatgpt_page = await context.new_page()
current_page = chatgpt_page current_page = chatgpt_page
await chatgpt_page.goto("https://chatgpt.com") await chatgpt_page.goto("https://chatgpt.com")
await chatgpt_page.wait_for_load_state("domcontentloaded") await chatgpt_page.wait_for_load_state("domcontentloaded")
await chatgpt_page.get_by_text("Sign up for free", exact=True).click() await chatgpt_page.get_by_text("Sign up for free", exact=True).click()
await chatgpt_page.wait_for_timeout(2000) await chatgpt_page.locator('input[type="email"]').first.wait_for(
state="visible", timeout=15000
)
await chatgpt_page.locator('input[type="email"]').fill(email) await chatgpt_page.locator('input[type="email"]').fill(email)
await chatgpt_page.wait_for_timeout(500) await click_continue(chatgpt_page)
await chatgpt_page.get_by_role( await chatgpt_page.locator('input[type="password"]').first.wait_for(
"button", name="Continue", exact=True state="visible", timeout=15000
).click() )
await chatgpt_page.wait_for_timeout(3000)
await chatgpt_page.locator('input[type="password"]').fill(password) await chatgpt_page.locator('input[type="password"]').fill(password)
await chatgpt_page.wait_for_timeout(500) await click_continue(chatgpt_page)
await chatgpt_page.get_by_role( await chatgpt_page.get_by_placeholder("Code").first.wait_for(
"button", name="Continue", exact=True state="visible", timeout=30000
).click() )
await chatgpt_page.wait_for_timeout(5000)
logger.info("[3/6] Getting verification message from email provider...") logger.info("[3/5] Getting verification message from email provider...")
code = await get_latest_code(email_provider, email) code = await get_latest_code(email_provider, email)
if not code: if not code:
raise AutomationError( raise AutomationError(
"email_provider", "Email provider returned no verification message" "email_provider", "Email provider returned no verification message"
) )
logger.info("[3/6] Verification code extracted: %s", code) logger.info("[3/5] Verification code extracted: %s", code)
used_codes = {code}
await chatgpt_page.bring_to_front() await chatgpt_page.bring_to_front()
code_input = chatgpt_page.get_by_placeholder("Code") code_input = chatgpt_page.get_by_placeholder("Code")
if await code_input.count() > 0: if await code_input.count() > 0:
await code_input.fill(code) await code_input.fill(code)
await chatgpt_page.wait_for_timeout(5000) await click_continue(chatgpt_page)
continue_btn = chatgpt_page.get_by_role( logger.info("[4/5] Setting profile...")
"button", name="Continue", exact=True
)
if await continue_btn.count() > 0:
await continue_btn.click()
await chatgpt_page.wait_for_timeout(5000)
logger.info("[4/6] Setting profile...")
name_input = chatgpt_page.get_by_placeholder("Full name") name_input = chatgpt_page.get_by_placeholder("Full name")
await name_input.first.wait_for(state="visible", timeout=20000)
if await name_input.count() > 0: if await name_input.count() > 0:
await name_input.fill(full_name) await name_input.fill(full_name)
await chatgpt_page.wait_for_timeout(500)
await fill_date_field(chatgpt_page, birth_month, birth_day, birth_year) await fill_date_field(chatgpt_page, birth_month, birth_day, birth_year)
await chatgpt_page.wait_for_timeout(1000) profile_url = chatgpt_page.url
await click_continue(chatgpt_page)
continue_btn = chatgpt_page.get_by_role(
"button", name="Continue", exact=True
)
if await continue_btn.count() > 0:
await continue_btn.click()
logger.info("Account registered!") logger.info("Account registered!")
await wait_for_signup_stabilization(chatgpt_page) await wait_for_signup_stabilization(chatgpt_page, source_url=profile_url)
logger.info("[5/6] Skipping onboarding...") logger.info("[5/5] Running OAuth flow to get tokens...")
for _ in range(5):
skip_btn = chatgpt_page.locator(
'button:has-text("Skip"):not(:has-text("Skip Tour"))'
)
if await skip_btn.count() > 0:
for i in range(await skip_btn.count()):
try:
btn = skip_btn.nth(i)
if await btn.is_visible():
await btn.click(timeout=5000)
logger.info("Clicked: Skip")
await chatgpt_page.wait_for_timeout(1500)
except:
pass
await chatgpt_page.wait_for_timeout(1000)
skip_tour = chatgpt_page.locator('button:has-text("Skip Tour")')
if await skip_tour.count() > 0:
try:
await skip_tour.first.wait_for(state="visible", timeout=5000)
await skip_tour.first.click(timeout=5000)
logger.info("Clicked: Skip Tour")
await chatgpt_page.wait_for_timeout(2000)
except:
pass
await chatgpt_page.wait_for_timeout(2000)
for _ in range(3):
continue_btn = chatgpt_page.locator('button:has-text("Continue")')
if await continue_btn.count() > 0:
try:
await continue_btn.first.wait_for(state="visible", timeout=5000)
await continue_btn.first.click(timeout=5000)
logger.info("Clicked: Continue")
await chatgpt_page.wait_for_timeout(2000)
except:
pass
await chatgpt_page.wait_for_timeout(2000)
okay_btn = chatgpt_page.locator('button:has-text("Okay, let")')
for _ in range(10):
try:
await okay_btn.first.wait_for(state="visible", timeout=3000)
await okay_btn.first.click(timeout=5000)
logger.info("Clicked: Okay, let's go")
await chatgpt_page.wait_for_timeout(3000)
break
except:
await chatgpt_page.wait_for_timeout(1000)
logger.info("Skipping subscription/card flow (disabled)")
await chatgpt_page.wait_for_timeout(2000)
logger.info("[6/6] Running OAuth flow to get tokens...")
oauth_page = await context.new_page() oauth_page = await context.new_page()
current_page = oauth_page current_page = oauth_page
@ -400,33 +387,34 @@ async def register_chatgpt_account(
oauth_page.on("request", handle_request) oauth_page.on("request", handle_request)
await oauth_page.goto(authorize_url, wait_until="domcontentloaded") await oauth_page.goto(authorize_url, wait_until="domcontentloaded")
await oauth_page.wait_for_timeout(2000) await oauth_page.locator(
'input[type="email"], input[name="email"]'
).first.wait_for(state="visible", timeout=20000)
email_input = oauth_page.locator('input[type="email"], input[name="email"]') email_input = oauth_page.locator('input[type="email"], input[name="email"]')
if await email_input.count() > 0: if await email_input.count() > 0:
await email_input.first.fill(email) await email_input.first.fill(email)
await oauth_page.wait_for_timeout(400)
continue_button = oauth_page.get_by_role("button", name="Continue") continue_button = oauth_page.get_by_role("button", name="Continue")
if await continue_button.count() > 0: if await continue_button.count() > 0:
await continue_button.first.click() await continue_button.first.click()
await oauth_page.wait_for_timeout(2500) await oauth_page.locator('input[type="password"]').first.wait_for(
state="visible", timeout=20000
)
password_input = oauth_page.locator('input[type="password"]') password_input = oauth_page.locator('input[type="password"]')
if await password_input.count() > 0: if await password_input.count() > 0:
await password_input.first.fill(password) await password_input.first.fill(password)
await oauth_page.wait_for_timeout(400)
continue_button = oauth_page.get_by_role("button", name="Continue") continue_button = oauth_page.get_by_role("button", name="Continue")
if await continue_button.count() > 0: if await continue_button.count() > 0:
await continue_button.first.click() await continue_button.first.click()
await oauth_page.wait_for_timeout(2500)
for label in ["Continue", "Allow", "Authorize"]: for label in ["Continue", "Allow", "Authorize"]:
button = oauth_page.get_by_role("button", name=label) button = oauth_page.get_by_role("button", name=label)
if await button.count() > 0: if await button.count() > 0:
try: try:
await button.first.click(timeout=5000) await button.first.click(timeout=5000)
await oauth_page.wait_for_timeout(2000) await oauth_page.wait_for_timeout(500)
except Exception: except Exception:
pass pass

View file

@ -3,9 +3,12 @@ import time
import os import os
import aiohttp import aiohttp
from pathlib import Path from pathlib import Path
import logging
from providers.base import ProviderTokens from providers.base import ProviderTokens
logger = logging.getLogger(__name__)
DATA_DIR = Path(os.environ.get("DATA_DIR", "./data")) DATA_DIR = Path(os.environ.get("DATA_DIR", "./data"))
TOKENS_FILE = DATA_DIR / "chatgpt_tokens.json" TOKENS_FILE = DATA_DIR / "chatgpt_tokens.json"
@ -66,17 +69,17 @@ async def refresh_tokens(refresh_token: str) -> ProviderTokens | None:
async def get_valid_tokens() -> ProviderTokens | None: async def get_valid_tokens() -> ProviderTokens | None:
tokens = load_tokens() tokens = load_tokens()
if not tokens: if not tokens:
print("No tokens found") logger.info("No tokens found")
return None return None
if tokens.is_expired: if tokens.is_expired:
print("Token expired, refreshing...") logger.info("Token expired, refreshing...")
if not tokens.refresh_token: if not tokens.refresh_token:
print("No refresh token available") logger.info("No refresh token available")
return None return None
new_tokens = await refresh_tokens(tokens.refresh_token) new_tokens = await refresh_tokens(tokens.refresh_token)
if not new_tokens: if not new_tokens:
print("Failed to refresh token") logger.warning("Failed to refresh token")
return None return None
save_tokens(new_tokens) save_tokens(new_tokens)
return new_tokens return new_tokens

View file

@ -17,7 +17,18 @@ def clamp_percent(value: Any) -> int:
return int(round(num)) return int(round(num))
def get_usage_percent(access_token: str, timeout_ms: int = 10000) -> int: def _parse_window(window: dict[str, Any] | None) -> dict[str, int] | None:
if not isinstance(window, dict):
return None
return {
"used_percent": clamp_percent(window.get("used_percent") or 0),
"limit_window_seconds": int(window.get("limit_window_seconds") or 0),
"reset_after_seconds": int(window.get("reset_after_seconds") or 0),
"reset_at": int(window.get("reset_at") or 0),
}
def get_usage_data(access_token: str, timeout_ms: int = 10000) -> dict[str, Any] | None:
headers = { headers = {
"Authorization": f"Bearer {access_token}", "Authorization": f"Bearer {access_token}",
"User-Agent": "CodexProxy", "User-Agent": "CodexProxy",
@ -33,18 +44,42 @@ def get_usage_percent(access_token: str, timeout_ms: int = 10000) -> int:
try: try:
with urllib.request.urlopen(req, timeout=timeout_ms / 1000) as res: with urllib.request.urlopen(req, timeout=timeout_ms / 1000) as res:
body = res.read().decode("utf-8", errors="replace") body = res.read().decode("utf-8", errors="replace")
except urllib.error.HTTPError as e: except urllib.error.HTTPError:
return -1 return None
except urllib.error.URLError, socket.timeout: except urllib.error.URLError, socket.timeout:
return -1 return None
try: try:
data = json.loads(body) data = json.loads(body)
except json.JSONDecodeError: except json.JSONDecodeError:
return -1 return None
primary = (data.get("rate_limit") or {}).get("primary_window") or {} rate_limit = data.get("rate_limit") or {}
primary = _parse_window(rate_limit.get("primary_window"))
secondary = _parse_window(rate_limit.get("secondary_window"))
used_candidates = []
if primary: if primary:
return clamp_percent(primary.get("used_percent") or 0) used_candidates.append(primary["used_percent"])
if secondary:
used_candidates.append(secondary["used_percent"])
return -1 if not used_candidates:
return None
effective_used = max(used_candidates)
return {
"used_percent": effective_used,
"remaining_percent": max(0, 100 - effective_used),
"primary_window": primary,
"secondary_window": secondary,
"limit_reached": bool(rate_limit.get("limit_reached")),
"allowed": bool(rate_limit.get("allowed", True)),
}
def get_usage_percent(access_token: str, timeout_ms: int = 10000) -> int:
data = get_usage_data(access_token, timeout_ms=timeout_ms)
if not data:
return -1
return int(data["used_percent"])

View file

@ -41,6 +41,47 @@ def build_limit(usage_percent: int) -> dict[str, int | bool]:
} }
async def ensure_provider_token_ready(provider_name: str):
provider = PROVIDERS.get(provider_name)
if not provider:
return
logger.info("[%s] Startup token check", provider_name)
token = await provider.get_token()
if not token:
logger.warning(
"[%s] Startup token check failed, forcing recreation", provider_name
)
if isinstance(provider, ChatGPTProvider):
token = await provider.force_recreate_token()
if not token:
logger.error("[%s] Could not prepare token at startup", provider_name)
return
usage_info = await provider.get_usage_info(token)
if "error" not in usage_info:
logger.info("[%s] Startup token is ready", provider_name)
return
logger.warning(
"[%s] Startup token invalid for usage, forcing recreation", provider_name
)
if isinstance(provider, ChatGPTProvider):
token = await provider.force_recreate_token()
if token:
logger.info("[%s] Startup token recreated successfully", provider_name)
return
logger.error("[%s] Startup token recreation failed", provider_name)
async def on_startup(app: web.Application):
del app
for provider_name in PROVIDERS.keys():
await ensure_provider_token_ready(provider_name)
async def issue_new_token(provider_name: str) -> str | None: async def issue_new_token(provider_name: str) -> str | None:
provider = PROVIDERS.get(provider_name) provider = PROVIDERS.get(provider_name)
if not provider: if not provider:
@ -114,6 +155,33 @@ async def token_handler(request: web.Request) -> web.Response:
) )
usage_percent = usage_info.get("used_percent", 0) usage_percent = usage_info.get("used_percent", 0)
remaining_percent = usage_info.get("remaining_percent", max(0, 100 - usage_percent))
logger.info(
"[%s] token issued, used=%s%% remaining=%s%%",
provider_name,
usage_percent,
remaining_percent,
)
primary_window = usage_info.get("primary_window")
secondary_window = usage_info.get("secondary_window")
if primary_window:
logger.info(
"[%s] primary window: used=%s%% window=%ss reset_in=%ss",
provider_name,
primary_window.get("used_percent", 0),
primary_window.get("limit_window_seconds", 0),
primary_window.get("reset_after_seconds", 0),
)
if secondary_window:
logger.info(
"[%s] secondary window: used=%s%% window=%ss reset_in=%ss",
provider_name,
secondary_window.get("used_percent", 0),
secondary_window.get("limit_window_seconds", 0),
secondary_window.get("reset_after_seconds", 0),
)
# Trigger background refresh if needed # Trigger background refresh if needed
if usage_percent >= USAGE_REFRESH_THRESHOLD: if usage_percent >= USAGE_REFRESH_THRESHOLD:
@ -126,12 +194,17 @@ async def token_handler(request: web.Request) -> web.Response:
{ {
"token": token, "token": token,
"limit": build_limit(usage_percent), "limit": build_limit(usage_percent),
"usage": {
"primary_window": primary_window,
"secondary_window": secondary_window,
},
} }
) )
def create_app() -> web.Application: def create_app() -> web.Application:
app = web.Application(middlewares=[request_log_middleware]) app = web.Application(middlewares=[request_log_middleware])
app.on_startup.append(on_startup)
# New route: /{provider}/token # New route: /{provider}/token
app.router.add_get("/{provider}/token", token_handler) app.router.add_get("/{provider}/token", token_handler)
# Legacy route for backward compatibility # Legacy route for backward compatibility