134 lines
4.1 KiB
Python
134 lines
4.1 KiB
Python
"""HTTP middleware utilities for API layer."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import time
|
|
from collections.abc import AsyncIterator
|
|
from typing import Any
|
|
|
|
from fastapi import FastAPI, Request
|
|
from starlette.responses import StreamingResponse
|
|
|
|
# Module-level logger used by the request-logging middleware in this file.
# NOTE(review): indentation was lost in this copy of the file, so it is not
# visible whether the setLevel/propagate lines below belong inside the
# `if not logger.handlers:` guard or run unconditionally -- confirm against
# the real source before reformatting.
logger = logging.getLogger("ai.http")
|
|
# Only build and attach a handler when this logger has none yet.
if not logger.handlers:
|
|
handler = logging.StreamHandler()
|
|
handler.setFormatter(
|
|
# Formatter arguments are positional: message format, then date format.
logging.Formatter(
|
|
"%(asctime)s %(levelname)s %(name)s %(message)s",
|
|
"%Y-%m-%d %H:%M:%S",
|
|
)
|
|
)
|
|
logger.addHandler(handler)
|
|
logger.setLevel(logging.INFO)
|
|
# Records are not passed on to the handlers of ancestor loggers.
logger.propagate = False
|
|
|
|
|
|
def install_request_logging(app: FastAPI) -> None:
    """Attach an HTTP middleware to ``app`` that logs requests via the module logger.

    Settings are read from the environment once, when this function is called:

    * ``AI_REQUEST_LOG_ENABLED``: detailed logging is on when the value
      (trimmed, case-insensitive) is ``1``, ``true``, ``yes`` or ``on``.
      Defaults to off.
    * ``AI_REQUEST_LOG_MAX_BODY_CHARS``: maximum number of characters of a
      request or response body written to one log line. Values below 1024
      are raised to 1024; a non-integer value falls back to 20000.

    Behaviour of the installed middleware:

    * If the downstream handler raises, the failure is logged with a
      traceback and the exception is re-raised unchanged.
    * Responses with status >= 400 are logged at WARNING together with the
      response body, whether or not detailed logging is enabled.
    * All other responses are logged at INFO, and only when detailed logging
      is enabled.

    Request bodies are captured only when detailed logging is enabled and the
    method is POST, PUT or PATCH; otherwise ``-`` is logged in their place.
    Every logged body is cut to the configured maximum length.

    Args:
        app: The FastAPI application to register the middleware on.
    """
    flag = os.getenv("AI_REQUEST_LOG_ENABLED", "false").strip().lower()
    detailed_logging_enabled = flag in {"1", "true", "yes", "on"}

    max_len_raw = os.getenv("AI_REQUEST_LOG_MAX_BODY_CHARS", "20000").strip()
    try:
        # Never allow a limit below 1024 characters.
        max_body_chars = max(1024, int(max_len_raw))
    except ValueError:
        max_body_chars = 20000

    @app.middleware("http")
    async def request_logging_middleware(request: Request, call_next):
        started_at = time.perf_counter()

        body_text = ""
        if detailed_logging_enabled and request.method in {"POST", "PUT", "PATCH"}:
            body_bytes = await request.body()
            body_text = _format_body(body_bytes)

        try:
            response = await call_next(request)
        except Exception:
            elapsed_ms = (time.perf_counter() - started_at) * 1000
            logger.exception(
                "http request failed method=%s path=%s duration_ms=%.2f query=%s body=%s",
                request.method,
                request.url.path,
                elapsed_ms,
                request.url.query or "-",
                # Apply the same size cap as the other log paths so a failing
                # request with a huge body cannot flood the log.
                _truncate(body_text or "-", max_body_chars),
            )
            raise
        elapsed_ms = (time.perf_counter() - started_at) * 1000

        if response.status_code >= 400:
            # Error responses are always logged, including their body.
            response_body = await _read_response_body(response)
            logger.warning(
                "http request error method=%s path=%s status=%s duration_ms=%.2f query=%s body=%s response_body=%s",
                request.method,
                request.url.path,
                response.status_code,
                elapsed_ms,
                request.url.query or "-",
                _truncate(body_text or "-", max_body_chars),
                _truncate(response_body or "-", max_body_chars),
            )
            return response

        if not detailed_logging_enabled:
            return response

        logger.info(
            "http request method=%s path=%s status=%s duration_ms=%.2f query=%s body=%s",
            request.method,
            request.url.path,
            response.status_code,
            elapsed_ms,
            request.url.query or "-",
            _truncate(body_text or "-", max_body_chars),
        )

        return response
|
|
|
|
|
|
def _format_body(body: bytes) -> str:
|
|
if not body:
|
|
return ""
|
|
try:
|
|
parsed: Any = json.loads(body)
|
|
text = json.dumps(parsed, ensure_ascii=True, separators=(",", ":"))
|
|
except Exception:
|
|
text = body.decode("utf-8", errors="replace")
|
|
return text
|
|
|
|
|
|
def _truncate(value: str, limit: int) -> str:
|
|
if len(value) <= limit:
|
|
return value
|
|
return f"{value[:limit]}...[truncated {len(value) - limit} chars]"
|
|
|
|
|
|
async def _read_response_body(response: Any) -> str:
    """Return the body of ``response`` as a loggable string.

    Resolution order:

    1. If ``response.body`` is ``bytes``/``bytearray`` it is formatted directly.
    2. If ``response`` is a ``StreamingResponse`` the stream is left untouched
       and the placeholder ``"<streaming-response>"`` is returned.
    3. Otherwise, if the response has a ``body_iterator``, it is drained
       completely into memory, formatted, and replaced by a one-shot iterator
       over the same bytes so the body can still be iterated afterwards.
    4. With no ``body_iterator`` the result is ``""``.

    Args:
        response: The response object returned by the downstream handler.

    Returns:
        The formatted body text, a placeholder, or an empty string.
    """
    body = getattr(response, "body", None)
    if isinstance(body, (bytes, bytearray)):
        return _format_body(bytes(body))

    if isinstance(response, StreamingResponse):
        return "<streaming-response>"

    iterator = getattr(response, "body_iterator", None)
    if iterator is None:
        return ""

    chunks: list[bytes] = []
    async for chunk in iterator:
        if isinstance(chunk, (bytes, bytearray, memoryview)):
            # Copy bytes-like chunks verbatim. Going through str() would turn
            # a bytearray/memoryview into its repr and corrupt the payload
            # that is re-attached to the response below.
            chunks.append(bytes(chunk))
        else:
            chunks.append(str(chunk).encode("utf-8", errors="replace"))

    raw = b"".join(chunks)
    # The original iterator is exhausted; hand the response a fresh one that
    # replays the collected bytes.
    response.body_iterator = _iterate_once(raw)
    return _format_body(raw)
|
|
|
|
|
|
async def _iterate_once(payload: bytes) -> AsyncIterator[bytes]:
|
|
yield payload
|