From 8b5449b1fd989257510557c0ec4de1b98efacd2f Mon Sep 17 00:00:00 2001 From: "Arthur K." Date: Mon, 2 Mar 2026 17:40:38 +0300 Subject: [PATCH 1/6] --- README.md | 16 +- pyproject.toml | 5 + src/browser.py | 13 +- src/email_providers/temp_mail_org.py | 24 +-- src/providers/base.py | 20 +++ src/providers/chatgpt/provider.py | 69 +++++--- src/providers/chatgpt/registration.py | 149 ++++++++++------- src/providers/chatgpt/tokens.py | 97 ++++++----- src/providers/chatgpt/usage.py | 54 +++--- src/server.py | 227 ++++++++++++-------------- tests/conftest.py | 12 ++ tests/test_registration_unit.py | 37 +++++ tests/test_server_unit.py | 150 +++++++++++++++++ tests/test_tokens_unit.py | 60 +++++++ tests/test_usage_unit.py | 32 ++++ 15 files changed, 663 insertions(+), 302 deletions(-) create mode 100644 tests/conftest.py create mode 100644 tests/test_registration_unit.py create mode 100644 tests/test_server_unit.py create mode 100644 tests/test_tokens_unit.py create mode 100644 tests/test_usage_unit.py diff --git a/README.md b/README.md index 6f423f1..a061094 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ Response shape: "used_percent": 0, "remaining_percent": 100, "exhausted": false, - "needs_refresh": false + "needs_prepare": false }, "usage": { "primary_window": { @@ -54,10 +54,8 @@ Behavior: ## Startup Behavior -On startup, service: - -1. Ensures active token exists and is usable. -2. Ensures `next_account` is prepared for ChatGPT. +On startup, service ensures active token exists and is usable. +Standby preparation runs through provider lifecycle hooks/background trigger when needed. ## Data Files @@ -70,6 +68,14 @@ On startup, service: PYTHONPATH=./src python src/server.py ``` +## Unit Tests + +The project has unit tests only (no integration/network tests). + +```bash +pytest -q +``` + ## Docker Notes - Dockerfile sets `DATA_DIR=/data`. 
diff --git a/pyproject.toml b/pyproject.toml index cc1ff33..7f927f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,5 +8,10 @@ dependencies = [ "pkce==1.0.3", ] +[project.optional-dependencies] +dev = [ + "pytest>=8.0.0", +] + [tool.uv] package = false diff --git a/src/browser.py b/src/browser.py index 3b35da9..daa7fef 100644 --- a/src/browser.py +++ b/src/browser.py @@ -1,6 +1,7 @@ import asyncio import json import logging +import socket import shutil import subprocess import tempfile @@ -44,7 +45,12 @@ CHROME_FLAGS = [ "--disable-search-engine-choice-screen", ] -DEFAULT_CDP_PORT = 9222 + +def _allocate_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return int(s.getsockname()[1]) def _fetch_ws_endpoint(port: int) -> str | None: @@ -79,10 +85,9 @@ class ManagedBrowser: shutil.rmtree(self.profile_dir, ignore_errors=True) -async def launch( - playwright: Playwright, cdp_port: int = DEFAULT_CDP_PORT -) -> ManagedBrowser: +async def launch(playwright: Playwright) -> ManagedBrowser: chrome_path = playwright.chromium.executable_path + cdp_port = _allocate_free_port() profile_dir = Path(tempfile.mkdtemp(prefix="megapt_profile-", dir="/tmp")) args = [ diff --git a/src/email_providers/temp_mail_org.py b/src/email_providers/temp_mail_org.py index cb2b840..0a0b65f 100644 --- a/src/email_providers/temp_mail_org.py +++ b/src/email_providers/temp_mail_org.py @@ -2,7 +2,7 @@ import asyncio import logging import re -from playwright.async_api import BrowserContext, Page +from playwright.async_api import BrowserContext, Error as PlaywrightError, Page from .base import BaseProvider @@ -44,7 +44,7 @@ class TempMailOrgProvider(BaseProvider): value, ) return value - except Exception: + except PlaywrightError: continue try: @@ -53,8 +53,8 @@ class TempMailOrgProvider(BaseProvider): if found: logger.info("[temp-mail.org] email found by body scan: %s", found) 
return found - except Exception: - pass + except PlaywrightError: + logger.debug("Failed to scan body text for email") await asyncio.sleep(1) @@ -76,7 +76,7 @@ class TempMailOrgProvider(BaseProvider): try: count = await items.count() logger.info("[temp-mail.org] inbox items: %s", count) - except Exception: + except PlaywrightError: count = 0 if count > 0: @@ -87,30 +87,30 @@ class TempMailOrgProvider(BaseProvider): continue text = (await item.inner_text()).strip().replace("\n", " ") logger.info("[temp-mail.org] item[%s]: %s", idx, text[:160]) - except Exception: + except PlaywrightError: continue if text: try: await item.click() logger.info("[temp-mail.org] opened item[%s]", idx) - except Exception: - pass + except PlaywrightError: + logger.debug("Failed to open inbox item[%s]", idx) message_text = text try: content = await page.content() if content and "Your ChatGPT code is" in content: message_text = content - except Exception: - pass + except PlaywrightError: + logger.debug("Failed to read opened message content") try: await page.go_back( wait_until="domcontentloaded", timeout=5000 ) logger.info("[temp-mail.org] returned back to inbox") - except Exception: - pass + except PlaywrightError: + logger.debug("Failed to return back to inbox") return message_text diff --git a/src/providers/base.py b/src/providers/base.py index 1a5b6bc..3185c5b 100644 --- a/src/providers/base.py +++ b/src/providers/base.py @@ -52,3 +52,23 @@ class Provider(ABC): def save_tokens(self, tokens: ProviderTokens) -> None: """Save tokens to storage""" pass + + async def force_recreate_token(self) -> str | None: + """Force-create a new active token when normal acquisition fails.""" + return None + + async def maybe_rotate_account(self, usage_percent: int) -> bool: + """Rotate active account/token if provider policy requires it.""" + return False + + async def ensure_standby_account( + self, + usage_percent: int, + prepare_threshold: int, + ) -> None: + """Prepare standby account/token 
asynchronously when needed.""" + return None + + async def startup_prepare(self) -> None: + """Optional provider-specific startup preparation.""" + return None diff --git a/src/providers/chatgpt/provider.py b/src/providers/chatgpt/provider.py index 2971c27..ae2e527 100644 --- a/src/providers/chatgpt/provider.py +++ b/src/providers/chatgpt/provider.py @@ -16,7 +16,7 @@ from .tokens import ( load_tokens, promote_next_tokens, refresh_tokens, - save_next_tokens, + save_state, save_tokens, ) from .usage import get_usage_data @@ -44,10 +44,14 @@ class ChatGPTProvider(Provider): attempt, CHATGPT_REGISTRATION_MAX_ATTEMPTS, ) - success = await self.register_new_account() - if success: + generated_tokens = await register_chatgpt_account( + email_provider_factory=self.email_provider_factory, + ) + if generated_tokens: + save_tokens(generated_tokens) return True logger.warning("Registration attempt %s failed", attempt) + await asyncio.sleep(1.5 * attempt) return False async def _create_next_account_under_lock(self) -> bool: @@ -56,23 +60,28 @@ class ChatGPTProvider(Provider): return True logger.info("Creating next account") - success = await self._register_with_retries() - if not success: - return False + for attempt in range(1, CHATGPT_REGISTRATION_MAX_ATTEMPTS + 1): + logger.info( + "Next-account registration attempt %s/%s", + attempt, + CHATGPT_REGISTRATION_MAX_ATTEMPTS, + ) + generated_tokens = await register_chatgpt_account( + email_provider_factory=self.email_provider_factory, + ) + if generated_tokens: + if active_before: + save_state(active_before, generated_tokens) + else: + save_state(generated_tokens, None) + logger.info("Next account is ready") + return True + logger.warning("Next-account registration attempt %s failed", attempt) + await asyncio.sleep(1.5 * attempt) - generated_active = load_tokens() - if not generated_active: - return False - - # Registration writes new tokens as active; restore old active and keep - # generated account as next. 
- if active_before: - save_tokens(active_before) - else: - clear_next_tokens() - save_next_tokens(generated_active) - logger.info("Next account is ready") - return True + if active_before or next_before: + save_state(active_before, next_before) + return False async def force_recreate_token(self) -> str | None: async with self._token_write_lock: @@ -85,6 +94,9 @@ class ChatGPTProvider(Provider): return None return tokens.access_token + async def startup_prepare(self) -> None: + await self.ensure_next_account() + async def ensure_next_account(self) -> bool: next_tokens = load_next_tokens() if next_tokens and not next_tokens.is_expired: @@ -96,6 +108,14 @@ class ChatGPTProvider(Provider): return True return await self._create_next_account_under_lock() + async def ensure_standby_account( + self, + usage_percent: int, + prepare_threshold: int, + ) -> None: + if usage_percent >= prepare_threshold: + await self.ensure_next_account() + async def maybe_switch_active_account(self, usage_percent: int) -> bool: if usage_percent < CHATGPT_SWITCH_THRESHOLD: return False @@ -119,6 +139,9 @@ class ChatGPTProvider(Provider): ) return switched + async def maybe_rotate_account(self, usage_percent: int) -> bool: + return await self.maybe_switch_active_account(usage_percent) + @property def name(self) -> str: return "chatgpt" @@ -154,13 +177,17 @@ class ChatGPTProvider(Provider): async def register_new_account(self) -> bool: """Register a new ChatGPT account""" - return await register_chatgpt_account( + generated_tokens = await register_chatgpt_account( email_provider_factory=self.email_provider_factory, ) + if not generated_tokens: + return False + save_tokens(generated_tokens) + return True async def get_usage_info(self, access_token: str) -> dict[str, Any]: """Get usage information for the current token""" - usage_data = get_usage_data(access_token) + usage_data = await get_usage_data(access_token) if not usage_data: return {"error": "Failed to get usage"} diff --git 
a/src/providers/chatgpt/registration.py b/src/providers/chatgpt/registration.py index 1af2581..3d106c1 100644 --- a/src/providers/chatgpt/registration.py +++ b/src/providers/chatgpt/registration.py @@ -14,12 +14,17 @@ from typing import Callable from urllib.parse import parse_qs, urlencode, urlparse import aiohttp -from playwright.async_api import async_playwright, Page, BrowserContext +from playwright.async_api import ( + async_playwright, + Error as PlaywrightError, + Page, + BrowserContext, +) from browser import launch as launch_browser from email_providers import BaseProvider from providers.base import ProviderTokens -from .tokens import CLIENT_ID, save_tokens +from .tokens import CLIENT_ID logger = logging.getLogger(__name__) @@ -46,9 +51,9 @@ async def save_error_screenshot(page: Page | None, step: str): filename = screenshots_dir / f"error_{step}_{timestamp}.png" try: await page.screenshot(path=str(filename)) - logger.error(f"Screenshot saved: {filename}") - except: - pass + logger.error("Screenshot saved: %s", filename) + except PlaywrightError as e: + logger.warning("Failed to save screenshot at step %s: %s", step, e) def generate_password(length: int = 20) -> str: @@ -204,8 +209,7 @@ def generate_state() -> str: return secrets.token_urlsafe(32) -def build_authorize_url(verifier: str, challenge: str, state: str) -> str: - del verifier +def build_authorize_url(challenge: str, state: str) -> str: params = { "response_type": "code", "client_id": CLIENT_ID, @@ -222,26 +226,33 @@ def build_authorize_url(verifier: str, challenge: str, state: str) -> str: async def exchange_code_for_tokens(code: str, verifier: str) -> ProviderTokens: - async with aiohttp.ClientSession() as session: - payload = { - "grant_type": "authorization_code", - "client_id": CLIENT_ID, - "code": code, - "code_verifier": verifier, - "redirect_uri": REDIRECT_URI, - } - async with session.post(TOKEN_URL, data=payload) as resp: - if not resp.ok: - text = await resp.text() - raise 
RuntimeError(f"Token exchange failed: {resp.status} {text}") - body = await resp.json() + payload = { + "grant_type": "authorization_code", + "client_id": CLIENT_ID, + "code": code, + "code_verifier": verifier, + "redirect_uri": REDIRECT_URI, + } + timeout = aiohttp.ClientTimeout(total=20) + try: + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post(TOKEN_URL, data=payload) as resp: + if not resp.ok: + text = await resp.text() + raise RuntimeError(f"Token exchange failed: {resp.status} {text}") + body = await resp.json() + except (aiohttp.ClientError, TimeoutError) as e: + raise RuntimeError(f"Token exchange request error: {e}") from e - expires_in = int(body["expires_in"]) - return ProviderTokens( - access_token=body["access_token"], - refresh_token=body["refresh_token"], - expires_at=time.time() + expires_in, - ) + try: + expires_in = int(body["expires_in"]) + return ProviderTokens( + access_token=body["access_token"], + refresh_token=body["refresh_token"], + expires_at=time.time() + expires_in, + ) + except (KeyError, TypeError, ValueError) as e: + raise RuntimeError(f"Token exchange response parse error: {e}") from e async def get_latest_code(email_provider: BaseProvider, email: str) -> str | None: @@ -270,6 +281,24 @@ async def click_continue(page: Page, timeout_ms: int = 10000): await btn.click() +async def click_any_visible_button( + page: Page, + labels: list[str], + timeout_ms: int = 2000, +) -> bool: + for label in labels: + button = page.get_by_role("button", name=label) + if await button.count() == 0: + continue + try: + await button.first.wait_for(state="visible", timeout=timeout_ms) + await button.first.click(timeout=timeout_ms) + return True + except PlaywrightError: + continue + return False + + async def wait_for_signup_stabilization( page: Page, source_url: str, @@ -288,12 +317,12 @@ async def wait_for_signup_stabilization( async def register_chatgpt_account( email_provider_factory: Callable[[BrowserContext], 
BaseProvider] | None = None, -) -> bool: +) -> ProviderTokens | None: logger.info("=== Starting ChatGPT account registration ===") if email_provider_factory is None: logger.error("No email provider factory configured") - return False + return None birth_month, birth_day, birth_year = generate_birthdate_90s() @@ -321,7 +350,7 @@ async def register_chatgpt_account( full_name = generate_name() verifier, challenge = generate_pkce_pair() oauth_state = generate_state() - authorize_url = build_authorize_url(verifier, challenge, oauth_state) + authorize_url = build_authorize_url(challenge, oauth_state) logger.info("[2/5] Registering ChatGPT for %s", email) chatgpt_page = await context.new_page() @@ -352,19 +381,18 @@ async def register_chatgpt_account( raise AutomationError( "email_provider", "Email provider returned no verification message" ) - logger.info("[3/5] Verification code extracted: %s", code) + logger.info("[3/5] Verification code extracted") await chatgpt_page.bring_to_front() code_input = chatgpt_page.get_by_placeholder("Code") - if await code_input.count() > 0: - await code_input.fill(code) + await code_input.first.wait_for(state="visible", timeout=10000) + await code_input.first.fill(code) await click_continue(chatgpt_page) logger.info("[4/5] Setting profile...") name_input = chatgpt_page.get_by_placeholder("Full name") await name_input.first.wait_for(state="visible", timeout=20000) - if await name_input.count() > 0: - await name_input.fill(full_name) + await name_input.first.fill(full_name) await fill_date_field(chatgpt_page, birth_month, birth_day, birth_year) profile_url = chatgpt_page.url @@ -387,45 +415,42 @@ async def register_chatgpt_account( oauth_page.on("request", handle_request) await oauth_page.goto(authorize_url, wait_until="domcontentloaded") - await oauth_page.locator( - 'input[type="email"], input[name="email"]' - ).first.wait_for(state="visible", timeout=20000) - email_input = oauth_page.locator('input[type="email"], input[name="email"]') if 
await email_input.count() > 0: + await email_input.first.wait_for(state="visible", timeout=10000) await email_input.first.fill(email) - - continue_button = oauth_page.get_by_role("button", name="Continue") - if await continue_button.count() > 0: - await continue_button.first.click() - await oauth_page.locator('input[type="password"]').first.wait_for( - state="visible", timeout=20000 + await click_any_visible_button( + oauth_page, ["Continue"], timeout_ms=4000 ) password_input = oauth_page.locator('input[type="password"]') if await password_input.count() > 0: + await password_input.first.wait_for(state="visible", timeout=10000) await password_input.first.fill(password) - continue_button = oauth_page.get_by_role("button", name="Continue") - if await continue_button.count() > 0: - await continue_button.first.click() + await click_any_visible_button( + oauth_page, ["Continue"], timeout_ms=4000 + ) - for label in ["Continue", "Allow", "Authorize"]: - button = oauth_page.get_by_role("button", name=label) - if await button.count() > 0: - try: - await button.first.click(timeout=5000) - await oauth_page.wait_for_timeout(500) - except Exception: - pass + for _ in range(6): + if redirect_url_captured: + break + clicked = await click_any_visible_button( + oauth_page, + ["Continue", "Allow", "Authorize"], + timeout_ms=2000, + ) + if clicked: + await asyncio.sleep(0.4) + else: + await asyncio.sleep(0.4) if not redirect_url_captured: try: - await oauth_page.wait_for_timeout(4000) current_url = oauth_page.url if "localhost:1455" in current_url and "code=" in current_url: redirect_url_captured = current_url logger.info("Captured OAuth redirect from page URL") - except Exception: + except PlaywrightError: pass if not redirect_url_captured: @@ -446,20 +471,18 @@ async def register_chatgpt_account( raise AutomationError("oauth", "OAuth state mismatch", oauth_page) tokens = await exchange_code_for_tokens(auth_code, verifier) - save_tokens(tokens) - logger.info("OAuth tokens saved 
successfully") + logger.info("OAuth tokens fetched successfully") - return True + return tokens except AutomationError as e: logger.error(f"Error at step [{e.step}]: {e.message}") await save_error_screenshot(e.page, e.step) - return False + return None except Exception as e: logger.error(f"Unexpected error: {e}") await save_error_screenshot(current_page, "unexpected") - return False + return None finally: if managed: - await asyncio.sleep(2) await managed.close() diff --git a/src/providers/chatgpt/tokens.py b/src/providers/chatgpt/tokens.py index cc9a4f6..b4ed7fb 100644 --- a/src/providers/chatgpt/tokens.py +++ b/src/providers/chatgpt/tokens.py @@ -1,6 +1,7 @@ import json import logging import os +import tempfile import time from pathlib import Path from typing import Any @@ -35,7 +36,7 @@ def _dict_to_tokens(data: dict[str, Any] | None) -> ProviderTokens | None: refresh_token=data["refresh_token"], expires_at=data["expires_at"], ) - except KeyError, TypeError: + except (KeyError, TypeError): return None @@ -54,8 +55,20 @@ def _load_raw() -> dict[str, Any] | None: def _save_raw(data: dict[str, Any]) -> None: TOKENS_FILE.parent.mkdir(parents=True, exist_ok=True) - with open(TOKENS_FILE, "w") as f: - json.dump(data, f, indent=2) + fd, tmp_path = tempfile.mkstemp( + prefix=f"{TOKENS_FILE.name}.", + suffix=".tmp", + dir=str(TOKENS_FILE.parent), + ) + try: + with os.fdopen(fd, "w") as f: + json.dump(data, f, indent=2) + f.flush() + os.fsync(f.fileno()) + os.replace(tmp_path, TOKENS_FILE) + finally: + if os.path.exists(tmp_path): + os.unlink(tmp_path) def _normalize_state(data: dict[str, Any] | None) -> dict[str, Any]: @@ -68,7 +81,6 @@ def _normalize_state(data: dict[str, Any] | None) -> dict[str, Any]: "next_account": data.get("next_account"), } - # Backward compatibility with old flat schema return {"active": data, "next_account": None} @@ -79,9 +91,7 @@ def load_state() -> tuple[ProviderTokens | None, ProviderTokens | None]: return active, next_account -def 
save_state( - active: ProviderTokens | None, next_account: ProviderTokens | None -) -> None: +def save_state(active: ProviderTokens | None, next_account: ProviderTokens | None) -> None: payload = { "active": _tokens_to_dict(active) if active else None, "next_account": _tokens_to_dict(next_account) if next_account else None, @@ -104,13 +114,8 @@ def save_tokens(tokens: ProviderTokens): save_state(tokens, next_account) -def save_next_tokens(tokens: ProviderTokens): - active, _ = load_state() - save_state(active, tokens) - - def promote_next_tokens() -> bool: - active, next_account = load_state() + _, next_account = load_state() if not next_account: return False save_state(next_account, None) @@ -123,42 +128,34 @@ def clear_next_tokens(): async def refresh_tokens(refresh_token: str) -> ProviderTokens | None: - async with aiohttp.ClientSession() as session: - data = { - "grant_type": "refresh_token", - "refresh_token": refresh_token, - "client_id": CLIENT_ID, - } - async with session.post(TOKEN_URL, data=data) as resp: - if not resp.ok: - text = await resp.text() - logger.warning("Token refresh failed: %s %s", resp.status, text) - return None - json_resp = await resp.json() - expires_in = json_resp["expires_in"] - return ProviderTokens( - access_token=json_resp["access_token"], - refresh_token=json_resp["refresh_token"], - expires_at=time.time() + expires_in, - ) - - -async def get_valid_tokens() -> ProviderTokens | None: - tokens = load_tokens() - if not tokens: - logger.info("No tokens found") + data = { + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "client_id": CLIENT_ID, + } + timeout = aiohttp.ClientTimeout(total=15) + try: + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post(TOKEN_URL, data=data) as resp: + if not resp.ok: + text = await resp.text() + logger.warning("Token refresh failed: %s %s", resp.status, text) + return None + json_resp = await resp.json() + except (aiohttp.ClientError, TimeoutError) 
as e: + logger.warning("Token refresh request error: %s", e) + return None + except Exception as e: + logger.warning("Token refresh unexpected error: %s", e) return None - if tokens.is_expired: - logger.info("Token expired, refreshing...") - if not tokens.refresh_token: - logger.info("No refresh token available") - return None - new_tokens = await refresh_tokens(tokens.refresh_token) - if not new_tokens: - logger.warning("Failed to refresh token") - return None - save_tokens(new_tokens) - return new_tokens - - return tokens + try: + expires_in = int(json_resp["expires_in"]) + return ProviderTokens( + access_token=json_resp["access_token"], + refresh_token=json_resp["refresh_token"], + expires_at=time.time() + expires_in, + ) + except (KeyError, TypeError, ValueError) as e: + logger.warning("Token refresh response parse error: %s", e) + return None diff --git a/src/providers/chatgpt/usage.py b/src/providers/chatgpt/usage.py index 6236bd8..7c2edab 100644 --- a/src/providers/chatgpt/usage.py +++ b/src/providers/chatgpt/usage.py @@ -1,14 +1,15 @@ -import json -import socket -import urllib.error -import urllib.request +import logging from typing import Any +import aiohttp + +logger = logging.getLogger(__name__) + def clamp_percent(value: Any) -> int: try: num = float(value) - except TypeError, ValueError: + except (TypeError, ValueError): return 0 if num < 0: return 0 @@ -28,30 +29,36 @@ def _parse_window(window: dict[str, Any] | None) -> dict[str, int] | None: } -def get_usage_data(access_token: str, timeout_ms: int = 10000) -> dict[str, Any] | None: +async def get_usage_data( + access_token: str, + timeout_ms: int = 10000, +) -> dict[str, Any] | None: headers = { "Authorization": f"Bearer {access_token}", "User-Agent": "CodexProxy", "Accept": "application/json", } - req = urllib.request.Request( - "https://chatgpt.com/backend-api/wham/usage", - headers=headers, - method="GET", - ) + timeout = aiohttp.ClientTimeout(total=timeout_ms / 1000) + url = 
"https://chatgpt.com/backend-api/wham/usage" try: - with urllib.request.urlopen(req, timeout=timeout_ms / 1000) as res: - body = res.read().decode("utf-8", errors="replace") - except urllib.error.HTTPError: + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.get(url, headers=headers) as res: + if not res.ok: + body = await res.text() + logger.warning( + "Usage fetch failed: status=%s body=%s", + res.status, + body[:300], + ) + return None + data = await res.json() + except (aiohttp.ClientError, TimeoutError) as e: + logger.warning("Usage fetch request error: %s", e) return None - except urllib.error.URLError, socket.timeout: - return None - - try: - data = json.loads(body) - except json.JSONDecodeError: + except Exception as e: + logger.warning("Usage fetch unexpected error: %s", e) return None rate_limit = data.get("rate_limit") or {} @@ -76,10 +83,3 @@ def get_usage_data(access_token: str, timeout_ms: int = 10000) -> dict[str, Any] "limit_reached": bool(rate_limit.get("limit_reached")), "allowed": bool(rate_limit.get("allowed", True)), } - - -def get_usage_percent(access_token: str, timeout_ms: int = 10000) -> int: - data = get_usage_data(access_token, timeout_ms=timeout_ms) - if not data: - return -1 - return int(data["used_percent"]) diff --git a/src/server.py b/src/server.py index 64ac91d..7740c37 100644 --- a/src/server.py +++ b/src/server.py @@ -5,23 +5,44 @@ import os from aiohttp import web from providers.chatgpt import ChatGPTProvider - -PORT = int(os.environ.get("PORT", "8080")) -CHATGPT_PREPARE_THRESHOLD = int(os.environ.get("CHATGPT_PREPARE_THRESHOLD", "85")) -LIMIT_EXHAUSTED_PERCENT = 100 +from providers.chatgpt.provider import CHATGPT_SWITCH_THRESHOLD +from providers.base import Provider logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -# Registry of available providers -PROVIDERS = { + +def _parse_int_env(name: str, default: int, minimum: int, maximum: int) -> int: + raw = 
os.environ.get(name) + if raw is None: + return default + try: + value = int(raw) + except ValueError: + logger.warning("Invalid %s=%r, using default %s", name, raw, default) + return default + if value < minimum or value > maximum: + logger.warning( + "%s=%s out of range [%s,%s], using default %s", + name, + value, + minimum, + maximum, + default, + ) + return default + return value + + +PORT = _parse_int_env("PORT", 8080, 1, 65535) +CHATGPT_PREPARE_THRESHOLD = _parse_int_env("CHATGPT_PREPARE_THRESHOLD", 85, 0, 100) +LIMIT_EXHAUSTED_PERCENT = 100 + +PROVIDERS: dict[str, Provider] = { "chatgpt": ChatGPTProvider(), } -refresh_locks = {name: asyncio.Lock() for name in PROVIDERS.keys()} -background_refresh_tasks: dict[str, asyncio.Task | None] = { - name: None for name in PROVIDERS.keys() -} +background_tasks: dict[str, asyncio.Task | None] = {name: None for name in PROVIDERS} @web.middleware @@ -31,16 +52,22 @@ async def request_log_middleware(request: web.Request, handler): return response -def build_limit(usage_percent: int) -> dict[str, int | bool]: +def build_limit(usage_percent: int, prepare_threshold: int) -> dict[str, int | bool]: remaining = max(0, 100 - usage_percent) return { "used_percent": usage_percent, "remaining_percent": remaining, "exhausted": usage_percent >= LIMIT_EXHAUSTED_PERCENT, - "needs_refresh": usage_percent >= CHATGPT_PREPARE_THRESHOLD, + "needs_prepare": usage_percent >= prepare_threshold, } +def get_prepare_threshold(provider_name: str) -> int: + if provider_name == "chatgpt": + return CHATGPT_PREPARE_THRESHOLD + return 100 + + async def ensure_provider_token_ready(provider_name: str): provider = PROVIDERS.get(provider_name) if not provider: @@ -52,135 +79,92 @@ async def ensure_provider_token_ready(provider_name: str): logger.warning( "[%s] Startup token check failed, forcing recreation", provider_name ) - if isinstance(provider, ChatGPTProvider): - token = await provider.force_recreate_token() + token = await 
provider.force_recreate_token() if not token: logger.error("[%s] Could not prepare token at startup", provider_name) return - if isinstance(provider, ChatGPTProvider): - await provider.ensure_next_account() - usage_info = await provider.get_usage_info(token) - if "error" not in usage_info: - logger.info("[%s] Startup token is ready", provider_name) - return - - logger.warning( - "[%s] Startup token invalid for usage, forcing recreation", provider_name - ) - if isinstance(provider, ChatGPTProvider): + if "error" in usage_info: + logger.warning( + "[%s] Startup token invalid for usage, forcing recreation", provider_name + ) token = await provider.force_recreate_token() - if token: - logger.info("[%s] Startup token recreated successfully", provider_name) + if not token: + logger.error("[%s] Startup token recreation failed", provider_name) return - logger.error("[%s] Startup token recreation failed", provider_name) + await provider.startup_prepare() + logger.info("[%s] Startup token is ready", provider_name) -async def on_startup(app: web.Application): - del app - for provider_name in PROVIDERS.keys(): - await ensure_provider_token_ready(provider_name) - - -async def issue_new_token(provider_name: str) -> str | None: +async def ensure_standby_task(provider_name: str, usage_percent: int, reason: str): provider = PROVIDERS.get(provider_name) if not provider: - return None - - async with refresh_locks[provider_name]: - logger.info(f"[{provider_name}] Generating new token") - success = await provider.register_new_account() - if not success: - logger.error(f"[{provider_name}] Token generation failed") - return None - - token = await provider.get_token() - if not token: - logger.error(f"[{provider_name}] Token was generated but not available") - return None - - return token - - -async def background_refresh_worker(provider_name: str, reason: str): + return try: - logger.info(f"[{provider_name}] Starting background token refresh ({reason})") - new_token = await 
issue_new_token(provider_name) - if new_token: - logger.info(f"[{provider_name}] Background token refresh completed") - else: - logger.error(f"[{provider_name}] Background token refresh failed") + logger.info("[%s] Preparing standby in background (%s)", provider_name, reason) + threshold = get_prepare_threshold(provider_name) + await provider.ensure_standby_account(usage_percent, threshold) except Exception: - logger.exception( - f"[{provider_name}] Unhandled error in background token refresh" - ) + logger.exception("[%s] Unhandled standby preparation error", provider_name) -def trigger_background_refresh(provider_name: str, reason: str): - task = background_refresh_tasks.get(provider_name) +def trigger_standby_prepare(provider_name: str, usage_percent: int, reason: str): + task = background_tasks.get(provider_name) if task and not task.done(): logger.info( - f"[{provider_name}] Background refresh already running, skip ({reason})" + "[%s] Standby prep already running, skip (%s)", provider_name, reason ) return - background_refresh_tasks[provider_name] = asyncio.create_task( - background_refresh_worker(provider_name, reason) + background_tasks[provider_name] = asyncio.create_task( + ensure_standby_task(provider_name, usage_percent, reason) ) async def token_handler(request: web.Request) -> web.Response: provider_name = request.match_info.get("provider", "chatgpt") - provider = PROVIDERS.get(provider_name) if not provider: return web.json_response( - {"error": f"Unknown provider: {provider_name}"}, - status=404, + {"error": f"Unknown provider: {provider_name}"}, status=404 ) - # Get or create token token = await provider.get_token() if not token: - return web.json_response( - {"error": "Failed to get active token"}, - status=503, - ) + return web.json_response({"error": "Failed to get active token"}, status=503) - # Get usage info usage_info = await provider.get_usage_info(token) if "error" in usage_info: - return web.json_response( - {"error": usage_info["error"]}, 
- status=503, + return web.json_response({"error": usage_info["error"]}, status=503) + + usage_percent = int(usage_info.get("used_percent", 0)) + switched = await provider.maybe_rotate_account(usage_percent) + if switched: + token = await provider.get_token() + if not token: + return web.json_response( + {"error": "Failed to get active token after account switch"}, + status=503, + ) + usage_info = await provider.get_usage_info(token) + if "error" in usage_info: + return web.json_response({"error": usage_info["error"]}, status=503) + usage_percent = int(usage_info.get("used_percent", 0)) + logger.info("[%s] Active account switched before response", provider_name) + + prepare_threshold = get_prepare_threshold(provider_name) + if usage_percent >= prepare_threshold: + trigger_standby_prepare( + provider_name, + usage_percent, + f"usage {usage_percent}% >= threshold {prepare_threshold}%", ) - usage_percent = usage_info.get("used_percent", 0) - remaining_percent = usage_info.get("remaining_percent", max(0, 100 - usage_percent)) - - if isinstance(provider, ChatGPTProvider): - switched = await provider.maybe_switch_active_account(usage_percent) - if switched: - token = await provider.get_token() - if not token: - return web.json_response( - {"error": "Failed to get active token after account switch"}, - status=503, - ) - usage_info = await provider.get_usage_info(token) - if "error" in usage_info: - return web.json_response( - {"error": usage_info["error"]}, - status=503, - ) - usage_percent = usage_info.get("used_percent", 0) - remaining_percent = usage_info.get( - "remaining_percent", max(0, 100 - usage_percent) - ) - logger.info("[%s] Active account switched before response", provider_name) - + remaining_percent = int( + usage_info.get("remaining_percent", max(0, 100 - usage_percent)) + ) logger.info( "[%s] token issued, used=%s%% remaining=%s%%", provider_name, @@ -207,20 +191,10 @@ async def token_handler(request: web.Request) -> web.Response: 
secondary_window.get("reset_after_seconds", 0), ) - # Trigger background refresh if needed - if usage_percent >= CHATGPT_PREPARE_THRESHOLD: - if isinstance(provider, ChatGPTProvider): - await provider.ensure_next_account() - else: - trigger_background_refresh( - provider_name, - f"usage {usage_percent}% >= threshold {CHATGPT_PREPARE_THRESHOLD}%", - ) - return web.json_response( { "token": token, - "limit": build_limit(usage_percent), + "limit": build_limit(usage_percent, prepare_threshold), "usage": { "primary_window": primary_window, "secondary_window": secondary_window, @@ -229,22 +203,35 @@ async def token_handler(request: web.Request) -> web.Response: ) +async def on_startup(app: web.Application): + del app + for provider_name in PROVIDERS: + await ensure_provider_token_ready(provider_name) + + +async def on_cleanup(app: web.Application): + del app + for task in background_tasks.values(): + if task and not task.done(): + task.cancel() + pending = [t for t in background_tasks.values() if t is not None] + if pending: + await asyncio.gather(*pending, return_exceptions=True) + + def create_app() -> web.Application: app = web.Application(middlewares=[request_log_middleware]) app.on_startup.append(on_startup) - # New route: /{provider}/token + app.on_cleanup.append(on_cleanup) app.router.add_get("/{provider}/token", token_handler) - # Legacy route for backward compatibility app.router.add_get("/token", token_handler) return app if __name__ == "__main__": logger.info("Starting token service on port %s", PORT) - logger.info( - "ChatGPT prepare-next threshold: %s%%", - CHATGPT_PREPARE_THRESHOLD, - ) + logger.info("ChatGPT prepare-next threshold: %s%%", CHATGPT_PREPARE_THRESHOLD) + logger.info("ChatGPT switch threshold: %s%%", CHATGPT_SWITCH_THRESHOLD) logger.info("Available providers: %s", ", ".join(PROVIDERS.keys())) app = create_app() web.run_app(app, host="0.0.0.0", port=PORT) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 
0000000..9bf27d4 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,12 @@ +import sys +from pathlib import Path + + +def _add_src_to_path() -> None: + root = Path(__file__).resolve().parents[1] + src = root / "src" + if str(src) not in sys.path: + sys.path.insert(0, str(src)) + + +_add_src_to_path() diff --git a/tests/test_registration_unit.py b/tests/test_registration_unit.py new file mode 100644 index 0000000..32fd941 --- /dev/null +++ b/tests/test_registration_unit.py @@ -0,0 +1,37 @@ +from providers.chatgpt.registration import ( + build_authorize_url, + extract_verification_code, + generate_birthdate_90s, + generate_name, +) + + +def test_generate_name_shape(): + name = generate_name() + parts = name.split(" ") + assert len(parts) == 2 + assert all(p.isalpha() for p in parts) + + +def test_generate_birthdate_90s_range(): + month, day, year = generate_birthdate_90s() + assert 1 <= int(month) <= 12 + assert 1 <= int(day) <= 28 + assert 1990 <= int(year) <= 1999 + + +def test_extract_verification_code_prefers_chatgpt_phrase(): + text = "foo 123456 bar Your ChatGPT code is 654321" + assert extract_verification_code(text) == "654321" + + +def test_extract_verification_code_fallback_last_code(): + text = "codes 111111 and 222222" + assert extract_verification_code(text) == "222222" + + +def test_build_authorize_url_contains_required_params(): + url = build_authorize_url("challenge", "state123") + assert "response_type=code" in url + assert "code_challenge=challenge" in url + assert "state=state123" in url diff --git a/tests/test_server_unit.py b/tests/test_server_unit.py new file mode 100644 index 0000000..4b74a3c --- /dev/null +++ b/tests/test_server_unit.py @@ -0,0 +1,150 @@ +import asyncio +import json + +import server +from providers.base import Provider, ProviderTokens + + +class FakeRequest: + def __init__(self, provider: str): + self.match_info = {"provider": provider} + + +class FakeProvider(Provider): + def __init__( + self, + token: str | None = "tok", + 
usage: dict | None = None, + rotate: bool = False, + ): + self._token = token + self._usage = usage or { + "used_percent": 10, + "remaining_percent": 90, + "primary_window": None, + "secondary_window": None, + } + self._rotate = rotate + self.get_token_calls = 0 + self.standby_calls = 0 + + @property + def name(self) -> str: + return "fake" + + async def get_token(self) -> str | None: + self.get_token_calls += 1 + return self._token + + async def register_new_account(self) -> bool: + return True + + async def get_usage_info(self, access_token: str) -> dict: + _ = access_token + return dict(self._usage) + + def load_tokens(self) -> ProviderTokens | None: + return None + + def save_tokens(self, tokens: ProviderTokens) -> None: + _ = tokens + + async def maybe_rotate_account(self, usage_percent: int) -> bool: + _ = usage_percent + return self._rotate + + async def ensure_standby_account( + self, usage_percent: int, prepare_threshold: int + ) -> None: + _ = usage_percent, prepare_threshold + self.standby_calls += 1 + + +def _response_json(resp) -> dict: + return json.loads(resp.body.decode("utf-8")) + + +def test_parse_int_env_defaults(monkeypatch): + monkeypatch.delenv("X_TEST", raising=False) + assert server._parse_int_env("X_TEST", 10, 1, 20) == 10 + + +def test_parse_int_env_invalid(monkeypatch): + monkeypatch.setenv("X_TEST", "abc") + assert server._parse_int_env("X_TEST", 10, 1, 20) == 10 + + +def test_parse_int_env_out_of_range(monkeypatch): + monkeypatch.setenv("X_TEST", "999") + assert server._parse_int_env("X_TEST", 10, 1, 20) == 10 + + +def test_build_limit_fields(): + limit = server.build_limit(90, 85) + assert limit == { + "used_percent": 90, + "remaining_percent": 10, + "exhausted": False, + "needs_prepare": True, + } + + +def test_get_prepare_threshold(): + assert server.get_prepare_threshold("chatgpt") == server.CHATGPT_PREPARE_THRESHOLD + assert server.get_prepare_threshold("unknown") == 100 + + +def test_token_handler_unknown_provider(monkeypatch): + 
monkeypatch.setattr(server, "PROVIDERS", {}) + resp = asyncio.run(server.token_handler(FakeRequest("missing"))) + assert resp.status == 404 + + +def test_token_handler_success(monkeypatch): + provider = FakeProvider() + monkeypatch.setattr(server, "PROVIDERS", {"fake": provider}) + monkeypatch.setattr(server, "background_tasks", {"fake": None}) + monkeypatch.setattr(server, "get_prepare_threshold", lambda _: 80) + + resp = asyncio.run(server.token_handler(FakeRequest("fake"))) + data = _response_json(resp) + + assert resp.status == 200 + assert data["token"] == "tok" + assert data["limit"]["needs_prepare"] is False + + +def test_token_handler_triggers_standby(monkeypatch): + provider = FakeProvider(usage={"used_percent": 90, "remaining_percent": 10}) + monkeypatch.setattr(server, "PROVIDERS", {"fake": provider}) + monkeypatch.setattr(server, "background_tasks", {"fake": None}) + + called = {"value": False} + + def fake_trigger(name, usage_percent, reason): + assert name == "fake" + assert usage_percent == 90 + assert "threshold" in reason + called["value"] = True + + monkeypatch.setattr(server, "get_prepare_threshold", lambda _: 80) + monkeypatch.setattr(server, "trigger_standby_prepare", fake_trigger) + + resp = asyncio.run(server.token_handler(FakeRequest("fake"))) + assert resp.status == 200 + assert called["value"] is True + + +def test_token_handler_rotation_path(monkeypatch): + provider = FakeProvider( + usage={"used_percent": 96, "remaining_percent": 4}, + rotate=True, + ) + monkeypatch.setattr(server, "PROVIDERS", {"fake": provider}) + monkeypatch.setattr(server, "background_tasks", {"fake": None}) + monkeypatch.setattr(server, "get_prepare_threshold", lambda _: 80) + monkeypatch.setattr(server, "trigger_standby_prepare", lambda *_: None) + + resp = asyncio.run(server.token_handler(FakeRequest("fake"))) + assert resp.status == 200 + assert provider.get_token_calls >= 2 diff --git a/tests/test_tokens_unit.py b/tests/test_tokens_unit.py new file mode 100644 
index 0000000..58f5af5 --- /dev/null +++ b/tests/test_tokens_unit.py @@ -0,0 +1,60 @@ +import json +from pathlib import Path + +from providers.base import ProviderTokens +from providers.chatgpt import tokens as t + + +def test_normalize_state_backward_compatible(): + raw = {"access_token": "a", "refresh_token": "r", "expires_at": 1} + normalized = t._normalize_state(raw) + assert normalized["active"]["access_token"] == "a" + assert normalized["next_account"] is None + + +def test_promote_next_tokens(tmp_path, monkeypatch): + file_path = tmp_path / "chatgpt_tokens.json" + monkeypatch.setattr(t, "TOKENS_FILE", file_path) + + active = ProviderTokens("a1", "r1", 100) + nxt = ProviderTokens("a2", "r2", 200) + t.save_state(active, nxt) + + assert t.promote_next_tokens() is True + cur, next_cur = t.load_state() + assert cur is not None + assert cur.access_token == "a2" + assert next_cur is None + + +def test_save_tokens_preserves_next(tmp_path, monkeypatch): + file_path = tmp_path / "chatgpt_tokens.json" + monkeypatch.setattr(t, "TOKENS_FILE", file_path) + + active = ProviderTokens("a1", "r1", 100) + nxt = ProviderTokens("a2", "r2", 200) + t.save_state(active, nxt) + + t.save_tokens(ProviderTokens("a3", "r3", 300)) + cur, next_cur = t.load_state() + assert cur is not None and cur.access_token == "a3" + assert next_cur is not None and next_cur.access_token == "a2" + + +def test_atomic_write_produces_valid_json(tmp_path, monkeypatch): + file_path = tmp_path / "chatgpt_tokens.json" + monkeypatch.setattr(t, "TOKENS_FILE", file_path) + + t.save_state(ProviderTokens("x", "y", 123), None) + with open(file_path) as f: + data = json.load(f) + assert "active" in data + assert data["active"]["access_token"] == "x" + + +def test_load_state_from_missing_file(tmp_path, monkeypatch): + file_path = tmp_path / "missing.json" + monkeypatch.setattr(t, "TOKENS_FILE", file_path) + active, nxt = t.load_state() + assert active is None + assert nxt is None diff --git a/tests/test_usage_unit.py 
b/tests/test_usage_unit.py new file mode 100644 index 0000000..63cad99 --- /dev/null +++ b/tests/test_usage_unit.py @@ -0,0 +1,32 @@ +from providers.chatgpt.usage import _parse_window, clamp_percent + + +def test_clamp_percent_bounds(): + assert clamp_percent(-1) == 0 + assert clamp_percent(150) == 100 + assert clamp_percent(49.6) == 50 + + +def test_clamp_percent_invalid(): + assert clamp_percent(None) == 0 + assert clamp_percent("bad") == 0 + + +def test_parse_window_valid(): + window = { + "used_percent": 34.4, + "limit_window_seconds": 3600, + "reset_after_seconds": 120, + "reset_at": 999, + } + parsed = _parse_window(window) + assert parsed == { + "used_percent": 34, + "limit_window_seconds": 3600, + "reset_after_seconds": 120, + "reset_at": 999, + } + + +def test_parse_window_none(): + assert _parse_window(None) is None From 307ca38ecc60904ca7838961c840c6215fdc0786 Mon Sep 17 00:00:00 2001 From: "Arthur K." Date: Mon, 2 Mar 2026 19:10:50 +0300 Subject: [PATCH 2/6] fix: revert old oauth behavior and new default email provider --- README.md | 5 + scripts/run_token_refresh_flow.py | 55 ++++++ src/email_providers/__init__.py | 8 +- src/email_providers/mail_tm.py | 232 ++++++++++++++++++++++++++ src/providers/chatgpt/provider.py | 4 +- src/providers/chatgpt/registration.py | 90 +++++++--- uv.lock | 68 ++++++++ 7 files changed, 439 insertions(+), 23 deletions(-) create mode 100644 scripts/run_token_refresh_flow.py create mode 100644 src/email_providers/mail_tm.py diff --git a/README.md b/README.md index a061094..3e37601 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,11 @@ Behavior: 4. When usage reaches `CHATGPT_PREPARE_THRESHOLD`, service prepares `next_account`. 5. When usage reaches `CHATGPT_SWITCH_THRESHOLD`, service switches active account to `next_account`. +## Disposable Email Provider + +- Default provider is `mail.tm` API (`MailTmProvider`) and does not use browser automation. 
+- Flow: fetch domains -> create account with random address/password -> get JWT token -> poll messages. + ## Startup Behavior On startup, service ensures active token exists and is usable. diff --git a/scripts/run_token_refresh_flow.py b/scripts/run_token_refresh_flow.py new file mode 100644 index 0000000..577763b --- /dev/null +++ b/scripts/run_token_refresh_flow.py @@ -0,0 +1,55 @@ +import argparse +import asyncio +import json +import logging + +import browser +import server + + +class _FakeRequest: + def __init__(self, provider: str): + self.match_info = {"provider": provider} + self.method = "GET" + self.path_qs = f"/{provider}/token" + + +def _enable_headed_browser() -> bool: + if "--no-startup-window" in browser.CHROME_FLAGS: + browser.CHROME_FLAGS.remove("--no-startup-window") + return True + return False + + +async def _run(provider: str) -> int: + patched = _enable_headed_browser() + logging.info("Headed mode patch applied: %s", patched) + + request = _FakeRequest(provider) + response = await server.token_handler(request) + payload = json.loads(response.body.decode("utf-8")) + + logging.info("Response status: %s", response.status) + logging.info("Response body: %s", json.dumps(payload, indent=2)) + return 0 if response.status == 200 else 1 + + +def main() -> int: + parser = argparse.ArgumentParser( + description=( + "Run the same token refresh/issue flow as server /{provider}/token " + "in headed browser mode (non-headless)." 
+ ) + ) + parser.add_argument("--provider", default="chatgpt") + args = parser.parse_args() + + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(name)s: %(message)s", + ) + return asyncio.run(_run(args.provider)) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/email_providers/__init__.py b/src/email_providers/__init__.py index 76e66ca..a1af08e 100644 --- a/src/email_providers/__init__.py +++ b/src/email_providers/__init__.py @@ -1,5 +1,11 @@ from .base import BaseProvider +from .mail_tm import MailTmProvider from .ten_minute_mail import TenMinuteMailProvider from .temp_mail_org import TempMailOrgProvider -__all__ = ["BaseProvider", "TenMinuteMailProvider", "TempMailOrgProvider"] +__all__ = [ + "BaseProvider", + "MailTmProvider", + "TenMinuteMailProvider", + "TempMailOrgProvider", +] diff --git a/src/email_providers/mail_tm.py b/src/email_providers/mail_tm.py new file mode 100644 index 0000000..891a789 --- /dev/null +++ b/src/email_providers/mail_tm.py @@ -0,0 +1,232 @@ +import asyncio +import logging +import os +import secrets +import string +from typing import Any + +import aiohttp +from playwright.async_api import BrowserContext + +from .base import BaseProvider + +logger = logging.getLogger(__name__) + +_API_BASE = os.environ.get("MAIL_TM_API_BASE", "https://api.mail.tm") +_TIMEOUT_SECONDS = 20 +_FIRST_NAMES = [ + "james", + "john", + "robert", + "michael", + "david", + "william", + "joseph", + "thomas", + "daniel", + "mark", + "paul", + "kevin", +] +_LAST_NAMES = [ + "smith", + "johnson", + "williams", + "brown", + "jones", + "miller", + "davis", + "wilson", + "anderson", + "taylor", + "martin", + "thompson", +] + + +def _generate_local_part() -> str: + first = secrets.choice(_FIRST_NAMES) + last = secrets.choice(_LAST_NAMES) + digits = "".join(secrets.choice(string.digits) for _ in range(8)) + return f"{first}{last}{digits}" + + +def _generate_password(length: int = 24) -> str: + alphabet = 
string.ascii_letters + string.digits + return "".join(secrets.choice(alphabet) for _ in range(length)) + + +class MailTmProvider(BaseProvider): + def __init__(self, browser_session: BrowserContext): + super().__init__(browser_session) + self._address: str | None = None + self._password: str | None = None + self._token: str | None = None + + async def _request( + self, + method: str, + path: str, + *, + token: str | None = None, + json_body: dict[str, Any] | None = None, + ) -> tuple[int, dict[str, Any] | list[Any] | None]: + url = f"{_API_BASE.rstrip('/')}{path}" + headers: dict[str, str] = {} + if token: + headers["Authorization"] = f"Bearer {token}" + + timeout = aiohttp.ClientTimeout(total=_TIMEOUT_SECONDS) + try: + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.request( + method, + url, + headers=headers, + json=json_body, + ) as resp: + status = resp.status + try: + payload = await resp.json() + except aiohttp.ContentTypeError: + payload = None + return status, payload + except aiohttp.ClientError as e: + logger.warning("[mail.tm] request failed %s %s: %s", method, path, e) + return 0, None + + async def _get_domains(self) -> list[str]: + status, payload = await self._request("GET", "/domains") + if status != 200 or not isinstance(payload, dict): + raise RuntimeError("mail.tm domains request failed") + + members = payload.get("hydra:member") + if not isinstance(members, list): + raise RuntimeError("mail.tm domains response has unexpected format") + + domains: list[str] = [] + for item in members: + if not isinstance(item, dict): + continue + domain = item.get("domain") + is_active = bool(item.get("isActive", True)) + if isinstance(domain, str) and domain and is_active: + domains.append(domain) + + if not domains: + raise RuntimeError("mail.tm returned no active domains") + return domains + + async def _create_account(self, address: str, password: str) -> bool: + status, _ = await self._request( + "POST", + "/accounts", + 
json_body={"address": address, "password": password}, + ) + if status in (200, 201): + return True + return False + + async def _create_token(self, address: str, password: str) -> str | None: + status, payload = await self._request( + "POST", + "/token", + json_body={"address": address, "password": password}, + ) + if status != 200 or not isinstance(payload, dict): + return None + token = payload.get("token") + if isinstance(token, str) and token: + return token + return None + + async def get_new_email(self) -> str: + domains = await self._get_domains() + + for _ in range(8): + domain = secrets.choice(domains) + address = f"{_generate_local_part()}@{domain}" + password = _generate_password() + + created = await self._create_account(address, password) + if not created: + continue + + token = await self._create_token(address, password) + if not token: + continue + + self._address = address + self._password = password + self._token = token + logger.info("[mail.tm] New mailbox acquired: %s", address) + return address + + raise RuntimeError("mail.tm could not create account") + + async def _list_messages(self) -> list[dict[str, Any]]: + if not self._token: + return [] + status, payload = await self._request( + "GET", + "/messages", + token=self._token, + ) + if status == 401 and self._address and self._password: + token = await self._create_token(self._address, self._password) + if token: + self._token = token + status, payload = await self._request( + "GET", + "/messages", + token=self._token, + ) + + if status != 200 or not isinstance(payload, dict): + return [] + + members = payload.get("hydra:member") + if not isinstance(members, list): + return [] + return [item for item in members if isinstance(item, dict)] + + async def _get_message_text(self, message_id: str) -> str | None: + if not self._token: + return None + status, payload = await self._request( + "GET", + f"/messages/{message_id}", + token=self._token, + ) + if status != 200 or not isinstance(payload, 
dict): + return None + + parts = [ + payload.get("subject"), + payload.get("intro"), + payload.get("text"), + payload.get("html"), + ] + text = "\n".join(str(part) for part in parts if part) + return text or None + + async def get_latest_message(self, email: str) -> str | None: + del email + if not self._token: + raise RuntimeError("mail.tm provider is not initialized with mailbox token") + + for _ in range(45): + messages = await self._list_messages() + if messages: + latest = messages[0] + message_id = latest.get("id") + if isinstance(message_id, str) and message_id: + full_message = await self._get_message_text(message_id) + if full_message: + logger.info("[mail.tm] Latest message received") + return full_message + + await asyncio.sleep(2) + + logger.warning("[mail.tm] No messages received within timeout") + return None diff --git a/src/providers/chatgpt/provider.py b/src/providers/chatgpt/provider.py index ae2e527..445500f 100644 --- a/src/providers/chatgpt/provider.py +++ b/src/providers/chatgpt/provider.py @@ -7,7 +7,7 @@ from typing import Callable from playwright.async_api import BrowserContext from email_providers import BaseProvider -from email_providers import TempMailOrgProvider +from email_providers import MailTmProvider from providers.base import Provider, ProviderTokens from .tokens import ( clear_next_tokens, @@ -34,7 +34,7 @@ class ChatGPTProvider(Provider): self, email_provider_factory: Callable[[BrowserContext], BaseProvider] | None = None, ): - self.email_provider_factory = email_provider_factory or TempMailOrgProvider + self.email_provider_factory = email_provider_factory or MailTmProvider self._token_write_lock = asyncio.Lock() async def _register_with_retries(self) -> bool: diff --git a/src/providers/chatgpt/registration.py b/src/providers/chatgpt/registration.py index 3d106c1..270c6d3 100644 --- a/src/providers/chatgpt/registration.py +++ b/src/providers/chatgpt/registration.py @@ -281,7 +281,34 @@ async def click_continue(page: Page, 
timeout_ms: int = 10000): await btn.click() -async def click_any_visible_button( +async def oauth_needs_email_check(page: Page) -> bool: + marker = page.get_by_text("Check your inbox", exact=False) + return await marker.count() > 0 + + +async def fill_oauth_code_if_present(page: Page, code: str) -> bool: + candidates = [ + page.get_by_placeholder("Code"), + page.get_by_label("Code"), + page.locator( + 'input[name*="code" i], input[id*="code" i], ' + 'input[autocomplete="one-time-code"], input[inputmode="numeric"]' + ), + ] + + for locator in candidates: + if await locator.count() == 0: + continue + try: + await locator.first.wait_for(state="visible", timeout=1500) + await locator.first.fill(code) + return True + except PlaywrightError: + continue + return False + + +async def click_first_visible_button( page: Page, labels: list[str], timeout_ms: int = 2000, @@ -415,44 +442,67 @@ async def register_chatgpt_account( oauth_page.on("request", handle_request) await oauth_page.goto(authorize_url, wait_until="domcontentloaded") + await oauth_page.locator( + 'input[type="email"], input[name="email"]' + ).first.wait_for(state="visible", timeout=20000) + email_input = oauth_page.locator('input[type="email"], input[name="email"]') if await email_input.count() > 0: - await email_input.first.wait_for(state="visible", timeout=10000) await email_input.first.fill(email) - await click_any_visible_button( - oauth_page, ["Continue"], timeout_ms=4000 + + continue_button = oauth_page.get_by_role("button", name="Continue") + if await continue_button.count() > 0: + await continue_button.first.click() + await oauth_page.locator('input[type="password"]').first.wait_for( + state="visible", timeout=20000 ) password_input = oauth_page.locator('input[type="password"]') if await password_input.count() > 0: - await password_input.first.wait_for(state="visible", timeout=10000) await password_input.first.fill(password) - await click_any_visible_button( - oauth_page, ["Continue"], timeout_ms=4000 - 
) + continue_button = oauth_page.get_by_role("button", name="Continue") + if await continue_button.count() > 0: + await continue_button.first.click() - for _ in range(6): + last_oauth_email_code = code + oauth_deadline = asyncio.get_running_loop().time() + 60 + while asyncio.get_running_loop().time() < oauth_deadline: if redirect_url_captured: break - clicked = await click_any_visible_button( - oauth_page, - ["Continue", "Allow", "Authorize"], - timeout_ms=2000, - ) - if clicked: - await asyncio.sleep(0.4) - else: - await asyncio.sleep(0.4) - if not redirect_url_captured: + if await oauth_needs_email_check(oauth_page): + logger.info("OAuth requested email confirmation code") + new_code = await get_latest_code(email_provider, email) + if new_code and new_code != last_oauth_email_code: + filled = await fill_oauth_code_if_present(oauth_page, new_code) + if filled: + last_oauth_email_code = new_code + logger.info("Filled OAuth email confirmation code") + else: + logger.warning( + "OAuth inbox challenge detected but code field not found" + ) + try: current_url = oauth_page.url if "localhost:1455" in current_url and "code=" in current_url: redirect_url_captured = current_url logger.info("Captured OAuth redirect from page URL") - except PlaywrightError: + break + except Exception: pass + clicked = await click_first_visible_button( + oauth_page, + ["Continue", "Allow", "Authorize", "Verify"], + timeout_ms=2000, + ) + + if clicked: + await oauth_page.wait_for_timeout(500) + else: + await oauth_page.wait_for_timeout(1000) + if not redirect_url_captured: raise AutomationError( "oauth", "OAuth redirect with code was not captured", oauth_page diff --git a/uv.lock b/uv.lock index bb427f9..17acd39 100644 --- a/uv.lock +++ b/uv.lock @@ -83,6 +83,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3a/2a/7cc015f5b9f5db42b7d48157e23356022889fc354a2813c15934b7cb5c0e/attrs-25.4.0-py3-none-any.whl", hash = 
"sha256:adcf7e2a1fb3b36ac48d97835bb6d8ade15b8dcce26aba8bf1d14847b57a3373", size = 67615, upload-time = "2025-10-06T13:54:43.17Z" }, ] +[[package]] +name = "colorama" +version = "0.4.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, +] + [[package]] name = "frozenlist" version = "1.8.0" @@ -158,6 +167,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008, upload-time = "2025-10-12T14:55:18.883Z" }, ] +[[package]] +name = "iniconfig" +version = "2.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, +] + [[package]] name = "megapt" version = "0.1.0" @@ -168,12 +186,19 @@ dependencies = [ { name 
= "playwright" }, ] +[package.optional-dependencies] +dev = [ + { name = "pytest" }, +] + [package.metadata] requires-dist = [ { name = "aiohttp", specifier = "==3.13.3" }, { name = "pkce", specifier = "==1.0.3" }, { name = "playwright", specifier = "==1.58.0" }, + { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" }, ] +provides-extras = ["dev"] [[package]] name = "multidict" @@ -220,6 +245,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/81/08/7036c080d7117f28a4af526d794aab6a84463126db031b007717c1a6676e/multidict-6.7.1-py3-none-any.whl", hash = "sha256:55d97cc6dae627efa6a6e548885712d4864b81110ac76fa4e534c03819fa4a56", size = 12319, upload-time = "2026-01-26T02:46:44.004Z" }, ] +[[package]] +name = "packaging" +version = "26.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/65/ee/299d360cdc32edc7d2cf530f3accf79c4fca01e96ffc950d8a52213bd8e4/packaging-26.0.tar.gz", hash = "sha256:00243ae351a257117b6a241061796684b084ed1c516a08c48a3f7e147a9d80b4", size = 143416, upload-time = "2026-01-21T20:50:39.064Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/b9/c538f279a4e237a006a2c98387d081e9eb060d203d8ed34467cc0f0b9b53/packaging-26.0-py3-none-any.whl", hash = "sha256:b36f1fef9334a5588b4166f8bcd26a14e521f2b55e6b9de3aaa80d3ff7a37529", size = 74366, upload-time = "2026-01-21T20:50:37.788Z" }, +] + [[package]] name = "pkce" version = "1.0.3" @@ -248,6 +282,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c8/c4/cc0229fea55c87d6c9c67fe44a21e2cd28d1d558a5478ed4d617e9fb0c93/playwright-1.58.0-py3-none-win_arm64.whl", hash = "sha256:32ffe5c303901a13a0ecab91d1c3f74baf73b84f4bedbb6b935f5bc11cc98e1b", size = 33085919, upload-time = "2026-01-30T15:09:45.71Z" }, ] +[[package]] +name = "pluggy" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, +] + [[package]] name = "propcache" version = "0.4.1" @@ -299,6 +342,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a0/c4/b4d4827c93ef43c01f599ef31453ccc1c132b353284fc6c87d535c233129/pyee-13.0.1-py3-none-any.whl", hash = "sha256:af2f8fede4171ef667dfded53f96e2ed0d6e6bd7ee3bb46437f77e3b57689228", size = 15659, upload-time = "2026-02-14T21:12:26.263Z" }, ] +[[package]] +name = "pygments" +version = "2.19.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, +] + +[[package]] +name = "pytest" +version = "9.0.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "pygments" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/d1/db/7ef3487e0fb0049ddb5ce41d3a49c235bf9ad299b6a25d5780a89f19230f/pytest-9.0.2.tar.gz", hash = "sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11", size = 1568901, upload-time = "2025-12-06T21:30:51.014Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" }, +] + [[package]] name = "typing-extensions" version = "4.15.0" From 858d127246929446e9f9a8eb8df377e5af518675 Mon Sep 17 00:00:00 2001 From: "Arthur K." Date: Mon, 2 Mar 2026 19:33:43 +0300 Subject: [PATCH 3/6] refactor: some minor cleanup --- src/email_providers/base.py | 2 +- src/email_providers/mail_tm.py | 11 +++------ src/email_providers/temp_mail_org.py | 8 +++---- src/email_providers/ten_minute_mail.py | 8 +++---- src/email_providers/utils.py | 10 ++++++++ src/providers/chatgpt/provider.py | 10 ++++++-- src/providers/chatgpt/registration.py | 15 ++++-------- src/server.py | 33 +++++--------------------- src/utils/__init__.py | 4 ++++ src/utils/env.py | 22 +++++++++++++++++ src/utils/randoms.py | 13 ++++++++++ tests/test_server_unit.py | 7 +++--- 12 files changed, 84 insertions(+), 59 deletions(-) create mode 100644 src/email_providers/utils.py create mode 100644 src/utils/__init__.py create mode 100644 src/utils/env.py create mode 100644 src/utils/randoms.py diff --git a/src/email_providers/base.py b/src/email_providers/base.py index 386601d..878c7e6 100644 --- a/src/email_providers/base.py +++ b/src/email_providers/base.py @@ -12,5 +12,5 @@ class BaseProvider(ABC): pass @abstractmethod - async def get_latest_message(self, email: str) -> str | None: + async def get_latest_message(self) -> str | None: pass diff --git a/src/email_providers/mail_tm.py b/src/email_providers/mail_tm.py index 
891a789..4dd9c6c 100644 --- a/src/email_providers/mail_tm.py +++ b/src/email_providers/mail_tm.py @@ -9,6 +9,7 @@ import aiohttp from playwright.async_api import BrowserContext from .base import BaseProvider +from utils.randoms import generate_password logger = logging.getLogger(__name__) @@ -51,11 +52,6 @@ def _generate_local_part() -> str: return f"{first}{last}{digits}" -def _generate_password(length: int = 24) -> str: - alphabet = string.ascii_letters + string.digits - return "".join(secrets.choice(alphabet) for _ in range(length)) - - class MailTmProvider(BaseProvider): def __init__(self, browser_session: BrowserContext): super().__init__(browser_session) @@ -146,7 +142,7 @@ class MailTmProvider(BaseProvider): for _ in range(8): domain = secrets.choice(domains) address = f"{_generate_local_part()}@{domain}" - password = _generate_password() + password = generate_password(length=24) created = await self._create_account(address, password) if not created: @@ -210,8 +206,7 @@ class MailTmProvider(BaseProvider): text = "\n".join(str(part) for part in parts if part) return text or None - async def get_latest_message(self, email: str) -> str | None: - del email + async def get_latest_message(self) -> str | None: if not self._token: raise RuntimeError("mail.tm provider is not initialized with mailbox token") diff --git a/src/email_providers/temp_mail_org.py b/src/email_providers/temp_mail_org.py index 0a0b65f..25d62cb 100644 --- a/src/email_providers/temp_mail_org.py +++ b/src/email_providers/temp_mail_org.py @@ -5,6 +5,7 @@ import re from playwright.async_api import BrowserContext, Error as PlaywrightError, Page from .base import BaseProvider +from .utils import ensure_page logger = logging.getLogger(__name__) @@ -15,8 +16,7 @@ class TempMailOrgProvider(BaseProvider): self.page: Page | None = None async def _ensure_page(self) -> Page: - if self.page is None or self.page.is_closed(): - self.page = await self.browser_session.new_page() + self.page = await 
ensure_page(self.browser_session, self.page) return self.page async def get_new_email(self) -> str: @@ -60,9 +60,9 @@ class TempMailOrgProvider(BaseProvider): raise RuntimeError("Could not get temp email from temp-mail.org") - async def get_latest_message(self, email: str) -> str | None: + async def get_latest_message(self) -> str | None: page = await self._ensure_page() - logger.info("[temp-mail.org] Waiting for latest message for %s", email) + logger.info("[temp-mail.org] Waiting for latest message") if page.is_closed(): raise RuntimeError("temp-mail.org tab was closed unexpectedly") diff --git a/src/email_providers/ten_minute_mail.py b/src/email_providers/ten_minute_mail.py index ae4c35e..1ff38a6 100644 --- a/src/email_providers/ten_minute_mail.py +++ b/src/email_providers/ten_minute_mail.py @@ -4,6 +4,7 @@ import logging from playwright.async_api import BrowserContext, Page from .base import BaseProvider +from .utils import ensure_page logger = logging.getLogger(__name__) @@ -14,8 +15,7 @@ class TenMinuteMailProvider(BaseProvider): self.page: Page | None = None async def _ensure_page(self) -> Page: - if self.page is None or self.page.is_closed(): - self.page = await self.browser_session.new_page() + self.page = await ensure_page(self.browser_session, self.page) return self.page async def get_new_email(self) -> str: @@ -34,9 +34,9 @@ class TenMinuteMailProvider(BaseProvider): logger.info("[10min] New email acquired: %s", email) return email - async def get_latest_message(self, email: str) -> str | None: + async def get_latest_message(self) -> str | None: page = await self._ensure_page() - logger.info("[10min] Waiting for latest message for %s", email) + logger.info("[10min] Waiting for latest message") seen_count = 0 for attempt in range(60): diff --git a/src/email_providers/utils.py b/src/email_providers/utils.py new file mode 100644 index 0000000..c6d44ce --- /dev/null +++ b/src/email_providers/utils.py @@ -0,0 +1,10 @@ +from playwright.async_api import 
BrowserContext, Page + + +async def ensure_page( + browser_session: BrowserContext, + page: Page | None, +) -> Page: + if page is None or page.is_closed(): + return await browser_session.new_page() + return page diff --git a/src/providers/chatgpt/provider.py b/src/providers/chatgpt/provider.py index 445500f..e171b51 100644 --- a/src/providers/chatgpt/provider.py +++ b/src/providers/chatgpt/provider.py @@ -1,6 +1,5 @@ import asyncio import logging -import os from typing import Any from typing import Callable @@ -9,6 +8,7 @@ from playwright.async_api import BrowserContext from email_providers import BaseProvider from email_providers import MailTmProvider from providers.base import Provider, ProviderTokens +from utils.env import parse_int_env from .tokens import ( clear_next_tokens, load_next_tokens, @@ -24,7 +24,13 @@ from .registration import register_chatgpt_account logger = logging.getLogger(__name__) CHATGPT_REGISTRATION_MAX_ATTEMPTS = 4 -CHATGPT_SWITCH_THRESHOLD = int(os.environ.get("CHATGPT_SWITCH_THRESHOLD", "95")) +CHATGPT_PREPARE_THRESHOLD = parse_int_env("CHATGPT_PREPARE_THRESHOLD", 85, 0, 100) +CHATGPT_SWITCH_THRESHOLD = parse_int_env( + "CHATGPT_SWITCH_THRESHOLD", + 95, + 0, + 100, +) class ChatGPTProvider(Provider): diff --git a/src/providers/chatgpt/registration.py b/src/providers/chatgpt/registration.py index 270c6d3..3087078 100644 --- a/src/providers/chatgpt/registration.py +++ b/src/providers/chatgpt/registration.py @@ -5,7 +5,6 @@ import logging import random import re import secrets -import string import time from datetime import datetime from pathlib import Path @@ -24,6 +23,7 @@ from playwright.async_api import ( from browser import launch as launch_browser from email_providers import BaseProvider from providers.base import ProviderTokens +from utils.randoms import generate_password from .tokens import CLIENT_ID logger = logging.getLogger(__name__) @@ -56,11 +56,6 @@ async def save_error_screenshot(page: Page | None, step: str): 
logger.warning("Failed to save screenshot at step %s: %s", step, e) -def generate_password(length: int = 20) -> str: - alphabet = string.ascii_letters + string.digits - return "".join(random.choice(alphabet) for _ in range(length)) - - def generate_name() -> str: first_names = [ "James", @@ -255,8 +250,8 @@ async def exchange_code_for_tokens(code: str, verifier: str) -> ProviderTokens: raise RuntimeError(f"Token exchange response parse error: {e}") from e -async def get_latest_code(email_provider: BaseProvider, email: str) -> str | None: - message = await email_provider.get_latest_message(email) +async def get_latest_code(email_provider: BaseProvider) -> str | None: + message = await email_provider.get_latest_message() if not message: return None return extract_verification_code(message) @@ -403,7 +398,7 @@ async def register_chatgpt_account( ) logger.info("[3/5] Getting verification message from email provider...") - code = await get_latest_code(email_provider, email) + code = await get_latest_code(email_provider) if not code: raise AutomationError( "email_provider", "Email provider returned no verification message" @@ -472,7 +467,7 @@ async def register_chatgpt_account( if await oauth_needs_email_check(oauth_page): logger.info("OAuth requested email confirmation code") - new_code = await get_latest_code(email_provider, email) + new_code = await get_latest_code(email_provider) if new_code and new_code != last_oauth_email_code: filled = await fill_oauth_code_if_present(oauth_page, new_code) if filled: diff --git a/src/server.py b/src/server.py index 7740c37..6e42ecc 100644 --- a/src/server.py +++ b/src/server.py @@ -5,37 +5,16 @@ import os from aiohttp import web from providers.chatgpt import ChatGPTProvider -from providers.chatgpt.provider import CHATGPT_SWITCH_THRESHOLD +from providers.chatgpt.provider import ( + CHATGPT_PREPARE_THRESHOLD, + CHATGPT_SWITCH_THRESHOLD, +) from providers.base import Provider +from utils.env import parse_int_env 
logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) - - -def _parse_int_env(name: str, default: int, minimum: int, maximum: int) -> int: - raw = os.environ.get(name) - if raw is None: - return default - try: - value = int(raw) - except ValueError: - logger.warning("Invalid %s=%r, using default %s", name, raw, default) - return default - if value < minimum or value > maximum: - logger.warning( - "%s=%s out of range [%s,%s], using default %s", - name, - value, - minimum, - maximum, - default, - ) - return default - return value - - -PORT = _parse_int_env("PORT", 8080, 1, 65535) -CHATGPT_PREPARE_THRESHOLD = _parse_int_env("CHATGPT_PREPARE_THRESHOLD", 85, 0, 100) +PORT = parse_int_env("PORT", 8080, 1, 65535) LIMIT_EXHAUSTED_PERCENT = 100 PROVIDERS: dict[str, Provider] = { diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..00303ef --- /dev/null +++ b/src/utils/__init__.py @@ -0,0 +1,4 @@ +from .env import parse_int_env +from .randoms import generate_password + +__all__ = ["parse_int_env", "generate_password"] diff --git a/src/utils/env.py b/src/utils/env.py new file mode 100644 index 0000000..688e65f --- /dev/null +++ b/src/utils/env.py @@ -0,0 +1,22 @@ +import os + + +def parse_int_env( + name: str, + default: int, + minimum: int, + maximum: int, +) -> int: + raw = os.environ.get(name) + if raw is None: + return default + + try: + value = int(raw) + except ValueError: + return default + + if value < minimum or value > maximum: + return default + + return value diff --git a/src/utils/randoms.py b/src/utils/randoms.py new file mode 100644 index 0000000..76adab8 --- /dev/null +++ b/src/utils/randoms.py @@ -0,0 +1,13 @@ +import random +import secrets +import string + + +def generate_password( + length: int = 20, + *, + secure: bool = True, +) -> str: + alphabet = string.ascii_letters + string.digits + chooser = secrets.choice if secure else random.choice + return "".join(chooser(alphabet) for _ in 
range(length)) diff --git a/tests/test_server_unit.py b/tests/test_server_unit.py index 4b74a3c..bfe06d4 100644 --- a/tests/test_server_unit.py +++ b/tests/test_server_unit.py @@ -3,6 +3,7 @@ import json import server from providers.base import Provider, ProviderTokens +from utils.env import parse_int_env class FakeRequest: @@ -66,17 +67,17 @@ def _response_json(resp) -> dict: def test_parse_int_env_defaults(monkeypatch): monkeypatch.delenv("X_TEST", raising=False) - assert server._parse_int_env("X_TEST", 10, 1, 20) == 10 + assert parse_int_env("X_TEST", 10, 1, 20) == 10 def test_parse_int_env_invalid(monkeypatch): monkeypatch.setenv("X_TEST", "abc") - assert server._parse_int_env("X_TEST", 10, 1, 20) == 10 + assert parse_int_env("X_TEST", 10, 1, 20) == 10 def test_parse_int_env_out_of_range(monkeypatch): monkeypatch.setenv("X_TEST", "999") - assert server._parse_int_env("X_TEST", 10, 1, 20) == 10 + assert parse_int_env("X_TEST", 10, 1, 20) == 10 def test_build_limit_fields(): From 2611d1bb6d228faa3b90c8ed763f6fd827df8d2f Mon Sep 17 00:00:00 2001 From: "Arthur K." 
Date: Mon, 2 Mar 2026 20:15:33 +0300 Subject: [PATCH 4/6] chore: handle noisy log messages --- src/providers/base.py | 15 ++++++++++- src/providers/chatgpt/provider.py | 18 +++++++++++--- src/server.py | 41 ++++++++++++++++++++----------- tests/test_server_unit.py | 24 ++++++++++++------ 4 files changed, 71 insertions(+), 27 deletions(-) diff --git a/src/providers/base.py b/src/providers/base.py index 3185c5b..478ab53 100644 --- a/src/providers/base.py +++ b/src/providers/base.py @@ -61,10 +61,23 @@ class Provider(ABC): """Rotate active account/token if provider policy requires it.""" return False + @property + def prepare_threshold(self) -> int: + """Usage percent when provider should prepare standby account/token.""" + return 100 + + @property + def switch_threshold(self) -> int | None: + """Usage percent when provider may switch active account/token.""" + return None + + def should_prepare_standby(self, usage_percent: int) -> bool: + """Whether standby preparation should be triggered for current usage.""" + return False + async def ensure_standby_account( self, usage_percent: int, - prepare_threshold: int, ) -> None: """Prepare standby account/token asynchronously when needed.""" return None diff --git a/src/providers/chatgpt/provider.py b/src/providers/chatgpt/provider.py index e171b51..603c247 100644 --- a/src/providers/chatgpt/provider.py +++ b/src/providers/chatgpt/provider.py @@ -43,6 +43,14 @@ class ChatGPTProvider(Provider): self.email_provider_factory = email_provider_factory or MailTmProvider self._token_write_lock = asyncio.Lock() + @property + def prepare_threshold(self) -> int: + return CHATGPT_PREPARE_THRESHOLD + + @property + def switch_threshold(self) -> int | None: + return CHATGPT_SWITCH_THRESHOLD + async def _register_with_retries(self) -> bool: for attempt in range(1, CHATGPT_REGISTRATION_MAX_ATTEMPTS + 1): logger.info( @@ -105,21 +113,23 @@ class ChatGPTProvider(Provider): async def ensure_next_account(self) -> bool: next_tokens = 
load_next_tokens() - if next_tokens and not next_tokens.is_expired: + if next_tokens: return True async with self._token_write_lock: next_tokens = load_next_tokens() - if next_tokens and not next_tokens.is_expired: + if next_tokens: return True return await self._create_next_account_under_lock() + def should_prepare_standby(self, usage_percent: int) -> bool: + return usage_percent >= self.prepare_threshold and bool(load_next_tokens()) + async def ensure_standby_account( self, usage_percent: int, - prepare_threshold: int, ) -> None: - if usage_percent >= prepare_threshold: + if self.should_prepare_standby(usage_percent): await self.ensure_next_account() async def maybe_switch_active_account(self, usage_percent: int) -> bool: diff --git a/src/server.py b/src/server.py index 6e42ecc..be2e905 100644 --- a/src/server.py +++ b/src/server.py @@ -1,14 +1,9 @@ import asyncio import logging -import os from aiohttp import web from providers.chatgpt import ChatGPTProvider -from providers.chatgpt.provider import ( - CHATGPT_PREPARE_THRESHOLD, - CHATGPT_SWITCH_THRESHOLD, -) from providers.base import Provider from utils.env import parse_int_env @@ -42,9 +37,17 @@ def build_limit(usage_percent: int, prepare_threshold: int) -> dict[str, int | b def get_prepare_threshold(provider_name: str) -> int: - if provider_name == "chatgpt": - return CHATGPT_PREPARE_THRESHOLD - return 100 + provider = PROVIDERS.get(provider_name) + if not provider: + return 100 + return provider.prepare_threshold + + +def should_trigger_standby_prepare(provider_name: str, usage_percent: int) -> bool: + provider = PROVIDERS.get(provider_name) + if not provider: + return False + return provider.should_prepare_standby(usage_percent) async def ensure_provider_token_ready(provider_name: str): @@ -82,10 +85,13 @@ async def ensure_standby_task(provider_name: str, usage_percent: int, reason: st provider = PROVIDERS.get(provider_name) if not provider: return + + if not provider.should_prepare_standby(usage_percent): + 
return + try: logger.info("[%s] Preparing standby in background (%s)", provider_name, reason) - threshold = get_prepare_threshold(provider_name) - await provider.ensure_standby_account(usage_percent, threshold) + await provider.ensure_standby_account(usage_percent) except Exception: logger.exception("[%s] Unhandled standby preparation error", provider_name) @@ -134,11 +140,11 @@ async def token_handler(request: web.Request) -> web.Response: logger.info("[%s] Active account switched before response", provider_name) prepare_threshold = get_prepare_threshold(provider_name) - if usage_percent >= prepare_threshold: + if should_trigger_standby_prepare(provider_name, usage_percent): trigger_standby_prepare( provider_name, usage_percent, - f"usage {usage_percent}% >= threshold {prepare_threshold}%", + f"usage {usage_percent}% reached standby policy", ) remaining_percent = int( @@ -209,8 +215,15 @@ def create_app() -> web.Application: if __name__ == "__main__": logger.info("Starting token service on port %s", PORT) - logger.info("ChatGPT prepare-next threshold: %s%%", CHATGPT_PREPARE_THRESHOLD) - logger.info("ChatGPT switch threshold: %s%%", CHATGPT_SWITCH_THRESHOLD) + chatgpt_provider = PROVIDERS.get("chatgpt") + if chatgpt_provider: + logger.info( + "ChatGPT prepare-next threshold: %s%%", chatgpt_provider.prepare_threshold + ) + if chatgpt_provider.switch_threshold is not None: + logger.info( + "ChatGPT switch threshold: %s%%", chatgpt_provider.switch_threshold + ) logger.info("Available providers: %s", ", ".join(PROVIDERS.keys())) app = create_app() web.run_app(app, host="0.0.0.0", port=PORT) diff --git a/tests/test_server_unit.py b/tests/test_server_unit.py index bfe06d4..68b3030 100644 --- a/tests/test_server_unit.py +++ b/tests/test_server_unit.py @@ -26,9 +26,17 @@ class FakeProvider(Provider): "secondary_window": None, } self._rotate = rotate + self._prepare_threshold = 80 self.get_token_calls = 0 self.standby_calls = 0 + @property + def prepare_threshold(self) -> 
int: + return self._prepare_threshold + + def should_prepare_standby(self, usage_percent: int) -> bool: + return usage_percent >= self.prepare_threshold + @property def name(self) -> str: return "fake" @@ -55,9 +63,10 @@ class FakeProvider(Provider): return self._rotate async def ensure_standby_account( - self, usage_percent: int, prepare_threshold: int + self, + usage_percent: int, ) -> None: - _ = usage_percent, prepare_threshold + _ = usage_percent self.standby_calls += 1 @@ -91,7 +100,10 @@ def test_build_limit_fields(): def test_get_prepare_threshold(): - assert server.get_prepare_threshold("chatgpt") == server.CHATGPT_PREPARE_THRESHOLD + assert ( + server.get_prepare_threshold("chatgpt") + == server.PROVIDERS["chatgpt"].prepare_threshold + ) assert server.get_prepare_threshold("unknown") == 100 @@ -105,8 +117,6 @@ def test_token_handler_success(monkeypatch): provider = FakeProvider() monkeypatch.setattr(server, "PROVIDERS", {"fake": provider}) monkeypatch.setattr(server, "background_tasks", {"fake": None}) - monkeypatch.setattr(server, "get_prepare_threshold", lambda _: 80) - resp = asyncio.run(server.token_handler(FakeRequest("fake"))) data = _response_json(resp) @@ -125,10 +135,9 @@ def test_token_handler_triggers_standby(monkeypatch): def fake_trigger(name, usage_percent, reason): assert name == "fake" assert usage_percent == 90 - assert "threshold" in reason + assert "standby policy" in reason called["value"] = True - monkeypatch.setattr(server, "get_prepare_threshold", lambda _: 80) monkeypatch.setattr(server, "trigger_standby_prepare", fake_trigger) resp = asyncio.run(server.token_handler(FakeRequest("fake"))) @@ -143,7 +152,6 @@ def test_token_handler_rotation_path(monkeypatch): ) monkeypatch.setattr(server, "PROVIDERS", {"fake": provider}) monkeypatch.setattr(server, "background_tasks", {"fake": None}) - monkeypatch.setattr(server, "get_prepare_threshold", lambda _: 80) monkeypatch.setattr(server, "trigger_standby_prepare", lambda *_: None) resp = 
asyncio.run(server.token_handler(FakeRequest("fake"))) From 42282ce8cbbdec88df0e44f9d3ddd8e9785ff5b6 Mon Sep 17 00:00:00 2001 From: "Arthur K." Date: Mon, 2 Mar 2026 20:44:34 +0300 Subject: [PATCH 5/6] chore: minor cleanup, healthcheck --- .gitignore | 1 + Dockerfile | 3 +++ src/healthcheck.py | 18 ++++++++++++++++++ src/providers/base.py | 2 +- src/providers/chatgpt/provider.py | 6 +++--- src/server.py | 10 ++++++++-- tests/test_server_unit.py | 10 ++++++++-- 7 files changed, 42 insertions(+), 8 deletions(-) create mode 100644 src/healthcheck.py diff --git a/.gitignore b/.gitignore index afa1ad1..701c6dd 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ __pycache__/ .ruff_cache/ .venv/ +.pytest_cache/ diff --git a/Dockerfile b/Dockerfile index 700b8ab..e08d4ba 100644 --- a/Dockerfile +++ b/Dockerfile @@ -28,6 +28,9 @@ VOLUME ["/data"] EXPOSE 80 +HEALTHCHECK --start-period=5s --start-interval=1s CMD \ + test "$(curl -fsS "http://127.0.0.1:$PORT/health")" = "ok" + STOPSIGNAL SIGINT CMD ["/entrypoint.sh"] diff --git a/src/healthcheck.py b/src/healthcheck.py new file mode 100644 index 0000000..b38e59a --- /dev/null +++ b/src/healthcheck.py @@ -0,0 +1,18 @@ +import os +import urllib.request + + +def main() -> int: + port = os.environ.get("PORT", "80") + url = f"http://127.0.0.1:{port}/health" + + try: + with urllib.request.urlopen(url, timeout=2) as resp: + body = resp.read().decode("utf-8").strip() + return 0 if resp.status == 200 and body == "ok" else 1 + except Exception: + return 1 + + +if __name__ == "__main__": + exit(main()) diff --git a/src/providers/base.py b/src/providers/base.py index 478ab53..5f9cdc6 100644 --- a/src/providers/base.py +++ b/src/providers/base.py @@ -71,7 +71,7 @@ class Provider(ABC): """Usage percent when provider may switch active account/token.""" return None - def should_prepare_standby(self, usage_percent: int) -> bool: + def should_prepare_standby(self) -> bool: """Whether standby preparation should be triggered for current 
usage.""" return False diff --git a/src/providers/chatgpt/provider.py b/src/providers/chatgpt/provider.py index 603c247..b536611 100644 --- a/src/providers/chatgpt/provider.py +++ b/src/providers/chatgpt/provider.py @@ -122,14 +122,14 @@ class ChatGPTProvider(Provider): return True return await self._create_next_account_under_lock() - def should_prepare_standby(self, usage_percent: int) -> bool: - return usage_percent >= self.prepare_threshold and bool(load_next_tokens()) + def should_prepare_standby(self) -> bool: + return bool(load_next_tokens()) async def ensure_standby_account( self, usage_percent: int, ) -> None: - if self.should_prepare_standby(usage_percent): + if usage_percent >= self.prepare_threshold: await self.ensure_next_account() async def maybe_switch_active_account(self, usage_percent: int) -> bool: diff --git a/src/server.py b/src/server.py index be2e905..171b5c8 100644 --- a/src/server.py +++ b/src/server.py @@ -47,7 +47,7 @@ def should_trigger_standby_prepare(provider_name: str, usage_percent: int) -> bo provider = PROVIDERS.get(provider_name) if not provider: return False - return provider.should_prepare_standby(usage_percent) + return provider.should_prepare_standby() async def ensure_provider_token_ready(provider_name: str): @@ -86,7 +86,7 @@ async def ensure_standby_task(provider_name: str, usage_percent: int, reason: st if not provider: return - if not provider.should_prepare_standby(usage_percent): + if not provider.should_prepare_standby(): return try: @@ -188,6 +188,11 @@ async def token_handler(request: web.Request) -> web.Response: ) +async def health_handler(request: web.Request) -> web.Response: + del request + return web.Response(text="ok") + + async def on_startup(app: web.Application): del app for provider_name in PROVIDERS: @@ -208,6 +213,7 @@ def create_app() -> web.Application: app = web.Application(middlewares=[request_log_middleware]) app.on_startup.append(on_startup) app.on_cleanup.append(on_cleanup) + 
app.router.add_get("/health", health_handler) app.router.add_get("/{provider}/token", token_handler) app.router.add_get("/token", token_handler) return app diff --git a/tests/test_server_unit.py b/tests/test_server_unit.py index 68b3030..8ebefb8 100644 --- a/tests/test_server_unit.py +++ b/tests/test_server_unit.py @@ -34,8 +34,8 @@ class FakeProvider(Provider): def prepare_threshold(self) -> int: return self._prepare_threshold - def should_prepare_standby(self, usage_percent: int) -> bool: - return usage_percent >= self.prepare_threshold + def should_prepare_standby(self) -> bool: + return False @property def name(self) -> str: @@ -113,6 +113,12 @@ def test_token_handler_unknown_provider(monkeypatch): assert resp.status == 404 +def test_health_handler_ok(): + resp = asyncio.run(server.health_handler(object())) + assert resp.status == 200 + assert resp.text == "ok" + + def test_token_handler_success(monkeypatch): provider = FakeProvider() monkeypatch.setattr(server, "PROVIDERS", {"fake": provider}) From 79460e499817e0f607f34de2e48231627cf6d32e Mon Sep 17 00:00:00 2001 From: "Arthur K." 
Date: Mon, 2 Mar 2026 20:51:29 +0300 Subject: [PATCH 6/6] chore: remove noisy healthcheck messages --- src/healthcheck.py | 18 ------------------ src/server.py | 23 +++++++++++++++-------- 2 files changed, 15 insertions(+), 26 deletions(-) delete mode 100644 src/healthcheck.py diff --git a/src/healthcheck.py b/src/healthcheck.py deleted file mode 100644 index b38e59a..0000000 --- a/src/healthcheck.py +++ /dev/null @@ -1,18 +0,0 @@ -import os -import urllib.request - - -def main() -> int: - port = os.environ.get("PORT", "80") - url = f"http://127.0.0.1:{port}/health" - - try: - with urllib.request.urlopen(url, timeout=2) as resp: - body = resp.read().decode("utf-8").strip() - return 0 if resp.status == 200 and body == "ok" else 1 - except Exception: - return 1 - - -if __name__ == "__main__": - exit(main()) diff --git a/src/server.py b/src/server.py index 171b5c8..464d92b 100644 --- a/src/server.py +++ b/src/server.py @@ -1,7 +1,7 @@ import asyncio import logging -from aiohttp import web +from aiohttp import web, web_log from providers.chatgpt import ChatGPTProvider from providers.base import Provider @@ -19,11 +19,13 @@ PROVIDERS: dict[str, Provider] = { background_tasks: dict[str, asyncio.Task | None] = {name: None for name in PROVIDERS} -@web.middleware -async def request_log_middleware(request: web.Request, handler): - response = await handler(request) - logger.info("%s %s -> %s", request.method, request.path_qs, response.status) - return response +class AccessLogger(web_log.AccessLogger): + def log( + self, request: web.BaseRequest, response: web.StreamResponse, time: float + ) -> None: + if request.path == "/health": + return + super().log(request, response, time) def build_limit(usage_percent: int, prepare_threshold: int) -> dict[str, int | bool]: @@ -210,7 +212,7 @@ async def on_cleanup(app: web.Application): def create_app() -> web.Application: - app = web.Application(middlewares=[request_log_middleware]) + app = web.Application() app.on_startup.append(on_startup)
app.on_cleanup.append(on_cleanup) app.router.add_get("/health", health_handler) @@ -232,4 +234,9 @@ if __name__ == "__main__": ) logger.info("Available providers: %s", ", ".join(PROVIDERS.keys())) app = create_app() - web.run_app(app, host="0.0.0.0", port=PORT) + web.run_app( + app, + host="0.0.0.0", + port=PORT, + access_log_class=AccessLogger, + )