diff --git a/.gitignore b/.gitignore index 7295a4f..afa1ad1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ /data/ __pycache__/ +.ruff_cache/ +.venv/ diff --git a/src/Dockerfile b/Dockerfile similarity index 96% rename from src/Dockerfile rename to Dockerfile index c407732..ede7be6 100644 --- a/src/Dockerfile +++ b/Dockerfile @@ -15,7 +15,7 @@ RUN pip install --no-cache-dir uv RUN uv sync --frozen --no-dev RUN /app/.venv/bin/python -m playwright install --with-deps chromium -COPY *.py /app/ +COPY src/*.py /app/ ENV PYTHONUNBUFFERED=1 ENV PORT=8000 diff --git a/src/pyproject.toml b/pyproject.toml similarity index 100% rename from src/pyproject.toml rename to pyproject.toml diff --git a/src/browser.py b/src/browser.py new file mode 100644 index 0000000..376295b --- /dev/null +++ b/src/browser.py @@ -0,0 +1,108 @@ +import asyncio +import json +import logging +import os +import shutil +import subprocess +import tempfile +import urllib.request +from dataclasses import dataclass +from pathlib import Path + +from playwright.async_api import Browser, Playwright + +logger = logging.getLogger(__name__) + +CHROME_FLAGS = [ + "--no-startup-window", + "--disable-field-trial-config", + "--disable-background-networking", + "--disable-background-timer-throttling", + "--disable-backgrounding-occluded-windows", + "--disable-back-forward-cache", + "--disable-breakpad", + "--disable-client-side-phishing-detection", + "--disable-component-extensions-with-background-pages", + "--disable-component-update", + "--no-default-browser-check", + "--disable-default-apps", + "--disable-dev-shm-usage", + "--disable-extensions", + "--disable-popup-blocking", + "--disable-prompt-on-repost", + "--disable-renderer-backgrounding", + "--disable-hang-monitor", + "--disable-ipc-flooding-protection", + "--force-color-profile=srgb", + "--metrics-recording-only", + "--no-first-run", + "--password-store=basic", + "--use-mock-keychain", + "--disable-infobars", + "--disable-sync", + "--enable-unsafe-swiftshader", + "--no-sandbox", + "--disable-search-engine-choice-screen", +] + + +def _fetch_ws_endpoint(port: int) -> str | None: + try: + with urllib.request.urlopen( + f"http://127.0.0.1:{port}/json/version", + timeout=1, + ) as resp: + data = json.loads(resp.read().decode("utf-8")) + return data.get("webSocketDebuggerUrl") + except Exception: + return None + + +@dataclass +class ManagedBrowser: + browser: Browser + process: subprocess.Popen + profile_dir: Path + + async def close(self) -> None: + try: + await self.browser.close() + except Exception: + pass + self.process.terminate() + try: + self.process.wait(timeout=5) + except subprocess.TimeoutExpired: + self.process.kill() + if self.profile_dir.exists(): + shutil.rmtree(self.profile_dir, ignore_errors=True) + + +async def launch(playwright: Playwright, cdp_port: int | None = None) -> ManagedBrowser: + chrome_path = os.environ.get("CHROMIUM_PATH") or playwright.chromium.executable_path + cdp_port = cdp_port or int(os.environ.get("CDP_PORT", "9222")) + profile_dir = Path(tempfile.mkdtemp(prefix="megapt_profile-", dir="/tmp")) + + args = [ + chrome_path, + *CHROME_FLAGS, + f"--user-data-dir={profile_dir}", + f"--remote-debugging-port={cdp_port}", + ] + + proc = subprocess.Popen(args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + + ws_endpoint = None + for _ in range(60): + ws_endpoint = await asyncio.to_thread(_fetch_ws_endpoint, cdp_port) + if ws_endpoint: + break + await asyncio.sleep(0.5) + + if not ws_endpoint: + proc.terminate() + raise RuntimeError(f"CDP websocket not available on port {cdp_port}") + + logger.info("CDP websocket: %s", ws_endpoint) + browser = await playwright.chromium.connect_over_cdp(ws_endpoint) + return ManagedBrowser(browser=browser, process=proc, profile_dir=profile_dir) diff --git a/src/email_providers/__init__.py b/src/email_providers/__init__.py new file mode 100644 index 0000000..76e66ca --- /dev/null +++ b/src/email_providers/__init__.py @@ -0,0 +1,5 @@ +from .base import BaseProvider +from .ten_minute_mail import TenMinuteMailProvider +from .temp_mail_org import TempMailOrgProvider + +__all__ = ["BaseProvider", "TenMinuteMailProvider", "TempMailOrgProvider"] diff --git a/src/email_providers/base.py b/src/email_providers/base.py new file mode 100644 index 0000000..386601d --- /dev/null +++ b/src/email_providers/base.py @@ -0,0 +1,16 @@ +from abc import ABC, abstractmethod + +from playwright.async_api import BrowserContext + + +class BaseProvider(ABC): + def __init__(self, browser_session: BrowserContext): + self.browser_session = browser_session + + @abstractmethod + async def get_new_email(self) -> str: + pass + + @abstractmethod + async def get_latest_message(self, email: str) -> str | None: + pass diff --git a/src/email_providers/temp_mail_org.py b/src/email_providers/temp_mail_org.py new file mode 100644 index 0000000..cb2b840 --- /dev/null +++ b/src/email_providers/temp_mail_org.py @@ -0,0 +1,125 @@ +import asyncio +import logging +import re + +from playwright.async_api import BrowserContext, Page + +from .base import BaseProvider + +logger = logging.getLogger(__name__) + + +class TempMailOrgProvider(BaseProvider): + def __init__(self, browser_session: BrowserContext): + super().__init__(browser_session) + self.page: Page | None = None + + async def _ensure_page(self) -> Page: + if self.page is None or self.page.is_closed(): + self.page = await self.browser_session.new_page() + return self.page + + async def get_new_email(self) -> str: + page = await self._ensure_page() + logger.info("[temp-mail.org] Opening mailbox page") + await page.goto("https://temp-mail.org", wait_until="domcontentloaded") + await page.locator("input#mail, #mail, input[value*='@']").first.wait_for( + state="visible", + timeout=30000, + ) + + selectors = ["#mail", "input#mail", "input[value*='@']"] + end_at = asyncio.get_running_loop().time() + 60 + while asyncio.get_running_loop().time() < end_at: + await page.bring_to_front() + for selector in selectors: + try: + field = page.locator(selector).first + if await field.is_visible(timeout=1000): + value = (await field.input_value()).strip() + if "@" in value: + logger.info( + "[temp-mail.org] selector matched: %s -> %s", + selector, + value, + ) + return value + except Exception: + continue + + try: + body = await page.inner_text("body") + found = extract_email(body) + if found: + logger.info("[temp-mail.org] email found by body scan: %s", found) + return found + except Exception: + pass + + await asyncio.sleep(1) + + raise RuntimeError("Could not get temp email from temp-mail.org") + + async def get_latest_message(self, email: str) -> str | None: + page = await self._ensure_page() + logger.info("[temp-mail.org] Waiting for latest message for %s", email) + + if page.is_closed(): + raise RuntimeError("temp-mail.org tab was closed unexpectedly") + + await page.bring_to_front() + + items = page.locator("div.inbox-dataList ul li") + + # temp-mail updates inbox via websocket; do not refresh/reload page. + for attempt in range(30): + try: + count = await items.count() + logger.info("[temp-mail.org] inbox items: %s", count) + except Exception: + count = 0 + + if count > 0: + for idx in reversed(range(count)): + try: + item = items.nth(idx) + if not await item.is_visible(timeout=1000): + continue + text = (await item.inner_text()).strip().replace("\n", " ") + logger.info("[temp-mail.org] item[%s]: %s", idx, text[:160]) + except Exception: + continue + if text: + try: + await item.click() + logger.info("[temp-mail.org] opened item[%s]", idx) + except Exception: + pass + + message_text = text + try: + content = await page.content() + if content and "Your ChatGPT code is" in content: + message_text = content + except Exception: + pass + + try: + await page.go_back( + wait_until="domcontentloaded", timeout=5000 + ) + logger.info("[temp-mail.org] returned back to inbox") + except Exception: + pass + + return message_text + + await asyncio.sleep(2) + + logger.warning("[temp-mail.org] No messages received within 60 seconds") + return None + + +def extract_email(text: str) -> str | None: + match = re.search(r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}", text) + return match.group(0) if match else None diff --git a/src/email_providers/ten_minute_mail.py b/src/email_providers/ten_minute_mail.py new file mode 100644 index 0000000..ae4c35e --- /dev/null +++ b/src/email_providers/ten_minute_mail.py @@ -0,0 +1,100 @@ +import asyncio +import logging + +from playwright.async_api import BrowserContext, Page + +from .base import BaseProvider + +logger = logging.getLogger(__name__) + + +class TenMinuteMailProvider(BaseProvider): + def __init__(self, browser_session: BrowserContext): + super().__init__(browser_session) + self.page: Page | None = None + + async def _ensure_page(self) -> Page: + if self.page is None or self.page.is_closed(): + self.page = await self.browser_session.new_page() + return self.page + + async def get_new_email(self) -> str: + page = await self._ensure_page() + logger.info("[10min] Opening https://10minutemail.com") + await page.goto("https://10minutemail.com", wait_until="domcontentloaded") + await page.wait_for_timeout(3000) + + email_input = page.locator("#mail_address") + await email_input.first.wait_for(state="visible", timeout=60000) + + email = (await email_input.first.input_value()).strip() + if not email or "@" not in email: + raise RuntimeError("10MinuteMail did not return a valid email") + + logger.info("[10min] New email acquired: %s", email) + return email + + async def get_latest_message(self, email: str) -> str | None: + page = await self._ensure_page() + logger.info("[10min] Waiting for latest message for %s", email) + + seen_count = 0 + for attempt in range(60): + try: + count = await page.evaluate( + """ + async () => { + const response = await fetch('/messages/messageCount', { credentials: 'include' }); + const data = await response.json(); + return Number(data.messageCount || 0); + } + """ + ) + except Exception: + count = 0 + + if count > 0: + if count != seen_count: + logger.info("[10min] Inbox has %s message(s)", count) + seen_count = count + + try: + messages = await page.evaluate( + """ + async () => { + const response = await fetch('/messages/messagesAfter/0', { credentials: 'include' }); + const data = await response.json(); + return Array.isArray(data) ? data : []; + } + """ + ) + except Exception: + messages = [] + + text = "" + if messages: + latest = messages[-1] + subject = str(latest.get("subject") or "") + sender = str(latest.get("sender") or "") + body_plain = str(latest.get("bodyPlainText") or "") + body_html = str(latest.get("bodyHtmlContent") or "") + text = "\n".join( + part + for part in [subject, sender, body_plain, body_html] + if part + ) + + if text: + logger.info("[10min] Latest message received") + return text + + if attempt % 3 == 0: + try: + await page.reload(wait_until="domcontentloaded", timeout=60000) + except Exception: + pass + + await asyncio.sleep(2) + + logger.warning("[10min] No messages received within timeout") + return None diff --git a/src/entrypoint.sh b/src/entrypoint.sh index 885c8aa..b289b49 100755 --- a/src/entrypoint.sh +++ b/src/entrypoint.sh @@ -12,4 +12,4 @@ cleanup() { trap cleanup EXIT INT TERM -exec /app/.venv/bin/python -u proxy.py +exec /app/.venv/bin/python -u server.py diff --git a/src/get_new_token.py b/src/get_new_token.py deleted file mode 100644 index a344dcf..0000000 --- a/src/get_new_token.py +++ /dev/null @@ -1,397 +0,0 @@ -import asyncio -import os -import re -import random -import secrets -import json -import time -import logging -from datetime import datetime - -import aiohttp -import pkce -from urllib.parse import urlencode, urlparse, parse_qs -from playwright.async_api import async_playwright, Page, Browser -from tokens import DATA_DIR, TOKENS_FILE - -logger = logging.getLogger(__name__) - -CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann" -AUTHORIZE_URL = "https://auth.openai.com/oauth/authorize" -TOKEN_URL = "https://auth.openai.com/oauth/token" -REDIRECT_URI = "http://localhost:1455/auth/callback" -SCOPE = "openid profile email offline_access" - - -class AutomationError(Exception): - def __init__(self, step: str, message: str, page: Page | None = None): - self.step = step - self.message = message - self.page = page - super().__init__(f"[{step}] {message}") - - -async def save_error_screenshot(page: Page | None, step: str): - if page: - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - screenshots_dir = DATA_DIR / "screenshots" - screenshots_dir.mkdir(parents=True, exist_ok=True) - filename = screenshots_dir / f"error_{step}_{timestamp}.png" - try: - await page.screenshot(path=str(filename)) - logger.error(f"Screenshot saved: {filename}") - except: - pass - - -def generate_pkce(): - return pkce.generate_pkce_pair() - - -def generate_state(): - return secrets.token_urlsafe(32) - - -def create_auth_url(verifier: str, challenge: str, state: str) -> str: - params = { - "response_type": "code", - "client_id": CLIENT_ID, - "redirect_uri": REDIRECT_URI, - "scope": SCOPE, - "code_challenge": challenge, - "code_challenge_method": "S256", - "state": state, - "id_token_add_organizations": "true", - "codex_cli_simplified_flow": "true", - "originator": "opencode", - } - return f"{AUTHORIZE_URL}?{urlencode(params)}" - - -async def get_temp_email(page: Page) -> str: - logger.info("Getting temp email...") - for i in range(30): - mail_input = page.locator("#mail") - if await mail_input.count() > 0: - val = await mail_input.input_value() - if val and "@" in val: - logger.info(f"Got email: {val}") - return val - await page.wait_for_timeout(1000) - raise AutomationError("get_email", "Failed to get email", page) - - -async def get_verification_code(page: Page, used_codes: list | None = None) -> str: - logger.info("Waiting for verification code...") - if used_codes is None: - used_codes = [] - await page.wait_for_timeout(10000) - - for attempt in range(20): - mail_items = page.locator(".inbox-dataList ul li") - count = await mail_items.count() - logger.debug(f"Attempt {attempt + 1}: {count} emails") - - if count > 0: - codes = [] - for i in range(count): - try: - item = mail_items.nth(i) - text = await item.inner_text() - match = re.search( - r"Your ChatGPT code is (\d{6})", text, re.IGNORECASE - ) - if match: - code = match.group(1) - if code not in used_codes: - codes.append(code) - except: - pass - - if codes: - logger.info(f"Got code: {codes[0]}") - return codes[0] - - await page.wait_for_timeout(5000) - await page.reload(wait_until="domcontentloaded") - await page.wait_for_timeout(5000) - - raise AutomationError("get_code", "Code not found", page) - - -async def fill_date_field(page: Page, month: str, day: str, year: str): - async def type_segment(segment_type: str, value: str): - field = page.locator(f'[data-type="{segment_type}"]') - if await field.count() == 0: - raise AutomationError( - "profile", f"Missing birthday segment: {segment_type}", page - ) - - target = field.first - await target.scroll_into_view_if_needed() - await target.focus() - await page.keyboard.press("Control+A") - await page.keyboard.press("Backspace") - await page.keyboard.type(value) - await page.wait_for_timeout(200) - - await type_segment("month", month) - await type_segment("day", day) - await type_segment("year", year) - - -def generate_name(): - first_names = [ - "Alex", - "Jordan", - "Taylor", - "Morgan", - "Casey", - "Riley", - "Quinn", - "Avery", - "Parker", - "Blake", - ] - last_names = [ - "Smith", - "Johnson", - "Williams", - "Brown", - "Jones", - "Davis", - "Miller", - "Wilson", - "Moore", - "Clark", - ] - return f"{random.choice(first_names)} {random.choice(last_names)}" - - -async def exchange_code_for_tokens(code: str, verifier: str) -> dict: - async with aiohttp.ClientSession() as session: - data = { - "grant_type": "authorization_code", - "client_id": CLIENT_ID, - "code": code, - "code_verifier": verifier, - "redirect_uri": REDIRECT_URI, - } - - async with session.post(TOKEN_URL, data=data) as resp: - if not resp.ok: - text = await resp.text() - raise Exception(f"Token exchange failed: {resp.status} {text}") - - json_resp = await resp.json() - return { - "access_token": json_resp["access_token"], - "refresh_token": json_resp["refresh_token"], - "expires_in": json_resp["expires_in"], - } - - -async def get_new_token(headless: bool = False) -> bool: - logger.info("=== Starting token generation ===") - - password = "TempPass123!" - full_name = generate_name() - birth_month, birth_day, birth_year = "01", "15", "1995" - - verifier, challenge = generate_pkce() - state = generate_state() - auth_url = create_auth_url(verifier, challenge, state) - - redirect_url_captured = None - browser: Browser | None = None - current_page: Page | None = None - - try: - async with async_playwright() as p: - chromium_path = os.environ.get("CHROMIUM_PATH") - if chromium_path: - browser = await p.chromium.launch( - headless=headless, - executable_path=chromium_path, - ) - else: - browser = await p.chromium.launch(headless=headless) - context = await browser.new_context() - page = await context.new_page() - current_page = page - - logger.info("[1/6] Getting email...") - await page.goto("https://temp-mail.org", wait_until="domcontentloaded") - email = await get_temp_email(page) - tempmail_page = page - - logger.info("[2/6] Registering ChatGPT...") - chatgpt_page = await context.new_page() - current_page = chatgpt_page - await chatgpt_page.goto("https://chatgpt.com") - await chatgpt_page.wait_for_load_state("domcontentloaded") - - await chatgpt_page.get_by_text("Sign up for free", exact=True).click() - await chatgpt_page.wait_for_timeout(2000) - - await chatgpt_page.locator('input[type="email"]').fill(email) - await chatgpt_page.wait_for_timeout(500) - await chatgpt_page.get_by_role( - "button", name="Continue", exact=True - ).click() - await chatgpt_page.wait_for_timeout(3000) - - await chatgpt_page.locator('input[type="password"]').fill(password) - await chatgpt_page.wait_for_timeout(500) - await chatgpt_page.get_by_role( - "button", name="Continue", exact=True - ).click() - await chatgpt_page.wait_for_timeout(5000) - - logger.info("[3/6] Getting verification code...") - await tempmail_page.bring_to_front() - code = await get_verification_code(tempmail_page) - - await chatgpt_page.bring_to_front() - code_input = chatgpt_page.get_by_placeholder("Code") - if await code_input.count() > 0: - await code_input.fill(code) - await chatgpt_page.wait_for_timeout(5000) - - continue_btn = chatgpt_page.get_by_role( - "button", name="Continue", exact=True - ) - if await continue_btn.count() > 0: - await continue_btn.click() - await chatgpt_page.wait_for_timeout(5000) - - logger.info("[4/6] Setting profile...") - name_input = chatgpt_page.get_by_placeholder("Full name") - if await name_input.count() > 0: - await name_input.fill(full_name) - - await chatgpt_page.wait_for_timeout(500) - await fill_date_field(chatgpt_page, birth_month, birth_day, birth_year) - await chatgpt_page.wait_for_timeout(1000) - - continue_btn = chatgpt_page.get_by_role( - "button", name="Continue", exact=True - ) - if await continue_btn.count() > 0: - await continue_btn.click() - - logger.info("Account registered!") - await chatgpt_page.wait_for_timeout(10000) - await chatgpt_page.wait_for_load_state("networkidle", timeout=30000) - await chatgpt_page.wait_for_timeout(5000) - - used_codes = [code] - - logger.info("[5/6] OAuth flow...") - oauth_page = await context.new_page() - current_page = oauth_page - - def handle_request(request): - nonlocal redirect_url_captured - url = request.url - if "localhost:1455" in url and "code=" in url: - logger.info("Redirect URL captured!") - redirect_url_captured = url - - oauth_page.on("request", handle_request) - - await oauth_page.goto(auth_url) - await oauth_page.wait_for_load_state("domcontentloaded") - await oauth_page.wait_for_timeout(3000) - - await oauth_page.locator('input[type="email"], input[name="email"]').fill( - email - ) - await oauth_page.wait_for_timeout(500) - await oauth_page.get_by_role("button", name="Continue", exact=True).click() - await oauth_page.wait_for_timeout(3000) - - password_input = oauth_page.locator('input[type="password"]') - if await password_input.count() > 0: - await password_input.fill(password) - await oauth_page.wait_for_timeout(500) - await oauth_page.get_by_role( - "button", name="Continue", exact=True - ).click() - await oauth_page.wait_for_timeout(5000) - - await tempmail_page.bring_to_front() - await tempmail_page.reload(wait_until="domcontentloaded") - await tempmail_page.wait_for_timeout(3000) - - try: - oauth_code = await get_verification_code(tempmail_page, used_codes) - except AutomationError: - logger.info("Reopening mail...") - tempmail_page = await context.new_page() - current_page = tempmail_page - await tempmail_page.goto( - "https://temp-mail.org", wait_until="domcontentloaded" - ) - await tempmail_page.wait_for_timeout(10000) - oauth_code = await get_verification_code(tempmail_page, used_codes) - - await oauth_page.bring_to_front() - code_input = oauth_page.get_by_placeholder("Code") - if await code_input.count() > 0: - await code_input.fill(oauth_code) - await oauth_page.wait_for_timeout(500) - await oauth_page.get_by_role( - "button", name="Continue", exact=True - ).click() - await oauth_page.wait_for_timeout(5000) - - for btn_text in ["Continue", "Allow", "Authorize"]: - btn = oauth_page.get_by_role("button", name=btn_text, exact=True) - if await btn.count() > 0: - await btn.click() - break - - await oauth_page.wait_for_timeout(5000) - - logger.info("[6/6] Exchanging code for tokens...") - if redirect_url_captured and "code=" in redirect_url_captured: - parsed = urlparse(redirect_url_captured) - params = parse_qs(parsed.query) - auth_code = params.get("code", [None])[0] - - if auth_code: - tokens = await exchange_code_for_tokens(auth_code, verifier) - - token_data = { - "access_token": tokens["access_token"], - "refresh_token": tokens["refresh_token"], - "expires_at": time.time() + tokens["expires_in"], - } - TOKENS_FILE.parent.mkdir(parents=True, exist_ok=True) - with open(TOKENS_FILE, "w") as f: - json.dump(token_data, f, indent=2) - - logger.info(f"Tokens saved to {TOKENS_FILE}") - return True - - raise AutomationError("token_exchange", "Failed to get tokens", oauth_page) - - except AutomationError as e: - logger.error(f"Error at step [{e.step}]: {e.message}") - await save_error_screenshot(e.page, e.step) - return False - except Exception as e: - logger.error(f"Unexpected error: {e}") - await save_error_screenshot(current_page, "unexpected") - return False - finally: - if browser: - await asyncio.sleep(2) - await browser.close() - - -if __name__ == "__main__": - logging.basicConfig(level=logging.INFO) - success = asyncio.run(get_new_token()) - exit(0 if success else 1) diff --git a/src/providers/__init__.py b/src/providers/__init__.py new file mode 100644 index 0000000..40e1179 --- /dev/null +++ b/src/providers/__init__.py @@ -0,0 +1,3 @@ +from .base import Provider, ProviderTokens + +__all__ = ["Provider", "ProviderTokens"] diff --git a/src/providers/base.py b/src/providers/base.py new file mode 100644 index 0000000..1a5b6bc --- /dev/null +++ b/src/providers/base.py @@ -0,0 +1,54 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any + + +@dataclass +class ProviderTokens: + """Base token structure for any provider""" + + access_token: str + refresh_token: str | None + expires_at: float + metadata: dict[str, Any] | None = None + + @property + def is_expired(self) -> bool: + import time + + return time.time() >= self.expires_at - 10 + + +class Provider(ABC): + """Base class for all account providers""" + + @property + @abstractmethod + def name(self) -> str: + """Provider name (e.g., 'chatgpt', 'claude')""" + pass + + @abstractmethod + async def get_token(self) -> str | None: + """Get valid access token, refreshing if needed""" + pass + + @abstractmethod + async def register_new_account(self) -> bool: + """Register a new account and get tokens""" + pass + + @abstractmethod + async def get_usage_info(self, access_token: str) -> dict[str, Any]: + """Get usage information for the current token""" + pass + + @abstractmethod + def load_tokens(self) -> ProviderTokens | None: + """Load tokens from storage""" + pass + + @abstractmethod + def save_tokens(self, tokens: ProviderTokens) -> None: + """Save tokens to storage""" + pass diff --git a/src/providers/chatgpt/__init__.py b/src/providers/chatgpt/__init__.py new file mode 100644 index 0000000..a485bb8 --- /dev/null +++ b/src/providers/chatgpt/__init__.py @@ -0,0 +1,3 @@ +from .provider import ChatGPTProvider + +__all__ = ["ChatGPTProvider"] diff --git a/src/providers/chatgpt/provider.py b/src/providers/chatgpt/provider.py new file mode 100644 index 0000000..c6c9b92 --- /dev/null +++ b/src/providers/chatgpt/provider.py @@ -0,0 +1,68 @@ +import logging +from typing import Callable +from typing import Any + +from playwright.async_api import BrowserContext + +from providers.base import Provider, ProviderTokens +from email_providers import BaseProvider +from email_providers import TempMailOrgProvider +from .tokens import load_tokens, save_tokens, get_valid_tokens +from .usage import get_usage_percent +from .registration import register_chatgpt_account + +logger = logging.getLogger(__name__) + + +class ChatGPTProvider(Provider): + """ChatGPT account provider""" + + def __init__( + self, + email_provider_factory: Callable[[BrowserContext], BaseProvider] | None = None, + ): + self.email_provider_factory = email_provider_factory or TempMailOrgProvider + + @property + def name(self) -> str: + return "chatgpt" + + async def get_token(self) -> str | None: + """Get valid access token, refreshing if needed""" + tokens = await get_valid_tokens() + if not tokens: + logger.info("No valid tokens, registering new account") + success = await self.register_new_account() + if not success: + return None + tokens = await get_valid_tokens() + if not tokens: + return None + return tokens.access_token + + async def register_new_account(self) -> bool: + """Register a new ChatGPT account""" + return await register_chatgpt_account( + email_provider_factory=self.email_provider_factory, + ) + + async def get_usage_info(self, access_token: str) -> dict[str, Any]: + """Get usage information for the current token""" + usage_percent = get_usage_percent(access_token) + if usage_percent < 0: + return {"error": "Failed to get usage"} + + remaining = max(0, 100 - usage_percent) + return { + "used_percent": usage_percent, + "remaining_percent": remaining, + "exhausted": usage_percent >= 100, + } + + def load_tokens(self) -> ProviderTokens | None: + """Load tokens from storage""" + return load_tokens() + + def save_tokens(self, tokens: ProviderTokens) -> None: + """Save tokens to storage""" + save_tokens(tokens) diff --git a/src/providers/chatgpt/registration.py b/src/providers/chatgpt/registration.py new file mode 100644 index 0000000..c1d5556 --- /dev/null +++ b/src/providers/chatgpt/registration.py @@ -0,0 +1,477 @@ +import asyncio +import base64 +import hashlib +import logging +import random +import re +import secrets +import string +import time +from datetime import datetime +from pathlib import Path +import os +from typing import Callable +from urllib.parse import parse_qs, urlencode, urlparse + +import aiohttp +from playwright.async_api import async_playwright, Page, BrowserContext + +from browser import launch as launch_browser +from email_providers import BaseProvider +from providers.base import ProviderTokens +from .tokens import CLIENT_ID, save_tokens + +logger = logging.getLogger(__name__) + +DATA_DIR = Path(os.environ.get("DATA_DIR", "./data")) +AUTHORIZE_URL = "https://auth.openai.com/oauth/authorize" +TOKEN_URL = "https://auth.openai.com/oauth/token" +REDIRECT_URI = "http://localhost:1455/auth/callback" +SCOPE = "openid profile email offline_access" + + +class AutomationError(Exception): + def __init__(self, step: str, message: str, page: Page | None = None): + self.step = step + self.message = message + self.page = page + super().__init__(f"[{step}] {message}") + + +async def save_error_screenshot(page: Page | None, step: str): + if page: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + screenshots_dir = DATA_DIR / "screenshots" + screenshots_dir.mkdir(parents=True, exist_ok=True) + filename = screenshots_dir / f"error_{step}_{timestamp}.png" + try: + await page.screenshot(path=str(filename)) + logger.error(f"Screenshot saved: {filename}") + except: + pass + + +def generate_password(length: int = 20) -> str: + alphabet = string.ascii_letters + string.digits + return "".join(random.choice(alphabet) for _ in range(length)) + + +def generate_name() -> str: + first_names = [ + "James", + "John", + "Robert", + "Michael", + "William", + "David", + "Richard", + "Joseph", + "Thomas", + "Charles", + "Christopher", + "Daniel", + "Matthew", + "Anthony", + "Mark", + "Donald", + "Steven", + "Paul", + "Andrew", + "Joshua", + ] + last_names = [ + "Smith", + "Johnson", + "Williams", + "Brown", + "Jones", + "Garcia", + "Miller", + "Davis", + "Rodriguez", + "Martinez", + "Hernandez", + "Lopez", + "Gonzalez", + "Wilson", + "Anderson", + "Thomas", + "Taylor", + "Moore", + "Jackson", + "Martin", + ] + return f"{random.choice(first_names)} {random.choice(last_names)}" + + +def extract_verification_code(message: str) -> str | None: + normalized = re.sub(r"\s+", " ", message) + + preferred = re.search( + r"Your\s+ChatGPT\s+code\s+is\s*(\d{6})", + normalized, + re.IGNORECASE, + ) + if preferred: + return preferred.group(1) + + openai_otp = re.search(r"OpenAI\s+otp.*?(\d{6})", normalized, re.IGNORECASE) + if openai_otp: + return openai_otp.group(1) + + all_codes = re.findall(r"\b(\d{6})\b", normalized) + if all_codes: + return all_codes[-1] + + return None + + +def generate_pkce_pair() -> tuple[str, str]: + verifier = secrets.token_urlsafe(64) + digest = hashlib.sha256(verifier.encode("utf-8")).digest() + challenge = base64.urlsafe_b64encode(digest).decode("utf-8").rstrip("=") + return verifier, challenge + + +def generate_state() -> str: + return secrets.token_urlsafe(32) + + +def build_authorize_url(verifier: str, challenge: str, state: str) -> str: + del verifier + params = { + "response_type": "code", + "client_id": CLIENT_ID, + "redirect_uri": REDIRECT_URI, + "scope": SCOPE, + "code_challenge": challenge, + "code_challenge_method": "S256", + "id_token_add_organizations": "true", + "codex_cli_simplified_flow": "true", + "state": state, + "originator": "opencode", + } + return f"{AUTHORIZE_URL}?{urlencode(params)}" + + +async def exchange_code_for_tokens(code: str, verifier: str) -> ProviderTokens: + async with aiohttp.ClientSession() as session: + payload = { + "grant_type": "authorization_code", + "client_id": CLIENT_ID, + "code": code, + "code_verifier": verifier, + "redirect_uri": REDIRECT_URI, + } + async with session.post(TOKEN_URL, data=payload) as resp: + if not resp.ok: + text = await resp.text() + raise RuntimeError(f"Token exchange failed: {resp.status} {text}") + body = await resp.json() + + expires_in = int(body["expires_in"]) + return ProviderTokens( + access_token=body["access_token"], + refresh_token=body["refresh_token"], + expires_at=time.time() + expires_in, + ) + + +async def get_new_verification_code( + email_provider: BaseProvider, + email: str, + used_codes: set[str], + timeout_seconds: int = 240, +) -> str | None: + attempts = max(1, timeout_seconds // 5) + for _ in range(attempts): + message = await email_provider.get_latest_message(email) + if message: + all_codes = re.findall(r"\b(\d{6})\b", message) + for candidate in all_codes: + if candidate not in used_codes: + return candidate + await asyncio.sleep(5) + return None + + +async def get_latest_code(email_provider: BaseProvider, email: str) -> str | None: + message = await email_provider.get_latest_message(email) + if not message: + return None + return extract_verification_code(message) + + +async def fill_date_field(page: Page, month: str, day: str, year: str): + month_field = page.locator('[data-type="month"]').first + if await month_field.count() == 0: + raise AutomationError("profile", "Missing birthday month field", page) + + await month_field.scroll_into_view_if_needed() + await month_field.click() + await page.wait_for_timeout(120) + + await page.keyboard.type(f"{month}{day}{year}") + await page.wait_for_timeout(200) + + +async def wait_for_signup_stabilization(page: Page): + try: + await page.wait_for_load_state("networkidle", timeout=15000) + except Exception: + logger.warning( + "Signup page did not reach networkidle quickly; continuing with fallback" + ) + try: + await page.wait_for_load_state("domcontentloaded", timeout=10000) + except Exception: + pass + await page.wait_for_timeout(3000) + + +async def register_chatgpt_account( + email_provider_factory: Callable[[BrowserContext], BaseProvider] | None = None, +) -> bool: + logger.info("=== Starting ChatGPT account registration ===") + + if email_provider_factory is None: + logger.error("No email provider factory configured") + return False + + birth_month, birth_day, birth_year = "01", "15", "1995" + + current_page: Page | None = None + redirect_url_captured: str | None = None + managed = None + + try: + async with async_playwright() as p: + managed = await launch_browser(p) + browser = managed.browser + context = ( + browser.contexts[0] if browser.contexts else await browser.new_context() + ) + email_provider = email_provider_factory(context) + + logger.info("[1/6] Getting new email from configured provider...") + email = await email_provider.get_new_email() + if not email: + raise AutomationError( + "email_provider", "Email provider returned empty email" + ) + + password = generate_password() + full_name = generate_name() + verifier, challenge = generate_pkce_pair() + oauth_state = generate_state() + authorize_url = build_authorize_url(verifier, challenge, oauth_state) + + logger.info("[2/6] Registering ChatGPT for %s", email) + chatgpt_page = await context.new_page() + current_page = chatgpt_page + await chatgpt_page.goto("https://chatgpt.com") + await chatgpt_page.wait_for_load_state("domcontentloaded") + + await chatgpt_page.get_by_text("Sign up for free", exact=True).click() + await chatgpt_page.wait_for_timeout(2000) + + await chatgpt_page.locator('input[type="email"]').fill(email) + await chatgpt_page.wait_for_timeout(500) + await chatgpt_page.get_by_role( + "button", name="Continue", exact=True + ).click() + await chatgpt_page.wait_for_timeout(3000) + + await chatgpt_page.locator('input[type="password"]').fill(password) + await chatgpt_page.wait_for_timeout(500) + await chatgpt_page.get_by_role( + "button", name="Continue", exact=True + ).click() + await chatgpt_page.wait_for_timeout(5000) + + logger.info("[3/6] Getting verification message from email provider...") + code = await get_latest_code(email_provider, email) + if not code: + raise AutomationError( + "email_provider", "Email provider returned no verification message" + ) + logger.info("[3/6] Verification code extracted: %s", code) + used_codes = {code} + + await chatgpt_page.bring_to_front() + code_input = chatgpt_page.get_by_placeholder("Code") + if await code_input.count() > 0: + await code_input.fill(code) + await chatgpt_page.wait_for_timeout(5000) + + continue_btn = chatgpt_page.get_by_role( + "button", name="Continue", exact=True + ) + if await continue_btn.count() > 0: + await continue_btn.click() + await chatgpt_page.wait_for_timeout(5000) + + logger.info("[4/6] Setting profile...") + name_input = chatgpt_page.get_by_placeholder("Full name") + if await name_input.count() > 0: + await name_input.fill(full_name) + + await chatgpt_page.wait_for_timeout(500) + await fill_date_field(chatgpt_page, birth_month, birth_day, birth_year) + await chatgpt_page.wait_for_timeout(1000) + + continue_btn = chatgpt_page.get_by_role( + "button", name="Continue", exact=True + ) + if await continue_btn.count() > 0: + await continue_btn.click() + + logger.info("Account registered!") + await wait_for_signup_stabilization(chatgpt_page) + + logger.info("[5/6] Skipping onboarding...") + + for _ in range(5): + skip_btn = chatgpt_page.locator( + 'button:has-text("Skip"):not(:has-text("Skip Tour"))' + ) + if await skip_btn.count() > 0: + for i in range(await skip_btn.count()): + try: + btn = skip_btn.nth(i) + if await btn.is_visible(): + await btn.click(timeout=5000) + logger.info("Clicked: Skip") + await chatgpt_page.wait_for_timeout(1500) + except: + pass + await chatgpt_page.wait_for_timeout(1000) + + skip_tour = chatgpt_page.locator('button:has-text("Skip Tour")') + if await skip_tour.count() > 0: + try: + await skip_tour.first.wait_for(state="visible", timeout=5000) + await skip_tour.first.click(timeout=5000) + logger.info("Clicked: Skip Tour") + await chatgpt_page.wait_for_timeout(2000) + except: + pass + + await chatgpt_page.wait_for_timeout(2000) + + for _ in range(3): + continue_btn = chatgpt_page.locator('button:has-text("Continue")') + if await continue_btn.count() > 0: + try: + await continue_btn.first.wait_for(state="visible", timeout=5000) + await continue_btn.first.click(timeout=5000) + logger.info("Clicked: Continue") + await chatgpt_page.wait_for_timeout(2000) + except: + pass + + await chatgpt_page.wait_for_timeout(2000) + + okay_btn = chatgpt_page.locator('button:has-text("Okay, let")') + for _ in range(10): + try: + await okay_btn.first.wait_for(state="visible", timeout=3000) + await okay_btn.first.click(timeout=5000) + logger.info("Clicked: Okay, let's go") + await chatgpt_page.wait_for_timeout(3000) + break + except: + await chatgpt_page.wait_for_timeout(1000) + + logger.info("Skipping subscription/card flow (disabled)") + await chatgpt_page.wait_for_timeout(2000) + + logger.info("[6/6] Running OAuth flow to get tokens...") + oauth_page = await context.new_page() + current_page = oauth_page + + def handle_request(request): + nonlocal redirect_url_captured + url = request.url + if "localhost:1455" in url and "code=" in url: + redirect_url_captured = url + logger.info("Captured OAuth redirect URL") + + oauth_page.on("request", handle_request) + + await oauth_page.goto(authorize_url, wait_until="domcontentloaded") + await oauth_page.wait_for_timeout(2000) + + email_input = oauth_page.locator('input[type="email"], input[name="email"]') + if await email_input.count() > 0: + await email_input.first.fill(email) + await oauth_page.wait_for_timeout(400) + + continue_button = oauth_page.get_by_role("button", name="Continue") + if await continue_button.count() > 0: + await continue_button.first.click() + await oauth_page.wait_for_timeout(2500) + + password_input = oauth_page.locator('input[type="password"]') + if await password_input.count() > 0: + await password_input.first.fill(password) + await oauth_page.wait_for_timeout(400) + continue_button = oauth_page.get_by_role("button", name="Continue") + if await continue_button.count() > 0: + await continue_button.first.click() + await oauth_page.wait_for_timeout(2500) + + for label in ["Continue", "Allow", "Authorize"]: + button = oauth_page.get_by_role("button", name=label) + if await button.count() > 0: + try: + await button.first.click(timeout=5000) + await oauth_page.wait_for_timeout(2000) + except Exception: + pass + + if not redirect_url_captured: + try: + await oauth_page.wait_for_timeout(4000) + current_url = oauth_page.url + if "localhost:1455" in current_url and "code=" in current_url: + redirect_url_captured = current_url + logger.info("Captured OAuth redirect from page URL") + except Exception: + pass + + if not redirect_url_captured: + raise AutomationError( + "oauth", "OAuth redirect with code was not captured", oauth_page + ) + + parsed = urlparse(redirect_url_captured) + params = parse_qs(parsed.query) + auth_code = params.get("code", [None])[0] + returned_state = params.get("state", [None])[0] + + if not auth_code: + raise AutomationError( + "oauth", "OAuth code missing in redirect", oauth_page + ) + if returned_state != oauth_state: + raise AutomationError("oauth", "OAuth state mismatch", oauth_page) + + tokens = await exchange_code_for_tokens(auth_code, verifier) + save_tokens(tokens) + logger.info("OAuth tokens saved successfully") + + return True + + except AutomationError as e: + logger.error(f"Error at step [{e.step}]: {e.message}") + await save_error_screenshot(e.page, e.step) + return False + except Exception as e: + logger.error(f"Unexpected error: {e}") + await save_error_screenshot(current_page, "unexpected") + return False + finally: + if managed: + await asyncio.sleep(2) + await managed.close() diff --git a/src/tokens.py b/src/providers/chatgpt/tokens.py similarity index 75% rename from src/tokens.py rename to src/providers/chatgpt/tokens.py index e7f7fbc..47960bc 100644 --- a/src/tokens.py +++ b/src/providers/chatgpt/tokens.py @@ -1,46 +1,34 @@ import json import time import os -import base64 -from pathlib import Path -from dataclasses import dataclass import aiohttp +from pathlib import Path + +from providers.base import ProviderTokens DATA_DIR = Path(os.environ.get("DATA_DIR", "./data")) -TOKENS_FILE = DATA_DIR / "tokens.json" +TOKENS_FILE = DATA_DIR / "chatgpt_tokens.json" CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann" TOKEN_URL = "https://auth.openai.com/oauth/token" -@dataclass -class Tokens: - access_token: str - refresh_token: str - expires_at: float # unix timestamp - - @property - def is_expired(self) -> bool: - return time.time() >= self.expires_at - 10 - - -def load_tokens() -> Tokens | None: +def load_tokens() -> ProviderTokens | None: if not TOKENS_FILE.exists(): return None try: with open(TOKENS_FILE) as f: data = json.load(f) - access_token = data["access_token"] - return Tokens( - access_token=access_token, + return ProviderTokens( + access_token=data["access_token"], refresh_token=data["refresh_token"], expires_at=data["expires_at"], ) - except (json.JSONDecodeError, KeyError): + except json.JSONDecodeError, KeyError: return None -def save_tokens(tokens: Tokens): +def save_tokens(tokens: ProviderTokens): TOKENS_FILE.parent.mkdir(parents=True, exist_ok=True) with open(TOKENS_FILE, "w") as f: json.dump( @@ -54,7 +42,7 @@ def save_tokens(tokens: Tokens): ) -async def refresh_tokens(refresh_token: str) -> Tokens | None: +async def refresh_tokens(refresh_token: str) -> ProviderTokens | None: async with aiohttp.ClientSession() as session: data = { "grant_type": "refresh_token", @@ -68,14 +56,14 @@ async def refresh_tokens(refresh_token: str) -> Tokens | None: return None json_resp = await resp.json() expires_in = json_resp["expires_in"] - return Tokens( + return ProviderTokens( access_token=json_resp["access_token"], refresh_token=json_resp["refresh_token"], expires_at=time.time() + expires_in, ) -async def get_valid_tokens() -> Tokens | None: +async def get_valid_tokens() -> ProviderTokens | None: tokens = load_tokens() if not tokens: print("No tokens found") diff --git a/src/codex_usage.py b/src/providers/chatgpt/usage.py similarity index 78% rename from src/codex_usage.py rename to src/providers/chatgpt/usage.py index 4fed548..abe38a9 100644 --- a/src/codex_usage.py +++ b/src/providers/chatgpt/usage.py @@ -8,7 +8,7 @@ from typing import Any def clamp_percent(value: Any) -> int: try: num = float(value) - except (TypeError, ValueError): + except TypeError, ValueError: return 0 if num < 0: return 0 @@ -35,7 +35,7 @@ def get_usage_percent(access_token: str, timeout_ms: int = 10000) -> int: body = res.read().decode("utf-8", errors="replace") except urllib.error.HTTPError as e: return -1 - except (urllib.error.URLError, socket.timeout): + except urllib.error.URLError, socket.timeout: return -1 try: @@ -48,14 +48,3 @@ def get_usage_percent(access_token: str, timeout_ms: int = 10000) -> int: return clamp_percent(primary.get("used_percent") or 0) return -1 - - -if __name__ == "__main__": - from tokens import load_tokens - - tokens = load_tokens() - if tokens: - usage = get_usage_percent(tokens.access_token) - print(f"{usage}%") - else: - print("No tokens") diff --git a/src/proxy.py b/src/proxy.py deleted file mode 100644 index 6fa5e56..0000000 --- a/src/proxy.py +++ /dev/null @@ -1,436 +0,0 @@ -import os -import asyncio -import logging -import time -import secrets -import json -import base64 -import uuid -from urllib.parse import urlencode -from aiohttp import web -import aiohttp - -from tokens import get_valid_tokens, load_tokens, DATA_DIR -from codex_usage import get_usage_percent -from get_new_token import get_new_token - -CODEX_BASE_URL = "https://chatgpt.com/backend-api" -PORT = int(os.environ.get("PORT", "8080")) -USAGE_THRESHOLD = int(os.environ.get("USAGE_THRESHOLD", "85")) -CHECK_INTERVAL = int(os.environ.get("CHECK_INTERVAL", "60")) -FAKE_EXPIRES_IN = 9999999999999 -AUTH_FILE = DATA_DIR / "auth.json" -JWT_AUTH_CLAIM_PATH = "https://api.openai.com/auth" -JWT_PROFILE_CLAIM_PATH = "https://api.openai.com/profile" - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -refresh_in_progress = False -auth_codes: dict[str, dict] = {} - - -def _b64url(data: bytes) -> str: - return base64.urlsafe_b64encode(data).decode("utf-8").rstrip("=") - - -def _generate_jwt_like() -> str: - account_id = str(uuid.uuid4()) - now = int(time.time()) - header = {"alg": "HS256", "typ": "JWT"} - user_id = f"user-{secrets.token_urlsafe(18)}" - account_user_id = f"{user_id}__{account_id}" - payload = { - "aud": ["https://api.openai.com/v1"], - "client_id": "app_EMoamEEZ73f0CkXaXp7hrann", - "iss": "https://auth.openai.com", - "iat": now, - "nbf": now, - "exp": now + 315360000, - "jti": str(uuid.uuid4()), - "scp": ["openid", "profile", "email", "offline_access"], - "session_id": f"authsess_{secrets.token_urlsafe(24)}", - JWT_AUTH_CLAIM_PATH: { - "chatgpt_account_id": account_id, - "chatgpt_account_user_id": account_user_id, - "chatgpt_compute_residency": "no_constraint", - "chatgpt_plan_type": "plus", - "chatgpt_user_id": user_id, - "user_id": user_id, - }, - JWT_PROFILE_CLAIM_PATH: { - "email": f"proxy-{secrets.token_hex(4)}@example.local", - "email_verified": True, - }, - "sub": f"auth0|{secrets.token_urlsafe(20)}", - } - head = _b64url(json.dumps(header, separators=(",", ":")).encode("utf-8")) - body = _b64url(json.dumps(payload, separators=(",", ":")).encode("utf-8")) - sign = _b64url(secrets.token_bytes(32)) - return f"{head}.{body}.{sign}" - - -def _generate_refresh_like() -> str: - return f"rt_{secrets.token_urlsafe(40)}.{secrets.token_urlsafe(32)}" - - -def _mask(value: str, head: int = 8, tail: int = 6) -> str: - if not value: - return "" - if len(value) <= head + tail: - return "" - return f"{value[:head]}...{value[-tail:]}" - - -def load_or_create_auth() -> dict: - if AUTH_FILE.exists(): - with open(AUTH_FILE) as f: - data = json.load(f) - if ( - data.get("access_token") - and data.get("refresh_token") - and data.get("expires_at") - ): - return data - - DATA_DIR.mkdir(parents=True, exist_ok=True) - access_token = _generate_jwt_like() - - data = { - "access_token": access_token, - "refresh_token": _generate_refresh_like(), - "expires_at": FAKE_EXPIRES_IN, - } - with open(AUTH_FILE, "w") as f: - json.dump(data, f, indent=2) - return data - - -@web.middleware -async def request_log_middleware(request: web.Request, handler): - started = time.perf_counter() - response = None - try: - response = await handler(request) - return response - finally: - elapsed_ms = int((time.perf_counter() - started) * 1000) - status = getattr(response, "status", "ERR") - logger.info( - "%s %s -> %s (%d ms)", - request.method, - request.path_qs, - status, - elapsed_ms, - ) - - -def check_auth(request: web.Request) -> bool: - auth_data = load_or_create_auth() - expected_token = auth_data["access_token"] - auth = request.headers.get("Authorization", "") - if auth.lower().startswith("bearer "): - token = auth[7:].strip() - return token == expected_token - return False - - -async def oauth_authorize_handler(request: web.Request) -> web.Response: - params = request.rel_url.query - redirect_uri = params.get("redirect_uri") - state = params.get("state", "") - - if not redirect_uri: - return web.json_response( - {"error": "invalid_request", "error_description": "Missing redirect_uri"}, - status=400, - ) - - code = f"ac_{secrets.token_urlsafe(48)}" - auth_codes[code] = { - "state": state, - "created_at": time.time(), - } - - query = urlencode( - { - "code": code, - "scope": "openid profile email offline_access", - "state": state, - } - ) - location = f"{redirect_uri}?{query}" - logger.info("OAuth authorize: issued code") - raise web.HTTPFound(location=location) - - -async def oauth_token_handler(request: web.Request) -> web.Response: - auth_data = load_or_create_auth() - - content_type = request.content_type or "" - grant_type = None - refresh_token = None - code = None - if content_type.startswith("application/json"): - body = await request.json() - grant_type = body.get("grant_type") - refresh_token = body.get("refresh_token") - code = body.get("code") - else: - form = await request.post() - grant_type = form.get("grant_type") - refresh_token = form.get("refresh_token") - code = form.get("code") - - if grant_type == "authorization_code": - code = str(code) if code else "" - if not code or code not in auth_codes: - return web.json_response( - { - "error": "invalid_grant", - "error_description": "Invalid authorization code", - }, - status=400, - ) - created_at = auth_codes[code]["created_at"] - del auth_codes[code] - if time.time() - created_at > 300: - return web.json_response( - { - "error": "invalid_grant", - "error_description": "Authorization code expired", - }, - status=400, - ) - - return web.json_response( - { - "access_token": auth_data["access_token"], - "refresh_token": auth_data["refresh_token"], - "token_type": "Bearer", - "expires_in": FAKE_EXPIRES_IN, - } - ) - - if grant_type == "refresh_token": - if refresh_token != auth_data["refresh_token"]: - return web.json_response( - { - "error": "invalid_grant", - "error_description": "Invalid refresh token", - }, - status=400, - ) - - return web.json_response( - { - "access_token": auth_data["access_token"], - "refresh_token": auth_data["refresh_token"], - "token_type": "Bearer", - "expires_in": FAKE_EXPIRES_IN, - } - ) - - return web.json_response( - { - "error": "unsupported_grant_type", - "error_description": "Only authorization_code and refresh_token are supported", - }, - status=400, - ) - - -async def refresh_tokens_task(): - global refresh_in_progress - if refresh_in_progress: - logger.info("Token refresh already in progress") - return - - refresh_in_progress = True - logger.info("Starting token refresh...") - - try: - success = await get_new_token(headless=False) - if success: - logger.info("Token refresh completed successfully") - else: - logger.error("Token refresh failed") - except Exception as e: - logger.error(f"Error during token refresh: {e}") - finally: - refresh_in_progress = False - - -async def usage_monitor(): - while True: - for _ in range(1): - tokens = load_tokens() - - if not tokens: - if not refresh_in_progress: - logger.warning("No tokens found, starting refresh...") - asyncio.create_task(refresh_tokens_task()) - break - - usage = get_usage_percent(tokens.access_token) - - if usage < 0: - logger.warning("Failed to get usage, token may be invalid") - asyncio.create_task(refresh_tokens_task()) - break - - logger.info(f"Current usage: {usage}%") - - if usage >= USAGE_THRESHOLD: - logger.info( - f"Usage {usage}% >= threshold {USAGE_THRESHOLD}%, starting refresh..." - ) - asyncio.create_task(refresh_tokens_task()) - break - - await asyncio.sleep(CHECK_INTERVAL) - - -async def proxy_handler(request: web.Request) -> web.StreamResponse | web.Response: - if not check_auth(request): - auth = request.headers.get("Authorization", "") - auth_preview = auth[:24] + ("..." if len(auth) > 24 else "") - logger.warning( - "Auth failed: method=%s path=%s auth_present=%s auth_preview=%s ua=%s", - request.method, - request.path, - bool(auth), - auth_preview, - request.headers.get("User-Agent", ""), - ) - return web.json_response({"error": "Unauthorized"}, status=401) - - tokens = await get_valid_tokens() - if not tokens: - return web.json_response({"error": "No valid tokens"}, status=500) - - path = request.path - target_url = f"{CODEX_BASE_URL}{path}" - logger.info( - "Proxying request: %s %s -> %s", - request.method, - request.path_qs, - target_url, - ) - - headers = {} - for key, value in request.headers.items(): - if key.lower() not in ("host", "authorization", "content-length"): - headers[key] = value - headers["Authorization"] = f"Bearer {tokens.access_token}" - - if request.method in ("POST", "PUT", "PATCH"): - body = await request.read() - else: - body = None - - async with aiohttp.ClientSession() as session: - try: - async with session.request( - method=request.method, - url=target_url, - headers=headers, - data=body, - params=request.query, - ) as resp: - content_type = resp.content_type or "application/json" - is_stream = ( - content_type == "text/event-stream" or "stream" in content_type - ) - - if is_stream: - response = web.StreamResponse( - status=resp.status, - reason=resp.reason, - headers={ - "Content-Type": content_type, - "Cache-Control": "no-cache", - "Connection": "keep-alive", - }, - ) - await response.prepare(request) - - async for chunk in resp.content.iter_any(): - await response.write(chunk) - - await response.write_eof() - return response - else: - response_body = await resp.read() - if resp.status >= 400: - preview = response_body[:500].decode("utf-8", errors="replace") - logger.warning( - "Upstream error: status=%s path=%s body=%s", - resp.status, - request.path, - preview, - ) - return web.Response( - status=resp.status, - body=response_body, - headers={"Content-Type": content_type}, - ) - except aiohttp.ClientError as e: - return web.json_response({"error": f"Proxy error: {e}"}, status=502) - - -async def health_handler(request: web.Request) -> web.Response: - tokens = await get_valid_tokens() - usage = -1 - if tokens: - usage = get_usage_percent(tokens.access_token) - - return web.json_response( - { - "status": "ok" if tokens else "no_tokens", - "has_tokens": tokens is not None, - "usage_percent": usage, - "refresh_in_progress": refresh_in_progress, - } - ) - - -async def start_background_tasks(app: web.Application): - app["usage_monitor"] = asyncio.create_task(usage_monitor()) - - -async def cleanup_background_tasks(app: web.Application): - app["usage_monitor"].cancel() - try: - await app["usage_monitor"] - except asyncio.CancelledError: - pass - - -def create_app() -> web.Application: - app = web.Application(middlewares=[request_log_middleware]) - app.router.add_get("/oauth/authorize", oauth_authorize_handler) - app.router.add_post("/oauth/token", oauth_token_handler) - app.router.add_get("/health", health_handler) - app.router.add_route("*", "/{path:.*}", proxy_handler) - app.on_startup.append(start_background_tasks) - app.on_cleanup.append(cleanup_background_tasks) - return app - - -if __name__ == "__main__": - logger.info(f"Starting proxy on port {PORT}") - logger.info(f"Usage threshold: {USAGE_THRESHOLD}%") - logger.info(f"Check interval: {CHECK_INTERVAL}s") - - auth_data = load_or_create_auth() - logger.info("Client access token: %s", _mask(auth_data["access_token"])) - logger.info("Client refresh token: %s", _mask(auth_data["refresh_token"])) - - startup_tokens = load_tokens() - if startup_tokens: - logger.info("Upstream access token: %s", _mask(startup_tokens.access_token)) - else: - logger.warning("No upstream token found at %s", DATA_DIR / "tokens.json") - app = create_app() - web.run_app(app, host="0.0.0.0", port=PORT) diff --git a/src/server.py b/src/server.py new file mode 100644 index 0000000..9fb791a --- /dev/null +++ b/src/server.py @@ -0,0 +1,147 @@ +import asyncio +import logging +import os + +from aiohttp import web + +from providers.chatgpt import ChatGPTProvider + +PORT = int(os.environ.get("PORT", "8080")) +USAGE_REFRESH_THRESHOLD = int(os.environ.get("USAGE_REFRESH_THRESHOLD", "85")) +LIMIT_EXHAUSTED_PERCENT = 100 + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Registry of available providers +PROVIDERS = { + "chatgpt": ChatGPTProvider(), +} + +refresh_locks = {name: asyncio.Lock() for name in PROVIDERS.keys()} +background_refresh_tasks: dict[str, asyncio.Task | None] = { + name: None for name in PROVIDERS.keys() +} + + +@web.middleware +async def request_log_middleware(request: web.Request, handler): + response = await handler(request) + logger.info("%s %s -> %s", request.method, request.path_qs, response.status) + return response + + +def build_limit(usage_percent: int) -> dict[str, int | bool]: + remaining = max(0, 100 - usage_percent) + return { + "used_percent": usage_percent, + "remaining_percent": remaining, + "exhausted": usage_percent >= LIMIT_EXHAUSTED_PERCENT, + "needs_refresh": usage_percent >= USAGE_REFRESH_THRESHOLD, + } + + +async def issue_new_token(provider_name: str) -> str | None: + provider = PROVIDERS.get(provider_name) + if not provider: + return None + + async with refresh_locks[provider_name]: + logger.info(f"[{provider_name}] Generating new token") + success = await provider.register_new_account() + if not success: + logger.error(f"[{provider_name}] Token generation failed") + return None + + token = await provider.get_token() + if not token: + logger.error(f"[{provider_name}] Token was generated but not available") + return None + + return token + + +async def background_refresh_worker(provider_name: str, reason: str): + try: + logger.info(f"[{provider_name}] Starting background token refresh ({reason})") + new_token = await issue_new_token(provider_name) + if new_token: + logger.info(f"[{provider_name}] Background token refresh completed") + else: + logger.error(f"[{provider_name}] Background token refresh failed") + except Exception: + logger.exception( + f"[{provider_name}] Unhandled error in background token refresh" + ) + + +def trigger_background_refresh(provider_name: str, reason: str): + task = background_refresh_tasks.get(provider_name) + if task and not task.done(): + logger.info( + f"[{provider_name}] Background refresh already running, skip ({reason})" + ) + return + background_refresh_tasks[provider_name] = asyncio.create_task( + background_refresh_worker(provider_name, reason) + ) + + +async def token_handler(request: web.Request) -> web.Response: + provider_name = request.match_info.get("provider", "chatgpt") + + provider = PROVIDERS.get(provider_name) + if not provider: + return web.json_response( + {"error": f"Unknown provider: {provider_name}"}, + status=404, + ) + + # Get or create token + token = await provider.get_token() + if not token: + return web.json_response( + {"error": "Failed to get active token"}, + status=503, + ) + + # Get usage info + usage_info = await provider.get_usage_info(token) + if "error" in usage_info: + return web.json_response( + {"error": usage_info["error"]}, + status=503, + ) + + usage_percent = usage_info.get("used_percent", 0) + + # Trigger background refresh if needed + if usage_percent >= USAGE_REFRESH_THRESHOLD: + trigger_background_refresh( + provider_name, + f"usage {usage_percent}% >= threshold {USAGE_REFRESH_THRESHOLD}%", + ) + + return web.json_response( + { + "token": token, + "limit": build_limit(usage_percent), + } + ) + + +def create_app() -> web.Application: + app = web.Application(middlewares=[request_log_middleware]) + # New route: /{provider}/token + app.router.add_get("/{provider}/token", token_handler) + # Legacy route for backward compatibility + app.router.add_get("/token", token_handler) + return app + + +if __name__ == "__main__": + logger.info("Starting token service on port %s", PORT) + logger.info("Usage refresh threshold: %s%%", USAGE_REFRESH_THRESHOLD) + logger.info("Available providers: %s", ", ".join(PROVIDERS.keys())) + app = create_app() + web.run_app(app, host="0.0.0.0", port=PORT) diff --git a/src/uv.lock b/uv.lock similarity index 100% rename from src/uv.lock rename to uv.lock