feat(router): implement codex streaming, yaml auth, and opencode plugin
This commit is contained in:
parent
28708500a5
commit
5f6ed46a9c
33 changed files with 13223 additions and 615 deletions
|
|
@ -16,6 +16,6 @@ COPY app ./app
|
|||
|
||||
VOLUME ["/data"]
|
||||
|
||||
EXPOSE 8000
|
||||
EXPOSE 80
|
||||
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "80"]
|
||||
|
|
|
|||
|
|
@ -18,6 +18,18 @@ from app.core.types import CoreChatRequest, CoreChunk, CoreMessage, CoreModel
|
|||
|
||||
def to_core_chat_request(payload: ChatCompletionsRequest) -> CoreChatRequest:
|
||||
extras = dict(payload.model_extra or {})
|
||||
reasoning_effort = payload.reasoning_effort
|
||||
reasoning_summary = payload.reasoning_summary
|
||||
if isinstance(payload.reasoning, dict):
|
||||
if reasoning_effort is None and isinstance(
|
||||
payload.reasoning.get("effort"), str
|
||||
):
|
||||
reasoning_effort = payload.reasoning["effort"]
|
||||
if reasoning_summary is None and isinstance(
|
||||
payload.reasoning.get("summary"), str
|
||||
):
|
||||
reasoning_summary = payload.reasoning["summary"]
|
||||
|
||||
return CoreChatRequest(
|
||||
model=payload.model,
|
||||
stream=payload.stream,
|
||||
|
|
@ -43,15 +55,21 @@ def to_core_chat_request(payload: ChatCompletionsRequest) -> CoreChatRequest:
|
|||
max_completion_tokens=payload.max_completion_tokens,
|
||||
max_tokens=payload.max_tokens,
|
||||
metadata=payload.metadata,
|
||||
provider=payload.provider,
|
||||
plugins=payload.plugins,
|
||||
session_id=payload.session_id,
|
||||
trace=payload.trace,
|
||||
modalities=list(payload.modalities) if payload.modalities is not None else None,
|
||||
models=payload.models,
|
||||
n=payload.n,
|
||||
parallel_tool_calls=payload.parallel_tool_calls,
|
||||
prediction=payload.prediction,
|
||||
presence_penalty=payload.presence_penalty,
|
||||
prompt_cache_key=payload.prompt_cache_key,
|
||||
prompt_cache_retention=payload.prompt_cache_retention,
|
||||
reasoning_effort=payload.reasoning_effort,
|
||||
reasoning_summary=payload.reasoning_summary,
|
||||
reasoning_effort=reasoning_effort,
|
||||
reasoning_summary=reasoning_summary,
|
||||
reasoning=payload.reasoning,
|
||||
response_format=payload.response_format,
|
||||
safety_identifier=payload.safety_identifier,
|
||||
seed=payload.seed,
|
||||
|
|
@ -60,6 +78,8 @@ def to_core_chat_request(payload: ChatCompletionsRequest) -> CoreChatRequest:
|
|||
store=payload.store,
|
||||
stream_options=payload.stream_options,
|
||||
temperature=payload.temperature,
|
||||
debug=payload.debug,
|
||||
image_config=payload.image_config,
|
||||
tool_choice=payload.tool_choice,
|
||||
tools=payload.tools,
|
||||
top_logprobs=payload.top_logprobs,
|
||||
|
|
@ -89,6 +109,7 @@ def to_api_chunk(
|
|||
role=chunk.role,
|
||||
content=chunk.content,
|
||||
reasoning_content=chunk.reasoning_content,
|
||||
reasoning_details=chunk.reasoning_details,
|
||||
tool_calls=chunk.tool_calls,
|
||||
),
|
||||
finish_reason=chunk.finish_reason,
|
||||
|
|
@ -106,6 +127,7 @@ def to_chat_completion_response(
|
|||
) -> ChatCompletionResponse:
|
||||
text_parts: list[str] = []
|
||||
reasoning_parts: list[str] = []
|
||||
reasoning_details_parts: list[dict[str, object]] = []
|
||||
tool_calls_parts: list[dict[str, object]] = []
|
||||
finish_reason: str | None = None
|
||||
for chunk in chunks:
|
||||
|
|
@ -113,6 +135,8 @@ def to_chat_completion_response(
|
|||
text_parts.append(chunk.content)
|
||||
if chunk.reasoning_content:
|
||||
reasoning_parts.append(chunk.reasoning_content)
|
||||
if chunk.reasoning_details:
|
||||
reasoning_details_parts.extend(chunk.reasoning_details)
|
||||
if chunk.tool_calls:
|
||||
tool_calls_parts.extend(chunk.tool_calls)
|
||||
if chunk.finish_reason is not None:
|
||||
|
|
@ -127,7 +151,9 @@ def to_chat_completion_response(
|
|||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
content="".join(text_parts),
|
||||
reasoning="".join(reasoning_parts) or None,
|
||||
reasoning_content="".join(reasoning_parts) or None,
|
||||
reasoning_details=reasoning_details_parts or None,
|
||||
tool_calls=_merge_tool_call_deltas(tool_calls_parts) or None,
|
||||
),
|
||||
finish_reason=finish_reason,
|
||||
|
|
@ -180,7 +206,7 @@ def to_models_response(models: list[CoreModel]) -> ModelsResponse:
|
|||
id=model.id,
|
||||
created=model.created,
|
||||
owned_by=model.owned_by,
|
||||
name=model.name,
|
||||
name=_format_model_name(model),
|
||||
description=model.description,
|
||||
context_length=model.context_length,
|
||||
architecture=model.architecture,
|
||||
|
|
@ -192,3 +218,17 @@ def to_models_response(models: list[CoreModel]) -> ModelsResponse:
|
|||
for model in models
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _format_model_name(model: CoreModel) -> str | None:
|
||||
if model.name is None:
|
||||
return None
|
||||
|
||||
provider_label = model.provider_display_name
|
||||
if provider_label is None:
|
||||
provider_name, _, _ = model.id.partition("/")
|
||||
provider_label = provider_name or None
|
||||
|
||||
if provider_label is None:
|
||||
return model.name
|
||||
return f"{provider_label}: {model.name}"
|
||||
|
|
|
|||
|
|
@ -6,9 +6,11 @@ import json
|
|||
import logging
|
||||
import os
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
logger = logging.getLogger("ai.http")
|
||||
if not logger.handlers:
|
||||
|
|
@ -26,15 +28,19 @@ logger.propagate = False
|
|||
|
||||
def install_request_logging(app: FastAPI) -> None:
|
||||
enabled = os.getenv("AI_REQUEST_LOG_ENABLED", "false").strip().lower()
|
||||
if enabled not in {"1", "true", "yes", "on"}:
|
||||
return
|
||||
detailed_logging_enabled = enabled in {"1", "true", "yes", "on"}
|
||||
max_len_raw = os.getenv("AI_REQUEST_LOG_MAX_BODY_CHARS", "20000").strip()
|
||||
try:
|
||||
max_body_chars = max(1024, int(max_len_raw))
|
||||
except ValueError:
|
||||
max_body_chars = 20000
|
||||
|
||||
@app.middleware("http")
|
||||
async def request_logging_middleware(request: Request, call_next):
|
||||
started_at = time.perf_counter()
|
||||
|
||||
body_text = ""
|
||||
if request.method in {"POST", "PUT", "PATCH"}:
|
||||
if detailed_logging_enabled and request.method in {"POST", "PUT", "PATCH"}:
|
||||
body_bytes = await request.body()
|
||||
body_text = _format_body(body_bytes)
|
||||
|
||||
|
|
@ -53,6 +59,23 @@ def install_request_logging(app: FastAPI) -> None:
|
|||
raise
|
||||
elapsed_ms = (time.perf_counter() - started_at) * 1000
|
||||
|
||||
if response.status_code >= 400:
|
||||
response_body = await _read_response_body(response)
|
||||
logger.warning(
|
||||
"http request error method=%s path=%s status=%s duration_ms=%.2f query=%s body=%s response_body=%s",
|
||||
request.method,
|
||||
request.url.path,
|
||||
response.status_code,
|
||||
elapsed_ms,
|
||||
request.url.query or "-",
|
||||
_truncate(body_text or "-", max_body_chars),
|
||||
_truncate(response_body or "-", max_body_chars),
|
||||
)
|
||||
return response
|
||||
|
||||
if not detailed_logging_enabled:
|
||||
return response
|
||||
|
||||
logger.info(
|
||||
"http request method=%s path=%s status=%s duration_ms=%.2f query=%s body=%s",
|
||||
request.method,
|
||||
|
|
@ -60,7 +83,7 @@ def install_request_logging(app: FastAPI) -> None:
|
|||
response.status_code,
|
||||
elapsed_ms,
|
||||
request.url.query or "-",
|
||||
body_text or "-",
|
||||
_truncate(body_text or "-", max_body_chars),
|
||||
)
|
||||
|
||||
return response
|
||||
|
|
@ -75,3 +98,37 @@ def _format_body(body: bytes) -> str:
|
|||
except Exception:
|
||||
text = body.decode("utf-8", errors="replace")
|
||||
return text
|
||||
|
||||
|
||||
def _truncate(value: str, limit: int) -> str:
|
||||
if len(value) <= limit:
|
||||
return value
|
||||
return f"{value[:limit]}...[truncated {len(value) - limit} chars]"
|
||||
|
||||
|
||||
async def _read_response_body(response: Any) -> str:
    """Extract a response body as text for error logging.

    Handles three shapes of response object:
    - a rendered response with an in-memory ``body`` (returned as-is),
    - a ``StreamingResponse`` (never drained; a placeholder is returned),
    - anything else exposing a ``body_iterator`` (drained, then the
      iterator is replaced so the client still receives the payload).
    """
    # Plain responses keep the rendered payload on `.body`.
    body = getattr(response, "body", None)
    if isinstance(body, (bytes, bytearray)):
        return _format_body(bytes(body))

    # NOTE(review): streaming bodies are deliberately not consumed here —
    # draining the iterator would break the stream sent to the client.
    if isinstance(response, StreamingResponse):
        return "<streaming-response>"

    iterator = getattr(response, "body_iterator", None)
    if iterator is None:
        return ""

    # Drain the buffered iterator; non-bytes chunks are stringified so a
    # mixed iterator cannot crash the logging path.
    chunks: list[bytes] = []
    async for chunk in iterator:
        if isinstance(chunk, bytes):
            chunks.append(chunk)
        else:
            chunks.append(str(chunk).encode("utf-8", errors="replace"))

    raw = b"".join(chunks)
    # Hand the response a fresh single-shot iterator so the payload we
    # just consumed is still delivered to the client.
    response.body_iterator = _iterate_once(raw)
    return _format_body(raw)
|
||||
|
||||
|
||||
async def _iterate_once(payload: bytes) -> AsyncIterator[bytes]:
|
||||
yield payload
|
||||
|
|
|
|||
|
|
@ -34,7 +34,12 @@ class ChatCompletionsRequest(BaseModel):
|
|||
max_completion_tokens: int | None = None
|
||||
max_tokens: int | None = None
|
||||
metadata: dict[str, str] | None = None
|
||||
modalities: list[Literal["text", "audio"]] | None = None
|
||||
provider: dict[str, Any] | None = None
|
||||
plugins: list[dict[str, Any]] | None = None
|
||||
session_id: str | None = None
|
||||
trace: dict[str, Any] | None = None
|
||||
modalities: list[Literal["text", "image"]] | None = None
|
||||
models: list[Any] | None = None
|
||||
n: int | None = None
|
||||
parallel_tool_calls: bool | None = None
|
||||
prediction: dict[str, Any] | None = None
|
||||
|
|
@ -48,6 +53,7 @@ class ChatCompletionsRequest(BaseModel):
|
|||
default=None,
|
||||
validation_alias=AliasChoices("reasoning_summary", "reasoningSummary"),
|
||||
)
|
||||
reasoning: dict[str, Any] | None = None
|
||||
response_format: dict[str, Any] | None = None
|
||||
safety_identifier: str | None = None
|
||||
seed: int | None = None
|
||||
|
|
@ -55,6 +61,8 @@ class ChatCompletionsRequest(BaseModel):
|
|||
stop: str | list[str] | None = None
|
||||
store: bool | None = None
|
||||
stream_options: dict[str, Any] | None = None
|
||||
debug: dict[str, Any] | None = None
|
||||
image_config: dict[str, Any] | None = None
|
||||
temperature: float | None = None
|
||||
tool_choice: str | dict[str, Any] | None = None
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
|
|
@ -71,6 +79,7 @@ class ChatCompletionChunkDelta(BaseModel):
|
|||
role: str | None = None
|
||||
content: str | None = None
|
||||
reasoning_content: str | None = None
|
||||
reasoning_details: list[dict[str, Any]] | None = None
|
||||
tool_calls: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
|
|
@ -97,7 +106,9 @@ class ChatCompletionMessage(BaseModel):
|
|||
|
||||
role: str = "assistant"
|
||||
content: str
|
||||
reasoning: str | None = None
|
||||
reasoning_content: str | None = None
|
||||
reasoning_details: list[dict[str, Any]] | None = None
|
||||
refusal: str | None = None
|
||||
tool_calls: list[dict[str, Any]] | None = None
|
||||
function_call: dict[str, Any] | None = None
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
|
@ -12,7 +11,7 @@ from app.config.models import AppConfig, AuthConfig, LoadedConfig, LoadedProvide
|
|||
|
||||
def load_config(config_path: Path, auth_path: Path) -> LoadedConfig:
|
||||
app_data = _read_yaml(config_path)
|
||||
auth_data = _read_json(auth_path)
|
||||
auth_data = _read_yaml(auth_path)
|
||||
|
||||
app_config = AppConfig.model_validate(app_data)
|
||||
auth_config = AuthConfig.model_validate(auth_data)
|
||||
|
|
@ -27,6 +26,8 @@ def load_config(config_path: Path, auth_path: Path) -> LoadedConfig:
|
|||
name=provider_name,
|
||||
url=provider.url,
|
||||
type=provider.type,
|
||||
display_name=provider.name,
|
||||
models=provider.models,
|
||||
whitelist=provider.whitelist,
|
||||
blacklist=provider.blacklist,
|
||||
auth=auth,
|
||||
|
|
@ -41,11 +42,3 @@ def _read_yaml(path: Path) -> dict:
|
|||
if not isinstance(raw, dict):
|
||||
raise ValueError(f"YAML file '{path}' must contain an object")
|
||||
return raw
|
||||
|
||||
|
||||
def _read_json(path: Path) -> dict:
|
||||
with path.open("r", encoding="utf-8") as handle:
|
||||
raw = json.load(handle)
|
||||
if not isinstance(raw, dict):
|
||||
raise ValueError(f"JSON file '{path}' must contain an object")
|
||||
return raw
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@ class ProviderConfig(BaseModel):
|
|||
|
||||
url: str
|
||||
type: ProviderType
|
||||
name: str | None = None
|
||||
models: dict[str, dict[str, str]] | None = None
|
||||
whitelist: list[str] | None = None
|
||||
blacklist: list[str] | None = None
|
||||
|
||||
|
|
@ -39,7 +41,13 @@ class OAuthAuth(BaseModel):
|
|||
expires: int
|
||||
|
||||
|
||||
ProviderAuth = TokenAuth | OAuthAuth
|
||||
class UrlAuth(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
url: str
|
||||
|
||||
|
||||
ProviderAuth = TokenAuth | OAuthAuth | UrlAuth
|
||||
|
||||
|
||||
class AuthConfig(BaseModel):
|
||||
|
|
@ -54,6 +62,8 @@ class LoadedProviderConfig(BaseModel):
|
|||
name: str
|
||||
url: str
|
||||
type: ProviderType
|
||||
display_name: str | None = None
|
||||
models: dict[str, dict[str, str]] | None = None
|
||||
whitelist: list[str] | None = None
|
||||
blacklist: list[str] | None = None
|
||||
auth: ProviderAuth
|
||||
|
|
|
|||
306
app/core/models_dev.py
Normal file
306
app/core/models_dev.py
Normal file
|
|
@ -0,0 +1,306 @@
|
|||
"""models.dev catalog lookup and model enrichment."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import datetime, timezone
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from app.core.types import CoreModel, ProviderModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelsDevCatalog:
    """Fetches and caches the models.dev provider catalog.

    The catalog is fetched lazily over HTTP and cached in memory for
    ``cache_ttl_seconds``; concurrent refreshes are serialized with an
    asyncio lock (double-checked inside the lock).
    """

    def __init__(
        self,
        *,
        catalog_url: str = "https://models.dev/api.json",
        cache_ttl_seconds: float = 600.0,
        timeout_seconds: float = 10.0,
        fetch_catalog: Callable[[], Awaitable[dict[str, Any]]] | None = None,
    ) -> None:
        # `fetch_catalog` lets callers inject a replacement fetch
        # coroutine (presumably for tests — no in-view caller uses it).
        self._catalog_url = catalog_url
        self._cache_ttl_seconds = cache_ttl_seconds
        self._timeout_seconds = timeout_seconds
        self._fetch_catalog = fetch_catalog or self._fetch_catalog_http
        # Cached catalog keyed by provider id; None means "never fetched".
        self._catalog: dict[str, dict[str, Any]] | None = None
        self._catalog_expires_at = 0.0
        self._catalog_lock = asyncio.Lock()

    async def get_provider_models(
        self, *, provider_name: str, provider_url: str
    ) -> tuple[str | None, dict[str, dict[str, Any]]]:
        """Return (provider display name, models-by-id) for one provider.

        Both fall back to ``(None, {})`` when the catalog is unavailable
        or the provider cannot be resolved.
        """
        catalog = await self._get_catalog()
        if not catalog:
            return None, {}

        provider_id = _resolve_provider_id(
            catalog,
            provider_name=provider_name,
            provider_url=provider_url,
        )
        if provider_id is None:
            return None, {}

        provider = _as_dict(catalog.get(provider_id))
        provider_display_name = _as_str(provider.get("name"))
        raw_models = provider.get("models")
        if not isinstance(raw_models, dict):
            return provider_display_name, {}

        # Keep only well-formed entries (str model id -> dict payload).
        return (
            provider_display_name,
            {
                model_id: model
                for model_id, model in raw_models.items()
                if isinstance(model_id, str) and isinstance(model, dict)
            },
        )

    async def _get_catalog(self) -> dict[str, dict[str, Any]]:
        """Return the cached catalog, refreshing it when the TTL expired.

        A failed fetch is logged and cached as an empty catalog for a
        full TTL, so a flaky upstream is not hammered on every request.
        """
        now = time.monotonic()
        if self._catalog is not None and now < self._catalog_expires_at:
            return self._catalog

        async with self._catalog_lock:
            # Re-check under the lock: another task may have refreshed
            # the cache while we were waiting.
            now = time.monotonic()
            if self._catalog is not None and now < self._catalog_expires_at:
                return self._catalog

            try:
                fetched = await self._fetch_catalog()
                self._catalog = _coerce_catalog(fetched)
            except Exception:
                logger.exception("failed to fetch models.dev catalog")
                self._catalog = {}

            # NOTE(review): expiry is computed from the pre-fetch `now`,
            # so the effective TTL is shortened by the fetch duration —
            # confirm this is intentional.
            self._catalog_expires_at = now + self._cache_ttl_seconds
            return self._catalog

    async def _fetch_catalog_http(self) -> dict[str, Any]:
        """Fetch the raw catalog JSON over HTTP (default fetch strategy)."""
        async with httpx.AsyncClient(timeout=self._timeout_seconds) as client:
            response = await client.get(self._catalog_url)
            response.raise_for_status()
            payload = response.json()

        # Anything other than a JSON object is treated as an empty catalog.
        return payload if isinstance(payload, dict) else {}
|
||||
|
||||
|
||||
def to_core_model(
    *,
    provider_name: str,
    provider_model: ProviderModel,
    models_dev_model: dict[str, Any] | None = None,
    provider_display_name: str | None = None,
    model_override: dict[str, Any] | None = None,
) -> CoreModel:
    """Merge provider metadata, models.dev metadata, and config overrides
    into a single CoreModel.

    Precedence per field: config override (name only) > models.dev
    catalog entry > the provider's own metadata. Fields use ``or``
    fallbacks, so falsy values (empty dict/list, 0) also fall through.
    """
    override_name = _as_str(model_override.get("name")) if model_override else None
    models_dev_name = (
        _as_str(models_dev_model.get("name")) if models_dev_model else None
    )
    name = override_name or models_dev_name or provider_model.name

    models_dev_context_length = (
        _context_length_from_models_dev(models_dev_model) if models_dev_model else None
    )
    context_length = models_dev_context_length or provider_model.context_length

    models_dev_architecture = (
        _architecture_from_models_dev(models_dev_model) if models_dev_model else None
    )
    architecture = models_dev_architecture or provider_model.architecture

    models_dev_pricing = (
        _pricing_from_models_dev(models_dev_model) if models_dev_model else None
    )
    pricing = models_dev_pricing or provider_model.pricing

    models_dev_supported_parameters = (
        _supported_parameters_from_models_dev(models_dev_model)
        if models_dev_model
        else None
    )
    supported_parameters = (
        models_dev_supported_parameters or provider_model.supported_parameters
    )

    models_dev_created = (
        _created_from_models_dev(models_dev_model) if models_dev_model else None
    )
    created = models_dev_created or provider_model.created

    models_dev_owned_by = (
        _as_str(models_dev_model.get("provider")) if models_dev_model else None
    )
    owned_by = models_dev_owned_by or provider_model.owned_by

    return CoreModel(
        # Routed id is namespaced by the configured provider name.
        id=f"{provider_name}/{provider_model.id}",
        created=created or 0,
        owned_by=owned_by or "wzray",
        name=name,
        provider_display_name=provider_display_name,
        description=provider_model.description,
        context_length=context_length,
        architecture=architecture,
        pricing=pricing,
        supported_parameters=supported_parameters,
        settings=provider_model.settings,
        opencode=provider_model.opencode,
        # The raw override dict is preserved so downstream consumers can
        # apply keys beyond "name".
        config_override=model_override,
    )
|
||||
|
||||
|
||||
def _coerce_catalog(raw: dict[str, Any]) -> dict[str, dict[str, Any]]:
|
||||
catalog: dict[str, dict[str, Any]] = {}
|
||||
for key, value in raw.items():
|
||||
if isinstance(key, str) and isinstance(value, dict):
|
||||
catalog[key] = value
|
||||
return catalog
|
||||
|
||||
|
||||
def _resolve_provider_id(
    catalog: dict[str, dict[str, Any]], *, provider_name: str, provider_url: str
) -> str | None:
    """Map a configured provider onto a models.dev catalog entry.

    Resolution order:
      1. exact catalog-key match on the configured name,
      2. host match between the provider URL and the catalog "api" URL,
      3. fuzzy prefix match on normalized (lowercase alphanumeric)
         names, preferring the candidate closest in length.

    Returns the catalog id, or None when nothing matches.
    """
    if provider_name in catalog:
        return provider_name

    provider_host = _host(provider_url)
    if provider_host is not None:
        for provider_id, provider in catalog.items():
            api_url = _as_str(provider.get("api"))
            if api_url is None:
                continue
            if _host(api_url) == provider_host:
                return provider_id

    normalized_name = _normalize_token(provider_name)
    # Bug fix: an empty normalized name would prefix-match every catalog
    # entry (str.startswith("") is always True) and silently resolve to
    # an arbitrary provider; bail out instead.
    if not normalized_name:
        return None

    candidates: list[tuple[int, int, str]] = []
    for provider_id in catalog:
        normalized_id = _normalize_token(provider_id)
        if not normalized_id:
            continue
        if normalized_name.startswith(normalized_id) or normalized_id.startswith(
            normalized_name
        ):
            # Rank by closeness in length, then by longer id; the final
            # tuple element makes the sort deterministic.
            candidates.append(
                (
                    abs(len(normalized_name) - len(normalized_id)),
                    -len(normalized_id),
                    provider_id,
                )
            )
    if candidates:
        candidates.sort()
        return candidates[0][2]

    return None
|
||||
|
||||
|
||||
def _context_length_from_models_dev(model: dict[str, Any]) -> int | None:
|
||||
limit = _as_dict(model.get("limit"))
|
||||
context = limit.get("context")
|
||||
return context if isinstance(context, int) else None
|
||||
|
||||
|
||||
def _architecture_from_models_dev(model: dict[str, Any]) -> dict[str, Any] | None:
|
||||
modalities = _as_dict(model.get("modalities"))
|
||||
architecture: dict[str, Any] = {}
|
||||
|
||||
input_modalities = modalities.get("input")
|
||||
if isinstance(input_modalities, list):
|
||||
architecture["input_modalities"] = [
|
||||
str(modality) for modality in input_modalities if isinstance(modality, str)
|
||||
]
|
||||
|
||||
output_modalities = modalities.get("output")
|
||||
if isinstance(output_modalities, list):
|
||||
architecture["output_modalities"] = [
|
||||
str(modality) for modality in output_modalities if isinstance(modality, str)
|
||||
]
|
||||
|
||||
family = _as_str(model.get("family"))
|
||||
if family is not None:
|
||||
architecture["family"] = family
|
||||
|
||||
return architecture or None
|
||||
|
||||
|
||||
def _pricing_from_models_dev(model: dict[str, Any]) -> dict[str, Any] | None:
|
||||
cost = model.get("cost")
|
||||
return cost if isinstance(cost, dict) else None
|
||||
|
||||
|
||||
def _supported_parameters_from_models_dev(model: dict[str, Any]) -> list[str] | None:
|
||||
supported: list[str] = []
|
||||
if _as_bool(model.get("reasoning")):
|
||||
supported.extend(["reasoning_effort", "reasoning_summary"])
|
||||
|
||||
if _as_bool(model.get("tool_call")):
|
||||
supported.extend(["tools", "tool_choice", "parallel_tool_calls"])
|
||||
|
||||
if _as_bool(model.get("structured_output")):
|
||||
supported.append("response_format")
|
||||
|
||||
output_limit = _as_dict(model.get("limit")).get("output")
|
||||
if isinstance(output_limit, int):
|
||||
supported.extend(["max_tokens", "max_completion_tokens"])
|
||||
|
||||
modalities = model.get("modalities")
|
||||
if isinstance(modalities, dict):
|
||||
supported.append("modalities")
|
||||
|
||||
return sorted(set(supported)) or None
|
||||
|
||||
|
||||
def _created_from_models_dev(model: dict[str, Any]) -> int | None:
|
||||
release_date = _as_str(model.get("release_date"))
|
||||
if release_date is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
parsed = datetime.fromisoformat(release_date)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
if parsed.tzinfo is None:
|
||||
parsed = parsed.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
parsed = parsed.astimezone(timezone.utc)
|
||||
|
||||
return int(parsed.timestamp())
|
||||
|
||||
|
||||
def _host(url: str) -> str | None:
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
except ValueError:
|
||||
return None
|
||||
if not parsed.hostname:
|
||||
return None
|
||||
return parsed.hostname.lower()
|
||||
|
||||
|
||||
def _normalize_token(value: str) -> str:
|
||||
return "".join(ch for ch in value.lower() if ch.isalnum())
|
||||
|
||||
|
||||
def _as_dict(raw: Any) -> dict[str, Any]:
|
||||
return raw if isinstance(raw, dict) else {}
|
||||
|
||||
|
||||
def _as_str(value: Any) -> str | None:
|
||||
return value if isinstance(value, str) else None
|
||||
|
||||
|
||||
def _as_bool(value: Any) -> bool:
|
||||
return value is True
|
||||
|
|
@ -2,18 +2,42 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from app.core.models_dev import ModelsDevCatalog, to_core_model
|
||||
from app.core.errors import InvalidModelError, ProviderNotFoundError
|
||||
from app.core.types import CoreChatRequest, CoreChunk, CoreModel, ProviderChatRequest
|
||||
from app.core.types import (
|
||||
CoreChatRequest,
|
||||
CoreChunk,
|
||||
CoreModel,
|
||||
ProviderChatRequest,
|
||||
ProviderModel,
|
||||
)
|
||||
from app.providers.base import BaseProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RouterCore:
|
||||
"""Routes normalized requests to a specific provider instance."""
|
||||
|
||||
def __init__(self, providers: dict[str, BaseProvider]) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
providers: dict[str, BaseProvider],
|
||||
*,
|
||||
models_cache_ttl_seconds: float = 600.0,
|
||||
models_dev_catalog: ModelsDevCatalog | None = None,
|
||||
) -> None:
|
||||
self.providers = providers
|
||||
self._models_cache_ttl_seconds = models_cache_ttl_seconds
|
||||
self._models_dev_catalog = models_dev_catalog or ModelsDevCatalog()
|
||||
self._models_cache: list[CoreModel] | None = None
|
||||
self._models_cache_expires_at = 0.0
|
||||
self._models_cache_lock = asyncio.Lock()
|
||||
|
||||
def resolve_provider(self, routed_model: str) -> tuple[BaseProvider, str]:
|
||||
provider_name, upstream_model = split_routed_model(routed_model)
|
||||
|
|
@ -36,7 +60,12 @@ class RouterCore:
|
|||
max_completion_tokens=request.max_completion_tokens,
|
||||
max_tokens=request.max_tokens,
|
||||
metadata=request.metadata,
|
||||
provider=request.provider,
|
||||
plugins=request.plugins,
|
||||
session_id=request.session_id,
|
||||
trace=request.trace,
|
||||
modalities=request.modalities,
|
||||
models=request.models,
|
||||
n=request.n,
|
||||
parallel_tool_calls=request.parallel_tool_calls,
|
||||
prediction=request.prediction,
|
||||
|
|
@ -45,6 +74,7 @@ class RouterCore:
|
|||
prompt_cache_retention=request.prompt_cache_retention,
|
||||
reasoning_effort=request.reasoning_effort,
|
||||
reasoning_summary=request.reasoning_summary,
|
||||
reasoning=request.reasoning,
|
||||
response_format=request.response_format,
|
||||
safety_identifier=request.safety_identifier,
|
||||
seed=request.seed,
|
||||
|
|
@ -53,6 +83,8 @@ class RouterCore:
|
|||
store=request.store,
|
||||
stream_options=request.stream_options,
|
||||
temperature=request.temperature,
|
||||
debug=request.debug,
|
||||
image_config=request.image_config,
|
||||
tool_choice=request.tool_choice,
|
||||
tools=request.tools,
|
||||
top_logprobs=request.top_logprobs,
|
||||
|
|
@ -66,28 +98,64 @@ class RouterCore:
|
|||
yield chunk
|
||||
|
||||
async def list_models(self) -> list[CoreModel]:
|
||||
models: list[CoreModel] = []
|
||||
for provider_name, provider in self.providers.items():
|
||||
provider_models = await provider.list_models()
|
||||
for model in provider_models:
|
||||
if not provider.is_model_allowed(model.id):
|
||||
continue
|
||||
models.append(
|
||||
CoreModel(
|
||||
id=f"{provider_name}/{model.id}",
|
||||
created=model.created or 0,
|
||||
owned_by=model.owned_by or "wzray",
|
||||
name=model.name,
|
||||
description=model.description,
|
||||
context_length=model.context_length,
|
||||
architecture=model.architecture,
|
||||
pricing=model.pricing,
|
||||
supported_parameters=model.supported_parameters,
|
||||
settings=model.settings,
|
||||
opencode=model.opencode,
|
||||
now = time.monotonic()
|
||||
if self._models_cache is not None and now < self._models_cache_expires_at:
|
||||
return list(self._models_cache)
|
||||
|
||||
async with self._models_cache_lock:
|
||||
now = time.monotonic()
|
||||
if self._models_cache is not None and now < self._models_cache_expires_at:
|
||||
return list(self._models_cache)
|
||||
|
||||
models: list[CoreModel] = []
|
||||
for provider_name, provider in self.providers.items():
|
||||
try:
|
||||
provider_models = await provider.list_models()
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"models listing failed provider=%s base_url=%s error=%s",
|
||||
provider_name,
|
||||
provider.base_url,
|
||||
repr(exc),
|
||||
)
|
||||
continue
|
||||
|
||||
(
|
||||
provider_display_name,
|
||||
models_dev_models,
|
||||
) = await self._models_dev_catalog.get_provider_models(
|
||||
provider_name=provider_name,
|
||||
provider_url=provider.base_url,
|
||||
)
|
||||
return models
|
||||
for model in provider_models:
|
||||
if not provider.is_model_allowed(model.id):
|
||||
continue
|
||||
|
||||
model_override = None
|
||||
if provider.models_config is not None:
|
||||
raw_override = provider.models_config.get(model.id)
|
||||
if isinstance(raw_override, dict):
|
||||
model_override = raw_override
|
||||
|
||||
provider_display_name = (
|
||||
provider.display_name or provider_display_name
|
||||
)
|
||||
|
||||
models.append(
|
||||
to_core_model(
|
||||
provider_name=provider_name,
|
||||
provider_model=model,
|
||||
models_dev_model=models_dev_models.get(model.id),
|
||||
provider_display_name=provider_display_name,
|
||||
model_override=model_override,
|
||||
)
|
||||
)
|
||||
|
||||
self._models_cache = list(models)
|
||||
self._models_cache_expires_at = (
|
||||
time.monotonic() + self._models_cache_ttl_seconds
|
||||
)
|
||||
return list(models)
|
||||
|
||||
|
||||
def split_routed_model(routed_model: str) -> tuple[str, str]:
|
||||
|
|
|
|||
|
|
@ -32,7 +32,12 @@ class CoreChatRequest:
|
|||
max_completion_tokens: int | None = None
|
||||
max_tokens: int | None = None
|
||||
metadata: dict[str, str] | None = None
|
||||
provider: dict[str, Any] | None = None
|
||||
plugins: list[dict[str, Any]] | None = None
|
||||
session_id: str | None = None
|
||||
trace: dict[str, Any] | None = None
|
||||
modalities: list[str] | None = None
|
||||
models: list[Any] | None = None
|
||||
n: int | None = None
|
||||
parallel_tool_calls: bool | None = None
|
||||
prediction: dict[str, Any] | None = None
|
||||
|
|
@ -41,6 +46,7 @@ class CoreChatRequest:
|
|||
prompt_cache_retention: str | None = None
|
||||
reasoning_effort: str | None = None
|
||||
reasoning_summary: str | None = None
|
||||
reasoning: dict[str, Any] | None = None
|
||||
response_format: dict[str, Any] | None = None
|
||||
safety_identifier: str | None = None
|
||||
seed: int | None = None
|
||||
|
|
@ -49,6 +55,8 @@ class CoreChatRequest:
|
|||
store: bool | None = None
|
||||
stream_options: dict[str, Any] | None = None
|
||||
temperature: float | None = None
|
||||
debug: dict[str, Any] | None = None
|
||||
image_config: dict[str, Any] | None = None
|
||||
tool_choice: str | dict[str, Any] | None = None
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
top_logprobs: int | None = None
|
||||
|
|
@ -71,7 +79,12 @@ class ProviderChatRequest:
|
|||
max_completion_tokens: int | None = None
|
||||
max_tokens: int | None = None
|
||||
metadata: dict[str, str] | None = None
|
||||
provider: dict[str, Any] | None = None
|
||||
plugins: list[dict[str, Any]] | None = None
|
||||
session_id: str | None = None
|
||||
trace: dict[str, Any] | None = None
|
||||
modalities: list[str] | None = None
|
||||
models: list[Any] | None = None
|
||||
n: int | None = None
|
||||
parallel_tool_calls: bool | None = None
|
||||
prediction: dict[str, Any] | None = None
|
||||
|
|
@ -80,6 +93,7 @@ class ProviderChatRequest:
|
|||
prompt_cache_retention: str | None = None
|
||||
reasoning_effort: str | None = None
|
||||
reasoning_summary: str | None = None
|
||||
reasoning: dict[str, Any] | None = None
|
||||
response_format: dict[str, Any] | None = None
|
||||
safety_identifier: str | None = None
|
||||
seed: int | None = None
|
||||
|
|
@ -88,6 +102,8 @@ class ProviderChatRequest:
|
|||
store: bool | None = None
|
||||
stream_options: dict[str, Any] | None = None
|
||||
temperature: float | None = None
|
||||
debug: dict[str, Any] | None = None
|
||||
image_config: dict[str, Any] | None = None
|
||||
tool_choice: str | dict[str, Any] | None = None
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
top_logprobs: int | None = None
|
||||
|
|
@ -120,6 +136,7 @@ class CoreModel:
|
|||
created: int = 0
|
||||
owned_by: str = "wzray"
|
||||
name: str | None = None
|
||||
provider_display_name: str | None = None
|
||||
description: str | None = None
|
||||
context_length: int | None = None
|
||||
architecture: dict[str, Any] | None = None
|
||||
|
|
@ -127,6 +144,7 @@ class CoreModel:
|
|||
supported_parameters: list[str] | None = None
|
||||
settings: dict[str, Any] | None = None
|
||||
opencode: dict[str, Any] | None = None
|
||||
config_override: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
|
|
@ -135,5 +153,6 @@ class CoreChunk:
|
|||
role: str | None = None
|
||||
content: str | None = None
|
||||
reasoning_content: str | None = None
|
||||
reasoning_details: list[dict[str, Any]] | None = None
|
||||
tool_calls: list[dict[str, Any]] | None = None
|
||||
finish_reason: str | None = None
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from app.providers.factory import build_provider_registry
|
|||
def _resolve_paths() -> tuple[Path, Path]:
|
||||
data_dir = Path(os.getenv("AI_ROUTER_DATA_DIR", "."))
|
||||
config_path = Path(os.getenv("AI_ROUTER_CONFIG", str(data_dir / "config.yml")))
|
||||
auth_path = Path(os.getenv("AI_ROUTER_AUTH", str(data_dir / "auth.json")))
|
||||
auth_path = Path(os.getenv("AI_ROUTER_AUTH", str(data_dir / "auth.yml")))
|
||||
return config_path, auth_path
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from __future__ import annotations
|
|||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Any
|
||||
|
||||
from app.core.errors import ModelNotAllowedError
|
||||
from app.core.types import CoreChunk, ProviderChatRequest, ProviderModel
|
||||
|
|
@ -30,6 +31,8 @@ class BaseProvider(ABC):
|
|||
self.api_type = api_type
|
||||
self.whitelist = whitelist
|
||||
self.blacklist = blacklist
|
||||
self.display_name: str | None = None
|
||||
self.models_config: dict[str, dict[str, Any]] | None = None
|
||||
|
||||
def ensure_model_allowed(self, model: str) -> None:
|
||||
if self.whitelist is not None and model not in self.whitelist:
|
||||
|
|
|
|||
84
app/providers/codex_responses/models.py
Normal file
84
app/providers/codex_responses/models.py
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
"""Model list mapping for codex responses provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.core.types import ProviderModel
|
||||
from app.providers.codex_responses.utils import _to_dict
|
||||
|
||||
|
||||
def _coerce_model_items(raw: Any) -> list[dict[str, Any]]:
|
||||
data = getattr(raw, "data", None)
|
||||
if isinstance(data, list):
|
||||
return [_to_dict(item) for item in data]
|
||||
models = getattr(raw, "models", None)
|
||||
if isinstance(models, list):
|
||||
return [_to_dict(item) for item in models]
|
||||
if isinstance(raw, list):
|
||||
return [_to_dict(item) for item in raw]
|
||||
if isinstance(raw, dict):
|
||||
models = raw.get("models")
|
||||
if isinstance(models, list):
|
||||
return [_to_dict(item) for item in models]
|
||||
data = raw.get("data")
|
||||
if isinstance(data, list):
|
||||
return [_to_dict(item) for item in data]
|
||||
return []
|
||||
|
||||
|
||||
def _to_provider_model(raw: dict[str, Any]) -> ProviderModel:
|
||||
model_id = raw.get("id")
|
||||
if not isinstance(model_id, str):
|
||||
model_id = raw.get("slug") if isinstance(raw.get("slug"), str) else None
|
||||
if model_id is None:
|
||||
raise ValueError("Codex model item missing id/slug")
|
||||
|
||||
arch = raw.get("architecture")
|
||||
if isinstance(arch, dict):
|
||||
arch = {k: v for k, v in arch.items() if k != "tokenizer"}
|
||||
else:
|
||||
input_modalities = raw.get("input_modalities")
|
||||
if isinstance(input_modalities, list):
|
||||
arch = {"input_modalities": input_modalities}
|
||||
else:
|
||||
arch = None
|
||||
|
||||
supported = raw.get("supported_parameters")
|
||||
supported_parameters = None
|
||||
if isinstance(supported, list):
|
||||
supported_parameters = [str(v) for v in supported]
|
||||
|
||||
context_length = raw.get("context_length")
|
||||
if not isinstance(context_length, int):
|
||||
context_length = raw.get("context_window")
|
||||
if not isinstance(context_length, int):
|
||||
top_provider = raw.get("top_provider")
|
||||
if isinstance(top_provider, dict) and isinstance(
|
||||
top_provider.get("context_length"), int
|
||||
):
|
||||
context_length = top_provider.get("context_length")
|
||||
else:
|
||||
context_length = None
|
||||
|
||||
return ProviderModel(
|
||||
id=model_id,
|
||||
name=(
|
||||
raw.get("name")
|
||||
if isinstance(raw.get("name"), str)
|
||||
else raw.get("display_name")
|
||||
if isinstance(raw.get("display_name"), str)
|
||||
else None
|
||||
),
|
||||
description=raw.get("description")
|
||||
if isinstance(raw.get("description"), str)
|
||||
else None,
|
||||
context_length=context_length,
|
||||
architecture=arch,
|
||||
pricing=raw.get("pricing") if isinstance(raw.get("pricing"), dict) else None,
|
||||
supported_parameters=supported_parameters,
|
||||
settings=raw.get("settings") if isinstance(raw.get("settings"), dict) else None,
|
||||
opencode=raw.get("opencode") if isinstance(raw.get("opencode"), dict) else None,
|
||||
created=raw.get("created") if isinstance(raw.get("created"), int) else None,
|
||||
owned_by=raw.get("owned_by") if isinstance(raw.get("owned_by"), str) else None,
|
||||
)
|
||||
|
|
@ -51,10 +51,13 @@ class CodexOAuthProvider:
|
|||
|
||||
headers = {"Authorization": f"Bearer {self._access}"}
|
||||
if self._account_id:
|
||||
headers["chatgpt-account-id"] = self._account_id
|
||||
headers["OpenAI-Beta"] = "responses=experimental"
|
||||
headers["ChatGPT-Account-Id"] = self._account_id
|
||||
return OAuthData(token=self._access, headers=headers)
|
||||
|
||||
async def get_headers(self) -> dict[str, str]:
|
||||
oauth = await self.get()
|
||||
return oauth.headers
|
||||
|
||||
def _is_expired(self) -> bool:
|
||||
return int(time.time() * 1000) >= self._expires - 60_000
|
||||
|
||||
|
|
|
|||
|
|
@ -4,17 +4,21 @@ from __future__ import annotations
|
|||
|
||||
from collections.abc import AsyncIterator
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Mapping, cast
|
||||
import random
|
||||
import string
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from openai import AsyncOpenAI, OpenAIError
|
||||
|
||||
from app.config.models import LoadedProviderConfig, OAuthAuth
|
||||
from app.config.models import LoadedProviderConfig, OAuthAuth, UrlAuth
|
||||
from app.core.errors import UpstreamProviderError
|
||||
from app.core.types import CoreChunk, ProviderChatRequest, ProviderModel
|
||||
from app.providers.base import BaseProvider
|
||||
from app.providers.codex_responses.models import _coerce_model_items, _to_provider_model
|
||||
from app.providers.codex_responses.oauth import CodexOAuthProvider
|
||||
from app.providers.codex_responses.stream import _map_response_stream_to_chunks
|
||||
from app.providers.codex_responses.translator import build_responses_create_args
|
||||
from app.providers.token_url_auth import TokenUrlAuthProvider
|
||||
from app.providers.registry import provider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -32,7 +36,7 @@ class CodexResponsesProvider(BaseProvider):
|
|||
*,
|
||||
name: str,
|
||||
base_url: str,
|
||||
oauth: CodexOAuthProvider,
|
||||
oauth: CodexOAuthProvider | TokenUrlAuthProvider,
|
||||
client_version: str = "0.100.0",
|
||||
whitelist: list[str] | None = None,
|
||||
blacklist: list[str] | None = None,
|
||||
|
|
@ -54,31 +58,49 @@ class CodexResponsesProvider(BaseProvider):
|
|||
|
||||
@classmethod
|
||||
def from_config(cls, config: LoadedProviderConfig) -> BaseProvider:
|
||||
if not isinstance(config.auth, OAuthAuth):
|
||||
oauth: CodexOAuthProvider | TokenUrlAuthProvider
|
||||
if isinstance(config.auth, OAuthAuth):
|
||||
oauth = CodexOAuthProvider(auth=config.auth)
|
||||
elif isinstance(config.auth, UrlAuth):
|
||||
oauth = TokenUrlAuthProvider(token_url=config.auth.url)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Provider '{config.name}' type 'codex-responses' requires oauth auth"
|
||||
f"Provider '{config.name}' type 'codex-responses' requires oauth or url auth"
|
||||
)
|
||||
return cls(
|
||||
provider = cls(
|
||||
name=config.name,
|
||||
base_url=config.url,
|
||||
oauth=CodexOAuthProvider(auth=config.auth),
|
||||
oauth=oauth,
|
||||
whitelist=config.whitelist,
|
||||
blacklist=config.blacklist,
|
||||
)
|
||||
provider.display_name = config.display_name
|
||||
provider.models_config = config.models
|
||||
return provider
|
||||
|
||||
async def stream_chat(
|
||||
self, request: ProviderChatRequest
|
||||
) -> AsyncIterator[CoreChunk]:
|
||||
auth = await self._oauth.get()
|
||||
auth = await self._oauth.get_headers()
|
||||
try:
|
||||
create_args, ignored_extra = build_responses_create_args(request)
|
||||
create_args, extra_body = build_responses_create_args(request)
|
||||
|
||||
_log_ignored_extra(ignored_extra, provider_name=self.name)
|
||||
_log_outbound_request(create_args, provider_name=self.name)
|
||||
extra_headers = dict(auth)
|
||||
extra_headers.update(
|
||||
{
|
||||
"session_id": "ses_"
|
||||
+ "".join(
|
||||
random.choice(string.ascii_letters + string.digits)
|
||||
for _ in range(28)
|
||||
),
|
||||
"originator": "opencode",
|
||||
}
|
||||
)
|
||||
|
||||
stream = await cast("AsyncResponses", self._client.responses).create(
|
||||
**create_args,
|
||||
extra_headers=auth.headers,
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
|
||||
async for chunk in _map_response_stream_to_chunks(
|
||||
|
|
@ -98,10 +120,10 @@ class CodexResponsesProvider(BaseProvider):
|
|||
) from exc
|
||||
|
||||
async def list_models(self) -> list[ProviderModel]:
|
||||
auth = await self._oauth.get()
|
||||
auth = await self._oauth.get_headers()
|
||||
try:
|
||||
response = await self._client.models.list(
|
||||
extra_headers=auth.headers,
|
||||
extra_headers=auth,
|
||||
extra_query={"client_version": self._client_version},
|
||||
)
|
||||
items = _coerce_model_items(response)
|
||||
|
|
@ -120,473 +142,3 @@ class CodexResponsesProvider(BaseProvider):
|
|||
raise UpstreamProviderError(
|
||||
f"Provider '{self.name}' failed while listing models: {exc}"
|
||||
) from exc
|
||||
|
||||
|
||||
def _log_ignored_extra(extra: dict[str, Any], *, provider_name: str) -> None:
|
||||
if not extra:
|
||||
return
|
||||
logger.error(
|
||||
"provider '%s' ignored unsupported extra params: %s",
|
||||
provider_name,
|
||||
extra,
|
||||
)
|
||||
|
||||
|
||||
def _log_outbound_request(
|
||||
create_args: Mapping[str, Any], *, provider_name: str
|
||||
) -> None:
|
||||
enabled = os.getenv("AI_CODEX_REQUEST_LOG_ENABLED", "false").strip().lower()
|
||||
if enabled not in {"1", "true", "yes", "on"}:
|
||||
return
|
||||
|
||||
logger.error(
|
||||
"provider '%s' outbound responses.create model=%s reasoning=%s text=%s parallel_tool_calls=%s tools_count=%s input_items=%s",
|
||||
provider_name,
|
||||
create_args.get("model"),
|
||||
create_args.get("reasoning"),
|
||||
create_args.get("text"),
|
||||
create_args.get("parallel_tool_calls"),
|
||||
len(create_args.get("tools", []) or []),
|
||||
len(create_args.get("input", []) or []),
|
||||
)
|
||||
|
||||
|
||||
def _coerce_model_items(raw: Any) -> list[dict[str, Any]]:
|
||||
data = getattr(raw, "data", None)
|
||||
if isinstance(data, list):
|
||||
return [_to_dict(item) for item in data]
|
||||
models = getattr(raw, "models", None)
|
||||
if isinstance(models, list):
|
||||
return [_to_dict(item) for item in models]
|
||||
if isinstance(raw, list):
|
||||
return [_to_dict(item) for item in raw]
|
||||
if isinstance(raw, dict):
|
||||
models = raw.get("models")
|
||||
if isinstance(models, list):
|
||||
return [_to_dict(item) for item in models]
|
||||
data = raw.get("data")
|
||||
if isinstance(data, list):
|
||||
return [_to_dict(item) for item in data]
|
||||
return []
|
||||
|
||||
|
||||
def _to_dict(raw: Any) -> dict[str, Any]:
|
||||
if hasattr(raw, "model_dump"):
|
||||
dumped = raw.model_dump()
|
||||
if isinstance(dumped, dict):
|
||||
return dumped
|
||||
if isinstance(raw, dict):
|
||||
return raw
|
||||
return {}
|
||||
|
||||
|
||||
def _as_dict(raw: Any) -> dict[str, Any]:
|
||||
return _to_dict(raw)
|
||||
|
||||
|
||||
def _as_str(value: Any) -> str | None:
|
||||
return value if isinstance(value, str) else None
|
||||
|
||||
|
||||
def _to_provider_model(raw: dict[str, Any]) -> ProviderModel:
|
||||
model_id = raw.get("id")
|
||||
if not isinstance(model_id, str):
|
||||
model_id = raw.get("slug") if isinstance(raw.get("slug"), str) else None
|
||||
if model_id is None:
|
||||
raise ValueError("Codex model item missing id/slug")
|
||||
|
||||
arch = raw.get("architecture")
|
||||
if isinstance(arch, dict):
|
||||
arch = {k: v for k, v in arch.items() if k != "tokenizer"}
|
||||
else:
|
||||
input_modalities = raw.get("input_modalities")
|
||||
if isinstance(input_modalities, list):
|
||||
arch = {"input_modalities": input_modalities}
|
||||
else:
|
||||
arch = None
|
||||
|
||||
supported = raw.get("supported_parameters")
|
||||
supported_parameters = None
|
||||
if isinstance(supported, list):
|
||||
supported_parameters = [str(v) for v in supported]
|
||||
|
||||
context_length = raw.get("context_length")
|
||||
if not isinstance(context_length, int):
|
||||
context_length = raw.get("context_window")
|
||||
if not isinstance(context_length, int):
|
||||
top_provider = raw.get("top_provider")
|
||||
if isinstance(top_provider, dict) and isinstance(
|
||||
top_provider.get("context_length"), int
|
||||
):
|
||||
context_length = top_provider.get("context_length")
|
||||
else:
|
||||
context_length = None
|
||||
|
||||
return ProviderModel(
|
||||
id=model_id,
|
||||
name=(
|
||||
raw.get("name")
|
||||
if isinstance(raw.get("name"), str)
|
||||
else raw.get("display_name")
|
||||
if isinstance(raw.get("display_name"), str)
|
||||
else None
|
||||
),
|
||||
description=raw.get("description")
|
||||
if isinstance(raw.get("description"), str)
|
||||
else None,
|
||||
context_length=context_length,
|
||||
architecture=arch,
|
||||
pricing=raw.get("pricing") if isinstance(raw.get("pricing"), dict) else None,
|
||||
supported_parameters=supported_parameters,
|
||||
settings=raw.get("settings") if isinstance(raw.get("settings"), dict) else None,
|
||||
opencode=raw.get("opencode") if isinstance(raw.get("opencode"), dict) else None,
|
||||
created=raw.get("created") if isinstance(raw.get("created"), int) else None,
|
||||
owned_by=raw.get("owned_by") if isinstance(raw.get("owned_by"), str) else None,
|
||||
)
|
||||
|
||||
|
||||
async def _map_response_stream_to_chunks(
|
||||
stream: AsyncIterator[Any], *, provider_name: str
|
||||
) -> AsyncIterator[CoreChunk]:
|
||||
sent_role = False
|
||||
emitted_text = False
|
||||
emitted_reasoning = False
|
||||
saw_tool_call = False
|
||||
tool_call_delta_seen: set[str] = set()
|
||||
tool_call_finalized: set[str] = set()
|
||||
tool_call_names: dict[str, str] = {}
|
||||
tool_call_indexes: dict[str, int] = {}
|
||||
pending_tool_arguments: dict[str, str] = {}
|
||||
next_tool_call_index = 0
|
||||
emitted_content_keys: set[tuple[str, int]] = set()
|
||||
saw_delta_keys: set[tuple[str, int]] = set()
|
||||
|
||||
async for event in stream:
|
||||
event_type = _event_type(event)
|
||||
|
||||
if event_type in {"response.output_text.delta", "response.refusal.delta"}:
|
||||
key = _content_key(event)
|
||||
if key is not None:
|
||||
saw_delta_keys.add(key)
|
||||
text_delta = _first_string(event, "delta")
|
||||
if text_delta:
|
||||
if not sent_role:
|
||||
yield CoreChunk(role="assistant")
|
||||
sent_role = True
|
||||
yield CoreChunk(content=text_delta)
|
||||
emitted_text = True
|
||||
continue
|
||||
|
||||
if event_type == "response.reasoning_summary_text.delta":
|
||||
reasoning_delta = _first_string(event, "delta")
|
||||
if reasoning_delta:
|
||||
if not sent_role:
|
||||
yield CoreChunk(role="assistant")
|
||||
sent_role = True
|
||||
yield CoreChunk(reasoning_content=reasoning_delta)
|
||||
emitted_reasoning = True
|
||||
continue
|
||||
|
||||
if event_type == "response.reasoning_summary_text.done":
|
||||
reasoning_done = _first_string(event, "text")
|
||||
if reasoning_done and not emitted_reasoning:
|
||||
if not sent_role:
|
||||
yield CoreChunk(role="assistant")
|
||||
sent_role = True
|
||||
yield CoreChunk(reasoning_content=reasoning_done)
|
||||
emitted_reasoning = True
|
||||
continue
|
||||
|
||||
if event_type in {"response.output_text.done", "response.refusal.done"}:
|
||||
key = _content_key(event)
|
||||
if key is not None and key in saw_delta_keys:
|
||||
emitted_content_keys.add(key)
|
||||
continue
|
||||
text_done = _first_string(event, "text", "refusal")
|
||||
if text_done:
|
||||
if not sent_role:
|
||||
yield CoreChunk(role="assistant")
|
||||
sent_role = True
|
||||
yield CoreChunk(content=text_done)
|
||||
emitted_text = True
|
||||
if key is not None:
|
||||
emitted_content_keys.add(key)
|
||||
continue
|
||||
|
||||
if event_type == "response.content_part.done":
|
||||
key = _content_key(event)
|
||||
if key is None or key in emitted_content_keys or key in saw_delta_keys:
|
||||
continue
|
||||
part = _as_dict(getattr(event, "part", None))
|
||||
if part.get("type") == "output_text":
|
||||
text = _as_str(part.get("text"))
|
||||
elif part.get("type") == "refusal":
|
||||
text = _as_str(part.get("refusal"))
|
||||
else:
|
||||
text = None
|
||||
if text:
|
||||
if not sent_role:
|
||||
yield CoreChunk(role="assistant")
|
||||
sent_role = True
|
||||
yield CoreChunk(content=text)
|
||||
emitted_text = True
|
||||
emitted_content_keys.add(key)
|
||||
continue
|
||||
|
||||
if event_type == "response.output_item.done":
|
||||
item = _as_dict(getattr(event, "item", None))
|
||||
item_type = _as_str(item.get("type"))
|
||||
if item_type == "function_call":
|
||||
saw_tool_call = True
|
||||
tool_chunk, next_tool_call_index = _tool_call_delta_from_item(
|
||||
item,
|
||||
tool_call_finalized=tool_call_finalized,
|
||||
tool_call_names=tool_call_names,
|
||||
tool_call_indexes=tool_call_indexes,
|
||||
pending_tool_arguments=pending_tool_arguments,
|
||||
next_tool_call_index=next_tool_call_index,
|
||||
)
|
||||
if tool_chunk is not None:
|
||||
if not sent_role:
|
||||
yield CoreChunk(role="assistant")
|
||||
sent_role = True
|
||||
yield tool_chunk
|
||||
continue
|
||||
|
||||
if item_type in {
|
||||
"custom_tool_call",
|
||||
"computer_call",
|
||||
"code_interpreter_call",
|
||||
"web_search_call",
|
||||
"file_search_call",
|
||||
"shell_call",
|
||||
"apply_patch_call",
|
||||
"mcp_call",
|
||||
}:
|
||||
saw_tool_call = True
|
||||
continue
|
||||
|
||||
if item_type != "message":
|
||||
continue
|
||||
item_id = _as_str(item.get("id"))
|
||||
if item_id is None:
|
||||
continue
|
||||
content = item.get("content")
|
||||
if not isinstance(content, list):
|
||||
continue
|
||||
for idx, part in enumerate(content):
|
||||
key = (item_id, idx)
|
||||
if key in emitted_content_keys or key in saw_delta_keys:
|
||||
continue
|
||||
part_dict = _as_dict(part)
|
||||
text = None
|
||||
if part_dict.get("type") == "output_text":
|
||||
text = _as_str(part_dict.get("text"))
|
||||
elif part_dict.get("type") == "refusal":
|
||||
text = _as_str(part_dict.get("refusal"))
|
||||
if text:
|
||||
if not sent_role:
|
||||
yield CoreChunk(role="assistant")
|
||||
sent_role = True
|
||||
yield CoreChunk(content=text)
|
||||
emitted_text = True
|
||||
emitted_content_keys.add(key)
|
||||
continue
|
||||
|
||||
if event_type in {
|
||||
"response.function_call_arguments.delta",
|
||||
"response.function_call_arguments.done",
|
||||
"response.custom_tool_call_input.delta",
|
||||
"response.custom_tool_call_input.done",
|
||||
}:
|
||||
saw_tool_call = True
|
||||
tool_chunk, next_tool_call_index = _tool_call_delta_from_event(
|
||||
event,
|
||||
event_type=event_type,
|
||||
tool_call_delta_seen=tool_call_delta_seen,
|
||||
tool_call_finalized=tool_call_finalized,
|
||||
tool_call_names=tool_call_names,
|
||||
tool_call_indexes=tool_call_indexes,
|
||||
pending_tool_arguments=pending_tool_arguments,
|
||||
next_tool_call_index=next_tool_call_index,
|
||||
)
|
||||
if tool_chunk is not None:
|
||||
if not sent_role:
|
||||
yield CoreChunk(role="assistant")
|
||||
sent_role = True
|
||||
yield tool_chunk
|
||||
continue
|
||||
|
||||
if event_type == "response.completed":
|
||||
finish_reason = (
|
||||
"tool_calls" if saw_tool_call and not emitted_text else "stop"
|
||||
)
|
||||
yield CoreChunk(finish_reason=finish_reason)
|
||||
continue
|
||||
|
||||
if event_type == "response.incomplete":
|
||||
reason = _extract_incomplete_finish_reason(event)
|
||||
yield CoreChunk(finish_reason=reason)
|
||||
continue
|
||||
|
||||
if event_type == "response.failed":
|
||||
raise UpstreamProviderError(
|
||||
f"Provider '{provider_name}' upstream response failed"
|
||||
)
|
||||
|
||||
|
||||
def _event_type(event: Any) -> str:
|
||||
return _as_str(getattr(event, "type", None)) or ""
|
||||
|
||||
|
||||
def _content_key(event: Any) -> tuple[str, int] | None:
|
||||
item_id = _as_str(getattr(event, "item_id", None))
|
||||
content_index = getattr(event, "content_index", None)
|
||||
if item_id is None or not isinstance(content_index, int):
|
||||
return None
|
||||
return (item_id, content_index)
|
||||
|
||||
|
||||
def _first_string(event: Any, *field_names: str) -> str | None:
|
||||
for name in field_names:
|
||||
value = getattr(event, name, None)
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def _extract_incomplete_finish_reason(event: Any) -> str:
|
||||
response = getattr(event, "response", None)
|
||||
response_dict = _as_dict(response)
|
||||
incomplete_details = response_dict.get("incomplete_details")
|
||||
details_dict = _as_dict(incomplete_details)
|
||||
reason = _as_str(details_dict.get("reason"))
|
||||
if reason == "max_output_tokens":
|
||||
return "length"
|
||||
if reason == "content_filter":
|
||||
return "content_filter"
|
||||
return "stop"
|
||||
|
||||
|
||||
def _tool_call_delta_from_event(
|
||||
event: Any,
|
||||
*,
|
||||
event_type: str,
|
||||
tool_call_delta_seen: set[str],
|
||||
tool_call_finalized: set[str],
|
||||
tool_call_names: dict[str, str],
|
||||
tool_call_indexes: dict[str, int],
|
||||
pending_tool_arguments: dict[str, str],
|
||||
next_tool_call_index: int,
|
||||
) -> tuple[CoreChunk | None, int]:
|
||||
item_id = _as_str(getattr(event, "item_id", None))
|
||||
if item_id is None:
|
||||
return None, next_tool_call_index
|
||||
|
||||
index = tool_call_indexes.get(item_id)
|
||||
if index is None:
|
||||
index = next_tool_call_index
|
||||
tool_call_indexes[item_id] = index
|
||||
next_tool_call_index += 1
|
||||
|
||||
if event_type == "response.function_call_arguments.delta":
|
||||
if item_id in tool_call_finalized:
|
||||
return None, next_tool_call_index
|
||||
arguments_delta = _as_str(getattr(event, "delta", None))
|
||||
if not arguments_delta:
|
||||
return None, next_tool_call_index
|
||||
tool_call_delta_seen.add(item_id)
|
||||
pending_tool_arguments[item_id] = (
|
||||
pending_tool_arguments.get(item_id, "") + arguments_delta
|
||||
)
|
||||
return None, next_tool_call_index
|
||||
|
||||
if event_type == "response.function_call_arguments.done":
|
||||
name = _as_str(getattr(event, "name", None))
|
||||
arguments = _as_str(getattr(event, "arguments", None))
|
||||
|
||||
if item_id in tool_call_finalized:
|
||||
return None, next_tool_call_index
|
||||
|
||||
if name:
|
||||
tool_call_names[item_id] = name
|
||||
function_name = tool_call_names.get(item_id)
|
||||
if not function_name:
|
||||
return None, next_tool_call_index
|
||||
|
||||
function_arguments = (
|
||||
arguments
|
||||
if arguments is not None
|
||||
else pending_tool_arguments.get(item_id, "")
|
||||
)
|
||||
pending_tool_arguments.pop(item_id, None)
|
||||
|
||||
tool_call_finalized.add(item_id)
|
||||
|
||||
return (
|
||||
CoreChunk(
|
||||
tool_calls=[
|
||||
{
|
||||
"index": index,
|
||||
"id": item_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": function_name,
|
||||
"arguments": function_arguments,
|
||||
},
|
||||
}
|
||||
]
|
||||
),
|
||||
next_tool_call_index,
|
||||
)
|
||||
|
||||
return None, next_tool_call_index
|
||||
|
||||
|
||||
def _tool_call_delta_from_item(
|
||||
item: dict[str, Any],
|
||||
*,
|
||||
tool_call_finalized: set[str],
|
||||
tool_call_names: dict[str, str],
|
||||
tool_call_indexes: dict[str, int],
|
||||
pending_tool_arguments: dict[str, str],
|
||||
next_tool_call_index: int,
|
||||
) -> tuple[CoreChunk | None, int]:
|
||||
if _as_str(item.get("type")) != "function_call":
|
||||
return None, next_tool_call_index
|
||||
|
||||
item_id = _as_str(item.get("id"))
|
||||
name = _as_str(item.get("name"))
|
||||
if item_id is None or name is None:
|
||||
return None, next_tool_call_index
|
||||
if item_id in tool_call_finalized:
|
||||
return None, next_tool_call_index
|
||||
|
||||
index = tool_call_indexes.get(item_id)
|
||||
if index is None:
|
||||
index = next_tool_call_index
|
||||
tool_call_indexes[item_id] = index
|
||||
next_tool_call_index += 1
|
||||
|
||||
tool_call_names[item_id] = name
|
||||
|
||||
arguments = _as_str(item.get("arguments"))
|
||||
function_arguments = (
|
||||
arguments if arguments is not None else pending_tool_arguments.get(item_id, "")
|
||||
)
|
||||
pending_tool_arguments.pop(item_id, None)
|
||||
tool_call_finalized.add(item_id)
|
||||
|
||||
return (
|
||||
CoreChunk(
|
||||
tool_calls=[
|
||||
{
|
||||
"index": index,
|
||||
"id": item_id,
|
||||
"type": "function",
|
||||
"function": {"name": name, "arguments": function_arguments},
|
||||
}
|
||||
]
|
||||
),
|
||||
next_tool_call_index,
|
||||
)
|
||||
|
|
|
|||
518
app/providers/codex_responses/stream.py
Normal file
518
app/providers/codex_responses/stream.py
Normal file
|
|
@ -0,0 +1,518 @@
|
|||
"""Streaming event mapping for codex responses provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator, Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from app.core.errors import UpstreamProviderError
|
||||
from app.core.types import CoreChunk
|
||||
from app.providers.codex_responses.utils import _as_dict, _as_str
|
||||
|
||||
_TEXT_DELTA_EVENTS = {"response.output_text.delta", "response.refusal.delta"}
|
||||
_TEXT_DONE_EVENTS = {"response.output_text.done", "response.refusal.done"}
|
||||
_TOOL_CALL_DELTA_EVENTS = {
|
||||
"response.function_call_arguments.delta",
|
||||
"response.function_call_arguments.done",
|
||||
"response.custom_tool_call_input.delta",
|
||||
"response.custom_tool_call_input.done",
|
||||
}
|
||||
_NON_FUNCTION_TOOL_CALL_ITEMS = {
|
||||
"custom_tool_call",
|
||||
"computer_call",
|
||||
"code_interpreter_call",
|
||||
"web_search_call",
|
||||
"file_search_call",
|
||||
"shell_call",
|
||||
"apply_patch_call",
|
||||
"mcp_call",
|
||||
}
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _ResponseStreamState:
|
||||
sent_role: bool = False
|
||||
emitted_text: bool = False
|
||||
emitted_reasoning: bool = False
|
||||
saw_tool_call: bool = False
|
||||
next_tool_call_index: int = 0
|
||||
tool_call_finalized: set[str] = field(default_factory=set)
|
||||
tool_call_names: dict[str, str] = field(default_factory=dict)
|
||||
tool_call_indexes: dict[str, int] = field(default_factory=dict)
|
||||
pending_tool_arguments: dict[str, str] = field(default_factory=dict)
|
||||
emitted_tool_arguments: dict[str, str] = field(default_factory=dict)
|
||||
emitted_content_keys: set[tuple[str, int]] = field(default_factory=set)
|
||||
saw_delta_keys: set[tuple[str, int]] = field(default_factory=set)
|
||||
emitted_reasoning_item_ids: set[str] = field(default_factory=set)
|
||||
|
||||
|
||||
async def _map_response_stream_to_chunks(
|
||||
stream: AsyncIterator[Any], *, provider_name: str
|
||||
) -> AsyncIterator[CoreChunk]:
|
||||
state = _ResponseStreamState()
|
||||
|
||||
async for event in stream:
|
||||
event_type = _event_type(event)
|
||||
|
||||
if event_type in _TEXT_DELTA_EVENTS:
|
||||
for chunk in _handle_text_delta(event, state):
|
||||
yield chunk
|
||||
continue
|
||||
|
||||
if event_type == "response.reasoning_summary_text.delta":
|
||||
for chunk in _emit_reasoning_chunk(_first_string(event, "delta"), state):
|
||||
yield chunk
|
||||
continue
|
||||
|
||||
if event_type == "response.reasoning_summary_text.done":
|
||||
details_chunk = _reasoning_summary_detail_chunk(event)
|
||||
if details_chunk is not None:
|
||||
for chunk in _wrap_assistant_chunk(details_chunk, state):
|
||||
yield chunk
|
||||
if not state.emitted_reasoning:
|
||||
for chunk in _emit_reasoning_chunk(_first_string(event, "text"), state):
|
||||
yield chunk
|
||||
continue
|
||||
|
||||
if event_type == "response.output_item.added":
|
||||
for chunk in _handle_output_item_added(event, state):
|
||||
yield chunk
|
||||
continue
|
||||
|
||||
if event_type in _TEXT_DONE_EVENTS:
|
||||
for chunk in _handle_text_done(event, state):
|
||||
yield chunk
|
||||
continue
|
||||
|
||||
if event_type == "response.content_part.done":
|
||||
for chunk in _handle_content_part_done(event, state):
|
||||
yield chunk
|
||||
continue
|
||||
|
||||
if event_type == "response.output_item.done":
|
||||
for chunk in _handle_output_item_done(event, state):
|
||||
yield chunk
|
||||
continue
|
||||
|
||||
if event_type in _TOOL_CALL_DELTA_EVENTS:
|
||||
state.saw_tool_call = True
|
||||
tool_chunk = _tool_call_delta_from_event(
|
||||
event,
|
||||
event_type=event_type,
|
||||
state=state,
|
||||
)
|
||||
if tool_chunk is not None:
|
||||
for chunk in _wrap_assistant_chunk(tool_chunk, state):
|
||||
yield chunk
|
||||
continue
|
||||
|
||||
if event_type == "response.completed":
|
||||
finish_reason = (
|
||||
"tool_calls"
|
||||
if state.saw_tool_call and not state.emitted_text
|
||||
else "stop"
|
||||
)
|
||||
yield CoreChunk(finish_reason=finish_reason)
|
||||
continue
|
||||
|
||||
if event_type == "response.incomplete":
|
||||
reason = _extract_incomplete_finish_reason(event)
|
||||
yield CoreChunk(finish_reason=reason)
|
||||
continue
|
||||
|
||||
if event_type == "response.failed":
|
||||
raise UpstreamProviderError(
|
||||
f"Provider '{provider_name}' upstream response failed"
|
||||
)
|
||||
|
||||
|
||||
def _event_type(event: Any) -> str:
|
||||
return _as_str(getattr(event, "type", None)) or ""
|
||||
|
||||
|
||||
def _ensure_assistant_role(state: _ResponseStreamState) -> CoreChunk | None:
|
||||
if state.sent_role:
|
||||
return None
|
||||
state.sent_role = True
|
||||
return CoreChunk(role="assistant")
|
||||
|
||||
|
||||
def _wrap_assistant_chunk(
|
||||
chunk: CoreChunk | None, state: _ResponseStreamState
|
||||
) -> list[CoreChunk]:
|
||||
if chunk is None:
|
||||
return []
|
||||
role_chunk = _ensure_assistant_role(state)
|
||||
return [role_chunk, chunk] if role_chunk is not None else [chunk]
|
||||
|
||||
|
||||
def _emit_text_chunk(text: str | None, state: _ResponseStreamState) -> list[CoreChunk]:
|
||||
if not text:
|
||||
return []
|
||||
state.emitted_text = True
|
||||
return _wrap_assistant_chunk(CoreChunk(content=text), state)
|
||||
|
||||
|
||||
def _emit_reasoning_chunk(
|
||||
text: str | None, state: _ResponseStreamState
|
||||
) -> list[CoreChunk]:
|
||||
if not text:
|
||||
return []
|
||||
state.emitted_reasoning = True
|
||||
return _wrap_assistant_chunk(CoreChunk(reasoning_content=text), state)
|
||||
|
||||
|
||||
def _handle_text_delta(event: Any, state: _ResponseStreamState) -> list[CoreChunk]:
    """Emit a text delta and remember which content part it belongs to."""
    delta_key = _content_key(event)
    if delta_key is not None:
        # Mark the part as streamed so its *.done event is not re-emitted.
        state.saw_delta_keys.add(delta_key)
    return _emit_text_chunk(_first_string(event, "delta"), state)
|
||||
|
||||
|
||||
def _handle_text_done(event: Any, state: _ResponseStreamState) -> list[CoreChunk]:
    """Emit the final text of a ``*.done`` event unless deltas already covered it."""
    done_key = _content_key(event)
    already_streamed = done_key is not None and done_key in state.saw_delta_keys
    if already_streamed:
        # Deltas delivered this text incrementally; just record completion.
        state.emitted_content_keys.add(done_key)
        return []

    emitted = _emit_text_chunk(_first_string(event, "text", "refusal"), state)
    if done_key is not None:
        state.emitted_content_keys.add(done_key)
    return emitted
|
||||
|
||||
|
||||
def _content_part_text(part: Mapping[str, Any]) -> str | None:
    """Return the text of an ``output_text`` or ``refusal`` content part."""
    part_type = part.get("type")
    text_field = (
        "text"
        if part_type == "output_text"
        else "refusal"
        if part_type == "refusal"
        else None
    )
    if text_field is None:
        return None
    return _as_str(part.get(text_field))
|
||||
|
||||
|
||||
def _handle_content_part_done(
    event: Any, state: _ResponseStreamState
) -> list[CoreChunk]:
    """Flush a content part's text when neither deltas nor done events covered it."""
    part_key = _content_key(event)
    already_handled = (
        part_key is None
        or part_key in state.emitted_content_keys
        or part_key in state.saw_delta_keys
    )
    if already_handled:
        return []

    part = _as_dict(getattr(event, "part", None))
    emitted = _emit_text_chunk(_content_part_text(part), state)
    if emitted:
        state.emitted_content_keys.add(part_key)
    return emitted
|
||||
|
||||
|
||||
def _handle_output_item_done(
    event: Any, state: _ResponseStreamState
) -> list[CoreChunk]:
    """Translate a ``response.output_item.done`` event into core chunks.

    Dispatches on the finished item's type: reasoning items yield an
    encrypted-reasoning detail chunk, function calls yield a tool-call
    chunk, other tool-call item types are only recorded, and message
    items flush any text parts that were never streamed as deltas.
    """
    item = _as_dict(getattr(event, "item", None))
    item_type = _as_str(item.get("type"))

    if item_type == "reasoning":
        # Emit the encrypted reasoning payload (deduplicated per item id
        # inside the helper) so clients can round-trip reasoning state.
        reasoning_chunk = _reasoning_encrypted_detail_chunk(
            item=item,
            item_id=_as_str(item.get("id")),
            output_index=getattr(event, "output_index", None),
            state=state,
        )
        return _wrap_assistant_chunk(reasoning_chunk, state)

    if item_type == "function_call":
        state.saw_tool_call = True
        return _wrap_assistant_chunk(
            _tool_call_delta_from_item(item, state=state), state
        )

    if item_type in _NON_FUNCTION_TOOL_CALL_ITEMS:
        # Built-in (non-function) tool calls influence the finish reason
        # but carry no payload that is forwarded.
        state.saw_tool_call = True
        return []

    if item_type != "message":
        return []

    item_id = _as_str(item.get("id"))
    content = item.get("content")
    if item_id is None or not isinstance(content, list):
        return []

    chunks: list[CoreChunk] = []
    for idx, part in enumerate(content):
        key = (item_id, idx)
        # Skip parts already delivered via delta or done events.
        if key in state.emitted_content_keys or key in state.saw_delta_keys:
            continue

        text = _content_part_text(_as_dict(part))
        emitted = _emit_text_chunk(text, state)
        if emitted:
            chunks.extend(emitted)
            state.emitted_content_keys.add(key)
    return chunks
|
||||
|
||||
|
||||
def _handle_output_item_added(
    event: Any, state: _ResponseStreamState
) -> list[CoreChunk]:
    """Forward the encrypted payload of a newly added reasoning item.

    Non-reasoning items produce no chunks at this stage.
    """
    item = _as_dict(getattr(event, "item", None))
    if _as_str(item.get("type")) != "reasoning":
        return []

    detail_chunk = _reasoning_encrypted_detail_chunk(
        item=item,
        item_id=_as_str(item.get("id")),
        output_index=getattr(event, "output_index", None),
        state=state,
    )
    return _wrap_assistant_chunk(detail_chunk, state)
|
||||
|
||||
|
||||
def _reasoning_encrypted_detail_chunk(
    *,
    item: dict[str, Any],
    item_id: str | None,
    output_index: Any,
    state: _ResponseStreamState,
) -> CoreChunk | None:
    """Wrap a reasoning item's encrypted payload in a detail chunk.

    Deduplicates per item id so the same payload is not emitted for both
    the ``added`` and ``done`` events of one item.
    """
    encrypted = _as_str(item.get("encrypted_content"))
    if encrypted is None:
        return None

    if item_id is not None:
        if item_id in state.emitted_reasoning_item_ids:
            return None
        state.emitted_reasoning_item_ids.add(item_id)

    payload: dict[str, Any] = {
        "type": "reasoning.encrypted",
        "data": encrypted,
        "format": "openai-responses-v1",
    }
    if item_id is not None:
        payload["id"] = item_id
    if isinstance(output_index, int):
        payload["index"] = output_index
    return CoreChunk(reasoning_details=[payload])
|
||||
|
||||
|
||||
def _reasoning_summary_detail_chunk(event: Any) -> CoreChunk | None:
    """Build a reasoning-summary detail chunk from a summary text event."""
    summary_text = _first_string(event, "text")
    if not summary_text:
        return None

    payload: dict[str, Any] = {
        "type": "reasoning.summary",
        "summary": summary_text,
        "format": "openai-responses-v1",
    }
    event_item_id = _as_str(getattr(event, "item_id", None))
    if event_item_id is not None:
        payload["id"] = event_item_id
    summary_idx = getattr(event, "summary_index", None)
    if isinstance(summary_idx, int):
        payload["index"] = summary_idx
    return CoreChunk(reasoning_details=[payload])
|
||||
|
||||
|
||||
def _content_key(event: Any) -> tuple[str, int] | None:
    """Key identifying one content part: ``(item_id, content_index)``.

    Returns ``None`` unless the event carries a string id and an int index.
    """
    item_id = getattr(event, "item_id", None)
    content_index = getattr(event, "content_index", None)
    if isinstance(item_id, str) and isinstance(content_index, int):
        return (item_id, content_index)
    return None
|
||||
|
||||
|
||||
def _first_string(event: Any, *field_names: str) -> str | None:
|
||||
for name in field_names:
|
||||
value = getattr(event, name, None)
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def _extract_incomplete_finish_reason(event: Any) -> str:
    """Map an incomplete response's reason onto an OpenAI-style finish_reason.

    ``max_output_tokens`` → ``length``, ``content_filter`` → ``content_filter``,
    anything else → ``stop``.
    """
    response_payload = _as_dict(getattr(event, "response", None))
    details = _as_dict(response_payload.get("incomplete_details"))
    reason = _as_str(details.get("reason"))
    finish_by_reason = {
        "max_output_tokens": "length",
        "content_filter": "content_filter",
    }
    return finish_by_reason.get(reason, "stop")
|
||||
|
||||
|
||||
def _tool_call_delta_from_event(
    event: Any,
    *,
    event_type: str,
    state: _ResponseStreamState,
) -> CoreChunk | None:
    """Convert a tool-call argument event into a tool-call delta chunk.

    Handles the ``*.delta`` events (incremental argument fragments) and the
    ``*.done`` events (final argument payload) for both function calls and
    custom tool calls. Fragments that arrive before the function name is
    known are buffered in ``state.pending_tool_arguments`` and flushed by
    the matching ``done`` event. Returns ``None`` when there is nothing
    new to emit.
    """
    item_id = _as_str(getattr(event, "item_id", None))
    if item_id is None:
        return None

    index = _get_tool_call_index(item_id, state)

    if event_type in {
        "response.function_call_arguments.delta",
        "response.custom_tool_call_input.delta",
    }:
        # Late fragments after finalization are dropped.
        if item_id in state.tool_call_finalized:
            return None
        arguments_delta = _as_str(getattr(event, "delta", None))
        if not arguments_delta:
            return None
        function_name = state.tool_call_names.get(item_id)
        if function_name is None:
            # Name unknown yet: buffer the fragment instead of emitting it.
            state.pending_tool_arguments[item_id] = (
                state.pending_tool_arguments.get(item_id, "") + arguments_delta
            )
            return None
        # Record what was streamed so the done event emits only the tail.
        state.emitted_tool_arguments[item_id] = (
            state.emitted_tool_arguments.get(item_id, "") + arguments_delta
        )
        return _build_tool_call_chunk(
            item_id=item_id,
            index=index,
            name=function_name,
            arguments=arguments_delta,
        )

    if event_type in {
        "response.function_call_arguments.done",
        "response.custom_tool_call_input.done",
    }:
        name = _as_str(getattr(event, "name", None))
        # Function calls carry "arguments"; custom tool calls carry "input".
        arguments = _as_str(getattr(event, "arguments", None)) or _as_str(
            getattr(event, "input", None)
        )

        if item_id in state.tool_call_finalized:
            return None

        if name:
            state.tool_call_names[item_id] = name
        function_name = state.tool_call_names.get(item_id)
        if function_name is None:
            # Without a name a valid tool call cannot be emitted.
            return None

        buffered = state.pending_tool_arguments.pop(item_id, "")
        emitted = state.emitted_tool_arguments.pop(item_id, "")

        done_arguments = _resolve_done_tool_arguments(
            arguments=arguments,
            buffered=buffered,
            emitted=emitted,
        )

        state.tool_call_finalized.add(item_id)
        if done_arguments is None:
            # Everything was already streamed incrementally.
            return None

        return _build_tool_call_chunk(
            item_id=item_id,
            index=index,
            name=function_name,
            arguments=done_arguments,
        )

    return None
|
||||
|
||||
|
||||
def _tool_call_delta_from_item(
    item: dict[str, Any],
    *,
    state: _ResponseStreamState,
) -> CoreChunk | None:
    """Build a tool-call chunk from a finished ``function_call`` output item.

    Fallback path for streams that deliver the complete function call as an
    output item rather than via argument events. Returns ``None`` for
    non-function items, items missing id/name, already-finalized calls, or
    when the arguments were fully streamed already.
    """
    if _as_str(item.get("type")) != "function_call":
        return None

    item_id = _as_str(item.get("id"))
    name = _as_str(item.get("name"))
    if item_id is None or name is None:
        return None
    if item_id in state.tool_call_finalized:
        return None

    index = _get_tool_call_index(item_id, state)

    state.tool_call_names[item_id] = name

    arguments = _as_str(item.get("arguments"))
    buffered = state.pending_tool_arguments.pop(item_id, "")
    emitted = state.emitted_tool_arguments.pop(item_id, "")
    # Emit only the portion of the arguments not already streamed as deltas.
    function_arguments = _resolve_done_tool_arguments(
        arguments=arguments,
        buffered=buffered,
        emitted=emitted,
    )
    if function_arguments is None:
        return None

    state.tool_call_finalized.add(item_id)

    return _build_tool_call_chunk(
        item_id=item_id,
        index=index,
        name=name,
        arguments=function_arguments,
    )
|
||||
|
||||
|
||||
def _get_tool_call_index(item_id: str, state: _ResponseStreamState) -> int:
|
||||
index = state.tool_call_indexes.get(item_id)
|
||||
if index is not None:
|
||||
return index
|
||||
|
||||
index = state.next_tool_call_index
|
||||
state.tool_call_indexes[item_id] = index
|
||||
state.next_tool_call_index += 1
|
||||
return index
|
||||
|
||||
|
||||
def _build_tool_call_chunk(
    *, item_id: str, index: int, name: str, arguments: str
) -> CoreChunk:
    """Assemble a single-entry ``tool_calls`` delta chunk."""
    call = {
        "index": index,
        "id": item_id,
        "type": "function",
        "function": {"name": name, "arguments": arguments},
    }
    return CoreChunk(tool_calls=[call])
|
||||
|
||||
|
||||
def _tool_arguments_tail(arguments: str | None, emitted: str) -> str:
|
||||
if arguments is None:
|
||||
return ""
|
||||
if not emitted:
|
||||
return arguments
|
||||
if arguments.startswith(emitted):
|
||||
return arguments[len(emitted) :]
|
||||
return ""
|
||||
|
||||
|
||||
def _resolve_done_tool_arguments(
    *, arguments: str | None, buffered: str, emitted: str
) -> str | None:
    """Decide what argument text a ``*.done`` event should still emit.

    Buffered fragments (collected before the name was known) are flushed in
    full, preferring the authoritative ``arguments`` payload when present.
    Otherwise only the suffix not already streamed is returned; ``None``
    means nothing remains to emit.
    """
    if buffered:
        return buffered if arguments is None else arguments

    if arguments is None:
        return None if emitted else ""

    if not emitted:
        return arguments

    remaining = _tool_arguments_tail(arguments, emitted)
    return remaining or None
|
||||
|
|
@ -37,13 +37,15 @@ from app.core.types import CoreMessage, ProviderChatRequest
|
|||
def build_responses_create_args(
|
||||
request: ProviderChatRequest,
|
||||
) -> tuple[ResponseCreateParamsStreaming, dict[str, Any]]:
|
||||
instructions = _build_instructions(request.messages)
|
||||
messages = list(request.messages)
|
||||
instructions = _pop_instruction_message(messages)
|
||||
|
||||
args: ResponseCreateParamsStreaming = {
|
||||
"model": request.model,
|
||||
"input": _build_input_items(request.messages),
|
||||
"input": _build_input_items(messages),
|
||||
"stream": True,
|
||||
"store": False,
|
||||
"include": ["reasoning.encrypted_content"],
|
||||
"parallel_tool_calls": request.parallel_tool_calls
|
||||
if request.parallel_tool_calls is not None
|
||||
else True,
|
||||
|
|
@ -54,16 +56,12 @@ def build_responses_create_args(
|
|||
|
||||
if request.metadata is not None:
|
||||
args["metadata"] = request.metadata
|
||||
if request.prompt_cache_key is not None:
|
||||
args["prompt_cache_key"] = request.prompt_cache_key
|
||||
if request.prompt_cache_retention is not None:
|
||||
args["prompt_cache_retention"] = cast(Any, request.prompt_cache_retention)
|
||||
if request.safety_identifier is not None:
|
||||
args["safety_identifier"] = request.safety_identifier
|
||||
if request.service_tier is not None:
|
||||
args["service_tier"] = cast(Any, request.service_tier)
|
||||
if request.temperature is not None:
|
||||
args["temperature"] = request.temperature
|
||||
if request.top_p is not None:
|
||||
args["top_p"] = request.top_p
|
||||
if request.tools is not None:
|
||||
|
|
@ -79,20 +77,53 @@ def build_responses_create_args(
|
|||
if text_config is not None:
|
||||
args["text"] = text_config
|
||||
|
||||
return args, dict(request.extra)
|
||||
return args, _build_extra_body(request)
|
||||
|
||||
|
||||
def _build_instructions(messages: list[CoreMessage]) -> str | None:
|
||||
parts: list[str] = []
|
||||
for message in messages:
|
||||
def _build_extra_body(request: ProviderChatRequest) -> dict[str, Any]:
|
||||
extra_body = dict(request.extra)
|
||||
|
||||
if request.provider is not None:
|
||||
extra_body["provider"] = request.provider
|
||||
if request.plugins is not None:
|
||||
extra_body["plugins"] = request.plugins
|
||||
if request.session_id is not None:
|
||||
extra_body["session_id"] = request.session_id
|
||||
if request.trace is not None:
|
||||
extra_body["trace"] = request.trace
|
||||
if request.models is not None:
|
||||
extra_body["models"] = request.models
|
||||
if request.debug is not None:
|
||||
extra_body["debug"] = request.debug
|
||||
if request.image_config is not None:
|
||||
extra_body["image_config"] = request.image_config
|
||||
|
||||
for key in (
|
||||
"metadata",
|
||||
"prompt_cache_retention",
|
||||
"safety_identifier",
|
||||
"service_tier",
|
||||
"top_p",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"reasoning",
|
||||
"response_format",
|
||||
"verbosity",
|
||||
):
|
||||
extra_body.pop(key, None)
|
||||
|
||||
return extra_body
|
||||
|
||||
|
||||
def _pop_instruction_message(messages: list[CoreMessage]) -> str:
|
||||
for index, message in enumerate(messages):
|
||||
if message.role not in {"developer", "system"}:
|
||||
continue
|
||||
text = _extract_text(message.content)
|
||||
if text:
|
||||
parts.append(text)
|
||||
if not parts:
|
||||
return "You are a helpful assistant."
|
||||
return "\n\n".join(parts)
|
||||
messages.pop(index)
|
||||
return text
|
||||
return "You are a helpful assistant."
|
||||
|
||||
|
||||
def _build_input_items(messages: list[CoreMessage]) -> ResponseInputParam:
|
||||
|
|
@ -104,6 +135,9 @@ def _build_input_items(messages: list[CoreMessage]) -> ResponseInputParam:
|
|||
continue
|
||||
|
||||
if message.role in {"developer", "system"}:
|
||||
system_message = _build_message_item(message, role_override="system")
|
||||
if system_message is not None:
|
||||
items.append(system_message)
|
||||
continue
|
||||
|
||||
if message.role in {"user", "assistant"}:
|
||||
|
|
@ -123,44 +157,73 @@ def _build_input_items(messages: list[CoreMessage]) -> ResponseInputParam:
|
|||
return items
|
||||
|
||||
|
||||
def _build_message_item(message: CoreMessage) -> EasyInputMessageParam | None:
|
||||
content = _build_message_content(message.content)
|
||||
def _build_message_item(
|
||||
message: CoreMessage, role_override: str | None = None
|
||||
) -> EasyInputMessageParam | None:
|
||||
role_value = role_override or message.role
|
||||
content = _build_message_content(message.content, role_value)
|
||||
if content is None:
|
||||
return None
|
||||
|
||||
role = cast("EasyInputMessageParam", {"role": message.role})["role"]
|
||||
role = cast("EasyInputMessageParam", {"role": role_value})["role"]
|
||||
item: EasyInputMessageParam = {
|
||||
"type": "message",
|
||||
"role": role,
|
||||
"content": content,
|
||||
"content": cast(Any, content),
|
||||
}
|
||||
return item
|
||||
|
||||
|
||||
def _build_message_content(
|
||||
content: Any,
|
||||
) -> str | list[ResponseInputContentParam] | None:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
role: str,
|
||||
) -> list[dict[str, Any]] | None:
|
||||
if content is None or content == "":
|
||||
return None
|
||||
|
||||
if isinstance(content, str):
|
||||
text_part_type = "output_text" if role == "assistant" else "input_text"
|
||||
return [{"type": text_part_type, "text": content}]
|
||||
|
||||
if isinstance(content, dict):
|
||||
content = [content]
|
||||
|
||||
if not isinstance(content, list):
|
||||
raise ValueError("Unsupported message content for responses input")
|
||||
|
||||
out: list[ResponseInputContentParam] = []
|
||||
out: list[dict[str, Any]] = []
|
||||
for part in content:
|
||||
if not isinstance(part, dict):
|
||||
continue
|
||||
|
||||
part_type = part.get("type")
|
||||
cache_control = _extract_cache_control(part)
|
||||
|
||||
if part_type == "text" and isinstance(part.get("text"), str):
|
||||
out.append({"type": "input_text", "text": part["text"]})
|
||||
text_part_type = "output_text" if role == "assistant" else "input_text"
|
||||
text_item: dict[str, Any] = {"type": text_part_type, "text": part["text"]}
|
||||
if cache_control is not None:
|
||||
text_item["cache_control"] = cache_control
|
||||
out.append(text_item)
|
||||
continue
|
||||
|
||||
if part_type in {"input_text", "output_text"} and isinstance(
|
||||
part.get("text"), str
|
||||
):
|
||||
text_item = {"type": part_type, "text": part["text"]}
|
||||
if cache_control is not None:
|
||||
text_item["cache_control"] = cache_control
|
||||
out.append(text_item)
|
||||
continue
|
||||
|
||||
if part_type == "refusal" and isinstance(part.get("refusal"), str):
|
||||
out.append({"type": "input_text", "text": part["refusal"]})
|
||||
refusal_item: dict[str, Any] = {
|
||||
"type": "input_text",
|
||||
"text": part["refusal"],
|
||||
}
|
||||
if cache_control is not None:
|
||||
refusal_item["cache_control"] = cache_control
|
||||
out.append(refusal_item)
|
||||
continue
|
||||
|
||||
if part_type == "image_url" and isinstance(part.get("image_url"), dict):
|
||||
|
|
@ -170,7 +233,49 @@ def _build_message_content(
|
|||
image_item["image_url"] = image["url"]
|
||||
if image.get("detail") in {"low", "high", "auto"}:
|
||||
image_item["detail"] = image["detail"]
|
||||
out.append(cast("ResponseInputContentParam", image_item))
|
||||
if cache_control is not None:
|
||||
image_item["cache_control"] = cache_control
|
||||
out.append(image_item)
|
||||
continue
|
||||
|
||||
if part_type == "input_image":
|
||||
image_item = {
|
||||
"type": "input_image",
|
||||
"image_url": part.get("image_url"),
|
||||
"file_id": part.get("file_id"),
|
||||
"detail": part.get("detail")
|
||||
if part.get("detail") in {"low", "high", "auto"}
|
||||
else "auto",
|
||||
}
|
||||
if cache_control is not None:
|
||||
image_item["cache_control"] = cache_control
|
||||
out.append({k: v for k, v in image_item.items() if v is not None})
|
||||
continue
|
||||
|
||||
if part_type == "input_audio" and isinstance(part.get("input_audio"), dict):
|
||||
audio = part["input_audio"]
|
||||
audio_format = audio.get("format")
|
||||
if isinstance(audio.get("data"), str) and audio_format in {
|
||||
"wav",
|
||||
"mp3",
|
||||
"flac",
|
||||
"m4a",
|
||||
"ogg",
|
||||
"aiff",
|
||||
"aac",
|
||||
"pcm16",
|
||||
"pcm24",
|
||||
}:
|
||||
audio_item: dict[str, Any] = {
|
||||
"type": "input_audio",
|
||||
"input_audio": {
|
||||
"data": audio["data"],
|
||||
"format": audio_format,
|
||||
},
|
||||
}
|
||||
if cache_control is not None:
|
||||
audio_item["cache_control"] = cache_control
|
||||
out.append(audio_item)
|
||||
continue
|
||||
|
||||
if part_type == "file" and isinstance(part.get("file"), dict):
|
||||
|
|
@ -182,7 +287,38 @@ def _build_message_content(
|
|||
file_item["file_id"] = wrapped["file_id"]
|
||||
if isinstance(wrapped.get("filename"), str):
|
||||
file_item["filename"] = wrapped["filename"]
|
||||
out.append(cast("ResponseInputContentParam", file_item))
|
||||
if isinstance(wrapped.get("file_url"), str):
|
||||
file_item["file_url"] = wrapped["file_url"]
|
||||
if cache_control is not None:
|
||||
file_item["cache_control"] = cache_control
|
||||
out.append(file_item)
|
||||
continue
|
||||
|
||||
if part_type == "input_file":
|
||||
file_item = {
|
||||
"type": "input_file",
|
||||
"file_data": part.get("file_data"),
|
||||
"file_id": part.get("file_id"),
|
||||
"filename": part.get("filename"),
|
||||
"file_url": part.get("file_url"),
|
||||
}
|
||||
if cache_control is not None:
|
||||
file_item["cache_control"] = cache_control
|
||||
out.append({k: v for k, v in file_item.items() if v is not None})
|
||||
continue
|
||||
|
||||
if part_type in {"video_url", "input_video"} and isinstance(
|
||||
part.get("video_url"), dict
|
||||
):
|
||||
video = part["video_url"]
|
||||
if isinstance(video.get("url"), str):
|
||||
video_item: dict[str, Any] = {
|
||||
"type": "input_file",
|
||||
"file_url": video["url"],
|
||||
}
|
||||
if cache_control is not None:
|
||||
video_item["cache_control"] = cache_control
|
||||
out.append(video_item)
|
||||
continue
|
||||
|
||||
if out:
|
||||
|
|
@ -314,11 +450,9 @@ def _normalize_tool_output(
|
|||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
converted = _build_message_content(content)
|
||||
converted = _build_message_content(content, "tool")
|
||||
if converted is None:
|
||||
return ""
|
||||
if isinstance(converted, str):
|
||||
return converted
|
||||
return cast("ResponseFunctionCallOutputItemListParam", converted)
|
||||
if content is None:
|
||||
return ""
|
||||
|
|
@ -328,6 +462,17 @@ def _normalize_tool_output(
|
|||
def _build_reasoning(request: ProviderChatRequest) -> Reasoning | None:
|
||||
effort = request.reasoning_effort
|
||||
summary = request.reasoning_summary
|
||||
|
||||
if effort is None and isinstance(request.reasoning, dict):
|
||||
reasoning_effort = request.reasoning.get("effort")
|
||||
if isinstance(reasoning_effort, str):
|
||||
effort = reasoning_effort
|
||||
|
||||
if summary is None and isinstance(request.reasoning, dict):
|
||||
reasoning_summary = request.reasoning.get("summary")
|
||||
if isinstance(reasoning_summary, str):
|
||||
summary = reasoning_summary
|
||||
|
||||
if summary is None:
|
||||
if request.verbosity == "low":
|
||||
summary = "concise"
|
||||
|
|
@ -345,6 +490,22 @@ def _build_reasoning(request: ProviderChatRequest) -> Reasoning | None:
|
|||
return reasoning
|
||||
|
||||
|
||||
def _extract_cache_control(part: dict[str, Any]) -> dict[str, Any] | None:
|
||||
cache_control = part.get("cache_control")
|
||||
if not isinstance(cache_control, dict):
|
||||
return None
|
||||
|
||||
cache_type = cache_control.get("type")
|
||||
if cache_type != "ephemeral":
|
||||
return None
|
||||
|
||||
out: dict[str, Any] = {"type": "ephemeral"}
|
||||
ttl = cache_control.get("ttl")
|
||||
if ttl in {"5m", "1h"}:
|
||||
out["ttl"] = ttl
|
||||
return out
|
||||
|
||||
|
||||
def _build_text_config(request: ProviderChatRequest) -> ResponseTextConfigParam | None:
|
||||
if request.response_format is None and request.verbosity is None:
|
||||
return None
|
||||
|
|
|
|||
36
app/providers/codex_responses/utils.py
Normal file
36
app/providers/codex_responses/utils.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
"""Shared helpers for codex responses provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _log_ignored_extra(extra: dict[str, Any], *, provider_name: str) -> None:
    """Log request extras the provider dropped; no-op when *extra* is empty."""
    if extra:
        logger.error(
            "provider '%s' ignored unsupported extra params: %s",
            provider_name,
            extra,
        )
|
||||
|
||||
|
||||
def _to_dict(raw: Any) -> dict[str, Any]:
|
||||
if hasattr(raw, "model_dump"):
|
||||
dumped = raw.model_dump()
|
||||
if isinstance(dumped, dict):
|
||||
return dumped
|
||||
if isinstance(raw, dict):
|
||||
return raw
|
||||
return {}
|
||||
|
||||
|
||||
def _as_dict(raw: Any) -> dict[str, Any]:
    """Alias of :func:`_to_dict`, kept for call-site readability."""
    return _to_dict(raw)
|
||||
|
||||
|
||||
def _as_str(value: Any) -> str | None:
|
||||
return value if isinstance(value, str) else None
|
||||
|
|
@ -4,19 +4,32 @@ from __future__ import annotations
|
|||
|
||||
from collections.abc import AsyncIterator
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
from typing import Any, Protocol, cast
|
||||
|
||||
from openai import AsyncOpenAI, OpenAIError
|
||||
|
||||
from app.config.models import LoadedProviderConfig, TokenAuth
|
||||
from app.config.models import LoadedProviderConfig, TokenAuth, UrlAuth
|
||||
from app.core.errors import UpstreamProviderError
|
||||
from app.core.types import CoreChunk, CoreMessage, ProviderChatRequest, ProviderModel
|
||||
from app.providers.base import BaseProvider
|
||||
from app.providers.token_url_auth import TokenUrlAuthProvider
|
||||
from app.providers.registry import provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _BearerAuthProvider(Protocol):
|
||||
async def get_headers(self) -> dict[str, str]: ...
|
||||
|
||||
|
||||
class _StaticBearerAuthProvider:
|
||||
def __init__(self, token: str) -> None:
|
||||
self._token = token
|
||||
|
||||
async def get_headers(self) -> dict[str, str]:
|
||||
return {"Authorization": f"Bearer {self._token}"}
|
||||
|
||||
|
||||
@provider(type="openai-completions")
|
||||
class OpenAICompletionsProvider(BaseProvider):
|
||||
"""Provider that talks to OpenAI-compatible /chat/completions APIs."""
|
||||
|
|
@ -26,9 +39,10 @@ class OpenAICompletionsProvider(BaseProvider):
|
|||
*,
|
||||
name: str,
|
||||
base_url: str,
|
||||
token: str,
|
||||
token: str | None = None,
|
||||
whitelist: list[str] | None = None,
|
||||
blacklist: list[str] | None = None,
|
||||
auth_provider: _BearerAuthProvider | None = None,
|
||||
client: Any | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
|
|
@ -38,30 +52,51 @@ class OpenAICompletionsProvider(BaseProvider):
|
|||
whitelist=whitelist,
|
||||
blacklist=blacklist,
|
||||
)
|
||||
self._client = client or AsyncOpenAI(api_key=token, base_url=base_url)
|
||||
if auth_provider is None:
|
||||
if token is None:
|
||||
raise ValueError(
|
||||
f"Provider '{name}' type 'openai-completions' requires auth provider or token"
|
||||
)
|
||||
auth_provider = _StaticBearerAuthProvider(token)
|
||||
self._auth_provider = auth_provider
|
||||
self._client = client or AsyncOpenAI(api_key="placeholder", base_url=base_url)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: LoadedProviderConfig) -> BaseProvider:
|
||||
if not isinstance(config.auth, TokenAuth):
|
||||
auth_provider: _BearerAuthProvider
|
||||
if isinstance(config.auth, TokenAuth):
|
||||
auth_provider = _StaticBearerAuthProvider(config.auth.token)
|
||||
elif isinstance(config.auth, UrlAuth):
|
||||
auth_provider = TokenUrlAuthProvider(token_url=config.auth.url)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Provider '{config.name}' type 'openai-completions' requires token auth"
|
||||
f"Provider '{config.name}' type 'openai-completions' requires token or url auth"
|
||||
)
|
||||
return cls(
|
||||
|
||||
provider = cls(
|
||||
name=config.name,
|
||||
base_url=config.url,
|
||||
token=config.auth.token,
|
||||
whitelist=config.whitelist,
|
||||
blacklist=config.blacklist,
|
||||
auth_provider=auth_provider,
|
||||
)
|
||||
provider.display_name = config.display_name
|
||||
provider.models_config = config.models
|
||||
return provider
|
||||
|
||||
async def stream_chat(
|
||||
self, request: ProviderChatRequest
|
||||
) -> AsyncIterator[CoreChunk]:
|
||||
try:
|
||||
auth_headers = await self._auth_provider.get_headers()
|
||||
extra_body, ignored_extra = _build_openai_extra_payload(request)
|
||||
sent_assistant_role = False
|
||||
stream = await cast(Any, self._client.chat.completions).create(
|
||||
model=request.model,
|
||||
messages=[_to_chat_message(m) for m in request.messages],
|
||||
stream=True,
|
||||
extra_headers=auth_headers,
|
||||
extra_body=extra_body,
|
||||
audio=request.audio,
|
||||
frequency_penalty=request.frequency_penalty,
|
||||
logit_bias=request.logit_bias,
|
||||
|
|
@ -93,22 +128,50 @@ class OpenAICompletionsProvider(BaseProvider):
|
|||
verbosity=request.verbosity,
|
||||
web_search_options=request.web_search_options,
|
||||
)
|
||||
_log_ignored_extra(request.extra, provider_name=self.name)
|
||||
_log_ignored_extra(ignored_extra, provider_name=self.name)
|
||||
async for chunk in stream:
|
||||
choices = getattr(chunk, "choices", [])
|
||||
for idx, choice in enumerate(choices):
|
||||
delta = getattr(choice, "delta", None)
|
||||
role = getattr(delta, "role", None) if delta is not None else None
|
||||
content = (
|
||||
getattr(delta, "content", None) if delta is not None else None
|
||||
)
|
||||
role = _delta_str(delta, "role")
|
||||
content = _delta_str(delta, "content")
|
||||
reasoning_content = _extract_reasoning_content(delta)
|
||||
reasoning_details = _extract_reasoning_details(delta)
|
||||
tool_calls = _extract_tool_calls(delta)
|
||||
finish_reason = getattr(choice, "finish_reason", None)
|
||||
if role is None and content is None and finish_reason is None:
|
||||
|
||||
choice_index = getattr(choice, "index", idx)
|
||||
if role == "assistant":
|
||||
sent_assistant_role = True
|
||||
elif (
|
||||
not sent_assistant_role
|
||||
and role is None
|
||||
and (
|
||||
content is not None
|
||||
or reasoning_content is not None
|
||||
or reasoning_details is not None
|
||||
or tool_calls is not None
|
||||
)
|
||||
):
|
||||
sent_assistant_role = True
|
||||
yield CoreChunk(index=choice_index, role="assistant")
|
||||
|
||||
if (
|
||||
role is None
|
||||
and content is None
|
||||
and reasoning_content is None
|
||||
and reasoning_details is None
|
||||
and tool_calls is None
|
||||
and finish_reason is None
|
||||
):
|
||||
continue
|
||||
yield CoreChunk(
|
||||
index=getattr(choice, "index", idx),
|
||||
index=choice_index,
|
||||
role=role,
|
||||
content=content,
|
||||
reasoning_content=reasoning_content,
|
||||
reasoning_details=reasoning_details,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
except OpenAIError as exc:
|
||||
|
|
@ -122,7 +185,8 @@ class OpenAICompletionsProvider(BaseProvider):
|
|||
|
||||
async def list_models(self) -> list[ProviderModel]:
|
||||
try:
|
||||
response = await self._client.models.list()
|
||||
auth_headers = await self._auth_provider.get_headers()
|
||||
response = await self._client.models.list(extra_headers=auth_headers)
|
||||
items = _coerce_model_items(response)
|
||||
return [_to_provider_model(item) for item in items if item.get("id")]
|
||||
except OpenAIError as exc:
|
||||
|
|
@ -156,6 +220,154 @@ def _to_chat_message(message: CoreMessage) -> dict[str, Any]:
|
|||
return out
|
||||
|
||||
|
||||
def _delta_value(delta: Any, field: str) -> Any:
|
||||
if delta is None:
|
||||
return None
|
||||
if isinstance(delta, dict):
|
||||
return delta.get(field)
|
||||
return getattr(delta, field, None)
|
||||
|
||||
|
||||
def _delta_str(delta: Any, field: str) -> str | None:
    """Like :func:`_delta_value`, but only string values pass through."""
    candidate = _delta_value(delta, field)
    if isinstance(candidate, str):
        return candidate
    return None
|
||||
|
||||
|
||||
def _extract_reasoning_content(delta: Any) -> str | None:
    """Pull reasoning text out of the various provider-specific delta shapes."""
    # Flat string fields used by different OpenAI-compatible backends.
    for flat_field in ("reasoning_content", "reasoning", "reasoning_text"):
        candidate = _delta_value(delta, flat_field)
        if isinstance(candidate, str):
            return candidate

    # Structured `reasoning` object with nested text fields.
    nested = _to_dict(_delta_value(delta, "reasoning"))
    for nested_field in ("content", "text", "summary"):
        candidate = nested.get(nested_field)
        if isinstance(candidate, str):
            return candidate
    return None
|
||||
|
||||
|
||||
def _extract_tool_calls(delta: Any) -> list[dict[str, Any]] | None:
    """Return the delta's tool_calls as plain dicts, or ``None`` when absent."""
    raw_calls = _delta_value(delta, "tool_calls")
    if not isinstance(raw_calls, list):
        return None

    normalized = [_to_dict(entry) for entry in raw_calls]
    return normalized or None
|
||||
|
||||
|
||||
def _extract_reasoning_details(delta: Any) -> list[dict[str, Any]] | None:
|
||||
candidates: list[Any] = []
|
||||
|
||||
raw = _delta_value(delta, "reasoning_details")
|
||||
if isinstance(raw, list):
|
||||
candidates.extend(raw)
|
||||
|
||||
reasoning_obj = _to_dict(_delta_value(delta, "reasoning"))
|
||||
nested = reasoning_obj.get("details")
|
||||
if isinstance(nested, list):
|
||||
candidates.extend(nested)
|
||||
|
||||
details: list[dict[str, Any]] = []
|
||||
for candidate in candidates:
|
||||
detail = _to_dict(candidate)
|
||||
if detail:
|
||||
details.append(detail)
|
||||
|
||||
return details or None
|
||||
|
||||
|
||||
def _build_openai_extra_payload(
|
||||
request: ProviderChatRequest,
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
extra_body: dict[str, Any] = {}
|
||||
extra = request.extra
|
||||
|
||||
provider_value = request.provider or (
|
||||
extra.get("provider") if isinstance(extra.get("provider"), dict) else None
|
||||
)
|
||||
if provider_value is not None:
|
||||
extra_body["provider"] = provider_value
|
||||
|
||||
plugins_value = request.plugins or (
|
||||
extra.get("plugins") if isinstance(extra.get("plugins"), list) else None
|
||||
)
|
||||
if plugins_value is not None:
|
||||
extra_body["plugins"] = plugins_value
|
||||
|
||||
session_id_value = request.session_id or (
|
||||
extra.get("session_id") if isinstance(extra.get("session_id"), str) else None
|
||||
)
|
||||
if session_id_value is not None:
|
||||
extra_body["session_id"] = session_id_value
|
||||
|
||||
trace_value = request.trace or (
|
||||
extra.get("trace") if isinstance(extra.get("trace"), dict) else None
|
||||
)
|
||||
if trace_value is not None:
|
||||
extra_body["trace"] = trace_value
|
||||
|
||||
debug_value = request.debug or (
|
||||
extra.get("debug") if isinstance(extra.get("debug"), dict) else None
|
||||
)
|
||||
if debug_value is not None:
|
||||
extra_body["debug"] = debug_value
|
||||
|
||||
image_config_value = request.image_config or (
|
||||
extra.get("image_config")
|
||||
if isinstance(extra.get("image_config"), dict)
|
||||
else None
|
||||
)
|
||||
if image_config_value is not None:
|
||||
extra_body["image_config"] = image_config_value
|
||||
|
||||
models_value = request.models or (
|
||||
extra.get("models") if isinstance(extra.get("models"), list) else None
|
||||
)
|
||||
if models_value is not None:
|
||||
extra_body["models"] = models_value
|
||||
|
||||
reasoning_value = _build_reasoning_payload(request)
|
||||
if reasoning_value is not None:
|
||||
extra_body["reasoning"] = reasoning_value
|
||||
|
||||
ignored_extra = dict(request.extra)
|
||||
for key in (
|
||||
"provider",
|
||||
"plugins",
|
||||
"session_id",
|
||||
"trace",
|
||||
"models",
|
||||
"debug",
|
||||
"image_config",
|
||||
):
|
||||
ignored_extra.pop(key, None)
|
||||
|
||||
return extra_body, ignored_extra
|
||||
|
||||
|
||||
def _build_reasoning_payload(request: ProviderChatRequest) -> dict[str, Any] | None:
|
||||
effort = request.reasoning_effort
|
||||
summary = request.reasoning_summary
|
||||
reasoning = request.reasoning
|
||||
|
||||
if isinstance(reasoning, dict):
|
||||
if effort is None and isinstance(reasoning.get("effort"), str):
|
||||
effort = reasoning["effort"]
|
||||
if summary is None and isinstance(reasoning.get("summary"), str):
|
||||
summary = reasoning["summary"]
|
||||
|
||||
if effort is None and summary is None:
|
||||
return None
|
||||
|
||||
out: dict[str, Any] = {}
|
||||
if effort is not None:
|
||||
out["effort"] = effort
|
||||
if summary is not None:
|
||||
out["summary"] = summary
|
||||
return out
|
||||
|
||||
|
||||
def _log_ignored_extra(extra: dict[str, Any], *, provider_name: str) -> None:
|
||||
if not extra:
|
||||
return
|
||||
|
|
|
|||
36
app/providers/token_url_auth.py
Normal file
36
app/providers/token_url_auth.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
"""Auth helper for fetching bearer token from external URL."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
class TokenUrlAuthProvider:
|
||||
"""Fetches bearer token from URL on each request."""
|
||||
|
||||
def __init__(self, *, token_url: str, timeout_seconds: float = 600.0) -> None:
|
||||
self._token_url = token_url
|
||||
self._timeout_seconds = timeout_seconds
|
||||
|
||||
async def get_headers(self) -> dict[str, str]:
|
||||
token = await self._fetch_token()
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
async def _fetch_token(self) -> str:
|
||||
async with httpx.AsyncClient(timeout=self._timeout_seconds) as client:
|
||||
response = await client.get(self._token_url)
|
||||
|
||||
if response.status_code >= 400:
|
||||
raise ValueError(
|
||||
f"Token URL auth failed with status {response.status_code}"
|
||||
)
|
||||
|
||||
try:
|
||||
token = response.json()['token']
|
||||
except ValueError:
|
||||
raise ValueError(f"Token URL auth returned invalid response: {response.content}")
|
||||
|
||||
if not token:
|
||||
raise ValueError("Token URL auth returned empty token")
|
||||
|
||||
return token
|
||||
|
|
@ -1,4 +1,11 @@
|
|||
providers:
|
||||
wzray:
|
||||
url: http://127.0.0.1:8000/v1
|
||||
type: openai-completions
|
||||
name: Wzray
|
||||
models:
|
||||
openai/gpt-5:
|
||||
name: GPT-5
|
||||
zai:
|
||||
url: https://api.z.ai/api/coding/paas/v4
|
||||
type: openai-completions
|
||||
|
|
@ -14,4 +21,4 @@ providers:
|
|||
- glm-5-free
|
||||
codex:
|
||||
url: https://chatgpt.com/backend-api
|
||||
type: codex-responses
|
||||
type: codex-responses
|
||||
1774
docs/Chat Completions schema.md
Normal file
1774
docs/Chat Completions schema.md
Normal file
File diff suppressed because it is too large
Load diff
8513
docs/Responses schema.md
Normal file
8513
docs/Responses schema.md
Normal file
File diff suppressed because it is too large
Load diff
88
docs/TODO.md
Normal file
88
docs/TODO.md
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
# Codex Router Alignment Plan
|
||||
|
||||
## Confirmed Scope (from latest requirements)
|
||||
- Keep `parallel_tool_calls: true` in outbound responses payloads.
|
||||
- Do not send `prompt_cache_key` from the router for now.
|
||||
- Always send `include: ["reasoning.encrypted_content"]`.
|
||||
- Header work now: remove self-added duplicate headers only.
|
||||
- Message payload work now: stop string `content` serialization and send content parts like the good sample.
|
||||
|
||||
## Deferred (intentionally postponed)
|
||||
- Full header parity with the golden capture (transport/runtime-level UA and low-level accept-encoding parity).
|
||||
- Full one-to-one `input` history shape parity (`type` omission strategy for message items).
|
||||
- Recovering or synthesizing top-level developer message from upstream chat-completions schema.
|
||||
- End-to-end reasoning item roundtrip parity in history (`type: reasoning` pass-through and replay behavior).
|
||||
- Prompt cache implementation strategy and lifecycle management.
|
||||
|
||||
## Feasible Path For Deferred Items
|
||||
1. Header parity
|
||||
- Keep current sdk-based client for now.
|
||||
- If exact parity is required, switch codex provider transport from `AsyncOpenAI` to a custom `httpx` SSE client and set an explicit header allowlist.
|
||||
|
||||
2. Input history shape parity
|
||||
- Add a translator mode that emits implicit message items (`{"role":...,"content":...}`) without `type: "message"`.
|
||||
- Keep explicit item support for `function_call` and `function_call_output` unchanged.
|
||||
|
||||
3. Developer message availability
|
||||
- Add optional request extension field(s) in `model_extra`, e.g. `opencode_developer_message` or `opencode_input_items`.
|
||||
- Use extension when provided; otherwise keep current first-system/developer-to-instructions behavior.
|
||||
|
||||
4. Reasoning item roundtrip
|
||||
- Accept explicit inbound items with `extra.type == "reasoning"` and pass through `encrypted_content` + `summary` to responses `input`.
|
||||
- Keep chat-completions output contract unchanged; reasoning passthrough is input-side only unless a dedicated raw endpoint is added.
|
||||
|
||||
5. Prompt cache strategy
|
||||
- Keep disabled by default.
|
||||
- Add optional feature flag for deterministic hash-based key generation once cache policy is agreed.
|
||||
|
||||
## Schema.md Gap Breakdown (planning only, no implementation yet)
|
||||
|
||||
### Legend
|
||||
- `Supported` = already implemented.
|
||||
- `Partial` = partly implemented but not schema-complete.
|
||||
- `Missing` = not implemented yet.
|
||||
|
||||
| # | Area | What it does for users | Current status | Decision from latest review | Notes / planned behavior |
|
||||
|---|---|---|---|---|---|
|
||||
| 1 | Extra request controls (`provider`, `plugins`, `session_id`, `trace`, `models`, `debug`, `image_config`) | Lets users steer upstream routing, observability, plugin behavior, and image/provider-specific behavior directly from request body. | Missing | Explain each field first, then choose individually | Keep pass-through design: accept fields in API schema, preserve in internal request, forward when provider supports. |
|
||||
| 2 | `reasoning` object in request (`reasoning.effort`, `reasoning.summary`) | Standard schema-compatible way to request reasoning effort and summary verbosity. | Partial (we use flat `reasoning_effort` / `reasoning_summary`) | Must support | Add canonical `reasoning` object support while preserving backward compatibility with current flat aliases. Define precedence rules if both forms are provided. |
|
||||
| 3 | `modalities` alignment (`text`/`image`) | Controls output modalities users request. Must match schema contract exactly. | Partial / mismatched (`text`/`audio` now) | Must support schema behavior | Change request schema and internal mapping to `text`/`image` for the public API; ensure providers receive compatible values. |
|
||||
| 4 | Full message content parts (audio/video/cache-control variants) | Enables multi-part multimodal inputs (audio, video, richer text metadata) and cache hints on message parts. | Partial | Must support | Expand accepted message content item parsing and translator mapping for all schema item variants, including preservation of unknown-but-valid provider fields where safe. |
|
||||
| 5 | Assistant response extensions (`reasoning`, `reasoning_details`, `images`) | Returns richer assistant payloads: plain reasoning, structured reasoning metadata, and generated image outputs. | Missing | Must support | Extend response schemas and mappers so these fields can be emitted in non-streaming and streaming-compatible forms. |
|
||||
| 6 | Encrypted reasoning passthrough (`reasoning_details` with encrypted data) | Exposes encrypted reasoning blocks from upstream exactly as received for advanced clients/debugging/replay. | Missing | High priority, must support | Capture encrypted reasoning items from responses stream (`response.output_item.*` for `type=reasoning`) and surface in API output as raw/structured reasoning details without lossy transformation. |
|
||||
| 7 | Usage passthrough fidelity | Users should receive full upstream usage payload (raw), not a reduced subset. | Partial | Needed: pass full raw usage through | Do not over-normalize; preserve upstream usage object as-is when available. If upstream omits usage, return `null`/missing naturally. |
|
||||
| 8 | Detailed HTTP error matrix parity | Strictly maps many status codes exactly like reference schema. | Partial | Not required now | Keep current error strategy unless product requirements change. |
|
||||
| 9 | Optional `model` when `models` routing is used | OpenRouter-style multi-model router behavior. | Missing | Not required for this project | Keep `model` required in our API for now. |
|
||||
|
||||
## Field-by-field reference for item #1 (for product decision)
|
||||
|
||||
| Field | User-visible purpose | Typical payload shape | Risk/complexity |
|
||||
|---|---|---|---|
|
||||
| `provider` | Control provider routing policy (allow/deny fallback, specific providers, price/perf constraints). | Object with routing knobs (order/only/ignore, pricing, latency/throughput prefs). | Medium-High (router semantics + validation + provider compatibility). |
|
||||
| `plugins` | Enable optional behavior modules (web search/moderation/auto-router/etc). | Array of plugin descriptors with `id` and optional settings. | Medium (validation + pass-through + provider-specific effects). |
|
||||
| `session_id` (body) | Group related requests for observability/conversation continuity. | String (usually short opaque id). | Low (mostly passthrough + precedence with headers if both exist). |
|
||||
| `trace` | Attach tracing metadata for distributed observability. | Object (`trace_id`, `span_name`, etc + custom keys). | Low-Medium (schema + passthrough). |
|
||||
| `models` | Candidate model set for automatic selection/router behavior. | Array of model identifiers/patterns. | Medium-High (changes model resolution flow). |
|
||||
| `debug` | Request debug payloads (e.g., transformed upstream request echo in stream). | Object flags like `echo_upstream_body`. | Medium (security/sensitivity review required). |
|
||||
| `image_config` | Provider/model-specific image generation tuning options. | Arbitrary object map by provider/model conventions. | Medium (loosely-typed passthrough plus safety limits). |
|
||||
|
||||
## Execution order when implementation starts (agreed priorities)
|
||||
1. Encrypted reasoning + reasoning details output path (#6 + #5 core subset).
|
||||
2. Full usage passthrough fidelity (#7).
|
||||
3. Request `reasoning` object support (#2).
|
||||
4. Modalities contract alignment to schema (`text`/`image`) (#3).
|
||||
5. Message content multimodal expansion (#4).
|
||||
6. Decide and then implement selected item-#1 controls (`provider/plugins/session_id/trace/models/debug/image_config`).
|
||||
|
||||
## Implementation Steps (current)
|
||||
1. Update codex translator payload fields:
|
||||
- remove `prompt_cache_key`
|
||||
- add mandatory `include`
|
||||
2. Update message content serialization:
|
||||
- serialize string message content as `[{"type":"input_text","text":...}]`
|
||||
- preserve empty-content filtering behavior
|
||||
3. Update codex provider header handling:
|
||||
- avoid mutating oauth headers in place
|
||||
- remove self-added duplicate `user-agent` header
|
||||
4. Update/extend tests for new payload contract.
|
||||
5. Run full `pytest` and fix regressions until green.
|
||||
39
opencode/README.md
Normal file
39
opencode/README.md
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
# OpenCode Wzray Plugin
|
||||
|
||||
Minimal OpenCode plugin that connects to your local API.
|
||||
|
||||
## What It Does
|
||||
|
||||
- Adds one provider in OpenCode config: `wzray`
|
||||
- Uses OpenAI-compatible transport
|
||||
- Fetches models from your API `GET /v1/models`
|
||||
- Uses `WZRAY_API_KEY` as bearer token when set
|
||||
|
||||
## Defaults
|
||||
|
||||
- Base URL: `http://127.0.0.1:8000/v1`
|
||||
- Provider key: `wzray`
|
||||
- Fallback model list contains one safe model (`openai/gpt-5`)
|
||||
|
||||
## Environment Variables
|
||||
|
||||
- `WZRAY_API_BASE_URL` (optional)
|
||||
- `AI_API_BASE_URL` (optional fallback)
|
||||
- `WZRAY_API_KEY` (optional)
|
||||
|
||||
## Install
|
||||
|
||||
From this directory:
|
||||
|
||||
```bash
|
||||
chmod +x ./install_opencode.sh
|
||||
./install_opencode.sh
|
||||
```
|
||||
|
||||
This copies plugin files to:
|
||||
|
||||
- `~/.config/opencode/plugin/opencode/`
|
||||
|
||||
And ensures `opencode.json` contains:
|
||||
|
||||
- `./plugin/opencode/plugin_wzray.ts`
|
||||
69
opencode/install_opencode.sh
Normal file
69
opencode/install_opencode.sh
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||
OPENCODE_DIR="${XDG_CONFIG_HOME:-$HOME/.config}/opencode"
|
||||
PLUGIN_DIR="$OPENCODE_DIR/plugin/opencode"
|
||||
CONFIG_FILE="$OPENCODE_DIR/opencode.json"
|
||||
PLUGIN_ENTRY="./plugin/opencode/plugin_wzray.ts"
|
||||
|
||||
mkdir -p "$PLUGIN_DIR"
|
||||
|
||||
cp "$SCRIPT_DIR/plugin_wzray.ts" "$PLUGIN_DIR/plugin_wzray.ts"
|
||||
cp "$SCRIPT_DIR/models_wzray.ts" "$PLUGIN_DIR/models_wzray.ts"
|
||||
|
||||
if [[ ! -f "$CONFIG_FILE" ]]; then
|
||||
cat > "$CONFIG_FILE" <<EOF
|
||||
{
|
||||
"\$schema": "https://opencode.ai/config.json",
|
||||
"plugin": [
|
||||
"$PLUGIN_ENTRY"
|
||||
]
|
||||
}
|
||||
EOF
|
||||
echo "Installed plugin and created $CONFIG_FILE"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
RUNTIME=""
|
||||
if command -v node >/dev/null 2>&1; then
|
||||
RUNTIME="node"
|
||||
elif command -v bun >/dev/null 2>&1; then
|
||||
RUNTIME="bun"
|
||||
else
|
||||
echo "Error: node or bun is required to update relaxed opencode.json" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
"$RUNTIME" - "$CONFIG_FILE" "$PLUGIN_ENTRY" <<'JS'
|
||||
const fs = require("node:fs");
|
||||
|
||||
const configPath = process.argv[2];
|
||||
const entry = process.argv[3];
|
||||
const source = fs.readFileSync(configPath, "utf8");
|
||||
|
||||
let data;
|
||||
try {
|
||||
data = new Function(`"use strict"; return (${source});`)();
|
||||
} catch (error) {
|
||||
console.error(`Failed to parse ${configPath}: ${String(error)}`);
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
if (!data || typeof data !== "object" || Array.isArray(data)) {
|
||||
data = {};
|
||||
}
|
||||
|
||||
const plugins = Array.isArray(data.plugin) ? data.plugin : [];
|
||||
if (!plugins.includes(entry)) {
|
||||
plugins.push(entry);
|
||||
}
|
||||
data.plugin = plugins;
|
||||
|
||||
fs.writeFileSync(configPath, `${JSON.stringify(data, null, 2)}\n`, "utf8");
|
||||
JS
|
||||
|
||||
echo "Installed plugin files to $PLUGIN_DIR"
|
||||
echo "Updated $CONFIG_FILE"
|
||||
echo "Optional envs: WZRAY_API_BASE_URL (default http://127.0.0.1:8000/v1), WZRAY_API_KEY"
|
||||
53
opencode/models_wzray.ts
Normal file
53
opencode/models_wzray.ts
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
type ModelInfo = { name: string };
|
||||
type AvailableModels = Record<string, ModelInfo>;
|
||||
|
||||
declare const process: {
|
||||
env: Record<string, string | undefined>;
|
||||
};
|
||||
|
||||
const FALLBACK_MODELS: AvailableModels = {
|
||||
"openai/gpt-5.3-codex": { name: "OpenAI: GPT-5.3 Codex" },
|
||||
};
|
||||
|
||||
function getApiKey(): string | undefined {
|
||||
return process.env.WZRAY_API_KEY;
|
||||
}
|
||||
|
||||
function normalizeBaseUrl(baseUrl: string): string {
|
||||
const trimmed = baseUrl.replace(/\/+$/, "");
|
||||
return trimmed.endsWith("/v1") ? trimmed : `${trimmed}/v1`;
|
||||
}
|
||||
|
||||
function createHeaders(): Record<string, string> {
|
||||
const apiKey = getApiKey();
|
||||
if (!apiKey) return { "Content-Type": "application/json" };
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
};
|
||||
}
|
||||
|
||||
export async function getAvailableModels(baseUrl: string): Promise<AvailableModels> {
|
||||
const modelsUrl = `${normalizeBaseUrl(baseUrl)}/models`;
|
||||
try {
|
||||
const response = await fetch(modelsUrl, { headers: createHeaders() });
|
||||
if (!response.ok) return FALLBACK_MODELS;
|
||||
|
||||
const payload = (await response.json()) as {
|
||||
data?: Array<{ id?: string; name?: string }>;
|
||||
};
|
||||
const data = Array.isArray(payload.data) ? payload.data : [];
|
||||
|
||||
const models: AvailableModels = {};
|
||||
for (const item of data) {
|
||||
if (!item?.id) continue;
|
||||
models[item.id] = { name: item.name || item.id };
|
||||
}
|
||||
return Object.keys(models).length > 0 ? models : FALLBACK_MODELS;
|
||||
} catch {
|
||||
return FALLBACK_MODELS;
|
||||
}
|
||||
}
|
||||
|
||||
export { FALLBACK_MODELS };
|
||||
export type { AvailableModels, ModelInfo };
|
||||
55
opencode/plugin_wzray.ts
Normal file
55
opencode/plugin_wzray.ts
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
// @ts-ignore
|
||||
import type { Plugin, PluginInput } from "@opencode-ai/plugin";
|
||||
import { FALLBACK_MODELS, getAvailableModels } from "./models_wzray";
|
||||
|
||||
declare const process: {
|
||||
env: Record<string, string | undefined>;
|
||||
};
|
||||
|
||||
const SCHEMA = "https://opencode.ai/config.json";
|
||||
const NPM_PACKAGE = "@ai-sdk/openai-compatible";
|
||||
const PROVIDER_KEY = "wzray";
|
||||
const PROVIDER_NAME = "AI Router";
|
||||
const DEFAULT_BASE_URL = "https://ai.wzray.com/v1";
|
||||
|
||||
function getBaseUrl(): string {
|
||||
return (
|
||||
process.env.WZRAY_API_BASE_URL ||
|
||||
process.env.AI_API_BASE_URL ||
|
||||
DEFAULT_BASE_URL
|
||||
);
|
||||
}
|
||||
|
||||
function getApiKey(): string {
|
||||
return process.env.WZRAY_API_KEY || "{env:WZRAY_API_KEY}";
|
||||
}
|
||||
|
||||
function createProviderConfig(models: Record<string, { name: string }>) {
|
||||
return {
|
||||
schema: SCHEMA,
|
||||
npm: NPM_PACKAGE,
|
||||
name: PROVIDER_NAME,
|
||||
options: {
|
||||
baseURL: getBaseUrl(),
|
||||
apiKey: getApiKey(),
|
||||
},
|
||||
models,
|
||||
};
|
||||
}
|
||||
|
||||
async function configure(config: any): Promise<void> {
|
||||
if (!config.provider) config.provider = {};
|
||||
if (config.provider[PROVIDER_KEY]) return;
|
||||
|
||||
const baseUrl = getBaseUrl();
|
||||
const models = await getAvailableModels(baseUrl);
|
||||
config.provider[PROVIDER_KEY] = createProviderConfig(
|
||||
Object.keys(models).length > 0 ? models : FALLBACK_MODELS,
|
||||
);
|
||||
}
|
||||
|
||||
const WzrayProviderPlugin: Plugin = async (_input: PluginInput) => {
|
||||
return { config: configure };
|
||||
};
|
||||
|
||||
export default WzrayProviderPlugin;
|
||||
|
|
@ -4,6 +4,7 @@ from collections.abc import AsyncIterator
|
|||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.core.models_dev import ModelsDevCatalog
|
||||
from app.core.router import RouterCore
|
||||
from app.core.types import CoreChunk, ProviderChatRequest, ProviderModel
|
||||
from app.dependencies import get_router_core
|
||||
|
|
@ -38,6 +39,53 @@ class _StreamingProvider(BaseProvider):
|
|||
return [ProviderModel(id="minimax/minimax-m2.5:free", name="MiniMax")]
|
||||
|
||||
|
||||
class _ReasoningProvider(_StreamingProvider):
|
||||
async def stream_chat(
|
||||
self, request: ProviderChatRequest
|
||||
) -> AsyncIterator[CoreChunk]:
|
||||
self.models_seen.append(request.model)
|
||||
self.last_request = request
|
||||
yield CoreChunk(role="assistant")
|
||||
yield CoreChunk(
|
||||
reasoning_content="**Plan**",
|
||||
reasoning_details=[
|
||||
{
|
||||
"type": "reasoning.encrypted",
|
||||
"data": "enc_123",
|
||||
"id": "rs_1",
|
||||
"format": "openai-responses-v1",
|
||||
},
|
||||
{
|
||||
"type": "reasoning.summary",
|
||||
"summary": "**Plan**",
|
||||
"id": "rs_1",
|
||||
"format": "openai-responses-v1",
|
||||
},
|
||||
],
|
||||
)
|
||||
yield CoreChunk(content="Hello")
|
||||
yield CoreChunk(finish_reason="stop")
|
||||
|
||||
|
||||
class _ModelsDevCatalogWithProviderName(ModelsDevCatalog):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(fetch_catalog=lambda: _never_called())
|
||||
|
||||
async def get_provider_models(
|
||||
self, *, provider_name: str, provider_url: str
|
||||
) -> tuple[str | None, dict[str, dict[str, object]]]:
|
||||
return "Kilo AI", {
|
||||
"minimax/minimax-m2.5:free": {
|
||||
"name": "MiniMax",
|
||||
"release_date": "2026-01-15",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
async def _never_called() -> dict[str, object]:
|
||||
raise RuntimeError("not expected")
|
||||
|
||||
|
||||
def _parse_sse_data(raw: str) -> list[str]:
|
||||
out: list[str] = []
|
||||
for line in raw.splitlines():
|
||||
|
|
@ -110,6 +158,43 @@ def test_chat_completions_non_stream_returns_chat_completion_object() -> None:
|
|||
assert body["choices"][0]["message"]["content"] == "Hello"
|
||||
|
||||
|
||||
def test_chat_completions_non_stream_includes_reasoning_details() -> None:
|
||||
app = create_app()
|
||||
provider = _ReasoningProvider()
|
||||
core = RouterCore(providers={"kilocode": provider})
|
||||
app.dependency_overrides[get_router_core] = lambda: core
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "kilocode/minimax/minimax-m2.5:free",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"stream": False,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
message = body["choices"][0]["message"]
|
||||
assert message["reasoning"] == "**Plan**"
|
||||
assert message["reasoning_content"] == "**Plan**"
|
||||
assert message["reasoning_details"] == [
|
||||
{
|
||||
"type": "reasoning.encrypted",
|
||||
"data": "enc_123",
|
||||
"id": "rs_1",
|
||||
"format": "openai-responses-v1",
|
||||
},
|
||||
{
|
||||
"type": "reasoning.summary",
|
||||
"summary": "**Plan**",
|
||||
"id": "rs_1",
|
||||
"format": "openai-responses-v1",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_chat_completions_supports_unversioned_path() -> None:
|
||||
app = create_app()
|
||||
provider = _StreamingProvider()
|
||||
|
|
@ -123,13 +208,16 @@ def test_chat_completions_supports_unversioned_path() -> None:
|
|||
"model": "kilocode/minimax/minimax-m2.5:free",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"stream": True,
|
||||
"reasoning_effort": "low",
|
||||
"reasoning": {"effort": "low", "summary": "auto"},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payloads = _parse_sse_data(response.text)
|
||||
assert payloads[-1] == "[DONE]"
|
||||
assert provider.last_request is not None
|
||||
assert provider.last_request.reasoning_effort == "low"
|
||||
assert provider.last_request.reasoning_summary == "auto"
|
||||
|
||||
|
||||
def test_chat_completions_accepts_temporary_extra_params() -> None:
|
||||
|
|
@ -180,12 +268,75 @@ def test_chat_completions_accepts_reasoning_summary_camel_alias() -> None:
|
|||
assert provider.last_request.reasoning_summary == "detailed"
|
||||
|
||||
|
||||
def test_models_endpoint_returns_aggregated_models() -> None:
|
||||
def test_chat_completions_accepts_schema_modalities() -> None:
|
||||
app = create_app()
|
||||
provider = _StreamingProvider()
|
||||
core = RouterCore(providers={"kilocode": provider})
|
||||
app.dependency_overrides[get_router_core] = lambda: core
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "kilocode/minimax/minimax-m2.5:free",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"modalities": ["text", "image"],
|
||||
"stream": False,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert provider.last_request is not None
|
||||
assert provider.last_request.modalities == ["text", "image"]
|
||||
|
||||
|
||||
def test_chat_completions_accepts_schema_router_fields() -> None:
|
||||
app = create_app()
|
||||
provider = _StreamingProvider()
|
||||
core = RouterCore(providers={"kilocode": provider})
|
||||
app.dependency_overrides[get_router_core] = lambda: core
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "kilocode/minimax/minimax-m2.5:free",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"provider": {"allow_fallbacks": False},
|
||||
"plugins": [{"id": "web", "enabled": True}],
|
||||
"session_id": "ses_123",
|
||||
"trace": {"trace_id": "tr_1"},
|
||||
"models": ["openai/gpt-5"],
|
||||
"debug": {"echo_upstream_body": True},
|
||||
"image_config": {"size": "1024x1024"},
|
||||
"reasoning": {"effort": "high", "summary": "detailed"},
|
||||
"temperature": 0.1,
|
||||
"stream": False,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert provider.last_request is not None
|
||||
assert provider.last_request.provider == {"allow_fallbacks": False}
|
||||
assert provider.last_request.plugins == [{"id": "web", "enabled": True}]
|
||||
assert provider.last_request.session_id == "ses_123"
|
||||
assert provider.last_request.trace == {"trace_id": "tr_1"}
|
||||
assert provider.last_request.models == ["openai/gpt-5"]
|
||||
assert provider.last_request.debug == {"echo_upstream_body": True}
|
||||
assert provider.last_request.image_config == {"size": "1024x1024"}
|
||||
assert provider.last_request.reasoning == {"effort": "high", "summary": "detailed"}
|
||||
assert provider.last_request.temperature == 0.1
|
||||
|
||||
|
||||
def test_models_endpoint_returns_aggregated_models() -> None:
|
||||
app = create_app()
|
||||
provider = _StreamingProvider()
|
||||
core = RouterCore(
|
||||
providers={"kilocode": provider},
|
||||
models_dev_catalog=_ModelsDevCatalogWithProviderName(),
|
||||
)
|
||||
app.dependency_overrides[get_router_core] = lambda: core
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/models")
|
||||
|
||||
|
|
@ -194,8 +345,9 @@ def test_models_endpoint_returns_aggregated_models() -> None:
|
|||
assert body["object"] == "list"
|
||||
assert body["data"][0]["id"] == "kilocode/minimax/minimax-m2.5:free"
|
||||
assert body["data"][0]["object"] == "model"
|
||||
assert body["data"][0]["created"] == 0
|
||||
assert body["data"][0]["created"] > 0
|
||||
assert body["data"][0]["owned_by"] == "wzray"
|
||||
assert body["data"][0]["name"] == "Kilo AI: MiniMax"
|
||||
|
||||
|
||||
def test_models_endpoint_supports_v1_path() -> None:
|
||||
|
|
|
|||
|
|
@ -35,7 +35,6 @@ def test_build_responses_create_args_maps_core_fields() -> None:
|
|||
safety_identifier="user-123",
|
||||
service_tier="default",
|
||||
parallel_tool_calls=False,
|
||||
temperature=0.2,
|
||||
top_p=0.9,
|
||||
tools=[
|
||||
{
|
||||
|
|
@ -68,11 +67,12 @@ def test_build_responses_create_args_maps_core_fields() -> None:
|
|||
assert args_dict["model"] == "gpt-5-codex"
|
||||
assert args_dict["stream"] is True
|
||||
assert args_dict["store"] is False
|
||||
assert args_dict["instructions"] == "Follow policy.\n\nSystem rule."
|
||||
assert args_dict["instructions"] == "Follow policy."
|
||||
assert args_dict["include"] == ["reasoning.encrypted_content"]
|
||||
assert args_dict["parallel_tool_calls"] is False
|
||||
assert "prompt_cache_key" not in args_dict
|
||||
assert args_dict["prompt_cache_retention"] == "24h"
|
||||
assert args_dict["metadata"] == {"team": "infra"}
|
||||
assert args_dict["temperature"] == 0.2
|
||||
assert args_dict["top_p"] == 0.9
|
||||
|
||||
assert args_dict["reasoning"] == {"effort": "medium", "summary": "detailed"}
|
||||
|
|
@ -87,19 +87,28 @@ def test_build_responses_create_args_maps_core_fields() -> None:
|
|||
}
|
||||
|
||||
items = args_dict["input"]
|
||||
assert items[0] == {"type": "message", "role": "user", "content": "Hello"}
|
||||
assert items[0] == {
|
||||
"type": "message",
|
||||
"role": "system",
|
||||
"content": [{"type": "input_text", "text": "System rule."}],
|
||||
}
|
||||
assert items[1] == {
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": "Calling tool",
|
||||
"role": "user",
|
||||
"content": [{"type": "input_text", "text": "Hello"}],
|
||||
}
|
||||
assert items[2] == {
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type": "output_text", "text": "Calling tool"}],
|
||||
}
|
||||
assert items[3] == {
|
||||
"type": "function_call",
|
||||
"call_id": "call_1",
|
||||
"name": "lookup",
|
||||
"arguments": '{"x": 1}',
|
||||
}
|
||||
assert items[3] == {
|
||||
assert items[4] == {
|
||||
"type": "function_call_output",
|
||||
"call_id": "call_1",
|
||||
"output": '{"ok": true}',
|
||||
|
|
@ -189,7 +198,161 @@ def test_build_responses_create_args_allows_missing_instructions() -> None:
|
|||
args_dict = cast(dict[str, Any], args)
|
||||
assert args_dict["instructions"] == "You are a helpful assistant."
|
||||
assert args_dict["input"] == [
|
||||
{"type": "message", "role": "user", "content": "hello"}
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [{"type": "input_text", "text": "hello"}],
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_build_responses_create_args_maps_followup_developer_to_system_message() -> (
|
||||
None
|
||||
):
|
||||
request = ProviderChatRequest(
|
||||
model="gpt-5-codex",
|
||||
messages=[
|
||||
CoreMessage(role="developer", content="Primary rules"),
|
||||
CoreMessage(role="developer", content="Secondary rules"),
|
||||
CoreMessage(role="user", content="hello"),
|
||||
],
|
||||
)
|
||||
|
||||
args, _ = build_responses_create_args(request)
|
||||
args_dict = cast(dict[str, Any], args)
|
||||
assert args_dict["instructions"] == "Primary rules"
|
||||
assert args_dict["input"] == [
|
||||
{
|
||||
"type": "message",
|
||||
"role": "system",
|
||||
"content": [{"type": "input_text", "text": "Secondary rules"}],
|
||||
},
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [{"type": "input_text", "text": "hello"}],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_build_responses_create_args_skips_empty_assistant_messages() -> None:
|
||||
request = ProviderChatRequest(
|
||||
model="gpt-5-codex",
|
||||
messages=[
|
||||
CoreMessage(role="developer", content="Rules"),
|
||||
CoreMessage(role="assistant", content=""),
|
||||
CoreMessage(role="user", content="hello"),
|
||||
],
|
||||
)
|
||||
|
||||
args, _ = build_responses_create_args(request)
|
||||
args_dict = cast(dict[str, Any], args)
|
||||
assert args_dict["input"] == [
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [{"type": "input_text", "text": "hello"}],
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_build_responses_create_args_maps_reasoning_object() -> None:
|
||||
request = ProviderChatRequest(
|
||||
model="gpt-5-codex",
|
||||
messages=[
|
||||
CoreMessage(role="developer", content="Rules"),
|
||||
CoreMessage(role="user", content="hello"),
|
||||
],
|
||||
reasoning={"effort": "low", "summary": "concise"},
|
||||
)
|
||||
|
||||
args, _ = build_responses_create_args(request)
|
||||
args_dict = cast(dict[str, Any], args)
|
||||
assert args_dict["reasoning"] == {"effort": "low", "summary": "concise"}
|
||||
|
||||
|
||||
def test_build_responses_create_args_keeps_flat_reasoning_precedence() -> None:
|
||||
request = ProviderChatRequest(
|
||||
model="gpt-5-codex",
|
||||
messages=[
|
||||
CoreMessage(role="developer", content="Rules"),
|
||||
CoreMessage(role="user", content="hello"),
|
||||
],
|
||||
reasoning_effort="high",
|
||||
reasoning_summary="detailed",
|
||||
reasoning={"effort": "low", "summary": "concise"},
|
||||
)
|
||||
|
||||
args, _ = build_responses_create_args(request)
|
||||
args_dict = cast(dict[str, Any], args)
|
||||
assert args_dict["reasoning"] == {"effort": "high", "summary": "detailed"}
|
||||
|
||||
|
||||
def test_build_responses_create_args_supports_multimodal_content_parts() -> None:
|
||||
request = ProviderChatRequest(
|
||||
model="gpt-5-codex",
|
||||
messages=[
|
||||
CoreMessage(role="developer", content="Rules"),
|
||||
CoreMessage(
|
||||
role="user",
|
||||
content=[
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Describe media",
|
||||
"cache_control": {"type": "ephemeral", "ttl": "5m"},
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://example.com/image.png",
|
||||
"detail": "high",
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "input_audio",
|
||||
"input_audio": {"data": "ZmFrZQ==", "format": "wav"},
|
||||
},
|
||||
{
|
||||
"type": "video_url",
|
||||
"video_url": {"url": "https://example.com/video.mp4"},
|
||||
},
|
||||
{
|
||||
"type": "input_file",
|
||||
"file_url": "https://example.com/doc.pdf",
|
||||
"filename": "doc.pdf",
|
||||
},
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
args, _ = build_responses_create_args(request)
|
||||
args_dict = cast(dict[str, Any], args)
|
||||
parts = args_dict["input"][0]["content"]
|
||||
assert parts == [
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": "Describe media",
|
||||
"cache_control": {"type": "ephemeral", "ttl": "5m"},
|
||||
},
|
||||
{
|
||||
"type": "input_image",
|
||||
"detail": "high",
|
||||
"image_url": "https://example.com/image.png",
|
||||
},
|
||||
{
|
||||
"type": "input_audio",
|
||||
"input_audio": {"data": "ZmFrZQ==", "format": "wav"},
|
||||
},
|
||||
{
|
||||
"type": "input_file",
|
||||
"file_url": "https://example.com/video.mp4",
|
||||
},
|
||||
{
|
||||
"type": "input_file",
|
||||
"file_url": "https://example.com/doc.pdf",
|
||||
"filename": "doc.pdf",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
|
@ -8,20 +7,16 @@ import yaml
|
|||
from pydantic import ValidationError
|
||||
|
||||
from app.config.loader import load_config
|
||||
from app.config.models import OAuthAuth, TokenAuth
|
||||
from app.config.models import OAuthAuth, TokenAuth, UrlAuth
|
||||
|
||||
|
||||
def _write_yaml(path: Path, data: dict) -> None:
|
||||
path.write_text(yaml.safe_dump(data), encoding="utf-8")
|
||||
|
||||
|
||||
def _write_json(path: Path, data: dict) -> None:
|
||||
path.write_text(json.dumps(data), encoding="utf-8")
|
||||
|
||||
|
||||
def test_load_config_success(tmp_path: Path) -> None:
|
||||
config_path = tmp_path / "config.yml"
|
||||
auth_path = tmp_path / "auth.json"
|
||||
auth_path = tmp_path / "auth.yml"
|
||||
|
||||
_write_yaml(
|
||||
config_path,
|
||||
|
|
@ -30,12 +25,14 @@ def test_load_config_success(tmp_path: Path) -> None:
|
|||
"kilocode": {
|
||||
"url": "https://api.kilo.ai/api/openrouter",
|
||||
"type": "openai-completions",
|
||||
"name": "Kilo Override",
|
||||
"models": {"minimax/minimax-m2.5:free": {"name": "MiniMax Custom"}},
|
||||
"whitelist": ["minimax/minimax-m2.5:free"],
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
_write_json(
|
||||
_write_yaml(
|
||||
auth_path,
|
||||
{"providers": {"kilocode": {"token": "public"}}},
|
||||
)
|
||||
|
|
@ -46,12 +43,14 @@ def test_load_config_success(tmp_path: Path) -> None:
|
|||
assert isinstance(provider.auth, TokenAuth)
|
||||
assert provider.type == "openai-completions"
|
||||
assert provider.auth.token == "public"
|
||||
assert provider.display_name == "Kilo Override"
|
||||
assert provider.models == {"minimax/minimax-m2.5:free": {"name": "MiniMax Custom"}}
|
||||
assert provider.whitelist == ["minimax/minimax-m2.5:free"]
|
||||
|
||||
|
||||
def test_load_config_requires_auth_entry(tmp_path: Path) -> None:
|
||||
config_path = tmp_path / "config.yml"
|
||||
auth_path = tmp_path / "auth.json"
|
||||
auth_path = tmp_path / "auth.yml"
|
||||
|
||||
_write_yaml(
|
||||
config_path,
|
||||
|
|
@ -61,7 +60,7 @@ def test_load_config_requires_auth_entry(tmp_path: Path) -> None:
|
|||
}
|
||||
},
|
||||
)
|
||||
_write_json(auth_path, {"providers": {}})
|
||||
_write_yaml(auth_path, {"providers": {}})
|
||||
|
||||
with pytest.raises(ValueError, match="Missing auth entry"):
|
||||
load_config(config_path=config_path, auth_path=auth_path)
|
||||
|
|
@ -69,7 +68,7 @@ def test_load_config_requires_auth_entry(tmp_path: Path) -> None:
|
|||
|
||||
def test_load_config_rejects_incomplete_oauth_auth(tmp_path: Path) -> None:
|
||||
config_path = tmp_path / "config.yml"
|
||||
auth_path = tmp_path / "auth.json"
|
||||
auth_path = tmp_path / "auth.yml"
|
||||
|
||||
_write_yaml(
|
||||
config_path,
|
||||
|
|
@ -79,7 +78,7 @@ def test_load_config_rejects_incomplete_oauth_auth(tmp_path: Path) -> None:
|
|||
}
|
||||
},
|
||||
)
|
||||
_write_json(auth_path, {"providers": {"kilocode": {"refresh": "rt_abc"}}})
|
||||
_write_yaml(auth_path, {"providers": {"kilocode": {"refresh": "rt_abc"}}})
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
load_config(config_path=config_path, auth_path=auth_path)
|
||||
|
|
@ -87,7 +86,7 @@ def test_load_config_rejects_incomplete_oauth_auth(tmp_path: Path) -> None:
|
|||
|
||||
def test_load_config_supports_oauth_auth(tmp_path: Path) -> None:
|
||||
config_path = tmp_path / "config.yml"
|
||||
auth_path = tmp_path / "auth.json"
|
||||
auth_path = tmp_path / "auth.yml"
|
||||
|
||||
_write_yaml(
|
||||
config_path,
|
||||
|
|
@ -100,7 +99,7 @@ def test_load_config_supports_oauth_auth(tmp_path: Path) -> None:
|
|||
}
|
||||
},
|
||||
)
|
||||
_write_json(
|
||||
_write_yaml(
|
||||
auth_path,
|
||||
{
|
||||
"providers": {
|
||||
|
|
@ -119,3 +118,35 @@ def test_load_config_supports_oauth_auth(tmp_path: Path) -> None:
|
|||
assert provider.auth.access == "acc"
|
||||
assert provider.auth.refresh == "ref"
|
||||
assert provider.auth.expires == 123
|
||||
|
||||
|
||||
def test_load_config_supports_url_auth(tmp_path: Path) -> None:
|
||||
config_path = tmp_path / "config.yml"
|
||||
auth_path = tmp_path / "auth.yml"
|
||||
|
||||
_write_yaml(
|
||||
config_path,
|
||||
{
|
||||
"providers": {
|
||||
"kilo": {
|
||||
"url": "https://api.kilo.ai/api/openrouter",
|
||||
"type": "openai-completions",
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
_write_yaml(
|
||||
auth_path,
|
||||
{
|
||||
"providers": {
|
||||
"kilo": {
|
||||
"url": "https://auth.local/token",
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
loaded = load_config(config_path=config_path, auth_path=auth_path)
|
||||
provider = loaded.providers["kilo"]
|
||||
assert isinstance(provider.auth, UrlAuth)
|
||||
assert provider.auth.url == "https://auth.local/token"
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from app.core.errors import (
|
|||
ModelNotAllowedError,
|
||||
ProviderNotFoundError,
|
||||
)
|
||||
from app.core.models_dev import ModelsDevCatalog
|
||||
from app.core.router import RouterCore, split_routed_model
|
||||
from app.core.types import (
|
||||
CoreChatRequest,
|
||||
|
|
@ -25,6 +26,7 @@ class StubProvider(BaseProvider):
|
|||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(api_type="stub", **kwargs)
|
||||
self.requests: list[ProviderChatRequest] = []
|
||||
self.list_models_calls = 0
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
|
|
@ -38,9 +40,90 @@ class StubProvider(BaseProvider):
|
|||
yield CoreChunk(finish_reason="stop")
|
||||
|
||||
async def list_models(self) -> list[ProviderModel]:
|
||||
self.list_models_calls += 1
|
||||
return [ProviderModel(id="model-a", name="Model A", context_length=123)]
|
||||
|
||||
|
||||
class StubModelsDevCatalog(ModelsDevCatalog):
|
||||
def __init__(self, payload: dict[str, dict[str, object]] | None = None) -> None:
|
||||
super().__init__(fetch_catalog=lambda: _never_called())
|
||||
self.payload = payload or {}
|
||||
self.calls: list[tuple[str, str]] = []
|
||||
self.provider_display_name = "Stub Provider"
|
||||
|
||||
async def get_provider_models(
|
||||
self, *, provider_name: str, provider_url: str
|
||||
) -> tuple[str | None, dict[str, dict[str, object]]]:
|
||||
self.calls.append((provider_name, provider_url))
|
||||
return self.provider_display_name, self.payload
|
||||
|
||||
|
||||
async def _never_called() -> dict[str, object]:
|
||||
raise RuntimeError("not expected")
|
||||
|
||||
|
||||
class MissingFieldsProvider(BaseProvider):
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(api_type="stub", **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
raise NotImplementedError
|
||||
|
||||
async def stream_chat(
|
||||
self, request: ProviderChatRequest
|
||||
) -> AsyncIterator[CoreChunk]:
|
||||
yield CoreChunk(finish_reason="stop")
|
||||
|
||||
async def list_models(self) -> list[ProviderModel]:
|
||||
return [ProviderModel(id="model-x")]
|
||||
|
||||
|
||||
class CompleteFieldsProvider(BaseProvider):
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(api_type="stub", **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
raise NotImplementedError
|
||||
|
||||
async def stream_chat(
|
||||
self, request: ProviderChatRequest
|
||||
) -> AsyncIterator[CoreChunk]:
|
||||
yield CoreChunk(finish_reason="stop")
|
||||
|
||||
async def list_models(self) -> list[ProviderModel]:
|
||||
return [
|
||||
ProviderModel(
|
||||
id="model-a",
|
||||
name="Model A",
|
||||
created=1,
|
||||
context_length=4096,
|
||||
architecture={"input_modalities": ["text"]},
|
||||
pricing={"input": 1.0, "output": 2.0},
|
||||
supported_parameters=["max_tokens"],
|
||||
owned_by="kilo",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
class FailingModelsProvider(BaseProvider):
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(api_type="stub", **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
raise NotImplementedError
|
||||
|
||||
async def stream_chat(
|
||||
self, request: ProviderChatRequest
|
||||
) -> AsyncIterator[CoreChunk]:
|
||||
yield CoreChunk(finish_reason="stop")
|
||||
|
||||
async def list_models(self) -> list[ProviderModel]:
|
||||
raise RuntimeError("upstream unavailable")
|
||||
|
||||
|
||||
def _collect(async_iter: AsyncIterator[CoreChunk]) -> list[CoreChunk]:
|
||||
async def _inner() -> list[CoreChunk]:
|
||||
out: list[CoreChunk] = []
|
||||
|
|
@ -88,14 +171,12 @@ def test_stream_chat_routes_to_provider_model_without_prefix() -> None:
|
|||
req = CoreChatRequest(
|
||||
model="kilo/minimax/minimax-m2.5:free",
|
||||
messages=[CoreMessage(role="user", content="ping")],
|
||||
temperature=0.2,
|
||||
)
|
||||
chunks = _collect(core.stream_chat(req))
|
||||
|
||||
assert [c.content for c in chunks if c.content] == ["hello"]
|
||||
assert chunks[-1].finish_reason == "stop"
|
||||
assert provider.requests[0].model == "minimax/minimax-m2.5:free"
|
||||
assert provider.requests[0].temperature == 0.2
|
||||
|
||||
|
||||
def test_list_models_prefixes_provider_and_applies_defaults() -> None:
|
||||
|
|
@ -121,3 +202,129 @@ def test_list_models_respects_whitelist() -> None:
|
|||
|
||||
models = asyncio.run(core.list_models())
|
||||
assert models == []
|
||||
|
||||
|
||||
def test_list_models_uses_ttl_cache() -> None:
|
||||
provider = StubProvider(
|
||||
name="kilo", base_url="https://kilo", whitelist=None, blacklist=None
|
||||
)
|
||||
core = RouterCore(providers={"kilo": provider})
|
||||
|
||||
first = asyncio.run(core.list_models())
|
||||
second = asyncio.run(core.list_models())
|
||||
|
||||
assert len(first) == 1
|
||||
assert len(second) == 1
|
||||
assert provider.list_models_calls == 1
|
||||
|
||||
|
||||
def test_list_models_cache_can_be_disabled_with_zero_ttl() -> None:
|
||||
provider = StubProvider(
|
||||
name="kilo", base_url="https://kilo", whitelist=None, blacklist=None
|
||||
)
|
||||
core = RouterCore(providers={"kilo": provider}, models_cache_ttl_seconds=0.0)
|
||||
|
||||
asyncio.run(core.list_models())
|
||||
asyncio.run(core.list_models())
|
||||
|
||||
assert provider.list_models_calls == 2
|
||||
|
||||
|
||||
def test_list_models_enriches_missing_fields_from_models_dev() -> None:
|
||||
provider = MissingFieldsProvider(
|
||||
name="zai",
|
||||
base_url="https://api.z.ai/api/coding/paas/v4",
|
||||
whitelist=None,
|
||||
blacklist=None,
|
||||
)
|
||||
catalog = StubModelsDevCatalog(
|
||||
{
|
||||
"model-x": {
|
||||
"name": "Model X",
|
||||
"limit": {"context": 65536, "output": 4096},
|
||||
"modalities": {"input": ["text"], "output": ["text"]},
|
||||
"family": "glm",
|
||||
"cost": {"input": 1.0, "output": 2.0},
|
||||
"tool_call": True,
|
||||
"reasoning": True,
|
||||
"structured_output": True,
|
||||
"release_date": "2026-01-15",
|
||||
"provider": "z-ai",
|
||||
}
|
||||
}
|
||||
)
|
||||
core = RouterCore(providers={"zai": provider}, models_dev_catalog=catalog)
|
||||
|
||||
models = asyncio.run(core.list_models())
|
||||
|
||||
assert len(models) == 1
|
||||
assert models[0].id == "zai/model-x"
|
||||
assert models[0].name == "Model X"
|
||||
assert models[0].context_length == 65536
|
||||
assert models[0].architecture == {
|
||||
"input_modalities": ["text"],
|
||||
"output_modalities": ["text"],
|
||||
"family": "glm",
|
||||
}
|
||||
assert models[0].pricing == {"input": 1.0, "output": 2.0}
|
||||
assert models[0].owned_by == "z-ai"
|
||||
assert models[0].created > 0
|
||||
assert models[0].supported_parameters == [
|
||||
"max_completion_tokens",
|
||||
"max_tokens",
|
||||
"modalities",
|
||||
"parallel_tool_calls",
|
||||
"reasoning_effort",
|
||||
"reasoning_summary",
|
||||
"response_format",
|
||||
"tool_choice",
|
||||
"tools",
|
||||
]
|
||||
assert models[0].provider_display_name == "Stub Provider"
|
||||
assert catalog.calls == [("zai", "https://api.z.ai/api/coding/paas/v4")]
|
||||
|
||||
|
||||
def test_list_models_always_calls_models_dev_and_prefers_it() -> None:
|
||||
provider = CompleteFieldsProvider(
|
||||
name="kilo", base_url="https://kilo", whitelist=None, blacklist=None
|
||||
)
|
||||
catalog = StubModelsDevCatalog({"model-a": {"name": "Never Used"}})
|
||||
core = RouterCore(providers={"kilo": provider}, models_dev_catalog=catalog)
|
||||
|
||||
models = asyncio.run(core.list_models())
|
||||
|
||||
assert len(models) == 1
|
||||
assert models[0].name == "Never Used"
|
||||
assert catalog.calls == [("kilo", "https://kilo")]
|
||||
|
||||
|
||||
def test_list_models_skips_failed_provider_and_returns_others() -> None:
|
||||
good = StubProvider(
|
||||
name="good", base_url="https://good", whitelist=None, blacklist=None
|
||||
)
|
||||
bad = FailingModelsProvider(
|
||||
name="bad", base_url="https://bad", whitelist=None, blacklist=None
|
||||
)
|
||||
core = RouterCore(providers={"good": good, "bad": bad})
|
||||
|
||||
models = asyncio.run(core.list_models())
|
||||
|
||||
assert len(models) == 1
|
||||
assert models[0].id == "good/model-a"
|
||||
|
||||
|
||||
def test_list_models_respects_provider_and_model_name_overrides() -> None:
|
||||
provider = StubProvider(
|
||||
name="kilo", base_url="https://kilo", whitelist=None, blacklist=None
|
||||
)
|
||||
provider.display_name = "Configured Provider"
|
||||
provider.models_config = {"model-a": {"name": "Configured Model"}}
|
||||
|
||||
catalog = StubModelsDevCatalog({"model-a": {"name": "ModelsDev Model"}})
|
||||
core = RouterCore(providers={"kilo": provider}, models_dev_catalog=catalog)
|
||||
|
||||
models = asyncio.run(core.list_models())
|
||||
|
||||
assert len(models) == 1
|
||||
assert models[0].name == "Configured Model"
|
||||
assert models[0].provider_display_name == "Configured Provider"
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from typing import Any, cast
|
|||
|
||||
import pytest
|
||||
|
||||
from app.config.models import LoadedProviderConfig, OAuthAuth, TokenAuth
|
||||
from app.config.models import LoadedProviderConfig, OAuthAuth, TokenAuth, UrlAuth
|
||||
from app.core.errors import UpstreamProviderError
|
||||
from app.core.types import CoreMessage, ProviderChatRequest
|
||||
from app.providers.codex_responses.oauth import OAuthData
|
||||
|
|
@ -25,6 +25,8 @@ class _FakeEvent:
|
|||
refusal: str | None = None,
|
||||
item_id: str | None = None,
|
||||
content_index: int | None = None,
|
||||
output_index: int | None = None,
|
||||
summary_index: int | None = None,
|
||||
item: dict[str, Any] | None = None,
|
||||
part: dict[str, Any] | None = None,
|
||||
response: dict[str, Any] | None = None,
|
||||
|
|
@ -37,6 +39,8 @@ class _FakeEvent:
|
|||
self.refusal = refusal
|
||||
self.item_id = item_id
|
||||
self.content_index = content_index
|
||||
self.output_index = output_index
|
||||
self.summary_index = summary_index
|
||||
self.item = item
|
||||
self.part = part
|
||||
self.response = response
|
||||
|
|
@ -101,6 +105,14 @@ class _FakeOauth:
|
|||
async def get(self) -> OAuthData:
|
||||
return OAuthData(token="acc", headers={"Authorization": "Bearer acc"})
|
||||
|
||||
async def get_headers(self) -> dict[str, str]:
|
||||
return {"Authorization": "Bearer acc"}
|
||||
|
||||
|
||||
class _StaticUrlAuth:
|
||||
async def get_headers(self) -> dict[str, str]:
|
||||
return {"Authorization": "Bearer fetched-token"}
|
||||
|
||||
|
||||
class _FailingResponses:
|
||||
async def create(self, **kwargs):
|
||||
|
|
@ -165,7 +177,10 @@ def test_codex_provider_streams_openai_responses_events() -> None:
|
|||
assert chunks[2].content == " world"
|
||||
assert chunks[-1].finish_reason == "stop"
|
||||
|
||||
assert client.responses.last_headers == {"Authorization": "Bearer acc"}
|
||||
assert client.responses.last_headers is not None
|
||||
assert client.responses.last_headers["Authorization"] == "Bearer acc"
|
||||
assert client.responses.last_headers["originator"] == "opencode"
|
||||
assert client.responses.last_headers["session_id"].startswith("ses_")
|
||||
payload = client.responses.last_payload
|
||||
assert payload is not None
|
||||
assert payload["model"] == "gpt-5-codex"
|
||||
|
|
@ -220,7 +235,7 @@ def test_codex_provider_requires_oauth_in_from_config() -> None:
|
|||
type="codex-responses",
|
||||
auth=TokenAuth(token="bad"),
|
||||
)
|
||||
with pytest.raises(ValueError, match="requires oauth auth"):
|
||||
with pytest.raises(ValueError, match="requires oauth or url auth"):
|
||||
CodexResponsesProvider.from_config(config)
|
||||
|
||||
|
||||
|
|
@ -266,6 +281,37 @@ def test_codex_provider_from_config_success() -> None:
|
|||
assert isinstance(provider, CodexResponsesProvider)
|
||||
|
||||
|
||||
def test_codex_provider_from_config_supports_url_auth() -> None:
|
||||
config = LoadedProviderConfig(
|
||||
name="codex",
|
||||
url="https://chatgpt.com/backend-api",
|
||||
type="codex-responses",
|
||||
auth=UrlAuth(url="https://auth.local/token"),
|
||||
)
|
||||
|
||||
provider = CodexResponsesProvider.from_config(config)
|
||||
assert isinstance(provider, CodexResponsesProvider)
|
||||
|
||||
|
||||
def test_codex_provider_uses_url_auth_headers() -> None:
|
||||
client = _FakeClient()
|
||||
provider = CodexResponsesProvider(
|
||||
name="codex",
|
||||
base_url="https://chatgpt.com/backend-api",
|
||||
oauth=cast(Any, _StaticUrlAuth()),
|
||||
client=client,
|
||||
)
|
||||
req = ProviderChatRequest(
|
||||
model="gpt-5-codex",
|
||||
messages=[CoreMessage(role="user", content="hello")],
|
||||
)
|
||||
|
||||
_collect(provider.stream_chat(req))
|
||||
|
||||
assert client.responses.last_headers is not None
|
||||
assert client.responses.last_headers["Authorization"] == "Bearer fetched-token"
|
||||
|
||||
|
||||
def test_codex_provider_emits_text_from_output_item_done_without_deltas() -> None:
|
||||
provider = CodexResponsesProvider(
|
||||
name="codex",
|
||||
|
|
@ -432,3 +478,64 @@ def test_codex_provider_streams_reasoning_summary_delta() -> None:
|
|||
assert chunks[0].role == "assistant"
|
||||
assert chunks[1].reasoning_content == "thinking..."
|
||||
assert chunks[-1].finish_reason == "stop"
|
||||
|
||||
|
||||
def test_codex_provider_streams_reasoning_details_with_encrypted_content() -> None:
|
||||
provider = CodexResponsesProvider(
|
||||
name="codex",
|
||||
base_url="https://chatgpt.com/backend-api",
|
||||
oauth=cast(Any, _FakeOauth()),
|
||||
client=_CustomClient(
|
||||
[
|
||||
_FakeEvent(
|
||||
type="response.output_item.added",
|
||||
output_index=0,
|
||||
item={
|
||||
"id": "rs_1",
|
||||
"type": "reasoning",
|
||||
"encrypted_content": "enc_123",
|
||||
},
|
||||
),
|
||||
_FakeEvent(
|
||||
type="response.reasoning_summary_text.done",
|
||||
item_id="rs_1",
|
||||
summary_index=0,
|
||||
text="**Plan**",
|
||||
),
|
||||
_FakeEvent(type="response.completed"),
|
||||
]
|
||||
),
|
||||
)
|
||||
req = ProviderChatRequest(
|
||||
model="gpt-5-codex",
|
||||
messages=[
|
||||
CoreMessage(role="developer", content="Be concise."),
|
||||
CoreMessage(role="user", content="Think first"),
|
||||
],
|
||||
)
|
||||
|
||||
chunks = _collect(provider.stream_chat(req))
|
||||
|
||||
details = [
|
||||
detail
|
||||
for chunk in chunks
|
||||
for detail in (chunk.reasoning_details or [])
|
||||
if isinstance(detail, dict)
|
||||
]
|
||||
assert {d.get("type") for d in details} == {
|
||||
"reasoning.encrypted",
|
||||
"reasoning.summary",
|
||||
}
|
||||
encrypted = next(d for d in details if d.get("type") == "reasoning.encrypted")
|
||||
assert encrypted["data"] == "enc_123"
|
||||
assert encrypted["id"] == "rs_1"
|
||||
assert encrypted["format"] == "openai-responses-v1"
|
||||
|
||||
summary = next(d for d in details if d.get("type") == "reasoning.summary")
|
||||
assert summary["summary"] == "**Plan**"
|
||||
assert summary["id"] == "rs_1"
|
||||
assert summary["format"] == "openai-responses-v1"
|
||||
|
||||
reasoning_chunks = [c.reasoning_content for c in chunks if c.reasoning_content]
|
||||
assert reasoning_chunks == ["**Plan**"]
|
||||
assert chunks[-1].finish_reason == "stop"
|
||||
|
|
|
|||
|
|
@ -5,15 +5,26 @@ from collections.abc import AsyncIterator
|
|||
|
||||
import pytest
|
||||
|
||||
from app.config.models import LoadedProviderConfig, OAuthAuth, UrlAuth
|
||||
from app.core.errors import UpstreamProviderError
|
||||
from app.core.types import CoreMessage, ProviderChatRequest
|
||||
from app.providers.openai_completions.provider import OpenAICompletionsProvider
|
||||
|
||||
|
||||
class _FakeDelta:
|
||||
def __init__(self, role: str | None = None, content: str | None = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
role: str | None = None,
|
||||
content: str | None = None,
|
||||
reasoning_content: str | None = None,
|
||||
reasoning: object | None = None,
|
||||
tool_calls: list[dict[str, object]] | None = None,
|
||||
) -> None:
|
||||
self.role = role
|
||||
self.content = content
|
||||
self.reasoning_content = reasoning_content
|
||||
self.reasoning = reasoning
|
||||
self.tool_calls = tool_calls
|
||||
|
||||
|
||||
class _FakeChoice:
|
||||
|
|
@ -61,19 +72,73 @@ class _FakeCompletions:
|
|||
)
|
||||
|
||||
|
||||
class _ReasoningCompletions:
|
||||
async def create(self, **kwargs):
|
||||
return _FakeStream(
|
||||
[
|
||||
_FakeChunk([_FakeChoice(index=0, delta=_FakeDelta(role="assistant"))]),
|
||||
_FakeChunk(
|
||||
[
|
||||
_FakeChoice(
|
||||
index=0,
|
||||
delta=_FakeDelta(reasoning={"summary": "thinking"}),
|
||||
)
|
||||
]
|
||||
),
|
||||
_FakeChunk(
|
||||
[
|
||||
_FakeChoice(
|
||||
index=0,
|
||||
delta=_FakeDelta(
|
||||
tool_calls=[
|
||||
{
|
||||
"index": 0,
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "question",
|
||||
"arguments": "{",
|
||||
},
|
||||
}
|
||||
]
|
||||
),
|
||||
)
|
||||
]
|
||||
),
|
||||
_FakeChunk([_FakeChoice(index=0, finish_reason="tool_calls")]),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class _FakeChat:
|
||||
def __init__(self) -> None:
|
||||
self.completions = _FakeCompletions()
|
||||
|
||||
|
||||
class _ReasoningChat:
|
||||
def __init__(self) -> None:
|
||||
self.completions = _ReasoningCompletions()
|
||||
|
||||
|
||||
class _FakeClient:
|
||||
def __init__(self) -> None:
|
||||
self.chat = _FakeChat()
|
||||
self.models = _FakeModels()
|
||||
|
||||
|
||||
class _ReasoningClient:
|
||||
def __init__(self) -> None:
|
||||
self.chat = _ReasoningChat()
|
||||
self.models = _FakeModels()
|
||||
|
||||
|
||||
class _FakeModels:
|
||||
async def list(self):
|
||||
def __init__(self) -> None:
|
||||
self.last_kwargs: dict[str, object] | None = None
|
||||
|
||||
async def list(self, **kwargs):
|
||||
self.last_kwargs = kwargs
|
||||
|
||||
class _Resp:
|
||||
data = [
|
||||
{
|
||||
|
|
@ -112,10 +177,15 @@ class _FailingClient:
|
|||
|
||||
|
||||
class _FailingModels:
|
||||
async def list(self):
|
||||
async def list(self, **kwargs):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
|
||||
class _StaticAuth:
|
||||
async def get_headers(self) -> dict[str, str]:
|
||||
return {"Authorization": "Bearer fetched-token"}
|
||||
|
||||
|
||||
def _collect(async_iter) -> list:
|
||||
async def _inner() -> list:
|
||||
out = []
|
||||
|
|
@ -137,7 +207,6 @@ def test_openai_completions_provider_streams_internal_chunks() -> None:
|
|||
req = ProviderChatRequest(
|
||||
model="minimax/minimax-m2.5:free",
|
||||
messages=[CoreMessage(role="user", content="hello")],
|
||||
temperature=0.1,
|
||||
top_p=0.9,
|
||||
max_tokens=123,
|
||||
)
|
||||
|
|
@ -152,10 +221,70 @@ def test_openai_completions_provider_streams_internal_chunks() -> None:
|
|||
assert payload is not None
|
||||
assert payload["model"] == "minimax/minimax-m2.5:free"
|
||||
assert payload["stream"] is True
|
||||
assert payload["temperature"] == 0.1
|
||||
assert payload["top_p"] == 0.9
|
||||
assert payload["max_tokens"] == 123
|
||||
assert payload["messages"] == [{"role": "user", "content": "hello"}]
|
||||
assert payload["extra_headers"] == {"Authorization": "Bearer public"}
|
||||
|
||||
|
||||
def test_openai_completions_provider_streams_reasoning_details() -> None:
|
||||
class _ReasoningDetailsChat:
|
||||
def __init__(self) -> None:
|
||||
self.completions = _ReasoningDetailsCompletions()
|
||||
|
||||
class _ReasoningDetailsCompletions:
|
||||
async def create(self, **kwargs):
|
||||
return _FakeStream(
|
||||
[
|
||||
_FakeChunk(
|
||||
[
|
||||
_FakeChoice(
|
||||
index=0,
|
||||
delta=_FakeDelta(
|
||||
reasoning={
|
||||
"details": [
|
||||
{
|
||||
"type": "reasoning.encrypted",
|
||||
"data": "enc_123",
|
||||
"format": "openai-responses-v1",
|
||||
}
|
||||
]
|
||||
}
|
||||
),
|
||||
)
|
||||
]
|
||||
),
|
||||
_FakeChunk([_FakeChoice(index=0, finish_reason="stop")]),
|
||||
]
|
||||
)
|
||||
|
||||
class _ReasoningDetailsClient:
|
||||
def __init__(self) -> None:
|
||||
self.chat = _ReasoningDetailsChat()
|
||||
self.models = _FakeModels()
|
||||
|
||||
provider = OpenAICompletionsProvider(
|
||||
name="kilo",
|
||||
base_url="https://api.kilo.ai/api/openrouter",
|
||||
token="public",
|
||||
client=_ReasoningDetailsClient(),
|
||||
)
|
||||
req = ProviderChatRequest(
|
||||
model="minimax/minimax-m2.5:free",
|
||||
messages=[CoreMessage(role="user", content="hello")],
|
||||
)
|
||||
|
||||
chunks = _collect(provider.stream_chat(req))
|
||||
|
||||
assert chunks[0].role == "assistant"
|
||||
assert chunks[1].reasoning_details == [
|
||||
{
|
||||
"type": "reasoning.encrypted",
|
||||
"data": "enc_123",
|
||||
"format": "openai-responses-v1",
|
||||
}
|
||||
]
|
||||
assert chunks[2].finish_reason == "stop"
|
||||
|
||||
|
||||
def test_openai_completions_provider_wraps_upstream_error() -> None:
|
||||
|
|
@ -174,12 +303,40 @@ def test_openai_completions_provider_wraps_upstream_error() -> None:
|
|||
_collect(provider.stream_chat(req))
|
||||
|
||||
|
||||
def test_openai_completions_provider_lists_models() -> None:
|
||||
def test_openai_completions_provider_streams_reasoning_and_tool_calls() -> None:
|
||||
provider = OpenAICompletionsProvider(
|
||||
name="kilo",
|
||||
base_url="https://api.kilo.ai/api/openrouter",
|
||||
token="public",
|
||||
client=_FakeClient(),
|
||||
client=_ReasoningClient(),
|
||||
)
|
||||
req = ProviderChatRequest(
|
||||
model="minimax/minimax-m2.5:free",
|
||||
messages=[CoreMessage(role="user", content="hello")],
|
||||
)
|
||||
|
||||
chunks = _collect(provider.stream_chat(req))
|
||||
|
||||
assert chunks[0].role == "assistant"
|
||||
assert chunks[1].reasoning_content == "thinking"
|
||||
assert chunks[2].tool_calls == [
|
||||
{
|
||||
"index": 0,
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {"name": "question", "arguments": "{"},
|
||||
}
|
||||
]
|
||||
assert chunks[3].finish_reason == "tool_calls"
|
||||
|
||||
|
||||
def test_openai_completions_provider_lists_models() -> None:
|
||||
client = _FakeClient()
|
||||
provider = OpenAICompletionsProvider(
|
||||
name="kilo",
|
||||
base_url="https://api.kilo.ai/api/openrouter",
|
||||
token="public",
|
||||
client=client,
|
||||
)
|
||||
|
||||
models = asyncio.run(provider.list_models())
|
||||
|
|
@ -187,6 +344,9 @@ def test_openai_completions_provider_lists_models() -> None:
|
|||
assert models[0].id == "minimax/minimax-m2.5:free"
|
||||
assert models[0].context_length == 2048
|
||||
assert models[0].architecture == {"input_modalities": ["text"]}
|
||||
assert client.models.last_kwargs == {
|
||||
"extra_headers": {"Authorization": "Bearer public"}
|
||||
}
|
||||
|
||||
|
||||
def test_openai_completions_provider_wraps_model_list_errors() -> None:
|
||||
|
|
@ -199,3 +359,84 @@ def test_openai_completions_provider_wraps_model_list_errors() -> None:
|
|||
|
||||
with pytest.raises(UpstreamProviderError, match="failed while listing models"):
|
||||
asyncio.run(provider.list_models())
|
||||
|
||||
|
||||
def test_openai_completions_provider_from_config_supports_url_auth() -> None:
|
||||
config = LoadedProviderConfig(
|
||||
name="kilo",
|
||||
url="https://api.kilo.ai/api/openrouter",
|
||||
type="openai-completions",
|
||||
auth=UrlAuth(url="https://auth.local/token"),
|
||||
)
|
||||
|
||||
provider = OpenAICompletionsProvider.from_config(config)
|
||||
assert isinstance(provider, OpenAICompletionsProvider)
|
||||
|
||||
|
||||
def test_openai_completions_provider_from_config_rejects_oauth() -> None:
|
||||
config = LoadedProviderConfig(
|
||||
name="kilo",
|
||||
url="https://api.kilo.ai/api/openrouter",
|
||||
type="openai-completions",
|
||||
auth=OAuthAuth(access="acc", refresh="ref", expires=1),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="requires token or url auth"):
|
||||
OpenAICompletionsProvider.from_config(config)
|
||||
|
||||
|
||||
def test_openai_completions_provider_uses_custom_auth_provider() -> None:
|
||||
client = _FakeClient()
|
||||
provider = OpenAICompletionsProvider(
|
||||
name="kilo",
|
||||
base_url="https://api.kilo.ai/api/openrouter",
|
||||
auth_provider=_StaticAuth(),
|
||||
client=client,
|
||||
)
|
||||
req = ProviderChatRequest(
|
||||
model="minimax/minimax-m2.5:free",
|
||||
messages=[CoreMessage(role="user", content="hello")],
|
||||
)
|
||||
|
||||
_collect(provider.stream_chat(req))
|
||||
|
||||
payload = client.chat.completions.payload
|
||||
assert payload is not None
|
||||
assert payload["extra_headers"] == {"Authorization": "Bearer fetched-token"}
|
||||
|
||||
|
||||
def test_openai_completions_provider_passes_schema_fields_via_extra_body() -> None:
|
||||
client = _FakeClient()
|
||||
provider = OpenAICompletionsProvider(
|
||||
name="kilo",
|
||||
base_url="https://api.kilo.ai/api/openrouter",
|
||||
token="public",
|
||||
client=client,
|
||||
)
|
||||
req = ProviderChatRequest(
|
||||
model="minimax/minimax-m2.5:free",
|
||||
messages=[CoreMessage(role="user", content="hello")],
|
||||
reasoning={"effort": "high", "summary": "detailed"},
|
||||
provider={"allow_fallbacks": False},
|
||||
plugins=[{"id": "web", "enabled": True}],
|
||||
session_id="ses_123",
|
||||
trace={"trace_id": "tr_1"},
|
||||
models=["openai/gpt-5"],
|
||||
debug={"echo_upstream_body": True},
|
||||
image_config={"size": "1024x1024"},
|
||||
)
|
||||
|
||||
_collect(provider.stream_chat(req))
|
||||
|
||||
payload = client.chat.completions.payload
|
||||
assert payload is not None
|
||||
assert payload["extra_body"] == {
|
||||
"reasoning": {"effort": "high", "summary": "detailed"},
|
||||
"provider": {"allow_fallbacks": False},
|
||||
"plugins": [{"id": "web", "enabled": True}],
|
||||
"session_id": "ses_123",
|
||||
"trace": {"trace_id": "tr_1"},
|
||||
"models": ["openai/gpt-5"],
|
||||
"debug": {"echo_upstream_body": True},
|
||||
"image_config": {"size": "1024x1024"},
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue