"""Local OAuth-emulating reverse proxy in front of the ChatGPT Codex backend.

Clients obtain a locally generated, JWT-shaped bearer token through the fake
``/oauth/authorize`` + ``/oauth/token`` endpoints (persisted in ``auth.json``
under ``DATA_DIR``). Every other request is checked against that token, has
its ``Authorization`` header replaced with the upstream access token returned
by ``get_valid_tokens()``, and is forwarded to ``CODEX_BASE_URL``. A background
task polls ``get_usage_percent`` and runs ``get_new_token`` when usage reaches
``USAGE_THRESHOLD`` or the upstream token looks unusable.
"""

import asyncio
import base64
import json
import logging
import os
import secrets
import time
import uuid
from urllib.parse import urlencode, urlsplit, urlunsplit

import aiohttp
from aiohttp import web

from tokens import get_valid_tokens, load_tokens, DATA_DIR
from codex_usage import get_usage_percent
from get_new_token import get_new_token

CODEX_BASE_URL = "https://chatgpt.com/backend-api"
PORT = int(os.environ.get("PORT", "8080"))
# Usage percentage at or above which a token refresh is started.
USAGE_THRESHOLD = int(os.environ.get("USAGE_THRESHOLD", "85"))
# Seconds between usage checks in the background monitor.
CHECK_INTERVAL = int(os.environ.get("CHECK_INTERVAL", "60"))
# Effectively "never expires": stored as the client's ``expires_at`` and
# returned as ``expires_in`` by the token endpoint.
FAKE_EXPIRES_IN = 9999999999999
AUTH_FILE = DATA_DIR / "auth.json"
JWT_AUTH_CLAIM_PATH = "https://api.openai.com/auth"
JWT_PROFILE_CLAIM_PATH = "https://api.openai.com/profile"
# Authorization codes older than this many seconds are rejected and purged.
AUTH_CODE_TTL_SECONDS = 300
# Request headers never forwarded upstream:
# - host, authorization, content-length: the proxy sets or derives these itself.
# - hop-by-hop headers (RFC 7230 section 6.1): they only apply to a single
#   connection, not to the whole request chain.
# - accept-encoding: dropped so aiohttp only negotiates encodings it can decode
#   (the proxy never passes Content-Encoding back to the client).
_SKIP_REQUEST_HEADERS = frozenset(
    {
        "host",
        "authorization",
        "content-length",
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "proxy-connection",
        "te",
        "trailer",
        "trailers",
        "transfer-encoding",
        "upgrade",
        "accept-encoding",
    }
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# True while a refresh_tokens_task() is running; prevents concurrent refreshes.
refresh_in_progress = False
# Outstanding authorization codes -> {"state": ..., "created_at": epoch seconds}.
auth_codes: dict[str, dict] = {}
# Strong references to fire-and-forget tasks: the event loop keeps only weak
# references, so an unreferenced task may be garbage-collected mid-run.
_background_tasks: set[asyncio.Task] = set()


def _b64url(data: bytes) -> str:
    """Base64url-encode *data* without ``=`` padding (JWT segment encoding)."""
    return base64.urlsafe_b64encode(data).decode("utf-8").rstrip("=")


def _generate_jwt_like() -> str:
    """Build a random token shaped like an OpenAI-issued JWT.

    The header and payload are real base64url JSON so clients that decode
    the token (e.g. to read the account id or plan claims) succeed. The
    signature is random bytes, so the token is not verifiable by anyone.
    ``exp`` is set 315360000 s (about 10 years) after issue.
    """
    account_id = str(uuid.uuid4())
    now = int(time.time())
    header = {"alg": "HS256", "typ": "JWT"}
    user_id = f"user-{secrets.token_urlsafe(18)}"
    account_user_id = f"{user_id}__{account_id}"
    payload = {
        "aud": ["https://api.openai.com/v1"],
        "client_id": "app_EMoamEEZ73f0CkXaXp7hrann",
        "iss": "https://auth.openai.com",
        "iat": now,
        "nbf": now,
        "exp": now + 315360000,
        "jti": str(uuid.uuid4()),
        "scp": ["openid", "profile", "email", "offline_access"],
        "session_id": f"authsess_{secrets.token_urlsafe(24)}",
        JWT_AUTH_CLAIM_PATH: {
            "chatgpt_account_id": account_id,
            "chatgpt_account_user_id": account_user_id,
            "chatgpt_compute_residency": "no_constraint",
            "chatgpt_plan_type": "plus",
            "chatgpt_user_id": user_id,
            "user_id": user_id,
        },
        JWT_PROFILE_CLAIM_PATH: {
            "email": f"proxy-{secrets.token_hex(4)}@example.local",
            "email_verified": True,
        },
        "sub": f"auth0|{secrets.token_urlsafe(20)}",
    }
    head = _b64url(json.dumps(header, separators=(",", ":")).encode("utf-8"))
    body = _b64url(json.dumps(payload, separators=(",", ":")).encode("utf-8"))
    sign = _b64url(secrets.token_bytes(32))
    return f"{head}.{body}.{sign}"


def _generate_refresh_like() -> str:
    """Return a random string shaped like an OpenAI refresh token (``rt_…``)."""
    return f"rt_{secrets.token_urlsafe(40)}.{secrets.token_urlsafe(32)}"


def _mask(value: str, head: int = 8, tail: int = 6) -> str:
    """Abbreviate a secret for logging as ``<first head>...<last tail>``.

    Returns ``""`` for empty values and for values no longer than
    ``head + tail`` characters, so short secrets are never printed at all.
    """
    if not value:
        return ""
    if len(value) <= head + tail:
        return ""
    return f"{value[:head]}...{value[-tail:]}"


def load_or_create_auth() -> dict:
    """Return the client-facing credentials, creating them on first use.

    Reads ``AUTH_FILE``. If it is missing, not valid JSON, not a JSON object,
    or lacks a truthy ``access_token`` / ``refresh_token`` / ``expires_at``, a
    fresh set is generated, written back, and returned.

    Returns:
        dict with ``access_token``, ``refresh_token`` and ``expires_at``.
    """
    if AUTH_FILE.exists():
        try:
            with open(AUTH_FILE, encoding="utf-8") as f:
                data = json.load(f)
        except ValueError as exc:  # JSONDecodeError / UnicodeDecodeError
            # Same recovery as the "missing keys" case below, instead of
            # failing every authenticated request with a 500.
            logger.warning("Regenerating unreadable %s: %s", AUTH_FILE, exc)
        else:
            if isinstance(data, dict) and (
                data.get("access_token")
                and data.get("refresh_token")
                and data.get("expires_at")
            ):
                return data

    DATA_DIR.mkdir(parents=True, exist_ok=True)
    data = {
        "access_token": _generate_jwt_like(),
        "refresh_token": _generate_refresh_like(),
        "expires_at": FAKE_EXPIRES_IN,
    }
    with open(AUTH_FILE, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)
    return data


@web.middleware
async def request_log_middleware(request: web.Request, handler):
    """Log method, path, final status and latency of every request.

    ``web.HTTPException`` subclasses are real responses (the authorize
    endpoint raises ``HTTPFound``), so their status is logged. Any other
    exception is logged as ``ERR``. Exceptions are always re-raised.
    """
    started = time.perf_counter()
    status = "ERR"
    try:
        response = await handler(request)
        status = getattr(response, "status", "ERR")
        return response
    except web.HTTPException as exc:
        status = exc.status
        raise
    finally:
        elapsed_ms = int((time.perf_counter() - started) * 1000)
        logger.info(
            "%s %s -> %s (%d ms)",
            request.method,
            request.path_qs,
            status,
            elapsed_ms,
        )


def _tokens_match(given, expected: str) -> bool:
    """Compare a client-supplied secret with *expected* in constant time.

    Returns False for a missing, empty, or non-string *given* value. Both
    sides are encoded to bytes so non-ASCII input cannot make
    ``compare_digest`` raise.
    """
    if not isinstance(given, str) or not given:
        return False
    return secrets.compare_digest(
        given.encode("utf-8", "surrogateescape"),
        expected.encode("utf-8", "surrogateescape"),
    )


def check_auth(request: web.Request) -> bool:
    """Return True if the request carries ``Bearer <client access token>``.

    The scheme match is case-insensitive. The token is compared in constant
    time against the ``access_token`` from ``load_or_create_auth()``.
    """
    auth_data = load_or_create_auth()
    auth = request.headers.get("Authorization", "")
    if not auth.lower().startswith("bearer "):
        return False
    return _tokens_match(auth[7:].strip(), auth_data["access_token"])


def _purge_expired_auth_codes() -> None:
    """Drop authorization codes older than ``AUTH_CODE_TTL_SECONDS``.

    Codes that are issued but never redeemed would otherwise accumulate in
    ``auth_codes`` forever.
    """
    now = time.time()
    expired = [
        code
        for code, meta in auth_codes.items()
        if now - meta["created_at"] > AUTH_CODE_TTL_SECONDS
    ]
    for code in expired:
        del auth_codes[code]


async def oauth_authorize_handler(request: web.Request) -> web.Response:
    """Fake OAuth authorize endpoint: immediately redirect back with a code.

    No user interaction takes place. A one-time code is issued and the client
    is redirected (302) to ``redirect_uri``. The ``code``, ``scope`` and
    ``state`` parameters are merged into any query the URI already has.

    Returns 400 JSON if ``redirect_uri`` is missing or unparsable.
    """
    params = request.rel_url.query
    redirect_uri = params.get("redirect_uri")
    state = params.get("state", "")
    if not redirect_uri:
        return web.json_response(
            {"error": "invalid_request", "error_description": "Missing redirect_uri"},
            status=400,
        )
    try:
        parts = urlsplit(redirect_uri)
    except ValueError:
        return web.json_response(
            {"error": "invalid_request", "error_description": "Invalid redirect_uri"},
            status=400,
        )
    # NOTE(review): redirect_uri is not checked against an allow-list, so this
    # endpoint redirects anywhere. Acceptable only if the proxy is not exposed
    # to untrusted networks.

    _purge_expired_auth_codes()
    code = f"ac_{secrets.token_urlsafe(48)}"
    auth_codes[code] = {
        "state": state,
        "created_at": time.time(),
    }
    query = urlencode(
        {
            "code": code,
            "scope": "openid profile email offline_access",
            "state": state,
        }
    )
    # Append to an existing query rather than emitting a second "?" (and keep
    # any fragment after the query, where it belongs).
    merged_query = f"{parts.query}&{query}" if parts.query else query
    location = urlunsplit(parts._replace(query=merged_query))
    logger.info("OAuth authorize: issued code")
    raise web.HTTPFound(location=location)


def _oauth_error(error: str, description: str, status: int = 400) -> web.Response:
    """Build an OAuth-style JSON error response."""
    return web.json_response(
        {"error": error, "error_description": description},
        status=status,
    )


def _token_success(auth_data: dict) -> web.Response:
    """Build the token-endpoint success payload for the stored client credentials."""
    return web.json_response(
        {
            "access_token": auth_data["access_token"],
            "refresh_token": auth_data["refresh_token"],
            "token_type": "Bearer",
            "expires_in": FAKE_EXPIRES_IN,
        }
    )


async def _read_token_request(request: web.Request):
    """Return the token request parameters as a mapping, or None if malformed.

    JSON bodies must decode to an object. Any other content type is read as
    a form; aiohttp returns an empty mapping for non-form bodies.
    """
    if (request.content_type or "").startswith("application/json"):
        try:
            body = await request.json()
        except ValueError:
            return None
        return body if isinstance(body, dict) else None
    return await request.post()


async def oauth_token_handler(request: web.Request) -> web.Response:
    """Fake OAuth token endpoint.

    Supports two grants:
    - ``authorization_code``: the code must exist and be at most
      ``AUTH_CODE_TTL_SECONDS`` old. It is consumed on first use.
    - ``refresh_token``: the token must equal the stored one.

    Both grants return the persistent client credentials. Malformed bodies
    and failures return 400 with an OAuth error payload.
    """
    auth_data = load_or_create_auth()
    params = await _read_token_request(request)
    if params is None:
        return _oauth_error("invalid_request", "Malformed request body")

    grant_type = params.get("grant_type")

    if grant_type == "authorization_code":
        code = params.get("code")
        code = str(code) if code else ""
        # pop(): codes are single-use, even when they turn out to be expired.
        entry = auth_codes.pop(code, None) if code else None
        if entry is None:
            return _oauth_error("invalid_grant", "Invalid authorization code")
        if time.time() - entry["created_at"] > AUTH_CODE_TTL_SECONDS:
            return _oauth_error("invalid_grant", "Authorization code expired")
        return _token_success(auth_data)

    if grant_type == "refresh_token":
        if not _tokens_match(params.get("refresh_token"), auth_data["refresh_token"]):
            return _oauth_error("invalid_grant", "Invalid refresh token")
        return _token_success(auth_data)

    return _oauth_error(
        "unsupported_grant_type",
        "Only authorization_code and refresh_token are supported",
    )


async def refresh_tokens_task():
    """Run ``get_new_token(headless=False)`` unless a refresh is already running.

    Outcomes are only logged; exceptions are logged with a traceback and not
    re-raised. ``refresh_in_progress`` is always reset afterwards.
    """
    global refresh_in_progress
    # Check-and-set with no await in between, so it is atomic on the event loop.
    if refresh_in_progress:
        logger.info("Token refresh already in progress")
        return

    refresh_in_progress = True
    logger.info("Starting token refresh...")
    try:
        success = await get_new_token(headless=False)
        if success:
            logger.info("Token refresh completed successfully")
        else:
            logger.error("Token refresh failed")
    except Exception as e:
        logger.exception("Error during token refresh: %s", e)
    finally:
        refresh_in_progress = False


def _schedule_refresh() -> None:
    """Start ``refresh_tokens_task()`` in the background, keeping a strong reference."""
    task = asyncio.create_task(refresh_tokens_task())
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)


