570 lines
18 KiB
Python
570 lines
18 KiB
Python
from __future__ import annotations
|
|
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from gibby.client import OpenAIAPIError, OpenAIClient
|
|
from gibby.manager import AccountManager, NoUsableAccountError
|
|
from gibby.models import AccountRecord, StateFile, UsageSnapshot, UsageWindow
|
|
from gibby.settings import Settings
|
|
from gibby.store import JsonStateStore
|
|
|
|
|
|
class FakeClient(OpenAIClient):
    """In-memory stand-in for ``OpenAIClient`` used by the manager tests.

    Nothing here touches the network. Usage lookups are served from
    ``usage_by_token`` and refreshes from ``refresh_map``. Every call is
    recorded so tests can assert on the exact sequence of requests.

    Args:
        usage_by_token: Maps an access token to the ``UsageSnapshot`` that
            ``fetch_usage_payload`` reports for it.
        refresh_map: Maps a refresh token to the value returned verbatim by
            ``refresh_access_token``. The tests here use an
            ``(access_token, refresh_token, expires_at)`` tuple.
        failing_tokens: Access tokens whose usage fetch raises ``RuntimeError``.
        permanent_refresh_tokens: Refresh tokens whose refresh raises a
            permanent ``OpenAIAPIError`` (``invalid_grant``, status 401).
    """

    def __init__(
        self,
        usage_by_token=None,
        refresh_map=None,
        failing_tokens=None,
        permanent_refresh_tokens=None,
    ):
        # NOTE(review): OpenAIClient.__init__ is intentionally not called;
        # confirm the manager never relies on attributes it would set.
        self.usage_by_token = usage_by_token if usage_by_token else {}
        self.refresh_map = refresh_map if refresh_map else {}
        self.failing_tokens = set(failing_tokens) if failing_tokens else set()
        self.permanent_refresh_tokens = (
            set(permanent_refresh_tokens) if permanent_refresh_tokens else set()
        )
        # Call logs, inspected by assertions.
        self.fetched_tokens: list[str] = []
        self.refresh_calls: list[str] = []

    async def refresh_access_token(self, refresh_token: str):
        """Record the call, then return the canned refresh result.

        Raises:
            OpenAIAPIError: A permanent ``invalid_grant`` (401), for tokens
                listed in ``permanent_refresh_tokens``.
            KeyError: If the token has no entry in ``refresh_map``.
        """
        self.refresh_calls.append(refresh_token)
        if refresh_token not in self.permanent_refresh_tokens:
            return self.refresh_map[refresh_token]
        raise OpenAIAPIError("invalid_grant", permanent=True, status_code=401)

    @staticmethod
    def _fake_window_dict(window):
        """Serialize a usage window into the payload's dict shape.

        A falsy window (e.g. ``None``) becomes ``None``.
        """
        if not window:
            return None
        return {
            "used_percent": window.used_percent,
            "limit_window_seconds": window.limit_window_seconds,
            "reset_after_seconds": window.reset_after_seconds,
            "reset_at": window.reset_at,
        }

    async def fetch_usage_payload(self, access_token: str):
        """Record the call and return a usage payload built from the canned snapshot.

        The email and account id are derived from the token
        (``<token>@example.com`` / ``acct-<token>``). The tests rely on this
        when asserting the ids that accounts are saved under.

        Raises:
            RuntimeError: For tokens listed in ``failing_tokens``.
            KeyError: If the token has no entry in ``usage_by_token``.
        """
        self.fetched_tokens.append(access_token)
        if access_token in self.failing_tokens:
            raise RuntimeError("usage failed")
        snapshot = self.usage_by_token[access_token]
        rate_limit = {
            "allowed": snapshot.allowed,
            "limit_reached": snapshot.limit_reached,
            "primary_window": self._fake_window_dict(snapshot.primary_window),
            "secondary_window": self._fake_window_dict(snapshot.secondary_window),
        }
        return {
            "email": f"{access_token}@example.com",
            "account_id": f"acct-{access_token}",
            "rate_limit": rate_limit,
        }
|
|
|
|
|
|
def make_usage(
    *,
    used: int,
    secondary_used: int | None = None,
    limit_reached: bool = False,
    reset_after: int = 0,
    now: int | None = None,
) -> UsageSnapshot:
    """Build a ``UsageSnapshot`` for tests.

    Args:
        used: Primary-window usage percentage.
        secondary_used: Secondary-window usage percentage. ``None`` omits the
            secondary window entirely.
        limit_reached: Value of the snapshot's ``limit_reached`` flag. Setting
            it also marks the snapshot ``exhausted``.
        reset_after: Seconds until both windows reset. ``0`` records no reset
            time (``reset_at`` is ``0``).
        now: Reference Unix timestamp; defaults to the current time. The clock
            is read only once, so ``checked_at`` and every ``reset_at`` agree
            even if a second boundary passes while building the snapshot.

    Returns:
        A snapshot marked ``exhausted`` (and not ``allowed``) when either
        window is at 100% or more, or when ``limit_reached`` is set.
    """
    if now is None:
        now = int(time.time())

    exhausted = (
        used >= 100
        or (secondary_used is not None and secondary_used >= 100)
        or limit_reached
    )
    # A zero reset_after is stored as reset_at=0 rather than as ``now``.
    reset_at = now + reset_after if reset_after else 0
    one_week_seconds = 604800

    def window(percent):
        # Both windows share the same length and reset timing.
        return UsageWindow(
            used_percent=percent,
            limit_window_seconds=one_week_seconds,
            reset_after_seconds=reset_after,
            reset_at=reset_at,
        )

    return UsageSnapshot(
        checked_at=now,
        used_percent=used,
        remaining_percent=max(0, 100 - used),
        exhausted=exhausted,
        primary_window=window(used),
        secondary_window=(
            window(secondary_used) if secondary_used is not None else None
        ),
        limit_reached=limit_reached,
        allowed=not exhausted,
    )
|
|
|
|
|
|
def make_manager(
    store: JsonStateStore,
    client: FakeClient,
    *,
    threshold: int = 95,
) -> AccountManager:
    """Create an ``AccountManager`` wired to ``store`` and ``client``.

    The settings' data directory is the folder that holds the store's file.
    ``threshold`` is passed through as ``exhausted_usage_threshold``.
    """
    settings = Settings(
        data_dir=store.path.parent,
        exhausted_usage_threshold=threshold,
    )
    return AccountManager(store, client, settings)
