refactor: some minor cleanup
This commit is contained in:
parent
307ca38ecc
commit
858d127246
12 changed files with 84 additions and 59 deletions
|
|
@ -12,5 +12,5 @@ class BaseProvider(ABC):
|
|||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_latest_message(self, email: str) -> str | None:
|
||||
async def get_latest_message(self) -> str | None:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import aiohttp
|
|||
from playwright.async_api import BrowserContext
|
||||
|
||||
from .base import BaseProvider
|
||||
from utils.randoms import generate_password
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -51,11 +52,6 @@ def _generate_local_part() -> str:
|
|||
return f"{first}{last}{digits}"
|
||||
|
||||
|
||||
def _generate_password(length: int = 24) -> str:
|
||||
alphabet = string.ascii_letters + string.digits
|
||||
return "".join(secrets.choice(alphabet) for _ in range(length))
|
||||
|
||||
|
||||
class MailTmProvider(BaseProvider):
|
||||
def __init__(self, browser_session: BrowserContext):
|
||||
super().__init__(browser_session)
|
||||
|
|
@ -146,7 +142,7 @@ class MailTmProvider(BaseProvider):
|
|||
for _ in range(8):
|
||||
domain = secrets.choice(domains)
|
||||
address = f"{_generate_local_part()}@{domain}"
|
||||
password = _generate_password()
|
||||
password = generate_password(length=24)
|
||||
|
||||
created = await self._create_account(address, password)
|
||||
if not created:
|
||||
|
|
@ -210,8 +206,7 @@ class MailTmProvider(BaseProvider):
|
|||
text = "\n".join(str(part) for part in parts if part)
|
||||
return text or None
|
||||
|
||||
async def get_latest_message(self, email: str) -> str | None:
|
||||
del email
|
||||
async def get_latest_message(self) -> str | None:
|
||||
if not self._token:
|
||||
raise RuntimeError("mail.tm provider is not initialized with mailbox token")
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import re
|
|||
from playwright.async_api import BrowserContext, Error as PlaywrightError, Page
|
||||
|
||||
from .base import BaseProvider
|
||||
from .utils import ensure_page
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -15,8 +16,7 @@ class TempMailOrgProvider(BaseProvider):
|
|||
self.page: Page | None = None
|
||||
|
||||
async def _ensure_page(self) -> Page:
|
||||
if self.page is None or self.page.is_closed():
|
||||
self.page = await self.browser_session.new_page()
|
||||
self.page = await ensure_page(self.browser_session, self.page)
|
||||
return self.page
|
||||
|
||||
async def get_new_email(self) -> str:
|
||||
|
|
@ -60,9 +60,9 @@ class TempMailOrgProvider(BaseProvider):
|
|||
|
||||
raise RuntimeError("Could not get temp email from temp-mail.org")
|
||||
|
||||
async def get_latest_message(self, email: str) -> str | None:
|
||||
async def get_latest_message(self) -> str | None:
|
||||
page = await self._ensure_page()
|
||||
logger.info("[temp-mail.org] Waiting for latest message for %s", email)
|
||||
logger.info("[temp-mail.org] Waiting for latest message")
|
||||
|
||||
if page.is_closed():
|
||||
raise RuntimeError("temp-mail.org tab was closed unexpectedly")
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import logging
|
|||
from playwright.async_api import BrowserContext, Page
|
||||
|
||||
from .base import BaseProvider
|
||||
from .utils import ensure_page
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -14,8 +15,7 @@ class TenMinuteMailProvider(BaseProvider):
|
|||
self.page: Page | None = None
|
||||
|
||||
async def _ensure_page(self) -> Page:
|
||||
if self.page is None or self.page.is_closed():
|
||||
self.page = await self.browser_session.new_page()
|
||||
self.page = await ensure_page(self.browser_session, self.page)
|
||||
return self.page
|
||||
|
||||
async def get_new_email(self) -> str:
|
||||
|
|
@ -34,9 +34,9 @@ class TenMinuteMailProvider(BaseProvider):
|
|||
logger.info("[10min] New email acquired: %s", email)
|
||||
return email
|
||||
|
||||
async def get_latest_message(self, email: str) -> str | None:
|
||||
async def get_latest_message(self) -> str | None:
|
||||
page = await self._ensure_page()
|
||||
logger.info("[10min] Waiting for latest message for %s", email)
|
||||
logger.info("[10min] Waiting for latest message")
|
||||
|
||||
seen_count = 0
|
||||
for attempt in range(60):
|
||||
|
|
|
|||
10
src/email_providers/utils.py
Normal file
10
src/email_providers/utils.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
from playwright.async_api import BrowserContext, Page
|
||||
|
||||
|
||||
async def ensure_page(
|
||||
browser_session: BrowserContext,
|
||||
page: Page | None,
|
||||
) -> Page:
|
||||
if page is None or page.is_closed():
|
||||
return await browser_session.new_page()
|
||||
return page
|
||||
|
|
@ -1,6 +1,5 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
|
||||
|
|
@ -9,6 +8,7 @@ from playwright.async_api import BrowserContext
|
|||
from email_providers import BaseProvider
|
||||
from email_providers import MailTmProvider
|
||||
from providers.base import Provider, ProviderTokens
|
||||
from utils.env import parse_int_env
|
||||
from .tokens import (
|
||||
clear_next_tokens,
|
||||
load_next_tokens,
|
||||
|
|
@ -24,7 +24,13 @@ from .registration import register_chatgpt_account
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
CHATGPT_REGISTRATION_MAX_ATTEMPTS = 4
|
||||
CHATGPT_SWITCH_THRESHOLD = int(os.environ.get("CHATGPT_SWITCH_THRESHOLD", "95"))
|
||||
CHATGPT_PREPARE_THRESHOLD = parse_int_env("CHATGPT_PREPARE_THRESHOLD", 85, 0, 100)
|
||||
CHATGPT_SWITCH_THRESHOLD = parse_int_env(
|
||||
"CHATGPT_SWITCH_THRESHOLD",
|
||||
95,
|
||||
0,
|
||||
100,
|
||||
)
|
||||
|
||||
|
||||
class ChatGPTProvider(Provider):
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ import logging
|
|||
import random
|
||||
import re
|
||||
import secrets
|
||||
import string
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
|
@ -24,6 +23,7 @@ from playwright.async_api import (
|
|||
from browser import launch as launch_browser
|
||||
from email_providers import BaseProvider
|
||||
from providers.base import ProviderTokens
|
||||
from utils.randoms import generate_password
|
||||
from .tokens import CLIENT_ID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -56,11 +56,6 @@ async def save_error_screenshot(page: Page | None, step: str):
|
|||
logger.warning("Failed to save screenshot at step %s: %s", step, e)
|
||||
|
||||
|
||||
def generate_password(length: int = 20) -> str:
|
||||
alphabet = string.ascii_letters + string.digits
|
||||
return "".join(random.choice(alphabet) for _ in range(length))
|
||||
|
||||
|
||||
def generate_name() -> str:
|
||||
first_names = [
|
||||
"James",
|
||||
|
|
@ -255,8 +250,8 @@ async def exchange_code_for_tokens(code: str, verifier: str) -> ProviderTokens:
|
|||
raise RuntimeError(f"Token exchange response parse error: {e}") from e
|
||||
|
||||
|
||||
async def get_latest_code(email_provider: BaseProvider, email: str) -> str | None:
|
||||
message = await email_provider.get_latest_message(email)
|
||||
async def get_latest_code(email_provider: BaseProvider) -> str | None:
|
||||
message = await email_provider.get_latest_message()
|
||||
if not message:
|
||||
return None
|
||||
return extract_verification_code(message)
|
||||
|
|
@ -403,7 +398,7 @@ async def register_chatgpt_account(
|
|||
)
|
||||
|
||||
logger.info("[3/5] Getting verification message from email provider...")
|
||||
code = await get_latest_code(email_provider, email)
|
||||
code = await get_latest_code(email_provider)
|
||||
if not code:
|
||||
raise AutomationError(
|
||||
"email_provider", "Email provider returned no verification message"
|
||||
|
|
@ -472,7 +467,7 @@ async def register_chatgpt_account(
|
|||
|
||||
if await oauth_needs_email_check(oauth_page):
|
||||
logger.info("OAuth requested email confirmation code")
|
||||
new_code = await get_latest_code(email_provider, email)
|
||||
new_code = await get_latest_code(email_provider)
|
||||
if new_code and new_code != last_oauth_email_code:
|
||||
filled = await fill_oauth_code_if_present(oauth_page, new_code)
|
||||
if filled:
|
||||
|
|
|
|||
|
|
@ -5,37 +5,16 @@ import os
|
|||
from aiohttp import web
|
||||
|
||||
from providers.chatgpt import ChatGPTProvider
|
||||
from providers.chatgpt.provider import CHATGPT_SWITCH_THRESHOLD
|
||||
from providers.chatgpt.provider import (
|
||||
CHATGPT_PREPARE_THRESHOLD,
|
||||
CHATGPT_SWITCH_THRESHOLD,
|
||||
)
|
||||
from providers.base import Provider
|
||||
from utils.env import parse_int_env
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _parse_int_env(name: str, default: int, minimum: int, maximum: int) -> int:
|
||||
raw = os.environ.get(name)
|
||||
if raw is None:
|
||||
return default
|
||||
try:
|
||||
value = int(raw)
|
||||
except ValueError:
|
||||
logger.warning("Invalid %s=%r, using default %s", name, raw, default)
|
||||
return default
|
||||
if value < minimum or value > maximum:
|
||||
logger.warning(
|
||||
"%s=%s out of range [%s,%s], using default %s",
|
||||
name,
|
||||
value,
|
||||
minimum,
|
||||
maximum,
|
||||
default,
|
||||
)
|
||||
return default
|
||||
return value
|
||||
|
||||
|
||||
PORT = _parse_int_env("PORT", 8080, 1, 65535)
|
||||
CHATGPT_PREPARE_THRESHOLD = _parse_int_env("CHATGPT_PREPARE_THRESHOLD", 85, 0, 100)
|
||||
PORT = parse_int_env("PORT", 8080, 1, 65535)
|
||||
LIMIT_EXHAUSTED_PERCENT = 100
|
||||
|
||||
PROVIDERS: dict[str, Provider] = {
|
||||
|
|
|
|||
4
src/utils/__init__.py
Normal file
4
src/utils/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
from .env import parse_int_env
|
||||
from .randoms import generate_password
|
||||
|
||||
__all__ = ["parse_int_env", "generate_password"]
|
||||
22
src/utils/env.py
Normal file
22
src/utils/env.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
import os
|
||||
|
||||
|
||||
def parse_int_env(
|
||||
name: str,
|
||||
default: int,
|
||||
minimum: int,
|
||||
maximum: int,
|
||||
) -> int:
|
||||
raw = os.environ.get(name)
|
||||
if raw is None:
|
||||
return default
|
||||
|
||||
try:
|
||||
value = int(raw)
|
||||
except ValueError:
|
||||
return default
|
||||
|
||||
if value < minimum or value > maximum:
|
||||
return default
|
||||
|
||||
return value
|
||||
13
src/utils/randoms.py
Normal file
13
src/utils/randoms.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
import random
|
||||
import secrets
|
||||
import string
|
||||
|
||||
|
||||
def generate_password(
|
||||
length: int = 20,
|
||||
*,
|
||||
secure: bool = True,
|
||||
) -> str:
|
||||
alphabet = string.ascii_letters + string.digits
|
||||
chooser = secrets.choice if secure else random.choice
|
||||
return "".join(chooser(alphabet) for _ in range(length))
|
||||
Loading…
Add table
Add a link
Reference in a new issue