1
0
Fork 0
gibidy/src/proxy.py
2026-02-20 19:11:32 +03:00

436 lines
14 KiB
Python

import os
import asyncio
import logging
import time
import secrets
import json
import base64
import uuid
from urllib.parse import urlencode
from aiohttp import web
import aiohttp
from tokens import get_valid_tokens, load_tokens, DATA_DIR
from codex_usage import get_usage_percent
from get_new_token import get_new_token
# --- Upstream target ---------------------------------------------------------
CODEX_BASE_URL = "https://chatgpt.com/backend-api"

# --- Runtime settings, parsed as integers from the environment ----------------
# PORT: listening port; USAGE_THRESHOLD: usage percentage at which a token
# refresh is started; CHECK_INTERVAL: seconds between usage checks.
PORT, USAGE_THRESHOLD, CHECK_INTERVAL = (
    int(os.environ.get(env_name, fallback))
    for env_name, fallback in (
        ("PORT", "8080"),
        ("USAGE_THRESHOLD", "85"),
        ("CHECK_INTERVAL", "60"),
    )
)

# Used both as the stored "expires_at" of the client credentials and as the
# "expires_in" of OAuth token responses.
FAKE_EXPIRES_IN = 9999999999999
# Persisted client credentials handed out by the emulated OAuth endpoints.
AUTH_FILE = DATA_DIR / "auth.json"

# Claim namespaces embedded in the JWT-shaped client access token.
JWT_AUTH_CLAIM_PATH = "https://api.openai.com/auth"
JWT_PROFILE_CLAIM_PATH = "https://api.openai.com/profile"

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# True while refresh_tokens_task is running; prevents overlapping refreshes.
refresh_in_progress = False
# Outstanding authorization codes -> {"state": ..., "created_at": epoch seconds}.
auth_codes: dict[str, dict] = {}
def _b64url(data: bytes) -> str:
return base64.urlsafe_b64encode(data).decode("utf-8").rstrip("=")
def _generate_jwt_like() -> str:
    """Return a random token shaped like an OpenAI access JWT (``header.payload.sig``).

    Header and payload are compact JSON encoded with ``_b64url``; the signature
    segment is 32 random bytes, so the token is not verifiable — it is only an
    opaque client credential. Account/user ids, jti, session id, e-mail and
    subject are freshly randomized on every call; ``exp`` is issue time plus
    315360000 seconds.
    """
    account_id = str(uuid.uuid4())
    issued_at = int(time.time())
    user_id = f"user-{secrets.token_urlsafe(18)}"

    # Dict literal order fixes both the JSON key order and the order in which
    # the random values below are drawn.
    segments = (
        {"alg": "HS256", "typ": "JWT"},
        {
            "aud": ["https://api.openai.com/v1"],
            "client_id": "app_EMoamEEZ73f0CkXaXp7hrann",
            "iss": "https://auth.openai.com",
            "iat": issued_at,
            "nbf": issued_at,
            "exp": issued_at + 315360000,
            "jti": str(uuid.uuid4()),
            "scp": ["openid", "profile", "email", "offline_access"],
            "session_id": f"authsess_{secrets.token_urlsafe(24)}",
            JWT_AUTH_CLAIM_PATH: {
                "chatgpt_account_id": account_id,
                "chatgpt_account_user_id": f"{user_id}__{account_id}",
                "chatgpt_compute_residency": "no_constraint",
                "chatgpt_plan_type": "plus",
                "chatgpt_user_id": user_id,
                "user_id": user_id,
            },
            JWT_PROFILE_CLAIM_PATH: {
                "email": f"proxy-{secrets.token_hex(4)}@example.local",
                "email_verified": True,
            },
            "sub": f"auth0|{secrets.token_urlsafe(20)}",
        },
    )
    parts = [
        _b64url(json.dumps(segment, separators=(",", ":")).encode("utf-8"))
        for segment in segments
    ]
    parts.append(_b64url(secrets.token_bytes(32)))  # random, non-verifiable "signature"
    return ".".join(parts)
