# gibidy/tests/test_server_unit.py
# Unit tests for the `server` module's handlers and helper functions.

import asyncio
import json
import server
from providers.base import Provider, ProviderTokens
from utils.env import parse_int_env
class FakeRequest:
    """Minimal request double that only exposes ``match_info``.

    Tests pass it to ``server.token_handler`` in place of a real request.
    """

    def __init__(self, provider: str):
        # The route parameter the handler under test looks up.
        self.match_info = dict(provider=provider)
class FakeProvider(Provider):
    """In-memory ``Provider`` double with scripted token, usage and rotation.

    It counts ``get_token`` and ``ensure_standby_account`` calls so tests can
    assert how the server interacted with the provider.
    """

    def __init__(
        self,
        token: str | None = "tok",
        usage: dict | None = None,
        rotate: bool = False,
        prepare_threshold: int = 80,
    ):
        """Create the fake.

        Args:
            token: Value returned by every ``get_token`` call.
            usage: Dict returned (as a shallow copy) by ``get_usage_info``.
                ``None`` selects a 10%-used default. An explicitly passed
                dict, even an empty one, is used as-is.
            rotate: Value returned by every ``maybe_rotate_account`` call.
            prepare_threshold: Usage percent at or above which
                ``should_prepare_standby`` returns True (default 80).
        """
        self._token = token
        # Compare with None rather than using `or`, so a deliberately empty
        # usage dict is not silently swapped for the default.
        if usage is None:
            usage = {
                "used_percent": 10,
                "remaining_percent": 90,
                "primary_window": None,
                "secondary_window": None,
            }
        self._usage = usage
        self._rotate = rotate
        self._prepare_threshold = prepare_threshold
        self.get_token_calls = 0
        self.standby_calls = 0

    @property
    def prepare_threshold(self) -> int:
        """Usage percent that triggers standby preparation."""
        return self._prepare_threshold

    def should_prepare_standby(self, usage_percent: int) -> bool:
        """Return True once ``usage_percent`` reaches the prepare threshold."""
        return usage_percent >= self.prepare_threshold

    @property
    def name(self) -> str:
        """Provider name used as the key in ``server.PROVIDERS`` by tests."""
        return "fake"

    async def get_token(self) -> str | None:
        """Return the scripted token and record the call."""
        self.get_token_calls += 1
        return self._token

    async def register_new_account(self) -> bool:
        """Pretend registration always succeeds."""
        return True

    async def get_usage_info(self, access_token: str) -> dict:
        """Return a copy of the scripted usage; the token is ignored."""
        _ = access_token
        # Copy so the handler cannot mutate the fake's stored usage.
        return dict(self._usage)

    def load_tokens(self) -> ProviderTokens | None:
        """No persisted tokens in the fake."""
        return None

    def save_tokens(self, tokens: ProviderTokens) -> None:
        """Discard tokens; the fake keeps no storage."""
        _ = tokens

    async def maybe_rotate_account(self, usage_percent: int) -> bool:
        """Return the scripted rotation outcome, regardless of usage."""
        _ = usage_percent
        return self._rotate

    async def ensure_standby_account(
        self,
        usage_percent: int,
    ) -> None:
        """Record that standby preparation was requested."""
        _ = usage_percent
        self.standby_calls += 1
def _response_json(resp) -> dict:
    """Decode a response's UTF-8 ``body`` bytes and parse them as JSON."""
    text = resp.body.decode("utf-8")
    return json.loads(text)
def test_parse_int_env_defaults(monkeypatch):
    """An unset variable yields the supplied default."""
    monkeypatch.delenv("X_TEST", raising=False)
    value = parse_int_env("X_TEST", 10, 1, 20)
    assert value == 10
def test_parse_int_env_invalid(monkeypatch):
    """A non-integer value falls back to the default."""
    monkeypatch.setenv("X_TEST", "abc")
    value = parse_int_env("X_TEST", 10, 1, 20)
    assert value == 10
