ai/ai
1
0
Fork 0
ai/app/providers/codex_responses/oauth.py

129 lines
3.8 KiB
Python

"""OAuth helper for codex responses provider."""
from __future__ import annotations
import asyncio
import base64
import json
import time
from dataclasses import dataclass
from typing import Any
import httpx
from app.config.models import OAuthAuth
# Namespaced (URI-keyed) JWT claim whose object value carries the
# ``chatgpt_account_id`` read by ``_extract_account_id``.
JWT_AUTH_CLAIM_PATH: str = "https://api.openai.com/auth"
# Default token endpoint that ``CodexOAuthProvider`` posts the
# ``refresh_token`` grant to.
OAUTH_TOKEN_URL: str = "https://auth.openai.com/oauth/token"
@dataclass(slots=True)
class OAuthData:
token: str
headers: dict[str, str]
class CodexOAuthProvider:
    """Hold OAuth credentials and refresh the access token when it expires.

    Expiry timestamps are milliseconds since the Unix epoch. A token is treated
    as expired 60 seconds before its recorded expiry, so it is never sent just
    as it lapses.
    """

    # Safety margin (ms) subtracted from the recorded expiry in _is_expired.
    _EXPIRY_MARGIN_MS = 60_000

    def __init__(
        self,
        *,
        auth: OAuthAuth,
        timeout_seconds: float = 30.0,
        token_url: str = OAUTH_TOKEN_URL,
    ) -> None:
        """Initialise from stored credentials.

        Args:
            auth: Stored credentials; ``access``, ``refresh`` and ``expires``
                are read from it (``expires`` is compared against epoch
                milliseconds).
            timeout_seconds: HTTP timeout for the refresh request.
            token_url: Token endpoint used for the ``refresh_token`` grant.
        """
        self._access = auth.access
        self._refresh = auth.refresh
        self._expires = auth.expires
        # client_id and account id are not stored separately; they are read
        # from the access token's JWT claims (None if not decodable/present).
        self._client_id = _extract_client_id(self._access)
        self._account_id = _extract_account_id(self._access)
        self._timeout_seconds = timeout_seconds
        self._token_url = token_url
        # Serialises refreshes so concurrent callers trigger one request.
        self._lock = asyncio.Lock()

    async def get(self) -> OAuthData:
        """Return a currently valid access token and its request headers.

        Refreshes the token first when it is expired or about to expire.

        Raises:
            ValueError: If a needed refresh fails (see ``_refresh_token``).
                Transport errors raised by httpx are not caught here.
        """
        if self._is_expired():
            async with self._lock:
                # Re-check under the lock: another task may have refreshed
                # while this one was waiting.
                if self._is_expired():
                    await self._refresh_token()
        headers = {"Authorization": f"Bearer {self._access}"}
        if self._account_id:
            headers["ChatGPT-Account-Id"] = self._account_id
        return OAuthData(token=self._access, headers=headers)

    async def get_headers(self) -> dict[str, str]:
        """Return only the authentication headers from ``get()``."""
        oauth = await self.get()
        return oauth.headers

    def _is_expired(self) -> bool:
        """Return True if the token expires within the safety margin."""
        now_ms = int(time.time() * 1000)
        return now_ms >= self._expires - self._EXPIRY_MARGIN_MS

    @staticmethod
    def _token_expiry_ms(token: str) -> int | None:
        """Return the JWT ``exp`` claim of *token* in epoch ms, or None.

        ``exp`` is a NumericDate (seconds since the epoch) per RFC 7519.
        """
        claims = _decode_jwt_payload(token)
        if claims is None:
            return None
        exp = claims.get("exp")
        if isinstance(exp, bool) or not isinstance(exp, (int, float)):
            return None
        return int(exp * 1000)

    async def _refresh_token(self) -> None:
        """Run the OAuth ``refresh_token`` grant and store the new credentials.

        All response checks happen before any state is modified, so a failed
        refresh leaves the provider unchanged. The refresh token is replaced
        only if the server returns a new one.

        Raises:
            ValueError: If the endpoint answers with status >= 400, the body is
                not valid JSON (``json.JSONDecodeError`` is a ``ValueError``)
                or not a JSON object, or it lacks a string ``access_token``.
        """
        payload: dict[str, Any] = {
            "grant_type": "refresh_token",
            "refresh_token": self._refresh,
        }
        if self._client_id:
            payload["client_id"] = self._client_id
        async with httpx.AsyncClient(timeout=self._timeout_seconds) as client:
            response = await client.post(self._token_url, json=payload)
        if response.status_code >= 400:
            raise ValueError(f"OAuth refresh failed with status {response.status_code}")

        data = response.json()
        # Without this guard a non-object body fails later with AttributeError.
        if not isinstance(data, dict):
            raise ValueError("OAuth refresh response is not a JSON object")
        access = data.get("access_token")
        if not isinstance(access, str):
            raise ValueError("OAuth refresh response missing access_token")
        refresh = data.get("refresh_token")
        expires_in = data.get("expires_in")

        now_ms = int(time.time() * 1000)
        new_expires: int | None
        if isinstance(expires_in, (int, float)) and not isinstance(expires_in, bool):
            new_expires = now_ms + int(expires_in * 1000)
        else:
            # Keeping the old, already-expired timestamp would make every
            # later get() refresh again, so use the token's own ``exp``.
            new_expires = self._token_expiry_ms(access)

        self._access = access
        if isinstance(refresh, str):
            self._refresh = refresh
        if new_expires is not None:
            self._expires = new_expires
        # Keep previously known IDs when the new token does not carry them.
        self._account_id = _extract_account_id(access) or self._account_id
        self._client_id = _extract_client_id(access) or self._client_id
def _extract_account_id(token: str) -> str | None:
    """Return the ChatGPT account id embedded in JWT *token*, if any.

    Looks up ``chatgpt_account_id`` inside the object stored under the
    ``JWT_AUTH_CLAIM_PATH`` claim. Returns None when the token cannot be
    decoded, the claim is not an object, or the id is missing or not a string.
    """
    auth_claim = (_decode_jwt_payload(token) or {}).get(JWT_AUTH_CLAIM_PATH)
    if isinstance(auth_claim, dict):
        candidate = auth_claim.get("chatgpt_account_id")
        if isinstance(candidate, str):
            return candidate
    return None
def _extract_client_id(token: str) -> str | None:
    """Return the top-level ``client_id`` claim of JWT *token*.

    Returns None if the token cannot be decoded or the claim is missing or
    not a string.
    """
    claims = _decode_jwt_payload(token)
    client_id = None if claims is None else claims.get("client_id")
    if not isinstance(client_id, str):
        return None
    return client_id
def _decode_jwt_payload(token: str) -> dict[str, Any] | None:
parts = token.split(".")
if len(parts) != 3:
return None
try:
payload = _b64url_decode(parts[1])
parsed = json.loads(payload)
except Exception:
return None
return parsed if isinstance(parsed, dict) else None
def _b64url_decode(data: str) -> str:
padded = data + "=" * ((4 - len(data) % 4) % 4)
raw = base64.urlsafe_b64decode(padded)
return raw.decode("utf-8")