1
0
Fork 0
tg-proxy/proxy.py
2026-05-04 23:12:53 +03:00

1026 lines
33 KiB
Python

import json
import logging
import os
import time
import traceback
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from typing import Any, cast
from fastapi import FastAPI, HTTPException, Query, Request
from pydantic import BaseModel, Field
from pyrogram.client import Client
from pyrogram.raw.functions.account.get_notify_exceptions import GetNotifyExceptions
from pyrogram.raw.functions.account.get_notify_settings import GetNotifySettings
from pyrogram.raw.functions.messages.get_dialog_filters import GetDialogFilters
from pyrogram.raw.types.input_notify_broadcasts import InputNotifyBroadcasts
from pyrogram.raw.types.input_notify_chats import InputNotifyChats
from pyrogram.raw.types.input_notify_users import InputNotifyUsers
from pyrogram.types import Dialog
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse, PlainTextResponse
# Root logging configuration: INFO level with a timestamp and level tag per line.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# Quiet uvicorn's own access log; presumably because AccessLogMiddleware below
# already logs one line per request.
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
# Telegram application credentials, read from the environment (placeholders when unset).
API_ID = int(os.environ.get("API_ID", "0"))
API_HASH = os.environ.get("API_HASH", "")

# Largest page size any endpoint accepts (enforced by validate_limit).
MAX_LIMIT = 100
# Cache lifetimes in seconds, compared against time.time() deltas.
DIALOGS_CACHE_TTL = 300
MESSAGES_CACHE_TTL = 1200
# Number of newest messages re-fetched per request to detect new activity in a chat.
HEAD_REFRESH_LIMIT = 20

# The single shared Telegram client; its workdir "session" is created in lifespan().
client = Client("tg_proxy", api_id=API_ID, api_hash=API_HASH, workdir="session")

# Dialog list, refreshed together with the notify settings by get_cached_dialogs().
_dialogs_cache: list[Dialog] = []
_dialogs_cache_time: float = 0
# Raw dialog filters (folders) as returned by Telegram.
_dialog_filters_cache: list[Any] = []
_dialog_filters_cache_time: float = 0
# chat_id -> sorted ids of the folders the chat belongs to.
_folder_membership_cache: dict[int, list[int]] = {}
_folder_membership_cache_time: float = 0
# "users" / "chats" / "broadcasts" -> account-wide mute_until value.
_global_notify_settings: dict[str, int] = {}
# chat_id -> per-chat mute_until override.
_notify_exceptions_cache: dict[int, int] = {}
# chat_id -> message_id -> (raw message, fetch timestamp).
_raw_message_cache: dict[int, dict[int, tuple[Any, float]]] = {}
# (chat_id, limit, offset) -> (head message id at fetch time, raw messages, fetch timestamp).
_history_page_cache: dict[
    tuple[int, int, int], tuple[int | None, list[Any], float]
] = {}
# (chat_id, ISO "since" or "none") -> (head message id, raw messages, fetch timestamp).
_delta_cache: dict[tuple[int, str], tuple[int | None, list[Any], float]] = {}
class Chat(BaseModel):
    """Public representation of one dialog, returned by /chats and embedded in message pages."""

    id: int
    title: str | None
    # Raw chat type string from the client library (compared elsewhere against
    # "private", "bot", "group", "supergroup", "forum", "channel").
    type: str
    # Coarse bucket derived from `type` by normalize_chat_type ("direct", "group", "channel").
    chat_type: str
    username: str | None = None
    members_count: int | None = None
    # `is_pinned` and `pinned` are filled with the same value (both names are emitted).
    is_pinned: bool = False
    pinned: bool = False
    last_message_date: datetime | None = None
    unread_count: int = 0
    # `is_muted` and `muted` are filled with the same value (both names are emitted).
    is_muted: bool = False
    muted: bool = False
    # True when the dialog reports a non-zero folder_id (see build_chat).
    archived: bool = False
    # Dialog-level folder id as reported by the client library.
    folder_id: int | None = None
    # Ids of user-defined folders (dialog filters) the chat matches; None when not computed.
    folder_ids: list[int] | None = None
    last_online_at: datetime | None = None
class DialogFolder(BaseModel):
    """A user-defined chat folder (Telegram dialog filter), as exposed by GET /folders."""

    id: int
    title: str | None = None
    # "chatlist" for DialogFilterChatlist filters, "folder" for the rest.
    type: str
    icon_emoji: str | None = None
    # Chat ids listed explicitly in the filter's pinned / included / excluded peers.
    pinned_chat_ids: list[int] = Field(default_factory=list)
    include_chat_ids: list[int] = Field(default_factory=list)
    exclude_chat_ids: list[int] = Field(default_factory=list)
    # Category inclusion flags copied from the raw filter.
    contacts: bool = False
    non_contacts: bool = False
    groups: bool = False
    broadcasts: bool = False
    bots: bool = False
    # Exclusion flags copied from the raw filter.
    exclude_muted: bool = False
    exclude_read: bool = False
    exclude_archived: bool = False
    # Copied as-is from the raw filter; None when the filter has no such attribute.
    has_my_invites: bool | None = None
