diff --git a/.env.example b/.env.example index f596830..356e109 100644 --- a/.env.example +++ b/.env.example @@ -1,11 +1,8 @@ # HTTP server port PORT=80 -# Prepare next ChatGPT account when active usage reaches threshold percent -CHATGPT_PREPARE_THRESHOLD=85 - -# Switch active ChatGPT account when usage reaches threshold percent -CHATGPT_SWITCH_THRESHOLD=95 +# Trigger background token refresh when usage reaches threshold percent +USAGE_REFRESH_THRESHOLD=85 # Persistent data directory (tokens, screenshots) DATA_DIR=/data diff --git a/README.md b/README.md index 3e37601..7c38f5c 100644 --- a/README.md +++ b/README.md @@ -1,22 +1,29 @@ # megapt -Service for issuing ChatGPT OAuth tokens via browser automation with disposable email. +HTTP service that returns an active ChatGPT access token. + +The service can: +- restore/refresh a saved token from `/data` +- auto-register a new ChatGPT account when needed +- get verification email from a disposable mail provider (`temp-mail.org`) +- expose token and usage info via HTTP endpoint + ## Endpoints -- `GET /chatgpt/token` -- `GET /token` (legacy alias, same as chatgpt) +- `GET /token` - legacy route (defaults to `chatgpt` provider) +- `GET /chatgpt/token` - explicit provider route -Response shape: +Example response: ```json { - "token": "...", + "token": "", "limit": { "used_percent": 0, "remaining_percent": 100, "exhausted": false, - "needs_prepare": false + "needs_refresh": false }, "usage": { "primary_window": { @@ -30,58 +37,67 @@ Response shape: } ``` -## Environment Variables -- `PORT` - HTTP server port (default: `8080`) -- `DATA_DIR` - persistent data directory for tokens/screenshots (default: `./data`) -- `CHATGPT_PREPARE_THRESHOLD` - usage threshold to prepare `next_account` (default: `85`) -- `CHATGPT_SWITCH_THRESHOLD` - usage threshold to switch active account to `next_account` (default: `95`) +## Environment variables -Example config is in `.env.example`. +See `.env.example`. -## Token Lifecycle +- `PORT` - HTTP port for the service +- `USAGE_REFRESH_THRESHOLD` - percent threshold to trigger background token rotation +- `DATA_DIR` - directory for persistent data (`chatgpt_tokens.json`, screenshots, etc.) -- **active account** - currently served token. -- **next account** - pre-created account/token stored for fast switch. -Behavior: +## Local run -1. If active token is valid, service returns it immediately. -2. If active token is expired, service tries refresh under a single write lock. -3. If refresh fails or token is missing, service registers a new account (up to 4 attempts). -4. When usage reaches `CHATGPT_PREPARE_THRESHOLD`, service prepares `next_account`. -5. When usage reaches `CHATGPT_SWITCH_THRESHOLD`, service switches active account to `next_account`. +Requirements: +- Python 3.14+ +- Playwright Chromium dependencies -## Disposable Email Provider - -- Default provider is `mail.tm` API (`MailTmProvider`) and does not use browser automation. -- Flow: fetch domains -> create account with random address/password -> get JWT token -> poll messages. - -## Startup Behavior - -On startup, service ensures active token exists and is usable. -Standby preparation runs through provider lifecycle hooks/background trigger when needed. - -## Data Files - -- `DATA_DIR/chatgpt_tokens.json` - token state with `active` and `next_account`. -- `DATA_DIR/screenshots/` - automation failure screenshots. - -## Run Locally +Install and run: ```bash -PYTHONPATH=./src python src/server.py +uv sync --frozen --no-dev +./.venv/bin/python -m playwright install --with-deps chromium +PYTHONPATH=./src ./.venv/bin/python src/server.py ``` -## Unit Tests - -The project has unit tests only (no integration/network tests). +Then request token: ```bash -pytest -q +curl http://127.0.0.1:8080/chatgpt/token ``` -## Docker Notes -- Dockerfile sets `DATA_DIR=/data`. -- `entrypoint.sh` starts Xvfb and runs `server.py`. +## Docker deployment + +Build image: + +```bash +docker build -t megapt:latest . +``` + +Run container: + +```bash +docker run -d \ + --name megapt \ + --restart unless-stopped \ + --env-file .env \ + -v ./data:/data \ + -p 80:80 \ + megapt:latest +``` + +Check logs: + +```bash +docker logs -f megapt +``` + + +## Notes + +- Service performs a startup token check and tries to recover token automatically. +- Token write path is synchronized (single-writer lock) to avoid parallel re-registration. +- Browser runs in virtual display (`Xvfb`) inside container. +- Keep `/data` persistent between restarts. diff --git a/pyproject.toml b/pyproject.toml index 7f927f9..cc1ff33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,10 +8,5 @@ dependencies = [ "pkce==1.0.3", ] -[project.optional-dependencies] -dev = [ - "pytest>=8.0.0", -] - [tool.uv] package = false diff --git a/scripts/run_token_refresh_flow.py b/scripts/run_token_refresh_flow.py deleted file mode 100644 index 577763b..0000000 --- a/scripts/run_token_refresh_flow.py +++ /dev/null @@ -1,55 +0,0 @@ -import argparse -import asyncio -import json -import logging - -import browser -import server - - -class _FakeRequest: - def __init__(self, provider: str): - self.match_info = {"provider": provider} - self.method = "GET" - self.path_qs = f"/{provider}/token" - - -def _enable_headed_browser() -> bool: - if "--no-startup-window" in browser.CHROME_FLAGS: - browser.CHROME_FLAGS.remove("--no-startup-window") - return True - return False - - -async def _run(provider: str) -> int: - patched = _enable_headed_browser() - logging.info("Headed mode patch applied: %s", patched) - - request = _FakeRequest(provider) - response = await server.token_handler(request) - payload = json.loads(response.body.decode("utf-8")) - - logging.info("Response status: %s", response.status) - logging.info("Response body: %s", json.dumps(payload, indent=2)) - return 0 if response.status == 200 else 1 - - -def main() -> int: - parser = argparse.ArgumentParser( - description=( - "Run the same token refresh/issue flow as server /{provider}/token " - "in headed browser mode (non-headless)." - ) - ) - parser.add_argument("--provider", default="chatgpt") - args = parser.parse_args() - - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s %(levelname)s %(name)s: %(message)s", - ) - return asyncio.run(_run(args.provider)) - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/src/browser.py b/src/browser.py index daa7fef..3b35da9 100644 --- a/src/browser.py +++ b/src/browser.py @@ -1,7 +1,6 @@ import asyncio import json import logging -import socket import shutil import subprocess import tempfile @@ -45,12 +44,7 @@ CHROME_FLAGS = [ "--disable-search-engine-choice-screen", ] - -def _allocate_free_port() -> int: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("127.0.0.1", 0)) - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - return int(s.getsockname()[1]) +DEFAULT_CDP_PORT = 9222 def _fetch_ws_endpoint(port: int) -> str | None: @@ -85,9 +79,10 @@ class ManagedBrowser: shutil.rmtree(self.profile_dir, ignore_errors=True) -async def launch(playwright: Playwright) -> ManagedBrowser: +async def launch( + playwright: Playwright, cdp_port: int = DEFAULT_CDP_PORT +) -> ManagedBrowser: chrome_path = playwright.chromium.executable_path - cdp_port = _allocate_free_port() profile_dir = Path(tempfile.mkdtemp(prefix="megapt_profile-", dir="/tmp")) args = [ diff --git a/src/email_providers/__init__.py b/src/email_providers/__init__.py index a1af08e..76e66ca 100644 --- a/src/email_providers/__init__.py +++ b/src/email_providers/__init__.py @@ -1,11 +1,5 @@ from .base import BaseProvider -from .mail_tm import MailTmProvider from .ten_minute_mail import TenMinuteMailProvider from .temp_mail_org import TempMailOrgProvider -__all__ = [ - "BaseProvider", - "MailTmProvider", - "TenMinuteMailProvider", - "TempMailOrgProvider", -] +__all__ = ["BaseProvider", "TenMinuteMailProvider", "TempMailOrgProvider"] diff --git a/src/email_providers/base.py b/src/email_providers/base.py index 878c7e6..386601d 100644 --- a/src/email_providers/base.py +++ b/src/email_providers/base.py @@ -12,5 +12,5 @@ class BaseProvider(ABC): pass @abstractmethod - async def get_latest_message(self) -> str | None: + async def get_latest_message(self, email: str) -> str | None: pass diff --git a/src/email_providers/mail_tm.py b/src/email_providers/mail_tm.py deleted file mode 100644 index 4dd9c6c..0000000 --- a/src/email_providers/mail_tm.py +++ /dev/null @@ -1,227 +0,0 @@ -import asyncio -import logging -import os -import secrets -import string -from typing import Any - -import aiohttp -from playwright.async_api import BrowserContext - -from .base import BaseProvider -from utils.randoms import generate_password - -logger = logging.getLogger(__name__) - -_API_BASE = os.environ.get("MAIL_TM_API_BASE", "https://api.mail.tm") -_TIMEOUT_SECONDS = 20 -_FIRST_NAMES = [ - "james", - "john", - "robert", - "michael", - "david", - "william", - "joseph", - "thomas", - "daniel", - "mark", - "paul", - "kevin", -] -_LAST_NAMES = [ - "smith", - "johnson", - "williams", - "brown", - "jones", - "miller", - "davis", - "wilson", - "anderson", - "taylor", - "martin", - "thompson", -] - - -def _generate_local_part() -> str: - first = secrets.choice(_FIRST_NAMES) - last = secrets.choice(_LAST_NAMES) - digits = "".join(secrets.choice(string.digits) for _ in range(8)) - return f"{first}{last}{digits}" - - -class MailTmProvider(BaseProvider): - def __init__(self, browser_session: BrowserContext): - super().__init__(browser_session) - self._address: str | None = None - self._password: str | None = None - self._token: str | None = None - - async def _request( - self, - method: str, - path: str, - *, - token: str | None = None, - json_body: dict[str, Any] | None = None, - ) -> tuple[int, dict[str, Any] | list[Any] | None]: - url = f"{_API_BASE.rstrip('/')}{path}" - headers: dict[str, str] = {} - if token: - headers["Authorization"] = f"Bearer {token}" - - timeout = aiohttp.ClientTimeout(total=_TIMEOUT_SECONDS) - try: - async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.request( - method, - url, - headers=headers, - json=json_body, - ) as resp: - status = resp.status - try: - payload = await resp.json() - except aiohttp.ContentTypeError: - payload = None - return status, payload - except aiohttp.ClientError as e: - logger.warning("[mail.tm] request failed %s %s: %s", method, path, e) - return 0, None - - async def _get_domains(self) -> list[str]: - status, payload = await self._request("GET", "/domains") - if status != 200 or not isinstance(payload, dict): - raise RuntimeError("mail.tm domains request failed") - - members = payload.get("hydra:member") - if not isinstance(members, list): - raise RuntimeError("mail.tm domains response has unexpected format") - - domains: list[str] = [] - for item in members: - if not isinstance(item, dict): - continue - domain = item.get("domain") - is_active = bool(item.get("isActive", True)) - if isinstance(domain, str) and domain and is_active: - domains.append(domain) - - if not domains: - raise RuntimeError("mail.tm returned no active domains") - return domains - - async def _create_account(self, address: str, password: str) -> bool: - status, _ = await self._request( - "POST", - "/accounts", - json_body={"address": address, "password": password}, - ) - if status in (200, 201): - return True - return False - - async def _create_token(self, address: str, password: str) -> str | None: - status, payload = await self._request( - "POST", - "/token", - json_body={"address": address, "password": password}, - ) - if status != 200 or not isinstance(payload, dict): - return None - token = payload.get("token") - if isinstance(token, str) and token: - return token - return None - - async def get_new_email(self) -> str: - domains = await self._get_domains() - - for _ in range(8): - domain = secrets.choice(domains) - address = f"{_generate_local_part()}@{domain}" - password = generate_password(length=24) - - created = await self._create_account(address, password) - if not created: - continue - - token = await self._create_token(address, password) - if not token: - continue - - self._address = address - self._password = password - self._token = token - logger.info("[mail.tm] New mailbox acquired: %s", address) - return address - - raise RuntimeError("mail.tm could not create account") - - async def _list_messages(self) -> list[dict[str, Any]]: - if not self._token: - return [] - status, payload = await self._request( - "GET", - "/messages", - token=self._token, - ) - if status == 401 and self._address and self._password: - token = await self._create_token(self._address, self._password) - if token: - self._token = token - status, payload = await self._request( - "GET", - "/messages", - token=self._token, - ) - - if status != 200 or not isinstance(payload, dict): - return [] - - members = payload.get("hydra:member") - if not isinstance(members, list): - return [] - return [item for item in members if isinstance(item, dict)] - - async def _get_message_text(self, message_id: str) -> str | None: - if not self._token: - return None - status, payload = await self._request( - "GET", - f"/messages/{message_id}", - token=self._token, - ) - if status != 200 or not isinstance(payload, dict): - return None - - parts = [ - payload.get("subject"), - payload.get("intro"), - payload.get("text"), - payload.get("html"), - ] - text = "\n".join(str(part) for part in parts if part) - return text or None - - async def get_latest_message(self) -> str | None: - if not self._token: - raise RuntimeError("mail.tm provider is not initialized with mailbox token") - - for _ in range(45): - messages = await self._list_messages() - if messages: - latest = messages[0] - message_id = latest.get("id") - if isinstance(message_id, str) and message_id: - full_message = await self._get_message_text(message_id) - if full_message: - logger.info("[mail.tm] Latest message received") - return full_message - - await asyncio.sleep(2) - - logger.warning("[mail.tm] No messages received within timeout") - return None diff --git a/src/email_providers/temp_mail_org.py b/src/email_providers/temp_mail_org.py index 25d62cb..cb2b840 100644 --- a/src/email_providers/temp_mail_org.py +++ b/src/email_providers/temp_mail_org.py @@ -2,10 +2,9 @@ import asyncio import logging import re -from playwright.async_api import BrowserContext, Error as PlaywrightError, Page +from playwright.async_api import BrowserContext, Page from .base import BaseProvider -from .utils import ensure_page logger = logging.getLogger(__name__) @@ -16,7 +15,8 @@ class TempMailOrgProvider(BaseProvider): self.page: Page | None = None async def _ensure_page(self) -> Page: - self.page = await ensure_page(self.browser_session, self.page) + if self.page is None or self.page.is_closed(): + self.page = await self.browser_session.new_page() return self.page async def get_new_email(self) -> str: @@ -44,7 +44,7 @@ class TempMailOrgProvider(BaseProvider): value, ) return value - except PlaywrightError: + except Exception: continue try: @@ -53,16 +53,16 @@ class TempMailOrgProvider(BaseProvider): if found: logger.info("[temp-mail.org] email found by body scan: %s", found) return found - except PlaywrightError: - logger.debug("Failed to scan body text for email") + except Exception: + pass await asyncio.sleep(1) raise RuntimeError("Could not get temp email from temp-mail.org") - async def get_latest_message(self) -> str | None: + async def get_latest_message(self, email: str) -> str | None: page = await self._ensure_page() - logger.info("[temp-mail.org] Waiting for latest message") + logger.info("[temp-mail.org] Waiting for latest message for %s", email) if page.is_closed(): raise RuntimeError("temp-mail.org tab was closed unexpectedly") @@ -76,7 +76,7 @@ class TempMailOrgProvider(BaseProvider): try: count = await items.count() logger.info("[temp-mail.org] inbox items: %s", count) - except PlaywrightError: + except Exception: count = 0 if count > 0: @@ -87,30 +87,30 @@ class TempMailOrgProvider(BaseProvider): continue text = (await item.inner_text()).strip().replace("\n", " ") logger.info("[temp-mail.org] item[%s]: %s", idx, text[:160]) - except PlaywrightError: + except Exception: continue if text: try: await item.click() logger.info("[temp-mail.org] opened item[%s]", idx) - except PlaywrightError: - logger.debug("Failed to open inbox item[%s]", idx) + except Exception: + pass message_text = text try: content = await page.content() if content and "Your ChatGPT code is" in content: message_text = content - except PlaywrightError: - logger.debug("Failed to read opened message content") + except Exception: + pass try: await page.go_back( wait_until="domcontentloaded", timeout=5000 ) logger.info("[temp-mail.org] returned back to inbox") - except PlaywrightError: - logger.debug("Failed to return back to inbox") + except Exception: + pass return message_text diff --git a/src/email_providers/ten_minute_mail.py b/src/email_providers/ten_minute_mail.py index 1ff38a6..ae4c35e 100644 --- a/src/email_providers/ten_minute_mail.py +++ b/src/email_providers/ten_minute_mail.py @@ -4,7 +4,6 @@ import logging from playwright.async_api import BrowserContext, Page from .base import BaseProvider -from .utils import ensure_page logger = logging.getLogger(__name__) @@ -15,7 +14,8 @@ class TenMinuteMailProvider(BaseProvider): self.page: Page | None = None async def _ensure_page(self) -> Page: - self.page = await ensure_page(self.browser_session, self.page) + if self.page is None or self.page.is_closed(): + self.page = await self.browser_session.new_page() return self.page async def get_new_email(self) -> str: @@ -34,9 +34,9 @@ class TenMinuteMailProvider(BaseProvider): logger.info("[10min] New email acquired: %s", email) return email - async def get_latest_message(self) -> str | None: + async def get_latest_message(self, email: str) -> str | None: page = await self._ensure_page() - logger.info("[10min] Waiting for latest message") + logger.info("[10min] Waiting for latest message for %s", email) seen_count = 0 for attempt in range(60): diff --git a/src/email_providers/utils.py b/src/email_providers/utils.py deleted file mode 100644 index c6d44ce..0000000 --- a/src/email_providers/utils.py +++ /dev/null @@ -1,10 +0,0 @@ -from playwright.async_api import BrowserContext, Page - - -async def ensure_page( - browser_session: BrowserContext, - page: Page | None, -) -> Page: - if page is None or page.is_closed(): - return await browser_session.new_page() - return page diff --git a/src/providers/base.py b/src/providers/base.py index 478ab53..1a5b6bc 100644 --- a/src/providers/base.py +++ b/src/providers/base.py @@ -52,36 +52,3 @@ class Provider(ABC): def save_tokens(self, tokens: ProviderTokens) -> None: """Save tokens to storage""" pass - - async def force_recreate_token(self) -> str | None: - """Force-create a new active token when normal acquisition fails.""" - return None - - async def maybe_rotate_account(self, usage_percent: int) -> bool: - """Rotate active account/token if provider policy requires it.""" - return False - - @property - def prepare_threshold(self) -> int: - """Usage percent when provider should prepare standby account/token.""" - return 100 - - @property - def switch_threshold(self) -> int | None: - """Usage percent when provider may switch active account/token.""" - return None - - def should_prepare_standby(self, usage_percent: int) -> bool: - """Whether standby preparation should be triggered for current usage.""" - return False - - async def ensure_standby_account( - self, - usage_percent: int, - ) -> None: - """Prepare standby account/token asynchronously when needed.""" - return None - - async def startup_prepare(self) -> None: - """Optional provider-specific startup preparation.""" - return None diff --git a/src/providers/chatgpt/provider.py b/src/providers/chatgpt/provider.py index 03ae90a..6b70b9c 100644 --- a/src/providers/chatgpt/provider.py +++ b/src/providers/chatgpt/provider.py @@ -1,36 +1,19 @@ import asyncio import logging -from typing import Any from typing import Callable +from typing import Any from playwright.async_api import BrowserContext -from email_providers import BaseProvider -from email_providers import MailTmProvider from providers.base import Provider, ProviderTokens -from utils.env import parse_int_env -from .tokens import ( - clear_next_tokens, - load_next_tokens, - load_state, - load_tokens, - promote_next_tokens, - refresh_tokens, - save_state, - save_tokens, -) +from email_providers import BaseProvider +from email_providers import TempMailOrgProvider +from .tokens import load_tokens, save_tokens, refresh_tokens from .usage import get_usage_data from .registration import register_chatgpt_account logger = logging.getLogger(__name__) -CHATGPT_REGISTRATION_MAX_ATTEMPTS = 4 -CHATGPT_PREPARE_THRESHOLD = parse_int_env("CHATGPT_PREPARE_THRESHOLD", 85, 0, 100) -CHATGPT_SWITCH_THRESHOLD = parse_int_env( - "CHATGPT_SWITCH_THRESHOLD", - 95, - 0, - 100, -) +MAX_REGISTRATION_ATTEMPTS = 4 class ChatGPTProvider(Provider): @@ -40,61 +23,20 @@ class ChatGPTProvider(Provider): self, email_provider_factory: Callable[[BrowserContext], BaseProvider] | None = None, ): - self.email_provider_factory = email_provider_factory or MailTmProvider + self.email_provider_factory = email_provider_factory or TempMailOrgProvider self._token_write_lock = asyncio.Lock() - @property - def prepare_threshold(self) -> int: - return CHATGPT_PREPARE_THRESHOLD - - @property - def switch_threshold(self) -> int | None: - return CHATGPT_SWITCH_THRESHOLD - async def _register_with_retries(self) -> bool: - for attempt in range(1, CHATGPT_REGISTRATION_MAX_ATTEMPTS + 1): + for attempt in range(1, MAX_REGISTRATION_ATTEMPTS + 1): logger.info( "Registration attempt %s/%s", attempt, - CHATGPT_REGISTRATION_MAX_ATTEMPTS, + MAX_REGISTRATION_ATTEMPTS, ) - generated_tokens = await register_chatgpt_account( - email_provider_factory=self.email_provider_factory, - ) - if generated_tokens: - save_tokens(generated_tokens) + success = await self.register_new_account() + if success: return True logger.warning("Registration attempt %s failed", attempt) - await asyncio.sleep(1.5 * attempt) - return False - - async def _create_next_account_under_lock(self) -> bool: - active_before, next_before = load_state() - if next_before: - return True - - logger.info("Creating next account") - for attempt in range(1, CHATGPT_REGISTRATION_MAX_ATTEMPTS + 1): - logger.info( - "Next-account registration attempt %s/%s", - attempt, - CHATGPT_REGISTRATION_MAX_ATTEMPTS, - ) - generated_tokens = await register_chatgpt_account( - email_provider_factory=self.email_provider_factory, - ) - if generated_tokens: - if active_before: - save_state(active_before, generated_tokens) - else: - save_state(generated_tokens, None) - logger.info("Next account is ready") - return True - logger.warning("Next-account registration attempt %s failed", attempt) - await asyncio.sleep(1.5 * attempt) - - if active_before or next_before: - save_state(active_before, next_before) return False async def force_recreate_token(self) -> str | None: @@ -102,62 +44,11 @@ class ChatGPTProvider(Provider): success = await self._register_with_retries() if not success: return None - clear_next_tokens() tokens = load_tokens() if not tokens: return None return tokens.access_token - async def startup_prepare(self) -> None: - await self.ensure_next_account() - - async def ensure_next_account(self) -> bool: - next_tokens = load_next_tokens() - if next_tokens: - return True - - async with self._token_write_lock: - next_tokens = load_next_tokens() - if next_tokens: - return True - return await self._create_next_account_under_lock() - - def should_prepare_standby(self, usage_percent: int) -> bool: - return usage_percent >= self.prepare_threshold - - async def ensure_standby_account( - self, - usage_percent: int, - ) -> None: - if self.should_prepare_standby(usage_percent): - await self.ensure_next_account() - - async def maybe_switch_active_account(self, usage_percent: int) -> bool: - if usage_percent < CHATGPT_SWITCH_THRESHOLD: - return False - - async with self._token_write_lock: - next_tokens = load_next_tokens() - if not next_tokens or next_tokens.is_expired: - logger.info( - "Active usage >= %s%% and next account missing", - CHATGPT_SWITCH_THRESHOLD, - ) - created = await self._create_next_account_under_lock() - if not created: - return False - - switched = promote_next_tokens() - if switched: - logger.info( - "Switched active account (usage >= %s%%)", - CHATGPT_SWITCH_THRESHOLD, - ) - return switched - - async def maybe_rotate_account(self, usage_percent: int) -> bool: - return await self.maybe_switch_active_account(usage_percent) - @property def name(self) -> str: return "chatgpt" @@ -193,17 +84,13 @@ class ChatGPTProvider(Provider): async def register_new_account(self) -> bool: """Register a new ChatGPT account""" - generated_tokens = await register_chatgpt_account( + return await register_chatgpt_account( email_provider_factory=self.email_provider_factory, ) - if not generated_tokens: - return False - save_tokens(generated_tokens) - return True async def get_usage_info(self, access_token: str) -> dict[str, Any]: """Get usage information for the current token""" - usage_data = await get_usage_data(access_token) + usage_data = get_usage_data(access_token) if not usage_data: return {"error": "Failed to get usage"} diff --git a/src/providers/chatgpt/registration.py b/src/providers/chatgpt/registration.py index 3087078..1af2581 100644 --- a/src/providers/chatgpt/registration.py +++ b/src/providers/chatgpt/registration.py @@ -5,6 +5,7 @@ import logging import random import re import secrets +import string import time from datetime import datetime from pathlib import Path @@ -13,18 +14,12 @@ from typing import Callable from urllib.parse import parse_qs, urlencode, urlparse import aiohttp -from playwright.async_api import ( - async_playwright, - Error as PlaywrightError, - Page, - BrowserContext, -) +from playwright.async_api import async_playwright, Page, BrowserContext from browser import launch as launch_browser from email_providers import BaseProvider from providers.base import ProviderTokens -from utils.randoms import generate_password -from .tokens import CLIENT_ID +from .tokens import CLIENT_ID, save_tokens logger = logging.getLogger(__name__) @@ -51,9 +46,14 @@ async def save_error_screenshot(page: Page | None, step: str): filename = screenshots_dir / f"error_{step}_{timestamp}.png" try: await page.screenshot(path=str(filename)) - logger.error("Screenshot saved: %s", filename) - except PlaywrightError as e: - logger.warning("Failed to save screenshot at step %s: %s", step, e) + logger.error(f"Screenshot saved: {filename}") + except: + pass + + +def generate_password(length: int = 20) -> str: + alphabet = string.ascii_letters + string.digits + return "".join(random.choice(alphabet) for _ in range(length)) def generate_name() -> str: @@ -204,7 +204,8 @@ def generate_state() -> str: return secrets.token_urlsafe(32) -def build_authorize_url(challenge: str, state: str) -> str: +def build_authorize_url(verifier: str, challenge: str, state: str) -> str: + del verifier params = { "response_type": "code", "client_id": CLIENT_ID, @@ -221,37 +222,30 @@ def build_authorize_url(challenge: str, state: str) -> str: async def exchange_code_for_tokens(code: str, verifier: str) -> ProviderTokens: - payload = { - "grant_type": "authorization_code", - "client_id": CLIENT_ID, - "code": code, - "code_verifier": verifier, - "redirect_uri": REDIRECT_URI, - } - timeout = aiohttp.ClientTimeout(total=20) - try: - async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.post(TOKEN_URL, data=payload) as resp: - if not resp.ok: - text = await resp.text() - raise RuntimeError(f"Token exchange failed: {resp.status} {text}") - body = await resp.json() - except (aiohttp.ClientError, TimeoutError) as e: - raise RuntimeError(f"Token exchange request error: {e}") from e + async with aiohttp.ClientSession() as session: + payload = { + "grant_type": "authorization_code", + "client_id": CLIENT_ID, + "code": code, + "code_verifier": verifier, + "redirect_uri": REDIRECT_URI, + } + async with session.post(TOKEN_URL, data=payload) as resp: + if not resp.ok: + text = await resp.text() + raise RuntimeError(f"Token exchange failed: {resp.status} {text}") + body = await resp.json() - try: - expires_in = int(body["expires_in"]) - return ProviderTokens( - access_token=body["access_token"], - refresh_token=body["refresh_token"], - expires_at=time.time() + expires_in, - ) - except (KeyError, TypeError, ValueError) as e: - raise RuntimeError(f"Token exchange response parse error: {e}") from e + expires_in = int(body["expires_in"]) + return ProviderTokens( + access_token=body["access_token"], + refresh_token=body["refresh_token"], + expires_at=time.time() + expires_in, + ) -async def get_latest_code(email_provider: BaseProvider) -> str | None: - message = await email_provider.get_latest_message() +async def get_latest_code(email_provider: BaseProvider, email: str) -> str | None: + message = await email_provider.get_latest_message(email) if not message: return None return extract_verification_code(message) @@ -276,51 +270,6 @@ async def click_continue(page: Page, timeout_ms: int = 10000): await btn.click() -async def oauth_needs_email_check(page: Page) -> bool: - marker = page.get_by_text("Check your inbox", exact=False) - return await marker.count() > 0 - - -async def fill_oauth_code_if_present(page: Page, code: str) -> bool: - candidates = [ - page.get_by_placeholder("Code"), - page.get_by_label("Code"), - page.locator( - 'input[name*="code" i], input[id*="code" i], ' - 'input[autocomplete="one-time-code"], input[inputmode="numeric"]' - ), - ] - - for locator in candidates: - if await locator.count() == 0: - continue - try: - await locator.first.wait_for(state="visible", timeout=1500) - await locator.first.fill(code) - return True - except PlaywrightError: - continue - return False - - -async def click_first_visible_button( - page: Page, - labels: list[str], - timeout_ms: int = 2000, -) -> bool: - for label in labels: - button = page.get_by_role("button", name=label) - if await button.count() == 0: - continue - try: - await button.first.wait_for(state="visible", timeout=timeout_ms) - await button.first.click(timeout=timeout_ms) - return True - except PlaywrightError: - continue - return False - - async def wait_for_signup_stabilization( page: Page, source_url: str, @@ -339,12 +288,12 @@ async def wait_for_signup_stabilization( async def register_chatgpt_account( email_provider_factory: Callable[[BrowserContext], BaseProvider] | None = None, -) -> ProviderTokens | None: +) -> bool: logger.info("=== Starting ChatGPT account registration ===") if email_provider_factory is None: logger.error("No email provider factory configured") - return None + return False birth_month, birth_day, birth_year = generate_birthdate_90s() @@ -372,7 +321,7 @@ async def register_chatgpt_account( full_name = generate_name() verifier, challenge = generate_pkce_pair() oauth_state = generate_state() - authorize_url = build_authorize_url(challenge, oauth_state) + authorize_url = build_authorize_url(verifier, challenge, oauth_state) logger.info("[2/5] Registering ChatGPT for %s", email) chatgpt_page = await context.new_page() @@ -398,23 +347,24 @@ async def register_chatgpt_account( ) logger.info("[3/5] Getting verification message from email provider...") - code = await get_latest_code(email_provider) + code = await get_latest_code(email_provider, email) if not code: raise AutomationError( "email_provider", "Email provider returned no verification message" ) - logger.info("[3/5] Verification code extracted") + logger.info("[3/5] Verification code extracted: %s", code) await chatgpt_page.bring_to_front() code_input = chatgpt_page.get_by_placeholder("Code") - await code_input.first.wait_for(state="visible", timeout=10000) - await code_input.first.fill(code) + if await code_input.count() > 0: + await code_input.fill(code) await click_continue(chatgpt_page) logger.info("[4/5] Setting profile...") name_input = chatgpt_page.get_by_placeholder("Full name") await name_input.first.wait_for(state="visible", timeout=20000) - await name_input.first.fill(full_name) + if await name_input.count() > 0: + await name_input.fill(full_name) await fill_date_field(chatgpt_page, birth_month, birth_day, birth_year) profile_url = chatgpt_page.url @@ -459,45 +409,25 @@ async def register_chatgpt_account( if await continue_button.count() > 0: await continue_button.first.click() - last_oauth_email_code = code - oauth_deadline = asyncio.get_running_loop().time() + 60 - while asyncio.get_running_loop().time() < oauth_deadline: - if redirect_url_captured: - break - - if await oauth_needs_email_check(oauth_page): - logger.info("OAuth requested email confirmation code") - new_code = await get_latest_code(email_provider) - if new_code and new_code != last_oauth_email_code: - filled = await fill_oauth_code_if_present(oauth_page, new_code) - if filled: - last_oauth_email_code = new_code - logger.info("Filled OAuth email confirmation code") - else: - logger.warning( - "OAuth inbox challenge detected but code field not found" - ) + for label in ["Continue", "Allow", "Authorize"]: + button = oauth_page.get_by_role("button", name=label) + if await button.count() > 0: + try: + await button.first.click(timeout=5000) + await oauth_page.wait_for_timeout(500) + except Exception: + pass + if not redirect_url_captured: try: + await oauth_page.wait_for_timeout(4000) current_url = oauth_page.url if "localhost:1455" in current_url and "code=" in current_url: redirect_url_captured = current_url logger.info("Captured OAuth redirect from page URL") - break except Exception: pass - clicked = await click_first_visible_button( - oauth_page, - ["Continue", "Allow", "Authorize", "Verify"], - timeout_ms=2000, - ) - - if clicked: - await oauth_page.wait_for_timeout(500) - else: - await oauth_page.wait_for_timeout(1000) - if not redirect_url_captured: raise AutomationError( "oauth", "OAuth redirect with code was not captured", oauth_page @@ -516,18 +446,20 @@ async def register_chatgpt_account( raise AutomationError("oauth", "OAuth state mismatch", oauth_page) tokens = await exchange_code_for_tokens(auth_code, verifier) - logger.info("OAuth tokens fetched successfully") + save_tokens(tokens) + logger.info("OAuth tokens saved successfully") - return tokens + return True except AutomationError as e: logger.error(f"Error at step [{e.step}]: {e.message}") await save_error_screenshot(e.page, e.step) - return None + return False except Exception as e: logger.error(f"Unexpected error: {e}") await save_error_screenshot(current_page, "unexpected") - return None + return False finally: if managed: + await asyncio.sleep(2) await managed.close() diff --git a/src/providers/chatgpt/tokens.py b/src/providers/chatgpt/tokens.py index b4ed7fb..4be1661 100644 --- a/src/providers/chatgpt/tokens.py +++ b/src/providers/chatgpt/tokens.py @@ -1,12 +1,9 @@ import json -import logging -import os -import tempfile import time -from pathlib import Path -from typing import Any - +import os import aiohttp +from pathlib import Path +import logging from providers.base import ProviderTokens @@ -19,143 +16,72 @@ CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann" TOKEN_URL = "https://auth.openai.com/oauth/token" -def _tokens_to_dict(tokens: ProviderTokens) -> dict[str, Any]: - return { - "access_token": tokens.access_token, - "refresh_token": tokens.refresh_token, - "expires_at": tokens.expires_at, - } - - -def _dict_to_tokens(data: dict[str, Any] | None) -> ProviderTokens | None: - if not isinstance(data, dict): - return None - try: - return ProviderTokens( - access_token=data["access_token"], - refresh_token=data["refresh_token"], - expires_at=data["expires_at"], - ) - except (KeyError, TypeError): - return None - - -def _load_raw() -> dict[str, Any] | None: +def load_tokens() -> ProviderTokens | None: if not TOKENS_FILE.exists(): return None try: with open(TOKENS_FILE) as f: data = json.load(f) - if isinstance(data, dict): - return data + return ProviderTokens( + access_token=data["access_token"], + refresh_token=data["refresh_token"], + expires_at=data["expires_at"], + ) + except json.JSONDecodeError, KeyError: return None - except json.JSONDecodeError: - return None - - -def _save_raw(data: dict[str, Any]) -> None: - TOKENS_FILE.parent.mkdir(parents=True, exist_ok=True) - fd, tmp_path = tempfile.mkstemp( - prefix=f"{TOKENS_FILE.name}.", - suffix=".tmp", - dir=str(TOKENS_FILE.parent), - ) - try: - with os.fdopen(fd, "w") as f: - json.dump(data, f, indent=2) - f.flush() - os.fsync(f.fileno()) - os.replace(tmp_path, TOKENS_FILE) - finally: - if os.path.exists(tmp_path): - os.unlink(tmp_path) - - -def _normalize_state(data: dict[str, Any] | None) -> dict[str, Any]: - if not data: - return {"active": None, "next_account": None} - - if "active" in data or "next_account" in data: - return { - "active": data.get("active"), - "next_account": data.get("next_account"), - } - - return {"active": data, "next_account": None} - - -def load_state() -> tuple[ProviderTokens | None, ProviderTokens | None]: - normalized = _normalize_state(_load_raw()) - active = _dict_to_tokens(normalized.get("active")) - next_account = _dict_to_tokens(normalized.get("next_account")) - return active, next_account - - -def save_state(active: ProviderTokens | None, next_account: ProviderTokens | None) -> None: - payload = { - "active": _tokens_to_dict(active) if active else None, - "next_account": _tokens_to_dict(next_account) if next_account else None, - } - _save_raw(payload) - - -def load_tokens() -> ProviderTokens | None: - active, _ = load_state() - return active - - -def load_next_tokens() -> ProviderTokens | None: - _, next_account = load_state() - return next_account def save_tokens(tokens: ProviderTokens): - _, next_account = load_state() - save_state(tokens, next_account) - - -def promote_next_tokens() -> bool: - _, next_account = load_state() - if not next_account: - return False - save_state(next_account, None) - return True - - -def clear_next_tokens(): - active, _ = load_state() - save_state(active, None) + TOKENS_FILE.parent.mkdir(parents=True, exist_ok=True) + with open(TOKENS_FILE, "w") as f: + json.dump( + { + "access_token": tokens.access_token, + "refresh_token": tokens.refresh_token, + "expires_at": tokens.expires_at, + }, + f, + indent=2, + ) async def refresh_tokens(refresh_token: str) -> ProviderTokens | None: - data = { - "grant_type": "refresh_token", - "refresh_token": refresh_token, - "client_id": CLIENT_ID, - } - timeout = aiohttp.ClientTimeout(total=15) - try: - async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.post(TOKEN_URL, data=data) as resp: - if not resp.ok: - text = await resp.text() - logger.warning("Token refresh failed: %s %s", resp.status, text) - return None - json_resp = await resp.json() - except (aiohttp.ClientError, TimeoutError) as e: - logger.warning("Token refresh request error: %s", e) - return None - except Exception as e: - logger.warning("Token refresh unexpected error: %s", e) + async with aiohttp.ClientSession() as session: + data = { + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "client_id": CLIENT_ID, + } + async with session.post(TOKEN_URL, data=data) as resp: + if not resp.ok: + text = await resp.text() + print(f"Token refresh failed: {resp.status} {text}") + return None + json_resp = await resp.json() + expires_in = json_resp["expires_in"] + return ProviderTokens( + access_token=json_resp["access_token"], + refresh_token=json_resp["refresh_token"], + expires_at=time.time() + expires_in, + ) + + +async def get_valid_tokens() -> ProviderTokens | None: + tokens = load_tokens() + if not tokens: + logger.info("No tokens found") return None - try: - expires_in = int(json_resp["expires_in"]) - return ProviderTokens( - access_token=json_resp["access_token"], - refresh_token=json_resp["refresh_token"], - expires_at=time.time() + expires_in, - ) - except (KeyError, TypeError, ValueError) as e: - logger.warning("Token refresh response parse error: %s", e) - return None + if tokens.is_expired: + logger.info("Token expired, refreshing...") + if not tokens.refresh_token: + logger.info("No refresh token available") + return None + new_tokens = await refresh_tokens(tokens.refresh_token) + if not new_tokens: + logger.warning("Failed to refresh token") + return None + save_tokens(new_tokens) + return new_tokens + + return tokens diff --git a/src/providers/chatgpt/usage.py b/src/providers/chatgpt/usage.py index 7c2edab..6236bd8 100644 --- a/src/providers/chatgpt/usage.py +++ b/src/providers/chatgpt/usage.py @@ -1,15 +1,14 @@ -import logging +import json +import socket +import urllib.error +import urllib.request from typing import Any -import aiohttp - -logger = logging.getLogger(__name__) - def clamp_percent(value: Any) -> int: try: num = float(value) - except (TypeError, ValueError): + except TypeError, ValueError: return 0 if num < 0: return 0 @@ -29,36 +28,30 @@ def _parse_window(window: dict[str, Any] | None) -> dict[str, int] | None: } -async def get_usage_data( - access_token: str, - timeout_ms: int = 10000, -) -> dict[str, Any] | None: +def get_usage_data(access_token: str, timeout_ms: int = 10000) -> dict[str, Any] | None: headers = { "Authorization": f"Bearer {access_token}", "User-Agent": "CodexProxy", "Accept": "application/json", } - timeout = aiohttp.ClientTimeout(total=timeout_ms / 1000) - url = "https://chatgpt.com/backend-api/wham/usage" + req = urllib.request.Request( + "https://chatgpt.com/backend-api/wham/usage", + headers=headers, + method="GET", + ) try: - async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.get(url, headers=headers) as res: - if not res.ok: - body = await res.text() - logger.warning( - "Usage fetch failed: status=%s body=%s", - res.status, - body[:300], - ) - return None - data = await res.json() - except (aiohttp.ClientError, TimeoutError) as e: - logger.warning("Usage fetch request error: %s", e) + with urllib.request.urlopen(req, timeout=timeout_ms / 1000) as res: + body = res.read().decode("utf-8", errors="replace") + except urllib.error.HTTPError: return None - except Exception as e: - logger.warning("Usage fetch unexpected error: %s", e) + except urllib.error.URLError, socket.timeout: + return None + + try: + data = json.loads(body) + except json.JSONDecodeError: return None rate_limit = data.get("rate_limit") or {} @@ -83,3 +76,10 @@ async def get_usage_data( "limit_reached": bool(rate_limit.get("limit_reached")), "allowed": bool(rate_limit.get("allowed", True)), } + + +def get_usage_percent(access_token: str, timeout_ms: int = 10000) -> int: + data = get_usage_data(access_token, timeout_ms=timeout_ms) + if not data: + return -1 + return int(data["used_percent"]) diff --git a/src/server.py b/src/server.py index be2e905..bb89e51 100644 --- a/src/server.py +++ b/src/server.py @@ -1,22 +1,27 @@ import asyncio import logging +import os from aiohttp import web from providers.chatgpt import ChatGPTProvider -from providers.base import Provider -from utils.env import parse_int_env + +PORT = int(os.environ.get("PORT", "8080")) +USAGE_REFRESH_THRESHOLD = int(os.environ.get("USAGE_REFRESH_THRESHOLD", "85")) +LIMIT_EXHAUSTED_PERCENT = 100 logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -PORT = parse_int_env("PORT", 8080, 1, 65535) -LIMIT_EXHAUSTED_PERCENT = 100 -PROVIDERS: dict[str, Provider] = { +# Registry of available providers +PROVIDERS = { "chatgpt": ChatGPTProvider(), } -background_tasks: dict[str, asyncio.Task | None] = {name: None for name in PROVIDERS} +refresh_locks = {name: asyncio.Lock() for name in PROVIDERS.keys()} +background_refresh_tasks: dict[str, asyncio.Task | None] = { + name: None for name in PROVIDERS.keys() +} @web.middleware @@ -26,30 +31,16 @@ async def request_log_middleware(request: web.Request, handler): return response -def build_limit(usage_percent: int, prepare_threshold: int) -> dict[str, int | bool]: +def build_limit(usage_percent: int) -> dict[str, int | bool]: remaining = max(0, 100 - usage_percent) return { "used_percent": usage_percent, "remaining_percent": remaining, "exhausted": usage_percent >= LIMIT_EXHAUSTED_PERCENT, - "needs_prepare": usage_percent >= prepare_threshold, + "needs_refresh": usage_percent >= USAGE_REFRESH_THRESHOLD, } -def get_prepare_threshold(provider_name: str) -> int: - provider = PROVIDERS.get(provider_name) - if not provider: - return 100 - return provider.prepare_threshold - - -def should_trigger_standby_prepare(provider_name: str, usage_percent: int) -> bool: - provider = PROVIDERS.get(provider_name) - if not provider: - return False - return provider.should_prepare_standby(usage_percent) - - async def ensure_provider_token_ready(provider_name: str): provider = PROVIDERS.get(provider_name) if not provider: @@ -61,95 +52,111 @@ async def ensure_provider_token_ready(provider_name: str): logger.warning( "[%s] Startup token check failed, forcing recreation", provider_name ) - token = await provider.force_recreate_token() + if isinstance(provider, ChatGPTProvider): + token = await provider.force_recreate_token() if not token: logger.error("[%s] Could not prepare token at startup", provider_name) return usage_info = await provider.get_usage_info(token) - if "error" in usage_info: - logger.warning( - "[%s] Startup token invalid for usage, forcing recreation", provider_name - ) + if "error" not in usage_info: + logger.info("[%s] Startup token is ready", provider_name) + return + + logger.warning( + "[%s] Startup token invalid for usage, forcing recreation", provider_name + ) + if isinstance(provider, ChatGPTProvider): token = await provider.force_recreate_token() - if not token: - logger.error("[%s] Startup token recreation failed", provider_name) + if token: + logger.info("[%s] Startup token recreated successfully", provider_name) return - await provider.startup_prepare() - logger.info("[%s] Startup token is ready", provider_name) + logger.error("[%s] Startup token recreation failed", provider_name) -async def ensure_standby_task(provider_name: str, usage_percent: int, reason: str): +async def on_startup(app: web.Application): + del app + for provider_name in PROVIDERS.keys(): + await ensure_provider_token_ready(provider_name) + + +async def issue_new_token(provider_name: str) -> str | None: provider = PROVIDERS.get(provider_name) if not provider: - return + return None - if not provider.should_prepare_standby(usage_percent): - return + async with refresh_locks[provider_name]: + logger.info(f"[{provider_name}] Generating new token") + success = await provider.register_new_account() + if not success: + logger.error(f"[{provider_name}] Token generation failed") + return None + token = await provider.get_token() + if not token: + logger.error(f"[{provider_name}] Token was generated but not available") + return None + + return token + + +async def background_refresh_worker(provider_name: str, reason: str): try: - logger.info("[%s] Preparing standby in background (%s)", provider_name, reason) - await provider.ensure_standby_account(usage_percent) + logger.info(f"[{provider_name}] Starting background token refresh ({reason})") + new_token = await issue_new_token(provider_name) + if new_token: + logger.info(f"[{provider_name}] Background token refresh completed") + else: + logger.error(f"[{provider_name}] Background token refresh failed") except Exception: - logger.exception("[%s] Unhandled standby preparation error", provider_name) + logger.exception( + f"[{provider_name}] Unhandled error in background token refresh" + ) -def trigger_standby_prepare(provider_name: str, usage_percent: int, reason: str): - task = background_tasks.get(provider_name) +def trigger_background_refresh(provider_name: str, reason: str): + task = background_refresh_tasks.get(provider_name) if task and not task.done(): logger.info( - "[%s] Standby prep already running, skip (%s)", provider_name, reason + f"[{provider_name}] Background refresh already running, skip ({reason})" ) return - background_tasks[provider_name] = asyncio.create_task( - ensure_standby_task(provider_name, usage_percent, reason) + background_refresh_tasks[provider_name] = asyncio.create_task( + background_refresh_worker(provider_name, reason) ) async def token_handler(request: web.Request) -> web.Response: provider_name = request.match_info.get("provider", "chatgpt") + provider = PROVIDERS.get(provider_name) if not provider: return web.json_response( - {"error": f"Unknown provider: {provider_name}"}, status=404 + {"error": f"Unknown provider: {provider_name}"}, + status=404, ) + # Get or create token token = await provider.get_token() if not token: - return web.json_response({"error": "Failed to get active token"}, status=503) - - usage_info = await provider.get_usage_info(token) - if "error" in usage_info: - return web.json_response({"error": usage_info["error"]}, status=503) - - usage_percent = int(usage_info.get("used_percent", 0)) - switched = await provider.maybe_rotate_account(usage_percent) - if switched: - token = await provider.get_token() - if not token: - return web.json_response( - {"error": "Failed to get active token after account switch"}, - status=503, - ) - usage_info = await provider.get_usage_info(token) - if "error" in usage_info: - return web.json_response({"error": usage_info["error"]}, status=503) - usage_percent = int(usage_info.get("used_percent", 0)) - logger.info("[%s] Active account switched before response", provider_name) - - prepare_threshold = get_prepare_threshold(provider_name) - if should_trigger_standby_prepare(provider_name, usage_percent): - trigger_standby_prepare( - provider_name, - usage_percent, - f"usage {usage_percent}% reached standby policy", + return web.json_response( + {"error": "Failed to get active token"}, + status=503, ) - remaining_percent = int( - usage_info.get("remaining_percent", max(0, 100 - usage_percent)) - ) + # Get usage info + usage_info = await provider.get_usage_info(token) + if "error" in usage_info: + return web.json_response( + {"error": usage_info["error"]}, + status=503, + ) + + usage_percent = usage_info.get("used_percent", 0) + remaining_percent = usage_info.get("remaining_percent", max(0, 100 - usage_percent)) + logger.info( "[%s] token issued, used=%s%% remaining=%s%%", provider_name, @@ -176,10 +183,17 @@ async def token_handler(request: web.Request) -> web.Response: secondary_window.get("reset_after_seconds", 0), ) + # Trigger background refresh if needed + if usage_percent >= USAGE_REFRESH_THRESHOLD: + trigger_background_refresh( + provider_name, + f"usage {usage_percent}% >= threshold {USAGE_REFRESH_THRESHOLD}%", + ) + return web.json_response( { "token": token, - "limit": build_limit(usage_percent, prepare_threshold), + "limit": build_limit(usage_percent), "usage": { "primary_window": primary_window, "secondary_window": secondary_window, @@ -188,42 +202,19 @@ async def token_handler(request: web.Request) -> web.Response: ) -async def on_startup(app: web.Application): - del app - for provider_name in PROVIDERS: - await ensure_provider_token_ready(provider_name) - - -async def on_cleanup(app: web.Application): - del app - for task in background_tasks.values(): - if task and not task.done(): - task.cancel() - pending = [t for t in background_tasks.values() if t is not None] - if pending: - await asyncio.gather(*pending, return_exceptions=True) - - def create_app() -> web.Application: app = web.Application(middlewares=[request_log_middleware]) app.on_startup.append(on_startup) - app.on_cleanup.append(on_cleanup) + # New route: /{provider}/token app.router.add_get("/{provider}/token", token_handler) + # Legacy route for backward compatibility app.router.add_get("/token", token_handler) return app if __name__ == "__main__": logger.info("Starting token service on port %s", PORT) - chatgpt_provider = PROVIDERS.get("chatgpt") - if chatgpt_provider: - logger.info( - "ChatGPT prepare-next threshold: %s%%", chatgpt_provider.prepare_threshold - ) - if chatgpt_provider.switch_threshold is not None: - logger.info( - "ChatGPT switch threshold: %s%%", chatgpt_provider.switch_threshold - ) + logger.info("Usage refresh threshold: %s%%", USAGE_REFRESH_THRESHOLD) logger.info("Available providers: %s", ", ".join(PROVIDERS.keys())) app = create_app() web.run_app(app, host="0.0.0.0", port=PORT) diff --git a/src/utils/__init__.py b/src/utils/__init__.py deleted file mode 100644 index 00303ef..0000000 --- a/src/utils/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .env import parse_int_env -from .randoms import generate_password - -__all__ = ["parse_int_env", "generate_password"] diff --git a/src/utils/env.py b/src/utils/env.py deleted file mode 100644 index 688e65f..0000000 --- a/src/utils/env.py +++ /dev/null @@ -1,22 +0,0 @@ -import os - - -def parse_int_env( - name: str, - default: int, - minimum: int, - maximum: int, -) -> int: - raw = os.environ.get(name) - if raw is None: - return default - - try: - value = int(raw) - except ValueError: - return default - - if value < minimum or value > maximum: - return default - - return value diff --git a/src/utils/randoms.py b/src/utils/randoms.py deleted file mode 100644 index 76adab8..0000000 --- a/src/utils/randoms.py +++ /dev/null @@ -1,13 +0,0 @@ -import random -import secrets -import string - - -def generate_password( - length: int = 20, - *, - secure: bool = True, -) -> str: - alphabet = string.ascii_letters + string.digits - chooser = secrets.choice if secure else random.choice - return "".join(chooser(alphabet) for _ in range(length)) diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index 9bf27d4..0000000 --- a/tests/conftest.py +++ /dev/null @@ -1,12 +0,0 @@ -import sys -from pathlib import Path - - -def _add_src_to_path() -> None: - root = Path(__file__).resolve().parents[1] - src = root / "src" - if str(src) not in sys.path: - sys.path.insert(0, str(src)) - - -_add_src_to_path() diff --git a/tests/test_registration_unit.py b/tests/test_registration_unit.py deleted file mode 100644 index 32fd941..0000000 --- a/tests/test_registration_unit.py +++ /dev/null @@ -1,37 +0,0 @@ -from providers.chatgpt.registration import ( - build_authorize_url, - extract_verification_code, - generate_birthdate_90s, - generate_name, -) - - -def test_generate_name_shape(): - name = generate_name() - parts = name.split(" ") - assert len(parts) == 2 - assert all(p.isalpha() for p in parts) - - -def test_generate_birthdate_90s_range(): - month, day, year = generate_birthdate_90s() - assert 1 <= int(month) <= 12 - assert 1 <= int(day) <= 28 - assert 1990 <= int(year) <= 1999 - - -def test_extract_verification_code_prefers_chatgpt_phrase(): - text = "foo 123456 bar Your ChatGPT code is 654321" - assert extract_verification_code(text) == "654321" - - -def test_extract_verification_code_fallback_last_code(): - text = "codes 111111 and 222222" - assert extract_verification_code(text) == "222222" - - -def test_build_authorize_url_contains_required_params(): - url = build_authorize_url("challenge", "state123") - assert "response_type=code" in url - assert "code_challenge=challenge" in url - assert "state=state123" in url diff --git a/tests/test_server_unit.py b/tests/test_server_unit.py deleted file mode 100644 index 68b3030..0000000 --- a/tests/test_server_unit.py +++ /dev/null @@ -1,159 +0,0 @@ -import asyncio -import json - -import server -from providers.base import Provider, ProviderTokens -from utils.env import parse_int_env - - -class FakeRequest: - def __init__(self, provider: str): - self.match_info = {"provider": provider} - - -class FakeProvider(Provider): - def __init__( - self, - token: str | None = "tok", - usage: dict | None = None, - rotate: bool = False, - ): - self._token = token - self._usage = usage or { - "used_percent": 10, - "remaining_percent": 90, - "primary_window": None, - "secondary_window": None, - } - self._rotate = rotate - self._prepare_threshold = 80 - self.get_token_calls = 0 - self.standby_calls = 0 - - @property - def prepare_threshold(self) -> int: - return self._prepare_threshold - - def should_prepare_standby(self, usage_percent: int) -> bool: - return usage_percent >= self.prepare_threshold - - @property - def name(self) -> str: - return "fake" - - async def get_token(self) -> str | None: - self.get_token_calls += 1 - return self._token - - async def register_new_account(self) -> bool: - return True - - async def get_usage_info(self, access_token: str) -> dict: - _ = access_token - return dict(self._usage) - - def load_tokens(self) -> ProviderTokens | None: - return None - - def save_tokens(self, tokens: ProviderTokens) -> None: - _ = tokens - - async def maybe_rotate_account(self, usage_percent: int) -> bool: - _ = usage_percent - return self._rotate - - async def ensure_standby_account( - self, - usage_percent: int, - ) -> None: - _ = usage_percent - self.standby_calls += 1 - - -def _response_json(resp) -> dict: - return json.loads(resp.body.decode("utf-8")) - - -def test_parse_int_env_defaults(monkeypatch): - monkeypatch.delenv("X_TEST", raising=False) - assert parse_int_env("X_TEST", 10, 1, 20) == 10 - - -def test_parse_int_env_invalid(monkeypatch): - monkeypatch.setenv("X_TEST", "abc") - assert parse_int_env("X_TEST", 10, 1, 20) == 10 - - -def test_parse_int_env_out_of_range(monkeypatch): - monkeypatch.setenv("X_TEST", "999") - assert parse_int_env("X_TEST", 10, 1, 20) == 10 - - -def test_build_limit_fields(): - limit = server.build_limit(90, 85) - assert limit == { - "used_percent": 90, - "remaining_percent": 10, - "exhausted": False, - "needs_prepare": True, - } - - -def test_get_prepare_threshold(): - assert ( - server.get_prepare_threshold("chatgpt") - == server.PROVIDERS["chatgpt"].prepare_threshold - ) - assert server.get_prepare_threshold("unknown") == 100 - - -def test_token_handler_unknown_provider(monkeypatch): - monkeypatch.setattr(server, "PROVIDERS", {}) - resp = asyncio.run(server.token_handler(FakeRequest("missing"))) - assert resp.status == 404 - - -def test_token_handler_success(monkeypatch): - provider = FakeProvider() - monkeypatch.setattr(server, "PROVIDERS", {"fake": provider}) - monkeypatch.setattr(server, "background_tasks", {"fake": None}) - resp = asyncio.run(server.token_handler(FakeRequest("fake"))) - data = _response_json(resp) - - assert resp.status == 200 - assert data["token"] == "tok" - assert data["limit"]["needs_prepare"] is False - - -def test_token_handler_triggers_standby(monkeypatch): - provider = FakeProvider(usage={"used_percent": 90, "remaining_percent": 10}) - monkeypatch.setattr(server, "PROVIDERS", {"fake": provider}) - monkeypatch.setattr(server, "background_tasks", {"fake": None}) - - called = {"value": False} - - def fake_trigger(name, usage_percent, reason): - assert name == "fake" - assert usage_percent == 90 - assert "standby policy" in reason - called["value"] = True - - monkeypatch.setattr(server, "trigger_standby_prepare", fake_trigger) - - resp = asyncio.run(server.token_handler(FakeRequest("fake"))) - assert resp.status == 200 - assert called["value"] is True - - -def test_token_handler_rotation_path(monkeypatch): - provider = FakeProvider( - usage={"used_percent": 96, "remaining_percent": 4}, - rotate=True, - ) - monkeypatch.setattr(server, "PROVIDERS", {"fake": provider}) - monkeypatch.setattr(server, "background_tasks", {"fake": None}) - monkeypatch.setattr(server, "trigger_standby_prepare", lambda *_: None) - - resp = asyncio.run(server.token_handler(FakeRequest("fake"))) - assert resp.status == 200 - assert provider.get_token_calls >= 2 diff --git a/tests/test_tokens_unit.py b/tests/test_tokens_unit.py deleted file mode 100644 index 58f5af5..0000000 --- a/tests/test_tokens_unit.py +++ /dev/null @@ -1,60 +0,0 @@ -import json -from pathlib import Path - -from providers.base import ProviderTokens -from providers.chatgpt import tokens as t - - -def test_normalize_state_backward_compatible(): - raw = {"access_token": "a", "refresh_token": "r", "expires_at": 1} - normalized = t._normalize_state(raw) - assert normalized["active"]["access_token"] == "a" - assert normalized["next_account"] is None - - -def test_promote_next_tokens(tmp_path, monkeypatch): - file_path = tmp_path / "chatgpt_tokens.json" - monkeypatch.setattr(t, "TOKENS_FILE", file_path) - - active = ProviderTokens("a1", "r1", 100) - nxt = ProviderTokens("a2", "r2", 200) - t.save_state(active, nxt) - - assert t.promote_next_tokens() is True - cur, next_cur = t.load_state() - assert cur is not None - assert cur.access_token == "a2" - assert next_cur is None - - -def test_save_tokens_preserves_next(tmp_path, monkeypatch): - file_path = tmp_path / "chatgpt_tokens.json" - monkeypatch.setattr(t, "TOKENS_FILE", file_path) - - active = ProviderTokens("a1", "r1", 100) - nxt = ProviderTokens("a2", "r2", 200) - t.save_state(active, nxt) - - t.save_tokens(ProviderTokens("a3", "r3", 300)) - cur, next_cur = t.load_state() - assert cur is not None and cur.access_token == "a3" - assert next_cur is not None and next_cur.access_token == "a2" - - -def test_atomic_write_produces_valid_json(tmp_path, monkeypatch): - file_path = tmp_path / "chatgpt_tokens.json" - monkeypatch.setattr(t, "TOKENS_FILE", file_path) - - t.save_state(ProviderTokens("x", "y", 123), None) - with open(file_path) as f: - data = json.load(f) - assert "active" in data - assert data["active"]["access_token"] == "x" - - -def test_load_state_from_missing_file(tmp_path, monkeypatch): - file_path = tmp_path / "missing.json" - monkeypatch.setattr(t, "TOKENS_FILE", file_path) - active, nxt = t.load_state() - assert active is None - assert nxt is None diff --git a/tests/test_usage_unit.py b/tests/test_usage_unit.py deleted file mode 100644 index 63cad99..0000000 --- a/tests/test_usage_unit.py +++ /dev/null @@ -1,32 +0,0 @@ -from providers.chatgpt.usage import _parse_window, clamp_percent - - -def test_clamp_percent_bounds(): - assert clamp_percent(-1) == 0 - assert clamp_percent(150) == 100 - assert clamp_percent(49.6) == 50 - - -def test_clamp_percent_invalid(): - assert clamp_percent(None) == 0 - assert clamp_percent("bad") == 0 - - -def test_parse_window_valid(): - window = { - "used_percent": 34.4, - "limit_window_seconds": 3600, - "reset_after_seconds": 120, - "reset_at": 999, - } - parsed = _parse_window(window) - assert parsed == { - "used_percent": 34, - "limit_window_seconds": 3600, - "reset_after_seconds": 120, - "reset_at": 999, - } - - -def test_parse_window_none(): - assert _parse_window(None) is None diff --git a/uv.lock b/uv.lock index 17acd39..bb427f9 100644 --- a/uv.lock +++ b/uv.lock @@ -83,15 +83,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3a/2a/7cc015f5b9f5db42b7d48157e23356022889fc354a2813c15934b7cb5c0e/attrs-25.4.0-py3-none-any.whl", hash = "sha256:adcf7e2a1fb3b36ac48d97835bb6d8ade15b8dcce26aba8bf1d14847b57a3373", size = 67615, upload-time = "2025-10-06T13:54:43.17Z" }, ] -[[package]] -name = "colorama" -version = "0.4.6" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, -] - [[package]] name = "frozenlist" version = "1.8.0" @@ -167,15 +158,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008, upload-time = "2025-10-12T14:55:18.883Z" }, ] -[[package]] -name = "iniconfig" -version = "2.3.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, -] - [[package]] name = "megapt" version = "0.1.0" @@ -186,19 +168,12 @@ dependencies = [ { name = "playwright" }, ] -[package.optional-dependencies] -dev = [ - { name = "pytest" }, -] - [package.metadata] requires-dist = [ { name = "aiohttp", specifier = "==3.13.3" }, { name = "pkce", specifier = "==1.0.3" }, { name = "playwright", specifier = "==1.58.0" }, - { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" }, ] -provides-extras = ["dev"] [[package]] name = "multidict" @@ -245,15 +220,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/81/08/7036c080d7117f28a4af526d794aab6a84463126db031b007717c1a6676e/multidict-6.7.1-py3-none-any.whl", hash = "sha256:55d97cc6dae627efa6a6e548885712d4864b81110ac76fa4e534c03819fa4a56", size = 12319, upload-time = "2026-01-26T02:46:44.004Z" }, ] -[[package]] -name = "packaging" -version = "26.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/65/ee/299d360cdc32edc7d2cf530f3accf79c4fca01e96ffc950d8a52213bd8e4/packaging-26.0.tar.gz", hash = "sha256:00243ae351a257117b6a241061796684b084ed1c516a08c48a3f7e147a9d80b4", size = 143416, upload-time = "2026-01-21T20:50:39.064Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/b7/b9/c538f279a4e237a006a2c98387d081e9eb060d203d8ed34467cc0f0b9b53/packaging-26.0-py3-none-any.whl", hash = "sha256:b36f1fef9334a5588b4166f8bcd26a14e521f2b55e6b9de3aaa80d3ff7a37529", size = 74366, upload-time = "2026-01-21T20:50:37.788Z" }, -] - [[package]] name = "pkce" version = "1.0.3" @@ -282,15 +248,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c8/c4/cc0229fea55c87d6c9c67fe44a21e2cd28d1d558a5478ed4d617e9fb0c93/playwright-1.58.0-py3-none-win_arm64.whl", hash = "sha256:32ffe5c303901a13a0ecab91d1c3f74baf73b84f4bedbb6b935f5bc11cc98e1b", size = 33085919, upload-time = "2026-01-30T15:09:45.71Z" }, ] -[[package]] -name = "pluggy" -version = "1.6.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, -] - [[package]] name = "propcache" version = "0.4.1" @@ -342,31 +299,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a0/c4/b4d4827c93ef43c01f599ef31453ccc1c132b353284fc6c87d535c233129/pyee-13.0.1-py3-none-any.whl", hash = "sha256:af2f8fede4171ef667dfded53f96e2ed0d6e6bd7ee3bb46437f77e3b57689228", size = 15659, upload-time = "2026-02-14T21:12:26.263Z" }, ] -[[package]] -name = "pygments" -version = "2.19.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, -] - -[[package]] -name = "pytest" -version = "9.0.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, - { name = "iniconfig" }, - { name = "packaging" }, - { name = "pluggy" }, - { name = "pygments" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/d1/db/7ef3487e0fb0049ddb5ce41d3a49c235bf9ad299b6a25d5780a89f19230f/pytest-9.0.2.tar.gz", hash = "sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11", size = 1568901, upload-time = "2025-12-06T21:30:51.014Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" }, -] - [[package]] name = "typing-extensions" version = "4.15.0"