147 lines
4.4 KiB
Python
147 lines
4.4 KiB
Python
import asyncio
|
|
import logging
|
|
import os
|
|
|
|
from aiohttp import web
|
|
|
|
from providers.chatgpt import ChatGPTProvider
|
|
|
|
# TCP port the HTTP server binds to; overridable through the PORT env var.
PORT = int(os.getenv("PORT", "8080"))
# Usage percentage at or above which a background token refresh is started;
# overridable through the USAGE_REFRESH_THRESHOLD env var.
USAGE_REFRESH_THRESHOLD = int(os.getenv("USAGE_REFRESH_THRESHOLD", "85"))
# Usage percentage at or above which a token is reported as exhausted.
LIMIT_EXHAUSTED_PERCENT = 100

# Root logging is configured once at import time; everything in this module
# logs through the module-level logger below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
# Registry of available providers, keyed by the name used in the URL path.
PROVIDERS = {"chatgpt": ChatGPTProvider()}

# One lock per registered provider; held while a new token is being issued.
refresh_locks = {name: asyncio.Lock() for name in PROVIDERS}
# Most recently scheduled background refresh task per provider
# (None until the first refresh is triggered).
background_refresh_tasks: dict[str, asyncio.Task | None] = dict.fromkeys(PROVIDERS)
|
|
|
|
|
|
@web.middleware
async def request_log_middleware(request: web.Request, handler):
    """Log one line per request: method, path with query string, and status.

    Args:
        request: The incoming aiohttp request.
        handler: The next handler in the middleware chain.

    Returns:
        The response produced by ``handler``, unmodified.

    Raises:
        web.HTTPException: Re-raised unchanged after being logged, so aiohttp
            still turns it into the corresponding HTTP response.
    """
    try:
        response = await handler(request)
    except web.HTTPException as exc:
        # aiohttp reports outcomes such as 404 for an unmatched route by
        # raising instead of returning; without this branch those requests
        # would never appear in the log.
        logger.info("%s %s -> %s", request.method, request.path_qs, exc.status)
        raise
    logger.info("%s %s -> %s", request.method, request.path_qs, response.status)
    return response
|
|
|
|
|
|
def build_limit(usage_percent: int) -> dict[str, int | bool]:
    """Describe a token's quota state for the JSON response.

    Args:
        usage_percent: Percentage of the quota already consumed.

    Returns:
        A dict with ``used_percent`` (the input), ``remaining_percent``
        (``100 - usage_percent``, floored at 0), ``exhausted`` (usage reached
        ``LIMIT_EXHAUSTED_PERCENT``) and ``needs_refresh`` (usage reached
        ``USAGE_REFRESH_THRESHOLD``).
    """
    headroom = 100 - usage_percent
    is_exhausted = usage_percent >= LIMIT_EXHAUSTED_PERCENT
    wants_refresh = usage_percent >= USAGE_REFRESH_THRESHOLD
    return {
        "used_percent": usage_percent,
        # Over-consumption must not show up as a negative remainder.
        "remaining_percent": headroom if headroom > 0 else 0,
        "exhausted": is_exhausted,
        "needs_refresh": wants_refresh,
    }
|
|
|
|
|
|
async def issue_new_token(provider_name: str) -> str | None:
    """Ask a provider to register a new account and return the resulting token.

    Args:
        provider_name: Key of the provider in ``PROVIDERS``.

    Returns:
        The token from ``provider.get_token()`` after a successful
        ``register_new_account()``; ``None`` when the provider is unknown,
        registration reports failure, or no token is available afterwards.
    """
    backend = PROVIDERS.get(provider_name)
    if not backend:
        return None

    # The per-provider lock lets only one issuance run at a time.
    # NOTE(review): indentation was lost in the source this was restored from;
    # the token read is assumed to sit inside the lock — confirm.
    async with refresh_locks[provider_name]:
        logger.info(f"[{provider_name}] Generating new token")
        if not await backend.register_new_account():
            logger.error(f"[{provider_name}] Token generation failed")
            return None

        fresh_token = await backend.get_token()
        if fresh_token:
            return fresh_token

        logger.error(f"[{provider_name}] Token was generated but not available")
        return None
