diff --git a/src/email_providers/base.py b/src/email_providers/base.py index 386601d..878c7e6 100644 --- a/src/email_providers/base.py +++ b/src/email_providers/base.py @@ -12,5 +12,5 @@ class BaseProvider(ABC): pass @abstractmethod - async def get_latest_message(self, email: str) -> str | None: + async def get_latest_message(self) -> str | None: pass diff --git a/src/email_providers/mail_tm.py b/src/email_providers/mail_tm.py index 891a789..4dd9c6c 100644 --- a/src/email_providers/mail_tm.py +++ b/src/email_providers/mail_tm.py @@ -9,6 +9,7 @@ import aiohttp from playwright.async_api import BrowserContext from .base import BaseProvider +from utils.randoms import generate_password logger = logging.getLogger(__name__) @@ -51,11 +52,6 @@ def _generate_local_part() -> str: return f"{first}{last}{digits}" -def _generate_password(length: int = 24) -> str: - alphabet = string.ascii_letters + string.digits - return "".join(secrets.choice(alphabet) for _ in range(length)) - - class MailTmProvider(BaseProvider): def __init__(self, browser_session: BrowserContext): super().__init__(browser_session) @@ -146,7 +142,7 @@ class MailTmProvider(BaseProvider): for _ in range(8): domain = secrets.choice(domains) address = f"{_generate_local_part()}@{domain}" - password = _generate_password() + password = generate_password(length=24) created = await self._create_account(address, password) if not created: @@ -210,8 +206,7 @@ class MailTmProvider(BaseProvider): text = "\n".join(str(part) for part in parts if part) return text or None - async def get_latest_message(self, email: str) -> str | None: - del email + async def get_latest_message(self) -> str | None: if not self._token: raise RuntimeError("mail.tm provider is not initialized with mailbox token") diff --git a/src/email_providers/temp_mail_org.py b/src/email_providers/temp_mail_org.py index 0a0b65f..25d62cb 100644 --- a/src/email_providers/temp_mail_org.py +++ b/src/email_providers/temp_mail_org.py @@ -5,6 +5,7 @@ import re from playwright.async_api import BrowserContext, Error as PlaywrightError, Page from .base import BaseProvider +from .utils import ensure_page logger = logging.getLogger(__name__) @@ -15,8 +16,7 @@ class TempMailOrgProvider(BaseProvider): self.page: Page | None = None async def _ensure_page(self) -> Page: - if self.page is None or self.page.is_closed(): - self.page = await self.browser_session.new_page() + self.page = await ensure_page(self.browser_session, self.page) return self.page async def get_new_email(self) -> str: @@ -60,9 +60,9 @@ class TempMailOrgProvider(BaseProvider): raise RuntimeError("Could not get temp email from temp-mail.org") - async def get_latest_message(self, email: str) -> str | None: + async def get_latest_message(self) -> str | None: page = await self._ensure_page() - logger.info("[temp-mail.org] Waiting for latest message for %s", email) + logger.info("[temp-mail.org] Waiting for latest message") if page.is_closed(): raise RuntimeError("temp-mail.org tab was closed unexpectedly") diff --git a/src/email_providers/ten_minute_mail.py b/src/email_providers/ten_minute_mail.py index ae4c35e..1ff38a6 100644 --- a/src/email_providers/ten_minute_mail.py +++ b/src/email_providers/ten_minute_mail.py @@ -4,6 +4,7 @@ import logging from playwright.async_api import BrowserContext, Page from .base import BaseProvider +from .utils import ensure_page logger = logging.getLogger(__name__) @@ -14,8 +15,7 @@ class TenMinuteMailProvider(BaseProvider): self.page: Page | None = None async def _ensure_page(self) -> Page: - if self.page is None or self.page.is_closed(): - self.page = await self.browser_session.new_page() + self.page = await ensure_page(self.browser_session, self.page) return self.page async def get_new_email(self) -> str: @@ -34,9 +34,9 @@ class TenMinuteMailProvider(BaseProvider): logger.info("[10min] New email acquired: %s", email) return email - async def get_latest_message(self, email: str) -> str | None: + async def get_latest_message(self) -> str | None: page = await self._ensure_page() - logger.info("[10min] Waiting for latest message for %s", email) + logger.info("[10min] Waiting for latest message") seen_count = 0 for attempt in range(60): diff --git a/src/email_providers/utils.py b/src/email_providers/utils.py new file mode 100644 index 0000000..c6d44ce --- /dev/null +++ b/src/email_providers/utils.py @@ -0,0 +1,10 @@ +from playwright.async_api import BrowserContext, Page + + +async def ensure_page( + browser_session: BrowserContext, + page: Page | None, +) -> Page: + if page is None or page.is_closed(): + return await browser_session.new_page() + return page diff --git a/src/providers/chatgpt/provider.py b/src/providers/chatgpt/provider.py index 445500f..e171b51 100644 --- a/src/providers/chatgpt/provider.py +++ b/src/providers/chatgpt/provider.py @@ -1,6 +1,5 @@ import asyncio import logging -import os from typing import Any from typing import Callable @@ -9,6 +8,7 @@ from playwright.async_api import BrowserContext from email_providers import BaseProvider from email_providers import MailTmProvider from providers.base import Provider, ProviderTokens +from utils.env import parse_int_env from .tokens import ( clear_next_tokens, load_next_tokens, @@ -24,7 +24,13 @@ from .registration import register_chatgpt_account logger = logging.getLogger(__name__) CHATGPT_REGISTRATION_MAX_ATTEMPTS = 4 -CHATGPT_SWITCH_THRESHOLD = int(os.environ.get("CHATGPT_SWITCH_THRESHOLD", "95")) +CHATGPT_PREPARE_THRESHOLD = parse_int_env("CHATGPT_PREPARE_THRESHOLD", 85, 0, 100) +CHATGPT_SWITCH_THRESHOLD = parse_int_env( + "CHATGPT_SWITCH_THRESHOLD", + 95, + 0, + 100, +) class ChatGPTProvider(Provider): diff --git a/src/providers/chatgpt/registration.py b/src/providers/chatgpt/registration.py index 270c6d3..3087078 100644 --- a/src/providers/chatgpt/registration.py +++ b/src/providers/chatgpt/registration.py @@ -5,7 +5,6 @@ import logging import random import re import secrets -import string import time from datetime import datetime from pathlib import Path @@ -24,6 +23,7 @@ from playwright.async_api import ( from browser import launch as launch_browser from email_providers import BaseProvider from providers.base import ProviderTokens +from utils.randoms import generate_password from .tokens import CLIENT_ID logger = logging.getLogger(__name__) @@ -56,11 +56,6 @@ async def save_error_screenshot(page: Page | None, step: str): logger.warning("Failed to save screenshot at step %s: %s", step, e) -def generate_password(length: int = 20) -> str: - alphabet = string.ascii_letters + string.digits - return "".join(random.choice(alphabet) for _ in range(length)) - - def generate_name() -> str: first_names = [ "James", @@ -255,8 +250,8 @@ async def exchange_code_for_tokens(code: str, verifier: str) -> ProviderTokens: raise RuntimeError(f"Token exchange response parse error: {e}") from e -async def get_latest_code(email_provider: BaseProvider, email: str) -> str | None: - message = await email_provider.get_latest_message(email) +async def get_latest_code(email_provider: BaseProvider) -> str | None: + message = await email_provider.get_latest_message() if not message: return None return extract_verification_code(message) @@ -403,7 +398,7 @@ async def register_chatgpt_account( ) logger.info("[3/5] Getting verification message from email provider...") - code = await get_latest_code(email_provider, email) + code = await get_latest_code(email_provider) if not code: raise AutomationError( "email_provider", "Email provider returned no verification message" @@ -472,7 +467,7 @@ async def register_chatgpt_account( if await oauth_needs_email_check(oauth_page): logger.info("OAuth requested email confirmation code") - new_code = await get_latest_code(email_provider, email) + new_code = await get_latest_code(email_provider) if new_code and new_code != last_oauth_email_code: filled = await fill_oauth_code_if_present(oauth_page, new_code) if filled: diff --git a/src/server.py b/src/server.py index 7740c37..6e42ecc 100644 --- a/src/server.py +++ b/src/server.py @@ -5,37 +5,16 @@ import os from aiohttp import web from providers.chatgpt import ChatGPTProvider -from providers.chatgpt.provider import CHATGPT_SWITCH_THRESHOLD +from providers.chatgpt.provider import ( + CHATGPT_PREPARE_THRESHOLD, + CHATGPT_SWITCH_THRESHOLD, +) from providers.base import Provider +from utils.env import parse_int_env logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) - - -def _parse_int_env(name: str, default: int, minimum: int, maximum: int) -> int: - raw = os.environ.get(name) - if raw is None: - return default - try: - value = int(raw) - except ValueError: - logger.warning("Invalid %s=%r, using default %s", name, raw, default) - return default - if value < minimum or value > maximum: - logger.warning( - "%s=%s out of range [%s,%s], using default %s", - name, - value, - minimum, - maximum, - default, - ) - return default - return value - - -PORT = _parse_int_env("PORT", 8080, 1, 65535) -CHATGPT_PREPARE_THRESHOLD = _parse_int_env("CHATGPT_PREPARE_THRESHOLD", 85, 0, 100) +PORT = parse_int_env("PORT", 8080, 1, 65535) LIMIT_EXHAUSTED_PERCENT = 100 PROVIDERS: dict[str, Provider] = { diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..00303ef --- /dev/null +++ b/src/utils/__init__.py @@ -0,0 +1,4 @@ +from .env import parse_int_env +from .randoms import generate_password + +__all__ = ["parse_int_env", "generate_password"] diff --git a/src/utils/env.py b/src/utils/env.py new file mode 100644 index 0000000..688e65f --- /dev/null +++ b/src/utils/env.py @@ -0,0 +1,22 @@ +import os + + +def parse_int_env( + name: str, + default: int, + minimum: int, + maximum: int, +) -> int: + raw = os.environ.get(name) + if raw is None: + return default + + try: + value = int(raw) + except ValueError: + return default + + if value < minimum or value > maximum: + return default + + return value diff --git a/src/utils/randoms.py b/src/utils/randoms.py new file mode 100644 index 0000000..76adab8 --- /dev/null +++ b/src/utils/randoms.py @@ -0,0 +1,13 @@ +import random +import secrets +import string + + +def generate_password( + length: int = 20, + *, + secure: bool = True, +) -> str: + alphabet = string.ascii_letters + string.digits + chooser = secrets.choice if secure else random.choice + return "".join(chooser(alphabet) for _ in range(length)) diff --git a/tests/test_server_unit.py b/tests/test_server_unit.py index 4b74a3c..bfe06d4 100644 --- a/tests/test_server_unit.py +++ b/tests/test_server_unit.py @@ -3,6 +3,7 @@ import json import server from providers.base import Provider, ProviderTokens +from utils.env import parse_int_env class FakeRequest: @@ -66,17 +67,17 @@ def _response_json(resp) -> dict: def test_parse_int_env_defaults(monkeypatch): monkeypatch.delenv("X_TEST", raising=False) - assert server._parse_int_env("X_TEST", 10, 1, 20) == 10 + assert parse_int_env("X_TEST", 10, 1, 20) == 10 def test_parse_int_env_invalid(monkeypatch): monkeypatch.setenv("X_TEST", "abc") - assert server._parse_int_env("X_TEST", 10, 1, 20) == 10 + assert parse_int_env("X_TEST", 10, 1, 20) == 10 def test_parse_int_env_out_of_range(monkeypatch): monkeypatch.setenv("X_TEST", "999") - assert server._parse_int_env("X_TEST", 10, 1, 20) == 10 + assert parse_int_env("X_TEST", 10, 1, 20) == 10 def test_build_limit_fields():