refactor: harden ChatGPT token lifecycle with startup recovery, single-writer locking, and faster auth flow
This commit is contained in:
parent
71d1050adb
commit
533e382e0e
9 changed files with 313 additions and 178 deletions
10
Dockerfile
10
Dockerfile
|
|
@ -15,19 +15,19 @@ RUN pip install --no-cache-dir uv
|
||||||
RUN uv sync --frozen --no-dev
|
RUN uv sync --frozen --no-dev
|
||||||
RUN /app/.venv/bin/python -m playwright install --with-deps chromium
|
RUN /app/.venv/bin/python -m playwright install --with-deps chromium
|
||||||
|
|
||||||
COPY src/*.py /app/
|
COPY entrypoint.sh /entrypoint.sh
|
||||||
|
|
||||||
|
COPY src .
|
||||||
|
|
||||||
ENV PYTHONUNBUFFERED=1
|
ENV PYTHONUNBUFFERED=1
|
||||||
ENV PORT=8000
|
ENV PORT=80
|
||||||
ENV DATA_DIR=/data
|
ENV DATA_DIR=/data
|
||||||
ENV PLAYWRIGHT_SKIP_BROWSER_DOWNLOAD=1
|
ENV PLAYWRIGHT_SKIP_BROWSER_DOWNLOAD=1
|
||||||
|
|
||||||
VOLUME ["/data"]
|
VOLUME ["/data"]
|
||||||
|
|
||||||
EXPOSE 8000
|
EXPOSE 80
|
||||||
|
|
||||||
STOPSIGNAL SIGINT
|
STOPSIGNAL SIGINT
|
||||||
|
|
||||||
COPY entrypoint.sh /entrypoint.sh
|
|
||||||
|
|
||||||
CMD ["/entrypoint.sh"]
|
CMD ["/entrypoint.sh"]
|
||||||
|
|
|
||||||
11
compose.yml
11
compose.yml
|
|
@ -1,11 +0,0 @@
|
||||||
services:
|
|
||||||
megapt:
|
|
||||||
build: src
|
|
||||||
restart: unless-stopped
|
|
||||||
environment:
|
|
||||||
USAGE_THRESHOLD: 85
|
|
||||||
CHECK_INTERVAL: 60
|
|
||||||
labels:
|
|
||||||
traefik.host: megapt
|
|
||||||
volumes:
|
|
||||||
- ./data:/data
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
@ -45,6 +44,8 @@ CHROME_FLAGS = [
|
||||||
"--disable-search-engine-choice-screen",
|
"--disable-search-engine-choice-screen",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
DEFAULT_CDP_PORT = 9222
|
||||||
|
|
||||||
|
|
||||||
def _fetch_ws_endpoint(port: int) -> str | None:
|
def _fetch_ws_endpoint(port: int) -> str | None:
|
||||||
try:
|
try:
|
||||||
|
|
@ -78,9 +79,10 @@ class ManagedBrowser:
|
||||||
shutil.rmtree(self.profile_dir, ignore_errors=True)
|
shutil.rmtree(self.profile_dir, ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
async def launch(playwright: Playwright, cdp_port: int | None = None) -> ManagedBrowser:
|
async def launch(
|
||||||
chrome_path = os.environ.get("CHROMIUM_PATH") or playwright.chromium.executable_path
|
playwright: Playwright, cdp_port: int = DEFAULT_CDP_PORT
|
||||||
cdp_port = cdp_port or int(os.environ.get("CDP_PORT", "9222"))
|
) -> ManagedBrowser:
|
||||||
|
chrome_path = playwright.chromium.executable_path
|
||||||
profile_dir = Path(tempfile.mkdtemp(prefix="megapt_profile-", dir="/tmp"))
|
profile_dir = Path(tempfile.mkdtemp(prefix="megapt_profile-", dir="/tmp"))
|
||||||
|
|
||||||
args = [
|
args = [
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
@ -7,11 +8,12 @@ from playwright.async_api import BrowserContext
|
||||||
from providers.base import Provider, ProviderTokens
|
from providers.base import Provider, ProviderTokens
|
||||||
from email_providers import BaseProvider
|
from email_providers import BaseProvider
|
||||||
from email_providers import TempMailOrgProvider
|
from email_providers import TempMailOrgProvider
|
||||||
from .tokens import load_tokens, save_tokens, get_valid_tokens
|
from .tokens import load_tokens, save_tokens, refresh_tokens
|
||||||
from .usage import get_usage_percent
|
from .usage import get_usage_data
|
||||||
from .registration import register_chatgpt_account
|
from .registration import register_chatgpt_account
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
MAX_REGISTRATION_ATTEMPTS = 4
|
||||||
|
|
||||||
|
|
||||||
class ChatGPTProvider(Provider):
|
class ChatGPTProvider(Provider):
|
||||||
|
|
@ -22,23 +24,63 @@ class ChatGPTProvider(Provider):
|
||||||
email_provider_factory: Callable[[BrowserContext], BaseProvider] | None = None,
|
email_provider_factory: Callable[[BrowserContext], BaseProvider] | None = None,
|
||||||
):
|
):
|
||||||
self.email_provider_factory = email_provider_factory or TempMailOrgProvider
|
self.email_provider_factory = email_provider_factory or TempMailOrgProvider
|
||||||
|
self._token_write_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def _register_with_retries(self) -> bool:
|
||||||
|
for attempt in range(1, MAX_REGISTRATION_ATTEMPTS + 1):
|
||||||
|
logger.info(
|
||||||
|
"Registration attempt %s/%s",
|
||||||
|
attempt,
|
||||||
|
MAX_REGISTRATION_ATTEMPTS,
|
||||||
|
)
|
||||||
|
success = await self.register_new_account()
|
||||||
|
if success:
|
||||||
|
return True
|
||||||
|
logger.warning("Registration attempt %s failed", attempt)
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def force_recreate_token(self) -> str | None:
|
||||||
|
async with self._token_write_lock:
|
||||||
|
success = await self._register_with_retries()
|
||||||
|
if not success:
|
||||||
|
return None
|
||||||
|
tokens = load_tokens()
|
||||||
|
if not tokens:
|
||||||
|
return None
|
||||||
|
return tokens.access_token
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return "chatgpt"
|
return "chatgpt"
|
||||||
|
|
||||||
async def get_token(self) -> str | None:
|
async def get_token(self) -> str | None:
|
||||||
"""Get valid access token, refreshing if needed"""
|
"""Get valid access token with single-writer refresh/register path."""
|
||||||
tokens = await get_valid_tokens()
|
tokens = load_tokens()
|
||||||
if not tokens:
|
if tokens and not tokens.is_expired:
|
||||||
logger.info("No valid tokens, registering new account")
|
return tokens.access_token
|
||||||
success = await self.register_new_account()
|
|
||||||
|
async with self._token_write_lock:
|
||||||
|
tokens = load_tokens()
|
||||||
|
if tokens and not tokens.is_expired:
|
||||||
|
return tokens.access_token
|
||||||
|
|
||||||
|
if tokens and tokens.refresh_token:
|
||||||
|
logger.info("Token expired, refreshing under lock")
|
||||||
|
refreshed = await refresh_tokens(tokens.refresh_token)
|
||||||
|
if refreshed:
|
||||||
|
save_tokens(refreshed)
|
||||||
|
return refreshed.access_token
|
||||||
|
logger.warning("Token refresh failed, falling back to registration")
|
||||||
|
|
||||||
|
logger.info("No valid tokens, registering new account under lock")
|
||||||
|
success = await self._register_with_retries()
|
||||||
if not success:
|
if not success:
|
||||||
return None
|
return None
|
||||||
tokens = await get_valid_tokens()
|
|
||||||
|
tokens = load_tokens()
|
||||||
if not tokens:
|
if not tokens:
|
||||||
return None
|
return None
|
||||||
return tokens.access_token
|
return tokens.access_token
|
||||||
|
|
||||||
async def register_new_account(self) -> bool:
|
async def register_new_account(self) -> bool:
|
||||||
"""Register a new ChatGPT account"""
|
"""Register a new ChatGPT account"""
|
||||||
|
|
@ -48,15 +90,18 @@ class ChatGPTProvider(Provider):
|
||||||
|
|
||||||
async def get_usage_info(self, access_token: str) -> dict[str, Any]:
|
async def get_usage_info(self, access_token: str) -> dict[str, Any]:
|
||||||
"""Get usage information for the current token"""
|
"""Get usage information for the current token"""
|
||||||
usage_percent = get_usage_percent(access_token)
|
usage_data = get_usage_data(access_token)
|
||||||
if usage_percent < 0:
|
if not usage_data:
|
||||||
return {"error": "Failed to get usage"}
|
return {"error": "Failed to get usage"}
|
||||||
|
|
||||||
remaining = max(0, 100 - usage_percent)
|
|
||||||
return {
|
return {
|
||||||
"used_percent": usage_percent,
|
"used_percent": int(usage_data["used_percent"]),
|
||||||
"remaining_percent": remaining,
|
"remaining_percent": int(usage_data["remaining_percent"]),
|
||||||
"exhausted": usage_percent >= 100,
|
"exhausted": int(usage_data["used_percent"]) >= 100,
|
||||||
|
"primary_window": usage_data.get("primary_window"),
|
||||||
|
"secondary_window": usage_data.get("secondary_window"),
|
||||||
|
"limit_reached": bool(usage_data.get("limit_reached")),
|
||||||
|
"allowed": bool(usage_data.get("allowed", True)),
|
||||||
}
|
}
|
||||||
|
|
||||||
def load_tokens(self) -> ProviderTokens | None:
|
def load_tokens(self) -> ProviderTokens | None:
|
||||||
|
|
|
||||||
|
|
@ -78,6 +78,36 @@ def generate_name() -> str:
|
||||||
"Paul",
|
"Paul",
|
||||||
"Andrew",
|
"Andrew",
|
||||||
"Joshua",
|
"Joshua",
|
||||||
|
"Kenneth",
|
||||||
|
"Kevin",
|
||||||
|
"Brian",
|
||||||
|
"George",
|
||||||
|
"Edward",
|
||||||
|
"Ronald",
|
||||||
|
"Timothy",
|
||||||
|
"Jason",
|
||||||
|
"Jeffrey",
|
||||||
|
"Ryan",
|
||||||
|
"Jacob",
|
||||||
|
"Gary",
|
||||||
|
"Nicholas",
|
||||||
|
"Eric",
|
||||||
|
"Jonathan",
|
||||||
|
"Stephen",
|
||||||
|
"Larry",
|
||||||
|
"Justin",
|
||||||
|
"Scott",
|
||||||
|
"Brandon",
|
||||||
|
"Benjamin",
|
||||||
|
"Samuel",
|
||||||
|
"Frank",
|
||||||
|
"Gregory",
|
||||||
|
"Raymond",
|
||||||
|
"Alexander",
|
||||||
|
"Patrick",
|
||||||
|
"Jack",
|
||||||
|
"Dennis",
|
||||||
|
"Jerry",
|
||||||
]
|
]
|
||||||
last_names = [
|
last_names = [
|
||||||
"Smith",
|
"Smith",
|
||||||
|
|
@ -100,10 +130,47 @@ def generate_name() -> str:
|
||||||
"Moore",
|
"Moore",
|
||||||
"Jackson",
|
"Jackson",
|
||||||
"Martin",
|
"Martin",
|
||||||
|
"Lee",
|
||||||
|
"Perez",
|
||||||
|
"Thompson",
|
||||||
|
"White",
|
||||||
|
"Harris",
|
||||||
|
"Sanchez",
|
||||||
|
"Clark",
|
||||||
|
"Ramirez",
|
||||||
|
"Lewis",
|
||||||
|
"Robinson",
|
||||||
|
"Walker",
|
||||||
|
"Young",
|
||||||
|
"Allen",
|
||||||
|
"King",
|
||||||
|
"Wright",
|
||||||
|
"Scott",
|
||||||
|
"Torres",
|
||||||
|
"Nguyen",
|
||||||
|
"Hill",
|
||||||
|
"Flores",
|
||||||
|
"Green",
|
||||||
|
"Adams",
|
||||||
|
"Nelson",
|
||||||
|
"Baker",
|
||||||
|
"Hall",
|
||||||
|
"Rivera",
|
||||||
|
"Campbell",
|
||||||
|
"Mitchell",
|
||||||
|
"Carter",
|
||||||
|
"Roberts",
|
||||||
]
|
]
|
||||||
return f"{random.choice(first_names)} {random.choice(last_names)}"
|
return f"{random.choice(first_names)} {random.choice(last_names)}"
|
||||||
|
|
||||||
|
|
||||||
|
def generate_birthdate_90s() -> tuple[str, str, str]:
|
||||||
|
year = random.randint(1990, 1999)
|
||||||
|
month = random.randint(1, 12)
|
||||||
|
day = random.randint(1, 28)
|
||||||
|
return f"{month:02d}", f"{day:02d}", str(year)
|
||||||
|
|
||||||
|
|
||||||
def extract_verification_code(message: str) -> str | None:
|
def extract_verification_code(message: str) -> str | None:
|
||||||
normalized = re.sub(r"\s+", " ", message)
|
normalized = re.sub(r"\s+", " ", message)
|
||||||
|
|
||||||
|
|
@ -177,24 +244,6 @@ async def exchange_code_for_tokens(code: str, verifier: str) -> ProviderTokens:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def get_new_verification_code(
|
|
||||||
email_provider: BaseProvider,
|
|
||||||
email: str,
|
|
||||||
used_codes: set[str],
|
|
||||||
timeout_seconds: int = 240,
|
|
||||||
) -> str | None:
|
|
||||||
attempts = max(1, timeout_seconds // 5)
|
|
||||||
for _ in range(attempts):
|
|
||||||
message = await email_provider.get_latest_message(email)
|
|
||||||
if message:
|
|
||||||
all_codes = re.findall(r"\b(\d{6})\b", message)
|
|
||||||
for candidate in all_codes:
|
|
||||||
if candidate not in used_codes:
|
|
||||||
return candidate
|
|
||||||
await asyncio.sleep(5)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def get_latest_code(email_provider: BaseProvider, email: str) -> str | None:
|
async def get_latest_code(email_provider: BaseProvider, email: str) -> str | None:
|
||||||
message = await email_provider.get_latest_message(email)
|
message = await email_provider.get_latest_message(email)
|
||||||
if not message:
|
if not message:
|
||||||
|
|
@ -209,24 +258,32 @@ async def fill_date_field(page: Page, month: str, day: str, year: str):
|
||||||
|
|
||||||
await month_field.scroll_into_view_if_needed()
|
await month_field.scroll_into_view_if_needed()
|
||||||
await month_field.click()
|
await month_field.click()
|
||||||
await page.wait_for_timeout(120)
|
await page.wait_for_timeout(80)
|
||||||
|
|
||||||
await page.keyboard.type(f"{month}{day}{year}")
|
await page.keyboard.type(f"{month}{day}{year}")
|
||||||
await page.wait_for_timeout(200)
|
await page.wait_for_timeout(120)
|
||||||
|
|
||||||
|
|
||||||
async def wait_for_signup_stabilization(page: Page):
|
async def click_continue(page: Page, timeout_ms: int = 10000):
|
||||||
try:
|
btn = page.get_by_role("button", name="Continue", exact=True).first
|
||||||
await page.wait_for_load_state("networkidle", timeout=15000)
|
await btn.wait_for(state="visible", timeout=timeout_ms)
|
||||||
except Exception:
|
await btn.click()
|
||||||
logger.warning(
|
|
||||||
"Signup page did not reach networkidle quickly; continuing with fallback"
|
|
||||||
)
|
async def wait_for_signup_stabilization(
|
||||||
try:
|
page: Page,
|
||||||
await page.wait_for_load_state("domcontentloaded", timeout=10000)
|
source_url: str,
|
||||||
except Exception:
|
timeout_seconds: int = 30,
|
||||||
pass
|
):
|
||||||
await page.wait_for_timeout(3000)
|
end_at = asyncio.get_running_loop().time() + timeout_seconds
|
||||||
|
while asyncio.get_running_loop().time() < end_at:
|
||||||
|
current_url = page.url
|
||||||
|
if current_url != source_url:
|
||||||
|
logger.info("Signup redirect detected: %s -> %s", source_url, current_url)
|
||||||
|
return
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
logger.warning("Signup redirect was not detected within %ss", timeout_seconds)
|
||||||
|
|
||||||
|
|
||||||
async def register_chatgpt_account(
|
async def register_chatgpt_account(
|
||||||
|
|
@ -238,7 +295,7 @@ async def register_chatgpt_account(
|
||||||
logger.error("No email provider factory configured")
|
logger.error("No email provider factory configured")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
birth_month, birth_day, birth_year = "01", "15", "1995"
|
birth_month, birth_day, birth_year = generate_birthdate_90s()
|
||||||
|
|
||||||
current_page: Page | None = None
|
current_page: Page | None = None
|
||||||
redirect_url_captured: str | None = None
|
redirect_url_captured: str | None = None
|
||||||
|
|
@ -253,7 +310,7 @@ async def register_chatgpt_account(
|
||||||
)
|
)
|
||||||
email_provider = email_provider_factory(context)
|
email_provider = email_provider_factory(context)
|
||||||
|
|
||||||
logger.info("[1/6] Getting new email from configured provider...")
|
logger.info("[1/5] Getting new email from configured provider...")
|
||||||
email = await email_provider.get_new_email()
|
email = await email_provider.get_new_email()
|
||||||
if not email:
|
if not email:
|
||||||
raise AutomationError(
|
raise AutomationError(
|
||||||
|
|
@ -266,127 +323,57 @@ async def register_chatgpt_account(
|
||||||
oauth_state = generate_state()
|
oauth_state = generate_state()
|
||||||
authorize_url = build_authorize_url(verifier, challenge, oauth_state)
|
authorize_url = build_authorize_url(verifier, challenge, oauth_state)
|
||||||
|
|
||||||
logger.info("[2/6] Registering ChatGPT for %s", email)
|
logger.info("[2/5] Registering ChatGPT for %s", email)
|
||||||
chatgpt_page = await context.new_page()
|
chatgpt_page = await context.new_page()
|
||||||
current_page = chatgpt_page
|
current_page = chatgpt_page
|
||||||
await chatgpt_page.goto("https://chatgpt.com")
|
await chatgpt_page.goto("https://chatgpt.com")
|
||||||
await chatgpt_page.wait_for_load_state("domcontentloaded")
|
await chatgpt_page.wait_for_load_state("domcontentloaded")
|
||||||
|
|
||||||
await chatgpt_page.get_by_text("Sign up for free", exact=True).click()
|
await chatgpt_page.get_by_text("Sign up for free", exact=True).click()
|
||||||
await chatgpt_page.wait_for_timeout(2000)
|
await chatgpt_page.locator('input[type="email"]').first.wait_for(
|
||||||
|
state="visible", timeout=15000
|
||||||
|
)
|
||||||
|
|
||||||
await chatgpt_page.locator('input[type="email"]').fill(email)
|
await chatgpt_page.locator('input[type="email"]').fill(email)
|
||||||
await chatgpt_page.wait_for_timeout(500)
|
await click_continue(chatgpt_page)
|
||||||
await chatgpt_page.get_by_role(
|
await chatgpt_page.locator('input[type="password"]').first.wait_for(
|
||||||
"button", name="Continue", exact=True
|
state="visible", timeout=15000
|
||||||
).click()
|
)
|
||||||
await chatgpt_page.wait_for_timeout(3000)
|
|
||||||
|
|
||||||
await chatgpt_page.locator('input[type="password"]').fill(password)
|
await chatgpt_page.locator('input[type="password"]').fill(password)
|
||||||
await chatgpt_page.wait_for_timeout(500)
|
await click_continue(chatgpt_page)
|
||||||
await chatgpt_page.get_by_role(
|
await chatgpt_page.get_by_placeholder("Code").first.wait_for(
|
||||||
"button", name="Continue", exact=True
|
state="visible", timeout=30000
|
||||||
).click()
|
)
|
||||||
await chatgpt_page.wait_for_timeout(5000)
|
|
||||||
|
|
||||||
logger.info("[3/6] Getting verification message from email provider...")
|
logger.info("[3/5] Getting verification message from email provider...")
|
||||||
code = await get_latest_code(email_provider, email)
|
code = await get_latest_code(email_provider, email)
|
||||||
if not code:
|
if not code:
|
||||||
raise AutomationError(
|
raise AutomationError(
|
||||||
"email_provider", "Email provider returned no verification message"
|
"email_provider", "Email provider returned no verification message"
|
||||||
)
|
)
|
||||||
logger.info("[3/6] Verification code extracted: %s", code)
|
logger.info("[3/5] Verification code extracted: %s", code)
|
||||||
used_codes = {code}
|
|
||||||
|
|
||||||
await chatgpt_page.bring_to_front()
|
await chatgpt_page.bring_to_front()
|
||||||
code_input = chatgpt_page.get_by_placeholder("Code")
|
code_input = chatgpt_page.get_by_placeholder("Code")
|
||||||
if await code_input.count() > 0:
|
if await code_input.count() > 0:
|
||||||
await code_input.fill(code)
|
await code_input.fill(code)
|
||||||
await chatgpt_page.wait_for_timeout(5000)
|
await click_continue(chatgpt_page)
|
||||||
|
|
||||||
continue_btn = chatgpt_page.get_by_role(
|
logger.info("[4/5] Setting profile...")
|
||||||
"button", name="Continue", exact=True
|
|
||||||
)
|
|
||||||
if await continue_btn.count() > 0:
|
|
||||||
await continue_btn.click()
|
|
||||||
await chatgpt_page.wait_for_timeout(5000)
|
|
||||||
|
|
||||||
logger.info("[4/6] Setting profile...")
|
|
||||||
name_input = chatgpt_page.get_by_placeholder("Full name")
|
name_input = chatgpt_page.get_by_placeholder("Full name")
|
||||||
|
await name_input.first.wait_for(state="visible", timeout=20000)
|
||||||
if await name_input.count() > 0:
|
if await name_input.count() > 0:
|
||||||
await name_input.fill(full_name)
|
await name_input.fill(full_name)
|
||||||
|
|
||||||
await chatgpt_page.wait_for_timeout(500)
|
|
||||||
await fill_date_field(chatgpt_page, birth_month, birth_day, birth_year)
|
await fill_date_field(chatgpt_page, birth_month, birth_day, birth_year)
|
||||||
await chatgpt_page.wait_for_timeout(1000)
|
profile_url = chatgpt_page.url
|
||||||
|
await click_continue(chatgpt_page)
|
||||||
continue_btn = chatgpt_page.get_by_role(
|
|
||||||
"button", name="Continue", exact=True
|
|
||||||
)
|
|
||||||
if await continue_btn.count() > 0:
|
|
||||||
await continue_btn.click()
|
|
||||||
|
|
||||||
logger.info("Account registered!")
|
logger.info("Account registered!")
|
||||||
await wait_for_signup_stabilization(chatgpt_page)
|
await wait_for_signup_stabilization(chatgpt_page, source_url=profile_url)
|
||||||
|
|
||||||
logger.info("[5/6] Skipping onboarding...")
|
logger.info("[5/5] Running OAuth flow to get tokens...")
|
||||||
|
|
||||||
for _ in range(5):
|
|
||||||
skip_btn = chatgpt_page.locator(
|
|
||||||
'button:has-text("Skip"):not(:has-text("Skip Tour"))'
|
|
||||||
)
|
|
||||||
if await skip_btn.count() > 0:
|
|
||||||
for i in range(await skip_btn.count()):
|
|
||||||
try:
|
|
||||||
btn = skip_btn.nth(i)
|
|
||||||
if await btn.is_visible():
|
|
||||||
await btn.click(timeout=5000)
|
|
||||||
logger.info("Clicked: Skip")
|
|
||||||
await chatgpt_page.wait_for_timeout(1500)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
await chatgpt_page.wait_for_timeout(1000)
|
|
||||||
|
|
||||||
skip_tour = chatgpt_page.locator('button:has-text("Skip Tour")')
|
|
||||||
if await skip_tour.count() > 0:
|
|
||||||
try:
|
|
||||||
await skip_tour.first.wait_for(state="visible", timeout=5000)
|
|
||||||
await skip_tour.first.click(timeout=5000)
|
|
||||||
logger.info("Clicked: Skip Tour")
|
|
||||||
await chatgpt_page.wait_for_timeout(2000)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
await chatgpt_page.wait_for_timeout(2000)
|
|
||||||
|
|
||||||
for _ in range(3):
|
|
||||||
continue_btn = chatgpt_page.locator('button:has-text("Continue")')
|
|
||||||
if await continue_btn.count() > 0:
|
|
||||||
try:
|
|
||||||
await continue_btn.first.wait_for(state="visible", timeout=5000)
|
|
||||||
await continue_btn.first.click(timeout=5000)
|
|
||||||
logger.info("Clicked: Continue")
|
|
||||||
await chatgpt_page.wait_for_timeout(2000)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
await chatgpt_page.wait_for_timeout(2000)
|
|
||||||
|
|
||||||
okay_btn = chatgpt_page.locator('button:has-text("Okay, let")')
|
|
||||||
for _ in range(10):
|
|
||||||
try:
|
|
||||||
await okay_btn.first.wait_for(state="visible", timeout=3000)
|
|
||||||
await okay_btn.first.click(timeout=5000)
|
|
||||||
logger.info("Clicked: Okay, let's go")
|
|
||||||
await chatgpt_page.wait_for_timeout(3000)
|
|
||||||
break
|
|
||||||
except:
|
|
||||||
await chatgpt_page.wait_for_timeout(1000)
|
|
||||||
|
|
||||||
logger.info("Skipping subscription/card flow (disabled)")
|
|
||||||
await chatgpt_page.wait_for_timeout(2000)
|
|
||||||
|
|
||||||
logger.info("[6/6] Running OAuth flow to get tokens...")
|
|
||||||
oauth_page = await context.new_page()
|
oauth_page = await context.new_page()
|
||||||
current_page = oauth_page
|
current_page = oauth_page
|
||||||
|
|
||||||
|
|
@ -400,33 +387,34 @@ async def register_chatgpt_account(
|
||||||
oauth_page.on("request", handle_request)
|
oauth_page.on("request", handle_request)
|
||||||
|
|
||||||
await oauth_page.goto(authorize_url, wait_until="domcontentloaded")
|
await oauth_page.goto(authorize_url, wait_until="domcontentloaded")
|
||||||
await oauth_page.wait_for_timeout(2000)
|
await oauth_page.locator(
|
||||||
|
'input[type="email"], input[name="email"]'
|
||||||
|
).first.wait_for(state="visible", timeout=20000)
|
||||||
|
|
||||||
email_input = oauth_page.locator('input[type="email"], input[name="email"]')
|
email_input = oauth_page.locator('input[type="email"], input[name="email"]')
|
||||||
if await email_input.count() > 0:
|
if await email_input.count() > 0:
|
||||||
await email_input.first.fill(email)
|
await email_input.first.fill(email)
|
||||||
await oauth_page.wait_for_timeout(400)
|
|
||||||
|
|
||||||
continue_button = oauth_page.get_by_role("button", name="Continue")
|
continue_button = oauth_page.get_by_role("button", name="Continue")
|
||||||
if await continue_button.count() > 0:
|
if await continue_button.count() > 0:
|
||||||
await continue_button.first.click()
|
await continue_button.first.click()
|
||||||
await oauth_page.wait_for_timeout(2500)
|
await oauth_page.locator('input[type="password"]').first.wait_for(
|
||||||
|
state="visible", timeout=20000
|
||||||
|
)
|
||||||
|
|
||||||
password_input = oauth_page.locator('input[type="password"]')
|
password_input = oauth_page.locator('input[type="password"]')
|
||||||
if await password_input.count() > 0:
|
if await password_input.count() > 0:
|
||||||
await password_input.first.fill(password)
|
await password_input.first.fill(password)
|
||||||
await oauth_page.wait_for_timeout(400)
|
|
||||||
continue_button = oauth_page.get_by_role("button", name="Continue")
|
continue_button = oauth_page.get_by_role("button", name="Continue")
|
||||||
if await continue_button.count() > 0:
|
if await continue_button.count() > 0:
|
||||||
await continue_button.first.click()
|
await continue_button.first.click()
|
||||||
await oauth_page.wait_for_timeout(2500)
|
|
||||||
|
|
||||||
for label in ["Continue", "Allow", "Authorize"]:
|
for label in ["Continue", "Allow", "Authorize"]:
|
||||||
button = oauth_page.get_by_role("button", name=label)
|
button = oauth_page.get_by_role("button", name=label)
|
||||||
if await button.count() > 0:
|
if await button.count() > 0:
|
||||||
try:
|
try:
|
||||||
await button.first.click(timeout=5000)
|
await button.first.click(timeout=5000)
|
||||||
await oauth_page.wait_for_timeout(2000)
|
await oauth_page.wait_for_timeout(500)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,9 +3,12 @@ import time
|
||||||
import os
|
import os
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import logging
|
||||||
|
|
||||||
from providers.base import ProviderTokens
|
from providers.base import ProviderTokens
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DATA_DIR = Path(os.environ.get("DATA_DIR", "./data"))
|
DATA_DIR = Path(os.environ.get("DATA_DIR", "./data"))
|
||||||
TOKENS_FILE = DATA_DIR / "chatgpt_tokens.json"
|
TOKENS_FILE = DATA_DIR / "chatgpt_tokens.json"
|
||||||
|
|
||||||
|
|
@ -66,17 +69,17 @@ async def refresh_tokens(refresh_token: str) -> ProviderTokens | None:
|
||||||
async def get_valid_tokens() -> ProviderTokens | None:
|
async def get_valid_tokens() -> ProviderTokens | None:
|
||||||
tokens = load_tokens()
|
tokens = load_tokens()
|
||||||
if not tokens:
|
if not tokens:
|
||||||
print("No tokens found")
|
logger.info("No tokens found")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if tokens.is_expired:
|
if tokens.is_expired:
|
||||||
print("Token expired, refreshing...")
|
logger.info("Token expired, refreshing...")
|
||||||
if not tokens.refresh_token:
|
if not tokens.refresh_token:
|
||||||
print("No refresh token available")
|
logger.info("No refresh token available")
|
||||||
return None
|
return None
|
||||||
new_tokens = await refresh_tokens(tokens.refresh_token)
|
new_tokens = await refresh_tokens(tokens.refresh_token)
|
||||||
if not new_tokens:
|
if not new_tokens:
|
||||||
print("Failed to refresh token")
|
logger.warning("Failed to refresh token")
|
||||||
return None
|
return None
|
||||||
save_tokens(new_tokens)
|
save_tokens(new_tokens)
|
||||||
return new_tokens
|
return new_tokens
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,18 @@ def clamp_percent(value: Any) -> int:
|
||||||
return int(round(num))
|
return int(round(num))
|
||||||
|
|
||||||
|
|
||||||
def get_usage_percent(access_token: str, timeout_ms: int = 10000) -> int:
|
def _parse_window(window: dict[str, Any] | None) -> dict[str, int] | None:
|
||||||
|
if not isinstance(window, dict):
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"used_percent": clamp_percent(window.get("used_percent") or 0),
|
||||||
|
"limit_window_seconds": int(window.get("limit_window_seconds") or 0),
|
||||||
|
"reset_after_seconds": int(window.get("reset_after_seconds") or 0),
|
||||||
|
"reset_at": int(window.get("reset_at") or 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_usage_data(access_token: str, timeout_ms: int = 10000) -> dict[str, Any] | None:
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {access_token}",
|
"Authorization": f"Bearer {access_token}",
|
||||||
"User-Agent": "CodexProxy",
|
"User-Agent": "CodexProxy",
|
||||||
|
|
@ -33,18 +44,42 @@ def get_usage_percent(access_token: str, timeout_ms: int = 10000) -> int:
|
||||||
try:
|
try:
|
||||||
with urllib.request.urlopen(req, timeout=timeout_ms / 1000) as res:
|
with urllib.request.urlopen(req, timeout=timeout_ms / 1000) as res:
|
||||||
body = res.read().decode("utf-8", errors="replace")
|
body = res.read().decode("utf-8", errors="replace")
|
||||||
except urllib.error.HTTPError as e:
|
except urllib.error.HTTPError:
|
||||||
return -1
|
return None
|
||||||
except urllib.error.URLError, socket.timeout:
|
except urllib.error.URLError, socket.timeout:
|
||||||
return -1
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data = json.loads(body)
|
data = json.loads(body)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
return -1
|
return None
|
||||||
|
|
||||||
primary = (data.get("rate_limit") or {}).get("primary_window") or {}
|
rate_limit = data.get("rate_limit") or {}
|
||||||
|
primary = _parse_window(rate_limit.get("primary_window"))
|
||||||
|
secondary = _parse_window(rate_limit.get("secondary_window"))
|
||||||
|
|
||||||
|
used_candidates = []
|
||||||
if primary:
|
if primary:
|
||||||
return clamp_percent(primary.get("used_percent") or 0)
|
used_candidates.append(primary["used_percent"])
|
||||||
|
if secondary:
|
||||||
|
used_candidates.append(secondary["used_percent"])
|
||||||
|
|
||||||
return -1
|
if not used_candidates:
|
||||||
|
return None
|
||||||
|
|
||||||
|
effective_used = max(used_candidates)
|
||||||
|
return {
|
||||||
|
"used_percent": effective_used,
|
||||||
|
"remaining_percent": max(0, 100 - effective_used),
|
||||||
|
"primary_window": primary,
|
||||||
|
"secondary_window": secondary,
|
||||||
|
"limit_reached": bool(rate_limit.get("limit_reached")),
|
||||||
|
"allowed": bool(rate_limit.get("allowed", True)),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_usage_percent(access_token: str, timeout_ms: int = 10000) -> int:
|
||||||
|
data = get_usage_data(access_token, timeout_ms=timeout_ms)
|
||||||
|
if not data:
|
||||||
|
return -1
|
||||||
|
return int(data["used_percent"])
|
||||||
|
|
|
||||||
|
|
@ -41,6 +41,47 @@ def build_limit(usage_percent: int) -> dict[str, int | bool]:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def ensure_provider_token_ready(provider_name: str):
|
||||||
|
provider = PROVIDERS.get(provider_name)
|
||||||
|
if not provider:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("[%s] Startup token check", provider_name)
|
||||||
|
token = await provider.get_token()
|
||||||
|
if not token:
|
||||||
|
logger.warning(
|
||||||
|
"[%s] Startup token check failed, forcing recreation", provider_name
|
||||||
|
)
|
||||||
|
if isinstance(provider, ChatGPTProvider):
|
||||||
|
token = await provider.force_recreate_token()
|
||||||
|
|
||||||
|
if not token:
|
||||||
|
logger.error("[%s] Could not prepare token at startup", provider_name)
|
||||||
|
return
|
||||||
|
|
||||||
|
usage_info = await provider.get_usage_info(token)
|
||||||
|
if "error" not in usage_info:
|
||||||
|
logger.info("[%s] Startup token is ready", provider_name)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
"[%s] Startup token invalid for usage, forcing recreation", provider_name
|
||||||
|
)
|
||||||
|
if isinstance(provider, ChatGPTProvider):
|
||||||
|
token = await provider.force_recreate_token()
|
||||||
|
if token:
|
||||||
|
logger.info("[%s] Startup token recreated successfully", provider_name)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.error("[%s] Startup token recreation failed", provider_name)
|
||||||
|
|
||||||
|
|
||||||
|
async def on_startup(app: web.Application):
|
||||||
|
del app
|
||||||
|
for provider_name in PROVIDERS.keys():
|
||||||
|
await ensure_provider_token_ready(provider_name)
|
||||||
|
|
||||||
|
|
||||||
async def issue_new_token(provider_name: str) -> str | None:
|
async def issue_new_token(provider_name: str) -> str | None:
|
||||||
provider = PROVIDERS.get(provider_name)
|
provider = PROVIDERS.get(provider_name)
|
||||||
if not provider:
|
if not provider:
|
||||||
|
|
@ -114,6 +155,33 @@ async def token_handler(request: web.Request) -> web.Response:
|
||||||
)
|
)
|
||||||
|
|
||||||
usage_percent = usage_info.get("used_percent", 0)
|
usage_percent = usage_info.get("used_percent", 0)
|
||||||
|
remaining_percent = usage_info.get("remaining_percent", max(0, 100 - usage_percent))
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"[%s] token issued, used=%s%% remaining=%s%%",
|
||||||
|
provider_name,
|
||||||
|
usage_percent,
|
||||||
|
remaining_percent,
|
||||||
|
)
|
||||||
|
|
||||||
|
primary_window = usage_info.get("primary_window")
|
||||||
|
secondary_window = usage_info.get("secondary_window")
|
||||||
|
if primary_window:
|
||||||
|
logger.info(
|
||||||
|
"[%s] primary window: used=%s%% window=%ss reset_in=%ss",
|
||||||
|
provider_name,
|
||||||
|
primary_window.get("used_percent", 0),
|
||||||
|
primary_window.get("limit_window_seconds", 0),
|
||||||
|
primary_window.get("reset_after_seconds", 0),
|
||||||
|
)
|
||||||
|
if secondary_window:
|
||||||
|
logger.info(
|
||||||
|
"[%s] secondary window: used=%s%% window=%ss reset_in=%ss",
|
||||||
|
provider_name,
|
||||||
|
secondary_window.get("used_percent", 0),
|
||||||
|
secondary_window.get("limit_window_seconds", 0),
|
||||||
|
secondary_window.get("reset_after_seconds", 0),
|
||||||
|
)
|
||||||
|
|
||||||
# Trigger background refresh if needed
|
# Trigger background refresh if needed
|
||||||
if usage_percent >= USAGE_REFRESH_THRESHOLD:
|
if usage_percent >= USAGE_REFRESH_THRESHOLD:
|
||||||
|
|
@ -126,12 +194,17 @@ async def token_handler(request: web.Request) -> web.Response:
|
||||||
{
|
{
|
||||||
"token": token,
|
"token": token,
|
||||||
"limit": build_limit(usage_percent),
|
"limit": build_limit(usage_percent),
|
||||||
|
"usage": {
|
||||||
|
"primary_window": primary_window,
|
||||||
|
"secondary_window": secondary_window,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_app() -> web.Application:
|
def create_app() -> web.Application:
|
||||||
app = web.Application(middlewares=[request_log_middleware])
|
app = web.Application(middlewares=[request_log_middleware])
|
||||||
|
app.on_startup.append(on_startup)
|
||||||
# New route: /{provider}/token
|
# New route: /{provider}/token
|
||||||
app.router.add_get("/{provider}/token", token_handler)
|
app.router.add_get("/{provider}/token", token_handler)
|
||||||
# Legacy route for backward compatibility
|
# Legacy route for backward compatibility
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue