diff --git a/src/providers/base.py b/src/providers/base.py index 3185c5b..478ab53 100644 --- a/src/providers/base.py +++ b/src/providers/base.py @@ -61,10 +61,23 @@ class Provider(ABC): """Rotate active account/token if provider policy requires it.""" return False + @property + def prepare_threshold(self) -> int: + """Usage percent when provider should prepare standby account/token.""" + return 100 + + @property + def switch_threshold(self) -> int | None: + """Usage percent when provider may switch active account/token.""" + return None + + def should_prepare_standby(self, usage_percent: int) -> bool: + """Whether standby preparation should be triggered for current usage.""" + return False + async def ensure_standby_account( self, usage_percent: int, - prepare_threshold: int, ) -> None: """Prepare standby account/token asynchronously when needed.""" return None diff --git a/src/providers/chatgpt/provider.py b/src/providers/chatgpt/provider.py index e171b51..603c247 100644 --- a/src/providers/chatgpt/provider.py +++ b/src/providers/chatgpt/provider.py @@ -43,6 +43,14 @@ class ChatGPTProvider(Provider): self.email_provider_factory = email_provider_factory or MailTmProvider self._token_write_lock = asyncio.Lock() + @property + def prepare_threshold(self) -> int: + return CHATGPT_PREPARE_THRESHOLD + + @property + def switch_threshold(self) -> int | None: + return CHATGPT_SWITCH_THRESHOLD + async def _register_with_retries(self) -> bool: for attempt in range(1, CHATGPT_REGISTRATION_MAX_ATTEMPTS + 1): logger.info( @@ -105,21 +113,23 @@ class ChatGPTProvider(Provider): async def ensure_next_account(self) -> bool: next_tokens = load_next_tokens() - if next_tokens and not next_tokens.is_expired: + if next_tokens: return True async with self._token_write_lock: next_tokens = load_next_tokens() - if next_tokens and not next_tokens.is_expired: + if next_tokens: return True return await self._create_next_account_under_lock() + def should_prepare_standby(self, usage_percent: int) -> bool: + return usage_percent >= self.prepare_threshold and bool(load_next_tokens()) + async def ensure_standby_account( self, usage_percent: int, - prepare_threshold: int, ) -> None: - if usage_percent >= prepare_threshold: + if self.should_prepare_standby(usage_percent): await self.ensure_next_account() async def maybe_switch_active_account(self, usage_percent: int) -> bool: diff --git a/src/server.py b/src/server.py index 6e42ecc..be2e905 100644 --- a/src/server.py +++ b/src/server.py @@ -1,14 +1,9 @@ import asyncio import logging -import os from aiohttp import web from providers.chatgpt import ChatGPTProvider -from providers.chatgpt.provider import ( - CHATGPT_PREPARE_THRESHOLD, - CHATGPT_SWITCH_THRESHOLD, -) from providers.base import Provider from utils.env import parse_int_env @@ -42,9 +37,17 @@ def build_limit(usage_percent: int, prepare_threshold: int) -> dict[str, int | b def get_prepare_threshold(provider_name: str) -> int: - if provider_name == "chatgpt": - return CHATGPT_PREPARE_THRESHOLD - return 100 + provider = PROVIDERS.get(provider_name) + if not provider: + return 100 + return provider.prepare_threshold + + +def should_trigger_standby_prepare(provider_name: str, usage_percent: int) -> bool: + provider = PROVIDERS.get(provider_name) + if not provider: + return False + return provider.should_prepare_standby(usage_percent) async def ensure_provider_token_ready(provider_name: str): @@ -82,10 +85,13 @@ async def ensure_standby_task(provider_name: str, usage_percent: int, reason: st provider = PROVIDERS.get(provider_name) if not provider: return + + if not provider.should_prepare_standby(usage_percent): + return + try: logger.info("[%s] Preparing standby in background (%s)", provider_name, reason) - threshold = get_prepare_threshold(provider_name) - await provider.ensure_standby_account(usage_percent, threshold) + await provider.ensure_standby_account(usage_percent) except Exception: logger.exception("[%s] Unhandled standby preparation error", provider_name) @@ -134,11 +140,11 @@ async def token_handler(request: web.Request) -> web.Response: logger.info("[%s] Active account switched before response", provider_name) prepare_threshold = get_prepare_threshold(provider_name) - if usage_percent >= prepare_threshold: + if should_trigger_standby_prepare(provider_name, usage_percent): trigger_standby_prepare( provider_name, usage_percent, - f"usage {usage_percent}% >= threshold {prepare_threshold}%", + f"usage {usage_percent}% reached standby policy", ) remaining_percent = int( @@ -209,8 +215,15 @@ def create_app() -> web.Application: if __name__ == "__main__": logger.info("Starting token service on port %s", PORT) - logger.info("ChatGPT prepare-next threshold: %s%%", CHATGPT_PREPARE_THRESHOLD) - logger.info("ChatGPT switch threshold: %s%%", CHATGPT_SWITCH_THRESHOLD) + chatgpt_provider = PROVIDERS.get("chatgpt") + if chatgpt_provider: + logger.info( + "ChatGPT prepare-next threshold: %s%%", chatgpt_provider.prepare_threshold + ) + if chatgpt_provider.switch_threshold is not None: + logger.info( + "ChatGPT switch threshold: %s%%", chatgpt_provider.switch_threshold + ) logger.info("Available providers: %s", ", ".join(PROVIDERS.keys())) app = create_app() web.run_app(app, host="0.0.0.0", port=PORT) diff --git a/tests/test_server_unit.py b/tests/test_server_unit.py index bfe06d4..68b3030 100644 --- a/tests/test_server_unit.py +++ b/tests/test_server_unit.py @@ -26,9 +26,17 @@ class FakeProvider(Provider): "secondary_window": None, } self._rotate = rotate + self._prepare_threshold = 80 self.get_token_calls = 0 self.standby_calls = 0 + @property + def prepare_threshold(self) -> int: + return self._prepare_threshold + + def should_prepare_standby(self, usage_percent: int) -> bool: + return usage_percent >= self.prepare_threshold + @property def name(self) -> str: return "fake" @@ -55,9 +63,10 @@ class FakeProvider(Provider): return self._rotate async def ensure_standby_account( - self, usage_percent: int, prepare_threshold: int + self, + usage_percent: int, ) -> None: - _ = usage_percent, prepare_threshold + _ = usage_percent self.standby_calls += 1 @@ -91,7 +100,10 @@ def test_build_limit_fields(): def test_get_prepare_threshold(): - assert server.get_prepare_threshold("chatgpt") == server.CHATGPT_PREPARE_THRESHOLD + assert ( + server.get_prepare_threshold("chatgpt") + == server.PROVIDERS["chatgpt"].prepare_threshold + ) assert server.get_prepare_threshold("unknown") == 100 @@ -105,8 +117,6 @@ def test_token_handler_success(monkeypatch): provider = FakeProvider() monkeypatch.setattr(server, "PROVIDERS", {"fake": provider}) monkeypatch.setattr(server, "background_tasks", {"fake": None}) - monkeypatch.setattr(server, "get_prepare_threshold", lambda _: 80) - resp = asyncio.run(server.token_handler(FakeRequest("fake"))) data = _response_json(resp) @@ -125,10 +135,9 @@ def test_token_handler_triggers_standby(monkeypatch): def fake_trigger(name, usage_percent, reason): assert name == "fake" assert usage_percent == 90 - assert "threshold" in reason + assert "standby policy" in reason called["value"] = True - monkeypatch.setattr(server, "get_prepare_threshold", lambda _: 80) monkeypatch.setattr(server, "trigger_standby_prepare", fake_trigger) resp = asyncio.run(server.token_handler(FakeRequest("fake"))) @@ -143,7 +152,6 @@ def test_token_handler_rotation_path(monkeypatch): ) monkeypatch.setattr(server, "PROVIDERS", {"fake": provider}) monkeypatch.setattr(server, "background_tasks", {"fake": None}) - monkeypatch.setattr(server, "get_prepare_threshold", lambda _: 80) monkeypatch.setattr(server, "trigger_standby_prepare", lambda *_: None) resp = asyncio.run(server.token_handler(FakeRequest("fake")))