1
0
Fork 0

refactor: some minor cleanup

This commit is contained in:
Arthur K. 2026-03-02 19:33:43 +03:00
parent 307ca38ecc
commit 858d127246
Signed by: wzray
GPG key ID: B97F30FDC4636357
12 changed files with 84 additions and 59 deletions

View file

@ -12,5 +12,5 @@ class BaseProvider(ABC):
pass pass
@abstractmethod @abstractmethod
async def get_latest_message(self, email: str) -> str | None: async def get_latest_message(self) -> str | None:
pass pass

View file

@ -9,6 +9,7 @@ import aiohttp
from playwright.async_api import BrowserContext from playwright.async_api import BrowserContext
from .base import BaseProvider from .base import BaseProvider
from utils.randoms import generate_password
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -51,11 +52,6 @@ def _generate_local_part() -> str:
return f"{first}{last}{digits}" return f"{first}{last}{digits}"
def _generate_password(length: int = 24) -> str:
alphabet = string.ascii_letters + string.digits
return "".join(secrets.choice(alphabet) for _ in range(length))
class MailTmProvider(BaseProvider): class MailTmProvider(BaseProvider):
def __init__(self, browser_session: BrowserContext): def __init__(self, browser_session: BrowserContext):
super().__init__(browser_session) super().__init__(browser_session)
@ -146,7 +142,7 @@ class MailTmProvider(BaseProvider):
for _ in range(8): for _ in range(8):
domain = secrets.choice(domains) domain = secrets.choice(domains)
address = f"{_generate_local_part()}@{domain}" address = f"{_generate_local_part()}@{domain}"
password = _generate_password() password = generate_password(length=24)
created = await self._create_account(address, password) created = await self._create_account(address, password)
if not created: if not created:
@ -210,8 +206,7 @@ class MailTmProvider(BaseProvider):
text = "\n".join(str(part) for part in parts if part) text = "\n".join(str(part) for part in parts if part)
return text or None return text or None
async def get_latest_message(self, email: str) -> str | None: async def get_latest_message(self) -> str | None:
del email
if not self._token: if not self._token:
raise RuntimeError("mail.tm provider is not initialized with mailbox token") raise RuntimeError("mail.tm provider is not initialized with mailbox token")

View file

@ -5,6 +5,7 @@ import re
from playwright.async_api import BrowserContext, Error as PlaywrightError, Page from playwright.async_api import BrowserContext, Error as PlaywrightError, Page
from .base import BaseProvider from .base import BaseProvider
from .utils import ensure_page
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -15,8 +16,7 @@ class TempMailOrgProvider(BaseProvider):
self.page: Page | None = None self.page: Page | None = None
async def _ensure_page(self) -> Page: async def _ensure_page(self) -> Page:
if self.page is None or self.page.is_closed(): self.page = await ensure_page(self.browser_session, self.page)
self.page = await self.browser_session.new_page()
return self.page return self.page
async def get_new_email(self) -> str: async def get_new_email(self) -> str:
@ -60,9 +60,9 @@ class TempMailOrgProvider(BaseProvider):
raise RuntimeError("Could not get temp email from temp-mail.org") raise RuntimeError("Could not get temp email from temp-mail.org")
async def get_latest_message(self, email: str) -> str | None: async def get_latest_message(self) -> str | None:
page = await self._ensure_page() page = await self._ensure_page()
logger.info("[temp-mail.org] Waiting for latest message for %s", email) logger.info("[temp-mail.org] Waiting for latest message")
if page.is_closed(): if page.is_closed():
raise RuntimeError("temp-mail.org tab was closed unexpectedly") raise RuntimeError("temp-mail.org tab was closed unexpectedly")

View file

@ -4,6 +4,7 @@ import logging
from playwright.async_api import BrowserContext, Page from playwright.async_api import BrowserContext, Page
from .base import BaseProvider from .base import BaseProvider
from .utils import ensure_page
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -14,8 +15,7 @@ class TenMinuteMailProvider(BaseProvider):
self.page: Page | None = None self.page: Page | None = None
async def _ensure_page(self) -> Page: async def _ensure_page(self) -> Page:
if self.page is None or self.page.is_closed(): self.page = await ensure_page(self.browser_session, self.page)
self.page = await self.browser_session.new_page()
return self.page return self.page
async def get_new_email(self) -> str: async def get_new_email(self) -> str:
@ -34,9 +34,9 @@ class TenMinuteMailProvider(BaseProvider):
logger.info("[10min] New email acquired: %s", email) logger.info("[10min] New email acquired: %s", email)
return email return email
async def get_latest_message(self, email: str) -> str | None: async def get_latest_message(self) -> str | None:
page = await self._ensure_page() page = await self._ensure_page()
logger.info("[10min] Waiting for latest message for %s", email) logger.info("[10min] Waiting for latest message")
seen_count = 0 seen_count = 0
for attempt in range(60): for attempt in range(60):

View file

@ -0,0 +1,10 @@
from playwright.async_api import BrowserContext, Page
async def ensure_page(
browser_session: BrowserContext,
page: Page | None,
) -> Page:
if page is None or page.is_closed():
return await browser_session.new_page()
return page

View file

@ -1,6 +1,5 @@
import asyncio import asyncio
import logging import logging
import os
from typing import Any from typing import Any
from typing import Callable from typing import Callable
@ -9,6 +8,7 @@ from playwright.async_api import BrowserContext
from email_providers import BaseProvider from email_providers import BaseProvider
from email_providers import MailTmProvider from email_providers import MailTmProvider
from providers.base import Provider, ProviderTokens from providers.base import Provider, ProviderTokens
from utils.env import parse_int_env
from .tokens import ( from .tokens import (
clear_next_tokens, clear_next_tokens,
load_next_tokens, load_next_tokens,
@ -24,7 +24,13 @@ from .registration import register_chatgpt_account
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
CHATGPT_REGISTRATION_MAX_ATTEMPTS = 4 CHATGPT_REGISTRATION_MAX_ATTEMPTS = 4
CHATGPT_SWITCH_THRESHOLD = int(os.environ.get("CHATGPT_SWITCH_THRESHOLD", "95")) CHATGPT_PREPARE_THRESHOLD = parse_int_env("CHATGPT_PREPARE_THRESHOLD", 85, 0, 100)
CHATGPT_SWITCH_THRESHOLD = parse_int_env(
"CHATGPT_SWITCH_THRESHOLD",
95,
0,
100,
)
class ChatGPTProvider(Provider): class ChatGPTProvider(Provider):

View file

@ -5,7 +5,6 @@ import logging
import random import random
import re import re
import secrets import secrets
import string
import time import time
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
@ -24,6 +23,7 @@ from playwright.async_api import (
from browser import launch as launch_browser from browser import launch as launch_browser
from email_providers import BaseProvider from email_providers import BaseProvider
from providers.base import ProviderTokens from providers.base import ProviderTokens
from utils.randoms import generate_password
from .tokens import CLIENT_ID from .tokens import CLIENT_ID
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -56,11 +56,6 @@ async def save_error_screenshot(page: Page | None, step: str):
logger.warning("Failed to save screenshot at step %s: %s", step, e) logger.warning("Failed to save screenshot at step %s: %s", step, e)
def generate_password(length: int = 20) -> str:
alphabet = string.ascii_letters + string.digits
return "".join(random.choice(alphabet) for _ in range(length))
def generate_name() -> str: def generate_name() -> str:
first_names = [ first_names = [
"James", "James",
@ -255,8 +250,8 @@ async def exchange_code_for_tokens(code: str, verifier: str) -> ProviderTokens:
raise RuntimeError(f"Token exchange response parse error: {e}") from e raise RuntimeError(f"Token exchange response parse error: {e}") from e
async def get_latest_code(email_provider: BaseProvider, email: str) -> str | None: async def get_latest_code(email_provider: BaseProvider) -> str | None:
message = await email_provider.get_latest_message(email) message = await email_provider.get_latest_message()
if not message: if not message:
return None return None
return extract_verification_code(message) return extract_verification_code(message)
@ -403,7 +398,7 @@ async def register_chatgpt_account(
) )
logger.info("[3/5] Getting verification message from email provider...") logger.info("[3/5] Getting verification message from email provider...")
code = await get_latest_code(email_provider, email) code = await get_latest_code(email_provider)
if not code: if not code:
raise AutomationError( raise AutomationError(
"email_provider", "Email provider returned no verification message" "email_provider", "Email provider returned no verification message"
@ -472,7 +467,7 @@ async def register_chatgpt_account(
if await oauth_needs_email_check(oauth_page): if await oauth_needs_email_check(oauth_page):
logger.info("OAuth requested email confirmation code") logger.info("OAuth requested email confirmation code")
new_code = await get_latest_code(email_provider, email) new_code = await get_latest_code(email_provider)
if new_code and new_code != last_oauth_email_code: if new_code and new_code != last_oauth_email_code:
filled = await fill_oauth_code_if_present(oauth_page, new_code) filled = await fill_oauth_code_if_present(oauth_page, new_code)
if filled: if filled:

View file

@ -5,37 +5,16 @@ import os
from aiohttp import web from aiohttp import web
from providers.chatgpt import ChatGPTProvider from providers.chatgpt import ChatGPTProvider
from providers.chatgpt.provider import CHATGPT_SWITCH_THRESHOLD from providers.chatgpt.provider import (
CHATGPT_PREPARE_THRESHOLD,
CHATGPT_SWITCH_THRESHOLD,
)
from providers.base import Provider from providers.base import Provider
from utils.env import parse_int_env
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
PORT = parse_int_env("PORT", 8080, 1, 65535)
def _parse_int_env(name: str, default: int, minimum: int, maximum: int) -> int:
raw = os.environ.get(name)
if raw is None:
return default
try:
value = int(raw)
except ValueError:
logger.warning("Invalid %s=%r, using default %s", name, raw, default)
return default
if value < minimum or value > maximum:
logger.warning(
"%s=%s out of range [%s,%s], using default %s",
name,
value,
minimum,
maximum,
default,
)
return default
return value
PORT = _parse_int_env("PORT", 8080, 1, 65535)
CHATGPT_PREPARE_THRESHOLD = _parse_int_env("CHATGPT_PREPARE_THRESHOLD", 85, 0, 100)
LIMIT_EXHAUSTED_PERCENT = 100 LIMIT_EXHAUSTED_PERCENT = 100
PROVIDERS: dict[str, Provider] = { PROVIDERS: dict[str, Provider] = {

4
src/utils/__init__.py Normal file
View file

@ -0,0 +1,4 @@
from .env import parse_int_env
from .randoms import generate_password
__all__ = ["parse_int_env", "generate_password"]

22
src/utils/env.py Normal file
View file

@ -0,0 +1,22 @@
import os
def parse_int_env(
name: str,
default: int,
minimum: int,
maximum: int,
) -> int:
raw = os.environ.get(name)
if raw is None:
return default
try:
value = int(raw)
except ValueError:
return default
if value < minimum or value > maximum:
return default
return value

13
src/utils/randoms.py Normal file
View file

@ -0,0 +1,13 @@
import random
import secrets
import string
def generate_password(
length: int = 20,
*,
secure: bool = True,
) -> str:
alphabet = string.ascii_letters + string.digits
chooser = secrets.choice if secure else random.choice
return "".join(chooser(alphabet) for _ in range(length))

View file

@ -3,6 +3,7 @@ import json
import server import server
from providers.base import Provider, ProviderTokens from providers.base import Provider, ProviderTokens
from utils.env import parse_int_env
class FakeRequest: class FakeRequest:
@ -66,17 +67,17 @@ def _response_json(resp) -> dict:
def test_parse_int_env_defaults(monkeypatch): def test_parse_int_env_defaults(monkeypatch):
monkeypatch.delenv("X_TEST", raising=False) monkeypatch.delenv("X_TEST", raising=False)
assert server._parse_int_env("X_TEST", 10, 1, 20) == 10 assert parse_int_env("X_TEST", 10, 1, 20) == 10
def test_parse_int_env_invalid(monkeypatch): def test_parse_int_env_invalid(monkeypatch):
monkeypatch.setenv("X_TEST", "abc") monkeypatch.setenv("X_TEST", "abc")
assert server._parse_int_env("X_TEST", 10, 1, 20) == 10 assert parse_int_env("X_TEST", 10, 1, 20) == 10
def test_parse_int_env_out_of_range(monkeypatch): def test_parse_int_env_out_of_range(monkeypatch):
monkeypatch.setenv("X_TEST", "999") monkeypatch.setenv("X_TEST", "999")
assert server._parse_int_env("X_TEST", 10, 1, 20) == 10 assert parse_int_env("X_TEST", 10, 1, 20) == 10
def test_build_limit_fields(): def test_build_limit_fields():