129 lines
3.8 KiB
Python
129 lines
3.8 KiB
Python
"""OAuth helper for codex responses provider."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import base64
|
|
import json
|
|
import time
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
import httpx
|
|
|
|
from app.config.models import OAuthAuth
|
|
|
|
JWT_AUTH_CLAIM_PATH = "https://api.openai.com/auth"
|
|
OAUTH_TOKEN_URL = "https://auth.openai.com/oauth/token"
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class OAuthData:
|
|
token: str
|
|
headers: dict[str, str]
|
|
|
|
|
|
class CodexOAuthProvider:
    """Keeps oauth state and refreshes access token when required.

    The provider holds the current access token, refresh token and expiry
    (epoch milliseconds). ``get()`` transparently refreshes the access token
    through the token endpoint once it is within 60 seconds of expiring.
    """

    def __init__(
        self,
        *,
        auth: OAuthAuth,
        timeout_seconds: float = 30.0,
        token_url: str = OAUTH_TOKEN_URL,
    ) -> None:
        """Initialise the provider from stored credentials.

        Args:
            auth: Stored credentials; its ``access``, ``refresh`` and
                ``expires`` attributes are read.
            timeout_seconds: Timeout passed to the HTTP client used for
                refresh requests.
            token_url: Token endpoint that refresh requests are POSTed to.
        """
        self._access = auth.access
        self._refresh = auth.refresh
        self._expires = auth.expires
        # Both ids are read from the access token's JWT payload and are
        # None when the token is not a decodable JWT or lacks the claim.
        self._client_id = _extract_client_id(self._access)
        self._account_id = _extract_account_id(self._access)

        self._timeout_seconds = timeout_seconds
        self._token_url = token_url
        self._lock = asyncio.Lock()

    async def get(self) -> OAuthData:
        """Return the current token and request headers, refreshing if needed.

        Returns:
            An ``OAuthData`` with the access token and an ``Authorization``
            header; a ``ChatGPT-Account-Id`` header is added when an account
            id is known.

        Raises:
            ValueError: If a required refresh fails or its response is
                unusable.
        """
        if self._is_expired():
            async with self._lock:
                # Re-check after acquiring the lock: a coroutine that waited
                # here may find the token already refreshed by the holder.
                if self._is_expired():
                    await self._refresh_token()

        headers = {"Authorization": f"Bearer {self._access}"}
        if self._account_id:
            headers["ChatGPT-Account-Id"] = self._account_id
        return OAuthData(token=self._access, headers=headers)

    async def get_headers(self) -> dict[str, str]:
        """Return only the request headers from ``get()``."""
        oauth = await self.get()
        return oauth.headers

    def _is_expired(self) -> bool:
        """Return True when the token expires within the next 60 seconds."""
        # ``_expires`` is compared in epoch milliseconds; 60_000 ms of margin
        # avoids handing out a token that is about to lapse.
        return int(time.time() * 1000) >= self._expires - 60_000

    @staticmethod
    def _resolve_expiry_ms(expires_in: object, access: str) -> int | None:
        """Work out the new expiry in epoch milliseconds.

        Prefers the response's ``expires_in`` (seconds from now). When that
        is absent or not a usable number, falls back to the ``exp`` claim
        (seconds since the epoch) of the new access token's JWT payload.

        Returns:
            The expiry in epoch milliseconds, or None if neither source
            yields a usable number.
        """

        def as_ms(seconds: object) -> int | None:
            # bool is an int subclass, but a JSON true/false is not a time.
            if isinstance(seconds, bool) or not isinstance(seconds, (int, float)):
                return None
            try:
                return int(seconds * 1000)
            except (OverflowError, ValueError):
                # NaN / Infinity cannot be converted to an int.
                return None

        lifetime_ms = as_ms(expires_in)
        if lifetime_ms is not None:
            return int(time.time() * 1000) + lifetime_ms

        claims = _decode_jwt_payload(access)
        if claims is not None:
            return as_ms(claims.get("exp"))
        return None

    async def _refresh_token(self) -> None:
        """Exchange the refresh token for a new access token.

        Updates the stored access token, refresh token (when a new one is
        returned), expiry, account id and client id.

        Raises:
            ValueError: If the endpoint answers with status >= 400, the body
                is not a JSON object, or ``access_token`` is missing.
        """
        payload: dict[str, Any] = {
            "grant_type": "refresh_token",
            "refresh_token": self._refresh,
        }
        if self._client_id:
            payload["client_id"] = self._client_id

        async with httpx.AsyncClient(timeout=self._timeout_seconds) as client:
            response = await client.post(self._token_url, json=payload)

        if response.status_code >= 400:
            raise ValueError(f"OAuth refresh failed with status {response.status_code}")

        data = response.json()
        if not isinstance(data, dict):
            # Fail with the same exception type as the other refresh errors
            # instead of an AttributeError on ``.get`` below.
            raise ValueError("OAuth refresh response is not a JSON object")

        access = data.get("access_token")
        if not isinstance(access, str):
            raise ValueError("OAuth refresh response missing access_token")
        self._access = access

        refresh = data.get("refresh_token")
        if isinstance(refresh, str):
            self._refresh = refresh

        # Without a new expiry the old (already elapsed) one would stay in
        # place and every subsequent get() would trigger another refresh.
        expires = self._resolve_expiry_ms(data.get("expires_in"), access)
        if expires is not None:
            self._expires = expires

        self._account_id = _extract_account_id(self._access)
        # Keep the previous client id when the new token does not carry one.
        self._client_id = _extract_client_id(self._access) or self._client_id
|
|
|
|
|
|
def _extract_account_id(token: str) -> str | None:
    """Return the ChatGPT account id carried in *token*, if any.

    Looks up the ``chatgpt_account_id`` entry inside the auth claim
    (``JWT_AUTH_CLAIM_PATH``) of the token's JWT payload. Returns None when
    the payload cannot be decoded, the claim is not a mapping, or the id is
    not a string.
    """
    claims = _decode_jwt_payload(token)
    auth_claim = claims.get(JWT_AUTH_CLAIM_PATH) if claims is not None else None
    if isinstance(auth_claim, dict):
        candidate = auth_claim.get("chatgpt_account_id")
        if isinstance(candidate, str):
            return candidate
    return None
|
|
|
|
|
|
def _extract_client_id(token: str) -> str | None:
    """Return the ``client_id`` claim of *token*'s JWT payload, if any.

    Returns None when the payload cannot be decoded or the claim is not a
    string.
    """
    claims = _decode_jwt_payload(token)
    if claims is not None:
        value = claims.get("client_id")
        if isinstance(value, str):
            return value
    return None
|
|
|
|
|
|
def _decode_jwt_payload(token: str) -> dict[str, Any] | None:
|
|
parts = token.split(".")
|
|
if len(parts) != 3:
|
|
return None
|
|
try:
|
|
payload = _b64url_decode(parts[1])
|
|
parsed = json.loads(payload)
|
|
except Exception:
|
|
return None
|
|
return parsed if isinstance(parsed, dict) else None
|
|
|
|
|
|
def _b64url_decode(data: str) -> str:
|
|
padded = data + "=" * ((4 - len(data) % 4) % 4)
|
|
raw = base64.urlsafe_b64decode(padded)
|
|
return raw.decode("utf-8")
|