def _generate_refresh_like() -> str:
return f"rt_{secrets.token_urlsafe(40)}.{secrets.token_urlsafe(32)}"
def _mask(value: str, head: int = 8, tail: int = 6) -> str:
if not value:
return "<empty>"
if len(value) <= head + tail:
return "<hidden>"
return f"{value[:head]}...{value[-tail:]}"
def load_or_create_auth() -> dict:
    """Return the proxy's client credentials, generating and persisting them if needed.

    If AUTH_FILE holds a JSON object with non-empty ``access_token``,
    ``refresh_token`` and ``expires_at``, that object is returned unchanged.
    Otherwise (missing, incomplete, unparsable, or not a JSON object) a fresh
    random credential set is generated, written to AUTH_FILE and returned.
    Regenerating invalidates any token previously handed to clients.

    Raises:
        OSError: if the file exists but cannot be read, or cannot be written.
    """
    if AUTH_FILE.exists():
        try:
            with open(AUTH_FILE, encoding="utf-8") as f:
                data = json.load(f)
        except ValueError as exc:
            # JSONDecodeError / UnicodeDecodeError: a corrupt file used to make
            # every authenticated request fail; regenerate instead, loudly.
            logger.warning("Ignoring unparsable auth file %s: %s", AUTH_FILE, exc)
            data = None
        if (
            isinstance(data, dict)
            and data.get("access_token")
            and data.get("refresh_token")
            and data.get("expires_at")
        ):
            return data
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    data = {
        "access_token": _generate_jwt_like(),
        "refresh_token": _generate_refresh_like(),
        "expires_at": FAKE_EXPIRES_IN,
    }
    # Write a sibling temp file and atomically swap it in so that a crash
    # mid-write cannot leave a truncated auth.json behind.
    tmp_path = AUTH_FILE.with_name(AUTH_FILE.name + ".tmp")
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)
    os.replace(tmp_path, AUTH_FILE)
    return data
@web.middleware
async def request_log_middleware(request: web.Request, handler):
    """Log method, path+query, final status and elapsed milliseconds of each request.

    Handlers that deliver their response by raising an aiohttp HTTP exception
    (e.g. ``web.HTTPFound`` from the OAuth authorize endpoint) are logged with
    that exception's status. Any other exception is logged as ``ERR`` and
    re-raised unchanged.
    """
    started = time.perf_counter()
    status = "ERR"
    try:
        response = await handler(request)
        status = getattr(response, "status", "ERR")
        return response
    except web.HTTPException as exc:
        # Redirects/4xx raised as exceptions are real responses, not failures;
        # previously these were all reported as "ERR".
        status = exc.status
        raise
    finally:
        elapsed_ms = int((time.perf_counter() - started) * 1000)
        logger.info(
            "%s %s -> %s (%d ms)",
            request.method,
            request.path_qs,
            status,
            elapsed_ms,
        )
def check_auth(request: web.Request) -> bool:
    """Return True if the request carries ``Authorization: Bearer <client access token>``.

    The scheme is matched case-insensitively; the expected token is the
    ``access_token`` from ``load_or_create_auth()`` (which may create it).
    """
    expected_token = load_or_create_auth()["access_token"]
    auth = request.headers.get("Authorization", "")
    if not auth.lower().startswith("bearer "):
        return False
    token = auth[7:].strip()
    # Constant-time comparison so response timing does not reveal how much of a
    # guessed token matched. Compare bytes: compare_digest rejects non-ASCII str.
    return secrets.compare_digest(
        token.encode("utf-8"), expected_token.encode("utf-8")
    )
async def oauth_authorize_handler(request: web.Request) -> web.Response:
    """Emulated OAuth ``/authorize``: immediately redirect back with a fresh code.

    Query params: ``redirect_uri`` (required) and ``state`` (echoed back).
    The issued code is stored in ``auth_codes`` and can be exchanged once at
    ``/oauth/token`` within 300 seconds.

    NOTE(review): no user authentication happens here — any caller that can
    reach this endpoint obtains a code that ``/oauth/token`` exchanges for the
    proxy's client access token. With the server bound to 0.0.0.0 that means
    anyone on the network; consider binding to localhost.

    Returns:
        A 400 JSON error when ``redirect_uri`` is missing.

    Raises:
        web.HTTPFound: on success (aiohttp turns it into the 302 redirect).
    """
    params = request.rel_url.query
    redirect_uri = params.get("redirect_uri")
    state = params.get("state", "")
    if not redirect_uri:
        return web.json_response(
            {"error": "invalid_request", "error_description": "Missing redirect_uri"},
            status=400,
        )
    now = time.time()
    # Codes that are never redeemed would otherwise accumulate forever; drop any
    # older than the 300 s lifetime oauth_token_handler enforces.
    expired = [
        stale for stale, info in auth_codes.items() if now - info["created_at"] > 300
    ]
    for stale in expired:
        del auth_codes[stale]
    code = f"ac_{secrets.token_urlsafe(48)}"
    auth_codes[code] = {
        "state": state,
        "created_at": now,
    }
    query = urlencode(
        {
            "code": code,
            "scope": "openid profile email offline_access",
            "state": state,
        }
    )
    # Keep any query string already present on redirect_uri valid.
    separator = "&" if "?" in redirect_uri else "?"
    location = f"{redirect_uri}{separator}{query}"
    logger.info("OAuth authorize: issued code")
    raise web.HTTPFound(location=location)
def _oauth_error(error: str, description: str) -> web.Response:
    """Build a 400 OAuth-style error response ``{"error", "error_description"}``."""
    return web.json_response(
        {"error": error, "error_description": description},
        status=400,
    )


def _token_response(auth_data: dict) -> web.Response:
    """Build the token response carrying the proxy's stored client credentials."""
    return web.json_response(
        {
            "access_token": auth_data["access_token"],
            "refresh_token": auth_data["refresh_token"],
            "token_type": "Bearer",
            "expires_in": FAKE_EXPIRES_IN,
        }
    )


async def oauth_token_handler(request: web.Request) -> web.Response:
    """Emulated OAuth ``/token`` endpoint.

    Accepts a JSON object or a form body with ``grant_type`` of:

    * ``authorization_code`` — ``code`` must be an unexpired (<= 300 s) code
      from ``/oauth/authorize``; it is consumed whether or not it has expired.
    * ``refresh_token`` — ``refresh_token`` must equal the stored one.

    Both grants return the same stored client credentials. All failures are
    400 OAuth error responses, including a malformed JSON body.
    """
    auth_data = load_or_create_auth()
    content_type = request.content_type or ""
    if content_type.startswith("application/json"):
        try:
            params = await request.json()
        except ValueError:  # JSONDecodeError / UnicodeDecodeError
            return _oauth_error("invalid_request", "Malformed JSON body")
        if not isinstance(params, dict):
            return _oauth_error("invalid_request", "Malformed JSON body")
    else:
        params = await request.post()
    grant_type = params.get("grant_type")
    refresh_token = params.get("refresh_token")
    code = params.get("code")

    if grant_type == "authorization_code":
        # Codes are single-use: pop removes it even if it turns out to be expired.
        entry = auth_codes.pop(str(code) if code else "", None)
        if entry is None:
            return _oauth_error("invalid_grant", "Invalid authorization code")
        if time.time() - entry["created_at"] > 300:
            return _oauth_error("invalid_grant", "Authorization code expired")
        return _token_response(auth_data)

    if grant_type == "refresh_token":
        presented = str(refresh_token) if refresh_token else ""
        # Constant-time comparison; bytes because compare_digest rejects non-ASCII str.
        if not secrets.compare_digest(
            presented.encode("utf-8"), auth_data["refresh_token"].encode("utf-8")
        ):
            return _oauth_error("invalid_grant", "Invalid refresh token")
        return _token_response(auth_data)

    return _oauth_error(
        "unsupported_grant_type",
        "Only authorization_code and refresh_token are supported",
    )
async def refresh_tokens_task():
    """Run ``get_new_token(headless=False)`` unless a refresh is already in progress.

    Uses the module-level ``refresh_in_progress`` flag as a guard; the flag is
    always cleared when the attempt ends. Failures (a falsy result or an
    exception) are logged, never raised, since this runs as a background task.
    """
    global refresh_in_progress
    if refresh_in_progress:
        logger.info("Token refresh already in progress")
        return
    # No await between the check above and this assignment, so two tasks on the
    # same event loop cannot both pass the guard.
    refresh_in_progress = True
    logger.info("Starting token refresh...")
    try:
        success = await get_new_token(headless=False)
    except Exception as e:
        # Background-task boundary: keep the traceback, which the old
        # logger.error(f"...") discarded.
        logger.exception("Error during token refresh: %s", e)
    else:
        if success:
            logger.info("Token refresh completed successfully")
        else:
            logger.error("Token refresh failed")
    finally:
        refresh_in_progress = False
# Strong references to fire-and-forget refresh tasks: the event loop keeps only
# weak references, so an unreferenced task can be garbage-collected mid-run.
_background_tasks: set[asyncio.Task] = set()


def _spawn_refresh() -> None:
    """Start refresh_tokens_task in the background and keep it referenced until done."""
    task = asyncio.create_task(refresh_tokens_task())
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)


async def _check_usage_once() -> None:
    """One monitoring pass: start a refresh if tokens are missing, unusable or over threshold."""
    tokens = load_tokens()
    if not tokens:
        if not refresh_in_progress:
            logger.warning("No tokens found, starting refresh...")
            _spawn_refresh()
        return
    # get_usage_percent is synchronous and presumably performs an HTTP request;
    # run it in a worker thread so proxied requests are not stalled meanwhile.
    usage = await asyncio.to_thread(get_usage_percent, tokens.access_token)
    if usage < 0:
        logger.warning("Failed to get usage, token may be invalid")
        _spawn_refresh()
        return
    logger.info(f"Current usage: {usage}%")
    if usage >= USAGE_THRESHOLD:
        logger.info(
            f"Usage {usage}% >= threshold {USAGE_THRESHOLD}%, starting refresh..."
        )
        _spawn_refresh()


async def usage_monitor():
    """Check upstream usage every CHECK_INTERVAL seconds, forever (until cancelled).

    A failing pass is logged and the loop continues; previously any exception
    ended the monitor task silently. CancelledError is not caught, so
    cancellation from the cleanup hook still stops the loop.
    """
    while True:
        try:
            await _check_usage_once()
        except Exception:
            logger.exception("Usage check failed")
        await asyncio.sleep(CHECK_INTERVAL)
# Request headers never copied upstream: hop-by-hop headers (RFC 9110 §7.6.1)
# describe only the client<->proxy connection, and host/authorization/
# content-length are rewritten or recomputed for the upstream request.
_SKIP_REQUEST_HEADERS = frozenset(
    {
        "host",
        "authorization",
        "content-length",
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    }
)


def _log_auth_failure(request: web.Request) -> None:
    """Log a rejected request with a 24-char preview of its Authorization header."""
    auth = request.headers.get("Authorization", "")
    auth_preview = auth[:24] + ("..." if len(auth) > 24 else "")
    logger.warning(
        "Auth failed: method=%s path=%s auth_present=%s auth_preview=%s ua=%s",
        request.method,
        request.path,
        bool(auth),
        auth_preview,
        request.headers.get("User-Agent", ""),
    )


def _upstream_headers(request: web.Request, access_token: str) -> dict[str, str]:
    """Copy forwardable client headers and substitute the upstream bearer token."""
    headers = {
        key: value
        for key, value in request.headers.items()
        if key.lower() not in _SKIP_REQUEST_HEADERS
    }
    headers["Authorization"] = f"Bearer {access_token}"
    return headers


async def proxy_handler(request: web.Request) -> web.StreamResponse | web.Response:
    """Forward an authenticated client request to CODEX_BASE_URL + path.

    Returns 401 if the client token is wrong and 500 if no upstream token is
    available. Event-stream (or any ``*stream*``) responses are relayed
    chunk-by-chunk; other responses are buffered and returned with the
    upstream status and Content-Type. Upstream connection errors give 502 and
    timeouts 504 — unless streaming already started, in which case the
    partially sent stream is ended instead of sending a second response.
    """
    if not check_auth(request):
        _log_auth_failure(request)
        return web.json_response({"error": "Unauthorized"}, status=401)
    tokens = await get_valid_tokens()
    if not tokens:
        return web.json_response({"error": "No valid tokens"}, status=500)
    target_url = f"{CODEX_BASE_URL}{request.path}"
    logger.info(
        "Proxying request: %s %s -> %s",
        request.method,
        request.path_qs,
        target_url,
    )
    headers = _upstream_headers(request, tokens.access_token)
    body = await request.read() if request.method in ("POST", "PUT", "PATCH") else None

    stream = None  # set once a streaming response has been sent to the client
    # TODO(review): a ClientSession per request defeats connection pooling;
    # consider one shared session created at app startup.
    async with aiohttp.ClientSession() as session:
        try:
            async with session.request(
                method=request.method,
                url=target_url,
                headers=headers,
                data=body,
                params=request.query,
            ) as resp:
                content_type = resp.content_type or "application/json"
                if content_type == "text/event-stream" or "stream" in content_type:
                    stream = web.StreamResponse(
                        status=resp.status,
                        reason=resp.reason,
                        headers={
                            "Content-Type": content_type,
                            "Cache-Control": "no-cache",
                            "Connection": "keep-alive",
                        },
                    )
                    await stream.prepare(request)
                    async for chunk in resp.content.iter_any():
                        await stream.write(chunk)
                    await stream.write_eof()
                    return stream
                response_body = await resp.read()
                if resp.status >= 400:
                    preview = response_body[:500].decode("utf-8", errors="replace")
                    logger.warning(
                        "Upstream error: status=%s path=%s body=%s",
                        resp.status,
                        request.path,
                        preview,
                    )
                return web.Response(
                    status=resp.status,
                    body=response_body,
                    headers={"Content-Type": content_type},
                )
        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
            if stream is not None:
                # Status and headers are already on the wire; returning a new
                # response would corrupt the connection. Let aiohttp end the stream.
                logger.warning(
                    "Upstream stream aborted: path=%s error=%r", request.path, e
                )
                return stream
            if isinstance(e, asyncio.TimeoutError):
                return web.json_response(
                    {"error": "Proxy error: upstream timed out"}, status=504
                )
            return web.json_response({"error": f"Proxy error: {e}"}, status=502)
async def health_handler(request: web.Request) -> web.Response:
    """Report token availability, upstream usage and refresh state as JSON.

    ``usage_percent`` is -1 when there are no tokens or the usage lookup
    raises; otherwise it is whatever ``get_usage_percent`` returns (a negative
    value presumably signals a failed lookup, as in usage_monitor).
    """
    tokens = await get_valid_tokens()
    usage = -1
    if tokens:
        try:
            # Synchronous, presumably blocking HTTP call: keep it off the event loop.
            usage = await asyncio.to_thread(get_usage_percent, tokens.access_token)
        except Exception:
            # A health probe should report degraded state, not answer 500.
            logger.exception("Usage lookup failed during health check")
    has_tokens = bool(tokens)  # same truthiness test as "status" and proxy_handler
    return web.json_response(
        {
            "status": "ok" if has_tokens else "no_tokens",
            "has_tokens": has_tokens,
            "usage_percent": usage,
            "refresh_in_progress": refresh_in_progress,
        }
    )
async def start_background_tasks(app: web.Application):
    """aiohttp on_startup hook: launch usage_monitor and store its task on the app."""
    monitor = asyncio.create_task(usage_monitor())
    app["usage_monitor"] = monitor
async def cleanup_background_tasks(app: web.Application):
    """aiohttp on_cleanup hook: cancel the usage_monitor task and wait for it to finish."""
    monitor = app["usage_monitor"]
    monitor.cancel()
    try:
        await monitor
    except asyncio.CancelledError:
        # Expected outcome of the cancel() above.
        return
def create_app() -> web.Application:
    """Build the aiohttp application.

    Registers request logging, the emulated OAuth endpoints, ``/health``, and
    a catch-all route that sends everything else to ``proxy_handler``; the
    usage monitor is started on startup and cancelled on cleanup.
    """
    app = web.Application(middlewares=[request_log_middleware])
    app.add_routes(
        [
            web.get("/oauth/authorize", oauth_authorize_handler),
            web.post("/oauth/token", oauth_token_handler),
            web.get("/health", health_handler),
            web.route("*", "/{path:.*}", proxy_handler),
        ]
    )
    app.on_startup.append(start_background_tasks)
    app.on_cleanup.append(cleanup_background_tasks)
    return app
if __name__ == "__main__":
    # Bind address; defaults to all interfaces as before. NOTE(review): the
    # /oauth/authorize endpoint issues client tokens without authentication, so
    # anything that can reach this address can use the proxy — prefer
    # BIND_HOST=127.0.0.1 unless remote access is intended.
    bind_host = os.environ.get("BIND_HOST", "0.0.0.0")

    for banner in (
        f"Starting proxy on port {PORT}",
        f"Usage threshold: {USAGE_THRESHOLD}%",
        f"Check interval: {CHECK_INTERVAL}s",
    ):
        logger.info(banner)

    # Print masked client credentials so the operator can configure clients.
    auth_data = load_or_create_auth()
    logger.info("Client access token: %s", _mask(auth_data["access_token"]))
    logger.info("Client refresh token: %s", _mask(auth_data["refresh_token"]))

    startup_tokens = load_tokens()
    if startup_tokens:
        logger.info("Upstream access token: %s", _mask(startup_tokens.access_token))
    else:
        logger.warning("No upstream token found at %s", DATA_DIR / "tokens.json")

    app = create_app()
    web.run_app(app, host=bind_host, port=PORT)