This commit is contained in:
parent
d6396e4050
commit
8b5449b1fd
15 changed files with 663 additions and 302 deletions
|
|
@ -1,6 +1,7 @@
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
|
@ -35,7 +36,7 @@ def _dict_to_tokens(data: dict[str, Any] | None) -> ProviderTokens | None:
|
|||
refresh_token=data["refresh_token"],
|
||||
expires_at=data["expires_at"],
|
||||
)
|
||||
except KeyError, TypeError:
|
||||
except (KeyError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
|
|
@ -54,8 +55,20 @@ def _load_raw() -> dict[str, Any] | None:
|
|||
|
||||
def _save_raw(data: dict[str, Any]) -> None:
|
||||
TOKENS_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(TOKENS_FILE, "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
fd, tmp_path = tempfile.mkstemp(
|
||||
prefix=f"{TOKENS_FILE.name}.",
|
||||
suffix=".tmp",
|
||||
dir=str(TOKENS_FILE.parent),
|
||||
)
|
||||
try:
|
||||
with os.fdopen(fd, "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
f.flush()
|
||||
os.fsync(f.fileno())
|
||||
os.replace(tmp_path, TOKENS_FILE)
|
||||
finally:
|
||||
if os.path.exists(tmp_path):
|
||||
os.unlink(tmp_path)
|
||||
|
||||
|
||||
def _normalize_state(data: dict[str, Any] | None) -> dict[str, Any]:
|
||||
|
|
@ -68,7 +81,6 @@ def _normalize_state(data: dict[str, Any] | None) -> dict[str, Any]:
|
|||
"next_account": data.get("next_account"),
|
||||
}
|
||||
|
||||
# Backward compatibility with old flat schema
|
||||
return {"active": data, "next_account": None}
|
||||
|
||||
|
||||
|
|
@ -79,9 +91,7 @@ def load_state() -> tuple[ProviderTokens | None, ProviderTokens | None]:
|
|||
return active, next_account
|
||||
|
||||
|
||||
def save_state(
|
||||
active: ProviderTokens | None, next_account: ProviderTokens | None
|
||||
) -> None:
|
||||
def save_state(active: ProviderTokens | None, next_account: ProviderTokens | None) -> None:
|
||||
payload = {
|
||||
"active": _tokens_to_dict(active) if active else None,
|
||||
"next_account": _tokens_to_dict(next_account) if next_account else None,
|
||||
|
|
@ -104,13 +114,8 @@ def save_tokens(tokens: ProviderTokens):
|
|||
save_state(tokens, next_account)
|
||||
|
||||
|
||||
def save_next_tokens(tokens: ProviderTokens):
|
||||
active, _ = load_state()
|
||||
save_state(active, tokens)
|
||||
|
||||
|
||||
def promote_next_tokens() -> bool:
|
||||
active, next_account = load_state()
|
||||
_, next_account = load_state()
|
||||
if not next_account:
|
||||
return False
|
||||
save_state(next_account, None)
|
||||
|
|
@ -123,42 +128,34 @@ def clear_next_tokens():
|
|||
|
||||
|
||||
async def refresh_tokens(refresh_token: str) -> ProviderTokens | None:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
data = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"client_id": CLIENT_ID,
|
||||
}
|
||||
async with session.post(TOKEN_URL, data=data) as resp:
|
||||
if not resp.ok:
|
||||
text = await resp.text()
|
||||
logger.warning("Token refresh failed: %s %s", resp.status, text)
|
||||
return None
|
||||
json_resp = await resp.json()
|
||||
expires_in = json_resp["expires_in"]
|
||||
return ProviderTokens(
|
||||
access_token=json_resp["access_token"],
|
||||
refresh_token=json_resp["refresh_token"],
|
||||
expires_at=time.time() + expires_in,
|
||||
)
|
||||
|
||||
|
||||
async def get_valid_tokens() -> ProviderTokens | None:
|
||||
tokens = load_tokens()
|
||||
if not tokens:
|
||||
logger.info("No tokens found")
|
||||
data = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"client_id": CLIENT_ID,
|
||||
}
|
||||
timeout = aiohttp.ClientTimeout(total=15)
|
||||
try:
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.post(TOKEN_URL, data=data) as resp:
|
||||
if not resp.ok:
|
||||
text = await resp.text()
|
||||
logger.warning("Token refresh failed: %s %s", resp.status, text)
|
||||
return None
|
||||
json_resp = await resp.json()
|
||||
except (aiohttp.ClientError, TimeoutError) as e:
|
||||
logger.warning("Token refresh request error: %s", e)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning("Token refresh unexpected error: %s", e)
|
||||
return None
|
||||
|
||||
if tokens.is_expired:
|
||||
logger.info("Token expired, refreshing...")
|
||||
if not tokens.refresh_token:
|
||||
logger.info("No refresh token available")
|
||||
return None
|
||||
new_tokens = await refresh_tokens(tokens.refresh_token)
|
||||
if not new_tokens:
|
||||
logger.warning("Failed to refresh token")
|
||||
return None
|
||||
save_tokens(new_tokens)
|
||||
return new_tokens
|
||||
|
||||
return tokens
|
||||
try:
|
||||
expires_in = int(json_resp["expires_in"])
|
||||
return ProviderTokens(
|
||||
access_token=json_resp["access_token"],
|
||||
refresh_token=json_resp["refresh_token"],
|
||||
expires_at=time.time() + expires_in,
|
||||
)
|
||||
except (KeyError, TypeError, ValueError) as e:
|
||||
logger.warning("Token refresh response parse error: %s", e)
|
||||
return None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue