From 0af717959665bf7aed9ea6c61e8818e9b6899a7e Mon Sep 17 00:00:00 2001 From: "Arthur K." Date: Mon, 2 Mar 2026 21:14:20 +0300 Subject: [PATCH] refactor!: a lot of stuff --- README.md | 16 +- pyproject.toml | 5 + src/browser.py | 13 +- src/email_providers/temp_mail_org.py | 24 +-- src/providers/base.py | 20 +++ src/providers/chatgpt/provider.py | 69 +++++--- src/providers/chatgpt/registration.py | 149 ++++++++++------- src/providers/chatgpt/tokens.py | 97 ++++++----- src/providers/chatgpt/usage.py | 54 +++--- src/server.py | 227 ++++++++++++-------------- tests/conftest.py | 12 ++ tests/test_registration_unit.py | 37 +++++ tests/test_server_unit.py | 150 +++++++++++++++++ tests/test_tokens_unit.py | 60 +++++++ tests/test_usage_unit.py | 32 ++++ 15 files changed, 663 insertions(+), 302 deletions(-) create mode 100644 tests/conftest.py create mode 100644 tests/test_registration_unit.py create mode 100644 tests/test_server_unit.py create mode 100644 tests/test_tokens_unit.py create mode 100644 tests/test_usage_unit.py diff --git a/README.md b/README.md index 6f423f1..a061094 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ Response shape: "used_percent": 0, "remaining_percent": 100, "exhausted": false, - "needs_refresh": false + "needs_prepare": false }, "usage": { "primary_window": { @@ -54,10 +54,8 @@ Behavior: ## Startup Behavior -On startup, service: - -1. Ensures active token exists and is usable. -2. Ensures `next_account` is prepared for ChatGPT. +On startup, service ensures active token exists and is usable. +Standby preparation runs through provider lifecycle hooks/background trigger when needed. ## Data Files @@ -70,6 +68,14 @@ On startup, service: PYTHONPATH=./src python src/server.py ``` +## Unit Tests + +The project has unit tests only (no integration/network tests). + +```bash +pytest -q +``` + ## Docker Notes - Dockerfile sets `DATA_DIR=/data`. diff --git a/pyproject.toml b/pyproject.toml index cc1ff33..7f927f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,5 +8,10 @@ dependencies = [ "pkce==1.0.3", ] +[project.optional-dependencies] +dev = [ + "pytest>=8.0.0", +] + [tool.uv] package = false diff --git a/src/browser.py b/src/browser.py index 3b35da9..daa7fef 100644 --- a/src/browser.py +++ b/src/browser.py @@ -1,6 +1,7 @@ import asyncio import json import logging +import socket import shutil import subprocess import tempfile @@ -44,7 +45,12 @@ CHROME_FLAGS = [ "--disable-search-engine-choice-screen", ] -DEFAULT_CDP_PORT = 9222 + +def _allocate_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return int(s.getsockname()[1]) def _fetch_ws_endpoint(port: int) -> str | None: @@ -79,10 +85,9 @@ class ManagedBrowser: shutil.rmtree(self.profile_dir, ignore_errors=True) -async def launch( - playwright: Playwright, cdp_port: int = DEFAULT_CDP_PORT -) -> ManagedBrowser: +async def launch(playwright: Playwright) -> ManagedBrowser: chrome_path = playwright.chromium.executable_path + cdp_port = _allocate_free_port() profile_dir = Path(tempfile.mkdtemp(prefix="megapt_profile-", dir="/tmp")) args = [ diff --git a/src/email_providers/temp_mail_org.py b/src/email_providers/temp_mail_org.py index cb2b840..0a0b65f 100644 --- a/src/email_providers/temp_mail_org.py +++ b/src/email_providers/temp_mail_org.py @@ -2,7 +2,7 @@ import asyncio import logging import re -from playwright.async_api import BrowserContext, Page +from playwright.async_api import BrowserContext, Error as PlaywrightError, Page from .base import BaseProvider @@ -44,7 +44,7 @@ class TempMailOrgProvider(BaseProvider): value, ) return value - except Exception: + except PlaywrightError: continue try: @@ -53,8 +53,8 @@ class TempMailOrgProvider(BaseProvider): if found: logger.info("[temp-mail.org] email found by body scan: %s", found) return found - except Exception: - pass + except PlaywrightError: + logger.debug("Failed to scan body text for email") await asyncio.sleep(1) @@ -76,7 +76,7 @@ class TempMailOrgProvider(BaseProvider): try: count = await items.count() logger.info("[temp-mail.org] inbox items: %s", count) - except Exception: + except PlaywrightError: count = 0 if count > 0: @@ -87,30 +87,30 @@ class TempMailOrgProvider(BaseProvider): continue text = (await item.inner_text()).strip().replace("\n", " ") logger.info("[temp-mail.org] item[%s]: %s", idx, text[:160]) - except Exception: + except PlaywrightError: continue if text: try: await item.click() logger.info("[temp-mail.org] opened item[%s]", idx) - except Exception: - pass + except PlaywrightError: + logger.debug("Failed to open inbox item[%s]", idx) message_text = text try: content = await page.content() if content and "Your ChatGPT code is" in content: message_text = content - except Exception: - pass + except PlaywrightError: + logger.debug("Failed to read opened message content") try: await page.go_back( wait_until="domcontentloaded", timeout=5000 ) logger.info("[temp-mail.org] returned back to inbox") - except Exception: - pass + except PlaywrightError: + logger.debug("Failed to return back to inbox") return message_text diff --git a/src/providers/base.py b/src/providers/base.py index 1a5b6bc..3185c5b 100644 --- a/src/providers/base.py +++ b/src/providers/base.py @@ -52,3 +52,23 @@ class Provider(ABC): def save_tokens(self, tokens: ProviderTokens) -> None: """Save tokens to storage""" pass + + async def force_recreate_token(self) -> str | None: + """Force-create a new active token when normal acquisition fails.""" + return None + + async def maybe_rotate_account(self, usage_percent: int) -> bool: + """Rotate active account/token if provider policy requires it.""" + return False + + async def ensure_standby_account( + self, + usage_percent: int, + prepare_threshold: int, + ) -> None: + """Prepare standby account/token asynchronously when needed.""" + return None + + async def startup_prepare(self) -> None: + """Optional provider-specific startup preparation.""" + return None diff --git a/src/providers/chatgpt/provider.py b/src/providers/chatgpt/provider.py index 2971c27..ae2e527 100644 --- a/src/providers/chatgpt/provider.py +++ b/src/providers/chatgpt/provider.py @@ -16,7 +16,7 @@ from .tokens import ( load_tokens, promote_next_tokens, refresh_tokens, - save_next_tokens, + save_state, save_tokens, ) from .usage import get_usage_data @@ -44,10 +44,14 @@ class ChatGPTProvider(Provider): attempt, CHATGPT_REGISTRATION_MAX_ATTEMPTS, ) - success = await self.register_new_account() - if success: + generated_tokens = await register_chatgpt_account( + email_provider_factory=self.email_provider_factory, + ) + if generated_tokens: + save_tokens(generated_tokens) return True logger.warning("Registration attempt %s failed", attempt) + await asyncio.sleep(1.5 * attempt) return False async def _create_next_account_under_lock(self) -> bool: @@ -56,23 +60,28 @@ class ChatGPTProvider(Provider): return True logger.info("Creating next account") - success = await self._register_with_retries() - if not success: - return False + for attempt in range(1, CHATGPT_REGISTRATION_MAX_ATTEMPTS + 1): + logger.info( + "Next-account registration attempt %s/%s", + attempt, + CHATGPT_REGISTRATION_MAX_ATTEMPTS, + ) + generated_tokens = await register_chatgpt_account( + email_provider_factory=self.email_provider_factory, + ) + if generated_tokens: + if active_before: + save_state(active_before, generated_tokens) + else: + save_state(generated_tokens, None) + logger.info("Next account is ready") + return True + logger.warning("Next-account registration attempt %s failed", attempt) + await asyncio.sleep(1.5 * attempt) - generated_active = load_tokens() - if not generated_active: - return False - - # Registration writes new tokens as active; restore old active and keep - # generated account as next. - if active_before: - save_tokens(active_before) - else: - clear_next_tokens() - save_next_tokens(generated_active) - logger.info("Next account is ready") - return True + if active_before or next_before: + save_state(active_before, next_before) + return False async def force_recreate_token(self) -> str | None: async with self._token_write_lock: @@ -85,6 +94,9 @@ class ChatGPTProvider(Provider): return None return tokens.access_token + async def startup_prepare(self) -> None: + await self.ensure_next_account() + async def ensure_next_account(self) -> bool: next_tokens = load_next_tokens() if next_tokens and not next_tokens.is_expired: @@ -96,6 +108,14 @@ class ChatGPTProvider(Provider): return True return await self._create_next_account_under_lock() + async def ensure_standby_account( + self, + usage_percent: int, + prepare_threshold: int, + ) -> None: + if usage_percent >= prepare_threshold: + await self.ensure_next_account() + async def maybe_switch_active_account(self, usage_percent: int) -> bool: if usage_percent < CHATGPT_SWITCH_THRESHOLD: return False @@ -119,6 +139,9 @@ class ChatGPTProvider(Provider): ) return switched + async def maybe_rotate_account(self, usage_percent: int) -> bool: + return await self.maybe_switch_active_account(usage_percent) + @property def name(self) -> str: return "chatgpt" @@ -154,13 +177,17 @@ class ChatGPTProvider(Provider): async def register_new_account(self) -> bool: """Register a new ChatGPT account""" - return await register_chatgpt_account( + generated_tokens = await register_chatgpt_account( email_provider_factory=self.email_provider_factory, ) + if not generated_tokens: + return False + save_tokens(generated_tokens) + return True async def get_usage_info(self, access_token: str) -> dict[str, Any]: """Get usage information for the current token""" - usage_data = get_usage_data(access_token) + usage_data = await get_usage_data(access_token) if not usage_data: return {"error": "Failed to get usage"} diff --git a/src/providers/chatgpt/registration.py b/src/providers/chatgpt/registration.py index 1af2581..3d106c1 100644 --- a/src/providers/chatgpt/registration.py +++ b/src/providers/chatgpt/registration.py @@ -14,12 +14,17 @@ from typing import Callable from urllib.parse import parse_qs, urlencode, urlparse import aiohttp -from playwright.async_api import async_playwright, Page, BrowserContext +from playwright.async_api import ( + async_playwright, + Error as PlaywrightError, + Page, + BrowserContext, +) from browser import launch as launch_browser from email_providers import BaseProvider from providers.base import ProviderTokens -from .tokens import CLIENT_ID, save_tokens +from .tokens import CLIENT_ID logger = logging.getLogger(__name__) @@ -46,9 +51,9 @@ async def save_error_screenshot(page: Page | None, step: str): filename = screenshots_dir / f"error_{step}_{timestamp}.png" try: await page.screenshot(path=str(filename)) - logger.error(f"Screenshot saved: {filename}") - except: - pass + logger.error("Screenshot saved: %s", filename) + except PlaywrightError as e: + logger.warning("Failed to save screenshot at step %s: %s", step, e) def generate_password(length: int = 20) -> str: @@ -204,8 +209,7 @@ def generate_state() -> str: return secrets.token_urlsafe(32) -def build_authorize_url(verifier: str, challenge: str, state: str) -> str: - del verifier +def build_authorize_url(challenge: str, state: str) -> str: params = { "response_type": "code", "client_id": CLIENT_ID, @@ -222,26 +226,33 @@ def build_authorize_url(verifier: str, challenge: str, state: str) -> str: async def exchange_code_for_tokens(code: str, verifier: str) -> ProviderTokens: - async with aiohttp.ClientSession() as session: - payload = { - "grant_type": "authorization_code", - "client_id": CLIENT_ID, - "code": code, - "code_verifier": verifier, - "redirect_uri": REDIRECT_URI, - } - async with session.post(TOKEN_URL, data=payload) as resp: - if not resp.ok: - text = await resp.text() - raise RuntimeError(f"Token exchange failed: {resp.status} {text}") - body = await resp.json() + payload = { + "grant_type": "authorization_code", + "client_id": CLIENT_ID, + "code": code, + "code_verifier": verifier, + "redirect_uri": REDIRECT_URI, + } + timeout = aiohttp.ClientTimeout(total=20) + try: + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post(TOKEN_URL, data=payload) as resp: + if not resp.ok: + text = await resp.text() + raise RuntimeError(f"Token exchange failed: {resp.status} {text}") + body = await resp.json() + except (aiohttp.ClientError, TimeoutError) as e: + raise RuntimeError(f"Token exchange request error: {e}") from e - expires_in = int(body["expires_in"]) - return ProviderTokens( - access_token=body["access_token"], - refresh_token=body["refresh_token"], - expires_at=time.time() + expires_in, - ) + try: + expires_in = int(body["expires_in"]) + return ProviderTokens( + access_token=body["access_token"], + refresh_token=body["refresh_token"], + expires_at=time.time() + expires_in, + ) + except (KeyError, TypeError, ValueError) as e: + raise RuntimeError(f"Token exchange response parse error: {e}") from e async def get_latest_code(email_provider: BaseProvider, email: str) -> str | None: @@ -270,6 +281,24 @@ async def click_continue(page: Page, timeout_ms: int = 10000): await btn.click() +async def click_any_visible_button( + page: Page, + labels: list[str], + timeout_ms: int = 2000, +) -> bool: + for label in labels: + button = page.get_by_role("button", name=label) + if await button.count() == 0: + continue + try: + await button.first.wait_for(state="visible", timeout=timeout_ms) + await button.first.click(timeout=timeout_ms) + return True + except PlaywrightError: + continue + return False + + async def wait_for_signup_stabilization( page: Page, source_url: str, @@ -288,12 +317,12 @@ async def wait_for_signup_stabilization( async def register_chatgpt_account( email_provider_factory: Callable[[BrowserContext], BaseProvider] | None = None, -) -> bool: +) -> ProviderTokens | None: logger.info("=== Starting ChatGPT account registration ===") if email_provider_factory is None: logger.error("No email provider factory configured") - return False + return None birth_month, birth_day, birth_year = generate_birthdate_90s() @@ -321,7 +350,7 @@ async def register_chatgpt_account( full_name = generate_name() verifier, challenge = generate_pkce_pair() oauth_state = generate_state() - authorize_url = build_authorize_url(verifier, challenge, oauth_state) + authorize_url = build_authorize_url(challenge, oauth_state) logger.info("[2/5] Registering ChatGPT for %s", email) chatgpt_page = await context.new_page() @@ -352,19 +381,18 @@ async def register_chatgpt_account( raise AutomationError( "email_provider", "Email provider returned no verification message" ) - logger.info("[3/5] Verification code extracted: %s", code) + logger.info("[3/5] Verification code extracted") await chatgpt_page.bring_to_front() code_input = chatgpt_page.get_by_placeholder("Code") - if await code_input.count() > 0: - await code_input.fill(code) + await code_input.first.wait_for(state="visible", timeout=10000) + await code_input.first.fill(code) await click_continue(chatgpt_page) logger.info("[4/5] Setting profile...") name_input = chatgpt_page.get_by_placeholder("Full name") await name_input.first.wait_for(state="visible", timeout=20000) - if await name_input.count() > 0: - await name_input.fill(full_name) + await name_input.first.fill(full_name) await fill_date_field(chatgpt_page, birth_month, birth_day, birth_year) profile_url = chatgpt_page.url @@ -387,45 +415,42 @@ async def register_chatgpt_account( oauth_page.on("request", handle_request) await oauth_page.goto(authorize_url, wait_until="domcontentloaded") - await oauth_page.locator( - 'input[type="email"], input[name="email"]' - ).first.wait_for(state="visible", timeout=20000) - email_input = oauth_page.locator('input[type="email"], input[name="email"]') if await email_input.count() > 0: + await email_input.first.wait_for(state="visible", timeout=10000) await email_input.first.fill(email) - - continue_button = oauth_page.get_by_role("button", name="Continue") - if await continue_button.count() > 0: - await continue_button.first.click() - await oauth_page.locator('input[type="password"]').first.wait_for( - state="visible", timeout=20000 + await click_any_visible_button( + oauth_page, ["Continue"], timeout_ms=4000 ) password_input = oauth_page.locator('input[type="password"]') if await password_input.count() > 0: + await password_input.first.wait_for(state="visible", timeout=10000) await password_input.first.fill(password) - continue_button = oauth_page.get_by_role("button", name="Continue") - if await continue_button.count() > 0: - await continue_button.first.click() + await click_any_visible_button( + oauth_page, ["Continue"], timeout_ms=4000 + ) - for label in ["Continue", "Allow", "Authorize"]: - button = oauth_page.get_by_role("button", name=label) - if await button.count() > 0: - try: - await button.first.click(timeout=5000) - await oauth_page.wait_for_timeout(500) - except Exception: - pass + for _ in range(6): + if redirect_url_captured: + break + clicked = await click_any_visible_button( + oauth_page, + ["Continue", "Allow", "Authorize"], + timeout_ms=2000, + ) + if clicked: + await asyncio.sleep(0.4) + else: + await asyncio.sleep(0.4) if not redirect_url_captured: try: - await oauth_page.wait_for_timeout(4000) current_url = oauth_page.url if "localhost:1455" in current_url and "code=" in current_url: redirect_url_captured = current_url logger.info("Captured OAuth redirect from page URL") - except Exception: + except PlaywrightError: pass if not redirect_url_captured: @@ -446,20 +471,18 @@ async def register_chatgpt_account( raise AutomationError("oauth", "OAuth state mismatch", oauth_page) tokens = await exchange_code_for_tokens(auth_code, verifier) - save_tokens(tokens) - logger.info("OAuth tokens saved successfully") + logger.info("OAuth tokens fetched successfully") - return True + return tokens except AutomationError as e: logger.error(f"Error at step [{e.step}]: {e.message}") await save_error_screenshot(e.page, e.step) - return False + return None except Exception as e: logger.error(f"Unexpected error: {e}") await save_error_screenshot(current_page, "unexpected") - return False + return None finally: if managed: - await asyncio.sleep(2) await managed.close() diff --git a/src/providers/chatgpt/tokens.py b/src/providers/chatgpt/tokens.py index cc9a4f6..b4ed7fb 100644 --- a/src/providers/chatgpt/tokens.py +++ b/src/providers/chatgpt/tokens.py @@ -1,6 +1,7 @@ import json import logging import os +import tempfile import time from pathlib import Path from typing import Any @@ -35,7 +36,7 @@ def _dict_to_tokens(data: dict[str, Any] | None) -> ProviderTokens | None: refresh_token=data["refresh_token"], expires_at=data["expires_at"], ) - except KeyError, TypeError: + except (KeyError, TypeError): return None @@ -54,8 +55,20 @@ def _load_raw() -> dict[str, Any] | None: def _save_raw(data: dict[str, Any]) -> None: TOKENS_FILE.parent.mkdir(parents=True, exist_ok=True) - with open(TOKENS_FILE, "w") as f: - json.dump(data, f, indent=2) + fd, tmp_path = tempfile.mkstemp( + prefix=f"{TOKENS_FILE.name}.", + suffix=".tmp", + dir=str(TOKENS_FILE.parent), + ) + try: + with os.fdopen(fd, "w") as f: + json.dump(data, f, indent=2) + f.flush() + os.fsync(f.fileno()) + os.replace(tmp_path, TOKENS_FILE) + finally: + if os.path.exists(tmp_path): + os.unlink(tmp_path) def _normalize_state(data: dict[str, Any] | None) -> dict[str, Any]: @@ -68,7 +81,6 @@ def _normalize_state(data: dict[str, Any] | None) -> dict[str, Any]: "next_account": data.get("next_account"), } - # Backward compatibility with old flat schema return {"active": data, "next_account": None} @@ -79,9 +91,7 @@ def load_state() -> tuple[ProviderTokens | None, ProviderTokens | None]: return active, next_account -def save_state( - active: ProviderTokens | None, next_account: ProviderTokens | None -) -> None: +def save_state(active: ProviderTokens | None, next_account: ProviderTokens | None) -> None: payload = { "active": _tokens_to_dict(active) if active else None, "next_account": _tokens_to_dict(next_account) if next_account else None, @@ -104,13 +114,8 @@ def save_tokens(tokens: ProviderTokens): save_state(tokens, next_account) -def save_next_tokens(tokens: ProviderTokens): - active, _ = load_state() - save_state(active, tokens) - - def promote_next_tokens() -> bool: - active, next_account = load_state() + _, next_account = load_state() if not next_account: return False save_state(next_account, None) @@ -123,42 +128,34 @@ def clear_next_tokens(): async def refresh_tokens(refresh_token: str) -> ProviderTokens | None: - async with aiohttp.ClientSession() as session: - data = { - "grant_type": "refresh_token", - "refresh_token": refresh_token, - "client_id": CLIENT_ID, - } - async with session.post(TOKEN_URL, data=data) as resp: - if not resp.ok: - text = await resp.text() - logger.warning("Token refresh failed: %s %s", resp.status, text) - return None - json_resp = await resp.json() - expires_in = json_resp["expires_in"] - return ProviderTokens( - access_token=json_resp["access_token"], - refresh_token=json_resp["refresh_token"], - expires_at=time.time() + expires_in, - ) - - -async def get_valid_tokens() -> ProviderTokens | None: - tokens = load_tokens() - if not tokens: - logger.info("No tokens found") + data = { + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "client_id": CLIENT_ID, + } + timeout = aiohttp.ClientTimeout(total=15) + try: + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post(TOKEN_URL, data=data) as resp: + if not resp.ok: + text = await resp.text() + logger.warning("Token refresh failed: %s %s", resp.status, text) + return None + json_resp = await resp.json() + except (aiohttp.ClientError, TimeoutError) as e: + logger.warning("Token refresh request error: %s", e) + return None + except Exception as e: + logger.warning("Token refresh unexpected error: %s", e) return None - if tokens.is_expired: - logger.info("Token expired, refreshing...") - if not tokens.refresh_token: - logger.info("No refresh token available") - return None - new_tokens = await refresh_tokens(tokens.refresh_token) - if not new_tokens: - logger.warning("Failed to refresh token") - return None - save_tokens(new_tokens) - return new_tokens - - return tokens + try: + expires_in = int(json_resp["expires_in"]) + return ProviderTokens( + access_token=json_resp["access_token"], + refresh_token=json_resp["refresh_token"], + expires_at=time.time() + expires_in, + ) + except (KeyError, TypeError, ValueError) as e: + logger.warning("Token refresh response parse error: %s", e) + return None diff --git a/src/providers/chatgpt/usage.py b/src/providers/chatgpt/usage.py index 6236bd8..7c2edab 100644 --- a/src/providers/chatgpt/usage.py +++ b/src/providers/chatgpt/usage.py @@ -1,14 +1,15 @@ -import json -import socket -import urllib.error -import urllib.request +import logging from typing import Any +import aiohttp + +logger = logging.getLogger(__name__) + def clamp_percent(value: Any) -> int: try: num = float(value) - except TypeError, ValueError: + except (TypeError, ValueError): return 0 if num < 0: return 0 @@ -28,30 +29,36 @@ def _parse_window(window: dict[str, Any] | None) -> dict[str, int] | None: } -def get_usage_data(access_token: str, timeout_ms: int = 10000) -> dict[str, Any] | None: +async def get_usage_data( + access_token: str, + timeout_ms: int = 10000, +) -> dict[str, Any] | None: headers = { "Authorization": f"Bearer {access_token}", "User-Agent": "CodexProxy", "Accept": "application/json", } - req = urllib.request.Request( - "https://chatgpt.com/backend-api/wham/usage", - headers=headers, - method="GET", - ) + timeout = aiohttp.ClientTimeout(total=timeout_ms / 1000) + url = "https://chatgpt.com/backend-api/wham/usage" try: - with urllib.request.urlopen(req, timeout=timeout_ms / 1000) as res: - body = res.read().decode("utf-8", errors="replace") - except urllib.error.HTTPError: + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.get(url, headers=headers) as res: + if not res.ok: + body = await res.text() + logger.warning( + "Usage fetch failed: status=%s body=%s", + res.status, + body[:300], + ) + return None + data = await res.json() + except (aiohttp.ClientError, TimeoutError) as e: + logger.warning("Usage fetch request error: %s", e) return None - except urllib.error.URLError, socket.timeout: - return None - - try: - data = json.loads(body) - except json.JSONDecodeError: + except Exception as e: + logger.warning("Usage fetch unexpected error: %s", e) return None rate_limit = data.get("rate_limit") or {} @@ -76,10 +83,3 @@ def get_usage_data(access_token: str, timeout_ms: int = 10000) -> dict[str, Any] "limit_reached": bool(rate_limit.get("limit_reached")), "allowed": bool(rate_limit.get("allowed", True)), } - - -def get_usage_percent(access_token: str, timeout_ms: int = 10000) -> int: - data = get_usage_data(access_token, timeout_ms=timeout_ms) - if not data: - return -1 - return int(data["used_percent"]) diff --git a/src/server.py b/src/server.py index 64ac91d..7740c37 100644 --- a/src/server.py +++ b/src/server.py @@ -5,23 +5,44 @@ import os from aiohttp import web from providers.chatgpt import ChatGPTProvider - -PORT = int(os.environ.get("PORT", "8080")) -CHATGPT_PREPARE_THRESHOLD = int(os.environ.get("CHATGPT_PREPARE_THRESHOLD", "85")) -LIMIT_EXHAUSTED_PERCENT = 100 +from providers.chatgpt.provider import CHATGPT_SWITCH_THRESHOLD +from providers.base import Provider logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -# Registry of available providers -PROVIDERS = { + +def _parse_int_env(name: str, default: int, minimum: int, maximum: int) -> int: + raw = os.environ.get(name) + if raw is None: + return default + try: + value = int(raw) + except ValueError: + logger.warning("Invalid %s=%r, using default %s", name, raw, default) + return default + if value < minimum or value > maximum: + logger.warning( + "%s=%s out of range [%s,%s], using default %s", + name, + value, + minimum, + maximum, + default, + ) + return default + return value + + +PORT = _parse_int_env("PORT", 8080, 1, 65535) +CHATGPT_PREPARE_THRESHOLD = _parse_int_env("CHATGPT_PREPARE_THRESHOLD", 85, 0, 100) +LIMIT_EXHAUSTED_PERCENT = 100 + +PROVIDERS: dict[str, Provider] = { "chatgpt": ChatGPTProvider(), } -refresh_locks = {name: asyncio.Lock() for name in PROVIDERS.keys()} -background_refresh_tasks: dict[str, asyncio.Task | None] = { - name: None for name in PROVIDERS.keys() -} +background_tasks: dict[str, asyncio.Task | None] = {name: None for name in PROVIDERS} @web.middleware @@ -31,16 +52,22 @@ async def request_log_middleware(request: web.Request, handler): return response -def build_limit(usage_percent: int) -> dict[str, int | bool]: +def build_limit(usage_percent: int, prepare_threshold: int) -> dict[str, int | bool]: remaining = max(0, 100 - usage_percent) return { "used_percent": usage_percent, "remaining_percent": remaining, "exhausted": usage_percent >= LIMIT_EXHAUSTED_PERCENT, - "needs_refresh": usage_percent >= CHATGPT_PREPARE_THRESHOLD, + "needs_prepare": usage_percent >= prepare_threshold, } +def get_prepare_threshold(provider_name: str) -> int: + if provider_name == "chatgpt": + return CHATGPT_PREPARE_THRESHOLD + return 100 + + async def ensure_provider_token_ready(provider_name: str): provider = PROVIDERS.get(provider_name) if not provider: @@ -52,135 +79,92 @@ async def ensure_provider_token_ready(provider_name: str): logger.warning( "[%s] Startup token check failed, forcing recreation", provider_name ) - if isinstance(provider, ChatGPTProvider): - token = await provider.force_recreate_token() + token = await provider.force_recreate_token() if not token: logger.error("[%s] Could not prepare token at startup", provider_name) return - if isinstance(provider, ChatGPTProvider): - await provider.ensure_next_account() - usage_info = await provider.get_usage_info(token) - if "error" not in usage_info: - logger.info("[%s] Startup token is ready", provider_name) - return - - logger.warning( - "[%s] Startup token invalid for usage, forcing recreation", provider_name - ) - if isinstance(provider, ChatGPTProvider): + if "error" in usage_info: + logger.warning( + "[%s] Startup token invalid for usage, forcing recreation", provider_name + ) token = await provider.force_recreate_token() - if token: - logger.info("[%s] Startup token recreated successfully", provider_name) + if not token: + logger.error("[%s] Startup token recreation failed", provider_name) return - logger.error("[%s] Startup token recreation failed", provider_name) + await provider.startup_prepare() + logger.info("[%s] Startup token is ready", provider_name) -async def on_startup(app: web.Application): - del app - for provider_name in PROVIDERS.keys(): - await ensure_provider_token_ready(provider_name) - - -async def issue_new_token(provider_name: str) -> str | None: +async def ensure_standby_task(provider_name: str, usage_percent: int, reason: str): provider = PROVIDERS.get(provider_name) if not provider: - return None - - async with refresh_locks[provider_name]: - logger.info(f"[{provider_name}] Generating new token") - success = await provider.register_new_account() - if not success: - logger.error(f"[{provider_name}] Token generation failed") - return None - - token = await provider.get_token() - if not token: - logger.error(f"[{provider_name}] Token was generated but not available") - return None - - return token - - -async def background_refresh_worker(provider_name: str, reason: str): + return try: - logger.info(f"[{provider_name}] Starting background token refresh ({reason})") - new_token = await issue_new_token(provider_name) - if new_token: - logger.info(f"[{provider_name}] Background token refresh completed") - else: - logger.error(f"[{provider_name}] Background token refresh failed") + logger.info("[%s] Preparing standby in background (%s)", provider_name, reason) + threshold = get_prepare_threshold(provider_name) + await provider.ensure_standby_account(usage_percent, threshold) except Exception: - logger.exception( - f"[{provider_name}] Unhandled error in background token refresh" - ) + logger.exception("[%s] Unhandled standby preparation error", provider_name) -def trigger_background_refresh(provider_name: str, reason: str): - task = background_refresh_tasks.get(provider_name) +def trigger_standby_prepare(provider_name: str, usage_percent: int, reason: str): + task = background_tasks.get(provider_name) if task and not task.done(): logger.info( - f"[{provider_name}] Background refresh already running, skip ({reason})" + "[%s] Standby prep already running, skip (%s)", provider_name, reason ) return - background_refresh_tasks[provider_name] = asyncio.create_task( - background_refresh_worker(provider_name, reason) + background_tasks[provider_name] = asyncio.create_task( + ensure_standby_task(provider_name, usage_percent, reason) ) async def token_handler(request: web.Request) -> web.Response: provider_name = request.match_info.get("provider", "chatgpt") - provider = PROVIDERS.get(provider_name) if not provider: return web.json_response( - {"error": f"Unknown provider: {provider_name}"}, - status=404, + {"error": f"Unknown provider: {provider_name}"}, status=404 ) - # Get or create token token = await provider.get_token() if not token: - return web.json_response( - {"error": "Failed to get active token"}, - status=503, - ) + return web.json_response({"error": "Failed to get active token"}, status=503) - # Get usage info usage_info = await provider.get_usage_info(token) if "error" in usage_info: - return web.json_response( - {"error": usage_info["error"]}, - status=503, + return web.json_response({"error": usage_info["error"]}, status=503) + + usage_percent = int(usage_info.get("used_percent", 0)) + switched = await provider.maybe_rotate_account(usage_percent) + if switched: + token = await provider.get_token() + if not token: + return web.json_response( + {"error": "Failed to get active token after account switch"}, + status=503, + ) + usage_info = await provider.get_usage_info(token) + if "error" in usage_info: + return web.json_response({"error": usage_info["error"]}, status=503) + usage_percent = int(usage_info.get("used_percent", 0)) + logger.info("[%s] Active account switched before response", provider_name) + + prepare_threshold = get_prepare_threshold(provider_name) + if usage_percent >= prepare_threshold: + trigger_standby_prepare( + provider_name, + usage_percent, + f"usage {usage_percent}% >= threshold {prepare_threshold}%", ) - usage_percent = usage_info.get("used_percent", 0) - remaining_percent = usage_info.get("remaining_percent", max(0, 100 - usage_percent)) - - if isinstance(provider, ChatGPTProvider): - switched = await provider.maybe_switch_active_account(usage_percent) - if switched: - token = await provider.get_token() - if not token: - return web.json_response( - {"error": "Failed to get active token after account switch"}, - status=503, - ) - usage_info = await provider.get_usage_info(token) - if "error" in usage_info: - return web.json_response( - {"error": usage_info["error"]}, - status=503, - ) - usage_percent = usage_info.get("used_percent", 0) - remaining_percent = usage_info.get( - "remaining_percent", max(0, 100 - usage_percent) - ) - logger.info("[%s] Active account switched before response", provider_name) - + remaining_percent = int( + usage_info.get("remaining_percent", max(0, 100 - usage_percent)) + ) logger.info( "[%s] token issued, used=%s%% remaining=%s%%", provider_name, @@ -207,20 +191,10 @@ async def token_handler(request: web.Request) -> web.Response: secondary_window.get("reset_after_seconds", 0), ) - # Trigger background refresh if needed - if usage_percent >= CHATGPT_PREPARE_THRESHOLD: - if isinstance(provider, ChatGPTProvider): - await provider.ensure_next_account() - else: - trigger_background_refresh( - provider_name, - f"usage {usage_percent}% >= threshold {CHATGPT_PREPARE_THRESHOLD}%", - ) - return web.json_response( { "token": token, - "limit": build_limit(usage_percent), + "limit": build_limit(usage_percent, prepare_threshold), "usage": { "primary_window": primary_window, "secondary_window": secondary_window, @@ -229,22 +203,35 @@ async def token_handler(request: web.Request) -> web.Response: ) +async def on_startup(app: web.Application): + del app + for provider_name in PROVIDERS: + await ensure_provider_token_ready(provider_name) + + +async def on_cleanup(app: web.Application): + del app + for task in background_tasks.values(): + if task and not task.done(): + task.cancel() + pending = [t for t in background_tasks.values() if t is not None] + if pending: + await asyncio.gather(*pending, return_exceptions=True) + + def create_app() -> web.Application: app = web.Application(middlewares=[request_log_middleware]) app.on_startup.append(on_startup) - # New route: /{provider}/token + app.on_cleanup.append(on_cleanup) app.router.add_get("/{provider}/token", token_handler) - # Legacy route for backward compatibility app.router.add_get("/token", token_handler) return app if __name__ == "__main__": logger.info("Starting token service on port %s", PORT) - logger.info( - "ChatGPT prepare-next threshold: %s%%", - CHATGPT_PREPARE_THRESHOLD, - ) + logger.info("ChatGPT prepare-next threshold: %s%%", CHATGPT_PREPARE_THRESHOLD) + logger.info("ChatGPT switch threshold: %s%%", CHATGPT_SWITCH_THRESHOLD) logger.info("Available providers: %s", ", ".join(PROVIDERS.keys())) app = create_app() web.run_app(app, host="0.0.0.0", port=PORT) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..9bf27d4 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,12 @@ +import sys +from pathlib import Path + + +def _add_src_to_path() -> None: + root = Path(__file__).resolve().parents[1] + src = root / "src" + if str(src) not in sys.path: + sys.path.insert(0, str(src)) + + +_add_src_to_path() diff --git a/tests/test_registration_unit.py b/tests/test_registration_unit.py new file mode 100644 index 0000000..32fd941 --- /dev/null +++ b/tests/test_registration_unit.py @@ -0,0 +1,37 @@ +from providers.chatgpt.registration import ( + build_authorize_url, + extract_verification_code, + generate_birthdate_90s, + generate_name, +) + + +def test_generate_name_shape(): + name = generate_name() + parts = name.split(" ") + assert len(parts) == 2 + assert all(p.isalpha() for p in parts) + + +def test_generate_birthdate_90s_range(): + month, day, year = generate_birthdate_90s() + assert 1 <= int(month) <= 12 + assert 1 <= int(day) <= 28 + assert 1990 <= int(year) <= 1999 + + +def test_extract_verification_code_prefers_chatgpt_phrase(): + text = "foo 123456 bar Your ChatGPT code is 654321" + assert extract_verification_code(text) == "654321" + + +def test_extract_verification_code_fallback_last_code(): + text = "codes 111111 and 222222" + assert extract_verification_code(text) == "222222" + + +def test_build_authorize_url_contains_required_params(): + url = build_authorize_url("challenge", "state123") + assert "response_type=code" in url + assert "code_challenge=challenge" in url + assert "state=state123" in url diff --git a/tests/test_server_unit.py b/tests/test_server_unit.py new file mode 100644 index 0000000..4b74a3c --- /dev/null +++ b/tests/test_server_unit.py @@ -0,0 +1,150 @@ +import asyncio +import json + +import server +from providers.base import Provider, ProviderTokens + + +class FakeRequest: + def __init__(self, provider: str): + self.match_info = {"provider": provider} + + +class FakeProvider(Provider): + def __init__( + self, + token: str | None = "tok", + usage: dict | None = None, + rotate: bool = False, + ): + self._token = token + self._usage = usage or { + "used_percent": 10, + "remaining_percent": 90, + "primary_window": None, + "secondary_window": None, + } + self._rotate = rotate + self.get_token_calls = 0 + self.standby_calls = 0 + + @property + def name(self) -> str: + return "fake" + + async def get_token(self) -> str | None: + self.get_token_calls += 1 + return self._token + + async def register_new_account(self) -> bool: + return True + + async def get_usage_info(self, access_token: str) -> dict: + _ = access_token + return dict(self._usage) + + def load_tokens(self) -> ProviderTokens | None: + return None + + def save_tokens(self, tokens: ProviderTokens) -> None: + _ = tokens + + async def maybe_rotate_account(self, usage_percent: int) -> bool: + _ = usage_percent + return self._rotate + + async def ensure_standby_account( + self, usage_percent: int, prepare_threshold: int + ) -> None: + _ = usage_percent, prepare_threshold + self.standby_calls += 1 + + +def _response_json(resp) -> dict: + return json.loads(resp.body.decode("utf-8")) + + +def test_parse_int_env_defaults(monkeypatch): + monkeypatch.delenv("X_TEST", raising=False) + assert server._parse_int_env("X_TEST", 10, 1, 20) == 10 + + +def test_parse_int_env_invalid(monkeypatch): + monkeypatch.setenv("X_TEST", "abc") + assert server._parse_int_env("X_TEST", 10, 1, 20) == 10 + + +def test_parse_int_env_out_of_range(monkeypatch): + monkeypatch.setenv("X_TEST", "999") + assert server._parse_int_env("X_TEST", 10, 1, 20) == 10 + + +def test_build_limit_fields(): + limit = server.build_limit(90, 85) + assert limit == { + "used_percent": 90, + "remaining_percent": 10, + "exhausted": False, + "needs_prepare": True, + } + + +def test_get_prepare_threshold(): + assert server.get_prepare_threshold("chatgpt") == server.CHATGPT_PREPARE_THRESHOLD + assert server.get_prepare_threshold("unknown") == 100 + + +def test_token_handler_unknown_provider(monkeypatch): + monkeypatch.setattr(server, "PROVIDERS", {}) + resp = asyncio.run(server.token_handler(FakeRequest("missing"))) + assert resp.status == 404 + + +def test_token_handler_success(monkeypatch): + provider = FakeProvider() + monkeypatch.setattr(server, "PROVIDERS", {"fake": provider}) + monkeypatch.setattr(server, "background_tasks", {"fake": None}) + monkeypatch.setattr(server, "get_prepare_threshold", lambda _: 80) + + resp = asyncio.run(server.token_handler(FakeRequest("fake"))) + data = _response_json(resp) + + assert resp.status == 200 + assert data["token"] == "tok" + assert data["limit"]["needs_prepare"] is False + + +def test_token_handler_triggers_standby(monkeypatch): + provider = FakeProvider(usage={"used_percent": 90, "remaining_percent": 10}) + monkeypatch.setattr(server, "PROVIDERS", {"fake": provider}) + monkeypatch.setattr(server, "background_tasks", {"fake": None}) + + called = {"value": False} + + def fake_trigger(name, usage_percent, reason): + assert name == "fake" + assert usage_percent == 90 + assert "threshold" in reason + called["value"] = True + + monkeypatch.setattr(server, "get_prepare_threshold", lambda _: 80) + monkeypatch.setattr(server, "trigger_standby_prepare", fake_trigger) + + resp = asyncio.run(server.token_handler(FakeRequest("fake"))) + assert resp.status == 200 + assert called["value"] is True + + +def test_token_handler_rotation_path(monkeypatch): + provider = FakeProvider( + usage={"used_percent": 96, "remaining_percent": 4}, + rotate=True, + ) + monkeypatch.setattr(server, "PROVIDERS", {"fake": provider}) + monkeypatch.setattr(server, "background_tasks", {"fake": None}) + monkeypatch.setattr(server, "get_prepare_threshold", lambda _: 80) + monkeypatch.setattr(server, "trigger_standby_prepare", lambda *_: None) + + resp = asyncio.run(server.token_handler(FakeRequest("fake"))) + assert resp.status == 200 + assert provider.get_token_calls >= 2 diff --git a/tests/test_tokens_unit.py b/tests/test_tokens_unit.py new file mode 100644 index 0000000..58f5af5 --- /dev/null +++ b/tests/test_tokens_unit.py @@ -0,0 +1,60 @@ +import json +from pathlib import Path + +from providers.base import ProviderTokens +from providers.chatgpt import tokens as t + + +def test_normalize_state_backward_compatible(): + raw = {"access_token": "a", "refresh_token": "r", "expires_at": 1} + normalized = t._normalize_state(raw) + assert normalized["active"]["access_token"] == "a" + assert normalized["next_account"] is None + + +def test_promote_next_tokens(tmp_path, monkeypatch): + file_path = tmp_path / "chatgpt_tokens.json" + monkeypatch.setattr(t, "TOKENS_FILE", file_path) + + active = ProviderTokens("a1", "r1", 100) + nxt = ProviderTokens("a2", "r2", 200) + t.save_state(active, nxt) + + assert t.promote_next_tokens() is True + cur, next_cur = t.load_state() + assert cur is not None + assert cur.access_token == "a2" + assert next_cur is None + + +def test_save_tokens_preserves_next(tmp_path, monkeypatch): + file_path = tmp_path / "chatgpt_tokens.json" + monkeypatch.setattr(t, "TOKENS_FILE", file_path) + + active = ProviderTokens("a1", "r1", 100) + nxt = ProviderTokens("a2", "r2", 200) + t.save_state(active, nxt) + + t.save_tokens(ProviderTokens("a3", "r3", 300)) + cur, next_cur = t.load_state() + assert cur is not None and cur.access_token == "a3" + assert next_cur is not None and next_cur.access_token == "a2" + + +def test_atomic_write_produces_valid_json(tmp_path, monkeypatch): + file_path = tmp_path / "chatgpt_tokens.json" + monkeypatch.setattr(t, "TOKENS_FILE", file_path) + + t.save_state(ProviderTokens("x", "y", 123), None) + with open(file_path) as f: + data = json.load(f) + assert "active" in data + assert data["active"]["access_token"] == "x" + + +def test_load_state_from_missing_file(tmp_path, monkeypatch): + file_path = tmp_path / "missing.json" + monkeypatch.setattr(t, "TOKENS_FILE", file_path) + active, nxt = t.load_state() + assert active is None + assert nxt is None diff --git a/tests/test_usage_unit.py b/tests/test_usage_unit.py new file mode 100644 index 0000000..63cad99 --- /dev/null +++ b/tests/test_usage_unit.py @@ -0,0 +1,32 @@ +from providers.chatgpt.usage import _parse_window, clamp_percent + + +def test_clamp_percent_bounds(): + assert clamp_percent(-1) == 0 + assert clamp_percent(150) == 100 + assert clamp_percent(49.6) == 50 + + +def test_clamp_percent_invalid(): + assert clamp_percent(None) == 0 + assert clamp_percent("bad") == 0 + + +def test_parse_window_valid(): + window = { + "used_percent": 34.4, + "limit_window_seconds": 3600, + "reset_after_seconds": 120, + "reset_at": 999, + } + parsed = _parse_window(window) + assert parsed == { + "used_percent": 34, + "limit_window_seconds": 3600, + "reset_after_seconds": 120, + "reset_at": 999, + } + + +def test_parse_window_none(): + assert _parse_window(None) is None