refactor!: change the entire purpose of this script
This commit is contained in:
parent
217e176975
commit
71d1050adb
20 changed files with 1124 additions and 872 deletions
147
src/server.py
Normal file
147
src/server.py
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from providers.chatgpt import ChatGPTProvider
|
||||
|
||||
PORT = int(os.environ.get("PORT", "8080"))
|
||||
USAGE_REFRESH_THRESHOLD = int(os.environ.get("USAGE_REFRESH_THRESHOLD", "85"))
|
||||
LIMIT_EXHAUSTED_PERCENT = 100
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Registry of available providers
|
||||
PROVIDERS = {
|
||||
"chatgpt": ChatGPTProvider(),
|
||||
}
|
||||
|
||||
refresh_locks = {name: asyncio.Lock() for name in PROVIDERS.keys()}
|
||||
background_refresh_tasks: dict[str, asyncio.Task | None] = {
|
||||
name: None for name in PROVIDERS.keys()
|
||||
}
|
||||
|
||||
|
||||
@web.middleware
|
||||
async def request_log_middleware(request: web.Request, handler):
|
||||
response = await handler(request)
|
||||
logger.info("%s %s -> %s", request.method, request.path_qs, response.status)
|
||||
return response
|
||||
|
||||
|
||||
def build_limit(usage_percent: int) -> dict[str, int | bool]:
|
||||
remaining = max(0, 100 - usage_percent)
|
||||
return {
|
||||
"used_percent": usage_percent,
|
||||
"remaining_percent": remaining,
|
||||
"exhausted": usage_percent >= LIMIT_EXHAUSTED_PERCENT,
|
||||
"needs_refresh": usage_percent >= USAGE_REFRESH_THRESHOLD,
|
||||
}
|
||||
|
||||
|
||||
async def issue_new_token(provider_name: str) -> str | None:
|
||||
provider = PROVIDERS.get(provider_name)
|
||||
if not provider:
|
||||
return None
|
||||
|
||||
async with refresh_locks[provider_name]:
|
||||
logger.info(f"[{provider_name}] Generating new token")
|
||||
success = await provider.register_new_account()
|
||||
if not success:
|
||||
logger.error(f"[{provider_name}] Token generation failed")
|
||||
return None
|
||||
|
||||
token = await provider.get_token()
|
||||
if not token:
|
||||
logger.error(f"[{provider_name}] Token was generated but not available")
|
||||
return None
|
||||
|
||||
return token
|
||||
|
||||
|
||||
async def background_refresh_worker(provider_name: str, reason: str):
|
||||
try:
|
||||
logger.info(f"[{provider_name}] Starting background token refresh ({reason})")
|
||||
new_token = await issue_new_token(provider_name)
|
||||
if new_token:
|
||||
logger.info(f"[{provider_name}] Background token refresh completed")
|
||||
else:
|
||||
logger.error(f"[{provider_name}] Background token refresh failed")
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"[{provider_name}] Unhandled error in background token refresh"
|
||||
)
|
||||
|
||||
|
||||
def trigger_background_refresh(provider_name: str, reason: str):
|
||||
task = background_refresh_tasks.get(provider_name)
|
||||
if task and not task.done():
|
||||
logger.info(
|
||||
f"[{provider_name}] Background refresh already running, skip ({reason})"
|
||||
)
|
||||
return
|
||||
background_refresh_tasks[provider_name] = asyncio.create_task(
|
||||
background_refresh_worker(provider_name, reason)
|
||||
)
|
||||
|
||||
|
||||
async def token_handler(request: web.Request) -> web.Response:
|
||||
provider_name = request.match_info.get("provider", "chatgpt")
|
||||
|
||||
provider = PROVIDERS.get(provider_name)
|
||||
if not provider:
|
||||
return web.json_response(
|
||||
{"error": f"Unknown provider: {provider_name}"},
|
||||
status=404,
|
||||
)
|
||||
|
||||
# Get or create token
|
||||
token = await provider.get_token()
|
||||
if not token:
|
||||
return web.json_response(
|
||||
{"error": "Failed to get active token"},
|
||||
status=503,
|
||||
)
|
||||
|
||||
# Get usage info
|
||||
usage_info = await provider.get_usage_info(token)
|
||||
if "error" in usage_info:
|
||||
return web.json_response(
|
||||
{"error": usage_info["error"]},
|
||||
status=503,
|
||||
)
|
||||
|
||||
usage_percent = usage_info.get("used_percent", 0)
|
||||
|
||||
# Trigger background refresh if needed
|
||||
if usage_percent >= USAGE_REFRESH_THRESHOLD:
|
||||
trigger_background_refresh(
|
||||
provider_name,
|
||||
f"usage {usage_percent}% >= threshold {USAGE_REFRESH_THRESHOLD}%",
|
||||
)
|
||||
|
||||
return web.json_response(
|
||||
{
|
||||
"token": token,
|
||||
"limit": build_limit(usage_percent),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def create_app() -> web.Application:
|
||||
app = web.Application(middlewares=[request_log_middleware])
|
||||
# New route: /{provider}/token
|
||||
app.router.add_get("/{provider}/token", token_handler)
|
||||
# Legacy route for backward compatibility
|
||||
app.router.add_get("/token", token_handler)
|
||||
return app
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("Starting token service on port %s", PORT)
|
||||
logger.info("Usage refresh threshold: %s%%", USAGE_REFRESH_THRESHOLD)
|
||||
logger.info("Available providers: %s", ", ".join(PROVIDERS.keys()))
|
||||
app = create_app()
|
||||
web.run_app(app, host="0.0.0.0", port=PORT)
|
||||
Loading…
Add table
Add a link
Reference in a new issue