This commit is contained in:
commit
7cef56de15
23 changed files with 3136 additions and 0 deletions
177
tests/test_oauth_helper.py
Normal file
177
tests/test_oauth_helper.py
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import urllib.parse
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "scripts"))
|
||||
|
||||
from gibby.client import OpenAIAPIError
|
||||
from gibby.models import UsageSnapshot, UsageWindow
|
||||
from gibby.oauth import build_authorize_url, generate_pkce_pair
|
||||
from gibby.settings import Settings
|
||||
from oauth_helper import ( # type: ignore[import-not-found]
|
||||
exchange_and_store_account,
|
||||
parse_redirect_url,
|
||||
wait_for_callback,
|
||||
)
|
||||
from gibby.store import JsonStateStore
|
||||
|
||||
|
||||
class FakeClient:
|
||||
def __init__(
|
||||
self,
|
||||
settings: Settings,
|
||||
usage: UsageSnapshot,
|
||||
*,
|
||||
transient_usage_failure: bool = False,
|
||||
):
|
||||
self.settings = settings
|
||||
self.usage = usage
|
||||
self.transient_usage_failure = transient_usage_failure
|
||||
|
||||
async def exchange_code(self, code: str, verifier: str) -> tuple[str, str, int]:
|
||||
return ("access-token", "refresh-token", 1776000000)
|
||||
|
||||
async def refresh_access_token(self, refresh_token: str) -> tuple[str, str, int]:
|
||||
return ("access-token", refresh_token, 1776000000)
|
||||
|
||||
async def fetch_usage_payload(self, access_token: str) -> dict:
|
||||
if self.transient_usage_failure:
|
||||
raise OpenAIAPIError("usage timeout", permanent=False)
|
||||
primary_window = self.usage.primary_window
|
||||
assert primary_window is not None
|
||||
secondary_window = (
|
||||
{
|
||||
"used_percent": self.usage.secondary_window.used_percent,
|
||||
"limit_window_seconds": self.usage.secondary_window.limit_window_seconds,
|
||||
"reset_after_seconds": self.usage.secondary_window.reset_after_seconds,
|
||||
"reset_at": self.usage.secondary_window.reset_at,
|
||||
}
|
||||
if self.usage.secondary_window is not None
|
||||
else None
|
||||
)
|
||||
return {
|
||||
"email": "oauth@example.com",
|
||||
"account_id": "oauth-1",
|
||||
"rate_limit": {
|
||||
"allowed": self.usage.allowed,
|
||||
"limit_reached": self.usage.limit_reached,
|
||||
"primary_window": {
|
||||
"used_percent": primary_window.used_percent,
|
||||
"limit_window_seconds": primary_window.limit_window_seconds,
|
||||
"reset_after_seconds": primary_window.reset_after_seconds,
|
||||
"reset_at": primary_window.reset_at,
|
||||
},
|
||||
"secondary_window": secondary_window,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def make_usage(primary: int, secondary: int | None = None) -> UsageSnapshot:
|
||||
return UsageSnapshot(
|
||||
checked_at=1775000000,
|
||||
used_percent=max(primary, secondary or 0),
|
||||
remaining_percent=max(0, 100 - max(primary, secondary or 0)),
|
||||
exhausted=False,
|
||||
primary_window=UsageWindow(primary, 18000, 100, 1775000100),
|
||||
secondary_window=UsageWindow(secondary, 604800, 100, 1775000100)
|
||||
if secondary is not None
|
||||
else None,
|
||||
limit_reached=False,
|
||||
allowed=True,
|
||||
)
|
||||
|
||||
|
||||
def test_generate_pkce_pair_shapes() -> None:
|
||||
verifier, challenge = generate_pkce_pair()
|
||||
assert len(verifier) > 40
|
||||
assert len(challenge) > 40
|
||||
assert "=" not in challenge
|
||||
|
||||
|
||||
def test_build_authorize_url_contains_redirect_and_state() -> None:
|
||||
settings = Settings(callback_host="localhost", callback_port=1455)
|
||||
url = build_authorize_url(settings, "challenge", "state-123")
|
||||
parsed = urllib.parse.urlparse(url)
|
||||
query = urllib.parse.parse_qs(parsed.query)
|
||||
assert parsed.scheme == "https"
|
||||
assert query["redirect_uri"] == ["http://localhost:1455/auth/callback"]
|
||||
assert query["state"] == ["state-123"]
|
||||
assert query["code_challenge"] == ["challenge"]
|
||||
|
||||
|
||||
def test_parse_redirect_url_extracts_code_and_state() -> None:
|
||||
code, state = parse_redirect_url(
|
||||
"http://127.0.0.1:1455/auth/callback?code=abc&state=xyz"
|
||||
)
|
||||
assert code == "abc"
|
||||
assert state == "xyz"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wait_for_callback_receives_code_and_state() -> None:
|
||||
task = asyncio.create_task(
|
||||
wait_for_callback("127.0.0.1", 18555, "state-1", timeout=5)
|
||||
)
|
||||
await asyncio.sleep(0.05)
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
"http://127.0.0.1:18555/auth/callback?code=abc123&state=state-1",
|
||||
timeout=5,
|
||||
)
|
||||
result = await task
|
||||
|
||||
assert response.status_code == 200
|
||||
assert result == ("abc123", "state-1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exchange_and_store_account_populates_usage_snapshot(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
store = JsonStateStore(tmp_path / "accounts.json")
|
||||
settings = Settings(data_dir=tmp_path)
|
||||
client = FakeClient(settings, make_usage(12, 3))
|
||||
|
||||
account = await exchange_and_store_account(
|
||||
store,
|
||||
cast(Any, client),
|
||||
"code",
|
||||
"verifier",
|
||||
False,
|
||||
)
|
||||
|
||||
assert account.last_known_usage is not None
|
||||
assert account.id == "oauth@example.com"
|
||||
assert account.last_known_usage.primary_window is not None
|
||||
assert account.last_known_usage.primary_window.used_percent == 12
|
||||
assert account.last_known_usage.secondary_window is not None
|
||||
assert account.last_known_usage.secondary_window.used_percent == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exchange_and_store_account_keeps_tokens_on_transient_usage_failure(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
store = JsonStateStore(tmp_path / "accounts.json")
|
||||
settings = Settings(data_dir=tmp_path)
|
||||
client = FakeClient(settings, make_usage(12, 3), transient_usage_failure=True)
|
||||
|
||||
account = await exchange_and_store_account(
|
||||
store,
|
||||
cast(Any, client),
|
||||
"code",
|
||||
"verifier",
|
||||
False,
|
||||
)
|
||||
|
||||
saved = store.load()
|
||||
assert account.last_known_usage is None
|
||||
assert saved.accounts[0].access_token == "access-token"
|
||||
assert saved.accounts[0].last_error is not None
|
||||
Loading…
Add table
Add a link
Reference in a new issue