async def _check_usage_once() -> None:
    """Perform one usage check and schedule a refresh if needed.

    A refresh is scheduled when:
    - no tokens are stored (and no refresh is already running);
    - the usage query fails (``get_usage_percent`` returns a negative value);
    - usage is at or above ``USAGE_THRESHOLD``.
    """
    tokens = load_tokens()
    if not tokens:
        if not refresh_in_progress:
            logger.warning("No tokens found, starting refresh...")
            _schedule_refresh()
        return

    # get_usage_percent is synchronous; run it off the event loop so that
    # in-flight proxied streams are not stalled while it waits.
    usage = await asyncio.to_thread(get_usage_percent, tokens.access_token)
    if usage < 0:
        logger.warning("Failed to get usage, token may be invalid")
        _schedule_refresh()
        return

    logger.info("Current usage: %s%%", usage)
    if usage >= USAGE_THRESHOLD:
        logger.info(
            "Usage %s%% >= threshold %s%%, starting refresh...",
            usage,
            USAGE_THRESHOLD,
        )
        _schedule_refresh()


async def usage_monitor():
    """Run ``_check_usage_once`` every ``CHECK_INTERVAL`` seconds until cancelled.

    A failing check is logged and does not stop the monitor.
    ``CancelledError`` is not an ``Exception`` subclass, so cancellation
    still propagates.
    """
    while True:
        try:
            await _check_usage_once()
        except Exception:
            logger.exception("Usage check failed")
        await asyncio.sleep(CHECK_INTERVAL)


async def _relay_stream(
    request: web.Request,
    resp: aiohttp.ClientResponse,
    content_type: str,
) -> web.StreamResponse:
    """Relay a streaming (e.g. SSE) upstream body to the client chunk by chunk.

    Once the response is prepared its headers are on the wire, so a failure
    on either side (upstream read error, client disconnect, timeout) can no
    longer be turned into an error response. It is logged and the prepared
    response is returned as-is.
    """
    response = web.StreamResponse(
        status=resp.status,
        reason=resp.reason,
        headers={
            "Content-Type": content_type,
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        },
    )
    await response.prepare(request)
    try:
        async for chunk in resp.content.iter_any():
            await response.write(chunk)
        await response.write_eof()
    except (aiohttp.ClientError, ConnectionResetError, asyncio.TimeoutError) as exc:
        logger.warning(
            "Stream relay interrupted: path=%s error=%s", request.path, exc
        )
    return response


async def proxy_handler(request: web.Request) -> web.StreamResponse | web.Response:
    """Authenticate the client and forward the request to ``CODEX_BASE_URL``.

    The request method, query and body (POST/PUT/PATCH only) are forwarded.
    Headers are forwarded too, except those in ``_SKIP_REQUEST_HEADERS``,
    and ``Authorization`` is replaced with the upstream access token.

    Returns:
    - 401 for a bad client token;
    - 500 when no upstream token is available;
    - 502 when the upstream request fails before any bytes are sent.
    Streaming responses (content type containing "stream") are relayed
    incrementally. Any other response is buffered and returned with only
    its Content-Type.
    """
    if not check_auth(request):
        auth = request.headers.get("Authorization", "")
        auth_preview = auth[:24] + ("..." if len(auth) > 24 else "")
        logger.warning(
            "Auth failed: method=%s path=%s auth_present=%s auth_preview=%s ua=%s",
            request.method,
            request.path,
            bool(auth),
            auth_preview,
            request.headers.get("User-Agent", ""),
        )
        return web.json_response({"error": "Unauthorized"}, status=401)

    tokens = await get_valid_tokens()
    if not tokens:
        return web.json_response({"error": "No valid tokens"}, status=500)

    target_url = f"{CODEX_BASE_URL}{request.path}"
    logger.info(
        "Proxying request: %s %s -> %s",
        request.method,
        request.path_qs,
        target_url,
    )

    headers = {
        key: value
        for key, value in request.headers.items()
        if key.lower() not in _SKIP_REQUEST_HEADERS
    }
    headers["Authorization"] = f"Bearer {tokens.access_token}"
    body = await request.read() if request.method in ("POST", "PUT", "PATCH") else None

    # NOTE(review): the session uses aiohttp's default ClientTimeout, whose
    # total limit also covers reading streamed bodies. Confirm it is long
    # enough for the longest expected upstream stream.
    async with aiohttp.ClientSession() as session:
        try:
            async with session.request(
                method=request.method,
                url=target_url,
                headers=headers,
                data=body,
                params=request.query,
            ) as resp:
                content_type = resp.content_type or "application/json"
                if "stream" in content_type:
                    return await _relay_stream(request, resp, content_type)
                status = resp.status
                response_body = await resp.read()
        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
            return web.json_response({"error": f"Proxy error: {e}"}, status=502)

    if status >= 400:
        preview = response_body[:500].decode("utf-8", errors="replace")
        logger.warning(
            "Upstream error: status=%s path=%s body=%s",
            status,
            request.path,
            preview,
        )
    return web.Response(
        status=status,
        body=response_body,
        headers={"Content-Type": content_type},
    )


async def health_handler(request: web.Request) -> web.Response:
    """Report whether an upstream token is available, its usage, and refresh state.

    ``usage_percent`` is -1 when there is no token.
    """
    tokens = await get_valid_tokens()
    usage = -1
    if tokens:
        # Synchronous helper; keep it off the event loop.
        usage = await asyncio.to_thread(get_usage_percent, tokens.access_token)
    return web.json_response(
        {
            "status": "ok" if tokens else "no_tokens",
            "has_tokens": tokens is not None,
            "usage_percent": usage,
            "refresh_in_progress": refresh_in_progress,
        }
    )


async def start_background_tasks(app: web.Application):
    """on_startup hook: launch the usage monitor and store it on the app."""
    app["usage_monitor"] = asyncio.create_task(usage_monitor())


async def cleanup_background_tasks(app: web.Application):
    """on_cleanup hook: cancel the usage monitor and wait for it to finish."""
    app["usage_monitor"].cancel()
    try:
        await app["usage_monitor"]
    except asyncio.CancelledError:
        pass


def create_app() -> web.Application:
    """Build the aiohttp application.

    Routes: OAuth endpoints, ``/health``, and a catch-all proxy route. Also
    installs request logging and the background-task hooks.
    """
    app = web.Application(middlewares=[request_log_middleware])
    app.router.add_get("/oauth/authorize", oauth_authorize_handler)
    app.router.add_post("/oauth/token", oauth_token_handler)
    app.router.add_get("/health", health_handler)
    app.router.add_route("*", "/{path:.*}", proxy_handler)
    app.on_startup.append(start_background_tasks)
    app.on_cleanup.append(cleanup_background_tasks)
    return app


if __name__ == "__main__":
    logger.info(f"Starting proxy on port {PORT}")
    logger.info(f"Usage threshold: {USAGE_THRESHOLD}%")
    logger.info(f"Check interval: {CHECK_INTERVAL}s")

    auth_data = load_or_create_auth()
    logger.info("Client access token: %s", _mask(auth_data["access_token"]))
    logger.info("Client refresh token: %s", _mask(auth_data["refresh_token"]))

    startup_tokens = load_tokens()
    if startup_tokens:
        logger.info("Upstream access token: %s", _mask(startup_tokens.access_token))
    else:
        logger.warning("No upstream token found at %s", DATA_DIR / "tokens.json")

    app = create_app()
    web.run_app(app, host="0.0.0.0", port=PORT)