refactor!: a lot of stuff
This commit is contained in:
parent
d6396e4050
commit
0af7179596
15 changed files with 663 additions and 302 deletions
16
README.md
16
README.md
|
|
@ -16,7 +16,7 @@ Response shape:
|
|||
"used_percent": 0,
|
||||
"remaining_percent": 100,
|
||||
"exhausted": false,
|
||||
"needs_refresh": false
|
||||
"needs_prepare": false
|
||||
},
|
||||
"usage": {
|
||||
"primary_window": {
|
||||
|
|
@ -54,10 +54,8 @@ Behavior:
|
|||
|
||||
## Startup Behavior
|
||||
|
||||
On startup, service:
|
||||
|
||||
1. Ensures active token exists and is usable.
|
||||
2. Ensures `next_account` is prepared for ChatGPT.
|
||||
On startup, service ensures active token exists and is usable.
|
||||
Standby preparation runs through provider lifecycle hooks/background trigger when needed.
|
||||
|
||||
## Data Files
|
||||
|
||||
|
|
@ -70,6 +68,14 @@ On startup, service:
|
|||
PYTHONPATH=./src python src/server.py
|
||||
```
|
||||
|
||||
## Unit Tests
|
||||
|
||||
The project has unit tests only (no integration/network tests).
|
||||
|
||||
```bash
|
||||
pytest -q
|
||||
```
|
||||
|
||||
## Docker Notes
|
||||
|
||||
- Dockerfile sets `DATA_DIR=/data`.
|
||||
|
|
|
|||
|
|
@ -8,5 +8,10 @@ dependencies = [
|
|||
"pkce==1.0.3",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=8.0.0",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
package = false
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import socket
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
|
@ -44,7 +45,12 @@ CHROME_FLAGS = [
|
|||
"--disable-search-engine-choice-screen",
|
||||
]
|
||||
|
||||
DEFAULT_CDP_PORT = 9222
|
||||
|
||||
def _allocate_free_port() -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
return int(s.getsockname()[1])
|
||||
|
||||
|
||||
def _fetch_ws_endpoint(port: int) -> str | None:
|
||||
|
|
@ -79,10 +85,9 @@ class ManagedBrowser:
|
|||
shutil.rmtree(self.profile_dir, ignore_errors=True)
|
||||
|
||||
|
||||
async def launch(
|
||||
playwright: Playwright, cdp_port: int = DEFAULT_CDP_PORT
|
||||
) -> ManagedBrowser:
|
||||
async def launch(playwright: Playwright) -> ManagedBrowser:
|
||||
chrome_path = playwright.chromium.executable_path
|
||||
cdp_port = _allocate_free_port()
|
||||
profile_dir = Path(tempfile.mkdtemp(prefix="megapt_profile-", dir="/tmp"))
|
||||
|
||||
args = [
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import asyncio
|
|||
import logging
|
||||
import re
|
||||
|
||||
from playwright.async_api import BrowserContext, Page
|
||||
from playwright.async_api import BrowserContext, Error as PlaywrightError, Page
|
||||
|
||||
from .base import BaseProvider
|
||||
|
||||
|
|
@ -44,7 +44,7 @@ class TempMailOrgProvider(BaseProvider):
|
|||
value,
|
||||
)
|
||||
return value
|
||||
except Exception:
|
||||
except PlaywrightError:
|
||||
continue
|
||||
|
||||
try:
|
||||
|
|
@ -53,8 +53,8 @@ class TempMailOrgProvider(BaseProvider):
|
|||
if found:
|
||||
logger.info("[temp-mail.org] email found by body scan: %s", found)
|
||||
return found
|
||||
except Exception:
|
||||
pass
|
||||
except PlaywrightError:
|
||||
logger.debug("Failed to scan body text for email")
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
|
@ -76,7 +76,7 @@ class TempMailOrgProvider(BaseProvider):
|
|||
try:
|
||||
count = await items.count()
|
||||
logger.info("[temp-mail.org] inbox items: %s", count)
|
||||
except Exception:
|
||||
except PlaywrightError:
|
||||
count = 0
|
||||
|
||||
if count > 0:
|
||||
|
|
@ -87,30 +87,30 @@ class TempMailOrgProvider(BaseProvider):
|
|||
continue
|
||||
text = (await item.inner_text()).strip().replace("\n", " ")
|
||||
logger.info("[temp-mail.org] item[%s]: %s", idx, text[:160])
|
||||
except Exception:
|
||||
except PlaywrightError:
|
||||
continue
|
||||
if text:
|
||||
try:
|
||||
await item.click()
|
||||
logger.info("[temp-mail.org] opened item[%s]", idx)
|
||||
except Exception:
|
||||
pass
|
||||
except PlaywrightError:
|
||||
logger.debug("Failed to open inbox item[%s]", idx)
|
||||
|
||||
message_text = text
|
||||
try:
|
||||
content = await page.content()
|
||||
if content and "Your ChatGPT code is" in content:
|
||||
message_text = content
|
||||
except Exception:
|
||||
pass
|
||||
except PlaywrightError:
|
||||
logger.debug("Failed to read opened message content")
|
||||
|
||||
try:
|
||||
await page.go_back(
|
||||
wait_until="domcontentloaded", timeout=5000
|
||||
)
|
||||
logger.info("[temp-mail.org] returned back to inbox")
|
||||
except Exception:
|
||||
pass
|
||||
except PlaywrightError:
|
||||
logger.debug("Failed to return back to inbox")
|
||||
|
||||
return message_text
|
||||
|
||||
|
|
|
|||
|
|
@ -52,3 +52,23 @@ class Provider(ABC):
|
|||
def save_tokens(self, tokens: ProviderTokens) -> None:
|
||||
"""Save tokens to storage"""
|
||||
pass
|
||||
|
||||
async def force_recreate_token(self) -> str | None:
|
||||
"""Force-create a new active token when normal acquisition fails."""
|
||||
return None
|
||||
|
||||
async def maybe_rotate_account(self, usage_percent: int) -> bool:
|
||||
"""Rotate active account/token if provider policy requires it."""
|
||||
return False
|
||||
|
||||
async def ensure_standby_account(
|
||||
self,
|
||||
usage_percent: int,
|
||||
prepare_threshold: int,
|
||||
) -> None:
|
||||
"""Prepare standby account/token asynchronously when needed."""
|
||||
return None
|
||||
|
||||
async def startup_prepare(self) -> None:
|
||||
"""Optional provider-specific startup preparation."""
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from .tokens import (
|
|||
load_tokens,
|
||||
promote_next_tokens,
|
||||
refresh_tokens,
|
||||
save_next_tokens,
|
||||
save_state,
|
||||
save_tokens,
|
||||
)
|
||||
from .usage import get_usage_data
|
||||
|
|
@ -44,10 +44,14 @@ class ChatGPTProvider(Provider):
|
|||
attempt,
|
||||
CHATGPT_REGISTRATION_MAX_ATTEMPTS,
|
||||
)
|
||||
success = await self.register_new_account()
|
||||
if success:
|
||||
generated_tokens = await register_chatgpt_account(
|
||||
email_provider_factory=self.email_provider_factory,
|
||||
)
|
||||
if generated_tokens:
|
||||
save_tokens(generated_tokens)
|
||||
return True
|
||||
logger.warning("Registration attempt %s failed", attempt)
|
||||
await asyncio.sleep(1.5 * attempt)
|
||||
return False
|
||||
|
||||
async def _create_next_account_under_lock(self) -> bool:
|
||||
|
|
@ -56,23 +60,28 @@ class ChatGPTProvider(Provider):
|
|||
return True
|
||||
|
||||
logger.info("Creating next account")
|
||||
success = await self._register_with_retries()
|
||||
if not success:
|
||||
return False
|
||||
|
||||
generated_active = load_tokens()
|
||||
if not generated_active:
|
||||
return False
|
||||
|
||||
# Registration writes new tokens as active; restore old active and keep
|
||||
# generated account as next.
|
||||
for attempt in range(1, CHATGPT_REGISTRATION_MAX_ATTEMPTS + 1):
|
||||
logger.info(
|
||||
"Next-account registration attempt %s/%s",
|
||||
attempt,
|
||||
CHATGPT_REGISTRATION_MAX_ATTEMPTS,
|
||||
)
|
||||
generated_tokens = await register_chatgpt_account(
|
||||
email_provider_factory=self.email_provider_factory,
|
||||
)
|
||||
if generated_tokens:
|
||||
if active_before:
|
||||
save_tokens(active_before)
|
||||
save_state(active_before, generated_tokens)
|
||||
else:
|
||||
clear_next_tokens()
|
||||
save_next_tokens(generated_active)
|
||||
save_state(generated_tokens, None)
|
||||
logger.info("Next account is ready")
|
||||
return True
|
||||
logger.warning("Next-account registration attempt %s failed", attempt)
|
||||
await asyncio.sleep(1.5 * attempt)
|
||||
|
||||
if active_before or next_before:
|
||||
save_state(active_before, next_before)
|
||||
return False
|
||||
|
||||
async def force_recreate_token(self) -> str | None:
|
||||
async with self._token_write_lock:
|
||||
|
|
@ -85,6 +94,9 @@ class ChatGPTProvider(Provider):
|
|||
return None
|
||||
return tokens.access_token
|
||||
|
||||
async def startup_prepare(self) -> None:
|
||||
await self.ensure_next_account()
|
||||
|
||||
async def ensure_next_account(self) -> bool:
|
||||
next_tokens = load_next_tokens()
|
||||
if next_tokens and not next_tokens.is_expired:
|
||||
|
|
@ -96,6 +108,14 @@ class ChatGPTProvider(Provider):
|
|||
return True
|
||||
return await self._create_next_account_under_lock()
|
||||
|
||||
async def ensure_standby_account(
|
||||
self,
|
||||
usage_percent: int,
|
||||
prepare_threshold: int,
|
||||
) -> None:
|
||||
if usage_percent >= prepare_threshold:
|
||||
await self.ensure_next_account()
|
||||
|
||||
async def maybe_switch_active_account(self, usage_percent: int) -> bool:
|
||||
if usage_percent < CHATGPT_SWITCH_THRESHOLD:
|
||||
return False
|
||||
|
|
@ -119,6 +139,9 @@ class ChatGPTProvider(Provider):
|
|||
)
|
||||
return switched
|
||||
|
||||
async def maybe_rotate_account(self, usage_percent: int) -> bool:
|
||||
return await self.maybe_switch_active_account(usage_percent)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "chatgpt"
|
||||
|
|
@ -154,13 +177,17 @@ class ChatGPTProvider(Provider):
|
|||
|
||||
async def register_new_account(self) -> bool:
|
||||
"""Register a new ChatGPT account"""
|
||||
return await register_chatgpt_account(
|
||||
generated_tokens = await register_chatgpt_account(
|
||||
email_provider_factory=self.email_provider_factory,
|
||||
)
|
||||
if not generated_tokens:
|
||||
return False
|
||||
save_tokens(generated_tokens)
|
||||
return True
|
||||
|
||||
async def get_usage_info(self, access_token: str) -> dict[str, Any]:
|
||||
"""Get usage information for the current token"""
|
||||
usage_data = get_usage_data(access_token)
|
||||
usage_data = await get_usage_data(access_token)
|
||||
if not usage_data:
|
||||
return {"error": "Failed to get usage"}
|
||||
|
||||
|
|
|
|||
|
|
@ -14,12 +14,17 @@ from typing import Callable
|
|||
from urllib.parse import parse_qs, urlencode, urlparse
|
||||
|
||||
import aiohttp
|
||||
from playwright.async_api import async_playwright, Page, BrowserContext
|
||||
from playwright.async_api import (
|
||||
async_playwright,
|
||||
Error as PlaywrightError,
|
||||
Page,
|
||||
BrowserContext,
|
||||
)
|
||||
|
||||
from browser import launch as launch_browser
|
||||
from email_providers import BaseProvider
|
||||
from providers.base import ProviderTokens
|
||||
from .tokens import CLIENT_ID, save_tokens
|
||||
from .tokens import CLIENT_ID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -46,9 +51,9 @@ async def save_error_screenshot(page: Page | None, step: str):
|
|||
filename = screenshots_dir / f"error_{step}_{timestamp}.png"
|
||||
try:
|
||||
await page.screenshot(path=str(filename))
|
||||
logger.error(f"Screenshot saved: {filename}")
|
||||
except:
|
||||
pass
|
||||
logger.error("Screenshot saved: %s", filename)
|
||||
except PlaywrightError as e:
|
||||
logger.warning("Failed to save screenshot at step %s: %s", step, e)
|
||||
|
||||
|
||||
def generate_password(length: int = 20) -> str:
|
||||
|
|
@ -204,8 +209,7 @@ def generate_state() -> str:
|
|||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
def build_authorize_url(verifier: str, challenge: str, state: str) -> str:
|
||||
del verifier
|
||||
def build_authorize_url(challenge: str, state: str) -> str:
|
||||
params = {
|
||||
"response_type": "code",
|
||||
"client_id": CLIENT_ID,
|
||||
|
|
@ -222,7 +226,6 @@ def build_authorize_url(verifier: str, challenge: str, state: str) -> str:
|
|||
|
||||
|
||||
async def exchange_code_for_tokens(code: str, verifier: str) -> ProviderTokens:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
payload = {
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": CLIENT_ID,
|
||||
|
|
@ -230,18 +233,26 @@ async def exchange_code_for_tokens(code: str, verifier: str) -> ProviderTokens:
|
|||
"code_verifier": verifier,
|
||||
"redirect_uri": REDIRECT_URI,
|
||||
}
|
||||
timeout = aiohttp.ClientTimeout(total=20)
|
||||
try:
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.post(TOKEN_URL, data=payload) as resp:
|
||||
if not resp.ok:
|
||||
text = await resp.text()
|
||||
raise RuntimeError(f"Token exchange failed: {resp.status} {text}")
|
||||
body = await resp.json()
|
||||
except (aiohttp.ClientError, TimeoutError) as e:
|
||||
raise RuntimeError(f"Token exchange request error: {e}") from e
|
||||
|
||||
try:
|
||||
expires_in = int(body["expires_in"])
|
||||
return ProviderTokens(
|
||||
access_token=body["access_token"],
|
||||
refresh_token=body["refresh_token"],
|
||||
expires_at=time.time() + expires_in,
|
||||
)
|
||||
except (KeyError, TypeError, ValueError) as e:
|
||||
raise RuntimeError(f"Token exchange response parse error: {e}") from e
|
||||
|
||||
|
||||
async def get_latest_code(email_provider: BaseProvider, email: str) -> str | None:
|
||||
|
|
@ -270,6 +281,24 @@ async def click_continue(page: Page, timeout_ms: int = 10000):
|
|||
await btn.click()
|
||||
|
||||
|
||||
async def click_any_visible_button(
|
||||
page: Page,
|
||||
labels: list[str],
|
||||
timeout_ms: int = 2000,
|
||||
) -> bool:
|
||||
for label in labels:
|
||||
button = page.get_by_role("button", name=label)
|
||||
if await button.count() == 0:
|
||||
continue
|
||||
try:
|
||||
await button.first.wait_for(state="visible", timeout=timeout_ms)
|
||||
await button.first.click(timeout=timeout_ms)
|
||||
return True
|
||||
except PlaywrightError:
|
||||
continue
|
||||
return False
|
||||
|
||||
|
||||
async def wait_for_signup_stabilization(
|
||||
page: Page,
|
||||
source_url: str,
|
||||
|
|
@ -288,12 +317,12 @@ async def wait_for_signup_stabilization(
|
|||
|
||||
async def register_chatgpt_account(
|
||||
email_provider_factory: Callable[[BrowserContext], BaseProvider] | None = None,
|
||||
) -> bool:
|
||||
) -> ProviderTokens | None:
|
||||
logger.info("=== Starting ChatGPT account registration ===")
|
||||
|
||||
if email_provider_factory is None:
|
||||
logger.error("No email provider factory configured")
|
||||
return False
|
||||
return None
|
||||
|
||||
birth_month, birth_day, birth_year = generate_birthdate_90s()
|
||||
|
||||
|
|
@ -321,7 +350,7 @@ async def register_chatgpt_account(
|
|||
full_name = generate_name()
|
||||
verifier, challenge = generate_pkce_pair()
|
||||
oauth_state = generate_state()
|
||||
authorize_url = build_authorize_url(verifier, challenge, oauth_state)
|
||||
authorize_url = build_authorize_url(challenge, oauth_state)
|
||||
|
||||
logger.info("[2/5] Registering ChatGPT for %s", email)
|
||||
chatgpt_page = await context.new_page()
|
||||
|
|
@ -352,19 +381,18 @@ async def register_chatgpt_account(
|
|||
raise AutomationError(
|
||||
"email_provider", "Email provider returned no verification message"
|
||||
)
|
||||
logger.info("[3/5] Verification code extracted: %s", code)
|
||||
logger.info("[3/5] Verification code extracted")
|
||||
|
||||
await chatgpt_page.bring_to_front()
|
||||
code_input = chatgpt_page.get_by_placeholder("Code")
|
||||
if await code_input.count() > 0:
|
||||
await code_input.fill(code)
|
||||
await code_input.first.wait_for(state="visible", timeout=10000)
|
||||
await code_input.first.fill(code)
|
||||
await click_continue(chatgpt_page)
|
||||
|
||||
logger.info("[4/5] Setting profile...")
|
||||
name_input = chatgpt_page.get_by_placeholder("Full name")
|
||||
await name_input.first.wait_for(state="visible", timeout=20000)
|
||||
if await name_input.count() > 0:
|
||||
await name_input.fill(full_name)
|
||||
await name_input.first.fill(full_name)
|
||||
|
||||
await fill_date_field(chatgpt_page, birth_month, birth_day, birth_year)
|
||||
profile_url = chatgpt_page.url
|
||||
|
|
@ -387,45 +415,42 @@ async def register_chatgpt_account(
|
|||
oauth_page.on("request", handle_request)
|
||||
|
||||
await oauth_page.goto(authorize_url, wait_until="domcontentloaded")
|
||||
await oauth_page.locator(
|
||||
'input[type="email"], input[name="email"]'
|
||||
).first.wait_for(state="visible", timeout=20000)
|
||||
|
||||
email_input = oauth_page.locator('input[type="email"], input[name="email"]')
|
||||
if await email_input.count() > 0:
|
||||
await email_input.first.wait_for(state="visible", timeout=10000)
|
||||
await email_input.first.fill(email)
|
||||
|
||||
continue_button = oauth_page.get_by_role("button", name="Continue")
|
||||
if await continue_button.count() > 0:
|
||||
await continue_button.first.click()
|
||||
await oauth_page.locator('input[type="password"]').first.wait_for(
|
||||
state="visible", timeout=20000
|
||||
await click_any_visible_button(
|
||||
oauth_page, ["Continue"], timeout_ms=4000
|
||||
)
|
||||
|
||||
password_input = oauth_page.locator('input[type="password"]')
|
||||
if await password_input.count() > 0:
|
||||
await password_input.first.wait_for(state="visible", timeout=10000)
|
||||
await password_input.first.fill(password)
|
||||
continue_button = oauth_page.get_by_role("button", name="Continue")
|
||||
if await continue_button.count() > 0:
|
||||
await continue_button.first.click()
|
||||
await click_any_visible_button(
|
||||
oauth_page, ["Continue"], timeout_ms=4000
|
||||
)
|
||||
|
||||
for label in ["Continue", "Allow", "Authorize"]:
|
||||
button = oauth_page.get_by_role("button", name=label)
|
||||
if await button.count() > 0:
|
||||
try:
|
||||
await button.first.click(timeout=5000)
|
||||
await oauth_page.wait_for_timeout(500)
|
||||
except Exception:
|
||||
pass
|
||||
for _ in range(6):
|
||||
if redirect_url_captured:
|
||||
break
|
||||
clicked = await click_any_visible_button(
|
||||
oauth_page,
|
||||
["Continue", "Allow", "Authorize"],
|
||||
timeout_ms=2000,
|
||||
)
|
||||
if clicked:
|
||||
await asyncio.sleep(0.4)
|
||||
else:
|
||||
await asyncio.sleep(0.4)
|
||||
|
||||
if not redirect_url_captured:
|
||||
try:
|
||||
await oauth_page.wait_for_timeout(4000)
|
||||
current_url = oauth_page.url
|
||||
if "localhost:1455" in current_url and "code=" in current_url:
|
||||
redirect_url_captured = current_url
|
||||
logger.info("Captured OAuth redirect from page URL")
|
||||
except Exception:
|
||||
except PlaywrightError:
|
||||
pass
|
||||
|
||||
if not redirect_url_captured:
|
||||
|
|
@ -446,20 +471,18 @@ async def register_chatgpt_account(
|
|||
raise AutomationError("oauth", "OAuth state mismatch", oauth_page)
|
||||
|
||||
tokens = await exchange_code_for_tokens(auth_code, verifier)
|
||||
save_tokens(tokens)
|
||||
logger.info("OAuth tokens saved successfully")
|
||||
logger.info("OAuth tokens fetched successfully")
|
||||
|
||||
return True
|
||||
return tokens
|
||||
|
||||
except AutomationError as e:
|
||||
logger.error(f"Error at step [{e.step}]: {e.message}")
|
||||
await save_error_screenshot(e.page, e.step)
|
||||
return False
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error: {e}")
|
||||
await save_error_screenshot(current_page, "unexpected")
|
||||
return False
|
||||
return None
|
||||
finally:
|
||||
if managed:
|
||||
await asyncio.sleep(2)
|
||||
await managed.close()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
|
@ -35,7 +36,7 @@ def _dict_to_tokens(data: dict[str, Any] | None) -> ProviderTokens | None:
|
|||
refresh_token=data["refresh_token"],
|
||||
expires_at=data["expires_at"],
|
||||
)
|
||||
except KeyError, TypeError:
|
||||
except (KeyError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
|
|
@ -54,8 +55,20 @@ def _load_raw() -> dict[str, Any] | None:
|
|||
|
||||
def _save_raw(data: dict[str, Any]) -> None:
|
||||
TOKENS_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(TOKENS_FILE, "w") as f:
|
||||
fd, tmp_path = tempfile.mkstemp(
|
||||
prefix=f"{TOKENS_FILE.name}.",
|
||||
suffix=".tmp",
|
||||
dir=str(TOKENS_FILE.parent),
|
||||
)
|
||||
try:
|
||||
with os.fdopen(fd, "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
f.flush()
|
||||
os.fsync(f.fileno())
|
||||
os.replace(tmp_path, TOKENS_FILE)
|
||||
finally:
|
||||
if os.path.exists(tmp_path):
|
||||
os.unlink(tmp_path)
|
||||
|
||||
|
||||
def _normalize_state(data: dict[str, Any] | None) -> dict[str, Any]:
|
||||
|
|
@ -68,7 +81,6 @@ def _normalize_state(data: dict[str, Any] | None) -> dict[str, Any]:
|
|||
"next_account": data.get("next_account"),
|
||||
}
|
||||
|
||||
# Backward compatibility with old flat schema
|
||||
return {"active": data, "next_account": None}
|
||||
|
||||
|
||||
|
|
@ -79,9 +91,7 @@ def load_state() -> tuple[ProviderTokens | None, ProviderTokens | None]:
|
|||
return active, next_account
|
||||
|
||||
|
||||
def save_state(
|
||||
active: ProviderTokens | None, next_account: ProviderTokens | None
|
||||
) -> None:
|
||||
def save_state(active: ProviderTokens | None, next_account: ProviderTokens | None) -> None:
|
||||
payload = {
|
||||
"active": _tokens_to_dict(active) if active else None,
|
||||
"next_account": _tokens_to_dict(next_account) if next_account else None,
|
||||
|
|
@ -104,13 +114,8 @@ def save_tokens(tokens: ProviderTokens):
|
|||
save_state(tokens, next_account)
|
||||
|
||||
|
||||
def save_next_tokens(tokens: ProviderTokens):
|
||||
active, _ = load_state()
|
||||
save_state(active, tokens)
|
||||
|
||||
|
||||
def promote_next_tokens() -> bool:
|
||||
active, next_account = load_state()
|
||||
_, next_account = load_state()
|
||||
if not next_account:
|
||||
return False
|
||||
save_state(next_account, None)
|
||||
|
|
@ -123,42 +128,34 @@ def clear_next_tokens():
|
|||
|
||||
|
||||
async def refresh_tokens(refresh_token: str) -> ProviderTokens | None:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
data = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"client_id": CLIENT_ID,
|
||||
}
|
||||
timeout = aiohttp.ClientTimeout(total=15)
|
||||
try:
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.post(TOKEN_URL, data=data) as resp:
|
||||
if not resp.ok:
|
||||
text = await resp.text()
|
||||
logger.warning("Token refresh failed: %s %s", resp.status, text)
|
||||
return None
|
||||
json_resp = await resp.json()
|
||||
expires_in = json_resp["expires_in"]
|
||||
except (aiohttp.ClientError, TimeoutError) as e:
|
||||
logger.warning("Token refresh request error: %s", e)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning("Token refresh unexpected error: %s", e)
|
||||
return None
|
||||
|
||||
try:
|
||||
expires_in = int(json_resp["expires_in"])
|
||||
return ProviderTokens(
|
||||
access_token=json_resp["access_token"],
|
||||
refresh_token=json_resp["refresh_token"],
|
||||
expires_at=time.time() + expires_in,
|
||||
)
|
||||
|
||||
|
||||
async def get_valid_tokens() -> ProviderTokens | None:
|
||||
tokens = load_tokens()
|
||||
if not tokens:
|
||||
logger.info("No tokens found")
|
||||
except (KeyError, TypeError, ValueError) as e:
|
||||
logger.warning("Token refresh response parse error: %s", e)
|
||||
return None
|
||||
|
||||
if tokens.is_expired:
|
||||
logger.info("Token expired, refreshing...")
|
||||
if not tokens.refresh_token:
|
||||
logger.info("No refresh token available")
|
||||
return None
|
||||
new_tokens = await refresh_tokens(tokens.refresh_token)
|
||||
if not new_tokens:
|
||||
logger.warning("Failed to refresh token")
|
||||
return None
|
||||
save_tokens(new_tokens)
|
||||
return new_tokens
|
||||
|
||||
return tokens
|
||||
|
|
|
|||
|
|
@ -1,14 +1,15 @@
|
|||
import json
|
||||
import socket
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def clamp_percent(value: Any) -> int:
|
||||
try:
|
||||
num = float(value)
|
||||
except TypeError, ValueError:
|
||||
except (TypeError, ValueError):
|
||||
return 0
|
||||
if num < 0:
|
||||
return 0
|
||||
|
|
@ -28,30 +29,36 @@ def _parse_window(window: dict[str, Any] | None) -> dict[str, int] | None:
|
|||
}
|
||||
|
||||
|
||||
def get_usage_data(access_token: str, timeout_ms: int = 10000) -> dict[str, Any] | None:
|
||||
async def get_usage_data(
|
||||
access_token: str,
|
||||
timeout_ms: int = 10000,
|
||||
) -> dict[str, Any] | None:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"User-Agent": "CodexProxy",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
req = urllib.request.Request(
|
||||
"https://chatgpt.com/backend-api/wham/usage",
|
||||
headers=headers,
|
||||
method="GET",
|
||||
timeout = aiohttp.ClientTimeout(total=timeout_ms / 1000)
|
||||
url = "https://chatgpt.com/backend-api/wham/usage"
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.get(url, headers=headers) as res:
|
||||
if not res.ok:
|
||||
body = await res.text()
|
||||
logger.warning(
|
||||
"Usage fetch failed: status=%s body=%s",
|
||||
res.status,
|
||||
body[:300],
|
||||
)
|
||||
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=timeout_ms / 1000) as res:
|
||||
body = res.read().decode("utf-8", errors="replace")
|
||||
except urllib.error.HTTPError:
|
||||
return None
|
||||
except urllib.error.URLError, socket.timeout:
|
||||
data = await res.json()
|
||||
except (aiohttp.ClientError, TimeoutError) as e:
|
||||
logger.warning("Usage fetch request error: %s", e)
|
||||
return None
|
||||
|
||||
try:
|
||||
data = json.loads(body)
|
||||
except json.JSONDecodeError:
|
||||
except Exception as e:
|
||||
logger.warning("Usage fetch unexpected error: %s", e)
|
||||
return None
|
||||
|
||||
rate_limit = data.get("rate_limit") or {}
|
||||
|
|
@ -76,10 +83,3 @@ def get_usage_data(access_token: str, timeout_ms: int = 10000) -> dict[str, Any]
|
|||
"limit_reached": bool(rate_limit.get("limit_reached")),
|
||||
"allowed": bool(rate_limit.get("allowed", True)),
|
||||
}
|
||||
|
||||
|
||||
def get_usage_percent(access_token: str, timeout_ms: int = 10000) -> int:
|
||||
data = get_usage_data(access_token, timeout_ms=timeout_ms)
|
||||
if not data:
|
||||
return -1
|
||||
return int(data["used_percent"])
|
||||
|
|
|
|||
199
src/server.py
199
src/server.py
|
|
@ -5,23 +5,44 @@ import os
|
|||
from aiohttp import web
|
||||
|
||||
from providers.chatgpt import ChatGPTProvider
|
||||
|
||||
PORT = int(os.environ.get("PORT", "8080"))
|
||||
CHATGPT_PREPARE_THRESHOLD = int(os.environ.get("CHATGPT_PREPARE_THRESHOLD", "85"))
|
||||
LIMIT_EXHAUSTED_PERCENT = 100
|
||||
from providers.chatgpt.provider import CHATGPT_SWITCH_THRESHOLD
|
||||
from providers.base import Provider
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Registry of available providers
|
||||
PROVIDERS = {
|
||||
|
||||
def _parse_int_env(name: str, default: int, minimum: int, maximum: int) -> int:
|
||||
raw = os.environ.get(name)
|
||||
if raw is None:
|
||||
return default
|
||||
try:
|
||||
value = int(raw)
|
||||
except ValueError:
|
||||
logger.warning("Invalid %s=%r, using default %s", name, raw, default)
|
||||
return default
|
||||
if value < minimum or value > maximum:
|
||||
logger.warning(
|
||||
"%s=%s out of range [%s,%s], using default %s",
|
||||
name,
|
||||
value,
|
||||
minimum,
|
||||
maximum,
|
||||
default,
|
||||
)
|
||||
return default
|
||||
return value
|
||||
|
||||
|
||||
PORT = _parse_int_env("PORT", 8080, 1, 65535)
|
||||
CHATGPT_PREPARE_THRESHOLD = _parse_int_env("CHATGPT_PREPARE_THRESHOLD", 85, 0, 100)
|
||||
LIMIT_EXHAUSTED_PERCENT = 100
|
||||
|
||||
PROVIDERS: dict[str, Provider] = {
|
||||
"chatgpt": ChatGPTProvider(),
|
||||
}
|
||||
|
||||
refresh_locks = {name: asyncio.Lock() for name in PROVIDERS.keys()}
|
||||
background_refresh_tasks: dict[str, asyncio.Task | None] = {
|
||||
name: None for name in PROVIDERS.keys()
|
||||
}
|
||||
background_tasks: dict[str, asyncio.Task | None] = {name: None for name in PROVIDERS}
|
||||
|
||||
|
||||
@web.middleware
|
||||
|
|
@ -31,16 +52,22 @@ async def request_log_middleware(request: web.Request, handler):
|
|||
return response
|
||||
|
||||
|
||||
def build_limit(usage_percent: int) -> dict[str, int | bool]:
|
||||
def build_limit(usage_percent: int, prepare_threshold: int) -> dict[str, int | bool]:
|
||||
remaining = max(0, 100 - usage_percent)
|
||||
return {
|
||||
"used_percent": usage_percent,
|
||||
"remaining_percent": remaining,
|
||||
"exhausted": usage_percent >= LIMIT_EXHAUSTED_PERCENT,
|
||||
"needs_refresh": usage_percent >= CHATGPT_PREPARE_THRESHOLD,
|
||||
"needs_prepare": usage_percent >= prepare_threshold,
|
||||
}
|
||||
|
||||
|
||||
def get_prepare_threshold(provider_name: str) -> int:
|
||||
if provider_name == "chatgpt":
|
||||
return CHATGPT_PREPARE_THRESHOLD
|
||||
return 100
|
||||
|
||||
|
||||
async def ensure_provider_token_ready(provider_name: str):
|
||||
provider = PROVIDERS.get(provider_name)
|
||||
if not provider:
|
||||
|
|
@ -52,116 +79,68 @@ async def ensure_provider_token_ready(provider_name: str):
|
|||
logger.warning(
|
||||
"[%s] Startup token check failed, forcing recreation", provider_name
|
||||
)
|
||||
if isinstance(provider, ChatGPTProvider):
|
||||
token = await provider.force_recreate_token()
|
||||
|
||||
if not token:
|
||||
logger.error("[%s] Could not prepare token at startup", provider_name)
|
||||
return
|
||||
|
||||
if isinstance(provider, ChatGPTProvider):
|
||||
await provider.ensure_next_account()
|
||||
|
||||
usage_info = await provider.get_usage_info(token)
|
||||
if "error" not in usage_info:
|
||||
logger.info("[%s] Startup token is ready", provider_name)
|
||||
return
|
||||
|
||||
if "error" in usage_info:
|
||||
logger.warning(
|
||||
"[%s] Startup token invalid for usage, forcing recreation", provider_name
|
||||
)
|
||||
if isinstance(provider, ChatGPTProvider):
|
||||
token = await provider.force_recreate_token()
|
||||
if token:
|
||||
logger.info("[%s] Startup token recreated successfully", provider_name)
|
||||
if not token:
|
||||
logger.error("[%s] Startup token recreation failed", provider_name)
|
||||
return
|
||||
|
||||
logger.error("[%s] Startup token recreation failed", provider_name)
|
||||
await provider.startup_prepare()
|
||||
logger.info("[%s] Startup token is ready", provider_name)
|
||||
|
||||
|
||||
async def on_startup(app: web.Application):
|
||||
del app
|
||||
for provider_name in PROVIDERS.keys():
|
||||
await ensure_provider_token_ready(provider_name)
|
||||
|
||||
|
||||
async def issue_new_token(provider_name: str) -> str | None:
|
||||
async def ensure_standby_task(provider_name: str, usage_percent: int, reason: str):
|
||||
provider = PROVIDERS.get(provider_name)
|
||||
if not provider:
|
||||
return None
|
||||
|
||||
async with refresh_locks[provider_name]:
|
||||
logger.info(f"[{provider_name}] Generating new token")
|
||||
success = await provider.register_new_account()
|
||||
if not success:
|
||||
logger.error(f"[{provider_name}] Token generation failed")
|
||||
return None
|
||||
|
||||
token = await provider.get_token()
|
||||
if not token:
|
||||
logger.error(f"[{provider_name}] Token was generated but not available")
|
||||
return None
|
||||
|
||||
return token
|
||||
|
||||
|
||||
async def background_refresh_worker(provider_name: str, reason: str):
|
||||
return
|
||||
try:
|
||||
logger.info(f"[{provider_name}] Starting background token refresh ({reason})")
|
||||
new_token = await issue_new_token(provider_name)
|
||||
if new_token:
|
||||
logger.info(f"[{provider_name}] Background token refresh completed")
|
||||
else:
|
||||
logger.error(f"[{provider_name}] Background token refresh failed")
|
||||
logger.info("[%s] Preparing standby in background (%s)", provider_name, reason)
|
||||
threshold = get_prepare_threshold(provider_name)
|
||||
await provider.ensure_standby_account(usage_percent, threshold)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"[{provider_name}] Unhandled error in background token refresh"
|
||||
)
|
||||
logger.exception("[%s] Unhandled standby preparation error", provider_name)
|
||||
|
||||
|
||||
def trigger_background_refresh(provider_name: str, reason: str):
|
||||
task = background_refresh_tasks.get(provider_name)
|
||||
def trigger_standby_prepare(provider_name: str, usage_percent: int, reason: str):
|
||||
task = background_tasks.get(provider_name)
|
||||
if task and not task.done():
|
||||
logger.info(
|
||||
f"[{provider_name}] Background refresh already running, skip ({reason})"
|
||||
"[%s] Standby prep already running, skip (%s)", provider_name, reason
|
||||
)
|
||||
return
|
||||
background_refresh_tasks[provider_name] = asyncio.create_task(
|
||||
background_refresh_worker(provider_name, reason)
|
||||
background_tasks[provider_name] = asyncio.create_task(
|
||||
ensure_standby_task(provider_name, usage_percent, reason)
|
||||
)
|
||||
|
||||
|
||||
async def token_handler(request: web.Request) -> web.Response:
|
||||
provider_name = request.match_info.get("provider", "chatgpt")
|
||||
|
||||
provider = PROVIDERS.get(provider_name)
|
||||
if not provider:
|
||||
return web.json_response(
|
||||
{"error": f"Unknown provider: {provider_name}"},
|
||||
status=404,
|
||||
{"error": f"Unknown provider: {provider_name}"}, status=404
|
||||
)
|
||||
|
||||
# Get or create token
|
||||
token = await provider.get_token()
|
||||
if not token:
|
||||
return web.json_response(
|
||||
{"error": "Failed to get active token"},
|
||||
status=503,
|
||||
)
|
||||
return web.json_response({"error": "Failed to get active token"}, status=503)
|
||||
|
||||
# Get usage info
|
||||
usage_info = await provider.get_usage_info(token)
|
||||
if "error" in usage_info:
|
||||
return web.json_response(
|
||||
{"error": usage_info["error"]},
|
||||
status=503,
|
||||
)
|
||||
return web.json_response({"error": usage_info["error"]}, status=503)
|
||||
|
||||
usage_percent = usage_info.get("used_percent", 0)
|
||||
remaining_percent = usage_info.get("remaining_percent", max(0, 100 - usage_percent))
|
||||
|
||||
if isinstance(provider, ChatGPTProvider):
|
||||
switched = await provider.maybe_switch_active_account(usage_percent)
|
||||
usage_percent = int(usage_info.get("used_percent", 0))
|
||||
switched = await provider.maybe_rotate_account(usage_percent)
|
||||
if switched:
|
||||
token = await provider.get_token()
|
||||
if not token:
|
||||
|
|
@ -171,16 +150,21 @@ async def token_handler(request: web.Request) -> web.Response:
|
|||
)
|
||||
usage_info = await provider.get_usage_info(token)
|
||||
if "error" in usage_info:
|
||||
return web.json_response(
|
||||
{"error": usage_info["error"]},
|
||||
status=503,
|
||||
)
|
||||
usage_percent = usage_info.get("used_percent", 0)
|
||||
remaining_percent = usage_info.get(
|
||||
"remaining_percent", max(0, 100 - usage_percent)
|
||||
)
|
||||
return web.json_response({"error": usage_info["error"]}, status=503)
|
||||
usage_percent = int(usage_info.get("used_percent", 0))
|
||||
logger.info("[%s] Active account switched before response", provider_name)
|
||||
|
||||
prepare_threshold = get_prepare_threshold(provider_name)
|
||||
if usage_percent >= prepare_threshold:
|
||||
trigger_standby_prepare(
|
||||
provider_name,
|
||||
usage_percent,
|
||||
f"usage {usage_percent}% >= threshold {prepare_threshold}%",
|
||||
)
|
||||
|
||||
remaining_percent = int(
|
||||
usage_info.get("remaining_percent", max(0, 100 - usage_percent))
|
||||
)
|
||||
logger.info(
|
||||
"[%s] token issued, used=%s%% remaining=%s%%",
|
||||
provider_name,
|
||||
|
|
@ -207,20 +191,10 @@ async def token_handler(request: web.Request) -> web.Response:
|
|||
secondary_window.get("reset_after_seconds", 0),
|
||||
)
|
||||
|
||||
# Trigger background refresh if needed
|
||||
if usage_percent >= CHATGPT_PREPARE_THRESHOLD:
|
||||
if isinstance(provider, ChatGPTProvider):
|
||||
await provider.ensure_next_account()
|
||||
else:
|
||||
trigger_background_refresh(
|
||||
provider_name,
|
||||
f"usage {usage_percent}% >= threshold {CHATGPT_PREPARE_THRESHOLD}%",
|
||||
)
|
||||
|
||||
return web.json_response(
|
||||
{
|
||||
"token": token,
|
||||
"limit": build_limit(usage_percent),
|
||||
"limit": build_limit(usage_percent, prepare_threshold),
|
||||
"usage": {
|
||||
"primary_window": primary_window,
|
||||
"secondary_window": secondary_window,
|
||||
|
|
@ -229,22 +203,35 @@ async def token_handler(request: web.Request) -> web.Response:
|
|||
)
|
||||
|
||||
|
||||
async def on_startup(app: web.Application):
|
||||
del app
|
||||
for provider_name in PROVIDERS:
|
||||
await ensure_provider_token_ready(provider_name)
|
||||
|
||||
|
||||
async def on_cleanup(app: web.Application):
|
||||
del app
|
||||
for task in background_tasks.values():
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
pending = [t for t in background_tasks.values() if t is not None]
|
||||
if pending:
|
||||
await asyncio.gather(*pending, return_exceptions=True)
|
||||
|
||||
|
||||
def create_app() -> web.Application:
|
||||
app = web.Application(middlewares=[request_log_middleware])
|
||||
app.on_startup.append(on_startup)
|
||||
# New route: /{provider}/token
|
||||
app.on_cleanup.append(on_cleanup)
|
||||
app.router.add_get("/{provider}/token", token_handler)
|
||||
# Legacy route for backward compatibility
|
||||
app.router.add_get("/token", token_handler)
|
||||
return app
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("Starting token service on port %s", PORT)
|
||||
logger.info(
|
||||
"ChatGPT prepare-next threshold: %s%%",
|
||||
CHATGPT_PREPARE_THRESHOLD,
|
||||
)
|
||||
logger.info("ChatGPT prepare-next threshold: %s%%", CHATGPT_PREPARE_THRESHOLD)
|
||||
logger.info("ChatGPT switch threshold: %s%%", CHATGPT_SWITCH_THRESHOLD)
|
||||
logger.info("Available providers: %s", ", ".join(PROVIDERS.keys()))
|
||||
app = create_app()
|
||||
web.run_app(app, host="0.0.0.0", port=PORT)
|
||||
|
|
|
|||
12
tests/conftest.py
Normal file
12
tests/conftest.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _add_src_to_path() -> None:
|
||||
root = Path(__file__).resolve().parents[1]
|
||||
src = root / "src"
|
||||
if str(src) not in sys.path:
|
||||
sys.path.insert(0, str(src))
|
||||
|
||||
|
||||
_add_src_to_path()
|
||||
37
tests/test_registration_unit.py
Normal file
37
tests/test_registration_unit.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
from providers.chatgpt.registration import (
|
||||
build_authorize_url,
|
||||
extract_verification_code,
|
||||
generate_birthdate_90s,
|
||||
generate_name,
|
||||
)
|
||||
|
||||
|
||||
def test_generate_name_shape():
|
||||
name = generate_name()
|
||||
parts = name.split(" ")
|
||||
assert len(parts) == 2
|
||||
assert all(p.isalpha() for p in parts)
|
||||
|
||||
|
||||
def test_generate_birthdate_90s_range():
|
||||
month, day, year = generate_birthdate_90s()
|
||||
assert 1 <= int(month) <= 12
|
||||
assert 1 <= int(day) <= 28
|
||||
assert 1990 <= int(year) <= 1999
|
||||
|
||||
|
||||
def test_extract_verification_code_prefers_chatgpt_phrase():
|
||||
text = "foo 123456 bar Your ChatGPT code is 654321"
|
||||
assert extract_verification_code(text) == "654321"
|
||||
|
||||
|
||||
def test_extract_verification_code_fallback_last_code():
|
||||
text = "codes 111111 and 222222"
|
||||
assert extract_verification_code(text) == "222222"
|
||||
|
||||
|
||||
def test_build_authorize_url_contains_required_params():
|
||||
url = build_authorize_url("challenge", "state123")
|
||||
assert "response_type=code" in url
|
||||
assert "code_challenge=challenge" in url
|
||||
assert "state=state123" in url
|
||||
150
tests/test_server_unit.py
Normal file
150
tests/test_server_unit.py
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
import asyncio
|
||||
import json
|
||||
|
||||
import server
|
||||
from providers.base import Provider, ProviderTokens
|
||||
|
||||
|
||||
class FakeRequest:
|
||||
def __init__(self, provider: str):
|
||||
self.match_info = {"provider": provider}
|
||||
|
||||
|
||||
class FakeProvider(Provider):
|
||||
def __init__(
|
||||
self,
|
||||
token: str | None = "tok",
|
||||
usage: dict | None = None,
|
||||
rotate: bool = False,
|
||||
):
|
||||
self._token = token
|
||||
self._usage = usage or {
|
||||
"used_percent": 10,
|
||||
"remaining_percent": 90,
|
||||
"primary_window": None,
|
||||
"secondary_window": None,
|
||||
}
|
||||
self._rotate = rotate
|
||||
self.get_token_calls = 0
|
||||
self.standby_calls = 0
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "fake"
|
||||
|
||||
async def get_token(self) -> str | None:
|
||||
self.get_token_calls += 1
|
||||
return self._token
|
||||
|
||||
async def register_new_account(self) -> bool:
|
||||
return True
|
||||
|
||||
async def get_usage_info(self, access_token: str) -> dict:
|
||||
_ = access_token
|
||||
return dict(self._usage)
|
||||
|
||||
def load_tokens(self) -> ProviderTokens | None:
|
||||
return None
|
||||
|
||||
def save_tokens(self, tokens: ProviderTokens) -> None:
|
||||
_ = tokens
|
||||
|
||||
async def maybe_rotate_account(self, usage_percent: int) -> bool:
|
||||
_ = usage_percent
|
||||
return self._rotate
|
||||
|
||||
async def ensure_standby_account(
|
||||
self, usage_percent: int, prepare_threshold: int
|
||||
) -> None:
|
||||
_ = usage_percent, prepare_threshold
|
||||
self.standby_calls += 1
|
||||
|
||||
|
||||
def _response_json(resp) -> dict:
|
||||
return json.loads(resp.body.decode("utf-8"))
|
||||
|
||||
|
||||
def test_parse_int_env_defaults(monkeypatch):
|
||||
monkeypatch.delenv("X_TEST", raising=False)
|
||||
assert server._parse_int_env("X_TEST", 10, 1, 20) == 10
|
||||
|
||||
|
||||
def test_parse_int_env_invalid(monkeypatch):
|
||||
monkeypatch.setenv("X_TEST", "abc")
|
||||
assert server._parse_int_env("X_TEST", 10, 1, 20) == 10
|
||||
|
||||
|
||||
def test_parse_int_env_out_of_range(monkeypatch):
|
||||
monkeypatch.setenv("X_TEST", "999")
|
||||
assert server._parse_int_env("X_TEST", 10, 1, 20) == 10
|
||||
|
||||
|
||||
def test_build_limit_fields():
|
||||
limit = server.build_limit(90, 85)
|
||||
assert limit == {
|
||||
"used_percent": 90,
|
||||
"remaining_percent": 10,
|
||||
"exhausted": False,
|
||||
"needs_prepare": True,
|
||||
}
|
||||
|
||||
|
||||
def test_get_prepare_threshold():
|
||||
assert server.get_prepare_threshold("chatgpt") == server.CHATGPT_PREPARE_THRESHOLD
|
||||
assert server.get_prepare_threshold("unknown") == 100
|
||||
|
||||
|
||||
def test_token_handler_unknown_provider(monkeypatch):
|
||||
monkeypatch.setattr(server, "PROVIDERS", {})
|
||||
resp = asyncio.run(server.token_handler(FakeRequest("missing")))
|
||||
assert resp.status == 404
|
||||
|
||||
|
||||
def test_token_handler_success(monkeypatch):
|
||||
provider = FakeProvider()
|
||||
monkeypatch.setattr(server, "PROVIDERS", {"fake": provider})
|
||||
monkeypatch.setattr(server, "background_tasks", {"fake": None})
|
||||
monkeypatch.setattr(server, "get_prepare_threshold", lambda _: 80)
|
||||
|
||||
resp = asyncio.run(server.token_handler(FakeRequest("fake")))
|
||||
data = _response_json(resp)
|
||||
|
||||
assert resp.status == 200
|
||||
assert data["token"] == "tok"
|
||||
assert data["limit"]["needs_prepare"] is False
|
||||
|
||||
|
||||
def test_token_handler_triggers_standby(monkeypatch):
|
||||
provider = FakeProvider(usage={"used_percent": 90, "remaining_percent": 10})
|
||||
monkeypatch.setattr(server, "PROVIDERS", {"fake": provider})
|
||||
monkeypatch.setattr(server, "background_tasks", {"fake": None})
|
||||
|
||||
called = {"value": False}
|
||||
|
||||
def fake_trigger(name, usage_percent, reason):
|
||||
assert name == "fake"
|
||||
assert usage_percent == 90
|
||||
assert "threshold" in reason
|
||||
called["value"] = True
|
||||
|
||||
monkeypatch.setattr(server, "get_prepare_threshold", lambda _: 80)
|
||||
monkeypatch.setattr(server, "trigger_standby_prepare", fake_trigger)
|
||||
|
||||
resp = asyncio.run(server.token_handler(FakeRequest("fake")))
|
||||
assert resp.status == 200
|
||||
assert called["value"] is True
|
||||
|
||||
|
||||
def test_token_handler_rotation_path(monkeypatch):
|
||||
provider = FakeProvider(
|
||||
usage={"used_percent": 96, "remaining_percent": 4},
|
||||
rotate=True,
|
||||
)
|
||||
monkeypatch.setattr(server, "PROVIDERS", {"fake": provider})
|
||||
monkeypatch.setattr(server, "background_tasks", {"fake": None})
|
||||
monkeypatch.setattr(server, "get_prepare_threshold", lambda _: 80)
|
||||
monkeypatch.setattr(server, "trigger_standby_prepare", lambda *_: None)
|
||||
|
||||
resp = asyncio.run(server.token_handler(FakeRequest("fake")))
|
||||
assert resp.status == 200
|
||||
assert provider.get_token_calls >= 2
|
||||
60
tests/test_tokens_unit.py
Normal file
60
tests/test_tokens_unit.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from providers.base import ProviderTokens
|
||||
from providers.chatgpt import tokens as t
|
||||
|
||||
|
||||
def test_normalize_state_backward_compatible():
|
||||
raw = {"access_token": "a", "refresh_token": "r", "expires_at": 1}
|
||||
normalized = t._normalize_state(raw)
|
||||
assert normalized["active"]["access_token"] == "a"
|
||||
assert normalized["next_account"] is None
|
||||
|
||||
|
||||
def test_promote_next_tokens(tmp_path, monkeypatch):
|
||||
file_path = tmp_path / "chatgpt_tokens.json"
|
||||
monkeypatch.setattr(t, "TOKENS_FILE", file_path)
|
||||
|
||||
active = ProviderTokens("a1", "r1", 100)
|
||||
nxt = ProviderTokens("a2", "r2", 200)
|
||||
t.save_state(active, nxt)
|
||||
|
||||
assert t.promote_next_tokens() is True
|
||||
cur, next_cur = t.load_state()
|
||||
assert cur is not None
|
||||
assert cur.access_token == "a2"
|
||||
assert next_cur is None
|
||||
|
||||
|
||||
def test_save_tokens_preserves_next(tmp_path, monkeypatch):
|
||||
file_path = tmp_path / "chatgpt_tokens.json"
|
||||
monkeypatch.setattr(t, "TOKENS_FILE", file_path)
|
||||
|
||||
active = ProviderTokens("a1", "r1", 100)
|
||||
nxt = ProviderTokens("a2", "r2", 200)
|
||||
t.save_state(active, nxt)
|
||||
|
||||
t.save_tokens(ProviderTokens("a3", "r3", 300))
|
||||
cur, next_cur = t.load_state()
|
||||
assert cur is not None and cur.access_token == "a3"
|
||||
assert next_cur is not None and next_cur.access_token == "a2"
|
||||
|
||||
|
||||
def test_atomic_write_produces_valid_json(tmp_path, monkeypatch):
|
||||
file_path = tmp_path / "chatgpt_tokens.json"
|
||||
monkeypatch.setattr(t, "TOKENS_FILE", file_path)
|
||||
|
||||
t.save_state(ProviderTokens("x", "y", 123), None)
|
||||
with open(file_path) as f:
|
||||
data = json.load(f)
|
||||
assert "active" in data
|
||||
assert data["active"]["access_token"] == "x"
|
||||
|
||||
|
||||
def test_load_state_from_missing_file(tmp_path, monkeypatch):
|
||||
file_path = tmp_path / "missing.json"
|
||||
monkeypatch.setattr(t, "TOKENS_FILE", file_path)
|
||||
active, nxt = t.load_state()
|
||||
assert active is None
|
||||
assert nxt is None
|
||||
32
tests/test_usage_unit.py
Normal file
32
tests/test_usage_unit.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
from providers.chatgpt.usage import _parse_window, clamp_percent
|
||||
|
||||
|
||||
def test_clamp_percent_bounds():
|
||||
assert clamp_percent(-1) == 0
|
||||
assert clamp_percent(150) == 100
|
||||
assert clamp_percent(49.6) == 50
|
||||
|
||||
|
||||
def test_clamp_percent_invalid():
|
||||
assert clamp_percent(None) == 0
|
||||
assert clamp_percent("bad") == 0
|
||||
|
||||
|
||||
def test_parse_window_valid():
|
||||
window = {
|
||||
"used_percent": 34.4,
|
||||
"limit_window_seconds": 3600,
|
||||
"reset_after_seconds": 120,
|
||||
"reset_at": 999,
|
||||
}
|
||||
parsed = _parse_window(window)
|
||||
assert parsed == {
|
||||
"used_percent": 34,
|
||||
"limit_window_seconds": 3600,
|
||||
"reset_after_seconds": 120,
|
||||
"reset_at": 999,
|
||||
}
|
||||
|
||||
|
||||
def test_parse_window_none():
|
||||
assert _parse_window(None) is None
|
||||
Loading…
Add table
Add a link
Reference in a new issue