refactor!: change the entire purpose of this script
This commit is contained in:
parent
217e176975
commit
71d1050adb
20 changed files with 1124 additions and 872 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
|
@ -1,2 +1,4 @@
|
|||
/data/
|
||||
__pycache__/
|
||||
.ruff_cache/
|
||||
.venv/
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ RUN pip install --no-cache-dir uv
|
|||
RUN uv sync --frozen --no-dev
|
||||
RUN /app/.venv/bin/python -m playwright install --with-deps chromium
|
||||
|
||||
COPY *.py /app/
|
||||
COPY src/*.py /app/
|
||||
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
ENV PORT=8000
|
||||
108
src/browser.py
Normal file
108
src/browser.py
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import urllib.request
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from playwright.async_api import Browser, Playwright
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CHROME_FLAGS = [
|
||||
"--no-startup-window",
|
||||
"--disable-field-trial-config",
|
||||
"--disable-background-networking",
|
||||
"--disable-background-timer-throttling",
|
||||
"--disable-backgrounding-occluded-windows",
|
||||
"--disable-back-forward-cache",
|
||||
"--disable-breakpad",
|
||||
"--disable-client-side-phishing-detection",
|
||||
"--disable-component-extensions-with-background-pages",
|
||||
"--disable-component-update",
|
||||
"--no-default-browser-check",
|
||||
"--disable-default-apps",
|
||||
"--disable-dev-shm-usage",
|
||||
"--disable-extensions",
|
||||
"--disable-popup-blocking",
|
||||
"--disable-prompt-on-repost",
|
||||
"--disable-renderer-backgrounding",
|
||||
"--disable-hang-monitor",
|
||||
"--disable-ipc-flooding-protection",
|
||||
"--force-color-profile=srgb",
|
||||
"--metrics-recording-only",
|
||||
"--no-first-run",
|
||||
"--password-store=basic",
|
||||
"--use-mock-keychain",
|
||||
"--disable-infobars",
|
||||
"--disable-sync",
|
||||
"--enable-unsafe-swiftshader",
|
||||
"--no-sandbox",
|
||||
"--disable-search-engine-choice-screen",
|
||||
]
|
||||
|
||||
|
||||
def _fetch_ws_endpoint(port: int) -> str | None:
|
||||
try:
|
||||
with urllib.request.urlopen(
|
||||
f"http://127.0.0.1:{port}/json/version",
|
||||
timeout=1,
|
||||
) as resp:
|
||||
data = json.loads(resp.read().decode("utf-8"))
|
||||
return data.get("webSocketDebuggerUrl")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
class ManagedBrowser:
    """Bundle of a CDP-connected Playwright browser, the underlying Chrome
    process, and the throwaway profile directory it was launched with."""

    # Playwright handle obtained via connect_over_cdp().
    browser: Browser
    # The raw Chrome process started by launch().
    process: subprocess.Popen
    # Temp user-data-dir created for this Chrome instance.
    profile_dir: Path

    async def close(self) -> None:
        """Tear everything down: disconnect the Playwright browser
        (best-effort), stop the Chrome process (terminate, then kill after
        5s), and delete the temp profile directory."""
        try:
            await self.browser.close()
        except Exception:
            # The browser may already be gone; teardown must continue.
            pass
        self.process.terminate()
        try:
            self.process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            # Chrome ignored SIGTERM; force it.
            self.process.kill()
        if self.profile_dir.exists():
            shutil.rmtree(self.profile_dir, ignore_errors=True)
|
||||
|
||||
|
||||
async def launch(playwright: Playwright, cdp_port: int | None = None) -> ManagedBrowser:
    """Start Chrome with remote debugging enabled and connect Playwright to it.

    Args:
        playwright: Active Playwright driver, used to resolve the bundled
            Chromium executable and open the CDP connection.
        cdp_port: DevTools port; defaults to $CDP_PORT or 9222.

    Returns:
        A ManagedBrowser owning the connected browser, the Chrome process,
        and the temporary profile directory.

    Raises:
        RuntimeError: if the CDP websocket never becomes reachable.
    """
    chrome_path = os.environ.get("CHROMIUM_PATH") or playwright.chromium.executable_path
    cdp_port = cdp_port or int(os.environ.get("CDP_PORT", "9222"))
    profile_dir = Path(tempfile.mkdtemp(prefix="megapt_profile-", dir="/tmp"))

    args = [
        chrome_path,
        *CHROME_FLAGS,
        f"--user-data-dir={profile_dir}",
        f"--remote-debugging-port={cdp_port}",
    ]

    proc = subprocess.Popen(args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

    try:
        # Chrome needs a moment to open the DevTools socket; poll up to ~30s.
        ws_endpoint = None
        for _ in range(60):
            ws_endpoint = await asyncio.to_thread(_fetch_ws_endpoint, cdp_port)
            if ws_endpoint:
                break
            await asyncio.sleep(0.5)

        if not ws_endpoint:
            raise RuntimeError(f"CDP websocket not available on port {cdp_port}")

        logger.info("CDP websocket: %s", ws_endpoint)
        browser = await playwright.chromium.connect_over_cdp(ws_endpoint)
    except Exception:
        # Fix: the original only terminate()d on the timeout path (never
        # wait()ed/kill()ed, leaving a potential zombie), never cleaned up
        # when connect_over_cdp failed, and always leaked profile_dir.
        proc.terminate()
        try:
            proc.wait(timeout=5)
        except subprocess.TimeoutExpired:
            proc.kill()
        shutil.rmtree(profile_dir, ignore_errors=True)
        raise

    return ManagedBrowser(browser=browser, process=proc, profile_dir=profile_dir)
|
||||
5
src/email_providers/__init__.py
Normal file
5
src/email_providers/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
from .base import BaseProvider
|
||||
from .ten_minute_mail import TenMinuteMailProvider
|
||||
from .temp_mail_org import TempMailOrgProvider
|
||||
|
||||
__all__ = ["BaseProvider", "TenMinuteMailProvider", "TempMailOrgProvider"]
|
||||
16
src/email_providers/base.py
Normal file
16
src/email_providers/base.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
from abc import ABC, abstractmethod
|
||||
|
||||
from playwright.async_api import BrowserContext
|
||||
|
||||
|
||||
class BaseProvider(ABC):
    """Abstract interface for disposable-email backends.

    Implementations drive a specific temp-mail website through the shared
    Playwright browser context passed to the constructor.
    """

    def __init__(self, browser_session: BrowserContext):
        # Shared Playwright context; implementations open their own tabs in it.
        self.browser_session = browser_session

    @abstractmethod
    async def get_new_email(self) -> str:
        """Obtain a fresh disposable email address."""
        pass

    @abstractmethod
    async def get_latest_message(self, email: str) -> str | None:
        """Return the newest message text for *email*, or None if nothing arrived."""
        pass
|
||||
125
src/email_providers/temp_mail_org.py
Normal file
125
src/email_providers/temp_mail_org.py
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
|
||||
from playwright.async_api import BrowserContext, Page
|
||||
|
||||
from .base import BaseProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TempMailOrgProvider(BaseProvider):
    """Disposable-email provider backed by the temp-mail.org web UI,
    driven via Playwright scraping."""

    def __init__(self, browser_session: BrowserContext):
        super().__init__(browser_session)
        # Lazily created tab, reused across calls (see _ensure_page).
        self.page: Page | None = None

    async def _ensure_page(self) -> Page:
        """Return this provider's tab, opening a new one if missing or closed."""
        if self.page is None or self.page.is_closed():
            self.page = await self.browser_session.new_page()
        return self.page

    async def get_new_email(self) -> str:
        """Open temp-mail.org and return the generated throwaway address.

        Polls a set of address-field selectors (and, as a fallback, scans
        the whole page body for anything email-shaped) for up to 60s.

        Raises:
            RuntimeError: if no address appears within the deadline.
        """
        page = await self._ensure_page()
        logger.info("[temp-mail.org] Opening mailbox page")
        await page.goto("https://temp-mail.org", wait_until="domcontentloaded")
        await page.locator("input#mail, #mail, input[value*='@']").first.wait_for(
            state="visible",
            timeout=30000,
        )

        selectors = ["#mail", "input#mail", "input[value*='@']"]
        end_at = asyncio.get_running_loop().time() + 60
        while asyncio.get_running_loop().time() < end_at:
            await page.bring_to_front()
            for selector in selectors:
                try:
                    field = page.locator(selector).first
                    if await field.is_visible(timeout=1000):
                        value = (await field.input_value()).strip()
                        if "@" in value:
                            logger.info(
                                "[temp-mail.org] selector matched: %s -> %s",
                                selector,
                                value,
                            )
                            return value
                except Exception:
                    # Selector may not exist yet while the page hydrates.
                    continue

            try:
                # Fallback: regex-scan the visible text for an address.
                body = await page.inner_text("body")
                found = extract_email(body)
                if found:
                    logger.info("[temp-mail.org] email found by body scan: %s", found)
                    return found
            except Exception:
                pass

            await asyncio.sleep(1)

        raise RuntimeError("Could not get temp email from temp-mail.org")

    async def get_latest_message(self, email: str) -> str | None:
        """Wait for the newest inbox message and return its text, or None.

        Polls the inbox list for ~60s (30 attempts x 2s). When an item
        appears it is clicked open; if the rendered page contains the
        ChatGPT verification phrase, the full HTML is returned instead of
        the list snippet so callers can regex the code out of it.
        """
        page = await self._ensure_page()
        logger.info("[temp-mail.org] Waiting for latest message for %s", email)

        if page.is_closed():
            raise RuntimeError("temp-mail.org tab was closed unexpectedly")

        await page.bring_to_front()

        items = page.locator("div.inbox-dataList ul li")

        # temp-mail updates inbox via websocket; do not refresh/reload page.
        for attempt in range(30):
            try:
                count = await items.count()
                logger.info("[temp-mail.org] inbox items: %s", count)
            except Exception:
                count = 0

            if count > 0:
                # Presumably newest messages are at the end, hence reversed()
                # — TODO confirm against the live site.
                # NOTE(review): `text` is only assigned inside the try; if an
                # item raises after a previous iteration set it, the stale
                # value could be reused by `if text:` — confirm intended.
                for idx in reversed(range(count)):
                    try:
                        item = items.nth(idx)
                        if not await item.is_visible(timeout=1000):
                            continue
                        text = (await item.inner_text()).strip().replace("\n", " ")
                        logger.info("[temp-mail.org] item[%s]: %s", idx, text[:160])
                    except Exception:
                        continue
                    if text:
                        try:
                            await item.click()
                            logger.info("[temp-mail.org] opened item[%s]", idx)
                        except Exception:
                            pass

                        message_text = text
                        try:
                            # Prefer the opened message's full HTML when it
                            # contains the verification phrase.
                            content = await page.content()
                            if content and "Your ChatGPT code is" in content:
                                message_text = content
                        except Exception:
                            pass

                        try:
                            # Return to the inbox so later polls still work.
                            await page.go_back(
                                wait_until="domcontentloaded", timeout=5000
                            )
                            logger.info("[temp-mail.org] returned back to inbox")
                        except Exception:
                            pass

                        return message_text

            await asyncio.sleep(2)

        logger.warning("[temp-mail.org] No messages received within 60 seconds")
        return None
|
||||
|
||||
|
||||
def extract_email(text: str) -> str | None:
|
||||
match = re.search(r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}", text)
|
||||
return match.group(0) if match else None
|
||||
100
src/email_providers/ten_minute_mail.py
Normal file
100
src/email_providers/ten_minute_mail.py
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
import asyncio
|
||||
import logging
|
||||
|
||||
from playwright.async_api import BrowserContext, Page
|
||||
|
||||
from .base import BaseProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TenMinuteMailProvider(BaseProvider):
    """Disposable-email provider backed by 10minutemail.com, read through
    its in-page JSON endpoints rather than DOM scraping."""

    def __init__(self, browser_session: BrowserContext):
        super().__init__(browser_session)
        # Lazily created tab, reused across calls (see _ensure_page).
        self.page: Page | None = None

    async def _ensure_page(self) -> Page:
        """Return this provider's tab, opening a new one if missing or closed."""
        if self.page is None or self.page.is_closed():
            self.page = await self.browser_session.new_page()
        return self.page

    async def get_new_email(self) -> str:
        """Open 10minutemail.com and return the pre-generated address.

        Raises:
            RuntimeError: if the address field is empty or not email-shaped.
        """
        page = await self._ensure_page()
        logger.info("[10min] Opening https://10minutemail.com")
        await page.goto("https://10minutemail.com", wait_until="domcontentloaded")
        await page.wait_for_timeout(3000)

        email_input = page.locator("#mail_address")
        await email_input.first.wait_for(state="visible", timeout=60000)

        email = (await email_input.first.input_value()).strip()
        if not email or "@" not in email:
            raise RuntimeError("10MinuteMail did not return a valid email")

        logger.info("[10min] New email acquired: %s", email)
        return email

    async def get_latest_message(self, email: str) -> str | None:
        """Poll the site's message endpoints and return the newest message.

        Up to 60 attempts (~2s apart); fetches via in-page `fetch` so the
        session cookies are reused. Returns subject/sender/body joined by
        newlines, or None on timeout. The page is reloaded every third
        attempt to keep the session fresh.
        """
        page = await self._ensure_page()
        logger.info("[10min] Waiting for latest message for %s", email)

        seen_count = 0
        for attempt in range(60):
            try:
                # Ask the site how many messages have arrived.
                count = await page.evaluate(
                    """
                    async () => {
                        const response = await fetch('/messages/messageCount', { credentials: 'include' });
                        const data = await response.json();
                        return Number(data.messageCount || 0);
                    }
                    """
                )
            except Exception:
                count = 0

            if count > 0:
                if count != seen_count:
                    logger.info("[10min] Inbox has %s message(s)", count)
                    seen_count = count

                try:
                    # Fetch every message; the last list entry is treated as
                    # the newest (assumes endpoint returns oldest-first —
                    # TODO confirm).
                    messages = await page.evaluate(
                        """
                        async () => {
                            const response = await fetch('/messages/messagesAfter/0', { credentials: 'include' });
                            const data = await response.json();
                            return Array.isArray(data) ? data : [];
                        }
                        """
                    )
                except Exception:
                    messages = []

                text = ""
                if messages:
                    latest = messages[-1]
                    subject = str(latest.get("subject") or "")
                    sender = str(latest.get("sender") or "")
                    body_plain = str(latest.get("bodyPlainText") or "")
                    body_html = str(latest.get("bodyHtmlContent") or "")
                    text = "\n".join(
                        part
                        for part in [subject, sender, body_plain, body_html]
                        if part
                    )

                if text:
                    logger.info("[10min] Latest message received")
                    return text

            if attempt % 3 == 0:
                try:
                    await page.reload(wait_until="domcontentloaded", timeout=60000)
                except Exception:
                    pass

            await asyncio.sleep(2)

        logger.warning("[10min] No messages received within timeout")
        return None
|
||||
|
|
@ -12,4 +12,4 @@ cleanup() {
|
|||
|
||||
trap cleanup EXIT INT TERM
|
||||
|
||||
exec /app/.venv/bin/python -u proxy.py
|
||||
exec /app/.venv/bin/python -u server.py
|
||||
|
|
|
|||
|
|
@ -1,397 +0,0 @@
|
|||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import random
|
||||
import secrets
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
import aiohttp
|
||||
import pkce
|
||||
from urllib.parse import urlencode, urlparse, parse_qs
|
||||
from playwright.async_api import async_playwright, Page, Browser
|
||||
from tokens import DATA_DIR, TOKENS_FILE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
AUTHORIZE_URL = "https://auth.openai.com/oauth/authorize"
|
||||
TOKEN_URL = "https://auth.openai.com/oauth/token"
|
||||
REDIRECT_URI = "http://localhost:1455/auth/callback"
|
||||
SCOPE = "openid profile email offline_access"
|
||||
|
||||
|
||||
class AutomationError(Exception):
    """Registration-flow failure tagged with the step name and, optionally,
    the Page to screenshot for debugging."""

    def __init__(self, step: str, message: str, page: Page | None = None):
        self.step = step        # pipeline step identifier, e.g. "get_code"
        self.message = message  # human-readable failure description
        self.page = page        # page to screenshot in the error handler, if any
        super().__init__(f"[{step}] {message}")
|
||||
|
||||
|
||||
async def save_error_screenshot(page: Page | None, step: str):
    """Save a debugging screenshot of *page* under DATA_DIR/screenshots.

    Best-effort: no-op when *page* is None, and screenshot failures are
    swallowed (the page may already be closed).
    """
    if page:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        screenshots_dir = DATA_DIR / "screenshots"
        screenshots_dir.mkdir(parents=True, exist_ok=True)
        filename = screenshots_dir / f"error_{step}_{timestamp}.png"
        try:
            await page.screenshot(path=str(filename))
            # Fix: actually log the saved path (the f-string had no
            # placeholder) and narrow the bare `except:` so
            # KeyboardInterrupt/SystemExit propagate.
            logger.error("Screenshot saved: %s", filename)
        except Exception:
            pass
|
||||
|
||||
|
||||
def generate_pkce():
    """Return a (code_verifier, code_challenge) pair via the `pkce` package."""
    return pkce.generate_pkce_pair()
|
||||
|
||||
|
||||
def generate_state():
    """Return a cryptographically random, URL-safe OAuth `state` value."""
    return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
def create_auth_url(verifier: str, challenge: str, state: str) -> str:
    """Build the OpenAI OAuth authorize URL for the PKCE flow.

    Note: *verifier* is accepted for call-site symmetry but only
    *challenge* and *state* appear in the URL; the verifier is used later
    during the token exchange.
    """
    params = {
        "response_type": "code",
        "client_id": CLIENT_ID,
        "redirect_uri": REDIRECT_URI,
        "scope": SCOPE,
        "code_challenge": challenge,
        "code_challenge_method": "S256",
        "state": state,
        "id_token_add_organizations": "true",
        "codex_cli_simplified_flow": "true",
        "originator": "opencode",
    }
    return f"{AUTHORIZE_URL}?{urlencode(params)}"
|
||||
|
||||
|
||||
async def get_temp_email(page: Page) -> str:
    """Poll temp-mail.org's address field (#mail) for up to ~30s and return
    the generated email address.

    Raises:
        AutomationError: when no address appears within the deadline.
    """
    logger.info("Getting temp email...")
    for i in range(30):
        mail_input = page.locator("#mail")
        if await mail_input.count() > 0:
            val = await mail_input.input_value()
            if val and "@" in val:
                logger.info(f"Got email: {val}")
                return val
        await page.wait_for_timeout(1000)
    raise AutomationError("get_email", "Failed to get email", page)
|
||||
|
||||
|
||||
async def get_verification_code(page: Page, used_codes: list | None = None) -> str:
    """Poll the temp-mail inbox for a 6-digit ChatGPT code not in *used_codes*.

    Waits 10s for mail to arrive, then scans the inbox up to 20 times,
    reloading the mailbox between attempts.

    Args:
        page: The temp-mail.org tab.
        used_codes: Codes already consumed earlier in the flow; matching
            codes in this list are skipped.

    Raises:
        AutomationError: when no fresh code shows up.
    """
    logger.info("Waiting for verification code...")
    if used_codes is None:
        used_codes = []
    await page.wait_for_timeout(10000)

    for attempt in range(20):
        mail_items = page.locator(".inbox-dataList ul li")
        count = await mail_items.count()
        logger.debug(f"Attempt {attempt + 1}: {count} emails")

        if count > 0:
            codes = []
            for i in range(count):
                try:
                    item = mail_items.nth(i)
                    text = await item.inner_text()
                    match = re.search(
                        r"Your ChatGPT code is (\d{6})", text, re.IGNORECASE
                    )
                    if match:
                        code = match.group(1)
                        if code not in used_codes:
                            codes.append(code)
                # Fix: narrowed from a bare `except:` so KeyboardInterrupt
                # and SystemExit are no longer swallowed.
                except Exception:
                    pass

            if codes:
                logger.info(f"Got code: {codes[0]}")
                return codes[0]

        await page.wait_for_timeout(5000)
        await page.reload(wait_until="domcontentloaded")
        await page.wait_for_timeout(5000)

    raise AutomationError("get_code", "Code not found", page)
|
||||
|
||||
|
||||
async def fill_date_field(page: Page, month: str, day: str, year: str):
    """Type a birth date into ChatGPT's segmented date widget.

    Each segment ([data-type="month"/"day"/"year"]) is focused, cleared
    with Ctrl+A/Backspace, and retyped.

    Raises:
        AutomationError: when a segment is missing (profile UI changed).
    """
    async def type_segment(segment_type: str, value: str):
        # Locate the segment for this part of the date.
        field = page.locator(f'[data-type="{segment_type}"]')
        if await field.count() == 0:
            raise AutomationError(
                "profile", f"Missing birthday segment: {segment_type}", page
            )

        target = field.first
        await target.scroll_into_view_if_needed()
        await target.focus()
        # Clear any prefilled value before typing.
        await page.keyboard.press("Control+A")
        await page.keyboard.press("Backspace")
        await page.keyboard.type(value)
        await page.wait_for_timeout(200)

    await type_segment("month", month)
    await type_segment("day", day)
    await type_segment("year", year)
|
||||
|
||||
|
||||
def generate_name():
    """Return a random 'First Last' display name drawn from fixed lists."""
    given_names = [
        "Alex", "Jordan", "Taylor", "Morgan", "Casey",
        "Riley", "Quinn", "Avery", "Parker", "Blake",
    ]
    family_names = [
        "Smith", "Johnson", "Williams", "Brown", "Jones",
        "Davis", "Miller", "Wilson", "Moore", "Clark",
    ]
    given = random.choice(given_names)
    family = random.choice(family_names)
    return " ".join((given, family))
|
||||
|
||||
|
||||
async def exchange_code_for_tokens(code: str, verifier: str) -> dict:
    """Swap the OAuth authorization *code* (plus PKCE *verifier*) for tokens.

    Returns:
        A dict with "access_token", "refresh_token", and "expires_in".

    Raises:
        Exception: on a non-2xx response from the token endpoint.
    """
    async with aiohttp.ClientSession() as session:
        data = {
            "grant_type": "authorization_code",
            "client_id": CLIENT_ID,
            "code": code,
            "code_verifier": verifier,
            "redirect_uri": REDIRECT_URI,
        }

        async with session.post(TOKEN_URL, data=data) as resp:
            if not resp.ok:
                text = await resp.text()
                raise Exception(f"Token exchange failed: {resp.status} {text}")

            json_resp = await resp.json()
            return {
                "access_token": json_resp["access_token"],
                "refresh_token": json_resp["refresh_token"],
                "expires_in": json_resp["expires_in"],
            }
|
||||
|
||||
|
||||
async def get_new_token(headless: bool = False) -> bool:
    """Run the full automated ChatGPT signup + OAuth flow and persist tokens.

    Six stages: get a temp email, register a ChatGPT account with it,
    confirm the email verification code, fill the profile, complete the
    PKCE OAuth flow (capturing the localhost redirect), and exchange the
    auth code for tokens written to TOKENS_FILE.

    Args:
        headless: Run the browser headless. Defaults to False.

    Returns:
        True when tokens were saved, False on any handled failure
        (a screenshot is saved via save_error_screenshot).
    """
    logger.info("=== Starting token generation ===")

    # NOTE(review): hard-coded password/birthday for every account — confirm
    # this is acceptable for these throwaway registrations.
    password = "TempPass123!"
    full_name = generate_name()
    birth_month, birth_day, birth_year = "01", "15", "1995"

    verifier, challenge = generate_pkce()
    state = generate_state()
    auth_url = create_auth_url(verifier, challenge, state)

    redirect_url_captured = None
    browser: Browser | None = None
    current_page: Page | None = None  # tracked for error screenshots

    try:
        async with async_playwright() as p:
            chromium_path = os.environ.get("CHROMIUM_PATH")
            if chromium_path:
                browser = await p.chromium.launch(
                    headless=headless,
                    executable_path=chromium_path,
                )
            else:
                browser = await p.chromium.launch(headless=headless)
            context = await browser.new_context()
            page = await context.new_page()
            current_page = page

            logger.info("[1/6] Getting email...")
            await page.goto("https://temp-mail.org", wait_until="domcontentloaded")
            email = await get_temp_email(page)
            tempmail_page = page

            logger.info("[2/6] Registering ChatGPT...")
            chatgpt_page = await context.new_page()
            current_page = chatgpt_page
            await chatgpt_page.goto("https://chatgpt.com")
            await chatgpt_page.wait_for_load_state("domcontentloaded")

            await chatgpt_page.get_by_text("Sign up for free", exact=True).click()
            await chatgpt_page.wait_for_timeout(2000)

            await chatgpt_page.locator('input[type="email"]').fill(email)
            await chatgpt_page.wait_for_timeout(500)
            await chatgpt_page.get_by_role(
                "button", name="Continue", exact=True
            ).click()
            await chatgpt_page.wait_for_timeout(3000)

            await chatgpt_page.locator('input[type="password"]').fill(password)
            await chatgpt_page.wait_for_timeout(500)
            await chatgpt_page.get_by_role(
                "button", name="Continue", exact=True
            ).click()
            await chatgpt_page.wait_for_timeout(5000)

            logger.info("[3/6] Getting verification code...")
            await tempmail_page.bring_to_front()
            code = await get_verification_code(tempmail_page)

            await chatgpt_page.bring_to_front()
            code_input = chatgpt_page.get_by_placeholder("Code")
            if await code_input.count() > 0:
                await code_input.fill(code)
                await chatgpt_page.wait_for_timeout(5000)

            continue_btn = chatgpt_page.get_by_role(
                "button", name="Continue", exact=True
            )
            if await continue_btn.count() > 0:
                await continue_btn.click()
                await chatgpt_page.wait_for_timeout(5000)

            logger.info("[4/6] Setting profile...")
            name_input = chatgpt_page.get_by_placeholder("Full name")
            if await name_input.count() > 0:
                await name_input.fill(full_name)

            await chatgpt_page.wait_for_timeout(500)
            await fill_date_field(chatgpt_page, birth_month, birth_day, birth_year)
            await chatgpt_page.wait_for_timeout(1000)

            continue_btn = chatgpt_page.get_by_role(
                "button", name="Continue", exact=True
            )
            if await continue_btn.count() > 0:
                await continue_btn.click()

            logger.info("Account registered!")
            await chatgpt_page.wait_for_timeout(10000)
            await chatgpt_page.wait_for_load_state("networkidle", timeout=30000)
            await chatgpt_page.wait_for_timeout(5000)

            # Remember the signup code so the OAuth stage ignores it.
            used_codes = [code]

            logger.info("[5/6] OAuth flow...")
            oauth_page = await context.new_page()
            current_page = oauth_page

            # Capture the localhost redirect carrying the auth code; the
            # request itself will fail (nothing listens on 1455) but the
            # URL is all we need.
            def handle_request(request):
                nonlocal redirect_url_captured
                url = request.url
                if "localhost:1455" in url and "code=" in url:
                    logger.info("Redirect URL captured!")
                    redirect_url_captured = url

            oauth_page.on("request", handle_request)

            await oauth_page.goto(auth_url)
            await oauth_page.wait_for_load_state("domcontentloaded")
            await oauth_page.wait_for_timeout(3000)

            await oauth_page.locator('input[type="email"], input[name="email"]').fill(
                email
            )
            await oauth_page.wait_for_timeout(500)
            await oauth_page.get_by_role("button", name="Continue", exact=True).click()
            await oauth_page.wait_for_timeout(3000)

            password_input = oauth_page.locator('input[type="password"]')
            if await password_input.count() > 0:
                await password_input.fill(password)
                await oauth_page.wait_for_timeout(500)
                await oauth_page.get_by_role(
                    "button", name="Continue", exact=True
                ).click()
                await oauth_page.wait_for_timeout(5000)

            await tempmail_page.bring_to_front()
            await tempmail_page.reload(wait_until="domcontentloaded")
            await tempmail_page.wait_for_timeout(3000)

            try:
                oauth_code = await get_verification_code(tempmail_page, used_codes)
            except AutomationError:
                # Mailbox tab may have gone stale; reopen it and retry once.
                logger.info("Reopening mail...")
                tempmail_page = await context.new_page()
                current_page = tempmail_page
                await tempmail_page.goto(
                    "https://temp-mail.org", wait_until="domcontentloaded"
                )
                await tempmail_page.wait_for_timeout(10000)
                oauth_code = await get_verification_code(tempmail_page, used_codes)

            await oauth_page.bring_to_front()
            code_input = oauth_page.get_by_placeholder("Code")
            if await code_input.count() > 0:
                await code_input.fill(oauth_code)
                await oauth_page.wait_for_timeout(500)
                await oauth_page.get_by_role(
                    "button", name="Continue", exact=True
                ).click()
                await oauth_page.wait_for_timeout(5000)

            # Final consent screen: click whichever confirm button exists.
            for btn_text in ["Continue", "Allow", "Authorize"]:
                btn = oauth_page.get_by_role("button", name=btn_text, exact=True)
                if await btn.count() > 0:
                    await btn.click()
                    break

            await oauth_page.wait_for_timeout(5000)

            logger.info("[6/6] Exchanging code for tokens...")
            if redirect_url_captured and "code=" in redirect_url_captured:
                parsed = urlparse(redirect_url_captured)
                params = parse_qs(parsed.query)
                auth_code = params.get("code", [None])[0]

                if auth_code:
                    tokens = await exchange_code_for_tokens(auth_code, verifier)

                    token_data = {
                        "access_token": tokens["access_token"],
                        "refresh_token": tokens["refresh_token"],
                        "expires_at": time.time() + tokens["expires_in"],
                    }
                    TOKENS_FILE.parent.mkdir(parents=True, exist_ok=True)
                    with open(TOKENS_FILE, "w") as f:
                        json.dump(token_data, f, indent=2)

                    logger.info(f"Tokens saved to {TOKENS_FILE}")
                    return True

            raise AutomationError("token_exchange", "Failed to get tokens", oauth_page)

    except AutomationError as e:
        logger.error(f"Error at step [{e.step}]: {e.message}")
        await save_error_screenshot(e.page, e.step)
        return False
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        await save_error_screenshot(current_page, "unexpected")
        return False
    finally:
        if browser:
            await asyncio.sleep(2)
            await browser.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: run one registration attempt and exit 0 on success.
    logging.basicConfig(level=logging.INFO)
    success = asyncio.run(get_new_token())
    exit(0 if success else 1)
|
||||
3
src/providers/__init__.py
Normal file
3
src/providers/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from .base import Provider, ProviderTokens
|
||||
|
||||
__all__ = ["Provider", "ProviderTokens"]
|
||||
54
src/providers/base.py
Normal file
54
src/providers/base.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderTokens:
|
||||
"""Base token structure for any provider"""
|
||||
|
||||
access_token: str
|
||||
refresh_token: str | None
|
||||
expires_at: float
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
import time
|
||||
|
||||
return time.time() >= self.expires_at - 10
|
||||
|
||||
|
||||
class Provider(ABC):
    """Base class for all account providers.

    A provider knows how to register accounts for a service, persist the
    resulting tokens, and report usage for the active token.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Provider name (e.g., 'chatgpt', 'claude')"""
        pass

    @abstractmethod
    async def get_token(self) -> str | None:
        """Get valid access token, refreshing if needed"""
        pass

    @abstractmethod
    async def register_new_account(self) -> bool:
        """Register a new account and get tokens"""
        pass

    @abstractmethod
    async def get_usage_info(self, access_token: str) -> dict[str, Any]:
        """Get usage information for the current token"""
        pass

    @abstractmethod
    def load_tokens(self) -> ProviderTokens | None:
        """Load tokens from storage"""
        pass

    @abstractmethod
    def save_tokens(self, tokens: ProviderTokens) -> None:
        """Save tokens to storage"""
        pass
|
||||
3
src/providers/chatgpt/__init__.py
Normal file
3
src/providers/chatgpt/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from .provider import ChatGPTProvider
|
||||
|
||||
__all__ = ["ChatGPTProvider"]
|
||||
68
src/providers/chatgpt/provider.py
Normal file
68
src/providers/chatgpt/provider.py
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
import logging
|
||||
from typing import Callable
|
||||
from typing import Any
|
||||
|
||||
from playwright.async_api import BrowserContext
|
||||
|
||||
from providers.base import Provider, ProviderTokens
|
||||
from email_providers import BaseProvider
|
||||
from email_providers import TempMailOrgProvider
|
||||
from .tokens import load_tokens, save_tokens, get_valid_tokens
|
||||
from .usage import get_usage_percent
|
||||
from .registration import register_chatgpt_account
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatGPTProvider(Provider):
    """ChatGPT account provider: wires together token storage, usage
    lookup, and automated account registration."""

    def __init__(
        self,
        email_provider_factory: Callable[[BrowserContext], BaseProvider] | None = None,
    ):
        # Factory used during registration to create the disposable-email
        # backend; defaults to the temp-mail.org implementation.
        self.email_provider_factory = email_provider_factory or TempMailOrgProvider

    @property
    def name(self) -> str:
        """Provider identifier used for storage/routing."""
        return "chatgpt"

    async def get_token(self) -> str | None:
        """Get valid access token, refreshing if needed.

        Falls back to registering a brand-new account when no valid tokens
        are available; returns None if that also fails.
        """
        tokens = await get_valid_tokens()
        if not tokens:
            logger.info("No valid tokens, registering new account")
            success = await self.register_new_account()
            if not success:
                return None
            tokens = await get_valid_tokens()
            if not tokens:
                return None
        return tokens.access_token

    async def register_new_account(self) -> bool:
        """Register a new ChatGPT account via the configured email backend."""
        return await register_chatgpt_account(
            email_provider_factory=self.email_provider_factory,
        )

    async def get_usage_info(self, access_token: str) -> dict[str, Any]:
        """Get usage information for the current token.

        Returns {"error": ...} when the lookup fails (negative percent).
        NOTE(review): get_usage_percent is called without await — presumably
        a synchronous helper; confirm against .usage.
        """
        usage_percent = get_usage_percent(access_token)
        if usage_percent < 0:
            return {"error": "Failed to get usage"}

        remaining = max(0, 100 - usage_percent)
        return {
            "used_percent": usage_percent,
            "remaining_percent": remaining,
            "exhausted": usage_percent >= 100,
        }

    def load_tokens(self) -> ProviderTokens | None:
        """Load tokens from storage"""
        return load_tokens()

    def save_tokens(self, tokens: ProviderTokens) -> None:
        """Save tokens to storage"""
        save_tokens(tokens)
|
||||
477
src/providers/chatgpt/registration.py
Normal file
477
src/providers/chatgpt/registration.py
Normal file
|
|
@ -0,0 +1,477 @@
|
|||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
import secrets
|
||||
import string
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import os
|
||||
from typing import Callable
|
||||
from urllib.parse import parse_qs, urlencode, urlparse
|
||||
|
||||
import aiohttp
|
||||
from playwright.async_api import async_playwright, Page, BrowserContext
|
||||
|
||||
from browser import launch as launch_browser
|
||||
from email_providers import BaseProvider
|
||||
from providers.base import ProviderTokens
|
||||
from .tokens import CLIENT_ID, save_tokens
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DATA_DIR = Path(os.environ.get("DATA_DIR", "./data"))
|
||||
AUTHORIZE_URL = "https://auth.openai.com/oauth/authorize"
|
||||
TOKEN_URL = "https://auth.openai.com/oauth/token"
|
||||
REDIRECT_URI = "http://localhost:1455/auth/callback"
|
||||
SCOPE = "openid profile email offline_access"
|
||||
|
||||
|
||||
class AutomationError(Exception):
    """Registration-flow failure tagged with the step name and, optionally,
    the Page to screenshot for debugging."""

    def __init__(self, step: str, message: str, page: Page | None = None):
        self.step = step        # pipeline step identifier, e.g. "profile"
        self.message = message  # human-readable failure description
        self.page = page        # page to screenshot in the error handler, if any
        super().__init__(f"[{step}] {message}")
|
||||
|
||||
|
||||
async def save_error_screenshot(page: Page | None, step: str) -> None:
    """Best-effort screenshot of *page* into ``DATA_DIR/screenshots``.

    Used for post-mortem debugging of failed automation steps. Never raises:
    a screenshot failure must not mask the original automation error.
    """
    if page is None:
        return
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    screenshots_dir = DATA_DIR / "screenshots"
    screenshots_dir.mkdir(parents=True, exist_ok=True)
    filename = screenshots_dir / f"error_{step}_{timestamp}.png"
    try:
        await page.screenshot(path=str(filename))
        # Fix: the log line previously did not include the actual path.
        logger.error("Screenshot saved: %s", filename)
    except Exception:
        # Best-effort only; swallow screenshot failures deliberately
        # (but not BaseException, so task cancellation still propagates).
        pass
|
||||
|
||||
|
||||
def generate_password(length: int = 20) -> str:
    """Return a random alphanumeric password of *length* characters.

    Uses ``secrets.choice`` rather than ``random.choice``: this password
    protects a real account, so it must come from a CSPRNG, not the
    deterministic Mersenne Twister in ``random``.
    """
    alphabet = string.ascii_letters + string.digits
    return "".join(secrets.choice(alphabet) for _ in range(length))
|
||||
|
||||
|
||||
def generate_name() -> str:
    """Return a plausible "First Last" name drawn from common US names."""
    given_names = (
        "James", "John", "Robert", "Michael", "William",
        "David", "Richard", "Joseph", "Thomas", "Charles",
        "Christopher", "Daniel", "Matthew", "Anthony", "Mark",
        "Donald", "Steven", "Paul", "Andrew", "Joshua",
    )
    family_names = (
        "Smith", "Johnson", "Williams", "Brown", "Jones",
        "Garcia", "Miller", "Davis", "Rodriguez", "Martinez",
        "Hernandez", "Lopez", "Gonzalez", "Wilson", "Anderson",
        "Thomas", "Taylor", "Moore", "Jackson", "Martin",
    )
    return f"{random.choice(given_names)} {random.choice(family_names)}"
|
||||
|
||||
|
||||
def extract_verification_code(message: str) -> str | None:
|
||||
normalized = re.sub(r"\s+", " ", message)
|
||||
|
||||
preferred = re.search(
|
||||
r"Your\s+ChatGPT\s+code\s+is\s*(\d{6})",
|
||||
normalized,
|
||||
re.IGNORECASE,
|
||||
)
|
||||
if preferred:
|
||||
return preferred.group(1)
|
||||
|
||||
openai_otp = re.search(r"OpenAI\s+otp.*?(\d{6})", normalized, re.IGNORECASE)
|
||||
if openai_otp:
|
||||
return openai_otp.group(1)
|
||||
|
||||
all_codes = re.findall(r"\b(\d{6})\b", normalized)
|
||||
if all_codes:
|
||||
return all_codes[-1]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def generate_pkce_pair() -> tuple[str, str]:
    """Return a PKCE ``(verifier, challenge)`` pair (RFC 7636, S256)."""
    code_verifier = secrets.token_urlsafe(64)
    hashed = hashlib.sha256(code_verifier.encode("utf-8")).digest()
    # Challenge is base64url without padding, per the PKCE spec.
    code_challenge = base64.urlsafe_b64encode(hashed).decode("utf-8").rstrip("=")
    return code_verifier, code_challenge
|
||||
|
||||
|
||||
def generate_state() -> str:
    """Return an unguessable OAuth ``state`` value (CSRF protection)."""
    return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
def build_authorize_url(verifier: str, challenge: str, state: str) -> str:
    """Build the OpenAI OAuth authorize URL for the PKCE flow.

    *verifier* is accepted for signature symmetry with
    ``generate_pkce_pair`` but is intentionally unused: only the S256
    *challenge* travels on the URL.
    """
    del verifier  # unused by design; never send the verifier to the server
    query = urlencode(
        {
            "response_type": "code",
            "client_id": CLIENT_ID,
            "redirect_uri": REDIRECT_URI,
            "scope": SCOPE,
            "code_challenge": challenge,
            "code_challenge_method": "S256",
            "id_token_add_organizations": "true",
            "codex_cli_simplified_flow": "true",
            "state": state,
            "originator": "opencode",
        }
    )
    return f"{AUTHORIZE_URL}?{query}"
|
||||
|
||||
|
||||
async def exchange_code_for_tokens(code: str, verifier: str) -> ProviderTokens:
    """Exchange an OAuth authorization code for access/refresh tokens.

    Args:
        code: The ``code`` query parameter captured from the OAuth redirect.
        verifier: The PKCE verifier matching the challenge sent earlier.

    Returns:
        ProviderTokens with an absolute unix ``expires_at`` timestamp.

    Raises:
        RuntimeError: If the token endpoint responds with a non-2xx status.
    """
    async with aiohttp.ClientSession() as session:
        payload = {
            "grant_type": "authorization_code",
            "client_id": CLIENT_ID,
            "code": code,
            "code_verifier": verifier,
            "redirect_uri": REDIRECT_URI,
        }
        async with session.post(TOKEN_URL, data=payload) as resp:
            if not resp.ok:
                text = await resp.text()
                raise RuntimeError(f"Token exchange failed: {resp.status} {text}")
            body = await resp.json()

        # The server returns a relative lifetime; store an absolute deadline.
        expires_in = int(body["expires_in"])
        return ProviderTokens(
            access_token=body["access_token"],
            refresh_token=body["refresh_token"],
            expires_at=time.time() + expires_in,
        )
|
||||
|
||||
|
||||
async def get_new_verification_code(
    email_provider: BaseProvider,
    email: str,
    used_codes: set[str],
    timeout_seconds: int = 240,
) -> str | None:
    """Poll the inbox until a 6-digit code not in *used_codes* appears.

    Polls every 5 seconds for roughly *timeout_seconds* total; returns the
    first unseen code found, or None on timeout.
    """
    polls = max(1, timeout_seconds // 5)
    for _ in range(polls):
        body = await email_provider.get_latest_message(email)
        if body:
            fresh = [c for c in re.findall(r"\b(\d{6})\b", body) if c not in used_codes]
            if fresh:
                return fresh[0]
        await asyncio.sleep(5)
    return None
|
||||
|
||||
|
||||
async def get_latest_code(email_provider: BaseProvider, email: str) -> str | None:
    """Fetch the newest inbox message and extract a 6-digit code, if any."""
    body = await email_provider.get_latest_message(email)
    return extract_verification_code(body) if body else None
|
||||
|
||||
|
||||
async def fill_date_field(page: Page, month: str, day: str, year: str):
    """Fill ChatGPT's segmented birthday widget by typing ``MMDDYYYY``.

    Clicking the month segment focuses the widget; the digits are then typed
    in one burst (the widget appears to advance through month/day/year as
    digits arrive — NOTE(review): confirm against the live signup form).

    Raises:
        AutomationError: If the month segment cannot be located.
    """
    month_field = page.locator('[data-type="month"]').first
    if await month_field.count() == 0:
        raise AutomationError("profile", "Missing birthday month field", page)

    await month_field.scroll_into_view_if_needed()
    await month_field.click()
    await page.wait_for_timeout(120)  # brief pause to let focus settle

    await page.keyboard.type(f"{month}{day}{year}")
    await page.wait_for_timeout(200)
|
||||
|
||||
|
||||
async def wait_for_signup_stabilization(page: Page):
    """Give the post-signup page time to settle before the next step.

    Prefers waiting for network idle; if the page keeps background requests
    open, falls back to DOM-content-loaded. Always ends with a fixed delay.
    """
    try:
        await page.wait_for_load_state("networkidle", timeout=15000)
    except Exception:
        logger.warning(
            "Signup page did not reach networkidle quickly; continuing with fallback"
        )
        try:
            await page.wait_for_load_state("domcontentloaded", timeout=10000)
        except Exception:
            pass
    await page.wait_for_timeout(3000)
|
||||
|
||||
|
||||
async def register_chatgpt_account(
    email_provider_factory: Callable[[BrowserContext], BaseProvider] | None = None,
) -> bool:
    """Register a fresh ChatGPT account and capture OAuth tokens for it.

    Drives a real browser through signup (email, password, verification
    code, profile), dismisses onboarding dialogs, then runs the PKCE OAuth
    flow and persists the resulting tokens via ``save_tokens``.

    Fix: the four bare ``except:`` clauses in the onboarding loops became
    ``except Exception:`` — in async code a bare except also swallows
    ``asyncio.CancelledError``, breaking task cancellation.

    Args:
        email_provider_factory: Builds a disposable-email provider bound to
            the browser context. Required; returning False otherwise.

    Returns:
        True on success, False on any failure (an error screenshot is saved
        when a page is available).
    """
    logger.info("=== Starting ChatGPT account registration ===")

    if email_provider_factory is None:
        logger.error("No email provider factory configured")
        return False

    # Fixed, innocuous birthday used for every registration.
    birth_month, birth_day, birth_year = "01", "15", "1995"

    current_page: Page | None = None
    redirect_url_captured: str | None = None
    managed = None

    try:
        async with async_playwright() as p:
            managed = await launch_browser(p)
            browser = managed.browser
            context = (
                browser.contexts[0] if browser.contexts else await browser.new_context()
            )
            email_provider = email_provider_factory(context)

            logger.info("[1/6] Getting new email from configured provider...")
            email = await email_provider.get_new_email()
            if not email:
                raise AutomationError(
                    "email_provider", "Email provider returned empty email"
                )

            password = generate_password()
            full_name = generate_name()
            verifier, challenge = generate_pkce_pair()
            oauth_state = generate_state()
            authorize_url = build_authorize_url(verifier, challenge, oauth_state)

            logger.info("[2/6] Registering ChatGPT for %s", email)
            chatgpt_page = await context.new_page()
            current_page = chatgpt_page
            await chatgpt_page.goto("https://chatgpt.com")
            await chatgpt_page.wait_for_load_state("domcontentloaded")

            await chatgpt_page.get_by_text("Sign up for free", exact=True).click()
            await chatgpt_page.wait_for_timeout(2000)

            await chatgpt_page.locator('input[type="email"]').fill(email)
            await chatgpt_page.wait_for_timeout(500)
            await chatgpt_page.get_by_role(
                "button", name="Continue", exact=True
            ).click()
            await chatgpt_page.wait_for_timeout(3000)

            await chatgpt_page.locator('input[type="password"]').fill(password)
            await chatgpt_page.wait_for_timeout(500)
            await chatgpt_page.get_by_role(
                "button", name="Continue", exact=True
            ).click()
            await chatgpt_page.wait_for_timeout(5000)

            logger.info("[3/6] Getting verification message from email provider...")
            code = await get_latest_code(email_provider, email)
            if not code:
                raise AutomationError(
                    "email_provider", "Email provider returned no verification message"
                )
            logger.info("[3/6] Verification code extracted: %s", code)
            used_codes = {code}

            await chatgpt_page.bring_to_front()
            code_input = chatgpt_page.get_by_placeholder("Code")
            if await code_input.count() > 0:
                await code_input.fill(code)
                await chatgpt_page.wait_for_timeout(5000)

            continue_btn = chatgpt_page.get_by_role(
                "button", name="Continue", exact=True
            )
            if await continue_btn.count() > 0:
                await continue_btn.click()
                await chatgpt_page.wait_for_timeout(5000)

            logger.info("[4/6] Setting profile...")
            name_input = chatgpt_page.get_by_placeholder("Full name")
            if await name_input.count() > 0:
                await name_input.fill(full_name)

            await chatgpt_page.wait_for_timeout(500)
            await fill_date_field(chatgpt_page, birth_month, birth_day, birth_year)
            await chatgpt_page.wait_for_timeout(1000)

            continue_btn = chatgpt_page.get_by_role(
                "button", name="Continue", exact=True
            )
            if await continue_btn.count() > 0:
                await continue_btn.click()

            logger.info("Account registered!")
            await wait_for_signup_stabilization(chatgpt_page)

            logger.info("[5/6] Skipping onboarding...")

            # Onboarding shows a variable number of "Skip" dialogs; sweep a
            # bounded number of passes rather than waiting on a fixed set.
            for _ in range(5):
                skip_btn = chatgpt_page.locator(
                    'button:has-text("Skip"):not(:has-text("Skip Tour"))'
                )
                if await skip_btn.count() > 0:
                    for i in range(await skip_btn.count()):
                        try:
                            btn = skip_btn.nth(i)
                            if await btn.is_visible():
                                await btn.click(timeout=5000)
                                logger.info("Clicked: Skip")
                                await chatgpt_page.wait_for_timeout(1500)
                        except Exception:
                            # Buttons can detach as dialogs close; ignore.
                            pass
                await chatgpt_page.wait_for_timeout(1000)

            skip_tour = chatgpt_page.locator('button:has-text("Skip Tour")')
            if await skip_tour.count() > 0:
                try:
                    await skip_tour.first.wait_for(state="visible", timeout=5000)
                    await skip_tour.first.click(timeout=5000)
                    logger.info("Clicked: Skip Tour")
                    await chatgpt_page.wait_for_timeout(2000)
                except Exception:
                    pass

            await chatgpt_page.wait_for_timeout(2000)

            for _ in range(3):
                continue_btn = chatgpt_page.locator('button:has-text("Continue")')
                if await continue_btn.count() > 0:
                    try:
                        await continue_btn.first.wait_for(state="visible", timeout=5000)
                        await continue_btn.first.click(timeout=5000)
                        logger.info("Clicked: Continue")
                        await chatgpt_page.wait_for_timeout(2000)
                    except Exception:
                        pass

            await chatgpt_page.wait_for_timeout(2000)

            okay_btn = chatgpt_page.locator('button:has-text("Okay, let")')
            for _ in range(10):
                try:
                    await okay_btn.first.wait_for(state="visible", timeout=3000)
                    await okay_btn.first.click(timeout=5000)
                    logger.info("Clicked: Okay, let's go")
                    await chatgpt_page.wait_for_timeout(3000)
                    break
                except Exception:
                    await chatgpt_page.wait_for_timeout(1000)

            logger.info("Skipping subscription/card flow (disabled)")
            await chatgpt_page.wait_for_timeout(2000)

            logger.info("[6/6] Running OAuth flow to get tokens...")
            oauth_page = await context.new_page()
            current_page = oauth_page

            def handle_request(request):
                # No server listens on :1455, so the authorization code must
                # be sniffed from the outgoing redirect request itself.
                nonlocal redirect_url_captured
                url = request.url
                if "localhost:1455" in url and "code=" in url:
                    redirect_url_captured = url
                    logger.info("Captured OAuth redirect URL")

            oauth_page.on("request", handle_request)

            await oauth_page.goto(authorize_url, wait_until="domcontentloaded")
            await oauth_page.wait_for_timeout(2000)

            email_input = oauth_page.locator('input[type="email"], input[name="email"]')
            if await email_input.count() > 0:
                await email_input.first.fill(email)
                await oauth_page.wait_for_timeout(400)

            continue_button = oauth_page.get_by_role("button", name="Continue")
            if await continue_button.count() > 0:
                await continue_button.first.click()
                await oauth_page.wait_for_timeout(2500)

            password_input = oauth_page.locator('input[type="password"]')
            if await password_input.count() > 0:
                await password_input.first.fill(password)
                await oauth_page.wait_for_timeout(400)
                continue_button = oauth_page.get_by_role("button", name="Continue")
                if await continue_button.count() > 0:
                    await continue_button.first.click()
                    await oauth_page.wait_for_timeout(2500)

            # Consent screens vary; click whichever confirmation appears.
            for label in ["Continue", "Allow", "Authorize"]:
                button = oauth_page.get_by_role("button", name=label)
                if await button.count() > 0:
                    try:
                        await button.first.click(timeout=5000)
                        await oauth_page.wait_for_timeout(2000)
                    except Exception:
                        pass

            if not redirect_url_captured:
                # The request listener can miss the navigation; fall back to
                # reading the final page URL.
                try:
                    await oauth_page.wait_for_timeout(4000)
                    current_url = oauth_page.url
                    if "localhost:1455" in current_url and "code=" in current_url:
                        redirect_url_captured = current_url
                        logger.info("Captured OAuth redirect from page URL")
                except Exception:
                    pass

            if not redirect_url_captured:
                raise AutomationError(
                    "oauth", "OAuth redirect with code was not captured", oauth_page
                )

            parsed = urlparse(redirect_url_captured)
            params = parse_qs(parsed.query)
            auth_code = params.get("code", [None])[0]
            returned_state = params.get("state", [None])[0]

            if not auth_code:
                raise AutomationError(
                    "oauth", "OAuth code missing in redirect", oauth_page
                )
            if returned_state != oauth_state:
                # State mismatch means the redirect was not ours (CSRF guard).
                raise AutomationError("oauth", "OAuth state mismatch", oauth_page)

            tokens = await exchange_code_for_tokens(auth_code, verifier)
            save_tokens(tokens)
            logger.info("OAuth tokens saved successfully")

            return True

    except AutomationError as e:
        logger.error(f"Error at step [{e.step}]: {e.message}")
        await save_error_screenshot(e.page, e.step)
        return False
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        await save_error_screenshot(current_page, "unexpected")
        return False
    finally:
        if managed:
            # Brief grace period so in-flight browser work can finish.
            await asyncio.sleep(2)
            await managed.close()
|
||||
|
|
@ -1,46 +1,34 @@
|
|||
import json
|
||||
import time
|
||||
import os
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
import aiohttp
|
||||
from pathlib import Path
|
||||
|
||||
from providers.base import ProviderTokens
|
||||
|
||||
DATA_DIR = Path(os.environ.get("DATA_DIR", "./data"))
|
||||
TOKENS_FILE = DATA_DIR / "tokens.json"
|
||||
TOKENS_FILE = DATA_DIR / "chatgpt_tokens.json"
|
||||
|
||||
CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
TOKEN_URL = "https://auth.openai.com/oauth/token"
|
||||
|
||||
|
||||
@dataclass
class Tokens:
    """OAuth token bundle persisted to TOKENS_FILE."""

    access_token: str  # bearer token presented to the upstream API
    refresh_token: str  # used to obtain a fresh access token
    expires_at: float  # unix timestamp

    @property
    def is_expired(self) -> bool:
        # 10-second safety margin so a token is never used right at expiry.
        return time.time() >= self.expires_at - 10
|
||||
|
||||
|
||||
def load_tokens() -> ProviderTokens | None:
    """Read persisted tokens from TOKENS_FILE.

    Fix: ``except json.JSONDecodeError, KeyError:`` is Python-2 syntax and a
    SyntaxError in Python 3 — the exception types must be a tuple.

    Returns:
        ProviderTokens, or None when the file is missing, is not valid JSON,
        or lacks a required key.
    """
    if not TOKENS_FILE.exists():
        return None
    try:
        with open(TOKENS_FILE) as f:
            data = json.load(f)
        return ProviderTokens(
            access_token=data["access_token"],
            refresh_token=data["refresh_token"],
            expires_at=data["expires_at"],
        )
    except (json.JSONDecodeError, KeyError):
        return None
|
||||
|
||||
|
||||
def save_tokens(tokens: Tokens):
|
||||
def save_tokens(tokens: ProviderTokens):
|
||||
TOKENS_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(TOKENS_FILE, "w") as f:
|
||||
json.dump(
|
||||
|
|
@ -54,7 +42,7 @@ def save_tokens(tokens: Tokens):
|
|||
)
|
||||
|
||||
|
||||
async def refresh_tokens(refresh_token: str) -> Tokens | None:
|
||||
async def refresh_tokens(refresh_token: str) -> ProviderTokens | None:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
data = {
|
||||
"grant_type": "refresh_token",
|
||||
|
|
@ -68,14 +56,14 @@ async def refresh_tokens(refresh_token: str) -> Tokens | None:
|
|||
return None
|
||||
json_resp = await resp.json()
|
||||
expires_in = json_resp["expires_in"]
|
||||
return Tokens(
|
||||
return ProviderTokens(
|
||||
access_token=json_resp["access_token"],
|
||||
refresh_token=json_resp["refresh_token"],
|
||||
expires_at=time.time() + expires_in,
|
||||
)
|
||||
|
||||
|
||||
async def get_valid_tokens() -> Tokens | None:
|
||||
async def get_valid_tokens() -> ProviderTokens | None:
|
||||
tokens = load_tokens()
|
||||
if not tokens:
|
||||
print("No tokens found")
|
||||
|
|
@ -8,7 +8,7 @@ from typing import Any
|
|||
def clamp_percent(value: Any) -> int:
|
||||
try:
|
||||
num = float(value)
|
||||
except (TypeError, ValueError):
|
||||
except TypeError, ValueError:
|
||||
return 0
|
||||
if num < 0:
|
||||
return 0
|
||||
|
|
@ -35,7 +35,7 @@ def get_usage_percent(access_token: str, timeout_ms: int = 10000) -> int:
|
|||
body = res.read().decode("utf-8", errors="replace")
|
||||
except urllib.error.HTTPError as e:
|
||||
return -1
|
||||
except (urllib.error.URLError, socket.timeout):
|
||||
except urllib.error.URLError, socket.timeout:
|
||||
return -1
|
||||
|
||||
try:
|
||||
|
|
@ -48,14 +48,3 @@ def get_usage_percent(access_token: str, timeout_ms: int = 10000) -> int:
|
|||
return clamp_percent(primary.get("used_percent") or 0)
|
||||
|
||||
return -1
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI smoke test: print the current usage percentage for stored tokens.
    from tokens import load_tokens

    tokens = load_tokens()
    if tokens:
        usage = get_usage_percent(tokens.access_token)
        print(f"{usage}%")
    else:
        print("No tokens")
|
||||
436
src/proxy.py
436
src/proxy.py
|
|
@ -1,436 +0,0 @@
|
|||
import os
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import secrets
|
||||
import json
|
||||
import base64
|
||||
import uuid
|
||||
from urllib.parse import urlencode
|
||||
from aiohttp import web
|
||||
import aiohttp
|
||||
|
||||
from tokens import get_valid_tokens, load_tokens, DATA_DIR
|
||||
from codex_usage import get_usage_percent
|
||||
from get_new_token import get_new_token
|
||||
|
||||
CODEX_BASE_URL = "https://chatgpt.com/backend-api"
|
||||
PORT = int(os.environ.get("PORT", "8080"))
|
||||
USAGE_THRESHOLD = int(os.environ.get("USAGE_THRESHOLD", "85"))
|
||||
CHECK_INTERVAL = int(os.environ.get("CHECK_INTERVAL", "60"))
|
||||
FAKE_EXPIRES_IN = 9999999999999
|
||||
AUTH_FILE = DATA_DIR / "auth.json"
|
||||
JWT_AUTH_CLAIM_PATH = "https://api.openai.com/auth"
|
||||
JWT_PROFILE_CLAIM_PATH = "https://api.openai.com/profile"
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
refresh_in_progress = False
|
||||
auth_codes: dict[str, dict] = {}
|
||||
|
||||
|
||||
def _b64url(data: bytes) -> str:
|
||||
return base64.urlsafe_b64encode(data).decode("utf-8").rstrip("=")
|
||||
|
||||
|
||||
def _generate_jwt_like() -> str:
    """Fabricate a JWT-shaped string used as the proxy's client credential.

    The token is NOT signed by any real authority — the signature segment is
    random bytes. It only needs to look like an OpenAI access token so JWT-
    aware clients accept it; the proxy itself validates it by plain string
    equality (see check_auth).
    """
    account_id = str(uuid.uuid4())
    now = int(time.time())
    header = {"alg": "HS256", "typ": "JWT"}
    user_id = f"user-{secrets.token_urlsafe(18)}"
    account_user_id = f"{user_id}__{account_id}"
    payload = {
        "aud": ["https://api.openai.com/v1"],
        "client_id": "app_EMoamEEZ73f0CkXaXp7hrann",
        "iss": "https://auth.openai.com",
        "iat": now,
        "nbf": now,
        "exp": now + 315360000,  # ~10 years from issue
        "jti": str(uuid.uuid4()),
        "scp": ["openid", "profile", "email", "offline_access"],
        "session_id": f"authsess_{secrets.token_urlsafe(24)}",
        JWT_AUTH_CLAIM_PATH: {
            "chatgpt_account_id": account_id,
            "chatgpt_account_user_id": account_user_id,
            "chatgpt_compute_residency": "no_constraint",
            "chatgpt_plan_type": "plus",
            "chatgpt_user_id": user_id,
            "user_id": user_id,
        },
        JWT_PROFILE_CLAIM_PATH: {
            "email": f"proxy-{secrets.token_hex(4)}@example.local",
            "email_verified": True,
        },
        "sub": f"auth0|{secrets.token_urlsafe(20)}",
    }
    # Compact JSON (no spaces) matches real JWT encoding.
    head = _b64url(json.dumps(header, separators=(",", ":")).encode("utf-8"))
    body = _b64url(json.dumps(payload, separators=(",", ":")).encode("utf-8"))
    sign = _b64url(secrets.token_bytes(32))
    return f"{head}.{body}.{sign}"
|
||||
|
||||
|
||||
def _generate_refresh_like() -> str:
|
||||
return f"rt_{secrets.token_urlsafe(40)}.{secrets.token_urlsafe(32)}"
|
||||
|
||||
|
||||
def _mask(value: str, head: int = 8, tail: int = 6) -> str:
|
||||
if not value:
|
||||
return "<empty>"
|
||||
if len(value) <= head + tail:
|
||||
return "<hidden>"
|
||||
return f"{value[:head]}...{value[-tail:]}"
|
||||
|
||||
|
||||
def load_or_create_auth() -> dict:
    """Load the proxy's client-facing credentials, creating them on first run.

    Returns:
        Dict with ``access_token`` / ``refresh_token`` / ``expires_at``. An
        existing AUTH_FILE is reused only when all three fields are present;
        otherwise a fresh fabricated credential set is generated and saved.
    """
    if AUTH_FILE.exists():
        with open(AUTH_FILE) as f:
            data = json.load(f)
        if (
            data.get("access_token")
            and data.get("refresh_token")
            and data.get("expires_at")
        ):
            return data

    DATA_DIR.mkdir(parents=True, exist_ok=True)
    access_token = _generate_jwt_like()

    data = {
        "access_token": access_token,
        "refresh_token": _generate_refresh_like(),
        "expires_at": FAKE_EXPIRES_IN,
    }
    with open(AUTH_FILE, "w") as f:
        json.dump(data, f, indent=2)
    return data
|
||||
|
||||
|
||||
@web.middleware
async def request_log_middleware(request: web.Request, handler):
    """Log every request with status and wall-clock latency.

    The log line is emitted from ``finally`` so failing handlers are
    recorded too; status shows as "ERR" when the handler raised before a
    response existed.
    """
    started = time.perf_counter()
    response = None
    try:
        response = await handler(request)
        return response
    finally:
        elapsed_ms = int((time.perf_counter() - started) * 1000)
        status = getattr(response, "status", "ERR")
        logger.info(
            "%s %s -> %s (%d ms)",
            request.method,
            request.path_qs,
            status,
            elapsed_ms,
        )
|
||||
|
||||
|
||||
def check_auth(request: web.Request) -> bool:
    """Return True when the request carries the proxy's expected Bearer token."""
    expected_token = load_or_create_auth()["access_token"]
    header = request.headers.get("Authorization", "")
    if not header.lower().startswith("bearer "):
        return False
    return header[7:].strip() == expected_token
|
||||
|
||||
|
||||
async def oauth_authorize_handler(request: web.Request) -> web.Response:
    """Fake OAuth authorize endpoint: immediately redirect back with a code.

    Mimics an identity provider for clients performing an
    authorization-code flow against this proxy; no user interaction occurs.
    """
    params = request.rel_url.query
    redirect_uri = params.get("redirect_uri")
    state = params.get("state", "")

    if not redirect_uri:
        return web.json_response(
            {"error": "invalid_request", "error_description": "Missing redirect_uri"},
            status=400,
        )

    # Single-use code; it is consumed and age-checked by oauth_token_handler.
    code = f"ac_{secrets.token_urlsafe(48)}"
    auth_codes[code] = {
        "state": state,
        "created_at": time.time(),
    }

    query = urlencode(
        {
            "code": code,
            "scope": "openid profile email offline_access",
            "state": state,
        }
    )
    location = f"{redirect_uri}?{query}"
    logger.info("OAuth authorize: issued code")
    raise web.HTTPFound(location=location)
|
||||
|
||||
|
||||
async def oauth_token_handler(request: web.Request) -> web.Response:
    """Fake OAuth token endpoint backing the proxy's client credentials.

    Supports ``authorization_code`` (codes minted by the authorize handler,
    single-use, 5-minute lifetime) and ``refresh_token`` grants. Both always
    return the same long-lived fabricated credential pair.
    """
    auth_data = load_or_create_auth()

    # Accept both JSON and form-encoded request bodies.
    content_type = request.content_type or ""
    grant_type = None
    refresh_token = None
    code = None
    if content_type.startswith("application/json"):
        body = await request.json()
        grant_type = body.get("grant_type")
        refresh_token = body.get("refresh_token")
        code = body.get("code")
    else:
        form = await request.post()
        grant_type = form.get("grant_type")
        refresh_token = form.get("refresh_token")
        code = form.get("code")

    if grant_type == "authorization_code":
        code = str(code) if code else ""
        if not code or code not in auth_codes:
            return web.json_response(
                {
                    "error": "invalid_grant",
                    "error_description": "Invalid authorization code",
                },
                status=400,
            )
        # Single use: the code is consumed before its age is checked.
        created_at = auth_codes[code]["created_at"]
        del auth_codes[code]
        if time.time() - created_at > 300:
            return web.json_response(
                {
                    "error": "invalid_grant",
                    "error_description": "Authorization code expired",
                },
                status=400,
            )

        return web.json_response(
            {
                "access_token": auth_data["access_token"],
                "refresh_token": auth_data["refresh_token"],
                "token_type": "Bearer",
                "expires_in": FAKE_EXPIRES_IN,
            }
        )

    if grant_type == "refresh_token":
        if refresh_token != auth_data["refresh_token"]:
            return web.json_response(
                {
                    "error": "invalid_grant",
                    "error_description": "Invalid refresh token",
                },
                status=400,
            )

        return web.json_response(
            {
                "access_token": auth_data["access_token"],
                "refresh_token": auth_data["refresh_token"],
                "token_type": "Bearer",
                "expires_in": FAKE_EXPIRES_IN,
            }
        )

    return web.json_response(
        {
            "error": "unsupported_grant_type",
            "error_description": "Only authorization_code and refresh_token are supported",
        },
        status=400,
    )
|
||||
|
||||
|
||||
async def refresh_tokens_task():
    """Run the browser-based token refresh, at most one at a time.

    Uses a module-level flag rather than a lock: overlapping invocations
    return immediately instead of queuing behind the running refresh.
    """
    global refresh_in_progress
    if refresh_in_progress:
        logger.info("Token refresh already in progress")
        return

    refresh_in_progress = True
    logger.info("Starting token refresh...")

    try:
        success = await get_new_token(headless=False)
        if success:
            logger.info("Token refresh completed successfully")
        else:
            logger.error("Token refresh failed")
    except Exception as e:
        logger.error(f"Error during token refresh: {e}")
    finally:
        # Always clear the flag so a failed refresh doesn't wedge the system.
        refresh_in_progress = False
|
||||
|
||||
|
||||
async def usage_monitor():
    """Background loop: poll quota usage and trigger token refresh as needed.

    Refactor: the original used ``for _ in range(1): ... break`` as a
    goto-style early exit; the checks now live in a helper with plain early
    returns. The sleep still runs after every check, exactly as before.
    """
    while True:
        _check_usage_once()
        await asyncio.sleep(CHECK_INTERVAL)


def _check_usage_once() -> None:
    """One poll: schedule a refresh when tokens are missing, invalid, or near the limit."""
    tokens = load_tokens()

    if not tokens:
        if not refresh_in_progress:
            logger.warning("No tokens found, starting refresh...")
            asyncio.create_task(refresh_tokens_task())
        return

    usage = get_usage_percent(tokens.access_token)

    if usage < 0:
        # Negative sentinel from get_usage_percent: the call failed.
        logger.warning("Failed to get usage, token may be invalid")
        asyncio.create_task(refresh_tokens_task())
        return

    logger.info(f"Current usage: {usage}%")

    if usage >= USAGE_THRESHOLD:
        logger.info(
            f"Usage {usage}% >= threshold {USAGE_THRESHOLD}%, starting refresh..."
        )
        asyncio.create_task(refresh_tokens_task())
|
||||
|
||||
|
||||
async def proxy_handler(request: web.Request) -> web.StreamResponse | web.Response:
    """Forward an authenticated client request to the Codex backend.

    Validates the client's Bearer token, swaps it for the real upstream
    access token, and relays the response — streaming SSE bodies chunk by
    chunk, buffering everything else.
    """
    if not check_auth(request):
        auth = request.headers.get("Authorization", "")
        # Log only a short prefix of the credential, never the full value.
        auth_preview = auth[:24] + ("..." if len(auth) > 24 else "")
        logger.warning(
            "Auth failed: method=%s path=%s auth_present=%s auth_preview=%s ua=%s",
            request.method,
            request.path,
            bool(auth),
            auth_preview,
            request.headers.get("User-Agent", ""),
        )
        return web.json_response({"error": "Unauthorized"}, status=401)

    tokens = await get_valid_tokens()
    if not tokens:
        return web.json_response({"error": "No valid tokens"}, status=500)

    path = request.path
    target_url = f"{CODEX_BASE_URL}{path}"
    logger.info(
        "Proxying request: %s %s -> %s",
        request.method,
        request.path_qs,
        target_url,
    )

    # Copy headers, dropping hop-specific ones; substitute the upstream token.
    headers = {}
    for key, value in request.headers.items():
        if key.lower() not in ("host", "authorization", "content-length"):
            headers[key] = value
    headers["Authorization"] = f"Bearer {tokens.access_token}"

    if request.method in ("POST", "PUT", "PATCH"):
        body = await request.read()
    else:
        body = None

    async with aiohttp.ClientSession() as session:
        try:
            async with session.request(
                method=request.method,
                url=target_url,
                headers=headers,
                data=body,
                params=request.query,
            ) as resp:
                content_type = resp.content_type or "application/json"
                is_stream = (
                    content_type == "text/event-stream" or "stream" in content_type
                )

                if is_stream:
                    # Relay SSE incrementally so clients see data as it
                    # arrives instead of waiting for the full body.
                    response = web.StreamResponse(
                        status=resp.status,
                        reason=resp.reason,
                        headers={
                            "Content-Type": content_type,
                            "Cache-Control": "no-cache",
                            "Connection": "keep-alive",
                        },
                    )
                    await response.prepare(request)

                    async for chunk in resp.content.iter_any():
                        await response.write(chunk)

                    await response.write_eof()
                    return response
                else:
                    response_body = await resp.read()
                    if resp.status >= 400:
                        # Log a bounded preview of upstream error bodies.
                        preview = response_body[:500].decode("utf-8", errors="replace")
                        logger.warning(
                            "Upstream error: status=%s path=%s body=%s",
                            resp.status,
                            request.path,
                            preview,
                        )
                    return web.Response(
                        status=resp.status,
                        body=response_body,
                        headers={"Content-Type": content_type},
                    )
        except aiohttp.ClientError as e:
            return web.json_response({"error": f"Proxy error: {e}"}, status=502)
|
||||
|
||||
|
||||
async def health_handler(request: web.Request) -> web.Response:
    """Report token availability, current usage, and refresh state."""
    tokens = await get_valid_tokens()
    usage = get_usage_percent(tokens.access_token) if tokens else -1

    payload = {
        "status": "ok" if tokens else "no_tokens",
        "has_tokens": tokens is not None,
        "usage_percent": usage,
        "refresh_in_progress": refresh_in_progress,
    }
    return web.json_response(payload)
|
||||
|
||||
|
||||
async def start_background_tasks(app: web.Application):
    """aiohttp startup hook: spawn the background usage monitor."""
    monitor_task = asyncio.create_task(usage_monitor())
    app["usage_monitor"] = monitor_task
|
||||
|
||||
|
||||
async def cleanup_background_tasks(app: web.Application):
    """aiohttp cleanup hook: cancel the usage monitor and wait for it to exit."""
    monitor = app["usage_monitor"]
    monitor.cancel()
    try:
        await monitor
    except asyncio.CancelledError:
        # Expected: cancellation is the normal shutdown path.
        pass
|
||||
|
||||
|
||||
def create_app() -> web.Application:
    """Assemble the aiohttp app: OAuth stubs, health check, catch-all proxy.

    Route order matters: the catch-all proxy route is registered last so
    the specific endpoints take precedence.
    """
    app = web.Application(middlewares=[request_log_middleware])
    app.router.add_get("/oauth/authorize", oauth_authorize_handler)
    app.router.add_post("/oauth/token", oauth_token_handler)
    app.router.add_get("/health", health_handler)
    app.router.add_route("*", "/{path:.*}", proxy_handler)
    app.on_startup.append(start_background_tasks)
    app.on_cleanup.append(cleanup_background_tasks)
    return app
|
||||
|
||||
|
||||
if __name__ == "__main__":
    logger.info(f"Starting proxy on port {PORT}")
    logger.info(f"Usage threshold: {USAGE_THRESHOLD}%")
    logger.info(f"Check interval: {CHECK_INTERVAL}s")

    # Log masked credentials so operators can correlate clients without
    # leaking full tokens into the logs.
    auth_data = load_or_create_auth()
    logger.info("Client access token: %s", _mask(auth_data["access_token"]))
    logger.info("Client refresh token: %s", _mask(auth_data["refresh_token"]))

    startup_tokens = load_tokens()
    if startup_tokens:
        logger.info("Upstream access token: %s", _mask(startup_tokens.access_token))
    else:
        logger.warning("No upstream token found at %s", DATA_DIR / "tokens.json")
    app = create_app()
    web.run_app(app, host="0.0.0.0", port=PORT)
|
||||
147
src/server.py
Normal file
147
src/server.py
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from providers.chatgpt import ChatGPTProvider
|
||||
|
||||
# Listening port for the token service (env-overridable).
PORT = int(os.environ.get("PORT", "8080"))
# Usage percentage at/above which a background token refresh is triggered.
USAGE_REFRESH_THRESHOLD = int(os.environ.get("USAGE_REFRESH_THRESHOLD", "85"))
# Usage percentage considered fully exhausted.
LIMIT_EXHAUSTED_PERCENT = 100

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Registry of available providers
PROVIDERS = {
    "chatgpt": ChatGPTProvider(),
}

# One lock per provider: serializes token generation for that provider.
refresh_locks = {name: asyncio.Lock() for name in PROVIDERS.keys()}
# At most one in-flight background refresh task per provider (None = idle).
background_refresh_tasks: dict[str, asyncio.Task | None] = {
    name: None for name in PROVIDERS.keys()
}
|
||||
|
||||
|
||||
@web.middleware
async def request_log_middleware(request: web.Request, handler):
    """aiohttp middleware: log method, path+query, and response status for each request."""
    resp = await handler(request)
    logger.info("%s %s -> %s", request.method, request.path_qs, resp.status)
    return resp
|
||||
|
||||
|
||||
def build_limit(usage_percent: int) -> dict[str, int | bool]:
    """Summarize a usage percentage as a limit descriptor.

    Returns used/remaining percentages (remaining clamped at 0) plus two
    flags: ``exhausted`` (at or past the hard limit) and ``needs_refresh``
    (at or past the background-refresh threshold).
    """
    left = 100 - usage_percent
    if left < 0:
        left = 0
    return {
        "used_percent": usage_percent,
        "remaining_percent": left,
        "exhausted": usage_percent >= LIMIT_EXHAUSTED_PERCENT,
        "needs_refresh": usage_percent >= USAGE_REFRESH_THRESHOLD,
    }
|
||||
|
||||
|
||||
async def issue_new_token(provider_name: str) -> str | None:
    """Register a fresh account with the named provider and return its token.

    Returns None for an unknown provider, a failed registration, or a
    registration that did not yield a retrievable token. Serialized per
    provider via ``refresh_locks`` so only one refresh runs at a time.
    """
    provider = PROVIDERS.get(provider_name)
    if provider is None:
        return None

    async with refresh_locks[provider_name]:
        logger.info(f"[{provider_name}] Generating new token")
        if not await provider.register_new_account():
            logger.error(f"[{provider_name}] Token generation failed")
            return None

        token = await provider.get_token()
        if not token:
            logger.error(f"[{provider_name}] Token was generated but not available")
            return None
        return token
|
||||
|
||||
|
||||
async def background_refresh_worker(provider_name: str, reason: str):
    """Run one token refresh for *provider_name*, logging outcome; never raises."""
    try:
        logger.info(f"[{provider_name}] Starting background token refresh ({reason})")
        token = await issue_new_token(provider_name)
        if not token:
            logger.error(f"[{provider_name}] Background token refresh failed")
        else:
            logger.info(f"[{provider_name}] Background token refresh completed")
    except Exception:
        # Last-resort guard: a background task must not die silently.
        logger.exception(
            f"[{provider_name}] Unhandled error in background token refresh"
        )
|
||||
|
||||
|
||||
def trigger_background_refresh(provider_name: str, reason: str):
    """Start a background refresh for the provider unless one is still running."""
    existing = background_refresh_tasks.get(provider_name)
    if existing and not existing.done():
        # A refresh is already in flight — don't stack another one.
        logger.info(
            f"[{provider_name}] Background refresh already running, skip ({reason})"
        )
        return
    fresh_task = asyncio.create_task(
        background_refresh_worker(provider_name, reason)
    )
    background_refresh_tasks[provider_name] = fresh_task
|
||||
|
||||
|
||||
async def token_handler(request: web.Request) -> web.Response:
    """GET /{provider}/token — return an active token plus its usage-limit info.

    Responds 404 for an unknown provider and 503 when no token or no usage
    info is available. Kicks off a background refresh (without blocking the
    response) once usage crosses ``USAGE_REFRESH_THRESHOLD``.
    """
    name = request.match_info.get("provider", "chatgpt")

    provider = PROVIDERS.get(name)
    if provider is None:
        return web.json_response(
            {"error": f"Unknown provider: {name}"},
            status=404,
        )

    # Get or create token
    token = await provider.get_token()
    if not token:
        return web.json_response(
            {"error": "Failed to get active token"},
            status=503,
        )

    # Get usage info
    usage = await provider.get_usage_info(token)
    if "error" in usage:
        return web.json_response(
            {"error": usage["error"]},
            status=503,
        )

    used = usage.get("used_percent", 0)

    # Trigger background refresh if needed
    if used >= USAGE_REFRESH_THRESHOLD:
        trigger_background_refresh(
            name,
            f"usage {used}% >= threshold {USAGE_REFRESH_THRESHOLD}%",
        )

    return web.json_response({"token": token, "limit": build_limit(used)})
|
||||
|
||||
|
||||
def create_app() -> web.Application:
    """Assemble the token-service application with its routes and middleware."""
    application = web.Application(middlewares=[request_log_middleware])
    # Provider-scoped route first, then the legacy unscoped route
    # kept for backward compatibility.
    for pattern in ("/{provider}/token", "/token"):
        application.router.add_get(pattern, token_handler)
    return application
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Announce effective configuration before serving.
    logger.info("Starting token service on port %s", PORT)
    logger.info("Usage refresh threshold: %s%%", USAGE_REFRESH_THRESHOLD)
    logger.info("Available providers: %s", ", ".join(PROVIDERS.keys()))
    web.run_app(create_app(), host="0.0.0.0", port=PORT)
|
||||
0
src/uv.lock → uv.lock
generated
0
src/uv.lock → uv.lock
generated
Loading…
Add table
Add a link
Reference in a new issue