165 lines
4.6 KiB
Python
165 lines
4.6 KiB
Python
import asyncio
|
|
import json
|
|
|
|
import server
|
|
from providers.base import Provider, ProviderTokens
|
|
from utils.env import parse_int_env
|
|
|
|
|
|
class FakeRequest:
    """Minimal request double for the handler tests.

    It only carries ``match_info``, a mapping whose ``"provider"`` entry
    holds the provider name passed at construction.
    """

    def __init__(self, provider: str):
        route_params = {"provider": provider}
        self.match_info = route_params
|
|
|
|
|
|
class FakeProvider(Provider):
    """In-memory ``Provider`` test double with canned data and call counters.

    Args:
        token: Value returned by ``get_token``; pass ``None`` to simulate a
            provider that has no token.
        usage: Payload returned (as a shallow copy) by ``get_usage_info``.
            When omitted, a low-usage default (10% used, 90% remaining) is
            used.
        rotate: Value returned by ``maybe_rotate_account``.
    """

    def __init__(
        self,
        token: str | None = "tok",
        usage: dict | None = None,
        rotate: bool = False,
    ):
        self._token = token
        # Test for None explicitly: a truthiness check would silently replace
        # an intentionally empty usage dict with the default payload.
        if usage is None:
            usage = {
                "used_percent": 10,
                "remaining_percent": 90,
                "primary_window": None,
                "secondary_window": None,
            }
        self._usage = usage
        self._rotate = rotate
        # Fixed threshold used by ``should_prepare_standby``.
        self._prepare_threshold = 80
        # Counters let tests assert how often the server called into the fake.
        self.get_token_calls = 0
        self.standby_calls = 0

    @property
    def prepare_threshold(self) -> int:
        """Usage percentage at or above which standby preparation is wanted."""
        return self._prepare_threshold

    def should_prepare_standby(self, usage_percent: int) -> bool:
        """Return True when ``usage_percent`` reaches the prepare threshold."""
        return usage_percent >= self.prepare_threshold

    @property
    def name(self) -> str:
        """Provider identifier; always ``"fake"``."""
        return "fake"

    async def get_token(self) -> str | None:
        """Count the call and return the configured token."""
        self.get_token_calls += 1
        return self._token

    async def register_new_account(self) -> bool:
        """Pretend account registration always succeeds."""
        return True

    async def get_usage_info(self, access_token: str) -> dict:
        """Return a shallow copy of the canned usage payload.

        ``access_token`` is accepted for interface compatibility and ignored.
        """
        _ = access_token
        return dict(self._usage)

    def load_tokens(self) -> ProviderTokens | None:
        """No persisted tokens in the fake; always returns None."""
        return None

    def save_tokens(self, tokens: ProviderTokens) -> None:
        """Discard ``tokens``; the fake has no persistence."""
        _ = tokens

    async def maybe_rotate_account(self, usage_percent: int) -> bool:
        """Return the configured ``rotate`` flag regardless of usage."""
        _ = usage_percent
        return self._rotate

    async def ensure_standby_account(
        self,
        usage_percent: int,
    ) -> None:
        """Record that standby preparation was requested."""
        _ = usage_percent
        self.standby_calls += 1
|
|
|
|
|
|
def _response_json(resp) -> dict:
|
|
return json.loads(resp.body.decode("utf-8"))
|
|
|
|
|
|
def test_parse_int_env_defaults(monkeypatch):
    """An unset variable yields the supplied default."""
    monkeypatch.delenv("X_TEST", raising=False)
    result = parse_int_env("X_TEST", 10, 1, 20)
    assert result == 10
|
|
|
|
|
|
def test_parse_int_env_invalid(monkeypatch):
    """A non-numeric value falls back to the default."""
    monkeypatch.setenv("X_TEST", "abc")
    result = parse_int_env("X_TEST", 10, 1, 20)
    assert result == 10
|
|
|
|
|
|
def test_parse_int_env_out_of_range(monkeypatch):
    """A value above the allowed maximum falls back to the default."""
    monkeypatch.setenv("X_TEST", "999")
    result = parse_int_env("X_TEST", 10, 1, 20)
    assert result == 10
|
|
|
|
|
|
def test_build_limit_fields():
    """``build_limit(90, 85)`` produces the full limit dictionary."""
    expected = {
        "used_percent": 90,
        "remaining_percent": 10,
        "exhausted": False,
        "needs_prepare": True,
    }
    assert server.build_limit(90, 85) == expected
|
|
|
|
|
|
def test_get_prepare_threshold():
    """A known provider reports its own threshold; an unknown name gets 100."""
    chatgpt_threshold = server.PROVIDERS["chatgpt"].prepare_threshold
    assert server.get_prepare_threshold("chatgpt") == chatgpt_threshold
    assert server.get_prepare_threshold("unknown") == 100
|
|
|
|
|
|
def test_token_handler_unknown_provider(monkeypatch):
    """With no providers registered, the token handler answers 404."""
    monkeypatch.setattr(server, "PROVIDERS", {})
    request = FakeRequest("missing")
    response = asyncio.run(server.token_handler(request))
    assert response.status == 404
|
|
|
|
|
|
def test_health_handler_ok():
    """The health handler answers 200 with body ``"ok"`` for an arbitrary request."""
    response = asyncio.run(server.health_handler(object()))
    assert (response.status, response.text) == (200, "ok")
|
|
|
|
|
|
def test_token_handler_success(monkeypatch):
    """Low usage returns the provider's token and does not ask for preparation."""
    fake = FakeProvider()
    monkeypatch.setattr(server, "PROVIDERS", {"fake": fake})
    # Presumably ``None`` means no background task is currently registered.
    monkeypatch.setattr(server, "background_tasks", {"fake": None})

    response = asyncio.run(server.token_handler(FakeRequest("fake")))
    body = _response_json(response)

    assert response.status == 200
    assert body["token"] == "tok"
    assert body["limit"]["needs_prepare"] is False
|
|
|
|
|
|
def test_token_handler_triggers_standby(monkeypatch):
    """Usage at 90% (above the fake's 80% threshold) triggers standby preparation."""
    fake = FakeProvider(usage={"used_percent": 90, "remaining_percent": 10})
    monkeypatch.setattr(server, "PROVIDERS", {"fake": fake})
    monkeypatch.setattr(server, "background_tasks", {"fake": None})

    seen = []

    def fake_trigger(name, usage_percent, reason):
        # Validate the arguments the handler passes, then record the call.
        assert name == "fake"
        assert usage_percent == 90
        assert "standby policy" in reason
        seen.append(True)

    monkeypatch.setattr(server, "trigger_standby_prepare", fake_trigger)

    response = asyncio.run(server.token_handler(FakeRequest("fake")))
    assert response.status == 200
    assert seen
|
|
|
|
|
|
def test_token_handler_rotation_path(monkeypatch):
    """High usage with a rotating provider still succeeds and re-fetches the token."""
    provider = FakeProvider(
        usage={"used_percent": 96, "remaining_percent": 4},
        rotate=True,
    )
    monkeypatch.setattr(server, "PROVIDERS", {"fake": provider})
    monkeypatch.setattr(server, "background_tasks", {"fake": None})
    # No-op stub that accepts any call shape. A ``*args``-only lambda would
    # raise TypeError if the handler ever passed arguments by keyword.
    monkeypatch.setattr(
        server, "trigger_standby_prepare", lambda *_args, **_kwargs: None
    )

    resp = asyncio.run(server.token_handler(FakeRequest("fake")))
    assert resp.status == 200
    # Presumably the handler fetches a fresh token after a successful
    # rotation, so get_token should have been called more than once.
    assert provider.get_token_calls >= 2
|