Compare commits
5 commits
ccd4d82194
...
320460f7f4
| Author | SHA1 | Date | |
|---|---|---|---|
| 320460f7f4 | |||
| 858d127246 | |||
| 307ca38ecc | |||
| 8b5449b1fd | |||
| d6396e4050 |
26 changed files with 1332 additions and 333 deletions
|
|
@ -1,8 +1,11 @@
|
||||||
# HTTP server port
|
# HTTP server port
|
||||||
PORT=80
|
PORT=80
|
||||||
|
|
||||||
# Trigger background token refresh when usage reaches threshold percent
|
# Prepare next ChatGPT account when active usage reaches threshold percent
|
||||||
USAGE_REFRESH_THRESHOLD=85
|
CHATGPT_PREPARE_THRESHOLD=85
|
||||||
|
|
||||||
|
# Switch active ChatGPT account when usage reaches threshold percent
|
||||||
|
CHATGPT_SWITCH_THRESHOLD=95
|
||||||
|
|
||||||
# Persistent data directory (tokens, screenshots)
|
# Persistent data directory (tokens, screenshots)
|
||||||
DATA_DIR=/data
|
DATA_DIR=/data
|
||||||
|
|
|
||||||
106
README.md
106
README.md
|
|
@ -1,29 +1,22 @@
|
||||||
# megapt
|
# megapt
|
||||||
|
|
||||||
HTTP service that returns an active ChatGPT access token.
|
Service for issuing ChatGPT OAuth tokens via browser automation with disposable email.
|
||||||
|
|
||||||
The service can:
|
|
||||||
- restore/refresh a saved token from `/data`
|
|
||||||
- auto-register a new ChatGPT account when needed
|
|
||||||
- get verification email from a disposable mail provider (`temp-mail.org`)
|
|
||||||
- expose token and usage info via HTTP endpoint
|
|
||||||
|
|
||||||
|
|
||||||
## Endpoints
|
## Endpoints
|
||||||
|
|
||||||
- `GET /token` - legacy route (defaults to `chatgpt` provider)
|
- `GET /chatgpt/token`
|
||||||
- `GET /chatgpt/token` - explicit provider route
|
- `GET /token` (legacy alias, same as chatgpt)
|
||||||
|
|
||||||
Example response:
|
Response shape:
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"token": "<access_token>",
|
"token": "...",
|
||||||
"limit": {
|
"limit": {
|
||||||
"used_percent": 0,
|
"used_percent": 0,
|
||||||
"remaining_percent": 100,
|
"remaining_percent": 100,
|
||||||
"exhausted": false,
|
"exhausted": false,
|
||||||
"needs_refresh": false
|
"needs_prepare": false
|
||||||
},
|
},
|
||||||
"usage": {
|
"usage": {
|
||||||
"primary_window": {
|
"primary_window": {
|
||||||
|
|
@ -37,67 +30,58 @@ Example response:
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Environment Variables
|
||||||
|
|
||||||
## Environment variables
|
- `PORT` - HTTP server port (default: `8080`)
|
||||||
|
- `DATA_DIR` - persistent data directory for tokens/screenshots (default: `./data`)
|
||||||
|
- `CHATGPT_PREPARE_THRESHOLD` - usage threshold to prepare `next_account` (default: `85`)
|
||||||
|
- `CHATGPT_SWITCH_THRESHOLD` - usage threshold to switch active account to `next_account` (default: `95`)
|
||||||
|
|
||||||
See `.env.example`.
|
Example config is in `.env.example`.
|
||||||
|
|
||||||
- `PORT` - HTTP port for the service
|
## Token Lifecycle
|
||||||
- `USAGE_REFRESH_THRESHOLD` - percent threshold to trigger background token rotation
|
|
||||||
- `DATA_DIR` - directory for persistent data (`chatgpt_tokens.json`, screenshots, etc.)
|
|
||||||
|
|
||||||
|
- **active account** - currently served token.
|
||||||
|
- **next account** - pre-created account/token stored for fast switch.
|
||||||
|
|
||||||
## Local run
|
Behavior:
|
||||||
|
|
||||||
Requirements:
|
1. If active token is valid, service returns it immediately.
|
||||||
- Python 3.14+
|
2. If active token is expired, service tries refresh under a single write lock.
|
||||||
- Playwright Chromium dependencies
|
3. If refresh fails or token is missing, service registers a new account (up to 4 attempts).
|
||||||
|
4. When usage reaches `CHATGPT_PREPARE_THRESHOLD`, service prepares `next_account`.
|
||||||
|
5. When usage reaches `CHATGPT_SWITCH_THRESHOLD`, service switches active account to `next_account`.
|
||||||
|
|
||||||
Install and run:
|
## Disposable Email Provider
|
||||||
|
|
||||||
|
- Default provider is `mail.tm` API (`MailTmProvider`) and does not use browser automation.
|
||||||
|
- Flow: fetch domains -> create account with random address/password -> get JWT token -> poll messages.
|
||||||
|
|
||||||
|
## Startup Behavior
|
||||||
|
|
||||||
|
On startup, service ensures active token exists and is usable.
|
||||||
|
Standby preparation runs through provider lifecycle hooks/background trigger when needed.
|
||||||
|
|
||||||
|
## Data Files
|
||||||
|
|
||||||
|
- `DATA_DIR/chatgpt_tokens.json` - token state with `active` and `next_account`.
|
||||||
|
- `DATA_DIR/screenshots/` - automation failure screenshots.
|
||||||
|
|
||||||
|
## Run Locally
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uv sync --frozen --no-dev
|
PYTHONPATH=./src python src/server.py
|
||||||
./.venv/bin/python -m playwright install --with-deps chromium
|
|
||||||
PYTHONPATH=./src ./.venv/bin/python src/server.py
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Then request token:
|
## Unit Tests
|
||||||
|
|
||||||
|
The project has unit tests only (no integration/network tests).
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
curl http://127.0.0.1:8080/chatgpt/token
|
pytest -q
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Docker Notes
|
||||||
|
|
||||||
## Docker deployment
|
- Dockerfile sets `DATA_DIR=/data`.
|
||||||
|
- `entrypoint.sh` starts Xvfb and runs `server.py`.
|
||||||
Build image:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
docker build -t megapt:latest .
|
|
||||||
```
|
|
||||||
|
|
||||||
Run container:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
docker run -d \
|
|
||||||
--name megapt \
|
|
||||||
--restart unless-stopped \
|
|
||||||
--env-file .env \
|
|
||||||
-v ./data:/data \
|
|
||||||
-p 80:80 \
|
|
||||||
megapt:latest
|
|
||||||
```
|
|
||||||
|
|
||||||
Check logs:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
docker logs -f megapt
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
## Notes
|
|
||||||
|
|
||||||
- Service performs a startup token check and tries to recover token automatically.
|
|
||||||
- Token write path is synchronized (single-writer lock) to avoid parallel re-registration.
|
|
||||||
- Browser runs in virtual display (`Xvfb`) inside container.
|
|
||||||
- Keep `/data` persistent between restarts.
|
|
||||||
|
|
|
||||||
|
|
@ -8,5 +8,10 @@ dependencies = [
|
||||||
"pkce==1.0.3",
|
"pkce==1.0.3",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
"pytest>=8.0.0",
|
||||||
|
]
|
||||||
|
|
||||||
[tool.uv]
|
[tool.uv]
|
||||||
package = false
|
package = false
|
||||||
|
|
|
||||||
55
scripts/run_token_refresh_flow.py
Normal file
55
scripts/run_token_refresh_flow.py
Normal file
|
|
@ -0,0 +1,55 @@
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import browser
|
||||||
|
import server
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeRequest:
|
||||||
|
def __init__(self, provider: str):
|
||||||
|
self.match_info = {"provider": provider}
|
||||||
|
self.method = "GET"
|
||||||
|
self.path_qs = f"/{provider}/token"
|
||||||
|
|
||||||
|
|
||||||
|
def _enable_headed_browser() -> bool:
|
||||||
|
if "--no-startup-window" in browser.CHROME_FLAGS:
|
||||||
|
browser.CHROME_FLAGS.remove("--no-startup-window")
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def _run(provider: str) -> int:
|
||||||
|
patched = _enable_headed_browser()
|
||||||
|
logging.info("Headed mode patch applied: %s", patched)
|
||||||
|
|
||||||
|
request = _FakeRequest(provider)
|
||||||
|
response = await server.token_handler(request)
|
||||||
|
payload = json.loads(response.body.decode("utf-8"))
|
||||||
|
|
||||||
|
logging.info("Response status: %s", response.status)
|
||||||
|
logging.info("Response body: %s", json.dumps(payload, indent=2))
|
||||||
|
return 0 if response.status == 200 else 1
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description=(
|
||||||
|
"Run the same token refresh/issue flow as server /{provider}/token "
|
||||||
|
"in headed browser mode (non-headless)."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
parser.add_argument("--provider", default="chatgpt")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
|
)
|
||||||
|
return asyncio.run(_run(args.provider))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise SystemExit(main())
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import socket
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
@ -44,7 +45,12 @@ CHROME_FLAGS = [
|
||||||
"--disable-search-engine-choice-screen",
|
"--disable-search-engine-choice-screen",
|
||||||
]
|
]
|
||||||
|
|
||||||
DEFAULT_CDP_PORT = 9222
|
|
||||||
|
def _allocate_free_port() -> int:
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
|
s.bind(("127.0.0.1", 0))
|
||||||
|
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||||
|
return int(s.getsockname()[1])
|
||||||
|
|
||||||
|
|
||||||
def _fetch_ws_endpoint(port: int) -> str | None:
|
def _fetch_ws_endpoint(port: int) -> str | None:
|
||||||
|
|
@ -79,10 +85,9 @@ class ManagedBrowser:
|
||||||
shutil.rmtree(self.profile_dir, ignore_errors=True)
|
shutil.rmtree(self.profile_dir, ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
async def launch(
|
async def launch(playwright: Playwright) -> ManagedBrowser:
|
||||||
playwright: Playwright, cdp_port: int = DEFAULT_CDP_PORT
|
|
||||||
) -> ManagedBrowser:
|
|
||||||
chrome_path = playwright.chromium.executable_path
|
chrome_path = playwright.chromium.executable_path
|
||||||
|
cdp_port = _allocate_free_port()
|
||||||
profile_dir = Path(tempfile.mkdtemp(prefix="megapt_profile-", dir="/tmp"))
|
profile_dir = Path(tempfile.mkdtemp(prefix="megapt_profile-", dir="/tmp"))
|
||||||
|
|
||||||
args = [
|
args = [
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,11 @@
|
||||||
from .base import BaseProvider
|
from .base import BaseProvider
|
||||||
|
from .mail_tm import MailTmProvider
|
||||||
from .ten_minute_mail import TenMinuteMailProvider
|
from .ten_minute_mail import TenMinuteMailProvider
|
||||||
from .temp_mail_org import TempMailOrgProvider
|
from .temp_mail_org import TempMailOrgProvider
|
||||||
|
|
||||||
__all__ = ["BaseProvider", "TenMinuteMailProvider", "TempMailOrgProvider"]
|
__all__ = [
|
||||||
|
"BaseProvider",
|
||||||
|
"MailTmProvider",
|
||||||
|
"TenMinuteMailProvider",
|
||||||
|
"TempMailOrgProvider",
|
||||||
|
]
|
||||||
|
|
|
||||||
|
|
@ -12,5 +12,5 @@ class BaseProvider(ABC):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_latest_message(self, email: str) -> str | None:
|
async def get_latest_message(self) -> str | None:
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
227
src/email_providers/mail_tm.py
Normal file
227
src/email_providers/mail_tm.py
Normal file
|
|
@ -0,0 +1,227 @@
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import secrets
|
||||||
|
import string
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from playwright.async_api import BrowserContext
|
||||||
|
|
||||||
|
from .base import BaseProvider
|
||||||
|
from utils.randoms import generate_password
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_API_BASE = os.environ.get("MAIL_TM_API_BASE", "https://api.mail.tm")
|
||||||
|
_TIMEOUT_SECONDS = 20
|
||||||
|
_FIRST_NAMES = [
|
||||||
|
"james",
|
||||||
|
"john",
|
||||||
|
"robert",
|
||||||
|
"michael",
|
||||||
|
"david",
|
||||||
|
"william",
|
||||||
|
"joseph",
|
||||||
|
"thomas",
|
||||||
|
"daniel",
|
||||||
|
"mark",
|
||||||
|
"paul",
|
||||||
|
"kevin",
|
||||||
|
]
|
||||||
|
_LAST_NAMES = [
|
||||||
|
"smith",
|
||||||
|
"johnson",
|
||||||
|
"williams",
|
||||||
|
"brown",
|
||||||
|
"jones",
|
||||||
|
"miller",
|
||||||
|
"davis",
|
||||||
|
"wilson",
|
||||||
|
"anderson",
|
||||||
|
"taylor",
|
||||||
|
"martin",
|
||||||
|
"thompson",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_local_part() -> str:
|
||||||
|
first = secrets.choice(_FIRST_NAMES)
|
||||||
|
last = secrets.choice(_LAST_NAMES)
|
||||||
|
digits = "".join(secrets.choice(string.digits) for _ in range(8))
|
||||||
|
return f"{first}{last}{digits}"
|
||||||
|
|
||||||
|
|
||||||
|
class MailTmProvider(BaseProvider):
|
||||||
|
def __init__(self, browser_session: BrowserContext):
|
||||||
|
super().__init__(browser_session)
|
||||||
|
self._address: str | None = None
|
||||||
|
self._password: str | None = None
|
||||||
|
self._token: str | None = None
|
||||||
|
|
||||||
|
async def _request(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
path: str,
|
||||||
|
*,
|
||||||
|
token: str | None = None,
|
||||||
|
json_body: dict[str, Any] | None = None,
|
||||||
|
) -> tuple[int, dict[str, Any] | list[Any] | None]:
|
||||||
|
url = f"{_API_BASE.rstrip('/')}{path}"
|
||||||
|
headers: dict[str, str] = {}
|
||||||
|
if token:
|
||||||
|
headers["Authorization"] = f"Bearer {token}"
|
||||||
|
|
||||||
|
timeout = aiohttp.ClientTimeout(total=_TIMEOUT_SECONDS)
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||||
|
async with session.request(
|
||||||
|
method,
|
||||||
|
url,
|
||||||
|
headers=headers,
|
||||||
|
json=json_body,
|
||||||
|
) as resp:
|
||||||
|
status = resp.status
|
||||||
|
try:
|
||||||
|
payload = await resp.json()
|
||||||
|
except aiohttp.ContentTypeError:
|
||||||
|
payload = None
|
||||||
|
return status, payload
|
||||||
|
except aiohttp.ClientError as e:
|
||||||
|
logger.warning("[mail.tm] request failed %s %s: %s", method, path, e)
|
||||||
|
return 0, None
|
||||||
|
|
||||||
|
async def _get_domains(self) -> list[str]:
|
||||||
|
status, payload = await self._request("GET", "/domains")
|
||||||
|
if status != 200 or not isinstance(payload, dict):
|
||||||
|
raise RuntimeError("mail.tm domains request failed")
|
||||||
|
|
||||||
|
members = payload.get("hydra:member")
|
||||||
|
if not isinstance(members, list):
|
||||||
|
raise RuntimeError("mail.tm domains response has unexpected format")
|
||||||
|
|
||||||
|
domains: list[str] = []
|
||||||
|
for item in members:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
domain = item.get("domain")
|
||||||
|
is_active = bool(item.get("isActive", True))
|
||||||
|
if isinstance(domain, str) and domain and is_active:
|
||||||
|
domains.append(domain)
|
||||||
|
|
||||||
|
if not domains:
|
||||||
|
raise RuntimeError("mail.tm returned no active domains")
|
||||||
|
return domains
|
||||||
|
|
||||||
|
async def _create_account(self, address: str, password: str) -> bool:
|
||||||
|
status, _ = await self._request(
|
||||||
|
"POST",
|
||||||
|
"/accounts",
|
||||||
|
json_body={"address": address, "password": password},
|
||||||
|
)
|
||||||
|
if status in (200, 201):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _create_token(self, address: str, password: str) -> str | None:
|
||||||
|
status, payload = await self._request(
|
||||||
|
"POST",
|
||||||
|
"/token",
|
||||||
|
json_body={"address": address, "password": password},
|
||||||
|
)
|
||||||
|
if status != 200 or not isinstance(payload, dict):
|
||||||
|
return None
|
||||||
|
token = payload.get("token")
|
||||||
|
if isinstance(token, str) and token:
|
||||||
|
return token
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_new_email(self) -> str:
|
||||||
|
domains = await self._get_domains()
|
||||||
|
|
||||||
|
for _ in range(8):
|
||||||
|
domain = secrets.choice(domains)
|
||||||
|
address = f"{_generate_local_part()}@{domain}"
|
||||||
|
password = generate_password(length=24)
|
||||||
|
|
||||||
|
created = await self._create_account(address, password)
|
||||||
|
if not created:
|
||||||
|
continue
|
||||||
|
|
||||||
|
token = await self._create_token(address, password)
|
||||||
|
if not token:
|
||||||
|
continue
|
||||||
|
|
||||||
|
self._address = address
|
||||||
|
self._password = password
|
||||||
|
self._token = token
|
||||||
|
logger.info("[mail.tm] New mailbox acquired: %s", address)
|
||||||
|
return address
|
||||||
|
|
||||||
|
raise RuntimeError("mail.tm could not create account")
|
||||||
|
|
||||||
|
async def _list_messages(self) -> list[dict[str, Any]]:
|
||||||
|
if not self._token:
|
||||||
|
return []
|
||||||
|
status, payload = await self._request(
|
||||||
|
"GET",
|
||||||
|
"/messages",
|
||||||
|
token=self._token,
|
||||||
|
)
|
||||||
|
if status == 401 and self._address and self._password:
|
||||||
|
token = await self._create_token(self._address, self._password)
|
||||||
|
if token:
|
||||||
|
self._token = token
|
||||||
|
status, payload = await self._request(
|
||||||
|
"GET",
|
||||||
|
"/messages",
|
||||||
|
token=self._token,
|
||||||
|
)
|
||||||
|
|
||||||
|
if status != 200 or not isinstance(payload, dict):
|
||||||
|
return []
|
||||||
|
|
||||||
|
members = payload.get("hydra:member")
|
||||||
|
if not isinstance(members, list):
|
||||||
|
return []
|
||||||
|
return [item for item in members if isinstance(item, dict)]
|
||||||
|
|
||||||
|
async def _get_message_text(self, message_id: str) -> str | None:
|
||||||
|
if not self._token:
|
||||||
|
return None
|
||||||
|
status, payload = await self._request(
|
||||||
|
"GET",
|
||||||
|
f"/messages/{message_id}",
|
||||||
|
token=self._token,
|
||||||
|
)
|
||||||
|
if status != 200 or not isinstance(payload, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
parts = [
|
||||||
|
payload.get("subject"),
|
||||||
|
payload.get("intro"),
|
||||||
|
payload.get("text"),
|
||||||
|
payload.get("html"),
|
||||||
|
]
|
||||||
|
text = "\n".join(str(part) for part in parts if part)
|
||||||
|
return text or None
|
||||||
|
|
||||||
|
async def get_latest_message(self) -> str | None:
|
||||||
|
if not self._token:
|
||||||
|
raise RuntimeError("mail.tm provider is not initialized with mailbox token")
|
||||||
|
|
||||||
|
for _ in range(45):
|
||||||
|
messages = await self._list_messages()
|
||||||
|
if messages:
|
||||||
|
latest = messages[0]
|
||||||
|
message_id = latest.get("id")
|
||||||
|
if isinstance(message_id, str) and message_id:
|
||||||
|
full_message = await self._get_message_text(message_id)
|
||||||
|
if full_message:
|
||||||
|
logger.info("[mail.tm] Latest message received")
|
||||||
|
return full_message
|
||||||
|
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
|
logger.warning("[mail.tm] No messages received within timeout")
|
||||||
|
return None
|
||||||
|
|
@ -2,9 +2,10 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from playwright.async_api import BrowserContext, Page
|
from playwright.async_api import BrowserContext, Error as PlaywrightError, Page
|
||||||
|
|
||||||
from .base import BaseProvider
|
from .base import BaseProvider
|
||||||
|
from .utils import ensure_page
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -15,8 +16,7 @@ class TempMailOrgProvider(BaseProvider):
|
||||||
self.page: Page | None = None
|
self.page: Page | None = None
|
||||||
|
|
||||||
async def _ensure_page(self) -> Page:
|
async def _ensure_page(self) -> Page:
|
||||||
if self.page is None or self.page.is_closed():
|
self.page = await ensure_page(self.browser_session, self.page)
|
||||||
self.page = await self.browser_session.new_page()
|
|
||||||
return self.page
|
return self.page
|
||||||
|
|
||||||
async def get_new_email(self) -> str:
|
async def get_new_email(self) -> str:
|
||||||
|
|
@ -44,7 +44,7 @@ class TempMailOrgProvider(BaseProvider):
|
||||||
value,
|
value,
|
||||||
)
|
)
|
||||||
return value
|
return value
|
||||||
except Exception:
|
except PlaywrightError:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -53,16 +53,16 @@ class TempMailOrgProvider(BaseProvider):
|
||||||
if found:
|
if found:
|
||||||
logger.info("[temp-mail.org] email found by body scan: %s", found)
|
logger.info("[temp-mail.org] email found by body scan: %s", found)
|
||||||
return found
|
return found
|
||||||
except Exception:
|
except PlaywrightError:
|
||||||
pass
|
logger.debug("Failed to scan body text for email")
|
||||||
|
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
raise RuntimeError("Could not get temp email from temp-mail.org")
|
raise RuntimeError("Could not get temp email from temp-mail.org")
|
||||||
|
|
||||||
async def get_latest_message(self, email: str) -> str | None:
|
async def get_latest_message(self) -> str | None:
|
||||||
page = await self._ensure_page()
|
page = await self._ensure_page()
|
||||||
logger.info("[temp-mail.org] Waiting for latest message for %s", email)
|
logger.info("[temp-mail.org] Waiting for latest message")
|
||||||
|
|
||||||
if page.is_closed():
|
if page.is_closed():
|
||||||
raise RuntimeError("temp-mail.org tab was closed unexpectedly")
|
raise RuntimeError("temp-mail.org tab was closed unexpectedly")
|
||||||
|
|
@ -76,7 +76,7 @@ class TempMailOrgProvider(BaseProvider):
|
||||||
try:
|
try:
|
||||||
count = await items.count()
|
count = await items.count()
|
||||||
logger.info("[temp-mail.org] inbox items: %s", count)
|
logger.info("[temp-mail.org] inbox items: %s", count)
|
||||||
except Exception:
|
except PlaywrightError:
|
||||||
count = 0
|
count = 0
|
||||||
|
|
||||||
if count > 0:
|
if count > 0:
|
||||||
|
|
@ -87,30 +87,30 @@ class TempMailOrgProvider(BaseProvider):
|
||||||
continue
|
continue
|
||||||
text = (await item.inner_text()).strip().replace("\n", " ")
|
text = (await item.inner_text()).strip().replace("\n", " ")
|
||||||
logger.info("[temp-mail.org] item[%s]: %s", idx, text[:160])
|
logger.info("[temp-mail.org] item[%s]: %s", idx, text[:160])
|
||||||
except Exception:
|
except PlaywrightError:
|
||||||
continue
|
continue
|
||||||
if text:
|
if text:
|
||||||
try:
|
try:
|
||||||
await item.click()
|
await item.click()
|
||||||
logger.info("[temp-mail.org] opened item[%s]", idx)
|
logger.info("[temp-mail.org] opened item[%s]", idx)
|
||||||
except Exception:
|
except PlaywrightError:
|
||||||
pass
|
logger.debug("Failed to open inbox item[%s]", idx)
|
||||||
|
|
||||||
message_text = text
|
message_text = text
|
||||||
try:
|
try:
|
||||||
content = await page.content()
|
content = await page.content()
|
||||||
if content and "Your ChatGPT code is" in content:
|
if content and "Your ChatGPT code is" in content:
|
||||||
message_text = content
|
message_text = content
|
||||||
except Exception:
|
except PlaywrightError:
|
||||||
pass
|
logger.debug("Failed to read opened message content")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await page.go_back(
|
await page.go_back(
|
||||||
wait_until="domcontentloaded", timeout=5000
|
wait_until="domcontentloaded", timeout=5000
|
||||||
)
|
)
|
||||||
logger.info("[temp-mail.org] returned back to inbox")
|
logger.info("[temp-mail.org] returned back to inbox")
|
||||||
except Exception:
|
except PlaywrightError:
|
||||||
pass
|
logger.debug("Failed to return back to inbox")
|
||||||
|
|
||||||
return message_text
|
return message_text
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ import logging
|
||||||
from playwright.async_api import BrowserContext, Page
|
from playwright.async_api import BrowserContext, Page
|
||||||
|
|
||||||
from .base import BaseProvider
|
from .base import BaseProvider
|
||||||
|
from .utils import ensure_page
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -14,8 +15,7 @@ class TenMinuteMailProvider(BaseProvider):
|
||||||
self.page: Page | None = None
|
self.page: Page | None = None
|
||||||
|
|
||||||
async def _ensure_page(self) -> Page:
|
async def _ensure_page(self) -> Page:
|
||||||
if self.page is None or self.page.is_closed():
|
self.page = await ensure_page(self.browser_session, self.page)
|
||||||
self.page = await self.browser_session.new_page()
|
|
||||||
return self.page
|
return self.page
|
||||||
|
|
||||||
async def get_new_email(self) -> str:
|
async def get_new_email(self) -> str:
|
||||||
|
|
@ -34,9 +34,9 @@ class TenMinuteMailProvider(BaseProvider):
|
||||||
logger.info("[10min] New email acquired: %s", email)
|
logger.info("[10min] New email acquired: %s", email)
|
||||||
return email
|
return email
|
||||||
|
|
||||||
async def get_latest_message(self, email: str) -> str | None:
|
async def get_latest_message(self) -> str | None:
|
||||||
page = await self._ensure_page()
|
page = await self._ensure_page()
|
||||||
logger.info("[10min] Waiting for latest message for %s", email)
|
logger.info("[10min] Waiting for latest message")
|
||||||
|
|
||||||
seen_count = 0
|
seen_count = 0
|
||||||
for attempt in range(60):
|
for attempt in range(60):
|
||||||
|
|
|
||||||
10
src/email_providers/utils.py
Normal file
10
src/email_providers/utils.py
Normal file
|
|
@ -0,0 +1,10 @@
|
||||||
|
from playwright.async_api import BrowserContext, Page
|
||||||
|
|
||||||
|
|
||||||
|
async def ensure_page(
|
||||||
|
browser_session: BrowserContext,
|
||||||
|
page: Page | None,
|
||||||
|
) -> Page:
|
||||||
|
if page is None or page.is_closed():
|
||||||
|
return await browser_session.new_page()
|
||||||
|
return page
|
||||||
|
|
@ -52,3 +52,36 @@ class Provider(ABC):
|
||||||
def save_tokens(self, tokens: ProviderTokens) -> None:
|
def save_tokens(self, tokens: ProviderTokens) -> None:
|
||||||
"""Save tokens to storage"""
|
"""Save tokens to storage"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def force_recreate_token(self) -> str | None:
|
||||||
|
"""Force-create a new active token when normal acquisition fails."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def maybe_rotate_account(self, usage_percent: int) -> bool:
|
||||||
|
"""Rotate active account/token if provider policy requires it."""
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prepare_threshold(self) -> int:
|
||||||
|
"""Usage percent when provider should prepare standby account/token."""
|
||||||
|
return 100
|
||||||
|
|
||||||
|
@property
|
||||||
|
def switch_threshold(self) -> int | None:
|
||||||
|
"""Usage percent when provider may switch active account/token."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
def should_prepare_standby(self, usage_percent: int) -> bool:
|
||||||
|
"""Whether standby preparation should be triggered for current usage."""
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def ensure_standby_account(
|
||||||
|
self,
|
||||||
|
usage_percent: int,
|
||||||
|
) -> None:
|
||||||
|
"""Prepare standby account/token asynchronously when needed."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def startup_prepare(self) -> None:
|
||||||
|
"""Optional provider-specific startup preparation."""
|
||||||
|
return None
|
||||||
|
|
|
||||||
|
|
@ -1,19 +1,36 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import Callable
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
from playwright.async_api import BrowserContext
|
from playwright.async_api import BrowserContext
|
||||||
|
|
||||||
from providers.base import Provider, ProviderTokens
|
|
||||||
from email_providers import BaseProvider
|
from email_providers import BaseProvider
|
||||||
from email_providers import TempMailOrgProvider
|
from email_providers import MailTmProvider
|
||||||
from .tokens import load_tokens, save_tokens, refresh_tokens
|
from providers.base import Provider, ProviderTokens
|
||||||
|
from utils.env import parse_int_env
|
||||||
|
from .tokens import (
|
||||||
|
clear_next_tokens,
|
||||||
|
load_next_tokens,
|
||||||
|
load_state,
|
||||||
|
load_tokens,
|
||||||
|
promote_next_tokens,
|
||||||
|
refresh_tokens,
|
||||||
|
save_state,
|
||||||
|
save_tokens,
|
||||||
|
)
|
||||||
from .usage import get_usage_data
|
from .usage import get_usage_data
|
||||||
from .registration import register_chatgpt_account
|
from .registration import register_chatgpt_account
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
MAX_REGISTRATION_ATTEMPTS = 4
|
CHATGPT_REGISTRATION_MAX_ATTEMPTS = 4
|
||||||
|
CHATGPT_PREPARE_THRESHOLD = parse_int_env("CHATGPT_PREPARE_THRESHOLD", 85, 0, 100)
|
||||||
|
CHATGPT_SWITCH_THRESHOLD = parse_int_env(
|
||||||
|
"CHATGPT_SWITCH_THRESHOLD",
|
||||||
|
95,
|
||||||
|
0,
|
||||||
|
100,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ChatGPTProvider(Provider):
|
class ChatGPTProvider(Provider):
|
||||||
|
|
@ -23,20 +40,61 @@ class ChatGPTProvider(Provider):
|
||||||
self,
|
self,
|
||||||
email_provider_factory: Callable[[BrowserContext], BaseProvider] | None = None,
|
email_provider_factory: Callable[[BrowserContext], BaseProvider] | None = None,
|
||||||
):
|
):
|
||||||
self.email_provider_factory = email_provider_factory or TempMailOrgProvider
|
self.email_provider_factory = email_provider_factory or MailTmProvider
|
||||||
self._token_write_lock = asyncio.Lock()
|
self._token_write_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prepare_threshold(self) -> int:
|
||||||
|
return CHATGPT_PREPARE_THRESHOLD
|
||||||
|
|
||||||
|
@property
|
||||||
|
def switch_threshold(self) -> int | None:
|
||||||
|
return CHATGPT_SWITCH_THRESHOLD
|
||||||
|
|
||||||
async def _register_with_retries(self) -> bool:
|
async def _register_with_retries(self) -> bool:
|
||||||
for attempt in range(1, MAX_REGISTRATION_ATTEMPTS + 1):
|
for attempt in range(1, CHATGPT_REGISTRATION_MAX_ATTEMPTS + 1):
|
||||||
logger.info(
|
logger.info(
|
||||||
"Registration attempt %s/%s",
|
"Registration attempt %s/%s",
|
||||||
attempt,
|
attempt,
|
||||||
MAX_REGISTRATION_ATTEMPTS,
|
CHATGPT_REGISTRATION_MAX_ATTEMPTS,
|
||||||
)
|
)
|
||||||
success = await self.register_new_account()
|
generated_tokens = await register_chatgpt_account(
|
||||||
if success:
|
email_provider_factory=self.email_provider_factory,
|
||||||
|
)
|
||||||
|
if generated_tokens:
|
||||||
|
save_tokens(generated_tokens)
|
||||||
return True
|
return True
|
||||||
logger.warning("Registration attempt %s failed", attempt)
|
logger.warning("Registration attempt %s failed", attempt)
|
||||||
|
await asyncio.sleep(1.5 * attempt)
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _create_next_account_under_lock(self) -> bool:
|
||||||
|
active_before, next_before = load_state()
|
||||||
|
if next_before:
|
||||||
|
return True
|
||||||
|
|
||||||
|
logger.info("Creating next account")
|
||||||
|
for attempt in range(1, CHATGPT_REGISTRATION_MAX_ATTEMPTS + 1):
|
||||||
|
logger.info(
|
||||||
|
"Next-account registration attempt %s/%s",
|
||||||
|
attempt,
|
||||||
|
CHATGPT_REGISTRATION_MAX_ATTEMPTS,
|
||||||
|
)
|
||||||
|
generated_tokens = await register_chatgpt_account(
|
||||||
|
email_provider_factory=self.email_provider_factory,
|
||||||
|
)
|
||||||
|
if generated_tokens:
|
||||||
|
if active_before:
|
||||||
|
save_state(active_before, generated_tokens)
|
||||||
|
else:
|
||||||
|
save_state(generated_tokens, None)
|
||||||
|
logger.info("Next account is ready")
|
||||||
|
return True
|
||||||
|
logger.warning("Next-account registration attempt %s failed", attempt)
|
||||||
|
await asyncio.sleep(1.5 * attempt)
|
||||||
|
|
||||||
|
if active_before or next_before:
|
||||||
|
save_state(active_before, next_before)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def force_recreate_token(self) -> str | None:
|
async def force_recreate_token(self) -> str | None:
|
||||||
|
|
@ -44,11 +102,62 @@ class ChatGPTProvider(Provider):
|
||||||
success = await self._register_with_retries()
|
success = await self._register_with_retries()
|
||||||
if not success:
|
if not success:
|
||||||
return None
|
return None
|
||||||
|
clear_next_tokens()
|
||||||
tokens = load_tokens()
|
tokens = load_tokens()
|
||||||
if not tokens:
|
if not tokens:
|
||||||
return None
|
return None
|
||||||
return tokens.access_token
|
return tokens.access_token
|
||||||
|
|
||||||
|
async def startup_prepare(self) -> None:
|
||||||
|
await self.ensure_next_account()
|
||||||
|
|
||||||
|
async def ensure_next_account(self) -> bool:
|
||||||
|
next_tokens = load_next_tokens()
|
||||||
|
if next_tokens:
|
||||||
|
return True
|
||||||
|
|
||||||
|
async with self._token_write_lock:
|
||||||
|
next_tokens = load_next_tokens()
|
||||||
|
if next_tokens:
|
||||||
|
return True
|
||||||
|
return await self._create_next_account_under_lock()
|
||||||
|
|
||||||
|
def should_prepare_standby(self, usage_percent: int) -> bool:
|
||||||
|
return usage_percent >= self.prepare_threshold
|
||||||
|
|
||||||
|
async def ensure_standby_account(
|
||||||
|
self,
|
||||||
|
usage_percent: int,
|
||||||
|
) -> None:
|
||||||
|
if self.should_prepare_standby(usage_percent):
|
||||||
|
await self.ensure_next_account()
|
||||||
|
|
||||||
|
async def maybe_switch_active_account(self, usage_percent: int) -> bool:
|
||||||
|
if usage_percent < CHATGPT_SWITCH_THRESHOLD:
|
||||||
|
return False
|
||||||
|
|
||||||
|
async with self._token_write_lock:
|
||||||
|
next_tokens = load_next_tokens()
|
||||||
|
if not next_tokens or next_tokens.is_expired:
|
||||||
|
logger.info(
|
||||||
|
"Active usage >= %s%% and next account missing",
|
||||||
|
CHATGPT_SWITCH_THRESHOLD,
|
||||||
|
)
|
||||||
|
created = await self._create_next_account_under_lock()
|
||||||
|
if not created:
|
||||||
|
return False
|
||||||
|
|
||||||
|
switched = promote_next_tokens()
|
||||||
|
if switched:
|
||||||
|
logger.info(
|
||||||
|
"Switched active account (usage >= %s%%)",
|
||||||
|
CHATGPT_SWITCH_THRESHOLD,
|
||||||
|
)
|
||||||
|
return switched
|
||||||
|
|
||||||
|
async def maybe_rotate_account(self, usage_percent: int) -> bool:
|
||||||
|
return await self.maybe_switch_active_account(usage_percent)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return "chatgpt"
|
return "chatgpt"
|
||||||
|
|
@ -84,13 +193,17 @@ class ChatGPTProvider(Provider):
|
||||||
|
|
||||||
async def register_new_account(self) -> bool:
|
async def register_new_account(self) -> bool:
|
||||||
"""Register a new ChatGPT account"""
|
"""Register a new ChatGPT account"""
|
||||||
return await register_chatgpt_account(
|
generated_tokens = await register_chatgpt_account(
|
||||||
email_provider_factory=self.email_provider_factory,
|
email_provider_factory=self.email_provider_factory,
|
||||||
)
|
)
|
||||||
|
if not generated_tokens:
|
||||||
|
return False
|
||||||
|
save_tokens(generated_tokens)
|
||||||
|
return True
|
||||||
|
|
||||||
async def get_usage_info(self, access_token: str) -> dict[str, Any]:
|
async def get_usage_info(self, access_token: str) -> dict[str, Any]:
|
||||||
"""Get usage information for the current token"""
|
"""Get usage information for the current token"""
|
||||||
usage_data = get_usage_data(access_token)
|
usage_data = await get_usage_data(access_token)
|
||||||
if not usage_data:
|
if not usage_data:
|
||||||
return {"error": "Failed to get usage"}
|
return {"error": "Failed to get usage"}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ import logging
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
import secrets
|
import secrets
|
||||||
import string
|
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
@ -14,12 +13,18 @@ from typing import Callable
|
||||||
from urllib.parse import parse_qs, urlencode, urlparse
|
from urllib.parse import parse_qs, urlencode, urlparse
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from playwright.async_api import async_playwright, Page, BrowserContext
|
from playwright.async_api import (
|
||||||
|
async_playwright,
|
||||||
|
Error as PlaywrightError,
|
||||||
|
Page,
|
||||||
|
BrowserContext,
|
||||||
|
)
|
||||||
|
|
||||||
from browser import launch as launch_browser
|
from browser import launch as launch_browser
|
||||||
from email_providers import BaseProvider
|
from email_providers import BaseProvider
|
||||||
from providers.base import ProviderTokens
|
from providers.base import ProviderTokens
|
||||||
from .tokens import CLIENT_ID, save_tokens
|
from utils.randoms import generate_password
|
||||||
|
from .tokens import CLIENT_ID
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -46,14 +51,9 @@ async def save_error_screenshot(page: Page | None, step: str):
|
||||||
filename = screenshots_dir / f"error_{step}_{timestamp}.png"
|
filename = screenshots_dir / f"error_{step}_{timestamp}.png"
|
||||||
try:
|
try:
|
||||||
await page.screenshot(path=str(filename))
|
await page.screenshot(path=str(filename))
|
||||||
logger.error(f"Screenshot saved: {filename}")
|
logger.error("Screenshot saved: %s", filename)
|
||||||
except:
|
except PlaywrightError as e:
|
||||||
pass
|
logger.warning("Failed to save screenshot at step %s: %s", step, e)
|
||||||
|
|
||||||
|
|
||||||
def generate_password(length: int = 20) -> str:
|
|
||||||
alphabet = string.ascii_letters + string.digits
|
|
||||||
return "".join(random.choice(alphabet) for _ in range(length))
|
|
||||||
|
|
||||||
|
|
||||||
def generate_name() -> str:
|
def generate_name() -> str:
|
||||||
|
|
@ -204,8 +204,7 @@ def generate_state() -> str:
|
||||||
return secrets.token_urlsafe(32)
|
return secrets.token_urlsafe(32)
|
||||||
|
|
||||||
|
|
||||||
def build_authorize_url(verifier: str, challenge: str, state: str) -> str:
|
def build_authorize_url(challenge: str, state: str) -> str:
|
||||||
del verifier
|
|
||||||
params = {
|
params = {
|
||||||
"response_type": "code",
|
"response_type": "code",
|
||||||
"client_id": CLIENT_ID,
|
"client_id": CLIENT_ID,
|
||||||
|
|
@ -222,7 +221,6 @@ def build_authorize_url(verifier: str, challenge: str, state: str) -> str:
|
||||||
|
|
||||||
|
|
||||||
async def exchange_code_for_tokens(code: str, verifier: str) -> ProviderTokens:
|
async def exchange_code_for_tokens(code: str, verifier: str) -> ProviderTokens:
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
payload = {
|
payload = {
|
||||||
"grant_type": "authorization_code",
|
"grant_type": "authorization_code",
|
||||||
"client_id": CLIENT_ID,
|
"client_id": CLIENT_ID,
|
||||||
|
|
@ -230,22 +228,30 @@ async def exchange_code_for_tokens(code: str, verifier: str) -> ProviderTokens:
|
||||||
"code_verifier": verifier,
|
"code_verifier": verifier,
|
||||||
"redirect_uri": REDIRECT_URI,
|
"redirect_uri": REDIRECT_URI,
|
||||||
}
|
}
|
||||||
|
timeout = aiohttp.ClientTimeout(total=20)
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||||
async with session.post(TOKEN_URL, data=payload) as resp:
|
async with session.post(TOKEN_URL, data=payload) as resp:
|
||||||
if not resp.ok:
|
if not resp.ok:
|
||||||
text = await resp.text()
|
text = await resp.text()
|
||||||
raise RuntimeError(f"Token exchange failed: {resp.status} {text}")
|
raise RuntimeError(f"Token exchange failed: {resp.status} {text}")
|
||||||
body = await resp.json()
|
body = await resp.json()
|
||||||
|
except (aiohttp.ClientError, TimeoutError) as e:
|
||||||
|
raise RuntimeError(f"Token exchange request error: {e}") from e
|
||||||
|
|
||||||
|
try:
|
||||||
expires_in = int(body["expires_in"])
|
expires_in = int(body["expires_in"])
|
||||||
return ProviderTokens(
|
return ProviderTokens(
|
||||||
access_token=body["access_token"],
|
access_token=body["access_token"],
|
||||||
refresh_token=body["refresh_token"],
|
refresh_token=body["refresh_token"],
|
||||||
expires_at=time.time() + expires_in,
|
expires_at=time.time() + expires_in,
|
||||||
)
|
)
|
||||||
|
except (KeyError, TypeError, ValueError) as e:
|
||||||
|
raise RuntimeError(f"Token exchange response parse error: {e}") from e
|
||||||
|
|
||||||
|
|
||||||
async def get_latest_code(email_provider: BaseProvider, email: str) -> str | None:
|
async def get_latest_code(email_provider: BaseProvider) -> str | None:
|
||||||
message = await email_provider.get_latest_message(email)
|
message = await email_provider.get_latest_message()
|
||||||
if not message:
|
if not message:
|
||||||
return None
|
return None
|
||||||
return extract_verification_code(message)
|
return extract_verification_code(message)
|
||||||
|
|
@ -270,6 +276,51 @@ async def click_continue(page: Page, timeout_ms: int = 10000):
|
||||||
await btn.click()
|
await btn.click()
|
||||||
|
|
||||||
|
|
||||||
|
async def oauth_needs_email_check(page: Page) -> bool:
|
||||||
|
marker = page.get_by_text("Check your inbox", exact=False)
|
||||||
|
return await marker.count() > 0
|
||||||
|
|
||||||
|
|
||||||
|
async def fill_oauth_code_if_present(page: Page, code: str) -> bool:
|
||||||
|
candidates = [
|
||||||
|
page.get_by_placeholder("Code"),
|
||||||
|
page.get_by_label("Code"),
|
||||||
|
page.locator(
|
||||||
|
'input[name*="code" i], input[id*="code" i], '
|
||||||
|
'input[autocomplete="one-time-code"], input[inputmode="numeric"]'
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
for locator in candidates:
|
||||||
|
if await locator.count() == 0:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
await locator.first.wait_for(state="visible", timeout=1500)
|
||||||
|
await locator.first.fill(code)
|
||||||
|
return True
|
||||||
|
except PlaywrightError:
|
||||||
|
continue
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def click_first_visible_button(
|
||||||
|
page: Page,
|
||||||
|
labels: list[str],
|
||||||
|
timeout_ms: int = 2000,
|
||||||
|
) -> bool:
|
||||||
|
for label in labels:
|
||||||
|
button = page.get_by_role("button", name=label)
|
||||||
|
if await button.count() == 0:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
await button.first.wait_for(state="visible", timeout=timeout_ms)
|
||||||
|
await button.first.click(timeout=timeout_ms)
|
||||||
|
return True
|
||||||
|
except PlaywrightError:
|
||||||
|
continue
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
async def wait_for_signup_stabilization(
|
async def wait_for_signup_stabilization(
|
||||||
page: Page,
|
page: Page,
|
||||||
source_url: str,
|
source_url: str,
|
||||||
|
|
@ -288,12 +339,12 @@ async def wait_for_signup_stabilization(
|
||||||
|
|
||||||
async def register_chatgpt_account(
|
async def register_chatgpt_account(
|
||||||
email_provider_factory: Callable[[BrowserContext], BaseProvider] | None = None,
|
email_provider_factory: Callable[[BrowserContext], BaseProvider] | None = None,
|
||||||
) -> bool:
|
) -> ProviderTokens | None:
|
||||||
logger.info("=== Starting ChatGPT account registration ===")
|
logger.info("=== Starting ChatGPT account registration ===")
|
||||||
|
|
||||||
if email_provider_factory is None:
|
if email_provider_factory is None:
|
||||||
logger.error("No email provider factory configured")
|
logger.error("No email provider factory configured")
|
||||||
return False
|
return None
|
||||||
|
|
||||||
birth_month, birth_day, birth_year = generate_birthdate_90s()
|
birth_month, birth_day, birth_year = generate_birthdate_90s()
|
||||||
|
|
||||||
|
|
@ -321,7 +372,7 @@ async def register_chatgpt_account(
|
||||||
full_name = generate_name()
|
full_name = generate_name()
|
||||||
verifier, challenge = generate_pkce_pair()
|
verifier, challenge = generate_pkce_pair()
|
||||||
oauth_state = generate_state()
|
oauth_state = generate_state()
|
||||||
authorize_url = build_authorize_url(verifier, challenge, oauth_state)
|
authorize_url = build_authorize_url(challenge, oauth_state)
|
||||||
|
|
||||||
logger.info("[2/5] Registering ChatGPT for %s", email)
|
logger.info("[2/5] Registering ChatGPT for %s", email)
|
||||||
chatgpt_page = await context.new_page()
|
chatgpt_page = await context.new_page()
|
||||||
|
|
@ -347,24 +398,23 @@ async def register_chatgpt_account(
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("[3/5] Getting verification message from email provider...")
|
logger.info("[3/5] Getting verification message from email provider...")
|
||||||
code = await get_latest_code(email_provider, email)
|
code = await get_latest_code(email_provider)
|
||||||
if not code:
|
if not code:
|
||||||
raise AutomationError(
|
raise AutomationError(
|
||||||
"email_provider", "Email provider returned no verification message"
|
"email_provider", "Email provider returned no verification message"
|
||||||
)
|
)
|
||||||
logger.info("[3/5] Verification code extracted: %s", code)
|
logger.info("[3/5] Verification code extracted")
|
||||||
|
|
||||||
await chatgpt_page.bring_to_front()
|
await chatgpt_page.bring_to_front()
|
||||||
code_input = chatgpt_page.get_by_placeholder("Code")
|
code_input = chatgpt_page.get_by_placeholder("Code")
|
||||||
if await code_input.count() > 0:
|
await code_input.first.wait_for(state="visible", timeout=10000)
|
||||||
await code_input.fill(code)
|
await code_input.first.fill(code)
|
||||||
await click_continue(chatgpt_page)
|
await click_continue(chatgpt_page)
|
||||||
|
|
||||||
logger.info("[4/5] Setting profile...")
|
logger.info("[4/5] Setting profile...")
|
||||||
name_input = chatgpt_page.get_by_placeholder("Full name")
|
name_input = chatgpt_page.get_by_placeholder("Full name")
|
||||||
await name_input.first.wait_for(state="visible", timeout=20000)
|
await name_input.first.wait_for(state="visible", timeout=20000)
|
||||||
if await name_input.count() > 0:
|
await name_input.first.fill(full_name)
|
||||||
await name_input.fill(full_name)
|
|
||||||
|
|
||||||
await fill_date_field(chatgpt_page, birth_month, birth_day, birth_year)
|
await fill_date_field(chatgpt_page, birth_month, birth_day, birth_year)
|
||||||
profile_url = chatgpt_page.url
|
profile_url = chatgpt_page.url
|
||||||
|
|
@ -409,25 +459,45 @@ async def register_chatgpt_account(
|
||||||
if await continue_button.count() > 0:
|
if await continue_button.count() > 0:
|
||||||
await continue_button.first.click()
|
await continue_button.first.click()
|
||||||
|
|
||||||
for label in ["Continue", "Allow", "Authorize"]:
|
last_oauth_email_code = code
|
||||||
button = oauth_page.get_by_role("button", name=label)
|
oauth_deadline = asyncio.get_running_loop().time() + 60
|
||||||
if await button.count() > 0:
|
while asyncio.get_running_loop().time() < oauth_deadline:
|
||||||
try:
|
if redirect_url_captured:
|
||||||
await button.first.click(timeout=5000)
|
break
|
||||||
await oauth_page.wait_for_timeout(500)
|
|
||||||
except Exception:
|
if await oauth_needs_email_check(oauth_page):
|
||||||
pass
|
logger.info("OAuth requested email confirmation code")
|
||||||
|
new_code = await get_latest_code(email_provider)
|
||||||
|
if new_code and new_code != last_oauth_email_code:
|
||||||
|
filled = await fill_oauth_code_if_present(oauth_page, new_code)
|
||||||
|
if filled:
|
||||||
|
last_oauth_email_code = new_code
|
||||||
|
logger.info("Filled OAuth email confirmation code")
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"OAuth inbox challenge detected but code field not found"
|
||||||
|
)
|
||||||
|
|
||||||
if not redirect_url_captured:
|
|
||||||
try:
|
try:
|
||||||
await oauth_page.wait_for_timeout(4000)
|
|
||||||
current_url = oauth_page.url
|
current_url = oauth_page.url
|
||||||
if "localhost:1455" in current_url and "code=" in current_url:
|
if "localhost:1455" in current_url and "code=" in current_url:
|
||||||
redirect_url_captured = current_url
|
redirect_url_captured = current_url
|
||||||
logger.info("Captured OAuth redirect from page URL")
|
logger.info("Captured OAuth redirect from page URL")
|
||||||
|
break
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
clicked = await click_first_visible_button(
|
||||||
|
oauth_page,
|
||||||
|
["Continue", "Allow", "Authorize", "Verify"],
|
||||||
|
timeout_ms=2000,
|
||||||
|
)
|
||||||
|
|
||||||
|
if clicked:
|
||||||
|
await oauth_page.wait_for_timeout(500)
|
||||||
|
else:
|
||||||
|
await oauth_page.wait_for_timeout(1000)
|
||||||
|
|
||||||
if not redirect_url_captured:
|
if not redirect_url_captured:
|
||||||
raise AutomationError(
|
raise AutomationError(
|
||||||
"oauth", "OAuth redirect with code was not captured", oauth_page
|
"oauth", "OAuth redirect with code was not captured", oauth_page
|
||||||
|
|
@ -446,20 +516,18 @@ async def register_chatgpt_account(
|
||||||
raise AutomationError("oauth", "OAuth state mismatch", oauth_page)
|
raise AutomationError("oauth", "OAuth state mismatch", oauth_page)
|
||||||
|
|
||||||
tokens = await exchange_code_for_tokens(auth_code, verifier)
|
tokens = await exchange_code_for_tokens(auth_code, verifier)
|
||||||
save_tokens(tokens)
|
logger.info("OAuth tokens fetched successfully")
|
||||||
logger.info("OAuth tokens saved successfully")
|
|
||||||
|
|
||||||
return True
|
return tokens
|
||||||
|
|
||||||
except AutomationError as e:
|
except AutomationError as e:
|
||||||
logger.error(f"Error at step [{e.step}]: {e.message}")
|
logger.error(f"Error at step [{e.step}]: {e.message}")
|
||||||
await save_error_screenshot(e.page, e.step)
|
await save_error_screenshot(e.page, e.step)
|
||||||
return False
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error: {e}")
|
logger.error(f"Unexpected error: {e}")
|
||||||
await save_error_screenshot(current_page, "unexpected")
|
await save_error_screenshot(current_page, "unexpected")
|
||||||
return False
|
return None
|
||||||
finally:
|
finally:
|
||||||
if managed:
|
if managed:
|
||||||
await asyncio.sleep(2)
|
|
||||||
await managed.close()
|
await managed.close()
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,12 @@
|
||||||
import json
|
import json
|
||||||
import time
|
|
||||||
import os
|
|
||||||
import aiohttp
|
|
||||||
from pathlib import Path
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
from providers.base import ProviderTokens
|
from providers.base import ProviderTokens
|
||||||
|
|
||||||
|
|
@ -16,72 +19,143 @@ CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||||
TOKEN_URL = "https://auth.openai.com/oauth/token"
|
TOKEN_URL = "https://auth.openai.com/oauth/token"
|
||||||
|
|
||||||
|
|
||||||
def load_tokens() -> ProviderTokens | None:
|
def _tokens_to_dict(tokens: ProviderTokens) -> dict[str, Any]:
|
||||||
if not TOKENS_FILE.exists():
|
return {
|
||||||
|
"access_token": tokens.access_token,
|
||||||
|
"refresh_token": tokens.refresh_token,
|
||||||
|
"expires_at": tokens.expires_at,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _dict_to_tokens(data: dict[str, Any] | None) -> ProviderTokens | None:
|
||||||
|
if not isinstance(data, dict):
|
||||||
return None
|
return None
|
||||||
try:
|
try:
|
||||||
with open(TOKENS_FILE) as f:
|
|
||||||
data = json.load(f)
|
|
||||||
return ProviderTokens(
|
return ProviderTokens(
|
||||||
access_token=data["access_token"],
|
access_token=data["access_token"],
|
||||||
refresh_token=data["refresh_token"],
|
refresh_token=data["refresh_token"],
|
||||||
expires_at=data["expires_at"],
|
expires_at=data["expires_at"],
|
||||||
)
|
)
|
||||||
except json.JSONDecodeError, KeyError:
|
except (KeyError, TypeError):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def save_tokens(tokens: ProviderTokens):
|
def _load_raw() -> dict[str, Any] | None:
|
||||||
|
if not TOKENS_FILE.exists():
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
with open(TOKENS_FILE) as f:
|
||||||
|
data = json.load(f)
|
||||||
|
if isinstance(data, dict):
|
||||||
|
return data
|
||||||
|
return None
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _save_raw(data: dict[str, Any]) -> None:
|
||||||
TOKENS_FILE.parent.mkdir(parents=True, exist_ok=True)
|
TOKENS_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||||
with open(TOKENS_FILE, "w") as f:
|
fd, tmp_path = tempfile.mkstemp(
|
||||||
json.dump(
|
prefix=f"{TOKENS_FILE.name}.",
|
||||||
{
|
suffix=".tmp",
|
||||||
"access_token": tokens.access_token,
|
dir=str(TOKENS_FILE.parent),
|
||||||
"refresh_token": tokens.refresh_token,
|
|
||||||
"expires_at": tokens.expires_at,
|
|
||||||
},
|
|
||||||
f,
|
|
||||||
indent=2,
|
|
||||||
)
|
)
|
||||||
|
try:
|
||||||
|
with os.fdopen(fd, "w") as f:
|
||||||
|
json.dump(data, f, indent=2)
|
||||||
|
f.flush()
|
||||||
|
os.fsync(f.fileno())
|
||||||
|
os.replace(tmp_path, TOKENS_FILE)
|
||||||
|
finally:
|
||||||
|
if os.path.exists(tmp_path):
|
||||||
|
os.unlink(tmp_path)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_state(data: dict[str, Any] | None) -> dict[str, Any]:
|
||||||
|
if not data:
|
||||||
|
return {"active": None, "next_account": None}
|
||||||
|
|
||||||
|
if "active" in data or "next_account" in data:
|
||||||
|
return {
|
||||||
|
"active": data.get("active"),
|
||||||
|
"next_account": data.get("next_account"),
|
||||||
|
}
|
||||||
|
|
||||||
|
return {"active": data, "next_account": None}
|
||||||
|
|
||||||
|
|
||||||
|
def load_state() -> tuple[ProviderTokens | None, ProviderTokens | None]:
|
||||||
|
normalized = _normalize_state(_load_raw())
|
||||||
|
active = _dict_to_tokens(normalized.get("active"))
|
||||||
|
next_account = _dict_to_tokens(normalized.get("next_account"))
|
||||||
|
return active, next_account
|
||||||
|
|
||||||
|
|
||||||
|
def save_state(active: ProviderTokens | None, next_account: ProviderTokens | None) -> None:
|
||||||
|
payload = {
|
||||||
|
"active": _tokens_to_dict(active) if active else None,
|
||||||
|
"next_account": _tokens_to_dict(next_account) if next_account else None,
|
||||||
|
}
|
||||||
|
_save_raw(payload)
|
||||||
|
|
||||||
|
|
||||||
|
def load_tokens() -> ProviderTokens | None:
|
||||||
|
active, _ = load_state()
|
||||||
|
return active
|
||||||
|
|
||||||
|
|
||||||
|
def load_next_tokens() -> ProviderTokens | None:
|
||||||
|
_, next_account = load_state()
|
||||||
|
return next_account
|
||||||
|
|
||||||
|
|
||||||
|
def save_tokens(tokens: ProviderTokens):
|
||||||
|
_, next_account = load_state()
|
||||||
|
save_state(tokens, next_account)
|
||||||
|
|
||||||
|
|
||||||
|
def promote_next_tokens() -> bool:
|
||||||
|
_, next_account = load_state()
|
||||||
|
if not next_account:
|
||||||
|
return False
|
||||||
|
save_state(next_account, None)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def clear_next_tokens():
|
||||||
|
active, _ = load_state()
|
||||||
|
save_state(active, None)
|
||||||
|
|
||||||
|
|
||||||
async def refresh_tokens(refresh_token: str) -> ProviderTokens | None:
|
async def refresh_tokens(refresh_token: str) -> ProviderTokens | None:
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
data = {
|
data = {
|
||||||
"grant_type": "refresh_token",
|
"grant_type": "refresh_token",
|
||||||
"refresh_token": refresh_token,
|
"refresh_token": refresh_token,
|
||||||
"client_id": CLIENT_ID,
|
"client_id": CLIENT_ID,
|
||||||
}
|
}
|
||||||
|
timeout = aiohttp.ClientTimeout(total=15)
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||||
async with session.post(TOKEN_URL, data=data) as resp:
|
async with session.post(TOKEN_URL, data=data) as resp:
|
||||||
if not resp.ok:
|
if not resp.ok:
|
||||||
text = await resp.text()
|
text = await resp.text()
|
||||||
print(f"Token refresh failed: {resp.status} {text}")
|
logger.warning("Token refresh failed: %s %s", resp.status, text)
|
||||||
return None
|
return None
|
||||||
json_resp = await resp.json()
|
json_resp = await resp.json()
|
||||||
expires_in = json_resp["expires_in"]
|
except (aiohttp.ClientError, TimeoutError) as e:
|
||||||
|
logger.warning("Token refresh request error: %s", e)
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Token refresh unexpected error: %s", e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
expires_in = int(json_resp["expires_in"])
|
||||||
return ProviderTokens(
|
return ProviderTokens(
|
||||||
access_token=json_resp["access_token"],
|
access_token=json_resp["access_token"],
|
||||||
refresh_token=json_resp["refresh_token"],
|
refresh_token=json_resp["refresh_token"],
|
||||||
expires_at=time.time() + expires_in,
|
expires_at=time.time() + expires_in,
|
||||||
)
|
)
|
||||||
|
except (KeyError, TypeError, ValueError) as e:
|
||||||
|
logger.warning("Token refresh response parse error: %s", e)
|
||||||
async def get_valid_tokens() -> ProviderTokens | None:
|
|
||||||
tokens = load_tokens()
|
|
||||||
if not tokens:
|
|
||||||
logger.info("No tokens found")
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if tokens.is_expired:
|
|
||||||
logger.info("Token expired, refreshing...")
|
|
||||||
if not tokens.refresh_token:
|
|
||||||
logger.info("No refresh token available")
|
|
||||||
return None
|
|
||||||
new_tokens = await refresh_tokens(tokens.refresh_token)
|
|
||||||
if not new_tokens:
|
|
||||||
logger.warning("Failed to refresh token")
|
|
||||||
return None
|
|
||||||
save_tokens(new_tokens)
|
|
||||||
return new_tokens
|
|
||||||
|
|
||||||
return tokens
|
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,15 @@
|
||||||
import json
|
import logging
|
||||||
import socket
|
|
||||||
import urllib.error
|
|
||||||
import urllib.request
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def clamp_percent(value: Any) -> int:
|
def clamp_percent(value: Any) -> int:
|
||||||
try:
|
try:
|
||||||
num = float(value)
|
num = float(value)
|
||||||
except TypeError, ValueError:
|
except (TypeError, ValueError):
|
||||||
return 0
|
return 0
|
||||||
if num < 0:
|
if num < 0:
|
||||||
return 0
|
return 0
|
||||||
|
|
@ -28,30 +29,36 @@ def _parse_window(window: dict[str, Any] | None) -> dict[str, int] | None:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_usage_data(access_token: str, timeout_ms: int = 10000) -> dict[str, Any] | None:
|
async def get_usage_data(
|
||||||
|
access_token: str,
|
||||||
|
timeout_ms: int = 10000,
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {access_token}",
|
"Authorization": f"Bearer {access_token}",
|
||||||
"User-Agent": "CodexProxy",
|
"User-Agent": "CodexProxy",
|
||||||
"Accept": "application/json",
|
"Accept": "application/json",
|
||||||
}
|
}
|
||||||
|
|
||||||
req = urllib.request.Request(
|
timeout = aiohttp.ClientTimeout(total=timeout_ms / 1000)
|
||||||
"https://chatgpt.com/backend-api/wham/usage",
|
url = "https://chatgpt.com/backend-api/wham/usage"
|
||||||
headers=headers,
|
|
||||||
method="GET",
|
try:
|
||||||
|
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||||
|
async with session.get(url, headers=headers) as res:
|
||||||
|
if not res.ok:
|
||||||
|
body = await res.text()
|
||||||
|
logger.warning(
|
||||||
|
"Usage fetch failed: status=%s body=%s",
|
||||||
|
res.status,
|
||||||
|
body[:300],
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
|
||||||
with urllib.request.urlopen(req, timeout=timeout_ms / 1000) as res:
|
|
||||||
body = res.read().decode("utf-8", errors="replace")
|
|
||||||
except urllib.error.HTTPError:
|
|
||||||
return None
|
return None
|
||||||
except urllib.error.URLError, socket.timeout:
|
data = await res.json()
|
||||||
|
except (aiohttp.ClientError, TimeoutError) as e:
|
||||||
|
logger.warning("Usage fetch request error: %s", e)
|
||||||
return None
|
return None
|
||||||
|
except Exception as e:
|
||||||
try:
|
logger.warning("Usage fetch unexpected error: %s", e)
|
||||||
data = json.loads(body)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
rate_limit = data.get("rate_limit") or {}
|
rate_limit = data.get("rate_limit") or {}
|
||||||
|
|
@ -76,10 +83,3 @@ def get_usage_data(access_token: str, timeout_ms: int = 10000) -> dict[str, Any]
|
||||||
"limit_reached": bool(rate_limit.get("limit_reached")),
|
"limit_reached": bool(rate_limit.get("limit_reached")),
|
||||||
"allowed": bool(rate_limit.get("allowed", True)),
|
"allowed": bool(rate_limit.get("allowed", True)),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_usage_percent(access_token: str, timeout_ms: int = 10000) -> int:
|
|
||||||
data = get_usage_data(access_token, timeout_ms=timeout_ms)
|
|
||||||
if not data:
|
|
||||||
return -1
|
|
||||||
return int(data["used_percent"])
|
|
||||||
|
|
|
||||||
175
src/server.py
175
src/server.py
|
|
@ -1,27 +1,22 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
from providers.chatgpt import ChatGPTProvider
|
from providers.chatgpt import ChatGPTProvider
|
||||||
|
from providers.base import Provider
|
||||||
PORT = int(os.environ.get("PORT", "8080"))
|
from utils.env import parse_int_env
|
||||||
USAGE_REFRESH_THRESHOLD = int(os.environ.get("USAGE_REFRESH_THRESHOLD", "85"))
|
|
||||||
LIMIT_EXHAUSTED_PERCENT = 100
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
PORT = parse_int_env("PORT", 8080, 1, 65535)
|
||||||
|
LIMIT_EXHAUSTED_PERCENT = 100
|
||||||
|
|
||||||
# Registry of available providers
|
PROVIDERS: dict[str, Provider] = {
|
||||||
PROVIDERS = {
|
|
||||||
"chatgpt": ChatGPTProvider(),
|
"chatgpt": ChatGPTProvider(),
|
||||||
}
|
}
|
||||||
|
|
||||||
refresh_locks = {name: asyncio.Lock() for name in PROVIDERS.keys()}
|
background_tasks: dict[str, asyncio.Task | None] = {name: None for name in PROVIDERS}
|
||||||
background_refresh_tasks: dict[str, asyncio.Task | None] = {
|
|
||||||
name: None for name in PROVIDERS.keys()
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@web.middleware
|
@web.middleware
|
||||||
|
|
@ -31,16 +26,30 @@ async def request_log_middleware(request: web.Request, handler):
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
def build_limit(usage_percent: int) -> dict[str, int | bool]:
|
def build_limit(usage_percent: int, prepare_threshold: int) -> dict[str, int | bool]:
|
||||||
remaining = max(0, 100 - usage_percent)
|
remaining = max(0, 100 - usage_percent)
|
||||||
return {
|
return {
|
||||||
"used_percent": usage_percent,
|
"used_percent": usage_percent,
|
||||||
"remaining_percent": remaining,
|
"remaining_percent": remaining,
|
||||||
"exhausted": usage_percent >= LIMIT_EXHAUSTED_PERCENT,
|
"exhausted": usage_percent >= LIMIT_EXHAUSTED_PERCENT,
|
||||||
"needs_refresh": usage_percent >= USAGE_REFRESH_THRESHOLD,
|
"needs_prepare": usage_percent >= prepare_threshold,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_prepare_threshold(provider_name: str) -> int:
|
||||||
|
provider = PROVIDERS.get(provider_name)
|
||||||
|
if not provider:
|
||||||
|
return 100
|
||||||
|
return provider.prepare_threshold
|
||||||
|
|
||||||
|
|
||||||
|
def should_trigger_standby_prepare(provider_name: str, usage_percent: int) -> bool:
|
||||||
|
provider = PROVIDERS.get(provider_name)
|
||||||
|
if not provider:
|
||||||
|
return False
|
||||||
|
return provider.should_prepare_standby(usage_percent)
|
||||||
|
|
||||||
|
|
||||||
async def ensure_provider_token_ready(provider_name: str):
|
async def ensure_provider_token_ready(provider_name: str):
|
||||||
provider = PROVIDERS.get(provider_name)
|
provider = PROVIDERS.get(provider_name)
|
||||||
if not provider:
|
if not provider:
|
||||||
|
|
@ -52,7 +61,6 @@ async def ensure_provider_token_ready(provider_name: str):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"[%s] Startup token check failed, forcing recreation", provider_name
|
"[%s] Startup token check failed, forcing recreation", provider_name
|
||||||
)
|
)
|
||||||
if isinstance(provider, ChatGPTProvider):
|
|
||||||
token = await provider.force_recreate_token()
|
token = await provider.force_recreate_token()
|
||||||
|
|
||||||
if not token:
|
if not token:
|
||||||
|
|
@ -60,103 +68,88 @@ async def ensure_provider_token_ready(provider_name: str):
|
||||||
return
|
return
|
||||||
|
|
||||||
usage_info = await provider.get_usage_info(token)
|
usage_info = await provider.get_usage_info(token)
|
||||||
if "error" not in usage_info:
|
if "error" in usage_info:
|
||||||
logger.info("[%s] Startup token is ready", provider_name)
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"[%s] Startup token invalid for usage, forcing recreation", provider_name
|
"[%s] Startup token invalid for usage, forcing recreation", provider_name
|
||||||
)
|
)
|
||||||
if isinstance(provider, ChatGPTProvider):
|
|
||||||
token = await provider.force_recreate_token()
|
token = await provider.force_recreate_token()
|
||||||
if token:
|
if not token:
|
||||||
logger.info("[%s] Startup token recreated successfully", provider_name)
|
logger.error("[%s] Startup token recreation failed", provider_name)
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.error("[%s] Startup token recreation failed", provider_name)
|
await provider.startup_prepare()
|
||||||
|
logger.info("[%s] Startup token is ready", provider_name)
|
||||||
|
|
||||||
|
|
||||||
async def on_startup(app: web.Application):
|
async def ensure_standby_task(provider_name: str, usage_percent: int, reason: str):
|
||||||
del app
|
|
||||||
for provider_name in PROVIDERS.keys():
|
|
||||||
await ensure_provider_token_ready(provider_name)
|
|
||||||
|
|
||||||
|
|
||||||
async def issue_new_token(provider_name: str) -> str | None:
|
|
||||||
provider = PROVIDERS.get(provider_name)
|
provider = PROVIDERS.get(provider_name)
|
||||||
if not provider:
|
if not provider:
|
||||||
return None
|
return
|
||||||
|
|
||||||
async with refresh_locks[provider_name]:
|
if not provider.should_prepare_standby(usage_percent):
|
||||||
logger.info(f"[{provider_name}] Generating new token")
|
return
|
||||||
success = await provider.register_new_account()
|
|
||||||
if not success:
|
|
||||||
logger.error(f"[{provider_name}] Token generation failed")
|
|
||||||
return None
|
|
||||||
|
|
||||||
token = await provider.get_token()
|
|
||||||
if not token:
|
|
||||||
logger.error(f"[{provider_name}] Token was generated but not available")
|
|
||||||
return None
|
|
||||||
|
|
||||||
return token
|
|
||||||
|
|
||||||
|
|
||||||
async def background_refresh_worker(provider_name: str, reason: str):
|
|
||||||
try:
|
try:
|
||||||
logger.info(f"[{provider_name}] Starting background token refresh ({reason})")
|
logger.info("[%s] Preparing standby in background (%s)", provider_name, reason)
|
||||||
new_token = await issue_new_token(provider_name)
|
await provider.ensure_standby_account(usage_percent)
|
||||||
if new_token:
|
|
||||||
logger.info(f"[{provider_name}] Background token refresh completed")
|
|
||||||
else:
|
|
||||||
logger.error(f"[{provider_name}] Background token refresh failed")
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception("[%s] Unhandled standby preparation error", provider_name)
|
||||||
f"[{provider_name}] Unhandled error in background token refresh"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def trigger_background_refresh(provider_name: str, reason: str):
|
def trigger_standby_prepare(provider_name: str, usage_percent: int, reason: str):
|
||||||
task = background_refresh_tasks.get(provider_name)
|
task = background_tasks.get(provider_name)
|
||||||
if task and not task.done():
|
if task and not task.done():
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[{provider_name}] Background refresh already running, skip ({reason})"
|
"[%s] Standby prep already running, skip (%s)", provider_name, reason
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
background_refresh_tasks[provider_name] = asyncio.create_task(
|
background_tasks[provider_name] = asyncio.create_task(
|
||||||
background_refresh_worker(provider_name, reason)
|
ensure_standby_task(provider_name, usage_percent, reason)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def token_handler(request: web.Request) -> web.Response:
|
async def token_handler(request: web.Request) -> web.Response:
|
||||||
provider_name = request.match_info.get("provider", "chatgpt")
|
provider_name = request.match_info.get("provider", "chatgpt")
|
||||||
|
|
||||||
provider = PROVIDERS.get(provider_name)
|
provider = PROVIDERS.get(provider_name)
|
||||||
if not provider:
|
if not provider:
|
||||||
return web.json_response(
|
return web.json_response(
|
||||||
{"error": f"Unknown provider: {provider_name}"},
|
{"error": f"Unknown provider: {provider_name}"}, status=404
|
||||||
status=404,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get or create token
|
token = await provider.get_token()
|
||||||
|
if not token:
|
||||||
|
return web.json_response({"error": "Failed to get active token"}, status=503)
|
||||||
|
|
||||||
|
usage_info = await provider.get_usage_info(token)
|
||||||
|
if "error" in usage_info:
|
||||||
|
return web.json_response({"error": usage_info["error"]}, status=503)
|
||||||
|
|
||||||
|
usage_percent = int(usage_info.get("used_percent", 0))
|
||||||
|
switched = await provider.maybe_rotate_account(usage_percent)
|
||||||
|
if switched:
|
||||||
token = await provider.get_token()
|
token = await provider.get_token()
|
||||||
if not token:
|
if not token:
|
||||||
return web.json_response(
|
return web.json_response(
|
||||||
{"error": "Failed to get active token"},
|
{"error": "Failed to get active token after account switch"},
|
||||||
status=503,
|
status=503,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get usage info
|
|
||||||
usage_info = await provider.get_usage_info(token)
|
usage_info = await provider.get_usage_info(token)
|
||||||
if "error" in usage_info:
|
if "error" in usage_info:
|
||||||
return web.json_response(
|
return web.json_response({"error": usage_info["error"]}, status=503)
|
||||||
{"error": usage_info["error"]},
|
usage_percent = int(usage_info.get("used_percent", 0))
|
||||||
status=503,
|
logger.info("[%s] Active account switched before response", provider_name)
|
||||||
|
|
||||||
|
prepare_threshold = get_prepare_threshold(provider_name)
|
||||||
|
if should_trigger_standby_prepare(provider_name, usage_percent):
|
||||||
|
trigger_standby_prepare(
|
||||||
|
provider_name,
|
||||||
|
usage_percent,
|
||||||
|
f"usage {usage_percent}% reached standby policy",
|
||||||
)
|
)
|
||||||
|
|
||||||
usage_percent = usage_info.get("used_percent", 0)
|
remaining_percent = int(
|
||||||
remaining_percent = usage_info.get("remaining_percent", max(0, 100 - usage_percent))
|
usage_info.get("remaining_percent", max(0, 100 - usage_percent))
|
||||||
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
"[%s] token issued, used=%s%% remaining=%s%%",
|
"[%s] token issued, used=%s%% remaining=%s%%",
|
||||||
provider_name,
|
provider_name,
|
||||||
|
|
@ -183,17 +176,10 @@ async def token_handler(request: web.Request) -> web.Response:
|
||||||
secondary_window.get("reset_after_seconds", 0),
|
secondary_window.get("reset_after_seconds", 0),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Trigger background refresh if needed
|
|
||||||
if usage_percent >= USAGE_REFRESH_THRESHOLD:
|
|
||||||
trigger_background_refresh(
|
|
||||||
provider_name,
|
|
||||||
f"usage {usage_percent}% >= threshold {USAGE_REFRESH_THRESHOLD}%",
|
|
||||||
)
|
|
||||||
|
|
||||||
return web.json_response(
|
return web.json_response(
|
||||||
{
|
{
|
||||||
"token": token,
|
"token": token,
|
||||||
"limit": build_limit(usage_percent),
|
"limit": build_limit(usage_percent, prepare_threshold),
|
||||||
"usage": {
|
"usage": {
|
||||||
"primary_window": primary_window,
|
"primary_window": primary_window,
|
||||||
"secondary_window": secondary_window,
|
"secondary_window": secondary_window,
|
||||||
|
|
@ -202,19 +188,42 @@ async def token_handler(request: web.Request) -> web.Response:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def on_startup(app: web.Application):
|
||||||
|
del app
|
||||||
|
for provider_name in PROVIDERS:
|
||||||
|
await ensure_provider_token_ready(provider_name)
|
||||||
|
|
||||||
|
|
||||||
|
async def on_cleanup(app: web.Application):
|
||||||
|
del app
|
||||||
|
for task in background_tasks.values():
|
||||||
|
if task and not task.done():
|
||||||
|
task.cancel()
|
||||||
|
pending = [t for t in background_tasks.values() if t is not None]
|
||||||
|
if pending:
|
||||||
|
await asyncio.gather(*pending, return_exceptions=True)
|
||||||
|
|
||||||
|
|
||||||
def create_app() -> web.Application:
|
def create_app() -> web.Application:
|
||||||
app = web.Application(middlewares=[request_log_middleware])
|
app = web.Application(middlewares=[request_log_middleware])
|
||||||
app.on_startup.append(on_startup)
|
app.on_startup.append(on_startup)
|
||||||
# New route: /{provider}/token
|
app.on_cleanup.append(on_cleanup)
|
||||||
app.router.add_get("/{provider}/token", token_handler)
|
app.router.add_get("/{provider}/token", token_handler)
|
||||||
# Legacy route for backward compatibility
|
|
||||||
app.router.add_get("/token", token_handler)
|
app.router.add_get("/token", token_handler)
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
logger.info("Starting token service on port %s", PORT)
|
logger.info("Starting token service on port %s", PORT)
|
||||||
logger.info("Usage refresh threshold: %s%%", USAGE_REFRESH_THRESHOLD)
|
chatgpt_provider = PROVIDERS.get("chatgpt")
|
||||||
|
if chatgpt_provider:
|
||||||
|
logger.info(
|
||||||
|
"ChatGPT prepare-next threshold: %s%%", chatgpt_provider.prepare_threshold
|
||||||
|
)
|
||||||
|
if chatgpt_provider.switch_threshold is not None:
|
||||||
|
logger.info(
|
||||||
|
"ChatGPT switch threshold: %s%%", chatgpt_provider.switch_threshold
|
||||||
|
)
|
||||||
logger.info("Available providers: %s", ", ".join(PROVIDERS.keys()))
|
logger.info("Available providers: %s", ", ".join(PROVIDERS.keys()))
|
||||||
app = create_app()
|
app = create_app()
|
||||||
web.run_app(app, host="0.0.0.0", port=PORT)
|
web.run_app(app, host="0.0.0.0", port=PORT)
|
||||||
|
|
|
||||||
4
src/utils/__init__.py
Normal file
4
src/utils/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
||||||
|
from .env import parse_int_env
|
||||||
|
from .randoms import generate_password
|
||||||
|
|
||||||
|
__all__ = ["parse_int_env", "generate_password"]
|
||||||
22
src/utils/env.py
Normal file
22
src/utils/env.py
Normal file
|
|
@ -0,0 +1,22 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def parse_int_env(
|
||||||
|
name: str,
|
||||||
|
default: int,
|
||||||
|
minimum: int,
|
||||||
|
maximum: int,
|
||||||
|
) -> int:
|
||||||
|
raw = os.environ.get(name)
|
||||||
|
if raw is None:
|
||||||
|
return default
|
||||||
|
|
||||||
|
try:
|
||||||
|
value = int(raw)
|
||||||
|
except ValueError:
|
||||||
|
return default
|
||||||
|
|
||||||
|
if value < minimum or value > maximum:
|
||||||
|
return default
|
||||||
|
|
||||||
|
return value
|
||||||
13
src/utils/randoms.py
Normal file
13
src/utils/randoms.py
Normal file
|
|
@ -0,0 +1,13 @@
|
||||||
|
import random
|
||||||
|
import secrets
|
||||||
|
import string
|
||||||
|
|
||||||
|
|
||||||
|
def generate_password(
|
||||||
|
length: int = 20,
|
||||||
|
*,
|
||||||
|
secure: bool = True,
|
||||||
|
) -> str:
|
||||||
|
alphabet = string.ascii_letters + string.digits
|
||||||
|
chooser = secrets.choice if secure else random.choice
|
||||||
|
return "".join(chooser(alphabet) for _ in range(length))
|
||||||
12
tests/conftest.py
Normal file
12
tests/conftest.py
Normal file
|
|
@ -0,0 +1,12 @@
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def _add_src_to_path() -> None:
|
||||||
|
root = Path(__file__).resolve().parents[1]
|
||||||
|
src = root / "src"
|
||||||
|
if str(src) not in sys.path:
|
||||||
|
sys.path.insert(0, str(src))
|
||||||
|
|
||||||
|
|
||||||
|
_add_src_to_path()
|
||||||
37
tests/test_registration_unit.py
Normal file
37
tests/test_registration_unit.py
Normal file
|
|
@ -0,0 +1,37 @@
|
||||||
|
from providers.chatgpt.registration import (
|
||||||
|
build_authorize_url,
|
||||||
|
extract_verification_code,
|
||||||
|
generate_birthdate_90s,
|
||||||
|
generate_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_name_shape():
|
||||||
|
name = generate_name()
|
||||||
|
parts = name.split(" ")
|
||||||
|
assert len(parts) == 2
|
||||||
|
assert all(p.isalpha() for p in parts)
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_birthdate_90s_range():
|
||||||
|
month, day, year = generate_birthdate_90s()
|
||||||
|
assert 1 <= int(month) <= 12
|
||||||
|
assert 1 <= int(day) <= 28
|
||||||
|
assert 1990 <= int(year) <= 1999
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_verification_code_prefers_chatgpt_phrase():
|
||||||
|
text = "foo 123456 bar Your ChatGPT code is 654321"
|
||||||
|
assert extract_verification_code(text) == "654321"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_verification_code_fallback_last_code():
|
||||||
|
text = "codes 111111 and 222222"
|
||||||
|
assert extract_verification_code(text) == "222222"
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_authorize_url_contains_required_params():
|
||||||
|
url = build_authorize_url("challenge", "state123")
|
||||||
|
assert "response_type=code" in url
|
||||||
|
assert "code_challenge=challenge" in url
|
||||||
|
assert "state=state123" in url
|
||||||
159
tests/test_server_unit.py
Normal file
159
tests/test_server_unit.py
Normal file
|
|
@ -0,0 +1,159 @@
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
|
||||||
|
import server
|
||||||
|
from providers.base import Provider, ProviderTokens
|
||||||
|
from utils.env import parse_int_env
|
||||||
|
|
||||||
|
|
||||||
|
class FakeRequest:
|
||||||
|
def __init__(self, provider: str):
|
||||||
|
self.match_info = {"provider": provider}
|
||||||
|
|
||||||
|
|
||||||
|
class FakeProvider(Provider):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
token: str | None = "tok",
|
||||||
|
usage: dict | None = None,
|
||||||
|
rotate: bool = False,
|
||||||
|
):
|
||||||
|
self._token = token
|
||||||
|
self._usage = usage or {
|
||||||
|
"used_percent": 10,
|
||||||
|
"remaining_percent": 90,
|
||||||
|
"primary_window": None,
|
||||||
|
"secondary_window": None,
|
||||||
|
}
|
||||||
|
self._rotate = rotate
|
||||||
|
self._prepare_threshold = 80
|
||||||
|
self.get_token_calls = 0
|
||||||
|
self.standby_calls = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prepare_threshold(self) -> int:
|
||||||
|
return self._prepare_threshold
|
||||||
|
|
||||||
|
def should_prepare_standby(self, usage_percent: int) -> bool:
|
||||||
|
return usage_percent >= self.prepare_threshold
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "fake"
|
||||||
|
|
||||||
|
async def get_token(self) -> str | None:
|
||||||
|
self.get_token_calls += 1
|
||||||
|
return self._token
|
||||||
|
|
||||||
|
async def register_new_account(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def get_usage_info(self, access_token: str) -> dict:
|
||||||
|
_ = access_token
|
||||||
|
return dict(self._usage)
|
||||||
|
|
||||||
|
def load_tokens(self) -> ProviderTokens | None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def save_tokens(self, tokens: ProviderTokens) -> None:
|
||||||
|
_ = tokens
|
||||||
|
|
||||||
|
async def maybe_rotate_account(self, usage_percent: int) -> bool:
|
||||||
|
_ = usage_percent
|
||||||
|
return self._rotate
|
||||||
|
|
||||||
|
async def ensure_standby_account(
|
||||||
|
self,
|
||||||
|
usage_percent: int,
|
||||||
|
) -> None:
|
||||||
|
_ = usage_percent
|
||||||
|
self.standby_calls += 1
|
||||||
|
|
||||||
|
|
||||||
|
def _response_json(resp) -> dict:
|
||||||
|
return json.loads(resp.body.decode("utf-8"))
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_int_env_defaults(monkeypatch):
|
||||||
|
monkeypatch.delenv("X_TEST", raising=False)
|
||||||
|
assert parse_int_env("X_TEST", 10, 1, 20) == 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_int_env_invalid(monkeypatch):
|
||||||
|
monkeypatch.setenv("X_TEST", "abc")
|
||||||
|
assert parse_int_env("X_TEST", 10, 1, 20) == 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_int_env_out_of_range(monkeypatch):
|
||||||
|
monkeypatch.setenv("X_TEST", "999")
|
||||||
|
assert parse_int_env("X_TEST", 10, 1, 20) == 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_limit_fields():
|
||||||
|
limit = server.build_limit(90, 85)
|
||||||
|
assert limit == {
|
||||||
|
"used_percent": 90,
|
||||||
|
"remaining_percent": 10,
|
||||||
|
"exhausted": False,
|
||||||
|
"needs_prepare": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_prepare_threshold():
|
||||||
|
assert (
|
||||||
|
server.get_prepare_threshold("chatgpt")
|
||||||
|
== server.PROVIDERS["chatgpt"].prepare_threshold
|
||||||
|
)
|
||||||
|
assert server.get_prepare_threshold("unknown") == 100
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_handler_unknown_provider(monkeypatch):
|
||||||
|
monkeypatch.setattr(server, "PROVIDERS", {})
|
||||||
|
resp = asyncio.run(server.token_handler(FakeRequest("missing")))
|
||||||
|
assert resp.status == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_handler_success(monkeypatch):
|
||||||
|
provider = FakeProvider()
|
||||||
|
monkeypatch.setattr(server, "PROVIDERS", {"fake": provider})
|
||||||
|
monkeypatch.setattr(server, "background_tasks", {"fake": None})
|
||||||
|
resp = asyncio.run(server.token_handler(FakeRequest("fake")))
|
||||||
|
data = _response_json(resp)
|
||||||
|
|
||||||
|
assert resp.status == 200
|
||||||
|
assert data["token"] == "tok"
|
||||||
|
assert data["limit"]["needs_prepare"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_handler_triggers_standby(monkeypatch):
|
||||||
|
provider = FakeProvider(usage={"used_percent": 90, "remaining_percent": 10})
|
||||||
|
monkeypatch.setattr(server, "PROVIDERS", {"fake": provider})
|
||||||
|
monkeypatch.setattr(server, "background_tasks", {"fake": None})
|
||||||
|
|
||||||
|
called = {"value": False}
|
||||||
|
|
||||||
|
def fake_trigger(name, usage_percent, reason):
|
||||||
|
assert name == "fake"
|
||||||
|
assert usage_percent == 90
|
||||||
|
assert "standby policy" in reason
|
||||||
|
called["value"] = True
|
||||||
|
|
||||||
|
monkeypatch.setattr(server, "trigger_standby_prepare", fake_trigger)
|
||||||
|
|
||||||
|
resp = asyncio.run(server.token_handler(FakeRequest("fake")))
|
||||||
|
assert resp.status == 200
|
||||||
|
assert called["value"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_handler_rotation_path(monkeypatch):
|
||||||
|
provider = FakeProvider(
|
||||||
|
usage={"used_percent": 96, "remaining_percent": 4},
|
||||||
|
rotate=True,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(server, "PROVIDERS", {"fake": provider})
|
||||||
|
monkeypatch.setattr(server, "background_tasks", {"fake": None})
|
||||||
|
monkeypatch.setattr(server, "trigger_standby_prepare", lambda *_: None)
|
||||||
|
|
||||||
|
resp = asyncio.run(server.token_handler(FakeRequest("fake")))
|
||||||
|
assert resp.status == 200
|
||||||
|
assert provider.get_token_calls >= 2
|
||||||
60
tests/test_tokens_unit.py
Normal file
60
tests/test_tokens_unit.py
Normal file
|
|
@ -0,0 +1,60 @@
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from providers.base import ProviderTokens
|
||||||
|
from providers.chatgpt import tokens as t
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_state_backward_compatible():
|
||||||
|
raw = {"access_token": "a", "refresh_token": "r", "expires_at": 1}
|
||||||
|
normalized = t._normalize_state(raw)
|
||||||
|
assert normalized["active"]["access_token"] == "a"
|
||||||
|
assert normalized["next_account"] is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_promote_next_tokens(tmp_path, monkeypatch):
|
||||||
|
file_path = tmp_path / "chatgpt_tokens.json"
|
||||||
|
monkeypatch.setattr(t, "TOKENS_FILE", file_path)
|
||||||
|
|
||||||
|
active = ProviderTokens("a1", "r1", 100)
|
||||||
|
nxt = ProviderTokens("a2", "r2", 200)
|
||||||
|
t.save_state(active, nxt)
|
||||||
|
|
||||||
|
assert t.promote_next_tokens() is True
|
||||||
|
cur, next_cur = t.load_state()
|
||||||
|
assert cur is not None
|
||||||
|
assert cur.access_token == "a2"
|
||||||
|
assert next_cur is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_tokens_preserves_next(tmp_path, monkeypatch):
|
||||||
|
file_path = tmp_path / "chatgpt_tokens.json"
|
||||||
|
monkeypatch.setattr(t, "TOKENS_FILE", file_path)
|
||||||
|
|
||||||
|
active = ProviderTokens("a1", "r1", 100)
|
||||||
|
nxt = ProviderTokens("a2", "r2", 200)
|
||||||
|
t.save_state(active, nxt)
|
||||||
|
|
||||||
|
t.save_tokens(ProviderTokens("a3", "r3", 300))
|
||||||
|
cur, next_cur = t.load_state()
|
||||||
|
assert cur is not None and cur.access_token == "a3"
|
||||||
|
assert next_cur is not None and next_cur.access_token == "a2"
|
||||||
|
|
||||||
|
|
||||||
|
def test_atomic_write_produces_valid_json(tmp_path, monkeypatch):
|
||||||
|
file_path = tmp_path / "chatgpt_tokens.json"
|
||||||
|
monkeypatch.setattr(t, "TOKENS_FILE", file_path)
|
||||||
|
|
||||||
|
t.save_state(ProviderTokens("x", "y", 123), None)
|
||||||
|
with open(file_path) as f:
|
||||||
|
data = json.load(f)
|
||||||
|
assert "active" in data
|
||||||
|
assert data["active"]["access_token"] == "x"
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_state_from_missing_file(tmp_path, monkeypatch):
|
||||||
|
file_path = tmp_path / "missing.json"
|
||||||
|
monkeypatch.setattr(t, "TOKENS_FILE", file_path)
|
||||||
|
active, nxt = t.load_state()
|
||||||
|
assert active is None
|
||||||
|
assert nxt is None
|
||||||
32
tests/test_usage_unit.py
Normal file
32
tests/test_usage_unit.py
Normal file
|
|
@ -0,0 +1,32 @@
|
||||||
|
from providers.chatgpt.usage import _parse_window, clamp_percent
|
||||||
|
|
||||||
|
|
||||||
|
def test_clamp_percent_bounds():
|
||||||
|
assert clamp_percent(-1) == 0
|
||||||
|
assert clamp_percent(150) == 100
|
||||||
|
assert clamp_percent(49.6) == 50
|
||||||
|
|
||||||
|
|
||||||
|
def test_clamp_percent_invalid():
|
||||||
|
assert clamp_percent(None) == 0
|
||||||
|
assert clamp_percent("bad") == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_window_valid():
|
||||||
|
window = {
|
||||||
|
"used_percent": 34.4,
|
||||||
|
"limit_window_seconds": 3600,
|
||||||
|
"reset_after_seconds": 120,
|
||||||
|
"reset_at": 999,
|
||||||
|
}
|
||||||
|
parsed = _parse_window(window)
|
||||||
|
assert parsed == {
|
||||||
|
"used_percent": 34,
|
||||||
|
"limit_window_seconds": 3600,
|
||||||
|
"reset_after_seconds": 120,
|
||||||
|
"reset_at": 999,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_window_none():
|
||||||
|
assert _parse_window(None) is None
|
||||||
68
uv.lock
generated
68
uv.lock
generated
|
|
@ -83,6 +83,15 @@ wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/3a/2a/7cc015f5b9f5db42b7d48157e23356022889fc354a2813c15934b7cb5c0e/attrs-25.4.0-py3-none-any.whl", hash = "sha256:adcf7e2a1fb3b36ac48d97835bb6d8ade15b8dcce26aba8bf1d14847b57a3373", size = 67615, upload-time = "2025-10-06T13:54:43.17Z" },
|
{ url = "https://files.pythonhosted.org/packages/3a/2a/7cc015f5b9f5db42b7d48157e23356022889fc354a2813c15934b7cb5c0e/attrs-25.4.0-py3-none-any.whl", hash = "sha256:adcf7e2a1fb3b36ac48d97835bb6d8ade15b8dcce26aba8bf1d14847b57a3373", size = 67615, upload-time = "2025-10-06T13:54:43.17Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "colorama"
|
||||||
|
version = "0.4.6"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "frozenlist"
|
name = "frozenlist"
|
||||||
version = "1.8.0"
|
version = "1.8.0"
|
||||||
|
|
@ -158,6 +167,15 @@ wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008, upload-time = "2025-10-12T14:55:18.883Z" },
|
{ url = "https://files.pythonhosted.org/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008, upload-time = "2025-10-12T14:55:18.883Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "iniconfig"
|
||||||
|
version = "2.3.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "megapt"
|
name = "megapt"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
|
|
@ -168,12 +186,19 @@ dependencies = [
|
||||||
{ name = "playwright" },
|
{ name = "playwright" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[package.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
{ name = "pytest" },
|
||||||
|
]
|
||||||
|
|
||||||
[package.metadata]
|
[package.metadata]
|
||||||
requires-dist = [
|
requires-dist = [
|
||||||
{ name = "aiohttp", specifier = "==3.13.3" },
|
{ name = "aiohttp", specifier = "==3.13.3" },
|
||||||
{ name = "pkce", specifier = "==1.0.3" },
|
{ name = "pkce", specifier = "==1.0.3" },
|
||||||
{ name = "playwright", specifier = "==1.58.0" },
|
{ name = "playwright", specifier = "==1.58.0" },
|
||||||
|
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" },
|
||||||
]
|
]
|
||||||
|
provides-extras = ["dev"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "multidict"
|
name = "multidict"
|
||||||
|
|
@ -220,6 +245,15 @@ wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/81/08/7036c080d7117f28a4af526d794aab6a84463126db031b007717c1a6676e/multidict-6.7.1-py3-none-any.whl", hash = "sha256:55d97cc6dae627efa6a6e548885712d4864b81110ac76fa4e534c03819fa4a56", size = 12319, upload-time = "2026-01-26T02:46:44.004Z" },
|
{ url = "https://files.pythonhosted.org/packages/81/08/7036c080d7117f28a4af526d794aab6a84463126db031b007717c1a6676e/multidict-6.7.1-py3-none-any.whl", hash = "sha256:55d97cc6dae627efa6a6e548885712d4864b81110ac76fa4e534c03819fa4a56", size = 12319, upload-time = "2026-01-26T02:46:44.004Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "packaging"
|
||||||
|
version = "26.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/65/ee/299d360cdc32edc7d2cf530f3accf79c4fca01e96ffc950d8a52213bd8e4/packaging-26.0.tar.gz", hash = "sha256:00243ae351a257117b6a241061796684b084ed1c516a08c48a3f7e147a9d80b4", size = 143416, upload-time = "2026-01-21T20:50:39.064Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/b7/b9/c538f279a4e237a006a2c98387d081e9eb060d203d8ed34467cc0f0b9b53/packaging-26.0-py3-none-any.whl", hash = "sha256:b36f1fef9334a5588b4166f8bcd26a14e521f2b55e6b9de3aaa80d3ff7a37529", size = 74366, upload-time = "2026-01-21T20:50:37.788Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pkce"
|
name = "pkce"
|
||||||
version = "1.0.3"
|
version = "1.0.3"
|
||||||
|
|
@ -248,6 +282,15 @@ wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/c8/c4/cc0229fea55c87d6c9c67fe44a21e2cd28d1d558a5478ed4d617e9fb0c93/playwright-1.58.0-py3-none-win_arm64.whl", hash = "sha256:32ffe5c303901a13a0ecab91d1c3f74baf73b84f4bedbb6b935f5bc11cc98e1b", size = 33085919, upload-time = "2026-01-30T15:09:45.71Z" },
|
{ url = "https://files.pythonhosted.org/packages/c8/c4/cc0229fea55c87d6c9c67fe44a21e2cd28d1d558a5478ed4d617e9fb0c93/playwright-1.58.0-py3-none-win_arm64.whl", hash = "sha256:32ffe5c303901a13a0ecab91d1c3f74baf73b84f4bedbb6b935f5bc11cc98e1b", size = 33085919, upload-time = "2026-01-30T15:09:45.71Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pluggy"
|
||||||
|
version = "1.6.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "propcache"
|
name = "propcache"
|
||||||
version = "0.4.1"
|
version = "0.4.1"
|
||||||
|
|
@ -299,6 +342,31 @@ wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/a0/c4/b4d4827c93ef43c01f599ef31453ccc1c132b353284fc6c87d535c233129/pyee-13.0.1-py3-none-any.whl", hash = "sha256:af2f8fede4171ef667dfded53f96e2ed0d6e6bd7ee3bb46437f77e3b57689228", size = 15659, upload-time = "2026-02-14T21:12:26.263Z" },
|
{ url = "https://files.pythonhosted.org/packages/a0/c4/b4d4827c93ef43c01f599ef31453ccc1c132b353284fc6c87d535c233129/pyee-13.0.1-py3-none-any.whl", hash = "sha256:af2f8fede4171ef667dfded53f96e2ed0d6e6bd7ee3bb46437f77e3b57689228", size = 15659, upload-time = "2026-02-14T21:12:26.263Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pygments"
|
||||||
|
version = "2.19.2"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pytest"
|
||||||
|
version = "9.0.2"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||||
|
{ name = "iniconfig" },
|
||||||
|
{ name = "packaging" },
|
||||||
|
{ name = "pluggy" },
|
||||||
|
{ name = "pygments" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/d1/db/7ef3487e0fb0049ddb5ce41d3a49c235bf9ad299b6a25d5780a89f19230f/pytest-9.0.2.tar.gz", hash = "sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11", size = 1568901, upload-time = "2025-12-06T21:30:51.014Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "typing-extensions"
|
name = "typing-extensions"
|
||||||
version = "4.15.0"
|
version = "4.15.0"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue