
Compare commits


5 commits

26 changed files with 1332 additions and 333 deletions

View file

@@ -1,8 +1,11 @@
# HTTP server port
PORT=80
# Trigger background token refresh when usage reaches threshold percent
USAGE_REFRESH_THRESHOLD=85
# Prepare next ChatGPT account when active usage reaches threshold percent
CHATGPT_PREPARE_THRESHOLD=85
# Switch active ChatGPT account when usage reaches threshold percent
CHATGPT_SWITCH_THRESHOLD=95
# Persistent data directory (tokens, screenshots)
DATA_DIR=/data

README.md
View file

@@ -1,29 +1,22 @@
# megapt
HTTP service that returns an active ChatGPT access token.
The service can:
- restore/refresh a saved token from `/data`
- auto-register a new ChatGPT account when needed
- get verification email from a disposable mail provider (`temp-mail.org`)
- expose token and usage info via HTTP endpoint
Service for issuing ChatGPT OAuth tokens via browser automation with disposable email.
## Endpoints
- `GET /token` - legacy route (defaults to `chatgpt` provider)
- `GET /chatgpt/token` - explicit provider route
- `GET /chatgpt/token`
- `GET /token` (legacy alias, same as chatgpt)
Example response:
Response shape:
```json
{
"token": "<access_token>",
"token": "...",
"limit": {
"used_percent": 0,
"remaining_percent": 100,
"exhausted": false,
"needs_refresh": false
"needs_prepare": false
},
"usage": {
"primary_window": {
@@ -37,67 +30,58 @@ Example response:
}
```
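A caller can act on the `limit` block directly; a minimal client-side sketch (field names taken from the response shape above, the policy itself is illustrative):

```python
# Classify a /chatgpt/token "limit" block into a client-side action.
# Field names follow the response shape above; the policy is illustrative.
def classify_limit(limit: dict) -> str:
    if limit.get("exhausted"):
        return "backoff"          # usage window exhausted, wait before retrying
    if limit.get("needs_prepare"):
        return "expect-rotation"  # server is preparing a standby account
    return "ok"


sample = {
    "used_percent": 0,
    "remaining_percent": 100,
    "exhausted": False,
    "needs_prepare": False,
}
print(classify_limit(sample))  # ok
```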
## Environment Variables
## Environment variables
- `PORT` - HTTP server port (default: `8080`)
- `DATA_DIR` - persistent data directory for tokens/screenshots (default: `./data`)
- `CHATGPT_PREPARE_THRESHOLD` - usage threshold to prepare `next_account` (default: `85`)
- `CHATGPT_SWITCH_THRESHOLD` - usage threshold to switch active account to `next_account` (default: `95`)
See `.env.example`.
Example config is in `.env.example`.
- `PORT` - HTTP port for the service
- `USAGE_REFRESH_THRESHOLD` - percent threshold to trigger background token rotation
- `DATA_DIR` - directory for persistent data (`chatgpt_tokens.json`, screenshots, etc.)
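The threshold variables are parsed as clamped integers via `parse_int_env` (see `utils/env.py` in this PR); a standalone sketch of such a helper, exact behavior assumed:

```python
import os


def parse_int_env(name: str, default: int, lo: int, hi: int) -> int:
    # Read an integer env var; fall back to the default on missing or
    # malformed values, then clamp the result into [lo, hi].
    raw = os.environ.get(name)
    try:
        value = int(raw) if raw is not None else default
    except ValueError:
        value = default
    return max(lo, min(hi, value))


print(parse_int_env("CHATGPT_PREPARE_THRESHOLD", 85, 0, 100))
```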
## Token Lifecycle
- **active account** - currently served token.
- **next account** - pre-created account/token stored for fast switch.
## Local run
Behavior:
Requirements:
- Python 3.14+
- Playwright Chromium dependencies
1. If active token is valid, service returns it immediately.
2. If active token is expired, service tries refresh under a single write lock.
3. If refresh fails or token is missing, service registers a new account (up to 4 attempts).
4. When usage reaches `CHATGPT_PREPARE_THRESHOLD`, service prepares `next_account`.
5. When usage reaches `CHATGPT_SWITCH_THRESHOLD`, service switches active account to `next_account`.
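Steps 4-5 reduce to a threshold comparison; a sketch of that policy with the default thresholds from `.env.example`:

```python
def rotation_action(usage_percent: int,
                    prepare_threshold: int = 85,
                    switch_threshold: int = 95) -> str:
    # Map current usage to the lifecycle step described above.
    if usage_percent >= switch_threshold:
        return "switch"    # promote next_account to active
    if usage_percent >= prepare_threshold:
        return "prepare"   # register next_account in the background
    return "serve"         # keep returning the active token


print(rotation_action(90))  # prepare
```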
Install and run:
## Disposable Email Provider
- Default provider is `mail.tm` API (`MailTmProvider`) and does not use browser automation.
- Flow: fetch domains -> create account with random address/password -> get JWT token -> poll messages.
## Startup Behavior
On startup, service ensures active token exists and is usable.
Standby preparation runs through provider lifecycle hooks/background trigger when needed.
## Data Files
- `DATA_DIR/chatgpt_tokens.json` - token state with `active` and `next_account`.
- `DATA_DIR/screenshots/` - automation failure screenshots.
## Run Locally
```bash
uv sync --frozen --no-dev
./.venv/bin/python -m playwright install --with-deps chromium
PYTHONPATH=./src ./.venv/bin/python src/server.py
PYTHONPATH=./src python src/server.py
```
Then request token:
## Unit Tests
The project has unit tests only (no integration/network tests).
```bash
curl http://127.0.0.1:8080/chatgpt/token
pytest -q
```
## Docker Notes
## Docker deployment
Build image:
```bash
docker build -t megapt:latest .
```
Run container:
```bash
docker run -d \
--name megapt \
--restart unless-stopped \
--env-file .env \
-v ./data:/data \
-p 80:80 \
megapt:latest
```
Check logs:
```bash
docker logs -f megapt
```
## Notes
- Service performs a startup token check and tries to recover token automatically.
- Token write path is synchronized (single-writer lock) to avoid parallel re-registration.
- Browser runs in virtual display (`Xvfb`) inside container.
- Keep `/data` persistent between restarts.
- Dockerfile sets `DATA_DIR=/data`.
- `entrypoint.sh` starts Xvfb and runs `server.py`.

View file

@@ -8,5 +8,10 @@ dependencies = [
"pkce==1.0.3",
]
[project.optional-dependencies]
dev = [
"pytest>=8.0.0",
]
[tool.uv]
package = false

View file

@@ -0,0 +1,55 @@
import argparse
import asyncio
import json
import logging
import browser
import server
class _FakeRequest:
def __init__(self, provider: str):
self.match_info = {"provider": provider}
self.method = "GET"
self.path_qs = f"/{provider}/token"
def _enable_headed_browser() -> bool:
if "--no-startup-window" in browser.CHROME_FLAGS:
browser.CHROME_FLAGS.remove("--no-startup-window")
return True
return False
async def _run(provider: str) -> int:
patched = _enable_headed_browser()
logging.info("Headed mode patch applied: %s", patched)
request = _FakeRequest(provider)
response = await server.token_handler(request)
payload = json.loads(response.body.decode("utf-8"))
logging.info("Response status: %s", response.status)
logging.info("Response body: %s", json.dumps(payload, indent=2))
return 0 if response.status == 200 else 1
def main() -> int:
parser = argparse.ArgumentParser(
description=(
"Run the same token refresh/issue flow as server /{provider}/token "
"in headed browser mode (non-headless)."
)
)
parser.add_argument("--provider", default="chatgpt")
args = parser.parse_args()
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
return asyncio.run(_run(args.provider))
if __name__ == "__main__":
raise SystemExit(main())

View file

@@ -1,6 +1,7 @@
import asyncio
import json
import logging
import socket
import shutil
import subprocess
import tempfile
@@ -44,7 +45,12 @@ CHROME_FLAGS = [
"--disable-search-engine-choice-screen",
]
DEFAULT_CDP_PORT = 9222
def _allocate_free_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind(("127.0.0.1", 0))
return int(s.getsockname()[1])
def _fetch_ws_endpoint(port: int) -> str | None:
@@ -79,10 +85,9 @@ class ManagedBrowser:
shutil.rmtree(self.profile_dir, ignore_errors=True)
async def launch(
playwright: Playwright, cdp_port: int = DEFAULT_CDP_PORT
) -> ManagedBrowser:
async def launch(playwright: Playwright) -> ManagedBrowser:
chrome_path = playwright.chromium.executable_path
cdp_port = _allocate_free_port()
profile_dir = Path(tempfile.mkdtemp(prefix="megapt_profile-", dir="/tmp"))
args = [

View file

@@ -1,5 +1,11 @@
from .base import BaseProvider
from .mail_tm import MailTmProvider
from .ten_minute_mail import TenMinuteMailProvider
from .temp_mail_org import TempMailOrgProvider
__all__ = ["BaseProvider", "TenMinuteMailProvider", "TempMailOrgProvider"]
__all__ = [
"BaseProvider",
"MailTmProvider",
"TenMinuteMailProvider",
"TempMailOrgProvider",
]

View file

@@ -12,5 +12,5 @@ class BaseProvider(ABC):
pass
@abstractmethod
async def get_latest_message(self, email: str) -> str | None:
async def get_latest_message(self) -> str | None:
pass

View file

@@ -0,0 +1,227 @@
import asyncio
import logging
import os
import secrets
import string
from typing import Any
import aiohttp
from playwright.async_api import BrowserContext
from .base import BaseProvider
from utils.randoms import generate_password
logger = logging.getLogger(__name__)
_API_BASE = os.environ.get("MAIL_TM_API_BASE", "https://api.mail.tm")
_TIMEOUT_SECONDS = 20
_FIRST_NAMES = [
"james",
"john",
"robert",
"michael",
"david",
"william",
"joseph",
"thomas",
"daniel",
"mark",
"paul",
"kevin",
]
_LAST_NAMES = [
"smith",
"johnson",
"williams",
"brown",
"jones",
"miller",
"davis",
"wilson",
"anderson",
"taylor",
"martin",
"thompson",
]
def _generate_local_part() -> str:
first = secrets.choice(_FIRST_NAMES)
last = secrets.choice(_LAST_NAMES)
digits = "".join(secrets.choice(string.digits) for _ in range(8))
return f"{first}{last}{digits}"
class MailTmProvider(BaseProvider):
def __init__(self, browser_session: BrowserContext):
super().__init__(browser_session)
self._address: str | None = None
self._password: str | None = None
self._token: str | None = None
async def _request(
self,
method: str,
path: str,
*,
token: str | None = None,
json_body: dict[str, Any] | None = None,
) -> tuple[int, dict[str, Any] | list[Any] | None]:
url = f"{_API_BASE.rstrip('/')}{path}"
headers: dict[str, str] = {}
if token:
headers["Authorization"] = f"Bearer {token}"
timeout = aiohttp.ClientTimeout(total=_TIMEOUT_SECONDS)
try:
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.request(
method,
url,
headers=headers,
json=json_body,
) as resp:
status = resp.status
try:
payload = await resp.json()
except aiohttp.ContentTypeError:
payload = None
return status, payload
except aiohttp.ClientError as e:
logger.warning("[mail.tm] request failed %s %s: %s", method, path, e)
return 0, None
async def _get_domains(self) -> list[str]:
status, payload = await self._request("GET", "/domains")
if status != 200 or not isinstance(payload, dict):
raise RuntimeError("mail.tm domains request failed")
members = payload.get("hydra:member")
if not isinstance(members, list):
raise RuntimeError("mail.tm domains response has unexpected format")
domains: list[str] = []
for item in members:
if not isinstance(item, dict):
continue
domain = item.get("domain")
is_active = bool(item.get("isActive", True))
if isinstance(domain, str) and domain and is_active:
domains.append(domain)
if not domains:
raise RuntimeError("mail.tm returned no active domains")
return domains
async def _create_account(self, address: str, password: str) -> bool:
status, _ = await self._request(
"POST",
"/accounts",
json_body={"address": address, "password": password},
)
if status in (200, 201):
return True
return False
async def _create_token(self, address: str, password: str) -> str | None:
status, payload = await self._request(
"POST",
"/token",
json_body={"address": address, "password": password},
)
if status != 200 or not isinstance(payload, dict):
return None
token = payload.get("token")
if isinstance(token, str) and token:
return token
return None
async def get_new_email(self) -> str:
domains = await self._get_domains()
for _ in range(8):
domain = secrets.choice(domains)
address = f"{_generate_local_part()}@{domain}"
password = generate_password(length=24)
created = await self._create_account(address, password)
if not created:
continue
token = await self._create_token(address, password)
if not token:
continue
self._address = address
self._password = password
self._token = token
logger.info("[mail.tm] New mailbox acquired: %s", address)
return address
raise RuntimeError("mail.tm could not create account")
async def _list_messages(self) -> list[dict[str, Any]]:
if not self._token:
return []
status, payload = await self._request(
"GET",
"/messages",
token=self._token,
)
if status == 401 and self._address and self._password:
token = await self._create_token(self._address, self._password)
if token:
self._token = token
status, payload = await self._request(
"GET",
"/messages",
token=self._token,
)
if status != 200 or not isinstance(payload, dict):
return []
members = payload.get("hydra:member")
if not isinstance(members, list):
return []
return [item for item in members if isinstance(item, dict)]
async def _get_message_text(self, message_id: str) -> str | None:
if not self._token:
return None
status, payload = await self._request(
"GET",
f"/messages/{message_id}",
token=self._token,
)
if status != 200 or not isinstance(payload, dict):
return None
parts = [
payload.get("subject"),
payload.get("intro"),
payload.get("text"),
payload.get("html"),
]
text = "\n".join(str(part) for part in parts if part)
return text or None
async def get_latest_message(self) -> str | None:
if not self._token:
raise RuntimeError("mail.tm provider is not initialized with mailbox token")
for _ in range(45):
messages = await self._list_messages()
if messages:
latest = messages[0]
message_id = latest.get("id")
if isinstance(message_id, str) and message_id:
full_message = await self._get_message_text(message_id)
if full_message:
logger.info("[mail.tm] Latest message received")
return full_message
await asyncio.sleep(2)
logger.warning("[mail.tm] No messages received within timeout")
return None

View file

@@ -2,9 +2,10 @@ import asyncio
import logging
import re
from playwright.async_api import BrowserContext, Page
from playwright.async_api import BrowserContext, Error as PlaywrightError, Page
from .base import BaseProvider
from .utils import ensure_page
logger = logging.getLogger(__name__)
@@ -15,8 +16,7 @@ class TempMailOrgProvider(BaseProvider):
self.page: Page | None = None
async def _ensure_page(self) -> Page:
if self.page is None or self.page.is_closed():
self.page = await self.browser_session.new_page()
self.page = await ensure_page(self.browser_session, self.page)
return self.page
async def get_new_email(self) -> str:
@@ -44,7 +44,7 @@ class TempMailOrgProvider(BaseProvider):
value,
)
return value
except Exception:
except PlaywrightError:
continue
try:
@@ -53,16 +53,16 @@ class TempMailOrgProvider(BaseProvider):
if found:
logger.info("[temp-mail.org] email found by body scan: %s", found)
return found
except Exception:
pass
except PlaywrightError:
logger.debug("Failed to scan body text for email")
await asyncio.sleep(1)
raise RuntimeError("Could not get temp email from temp-mail.org")
async def get_latest_message(self, email: str) -> str | None:
async def get_latest_message(self) -> str | None:
page = await self._ensure_page()
logger.info("[temp-mail.org] Waiting for latest message for %s", email)
logger.info("[temp-mail.org] Waiting for latest message")
if page.is_closed():
raise RuntimeError("temp-mail.org tab was closed unexpectedly")
@@ -76,7 +76,7 @@ class TempMailOrgProvider(BaseProvider):
try:
count = await items.count()
logger.info("[temp-mail.org] inbox items: %s", count)
except Exception:
except PlaywrightError:
count = 0
if count > 0:
@@ -87,30 +87,30 @@
continue
text = (await item.inner_text()).strip().replace("\n", " ")
logger.info("[temp-mail.org] item[%s]: %s", idx, text[:160])
except Exception:
except PlaywrightError:
continue
if text:
try:
await item.click()
logger.info("[temp-mail.org] opened item[%s]", idx)
except Exception:
pass
except PlaywrightError:
logger.debug("Failed to open inbox item[%s]", idx)
message_text = text
try:
content = await page.content()
if content and "Your ChatGPT code is" in content:
message_text = content
except Exception:
pass
except PlaywrightError:
logger.debug("Failed to read opened message content")
try:
await page.go_back(
wait_until="domcontentloaded", timeout=5000
)
logger.info("[temp-mail.org] returned back to inbox")
except Exception:
pass
except PlaywrightError:
logger.debug("Failed to return back to inbox")
return message_text

View file

@@ -4,6 +4,7 @@ import logging
from playwright.async_api import BrowserContext, Page
from .base import BaseProvider
from .utils import ensure_page
logger = logging.getLogger(__name__)
@@ -14,8 +15,7 @@ class TenMinuteMailProvider(BaseProvider):
self.page: Page | None = None
async def _ensure_page(self) -> Page:
if self.page is None or self.page.is_closed():
self.page = await self.browser_session.new_page()
self.page = await ensure_page(self.browser_session, self.page)
return self.page
async def get_new_email(self) -> str:
@@ -34,9 +34,9 @@
logger.info("[10min] New email acquired: %s", email)
return email
async def get_latest_message(self, email: str) -> str | None:
async def get_latest_message(self) -> str | None:
page = await self._ensure_page()
logger.info("[10min] Waiting for latest message for %s", email)
logger.info("[10min] Waiting for latest message")
seen_count = 0
for attempt in range(60):

View file

@@ -0,0 +1,10 @@
from playwright.async_api import BrowserContext, Page
async def ensure_page(
browser_session: BrowserContext,
page: Page | None,
) -> Page:
if page is None or page.is_closed():
return await browser_session.new_page()
return page

View file

@@ -52,3 +52,36 @@ class Provider(ABC):
def save_tokens(self, tokens: ProviderTokens) -> None:
"""Save tokens to storage"""
pass
async def force_recreate_token(self) -> str | None:
"""Force-create a new active token when normal acquisition fails."""
return None
async def maybe_rotate_account(self, usage_percent: int) -> bool:
"""Rotate active account/token if provider policy requires it."""
return False
@property
def prepare_threshold(self) -> int:
"""Usage percent when provider should prepare standby account/token."""
return 100
@property
def switch_threshold(self) -> int | None:
"""Usage percent when provider may switch active account/token."""
return None
def should_prepare_standby(self, usage_percent: int) -> bool:
"""Whether standby preparation should be triggered for current usage."""
return False
async def ensure_standby_account(
self,
usage_percent: int,
) -> None:
"""Prepare standby account/token asynchronously when needed."""
return None
async def startup_prepare(self) -> None:
"""Optional provider-specific startup preparation."""
return None

View file

@@ -1,19 +1,36 @@
import asyncio
import logging
from typing import Callable
from typing import Any
from typing import Callable
from playwright.async_api import BrowserContext
from providers.base import Provider, ProviderTokens
from email_providers import BaseProvider
from email_providers import TempMailOrgProvider
from .tokens import load_tokens, save_tokens, refresh_tokens
from email_providers import MailTmProvider
from providers.base import Provider, ProviderTokens
from utils.env import parse_int_env
from .tokens import (
clear_next_tokens,
load_next_tokens,
load_state,
load_tokens,
promote_next_tokens,
refresh_tokens,
save_state,
save_tokens,
)
from .usage import get_usage_data
from .registration import register_chatgpt_account
logger = logging.getLogger(__name__)
MAX_REGISTRATION_ATTEMPTS = 4
CHATGPT_REGISTRATION_MAX_ATTEMPTS = 4
CHATGPT_PREPARE_THRESHOLD = parse_int_env("CHATGPT_PREPARE_THRESHOLD", 85, 0, 100)
CHATGPT_SWITCH_THRESHOLD = parse_int_env(
"CHATGPT_SWITCH_THRESHOLD",
95,
0,
100,
)
class ChatGPTProvider(Provider):
@@ -23,20 +40,61 @@ class ChatGPTProvider(Provider):
self,
email_provider_factory: Callable[[BrowserContext], BaseProvider] | None = None,
):
self.email_provider_factory = email_provider_factory or TempMailOrgProvider
self.email_provider_factory = email_provider_factory or MailTmProvider
self._token_write_lock = asyncio.Lock()
@property
def prepare_threshold(self) -> int:
return CHATGPT_PREPARE_THRESHOLD
@property
def switch_threshold(self) -> int | None:
return CHATGPT_SWITCH_THRESHOLD
async def _register_with_retries(self) -> bool:
for attempt in range(1, MAX_REGISTRATION_ATTEMPTS + 1):
for attempt in range(1, CHATGPT_REGISTRATION_MAX_ATTEMPTS + 1):
logger.info(
"Registration attempt %s/%s",
attempt,
MAX_REGISTRATION_ATTEMPTS,
CHATGPT_REGISTRATION_MAX_ATTEMPTS,
)
success = await self.register_new_account()
if success:
generated_tokens = await register_chatgpt_account(
email_provider_factory=self.email_provider_factory,
)
if generated_tokens:
save_tokens(generated_tokens)
return True
logger.warning("Registration attempt %s failed", attempt)
await asyncio.sleep(1.5 * attempt)
return False
async def _create_next_account_under_lock(self) -> bool:
active_before, next_before = load_state()
if next_before:
return True
logger.info("Creating next account")
for attempt in range(1, CHATGPT_REGISTRATION_MAX_ATTEMPTS + 1):
logger.info(
"Next-account registration attempt %s/%s",
attempt,
CHATGPT_REGISTRATION_MAX_ATTEMPTS,
)
generated_tokens = await register_chatgpt_account(
email_provider_factory=self.email_provider_factory,
)
if generated_tokens:
if active_before:
save_state(active_before, generated_tokens)
else:
save_state(generated_tokens, None)
logger.info("Next account is ready")
return True
logger.warning("Next-account registration attempt %s failed", attempt)
await asyncio.sleep(1.5 * attempt)
if active_before or next_before:
save_state(active_before, next_before)
return False
async def force_recreate_token(self) -> str | None:
@@ -44,11 +102,62 @@ class ChatGPTProvider(Provider):
success = await self._register_with_retries()
if not success:
return None
clear_next_tokens()
tokens = load_tokens()
if not tokens:
return None
return tokens.access_token
async def startup_prepare(self) -> None:
await self.ensure_next_account()
async def ensure_next_account(self) -> bool:
next_tokens = load_next_tokens()
if next_tokens:
return True
async with self._token_write_lock:
next_tokens = load_next_tokens()
if next_tokens:
return True
return await self._create_next_account_under_lock()
def should_prepare_standby(self, usage_percent: int) -> bool:
return usage_percent >= self.prepare_threshold
async def ensure_standby_account(
self,
usage_percent: int,
) -> None:
if self.should_prepare_standby(usage_percent):
await self.ensure_next_account()
async def maybe_switch_active_account(self, usage_percent: int) -> bool:
if usage_percent < CHATGPT_SWITCH_THRESHOLD:
return False
async with self._token_write_lock:
next_tokens = load_next_tokens()
if not next_tokens or next_tokens.is_expired:
logger.info(
"Active usage >= %s%% and next account missing",
CHATGPT_SWITCH_THRESHOLD,
)
created = await self._create_next_account_under_lock()
if not created:
return False
switched = promote_next_tokens()
if switched:
logger.info(
"Switched active account (usage >= %s%%)",
CHATGPT_SWITCH_THRESHOLD,
)
return switched
async def maybe_rotate_account(self, usage_percent: int) -> bool:
return await self.maybe_switch_active_account(usage_percent)
@property
def name(self) -> str:
return "chatgpt"
@@ -84,13 +193,17 @@ class ChatGPTProvider(Provider):
async def register_new_account(self) -> bool:
"""Register a new ChatGPT account"""
return await register_chatgpt_account(
generated_tokens = await register_chatgpt_account(
email_provider_factory=self.email_provider_factory,
)
if not generated_tokens:
return False
save_tokens(generated_tokens)
return True
async def get_usage_info(self, access_token: str) -> dict[str, Any]:
"""Get usage information for the current token"""
usage_data = get_usage_data(access_token)
usage_data = await get_usage_data(access_token)
if not usage_data:
return {"error": "Failed to get usage"}

View file

@@ -5,7 +5,6 @@ import logging
import random
import re
import secrets
import string
import time
from datetime import datetime
from pathlib import Path
@@ -14,12 +13,18 @@ from typing import Callable
from urllib.parse import parse_qs, urlencode, urlparse
import aiohttp
from playwright.async_api import async_playwright, Page, BrowserContext
from playwright.async_api import (
async_playwright,
Error as PlaywrightError,
Page,
BrowserContext,
)
from browser import launch as launch_browser
from email_providers import BaseProvider
from providers.base import ProviderTokens
from .tokens import CLIENT_ID, save_tokens
from utils.randoms import generate_password
from .tokens import CLIENT_ID
logger = logging.getLogger(__name__)
@@ -46,14 +51,9 @@ async def save_error_screenshot(page: Page | None, step: str):
filename = screenshots_dir / f"error_{step}_{timestamp}.png"
try:
await page.screenshot(path=str(filename))
logger.error(f"Screenshot saved: {filename}")
except:
pass
def generate_password(length: int = 20) -> str:
alphabet = string.ascii_letters + string.digits
return "".join(random.choice(alphabet) for _ in range(length))
logger.error("Screenshot saved: %s", filename)
except PlaywrightError as e:
logger.warning("Failed to save screenshot at step %s: %s", step, e)
def generate_name() -> str:
@@ -204,8 +204,7 @@ def generate_state() -> str:
return secrets.token_urlsafe(32)
def build_authorize_url(verifier: str, challenge: str, state: str) -> str:
del verifier
def build_authorize_url(challenge: str, state: str) -> str:
params = {
"response_type": "code",
"client_id": CLIENT_ID,
@@ -222,7 +221,6 @@ def build_authorize_url(verifier: str, challenge: str, state: str) -> str:
async def exchange_code_for_tokens(code: str, verifier: str) -> ProviderTokens:
async with aiohttp.ClientSession() as session:
payload = {
"grant_type": "authorization_code",
"client_id": CLIENT_ID,
@@ -230,22 +228,30 @@ async def exchange_code_for_tokens(code: str, verifier: str) -> ProviderTokens:
"code_verifier": verifier,
"redirect_uri": REDIRECT_URI,
}
timeout = aiohttp.ClientTimeout(total=20)
try:
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.post(TOKEN_URL, data=payload) as resp:
if not resp.ok:
text = await resp.text()
raise RuntimeError(f"Token exchange failed: {resp.status} {text}")
body = await resp.json()
except (aiohttp.ClientError, TimeoutError) as e:
raise RuntimeError(f"Token exchange request error: {e}") from e
try:
expires_in = int(body["expires_in"])
return ProviderTokens(
access_token=body["access_token"],
refresh_token=body["refresh_token"],
expires_at=time.time() + expires_in,
)
except (KeyError, TypeError, ValueError) as e:
raise RuntimeError(f"Token exchange response parse error: {e}") from e
async def get_latest_code(email_provider: BaseProvider, email: str) -> str | None:
message = await email_provider.get_latest_message(email)
async def get_latest_code(email_provider: BaseProvider) -> str | None:
message = await email_provider.get_latest_message()
if not message:
return None
return extract_verification_code(message)
@@ -270,6 +276,51 @@ async def click_continue(page: Page, timeout_ms: int = 10000):
await btn.click()
async def oauth_needs_email_check(page: Page) -> bool:
marker = page.get_by_text("Check your inbox", exact=False)
return await marker.count() > 0
async def fill_oauth_code_if_present(page: Page, code: str) -> bool:
candidates = [
page.get_by_placeholder("Code"),
page.get_by_label("Code"),
page.locator(
'input[name*="code" i], input[id*="code" i], '
'input[autocomplete="one-time-code"], input[inputmode="numeric"]'
),
]
for locator in candidates:
if await locator.count() == 0:
continue
try:
await locator.first.wait_for(state="visible", timeout=1500)
await locator.first.fill(code)
return True
except PlaywrightError:
continue
return False
async def click_first_visible_button(
page: Page,
labels: list[str],
timeout_ms: int = 2000,
) -> bool:
for label in labels:
button = page.get_by_role("button", name=label)
if await button.count() == 0:
continue
try:
await button.first.wait_for(state="visible", timeout=timeout_ms)
await button.first.click(timeout=timeout_ms)
return True
except PlaywrightError:
continue
return False
async def wait_for_signup_stabilization(
page: Page,
source_url: str,
@@ -288,12 +339,12 @@ async def wait_for_signup_stabilization(
async def register_chatgpt_account(
email_provider_factory: Callable[[BrowserContext], BaseProvider] | None = None,
) -> bool:
) -> ProviderTokens | None:
logger.info("=== Starting ChatGPT account registration ===")
if email_provider_factory is None:
logger.error("No email provider factory configured")
return False
return None
birth_month, birth_day, birth_year = generate_birthdate_90s()
@@ -321,7 +372,7 @@ async def register_chatgpt_account(
full_name = generate_name()
verifier, challenge = generate_pkce_pair()
oauth_state = generate_state()
authorize_url = build_authorize_url(verifier, challenge, oauth_state)
authorize_url = build_authorize_url(challenge, oauth_state)
logger.info("[2/5] Registering ChatGPT for %s", email)
chatgpt_page = await context.new_page()
@@ -347,24 +398,23 @@
)
logger.info("[3/5] Getting verification message from email provider...")
code = await get_latest_code(email_provider, email)
code = await get_latest_code(email_provider)
if not code:
raise AutomationError(
"email_provider", "Email provider returned no verification message"
)
logger.info("[3/5] Verification code extracted: %s", code)
logger.info("[3/5] Verification code extracted")
await chatgpt_page.bring_to_front()
code_input = chatgpt_page.get_by_placeholder("Code")
if await code_input.count() > 0:
await code_input.fill(code)
await code_input.first.wait_for(state="visible", timeout=10000)
await code_input.first.fill(code)
await click_continue(chatgpt_page)
logger.info("[4/5] Setting profile...")
name_input = chatgpt_page.get_by_placeholder("Full name")
await name_input.first.wait_for(state="visible", timeout=20000)
if await name_input.count() > 0:
await name_input.fill(full_name)
await name_input.first.fill(full_name)
await fill_date_field(chatgpt_page, birth_month, birth_day, birth_year)
profile_url = chatgpt_page.url
@@ -409,25 +459,45 @@ async def register_chatgpt_account(
if await continue_button.count() > 0:
await continue_button.first.click()
for label in ["Continue", "Allow", "Authorize"]:
button = oauth_page.get_by_role("button", name=label)
if await button.count() > 0:
try:
await button.first.click(timeout=5000)
await oauth_page.wait_for_timeout(500)
except Exception:
pass
last_oauth_email_code = code
oauth_deadline = asyncio.get_running_loop().time() + 60
while asyncio.get_running_loop().time() < oauth_deadline:
if redirect_url_captured:
break
if await oauth_needs_email_check(oauth_page):
logger.info("OAuth requested email confirmation code")
new_code = await get_latest_code(email_provider)
if new_code and new_code != last_oauth_email_code:
filled = await fill_oauth_code_if_present(oauth_page, new_code)
if filled:
last_oauth_email_code = new_code
logger.info("Filled OAuth email confirmation code")
else:
logger.warning(
"OAuth inbox challenge detected but code field not found"
)
if not redirect_url_captured:
try:
await oauth_page.wait_for_timeout(4000)
current_url = oauth_page.url
if "localhost:1455" in current_url and "code=" in current_url:
redirect_url_captured = current_url
logger.info("Captured OAuth redirect from page URL")
break
except Exception:
pass
clicked = await click_first_visible_button(
oauth_page,
["Continue", "Allow", "Authorize", "Verify"],
timeout_ms=2000,
)
if clicked:
await oauth_page.wait_for_timeout(500)
else:
await oauth_page.wait_for_timeout(1000)
if not redirect_url_captured:
raise AutomationError(
"oauth", "OAuth redirect with code was not captured", oauth_page
@@ -446,20 +516,18 @@
raise AutomationError("oauth", "OAuth state mismatch", oauth_page)
tokens = await exchange_code_for_tokens(auth_code, verifier)
save_tokens(tokens)
logger.info("OAuth tokens saved successfully")
logger.info("OAuth tokens fetched successfully")
return True
return tokens
except AutomationError as e:
logger.error(f"Error at step [{e.step}]: {e.message}")
await save_error_screenshot(e.page, e.step)
return False
return None
except Exception as e:
logger.error(f"Unexpected error: {e}")
await save_error_screenshot(current_page, "unexpected")
return False
return None
finally:
if managed:
await asyncio.sleep(2)
await managed.close()

View file

@@ -1,9 +1,12 @@
import json
import time
import os
import aiohttp
from pathlib import Path
import logging
import os
import tempfile
import time
from pathlib import Path
from typing import Any
import aiohttp
from providers.base import ProviderTokens
@@ -16,72 +19,143 @@ CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
TOKEN_URL = "https://auth.openai.com/oauth/token"
def load_tokens() -> ProviderTokens | None:
if not TOKENS_FILE.exists():
def _tokens_to_dict(tokens: ProviderTokens) -> dict[str, Any]:
return {
"access_token": tokens.access_token,
"refresh_token": tokens.refresh_token,
"expires_at": tokens.expires_at,
}
def _dict_to_tokens(data: dict[str, Any] | None) -> ProviderTokens | None:
if not isinstance(data, dict):
return None
try:
with open(TOKENS_FILE) as f:
data = json.load(f)
return ProviderTokens(
access_token=data["access_token"],
refresh_token=data["refresh_token"],
expires_at=data["expires_at"],
)
except json.JSONDecodeError, KeyError:
except (KeyError, TypeError):
return None
def save_tokens(tokens: ProviderTokens):
def _load_raw() -> dict[str, Any] | None:
if not TOKENS_FILE.exists():
return None
try:
with open(TOKENS_FILE) as f:
data = json.load(f)
if isinstance(data, dict):
return data
return None
except json.JSONDecodeError:
return None
def _save_raw(data: dict[str, Any]) -> None:
TOKENS_FILE.parent.mkdir(parents=True, exist_ok=True)
with open(TOKENS_FILE, "w") as f:
json.dump(
{
"access_token": tokens.access_token,
"refresh_token": tokens.refresh_token,
"expires_at": tokens.expires_at,
},
f,
indent=2,
fd, tmp_path = tempfile.mkstemp(
prefix=f"{TOKENS_FILE.name}.",
suffix=".tmp",
dir=str(TOKENS_FILE.parent),
)
try:
with os.fdopen(fd, "w") as f:
json.dump(data, f, indent=2)
f.flush()
os.fsync(f.fileno())
os.replace(tmp_path, TOKENS_FILE)
finally:
if os.path.exists(tmp_path):
os.unlink(tmp_path)
def _normalize_state(data: dict[str, Any] | None) -> dict[str, Any]:
if not data:
return {"active": None, "next_account": None}
if "active" in data or "next_account" in data:
return {
"active": data.get("active"),
"next_account": data.get("next_account"),
}
return {"active": data, "next_account": None}
def load_state() -> tuple[ProviderTokens | None, ProviderTokens | None]:
normalized = _normalize_state(_load_raw())
active = _dict_to_tokens(normalized.get("active"))
next_account = _dict_to_tokens(normalized.get("next_account"))
return active, next_account
def save_state(active: ProviderTokens | None, next_account: ProviderTokens | None) -> None:
payload = {
"active": _tokens_to_dict(active) if active else None,
"next_account": _tokens_to_dict(next_account) if next_account else None,
}
_save_raw(payload)
def load_tokens() -> ProviderTokens | None:
active, _ = load_state()
return active
def load_next_tokens() -> ProviderTokens | None:
_, next_account = load_state()
return next_account
def save_tokens(tokens: ProviderTokens):
_, next_account = load_state()
save_state(tokens, next_account)
def promote_next_tokens() -> bool:
_, next_account = load_state()
if not next_account:
return False
save_state(next_account, None)
return True
def clear_next_tokens():
active, _ = load_state()
save_state(active, None)
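The `active`/`next_account` helpers above implement a small two-slot rotation: `save_tokens` replaces the active slot without discarding a staged standby, and `promote_next_tokens` swaps the standby in and clears it. A standalone sketch of that lifecycle (an illustrative in-memory stand-in, not the repo module):

```python
# Illustrative two-slot rotation state machine, mirroring the
# active/next_account helpers above (in-memory dict instead of a JSON file).
from typing import Any

state: dict[str, Any] = {"active": None, "next_account": None}

def save_tokens(tokens: dict) -> None:
    # Replacing the active slot must not discard a prepared next_account.
    state["active"] = tokens

def set_next(tokens: dict) -> None:
    state["next_account"] = tokens

def promote_next() -> bool:
    # Switch: the standby becomes active and the standby slot is cleared.
    if not state["next_account"]:
        return False
    state["active"] = state["next_account"]
    state["next_account"] = None
    return True

save_tokens({"access_token": "a1"})
set_next({"access_token": "a2"})
assert promote_next() is True
assert state["active"]["access_token"] == "a2"
assert state["next_account"] is None
assert promote_next() is False  # nothing staged anymore
```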
async def refresh_tokens(refresh_token: str) -> ProviderTokens | None:
async with aiohttp.ClientSession() as session:
data = {
"grant_type": "refresh_token",
"refresh_token": refresh_token,
"client_id": CLIENT_ID,
}
timeout = aiohttp.ClientTimeout(total=15)
try:
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.post(TOKEN_URL, data=data) as resp:
if not resp.ok:
text = await resp.text()
print(f"Token refresh failed: {resp.status} {text}")
logger.warning("Token refresh failed: %s %s", resp.status, text)
return None
json_resp = await resp.json()
expires_in = json_resp["expires_in"]
except (aiohttp.ClientError, TimeoutError) as e:
logger.warning("Token refresh request error: %s", e)
return None
except Exception as e:
logger.warning("Token refresh unexpected error: %s", e)
return None
try:
expires_in = int(json_resp["expires_in"])
return ProviderTokens(
access_token=json_resp["access_token"],
refresh_token=json_resp["refresh_token"],
expires_at=time.time() + expires_in,
)
async def get_valid_tokens() -> ProviderTokens | None:
tokens = load_tokens()
if not tokens:
logger.info("No tokens found")
except (KeyError, TypeError, ValueError) as e:
logger.warning("Token refresh response parse error: %s", e)
return None
if tokens.is_expired:
logger.info("Token expired, refreshing...")
if not tokens.refresh_token:
logger.info("No refresh token available")
return None
new_tokens = await refresh_tokens(tokens.refresh_token)
if not new_tokens:
logger.warning("Failed to refresh token")
return None
save_tokens(new_tokens)
return new_tokens
return tokens


@ -1,14 +1,15 @@
import json
import socket
import urllib.error
import urllib.request
import logging
from typing import Any
import aiohttp
logger = logging.getLogger(__name__)
def clamp_percent(value: Any) -> int:
try:
num = float(value)
except TypeError, ValueError:
except (TypeError, ValueError):
return 0
if num < 0:
return 0
@ -28,30 +29,36 @@ def _parse_window(window: dict[str, Any] | None) -> dict[str, int] | None:
}
def get_usage_data(access_token: str, timeout_ms: int = 10000) -> dict[str, Any] | None:
async def get_usage_data(
access_token: str,
timeout_ms: int = 10000,
) -> dict[str, Any] | None:
headers = {
"Authorization": f"Bearer {access_token}",
"User-Agent": "CodexProxy",
"Accept": "application/json",
}
req = urllib.request.Request(
"https://chatgpt.com/backend-api/wham/usage",
headers=headers,
method="GET",
timeout = aiohttp.ClientTimeout(total=timeout_ms / 1000)
url = "https://chatgpt.com/backend-api/wham/usage"
try:
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(url, headers=headers) as res:
if not res.ok:
body = await res.text()
logger.warning(
"Usage fetch failed: status=%s body=%s",
res.status,
body[:300],
)
try:
with urllib.request.urlopen(req, timeout=timeout_ms / 1000) as res:
body = res.read().decode("utf-8", errors="replace")
except urllib.error.HTTPError:
return None
except urllib.error.URLError, socket.timeout:
data = await res.json()
except (aiohttp.ClientError, TimeoutError) as e:
logger.warning("Usage fetch request error: %s", e)
return None
try:
data = json.loads(body)
except json.JSONDecodeError:
except Exception as e:
logger.warning("Usage fetch unexpected error: %s", e)
return None
rate_limit = data.get("rate_limit") or {}
@ -76,10 +83,3 @@ def get_usage_data(access_token: str, timeout_ms: int = 10000) -> dict[str, Any]
"limit_reached": bool(rate_limit.get("limit_reached")),
"allowed": bool(rate_limit.get("allowed", True)),
}
def get_usage_percent(access_token: str, timeout_ms: int = 10000) -> int:
data = get_usage_data(access_token, timeout_ms=timeout_ms)
if not data:
return -1
return int(data["used_percent"])


@ -1,27 +1,22 @@
import asyncio
import logging
import os
from aiohttp import web
from providers.chatgpt import ChatGPTProvider
PORT = int(os.environ.get("PORT", "8080"))
USAGE_REFRESH_THRESHOLD = int(os.environ.get("USAGE_REFRESH_THRESHOLD", "85"))
LIMIT_EXHAUSTED_PERCENT = 100
from providers.base import Provider
from utils.env import parse_int_env
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
PORT = parse_int_env("PORT", 8080, 1, 65535)
LIMIT_EXHAUSTED_PERCENT = 100
# Registry of available providers
PROVIDERS = {
PROVIDERS: dict[str, Provider] = {
"chatgpt": ChatGPTProvider(),
}
refresh_locks = {name: asyncio.Lock() for name in PROVIDERS.keys()}
background_refresh_tasks: dict[str, asyncio.Task | None] = {
name: None for name in PROVIDERS.keys()
}
background_tasks: dict[str, asyncio.Task | None] = {name: None for name in PROVIDERS}
@web.middleware
@ -31,16 +26,30 @@ async def request_log_middleware(request: web.Request, handler):
return response
def build_limit(usage_percent: int) -> dict[str, int | bool]:
def build_limit(usage_percent: int, prepare_threshold: int) -> dict[str, int | bool]:
remaining = max(0, 100 - usage_percent)
return {
"used_percent": usage_percent,
"remaining_percent": remaining,
"exhausted": usage_percent >= LIMIT_EXHAUSTED_PERCENT,
"needs_refresh": usage_percent >= USAGE_REFRESH_THRESHOLD,
"needs_prepare": usage_percent >= prepare_threshold,
}
def get_prepare_threshold(provider_name: str) -> int:
provider = PROVIDERS.get(provider_name)
if not provider:
return 100
return provider.prepare_threshold
def should_trigger_standby_prepare(provider_name: str, usage_percent: int) -> bool:
provider = PROVIDERS.get(provider_name)
if not provider:
return False
return provider.should_prepare_standby(usage_percent)
async def ensure_provider_token_ready(provider_name: str):
provider = PROVIDERS.get(provider_name)
if not provider:
@ -52,7 +61,6 @@ async def ensure_provider_token_ready(provider_name: str):
logger.warning(
"[%s] Startup token check failed, forcing recreation", provider_name
)
if isinstance(provider, ChatGPTProvider):
token = await provider.force_recreate_token()
if not token:
@ -60,103 +68,88 @@ async def ensure_provider_token_ready(provider_name: str):
return
usage_info = await provider.get_usage_info(token)
if "error" not in usage_info:
logger.info("[%s] Startup token is ready", provider_name)
return
if "error" in usage_info:
logger.warning(
"[%s] Startup token invalid for usage, forcing recreation", provider_name
)
if isinstance(provider, ChatGPTProvider):
token = await provider.force_recreate_token()
if token:
logger.info("[%s] Startup token recreated successfully", provider_name)
if not token:
logger.error("[%s] Startup token recreation failed", provider_name)
return
logger.error("[%s] Startup token recreation failed", provider_name)
await provider.startup_prepare()
logger.info("[%s] Startup token is ready", provider_name)
async def on_startup(app: web.Application):
del app
for provider_name in PROVIDERS.keys():
await ensure_provider_token_ready(provider_name)
async def issue_new_token(provider_name: str) -> str | None:
async def ensure_standby_task(provider_name: str, usage_percent: int, reason: str):
provider = PROVIDERS.get(provider_name)
if not provider:
return None
return
async with refresh_locks[provider_name]:
logger.info(f"[{provider_name}] Generating new token")
success = await provider.register_new_account()
if not success:
logger.error(f"[{provider_name}] Token generation failed")
return None
if not provider.should_prepare_standby(usage_percent):
return
token = await provider.get_token()
if not token:
logger.error(f"[{provider_name}] Token was generated but not available")
return None
return token
async def background_refresh_worker(provider_name: str, reason: str):
try:
logger.info(f"[{provider_name}] Starting background token refresh ({reason})")
new_token = await issue_new_token(provider_name)
if new_token:
logger.info(f"[{provider_name}] Background token refresh completed")
else:
logger.error(f"[{provider_name}] Background token refresh failed")
logger.info("[%s] Preparing standby in background (%s)", provider_name, reason)
await provider.ensure_standby_account(usage_percent)
except Exception:
logger.exception(
f"[{provider_name}] Unhandled error in background token refresh"
)
logger.exception("[%s] Unhandled standby preparation error", provider_name)
def trigger_background_refresh(provider_name: str, reason: str):
task = background_refresh_tasks.get(provider_name)
def trigger_standby_prepare(provider_name: str, usage_percent: int, reason: str):
task = background_tasks.get(provider_name)
if task and not task.done():
logger.info(
f"[{provider_name}] Background refresh already running, skip ({reason})"
"[%s] Standby prep already running, skip (%s)", provider_name, reason
)
return
background_refresh_tasks[provider_name] = asyncio.create_task(
background_refresh_worker(provider_name, reason)
background_tasks[provider_name] = asyncio.create_task(
ensure_standby_task(provider_name, usage_percent, reason)
)
async def token_handler(request: web.Request) -> web.Response:
provider_name = request.match_info.get("provider", "chatgpt")
provider = PROVIDERS.get(provider_name)
if not provider:
return web.json_response(
{"error": f"Unknown provider: {provider_name}"},
status=404,
{"error": f"Unknown provider: {provider_name}"}, status=404
)
# Get or create token
token = await provider.get_token()
if not token:
return web.json_response({"error": "Failed to get active token"}, status=503)
usage_info = await provider.get_usage_info(token)
if "error" in usage_info:
return web.json_response({"error": usage_info["error"]}, status=503)
usage_percent = int(usage_info.get("used_percent", 0))
switched = await provider.maybe_rotate_account(usage_percent)
if switched:
token = await provider.get_token()
if not token:
return web.json_response(
{"error": "Failed to get active token"},
{"error": "Failed to get active token after account switch"},
status=503,
)
# Get usage info
usage_info = await provider.get_usage_info(token)
if "error" in usage_info:
return web.json_response(
{"error": usage_info["error"]},
status=503,
return web.json_response({"error": usage_info["error"]}, status=503)
usage_percent = int(usage_info.get("used_percent", 0))
logger.info("[%s] Active account switched before response", provider_name)
prepare_threshold = get_prepare_threshold(provider_name)
if should_trigger_standby_prepare(provider_name, usage_percent):
trigger_standby_prepare(
provider_name,
usage_percent,
f"usage {usage_percent}% reached standby policy",
)
usage_percent = usage_info.get("used_percent", 0)
remaining_percent = usage_info.get("remaining_percent", max(0, 100 - usage_percent))
remaining_percent = int(
usage_info.get("remaining_percent", max(0, 100 - usage_percent))
)
logger.info(
"[%s] token issued, used=%s%% remaining=%s%%",
provider_name,
@ -183,17 +176,10 @@ async def token_handler(request: web.Request) -> web.Response:
secondary_window.get("reset_after_seconds", 0),
)
# Trigger background refresh if needed
if usage_percent >= USAGE_REFRESH_THRESHOLD:
trigger_background_refresh(
provider_name,
f"usage {usage_percent}% >= threshold {USAGE_REFRESH_THRESHOLD}%",
)
return web.json_response(
{
"token": token,
"limit": build_limit(usage_percent),
"limit": build_limit(usage_percent, prepare_threshold),
"usage": {
"primary_window": primary_window,
"secondary_window": secondary_window,
@ -202,19 +188,42 @@ async def token_handler(request: web.Request) -> web.Response:
)
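The handler above combines two thresholds: `maybe_rotate_account` switches accounts at the higher one, while `trigger_standby_prepare` stages a fresh account at the lower one. A standalone sketch of that policy (thresholds taken from the `.env.example` defaults; `plan` is a hypothetical helper, not part of the server):

```python
# Two-threshold policy sketch: stage a standby account early, switch late.
PREPARE_THRESHOLD = 85  # CHATGPT_PREPARE_THRESHOLD default
SWITCH_THRESHOLD = 95   # CHATGPT_SWITCH_THRESHOLD default

def plan(usage_percent: int) -> list[str]:
    actions = []
    if usage_percent >= SWITCH_THRESHOLD:
        # Rotate to the prepared next_account.
        actions.append("switch")
    if usage_percent >= PREPARE_THRESHOLD:
        # Kick off background registration of a standby account.
        actions.append("prepare")
    return actions

assert plan(50) == []
assert plan(90) == ["prepare"]
assert plan(97) == ["switch", "prepare"]
```

Preparing again right after a switch matches the handler's behavior: once it rotates, it re-checks usage of the new account and stages another standby if that account is also above the prepare threshold.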
async def on_startup(app: web.Application):
del app
for provider_name in PROVIDERS:
await ensure_provider_token_ready(provider_name)
async def on_cleanup(app: web.Application):
del app
for task in background_tasks.values():
if task and not task.done():
task.cancel()
pending = [t for t in background_tasks.values() if t is not None]
if pending:
await asyncio.gather(*pending, return_exceptions=True)
def create_app() -> web.Application:
app = web.Application(middlewares=[request_log_middleware])
app.on_startup.append(on_startup)
# New route: /{provider}/token
app.on_cleanup.append(on_cleanup)
app.router.add_get("/{provider}/token", token_handler)
# Legacy route for backward compatibility
app.router.add_get("/token", token_handler)
return app
if __name__ == "__main__":
logger.info("Starting token service on port %s", PORT)
logger.info("Usage refresh threshold: %s%%", USAGE_REFRESH_THRESHOLD)
chatgpt_provider = PROVIDERS.get("chatgpt")
if chatgpt_provider:
logger.info(
"ChatGPT prepare-next threshold: %s%%", chatgpt_provider.prepare_threshold
)
if chatgpt_provider.switch_threshold is not None:
logger.info(
"ChatGPT switch threshold: %s%%", chatgpt_provider.switch_threshold
)
logger.info("Available providers: %s", ", ".join(PROVIDERS.keys()))
app = create_app()
web.run_app(app, host="0.0.0.0", port=PORT)

src/utils/__init__.py Normal file

@ -0,0 +1,4 @@
from .env import parse_int_env
from .randoms import generate_password
__all__ = ["parse_int_env", "generate_password"]

src/utils/env.py Normal file

@ -0,0 +1,22 @@
import os
def parse_int_env(
name: str,
default: int,
minimum: int,
maximum: int,
) -> int:
raw = os.environ.get(name)
if raw is None:
return default
try:
value = int(raw)
except ValueError:
return default
if value < minimum or value > maximum:
return default
return value
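For example, non-numeric or out-of-range values fall back to the default rather than raising at import time (a usage sketch with a self-contained copy of the helper; `DEMO_PORT` is a made-up variable):

```python
import os

# Self-contained copy of the bounds-checked env parser above.
def parse_int_env(name: str, default: int, minimum: int, maximum: int) -> int:
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        value = int(raw)
    except ValueError:
        return default
    if value < minimum or value > maximum:
        return default
    return value

os.environ["DEMO_PORT"] = "70000"   # above the 1..65535 port range
assert parse_int_env("DEMO_PORT", 8080, 1, 65535) == 8080
os.environ["DEMO_PORT"] = "not-a-number"
assert parse_int_env("DEMO_PORT", 8080, 1, 65535) == 8080
del os.environ["DEMO_PORT"]         # unset -> default as well
assert parse_int_env("DEMO_PORT", 8080, 1, 65535) == 8080
```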

src/utils/randoms.py Normal file

@ -0,0 +1,13 @@
import random
import secrets
import string
def generate_password(
length: int = 20,
*,
secure: bool = True,
) -> str:
alphabet = string.ascii_letters + string.digits
chooser = secrets.choice if secure else random.choice
return "".join(chooser(alphabet) for _ in range(length))

tests/conftest.py Normal file

@ -0,0 +1,12 @@
import sys
from pathlib import Path
def _add_src_to_path() -> None:
root = Path(__file__).resolve().parents[1]
src = root / "src"
if str(src) not in sys.path:
sys.path.insert(0, str(src))
_add_src_to_path()


@ -0,0 +1,37 @@
from providers.chatgpt.registration import (
build_authorize_url,
extract_verification_code,
generate_birthdate_90s,
generate_name,
)
def test_generate_name_shape():
name = generate_name()
parts = name.split(" ")
assert len(parts) == 2
assert all(p.isalpha() for p in parts)
def test_generate_birthdate_90s_range():
month, day, year = generate_birthdate_90s()
assert 1 <= int(month) <= 12
assert 1 <= int(day) <= 28
assert 1990 <= int(year) <= 1999
def test_extract_verification_code_prefers_chatgpt_phrase():
text = "foo 123456 bar Your ChatGPT code is 654321"
assert extract_verification_code(text) == "654321"
def test_extract_verification_code_fallback_last_code():
text = "codes 111111 and 222222"
assert extract_verification_code(text) == "222222"
def test_build_authorize_url_contains_required_params():
url = build_authorize_url("challenge", "state123")
assert "response_type=code" in url
assert "code_challenge=challenge" in url
assert "state=state123" in url

tests/test_server_unit.py Normal file

@ -0,0 +1,159 @@
import asyncio
import json
import server
from providers.base import Provider, ProviderTokens
from utils.env import parse_int_env
class FakeRequest:
def __init__(self, provider: str):
self.match_info = {"provider": provider}
class FakeProvider(Provider):
def __init__(
self,
token: str | None = "tok",
usage: dict | None = None,
rotate: bool = False,
):
self._token = token
self._usage = usage or {
"used_percent": 10,
"remaining_percent": 90,
"primary_window": None,
"secondary_window": None,
}
self._rotate = rotate
self._prepare_threshold = 80
self.get_token_calls = 0
self.standby_calls = 0
@property
def prepare_threshold(self) -> int:
return self._prepare_threshold
def should_prepare_standby(self, usage_percent: int) -> bool:
return usage_percent >= self.prepare_threshold
@property
def name(self) -> str:
return "fake"
async def get_token(self) -> str | None:
self.get_token_calls += 1
return self._token
async def register_new_account(self) -> bool:
return True
async def get_usage_info(self, access_token: str) -> dict:
_ = access_token
return dict(self._usage)
def load_tokens(self) -> ProviderTokens | None:
return None
def save_tokens(self, tokens: ProviderTokens) -> None:
_ = tokens
async def maybe_rotate_account(self, usage_percent: int) -> bool:
_ = usage_percent
return self._rotate
async def ensure_standby_account(
self,
usage_percent: int,
) -> None:
_ = usage_percent
self.standby_calls += 1
def _response_json(resp) -> dict:
return json.loads(resp.body.decode("utf-8"))
def test_parse_int_env_defaults(monkeypatch):
monkeypatch.delenv("X_TEST", raising=False)
assert parse_int_env("X_TEST", 10, 1, 20) == 10
def test_parse_int_env_invalid(monkeypatch):
monkeypatch.setenv("X_TEST", "abc")
assert parse_int_env("X_TEST", 10, 1, 20) == 10
def test_parse_int_env_out_of_range(monkeypatch):
monkeypatch.setenv("X_TEST", "999")
assert parse_int_env("X_TEST", 10, 1, 20) == 10
def test_build_limit_fields():
limit = server.build_limit(90, 85)
assert limit == {
"used_percent": 90,
"remaining_percent": 10,
"exhausted": False,
"needs_prepare": True,
}
def test_get_prepare_threshold():
assert (
server.get_prepare_threshold("chatgpt")
== server.PROVIDERS["chatgpt"].prepare_threshold
)
assert server.get_prepare_threshold("unknown") == 100
def test_token_handler_unknown_provider(monkeypatch):
monkeypatch.setattr(server, "PROVIDERS", {})
resp = asyncio.run(server.token_handler(FakeRequest("missing")))
assert resp.status == 404
def test_token_handler_success(monkeypatch):
provider = FakeProvider()
monkeypatch.setattr(server, "PROVIDERS", {"fake": provider})
monkeypatch.setattr(server, "background_tasks", {"fake": None})
resp = asyncio.run(server.token_handler(FakeRequest("fake")))
data = _response_json(resp)
assert resp.status == 200
assert data["token"] == "tok"
assert data["limit"]["needs_prepare"] is False
def test_token_handler_triggers_standby(monkeypatch):
provider = FakeProvider(usage={"used_percent": 90, "remaining_percent": 10})
monkeypatch.setattr(server, "PROVIDERS", {"fake": provider})
monkeypatch.setattr(server, "background_tasks", {"fake": None})
called = {"value": False}
def fake_trigger(name, usage_percent, reason):
assert name == "fake"
assert usage_percent == 90
assert "standby policy" in reason
called["value"] = True
monkeypatch.setattr(server, "trigger_standby_prepare", fake_trigger)
resp = asyncio.run(server.token_handler(FakeRequest("fake")))
assert resp.status == 200
assert called["value"] is True
def test_token_handler_rotation_path(monkeypatch):
provider = FakeProvider(
usage={"used_percent": 96, "remaining_percent": 4},
rotate=True,
)
monkeypatch.setattr(server, "PROVIDERS", {"fake": provider})
monkeypatch.setattr(server, "background_tasks", {"fake": None})
monkeypatch.setattr(server, "trigger_standby_prepare", lambda *_: None)
resp = asyncio.run(server.token_handler(FakeRequest("fake")))
assert resp.status == 200
assert provider.get_token_calls >= 2

tests/test_tokens_unit.py Normal file

@ -0,0 +1,60 @@
import json
from pathlib import Path
from providers.base import ProviderTokens
from providers.chatgpt import tokens as t
def test_normalize_state_backward_compatible():
raw = {"access_token": "a", "refresh_token": "r", "expires_at": 1}
normalized = t._normalize_state(raw)
assert normalized["active"]["access_token"] == "a"
assert normalized["next_account"] is None
def test_promote_next_tokens(tmp_path, monkeypatch):
file_path = tmp_path / "chatgpt_tokens.json"
monkeypatch.setattr(t, "TOKENS_FILE", file_path)
active = ProviderTokens("a1", "r1", 100)
nxt = ProviderTokens("a2", "r2", 200)
t.save_state(active, nxt)
assert t.promote_next_tokens() is True
cur, next_cur = t.load_state()
assert cur is not None
assert cur.access_token == "a2"
assert next_cur is None
def test_save_tokens_preserves_next(tmp_path, monkeypatch):
file_path = tmp_path / "chatgpt_tokens.json"
monkeypatch.setattr(t, "TOKENS_FILE", file_path)
active = ProviderTokens("a1", "r1", 100)
nxt = ProviderTokens("a2", "r2", 200)
t.save_state(active, nxt)
t.save_tokens(ProviderTokens("a3", "r3", 300))
cur, next_cur = t.load_state()
assert cur is not None and cur.access_token == "a3"
assert next_cur is not None and next_cur.access_token == "a2"
def test_atomic_write_produces_valid_json(tmp_path, monkeypatch):
file_path = tmp_path / "chatgpt_tokens.json"
monkeypatch.setattr(t, "TOKENS_FILE", file_path)
t.save_state(ProviderTokens("x", "y", 123), None)
with open(file_path) as f:
data = json.load(f)
assert "active" in data
assert data["active"]["access_token"] == "x"
def test_load_state_from_missing_file(tmp_path, monkeypatch):
file_path = tmp_path / "missing.json"
monkeypatch.setattr(t, "TOKENS_FILE", file_path)
active, nxt = t.load_state()
assert active is None
assert nxt is None

tests/test_usage_unit.py Normal file

@ -0,0 +1,32 @@
from providers.chatgpt.usage import _parse_window, clamp_percent
def test_clamp_percent_bounds():
assert clamp_percent(-1) == 0
assert clamp_percent(150) == 100
assert clamp_percent(49.6) == 50
def test_clamp_percent_invalid():
assert clamp_percent(None) == 0
assert clamp_percent("bad") == 0
def test_parse_window_valid():
window = {
"used_percent": 34.4,
"limit_window_seconds": 3600,
"reset_after_seconds": 120,
"reset_at": 999,
}
parsed = _parse_window(window)
assert parsed == {
"used_percent": 34,
"limit_window_seconds": 3600,
"reset_after_seconds": 120,
"reset_at": 999,
}
def test_parse_window_none():
assert _parse_window(None) is None

uv.lock generated

@ -83,6 +83,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/3a/2a/7cc015f5b9f5db42b7d48157e23356022889fc354a2813c15934b7cb5c0e/attrs-25.4.0-py3-none-any.whl", hash = "sha256:adcf7e2a1fb3b36ac48d97835bb6d8ade15b8dcce26aba8bf1d14847b57a3373", size = 67615, upload-time = "2025-10-06T13:54:43.17Z" },
]
[[package]]
name = "colorama"
version = "0.4.6"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" },
]
[[package]]
name = "frozenlist"
version = "1.8.0"
@ -158,6 +167,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008, upload-time = "2025-10-12T14:55:18.883Z" },
]
[[package]]
name = "iniconfig"
version = "2.3.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" },
]
[[package]]
name = "megapt"
version = "0.1.0"
@ -168,12 +186,19 @@ dependencies = [
{ name = "playwright" },
]
[package.optional-dependencies]
dev = [
{ name = "pytest" },
]
[package.metadata]
requires-dist = [
{ name = "aiohttp", specifier = "==3.13.3" },
{ name = "pkce", specifier = "==1.0.3" },
{ name = "playwright", specifier = "==1.58.0" },
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" },
]
provides-extras = ["dev"]
[[package]]
name = "multidict"
@ -220,6 +245,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/81/08/7036c080d7117f28a4af526d794aab6a84463126db031b007717c1a6676e/multidict-6.7.1-py3-none-any.whl", hash = "sha256:55d97cc6dae627efa6a6e548885712d4864b81110ac76fa4e534c03819fa4a56", size = 12319, upload-time = "2026-01-26T02:46:44.004Z" },
]
[[package]]
name = "packaging"
version = "26.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/65/ee/299d360cdc32edc7d2cf530f3accf79c4fca01e96ffc950d8a52213bd8e4/packaging-26.0.tar.gz", hash = "sha256:00243ae351a257117b6a241061796684b084ed1c516a08c48a3f7e147a9d80b4", size = 143416, upload-time = "2026-01-21T20:50:39.064Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b7/b9/c538f279a4e237a006a2c98387d081e9eb060d203d8ed34467cc0f0b9b53/packaging-26.0-py3-none-any.whl", hash = "sha256:b36f1fef9334a5588b4166f8bcd26a14e521f2b55e6b9de3aaa80d3ff7a37529", size = 74366, upload-time = "2026-01-21T20:50:37.788Z" },
]
[[package]]
name = "pkce"
version = "1.0.3"
@ -248,6 +282,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/c8/c4/cc0229fea55c87d6c9c67fe44a21e2cd28d1d558a5478ed4d617e9fb0c93/playwright-1.58.0-py3-none-win_arm64.whl", hash = "sha256:32ffe5c303901a13a0ecab91d1c3f74baf73b84f4bedbb6b935f5bc11cc98e1b", size = 33085919, upload-time = "2026-01-30T15:09:45.71Z" },
]
[[package]]
name = "pluggy"
version = "1.6.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
]
[[package]]
name = "propcache"
version = "0.4.1"
@ -299,6 +342,31 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/a0/c4/b4d4827c93ef43c01f599ef31453ccc1c132b353284fc6c87d535c233129/pyee-13.0.1-py3-none-any.whl", hash = "sha256:af2f8fede4171ef667dfded53f96e2ed0d6e6bd7ee3bb46437f77e3b57689228", size = 15659, upload-time = "2026-02-14T21:12:26.263Z" },
]
[[package]]
name = "pygments"
version = "2.19.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" },
]
[[package]]
name = "pytest"
version = "9.0.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "sys_platform == 'win32'" },
{ name = "iniconfig" },
{ name = "packaging" },
{ name = "pluggy" },
{ name = "pygments" },
]
sdist = { url = "https://files.pythonhosted.org/packages/d1/db/7ef3487e0fb0049ddb5ce41d3a49c235bf9ad299b6a25d5780a89f19230f/pytest-9.0.2.tar.gz", hash = "sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11", size = 1568901, upload-time = "2025-12-06T21:30:51.014Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" },
]
[[package]]
name = "typing-extensions"
version = "4.15.0"