200 lines
6 KiB
Python
200 lines
6 KiB
Python
from __future__ import annotations
|
|
|
|
import argparse
|
|
import asyncio
|
|
import shutil
|
|
import subprocess
|
|
import sys
|
|
from pathlib import Path
|
|
from urllib.parse import parse_qs, urlparse
|
|
|
|
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
|
|
|
|
from gibby.account_ops import (
|
|
PermanentAccountFailure,
|
|
failed_identifier,
|
|
refresh_account_usage,
|
|
window_used_percent,
|
|
)
|
|
from gibby.client import OpenAIAPIError, OpenAIClient
|
|
from gibby.models import AccountRecord, format_reset_in
|
|
from gibby.oauth import (
|
|
build_authorize_url,
|
|
generate_pkce_pair,
|
|
generate_state,
|
|
make_account_id,
|
|
)
|
|
from gibby.settings import Settings
|
|
from gibby.store import JsonStateStore
|
|
|
|
|
|
def parse_redirect_url(url: str) -> tuple[str, str]:
    """Extract the authorization ``code`` and ``state`` from an OAuth redirect.

    Accepts either a full redirect URL or just its path-and-query part (as it
    appears in an HTTP request line), e.g. ``/auth/callback?code=...&state=...``.
    Surrounding whitespace is ignored.

    Args:
        url: The redirect URL (or request target) to parse.

    Returns:
        The ``(code, state)`` pair.

    Raises:
        ValueError: If ``code`` or ``state`` is missing or empty. When the
            redirect is an OAuth error response (``error`` query parameter),
            the message carries the provider's error instead of the generic one.
    """
    query = parse_qs(urlparse(url.strip()).query)
    code = (query.get("code") or [None])[0]
    state = (query.get("state") or [None])[0]
    if not code or not state:
        # OAuth error responses (RFC 6749 section 4.1.2.1) carry no code; surface
        # the provider's reason rather than a misleading "missing code" message.
        error = (query.get("error") or [None])[0]
        if error:
            description = (query.get("error_description") or [None])[0]
            detail = f"{error}: {description}" if description else error
            raise ValueError(f"Authorization failed: {detail}")
        raise ValueError("Redirect URL must contain code and state")
    return code, state
|
|
|
|
|
|
async def wait_for_callback(
    host: str, port: int, expected_state: str, timeout: float = 300.0
) -> tuple[str, str]:
    """Serve the OAuth redirect on a throwaway local HTTP listener.

    The first request whose query string carries ``code`` and ``state``
    resolves the wait. A request carrying an OAuth ``error`` parameter fails
    the wait immediately instead of leaving the caller waiting for
    ``timeout``. Any other request (for example a browser's ``/favicon.ico``
    fetch) is answered with a 404 and otherwise ignored. The listener is shut
    down before this function returns or raises.

    Args:
        host: Interface to listen on.
        port: TCP port to listen on.
        expected_state: The ``state`` value that was put into the authorize URL.
        timeout: Seconds to wait for a usable redirect.

    Returns:
        The ``(code, state)`` pair from the redirect.

    Raises:
        ValueError: If the redirect reports an OAuth error, or its ``state``
            does not equal ``expected_state``.
        asyncio.TimeoutError: If no usable redirect arrives within ``timeout``
            (this is the builtin ``TimeoutError`` on Python 3.11+).
        OSError: If the listener cannot bind to ``host:port``.
    """
    loop = asyncio.get_running_loop()
    result: asyncio.Future[tuple[str, str]] = loop.create_future()
    # Writers of connections still being served. They are force-closed on
    # shutdown so an idle connection (e.g. a browser preconnect that never
    # sends a request) cannot stall server.wait_closed(), which waits for
    # open connections on newer Python versions.
    writers: set[asyncio.StreamWriter] = set()

    async def handle_client(
        reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        writers.add(writer)
        # What this request resolves the wait with: a (code, state) pair, an
        # exception to raise in the waiter, or None if the request is unrelated.
        outcome: tuple[str, str] | BaseException | None = None
        try:
            request_line = await reader.readline()
            if not request_line:
                return  # Connection closed before sending a request.
            status, message = b"400 Bad Request", b"Invalid HTTP request."
            parts = request_line.decode("utf-8", errors="replace").split()
            if len(parts) >= 2:
                target = parts[1]
                query = parse_qs(urlparse(target).query)
                error = (query.get("error") or [None])[0]
                if error:
                    outcome = ValueError(f"OAuth authorization failed: {error}")
                    status = b"200 OK"
                    message = b"Authorization failed. You can close this tab."
                else:
                    try:
                        outcome = parse_redirect_url(target)
                    except ValueError:
                        # Not the callback (e.g. /favicon.ico); keep waiting.
                        status, message = b"404 Not Found", b"Not found."
                    else:
                        status = b"200 OK"
                        message = b"Authorization received. You can close this tab."

            # Consume the request headers before replying.
            while True:
                line = await reader.readline()
                if not line or line in (b"\r\n", b"\n"):
                    break

            writer.write(
                b"HTTP/1.1 %s\r\n"
                b"Content-Type: text/plain; charset=utf-8\r\n"
                b"Content-Length: %d\r\n"
                b"Connection: close\r\n\r\n"
                b"%s" % (status, len(message), message)
            )
            await writer.drain()
        except (ConnectionError, ValueError):
            # The client went away, or sent a line longer than the
            # StreamReader limit (readline raises ValueError); nothing to send.
            pass
        finally:
            writers.discard(writer)
            writer.close()
            # Resolve only after the reply was written, so shutting the server
            # down cannot cut the browser's response short.
            if outcome is not None and not result.done():
                if isinstance(outcome, BaseException):
                    result.set_exception(outcome)
                else:
                    result.set_result(outcome)
            try:
                await writer.wait_closed()
            except ConnectionError:
                pass

    server = await asyncio.start_server(handle_client, host, port)
    try:
        code, state = await asyncio.wait_for(result, timeout=timeout)
    finally:
        server.close()
        for open_writer in list(writers):
            open_writer.close()
        await server.wait_closed()

    if state != expected_state:
        raise ValueError("OAuth state mismatch")
    return code, state
|
|
|
|
|
|
async def exchange_and_store_account(
    store: JsonStateStore,
    client: OpenAIClient,
    code: str,
    verifier: str,
    set_active: bool,
) -> AccountRecord:
    """Exchange an authorization code for tokens and save the resulting account.

    A new AccountRecord, with an id from ``make_account_id()``, is built from
    the tokens returned by ``client.exchange_code``. Its usage is then
    refreshed once, using ``client.settings.exhausted_usage_threshold``.

    - Refresh succeeds: the account is upserted and a one-line summary of the
      primary and secondary usage windows is printed.
    - Refresh raises OpenAIAPIError: the error text is kept in
      ``account.last_error``, the account is upserted anyway, and a notice is
      printed.
    - Refresh raises PermanentAccountFailure: the account's failed identifier
      is appended to the store and the exception is re-raised. The account
      itself is not upserted.

    Args:
        store: State store the account (or its failure) is written to.
        client: API client used for the code exchange and the usage refresh.
        code: Authorization code from the OAuth redirect.
        verifier: PKCE verifier matching the challenge sent in the authorize URL.
        set_active: Forwarded to ``store.upsert_account``.

    Returns:
        The newly created account record.

    Raises:
        PermanentAccountFailure: Re-raised from the usage refresh, as described
            above. Errors from the code exchange propagate unchanged.
    """
    exchanged = await client.exchange_code(code, verifier)
    access_token, refresh_token, expires_at = exchanged
    account = AccountRecord(
        id=make_account_id(),
        access_token=access_token,
        refresh_token=refresh_token,
        expires_at=expires_at,
    )

    try:
        usage = await refresh_account_usage(
            account, client, client.settings.exhausted_usage_threshold
        )
    except PermanentAccountFailure:
        # Record the failure in the store before propagating it.
        store.append_failed_identifier(failed_identifier(account))
        raise
    except OpenAIAPIError as err:
        # Usage could not be fetched; keep the account, noting why.
        account.last_error = str(err)
        store.upsert_account(account, set_active=set_active)
        print("Usage fetch failed, stored account without usage snapshot.")
    else:
        store.upsert_account(account, set_active=set_active)
        primary = usage.primary_window
        secondary = usage.secondary_window
        print(
            f"token ready for {account.id}, "
            f"primary {window_used_percent(primary)}% "
            f"reset in {format_reset_in(primary.reset_at if primary else None)}, "
            f"secondary {window_used_percent(secondary)}% "
            f"reset in {format_reset_in(secondary.reset_at if secondary else None)}"
        )
    return account
|
|
|
|
|
|
def _open_in_browser(url: str) -> None:
    """Best-effort launch of ``url`` with ``xdg-open``, printing what happened.

    When ``xdg-open`` is not on PATH, or cannot be started, the user is told
    to open the URL manually.
    """
    opener = shutil.which("xdg-open")
    if opener:
        print("Opening browser...")
        print()
        try:
            subprocess.Popen(
                [opener, url],
                # Keep the launcher's output off the terminal and away from
                # stdin, which the manual-mode prompt reads.
                stdin=subprocess.DEVNULL,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                # Own session: a later Ctrl-C in this terminal must not reach
                # a browser that xdg-open started as its child.
                start_new_session=True,
            )
            return
        except OSError:
            pass
    print("Open the URL manually.")
    print()


async def run(
    mode: str,
    set_active: bool,
    data_dir: Path | None = None,
) -> None:
    """Run the interactive OAuth login and store the resulting account.

    Prints the authorize URL (PKCE challenge plus a fresh ``state``) and tries
    to open it in a browser. It then obtains the authorization code and hands
    it to ``exchange_and_store_account``.

    Args:
        mode: ``"local"`` waits for the redirect on the callback host/port
            from settings. Any other value prompts the user to paste the
            final redirect URL.
        set_active: Forwarded to ``exchange_and_store_account``.
        data_dir: When given, overrides ``settings.data_dir``.

    Raises:
        ValueError: On a state mismatch, an unparsable redirect URL, or when
            stdin is closed at the manual prompt.
        Errors from ``wait_for_callback`` and ``exchange_and_store_account``
        propagate unchanged.
    """
    settings = Settings()
    if data_dir is not None:
        settings.data_dir = data_dir
    verifier, challenge = generate_pkce_pair()
    state = generate_state()
    url = build_authorize_url(settings, challenge, state)

    print("Open this URL and complete authorization:\n")
    print(url)
    print()
    _open_in_browser(url)

    if mode == "local":
        code, _ = await wait_for_callback(
            settings.callback_host,
            settings.callback_port,
            state,
        )
    else:
        try:
            redirect_url = input("Paste the final redirect URL: ").strip()
        except EOFError:
            # stdin closed (Ctrl-D or a non-interactive run): report it as a
            # user-facing error instead of a traceback.
            raise ValueError("No redirect URL provided") from None
        code, returned_state = parse_redirect_url(redirect_url)
        if returned_state != state:
            raise ValueError("OAuth state mismatch")

    store = JsonStateStore(settings.accounts_file, settings.failed_file)
    client = OpenAIClient(settings)
    try:
        account = await exchange_and_store_account(
            store, client, code, verifier, set_active
        )
    finally:
        await client.aclose()

    print(f"Stored account: {account.id}")
    print(f"Access token expires at: {account.expires_at}")
|
|
|
|
|
|
def main() -> None:
    """Command-line entry point: parse arguments and run the login flow.

    Handled failures (timeout, ValueError, PermanentAccountFailure) are
    reported as a one-line message on stderr, and the process exits with a
    non-zero status (130 on Ctrl-C, 1 otherwise). Shell callers can therefore
    tell that no account was stored.
    """
    parser = argparse.ArgumentParser(
        description="Authorize an account via OAuth and store it."
    )
    parser.add_argument("-m", "--mode", choices=["local", "manual"], default="local")
    parser.add_argument("-a", "--set-active", action="store_true")
    parser.add_argument("-d", "--data-dir", type=Path)
    args = parser.parse_args()
    try:
        asyncio.run(run(args.mode, args.set_active, args.data_dir))
    except KeyboardInterrupt:
        print("\nCancelled.", file=sys.stderr)
        sys.exit(130)
    except (TimeoutError, asyncio.TimeoutError):
        # Before Python 3.11, asyncio.wait_for raises asyncio.TimeoutError,
        # which is not a subclass of the builtin TimeoutError.
        print("Timed out waiting for OAuth callback.", file=sys.stderr)
        sys.exit(1)
    except (ValueError, PermanentAccountFailure) as exc:
        print(str(exc), file=sys.stderr)
        sys.exit(1)
|
|
|
|
|
|
if __name__ == "__main__":
    # Start the CLI only when executed as a script, not when imported.
    main()
|