From 0a71012709bae8e0d7ec80c8ff0279b3a8c0990d Mon Sep 17 00:00:00 2001 From: "Arthur K." Date: Mon, 2 Mar 2026 19:33:43 +0300 Subject: [PATCH] refactor: some minor cleanup --- .gitignore | 1 + Dockerfile | 3 + src/email_providers/base.py | 2 +- src/email_providers/mail_tm.py | 11 +-- src/email_providers/temp_mail_org.py | 8 +- src/email_providers/ten_minute_mail.py | 8 +- src/email_providers/utils.py | 10 +++ src/providers/base.py | 16 +++- src/providers/chatgpt/provider.py | 28 +++++-- src/providers/chatgpt/registration.py | 15 ++-- src/server.py | 101 ++++++++++++++----------- src/utils/__init__.py | 4 + src/utils/env.py | 22 ++++++ src/utils/randoms.py | 13 ++++ tests/test_server_unit.py | 37 ++++++--- 15 files changed, 188 insertions(+), 91 deletions(-) create mode 100644 src/email_providers/utils.py create mode 100644 src/utils/__init__.py create mode 100644 src/utils/env.py create mode 100644 src/utils/randoms.py diff --git a/.gitignore b/.gitignore index afa1ad1..701c6dd 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ __pycache__/ .ruff_cache/ .venv/ +.pytest_cache/ diff --git a/Dockerfile b/Dockerfile index 700b8ab..e08d4ba 100644 --- a/Dockerfile +++ b/Dockerfile @@ -28,6 +28,9 @@ VOLUME ["/data"] EXPOSE 80 +HEALTHCHECK --start-period=5s --start-interval=1s CMD \ + test "$(curl -fsS "http://127.0.0.1:$PORT/health")" = "ok" + STOPSIGNAL SIGINT CMD ["/entrypoint.sh"] diff --git a/src/email_providers/base.py b/src/email_providers/base.py index 386601d..878c7e6 100644 --- a/src/email_providers/base.py +++ b/src/email_providers/base.py @@ -12,5 +12,5 @@ class BaseProvider(ABC): pass @abstractmethod - async def get_latest_message(self, email: str) -> str | None: + async def get_latest_message(self) -> str | None: pass diff --git a/src/email_providers/mail_tm.py b/src/email_providers/mail_tm.py index 891a789..4dd9c6c 100644 --- a/src/email_providers/mail_tm.py +++ b/src/email_providers/mail_tm.py @@ -9,6 +9,7 @@ import aiohttp from playwright.async_api import BrowserContext from .base import BaseProvider +from utils.randoms import generate_password logger = logging.getLogger(__name__) @@ -51,11 +52,6 @@ def _generate_local_part() -> str: return f"{first}{last}{digits}" -def _generate_password(length: int = 24) -> str: - alphabet = string.ascii_letters + string.digits - return "".join(secrets.choice(alphabet) for _ in range(length)) - - class MailTmProvider(BaseProvider): def __init__(self, browser_session: BrowserContext): super().__init__(browser_session) @@ -146,7 +142,7 @@ class MailTmProvider(BaseProvider): for _ in range(8): domain = secrets.choice(domains) address = f"{_generate_local_part()}@{domain}" - password = _generate_password() + password = generate_password(length=24) created = await self._create_account(address, password) if not created: @@ -210,8 +206,7 @@ class MailTmProvider(BaseProvider): text = "\n".join(str(part) for part in parts if part) return text or None - async def get_latest_message(self, email: str) -> str | None: - del email + async def get_latest_message(self) -> str | None: if not self._token: raise RuntimeError("mail.tm provider is not initialized with mailbox token") diff --git a/src/email_providers/temp_mail_org.py b/src/email_providers/temp_mail_org.py index 0a0b65f..25d62cb 100644 --- a/src/email_providers/temp_mail_org.py +++ b/src/email_providers/temp_mail_org.py @@ -5,6 +5,7 @@ import re from playwright.async_api import BrowserContext, Error as PlaywrightError, Page from .base import BaseProvider +from .utils import ensure_page logger = logging.getLogger(__name__) @@ -15,8 +16,7 @@ class TempMailOrgProvider(BaseProvider): self.page: Page | None = None async def _ensure_page(self) -> Page: - if self.page is None or self.page.is_closed(): - self.page = await self.browser_session.new_page() + self.page = await ensure_page(self.browser_session, self.page) return self.page async def get_new_email(self) -> str: @@ -60,9 +60,9 @@ class TempMailOrgProvider(BaseProvider): raise RuntimeError("Could not get temp email from temp-mail.org") - async def get_latest_message(self, email: str) -> str | None: + async def get_latest_message(self) -> str | None: page = await self._ensure_page() - logger.info("[temp-mail.org] Waiting for latest message for %s", email) + logger.info("[temp-mail.org] Waiting for latest message") if page.is_closed(): raise RuntimeError("temp-mail.org tab was closed unexpectedly") diff --git a/src/email_providers/ten_minute_mail.py b/src/email_providers/ten_minute_mail.py index ae4c35e..1ff38a6 100644 --- a/src/email_providers/ten_minute_mail.py +++ b/src/email_providers/ten_minute_mail.py @@ -4,6 +4,7 @@ import logging from playwright.async_api import BrowserContext, Page from .base import BaseProvider +from .utils import ensure_page logger = logging.getLogger(__name__) @@ -14,8 +15,7 @@ class TenMinuteMailProvider(BaseProvider): self.page: Page | None = None async def _ensure_page(self) -> Page: - if self.page is None or self.page.is_closed(): - self.page = await self.browser_session.new_page() + self.page = await ensure_page(self.browser_session, self.page) return self.page async def get_new_email(self) -> str: @@ -34,9 +34,9 @@ class TenMinuteMailProvider(BaseProvider): logger.info("[10min] New email acquired: %s", email) return email - async def get_latest_message(self, email: str) -> str | None: + async def get_latest_message(self) -> str | None: page = await self._ensure_page() - logger.info("[10min] Waiting for latest message for %s", email) + logger.info("[10min] Waiting for latest message") seen_count = 0 for attempt in range(60): diff --git a/src/email_providers/utils.py b/src/email_providers/utils.py new file mode 100644 index 0000000..c6d44ce --- /dev/null +++ b/src/email_providers/utils.py @@ -0,0 +1,10 @@ +from playwright.async_api import BrowserContext, Page + + +async def ensure_page( + browser_session: BrowserContext, + page: Page | None, +) -> Page: + if page is None or page.is_closed(): + return await browser_session.new_page() + return page diff --git a/src/providers/base.py b/src/providers/base.py index 3185c5b..d1deb94 100644 --- a/src/providers/base.py +++ b/src/providers/base.py @@ -61,10 +61,24 @@ class Provider(ABC): """Rotate active account/token if provider policy requires it.""" return False + @property + def prepare_threshold(self) -> int: + """Usage percent when provider should prepare standby account/token.""" + return 100 + + @property + def switch_threshold(self) -> int | None: + """Usage percent when provider may switch active account/token.""" + return None + + def should_prepare_standby(self, usage_percent: int) -> bool: + """Whether standby preparation should be triggered for current usage.""" + _ = usage_percent + return False + async def ensure_standby_account( self, usage_percent: int, - prepare_threshold: int, ) -> None: """Prepare standby account/token asynchronously when needed.""" return None diff --git a/src/providers/chatgpt/provider.py b/src/providers/chatgpt/provider.py index 445500f..67ffa88 100644 --- a/src/providers/chatgpt/provider.py +++ b/src/providers/chatgpt/provider.py @@ -1,6 +1,5 @@ import asyncio import logging -import os from typing import Any from typing import Callable @@ -9,6 +8,7 @@ from playwright.async_api import BrowserContext from email_providers import BaseProvider from email_providers import MailTmProvider from providers.base import Provider, ProviderTokens +from utils.env import parse_int_env from .tokens import ( clear_next_tokens, load_next_tokens, @@ -24,7 +24,13 @@ from .registration import register_chatgpt_account logger = logging.getLogger(__name__) CHATGPT_REGISTRATION_MAX_ATTEMPTS = 4 -CHATGPT_SWITCH_THRESHOLD = int(os.environ.get("CHATGPT_SWITCH_THRESHOLD", "95")) +CHATGPT_PREPARE_THRESHOLD = parse_int_env("CHATGPT_PREPARE_THRESHOLD", 85, 0, 100) +CHATGPT_SWITCH_THRESHOLD = parse_int_env( + "CHATGPT_SWITCH_THRESHOLD", + 95, + 0, + 100, +) class ChatGPTProvider(Provider): @@ -37,6 +43,14 @@ class ChatGPTProvider(Provider): self.email_provider_factory = email_provider_factory or MailTmProvider self._token_write_lock = asyncio.Lock() + @property + def prepare_threshold(self) -> int: + return CHATGPT_PREPARE_THRESHOLD + + @property + def switch_threshold(self) -> int | None: + return CHATGPT_SWITCH_THRESHOLD + async def _register_with_retries(self) -> bool: for attempt in range(1, CHATGPT_REGISTRATION_MAX_ATTEMPTS + 1): logger.info( @@ -99,21 +113,23 @@ class ChatGPTProvider(Provider): async def ensure_next_account(self) -> bool: next_tokens = load_next_tokens() - if next_tokens and not next_tokens.is_expired: + if next_tokens: return True async with self._token_write_lock: next_tokens = load_next_tokens() - if next_tokens and not next_tokens.is_expired: + if next_tokens: return True return await self._create_next_account_under_lock() + def should_prepare_standby(self, usage_percent: int) -> bool: + return usage_percent >= self.prepare_threshold and not bool(load_next_tokens()) + async def ensure_standby_account( self, usage_percent: int, - prepare_threshold: int, ) -> None: - if usage_percent >= prepare_threshold: + if usage_percent >= self.prepare_threshold: await self.ensure_next_account() async def maybe_switch_active_account(self, usage_percent: int) -> bool: diff --git a/src/providers/chatgpt/registration.py b/src/providers/chatgpt/registration.py index 270c6d3..3087078 100644 --- a/src/providers/chatgpt/registration.py +++ b/src/providers/chatgpt/registration.py @@ -5,7 +5,6 @@ import logging import random import re import secrets -import string import time from datetime import datetime from pathlib import Path @@ -24,6 +23,7 @@ from playwright.async_api import ( from browser import launch as launch_browser from email_providers import BaseProvider from providers.base import ProviderTokens +from utils.randoms import generate_password from .tokens import CLIENT_ID logger = logging.getLogger(__name__) @@ -56,11 +56,6 @@ async def save_error_screenshot(page: Page | None, step: str): logger.warning("Failed to save screenshot at step %s: %s", step, e) -def generate_password(length: int = 20) -> str: - alphabet = string.ascii_letters + string.digits - return "".join(random.choice(alphabet) for _ in range(length)) - - def generate_name() -> str: first_names = [ "James", @@ -255,8 +250,8 @@ async def exchange_code_for_tokens(code: str, verifier: str) -> ProviderTokens: raise RuntimeError(f"Token exchange response parse error: {e}") from e -async def get_latest_code(email_provider: BaseProvider, email: str) -> str | None: - message = await email_provider.get_latest_message(email) +async def get_latest_code(email_provider: BaseProvider) -> str | None: + message = await email_provider.get_latest_message() if not message: return None return extract_verification_code(message) @@ -403,7 +398,7 @@ async def register_chatgpt_account( ) logger.info("[3/5] Getting verification message from email provider...") - code = await get_latest_code(email_provider, email) + code = await get_latest_code(email_provider) if not code: raise AutomationError( "email_provider", "Email provider returned no verification message" @@ -472,7 +467,7 @@ async def register_chatgpt_account( if await oauth_needs_email_check(oauth_page): logger.info("OAuth requested email confirmation code") - new_code = await get_latest_code(email_provider, email) + new_code = await get_latest_code(email_provider) if new_code and new_code != last_oauth_email_code: filled = await fill_oauth_code_if_present(oauth_page, new_code) if filled: diff --git a/src/server.py b/src/server.py index 7740c37..39fc3ca 100644 --- a/src/server.py +++ b/src/server.py @@ -1,41 +1,15 @@ import asyncio import logging -import os -from aiohttp import web +from aiohttp import web, web_log from providers.chatgpt import ChatGPTProvider -from providers.chatgpt.provider import CHATGPT_SWITCH_THRESHOLD from providers.base import Provider +from utils.env import parse_int_env logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) - - -def _parse_int_env(name: str, default: int, minimum: int, maximum: int) -> int: - raw = os.environ.get(name) - if raw is None: - return default - try: - value = int(raw) - except ValueError: - logger.warning("Invalid %s=%r, using default %s", name, raw, default) - return default - if value < minimum or value > maximum: - logger.warning( - "%s=%s out of range [%s,%s], using default %s", - name, - value, - minimum, - maximum, - default, - ) - return default - return value - - -PORT = _parse_int_env("PORT", 8080, 1, 65535) -CHATGPT_PREPARE_THRESHOLD = _parse_int_env("CHATGPT_PREPARE_THRESHOLD", 85, 0, 100) +PORT = parse_int_env("PORT", 8080, 1, 65535) LIMIT_EXHAUSTED_PERCENT = 100 PROVIDERS: dict[str, Provider] = { @@ -45,11 +19,13 @@ PROVIDERS: dict[str, Provider] = { background_tasks: dict[str, asyncio.Task | None] = {name: None for name in PROVIDERS} -@web.middleware -async def request_log_middleware(request: web.Request, handler): - response = await handler(request) - logger.info("%s %s -> %s", request.method, request.path_qs, response.status) - return response +class AccessLogger(web_log.AccessLogger): + def log( + self, request: web.BaseRequest, response: web.StreamResponse, time: float + ) -> None: + if request.path == "/health": + return + super().log(request, response, time) def build_limit(usage_percent: int, prepare_threshold: int) -> dict[str, int | bool]: @@ -63,9 +39,17 @@ def build_limit(usage_percent: int, prepare_threshold: int) -> dict[str, int | b def get_prepare_threshold(provider_name: str) -> int: - if provider_name == "chatgpt": - return CHATGPT_PREPARE_THRESHOLD - return 100 + provider = PROVIDERS.get(provider_name) + if not provider: + return 100 + return provider.prepare_threshold + + +def should_trigger_standby_prepare(provider_name: str, usage_percent: int) -> bool: + provider = PROVIDERS.get(provider_name) + if not provider: + return False + return provider.should_prepare_standby(usage_percent) async def ensure_provider_token_ready(provider_name: str): @@ -103,10 +87,13 @@ async def ensure_standby_task(provider_name: str, usage_percent: int, reason: st provider = PROVIDERS.get(provider_name) if not provider: return + + if not provider.should_prepare_standby(usage_percent): + return + try: logger.info("[%s] Preparing standby in background (%s)", provider_name, reason) - threshold = get_prepare_threshold(provider_name) - await provider.ensure_standby_account(usage_percent, threshold) + await provider.ensure_standby_account(usage_percent) except Exception: logger.exception("[%s] Unhandled standby preparation error", provider_name) @@ -155,11 +142,11 @@ async def token_handler(request: web.Request) -> web.Response: logger.info("[%s] Active account switched before response", provider_name) prepare_threshold = get_prepare_threshold(provider_name) - if usage_percent >= prepare_threshold: + if should_trigger_standby_prepare(provider_name, usage_percent): trigger_standby_prepare( provider_name, usage_percent, - f"usage {usage_percent}% >= threshold {prepare_threshold}%", + f"usage {usage_percent}% reached standby policy", ) remaining_percent = int( @@ -203,6 +190,11 @@ async def token_handler(request: web.Request) -> web.Response: ) +async def health_handler(request: web.Request) -> web.Response: + del request + return web.Response(text="ok") + + async def on_startup(app: web.Application): del app for provider_name in PROVIDERS: @@ -220,18 +212,35 @@ async def on_cleanup(app: web.Application): def create_app() -> web.Application: - app = web.Application(middlewares=[request_log_middleware]) + app = web.Application() app.on_startup.append(on_startup) app.on_cleanup.append(on_cleanup) + app.router.add_get("/health", health_handler) app.router.add_get("/{provider}/token", token_handler) app.router.add_get("/token", token_handler) return app -if __name__ == "__main__": +def main(): logger.info("Starting token service on port %s", PORT) - logger.info("ChatGPT prepare-next threshold: %s%%", CHATGPT_PREPARE_THRESHOLD) - logger.info("ChatGPT switch threshold: %s%%", CHATGPT_SWITCH_THRESHOLD) + chatgpt_provider = PROVIDERS.get("chatgpt") + if chatgpt_provider: + logger.info( + "ChatGPT prepare-next threshold: %s%%", chatgpt_provider.prepare_threshold + ) + if chatgpt_provider.switch_threshold is not None: + logger.info( + "ChatGPT switch threshold: %s%%", chatgpt_provider.switch_threshold + ) logger.info("Available providers: %s", ", ".join(PROVIDERS.keys())) app = create_app() - web.run_app(app, host="0.0.0.0", port=PORT) + web.run_app( + app, + host="0.0.0.0", + port=PORT, + access_log_class=AccessLogger, + ) + + +if __name__ == "__main__": + main() diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..00303ef --- /dev/null +++ b/src/utils/__init__.py @@ -0,0 +1,4 @@ +from .env import parse_int_env +from .randoms import generate_password + +__all__ = ["parse_int_env", "generate_password"] diff --git a/src/utils/env.py b/src/utils/env.py new file mode 100644 index 0000000..688e65f --- /dev/null +++ b/src/utils/env.py @@ -0,0 +1,22 @@ +import os + + +def parse_int_env( + name: str, + default: int, + minimum: int, + maximum: int, +) -> int: + raw = os.environ.get(name) + if raw is None: + return default + + try: + value = int(raw) + except ValueError: + return default + + if value < minimum or value > maximum: + return default + + return value diff --git a/src/utils/randoms.py b/src/utils/randoms.py new file mode 100644 index 0000000..76adab8 --- /dev/null +++ b/src/utils/randoms.py @@ -0,0 +1,13 @@ +import random +import secrets +import string + + +def generate_password( + length: int = 20, + *, + secure: bool = True, +) -> str: + alphabet = string.ascii_letters + string.digits + chooser = secrets.choice if secure else random.choice + return "".join(chooser(alphabet) for _ in range(length)) diff --git a/tests/test_server_unit.py b/tests/test_server_unit.py index 4b74a3c..864d374 100644 --- a/tests/test_server_unit.py +++ b/tests/test_server_unit.py @@ -3,6 +3,7 @@ import json import server from providers.base import Provider, ProviderTokens +from utils.env import parse_int_env class FakeRequest: @@ -25,9 +26,17 @@ class FakeProvider(Provider): "secondary_window": None, } self._rotate = rotate + self._prepare_threshold = 80 self.get_token_calls = 0 self.standby_calls = 0 + @property + def prepare_threshold(self) -> int: + return self._prepare_threshold + + def should_prepare_standby(self, usage_percent: int) -> bool: + return usage_percent >= self.prepare_threshold + @property def name(self) -> str: return "fake" @@ -54,9 +63,10 @@ class FakeProvider(Provider): return self._rotate async def ensure_standby_account( - self, usage_percent: int, prepare_threshold: int + self, + usage_percent: int, ) -> None: - _ = usage_percent, prepare_threshold + _ = usage_percent self.standby_calls += 1 @@ -66,17 +76,17 @@ def _response_json(resp) -> dict: def test_parse_int_env_defaults(monkeypatch): monkeypatch.delenv("X_TEST", raising=False) - assert server._parse_int_env("X_TEST", 10, 1, 20) == 10 + assert parse_int_env("X_TEST", 10, 1, 20) == 10 def test_parse_int_env_invalid(monkeypatch): monkeypatch.setenv("X_TEST", "abc") - assert server._parse_int_env("X_TEST", 10, 1, 20) == 10 + assert parse_int_env("X_TEST", 10, 1, 20) == 10 def test_parse_int_env_out_of_range(monkeypatch): monkeypatch.setenv("X_TEST", "999") - assert server._parse_int_env("X_TEST", 10, 1, 20) == 10 + assert parse_int_env("X_TEST", 10, 1, 20) == 10 def test_build_limit_fields(): @@ -90,7 +100,10 @@ def test_build_limit_fields(): def test_get_prepare_threshold(): - assert server.get_prepare_threshold("chatgpt") == server.CHATGPT_PREPARE_THRESHOLD + assert ( + server.get_prepare_threshold("chatgpt") + == server.PROVIDERS["chatgpt"].prepare_threshold + ) assert server.get_prepare_threshold("unknown") == 100 @@ -100,12 +113,16 @@ def test_token_handler_unknown_provider(monkeypatch): assert resp.status == 404 +def test_health_handler_ok(): + resp = asyncio.run(server.health_handler(object())) + assert resp.status == 200 + assert resp.text == "ok" + + def test_token_handler_success(monkeypatch): provider = FakeProvider() monkeypatch.setattr(server, "PROVIDERS", {"fake": provider}) monkeypatch.setattr(server, "background_tasks", {"fake": None}) - monkeypatch.setattr(server, "get_prepare_threshold", lambda _: 80) - resp = asyncio.run(server.token_handler(FakeRequest("fake"))) data = _response_json(resp) @@ -124,10 +141,9 @@ def test_token_handler_triggers_standby(monkeypatch): def fake_trigger(name, usage_percent, reason): assert name == "fake" assert usage_percent == 90 - assert "threshold" in reason + assert "standby policy" in reason called["value"] = True - monkeypatch.setattr(server, "get_prepare_threshold", lambda _: 80) monkeypatch.setattr(server, "trigger_standby_prepare", fake_trigger) resp = asyncio.run(server.token_handler(FakeRequest("fake"))) @@ -142,7 +158,6 @@ def test_token_handler_rotation_path(monkeypatch): ) monkeypatch.setattr(server, "PROVIDERS", {"fake": provider}) monkeypatch.setattr(server, "background_tasks", {"fake": None}) - monkeypatch.setattr(server, "get_prepare_threshold", lambda _: 80) monkeypatch.setattr(server, "trigger_standby_prepare", lambda *_: None) resp = asyncio.run(server.token_handler(FakeRequest("fake")))