# Tests for the ``refresh_limits`` module.
from __future__ import annotations
|
|
|
|
import time
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "scripts"))
|
|
|
|
from gibby.client import OpenAIAPIError
|
|
from gibby.models import AccountRecord, StateFile, UsageSnapshot, UsageWindow
|
|
from gibby.store import JsonStateStore
|
|
import refresh_limits # type: ignore[import-not-found]
|
|
|
|
|
|
def make_usage(primary: int, secondary: int | None = None) -> UsageSnapshot:
    """Build a ``UsageSnapshot`` stamped with the current time.

    Args:
        primary: First positional argument of the primary ``UsageWindow``
            (read back elsewhere in this file as ``used_percent``).
        secondary: Same value for the secondary window; ``None`` leaves
            ``secondary_window`` set to ``None``.

    Returns:
        A snapshot whose ``checked_at`` is "now" and whose windows are
        created with a second argument of now + 10 seconds.
    """
    # Read the clock once so every timestamp in the snapshot agrees, even
    # if the wall clock ticks over to the next second mid-construction.
    now = int(time.time())
    reset_at = now + 10
    secondary_window = (
        UsageWindow(secondary, reset_at) if secondary is not None else None
    )
    return UsageSnapshot(
        checked_at=now,
        primary_window=UsageWindow(primary, reset_at),
        secondary_window=secondary_window,
    )
|
|
|
|
|
|
class FakeClient:
    """In-memory replacement patched in for ``refresh_limits.OpenAIClient``.

    When constructed with ``permanent=True`` every usage fetch raises a
    permanent ``OpenAIAPIError``; otherwise a canned payload is returned.
    """

    def __init__(self, settings, *, permanent: bool = False):
        self.permanent = permanent
        self.settings = settings

    async def aclose(self) -> None:
        """Do nothing; the fake holds no resources."""
        return None

    async def refresh_access_token(self, refresh_token: str):
        """Return fixed token strings and a timestamp 600 seconds ahead."""
        timestamp = int(time.time()) + 600
        return ("new-token", "new-refresh", timestamp)

    async def fetch_usage_payload(self, access_token: str) -> dict:
        """Return a canned usage payload, or raise if ``permanent`` is set."""
        if self.permanent:
            raise OpenAIAPIError("invalid_grant", permanent=True, status_code=401)

        snapshot = make_usage(12, 4)
        primary = snapshot.primary_window
        # Narrows the optional type for the attribute reads below.
        assert primary is not None

        secondary = snapshot.secondary_window
        secondary_payload = None
        if secondary is not None:
            secondary_payload = {
                "used_percent": secondary.used_percent,
                "reset_at": secondary.reset_at,
            }

        rate_limit = {
            "allowed": snapshot.allowed,
            "limit_reached": snapshot.limit_reached,
            "primary_window": {
                "used_percent": primary.used_percent,
                "reset_at": primary.reset_at,
            },
            "secondary_window": secondary_payload,
        }
        return {
            "email": "acc@example.com",
            "account_id": "acc-1",
            "rate_limit": rate_limit,
        }
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refresh_limits_updates_all_accounts(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Run ``refresh_limits.run`` against ``FakeClient`` and check stored usage.

    Seeds one account in ``accounts.json``, patches ``OpenAIClient`` with the
    fake, and expects the reloaded account's primary window to report 12.
    """
    state_store = JsonStateStore(tmp_path / "accounts.json")
    seeded = AccountRecord(
        email="acc@example.com",
        access_token="tok-a1",
        refresh_token="ref-a1",
        token_refresh_at=int(time.time()) + 600,
    )
    state_store.save(StateFile(accounts=[seeded]))
    monkeypatch.setattr(refresh_limits, "OpenAIClient", FakeClient)

    await refresh_limits.run(tmp_path)

    refreshed = state_store.load().accounts[0]
    assert refreshed.usage is not None
    assert refreshed.usage.primary_window is not None
    assert refreshed.usage.primary_window.used_percent == 12
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refresh_limits_removes_permanently_failed_account(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Check that a permanent API error moves the account to the failed list.

    Seeds one account, patches ``OpenAIClient`` with a ``FakeClient`` built
    with ``permanent=True``, runs ``refresh_limits.run``, then expects the
    active list to be empty and the failed list to hold that account's email.
    """
    state_path = tmp_path / "accounts.json"
    state_store = JsonStateStore(state_path)
    doomed = AccountRecord(
        email="dead@example.com",
        access_token="tok-a1",
        refresh_token="ref-a1",
        token_refresh_at=int(time.time()) + 600,
    )
    state_store.save(StateFile(accounts=[doomed]))

    monkeypatch.setattr(
        refresh_limits,
        "OpenAIClient",
        lambda settings: FakeClient(settings, permanent=True),
    )

    await refresh_limits.run(tmp_path)

    assert state_store.load().accounts == []
    # A fresh store instance is used for the failed-accounts read.
    failed_records = JsonStateStore(state_path).load_failed_accounts()
    failed_emails = [record.email for record in failed_records]
    assert failed_emails == ["dead@example.com"]
|