class Attachment(BaseModel):
    """Metadata about the media attached to a message (built by build_attachments)."""

    # Media kind, e.g. "photo", "document", "video", "voice", "sticker".
    type: str
    filename: str | None = None
    mime: str | None = None
    # Duration as reported by the media object; only set for kinds that expose one.
    duration: int | None = None
    # File size as reported by the media object.
    size: int | None = None
class Message(BaseModel):
    """Public representation of one chat message."""

    id: int
    date: datetime | None
    text: str | None
    # First name of the sending user, or None when the message has no user sender.
    from_user: str | None
    chat_id: int
    # `from_me` and `is_outgoing` carry the same value: the message's outgoing flag.
    from_me: bool | None = None
    is_outgoing: bool | None = None
    reply_to_message_id: int | None = None
    # `quoted_text` and `reply_snippet` carry the same short preview of the
    # replied-to message (filled by enrich_reply_fields).
    quoted_text: str | None = None
    reply_snippet: str | None = None
    edited_at: datetime | None = None
    # Whether the message id is at or below the dialog's read marker; None when unknown.
    is_read: bool | None = None
    attachments: list[Attachment] | None = None
class PaginatedChats(BaseModel):
    """One page of GET /chats results."""

    items: list[Chat]
    limit: int
    offset: int
    # True when more matching chats exist after this page.
    has_more: bool
    # Number of matching chats after this page.
    remaining_count: int | None = None
class PaginatedMessages(BaseModel):
    """One page of messages, optionally with the chat they belong to."""

    # None when the chat is not among the account's dialogs.
    chat: Chat | None = None
    items: list[Message]
    limit: int
    offset: int
    has_more: bool
    # Left as None by endpoints that do not compute it.
    remaining_count: int | None = None
class AccessLogMiddleware(BaseHTTPMiddleware):
    """Log one line per handled request: client address, method, path and status code."""

    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)

        forwarded_for = request.headers.get("x-forwarded-for", "")
        if forwarded_for:
            # Leftmost X-Forwarded-For entry; note the header is client-controlled.
            client_ip = forwarded_for.split(",")[0].strip()
        elif request.client:
            client_ip = request.client.host
        else:
            client_ip = "-"

        logger.info(
            f'{client_ip} - "{request.method} {request.url.path}" {response.status_code}'
        )
        return response
class PrettyJSONResponse(JSONResponse):
    """JSON response rendered with 2-space indentation and raw (non-escaped) Unicode."""

    def render(self, content) -> bytes:
        # NaN/Infinity are rejected (ValueError) because they are not valid JSON.
        document = json.dumps(
            content,
            indent=2,
            allow_nan=False,
            ensure_ascii=False,
        )
        return document.encode("utf-8")
def validate_limit(limit: int) -> int:
    """Return *limit* unchanged, or raise HTTP 400 when it exceeds MAX_LIMIT."""
    too_large = limit > MAX_LIMIT
    if too_large:
        raise HTTPException(status_code=400, detail=f"Limit cannot exceed {MAX_LIMIT}")
    return limit
def normalize_chat_type(chat_type: str) -> str:
    """Collapse a raw chat type into "direct", "group" or "channel".

    Unknown types are returned unchanged.
    """
    buckets = {
        "private": "direct",
        "bot": "direct",
        "direct": "direct",
        "group": "group",
        "supergroup": "group",
        "forum": "group",
        "channel": "channel",
    }
    return buckets.get(chat_type, chat_type)
def matches_chat_type_filter(chat: Chat, chat_type_filter: str | None) -> bool:
    """Check *chat* against the optional `chat_type` query filter.

    None matches everything. The filter is case-insensitive and whitespace-trimmed.
    "direct" deliberately excludes bots, which have their own "bot" filter.
    Unknown filter values match nothing.
    """
    if chat_type_filter is None:
        return True
    accepted_types = {
        "direct": ("private", "direct"),
        "bot": ("bot",),
        "group": ("group", "supergroup", "forum"),
        "channel": ("channel",),
    }.get(chat_type_filter.strip().lower())
    return accepted_types is not None and chat.type in accepted_types
def normalize_int(value) -> int | None:
if value is None:
return None
try:
return int(value)
except (TypeError, ValueError):
return None
def normalize_datetime(value: datetime | None) -> datetime | None:
if value is None:
return None
if value.tzinfo is None:
return value.replace(tzinfo=timezone.utc)
return value.astimezone(timezone.utc)
def raw_peer_to_chat_id(peer: Any) -> int | None:
    """Convert a raw Telegram peer into the signed chat id used by the high-level API.

    Mapping (the same convention as Pyrogram's ``utils.get_peer_id``):
    - user peer (``user_id``)       -> user_id
    - basic group (``chat_id``)     -> -chat_id
    - channel (``channel_id``)      -> -1000000000000 - channel_id

    Wrapper objects that expose the real peer under ``.peer`` are unwrapped
    recursively. Returns None when no id can be found.
    """
    if peer is None:
        return None
    if hasattr(peer, "user_id"):
        return normalize_int(getattr(peer, "user_id", None))
    if hasattr(peer, "chat_id"):
        chat_id = normalize_int(getattr(peer, "chat_id", None))
        return -chat_id if chat_id is not None else None
    if hasattr(peer, "channel_id"):
        channel_id = normalize_int(getattr(peer, "channel_id", None))
        if channel_id is None:
            return None
        # Arithmetic rather than string concatenation ("-100" + id). The two
        # only agree for exactly 10-digit channel ids.
        return -1_000_000_000_000 - channel_id
    nested_peer = getattr(peer, "peer", None)
    if nested_peer is not None and nested_peer is not peer:
        return raw_peer_to_chat_id(nested_peer)
    return None
def extract_filter_title(raw_filter: Any) -> str | None:
title = getattr(raw_filter, "title", None)
if title is None:
return None
if isinstance(title, str):
return title
text = getattr(title, "text", None)
if isinstance(text, str):
return text
return str(title)
def collect_filter_chat_ids(peers: list[Any] | None) -> list[int]:
    """Resolve raw peers to chat ids, returning the unique ids in ascending order.

    Peers that cannot be resolved are dropped. None or empty input gives [].
    """
    resolved_ids = {raw_peer_to_chat_id(peer) for peer in peers or ()}
    resolved_ids.discard(None)
    return sorted(resolved_ids)
def build_dialog_folder(raw_filter: Any) -> DialogFolder | None:
    """Convert a raw dialog filter into a DialogFolder.

    Returns None when the filter has no usable integer id, or when it is the
    implicit DialogFilterDefault ("All chats") entry.
    """
    filter_kind = raw_filter.__class__.__name__
    folder_id = normalize_int(getattr(raw_filter, "id", None))
    if folder_id is None or filter_kind == "DialogFilterDefault":
        return None

    # DialogFolder field -> raw filter attribute holding the peer list.
    peer_sources = (
        ("pinned_chat_ids", "pinned_peers"),
        ("include_chat_ids", "include_peers"),
        ("exclude_chat_ids", "exclude_peers"),
    )
    peer_lists = {
        field: collect_filter_chat_ids(getattr(raw_filter, source, None))
        for field, source in peer_sources
    }
    flag_names = (
        "contacts",
        "non_contacts",
        "groups",
        "broadcasts",
        "bots",
        "exclude_muted",
        "exclude_read",
        "exclude_archived",
    )
    rule_flags = {name: bool(getattr(raw_filter, name, False)) for name in flag_names}

    return DialogFolder(
        id=folder_id,
        title=extract_filter_title(raw_filter),
        type="chatlist" if filter_kind == "DialogFilterChatlist" else "folder",
        icon_emoji=getattr(raw_filter, "emoticon", None),
        has_my_invites=getattr(raw_filter, "has_my_invites", None),
        **peer_lists,
        **rule_flags,
    )
def dialog_is_read(dialog: Dialog) -> bool:
    """True when the dialog reports no unread messages (a missing/None count counts as read)."""
    unread_count = getattr(dialog, "unread_messages_count", 0)
    return not unread_count
def dialog_matches_folder(dialog: Dialog, chat: Chat, raw_filter: Any) -> bool:
    """Decide whether *dialog* (already converted to *chat*) belongs to *raw_filter*.

    Evaluation order:
    1. The default filter and unusable filters never match.
    2. Explicitly excluded chats never match.
    3. Chat-list filters match only their explicitly listed (pinned/included) chats.
    4. Regular folders match explicitly listed chats, or chats in an enabled
       category (contacts, non-contacts, groups, broadcasts, bots).
    5. A match is then vetoed by the folder's exclude_muted / exclude_read /
       exclude_archived flags.
    """
    filter_kind = raw_filter.__class__.__name__
    if filter_kind == "DialogFilterDefault":
        return False
    folder = build_dialog_folder(raw_filter)
    if folder is None:
        return False

    if chat.id in folder.exclude_chat_ids:
        return False
    explicitly_listed = (
        chat.id in folder.pinned_chat_ids or chat.id in folder.include_chat_ids
    )
    if filter_kind == "DialogFilterChatlist":
        return explicitly_listed

    raw_chat = dialog.chat
    kind = chat.type
    is_bot = kind == "bot" or bool(getattr(raw_chat, "is_bot", False))
    # NOTE(review): confirm the client's chat object actually exposes `is_contact`.
    # If it does not, the contacts rule never fires and every non-bot private
    # chat is treated as a non-contact.
    is_contact = bool(getattr(raw_chat, "is_contact", False))
    is_private = kind in ("private", "direct")
    category_rules = (
        (folder.contacts, is_contact),
        (folder.non_contacts, is_private and not is_contact and not is_bot),
        (folder.groups, kind in ("group", "supergroup", "forum")),
        (folder.broadcasts, kind == "channel"),
        (folder.bots, is_bot),
    )
    included = explicitly_listed or any(
        enabled and applies for enabled, applies in category_rules
    )
    if not included:
        return False

    vetoed = (
        (folder.exclude_muted and chat.is_muted)
        or (folder.exclude_read and dialog_is_read(dialog))
        or (folder.exclude_archived and chat.archived)
    )
    return not vetoed
def build_folder_membership_map(
    dialogs: list[Dialog], raw_filters: list[Any]
) -> dict[int, list[int]]:
    """Map each dialog's chat id to the sorted ids of the folders it belongs to.

    Filters for which build_dialog_folder returns None (e.g. the default
    filter) are ignored. Every dialog gets an entry, possibly an empty list.
    """
    # A folder's id depends only on its filter, so resolve every filter once
    # up front instead of once per dialog.
    candidate_filters = [
        (folder.id, raw_filter)
        for raw_filter in raw_filters
        if (folder := build_dialog_folder(raw_filter)) is not None
    ]
    memberships: dict[int, list[int]] = {}
    for dialog in dialogs:
        chat = build_chat(dialog)
        memberships[chat.id] = sorted(
            {
                folder_id
                for folder_id, raw_filter in candidate_filters
                if dialog_matches_folder(dialog, chat, raw_filter)
            }
        )
    return memberships
def extract_message_snippet(message) -> str | None:
if not message:
return None
for attr in ("text", "caption"):
value = getattr(message, attr, None)
if value:
return value[:200]
media = getattr(message, "media", None)
if media:
media_value = getattr(media, "value", None) or str(media)
return f"[{media_value}]"
return None
def build_attachments(message) -> list[Attachment] | None:
    """Describe the media carried by *message* as a one-element Attachment list.

    The media kind comes from ``message.media.value``. The media object is
    read from the message attribute of the same name (``message.photo``,
    ``message.video``, ...). Each kind reports only the fields enabled in the
    table below. Other kinds, or no media at all, yield None.
    """
    # kind -> (reads file_name, reads mime_type, reads duration); file_size is always read.
    field_table = {
        "photo": (False, True, False),
        "document": (True, True, False),
        "video": (True, True, True),
        "audio": (True, True, True),
        "voice": (False, True, True),
        "animation": (True, True, True),
        "sticker": (True, True, False),
        "video_note": (False, False, True),
    }
    media_kind = getattr(getattr(message, "media", None), "value", None)
    if media_kind not in field_table:
        return None

    wants_filename, wants_mime, wants_duration = field_table[media_kind]
    media_object = getattr(message, media_kind, None)
    fields: dict[str, Any] = {"type": media_kind}
    if wants_filename:
        fields["filename"] = getattr(media_object, "file_name", None)
    if wants_mime:
        fields["mime"] = getattr(media_object, "mime_type", None)
    if wants_duration:
        fields["duration"] = normalize_int(getattr(media_object, "duration", None))
    fields["size"] = normalize_int(getattr(media_object, "file_size", None))
    return [Attachment(**fields)]
async def _fetch_global_notify_settings() -> dict[str, int]:
    """Fetch the account-wide mute_until for users, chats and broadcasts (best effort).

    On failure the warning is logged and whatever was collected so far is returned.
    """
    settings: dict[str, int] = {}
    try:
        for name, input_notify in (
            ("users", InputNotifyUsers()),
            ("chats", InputNotifyChats()),
            ("broadcasts", InputNotifyBroadcasts()),
        ):
            result = await client.invoke(GetNotifySettings(peer=input_notify))
            mute_until = getattr(result, "mute_until", 0) or 0
            settings[name] = mute_until
            logger.info(f"Global {name} mute_until: {mute_until}")
    except Exception as e:
        logger.warning(f"Failed to fetch global notify settings: {e}")
    return settings


async def _fetch_notify_exceptions() -> dict[int, int]:
    """Fetch per-chat mute overrides as {chat_id: mute_until} (best effort).

    Chat ids use the same signed convention as dialog chat ids, via
    raw_peer_to_chat_id, so the keys can be matched against ``dialog.chat.id``.
    Two kinds of entries are skipped:
    - entries that do not set ``mute_until`` (the chat then keeps the global default);
    - forum-topic entries (``top_msg_id``), which concern a single topic rather
      than the whole chat.
    """
    exceptions: dict[int, int] = {}
    try:
        for input_notify in (
            InputNotifyUsers(),
            InputNotifyChats(),
            InputNotifyBroadcasts(),
        ):
            result = await client.invoke(GetNotifyExceptions(peer=input_notify))
            updates = cast(list[Any], getattr(result, "updates", []))
            for update in updates:
                notify_peer = getattr(update, "peer", None)
                notify_settings = getattr(update, "notify_settings", None)
                if notify_peer is None or notify_settings is None:
                    continue
                if hasattr(notify_peer, "top_msg_id"):
                    continue
                chat_id = raw_peer_to_chat_id(notify_peer)
                mute_until = getattr(notify_settings, "mute_until", None)
                if chat_id is None or mute_until is None:
                    continue
                exceptions[chat_id] = mute_until
    except Exception as e:
        logger.warning(f"Failed to fetch notify exceptions: {e}")
    return exceptions


async def get_cached_dialogs() -> list[Dialog]:
    """Return all dialogs, refreshing them (and the notify settings) every DIALOGS_CACHE_TTL s.

    A refresh also replaces the global notify settings and per-chat notify
    exceptions that build_chat uses to derive the mute state.
    """
    global \
        _dialogs_cache, \
        _dialogs_cache_time, \
        _global_notify_settings, \
        _notify_exceptions_cache
    now = time.time()
    if _dialogs_cache and (now - _dialogs_cache_time) < DIALOGS_CACHE_TTL:
        logger.info("Returning dialogs from cache")
        return _dialogs_cache
    logger.info("Fetching dialogs from Telegram...")
    dialogs: list[Dialog] = [dialog async for dialog in client.get_dialogs()]
    _dialogs_cache = dialogs
    _dialogs_cache_time = now
    logger.info(f"Cached {len(dialogs)} dialogs")
    logger.info("Fetching global notify settings...")
    _global_notify_settings = await _fetch_global_notify_settings()
    logger.info("Fetching notify exceptions...")
    _notify_exceptions_cache = await _fetch_notify_exceptions()
    logger.info(f"Cached {len(_notify_exceptions_cache)} notify exceptions")
    return dialogs
