refactor: harden ChatGPT token lifecycle with startup recovery, single-writer locking, and faster auth flow
This commit is contained in:
parent
71d1050adb
commit
533e382e0e
9 changed files with 313 additions and 178 deletions
10
Dockerfile
10
Dockerfile
|
|
@ -15,19 +15,19 @@ RUN pip install --no-cache-dir uv
|
|||
RUN uv sync --frozen --no-dev
|
||||
RUN /app/.venv/bin/python -m playwright install --with-deps chromium
|
||||
|
||||
COPY src/*.py /app/
|
||||
COPY entrypoint.sh /entrypoint.sh
|
||||
|
||||
COPY src .
|
||||
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
ENV PORT=8000
|
||||
ENV PORT=80
|
||||
ENV DATA_DIR=/data
|
||||
ENV PLAYWRIGHT_SKIP_BROWSER_DOWNLOAD=1
|
||||
|
||||
VOLUME ["/data"]
|
||||
|
||||
EXPOSE 8000
|
||||
EXPOSE 80
|
||||
|
||||
STOPSIGNAL SIGINT
|
||||
|
||||
COPY entrypoint.sh /entrypoint.sh
|
||||
|
||||
CMD ["/entrypoint.sh"]
|
||||
|
|
|
|||
11
compose.yml
11
compose.yml
|
|
@ -1,11 +0,0 @@
|
|||
services:
|
||||
megapt:
|
||||
build: src
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
USAGE_THRESHOLD: 85
|
||||
CHECK_INTERVAL: 60
|
||||
labels:
|
||||
traefik.host: megapt
|
||||
volumes:
|
||||
- ./data:/data
|
||||
|
|
@ -1,7 +1,6 @@
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
|
@ -45,6 +44,8 @@ CHROME_FLAGS = [
|
|||
"--disable-search-engine-choice-screen",
|
||||
]
|
||||
|
||||
DEFAULT_CDP_PORT = 9222
|
||||
|
||||
|
||||
def _fetch_ws_endpoint(port: int) -> str | None:
|
||||
try:
|
||||
|
|
@ -78,9 +79,10 @@ class ManagedBrowser:
|
|||
shutil.rmtree(self.profile_dir, ignore_errors=True)
|
||||
|
||||
|
||||
async def launch(playwright: Playwright, cdp_port: int | None = None) -> ManagedBrowser:
|
||||
chrome_path = os.environ.get("CHROMIUM_PATH") or playwright.chromium.executable_path
|
||||
cdp_port = cdp_port or int(os.environ.get("CDP_PORT", "9222"))
|
||||
async def launch(
|
||||
playwright: Playwright, cdp_port: int = DEFAULT_CDP_PORT
|
||||
) -> ManagedBrowser:
|
||||
chrome_path = playwright.chromium.executable_path
|
||||
profile_dir = Path(tempfile.mkdtemp(prefix="megapt_profile-", dir="/tmp"))
|
||||
|
||||
args = [
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from typing import Callable
|
||||
from typing import Any
|
||||
|
|
@ -7,11 +8,12 @@ from playwright.async_api import BrowserContext
|
|||
from providers.base import Provider, ProviderTokens
|
||||
from email_providers import BaseProvider
|
||||
from email_providers import TempMailOrgProvider
|
||||
from .tokens import load_tokens, save_tokens, get_valid_tokens
|
||||
from .usage import get_usage_percent
|
||||
from .tokens import load_tokens, save_tokens, refresh_tokens
|
||||
from .usage import get_usage_data
|
||||
from .registration import register_chatgpt_account
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
MAX_REGISTRATION_ATTEMPTS = 4
|
||||
|
||||
|
||||
class ChatGPTProvider(Provider):
|
||||
|
|
@ -22,23 +24,63 @@ class ChatGPTProvider(Provider):
|
|||
email_provider_factory: Callable[[BrowserContext], BaseProvider] | None = None,
|
||||
):
|
||||
self.email_provider_factory = email_provider_factory or TempMailOrgProvider
|
||||
self._token_write_lock = asyncio.Lock()
|
||||
|
||||
async def _register_with_retries(self) -> bool:
|
||||
for attempt in range(1, MAX_REGISTRATION_ATTEMPTS + 1):
|
||||
logger.info(
|
||||
"Registration attempt %s/%s",
|
||||
attempt,
|
||||
MAX_REGISTRATION_ATTEMPTS,
|
||||
)
|
||||
success = await self.register_new_account()
|
||||
if success:
|
||||
return True
|
||||
logger.warning("Registration attempt %s failed", attempt)
|
||||
return False
|
||||
|
||||
async def force_recreate_token(self) -> str | None:
|
||||
async with self._token_write_lock:
|
||||
success = await self._register_with_retries()
|
||||
if not success:
|
||||
return None
|
||||
tokens = load_tokens()
|
||||
if not tokens:
|
||||
return None
|
||||
return tokens.access_token
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "chatgpt"
|
||||
|
||||
async def get_token(self) -> str | None:
|
||||
"""Get valid access token, refreshing if needed"""
|
||||
tokens = await get_valid_tokens()
|
||||
if not tokens:
|
||||
logger.info("No valid tokens, registering new account")
|
||||
success = await self.register_new_account()
|
||||
"""Get valid access token with single-writer refresh/register path."""
|
||||
tokens = load_tokens()
|
||||
if tokens and not tokens.is_expired:
|
||||
return tokens.access_token
|
||||
|
||||
async with self._token_write_lock:
|
||||
tokens = load_tokens()
|
||||
if tokens and not tokens.is_expired:
|
||||
return tokens.access_token
|
||||
|
||||
if tokens and tokens.refresh_token:
|
||||
logger.info("Token expired, refreshing under lock")
|
||||
refreshed = await refresh_tokens(tokens.refresh_token)
|
||||
if refreshed:
|
||||
save_tokens(refreshed)
|
||||
return refreshed.access_token
|
||||
logger.warning("Token refresh failed, falling back to registration")
|
||||
|
||||
logger.info("No valid tokens, registering new account under lock")
|
||||
success = await self._register_with_retries()
|
||||
if not success:
|
||||
return None
|
||||
tokens = await get_valid_tokens()
|
||||
|
||||
tokens = load_tokens()
|
||||
if not tokens:
|
||||
return None
|
||||
return tokens.access_token
|
||||
return tokens.access_token
|
||||
|
||||
async def register_new_account(self) -> bool:
|
||||
"""Register a new ChatGPT account"""
|
||||
|
|
@ -48,15 +90,18 @@ class ChatGPTProvider(Provider):
|
|||
|
||||
async def get_usage_info(self, access_token: str) -> dict[str, Any]:
|
||||
"""Get usage information for the current token"""
|
||||
usage_percent = get_usage_percent(access_token)
|
||||
if usage_percent < 0:
|
||||
usage_data = get_usage_data(access_token)
|
||||
if not usage_data:
|
||||
return {"error": "Failed to get usage"}
|
||||
|
||||
remaining = max(0, 100 - usage_percent)
|
||||
return {
|
||||
"used_percent": usage_percent,
|
||||
"remaining_percent": remaining,
|
||||
"exhausted": usage_percent >= 100,
|
||||
"used_percent": int(usage_data["used_percent"]),
|
||||
"remaining_percent": int(usage_data["remaining_percent"]),
|
||||
"exhausted": int(usage_data["used_percent"]) >= 100,
|
||||
"primary_window": usage_data.get("primary_window"),
|
||||
"secondary_window": usage_data.get("secondary_window"),
|
||||
"limit_reached": bool(usage_data.get("limit_reached")),
|
||||
"allowed": bool(usage_data.get("allowed", True)),
|
||||
}
|
||||
|
||||
def load_tokens(self) -> ProviderTokens | None:
|
||||
|
|
|
|||
|
|
@ -78,6 +78,36 @@ def generate_name() -> str:
|
|||
"Paul",
|
||||
"Andrew",
|
||||
"Joshua",
|
||||
"Kenneth",
|
||||
"Kevin",
|
||||
"Brian",
|
||||
"George",
|
||||
"Edward",
|
||||
"Ronald",
|
||||
"Timothy",
|
||||
"Jason",
|
||||
"Jeffrey",
|
||||
"Ryan",
|
||||
"Jacob",
|
||||
"Gary",
|
||||
"Nicholas",
|
||||
"Eric",
|
||||
"Jonathan",
|
||||
"Stephen",
|
||||
"Larry",
|
||||
"Justin",
|
||||
"Scott",
|
||||
"Brandon",
|
||||
"Benjamin",
|
||||
"Samuel",
|
||||
"Frank",
|
||||
"Gregory",
|
||||
"Raymond",
|
||||
"Alexander",
|
||||
"Patrick",
|
||||
"Jack",
|
||||
"Dennis",
|
||||
"Jerry",
|
||||
]
|
||||
last_names = [
|
||||
"Smith",
|
||||
|
|
@ -100,10 +130,47 @@ def generate_name() -> str:
|
|||
"Moore",
|
||||
"Jackson",
|
||||
"Martin",
|
||||
"Lee",
|
||||
"Perez",
|
||||
"Thompson",
|
||||
"White",
|
||||
"Harris",
|
||||
"Sanchez",
|
||||
"Clark",
|
||||
"Ramirez",
|
||||
"Lewis",
|
||||
"Robinson",
|
||||
"Walker",
|
||||
"Young",
|
||||
"Allen",
|
||||
"King",
|
||||
"Wright",
|
||||
"Scott",
|
||||
"Torres",
|
||||
"Nguyen",
|
||||
"Hill",
|
||||
"Flores",
|
||||
"Green",
|
||||
"Adams",
|
||||
"Nelson",
|
||||
"Baker",
|
||||
"Hall",
|
||||
"Rivera",
|
||||
"Campbell",
|
||||
"Mitchell",
|
||||
"Carter",
|
||||
"Roberts",
|
||||
]
|
||||
return f"{random.choice(first_names)} {random.choice(last_names)}"
|
||||
|
||||
|
||||
def generate_birthdate_90s() -> tuple[str, str, str]:
|
||||
year = random.randint(1990, 1999)
|
||||
month = random.randint(1, 12)
|
||||
day = random.randint(1, 28)
|
||||
return f"{month:02d}", f"{day:02d}", str(year)
|
||||
|
||||
|
||||
def extract_verification_code(message: str) -> str | None:
|
||||
normalized = re.sub(r"\s+", " ", message)
|
||||
|
||||
|
|
@ -177,24 +244,6 @@ async def exchange_code_for_tokens(code: str, verifier: str) -> ProviderTokens:
|
|||
)
|
||||
|
||||
|
||||
async def get_new_verification_code(
|
||||
email_provider: BaseProvider,
|
||||
email: str,
|
||||
used_codes: set[str],
|
||||
timeout_seconds: int = 240,
|
||||
) -> str | None:
|
||||
attempts = max(1, timeout_seconds // 5)
|
||||
for _ in range(attempts):
|
||||
message = await email_provider.get_latest_message(email)
|
||||
if message:
|
||||
all_codes = re.findall(r"\b(\d{6})\b", message)
|
||||
for candidate in all_codes:
|
||||
if candidate not in used_codes:
|
||||
return candidate
|
||||
await asyncio.sleep(5)
|
||||
return None
|
||||
|
||||
|
||||
async def get_latest_code(email_provider: BaseProvider, email: str) -> str | None:
|
||||
message = await email_provider.get_latest_message(email)
|
||||
if not message:
|
||||
|
|
@ -209,24 +258,32 @@ async def fill_date_field(page: Page, month: str, day: str, year: str):
|
|||
|
||||
await month_field.scroll_into_view_if_needed()
|
||||
await month_field.click()
|
||||
await page.wait_for_timeout(120)
|
||||
await page.wait_for_timeout(80)
|
||||
|
||||
await page.keyboard.type(f"{month}{day}{year}")
|
||||
await page.wait_for_timeout(200)
|
||||
await page.wait_for_timeout(120)
|
||||
|
||||
|
||||
async def wait_for_signup_stabilization(page: Page):
|
||||
try:
|
||||
await page.wait_for_load_state("networkidle", timeout=15000)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Signup page did not reach networkidle quickly; continuing with fallback"
|
||||
)
|
||||
try:
|
||||
await page.wait_for_load_state("domcontentloaded", timeout=10000)
|
||||
except Exception:
|
||||
pass
|
||||
await page.wait_for_timeout(3000)
|
||||
async def click_continue(page: Page, timeout_ms: int = 10000):
|
||||
btn = page.get_by_role("button", name="Continue", exact=True).first
|
||||
await btn.wait_for(state="visible", timeout=timeout_ms)
|
||||
await btn.click()
|
||||
|
||||
|
||||
async def wait_for_signup_stabilization(
|
||||
page: Page,
|
||||
source_url: str,
|
||||
timeout_seconds: int = 30,
|
||||
):
|
||||
end_at = asyncio.get_running_loop().time() + timeout_seconds
|
||||
while asyncio.get_running_loop().time() < end_at:
|
||||
current_url = page.url
|
||||
if current_url != source_url:
|
||||
logger.info("Signup redirect detected: %s -> %s", source_url, current_url)
|
||||
return
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
logger.warning("Signup redirect was not detected within %ss", timeout_seconds)
|
||||
|
||||
|
||||
async def register_chatgpt_account(
|
||||
|
|
@ -238,7 +295,7 @@ async def register_chatgpt_account(
|
|||
logger.error("No email provider factory configured")
|
||||
return False
|
||||
|
||||
birth_month, birth_day, birth_year = "01", "15", "1995"
|
||||
birth_month, birth_day, birth_year = generate_birthdate_90s()
|
||||
|
||||
current_page: Page | None = None
|
||||
redirect_url_captured: str | None = None
|
||||
|
|
@ -253,7 +310,7 @@ async def register_chatgpt_account(
|
|||
)
|
||||
email_provider = email_provider_factory(context)
|
||||
|
||||
logger.info("[1/6] Getting new email from configured provider...")
|
||||
logger.info("[1/5] Getting new email from configured provider...")
|
||||
email = await email_provider.get_new_email()
|
||||
if not email:
|
||||
raise AutomationError(
|
||||
|
|
@ -266,127 +323,57 @@ async def register_chatgpt_account(
|
|||
oauth_state = generate_state()
|
||||
authorize_url = build_authorize_url(verifier, challenge, oauth_state)
|
||||
|
||||
logger.info("[2/6] Registering ChatGPT for %s", email)
|
||||
logger.info("[2/5] Registering ChatGPT for %s", email)
|
||||
chatgpt_page = await context.new_page()
|
||||
current_page = chatgpt_page
|
||||
await chatgpt_page.goto("https://chatgpt.com")
|
||||
await chatgpt_page.wait_for_load_state("domcontentloaded")
|
||||
|
||||
await chatgpt_page.get_by_text("Sign up for free", exact=True).click()
|
||||
await chatgpt_page.wait_for_timeout(2000)
|
||||
await chatgpt_page.locator('input[type="email"]').first.wait_for(
|
||||
state="visible", timeout=15000
|
||||
)
|
||||
|
||||
await chatgpt_page.locator('input[type="email"]').fill(email)
|
||||
await chatgpt_page.wait_for_timeout(500)
|
||||
await chatgpt_page.get_by_role(
|
||||
"button", name="Continue", exact=True
|
||||
).click()
|
||||
await chatgpt_page.wait_for_timeout(3000)
|
||||
await click_continue(chatgpt_page)
|
||||
await chatgpt_page.locator('input[type="password"]').first.wait_for(
|
||||
state="visible", timeout=15000
|
||||
)
|
||||
|
||||
await chatgpt_page.locator('input[type="password"]').fill(password)
|
||||
await chatgpt_page.wait_for_timeout(500)
|
||||
await chatgpt_page.get_by_role(
|
||||
"button", name="Continue", exact=True
|
||||
).click()
|
||||
await chatgpt_page.wait_for_timeout(5000)
|
||||
await click_continue(chatgpt_page)
|
||||
await chatgpt_page.get_by_placeholder("Code").first.wait_for(
|
||||
state="visible", timeout=30000
|
||||
)
|
||||
|
||||
logger.info("[3/6] Getting verification message from email provider...")
|
||||
logger.info("[3/5] Getting verification message from email provider...")
|
||||
code = await get_latest_code(email_provider, email)
|
||||
if not code:
|
||||
raise AutomationError(
|
||||
"email_provider", "Email provider returned no verification message"
|
||||
)
|
||||
logger.info("[3/6] Verification code extracted: %s", code)
|
||||
used_codes = {code}
|
||||
logger.info("[3/5] Verification code extracted: %s", code)
|
||||
|
||||
await chatgpt_page.bring_to_front()
|
||||
code_input = chatgpt_page.get_by_placeholder("Code")
|
||||
if await code_input.count() > 0:
|
||||
await code_input.fill(code)
|
||||
await chatgpt_page.wait_for_timeout(5000)
|
||||
await click_continue(chatgpt_page)
|
||||
|
||||
continue_btn = chatgpt_page.get_by_role(
|
||||
"button", name="Continue", exact=True
|
||||
)
|
||||
if await continue_btn.count() > 0:
|
||||
await continue_btn.click()
|
||||
await chatgpt_page.wait_for_timeout(5000)
|
||||
|
||||
logger.info("[4/6] Setting profile...")
|
||||
logger.info("[4/5] Setting profile...")
|
||||
name_input = chatgpt_page.get_by_placeholder("Full name")
|
||||
await name_input.first.wait_for(state="visible", timeout=20000)
|
||||
if await name_input.count() > 0:
|
||||
await name_input.fill(full_name)
|
||||
|
||||
await chatgpt_page.wait_for_timeout(500)
|
||||
await fill_date_field(chatgpt_page, birth_month, birth_day, birth_year)
|
||||
await chatgpt_page.wait_for_timeout(1000)
|
||||
|
||||
continue_btn = chatgpt_page.get_by_role(
|
||||
"button", name="Continue", exact=True
|
||||
)
|
||||
if await continue_btn.count() > 0:
|
||||
await continue_btn.click()
|
||||
profile_url = chatgpt_page.url
|
||||
await click_continue(chatgpt_page)
|
||||
|
||||
logger.info("Account registered!")
|
||||
await wait_for_signup_stabilization(chatgpt_page)
|
||||
await wait_for_signup_stabilization(chatgpt_page, source_url=profile_url)
|
||||
|
||||
logger.info("[5/6] Skipping onboarding...")
|
||||
|
||||
for _ in range(5):
|
||||
skip_btn = chatgpt_page.locator(
|
||||
'button:has-text("Skip"):not(:has-text("Skip Tour"))'
|
||||
)
|
||||
if await skip_btn.count() > 0:
|
||||
for i in range(await skip_btn.count()):
|
||||
try:
|
||||
btn = skip_btn.nth(i)
|
||||
if await btn.is_visible():
|
||||
await btn.click(timeout=5000)
|
||||
logger.info("Clicked: Skip")
|
||||
await chatgpt_page.wait_for_timeout(1500)
|
||||
except:
|
||||
pass
|
||||
await chatgpt_page.wait_for_timeout(1000)
|
||||
|
||||
skip_tour = chatgpt_page.locator('button:has-text("Skip Tour")')
|
||||
if await skip_tour.count() > 0:
|
||||
try:
|
||||
await skip_tour.first.wait_for(state="visible", timeout=5000)
|
||||
await skip_tour.first.click(timeout=5000)
|
||||
logger.info("Clicked: Skip Tour")
|
||||
await chatgpt_page.wait_for_timeout(2000)
|
||||
except:
|
||||
pass
|
||||
|
||||
await chatgpt_page.wait_for_timeout(2000)
|
||||
|
||||
for _ in range(3):
|
||||
continue_btn = chatgpt_page.locator('button:has-text("Continue")')
|
||||
if await continue_btn.count() > 0:
|
||||
try:
|
||||
await continue_btn.first.wait_for(state="visible", timeout=5000)
|
||||
await continue_btn.first.click(timeout=5000)
|
||||
logger.info("Clicked: Continue")
|
||||
await chatgpt_page.wait_for_timeout(2000)
|
||||
except:
|
||||
pass
|
||||
|
||||
await chatgpt_page.wait_for_timeout(2000)
|
||||
|
||||
okay_btn = chatgpt_page.locator('button:has-text("Okay, let")')
|
||||
for _ in range(10):
|
||||
try:
|
||||
await okay_btn.first.wait_for(state="visible", timeout=3000)
|
||||
await okay_btn.first.click(timeout=5000)
|
||||
logger.info("Clicked: Okay, let's go")
|
||||
await chatgpt_page.wait_for_timeout(3000)
|
||||
break
|
||||
except:
|
||||
await chatgpt_page.wait_for_timeout(1000)
|
||||
|
||||
logger.info("Skipping subscription/card flow (disabled)")
|
||||
await chatgpt_page.wait_for_timeout(2000)
|
||||
|
||||
logger.info("[6/6] Running OAuth flow to get tokens...")
|
||||
logger.info("[5/5] Running OAuth flow to get tokens...")
|
||||
oauth_page = await context.new_page()
|
||||
current_page = oauth_page
|
||||
|
||||
|
|
@ -400,33 +387,34 @@ async def register_chatgpt_account(
|
|||
oauth_page.on("request", handle_request)
|
||||
|
||||
await oauth_page.goto(authorize_url, wait_until="domcontentloaded")
|
||||
await oauth_page.wait_for_timeout(2000)
|
||||
await oauth_page.locator(
|
||||
'input[type="email"], input[name="email"]'
|
||||
).first.wait_for(state="visible", timeout=20000)
|
||||
|
||||
email_input = oauth_page.locator('input[type="email"], input[name="email"]')
|
||||
if await email_input.count() > 0:
|
||||
await email_input.first.fill(email)
|
||||
await oauth_page.wait_for_timeout(400)
|
||||
|
||||
continue_button = oauth_page.get_by_role("button", name="Continue")
|
||||
if await continue_button.count() > 0:
|
||||
await continue_button.first.click()
|
||||
await oauth_page.wait_for_timeout(2500)
|
||||
await oauth_page.locator('input[type="password"]').first.wait_for(
|
||||
state="visible", timeout=20000
|
||||
)
|
||||
|
||||
password_input = oauth_page.locator('input[type="password"]')
|
||||
if await password_input.count() > 0:
|
||||
await password_input.first.fill(password)
|
||||
await oauth_page.wait_for_timeout(400)
|
||||
continue_button = oauth_page.get_by_role("button", name="Continue")
|
||||
if await continue_button.count() > 0:
|
||||
await continue_button.first.click()
|
||||
await oauth_page.wait_for_timeout(2500)
|
||||
|
||||
for label in ["Continue", "Allow", "Authorize"]:
|
||||
button = oauth_page.get_by_role("button", name=label)
|
||||
if await button.count() > 0:
|
||||
try:
|
||||
await button.first.click(timeout=5000)
|
||||
await oauth_page.wait_for_timeout(2000)
|
||||
await oauth_page.wait_for_timeout(500)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -3,9 +3,12 @@ import time
|
|||
import os
|
||||
import aiohttp
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
from providers.base import ProviderTokens
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DATA_DIR = Path(os.environ.get("DATA_DIR", "./data"))
|
||||
TOKENS_FILE = DATA_DIR / "chatgpt_tokens.json"
|
||||
|
||||
|
|
@ -66,17 +69,17 @@ async def refresh_tokens(refresh_token: str) -> ProviderTokens | None:
|
|||
async def get_valid_tokens() -> ProviderTokens | None:
|
||||
tokens = load_tokens()
|
||||
if not tokens:
|
||||
print("No tokens found")
|
||||
logger.info("No tokens found")
|
||||
return None
|
||||
|
||||
if tokens.is_expired:
|
||||
print("Token expired, refreshing...")
|
||||
logger.info("Token expired, refreshing...")
|
||||
if not tokens.refresh_token:
|
||||
print("No refresh token available")
|
||||
logger.info("No refresh token available")
|
||||
return None
|
||||
new_tokens = await refresh_tokens(tokens.refresh_token)
|
||||
if not new_tokens:
|
||||
print("Failed to refresh token")
|
||||
logger.warning("Failed to refresh token")
|
||||
return None
|
||||
save_tokens(new_tokens)
|
||||
return new_tokens
|
||||
|
|
|
|||
|
|
@ -17,7 +17,18 @@ def clamp_percent(value: Any) -> int:
|
|||
return int(round(num))
|
||||
|
||||
|
||||
def get_usage_percent(access_token: str, timeout_ms: int = 10000) -> int:
|
||||
def _parse_window(window: dict[str, Any] | None) -> dict[str, int] | None:
|
||||
if not isinstance(window, dict):
|
||||
return None
|
||||
return {
|
||||
"used_percent": clamp_percent(window.get("used_percent") or 0),
|
||||
"limit_window_seconds": int(window.get("limit_window_seconds") or 0),
|
||||
"reset_after_seconds": int(window.get("reset_after_seconds") or 0),
|
||||
"reset_at": int(window.get("reset_at") or 0),
|
||||
}
|
||||
|
||||
|
||||
def get_usage_data(access_token: str, timeout_ms: int = 10000) -> dict[str, Any] | None:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"User-Agent": "CodexProxy",
|
||||
|
|
@ -33,18 +44,42 @@ def get_usage_percent(access_token: str, timeout_ms: int = 10000) -> int:
|
|||
try:
|
||||
with urllib.request.urlopen(req, timeout=timeout_ms / 1000) as res:
|
||||
body = res.read().decode("utf-8", errors="replace")
|
||||
except urllib.error.HTTPError as e:
|
||||
return -1
|
||||
except urllib.error.HTTPError:
|
||||
return None
|
||||
except urllib.error.URLError, socket.timeout:
|
||||
return -1
|
||||
return None
|
||||
|
||||
try:
|
||||
data = json.loads(body)
|
||||
except json.JSONDecodeError:
|
||||
return -1
|
||||
return None
|
||||
|
||||
primary = (data.get("rate_limit") or {}).get("primary_window") or {}
|
||||
rate_limit = data.get("rate_limit") or {}
|
||||
primary = _parse_window(rate_limit.get("primary_window"))
|
||||
secondary = _parse_window(rate_limit.get("secondary_window"))
|
||||
|
||||
used_candidates = []
|
||||
if primary:
|
||||
return clamp_percent(primary.get("used_percent") or 0)
|
||||
used_candidates.append(primary["used_percent"])
|
||||
if secondary:
|
||||
used_candidates.append(secondary["used_percent"])
|
||||
|
||||
return -1
|
||||
if not used_candidates:
|
||||
return None
|
||||
|
||||
effective_used = max(used_candidates)
|
||||
return {
|
||||
"used_percent": effective_used,
|
||||
"remaining_percent": max(0, 100 - effective_used),
|
||||
"primary_window": primary,
|
||||
"secondary_window": secondary,
|
||||
"limit_reached": bool(rate_limit.get("limit_reached")),
|
||||
"allowed": bool(rate_limit.get("allowed", True)),
|
||||
}
|
||||
|
||||
|
||||
def get_usage_percent(access_token: str, timeout_ms: int = 10000) -> int:
|
||||
data = get_usage_data(access_token, timeout_ms=timeout_ms)
|
||||
if not data:
|
||||
return -1
|
||||
return int(data["used_percent"])
|
||||
|
|
|
|||
|
|
@ -41,6 +41,47 @@ def build_limit(usage_percent: int) -> dict[str, int | bool]:
|
|||
}
|
||||
|
||||
|
||||
async def ensure_provider_token_ready(provider_name: str):
|
||||
provider = PROVIDERS.get(provider_name)
|
||||
if not provider:
|
||||
return
|
||||
|
||||
logger.info("[%s] Startup token check", provider_name)
|
||||
token = await provider.get_token()
|
||||
if not token:
|
||||
logger.warning(
|
||||
"[%s] Startup token check failed, forcing recreation", provider_name
|
||||
)
|
||||
if isinstance(provider, ChatGPTProvider):
|
||||
token = await provider.force_recreate_token()
|
||||
|
||||
if not token:
|
||||
logger.error("[%s] Could not prepare token at startup", provider_name)
|
||||
return
|
||||
|
||||
usage_info = await provider.get_usage_info(token)
|
||||
if "error" not in usage_info:
|
||||
logger.info("[%s] Startup token is ready", provider_name)
|
||||
return
|
||||
|
||||
logger.warning(
|
||||
"[%s] Startup token invalid for usage, forcing recreation", provider_name
|
||||
)
|
||||
if isinstance(provider, ChatGPTProvider):
|
||||
token = await provider.force_recreate_token()
|
||||
if token:
|
||||
logger.info("[%s] Startup token recreated successfully", provider_name)
|
||||
return
|
||||
|
||||
logger.error("[%s] Startup token recreation failed", provider_name)
|
||||
|
||||
|
||||
async def on_startup(app: web.Application):
|
||||
del app
|
||||
for provider_name in PROVIDERS.keys():
|
||||
await ensure_provider_token_ready(provider_name)
|
||||
|
||||
|
||||
async def issue_new_token(provider_name: str) -> str | None:
|
||||
provider = PROVIDERS.get(provider_name)
|
||||
if not provider:
|
||||
|
|
@ -114,6 +155,33 @@ async def token_handler(request: web.Request) -> web.Response:
|
|||
)
|
||||
|
||||
usage_percent = usage_info.get("used_percent", 0)
|
||||
remaining_percent = usage_info.get("remaining_percent", max(0, 100 - usage_percent))
|
||||
|
||||
logger.info(
|
||||
"[%s] token issued, used=%s%% remaining=%s%%",
|
||||
provider_name,
|
||||
usage_percent,
|
||||
remaining_percent,
|
||||
)
|
||||
|
||||
primary_window = usage_info.get("primary_window")
|
||||
secondary_window = usage_info.get("secondary_window")
|
||||
if primary_window:
|
||||
logger.info(
|
||||
"[%s] primary window: used=%s%% window=%ss reset_in=%ss",
|
||||
provider_name,
|
||||
primary_window.get("used_percent", 0),
|
||||
primary_window.get("limit_window_seconds", 0),
|
||||
primary_window.get("reset_after_seconds", 0),
|
||||
)
|
||||
if secondary_window:
|
||||
logger.info(
|
||||
"[%s] secondary window: used=%s%% window=%ss reset_in=%ss",
|
||||
provider_name,
|
||||
secondary_window.get("used_percent", 0),
|
||||
secondary_window.get("limit_window_seconds", 0),
|
||||
secondary_window.get("reset_after_seconds", 0),
|
||||
)
|
||||
|
||||
# Trigger background refresh if needed
|
||||
if usage_percent >= USAGE_REFRESH_THRESHOLD:
|
||||
|
|
@ -126,12 +194,17 @@ async def token_handler(request: web.Request) -> web.Response:
|
|||
{
|
||||
"token": token,
|
||||
"limit": build_limit(usage_percent),
|
||||
"usage": {
|
||||
"primary_window": primary_window,
|
||||
"secondary_window": secondary_window,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def create_app() -> web.Application:
|
||||
app = web.Application(middlewares=[request_log_middleware])
|
||||
app.on_startup.append(on_startup)
|
||||
# New route: /{provider}/token
|
||||
app.router.add_get("/{provider}/token", token_handler)
|
||||
# Legacy route for backward compatibility
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue