166 lines
5.2 KiB
Python
166 lines
5.2 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import sys
|
|
import urllib.parse
|
|
from pathlib import Path
|
|
from typing import Any, cast
|
|
|
|
import httpx
|
|
import pytest
|
|
|
|
# Prepend the ``scripts/`` directory that sits next to this file's parent
# directory (parents[1] of the resolved test path) to ``sys.path`` so that
# ``oauth_helper`` can be imported below; the ``type: ignore[import-not-found]``
# on that import suggests it is not an installed package — presumably it lives
# only under ``scripts/``.
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "scripts"))
|
|
|
|
from gibby.client import OpenAIAPIError
|
|
from gibby.models import UsageSnapshot, UsageWindow
|
|
from gibby.oauth import build_authorize_url, generate_pkce_pair
|
|
from gibby.settings import Settings
|
|
from oauth_helper import ( # type: ignore[import-not-found]
|
|
exchange_and_store_account,
|
|
parse_redirect_url,
|
|
wait_for_callback,
|
|
)
|
|
from gibby.store import JsonStateStore
|
|
|
|
|
|
class FakeClient:
    """Canned stand-in for the API client handed to ``exchange_and_store_account``.

    The token methods ignore their inputs (apart from echoing the refresh
    token) and return fixed values. ``fetch_usage_payload`` turns the
    ``usage`` snapshot given at construction into a nested payload dict, or
    raises a non-permanent ``OpenAIAPIError`` when ``transient_usage_failure``
    is set. ``settings`` is stored but not read by any method here.
    """

    def __init__(
        self,
        settings: Settings,
        usage: UsageSnapshot,
        *,
        transient_usage_failure: bool = False,
    ):
        self.settings, self.usage = settings, usage
        self.transient_usage_failure = transient_usage_failure

    async def exchange_code(self, code: str, verifier: str) -> tuple[str, str, int]:
        """Return the fixed token triple, whatever ``code``/``verifier`` are."""
        return "access-token", "refresh-token", 1776000000

    async def refresh_access_token(self, refresh_token: str) -> tuple[str, str, int]:
        """Return a fixed access token, echoing ``refresh_token`` back unchanged."""
        return "access-token", refresh_token, 1776000000

    @staticmethod
    def _window_dict(window: UsageWindow) -> dict:
        """Serialise one usage window into its payload sub-dict."""
        return {"used_percent": window.used_percent, "reset_at": window.reset_at}

    async def fetch_usage_payload(self, access_token: str) -> dict:
        """Build a usage payload dict from ``self.usage``.

        The result holds fixed ``email``/``account_id`` values plus a
        ``rate_limit`` section mirroring the snapshot; a missing secondary
        window is reported as ``None``.

        Raises:
            OpenAIAPIError: with ``permanent=False`` when
                ``transient_usage_failure`` is set.
        """
        if self.transient_usage_failure:
            raise OpenAIAPIError("usage timeout", permanent=False)

        snapshot = self.usage
        primary = snapshot.primary_window
        # This fake only models snapshots that carry a primary window.
        assert primary is not None
        secondary = snapshot.secondary_window

        rate_limit = {
            "allowed": snapshot.allowed,
            "limit_reached": snapshot.limit_reached,
            "primary_window": self._window_dict(primary),
            "secondary_window": (
                None if secondary is None else self._window_dict(secondary)
            ),
        }
        return {
            "email": "oauth@example.com",
            "account_id": "oauth-1",
            "rate_limit": rate_limit,
        }
|
|
|
|
|
|
def make_usage(
    primary: int,
    secondary: int | None = None,
    *,
    checked_at: int = 1775000000,
    reset_at: int = 1775000100,
) -> UsageSnapshot:
    """Build a ``UsageSnapshot`` fixture from window usage numbers.

    Args:
        primary: First positional argument of the primary ``UsageWindow``
            (``FakeClient`` serialises it as ``used_percent``).
        secondary: Same for the secondary window; when ``None`` the snapshot
            has no secondary window at all.
        checked_at: Snapshot ``checked_at`` timestamp (default unchanged).
        reset_at: Second positional argument of both ``UsageWindow``s
            (``FakeClient`` serialises it as ``reset_at``).

    Returns:
        The assembled ``UsageSnapshot``.
    """
    return UsageSnapshot(
        checked_at=checked_at,
        primary_window=UsageWindow(primary, reset_at),
        # The ternary already guarantees ``secondary`` is an int here, so the
        # old ``secondary or 0`` fallback was dead code.
        secondary_window=(
            UsageWindow(secondary, reset_at) if secondary is not None else None
        ),
    )
|
|
|
|
|
|
def test_generate_pkce_pair_shapes() -> None:
    """PKCE verifier and challenge are well-formed per RFC 7636.

    RFC 7636 section 4.1 requires the code verifier to be 43-128 characters
    from the unreserved set ``[A-Za-z0-9-._~]``. The challenge travels as a
    URL query parameter, so it must be unpadded (no ``=``) and URL-safe; both
    the S256 (43-char base64url) and plain forms fit the same pattern.
    """
    import re

    pkce_token = re.compile(r"[A-Za-z0-9\-._~]{43,128}")

    verifier, challenge = generate_pkce_pair()

    # The old ``len > 40`` check accepted 41/42-char verifiers, which the RFC forbids.
    assert pkce_token.fullmatch(verifier), verifier
    assert pkce_token.fullmatch(challenge), challenge
    assert "=" not in challenge
|
|
|
|
|
|
def test_build_authorize_url_contains_redirect_and_state() -> None:
    """The authorize URL is HTTPS and carries redirect URI, state and challenge."""
    authorize_url = build_authorize_url(
        Settings(callback_host="localhost", callback_port=1455),
        "challenge",
        "state-123",
    )
    pieces = urllib.parse.urlparse(authorize_url)
    params = urllib.parse.parse_qs(pieces.query)

    assert pieces.scheme == "https"
    expected_params = {
        "redirect_uri": ["http://localhost:1455/auth/callback"],
        "state": ["state-123"],
        "code_challenge": ["challenge"],
    }
    for name, wanted in expected_params.items():
        assert params[name] == wanted, name
|
|
|
|
|
|
def test_parse_redirect_url_extracts_code_and_state() -> None:
    """``code`` and ``state`` are pulled out of a pasted callback URL."""
    callback_url = "http://127.0.0.1:1455/auth/callback?code=abc&state=xyz"

    parsed_code, parsed_state = parse_redirect_url(callback_url)

    assert (parsed_code, parsed_state) == ("abc", "xyz")
|
|
|
|
|
|
def _free_port(host: str) -> int:
    """Return a TCP port on ``host`` that was unused a moment ago."""
    import socket

    # Binding to port 0 lets the OS pick; avoids collisions with a fixed port
    # when tests run in parallel.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.bind((host, 0))
        return probe.getsockname()[1]


async def _get_when_listening(
    client: httpx.AsyncClient,
    url: str,
    server: asyncio.Task[Any],
    *,
    attempts: int = 100,
    delay: float = 0.02,
) -> httpx.Response:
    """GET ``url``, retrying refused connections until the server is up.

    This replaces a fixed sleep, which races the server's bind on slow hosts.
    If the server task finishes before a request gets through (e.g. it failed
    to bind), its own exception is re-raised so the real cause is reported.

    Raises:
        AssertionError: if the server exits early without an error, or never
            accepts a connection within ``attempts`` tries.
    """
    for _ in range(attempts):
        if server.done():
            server.result()  # re-raises the server's exception, if any
            raise AssertionError("callback server exited before the request was sent")
        try:
            return await client.get(url, timeout=5)
        except httpx.ConnectError:
            await asyncio.sleep(delay)
    raise AssertionError(f"callback server for {url} never started listening")


@pytest.mark.asyncio
async def test_wait_for_callback_receives_code_and_state() -> None:
    """``wait_for_callback`` answers the redirect and returns ``(code, state)``."""
    host = "127.0.0.1"
    port = _free_port(host)
    url = f"http://{host}:{port}/auth/callback?code=abc123&state=state-1"

    task = asyncio.create_task(wait_for_callback(host, port, "state-1", timeout=5))
    try:
        async with httpx.AsyncClient() as client:
            response = await _get_when_listening(client, url, task)
        result = await asyncio.wait_for(task, timeout=5)
    finally:
        # Never leak a still-listening server into later tests if anything
        # above failed.
        if not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass  # expected: we just cancelled it

    assert response.status_code == 200
    assert result == ("abc123", "state-1")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_exchange_and_store_account_populates_usage_snapshot(
    tmp_path: Path,
) -> None:
    """A successful exchange returns an account with usage and persists it.

    The fake client reports 12 primary / 3 secondary usage; both must appear on
    the returned account, and the account must also be written to the store
    (the counterpart of the empty-store check on the transient-failure path).
    """
    store = JsonStateStore(tmp_path / "accounts.json")
    settings = Settings(data_dir=tmp_path)
    client = FakeClient(settings, make_usage(12, 3))

    account = await exchange_and_store_account(
        store,
        cast(Any, client),
        "code",
        "verifier",
        False,
    )

    assert account.email == "oauth@example.com"
    usage = account.usage
    assert usage is not None
    assert usage.primary_window is not None
    assert usage.primary_window.used_percent == 12
    assert usage.secondary_window is not None
    assert usage.secondary_window.used_percent == 3

    # The function's contract includes persistence, not just the return value.
    persisted = store.load().accounts
    assert [stored.email for stored in persisted] == ["oauth@example.com"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_exchange_and_store_account_raises_on_transient_usage_failure(
    tmp_path: Path,
) -> None:
    """A non-permanent usage-fetch error propagates and nothing is persisted."""
    state_store = JsonStateStore(tmp_path / "accounts.json")
    failing_client = FakeClient(
        Settings(data_dir=tmp_path),
        make_usage(12, 3),
        transient_usage_failure=True,
    )

    with pytest.raises(OpenAIAPIError):
        await exchange_and_store_account(
            state_store, cast(Any, failing_client), "code", "verifier", False
        )

    persisted = state_store.load().accounts
    assert persisted == []
|