refactor!: change the entire purpose of this script
This commit is contained in:
parent
217e176975
commit
71d1050adb
20 changed files with 1124 additions and 872 deletions
84
src/providers/chatgpt/tokens.py
Normal file
84
src/providers/chatgpt/tokens.py
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
import json
|
||||
import time
|
||||
import os
|
||||
import aiohttp
|
||||
from pathlib import Path
|
||||
|
||||
from providers.base import ProviderTokens
|
||||
|
||||
DATA_DIR = Path(os.environ.get("DATA_DIR", "./data"))
|
||||
TOKENS_FILE = DATA_DIR / "chatgpt_tokens.json"
|
||||
|
||||
CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
TOKEN_URL = "https://auth.openai.com/oauth/token"
|
||||
|
||||
|
||||
def load_tokens() -> ProviderTokens | None:
|
||||
if not TOKENS_FILE.exists():
|
||||
return None
|
||||
try:
|
||||
with open(TOKENS_FILE) as f:
|
||||
data = json.load(f)
|
||||
return ProviderTokens(
|
||||
access_token=data["access_token"],
|
||||
refresh_token=data["refresh_token"],
|
||||
expires_at=data["expires_at"],
|
||||
)
|
||||
except json.JSONDecodeError, KeyError:
|
||||
return None
|
||||
|
||||
|
||||
def save_tokens(tokens: ProviderTokens):
|
||||
TOKENS_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(TOKENS_FILE, "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"access_token": tokens.access_token,
|
||||
"refresh_token": tokens.refresh_token,
|
||||
"expires_at": tokens.expires_at,
|
||||
},
|
||||
f,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
|
||||
async def refresh_tokens(refresh_token: str) -> ProviderTokens | None:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
data = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"client_id": CLIENT_ID,
|
||||
}
|
||||
async with session.post(TOKEN_URL, data=data) as resp:
|
||||
if not resp.ok:
|
||||
text = await resp.text()
|
||||
print(f"Token refresh failed: {resp.status} {text}")
|
||||
return None
|
||||
json_resp = await resp.json()
|
||||
expires_in = json_resp["expires_in"]
|
||||
return ProviderTokens(
|
||||
access_token=json_resp["access_token"],
|
||||
refresh_token=json_resp["refresh_token"],
|
||||
expires_at=time.time() + expires_in,
|
||||
)
|
||||
|
||||
|
||||
async def get_valid_tokens() -> ProviderTokens | None:
|
||||
tokens = load_tokens()
|
||||
if not tokens:
|
||||
print("No tokens found")
|
||||
return None
|
||||
|
||||
if tokens.is_expired:
|
||||
print("Token expired, refreshing...")
|
||||
if not tokens.refresh_token:
|
||||
print("No refresh token available")
|
||||
return None
|
||||
new_tokens = await refresh_tokens(tokens.refresh_token)
|
||||
if not new_tokens:
|
||||
print("Failed to refresh token")
|
||||
return None
|
||||
save_tokens(new_tokens)
|
||||
return new_tokens
|
||||
|
||||
return tokens
|
||||
Loading…
Add table
Add a link
Reference in a new issue