|
|
|
|
|
|
def make_store(tmp_path: Path, state: StateFile) -> JsonStateStore:
    """Persist ``state`` to ``tmp_path/accounts.json`` and return the store backing it."""
    state_path = tmp_path / "accounts.json"
    json_store = JsonStateStore(state_path)
    json_store.save(state)
    return json_store
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_prefers_active_account_when_locally_usable(tmp_path: Path) -> None:
    """The active account is kept while its saved snapshot is usable.

    a2 has higher saved usage, but only the active a1 is live-checked and
    served. The active id is then re-saved under the email from a1's payload.
    """
    current_active = AccountRecord(
        id="a1",
        access_token="tok-a1",
        refresh_token="ref-a1",
        expires_at=int(time.time()) + 600,
        last_known_usage=make_usage(used=20),
    )
    other = AccountRecord(
        id="a2",
        access_token="tok-a2",
        refresh_token="ref-a2",
        expires_at=int(time.time()) + 600,
        last_known_usage=make_usage(used=70),
    )
    initial_state = StateFile(active_account_id="a1", accounts=[current_active, other])
    store = make_store(tmp_path, initial_state)
    live_usage = {
        "tok-a1": make_usage(used=21),
        "tok-a2": make_usage(used=72),
    }
    client = FakeClient(usage_by_token=live_usage)

    manager = make_manager(store, client)
    payload = await manager.issue_token_response()

    assert payload["token"] == "tok-a1"
    assert client.fetched_tokens == ["tok-a1"]
    saved = store.load()
    assert saved.active_account_id == "tok-a1@example.com"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_prefers_higher_primary_usage_from_saved_snapshot(tmp_path: Path) -> None:
    """With no active account, the highest saved primary usage is tried first.

    a2 (70%) outranks a1 (20%), is the only account live-checked, and becomes
    the persisted active account.
    """
    low = AccountRecord(
        id="a1",
        access_token="tok-a1",
        refresh_token="ref-a1",
        expires_at=int(time.time()) + 600,
        last_known_usage=make_usage(used=20),
    )
    high = AccountRecord(
        id="a2",
        access_token="tok-a2",
        refresh_token="ref-a2",
        expires_at=int(time.time()) + 600,
        last_known_usage=make_usage(used=70),
    )
    store = make_store(tmp_path, StateFile(accounts=[low, high]))
    client = FakeClient(usage_by_token={"tok-a2": make_usage(used=72)})

    manager = make_manager(store, client)
    payload = await manager.issue_token_response()
    persisted = store.load()

    assert payload["token"] == "tok-a2"
    assert client.fetched_tokens == ["tok-a2"]
    assert persisted.active_account_id == "tok-a2@example.com"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_breaks_ties_with_secondary_usage(tmp_path: Path) -> None:
    """When primary usage ties, the higher saved secondary usage wins.

    Both accounts sit at 60% primary. a2's 40% secondary beats a1's 10%, so
    a2 is the only account live-checked. a2 is served and persisted as the
    active account.
    """
    lower_secondary = AccountRecord(
        id="a1",
        access_token="tok-a1",
        refresh_token="ref-a1",
        expires_at=int(time.time()) + 600,
        last_known_usage=make_usage(used=60, secondary_used=10),
    )
    higher_secondary = AccountRecord(
        id="a2",
        access_token="tok-a2",
        refresh_token="ref-a2",
        expires_at=int(time.time()) + 600,
        last_known_usage=make_usage(used=60, secondary_used=40),
    )
    store = make_store(tmp_path, StateFile(accounts=[lower_secondary, higher_secondary]))
    client = FakeClient(
        usage_by_token={"tok-a2": make_usage(used=61, secondary_used=41)}
    )

    payload = await make_manager(store, client).issue_token_response()

    assert payload["token"] == "tok-a2"
    assert client.fetched_tokens == ["tok-a2"]
    # Matches the persistence checks in the sibling selection tests.
    assert store.load().active_account_id == "tok-a2@example.com"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_treats_missing_secondary_as_zero(tmp_path: Path) -> None:
    """A missing secondary window ranks below any recorded secondary usage.

    Both accounts sit at 60% primary. a1 has no secondary window and a2 has
    1%, so a2 is preferred, live-checked alone, and persisted as active.
    """
    no_secondary = AccountRecord(
        id="a1",
        access_token="tok-a1",
        refresh_token="ref-a1",
        expires_at=int(time.time()) + 600,
        last_known_usage=make_usage(used=60),
    )
    tiny_secondary = AccountRecord(
        id="a2",
        access_token="tok-a2",
        refresh_token="ref-a2",
        expires_at=int(time.time()) + 600,
        last_known_usage=make_usage(used=60, secondary_used=1),
    )
    store = make_store(tmp_path, StateFile(accounts=[no_secondary, tiny_secondary]))
    client = FakeClient(
        usage_by_token={"tok-a2": make_usage(used=61, secondary_used=1)}
    )

    payload = await make_manager(store, client).issue_token_response()

    assert payload["token"] == "tok-a2"
    assert client.fetched_tokens == ["tok-a2"]
    # Matches the persistence checks in the sibling selection tests.
    assert store.load().active_account_id == "tok-a2@example.com"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_skips_account_still_in_cooldown(tmp_path: Path) -> None:
    """An account whose ``cooldown_until`` is in the future is passed over.

    a1 is active and has the higher saved usage, but its cooldown runs for
    another 300 seconds. It is never live-checked; a2 is served instead and
    becomes the persisted active account.
    """
    cooling = AccountRecord(
        id="a1",
        access_token="tok-a1",
        refresh_token="ref-a1",
        expires_at=int(time.time()) + 600,
        cooldown_until=int(time.time()) + 300,
        last_known_usage=make_usage(used=80),
    )
    fallback = AccountRecord(
        id="a2",
        access_token="tok-a2",
        refresh_token="ref-a2",
        expires_at=int(time.time()) + 600,
        last_known_usage=make_usage(used=20),
    )
    store = make_store(
        tmp_path,
        StateFile(active_account_id="a1", accounts=[cooling, fallback]),
    )
    client = FakeClient(usage_by_token={"tok-a2": make_usage(used=25)})

    manager = make_manager(store, client)
    payload = await manager.issue_token_response()

    assert payload["token"] == "tok-a2"
    assert client.fetched_tokens == ["tok-a2"]
    assert store.load().active_account_id == "tok-a2@example.com"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_skips_active_account_blocked_by_local_exhausted_snapshot(
    tmp_path: Path,
) -> None:
    """An active account whose saved usage already exceeds the threshold is skipped.

    a1's saved snapshot is at 96%, above the default threshold of 95 used by
    ``make_manager``. a1 is ruled out without a live check, and a2 is served
    and persisted as active.
    """
    over_limit = AccountRecord(
        id="a1",
        access_token="tok-a1",
        refresh_token="ref-a1",
        expires_at=int(time.time()) + 600,
        last_known_usage=make_usage(used=96),
    )
    fallback = AccountRecord(
        id="a2",
        access_token="tok-a2",
        refresh_token="ref-a2",
        expires_at=int(time.time()) + 600,
        last_known_usage=make_usage(used=20),
    )
    state = StateFile(active_account_id="a1", accounts=[over_limit, fallback])
    store = make_store(tmp_path, state)
    client = FakeClient(usage_by_token={"tok-a2": make_usage(used=25)})

    manager = make_manager(store, client)
    payload = await manager.issue_token_response()

    assert payload["token"] == "tok-a2"
    assert client.fetched_tokens == ["tok-a2"]
    assert store.load().active_account_id == "tok-a2@example.com"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_live_checks_depleted_and_moves_to_next(tmp_path: Path) -> None:
    """An account that looks usable locally but is depleted live is skipped.

    a1's saved 94% is just under the threshold, so a1 is tried and
    live-checked first. Its live payload reports 96%, so the manager moves on
    to a2.
    """
    near_limit = AccountRecord(
        id="a1",
        access_token="tok-a1",
        refresh_token="ref-a1",
        expires_at=int(time.time()) + 600,
        last_known_usage=make_usage(used=94),
    )
    spare = AccountRecord(
        id="a2",
        access_token="tok-a2",
        refresh_token="ref-a2",
        expires_at=int(time.time()) + 600,
        last_known_usage=make_usage(used=50),
    )
    store = make_store(tmp_path, StateFile(accounts=[near_limit, spare]))
    live_usage = {
        "tok-a1": make_usage(used=96, reset_after=120),
        "tok-a2": make_usage(used=52),
    }
    client = FakeClient(usage_by_token=live_usage)

    manager = make_manager(store, client)
    payload = await manager.issue_token_response()

    assert payload["token"] == "tok-a2"
    assert client.fetched_tokens == ["tok-a1", "tok-a2"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_live_checks_secondary_depleted_and_moves_to_next(tmp_path: Path) -> None:
    """A live secondary window at 100% disqualifies an account despite low primary usage.

    a1 (30% primary) outranks a2 (20%) and is live-checked first. Its live
    secondary window is at 100%, so a2 is checked next and served.
    """
    secondary_heavy = AccountRecord(
        id="a1",
        access_token="tok-a1",
        refresh_token="ref-a1",
        expires_at=int(time.time()) + 600,
        last_known_usage=make_usage(used=30, secondary_used=94),
    )
    spare = AccountRecord(
        id="a2",
        access_token="tok-a2",
        refresh_token="ref-a2",
        expires_at=int(time.time()) + 600,
        last_known_usage=make_usage(used=20, secondary_used=10),
    )
    store = make_store(tmp_path, StateFile(accounts=[secondary_heavy, spare]))
    live_usage = {
        "tok-a1": make_usage(used=30, secondary_used=100, reset_after=120),
        "tok-a2": make_usage(used=22, secondary_used=10),
    }
    client = FakeClient(usage_by_token=live_usage)

    manager = make_manager(store, client)
    payload = await manager.issue_token_response()

    assert payload["token"] == "tok-a2"
    assert client.fetched_tokens == ["tok-a1", "tok-a2"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_falls_through_when_live_usage_is_depleted(tmp_path: Path) -> None:
    """Live usage exactly at the threshold disqualifies an account and stores a cooldown.

    a1's live payload reports 95% (the default threshold), so a2 is served.
    a1's saved record, stored under the id ``tok-a1@example.com``, carries a
    cooldown and the refreshed 95% snapshot.
    """
    primary = AccountRecord(
        id="a1",
        access_token="tok-a1",
        refresh_token="ref-a1",
        expires_at=int(time.time()) + 600,
        last_known_usage=make_usage(used=80),
    )
    backup = AccountRecord(
        id="a2",
        access_token="tok-a2",
        refresh_token="ref-a2",
        expires_at=int(time.time()) + 600,
        last_known_usage=make_usage(used=70),
    )
    store = make_store(tmp_path, StateFile(accounts=[primary, backup]))
    live_usage = {
        "tok-a1": make_usage(used=95, reset_after=120),
        "tok-a2": make_usage(used=71),
    }
    client = FakeClient(usage_by_token=live_usage)

    payload = await make_manager(store, client).issue_token_response()
    saved = store.load()

    assert payload["token"] == "tok-a2"
    assert client.fetched_tokens == ["tok-a1", "tok-a2"]
    depleted = next(
        record for record in saved.accounts if record.id == "tok-a1@example.com"
    )
    assert depleted.cooldown_until is not None
    recorded_usage = depleted.last_known_usage
    assert recorded_usage is not None
    assert recorded_usage.used_percent == 95
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_keeps_account_out_until_blocking_window_resets(tmp_path: Path) -> None:
    """A blocked account stays in cooldown until its later window reset.

    a1's live payload reports ``limit_reached`` and not ``allowed``, although
    neither window is at 100%. The primary window resets in 60s and the
    secondary in 240s. The test expects a2 to be served and a1's cooldown to
    end at the later reset time, ``current + 240``.

    ``current`` is captured once and reused for every timestamp, so the
    expected cooldown can be compared exactly.
    """
    current = int(time.time())
    blocked = AccountRecord(
        id="a1",
        access_token="tok-a1",
        refresh_token="ref-a1",
        expires_at=current + 600,
        last_known_usage=make_usage(used=80, secondary_used=40),
    )
    second = AccountRecord(
        id="a2",
        access_token="tok-a2",
        refresh_token="ref-a2",
        expires_at=current + 600,
        last_known_usage=make_usage(used=70),
    )
    store = make_store(tmp_path, StateFile(accounts=[blocked, second]))
    # Hand-built because make_usage gives both windows the same reset time.
    blocked_usage = UsageSnapshot(
        checked_at=current,
        used_percent=80,
        remaining_percent=20,
        exhausted=True,
        primary_window=UsageWindow(
            used_percent=80,
            limit_window_seconds=604800,
            reset_after_seconds=60,
            reset_at=current + 60,
        ),
        secondary_window=UsageWindow(
            used_percent=40,
            limit_window_seconds=604800,
            reset_after_seconds=240,
            reset_at=current + 240,
        ),
        limit_reached=True,
        allowed=False,
    )
    client = FakeClient(
        usage_by_token={
            "tok-a1": blocked_usage,
            "tok-a2": make_usage(used=71),
        }
    )

    payload = await make_manager(store, client).issue_token_response()
    saved = store.load()
    blocked_saved = next(
        account for account in saved.accounts if account.id == "tok-a1@example.com"
    )

    assert payload["token"] == "tok-a2"
    assert blocked_saved.cooldown_until == current + 240
    assert client.fetched_tokens == ["tok-a1", "tok-a2"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refreshes_expired_token_before_usage(tmp_path: Path) -> None:
    """An expired access token is refreshed before usage is fetched.

    Only the refreshed token reaches the usage endpoint. The stored record
    ends up with the new access and refresh tokens and is saved under the id
    ``new-token@example.com``.
    """
    expired = AccountRecord(
        id="a1",
        access_token="old-token",
        refresh_token="ref-a1",
        expires_at=int(time.time()) - 1,
        last_known_usage=make_usage(used=20),
    )
    store = make_store(tmp_path, StateFile(active_account_id="a1", accounts=[expired]))
    live_usage = {"new-token": make_usage(used=20)}
    refresh_results = {"ref-a1": ("new-token", "new-refresh", int(time.time()) + 600)}
    client = FakeClient(usage_by_token=live_usage, refresh_map=refresh_results)

    payload = await make_manager(store, client).issue_token_response()
    record = store.load().accounts[0]

    assert payload["token"] == "new-token"
    assert client.refresh_calls == ["ref-a1"]
    assert client.fetched_tokens == ["new-token"]
    assert record.id == "new-token@example.com"
    assert record.access_token == "new-token"
    assert record.refresh_token == "new-refresh"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_raises_when_all_accounts_unusable(tmp_path: Path) -> None:
    """``NoUsableAccountError`` is raised when every account is depleted live.

    Both live payloads exceed the threshold. After the failure, every saved
    account carries a cooldown.
    """
    active = AccountRecord(
        id="a1",
        access_token="tok-a1",
        refresh_token="ref-a1",
        expires_at=int(time.time()) + 600,
        last_known_usage=make_usage(used=80),
    )
    standby = AccountRecord(
        id="a2",
        access_token="tok-a2",
        refresh_token="ref-a2",
        expires_at=int(time.time()) + 600,
        last_known_usage=make_usage(used=70),
    )
    state = StateFile(active_account_id="a1", accounts=[active, standby])
    store = make_store(tmp_path, state)
    live_usage = {
        "tok-a1": make_usage(used=96, reset_after=120),
        "tok-a2": make_usage(used=97, reset_after=120),
    }
    client = FakeClient(usage_by_token=live_usage)

    with pytest.raises(NoUsableAccountError):
        manager = make_manager(store, client)
        await manager.issue_token_response()

    saved = store.load()
    assert not any(record.cooldown_until is None for record in saved.accounts)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_threshold_can_be_overridden_for_selection(tmp_path: Path) -> None:
    """Raising the threshold to 97 keeps a 96%-used active account selectable.

    With the override, a1 is neither skipped locally nor disqualified live. It
    is the only account checked and it is served.
    """
    heavy_active = AccountRecord(
        id="a1",
        access_token="tok-a1",
        refresh_token="ref-a1",
        expires_at=int(time.time()) + 600,
        last_known_usage=make_usage(used=96),
    )
    light = AccountRecord(
        id="a2",
        access_token="tok-a2",
        refresh_token="ref-a2",
        expires_at=int(time.time()) + 600,
        last_known_usage=make_usage(used=20),
    )
    state = StateFile(active_account_id="a1", accounts=[heavy_active, light])
    store = make_store(tmp_path, state)
    live_usage = {
        "tok-a1": make_usage(used=96),
        "tok-a2": make_usage(used=25),
    }
    client = FakeClient(usage_by_token=live_usage)

    manager = make_manager(store, client, threshold=97)
    payload = await manager.issue_token_response()

    assert payload["token"] == "tok-a1"
    assert client.fetched_tokens == ["tok-a1"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_removes_account_and_records_failed_email_on_permanent_refresh_failure(
    tmp_path: Path,
) -> None:
    """A permanent refresh failure drops the account and logs its email.

    a1's access token is expired and its refresh is rejected permanently. a1
    is removed from the store, its email is appended to ``failed.txt`` in the
    same directory as the state file, and a2 is served.
    """
    dead = AccountRecord(
        id="a1",
        email="dead@example.com",
        access_token="tok-a1",
        refresh_token="ref-a1",
        expires_at=int(time.time()) - 1,
        last_known_usage=make_usage(used=80),
    )
    alive = AccountRecord(
        id="a2",
        email="alive@example.com",
        access_token="tok-a2",
        refresh_token="ref-a2",
        expires_at=int(time.time()) + 600,
        last_known_usage=make_usage(used=70),
    )
    store = make_store(tmp_path, StateFile(accounts=[dead, alive]))
    client = FakeClient(
        usage_by_token={"tok-a2": make_usage(used=71)},
        permanent_refresh_tokens={"ref-a1"},
    )

    payload = await make_manager(store, client).issue_token_response()

    assert payload["token"] == "tok-a2"
    remaining_ids = [record.id for record in store.load().accounts]
    assert remaining_ids == ["tok-a2@example.com"]
    failed_emails = (tmp_path / "failed.txt").read_text().splitlines()
    assert failed_emails == ["dead@example.com"]
|