|
|
|
|
|
|
async def background_refresh_worker(provider_name: str, reason: str):
    """Run one token refresh for a provider and log how it ended.

    Args:
        provider_name: Key of the provider in ``PROVIDERS``.
        reason: Human-readable explanation of why the refresh was requested;
            only used in the log output.

    Any ``Exception`` raised by the refresh is logged with its traceback and
    swallowed, so the task running this coroutine never ends with an
    unretrieved error.
    """
    try:
        logger.info(f"[{provider_name}] Starting background token refresh ({reason})")
        refreshed = await issue_new_token(provider_name)
        if not refreshed:
            logger.error(f"[{provider_name}] Background token refresh failed")
            return
        logger.info(f"[{provider_name}] Background token refresh completed")
    except Exception:
        logger.exception(f"[{provider_name}] Unhandled error in background token refresh")
|
|
|
|
|
|
def trigger_background_refresh(provider_name: str, reason: str):
    """Schedule a background token refresh unless one is still in flight.

    Args:
        provider_name: Key of the provider whose token should be refreshed.
        reason: Human-readable trigger description, passed on to the worker
            and included in the log output.

    Must be called while an event loop is running, because it schedules the
    worker with ``asyncio.create_task``.
    """
    current = background_refresh_tasks.get(provider_name)
    if not current or current.done():
        # Keeping the task in the registry both marks a refresh as in flight
        # and holds a reference to the task while it runs.
        worker = background_refresh_worker(provider_name, reason)
        background_refresh_tasks[provider_name] = asyncio.create_task(worker)
        return

    logger.info(f"[{provider_name}] Background refresh already running, skip ({reason})")
|
|
|
|
|
|
async def token_handler(request: web.Request) -> web.Response:
    """Return a provider's current token together with its usage limits.

    The provider name comes from the ``provider`` path segment and falls back
    to ``"chatgpt"`` when the matched route has no such segment.

    Responses:
        200: ``{"token": ..., "limit": {...}}`` with ``limit`` from ``build_limit``.
        404: the provider name is not in ``PROVIDERS``.
        503: the provider has no token, or its usage lookup reported an error.

    When usage is at or above ``USAGE_REFRESH_THRESHOLD`` a background refresh
    is requested; the current token is still returned to the caller.
    """

    def failure(message, status: int) -> web.Response:
        # Every error response shares the same {"error": ...} JSON shape.
        return web.json_response({"error": message}, status=status)

    provider_name = request.match_info.get("provider", "chatgpt")
    backend = PROVIDERS.get(provider_name)
    if not backend:
        return failure(f"Unknown provider: {provider_name}", 404)

    active_token = await backend.get_token()
    if not active_token:
        return failure("Failed to get active token", 503)

    usage = await backend.get_usage_info(active_token)
    if "error" in usage:
        return failure(usage["error"], 503)

    # A missing "used_percent" is treated as no usage at all.
    usage_percent = usage.get("used_percent", 0)
    if usage_percent >= USAGE_REFRESH_THRESHOLD:
        reason = f"usage {usage_percent}% >= threshold {USAGE_REFRESH_THRESHOLD}%"
        trigger_background_refresh(provider_name, reason)

    payload = {"token": active_token, "limit": build_limit(usage_percent)}
    return web.json_response(payload)
|
|
|
|
|
|
def create_app() -> web.Application:
    """Build the aiohttp application with request logging and token routes.

    Returns:
        An application that serves ``token_handler`` on ``/{provider}/token``
        and on the legacy ``/token`` path.
    """
    application = web.Application(middlewares=[request_log_middleware])
    # The provider-scoped route is registered first; the bare /token path is
    # kept for backward compatibility.
    for path in ("/{provider}/token", "/token"):
        application.router.add_get(path, token_handler)
    return application
|
|
|
|
|
|
if __name__ == "__main__":
    # Log the effective configuration before the server starts.
    provider_names = ", ".join(PROVIDERS)
    logger.info("Starting token service on port %s", PORT)
    logger.info("Usage refresh threshold: %s%%", USAGE_REFRESH_THRESHOLD)
    logger.info("Available providers: %s", provider_names)
    app = create_app()
    # Listen on all interfaces.
    web.run_app(app, host="0.0.0.0", port=PORT)
|