def test_parse_int_env_out_of_range(monkeypatch):
    """A value above the allowed maximum falls back to the default."""
    monkeypatch.setenv("X_TEST", "999")
    value = parse_int_env("X_TEST", 10, 1, 20)
    assert value == 10
def test_build_limit_fields():
    """90% used against an 85% threshold: 10% left, not exhausted, prepare."""
    limit = server.build_limit(90, 85)
    expected = {
        "used_percent": 90,
        "remaining_percent": 10,
        "exhausted": False,
        "needs_prepare": True,
    }
    assert limit == expected
def test_get_prepare_threshold():
    """Known providers report their own threshold; unknown names give 100."""
    chatgpt_threshold = server.get_prepare_threshold("chatgpt")
    assert chatgpt_threshold == server.PROVIDERS["chatgpt"].prepare_threshold
    unknown_threshold = server.get_prepare_threshold("unknown")
    assert unknown_threshold == 100
def test_token_handler_unknown_provider(monkeypatch):
    """Requesting a provider that is not registered yields HTTP 404."""
    monkeypatch.setattr(server, "PROVIDERS", {})
    request = FakeRequest("missing")
    response = asyncio.run(server.token_handler(request))
    assert response.status == 404
def test_health_handler_ok():
    """The health endpoint answers 200 with the body "ok"."""
    # A bare object() stands in for the request.
    response = asyncio.run(server.health_handler(object()))
    assert response.status == 200
    assert response.text == "ok"
def test_token_handler_success(monkeypatch):
    """At low usage the token is returned and no preparation is flagged."""
    fake = FakeProvider()
    monkeypatch.setattr(server, "PROVIDERS", {"fake": fake})
    # NOTE(review): presumably "no background task for 'fake' yet" — confirm.
    monkeypatch.setattr(server, "background_tasks", {"fake": None})
    response = asyncio.run(server.token_handler(FakeRequest("fake")))
    payload = _response_json(response)
    assert response.status == 200
    assert payload["token"] == "tok"
    assert payload["limit"]["needs_prepare"] is False
def test_token_handler_triggers_standby(monkeypatch):
    """Usage at 90% (above the fake's 80% threshold) triggers standby prep."""
    fake = FakeProvider(usage={"used_percent": 90, "remaining_percent": 10})
    monkeypatch.setattr(server, "PROVIDERS", {"fake": fake})
    monkeypatch.setattr(server, "background_tasks", {"fake": None})
    triggered = []

    # Parameter names are kept as-is in case the server passes them by keyword.
    def fake_trigger(name, usage_percent, reason):
        assert name == "fake"
        assert usage_percent == 90
        assert "standby policy" in reason
        triggered.append(True)

    monkeypatch.setattr(server, "trigger_standby_prepare", fake_trigger)
    response = asyncio.run(server.token_handler(FakeRequest("fake")))
    assert response.status == 200
    assert triggered
def test_token_handler_rotation_path(monkeypatch):
    """High usage with a successful rotation makes the handler re-fetch a token.

    Usage is 96% and ``maybe_rotate_account`` reports a rotation, so
    ``get_token`` is expected to be called again after rotating.
    """
    provider = FakeProvider(
        usage={"used_percent": 96, "remaining_percent": 4},
        rotate=True,
    )
    monkeypatch.setattr(server, "PROVIDERS", {"fake": provider})
    monkeypatch.setattr(server, "background_tasks", {"fake": None})
    # Accept keyword arguments as well as positional ones. The standby test
    # stubs this function as (name, usage_percent, reason); a positional-only
    # stub would raise TypeError if any of those were passed by keyword.
    monkeypatch.setattr(
        server, "trigger_standby_prepare", lambda *_args, **_kwargs: None
    )
    resp = asyncio.run(server.token_handler(FakeRequest("fake")))
    assert resp.status == 200
    assert provider.get_token_calls >= 2