436 lines
14 KiB
Python
436 lines
14 KiB
Python
import os
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
import secrets
|
|
import json
|
|
import base64
|
|
import uuid
|
|
from urllib.parse import urlencode
|
|
from aiohttp import web
|
|
import aiohttp
|
|
|
|
from tokens import get_valid_tokens, load_tokens, DATA_DIR
|
|
from codex_usage import get_usage_percent
|
|
from get_new_token import get_new_token
|
|
|
|
# Upstream ChatGPT backend that proxied requests are forwarded to.
CODEX_BASE_URL = "https://chatgpt.com/backend-api"

# Runtime tunables, each overridable through an environment variable.
PORT = int(os.getenv("PORT", "8080"))
USAGE_THRESHOLD = int(os.getenv("USAGE_THRESHOLD", "85"))  # percent
CHECK_INTERVAL = int(os.getenv("CHECK_INTERVAL", "60"))  # seconds between usage checks

# Lifetime advertised for the locally issued client tokens; used both as
# "expires_in" in token responses and as "expires_at" in the auth file.
FAKE_EXPIRES_IN = 9_999_999_999_999

# Where the locally issued client credentials are persisted.
AUTH_FILE = DATA_DIR / "auth.json"

# Namespaced claim keys used inside the synthesized access token payload.
JWT_AUTH_CLAIM_PATH = "https://api.openai.com/auth"
JWT_PROFILE_CLAIM_PATH = "https://api.openai.com/profile"


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# True while refresh_tokens_task is running; checked to avoid concurrent refreshes.
refresh_in_progress = False

# Outstanding OAuth authorization codes -> {"state": str, "created_at": epoch seconds}.
auth_codes: dict[str, dict] = {}
|
|
|
|
|
|
def _b64url(data: bytes) -> str:
|
|
return base64.urlsafe_b64encode(data).decode("utf-8").rstrip("=")
|
|
|
|
|
|
def _generate_jwt_like() -> str:
    """Build a random token shaped like an OpenAI access JWT.

    The header and payload are real base64url-encoded compact JSON. The
    signature segment is 32 random bytes, so the token cannot be verified.
    Presumably it only needs to look plausible to clients that decode its
    claims (account id, plan, email) — confirm against the client.
    ``exp`` is set 315360000 seconds (3650 days) after ``iat``.
    """
    account_id = str(uuid.uuid4())
    issued_at = int(time.time())
    user_id = f"user-{secrets.token_urlsafe(18)}"

    auth_claims = {
        "chatgpt_account_id": account_id,
        "chatgpt_account_user_id": f"{user_id}__{account_id}",
        "chatgpt_compute_residency": "no_constraint",
        "chatgpt_plan_type": "plus",
        "chatgpt_user_id": user_id,
        "user_id": user_id,
    }

    # Key order is significant: it determines the serialized payload.
    payload = {
        "aud": ["https://api.openai.com/v1"],
        "client_id": "app_EMoamEEZ73f0CkXaXp7hrann",
        "iss": "https://auth.openai.com",
        "iat": issued_at,
        "nbf": issued_at,
        "exp": issued_at + 315360000,
        "jti": str(uuid.uuid4()),
        "scp": ["openid", "profile", "email", "offline_access"],
        "session_id": f"authsess_{secrets.token_urlsafe(24)}",
        JWT_AUTH_CLAIM_PATH: auth_claims,
        JWT_PROFILE_CLAIM_PATH: {
            "email": f"proxy-{secrets.token_hex(4)}@example.local",
            "email_verified": True,
        },
        "sub": f"auth0|{secrets.token_urlsafe(20)}",
    }

    segments = [
        _b64url(json.dumps(part, separators=(",", ":")).encode("utf-8"))
        for part in ({"alg": "HS256", "typ": "JWT"}, payload)
    ]
    segments.append(_b64url(secrets.token_bytes(32)))
    return ".".join(segments)
|
|
|
|
|
|
def _generate_refresh_like() -> str:
|
|
return f"rt_{secrets.token_urlsafe(40)}.{secrets.token_urlsafe(32)}"
|
|
|
|
|
|
def _mask(value: str, head: int = 8, tail: int = 6) -> str:
|
|
if not value:
|
|
return "<empty>"
|
|
if len(value) <= head + tail:
|
|
return "<hidden>"
|
|
return f"{value[:head]}...{value[-tail:]}"
|
|
|
|
|
|
def load_or_create_auth() -> dict:
    """Return the proxy's client-facing credentials, creating them if needed.

    The credentials are stored in ``AUTH_FILE`` as a JSON object with
    ``access_token``, ``refresh_token`` and ``expires_at`` keys. The stored
    data is reused only when it is a JSON object with all three keys truthy.
    If the file is missing, not valid UTF-8 JSON, not a JSON object, or
    incomplete, a fresh random set is generated and persisted. Regeneration
    invalidates any tokens previously handed to clients.

    Returns:
        dict with ``access_token``, ``refresh_token`` and ``expires_at``.
    """
    if AUTH_FILE.exists():
        data = None
        try:
            with open(AUTH_FILE, encoding="utf-8") as f:
                data = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            # A corrupt file would otherwise make every authenticated request fail.
            logger.warning("Ignoring unreadable auth file %s: %s", AUTH_FILE, e)
        if (
            isinstance(data, dict)
            and data.get("access_token")
            and data.get("refresh_token")
            and data.get("expires_at")
        ):
            return data

    DATA_DIR.mkdir(parents=True, exist_ok=True)
    data = {
        "access_token": _generate_jwt_like(),
        "refresh_token": _generate_refresh_like(),
        "expires_at": FAKE_EXPIRES_IN,
    }

    # Write to a sibling temp file and rename it into place. A crash mid-write,
    # or a concurrent reader, then never sees a truncated auth file. The file
    # holds bearer credentials, so it is created owner-only.
    tmp_path = AUTH_FILE.with_name(AUTH_FILE.name + ".tmp")
    fd = os.open(tmp_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)
    os.replace(tmp_path, AUTH_FILE)
    return data
|
|
|
|
|
|
@web.middleware
async def request_log_middleware(request: web.Request, handler):
    """Log method, path+query, final status and latency for every request.

    aiohttp handlers may produce their response by raising
    ``web.HTTPException`` (for example a redirect raised as ``HTTPFound``).
    Such exceptions are real HTTP responses, so their status code is logged
    and the exception is re-raised unchanged. Any other exception is logged
    with status ``"ERR"`` and propagates as before.
    """
    started = time.perf_counter()
    status = "ERR"
    try:
        response = await handler(request)
        status = getattr(response, "status", "ERR")
        return response
    except web.HTTPException as exc:
        status = exc.status
        raise
    finally:
        elapsed_ms = int((time.perf_counter() - started) * 1000)
        logger.info(
            "%s %s -> %s (%d ms)",
            request.method,
            request.path_qs,
            status,
            elapsed_ms,
        )
|
|
|
|
|
|
def check_auth(request: web.Request) -> bool:
    """Return True when the request carries the proxy's client access token.

    Expects ``Authorization: Bearer <token>``; the scheme is matched
    case-insensitively and surrounding whitespace around the token is ignored.
    The token is compared in constant time so response timing does not reveal
    how much of a guessed token matched.
    """
    # Loaded first (as before) so the auth file is created on first use even
    # when the header is missing.
    expected_token = load_or_create_auth()["access_token"]
    auth = request.headers.get("Authorization", "")
    if not auth.lower().startswith("bearer "):
        return False
    token = auth[7:].strip()
    # compare_digest rejects non-ASCII str, so compare encoded bytes.
    # surrogatepass keeps header text that may contain lone surrogates
    # encodable instead of raising.
    return secrets.compare_digest(
        token.encode("utf-8", "surrogatepass"),
        expected_token.encode("utf-8", "surrogatepass"),
    )
|
|
|
|
|
|
async def oauth_authorize_handler(request: web.Request) -> web.Response:
    """Emulate an OAuth authorization endpoint that approves every request.

    Issues a one-time authorization code tied to the caller's ``state``. It
    then redirects (302) to ``redirect_uri`` with ``code``, ``scope`` and
    ``state`` merged into any query string already present on that URI.

    NOTE(review): ``redirect_uri`` is not checked against an allow-list, so
    this is an open redirect. That is acceptable only if the proxy is
    reachable solely by trusted clients — confirm the deployment.

    Returns:
        400 JSON error when ``redirect_uri`` is missing.

    Raises:
        web.HTTPFound: on success; aiohttp turns it into the redirect response.
    """
    from urllib.parse import urlsplit, urlunsplit

    params = request.rel_url.query
    redirect_uri = params.get("redirect_uri")
    state = params.get("state", "")

    if not redirect_uri:
        return web.json_response(
            {"error": "invalid_request", "error_description": "Missing redirect_uri"},
            status=400,
        )

    # Purge codes that were never redeemed. The token endpoint rejects codes
    # older than 300 s anyway; without this the dict grows without bound.
    now = time.time()
    stale = [c for c, info in auth_codes.items() if now - info["created_at"] > 300]
    for old_code in stale:
        del auth_codes[old_code]

    code = f"ac_{secrets.token_urlsafe(48)}"
    auth_codes[code] = {
        "state": state,
        "created_at": now,
    }

    query = urlencode(
        {
            "code": code,
            "scope": "openid profile email offline_access",
            "state": state,
        }
    )
    # Merge with an existing query string rather than appending a second "?".
    parts = urlsplit(redirect_uri)
    merged_query = f"{parts.query}&{query}" if parts.query else query
    location = urlunsplit(parts._replace(query=merged_query))
    logger.info("OAuth authorize: issued code")
    raise web.HTTPFound(location=location)
|
|
|
|
|
|
async def oauth_token_handler(request: web.Request) -> web.Response:
    """Emulate an OAuth token endpoint for the proxy's local client credentials.

    The body may be JSON (``Content-Type: application/json``) or form-encoded,
    with a ``grant_type`` of either:

    * ``authorization_code``: redeems a code from ``auth_codes``. Codes are
      single-use and rejected once older than 300 seconds.
    * ``refresh_token``: must equal the stored refresh token (compared in
      constant time).

    Both grants return the same persisted access/refresh token pair from
    ``load_or_create_auth``.

    Returns:
        200 JSON token response on success. Otherwise 400 with an OAuth-style
        ``error``/``error_description`` body, for a malformed body, an
        invalid/expired code, a wrong refresh token, or an unsupported grant.
    """
    auth_data = load_or_create_auth()

    def _error(error: str, description: str) -> web.Response:
        # Shared shape for every OAuth error reply from this endpoint.
        return web.json_response(
            {"error": error, "error_description": description},
            status=400,
        )

    if (request.content_type or "").startswith("application/json"):
        try:
            body = await request.json()
        except ValueError:  # includes json.JSONDecodeError / UnicodeDecodeError
            return _error("invalid_request", "Malformed JSON body")
        if not isinstance(body, dict):
            return _error("invalid_request", "JSON body must be an object")
    else:
        body = await request.post()

    grant_type = body.get("grant_type")
    refresh_token = body.get("refresh_token")
    code = body.get("code")

    if grant_type == "authorization_code":
        # pop() consumes the code even when it turns out to be expired,
        # keeping codes strictly single-use.
        issued = auth_codes.pop(str(code), None) if code else None
        if issued is None:
            return _error("invalid_grant", "Invalid authorization code")
        if time.time() - issued["created_at"] > 300:
            return _error("invalid_grant", "Authorization code expired")
    elif grant_type == "refresh_token":
        if not isinstance(refresh_token, str) or not secrets.compare_digest(
            refresh_token.encode("utf-8", "surrogatepass"),
            auth_data["refresh_token"].encode("utf-8", "surrogatepass"),
        ):
            return _error("invalid_grant", "Invalid refresh token")
    else:
        return _error(
            "unsupported_grant_type",
            "Only authorization_code and refresh_token are supported",
        )

    return web.json_response(
        {
            "access_token": auth_data["access_token"],
            "refresh_token": auth_data["refresh_token"],
            "token_type": "Bearer",
            "expires_in": FAKE_EXPIRES_IN,
        }
    )
|
|
|
|
|
|
async def refresh_tokens_task():
    """Obtain new upstream tokens via ``get_new_token``, one refresh at a time.

    Returns immediately if a refresh is already running. The
    ``refresh_in_progress`` flag is tested and set with no ``await`` in
    between, so callers on the same event loop cannot both start a refresh.
    Exceptions other than cancellation are logged with their traceback and
    do not propagate. The flag is always cleared on exit.
    """
    global refresh_in_progress
    if refresh_in_progress:
        logger.info("Token refresh already in progress")
        return

    refresh_in_progress = True
    logger.info("Starting token refresh...")
    try:
        # NOTE(review): headless=False presumably shows a visible browser for
        # the login flow — confirm against get_new_token.
        success = await get_new_token(headless=False)
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception("Error during token refresh: %s", e)
    else:
        if success:
            logger.info("Token refresh completed successfully")
        else:
            logger.error("Token refresh failed")
    finally:
        refresh_in_progress = False
|
|
|
|
|
|
# Strong references to in-flight refresh tasks. The event loop keeps only weak
# references to tasks, so an unreferenced task may be garbage-collected before
# it finishes.
_refresh_tasks: set[asyncio.Task] = set()


def _spawn_refresh() -> None:
    """Schedule refresh_tokens_task() and keep it referenced until it is done."""
    task = asyncio.create_task(refresh_tokens_task())
    _refresh_tasks.add(task)
    task.add_done_callback(_refresh_tasks.discard)


def _check_usage_once() -> None:
    """Run a single usage check and schedule a token refresh if one is needed."""
    tokens = load_tokens()
    if not tokens:
        if not refresh_in_progress:
            logger.warning("No tokens found, starting refresh...")
            _spawn_refresh()
        return

    usage = get_usage_percent(tokens.access_token)
    if usage < 0:
        # A negative result is treated as "usage could not be read".
        logger.warning("Failed to get usage, token may be invalid")
        _spawn_refresh()
        return

    logger.info("Current usage: %s%%", usage)
    if usage >= USAGE_THRESHOLD:
        logger.info(
            "Usage %s%% >= threshold %s%%, starting refresh...",
            usage,
            USAGE_THRESHOLD,
        )
        _spawn_refresh()


async def usage_monitor():
    """Check upstream usage every CHECK_INTERVAL seconds, refreshing tokens as needed.

    A refresh is scheduled when no tokens are stored, when usage cannot be read
    (``get_usage_percent`` returns a negative value), or when usage has reached
    ``USAGE_THRESHOLD`` percent. An error in one check is logged and the loop
    keeps running, so a transient failure cannot silently stop monitoring.

    NOTE(review): ``load_tokens`` and ``get_usage_percent`` are called
    synchronously. If they do network or disk I/O they block the event loop
    while they run; consider ``asyncio.to_thread`` once they are confirmed
    thread-safe.
    """
    while True:
        try:
            _check_usage_once()
        except Exception:
            logger.exception("Usage check failed")
        await asyncio.sleep(CHECK_INTERVAL)
|
|
|
|
|
|
# Client request headers that are never forwarded upstream:
# - Host and Content-Length are recomputed by aiohttp for the upstream request.
# - Authorization is replaced with the upstream token.
# - The rest are hop-by-hop headers (RFC 9110 §7.6.1) describing only the
#   client->proxy connection. For example, forwarding "Transfer-Encoding:
#   chunked" with an already-buffered body would corrupt the request framing.
_NON_FORWARDED_REQUEST_HEADERS = frozenset(
    {
        "host",
        "authorization",
        "content-length",
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    }
)


def _upstream_headers(incoming, access_token: str) -> dict[str, str]:
    """Copy forwardable client headers and attach the upstream bearer token."""
    dropped = set(_NON_FORWARDED_REQUEST_HEADERS)
    # Headers named in the Connection header are hop-by-hop for this request too.
    for name in incoming.get("Connection", "").split(","):
        if name.strip():
            dropped.add(name.strip().lower())
    headers = {k: v for k, v in incoming.items() if k.lower() not in dropped}
    headers["Authorization"] = f"Bearer {access_token}"
    return headers


async def proxy_handler(request: web.Request) -> web.StreamResponse | web.Response:
    """Forward an authenticated client request to ``CODEX_BASE_URL`` + path.

    The client's bearer token (the proxy's own access token) is verified,
    then replaced by the upstream access token from ``get_valid_tokens``.
    Query parameters are always forwarded; a body only for POST/PUT/PATCH.
    Responses whose content type is ``text/event-stream`` or contains
    "stream" are relayed chunk by chunk. Others are buffered and returned
    with the upstream status and content type.

    Returns:
        The relayed response. Otherwise: 401 for a bad client token, 500 when
        no upstream token is available, 502 on upstream client errors, and
        504 on upstream timeout.
    """
    if not check_auth(request):
        auth = request.headers.get("Authorization", "")
        auth_preview = auth[:24] + ("..." if len(auth) > 24 else "")
        logger.warning(
            "Auth failed: method=%s path=%s auth_present=%s auth_preview=%s ua=%s",
            request.method,
            request.path,
            bool(auth),
            auth_preview,
            request.headers.get("User-Agent", ""),
        )
        return web.json_response({"error": "Unauthorized"}, status=401)

    tokens = await get_valid_tokens()
    if not tokens:
        return web.json_response({"error": "No valid tokens"}, status=500)

    target_url = f"{CODEX_BASE_URL}{request.path}"
    logger.info(
        "Proxying request: %s %s -> %s",
        request.method,
        request.path_qs,
        target_url,
    )

    headers = _upstream_headers(request.headers, tokens.access_token)
    body = await request.read() if request.method in ("POST", "PUT", "PATCH") else None

    # Set once streaming starts; after that, an error can no longer be sent
    # as a fresh JSON response because headers are already on the wire.
    stream_response: web.StreamResponse | None = None
    async with aiohttp.ClientSession() as session:
        try:
            async with session.request(
                method=request.method,
                url=target_url,
                headers=headers,
                data=body,
                params=request.query,
            ) as resp:
                content_type = resp.content_type or "application/json"
                is_stream = (
                    content_type == "text/event-stream" or "stream" in content_type
                )

                if is_stream:
                    stream_response = web.StreamResponse(
                        status=resp.status,
                        reason=resp.reason,
                        headers={
                            "Content-Type": content_type,
                            "Cache-Control": "no-cache",
                            "Connection": "keep-alive",
                        },
                    )
                    await stream_response.prepare(request)
                    async for chunk in resp.content.iter_any():
                        await stream_response.write(chunk)
                    await stream_response.write_eof()
                    return stream_response

                response_body = await resp.read()
                if resp.status >= 400:
                    preview = response_body[:500].decode("utf-8", errors="replace")
                    logger.warning(
                        "Upstream error: status=%s path=%s body=%s",
                        resp.status,
                        request.path,
                        preview,
                    )
                return web.Response(
                    status=resp.status,
                    body=response_body,
                    headers={"Content-Type": content_type},
                )
        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
            if stream_response is not None and stream_response.prepared:
                # Too late to change the status; end the truncated stream.
                logger.warning(
                    "Upstream stream aborted: path=%s error=%r", request.path, e
                )
                return stream_response
            if isinstance(e, asyncio.TimeoutError):
                # aiohttp's total timeout raises a bare TimeoutError, not ClientError.
                return web.json_response(
                    {"error": "Proxy error: upstream request timed out"}, status=504
                )
            return web.json_response({"error": f"Proxy error: {e}"}, status=502)
|
|
|
|
|
|
async def health_handler(request: web.Request) -> web.Response:
    """Report token availability, upstream usage and refresh state as JSON.

    ``usage_percent`` is -1 when no valid upstream tokens are available.
    Otherwise it is whatever ``get_usage_percent`` returns; ``usage_monitor``
    treats a negative value from it as a failed read.
    """
    tokens = await get_valid_tokens()
    usage = get_usage_percent(tokens.access_token) if tokens else -1
    payload = {
        "status": "no_tokens" if not tokens else "ok",
        "has_tokens": tokens is not None,
        "usage_percent": usage,
        "refresh_in_progress": refresh_in_progress,
    }
    return web.json_response(payload)
|
|
|
|
|
|
async def start_background_tasks(app: web.Application):
    """on_startup hook: launch usage_monitor() and keep its task on *app*.

    The task is stored under ``"usage_monitor"`` so the cleanup hook can
    cancel it.
    """
    monitor_task = asyncio.create_task(usage_monitor())
    app["usage_monitor"] = monitor_task
|
|
|
|
|
|
async def cleanup_background_tasks(app: web.Application):
    """on_cleanup hook: cancel the usage monitor task and wait for it to end.

    If the monitor had already died with an exception, that exception is
    logged rather than re-raised, so it does not propagate out of this
    cleanup hook.
    """
    task = app["usage_monitor"]
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass
    except Exception:
        logger.exception("Usage monitor task had failed")
|
|
|
|
|
|
def create_app() -> web.Application:
    """Build the aiohttp application.

    Installs the request-logging middleware, the OAuth emulation and health
    routes, and a catch-all route that proxies every other path and method
    upstream. Also registers the background-task startup/cleanup hooks.
    """
    app = web.Application(middlewares=[request_log_middleware])
    app.add_routes(
        [
            web.get("/oauth/authorize", oauth_authorize_handler),
            web.post("/oauth/token", oauth_token_handler),
            web.get("/health", health_handler),
            # Catch-all, registered after the specific routes above.
            web.route("*", "/{path:.*}", proxy_handler),
        ]
    )
    app.on_startup.append(start_background_tasks)
    app.on_cleanup.append(cleanup_background_tasks)
    return app
|
|
|
|
|
|
if __name__ == "__main__":
    # Echo the effective configuration.
    for message in (
        f"Starting proxy on port {PORT}",
        f"Usage threshold: {USAGE_THRESHOLD}%",
        f"Check interval: {CHECK_INTERVAL}s",
    ):
        logger.info(message)

    # Ensure client credentials exist and show masked previews of them.
    client_auth = load_or_create_auth()
    for label, key in (("access", "access_token"), ("refresh", "refresh_token")):
        logger.info("Client %s token: %s", label, _mask(client_auth[key]))

    upstream = load_tokens()
    if not upstream:
        logger.warning("No upstream token found at %s", DATA_DIR / "tokens.json")
    else:
        logger.info("Upstream access token: %s", _mask(upstream.access_token))

    web.run_app(create_app(), host="0.0.0.0", port=PORT)
|