async def get_cached_dialog_filters() -> list[Any]:
    """Return the account's raw dialog filters (folders), cached for DIALOGS_CACHE_TTL seconds.

    The result of GetDialogFilters is accepted in two shapes:
    - a bare vector (list) of filters, as in older API layers;
    - a wrapper object exposing them under ``.filters``.

    Freshness is judged by the fetch timestamp, not by the list being
    non-empty, so an account without folders is not re-queried on every call.
    On failure a warning is logged, [] is returned and nothing is cached, so
    the next call retries.
    """
    global _dialog_filters_cache, _dialog_filters_cache_time
    now = time.time()
    if (
        _dialog_filters_cache_time
        and (now - _dialog_filters_cache_time) < DIALOGS_CACHE_TTL
    ):
        logger.info("Returning dialog filters from cache")
        return _dialog_filters_cache
    logger.info("Fetching dialog filters from Telegram...")
    try:
        result = await client.invoke(GetDialogFilters())
        raw_filters = (
            result if isinstance(result, list) else getattr(result, "filters", [])
        )
        filters = list(cast(list[Any], raw_filters))
    except Exception as e:
        logger.warning(f"Failed to fetch dialog filters: {e}")
        return []
    _dialog_filters_cache = filters
    _dialog_filters_cache_time = now
    logger.info(f"Cached {len(filters)} dialog filters")
    return filters
async def get_cached_folder_memberships() -> dict[int, list[int]]:
    """Return {chat_id: folder ids}, recomputed from dialogs + filters every DIALOGS_CACHE_TTL s.

    An empty cached map is treated as missing and recomputed.
    """
    global _folder_membership_cache, _folder_membership_cache_time
    now = time.time()
    is_fresh = (
        _folder_membership_cache
        and (now - _folder_membership_cache_time) < DIALOGS_CACHE_TTL
    )
    if is_fresh:
        logger.info("Returning folder memberships from cache")
        return _folder_membership_cache

    all_dialogs = await get_cached_dialogs()
    all_filters = await get_cached_dialog_filters()
    computed = build_folder_membership_map(all_dialogs, all_filters)
    _folder_membership_cache, _folder_membership_cache_time = computed, now
    logger.info(f"Cached folder memberships for {len(computed)} chats")
    return computed
def gc_message_caches() -> None:
    """Drop every message-cache entry older than MESSAGES_CACHE_TTL seconds.

    Covers the raw per-message cache (removing chats left empty), the history
    page cache and the delta cache.
    """
    now = time.time()

    def expired(fetched_at: float) -> bool:
        return (now - fetched_at) >= MESSAGES_CACHE_TTL

    for chat_id, per_chat in list(_raw_message_cache.items()):
        for message_id in [mid for mid, (_, ts) in per_chat.items() if expired(ts)]:
            del per_chat[message_id]
        if not per_chat:
            del _raw_message_cache[chat_id]

    for page_cache in (_history_page_cache, _delta_cache):
        for key in [k for k, (_, _, ts) in page_cache.items() if expired(ts)]:
            del page_cache[key]
def cache_raw_messages(chat_id: int, messages: list[Any]) -> None:
    """Store raw *messages* under *chat_id*, all stamped with the current time."""
    if not messages:
        return
    stamp = time.time()
    per_chat = _raw_message_cache.setdefault(chat_id, {})
    per_chat.update((message.id, (message, stamp)) for message in messages)
def get_cached_raw_message(chat_id: int, message_id: int) -> Any | None:
    """Look up a raw message in the cache.

    Returns None when the entry is absent or expired. An expired entry is
    evicted on access, together with its chat bucket if that becomes empty.
    """
    per_chat = _raw_message_cache.get(chat_id)
    entry = per_chat.get(message_id) if per_chat else None
    if entry is None:
        return None
    message, fetched_at = entry
    if (time.time() - fetched_at) < MESSAGES_CACHE_TTL:
        return message
    del per_chat[message_id]
    if not per_chat:
        del _raw_message_cache[chat_id]
    return None
async def fetch_raw_messages(chat_id: int, limit: int, offset: int = 0) -> list[Any]:
    """Fetch up to *limit* history messages after skipping *offset*, and cache them.

    Messages keep the order the client yields them in; callers treat element 0
    as the newest.
    """
    history = client.get_chat_history(chat_id, limit=limit, offset=offset)
    fetched = [message async for message in history]
    cache_raw_messages(chat_id, fetched)
    return fetched
async def refresh_chat_head(chat_id: int) -> tuple[int | None, list[Any]]:
    """Re-fetch the newest HEAD_REFRESH_LIMIT messages.

    Returns (id of the first message or None when the chat is empty, the messages).
    """
    latest = await fetch_raw_messages(chat_id, HEAD_REFRESH_LIMIT)
    if not latest:
        return None, latest
    return latest[0].id, latest
async def get_cached_or_fetch_history_page(
    chat_id: int,
    limit: int,
    offset: int,
) -> list[Any]:
    """Return up to ``limit + 1`` raw messages, skipping the *offset* newest ones.

    Every call re-fetches the newest HEAD_REFRESH_LIMIT messages. The id of
    the newest one (the "head") invalidates cached pages once new messages
    arrive. The page is sliced from that head batch, with no second history
    request, in two cases:
    - the requested window lies entirely inside the batch;
    - the batch came back short, so it already holds the whole history.
      This assumes the client only returns fewer messages than requested once
      the history is exhausted.
    """
    gc_message_caches()
    head_id, head_messages = await refresh_chat_head(chat_id)
    cache_key = (chat_id, limit, offset)
    cached_entry = _history_page_cache.get(cache_key)
    if cached_entry is not None:
        cached_head_id, cached_messages, fetched_at = cached_entry
        if (
            cached_head_id == head_id
            and (time.time() - fetched_at) < MESSAGES_CACHE_TTL
        ):
            cache_raw_messages(chat_id, cached_messages)
            return cached_messages

    page_end = offset + limit + 1  # one extra message lets callers compute has_more
    if page_end <= len(head_messages) or len(head_messages) < HEAD_REFRESH_LIMIT:
        messages = head_messages[offset:page_end]
    else:
        messages = await fetch_raw_messages(chat_id, limit + 1, offset)
    _history_page_cache[cache_key] = (head_id, messages, time.time())
    return messages
async def get_cached_or_fetch_delta_messages(
    chat_id: int, since: datetime
) -> list[Any]:
    """Return raw messages dated at or after *since*, in reverse of the client's history order.

    Results are cached per (chat, normalized since) and reused while the chat
    head is unchanged and the entry is younger than MESSAGES_CACHE_TTL.
    """
    gc_message_caches()
    head_id, _ = await refresh_chat_head(chat_id)
    since_utc = normalize_datetime(since)
    cache_key = (chat_id, since_utc.isoformat() if since_utc else "none")

    hit = _delta_cache.get(cache_key)
    if hit is not None:
        hit_head_id, hit_messages, hit_time = hit
        still_valid = (
            hit_head_id == head_id
            and (time.time() - hit_time) < MESSAGES_CACHE_TTL
        )
        if still_valid:
            cache_raw_messages(chat_id, hit_messages)
            return hit_messages

    collected: list[Any] = []
    async for message in client.get_chat_history(chat_id):
        sent_at = normalize_datetime(message.date)
        # History is walked newest-first; stop at the first message older than `since`.
        if sent_at and since_utc and sent_at < since_utc:
            break
        collected.append(message)
    collected = collected[::-1]
    cache_raw_messages(chat_id, collected)
    _delta_cache[cache_key] = (head_id, collected, time.time())
    return collected
async def get_dialog_by_chat_id(chat_id: int) -> Dialog | None:
    """Return the first cached dialog whose chat id equals *chat_id*, or None."""
    dialogs = await get_cached_dialogs()
    return next((candidate for candidate in dialogs if candidate.chat.id == chat_id), None)
def build_chat(dialog: Dialog, folder_ids: list[int] | None = None) -> Chat:
    """Convert a dialog into the public Chat model.

    The mute state comes from the chat's notify exception when one exists.
    Otherwise it uses the account-wide setting for the chat's category
    (users, chats or broadcasts). `mute_until` is a Unix timestamp (0 = not
    muted), so a chat counts as muted only while that timestamp is still in
    the future; an expired timed mute is reported as unmuted.

    Args:
        dialog: dialog object from the Telegram client.
        folder_ids: folder ids the chat belongs to, if already computed.
    """
    chat = dialog.chat
    chat_type = str(chat.type.value) if chat.type else "unknown"
    if chat.id in _notify_exceptions_cache:
        muted_until = _notify_exceptions_cache[chat.id]
    elif chat_type in ("private", "bot", "direct"):
        muted_until = _global_notify_settings.get("users", 0)
    elif chat_type in ("group", "supergroup", "forum"):
        muted_until = _global_notify_settings.get("chats", 0)
    elif chat_type == "channel":
        muted_until = _global_notify_settings.get("broadcasts", 0)
    else:
        muted_until = 0
    is_muted = muted_until > time.time()
    # Read once with a fallback so a dialog without the attribute cannot raise.
    dialog_folder_id = getattr(dialog, "folder_id", None)
    return Chat(
        id=chat.id or 0,
        title=chat.title or chat.first_name,
        type=chat_type,
        chat_type=normalize_chat_type(chat_type),
        username=chat.username,
        members_count=chat.members_count,
        is_pinned=dialog.is_pinned or False,
        pinned=dialog.is_pinned or False,
        last_message_date=dialog.top_message.date if dialog.top_message else None,
        unread_count=dialog.unread_messages_count or 0,
        is_muted=is_muted,
        muted=is_muted,
        archived=bool(dialog_folder_id),
        folder_id=dialog_folder_id,
        folder_ids=folder_ids,
        last_online_at=getattr(chat, "last_online_date", None),
    )
def build_message(message, chat_id: int, dialog: Dialog | None = None) -> Message:
    """Convert a raw client message into the public Message model.

    `text` falls back to the media caption when the message has no text.
    The Message model has no separate caption field, so captions would
    otherwise be lost.

    `is_read` compares the message id with the dialog's outbox read marker
    for outgoing messages, and the inbox read marker for all others. It is
    None when no dialog is given or the marker is missing.

    Reply preview fields are left empty; enrich_reply_fields fills them.
    """
    from_me = getattr(message, "outgoing", None)
    if dialog is None:
        is_read = None
    else:
        marker_name = "read_outbox_max_id" if from_me else "read_inbox_max_id"
        read_max_id = getattr(dialog, marker_name, None)
        is_read = None if read_max_id is None else message.id <= read_max_id
    text = message.text if message.text is not None else getattr(message, "caption", None)
    return Message(
        id=message.id,
        date=message.date,
        text=text,
        from_user=message.from_user.first_name if message.from_user else None,
        chat_id=chat_id,
        from_me=from_me,
        is_outgoing=from_me,
        reply_to_message_id=getattr(message, "reply_to_message_id", None),
        quoted_text=None,
        reply_snippet=None,
        edited_at=getattr(message, "edit_date", None),
        is_read=is_read,
        attachments=build_attachments(message),
    )
async def enrich_reply_fields(
    messages: list[Message],
    chat_id: int,
    reply_cache: dict[int, str | None],
) -> list[Message]:
    """Fill `quoted_text` / `reply_snippet` on every message that replies to another.

    The preview for a replied-to message comes from, in order:
      1. the message itself when it is part of *messages*. Its raw cached
         object is preferred, because the built Message model has no
         caption/media attributes;
      2. *reply_cache*, a caller-supplied dict that is also filled in place;
      3. the module-level raw-message cache;
      4. batched ``client.get_messages`` calls (200 ids per call).
    Ids that are still unresolved after the fetch are recorded as None in *reply_cache*.

    Returns:
        The same list, with the messages mutated in place.
    """
    messages_by_id: dict[int, Message] = {message.id: message for message in messages}
    # dict keys give ordered de-duplication without an O(n) list scan per id.
    missing_reply_ids: dict[int, None] = {}
    for message in messages:
        reply_to_message_id = message.reply_to_message_id
        if not reply_to_message_id:
            continue
        if reply_to_message_id in messages_by_id:
            continue
        if reply_to_message_id in reply_cache:
            continue
        cached_reply = get_cached_raw_message(chat_id, reply_to_message_id)
        if cached_reply is not None:
            reply_cache[reply_to_message_id] = extract_message_snippet(cached_reply)
            continue
        missing_reply_ids[reply_to_message_id] = None

    pending_ids = list(missing_reply_ids)
    for i in range(0, len(pending_ids), 200):
        batch_ids = pending_ids[i : i + 200]
        fetched_replies = await client.get_messages(chat_id, batch_ids)
        fetched_replies_list = list(fetched_replies)
        cache_raw_messages(chat_id, fetched_replies_list)
        for reply in fetched_replies_list:
            reply_cache[reply.id] = extract_message_snippet(reply)
        for reply_id in batch_ids:
            reply_cache.setdefault(reply_id, None)

    for message in messages:
        reply_to_message_id = message.reply_to_message_id
        if not reply_to_message_id:
            continue
        if reply_to_message_id in messages_by_id:
            # Snippets built from the Message model would miss captions and
            # "[media]" labels, which would make in-page replies inconsistent
            # with replies resolved through the raw cache.
            raw_reply = get_cached_raw_message(chat_id, reply_to_message_id)
            source = (
                raw_reply
                if raw_reply is not None
                else messages_by_id[reply_to_message_id]
            )
            reply_snippet = extract_message_snippet(source)
        else:
            reply_snippet = reply_cache.get(reply_to_message_id)
        message.quoted_text = reply_snippet
        message.reply_snippet = reply_snippet
    return messages
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Run the Telegram client for the lifetime of the application.

    Before serving, this ensures the client's "session" workdir exists and
    starts the client. On shutdown the client is stopped, also when the
    lifespan is exited with an exception (try/finally around the yield).
    """
    os.makedirs("session", exist_ok=True)
    logger.info("Starting Telegram client...")
    await client.start()
    logger.info("Telegram client started")
    try:
        yield
    finally:
        logger.info("Stopping Telegram client...")
        await client.stop()
        logger.info("Telegram client stopped")
# ASGI application. The Telegram client lives for the app's lifespan, and
# responses default to indented JSON.
app = FastAPI(default_response_class=PrettyJSONResponse, lifespan=lifespan)
app.add_middleware(AccessLogMiddleware)
@app.exception_handler(Exception)
async def unhandled_exception_handler(_: Request, exc: Exception):
    """Log any uncaught exception and return its formatted traceback as a plain-text 500.

    NOTE(review): the full traceback is sent to the HTTP client. This exposes
    internals (paths, code, arguments) to anyone who can reach the API.
    Acceptable only if the service is private.
    """
    logger.exception("Unhandled exception")
    trace_lines = traceback.format_exception(type(exc), exc, exc.__traceback__)
    return PlainTextResponse("".join(trace_lines), status_code=500)
@app.get(
    "/folders", response_model=list[DialogFolder], response_model_exclude_none=True
)
async def get_folders() -> list[DialogFolder]:
    """List the account's user folders ordered by folder id (default filter omitted)."""
    candidates = map(build_dialog_folder, await get_cached_dialog_filters())
    folders = [folder for folder in candidates if folder is not None]
    folders.sort(key=lambda item: item.id)
    return folders
def _chat_passes_filters(
    chat: Chat,
    archived: bool | None,
    chat_type: str | None,
    folder_id: int | None,
) -> bool:
    """Return True when *chat* satisfies every optional GET /chats filter."""
    if archived is not None and chat.archived != archived:
        return False
    if not matches_chat_type_filter(chat, chat_type):
        return False
    return folder_id is None or folder_id in (chat.folder_ids or [])


@app.get("/chats", response_model_exclude_none=True)
async def get_chats(
    limit: int = Query(default=50, ge=1),
    offset: int = Query(default=0, ge=0),
    archived: bool | None = Query(default=None),
    chat_type: str | None = Query(default=None),
    folder_id: int | None = Query(default=None),
) -> PaginatedChats:
    """List dialogs as Chat models, filtered and then paginated.

    Filters (all optional):
    - archived: keep only archived or only non-archived chats;
    - chat_type: "direct", "bot", "group" or "channel";
    - folder_id: keep chats belonging to that user folder.

    Raises:
        HTTPException(400): limit exceeds MAX_LIMIT, or folder_id is unknown.
    """
    validate_limit(limit)
    if folder_id is not None:
        folders = await get_folders()
        valid_folder_ids = {folder.id for folder in folders}
        if folder_id not in valid_folder_ids:
            raise HTTPException(
                status_code=400, detail=f"Unknown folder_id: {folder_id}"
            )
    dialogs = await get_cached_dialogs()
    folder_memberships = (
        await get_cached_folder_memberships() if folder_id is not None else {}
    )

    # Single pass: the full match list gives both the page and the totals.
    matching: list[Chat] = []
    for dialog in dialogs:
        chat = build_chat(dialog, folder_memberships.get(dialog.chat.id or 0))
        if _chat_passes_filters(chat, archived, chat_type, folder_id):
            matching.append(chat)

    page = matching[offset : offset + limit]
    return PaginatedChats(
        items=page,
        limit=limit,
        offset=offset,
        has_more=len(matching) > offset + limit,
        remaining_count=max(0, len(matching) - offset - len(page)),
    )
@app.get("/chats/{chat_id}/messages", response_model_exclude_none=True)
async def get_chat_messages(
    chat_id: int,
    limit: int = Query(default=50, ge=1),
    offset: int = Query(default=0, ge=0),
) -> PaginatedMessages:
    """Return one page of a chat's history, in the order the history helper yields it.

    The history helper returns one message beyond *limit*. It is used only
    to compute `has_more` and is dropped from `items`. `chat` is None when
    the chat is not among the cached dialogs.
    """
    validate_limit(limit)
    dialog = await get_dialog_by_chat_id(chat_id)
    folder_memberships = await get_cached_folder_memberships()
    raw_page = await get_cached_or_fetch_history_page(chat_id, limit, offset)
    built = await enrich_reply_fields(
        [build_message(raw, chat_id, dialog) for raw in raw_page],
        chat_id,
        {},
    )
    chat = build_chat(dialog, folder_memberships.get(chat_id)) if dialog else None
    return PaginatedMessages(
        chat=chat,
        items=built[:limit],
        limit=limit,
        offset=offset,
        has_more=len(built) > limit,
    )
@app.get("/chats/{chat_id}/delta", response_model_exclude_none=True)
async def get_chat_messages_delta(
    chat_id: int,
    since: datetime = Query(...),
    limit: int = Query(default=50, ge=0),
    offset: int = Query(default=0, ge=0),
) -> PaginatedMessages:
    """Return messages dated at or after *since*, paginated.

    Messages keep the order produced by get_cached_or_fetch_delta_messages.
    `limit=0` means "everything from *offset* on" (has_more=False,
    remaining_count=0).

    Only the returned page is converted and reply-enriched. Replies that
    point at delta messages outside the page are still resolved without
    network calls, because the delta helper already put every delta message
    in the raw-message cache.
    """
    validate_limit(limit)
    dialog = await get_dialog_by_chat_id(chat_id)
    folder_memberships = await get_cached_folder_memberships()
    raw_messages = await get_cached_or_fetch_delta_messages(chat_id, since)
    total = len(raw_messages)
    if limit == 0:
        raw_page = raw_messages[offset:]
        has_more = False
        remaining_count = 0
    else:
        raw_page = raw_messages[offset : offset + limit]
        has_more = total > offset + limit
        remaining_count = max(0, total - offset - len(raw_page))
    items = [build_message(msg, chat_id, dialog) for msg in raw_page]
    items = await enrich_reply_fields(items, chat_id, {})
    chat = build_chat(dialog, folder_memberships.get(chat_id)) if dialog else None
    return PaginatedMessages(
        chat=chat,
        items=items,
        limit=limit,
        offset=offset,
        has_more=has_more,
        remaining_count=remaining_count,
    )