363 lines
13 KiB
Python
363 lines
13 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from gibby.client import OpenAIAPIError, OpenAIClient
|
|
from gibby.manager import AccountManager, NoUsableAccountError
|
|
from gibby.models import AccountRecord, StateFile, UsageSnapshot, UsageWindow
|
|
from gibby.settings import Settings
|
|
from gibby.store import JsonStateStore
|
|
|
|
|
|
class FakeClient(OpenAIClient):
    """In-memory stand-in for ``OpenAIClient`` used by the manager tests.

    All behaviour is configured through constructor arguments. Every call is
    logged (``refresh_calls``, ``fetched_usage_tokens``, ``validated_tokens``)
    so tests can assert exactly which tokens were touched and in what order.
    The base-class ``__init__`` is not invoked; ``settings`` is assigned
    directly instead.
    """

    def __init__(
        self,
        usage_by_token=None,
        refresh_map=None,
        invalid_tokens=None,
        transient_validation_tokens=None,
        auth_failing_usage_tokens=None,
        auth_failing_validation_tokens=None,
        permanent_refresh_tokens=None,
    ):
        # Lookup tables: access token -> usage snapshot served by
        # fetch_usage_payload, and refresh token -> value returned by
        # refresh_access_token.
        self.usage_by_token = usage_by_token if usage_by_token else {}
        self.refresh_map = refresh_map if refresh_map else {}

        # Behaviour switches, normalised to sets for membership checks.
        self.invalid_tokens = set(invalid_tokens) if invalid_tokens else set()
        self.transient_validation_tokens = (
            set(transient_validation_tokens) if transient_validation_tokens else set()
        )
        self.auth_failing_usage_tokens = (
            set(auth_failing_usage_tokens) if auth_failing_usage_tokens else set()
        )
        self.auth_failing_validation_tokens = (
            set(auth_failing_validation_tokens) if auth_failing_validation_tokens else set()
        )
        self.permanent_refresh_tokens = (
            set(permanent_refresh_tokens) if permanent_refresh_tokens else set()
        )

        # Call logs, appended in call order.
        self.refresh_calls: list[str] = []
        self.fetched_usage_tokens: list[str] = []
        self.validated_tokens: list[str] = []

        self.settings = Settings(data_dir=Path("."))

    async def refresh_access_token(self, refresh_token: str):
        """Log the call, then either fail permanently or return the mapped result.

        The mapped value is whatever the test put in ``refresh_map``. In this
        file it is presumably ``(access_token, refresh_token, refresh_at)``.

        Raises:
            OpenAIAPIError: permanent 401 ``invalid_grant`` for tokens listed in
                ``permanent_refresh_tokens``.
            KeyError: if the refresh token has no entry in ``refresh_map``.
        """
        self.refresh_calls.append(refresh_token)
        if refresh_token not in self.permanent_refresh_tokens:
            return self.refresh_map[refresh_token]
        raise OpenAIAPIError("invalid_grant", permanent=True, status_code=401)

    async def fetch_usage_payload(self, access_token: str):
        """Log the call and return a usage payload built from ``usage_by_token``.

        The payload's email is derived from the access token
        (``<token>@example.com``).

        Raises:
            OpenAIAPIError: permanent 401 for tokens in ``auth_failing_usage_tokens``.
            KeyError: if the token has no entry in ``usage_by_token``.
        """
        self.fetched_usage_tokens.append(access_token)
        if access_token in self.auth_failing_usage_tokens:
            raise OpenAIAPIError("usage auth failed", permanent=True, status_code=401)

        snapshot = self.usage_by_token[access_token]

        def window_payload(window):
            # A falsy/missing window is reported as None.
            if not window:
                return None
            return {"used_percent": window.used_percent, "reset_at": window.reset_at}

        return {
            "email": f"{access_token}@example.com",
            "rate_limit": {
                "allowed": snapshot.allowed,
                "limit_reached": snapshot.limit_reached,
                "primary_window": window_payload(snapshot.primary_window),
                "secondary_window": window_payload(snapshot.secondary_window),
            },
        }

    async def validate_token(self, access_token: str) -> bool:
        """Log the call and report whether the token is considered valid.

        Tokens in ``auth_failing_validation_tokens`` are rejected (``False``)
        before any other check. Tokens in ``transient_validation_tokens`` raise
        a non-permanent 502. Otherwise a token is valid unless it is listed in
        ``invalid_tokens``.
        """
        self.validated_tokens.append(access_token)
        rejected = access_token in self.auth_failing_validation_tokens
        if not rejected and access_token in self.transient_validation_tokens:
            raise OpenAIAPIError("validation 502", permanent=False, status_code=502)
        return not rejected and access_token not in self.invalid_tokens
|
|
|
|
|
|
def make_usage(primary: int, secondary: int = 0, *, checked_at: int | None = None):
    """Build a ``UsageSnapshot`` with the given primary/secondary used percentages.

    Both usage windows reset 300 seconds after "now".

    Args:
        primary: used percentage of the primary window.
        secondary: used percentage of the secondary window.
        checked_at: timestamp recorded as the time the usage was checked.
            Defaults to the current time. An explicit value, including ``0``,
            is used verbatim.

    Returns:
        A ``UsageSnapshot`` built from the values above.
    """
    # Read the clock once so checked_at and both reset_at values agree even if
    # the call straddles a second boundary.
    now = int(time.time())
    return UsageSnapshot(
        # `is None` rather than `or`: an explicit 0 must not be replaced by now.
        checked_at=now if checked_at is None else checked_at,
        primary_window=UsageWindow(used_percent=primary, reset_at=now + 300),
        secondary_window=UsageWindow(used_percent=secondary, reset_at=now + 300),
    )
|
|
|
|
|
|
def make_account(
    email: str,
    *,
    token: str,
    refresh_token: str = "refresh",
    token_refresh_at: int | None = None,
    usage: UsageSnapshot | None = None,
    disabled: bool = False,
) -> AccountRecord:
    """Build an ``AccountRecord`` for tests.

    Args:
        email: account identifier.
        token: access token stored on the record.
        refresh_token: refresh token stored on the record.
        token_refresh_at: timestamp at which the access token should be
            refreshed. Defaults to 600 seconds from now. An explicit value,
            including ``0``, is used verbatim.
        usage: optional usage snapshot. Its ``checked_at`` is mirrored into
            ``usage_checked_at``.
        disabled: whether the account is disabled.

    Returns:
        The constructed ``AccountRecord``.
    """
    # `is None` rather than `or`: an explicit 0 (already due) must not be
    # silently replaced with a timestamp in the future.
    if token_refresh_at is None:
        token_refresh_at = int(time.time()) + 600
    return AccountRecord(
        email=email,
        access_token=token,
        refresh_token=refresh_token,
        token_refresh_at=token_refresh_at,
        usage=usage,
        usage_checked_at=None if usage is None else usage.checked_at,
        disabled=disabled,
    )
|
|
|
|
|
|
def make_store(tmp_path: Path, state: StateFile) -> JsonStateStore:
    """Create a ``JsonStateStore`` at ``tmp_path/accounts.json`` seeded with *state*."""
    accounts_file = tmp_path / "accounts.json"
    seeded = JsonStateStore(accounts_file)
    seeded.save(state)
    return seeded
|
|
|
|
|
|
def make_manager(
    store: JsonStateStore,
    client: FakeClient,
    *,
    threshold: int = 95,
    stale_seconds: int = 3600,
) -> AccountManager:
    """Build an ``AccountManager`` over *store* and *client*.

    The settings place the data directory next to the store's file and use the
    given exhaustion threshold and usage-staleness window.
    """
    settings = Settings(
        data_dir=store.path.parent,
        exhausted_usage_threshold=threshold,
        usage_stale_seconds=stale_seconds,
    )
    return AccountManager(store, client, settings)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_prefers_active_account_when_usable(tmp_path: Path) -> None:
    """A usable active account is issued without any usage lookups."""
    accounts = [
        make_account("a@example.com", token="tok-a", usage=make_usage(20, 0)),
        make_account("b@example.com", token="tok-b", usage=make_usage(80, 0)),
    ]
    store = make_store(tmp_path, StateFile(active_account="a@example.com", accounts=accounts))
    fake = FakeClient()

    result = await make_manager(store, fake).issue_token_response()

    assert result["token"] == "tok-a"
    # Usage was checked just now, well inside the 3600s staleness window.
    assert fake.fetched_usage_tokens == []
    assert fake.validated_tokens == ["tok-a"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refreshes_stale_active_usage_before_deciding(tmp_path: Path) -> None:
    """Usage older than the staleness window is refetched before choosing an account."""
    two_hours_ago = int(time.time()) - 7200
    accounts = [
        make_account(
            "a@example.com", token="tok-a", usage=make_usage(20, 0, checked_at=two_hours_ago)
        ),
        make_account("b@example.com", token="tok-b", usage=make_usage(80, 0)),
    ]
    store = make_store(tmp_path, StateFile(active_account="a@example.com", accounts=accounts))
    fake = FakeClient(usage_by_token={"tok-a": make_usage(21, 0)})

    result = await make_manager(store, fake).issue_token_response()

    assert result["token"] == "tok-a"
    assert fake.fetched_usage_tokens == ["tok-a"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_falls_back_to_highest_primary_usage_when_active_unusable(tmp_path: Path) -> None:
    """When the active account is unusable, the candidate with the highest primary usage wins.

    The new choice is also persisted as the active account.
    """
    # 95% primary usage sits at make_manager's default threshold of 95.
    exhausted = make_account("a@example.com", token="tok-a", usage=make_usage(95, 0))
    lower = make_account("b@example.com", token="tok-b", usage=make_usage(40, 0))
    higher = make_account("c@example.com", token="tok-c", usage=make_usage(70, 0))
    state = StateFile(active_account="a@example.com", accounts=[exhausted, lower, higher])
    store = make_store(tmp_path, state)

    result = await make_manager(store, FakeClient()).issue_token_response()
    persisted = store.load()

    assert result["token"] == "tok-c"
    assert persisted.active_account == "c@example.com"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_skips_disabled_accounts(tmp_path: Path) -> None:
    """A disabled active account is passed over even though its usage is low."""
    disabled_active = make_account(
        "a@example.com", token="tok-a", usage=make_usage(20, 0), disabled=True
    )
    fallback = make_account("b@example.com", token="tok-b", usage=make_usage(70, 0))
    store = make_store(
        tmp_path, StateFile(active_account="a@example.com", accounts=[disabled_active, fallback])
    )

    result = await make_manager(store, FakeClient()).issue_token_response()

    assert result["token"] == "tok-b"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_secondary_100_makes_account_unusable(tmp_path: Path) -> None:
    """A fully used secondary window disqualifies an account despite low primary usage."""
    saturated = make_account("a@example.com", token="tok-a", usage=make_usage(20, 100))
    alternative = make_account("b@example.com", token="tok-b", usage=make_usage(30, 0))
    store = make_store(
        tmp_path, StateFile(active_account="a@example.com", accounts=[saturated, alternative])
    )

    result = await make_manager(store, FakeClient()).issue_token_response()

    assert result["token"] == "tok-b"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refreshes_token_before_validation(tmp_path: Path) -> None:
    """A token past its refresh time is refreshed first; only the new token is validated.

    The refreshed credentials are persisted to the store.
    """
    due_account = make_account(
        "a@example.com",
        token="old-token",
        refresh_token="ref-a",
        token_refresh_at=int(time.time()) - 1,  # refresh time already passed
        usage=make_usage(20, 0),
    )
    store = make_store(tmp_path, StateFile(active_account="a@example.com", accounts=[due_account]))
    fake = FakeClient(refresh_map={"ref-a": ("new-token", "new-refresh", int(time.time()) + 600)})

    result = await make_manager(store, fake).issue_token_response()
    stored = store.load().accounts[0]

    assert result["token"] == "new-token"
    assert fake.refresh_calls == ["ref-a"]
    assert fake.validated_tokens == ["new-token"]
    assert stored.access_token == "new-token"
    assert stored.refresh_token == "new-refresh"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_invalid_token_moves_account_to_failed_json(tmp_path: Path) -> None:
    """An account whose token is rejected before and after a refresh is moved to failed.json.

    The next account is then issued.
    """
    broken = make_account(
        "bad@example.com", token="tok-bad", refresh_token="ref-bad", usage=make_usage(20, 0)
    )
    healthy = make_account("good@example.com", token="tok-good", usage=make_usage(30, 0))
    store = make_store(
        tmp_path, StateFile(active_account="bad@example.com", accounts=[broken, healthy])
    )
    fake = FakeClient(
        refresh_map={"ref-bad": ("tok-bad-2", "ref-bad-2", int(time.time()) + 600)},
        # Both the original and the refreshed token fail validation.
        auth_failing_validation_tokens={"tok-bad", "tok-bad-2"},
    )

    result = await make_manager(store, fake).issue_token_response()
    remaining = store.load()
    failed_doc = json.loads((tmp_path / "failed.json").read_text())

    assert result["token"] == "tok-good"
    assert fake.validated_tokens == ["tok-bad", "tok-bad-2", "tok-good"]
    assert [acct.email for acct in remaining.accounts] == ["good@example.com"]
    assert failed_doc["accounts"][0]["email"] == "bad@example.com"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_transient_validation_error_does_not_move_account_to_failed(
    tmp_path: Path,
) -> None:
    """A transient (non-permanent) validation error keeps the account in place.

    No token is issued and failed.json is never written.
    """
    only_account = make_account("a@example.com", token="tok-a", usage=make_usage(20, 0))
    store = make_store(
        tmp_path, StateFile(active_account="a@example.com", accounts=[only_account])
    )
    manager = make_manager(store, FakeClient(transient_validation_tokens={"tok-a"}))

    with pytest.raises(NoUsableAccountError):
        await manager.issue_token_response()

    remaining = store.load()
    assert [acct.email for acct in remaining.accounts] == ["a@example.com"]
    assert not (tmp_path / "failed.json").exists()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_usage_auth_failure_refreshes_token_before_failed_json(
    tmp_path: Path,
) -> None:
    """A 401 from the usage endpoint triggers a token refresh instead of failing the account."""
    two_hours_ago = int(time.time()) - 7200
    stale_account = make_account(
        "a@example.com",
        token="old-token",
        refresh_token="ref-a",
        token_refresh_at=int(time.time()) + 600,
        usage=make_usage(20, 0, checked_at=two_hours_ago),
    )
    store = make_store(
        tmp_path, StateFile(active_account="a@example.com", accounts=[stale_account])
    )
    fake = FakeClient(
        usage_by_token={"new-token": make_usage(21, 0)},
        refresh_map={"ref-a": ("new-token", "new-refresh", int(time.time()) + 600)},
        auth_failing_usage_tokens={"old-token"},
    )

    result = await make_manager(store, fake).issue_token_response()
    remaining = store.load()

    assert result["token"] == "new-token"
    assert fake.fetched_usage_tokens == ["old-token", "new-token"]
    # FakeClient derives the payload email from the access token, so the stored
    # record now carries the refreshed token's email.
    assert [acct.email for acct in remaining.accounts] == ["new-token@example.com"]
    assert not (tmp_path / "failed.json").exists()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_validation_auth_failure_refreshes_token_before_failed_json(
    tmp_path: Path,
) -> None:
    """A token rejected during validation is refreshed and revalidated.

    The account is not moved to failed.json.
    """
    rejected_account = make_account(
        "a@example.com",
        token="old-token",
        refresh_token="ref-a",
        token_refresh_at=int(time.time()) + 600,
        usage=make_usage(20, 0),
    )
    store = make_store(
        tmp_path, StateFile(active_account="a@example.com", accounts=[rejected_account])
    )
    fake = FakeClient(
        refresh_map={"ref-a": ("new-token", "new-refresh", int(time.time()) + 600)},
        auth_failing_validation_tokens={"old-token"},
    )

    result = await make_manager(store, fake).issue_token_response()
    remaining = store.load()

    assert result["token"] == "new-token"
    assert fake.validated_tokens == ["old-token", "new-token"]
    assert [acct.email for acct in remaining.accounts] == ["a@example.com"]
    assert not (tmp_path / "failed.json").exists()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_rereads_disk_between_requests(tmp_path: Path) -> None:
    """State written to disk between calls is picked up by the same manager instance."""
    original = make_account("a@example.com", token="tok-a", usage=make_usage(20, 0))
    store = make_store(tmp_path, StateFile(active_account="a@example.com", accounts=[original]))
    manager = make_manager(store, FakeClient())

    assert (await manager.issue_token_response())["token"] == "tok-a"

    # Replace the persisted state behind the manager's back.
    swapped_in = make_account("b@example.com", token="tok-b", usage=make_usage(10, 0))
    store.save(StateFile(active_account="b@example.com", accounts=[swapped_in]))

    assert (await manager.issue_token_response())["token"] == "tok-b"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_raises_when_no_usable_accounts(tmp_path: Path) -> None:
    """With every account disabled or exhausted, issuing a token raises NoUsableAccountError."""
    unusable = [
        make_account("a@example.com", token="tok-a", usage=make_usage(10, 0), disabled=True),
        # 95% primary usage sits at make_manager's default threshold of 95.
        make_account("b@example.com", token="tok-b", usage=make_usage(95, 0)),
    ]
    manager = make_manager(make_store(tmp_path, StateFile(accounts=unusable)), FakeClient())

    with pytest.raises(NoUsableAccountError):
        await manager.issue_token_response()
|