small fix
This commit is contained in:
parent
5349befcf4
commit
89b85b321e
8 changed files with 251 additions and 211 deletions
|
|
@ -19,11 +19,17 @@ class FakeClient(OpenAIClient):
|
|||
usage_by_token=None,
|
||||
refresh_map=None,
|
||||
invalid_tokens=None,
|
||||
transient_validation_tokens=None,
|
||||
auth_failing_usage_tokens=None,
|
||||
auth_failing_validation_tokens=None,
|
||||
permanent_refresh_tokens=None,
|
||||
):
|
||||
self.usage_by_token = usage_by_token or {}
|
||||
self.refresh_map = refresh_map or {}
|
||||
self.invalid_tokens = set(invalid_tokens or [])
|
||||
self.transient_validation_tokens = set(transient_validation_tokens or [])
|
||||
self.auth_failing_usage_tokens = set(auth_failing_usage_tokens or [])
|
||||
self.auth_failing_validation_tokens = set(auth_failing_validation_tokens or [])
|
||||
self.permanent_refresh_tokens = set(permanent_refresh_tokens or [])
|
||||
self.fetched_usage_tokens: list[str] = []
|
||||
self.validated_tokens: list[str] = []
|
||||
|
|
@ -38,6 +44,8 @@ class FakeClient(OpenAIClient):
|
|||
|
||||
async def fetch_usage_payload(self, access_token: str):
|
||||
self.fetched_usage_tokens.append(access_token)
|
||||
if access_token in self.auth_failing_usage_tokens:
|
||||
raise OpenAIAPIError("usage auth failed", permanent=True, status_code=401)
|
||||
usage = self.usage_by_token[access_token]
|
||||
return {
|
||||
"email": f"{access_token}@example.com",
|
||||
|
|
@ -61,6 +69,10 @@ class FakeClient(OpenAIClient):
|
|||
|
||||
async def validate_token(self, access_token: str) -> bool:
|
||||
self.validated_tokens.append(access_token)
|
||||
if access_token in self.auth_failing_validation_tokens:
|
||||
return False
|
||||
if access_token in self.transient_validation_tokens:
|
||||
raise OpenAIAPIError("validation 502", permanent=False, status_code=502)
|
||||
return access_token not in self.invalid_tokens
|
||||
|
||||
|
||||
|
|
@ -224,20 +236,105 @@ async def test_refreshes_token_before_validation(tmp_path: Path) -> None:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_token_moves_account_to_failed_json(tmp_path: Path) -> None:
|
||||
bad = make_account("bad@example.com", token="tok-bad", usage=make_usage(20, 0))
|
||||
bad = make_account(
|
||||
"bad@example.com",
|
||||
token="tok-bad",
|
||||
refresh_token="ref-bad",
|
||||
usage=make_usage(20, 0),
|
||||
)
|
||||
good = make_account("good@example.com", token="tok-good", usage=make_usage(30, 0))
|
||||
store = make_store(tmp_path, StateFile(active_account="bad@example.com", accounts=[bad, good]))
|
||||
client = FakeClient(invalid_tokens={"tok-bad"})
|
||||
client = FakeClient(
|
||||
refresh_map={"ref-bad": ("tok-bad-2", "ref-bad-2", int(time.time()) + 600)},
|
||||
auth_failing_validation_tokens={"tok-bad", "tok-bad-2"},
|
||||
)
|
||||
|
||||
payload = await make_manager(store, client).issue_token_response()
|
||||
state = store.load()
|
||||
failed = json.loads((tmp_path / "failed.json").read_text())
|
||||
|
||||
assert payload["token"] == "tok-good"
|
||||
assert client.validated_tokens == ["tok-bad", "tok-bad-2", "tok-good"]
|
||||
assert [account.email for account in state.accounts] == ["good@example.com"]
|
||||
assert failed["accounts"][0]["email"] == "bad@example.com"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transient_validation_error_does_not_move_account_to_failed(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
account = make_account("a@example.com", token="tok-a", usage=make_usage(20, 0))
|
||||
store = make_store(
|
||||
tmp_path, StateFile(active_account="a@example.com", accounts=[account])
|
||||
)
|
||||
client = FakeClient(transient_validation_tokens={"tok-a"})
|
||||
|
||||
with pytest.raises(NoUsableAccountError):
|
||||
await make_manager(store, client).issue_token_response()
|
||||
|
||||
state = store.load()
|
||||
assert [account.email for account in state.accounts] == ["a@example.com"]
|
||||
assert not (tmp_path / "failed.json").exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_auth_failure_refreshes_token_before_failed_json(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
stale = int(time.time()) - 7200
|
||||
account = make_account(
|
||||
"a@example.com",
|
||||
token="old-token",
|
||||
refresh_token="ref-a",
|
||||
token_refresh_at=int(time.time()) + 600,
|
||||
usage=make_usage(20, 0, checked_at=stale),
|
||||
)
|
||||
store = make_store(
|
||||
tmp_path, StateFile(active_account="a@example.com", accounts=[account])
|
||||
)
|
||||
client = FakeClient(
|
||||
usage_by_token={"new-token": make_usage(21, 0)},
|
||||
refresh_map={"ref-a": ("new-token", "new-refresh", int(time.time()) + 600)},
|
||||
auth_failing_usage_tokens={"old-token"},
|
||||
)
|
||||
|
||||
payload = await make_manager(store, client).issue_token_response()
|
||||
state = store.load()
|
||||
|
||||
assert payload["token"] == "new-token"
|
||||
assert client.fetched_usage_tokens == ["old-token", "new-token"]
|
||||
assert [account.email for account in state.accounts] == ["new-token@example.com"]
|
||||
assert not (tmp_path / "failed.json").exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validation_auth_failure_refreshes_token_before_failed_json(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
account = make_account(
|
||||
"a@example.com",
|
||||
token="old-token",
|
||||
refresh_token="ref-a",
|
||||
token_refresh_at=int(time.time()) + 600,
|
||||
usage=make_usage(20, 0),
|
||||
)
|
||||
store = make_store(
|
||||
tmp_path, StateFile(active_account="a@example.com", accounts=[account])
|
||||
)
|
||||
client = FakeClient(
|
||||
refresh_map={"ref-a": ("new-token", "new-refresh", int(time.time()) + 600)},
|
||||
auth_failing_validation_tokens={"old-token"},
|
||||
)
|
||||
|
||||
payload = await make_manager(store, client).issue_token_response()
|
||||
state = store.load()
|
||||
|
||||
assert payload["token"] == "new-token"
|
||||
assert client.validated_tokens == ["old-token", "new-token"]
|
||||
assert [account.email for account in state.accounts] == ["a@example.com"]
|
||||
assert not (tmp_path / "failed.json").exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rereads_disk_between_requests(tmp_path: Path) -> None:
|
||||
first = make_account("a@example.com", token="tok-a", usage=make_usage(20, 0))
|
||||
|
|
|
|||
|
|
@ -42,6 +42,8 @@ def test_store_writes_minimal_accounts_schema(tmp_path) -> None:
|
|||
"usage": {
|
||||
"primary": {"used_percent": 70, "reset_at": 1300},
|
||||
"secondary": {"used_percent": 20, "reset_at": 4600},
|
||||
"limit_reached": False,
|
||||
"allowed": True,
|
||||
},
|
||||
"usage_checked_at": 1000,
|
||||
"disabled": False,
|
||||
|
|
@ -65,6 +67,8 @@ def test_store_load_reconstructs_account_state(tmp_path) -> None:
|
|||
"usage": {
|
||||
"primary": {"used_percent": 80, "reset_at": 1300},
|
||||
"secondary": {"used_percent": 15, "reset_at": 4600},
|
||||
"limit_reached": True,
|
||||
"allowed": False,
|
||||
},
|
||||
"usage_checked_at": 1000,
|
||||
"disabled": True,
|
||||
|
|
@ -82,6 +86,8 @@ def test_store_load_reconstructs_account_state(tmp_path) -> None:
|
|||
assert state.accounts[0].usage is not None
|
||||
assert state.accounts[0].usage.primary_window is not None
|
||||
assert state.accounts[0].usage.primary_window.used_percent == 80
|
||||
assert state.accounts[0].usage.limit_reached is True
|
||||
assert state.accounts[0].usage.allowed is False
|
||||
assert state.accounts[0].disabled is True
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue