refactor: some minor cleanup
This commit is contained in:
parent
6dd26ad3d8
commit
0a71012709
15 changed files with 188 additions and 91 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -2,3 +2,4 @@
|
||||||
__pycache__/
|
__pycache__/
|
||||||
.ruff_cache/
|
.ruff_cache/
|
||||||
.venv/
|
.venv/
|
||||||
|
.pytest_cache/
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,9 @@ VOLUME ["/data"]
|
||||||
|
|
||||||
EXPOSE 80
|
EXPOSE 80
|
||||||
|
|
||||||
|
HEALTHCHECK --start-period=5s --start-interval=1s CMD \
|
||||||
|
test "$(curl -fsS "http://127.0.0.1:$PORT/health")" = "ok"
|
||||||
|
|
||||||
STOPSIGNAL SIGINT
|
STOPSIGNAL SIGINT
|
||||||
|
|
||||||
CMD ["/entrypoint.sh"]
|
CMD ["/entrypoint.sh"]
|
||||||
|
|
|
||||||
|
|
@ -12,5 +12,5 @@ class BaseProvider(ABC):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_latest_message(self, email: str) -> str | None:
|
async def get_latest_message(self) -> str | None:
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ import aiohttp
|
||||||
from playwright.async_api import BrowserContext
|
from playwright.async_api import BrowserContext
|
||||||
|
|
||||||
from .base import BaseProvider
|
from .base import BaseProvider
|
||||||
|
from utils.randoms import generate_password
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -51,11 +52,6 @@ def _generate_local_part() -> str:
|
||||||
return f"{first}{last}{digits}"
|
return f"{first}{last}{digits}"
|
||||||
|
|
||||||
|
|
||||||
def _generate_password(length: int = 24) -> str:
|
|
||||||
alphabet = string.ascii_letters + string.digits
|
|
||||||
return "".join(secrets.choice(alphabet) for _ in range(length))
|
|
||||||
|
|
||||||
|
|
||||||
class MailTmProvider(BaseProvider):
|
class MailTmProvider(BaseProvider):
|
||||||
def __init__(self, browser_session: BrowserContext):
|
def __init__(self, browser_session: BrowserContext):
|
||||||
super().__init__(browser_session)
|
super().__init__(browser_session)
|
||||||
|
|
@ -146,7 +142,7 @@ class MailTmProvider(BaseProvider):
|
||||||
for _ in range(8):
|
for _ in range(8):
|
||||||
domain = secrets.choice(domains)
|
domain = secrets.choice(domains)
|
||||||
address = f"{_generate_local_part()}@{domain}"
|
address = f"{_generate_local_part()}@{domain}"
|
||||||
password = _generate_password()
|
password = generate_password(length=24)
|
||||||
|
|
||||||
created = await self._create_account(address, password)
|
created = await self._create_account(address, password)
|
||||||
if not created:
|
if not created:
|
||||||
|
|
@ -210,8 +206,7 @@ class MailTmProvider(BaseProvider):
|
||||||
text = "\n".join(str(part) for part in parts if part)
|
text = "\n".join(str(part) for part in parts if part)
|
||||||
return text or None
|
return text or None
|
||||||
|
|
||||||
async def get_latest_message(self, email: str) -> str | None:
|
async def get_latest_message(self) -> str | None:
|
||||||
del email
|
|
||||||
if not self._token:
|
if not self._token:
|
||||||
raise RuntimeError("mail.tm provider is not initialized with mailbox token")
|
raise RuntimeError("mail.tm provider is not initialized with mailbox token")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ import re
|
||||||
from playwright.async_api import BrowserContext, Error as PlaywrightError, Page
|
from playwright.async_api import BrowserContext, Error as PlaywrightError, Page
|
||||||
|
|
||||||
from .base import BaseProvider
|
from .base import BaseProvider
|
||||||
|
from .utils import ensure_page
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -15,8 +16,7 @@ class TempMailOrgProvider(BaseProvider):
|
||||||
self.page: Page | None = None
|
self.page: Page | None = None
|
||||||
|
|
||||||
async def _ensure_page(self) -> Page:
|
async def _ensure_page(self) -> Page:
|
||||||
if self.page is None or self.page.is_closed():
|
self.page = await ensure_page(self.browser_session, self.page)
|
||||||
self.page = await self.browser_session.new_page()
|
|
||||||
return self.page
|
return self.page
|
||||||
|
|
||||||
async def get_new_email(self) -> str:
|
async def get_new_email(self) -> str:
|
||||||
|
|
@ -60,9 +60,9 @@ class TempMailOrgProvider(BaseProvider):
|
||||||
|
|
||||||
raise RuntimeError("Could not get temp email from temp-mail.org")
|
raise RuntimeError("Could not get temp email from temp-mail.org")
|
||||||
|
|
||||||
async def get_latest_message(self, email: str) -> str | None:
|
async def get_latest_message(self) -> str | None:
|
||||||
page = await self._ensure_page()
|
page = await self._ensure_page()
|
||||||
logger.info("[temp-mail.org] Waiting for latest message for %s", email)
|
logger.info("[temp-mail.org] Waiting for latest message")
|
||||||
|
|
||||||
if page.is_closed():
|
if page.is_closed():
|
||||||
raise RuntimeError("temp-mail.org tab was closed unexpectedly")
|
raise RuntimeError("temp-mail.org tab was closed unexpectedly")
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ import logging
|
||||||
from playwright.async_api import BrowserContext, Page
|
from playwright.async_api import BrowserContext, Page
|
||||||
|
|
||||||
from .base import BaseProvider
|
from .base import BaseProvider
|
||||||
|
from .utils import ensure_page
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -14,8 +15,7 @@ class TenMinuteMailProvider(BaseProvider):
|
||||||
self.page: Page | None = None
|
self.page: Page | None = None
|
||||||
|
|
||||||
async def _ensure_page(self) -> Page:
|
async def _ensure_page(self) -> Page:
|
||||||
if self.page is None or self.page.is_closed():
|
self.page = await ensure_page(self.browser_session, self.page)
|
||||||
self.page = await self.browser_session.new_page()
|
|
||||||
return self.page
|
return self.page
|
||||||
|
|
||||||
async def get_new_email(self) -> str:
|
async def get_new_email(self) -> str:
|
||||||
|
|
@ -34,9 +34,9 @@ class TenMinuteMailProvider(BaseProvider):
|
||||||
logger.info("[10min] New email acquired: %s", email)
|
logger.info("[10min] New email acquired: %s", email)
|
||||||
return email
|
return email
|
||||||
|
|
||||||
async def get_latest_message(self, email: str) -> str | None:
|
async def get_latest_message(self) -> str | None:
|
||||||
page = await self._ensure_page()
|
page = await self._ensure_page()
|
||||||
logger.info("[10min] Waiting for latest message for %s", email)
|
logger.info("[10min] Waiting for latest message")
|
||||||
|
|
||||||
seen_count = 0
|
seen_count = 0
|
||||||
for attempt in range(60):
|
for attempt in range(60):
|
||||||
|
|
|
||||||
10
src/email_providers/utils.py
Normal file
10
src/email_providers/utils.py
Normal file
|
|
@ -0,0 +1,10 @@
|
||||||
|
from playwright.async_api import BrowserContext, Page
|
||||||
|
|
||||||
|
|
||||||
|
async def ensure_page(
|
||||||
|
browser_session: BrowserContext,
|
||||||
|
page: Page | None,
|
||||||
|
) -> Page:
|
||||||
|
if page is None or page.is_closed():
|
||||||
|
return await browser_session.new_page()
|
||||||
|
return page
|
||||||
|
|
@ -61,10 +61,24 @@ class Provider(ABC):
|
||||||
"""Rotate active account/token if provider policy requires it."""
|
"""Rotate active account/token if provider policy requires it."""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prepare_threshold(self) -> int:
|
||||||
|
"""Usage percent when provider should prepare standby account/token."""
|
||||||
|
return 100
|
||||||
|
|
||||||
|
@property
|
||||||
|
def switch_threshold(self) -> int | None:
|
||||||
|
"""Usage percent when provider may switch active account/token."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
def should_prepare_standby(self, usage_percent: int) -> bool:
|
||||||
|
"""Whether standby preparation should be triggered for current usage."""
|
||||||
|
_ = usage_percent
|
||||||
|
return False
|
||||||
|
|
||||||
async def ensure_standby_account(
|
async def ensure_standby_account(
|
||||||
self,
|
self,
|
||||||
usage_percent: int,
|
usage_percent: int,
|
||||||
prepare_threshold: int,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Prepare standby account/token asynchronously when needed."""
|
"""Prepare standby account/token asynchronously when needed."""
|
||||||
return None
|
return None
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
|
|
@ -9,6 +8,7 @@ from playwright.async_api import BrowserContext
|
||||||
from email_providers import BaseProvider
|
from email_providers import BaseProvider
|
||||||
from email_providers import MailTmProvider
|
from email_providers import MailTmProvider
|
||||||
from providers.base import Provider, ProviderTokens
|
from providers.base import Provider, ProviderTokens
|
||||||
|
from utils.env import parse_int_env
|
||||||
from .tokens import (
|
from .tokens import (
|
||||||
clear_next_tokens,
|
clear_next_tokens,
|
||||||
load_next_tokens,
|
load_next_tokens,
|
||||||
|
|
@ -24,7 +24,13 @@ from .registration import register_chatgpt_account
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
CHATGPT_REGISTRATION_MAX_ATTEMPTS = 4
|
CHATGPT_REGISTRATION_MAX_ATTEMPTS = 4
|
||||||
CHATGPT_SWITCH_THRESHOLD = int(os.environ.get("CHATGPT_SWITCH_THRESHOLD", "95"))
|
CHATGPT_PREPARE_THRESHOLD = parse_int_env("CHATGPT_PREPARE_THRESHOLD", 85, 0, 100)
|
||||||
|
CHATGPT_SWITCH_THRESHOLD = parse_int_env(
|
||||||
|
"CHATGPT_SWITCH_THRESHOLD",
|
||||||
|
95,
|
||||||
|
0,
|
||||||
|
100,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ChatGPTProvider(Provider):
|
class ChatGPTProvider(Provider):
|
||||||
|
|
@ -37,6 +43,14 @@ class ChatGPTProvider(Provider):
|
||||||
self.email_provider_factory = email_provider_factory or MailTmProvider
|
self.email_provider_factory = email_provider_factory or MailTmProvider
|
||||||
self._token_write_lock = asyncio.Lock()
|
self._token_write_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prepare_threshold(self) -> int:
|
||||||
|
return CHATGPT_PREPARE_THRESHOLD
|
||||||
|
|
||||||
|
@property
|
||||||
|
def switch_threshold(self) -> int | None:
|
||||||
|
return CHATGPT_SWITCH_THRESHOLD
|
||||||
|
|
||||||
async def _register_with_retries(self) -> bool:
|
async def _register_with_retries(self) -> bool:
|
||||||
for attempt in range(1, CHATGPT_REGISTRATION_MAX_ATTEMPTS + 1):
|
for attempt in range(1, CHATGPT_REGISTRATION_MAX_ATTEMPTS + 1):
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
@ -99,21 +113,23 @@ class ChatGPTProvider(Provider):
|
||||||
|
|
||||||
async def ensure_next_account(self) -> bool:
|
async def ensure_next_account(self) -> bool:
|
||||||
next_tokens = load_next_tokens()
|
next_tokens = load_next_tokens()
|
||||||
if next_tokens and not next_tokens.is_expired:
|
if next_tokens:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async with self._token_write_lock:
|
async with self._token_write_lock:
|
||||||
next_tokens = load_next_tokens()
|
next_tokens = load_next_tokens()
|
||||||
if next_tokens and not next_tokens.is_expired:
|
if next_tokens:
|
||||||
return True
|
return True
|
||||||
return await self._create_next_account_under_lock()
|
return await self._create_next_account_under_lock()
|
||||||
|
|
||||||
|
def should_prepare_standby(self, usage_percent: int) -> bool:
|
||||||
|
return usage_percent >= self.prepare_threshold and not bool(load_next_tokens())
|
||||||
|
|
||||||
async def ensure_standby_account(
|
async def ensure_standby_account(
|
||||||
self,
|
self,
|
||||||
usage_percent: int,
|
usage_percent: int,
|
||||||
prepare_threshold: int,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
if usage_percent >= prepare_threshold:
|
if usage_percent >= self.prepare_threshold:
|
||||||
await self.ensure_next_account()
|
await self.ensure_next_account()
|
||||||
|
|
||||||
async def maybe_switch_active_account(self, usage_percent: int) -> bool:
|
async def maybe_switch_active_account(self, usage_percent: int) -> bool:
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ import logging
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
import secrets
|
import secrets
|
||||||
import string
|
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
@ -24,6 +23,7 @@ from playwright.async_api import (
|
||||||
from browser import launch as launch_browser
|
from browser import launch as launch_browser
|
||||||
from email_providers import BaseProvider
|
from email_providers import BaseProvider
|
||||||
from providers.base import ProviderTokens
|
from providers.base import ProviderTokens
|
||||||
|
from utils.randoms import generate_password
|
||||||
from .tokens import CLIENT_ID
|
from .tokens import CLIENT_ID
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -56,11 +56,6 @@ async def save_error_screenshot(page: Page | None, step: str):
|
||||||
logger.warning("Failed to save screenshot at step %s: %s", step, e)
|
logger.warning("Failed to save screenshot at step %s: %s", step, e)
|
||||||
|
|
||||||
|
|
||||||
def generate_password(length: int = 20) -> str:
|
|
||||||
alphabet = string.ascii_letters + string.digits
|
|
||||||
return "".join(random.choice(alphabet) for _ in range(length))
|
|
||||||
|
|
||||||
|
|
||||||
def generate_name() -> str:
|
def generate_name() -> str:
|
||||||
first_names = [
|
first_names = [
|
||||||
"James",
|
"James",
|
||||||
|
|
@ -255,8 +250,8 @@ async def exchange_code_for_tokens(code: str, verifier: str) -> ProviderTokens:
|
||||||
raise RuntimeError(f"Token exchange response parse error: {e}") from e
|
raise RuntimeError(f"Token exchange response parse error: {e}") from e
|
||||||
|
|
||||||
|
|
||||||
async def get_latest_code(email_provider: BaseProvider, email: str) -> str | None:
|
async def get_latest_code(email_provider: BaseProvider) -> str | None:
|
||||||
message = await email_provider.get_latest_message(email)
|
message = await email_provider.get_latest_message()
|
||||||
if not message:
|
if not message:
|
||||||
return None
|
return None
|
||||||
return extract_verification_code(message)
|
return extract_verification_code(message)
|
||||||
|
|
@ -403,7 +398,7 @@ async def register_chatgpt_account(
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("[3/5] Getting verification message from email provider...")
|
logger.info("[3/5] Getting verification message from email provider...")
|
||||||
code = await get_latest_code(email_provider, email)
|
code = await get_latest_code(email_provider)
|
||||||
if not code:
|
if not code:
|
||||||
raise AutomationError(
|
raise AutomationError(
|
||||||
"email_provider", "Email provider returned no verification message"
|
"email_provider", "Email provider returned no verification message"
|
||||||
|
|
@ -472,7 +467,7 @@ async def register_chatgpt_account(
|
||||||
|
|
||||||
if await oauth_needs_email_check(oauth_page):
|
if await oauth_needs_email_check(oauth_page):
|
||||||
logger.info("OAuth requested email confirmation code")
|
logger.info("OAuth requested email confirmation code")
|
||||||
new_code = await get_latest_code(email_provider, email)
|
new_code = await get_latest_code(email_provider)
|
||||||
if new_code and new_code != last_oauth_email_code:
|
if new_code and new_code != last_oauth_email_code:
|
||||||
filled = await fill_oauth_code_if_present(oauth_page, new_code)
|
filled = await fill_oauth_code_if_present(oauth_page, new_code)
|
||||||
if filled:
|
if filled:
|
||||||
|
|
|
||||||
101
src/server.py
101
src/server.py
|
|
@ -1,41 +1,15 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web, web_log
|
||||||
|
|
||||||
from providers.chatgpt import ChatGPTProvider
|
from providers.chatgpt import ChatGPTProvider
|
||||||
from providers.chatgpt.provider import CHATGPT_SWITCH_THRESHOLD
|
|
||||||
from providers.base import Provider
|
from providers.base import Provider
|
||||||
|
from utils.env import parse_int_env
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
PORT = parse_int_env("PORT", 8080, 1, 65535)
|
||||||
|
|
||||||
def _parse_int_env(name: str, default: int, minimum: int, maximum: int) -> int:
|
|
||||||
raw = os.environ.get(name)
|
|
||||||
if raw is None:
|
|
||||||
return default
|
|
||||||
try:
|
|
||||||
value = int(raw)
|
|
||||||
except ValueError:
|
|
||||||
logger.warning("Invalid %s=%r, using default %s", name, raw, default)
|
|
||||||
return default
|
|
||||||
if value < minimum or value > maximum:
|
|
||||||
logger.warning(
|
|
||||||
"%s=%s out of range [%s,%s], using default %s",
|
|
||||||
name,
|
|
||||||
value,
|
|
||||||
minimum,
|
|
||||||
maximum,
|
|
||||||
default,
|
|
||||||
)
|
|
||||||
return default
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
PORT = _parse_int_env("PORT", 8080, 1, 65535)
|
|
||||||
CHATGPT_PREPARE_THRESHOLD = _parse_int_env("CHATGPT_PREPARE_THRESHOLD", 85, 0, 100)
|
|
||||||
LIMIT_EXHAUSTED_PERCENT = 100
|
LIMIT_EXHAUSTED_PERCENT = 100
|
||||||
|
|
||||||
PROVIDERS: dict[str, Provider] = {
|
PROVIDERS: dict[str, Provider] = {
|
||||||
|
|
@ -45,11 +19,13 @@ PROVIDERS: dict[str, Provider] = {
|
||||||
background_tasks: dict[str, asyncio.Task | None] = {name: None for name in PROVIDERS}
|
background_tasks: dict[str, asyncio.Task | None] = {name: None for name in PROVIDERS}
|
||||||
|
|
||||||
|
|
||||||
@web.middleware
|
class AccessLogger(web_log.AccessLogger):
|
||||||
async def request_log_middleware(request: web.Request, handler):
|
def log(
|
||||||
response = await handler(request)
|
self, request: web.BaseRequest, response: web.StreamResponse, time: float
|
||||||
logger.info("%s %s -> %s", request.method, request.path_qs, response.status)
|
) -> None:
|
||||||
return response
|
if request.path == "/health":
|
||||||
|
return
|
||||||
|
super().log(request, response, time)
|
||||||
|
|
||||||
|
|
||||||
def build_limit(usage_percent: int, prepare_threshold: int) -> dict[str, int | bool]:
|
def build_limit(usage_percent: int, prepare_threshold: int) -> dict[str, int | bool]:
|
||||||
|
|
@ -63,9 +39,17 @@ def build_limit(usage_percent: int, prepare_threshold: int) -> dict[str, int | b
|
||||||
|
|
||||||
|
|
||||||
def get_prepare_threshold(provider_name: str) -> int:
|
def get_prepare_threshold(provider_name: str) -> int:
|
||||||
if provider_name == "chatgpt":
|
provider = PROVIDERS.get(provider_name)
|
||||||
return CHATGPT_PREPARE_THRESHOLD
|
if not provider:
|
||||||
return 100
|
return 100
|
||||||
|
return provider.prepare_threshold
|
||||||
|
|
||||||
|
|
||||||
|
def should_trigger_standby_prepare(provider_name: str, usage_percent: int) -> bool:
|
||||||
|
provider = PROVIDERS.get(provider_name)
|
||||||
|
if not provider:
|
||||||
|
return False
|
||||||
|
return provider.should_prepare_standby(usage_percent)
|
||||||
|
|
||||||
|
|
||||||
async def ensure_provider_token_ready(provider_name: str):
|
async def ensure_provider_token_ready(provider_name: str):
|
||||||
|
|
@ -103,10 +87,13 @@ async def ensure_standby_task(provider_name: str, usage_percent: int, reason: st
|
||||||
provider = PROVIDERS.get(provider_name)
|
provider = PROVIDERS.get(provider_name)
|
||||||
if not provider:
|
if not provider:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if not provider.should_prepare_standby(usage_percent):
|
||||||
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info("[%s] Preparing standby in background (%s)", provider_name, reason)
|
logger.info("[%s] Preparing standby in background (%s)", provider_name, reason)
|
||||||
threshold = get_prepare_threshold(provider_name)
|
await provider.ensure_standby_account(usage_percent)
|
||||||
await provider.ensure_standby_account(usage_percent, threshold)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("[%s] Unhandled standby preparation error", provider_name)
|
logger.exception("[%s] Unhandled standby preparation error", provider_name)
|
||||||
|
|
||||||
|
|
@ -155,11 +142,11 @@ async def token_handler(request: web.Request) -> web.Response:
|
||||||
logger.info("[%s] Active account switched before response", provider_name)
|
logger.info("[%s] Active account switched before response", provider_name)
|
||||||
|
|
||||||
prepare_threshold = get_prepare_threshold(provider_name)
|
prepare_threshold = get_prepare_threshold(provider_name)
|
||||||
if usage_percent >= prepare_threshold:
|
if should_trigger_standby_prepare(provider_name, usage_percent):
|
||||||
trigger_standby_prepare(
|
trigger_standby_prepare(
|
||||||
provider_name,
|
provider_name,
|
||||||
usage_percent,
|
usage_percent,
|
||||||
f"usage {usage_percent}% >= threshold {prepare_threshold}%",
|
f"usage {usage_percent}% reached standby policy",
|
||||||
)
|
)
|
||||||
|
|
||||||
remaining_percent = int(
|
remaining_percent = int(
|
||||||
|
|
@ -203,6 +190,11 @@ async def token_handler(request: web.Request) -> web.Response:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def health_handler(request: web.Request) -> web.Response:
|
||||||
|
del request
|
||||||
|
return web.Response(text="ok")
|
||||||
|
|
||||||
|
|
||||||
async def on_startup(app: web.Application):
|
async def on_startup(app: web.Application):
|
||||||
del app
|
del app
|
||||||
for provider_name in PROVIDERS:
|
for provider_name in PROVIDERS:
|
||||||
|
|
@ -220,18 +212,35 @@ async def on_cleanup(app: web.Application):
|
||||||
|
|
||||||
|
|
||||||
def create_app() -> web.Application:
|
def create_app() -> web.Application:
|
||||||
app = web.Application(middlewares=[request_log_middleware])
|
app = web.Application()
|
||||||
app.on_startup.append(on_startup)
|
app.on_startup.append(on_startup)
|
||||||
app.on_cleanup.append(on_cleanup)
|
app.on_cleanup.append(on_cleanup)
|
||||||
|
app.router.add_get("/health", health_handler)
|
||||||
app.router.add_get("/{provider}/token", token_handler)
|
app.router.add_get("/{provider}/token", token_handler)
|
||||||
app.router.add_get("/token", token_handler)
|
app.router.add_get("/token", token_handler)
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def main():
|
||||||
logger.info("Starting token service on port %s", PORT)
|
logger.info("Starting token service on port %s", PORT)
|
||||||
logger.info("ChatGPT prepare-next threshold: %s%%", CHATGPT_PREPARE_THRESHOLD)
|
chatgpt_provider = PROVIDERS.get("chatgpt")
|
||||||
logger.info("ChatGPT switch threshold: %s%%", CHATGPT_SWITCH_THRESHOLD)
|
if chatgpt_provider:
|
||||||
|
logger.info(
|
||||||
|
"ChatGPT prepare-next threshold: %s%%", chatgpt_provider.prepare_threshold
|
||||||
|
)
|
||||||
|
if chatgpt_provider.switch_threshold is not None:
|
||||||
|
logger.info(
|
||||||
|
"ChatGPT switch threshold: %s%%", chatgpt_provider.switch_threshold
|
||||||
|
)
|
||||||
logger.info("Available providers: %s", ", ".join(PROVIDERS.keys()))
|
logger.info("Available providers: %s", ", ".join(PROVIDERS.keys()))
|
||||||
app = create_app()
|
app = create_app()
|
||||||
web.run_app(app, host="0.0.0.0", port=PORT)
|
web.run_app(
|
||||||
|
app,
|
||||||
|
host="0.0.0.0",
|
||||||
|
port=PORT,
|
||||||
|
access_log_class=AccessLogger,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
|
||||||
4
src/utils/__init__.py
Normal file
4
src/utils/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
||||||
|
from .env import parse_int_env
|
||||||
|
from .randoms import generate_password
|
||||||
|
|
||||||
|
__all__ = ["parse_int_env", "generate_password"]
|
||||||
22
src/utils/env.py
Normal file
22
src/utils/env.py
Normal file
|
|
@ -0,0 +1,22 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def parse_int_env(
|
||||||
|
name: str,
|
||||||
|
default: int,
|
||||||
|
minimum: int,
|
||||||
|
maximum: int,
|
||||||
|
) -> int:
|
||||||
|
raw = os.environ.get(name)
|
||||||
|
if raw is None:
|
||||||
|
return default
|
||||||
|
|
||||||
|
try:
|
||||||
|
value = int(raw)
|
||||||
|
except ValueError:
|
||||||
|
return default
|
||||||
|
|
||||||
|
if value < minimum or value > maximum:
|
||||||
|
return default
|
||||||
|
|
||||||
|
return value
|
||||||
13
src/utils/randoms.py
Normal file
13
src/utils/randoms.py
Normal file
|
|
@ -0,0 +1,13 @@
|
||||||
|
import random
|
||||||
|
import secrets
|
||||||
|
import string
|
||||||
|
|
||||||
|
|
||||||
|
def generate_password(
|
||||||
|
length: int = 20,
|
||||||
|
*,
|
||||||
|
secure: bool = True,
|
||||||
|
) -> str:
|
||||||
|
alphabet = string.ascii_letters + string.digits
|
||||||
|
chooser = secrets.choice if secure else random.choice
|
||||||
|
return "".join(chooser(alphabet) for _ in range(length))
|
||||||
|
|
@ -3,6 +3,7 @@ import json
|
||||||
|
|
||||||
import server
|
import server
|
||||||
from providers.base import Provider, ProviderTokens
|
from providers.base import Provider, ProviderTokens
|
||||||
|
from utils.env import parse_int_env
|
||||||
|
|
||||||
|
|
||||||
class FakeRequest:
|
class FakeRequest:
|
||||||
|
|
@ -25,9 +26,17 @@ class FakeProvider(Provider):
|
||||||
"secondary_window": None,
|
"secondary_window": None,
|
||||||
}
|
}
|
||||||
self._rotate = rotate
|
self._rotate = rotate
|
||||||
|
self._prepare_threshold = 80
|
||||||
self.get_token_calls = 0
|
self.get_token_calls = 0
|
||||||
self.standby_calls = 0
|
self.standby_calls = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prepare_threshold(self) -> int:
|
||||||
|
return self._prepare_threshold
|
||||||
|
|
||||||
|
def should_prepare_standby(self, usage_percent: int) -> bool:
|
||||||
|
return usage_percent >= self.prepare_threshold
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return "fake"
|
return "fake"
|
||||||
|
|
@ -54,9 +63,10 @@ class FakeProvider(Provider):
|
||||||
return self._rotate
|
return self._rotate
|
||||||
|
|
||||||
async def ensure_standby_account(
|
async def ensure_standby_account(
|
||||||
self, usage_percent: int, prepare_threshold: int
|
self,
|
||||||
|
usage_percent: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
_ = usage_percent, prepare_threshold
|
_ = usage_percent
|
||||||
self.standby_calls += 1
|
self.standby_calls += 1
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -66,17 +76,17 @@ def _response_json(resp) -> dict:
|
||||||
|
|
||||||
def test_parse_int_env_defaults(monkeypatch):
|
def test_parse_int_env_defaults(monkeypatch):
|
||||||
monkeypatch.delenv("X_TEST", raising=False)
|
monkeypatch.delenv("X_TEST", raising=False)
|
||||||
assert server._parse_int_env("X_TEST", 10, 1, 20) == 10
|
assert parse_int_env("X_TEST", 10, 1, 20) == 10
|
||||||
|
|
||||||
|
|
||||||
def test_parse_int_env_invalid(monkeypatch):
|
def test_parse_int_env_invalid(monkeypatch):
|
||||||
monkeypatch.setenv("X_TEST", "abc")
|
monkeypatch.setenv("X_TEST", "abc")
|
||||||
assert server._parse_int_env("X_TEST", 10, 1, 20) == 10
|
assert parse_int_env("X_TEST", 10, 1, 20) == 10
|
||||||
|
|
||||||
|
|
||||||
def test_parse_int_env_out_of_range(monkeypatch):
|
def test_parse_int_env_out_of_range(monkeypatch):
|
||||||
monkeypatch.setenv("X_TEST", "999")
|
monkeypatch.setenv("X_TEST", "999")
|
||||||
assert server._parse_int_env("X_TEST", 10, 1, 20) == 10
|
assert parse_int_env("X_TEST", 10, 1, 20) == 10
|
||||||
|
|
||||||
|
|
||||||
def test_build_limit_fields():
|
def test_build_limit_fields():
|
||||||
|
|
@ -90,7 +100,10 @@ def test_build_limit_fields():
|
||||||
|
|
||||||
|
|
||||||
def test_get_prepare_threshold():
|
def test_get_prepare_threshold():
|
||||||
assert server.get_prepare_threshold("chatgpt") == server.CHATGPT_PREPARE_THRESHOLD
|
assert (
|
||||||
|
server.get_prepare_threshold("chatgpt")
|
||||||
|
== server.PROVIDERS["chatgpt"].prepare_threshold
|
||||||
|
)
|
||||||
assert server.get_prepare_threshold("unknown") == 100
|
assert server.get_prepare_threshold("unknown") == 100
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -100,12 +113,16 @@ def test_token_handler_unknown_provider(monkeypatch):
|
||||||
assert resp.status == 404
|
assert resp.status == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_health_handler_ok():
|
||||||
|
resp = asyncio.run(server.health_handler(object()))
|
||||||
|
assert resp.status == 200
|
||||||
|
assert resp.text == "ok"
|
||||||
|
|
||||||
|
|
||||||
def test_token_handler_success(monkeypatch):
|
def test_token_handler_success(monkeypatch):
|
||||||
provider = FakeProvider()
|
provider = FakeProvider()
|
||||||
monkeypatch.setattr(server, "PROVIDERS", {"fake": provider})
|
monkeypatch.setattr(server, "PROVIDERS", {"fake": provider})
|
||||||
monkeypatch.setattr(server, "background_tasks", {"fake": None})
|
monkeypatch.setattr(server, "background_tasks", {"fake": None})
|
||||||
monkeypatch.setattr(server, "get_prepare_threshold", lambda _: 80)
|
|
||||||
|
|
||||||
resp = asyncio.run(server.token_handler(FakeRequest("fake")))
|
resp = asyncio.run(server.token_handler(FakeRequest("fake")))
|
||||||
data = _response_json(resp)
|
data = _response_json(resp)
|
||||||
|
|
||||||
|
|
@ -124,10 +141,9 @@ def test_token_handler_triggers_standby(monkeypatch):
|
||||||
def fake_trigger(name, usage_percent, reason):
|
def fake_trigger(name, usage_percent, reason):
|
||||||
assert name == "fake"
|
assert name == "fake"
|
||||||
assert usage_percent == 90
|
assert usage_percent == 90
|
||||||
assert "threshold" in reason
|
assert "standby policy" in reason
|
||||||
called["value"] = True
|
called["value"] = True
|
||||||
|
|
||||||
monkeypatch.setattr(server, "get_prepare_threshold", lambda _: 80)
|
|
||||||
monkeypatch.setattr(server, "trigger_standby_prepare", fake_trigger)
|
monkeypatch.setattr(server, "trigger_standby_prepare", fake_trigger)
|
||||||
|
|
||||||
resp = asyncio.run(server.token_handler(FakeRequest("fake")))
|
resp = asyncio.run(server.token_handler(FakeRequest("fake")))
|
||||||
|
|
@ -142,7 +158,6 @@ def test_token_handler_rotation_path(monkeypatch):
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(server, "PROVIDERS", {"fake": provider})
|
monkeypatch.setattr(server, "PROVIDERS", {"fake": provider})
|
||||||
monkeypatch.setattr(server, "background_tasks", {"fake": None})
|
monkeypatch.setattr(server, "background_tasks", {"fake": None})
|
||||||
monkeypatch.setattr(server, "get_prepare_threshold", lambda _: 80)
|
|
||||||
monkeypatch.setattr(server, "trigger_standby_prepare", lambda *_: None)
|
monkeypatch.setattr(server, "trigger_standby_prepare", lambda *_: None)
|
||||||
|
|
||||||
resp = asyncio.run(server.token_handler(FakeRequest("fake")))
|
resp = asyncio.run(server.token_handler(FakeRequest("fake")))
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue