diff --git a/src/providers/base.py b/src/providers/base.py index 5f9cdc6..d1deb94 100644 --- a/src/providers/base.py +++ b/src/providers/base.py @@ -71,8 +71,9 @@ class Provider(ABC): """Usage percent when provider may switch active account/token.""" return None - def should_prepare_standby(self) -> bool: + def should_prepare_standby(self, usage_percent: int) -> bool: """Whether standby preparation should be triggered for current usage.""" + _ = usage_percent return False async def ensure_standby_account( diff --git a/src/providers/chatgpt/provider.py b/src/providers/chatgpt/provider.py index b536611..67ffa88 100644 --- a/src/providers/chatgpt/provider.py +++ b/src/providers/chatgpt/provider.py @@ -122,8 +122,8 @@ class ChatGPTProvider(Provider): return True return await self._create_next_account_under_lock() - def should_prepare_standby(self) -> bool: - return bool(load_next_tokens()) + def should_prepare_standby(self, usage_percent: int) -> bool: + return usage_percent >= self.prepare_threshold and not bool(load_next_tokens()) async def ensure_standby_account( self, diff --git a/src/server.py b/src/server.py index 464d92b..39fc3ca 100644 --- a/src/server.py +++ b/src/server.py @@ -49,7 +49,7 @@ def should_trigger_standby_prepare(provider_name: str, usage_percent: int) -> bo provider = PROVIDERS.get(provider_name) if not provider: return False - return provider.should_prepare_standby() + return provider.should_prepare_standby(usage_percent) async def ensure_provider_token_ready(provider_name: str): @@ -88,7 +88,7 @@ async def ensure_standby_task(provider_name: str, usage_percent: int, reason: st if not provider: return - if not provider.should_prepare_standby(): + if not provider.should_prepare_standby(usage_percent): return try: @@ -212,6 +212,7 @@ async def on_cleanup(app: web.Application): def create_app() -> web.Application: + app = web.Application() app.on_startup.append(on_startup) app.on_cleanup.append(on_cleanup) app.router.add_get("/health", health_handler) @@ -220,7 +221,7 @@ def create_app() -> web.Application: return app -if __name__ == "__main__": +def main(): logger.info("Starting token service on port %s", PORT) chatgpt_provider = PROVIDERS.get("chatgpt") if chatgpt_provider: @@ -239,3 +240,7 @@ if __name__ == "__main__": port=PORT, access_log_class=AccessLogger, ) + + +if __name__ == "__main__": + main() diff --git a/tests/test_server_unit.py b/tests/test_server_unit.py index 8ebefb8..864d374 100644 --- a/tests/test_server_unit.py +++ b/tests/test_server_unit.py @@ -34,8 +34,8 @@ class FakeProvider(Provider): def prepare_threshold(self) -> int: return self._prepare_threshold - def should_prepare_standby(self) -> bool: - return False + def should_prepare_standby(self, usage_percent: int) -> bool: + return usage_percent >= self.prepare_threshold @property def name(self) -> str: