This commit is contained in:
commit
7cef56de15
23 changed files with 3136 additions and 0 deletions
200
scripts/oauth_helper.py
Normal file
200
scripts/oauth_helper.py
Normal file
|
|
@ -0,0 +1,200 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
|
||||
|
||||
from gibby.account_ops import (
|
||||
PermanentAccountFailure,
|
||||
failed_identifier,
|
||||
refresh_account_usage,
|
||||
window_used_percent,
|
||||
)
|
||||
from gibby.client import OpenAIAPIError, OpenAIClient
|
||||
from gibby.models import AccountRecord, format_reset_in
|
||||
from gibby.oauth import (
|
||||
build_authorize_url,
|
||||
generate_pkce_pair,
|
||||
generate_state,
|
||||
make_account_id,
|
||||
)
|
||||
from gibby.settings import Settings
|
||||
from gibby.store import JsonStateStore
|
||||
|
||||
|
||||
def parse_redirect_url(url: str) -> tuple[str, str]:
|
||||
parsed = urlparse(url.strip())
|
||||
query = parse_qs(parsed.query)
|
||||
code = (query.get("code") or [None])[0]
|
||||
state = (query.get("state") or [None])[0]
|
||||
if not code or not state:
|
||||
raise ValueError("Redirect URL must contain code and state")
|
||||
return code, state
|
||||
|
||||
|
||||
async def wait_for_callback(
|
||||
host: str, port: int, expected_state: str, timeout: float = 300.0
|
||||
) -> tuple[str, str]:
|
||||
loop = asyncio.get_running_loop()
|
||||
result: asyncio.Future[tuple[str, str]] = loop.create_future()
|
||||
|
||||
async def handle_client(
|
||||
reader: asyncio.StreamReader, writer: asyncio.StreamWriter
|
||||
) -> None:
|
||||
try:
|
||||
request_line = await reader.readline()
|
||||
parts = request_line.decode("utf-8", errors="replace").split()
|
||||
if len(parts) < 2:
|
||||
raise ValueError("Invalid HTTP request line")
|
||||
code, state = parse_redirect_url(parts[1])
|
||||
if not result.done():
|
||||
result.set_result((code, state))
|
||||
|
||||
while True:
|
||||
line = await reader.readline()
|
||||
if not line or line == b"\r\n":
|
||||
break
|
||||
|
||||
writer.write(
|
||||
b"HTTP/1.1 200 OK\r\n"
|
||||
b"Content-Type: text/plain; charset=utf-8\r\n"
|
||||
b"Connection: close\r\n\r\n"
|
||||
b"Authorization received. You can close this tab."
|
||||
)
|
||||
await writer.drain()
|
||||
finally:
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
|
||||
server = await asyncio.start_server(handle_client, host, port)
|
||||
try:
|
||||
code, state = await asyncio.wait_for(result, timeout=timeout)
|
||||
finally:
|
||||
server.close()
|
||||
await server.wait_closed()
|
||||
|
||||
if state != expected_state:
|
||||
raise ValueError("OAuth state mismatch")
|
||||
return code, state
|
||||
|
||||
|
||||
async def exchange_and_store_account(
|
||||
store: JsonStateStore,
|
||||
client: OpenAIClient,
|
||||
code: str,
|
||||
verifier: str,
|
||||
set_active: bool,
|
||||
) -> AccountRecord:
|
||||
access_token, refresh_token, expires_at = await client.exchange_code(code, verifier)
|
||||
account = AccountRecord(
|
||||
id=make_account_id(),
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
try:
|
||||
usage = await refresh_account_usage(
|
||||
account,
|
||||
client,
|
||||
client.settings.exhausted_usage_threshold,
|
||||
)
|
||||
except PermanentAccountFailure:
|
||||
store.append_failed_identifier(failed_identifier(account))
|
||||
raise
|
||||
except OpenAIAPIError as exc:
|
||||
account.last_error = str(exc)
|
||||
store.upsert_account(account, set_active=set_active)
|
||||
print("Usage fetch failed, stored account without usage snapshot.")
|
||||
return account
|
||||
|
||||
store.upsert_account(account, set_active=set_active)
|
||||
print(
|
||||
f"token ready for {account.id}, "
|
||||
f"primary {window_used_percent(usage.primary_window)}% "
|
||||
f"reset in {format_reset_in(usage.primary_window.reset_at if usage.primary_window else None)}, "
|
||||
f"secondary {window_used_percent(usage.secondary_window)}% "
|
||||
f"reset in {format_reset_in(usage.secondary_window.reset_at if usage.secondary_window else None)}"
|
||||
)
|
||||
return account
|
||||
|
||||
|
||||
async def run(
|
||||
mode: str,
|
||||
set_active: bool,
|
||||
data_dir: Path | None = None,
|
||||
) -> None:
|
||||
settings = Settings()
|
||||
if data_dir is not None:
|
||||
settings.data_dir = data_dir
|
||||
verifier, challenge = generate_pkce_pair()
|
||||
state = generate_state()
|
||||
url = build_authorize_url(settings, challenge, state)
|
||||
|
||||
print("Open this URL and complete authorization:\n")
|
||||
print(url)
|
||||
print()
|
||||
|
||||
opener = shutil.which("xdg-open")
|
||||
if opener:
|
||||
print("Opening browser...")
|
||||
print()
|
||||
try:
|
||||
subprocess.Popen([opener, url])
|
||||
except OSError:
|
||||
print("Open the URL manually.")
|
||||
print()
|
||||
else:
|
||||
print("Open the URL manually.")
|
||||
print()
|
||||
|
||||
if mode == "local":
|
||||
code, _ = await wait_for_callback(
|
||||
settings.callback_host,
|
||||
settings.callback_port,
|
||||
state,
|
||||
)
|
||||
else:
|
||||
redirect_url = input("Paste the final redirect URL: ").strip()
|
||||
code, returned_state = parse_redirect_url(redirect_url)
|
||||
if returned_state != state:
|
||||
raise ValueError("OAuth state mismatch")
|
||||
|
||||
store = JsonStateStore(settings.accounts_file, settings.failed_file)
|
||||
client = OpenAIClient(settings)
|
||||
try:
|
||||
account = await exchange_and_store_account(
|
||||
store, client, code, verifier, set_active
|
||||
)
|
||||
finally:
|
||||
await client.aclose()
|
||||
|
||||
print(f"Stored account: {account.id}")
|
||||
print(f"Access token expires at: {account.expires_at}")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-m", "--mode", choices=["local", "manual"], default="local")
|
||||
parser.add_argument("-a", "--set-active", action="store_true")
|
||||
parser.add_argument("-d", "--data-dir", type=Path)
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
asyncio.run(run(args.mode, args.set_active, args.data_dir))
|
||||
except KeyboardInterrupt:
|
||||
print("\nCancelled.")
|
||||
except TimeoutError:
|
||||
print("Timed out waiting for OAuth callback.")
|
||||
except ValueError as exc:
|
||||
print(str(exc))
|
||||
except PermanentAccountFailure as exc:
|
||||
print(str(exc))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue