ai/ai
1
0
Fork 0
ai/app/api/middleware.py

134 lines
4.1 KiB
Python

"""HTTP middleware utilities for API layer."""
from __future__ import annotations
import json
import logging
import os
import time
from collections.abc import AsyncIterator
from typing import Any
from fastapi import FastAPI, Request
from starlette.responses import StreamingResponse
# Dedicated logger for HTTP request/response records.
logger = logging.getLogger("ai.http")

# Attach a stream handler only when the logger has none yet, so handlers are
# not stacked on top of an existing configuration.
if not logger.handlers:
    handler = logging.StreamHandler()
    handler.setFormatter(
        logging.Formatter(
            fmt="%(asctime)s %(levelname)s %(name)s %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    logger.addHandler(handler)

# Emit INFO and above, and do not hand records to ancestor loggers.
logger.setLevel(logging.INFO)
logger.propagate = False
def install_request_logging(app: FastAPI) -> None:
    """Register an HTTP middleware on ``app`` that logs requests.

    Configuration is read from the environment once, when this function is
    called:

    * ``AI_REQUEST_LOG_ENABLED`` -- ``1``/``true``/``yes``/``on``
      (case-insensitive) turns on detailed logging. Defaults to off.
    * ``AI_REQUEST_LOG_MAX_BODY_CHARS`` -- maximum number of body characters
      written per log field. Values below 1024 are raised to 1024; a
      non-integer value falls back to 20000.

    Logging behaviour of the installed middleware:

    * An exception raised downstream is logged with its traceback and
      re-raised unchanged.
    * A response with status >= 400 is logged at WARNING together with the
      response body, whether or not detailed logging is enabled.
    * Any other response is logged at INFO only when detailed logging is
      enabled.

    Request bodies are captured only when detailed logging is enabled and the
    method is POST, PUT or PATCH; otherwise the body field is ``-``.

    Args:
        app: The FastAPI application to attach the middleware to.
    """
    enabled = os.getenv("AI_REQUEST_LOG_ENABLED", "false").strip().lower()
    detailed_logging_enabled = enabled in {"1", "true", "yes", "on"}

    max_len_raw = os.getenv("AI_REQUEST_LOG_MAX_BODY_CHARS", "20000").strip()
    try:
        max_body_chars = max(1024, int(max_len_raw))
    except ValueError:
        max_body_chars = 20000

    body_methods = frozenset({"POST", "PUT", "PATCH"})

    @app.middleware("http")
    async def request_logging_middleware(request: Request, call_next):
        started_at = time.perf_counter()

        body_text = ""
        if detailed_logging_enabled and request.method in body_methods:
            # NOTE(review): this consumes the request body before the route
            # handler runs and relies on the framework replaying it
            # downstream -- confirm for the pinned Starlette version.
            body_bytes = await request.body()
            body_text = _format_body(body_bytes)

        # Truncate once so every log path, including the exception path,
        # is bounded by max_body_chars.
        logged_body = _truncate(body_text or "-", max_body_chars)
        query = request.url.query or "-"

        try:
            response = await call_next(request)
        except Exception:
            elapsed_ms = (time.perf_counter() - started_at) * 1000
            logger.exception(
                "http request failed method=%s path=%s duration_ms=%.2f query=%s body=%s",
                request.method,
                request.url.path,
                elapsed_ms,
                query,
                logged_body,
            )
            raise

        elapsed_ms = (time.perf_counter() - started_at) * 1000

        if response.status_code >= 400:
            # Error responses are always logged, with their body.
            response_body = await _read_response_body(response)
            logger.warning(
                "http request error method=%s path=%s status=%s duration_ms=%.2f query=%s body=%s response_body=%s",
                request.method,
                request.url.path,
                response.status_code,
                elapsed_ms,
                query,
                logged_body,
                _truncate(response_body or "-", max_body_chars),
            )
            return response

        if not detailed_logging_enabled:
            return response

        logger.info(
            "http request method=%s path=%s status=%s duration_ms=%.2f query=%s body=%s",
            request.method,
            request.url.path,
            response.status_code,
            elapsed_ms,
            query,
            logged_body,
        )
        return response
def _format_body(body: bytes) -> str:
    """Render a raw body as single-line text suitable for a log field.

    An empty body gives ``""``. A body that parses as JSON is re-serialised
    compactly (no spaces after separators, non-ASCII characters escaped).
    Anything else is decoded as UTF-8, with undecodable bytes replaced.
    """
    if body:
        try:
            return json.dumps(
                json.loads(body),
                ensure_ascii=True,
                separators=(",", ":"),
            )
        except Exception:
            # Any failure to parse or re-serialise means "treat as plain text".
            return body.decode("utf-8", errors="replace")
    return ""
def _truncate(value: str, limit: int) -> str:
    """Cap ``value`` at ``limit`` characters.

    When characters are cut, a ``...[truncated N chars]`` marker is appended,
    where N is the number of characters removed.
    """
    overflow = len(value) - limit
    if overflow > 0:
        return f"{value[:limit]}...[truncated {overflow} chars]"
    return value
async def _read_response_body(response: Any) -> str:
    """Return the body of ``response`` as log-friendly text.

    Resolution order:

    1. If ``response.body`` is bytes-like, format it directly.
    2. If ``response`` is a ``StreamingResponse``, return the placeholder
       ``"<streaming-response>"`` without consuming the stream.
    3. If ``response`` exposes a ``body_iterator``, drain it, then replace
       ``response.body_iterator`` with a one-shot iterator over the buffered
       bytes so the body can still be sent to the client.
    4. Otherwise return ``""``.

    NOTE(review): whether the object returned by ``call_next`` is itself a
    ``StreamingResponse`` subclass may depend on the Starlette version; if it
    is, step 2 applies to every response -- confirm against the pinned version.

    Args:
        response: A response object; only the attributes above are inspected.

    Returns:
        The formatted body, a placeholder, or an empty string.
    """
    bytes_like = (bytes, bytearray, memoryview)

    body = getattr(response, "body", None)
    if isinstance(body, bytes_like):
        return _format_body(bytes(body))

    if isinstance(response, StreamingResponse):
        return "<streaming-response>"

    iterator = getattr(response, "body_iterator", None)
    if iterator is None:
        return ""

    chunks: list[bytes] = []
    async for chunk in iterator:
        if isinstance(chunk, bytes_like):
            # Copy bytes-like chunks verbatim; str() on a bytearray or
            # memoryview would yield its repr and corrupt the replayed body.
            chunks.append(bytes(chunk))
        else:
            chunks.append(str(chunk).encode("utf-8", errors="replace"))

    raw = b"".join(chunks)
    # The original iterator is now exhausted; hand the buffered bytes back.
    response.body_iterator = _iterate_once(raw)
    return _format_body(raw)
async def _iterate_once(payload: bytes) -> AsyncIterator[bytes]:
    """Yield ``payload`` as a single chunk, then finish.

    Lets an already-buffered body be exposed again as an async iterator.
    """
    yield payload