
feat(router): implement codex streaming, yaml auth, and opencode plugin

Arthur K. 2026-03-01 21:00:19 +03:00
parent 28708500a5
commit 5f6ed46a9c
Signed by: wzray
GPG key ID: B97F30FDC4636357
33 changed files with 13223 additions and 615 deletions

View file

@@ -16,6 +16,6 @@ COPY app ./app
VOLUME ["/data"]
EXPOSE 8000
EXPOSE 80
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "80"]

View file

@@ -18,6 +18,18 @@ from app.core.types import CoreChatRequest, CoreChunk, CoreMessage, CoreModel
def to_core_chat_request(payload: ChatCompletionsRequest) -> CoreChatRequest:
extras = dict(payload.model_extra or {})
reasoning_effort = payload.reasoning_effort
reasoning_summary = payload.reasoning_summary
if isinstance(payload.reasoning, dict):
if reasoning_effort is None and isinstance(
payload.reasoning.get("effort"), str
):
reasoning_effort = payload.reasoning["effort"]
if reasoning_summary is None and isinstance(
payload.reasoning.get("summary"), str
):
reasoning_summary = payload.reasoning["summary"]
return CoreChatRequest(
model=payload.model,
stream=payload.stream,
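The fallback added above means an explicit top-level reasoning_effort (or reasoning_summary) always wins, and the nested reasoning.effort / reasoning.summary values are only consulted when the flat fields are absent. A minimal sketch of that precedence, using a plain dict as a hypothetical stand-in for the validated payload:

# Hypothetical payload dicts, not the real ChatCompletionsRequest model.
def resolve_effort(payload: dict) -> str | None:
    effort = payload.get("reasoning_effort")
    reasoning = payload.get("reasoning")
    if effort is None and isinstance(reasoning, dict) and isinstance(reasoning.get("effort"), str):
        effort = reasoning["effort"]
    return effort

assert resolve_effort({"reasoning": {"effort": "high"}}) == "high"  # nested fallback
assert resolve_effort({"reasoning_effort": "low", "reasoning": {"effort": "high"}}) == "low"  # flat field wins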
@@ -43,15 +55,21 @@ def to_core_chat_request(payload: ChatCompletionsRequest) -> CoreChatRequest:
max_completion_tokens=payload.max_completion_tokens,
max_tokens=payload.max_tokens,
metadata=payload.metadata,
provider=payload.provider,
plugins=payload.plugins,
session_id=payload.session_id,
trace=payload.trace,
modalities=list(payload.modalities) if payload.modalities is not None else None,
models=payload.models,
n=payload.n,
parallel_tool_calls=payload.parallel_tool_calls,
prediction=payload.prediction,
presence_penalty=payload.presence_penalty,
prompt_cache_key=payload.prompt_cache_key,
prompt_cache_retention=payload.prompt_cache_retention,
reasoning_effort=payload.reasoning_effort,
reasoning_summary=payload.reasoning_summary,
reasoning_effort=reasoning_effort,
reasoning_summary=reasoning_summary,
reasoning=payload.reasoning,
response_format=payload.response_format,
safety_identifier=payload.safety_identifier,
seed=payload.seed,
@@ -60,6 +78,8 @@ def to_core_chat_request(payload: ChatCompletionsRequest) -> CoreChatRequest:
store=payload.store,
stream_options=payload.stream_options,
temperature=payload.temperature,
debug=payload.debug,
image_config=payload.image_config,
tool_choice=payload.tool_choice,
tools=payload.tools,
top_logprobs=payload.top_logprobs,
@@ -89,6 +109,7 @@ def to_api_chunk(
role=chunk.role,
content=chunk.content,
reasoning_content=chunk.reasoning_content,
reasoning_details=chunk.reasoning_details,
tool_calls=chunk.tool_calls,
),
finish_reason=chunk.finish_reason,
@@ -106,6 +127,7 @@ def to_chat_completion_response(
) -> ChatCompletionResponse:
text_parts: list[str] = []
reasoning_parts: list[str] = []
reasoning_details_parts: list[dict[str, object]] = []
tool_calls_parts: list[dict[str, object]] = []
finish_reason: str | None = None
for chunk in chunks:
@@ -113,6 +135,8 @@
text_parts.append(chunk.content)
if chunk.reasoning_content:
reasoning_parts.append(chunk.reasoning_content)
if chunk.reasoning_details:
reasoning_details_parts.extend(chunk.reasoning_details)
if chunk.tool_calls:
tool_calls_parts.extend(chunk.tool_calls)
if chunk.finish_reason is not None:
@@ -127,7 +151,9 @@
index=0,
message=ChatCompletionMessage(
content="".join(text_parts),
reasoning="".join(reasoning_parts) or None,
reasoning_content="".join(reasoning_parts) or None,
reasoning_details=reasoning_details_parts or None,
tool_calls=_merge_tool_call_deltas(tool_calls_parts) or None,
),
finish_reason=finish_reason,
@@ -180,7 +206,7 @@ def to_models_response(models: list[CoreModel]) -> ModelsResponse:
id=model.id,
created=model.created,
owned_by=model.owned_by,
name=model.name,
name=_format_model_name(model),
description=model.description,
context_length=model.context_length,
architecture=model.architecture,
@@ -192,3 +218,17 @@ def to_models_response(models: list[CoreModel]) -> ModelsResponse:
for model in models
]
)
def _format_model_name(model: CoreModel) -> str | None:
if model.name is None:
return None
provider_label = model.provider_display_name
if provider_label is None:
provider_name, _, _ = model.id.partition("/")
provider_label = provider_name or None
if provider_label is None:
return model.name
return f"{provider_label}: {model.name}"
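_format_model_name prefers the provider's models.dev display name and falls back to the provider segment of the routed id, so clients see labels like "OpenAI: GPT-5" instead of a bare model name. A rough sketch of the three branches, with a hypothetical stand-in for CoreModel:

from dataclasses import dataclass

@dataclass
class FakeModel:  # hypothetical stand-in for CoreModel, illustration only
    id: str
    name: str | None = None
    provider_display_name: str | None = None

def format_name(model: FakeModel) -> str | None:
    if model.name is None:
        return None
    label = model.provider_display_name
    if label is None:
        provider, _, _ = model.id.partition("/")
        label = provider or None
    return model.name if label is None else f"{label}: {model.name}"

assert format_name(FakeModel("codex/gpt-5", "GPT-5", "OpenAI")) == "OpenAI: GPT-5"
assert format_name(FakeModel("codex/gpt-5", "GPT-5")) == "codex: GPT-5"  # id prefix fallback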

View file

@@ -6,9 +6,11 @@ import json
import logging
import os
import time
from collections.abc import AsyncIterator
from typing import Any
from fastapi import FastAPI, Request
from starlette.responses import StreamingResponse
logger = logging.getLogger("ai.http")
if not logger.handlers:
@@ -26,15 +28,19 @@ logger.propagate = False
def install_request_logging(app: FastAPI) -> None:
enabled = os.getenv("AI_REQUEST_LOG_ENABLED", "false").strip().lower()
if enabled not in {"1", "true", "yes", "on"}:
return
detailed_logging_enabled = enabled in {"1", "true", "yes", "on"}
max_len_raw = os.getenv("AI_REQUEST_LOG_MAX_BODY_CHARS", "20000").strip()
try:
max_body_chars = max(1024, int(max_len_raw))
except ValueError:
max_body_chars = 20000
@app.middleware("http")
async def request_logging_middleware(request: Request, call_next):
started_at = time.perf_counter()
body_text = ""
if request.method in {"POST", "PUT", "PATCH"}:
if detailed_logging_enabled and request.method in {"POST", "PUT", "PATCH"}:
body_bytes = await request.body()
body_text = _format_body(body_bytes)
@@ -53,6 +59,23 @@ def install_request_logging(app: FastAPI) -> None:
raise
elapsed_ms = (time.perf_counter() - started_at) * 1000
if response.status_code >= 400:
response_body = await _read_response_body(response)
logger.warning(
"http request error method=%s path=%s status=%s duration_ms=%.2f query=%s body=%s response_body=%s",
request.method,
request.url.path,
response.status_code,
elapsed_ms,
request.url.query or "-",
_truncate(body_text or "-", max_body_chars),
_truncate(response_body or "-", max_body_chars),
)
return response
if not detailed_logging_enabled:
return response
logger.info(
"http request method=%s path=%s status=%s duration_ms=%.2f query=%s body=%s",
request.method,
@@ -60,7 +83,7 @@ def install_request_logging(app: FastAPI) -> None:
response.status_code,
elapsed_ms,
request.url.query or "-",
body_text or "-",
_truncate(body_text or "-", max_body_chars),
)
return response
@@ -75,3 +98,37 @@ def _format_body(body: bytes) -> str:
except Exception:
text = body.decode("utf-8", errors="replace")
return text
def _truncate(value: str, limit: int) -> str:
if len(value) <= limit:
return value
return f"{value[:limit]}...[truncated {len(value) - limit} chars]"
async def _read_response_body(response: Any) -> str:
body = getattr(response, "body", None)
if isinstance(body, (bytes, bytearray)):
return _format_body(bytes(body))
if isinstance(response, StreamingResponse):
return "<streaming-response>"
iterator = getattr(response, "body_iterator", None)
if iterator is None:
return ""
chunks: list[bytes] = []
async for chunk in iterator:
if isinstance(chunk, bytes):
chunks.append(chunk)
else:
chunks.append(str(chunk).encode("utf-8", errors="replace"))
raw = b"".join(chunks)
response.body_iterator = _iterate_once(raw)
return _format_body(raw)
async def _iterate_once(payload: bytes) -> AsyncIterator[bytes]:
yield payload
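Draining body_iterator consumes a one-shot async iterator, so _read_response_body re-installs the collected bytes via _iterate_once before Starlette sends the response. The same replay trick in isolation, as a minimal sketch:

import asyncio
from collections.abc import AsyncIterator

async def iterate_once(payload: bytes) -> AsyncIterator[bytes]:
    yield payload

async def demo() -> None:
    async def upstream() -> AsyncIterator[bytes]:  # stands in for response.body_iterator
        yield b'{"detail": '
        yield b'"boom"}'

    drained = b"".join([chunk async for chunk in upstream()])  # logged copy
    replay = iterate_once(drained)  # handed back so the client still receives the body
    assert b"".join([chunk async for chunk in replay]) == drained

asyncio.run(demo())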

View file

@@ -34,7 +34,12 @@ class ChatCompletionsRequest(BaseModel):
max_completion_tokens: int | None = None
max_tokens: int | None = None
metadata: dict[str, str] | None = None
modalities: list[Literal["text", "audio"]] | None = None
provider: dict[str, Any] | None = None
plugins: list[dict[str, Any]] | None = None
session_id: str | None = None
trace: dict[str, Any] | None = None
modalities: list[Literal["text", "image"]] | None = None
models: list[Any] | None = None
n: int | None = None
parallel_tool_calls: bool | None = None
prediction: dict[str, Any] | None = None
@@ -48,6 +53,7 @@ class ChatCompletionsRequest(BaseModel):
default=None,
validation_alias=AliasChoices("reasoning_summary", "reasoningSummary"),
)
reasoning: dict[str, Any] | None = None
response_format: dict[str, Any] | None = None
safety_identifier: str | None = None
seed: int | None = None
@@ -55,6 +61,8 @@
stop: str | list[str] | None = None
store: bool | None = None
stream_options: dict[str, Any] | None = None
debug: dict[str, Any] | None = None
image_config: dict[str, Any] | None = None
temperature: float | None = None
tool_choice: str | dict[str, Any] | None = None
tools: list[dict[str, Any]] | None = None
@@ -71,6 +79,7 @@ class ChatCompletionChunkDelta(BaseModel):
role: str | None = None
content: str | None = None
reasoning_content: str | None = None
reasoning_details: list[dict[str, Any]] | None = None
tool_calls: list[dict[str, Any]] | None = None
@@ -97,7 +106,9 @@ class ChatCompletionMessage(BaseModel):
role: str = "assistant"
content: str
reasoning: str | None = None
reasoning_content: str | None = None
reasoning_details: list[dict[str, Any]] | None = None
refusal: str | None = None
tool_calls: list[dict[str, Any]] | None = None
function_call: dict[str, Any] | None = None

View file

@@ -2,7 +2,6 @@
from __future__ import annotations
import json
from pathlib import Path
import yaml
@@ -12,7 +11,7 @@ from app.config.models import AppConfig, AuthConfig, LoadedConfig, LoadedProviderConfig
def load_config(config_path: Path, auth_path: Path) -> LoadedConfig:
app_data = _read_yaml(config_path)
auth_data = _read_json(auth_path)
auth_data = _read_yaml(auth_path)
app_config = AppConfig.model_validate(app_data)
auth_config = AuthConfig.model_validate(auth_data)
@@ -27,6 +26,8 @@ def load_config(config_path: Path, auth_path: Path) -> LoadedConfig:
name=provider_name,
url=provider.url,
type=provider.type,
display_name=provider.name,
models=provider.models,
whitelist=provider.whitelist,
blacklist=provider.blacklist,
auth=auth,
@@ -41,11 +42,3 @@ def _read_yaml(path: Path) -> dict:
if not isinstance(raw, dict):
raise ValueError(f"YAML file '{path}' must contain an object")
return raw
def _read_json(path: Path) -> dict:
with path.open("r", encoding="utf-8") as handle:
raw = json.load(handle)
if not isinstance(raw, dict):
raise ValueError(f"JSON file '{path}' must contain an object")
return raw

View file

@@ -15,6 +15,8 @@ class ProviderConfig(BaseModel):
url: str
type: ProviderType
name: str | None = None
models: dict[str, dict[str, str]] | None = None
whitelist: list[str] | None = None
blacklist: list[str] | None = None
@@ -39,7 +41,13 @@ class OAuthAuth(BaseModel):
expires: int
ProviderAuth = TokenAuth | OAuthAuth
class UrlAuth(BaseModel):
model_config = ConfigDict(extra="forbid")
url: str
ProviderAuth = TokenAuth | OAuthAuth | UrlAuth
class AuthConfig(BaseModel):
@@ -54,6 +62,8 @@ class LoadedProviderConfig(BaseModel):
name: str
url: str
type: ProviderType
display_name: str | None = None
models: dict[str, dict[str, str]] | None = None
whitelist: list[str] | None = None
blacklist: list[str] | None = None
auth: ProviderAuth
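With UrlAuth in the union, an auth entry can now be a static token, an OAuth record, or just a URL to fetch a token from, and the whole file is YAML rather than JSON. A sketch of validating the new shape, assuming a hypothetical auth.yml that maps provider names to auth objects:

import yaml
from pydantic import BaseModel, ConfigDict

class UrlAuth(BaseModel):  # mirrors the model added above
    model_config = ConfigDict(extra="forbid")
    url: str

# Hypothetical auth.yml entry; the host and provider name are placeholders.
raw = yaml.safe_load("codex:\n  url: https://tokens.internal/codex\n")
auth = UrlAuth.model_validate(raw["codex"])
assert auth.url == "https://tokens.internal/codex"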

app/core/models_dev.py Normal file (306 added lines)
View file

@@ -0,0 +1,306 @@
"""models.dev catalog lookup and model enrichment."""
from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
from datetime import datetime, timezone
import logging
import time
from typing import Any
from urllib.parse import urlparse
import httpx
from app.core.types import CoreModel, ProviderModel
logger = logging.getLogger(__name__)
class ModelsDevCatalog:
"""Fetches and caches models.dev provider catalog."""
def __init__(
self,
*,
catalog_url: str = "https://models.dev/api.json",
cache_ttl_seconds: float = 600.0,
timeout_seconds: float = 10.0,
fetch_catalog: Callable[[], Awaitable[dict[str, Any]]] | None = None,
) -> None:
self._catalog_url = catalog_url
self._cache_ttl_seconds = cache_ttl_seconds
self._timeout_seconds = timeout_seconds
self._fetch_catalog = fetch_catalog or self._fetch_catalog_http
self._catalog: dict[str, dict[str, Any]] | None = None
self._catalog_expires_at = 0.0
self._catalog_lock = asyncio.Lock()
async def get_provider_models(
self, *, provider_name: str, provider_url: str
) -> tuple[str | None, dict[str, dict[str, Any]]]:
catalog = await self._get_catalog()
if not catalog:
return None, {}
provider_id = _resolve_provider_id(
catalog,
provider_name=provider_name,
provider_url=provider_url,
)
if provider_id is None:
return None, {}
provider = _as_dict(catalog.get(provider_id))
provider_display_name = _as_str(provider.get("name"))
raw_models = provider.get("models")
if not isinstance(raw_models, dict):
return provider_display_name, {}
return (
provider_display_name,
{
model_id: model
for model_id, model in raw_models.items()
if isinstance(model_id, str) and isinstance(model, dict)
},
)
async def _get_catalog(self) -> dict[str, dict[str, Any]]:
now = time.monotonic()
if self._catalog is not None and now < self._catalog_expires_at:
return self._catalog
async with self._catalog_lock:
now = time.monotonic()
if self._catalog is not None and now < self._catalog_expires_at:
return self._catalog
try:
fetched = await self._fetch_catalog()
self._catalog = _coerce_catalog(fetched)
except Exception:
logger.exception("failed to fetch models.dev catalog")
self._catalog = {}
self._catalog_expires_at = now + self._cache_ttl_seconds
return self._catalog
async def _fetch_catalog_http(self) -> dict[str, Any]:
async with httpx.AsyncClient(timeout=self._timeout_seconds) as client:
response = await client.get(self._catalog_url)
response.raise_for_status()
payload = response.json()
return payload if isinstance(payload, dict) else {}
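Because the fetcher is injectable, the catalog can be exercised without touching the network. A rough usage sketch with a canned models.dev payload whose field names follow the lookups in this file:

import asyncio
from app.core.models_dev import ModelsDevCatalog

async def main() -> None:
    async def fake_fetch() -> dict:  # replaces the HTTP call, e.g. in tests
        return {
            "openai": {
                "name": "OpenAI",
                "api": "https://api.openai.com/v1",
                "models": {"gpt-5": {"name": "GPT-5", "limit": {"context": 400000}}},
            }
        }

    catalog = ModelsDevCatalog(fetch_catalog=fake_fetch)
    display, models = await catalog.get_provider_models(
        provider_name="openai", provider_url="https://api.openai.com/v1"
    )
    assert display == "OpenAI" and "gpt-5" in models

asyncio.run(main())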
def to_core_model(
*,
provider_name: str,
provider_model: ProviderModel,
models_dev_model: dict[str, Any] | None = None,
provider_display_name: str | None = None,
model_override: dict[str, Any] | None = None,
) -> CoreModel:
override_name = _as_str(model_override.get("name")) if model_override else None
models_dev_name = (
_as_str(models_dev_model.get("name")) if models_dev_model else None
)
name = override_name or models_dev_name or provider_model.name
models_dev_context_length = (
_context_length_from_models_dev(models_dev_model) if models_dev_model else None
)
context_length = models_dev_context_length or provider_model.context_length
models_dev_architecture = (
_architecture_from_models_dev(models_dev_model) if models_dev_model else None
)
architecture = models_dev_architecture or provider_model.architecture
models_dev_pricing = (
_pricing_from_models_dev(models_dev_model) if models_dev_model else None
)
pricing = models_dev_pricing or provider_model.pricing
models_dev_supported_parameters = (
_supported_parameters_from_models_dev(models_dev_model)
if models_dev_model
else None
)
supported_parameters = (
models_dev_supported_parameters or provider_model.supported_parameters
)
models_dev_created = (
_created_from_models_dev(models_dev_model) if models_dev_model else None
)
created = models_dev_created or provider_model.created
models_dev_owned_by = (
_as_str(models_dev_model.get("provider")) if models_dev_model else None
)
owned_by = models_dev_owned_by or provider_model.owned_by
return CoreModel(
id=f"{provider_name}/{provider_model.id}",
created=created or 0,
owned_by=owned_by or "wzray",
name=name,
provider_display_name=provider_display_name,
description=provider_model.description,
context_length=context_length,
architecture=architecture,
pricing=pricing,
supported_parameters=supported_parameters,
settings=provider_model.settings,
opencode=provider_model.opencode,
config_override=model_override,
)
def _coerce_catalog(raw: dict[str, Any]) -> dict[str, dict[str, Any]]:
catalog: dict[str, dict[str, Any]] = {}
for key, value in raw.items():
if isinstance(key, str) and isinstance(value, dict):
catalog[key] = value
return catalog
def _resolve_provider_id(
catalog: dict[str, dict[str, Any]], *, provider_name: str, provider_url: str
) -> str | None:
if provider_name in catalog:
return provider_name
provider_host = _host(provider_url)
if provider_host is not None:
for provider_id, provider in catalog.items():
api_url = _as_str(provider.get("api"))
if api_url is None:
continue
if _host(api_url) == provider_host:
return provider_id
normalized_name = _normalize_token(provider_name)
candidates: list[tuple[int, int, str]] = []
for provider_id in catalog:
normalized_id = _normalize_token(provider_id)
if not normalized_id:
continue
if normalized_name.startswith(normalized_id) or normalized_id.startswith(
normalized_name
):
candidates.append(
(
abs(len(normalized_name) - len(normalized_id)),
-len(normalized_id),
provider_id,
)
)
if candidates:
candidates.sort()
return candidates[0][2]
return None
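Resolution tries an exact catalog key first, then a host match against each provider's api URL, and finally the closest prefix overlap between normalized names, preferring the smallest length gap and then the longest catalog id. A worked example of the prefix step (hosts are placeholders):

from app.core.models_dev import _resolve_provider_id  # private helper defined above

catalog = {
    "openai": {"api": "https://api.openai.com/v1"},
    "openrouter": {"api": "https://openrouter.ai/api/v1"},
}
# "openai-codex" is no exact key and the host matches neither api URL, so the
# prefix step runs: "openaicodex".startswith("openai") yields the only
# candidate (length gap 5); "openrouter" shares no prefix and is skipped.
resolved = _resolve_provider_id(
    catalog, provider_name="openai-codex", provider_url="https://example.invalid"
)
assert resolved == "openai"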
def _context_length_from_models_dev(model: dict[str, Any]) -> int | None:
limit = _as_dict(model.get("limit"))
context = limit.get("context")
return context if isinstance(context, int) else None
def _architecture_from_models_dev(model: dict[str, Any]) -> dict[str, Any] | None:
modalities = _as_dict(model.get("modalities"))
architecture: dict[str, Any] = {}
input_modalities = modalities.get("input")
if isinstance(input_modalities, list):
architecture["input_modalities"] = [
str(modality) for modality in input_modalities if isinstance(modality, str)
]
output_modalities = modalities.get("output")
if isinstance(output_modalities, list):
architecture["output_modalities"] = [
str(modality) for modality in output_modalities if isinstance(modality, str)
]
family = _as_str(model.get("family"))
if family is not None:
architecture["family"] = family
return architecture or None
def _pricing_from_models_dev(model: dict[str, Any]) -> dict[str, Any] | None:
cost = model.get("cost")
return cost if isinstance(cost, dict) else None
def _supported_parameters_from_models_dev(model: dict[str, Any]) -> list[str] | None:
supported: list[str] = []
if _as_bool(model.get("reasoning")):
supported.extend(["reasoning_effort", "reasoning_summary"])
if _as_bool(model.get("tool_call")):
supported.extend(["tools", "tool_choice", "parallel_tool_calls"])
if _as_bool(model.get("structured_output")):
supported.append("response_format")
output_limit = _as_dict(model.get("limit")).get("output")
if isinstance(output_limit, int):
supported.extend(["max_tokens", "max_completion_tokens"])
modalities = model.get("modalities")
if isinstance(modalities, dict):
supported.append("modalities")
return sorted(set(supported)) or None
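The boolean capability flags expand into OpenAI-style parameter names, so a reasoning-capable, tool-calling model with an output limit advertises the sorted set below (worked through the mapping above):

from app.core.models_dev import _supported_parameters_from_models_dev

model = {"reasoning": True, "tool_call": True, "limit": {"context": 200000, "output": 8192}}
assert _supported_parameters_from_models_dev(model) == [
    "max_completion_tokens",   # from limit.output
    "max_tokens",              # from limit.output
    "parallel_tool_calls",     # from tool_call
    "reasoning_effort",        # from reasoning
    "reasoning_summary",       # from reasoning
    "tool_choice",             # from tool_call
    "tools",                   # from tool_call
]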
def _created_from_models_dev(model: dict[str, Any]) -> int | None:
release_date = _as_str(model.get("release_date"))
if release_date is None:
return None
try:
parsed = datetime.fromisoformat(release_date)
except ValueError:
return None
if parsed.tzinfo is None:
parsed = parsed.replace(tzinfo=timezone.utc)
else:
parsed = parsed.astimezone(timezone.utc)
return int(parsed.timestamp())
def _host(url: str) -> str | None:
try:
parsed = urlparse(url)
except ValueError:
return None
if not parsed.hostname:
return None
return parsed.hostname.lower()
def _normalize_token(value: str) -> str:
return "".join(ch for ch in value.lower() if ch.isalnum())
def _as_dict(raw: Any) -> dict[str, Any]:
return raw if isinstance(raw, dict) else {}
def _as_str(value: Any) -> str | None:
return value if isinstance(value, str) else None
def _as_bool(value: Any) -> bool:
return value is True

View file

@@ -2,18 +2,42 @@
from __future__ import annotations
import asyncio
from collections.abc import AsyncIterator
import logging
import time
from typing import Any
from app.core.models_dev import ModelsDevCatalog, to_core_model
from app.core.errors import InvalidModelError, ProviderNotFoundError
from app.core.types import CoreChatRequest, CoreChunk, CoreModel, ProviderChatRequest
from app.core.types import (
CoreChatRequest,
CoreChunk,
CoreModel,
ProviderChatRequest,
ProviderModel,
)
from app.providers.base import BaseProvider
logger = logging.getLogger(__name__)
class RouterCore:
"""Routes normalized requests to a specific provider instance."""
def __init__(self, providers: dict[str, BaseProvider]) -> None:
def __init__(
self,
providers: dict[str, BaseProvider],
*,
models_cache_ttl_seconds: float = 600.0,
models_dev_catalog: ModelsDevCatalog | None = None,
) -> None:
self.providers = providers
self._models_cache_ttl_seconds = models_cache_ttl_seconds
self._models_dev_catalog = models_dev_catalog or ModelsDevCatalog()
self._models_cache: list[CoreModel] | None = None
self._models_cache_expires_at = 0.0
self._models_cache_lock = asyncio.Lock()
def resolve_provider(self, routed_model: str) -> tuple[BaseProvider, str]:
provider_name, upstream_model = split_routed_model(routed_model)
@@ -36,7 +60,12 @@ class RouterCore:
max_completion_tokens=request.max_completion_tokens,
max_tokens=request.max_tokens,
metadata=request.metadata,
provider=request.provider,
plugins=request.plugins,
session_id=request.session_id,
trace=request.trace,
modalities=request.modalities,
models=request.models,
n=request.n,
parallel_tool_calls=request.parallel_tool_calls,
prediction=request.prediction,
@@ -45,6 +74,7 @@
prompt_cache_retention=request.prompt_cache_retention,
reasoning_effort=request.reasoning_effort,
reasoning_summary=request.reasoning_summary,
reasoning=request.reasoning,
response_format=request.response_format,
safety_identifier=request.safety_identifier,
seed=request.seed,
@@ -53,6 +83,8 @@
store=request.store,
stream_options=request.stream_options,
temperature=request.temperature,
debug=request.debug,
image_config=request.image_config,
tool_choice=request.tool_choice,
tools=request.tools,
top_logprobs=request.top_logprobs,
@@ -66,28 +98,64 @@ class RouterCore:
yield chunk
async def list_models(self) -> list[CoreModel]:
models: list[CoreModel] = []
for provider_name, provider in self.providers.items():
provider_models = await provider.list_models()
for model in provider_models:
if not provider.is_model_allowed(model.id):
continue
models.append(
CoreModel(
id=f"{provider_name}/{model.id}",
created=model.created or 0,
owned_by=model.owned_by or "wzray",
name=model.name,
description=model.description,
context_length=model.context_length,
architecture=model.architecture,
pricing=model.pricing,
supported_parameters=model.supported_parameters,
settings=model.settings,
opencode=model.opencode,
now = time.monotonic()
if self._models_cache is not None and now < self._models_cache_expires_at:
return list(self._models_cache)
async with self._models_cache_lock:
now = time.monotonic()
if self._models_cache is not None and now < self._models_cache_expires_at:
return list(self._models_cache)
models: list[CoreModel] = []
for provider_name, provider in self.providers.items():
try:
provider_models = await provider.list_models()
except Exception as exc:
logger.exception(
"models listing failed provider=%s base_url=%s error=%s",
provider_name,
provider.base_url,
repr(exc),
)
continue
(
provider_display_name,
models_dev_models,
) = await self._models_dev_catalog.get_provider_models(
provider_name=provider_name,
provider_url=provider.base_url,
)
return models
for model in provider_models:
if not provider.is_model_allowed(model.id):
continue
model_override = None
if provider.models_config is not None:
raw_override = provider.models_config.get(model.id)
if isinstance(raw_override, dict):
model_override = raw_override
provider_display_name = (
provider.display_name or provider_display_name
)
models.append(
to_core_model(
provider_name=provider_name,
provider_model=model,
models_dev_model=models_dev_models.get(model.id),
provider_display_name=provider_display_name,
model_override=model_override,
)
)
self._models_cache = list(models)
self._models_cache_expires_at = (
time.monotonic() + self._models_cache_ttl_seconds
)
return list(models)
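list_models now memoizes behind a double-checked asyncio.Lock: the fast path reads the cache without locking, and the re-check inside the lock keeps concurrent callers from refreshing twice. The same shape distilled into a standalone sketch (not the router itself):

import asyncio
import time

class CachedList:
    def __init__(self, ttl: float) -> None:
        self._ttl = ttl
        self._value: list[str] | None = None
        self._expires_at = 0.0
        self._lock = asyncio.Lock()

    async def get(self) -> list[str]:
        if self._value is not None and time.monotonic() < self._expires_at:
            return list(self._value)  # fast path, lock-free
        async with self._lock:
            if self._value is not None and time.monotonic() < self._expires_at:
                return list(self._value)  # another task refreshed while we waited
            self._value = await self._refresh()
            self._expires_at = time.monotonic() + self._ttl
            return list(self._value)

    async def _refresh(self) -> list[str]:
        await asyncio.sleep(0)  # stands in for the provider round-trips
        return ["provider/model"]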
def split_routed_model(routed_model: str) -> tuple[str, str]:

View file

@@ -32,7 +32,12 @@ class CoreChatRequest:
max_completion_tokens: int | None = None
max_tokens: int | None = None
metadata: dict[str, str] | None = None
provider: dict[str, Any] | None = None
plugins: list[dict[str, Any]] | None = None
session_id: str | None = None
trace: dict[str, Any] | None = None
modalities: list[str] | None = None
models: list[Any] | None = None
n: int | None = None
parallel_tool_calls: bool | None = None
prediction: dict[str, Any] | None = None
@@ -41,6 +46,7 @@
prompt_cache_retention: str | None = None
reasoning_effort: str | None = None
reasoning_summary: str | None = None
reasoning: dict[str, Any] | None = None
response_format: dict[str, Any] | None = None
safety_identifier: str | None = None
seed: int | None = None
@@ -49,6 +55,8 @@
store: bool | None = None
stream_options: dict[str, Any] | None = None
temperature: float | None = None
debug: dict[str, Any] | None = None
image_config: dict[str, Any] | None = None
tool_choice: str | dict[str, Any] | None = None
tools: list[dict[str, Any]] | None = None
top_logprobs: int | None = None
@@ -71,7 +79,12 @@ class ProviderChatRequest:
max_completion_tokens: int | None = None
max_tokens: int | None = None
metadata: dict[str, str] | None = None
provider: dict[str, Any] | None = None
plugins: list[dict[str, Any]] | None = None
session_id: str | None = None
trace: dict[str, Any] | None = None
modalities: list[str] | None = None
models: list[Any] | None = None
n: int | None = None
parallel_tool_calls: bool | None = None
prediction: dict[str, Any] | None = None
@@ -80,6 +93,7 @@
prompt_cache_retention: str | None = None
reasoning_effort: str | None = None
reasoning_summary: str | None = None
reasoning: dict[str, Any] | None = None
response_format: dict[str, Any] | None = None
safety_identifier: str | None = None
seed: int | None = None
@@ -88,6 +102,8 @@
store: bool | None = None
stream_options: dict[str, Any] | None = None
temperature: float | None = None
debug: dict[str, Any] | None = None
image_config: dict[str, Any] | None = None
tool_choice: str | dict[str, Any] | None = None
tools: list[dict[str, Any]] | None = None
top_logprobs: int | None = None
@@ -120,6 +136,7 @@ class CoreModel:
created: int = 0
owned_by: str = "wzray"
name: str | None = None
provider_display_name: str | None = None
description: str | None = None
context_length: int | None = None
architecture: dict[str, Any] | None = None
@ -127,6 +144,7 @@ class CoreModel:
supported_parameters: list[str] | None = None
settings: dict[str, Any] | None = None
opencode: dict[str, Any] | None = None
config_override: dict[str, Any] | None = None
@dataclass(slots=True)
@@ -135,5 +153,6 @@ class CoreChunk:
role: str | None = None
content: str | None = None
reasoning_content: str | None = None
reasoning_details: list[dict[str, Any]] | None = None
tool_calls: list[dict[str, Any]] | None = None
finish_reason: str | None = None

View file

@@ -14,7 +14,7 @@ from app.providers.factory import build_provider_registry
def _resolve_paths() -> tuple[Path, Path]:
data_dir = Path(os.getenv("AI_ROUTER_DATA_DIR", "."))
config_path = Path(os.getenv("AI_ROUTER_CONFIG", str(data_dir / "config.yml")))
auth_path = Path(os.getenv("AI_ROUTER_AUTH", str(data_dir / "auth.json")))
auth_path = Path(os.getenv("AI_ROUTER_AUTH", str(data_dir / "auth.yml")))
return config_path, auth_path
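The auth file default moves from auth.json to auth.yml; since JSON is a YAML subset, an existing auth.json still parses if pinned explicitly (path below is illustrative):

import os
os.environ["AI_ROUTER_AUTH"] = "/data/auth.json"  # opt out of the new auth.yml default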

View file

@@ -5,6 +5,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING
from typing import Any
from app.core.errors import ModelNotAllowedError
from app.core.types import CoreChunk, ProviderChatRequest, ProviderModel
@@ -30,6 +31,8 @@ class BaseProvider(ABC):
self.api_type = api_type
self.whitelist = whitelist
self.blacklist = blacklist
self.display_name: str | None = None
self.models_config: dict[str, dict[str, Any]] | None = None
def ensure_model_allowed(self, model: str) -> None:
if self.whitelist is not None and model not in self.whitelist:

View file

@@ -0,0 +1,84 @@
"""Model list mapping for codex responses provider."""
from __future__ import annotations
from typing import Any
from app.core.types import ProviderModel
from app.providers.codex_responses.utils import _to_dict
def _coerce_model_items(raw: Any) -> list[dict[str, Any]]:
data = getattr(raw, "data", None)
if isinstance(data, list):
return [_to_dict(item) for item in data]
models = getattr(raw, "models", None)
if isinstance(models, list):
return [_to_dict(item) for item in models]
if isinstance(raw, list):
return [_to_dict(item) for item in raw]
if isinstance(raw, dict):
models = raw.get("models")
if isinstance(models, list):
return [_to_dict(item) for item in models]
data = raw.get("data")
if isinstance(data, list):
return [_to_dict(item) for item in data]
return []
def _to_provider_model(raw: dict[str, Any]) -> ProviderModel:
model_id = raw.get("id")
if not isinstance(model_id, str):
model_id = raw.get("slug") if isinstance(raw.get("slug"), str) else None
if model_id is None:
raise ValueError("Codex model item missing id/slug")
arch = raw.get("architecture")
if isinstance(arch, dict):
arch = {k: v for k, v in arch.items() if k != "tokenizer"}
else:
input_modalities = raw.get("input_modalities")
if isinstance(input_modalities, list):
arch = {"input_modalities": input_modalities}
else:
arch = None
supported = raw.get("supported_parameters")
supported_parameters = None
if isinstance(supported, list):
supported_parameters = [str(v) for v in supported]
context_length = raw.get("context_length")
if not isinstance(context_length, int):
context_length = raw.get("context_window")
if not isinstance(context_length, int):
top_provider = raw.get("top_provider")
if isinstance(top_provider, dict) and isinstance(
top_provider.get("context_length"), int
):
context_length = top_provider.get("context_length")
else:
context_length = None
return ProviderModel(
id=model_id,
name=(
raw.get("name")
if isinstance(raw.get("name"), str)
else raw.get("display_name")
if isinstance(raw.get("display_name"), str)
else None
),
description=raw.get("description")
if isinstance(raw.get("description"), str)
else None,
context_length=context_length,
architecture=arch,
pricing=raw.get("pricing") if isinstance(raw.get("pricing"), dict) else None,
supported_parameters=supported_parameters,
settings=raw.get("settings") if isinstance(raw.get("settings"), dict) else None,
opencode=raw.get("opencode") if isinstance(raw.get("opencode"), dict) else None,
created=raw.get("created") if isinstance(raw.get("created"), int) else None,
owned_by=raw.get("owned_by") if isinstance(raw.get("owned_by"), str) else None,
)

View file

@@ -51,10 +51,13 @@ class CodexOAuthProvider:
headers = {"Authorization": f"Bearer {self._access}"}
if self._account_id:
headers["chatgpt-account-id"] = self._account_id
headers["OpenAI-Beta"] = "responses=experimental"
headers["ChatGPT-Account-Id"] = self._account_id
return OAuthData(token=self._access, headers=headers)
async def get_headers(self) -> dict[str, str]:
oauth = await self.get()
return oauth.headers
def _is_expired(self) -> bool:
return int(time.time() * 1000) >= self._expires - 60_000
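After the casing fix and the new beta flag, a refreshed token produces headers shaped like the following (values are placeholders):

headers = {
    "Authorization": "Bearer <access-token>",
    "OpenAI-Beta": "responses=experimental",
    "ChatGPT-Account-Id": "<account-id>",  # only set when an account id is configured
}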

View file

@@ -4,17 +4,21 @@ from __future__ import annotations
from collections.abc import AsyncIterator
import logging
import os
from typing import TYPE_CHECKING, Any, Mapping, cast
import random
import string
from typing import TYPE_CHECKING, Any, cast
from openai import AsyncOpenAI, OpenAIError
from app.config.models import LoadedProviderConfig, OAuthAuth
from app.config.models import LoadedProviderConfig, OAuthAuth, UrlAuth
from app.core.errors import UpstreamProviderError
from app.core.types import CoreChunk, ProviderChatRequest, ProviderModel
from app.providers.base import BaseProvider
from app.providers.codex_responses.models import _coerce_model_items, _to_provider_model
from app.providers.codex_responses.oauth import CodexOAuthProvider
from app.providers.codex_responses.stream import _map_response_stream_to_chunks
from app.providers.codex_responses.translator import build_responses_create_args
from app.providers.token_url_auth import TokenUrlAuthProvider
from app.providers.registry import provider
if TYPE_CHECKING:
@@ -32,7 +36,7 @@ class CodexResponsesProvider(BaseProvider):
*,
name: str,
base_url: str,
oauth: CodexOAuthProvider,
oauth: CodexOAuthProvider | TokenUrlAuthProvider,
client_version: str = "0.100.0",
whitelist: list[str] | None = None,
blacklist: list[str] | None = None,
@@ -54,31 +58,49 @@
@classmethod
def from_config(cls, config: LoadedProviderConfig) -> BaseProvider:
if not isinstance(config.auth, OAuthAuth):
oauth: CodexOAuthProvider | TokenUrlAuthProvider
if isinstance(config.auth, OAuthAuth):
oauth = CodexOAuthProvider(auth=config.auth)
elif isinstance(config.auth, UrlAuth):
oauth = TokenUrlAuthProvider(token_url=config.auth.url)
else:
raise ValueError(
f"Provider '{config.name}' type 'codex-responses' requires oauth auth"
f"Provider '{config.name}' type 'codex-responses' requires oauth or url auth"
)
return cls(
provider = cls(
name=config.name,
base_url=config.url,
oauth=CodexOAuthProvider(auth=config.auth),
oauth=oauth,
whitelist=config.whitelist,
blacklist=config.blacklist,
)
provider.display_name = config.display_name
provider.models_config = config.models
return provider
async def stream_chat(
self, request: ProviderChatRequest
) -> AsyncIterator[CoreChunk]:
auth = await self._oauth.get()
auth = await self._oauth.get_headers()
try:
create_args, ignored_extra = build_responses_create_args(request)
create_args, extra_body = build_responses_create_args(request)
_log_ignored_extra(ignored_extra, provider_name=self.name)
_log_outbound_request(create_args, provider_name=self.name)
extra_headers = dict(auth)
extra_headers.update(
{
"session_id": "ses_"
+ "".join(
random.choice(string.ascii_letters + string.digits)
for _ in range(28)
),
"originator": "opencode",
}
)
stream = await cast("AsyncResponses", self._client.responses).create(
**create_args,
extra_headers=auth.headers,
extra_headers=extra_headers,
extra_body=extra_body,
)
async for chunk in _map_response_stream_to_chunks(
@@ -98,10 +120,10 @@
) from exc
async def list_models(self) -> list[ProviderModel]:
auth = await self._oauth.get()
auth = await self._oauth.get_headers()
try:
response = await self._client.models.list(
extra_headers=auth.headers,
extra_headers=auth,
extra_query={"client_version": self._client_version},
)
items = _coerce_model_items(response)
@@ -120,473 +142,3 @@
raise UpstreamProviderError(
f"Provider '{self.name}' failed while listing models: {exc}"
) from exc
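Every outbound responses.create call now carries a fresh opencode-style session id alongside the auth headers. A sketch of the id generation inlined in stream_chat above, pulled out into a hypothetical helper:

import random
import string

def new_session_id() -> str:
    # "ses_" plus 28 random alphanumeric characters, as in stream_chat above
    alphabet = string.ascii_letters + string.digits
    return "ses_" + "".join(random.choice(alphabet) for _ in range(28))

extra_headers = {"session_id": new_session_id(), "originator": "opencode"}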
def _log_ignored_extra(extra: dict[str, Any], *, provider_name: str) -> None:
if not extra:
return
logger.error(
"provider '%s' ignored unsupported extra params: %s",
provider_name,
extra,
)
def _log_outbound_request(
create_args: Mapping[str, Any], *, provider_name: str
) -> None:
enabled = os.getenv("AI_CODEX_REQUEST_LOG_ENABLED", "false").strip().lower()
if enabled not in {"1", "true", "yes", "on"}:
return
logger.error(
"provider '%s' outbound responses.create model=%s reasoning=%s text=%s parallel_tool_calls=%s tools_count=%s input_items=%s",
provider_name,
create_args.get("model"),
create_args.get("reasoning"),
create_args.get("text"),
create_args.get("parallel_tool_calls"),
len(create_args.get("tools", []) or []),
len(create_args.get("input", []) or []),
)
def _coerce_model_items(raw: Any) -> list[dict[str, Any]]:
data = getattr(raw, "data", None)
if isinstance(data, list):
return [_to_dict(item) for item in data]
models = getattr(raw, "models", None)
if isinstance(models, list):
return [_to_dict(item) for item in models]
if isinstance(raw, list):
return [_to_dict(item) for item in raw]
if isinstance(raw, dict):
models = raw.get("models")
if isinstance(models, list):
return [_to_dict(item) for item in models]
data = raw.get("data")
if isinstance(data, list):
return [_to_dict(item) for item in data]
return []
def _to_dict(raw: Any) -> dict[str, Any]:
if hasattr(raw, "model_dump"):
dumped = raw.model_dump()
if isinstance(dumped, dict):
return dumped
if isinstance(raw, dict):
return raw
return {}
def _as_dict(raw: Any) -> dict[str, Any]:
return _to_dict(raw)
def _as_str(value: Any) -> str | None:
return value if isinstance(value, str) else None
def _to_provider_model(raw: dict[str, Any]) -> ProviderModel:
model_id = raw.get("id")
if not isinstance(model_id, str):
model_id = raw.get("slug") if isinstance(raw.get("slug"), str) else None
if model_id is None:
raise ValueError("Codex model item missing id/slug")
arch = raw.get("architecture")
if isinstance(arch, dict):
arch = {k: v for k, v in arch.items() if k != "tokenizer"}
else:
input_modalities = raw.get("input_modalities")
if isinstance(input_modalities, list):
arch = {"input_modalities": input_modalities}
else:
arch = None
supported = raw.get("supported_parameters")
supported_parameters = None
if isinstance(supported, list):
supported_parameters = [str(v) for v in supported]
context_length = raw.get("context_length")
if not isinstance(context_length, int):
context_length = raw.get("context_window")
if not isinstance(context_length, int):
top_provider = raw.get("top_provider")
if isinstance(top_provider, dict) and isinstance(
top_provider.get("context_length"), int
):
context_length = top_provider.get("context_length")
else:
context_length = None
return ProviderModel(
id=model_id,
name=(
raw.get("name")
if isinstance(raw.get("name"), str)
else raw.get("display_name")
if isinstance(raw.get("display_name"), str)
else None
),
description=raw.get("description")
if isinstance(raw.get("description"), str)
else None,
context_length=context_length,
architecture=arch,
pricing=raw.get("pricing") if isinstance(raw.get("pricing"), dict) else None,
supported_parameters=supported_parameters,
settings=raw.get("settings") if isinstance(raw.get("settings"), dict) else None,
opencode=raw.get("opencode") if isinstance(raw.get("opencode"), dict) else None,
created=raw.get("created") if isinstance(raw.get("created"), int) else None,
owned_by=raw.get("owned_by") if isinstance(raw.get("owned_by"), str) else None,
)
async def _map_response_stream_to_chunks(
stream: AsyncIterator[Any], *, provider_name: str
) -> AsyncIterator[CoreChunk]:
sent_role = False
emitted_text = False
emitted_reasoning = False
saw_tool_call = False
tool_call_delta_seen: set[str] = set()
tool_call_finalized: set[str] = set()
tool_call_names: dict[str, str] = {}
tool_call_indexes: dict[str, int] = {}
pending_tool_arguments: dict[str, str] = {}
next_tool_call_index = 0
emitted_content_keys: set[tuple[str, int]] = set()
saw_delta_keys: set[tuple[str, int]] = set()
async for event in stream:
event_type = _event_type(event)
if event_type in {"response.output_text.delta", "response.refusal.delta"}:
key = _content_key(event)
if key is not None:
saw_delta_keys.add(key)
text_delta = _first_string(event, "delta")
if text_delta:
if not sent_role:
yield CoreChunk(role="assistant")
sent_role = True
yield CoreChunk(content=text_delta)
emitted_text = True
continue
if event_type == "response.reasoning_summary_text.delta":
reasoning_delta = _first_string(event, "delta")
if reasoning_delta:
if not sent_role:
yield CoreChunk(role="assistant")
sent_role = True
yield CoreChunk(reasoning_content=reasoning_delta)
emitted_reasoning = True
continue
if event_type == "response.reasoning_summary_text.done":
reasoning_done = _first_string(event, "text")
if reasoning_done and not emitted_reasoning:
if not sent_role:
yield CoreChunk(role="assistant")
sent_role = True
yield CoreChunk(reasoning_content=reasoning_done)
emitted_reasoning = True
continue
if event_type in {"response.output_text.done", "response.refusal.done"}:
key = _content_key(event)
if key is not None and key in saw_delta_keys:
emitted_content_keys.add(key)
continue
text_done = _first_string(event, "text", "refusal")
if text_done:
if not sent_role:
yield CoreChunk(role="assistant")
sent_role = True
yield CoreChunk(content=text_done)
emitted_text = True
if key is not None:
emitted_content_keys.add(key)
continue
if event_type == "response.content_part.done":
key = _content_key(event)
if key is None or key in emitted_content_keys or key in saw_delta_keys:
continue
part = _as_dict(getattr(event, "part", None))
if part.get("type") == "output_text":
text = _as_str(part.get("text"))
elif part.get("type") == "refusal":
text = _as_str(part.get("refusal"))
else:
text = None
if text:
if not sent_role:
yield CoreChunk(role="assistant")
sent_role = True
yield CoreChunk(content=text)
emitted_text = True
emitted_content_keys.add(key)
continue
if event_type == "response.output_item.done":
item = _as_dict(getattr(event, "item", None))
item_type = _as_str(item.get("type"))
if item_type == "function_call":
saw_tool_call = True
tool_chunk, next_tool_call_index = _tool_call_delta_from_item(
item,
tool_call_finalized=tool_call_finalized,
tool_call_names=tool_call_names,
tool_call_indexes=tool_call_indexes,
pending_tool_arguments=pending_tool_arguments,
next_tool_call_index=next_tool_call_index,
)
if tool_chunk is not None:
if not sent_role:
yield CoreChunk(role="assistant")
sent_role = True
yield tool_chunk
continue
if item_type in {
"custom_tool_call",
"computer_call",
"code_interpreter_call",
"web_search_call",
"file_search_call",
"shell_call",
"apply_patch_call",
"mcp_call",
}:
saw_tool_call = True
continue
if item_type != "message":
continue
item_id = _as_str(item.get("id"))
if item_id is None:
continue
content = item.get("content")
if not isinstance(content, list):
continue
for idx, part in enumerate(content):
key = (item_id, idx)
if key in emitted_content_keys or key in saw_delta_keys:
continue
part_dict = _as_dict(part)
text = None
if part_dict.get("type") == "output_text":
text = _as_str(part_dict.get("text"))
elif part_dict.get("type") == "refusal":
text = _as_str(part_dict.get("refusal"))
if text:
if not sent_role:
yield CoreChunk(role="assistant")
sent_role = True
yield CoreChunk(content=text)
emitted_text = True
emitted_content_keys.add(key)
continue
if event_type in {
"response.function_call_arguments.delta",
"response.function_call_arguments.done",
"response.custom_tool_call_input.delta",
"response.custom_tool_call_input.done",
}:
saw_tool_call = True
tool_chunk, next_tool_call_index = _tool_call_delta_from_event(
event,
event_type=event_type,
tool_call_delta_seen=tool_call_delta_seen,
tool_call_finalized=tool_call_finalized,
tool_call_names=tool_call_names,
tool_call_indexes=tool_call_indexes,
pending_tool_arguments=pending_tool_arguments,
next_tool_call_index=next_tool_call_index,
)
if tool_chunk is not None:
if not sent_role:
yield CoreChunk(role="assistant")
sent_role = True
yield tool_chunk
continue
if event_type == "response.completed":
finish_reason = (
"tool_calls" if saw_tool_call and not emitted_text else "stop"
)
yield CoreChunk(finish_reason=finish_reason)
continue
if event_type == "response.incomplete":
reason = _extract_incomplete_finish_reason(event)
yield CoreChunk(finish_reason=reason)
continue
if event_type == "response.failed":
raise UpstreamProviderError(
f"Provider '{provider_name}' upstream response failed"
)
def _event_type(event: Any) -> str:
return _as_str(getattr(event, "type", None)) or ""
def _content_key(event: Any) -> tuple[str, int] | None:
item_id = _as_str(getattr(event, "item_id", None))
content_index = getattr(event, "content_index", None)
if item_id is None or not isinstance(content_index, int):
return None
return (item_id, content_index)
def _first_string(event: Any, *field_names: str) -> str | None:
for name in field_names:
value = getattr(event, name, None)
if isinstance(value, str):
return value
return None
def _extract_incomplete_finish_reason(event: Any) -> str:
response = getattr(event, "response", None)
response_dict = _as_dict(response)
incomplete_details = response_dict.get("incomplete_details")
details_dict = _as_dict(incomplete_details)
reason = _as_str(details_dict.get("reason"))
if reason == "max_output_tokens":
return "length"
if reason == "content_filter":
return "content_filter"
return "stop"
def _tool_call_delta_from_event(
event: Any,
*,
event_type: str,
tool_call_delta_seen: set[str],
tool_call_finalized: set[str],
tool_call_names: dict[str, str],
tool_call_indexes: dict[str, int],
pending_tool_arguments: dict[str, str],
next_tool_call_index: int,
) -> tuple[CoreChunk | None, int]:
item_id = _as_str(getattr(event, "item_id", None))
if item_id is None:
return None, next_tool_call_index
index = tool_call_indexes.get(item_id)
if index is None:
index = next_tool_call_index
tool_call_indexes[item_id] = index
next_tool_call_index += 1
if event_type == "response.function_call_arguments.delta":
if item_id in tool_call_finalized:
return None, next_tool_call_index
arguments_delta = _as_str(getattr(event, "delta", None))
if not arguments_delta:
return None, next_tool_call_index
tool_call_delta_seen.add(item_id)
pending_tool_arguments[item_id] = (
pending_tool_arguments.get(item_id, "") + arguments_delta
)
return None, next_tool_call_index
if event_type == "response.function_call_arguments.done":
name = _as_str(getattr(event, "name", None))
arguments = _as_str(getattr(event, "arguments", None))
if item_id in tool_call_finalized:
return None, next_tool_call_index
if name:
tool_call_names[item_id] = name
function_name = tool_call_names.get(item_id)
if not function_name:
return None, next_tool_call_index
function_arguments = (
arguments
if arguments is not None
else pending_tool_arguments.get(item_id, "")
)
pending_tool_arguments.pop(item_id, None)
tool_call_finalized.add(item_id)
return (
CoreChunk(
tool_calls=[
{
"index": index,
"id": item_id,
"type": "function",
"function": {
"name": function_name,
"arguments": function_arguments,
},
}
]
),
next_tool_call_index,
)
return None, next_tool_call_index
def _tool_call_delta_from_item(
item: dict[str, Any],
*,
tool_call_finalized: set[str],
tool_call_names: dict[str, str],
tool_call_indexes: dict[str, int],
pending_tool_arguments: dict[str, str],
next_tool_call_index: int,
) -> tuple[CoreChunk | None, int]:
if _as_str(item.get("type")) != "function_call":
return None, next_tool_call_index
item_id = _as_str(item.get("id"))
name = _as_str(item.get("name"))
if item_id is None or name is None:
return None, next_tool_call_index
if item_id in tool_call_finalized:
return None, next_tool_call_index
index = tool_call_indexes.get(item_id)
if index is None:
index = next_tool_call_index
tool_call_indexes[item_id] = index
next_tool_call_index += 1
tool_call_names[item_id] = name
arguments = _as_str(item.get("arguments"))
function_arguments = (
arguments if arguments is not None else pending_tool_arguments.get(item_id, "")
)
pending_tool_arguments.pop(item_id, None)
tool_call_finalized.add(item_id)
return (
CoreChunk(
tool_calls=[
{
"index": index,
"id": item_id,
"type": "function",
"function": {"name": name, "arguments": function_arguments},
}
]
),
next_tool_call_index,
)

View file

@@ -0,0 +1,518 @@
"""Streaming event mapping for codex responses provider."""
from __future__ import annotations
from collections.abc import AsyncIterator, Mapping
from dataclasses import dataclass, field
from typing import Any
from app.core.errors import UpstreamProviderError
from app.core.types import CoreChunk
from app.providers.codex_responses.utils import _as_dict, _as_str
_TEXT_DELTA_EVENTS = {"response.output_text.delta", "response.refusal.delta"}
_TEXT_DONE_EVENTS = {"response.output_text.done", "response.refusal.done"}
_TOOL_CALL_DELTA_EVENTS = {
"response.function_call_arguments.delta",
"response.function_call_arguments.done",
"response.custom_tool_call_input.delta",
"response.custom_tool_call_input.done",
}
_NON_FUNCTION_TOOL_CALL_ITEMS = {
"custom_tool_call",
"computer_call",
"code_interpreter_call",
"web_search_call",
"file_search_call",
"shell_call",
"apply_patch_call",
"mcp_call",
}
@dataclass(slots=True)
class _ResponseStreamState:
sent_role: bool = False
emitted_text: bool = False
emitted_reasoning: bool = False
saw_tool_call: bool = False
next_tool_call_index: int = 0
tool_call_finalized: set[str] = field(default_factory=set)
tool_call_names: dict[str, str] = field(default_factory=dict)
tool_call_indexes: dict[str, int] = field(default_factory=dict)
pending_tool_arguments: dict[str, str] = field(default_factory=dict)
emitted_tool_arguments: dict[str, str] = field(default_factory=dict)
emitted_content_keys: set[tuple[str, int]] = field(default_factory=set)
saw_delta_keys: set[tuple[str, int]] = field(default_factory=set)
emitted_reasoning_item_ids: set[str] = field(default_factory=set)
async def _map_response_stream_to_chunks(
stream: AsyncIterator[Any], *, provider_name: str
) -> AsyncIterator[CoreChunk]:
state = _ResponseStreamState()
async for event in stream:
event_type = _event_type(event)
if event_type in _TEXT_DELTA_EVENTS:
for chunk in _handle_text_delta(event, state):
yield chunk
continue
if event_type == "response.reasoning_summary_text.delta":
for chunk in _emit_reasoning_chunk(_first_string(event, "delta"), state):
yield chunk
continue
if event_type == "response.reasoning_summary_text.done":
details_chunk = _reasoning_summary_detail_chunk(event)
if details_chunk is not None:
for chunk in _wrap_assistant_chunk(details_chunk, state):
yield chunk
if not state.emitted_reasoning:
for chunk in _emit_reasoning_chunk(_first_string(event, "text"), state):
yield chunk
continue
if event_type == "response.output_item.added":
for chunk in _handle_output_item_added(event, state):
yield chunk
continue
if event_type in _TEXT_DONE_EVENTS:
for chunk in _handle_text_done(event, state):
yield chunk
continue
if event_type == "response.content_part.done":
for chunk in _handle_content_part_done(event, state):
yield chunk
continue
if event_type == "response.output_item.done":
for chunk in _handle_output_item_done(event, state):
yield chunk
continue
if event_type in _TOOL_CALL_DELTA_EVENTS:
state.saw_tool_call = True
tool_chunk = _tool_call_delta_from_event(
event,
event_type=event_type,
state=state,
)
if tool_chunk is not None:
for chunk in _wrap_assistant_chunk(tool_chunk, state):
yield chunk
continue
if event_type == "response.completed":
finish_reason = (
"tool_calls"
if state.saw_tool_call and not state.emitted_text
else "stop"
)
yield CoreChunk(finish_reason=finish_reason)
continue
if event_type == "response.incomplete":
reason = _extract_incomplete_finish_reason(event)
yield CoreChunk(finish_reason=reason)
continue
if event_type == "response.failed":
raise UpstreamProviderError(
f"Provider '{provider_name}' upstream response failed"
)
def _event_type(event: Any) -> str:
return _as_str(getattr(event, "type", None)) or ""
def _ensure_assistant_role(state: _ResponseStreamState) -> CoreChunk | None:
if state.sent_role:
return None
state.sent_role = True
return CoreChunk(role="assistant")
def _wrap_assistant_chunk(
chunk: CoreChunk | None, state: _ResponseStreamState
) -> list[CoreChunk]:
if chunk is None:
return []
role_chunk = _ensure_assistant_role(state)
return [role_chunk, chunk] if role_chunk is not None else [chunk]
def _emit_text_chunk(text: str | None, state: _ResponseStreamState) -> list[CoreChunk]:
if not text:
return []
state.emitted_text = True
return _wrap_assistant_chunk(CoreChunk(content=text), state)
def _emit_reasoning_chunk(
text: str | None, state: _ResponseStreamState
) -> list[CoreChunk]:
if not text:
return []
state.emitted_reasoning = True
return _wrap_assistant_chunk(CoreChunk(reasoning_content=text), state)
def _handle_text_delta(event: Any, state: _ResponseStreamState) -> list[CoreChunk]:
key = _content_key(event)
if key is not None:
state.saw_delta_keys.add(key)
return _emit_text_chunk(_first_string(event, "delta"), state)
def _handle_text_done(event: Any, state: _ResponseStreamState) -> list[CoreChunk]:
key = _content_key(event)
if key is not None and key in state.saw_delta_keys:
state.emitted_content_keys.add(key)
return []
chunks = _emit_text_chunk(_first_string(event, "text", "refusal"), state)
if key is not None:
state.emitted_content_keys.add(key)
return chunks
def _content_part_text(part: Mapping[str, Any]) -> str | None:
part_type = part.get("type")
if part_type == "output_text":
return _as_str(part.get("text"))
if part_type == "refusal":
return _as_str(part.get("refusal"))
return None
def _handle_content_part_done(
event: Any, state: _ResponseStreamState
) -> list[CoreChunk]:
key = _content_key(event)
if key is None or key in state.emitted_content_keys or key in state.saw_delta_keys:
return []
text = _content_part_text(_as_dict(getattr(event, "part", None)))
chunks = _emit_text_chunk(text, state)
if chunks:
state.emitted_content_keys.add(key)
return chunks
def _handle_output_item_done(
event: Any, state: _ResponseStreamState
) -> list[CoreChunk]:
item = _as_dict(getattr(event, "item", None))
item_type = _as_str(item.get("type"))
if item_type == "reasoning":
reasoning_chunk = _reasoning_encrypted_detail_chunk(
item=item,
item_id=_as_str(item.get("id")),
output_index=getattr(event, "output_index", None),
state=state,
)
return _wrap_assistant_chunk(reasoning_chunk, state)
if item_type == "function_call":
state.saw_tool_call = True
return _wrap_assistant_chunk(
_tool_call_delta_from_item(item, state=state), state
)
if item_type in _NON_FUNCTION_TOOL_CALL_ITEMS:
state.saw_tool_call = True
return []
if item_type != "message":
return []
item_id = _as_str(item.get("id"))
content = item.get("content")
if item_id is None or not isinstance(content, list):
return []
chunks: list[CoreChunk] = []
for idx, part in enumerate(content):
key = (item_id, idx)
if key in state.emitted_content_keys or key in state.saw_delta_keys:
continue
text = _content_part_text(_as_dict(part))
emitted = _emit_text_chunk(text, state)
if emitted:
chunks.extend(emitted)
state.emitted_content_keys.add(key)
return chunks
def _handle_output_item_added(
event: Any, state: _ResponseStreamState
) -> list[CoreChunk]:
item = _as_dict(getattr(event, "item", None))
if _as_str(item.get("type")) != "reasoning":
return []
reasoning_chunk = _reasoning_encrypted_detail_chunk(
item=item,
item_id=_as_str(item.get("id")),
output_index=getattr(event, "output_index", None),
state=state,
)
return _wrap_assistant_chunk(reasoning_chunk, state)
def _reasoning_encrypted_detail_chunk(
*,
item: dict[str, Any],
item_id: str | None,
output_index: Any,
state: _ResponseStreamState,
) -> CoreChunk | None:
encrypted = _as_str(item.get("encrypted_content"))
if encrypted is None:
return None
if item_id is not None:
if item_id in state.emitted_reasoning_item_ids:
return None
state.emitted_reasoning_item_ids.add(item_id)
detail: dict[str, Any] = {
"type": "reasoning.encrypted",
"data": encrypted,
"format": "openai-responses-v1",
}
if item_id is not None:
detail["id"] = item_id
if isinstance(output_index, int):
detail["index"] = output_index
return CoreChunk(reasoning_details=[detail])
def _reasoning_summary_detail_chunk(event: Any) -> CoreChunk | None:
text = _first_string(event, "text")
if not text:
return None
detail: dict[str, Any] = {
"type": "reasoning.summary",
"summary": text,
"format": "openai-responses-v1",
}
item_id = _as_str(getattr(event, "item_id", None))
if item_id is not None:
detail["id"] = item_id
summary_index = getattr(event, "summary_index", None)
if isinstance(summary_index, int):
detail["index"] = summary_index
return CoreChunk(reasoning_details=[detail])
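Downstream clients therefore see two shapes inside reasoning_details, both tagged with the openai-responses-v1 format: an encrypted blob per reasoning item and a plain-text summary. Representative payloads (contents are placeholders):

encrypted_detail = {
    "type": "reasoning.encrypted",
    "data": "<opaque-encrypted-content>",
    "format": "openai-responses-v1",
    "id": "rs_123",  # reasoning item id, when present
    "index": 0,      # output_index, when present
}
summary_detail = {
    "type": "reasoning.summary",
    "summary": "Weighed both approaches before answering.",
    "format": "openai-responses-v1",
    "id": "rs_123",
    "index": 0,      # summary_index, when present
}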
def _content_key(event: Any) -> tuple[str, int] | None:
item_id = _as_str(getattr(event, "item_id", None))
content_index = getattr(event, "content_index", None)
if item_id is None or not isinstance(content_index, int):
return None
return (item_id, content_index)
def _first_string(event: Any, *field_names: str) -> str | None:
for name in field_names:
value = getattr(event, name, None)
if isinstance(value, str):
return value
return None
def _extract_incomplete_finish_reason(event: Any) -> str:
response = getattr(event, "response", None)
response_dict = _as_dict(response)
incomplete_details = response_dict.get("incomplete_details")
details_dict = _as_dict(incomplete_details)
reason = _as_str(details_dict.get("reason"))
if reason == "max_output_tokens":
return "length"
if reason == "content_filter":
return "content_filter"
return "stop"
def _tool_call_delta_from_event(
event: Any,
*,
event_type: str,
state: _ResponseStreamState,
) -> CoreChunk | None:
item_id = _as_str(getattr(event, "item_id", None))
if item_id is None:
return None
index = _get_tool_call_index(item_id, state)
if event_type in {
"response.function_call_arguments.delta",
"response.custom_tool_call_input.delta",
}:
if item_id in state.tool_call_finalized:
return None
arguments_delta = _as_str(getattr(event, "delta", None))
if not arguments_delta:
return None
function_name = state.tool_call_names.get(item_id)
if function_name is None:
state.pending_tool_arguments[item_id] = (
state.pending_tool_arguments.get(item_id, "") + arguments_delta
)
return None
state.emitted_tool_arguments[item_id] = (
state.emitted_tool_arguments.get(item_id, "") + arguments_delta
)
return _build_tool_call_chunk(
item_id=item_id,
index=index,
name=function_name,
arguments=arguments_delta,
)
if event_type in {
"response.function_call_arguments.done",
"response.custom_tool_call_input.done",
}:
name = _as_str(getattr(event, "name", None))
arguments = _as_str(getattr(event, "arguments", None)) or _as_str(
getattr(event, "input", None)
)
if item_id in state.tool_call_finalized:
return None
if name:
state.tool_call_names[item_id] = name
function_name = state.tool_call_names.get(item_id)
if function_name is None:
return None
buffered = state.pending_tool_arguments.pop(item_id, "")
emitted = state.emitted_tool_arguments.pop(item_id, "")
done_arguments = _resolve_done_tool_arguments(
arguments=arguments,
buffered=buffered,
emitted=emitted,
)
state.tool_call_finalized.add(item_id)
if done_arguments is None:
return None
return _build_tool_call_chunk(
item_id=item_id,
index=index,
name=function_name,
arguments=done_arguments,
)
return None
def _tool_call_delta_from_item(
item: dict[str, Any],
*,
state: _ResponseStreamState,
) -> CoreChunk | None:
if _as_str(item.get("type")) != "function_call":
return None
item_id = _as_str(item.get("id"))
name = _as_str(item.get("name"))
if item_id is None or name is None:
return None
if item_id in state.tool_call_finalized:
return None
index = _get_tool_call_index(item_id, state)
state.tool_call_names[item_id] = name
arguments = _as_str(item.get("arguments"))
buffered = state.pending_tool_arguments.pop(item_id, "")
emitted = state.emitted_tool_arguments.pop(item_id, "")
function_arguments = _resolve_done_tool_arguments(
arguments=arguments,
buffered=buffered,
emitted=emitted,
)
if function_arguments is None:
return None
state.tool_call_finalized.add(item_id)
return _build_tool_call_chunk(
item_id=item_id,
index=index,
name=name,
arguments=function_arguments,
)
def _get_tool_call_index(item_id: str, state: _ResponseStreamState) -> int:
index = state.tool_call_indexes.get(item_id)
if index is not None:
return index
index = state.next_tool_call_index
state.tool_call_indexes[item_id] = index
state.next_tool_call_index += 1
return index
def _build_tool_call_chunk(
*, item_id: str, index: int, name: str, arguments: str
) -> CoreChunk:
return CoreChunk(
tool_calls=[
{
"index": index,
"id": item_id,
"type": "function",
"function": {"name": name, "arguments": arguments},
}
]
)
def _tool_arguments_tail(arguments: str | None, emitted: str) -> str:
if arguments is None:
return ""
if not emitted:
return arguments
if arguments.startswith(emitted):
return arguments[len(emitted) :]
return ""
def _resolve_done_tool_arguments(
*, arguments: str | None, buffered: str, emitted: str
) -> str | None:
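    # Resolution order: a complete done payload beats buffered deltas;
    # buffered deltas (held back because the name arrived late) are flushed
    # whole; if deltas were already streamed, only the unsent tail goes out so
    # clients never receive duplicated argument fragments.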
if buffered:
if arguments is None:
return buffered
return arguments
if arguments is None:
return None if emitted else ""
if not emitted:
return arguments
tail = _tool_arguments_tail(arguments, emitted)
return tail or None

View file

@ -37,13 +37,15 @@ from app.core.types import CoreMessage, ProviderChatRequest
def build_responses_create_args(
request: ProviderChatRequest,
) -> tuple[ResponseCreateParamsStreaming, dict[str, Any]]:
instructions = _build_instructions(request.messages)
messages = list(request.messages)
instructions = _pop_instruction_message(messages)
args: ResponseCreateParamsStreaming = {
"model": request.model,
"input": _build_input_items(request.messages),
"input": _build_input_items(messages),
"stream": True,
"store": False,
"include": ["reasoning.encrypted_content"],
"parallel_tool_calls": request.parallel_tool_calls
if request.parallel_tool_calls is not None
else True,
@ -54,16 +56,12 @@ def build_responses_create_args(
if request.metadata is not None:
args["metadata"] = request.metadata
if request.prompt_cache_key is not None:
args["prompt_cache_key"] = request.prompt_cache_key
if request.prompt_cache_retention is not None:
args["prompt_cache_retention"] = cast(Any, request.prompt_cache_retention)
if request.safety_identifier is not None:
args["safety_identifier"] = request.safety_identifier
if request.service_tier is not None:
args["service_tier"] = cast(Any, request.service_tier)
if request.temperature is not None:
args["temperature"] = request.temperature
if request.top_p is not None:
args["top_p"] = request.top_p
if request.tools is not None:
@ -79,20 +77,53 @@ def build_responses_create_args(
if text_config is not None:
args["text"] = text_config
return args, dict(request.extra)
return args, _build_extra_body(request)
def _build_instructions(messages: list[CoreMessage]) -> str | None:
parts: list[str] = []
for message in messages:
def _build_extra_body(request: ProviderChatRequest) -> dict[str, Any]:
extra_body = dict(request.extra)
if request.provider is not None:
extra_body["provider"] = request.provider
if request.plugins is not None:
extra_body["plugins"] = request.plugins
if request.session_id is not None:
extra_body["session_id"] = request.session_id
if request.trace is not None:
extra_body["trace"] = request.trace
if request.models is not None:
extra_body["models"] = request.models
if request.debug is not None:
extra_body["debug"] = request.debug
if request.image_config is not None:
extra_body["image_config"] = request.image_config
for key in (
"metadata",
"prompt_cache_retention",
"safety_identifier",
"service_tier",
"top_p",
"tools",
"tool_choice",
"reasoning",
"response_format",
"verbosity",
):
extra_body.pop(key, None)
return extra_body
def _pop_instruction_message(messages: list[CoreMessage]) -> str:
for index, message in enumerate(messages):
if message.role not in {"developer", "system"}:
continue
text = _extract_text(message.content)
if text:
parts.append(text)
if not parts:
return "You are a helpful assistant."
return "\n\n".join(parts)
messages.pop(index)
return text
return "You are a helpful assistant."
def _build_input_items(messages: list[CoreMessage]) -> ResponseInputParam:
@ -104,6 +135,9 @@ def _build_input_items(messages: list[CoreMessage]) -> ResponseInputParam:
continue
if message.role in {"developer", "system"}:
system_message = _build_message_item(message, role_override="system")
if system_message is not None:
items.append(system_message)
continue
if message.role in {"user", "assistant"}:
@ -123,44 +157,73 @@ def _build_input_items(messages: list[CoreMessage]) -> ResponseInputParam:
return items
def _build_message_item(message: CoreMessage) -> EasyInputMessageParam | None:
content = _build_message_content(message.content)
def _build_message_item(
message: CoreMessage, role_override: str | None = None
) -> EasyInputMessageParam | None:
role_value = role_override or message.role
content = _build_message_content(message.content, role_value)
if content is None:
return None
role = cast("EasyInputMessageParam", {"role": message.role})["role"]
role = cast("EasyInputMessageParam", {"role": role_value})["role"]
item: EasyInputMessageParam = {
"type": "message",
"role": role,
"content": content,
"content": cast(Any, content),
}
return item
def _build_message_content(
content: Any,
) -> str | list[ResponseInputContentParam] | None:
if isinstance(content, str):
return content
role: str,
) -> list[dict[str, Any]] | None:
if content is None or content == "":
return None
if isinstance(content, str):
text_part_type = "output_text" if role == "assistant" else "input_text"
return [{"type": text_part_type, "text": content}]
if isinstance(content, dict):
content = [content]
if not isinstance(content, list):
raise ValueError("Unsupported message content for responses input")
out: list[ResponseInputContentParam] = []
out: list[dict[str, Any]] = []
for part in content:
if not isinstance(part, dict):
continue
part_type = part.get("type")
cache_control = _extract_cache_control(part)
if part_type == "text" and isinstance(part.get("text"), str):
out.append({"type": "input_text", "text": part["text"]})
text_part_type = "output_text" if role == "assistant" else "input_text"
text_item: dict[str, Any] = {"type": text_part_type, "text": part["text"]}
if cache_control is not None:
text_item["cache_control"] = cache_control
out.append(text_item)
continue
if part_type in {"input_text", "output_text"} and isinstance(
part.get("text"), str
):
text_item = {"type": part_type, "text": part["text"]}
if cache_control is not None:
text_item["cache_control"] = cache_control
out.append(text_item)
continue
if part_type == "refusal" and isinstance(part.get("refusal"), str):
out.append({"type": "input_text", "text": part["refusal"]})
refusal_item: dict[str, Any] = {
"type": "input_text",
"text": part["refusal"],
}
if cache_control is not None:
refusal_item["cache_control"] = cache_control
out.append(refusal_item)
continue
if part_type == "image_url" and isinstance(part.get("image_url"), dict):
@ -170,7 +233,49 @@ def _build_message_content(
image_item["image_url"] = image["url"]
if image.get("detail") in {"low", "high", "auto"}:
image_item["detail"] = image["detail"]
out.append(cast("ResponseInputContentParam", image_item))
if cache_control is not None:
image_item["cache_control"] = cache_control
out.append(image_item)
continue
if part_type == "input_image":
image_item = {
"type": "input_image",
"image_url": part.get("image_url"),
"file_id": part.get("file_id"),
"detail": part.get("detail")
if part.get("detail") in {"low", "high", "auto"}
else "auto",
}
if cache_control is not None:
image_item["cache_control"] = cache_control
out.append({k: v for k, v in image_item.items() if v is not None})
continue
if part_type == "input_audio" and isinstance(part.get("input_audio"), dict):
audio = part["input_audio"]
audio_format = audio.get("format")
if isinstance(audio.get("data"), str) and audio_format in {
"wav",
"mp3",
"flac",
"m4a",
"ogg",
"aiff",
"aac",
"pcm16",
"pcm24",
}:
audio_item: dict[str, Any] = {
"type": "input_audio",
"input_audio": {
"data": audio["data"],
"format": audio_format,
},
}
if cache_control is not None:
audio_item["cache_control"] = cache_control
out.append(audio_item)
continue
if part_type == "file" and isinstance(part.get("file"), dict):
@ -182,7 +287,38 @@ def _build_message_content(
file_item["file_id"] = wrapped["file_id"]
if isinstance(wrapped.get("filename"), str):
file_item["filename"] = wrapped["filename"]
out.append(cast("ResponseInputContentParam", file_item))
if isinstance(wrapped.get("file_url"), str):
file_item["file_url"] = wrapped["file_url"]
if cache_control is not None:
file_item["cache_control"] = cache_control
out.append(file_item)
continue
if part_type == "input_file":
file_item = {
"type": "input_file",
"file_data": part.get("file_data"),
"file_id": part.get("file_id"),
"filename": part.get("filename"),
"file_url": part.get("file_url"),
}
if cache_control is not None:
file_item["cache_control"] = cache_control
out.append({k: v for k, v in file_item.items() if v is not None})
continue
if part_type in {"video_url", "input_video"} and isinstance(
part.get("video_url"), dict
):
video = part["video_url"]
if isinstance(video.get("url"), str):
video_item: dict[str, Any] = {
"type": "input_file",
"file_url": video["url"],
}
if cache_control is not None:
video_item["cache_control"] = cache_control
out.append(video_item)
continue
if out:
@ -314,11 +450,9 @@ def _normalize_tool_output(
if isinstance(content, str):
return content
if isinstance(content, list):
converted = _build_message_content(content)
converted = _build_message_content(content, "tool")
if converted is None:
return ""
if isinstance(converted, str):
return converted
return cast("ResponseFunctionCallOutputItemListParam", converted)
if content is None:
return ""
@ -328,6 +462,17 @@ def _normalize_tool_output(
def _build_reasoning(request: ProviderChatRequest) -> Reasoning | None:
effort = request.reasoning_effort
summary = request.reasoning_summary
if effort is None and isinstance(request.reasoning, dict):
reasoning_effort = request.reasoning.get("effort")
if isinstance(reasoning_effort, str):
effort = reasoning_effort
if summary is None and isinstance(request.reasoning, dict):
reasoning_summary = request.reasoning.get("summary")
if isinstance(reasoning_summary, str):
summary = reasoning_summary
if summary is None:
if request.verbosity == "low":
summary = "concise"
@ -345,6 +490,22 @@ def _build_reasoning(request: ProviderChatRequest) -> Reasoning | None:
return reasoning
def _extract_cache_control(part: dict[str, Any]) -> dict[str, Any] | None:
cache_control = part.get("cache_control")
if not isinstance(cache_control, dict):
return None
cache_type = cache_control.get("type")
if cache_type != "ephemeral":
return None
out: dict[str, Any] = {"type": "ephemeral"}
ttl = cache_control.get("ttl")
if ttl in {"5m", "1h"}:
out["ttl"] = ttl
return out
def _build_text_config(request: ProviderChatRequest) -> ResponseTextConfigParam | None:
if request.response_format is None and request.verbosity is None:
return None

View file

@ -0,0 +1,36 @@
"""Shared helpers for codex responses provider."""
from __future__ import annotations
import logging
from typing import Any
logger = logging.getLogger(__name__)
def _log_ignored_extra(extra: dict[str, Any], *, provider_name: str) -> None:
if not extra:
return
logger.error(
"provider '%s' ignored unsupported extra params: %s",
provider_name,
extra,
)
def _to_dict(raw: Any) -> dict[str, Any]:
if hasattr(raw, "model_dump"):
dumped = raw.model_dump()
if isinstance(dumped, dict):
return dumped
if isinstance(raw, dict):
return raw
return {}
def _as_dict(raw: Any) -> dict[str, Any]:
return _to_dict(raw)
def _as_str(value: Any) -> str | None:
return value if isinstance(value, str) else None

View file

@ -4,19 +4,32 @@ from __future__ import annotations
from collections.abc import AsyncIterator
import logging
from typing import Any, cast
from typing import Any, Protocol, cast
from openai import AsyncOpenAI, OpenAIError
from app.config.models import LoadedProviderConfig, TokenAuth
from app.config.models import LoadedProviderConfig, TokenAuth, UrlAuth
from app.core.errors import UpstreamProviderError
from app.core.types import CoreChunk, CoreMessage, ProviderChatRequest, ProviderModel
from app.providers.base import BaseProvider
from app.providers.token_url_auth import TokenUrlAuthProvider
from app.providers.registry import provider
logger = logging.getLogger(__name__)
class _BearerAuthProvider(Protocol):
async def get_headers(self) -> dict[str, str]: ...
class _StaticBearerAuthProvider:
def __init__(self, token: str) -> None:
self._token = token
async def get_headers(self) -> dict[str, str]:
return {"Authorization": f"Bearer {self._token}"}
@provider(type="openai-completions")
class OpenAICompletionsProvider(BaseProvider):
"""Provider that talks to OpenAI-compatible /chat/completions APIs."""
@ -26,9 +39,10 @@ class OpenAICompletionsProvider(BaseProvider):
*,
name: str,
base_url: str,
token: str,
token: str | None = None,
whitelist: list[str] | None = None,
blacklist: list[str] | None = None,
auth_provider: _BearerAuthProvider | None = None,
client: Any | None = None,
) -> None:
super().__init__(
@ -38,30 +52,51 @@ class OpenAICompletionsProvider(BaseProvider):
whitelist=whitelist,
blacklist=blacklist,
)
self._client = client or AsyncOpenAI(api_key=token, base_url=base_url)
if auth_provider is None:
if token is None:
raise ValueError(
f"Provider '{name}' type 'openai-completions' requires auth provider or token"
)
auth_provider = _StaticBearerAuthProvider(token)
self._auth_provider = auth_provider
self._client = client or AsyncOpenAI(api_key="placeholder", base_url=base_url)
@classmethod
def from_config(cls, config: LoadedProviderConfig) -> BaseProvider:
if not isinstance(config.auth, TokenAuth):
auth_provider: _BearerAuthProvider
if isinstance(config.auth, TokenAuth):
auth_provider = _StaticBearerAuthProvider(config.auth.token)
elif isinstance(config.auth, UrlAuth):
auth_provider = TokenUrlAuthProvider(token_url=config.auth.url)
else:
raise ValueError(
f"Provider '{config.name}' type 'openai-completions' requires token auth"
f"Provider '{config.name}' type 'openai-completions' requires token or url auth"
)
return cls(
provider = cls(
name=config.name,
base_url=config.url,
token=config.auth.token,
whitelist=config.whitelist,
blacklist=config.blacklist,
auth_provider=auth_provider,
)
provider.display_name = config.display_name
provider.models_config = config.models
return provider
async def stream_chat(
self, request: ProviderChatRequest
) -> AsyncIterator[CoreChunk]:
try:
auth_headers = await self._auth_provider.get_headers()
extra_body, ignored_extra = _build_openai_extra_payload(request)
sent_assistant_role = False
stream = await cast(Any, self._client.chat.completions).create(
model=request.model,
messages=[_to_chat_message(m) for m in request.messages],
stream=True,
extra_headers=auth_headers,
extra_body=extra_body,
audio=request.audio,
frequency_penalty=request.frequency_penalty,
logit_bias=request.logit_bias,
@ -93,22 +128,50 @@ class OpenAICompletionsProvider(BaseProvider):
verbosity=request.verbosity,
web_search_options=request.web_search_options,
)
_log_ignored_extra(request.extra, provider_name=self.name)
_log_ignored_extra(ignored_extra, provider_name=self.name)
async for chunk in stream:
choices = getattr(chunk, "choices", [])
for idx, choice in enumerate(choices):
delta = getattr(choice, "delta", None)
role = getattr(delta, "role", None) if delta is not None else None
content = (
getattr(delta, "content", None) if delta is not None else None
)
role = _delta_str(delta, "role")
content = _delta_str(delta, "content")
reasoning_content = _extract_reasoning_content(delta)
reasoning_details = _extract_reasoning_details(delta)
tool_calls = _extract_tool_calls(delta)
finish_reason = getattr(choice, "finish_reason", None)
if role is None and content is None and finish_reason is None:
choice_index = getattr(choice, "index", idx)
if role == "assistant":
sent_assistant_role = True
elif (
not sent_assistant_role
and role is None
and (
content is not None
or reasoning_content is not None
or reasoning_details is not None
or tool_calls is not None
)
):
sent_assistant_role = True
yield CoreChunk(index=choice_index, role="assistant")
if (
role is None
and content is None
and reasoning_content is None
and reasoning_details is None
and tool_calls is None
and finish_reason is None
):
continue
yield CoreChunk(
index=getattr(choice, "index", idx),
index=choice_index,
role=role,
content=content,
reasoning_content=reasoning_content,
reasoning_details=reasoning_details,
tool_calls=tool_calls,
finish_reason=finish_reason,
)
except OpenAIError as exc:
@ -122,7 +185,8 @@ class OpenAICompletionsProvider(BaseProvider):
async def list_models(self) -> list[ProviderModel]:
try:
response = await self._client.models.list()
auth_headers = await self._auth_provider.get_headers()
response = await self._client.models.list(extra_headers=auth_headers)
items = _coerce_model_items(response)
return [_to_provider_model(item) for item in items if item.get("id")]
except OpenAIError as exc:
@ -156,6 +220,154 @@ def _to_chat_message(message: CoreMessage) -> dict[str, Any]:
return out
def _delta_value(delta: Any, field: str) -> Any:
if delta is None:
return None
if isinstance(delta, dict):
return delta.get(field)
return getattr(delta, field, None)
def _delta_str(delta: Any, field: str) -> str | None:
value = _delta_value(delta, field)
return value if isinstance(value, str) else None
def _extract_reasoning_content(delta: Any) -> str | None:
for field in ("reasoning_content", "reasoning", "reasoning_text"):
value = _delta_value(delta, field)
if isinstance(value, str):
return value
reasoning_obj = _to_dict(_delta_value(delta, "reasoning"))
for field in ("content", "text", "summary"):
value = reasoning_obj.get(field)
if isinstance(value, str):
return value
return None
def _extract_tool_calls(delta: Any) -> list[dict[str, Any]] | None:
raw = _delta_value(delta, "tool_calls")
if not isinstance(raw, list):
return None
tool_calls = [_to_dict(item) for item in raw]
return tool_calls or None
def _extract_reasoning_details(delta: Any) -> list[dict[str, Any]] | None:
candidates: list[Any] = []
raw = _delta_value(delta, "reasoning_details")
if isinstance(raw, list):
candidates.extend(raw)
reasoning_obj = _to_dict(_delta_value(delta, "reasoning"))
nested = reasoning_obj.get("details")
if isinstance(nested, list):
candidates.extend(nested)
details: list[dict[str, Any]] = []
for candidate in candidates:
detail = _to_dict(candidate)
if detail:
details.append(detail)
return details or None
def _build_openai_extra_payload(
request: ProviderChatRequest,
) -> tuple[dict[str, Any], dict[str, Any]]:
extra_body: dict[str, Any] = {}
extra = request.extra
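    # Typed request fields take precedence; values arriving via the
    # loosely-typed extra payload are used only as a fallback.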
provider_value = request.provider or (
extra.get("provider") if isinstance(extra.get("provider"), dict) else None
)
if provider_value is not None:
extra_body["provider"] = provider_value
plugins_value = request.plugins or (
extra.get("plugins") if isinstance(extra.get("plugins"), list) else None
)
if plugins_value is not None:
extra_body["plugins"] = plugins_value
session_id_value = request.session_id or (
extra.get("session_id") if isinstance(extra.get("session_id"), str) else None
)
if session_id_value is not None:
extra_body["session_id"] = session_id_value
trace_value = request.trace or (
extra.get("trace") if isinstance(extra.get("trace"), dict) else None
)
if trace_value is not None:
extra_body["trace"] = trace_value
debug_value = request.debug or (
extra.get("debug") if isinstance(extra.get("debug"), dict) else None
)
if debug_value is not None:
extra_body["debug"] = debug_value
image_config_value = request.image_config or (
extra.get("image_config")
if isinstance(extra.get("image_config"), dict)
else None
)
if image_config_value is not None:
extra_body["image_config"] = image_config_value
models_value = request.models or (
extra.get("models") if isinstance(extra.get("models"), list) else None
)
if models_value is not None:
extra_body["models"] = models_value
reasoning_value = _build_reasoning_payload(request)
if reasoning_value is not None:
extra_body["reasoning"] = reasoning_value
ignored_extra = dict(request.extra)
for key in (
"provider",
"plugins",
"session_id",
"trace",
"models",
"debug",
"image_config",
):
ignored_extra.pop(key, None)
return extra_body, ignored_extra
def _build_reasoning_payload(request: ProviderChatRequest) -> dict[str, Any] | None:
effort = request.reasoning_effort
summary = request.reasoning_summary
reasoning = request.reasoning
if isinstance(reasoning, dict):
if effort is None and isinstance(reasoning.get("effort"), str):
effort = reasoning["effort"]
if summary is None and isinstance(reasoning.get("summary"), str):
summary = reasoning["summary"]
if effort is None and summary is None:
return None
out: dict[str, Any] = {}
if effort is not None:
out["effort"] = effort
if summary is not None:
out["summary"] = summary
return out
def _log_ignored_extra(extra: dict[str, Any], *, provider_name: str) -> None:
if not extra:
return

View file

@ -0,0 +1,36 @@
"""Auth helper for fetching bearer token from external URL."""
from __future__ import annotations
import httpx
class TokenUrlAuthProvider:
"""Fetches bearer token from URL on each request."""
def __init__(self, *, token_url: str, timeout_seconds: float = 600.0) -> None:
self._token_url = token_url
self._timeout_seconds = timeout_seconds
async def get_headers(self) -> dict[str, str]:
token = await self._fetch_token()
return {"Authorization": f"Bearer {token}"}
async def _fetch_token(self) -> str:
async with httpx.AsyncClient(timeout=self._timeout_seconds) as client:
response = await client.get(self._token_url)
if response.status_code >= 400:
raise ValueError(
f"Token URL auth failed with status {response.status_code}"
)
            try:
                token = response.json()["token"]
            except (ValueError, KeyError) as exc:
                raise ValueError(
                    f"Token URL auth returned invalid response: {response.content!r}"
                ) from exc
if not token:
raise ValueError("Token URL auth returned empty token")
return token

View file

@ -1,4 +1,11 @@
providers:
wzray:
url: http://127.0.0.1:8000/v1
type: openai-completions
name: Wzray
models:
openai/gpt-5:
name: GPT-5
zai:
url: https://api.z.ai/api/coding/paas/v4
type: openai-completions
@ -14,4 +21,4 @@ providers:
- glm-5-free
codex:
url: https://chatgpt.com/backend-api
type: codex-responses
type: codex-responses

File diff suppressed because it is too large

docs/Responses schema.md Normal file

File diff suppressed because it is too large

docs/TODO.md Normal file
View file

@ -0,0 +1,88 @@
# Codex Router Alignment Plan
## Confirmed Scope (from latest requirements)
- Keep `parallel_tool_calls: true` in outbound responses payloads.
- Do not send `prompt_cache_key` from the router for now.
- Always send `include: ["reasoning.encrypted_content"]`.
- Header work now: remove self-added duplicate headers only.
- Message payload work now: stop string `content` serialization and send structured content parts that match the golden capture.
## Deferred (intentionally postponed)
- Full header parity with the golden capture (transport/runtime-level UA and low-level accept-encoding parity).
- Full one-to-one `input` history shape parity (`type` omission strategy for message items).
- Recovering or synthesizing top-level developer message from upstream chat-completions schema.
- End-to-end reasoning item roundtrip parity in history (`type: reasoning` pass-through and replay behavior).
- Prompt cache implementation strategy and lifecycle management.
## Feasible Path For Deferred Items
1. Header parity
- Keep current sdk-based client for now.
- If exact parity is required, switch codex provider transport from `AsyncOpenAI` to a custom `httpx` SSE client and set an explicit header allowlist.
2. Input history shape parity
- Add a translator mode that emits implicit message items (`{"role":...,"content":...}`) without `type: "message"`.
- Keep explicit item support for `function_call` and `function_call_output` unchanged.
3. Developer message availability
- Add optional request extension field(s) in `model_extra`, e.g. `opencode_developer_message` or `opencode_input_items`.
- Use extension when provided; otherwise keep current first-system/developer-to-instructions behavior.
4. Reasoning item roundtrip
- Accept explicit inbound items with `extra.type == "reasoning"` and pass through `encrypted_content` + `summary` to responses `input`.
- Keep chat-completions output contract unchanged; reasoning passthrough is input-side only unless a dedicated raw endpoint is added.
5. Prompt cache strategy
- Keep disabled by default.
- Add optional feature flag for deterministic hash-based key generation once cache policy is agreed.
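A minimal sketch of what step 5's deterministic key could look like once the feature flag exists; `prompt_cache_key_for` is a hypothetical helper, not current router code, and the exact canonicalization is still an open design question:
```python
import hashlib
import json
from typing import Any


def prompt_cache_key_for(model: str, messages: list[dict[str, Any]]) -> str:
    # Same model + same normalized message prefix -> same cache key.
    canonical = json.dumps(
        {"model": model, "messages": messages},
        sort_keys=True,
        separators=(",", ":"),
    )
    return "pck_" + hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:32]
```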
## Schema.md Gap Breakdown (planning only, no implementation yet)
### Legend
- `Supported` = already implemented.
- `Partial` = partly implemented but not schema-complete.
- `Missing` = not implemented yet.
| # | Area | What it does for users | Current status | Decision from latest review | Notes / planned behavior |
|---|---|---|---|---|---|
| 1 | Extra request controls (`provider`, `plugins`, `session_id`, `trace`, `models`, `debug`, `image_config`) | Lets users steer upstream routing, observability, plugin behavior, and image/provider-specific behavior directly from request body. | Missing | Explain each field first, then choose individually | Keep pass-through design: accept fields in API schema, preserve in internal request, forward when provider supports. |
| 2 | `reasoning` object in request (`reasoning.effort`, `reasoning.summary`) | Standard schema-compatible way to request reasoning effort and summary verbosity. | Partial (we use flat `reasoning_effort` / `reasoning_summary`) | Must support | Add canonical `reasoning` object support while preserving backward compatibility with current flat aliases. Define precedence rules if both forms are provided. |
| 3 | `modalities` alignment (`text`/`image`) | Controls output modalities users request. Must match schema contract exactly. | Partial / mismatched (`text`/`audio` now) | Must support schema behavior | Change request schema and internal mapping to `text`/`image` for the public API; ensure providers receive compatible values. |
| 4 | Full message content parts (audio/video/cache-control variants) | Enables multi-part multimodal inputs (audio, video, richer text metadata) and cache hints on message parts. | Partial | Must support | Expand accepted message content item parsing and translator mapping for all schema item variants, including preservation of unknown-but-valid provider fields where safe. |
| 5 | Assistant response extensions (`reasoning`, `reasoning_details`, `images`) | Returns richer assistant payloads: plain reasoning, structured reasoning metadata, and generated image outputs. | Missing | Must support | Extend response schemas and mappers so these fields can be emitted in non-streaming and streaming-compatible forms. |
| 6 | Encrypted reasoning passthrough (`reasoning_details` with encrypted data) | Exposes encrypted reasoning blocks from upstream exactly as received for advanced clients/debugging/replay. | Missing | High priority, must support | Capture encrypted reasoning items from responses stream (`response.output_item.*` for `type=reasoning`) and surface in API output as raw/structured reasoning details without lossy transformation. |
| 7 | Usage passthrough fidelity | Users should receive full upstream usage payload (raw), not a reduced subset. | Partial | Needed: pass full raw usage through | Do not over-normalize; preserve upstream usage object as-is when available. If upstream omits usage, return `null`/missing naturally. |
| 8 | Detailed HTTP error matrix parity | Strictly maps many status codes exactly like reference schema. | Partial | Not required now | Keep current error strategy unless product requirements change. |
| 9 | Optional `model` when `models` routing is used | OpenRouter-style multi-model router behavior. | Missing | Not required for this project | Keep `model` required in our API for now. |
## Field-by-field reference for item #1 (for product decision)
| Field | User-visible purpose | Typical payload shape | Risk/complexity |
|---|---|---|---|
| `provider` | Control provider routing policy (allow/deny fallback, specific providers, price/perf constraints). | Object with routing knobs (order/only/ignore, pricing, latency/throughput prefs). | Medium-High (router semantics + validation + provider compatibility). |
| `plugins` | Enable optional behavior modules (web search/moderation/auto-router/etc). | Array of plugin descriptors with `id` and optional settings. | Medium (validation + pass-through + provider-specific effects). |
| `session_id` (body) | Group related requests for observability/conversation continuity. | String (usually short opaque id). | Low (mostly passthrough + precedence with headers if both exist). |
| `trace` | Attach tracing metadata for distributed observability. | Object (`trace_id`, `span_name`, etc + custom keys). | Low-Medium (schema + passthrough). |
| `models` | Candidate model set for automatic selection/router behavior. | Array of model identifiers/patterns. | Medium-High (changes model resolution flow). |
| `debug` | Request debug payloads (e.g., transformed upstream request echo in stream). | Object flags like `echo_upstream_body`. | Medium (security/sensitivity review required). |
| `image_config` | Provider/model-specific image generation tuning options. | Arbitrary object map by provider/model conventions. | Medium (loosely-typed passthrough plus safety limits). |
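For orientation, a pass-through request combining the item-#1 fields could look like the sketch below; the payload mirrors the router tests added in this commit, and the local base URL is an assumption:
```python
import httpx

payload = {
    "model": "kilocode/minimax/minimax-m2.5:free",
    "messages": [{"role": "user", "content": "hi"}],
    "provider": {"allow_fallbacks": False},
    "plugins": [{"id": "web", "enabled": True}],
    "session_id": "ses_123",
    "trace": {"trace_id": "tr_1"},
    "models": ["openai/gpt-5"],
    "debug": {"echo_upstream_body": True},
    "image_config": {"size": "1024x1024"},
}
response = httpx.post("http://127.0.0.1:8000/v1/chat/completions", json=payload)
```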
## Execution order when implementation starts (agreed priorities)
1. Encrypted reasoning + reasoning details output path (#6 + #5 core subset).
2. Full usage passthrough fidelity (#7).
3. Request `reasoning` object support (#2).
4. Modalities contract alignment to schema (`text`/`image`) (#3).
5. Message content multimodal expansion (#4).
6. Decide and then implement selected item-#1 controls (`provider/plugins/session_id/trace/models/debug/image_config`).
## Implementation Steps (current)
1. Update codex translator payload fields:
- remove `prompt_cache_key`
- add mandatory `include`
2. Update message content serialization:
- serialize string message content as `[{"type":"input_text","text":...}]` (restated in the sketch after this list)
- preserve empty-content filtering behavior
3. Update codex provider header handling:
- avoid mutating oauth headers in place
- remove self-added duplicate `user-agent` header
4. Update/extend tests for new payload contract.
5. Run full `pytest` and fix regressions until green.
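A condensed restatement of the step-2 rule as the translator in this commit implements it (assistant history serializes as `output_text`, every other role as `input_text`); `serialize_string_content` is illustrative, not a real helper:
```python
def serialize_string_content(role: str, content: str) -> list[dict[str, str]]:
    # A plain-string message body becomes a one-element content-part list.
    part_type = "output_text" if role == "assistant" else "input_text"
    return [{"type": part_type, "text": content}]
```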

opencode/README.md Normal file
View file

@ -0,0 +1,39 @@
# OpenCode Wzray Plugin
Minimal OpenCode plugin that connects OpenCode to the Wzray router API.
## What It Does
- Adds one provider in OpenCode config: `wzray`
- Uses OpenAI-compatible transport
- Fetches models from your API via `GET /v1/models`
- Uses `WZRAY_API_KEY` as bearer token when set
## Defaults
- Base URL: `https://ai.wzray.com/v1` (override with `WZRAY_API_BASE_URL`, e.g. `http://127.0.0.1:8000/v1` for a local instance)
- Provider key: `wzray`
- Fallback model list contains one safe model (`openai/gpt-5.3-codex`)
## Environment Variables
- `WZRAY_API_BASE_URL` (optional)
- `AI_API_BASE_URL` (optional fallback)
- `WZRAY_API_KEY` (optional)
## Install
From this directory:
```bash
chmod +x ./install_opencode.sh
./install_opencode.sh
```
This copies plugin files to:
- `~/.config/opencode/plugin/opencode/`
And ensures `opencode.json` contains:
- `./plugin/opencode/plugin_wzray.ts`

View file

@ -0,0 +1,69 @@
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
OPENCODE_DIR="${XDG_CONFIG_HOME:-$HOME/.config}/opencode"
PLUGIN_DIR="$OPENCODE_DIR/plugin/opencode"
CONFIG_FILE="$OPENCODE_DIR/opencode.json"
PLUGIN_ENTRY="./plugin/opencode/plugin_wzray.ts"
mkdir -p "$PLUGIN_DIR"
cp "$SCRIPT_DIR/plugin_wzray.ts" "$PLUGIN_DIR/plugin_wzray.ts"
cp "$SCRIPT_DIR/models_wzray.ts" "$PLUGIN_DIR/models_wzray.ts"
if [[ ! -f "$CONFIG_FILE" ]]; then
cat > "$CONFIG_FILE" <<EOF
{
"\$schema": "https://opencode.ai/config.json",
"plugin": [
"$PLUGIN_ENTRY"
]
}
EOF
echo "Installed plugin and created $CONFIG_FILE"
exit 0
fi
RUNTIME=""
if command -v node >/dev/null 2>&1; then
RUNTIME="node"
elif command -v bun >/dev/null 2>&1; then
RUNTIME="bun"
else
echo "Error: node or bun is required to update relaxed opencode.json" >&2
exit 1
fi
"$RUNTIME" - "$CONFIG_FILE" "$PLUGIN_ENTRY" <<'JS'
const fs = require("node:fs");
const configPath = process.argv[2];
const entry = process.argv[3];
const source = fs.readFileSync(configPath, "utf8");
let data;
try {
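// Evaluate the config as a JS expression so relaxed JSON (comments,
// trailing commas) still parses where strict JSON.parse would fail.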
data = new Function(`"use strict"; return (${source});`)();
} catch (error) {
console.error(`Failed to parse ${configPath}: ${String(error)}`);
process.exit(1);
}
if (!data || typeof data !== "object" || Array.isArray(data)) {
data = {};
}
const plugins = Array.isArray(data.plugin) ? data.plugin : [];
if (!plugins.includes(entry)) {
plugins.push(entry);
}
data.plugin = plugins;
fs.writeFileSync(configPath, `${JSON.stringify(data, null, 2)}\n`, "utf8");
JS
echo "Installed plugin files to $PLUGIN_DIR"
echo "Updated $CONFIG_FILE"
echo "Optional envs: WZRAY_API_BASE_URL (default http://127.0.0.1:8000/v1), WZRAY_API_KEY"

opencode/models_wzray.ts Normal file
View file

@ -0,0 +1,53 @@
type ModelInfo = { name: string };
type AvailableModels = Record<string, ModelInfo>;
declare const process: {
env: Record<string, string | undefined>;
};
const FALLBACK_MODELS: AvailableModels = {
"openai/gpt-5.3-codex": { name: "OpenAI: GPT-5.3 Codex" },
};
function getApiKey(): string | undefined {
return process.env.WZRAY_API_KEY;
}
function normalizeBaseUrl(baseUrl: string): string {
const trimmed = baseUrl.replace(/\/+$/, "");
return trimmed.endsWith("/v1") ? trimmed : `${trimmed}/v1`;
}
function createHeaders(): Record<string, string> {
const apiKey = getApiKey();
if (!apiKey) return { "Content-Type": "application/json" };
return {
"Content-Type": "application/json",
Authorization: `Bearer ${apiKey}`,
};
}
export async function getAvailableModels(baseUrl: string): Promise<AvailableModels> {
const modelsUrl = `${normalizeBaseUrl(baseUrl)}/models`;
try {
const response = await fetch(modelsUrl, { headers: createHeaders() });
if (!response.ok) return FALLBACK_MODELS;
const payload = (await response.json()) as {
data?: Array<{ id?: string; name?: string }>;
};
const data = Array.isArray(payload.data) ? payload.data : [];
const models: AvailableModels = {};
for (const item of data) {
if (!item?.id) continue;
models[item.id] = { name: item.name || item.id };
}
return Object.keys(models).length > 0 ? models : FALLBACK_MODELS;
} catch {
return FALLBACK_MODELS;
}
}
export { FALLBACK_MODELS };
export type { AvailableModels, ModelInfo };

opencode/plugin_wzray.ts Normal file
View file

@ -0,0 +1,55 @@
// @ts-ignore
import type { Plugin, PluginInput } from "@opencode-ai/plugin";
import { FALLBACK_MODELS, getAvailableModels } from "./models_wzray";
declare const process: {
env: Record<string, string | undefined>;
};
const SCHEMA = "https://opencode.ai/config.json";
const NPM_PACKAGE = "@ai-sdk/openai-compatible";
const PROVIDER_KEY = "wzray";
const PROVIDER_NAME = "AI Router";
const DEFAULT_BASE_URL = "https://ai.wzray.com/v1";
function getBaseUrl(): string {
return (
process.env.WZRAY_API_BASE_URL ||
process.env.AI_API_BASE_URL ||
DEFAULT_BASE_URL
);
}
function getApiKey(): string {
return process.env.WZRAY_API_KEY || "{env:WZRAY_API_KEY}";
}
function createProviderConfig(models: Record<string, { name: string }>) {
return {
schema: SCHEMA,
npm: NPM_PACKAGE,
name: PROVIDER_NAME,
options: {
baseURL: getBaseUrl(),
apiKey: getApiKey(),
},
models,
};
}
async function configure(config: any): Promise<void> {
if (!config.provider) config.provider = {};
if (config.provider[PROVIDER_KEY]) return;
const baseUrl = getBaseUrl();
const models = await getAvailableModels(baseUrl);
config.provider[PROVIDER_KEY] = createProviderConfig(
Object.keys(models).length > 0 ? models : FALLBACK_MODELS,
);
}
const WzrayProviderPlugin: Plugin = async (_input: PluginInput) => {
return { config: configure };
};
export default WzrayProviderPlugin;

View file

@ -4,6 +4,7 @@ from collections.abc import AsyncIterator
from fastapi.testclient import TestClient
from app.core.models_dev import ModelsDevCatalog
from app.core.router import RouterCore
from app.core.types import CoreChunk, ProviderChatRequest, ProviderModel
from app.dependencies import get_router_core
@ -38,6 +39,53 @@ class _StreamingProvider(BaseProvider):
return [ProviderModel(id="minimax/minimax-m2.5:free", name="MiniMax")]
class _ReasoningProvider(_StreamingProvider):
async def stream_chat(
self, request: ProviderChatRequest
) -> AsyncIterator[CoreChunk]:
self.models_seen.append(request.model)
self.last_request = request
yield CoreChunk(role="assistant")
yield CoreChunk(
reasoning_content="**Plan**",
reasoning_details=[
{
"type": "reasoning.encrypted",
"data": "enc_123",
"id": "rs_1",
"format": "openai-responses-v1",
},
{
"type": "reasoning.summary",
"summary": "**Plan**",
"id": "rs_1",
"format": "openai-responses-v1",
},
],
)
yield CoreChunk(content="Hello")
yield CoreChunk(finish_reason="stop")
class _ModelsDevCatalogWithProviderName(ModelsDevCatalog):
def __init__(self) -> None:
super().__init__(fetch_catalog=lambda: _never_called())
async def get_provider_models(
self, *, provider_name: str, provider_url: str
) -> tuple[str | None, dict[str, dict[str, object]]]:
return "Kilo AI", {
"minimax/minimax-m2.5:free": {
"name": "MiniMax",
"release_date": "2026-01-15",
}
}
async def _never_called() -> dict[str, object]:
raise RuntimeError("not expected")
def _parse_sse_data(raw: str) -> list[str]:
out: list[str] = []
for line in raw.splitlines():
@ -110,6 +158,43 @@ def test_chat_completions_non_stream_returns_chat_completion_object() -> None:
assert body["choices"][0]["message"]["content"] == "Hello"
def test_chat_completions_non_stream_includes_reasoning_details() -> None:
app = create_app()
provider = _ReasoningProvider()
core = RouterCore(providers={"kilocode": provider})
app.dependency_overrides[get_router_core] = lambda: core
client = TestClient(app)
response = client.post(
"/v1/chat/completions",
json={
"model": "kilocode/minimax/minimax-m2.5:free",
"messages": [{"role": "user", "content": "hi"}],
"stream": False,
},
)
assert response.status_code == 200
body = response.json()
message = body["choices"][0]["message"]
assert message["reasoning"] == "**Plan**"
assert message["reasoning_content"] == "**Plan**"
assert message["reasoning_details"] == [
{
"type": "reasoning.encrypted",
"data": "enc_123",
"id": "rs_1",
"format": "openai-responses-v1",
},
{
"type": "reasoning.summary",
"summary": "**Plan**",
"id": "rs_1",
"format": "openai-responses-v1",
},
]
def test_chat_completions_supports_unversioned_path() -> None:
app = create_app()
provider = _StreamingProvider()
@ -123,13 +208,16 @@ def test_chat_completions_supports_unversioned_path() -> None:
"model": "kilocode/minimax/minimax-m2.5:free",
"messages": [{"role": "user", "content": "hi"}],
"stream": True,
"reasoning_effort": "low",
"reasoning": {"effort": "low", "summary": "auto"},
},
)
assert response.status_code == 200
payloads = _parse_sse_data(response.text)
assert payloads[-1] == "[DONE]"
assert provider.last_request is not None
assert provider.last_request.reasoning_effort == "low"
assert provider.last_request.reasoning_summary == "auto"
def test_chat_completions_accepts_temporary_extra_params() -> None:
@ -180,12 +268,75 @@ def test_chat_completions_accepts_reasoning_summary_camel_alias() -> None:
assert provider.last_request.reasoning_summary == "detailed"
def test_models_endpoint_returns_aggregated_models() -> None:
def test_chat_completions_accepts_schema_modalities() -> None:
app = create_app()
provider = _StreamingProvider()
core = RouterCore(providers={"kilocode": provider})
app.dependency_overrides[get_router_core] = lambda: core
client = TestClient(app)
response = client.post(
"/v1/chat/completions",
json={
"model": "kilocode/minimax/minimax-m2.5:free",
"messages": [{"role": "user", "content": "hi"}],
"modalities": ["text", "image"],
"stream": False,
},
)
assert response.status_code == 200
assert provider.last_request is not None
assert provider.last_request.modalities == ["text", "image"]
def test_chat_completions_accepts_schema_router_fields() -> None:
app = create_app()
provider = _StreamingProvider()
core = RouterCore(providers={"kilocode": provider})
app.dependency_overrides[get_router_core] = lambda: core
client = TestClient(app)
response = client.post(
"/v1/chat/completions",
json={
"model": "kilocode/minimax/minimax-m2.5:free",
"messages": [{"role": "user", "content": "hi"}],
"provider": {"allow_fallbacks": False},
"plugins": [{"id": "web", "enabled": True}],
"session_id": "ses_123",
"trace": {"trace_id": "tr_1"},
"models": ["openai/gpt-5"],
"debug": {"echo_upstream_body": True},
"image_config": {"size": "1024x1024"},
"reasoning": {"effort": "high", "summary": "detailed"},
"temperature": 0.1,
"stream": False,
},
)
assert response.status_code == 200
assert provider.last_request is not None
assert provider.last_request.provider == {"allow_fallbacks": False}
assert provider.last_request.plugins == [{"id": "web", "enabled": True}]
assert provider.last_request.session_id == "ses_123"
assert provider.last_request.trace == {"trace_id": "tr_1"}
assert provider.last_request.models == ["openai/gpt-5"]
assert provider.last_request.debug == {"echo_upstream_body": True}
assert provider.last_request.image_config == {"size": "1024x1024"}
assert provider.last_request.reasoning == {"effort": "high", "summary": "detailed"}
assert provider.last_request.temperature == 0.1
def test_models_endpoint_returns_aggregated_models() -> None:
app = create_app()
provider = _StreamingProvider()
core = RouterCore(
providers={"kilocode": provider},
models_dev_catalog=_ModelsDevCatalogWithProviderName(),
)
app.dependency_overrides[get_router_core] = lambda: core
client = TestClient(app)
response = client.get("/models")
@ -194,8 +345,9 @@ def test_models_endpoint_returns_aggregated_models() -> None:
assert body["object"] == "list"
assert body["data"][0]["id"] == "kilocode/minimax/minimax-m2.5:free"
assert body["data"][0]["object"] == "model"
assert body["data"][0]["created"] == 0
assert body["data"][0]["created"] > 0
assert body["data"][0]["owned_by"] == "wzray"
assert body["data"][0]["name"] == "Kilo AI: MiniMax"
def test_models_endpoint_supports_v1_path() -> None:

View file

@ -35,7 +35,6 @@ def test_build_responses_create_args_maps_core_fields() -> None:
safety_identifier="user-123",
service_tier="default",
parallel_tool_calls=False,
temperature=0.2,
top_p=0.9,
tools=[
{
@ -68,11 +67,12 @@ def test_build_responses_create_args_maps_core_fields() -> None:
assert args_dict["model"] == "gpt-5-codex"
assert args_dict["stream"] is True
assert args_dict["store"] is False
assert args_dict["instructions"] == "Follow policy.\n\nSystem rule."
assert args_dict["instructions"] == "Follow policy."
assert args_dict["include"] == ["reasoning.encrypted_content"]
assert args_dict["parallel_tool_calls"] is False
assert "prompt_cache_key" not in args_dict
assert args_dict["prompt_cache_retention"] == "24h"
assert args_dict["metadata"] == {"team": "infra"}
assert args_dict["temperature"] == 0.2
assert args_dict["top_p"] == 0.9
assert args_dict["reasoning"] == {"effort": "medium", "summary": "detailed"}
@ -87,19 +87,28 @@ def test_build_responses_create_args_maps_core_fields() -> None:
}
items = args_dict["input"]
assert items[0] == {"type": "message", "role": "user", "content": "Hello"}
assert items[0] == {
"type": "message",
"role": "system",
"content": [{"type": "input_text", "text": "System rule."}],
}
assert items[1] == {
"type": "message",
"role": "assistant",
"content": "Calling tool",
"role": "user",
"content": [{"type": "input_text", "text": "Hello"}],
}
assert items[2] == {
"type": "message",
"role": "assistant",
"content": [{"type": "output_text", "text": "Calling tool"}],
}
assert items[3] == {
"type": "function_call",
"call_id": "call_1",
"name": "lookup",
"arguments": '{"x": 1}',
}
assert items[3] == {
assert items[4] == {
"type": "function_call_output",
"call_id": "call_1",
"output": '{"ok": true}',
@ -189,7 +198,161 @@ def test_build_responses_create_args_allows_missing_instructions() -> None:
args_dict = cast(dict[str, Any], args)
assert args_dict["instructions"] == "You are a helpful assistant."
assert args_dict["input"] == [
{"type": "message", "role": "user", "content": "hello"}
{
"type": "message",
"role": "user",
"content": [{"type": "input_text", "text": "hello"}],
}
]
def test_build_responses_create_args_maps_followup_developer_to_system_message() -> (
None
):
request = ProviderChatRequest(
model="gpt-5-codex",
messages=[
CoreMessage(role="developer", content="Primary rules"),
CoreMessage(role="developer", content="Secondary rules"),
CoreMessage(role="user", content="hello"),
],
)
args, _ = build_responses_create_args(request)
args_dict = cast(dict[str, Any], args)
assert args_dict["instructions"] == "Primary rules"
assert args_dict["input"] == [
{
"type": "message",
"role": "system",
"content": [{"type": "input_text", "text": "Secondary rules"}],
},
{
"type": "message",
"role": "user",
"content": [{"type": "input_text", "text": "hello"}],
},
]
def test_build_responses_create_args_skips_empty_assistant_messages() -> None:
request = ProviderChatRequest(
model="gpt-5-codex",
messages=[
CoreMessage(role="developer", content="Rules"),
CoreMessage(role="assistant", content=""),
CoreMessage(role="user", content="hello"),
],
)
args, _ = build_responses_create_args(request)
args_dict = cast(dict[str, Any], args)
assert args_dict["input"] == [
{
"type": "message",
"role": "user",
"content": [{"type": "input_text", "text": "hello"}],
}
]
def test_build_responses_create_args_maps_reasoning_object() -> None:
request = ProviderChatRequest(
model="gpt-5-codex",
messages=[
CoreMessage(role="developer", content="Rules"),
CoreMessage(role="user", content="hello"),
],
reasoning={"effort": "low", "summary": "concise"},
)
args, _ = build_responses_create_args(request)
args_dict = cast(dict[str, Any], args)
assert args_dict["reasoning"] == {"effort": "low", "summary": "concise"}
def test_build_responses_create_args_keeps_flat_reasoning_precedence() -> None:
request = ProviderChatRequest(
model="gpt-5-codex",
messages=[
CoreMessage(role="developer", content="Rules"),
CoreMessage(role="user", content="hello"),
],
reasoning_effort="high",
reasoning_summary="detailed",
reasoning={"effort": "low", "summary": "concise"},
)
args, _ = build_responses_create_args(request)
args_dict = cast(dict[str, Any], args)
assert args_dict["reasoning"] == {"effort": "high", "summary": "detailed"}
def test_build_responses_create_args_supports_multimodal_content_parts() -> None:
request = ProviderChatRequest(
model="gpt-5-codex",
messages=[
CoreMessage(role="developer", content="Rules"),
CoreMessage(
role="user",
content=[
{
"type": "text",
"text": "Describe media",
"cache_control": {"type": "ephemeral", "ttl": "5m"},
},
{
"type": "image_url",
"image_url": {
"url": "https://example.com/image.png",
"detail": "high",
},
},
{
"type": "input_audio",
"input_audio": {"data": "ZmFrZQ==", "format": "wav"},
},
{
"type": "video_url",
"video_url": {"url": "https://example.com/video.mp4"},
},
{
"type": "input_file",
"file_url": "https://example.com/doc.pdf",
"filename": "doc.pdf",
},
],
),
],
)
args, _ = build_responses_create_args(request)
args_dict = cast(dict[str, Any], args)
parts = args_dict["input"][0]["content"]
assert parts == [
{
"type": "input_text",
"text": "Describe media",
"cache_control": {"type": "ephemeral", "ttl": "5m"},
},
{
"type": "input_image",
"detail": "high",
"image_url": "https://example.com/image.png",
},
{
"type": "input_audio",
"input_audio": {"data": "ZmFrZQ==", "format": "wav"},
},
{
"type": "input_file",
"file_url": "https://example.com/video.mp4",
},
{
"type": "input_file",
"file_url": "https://example.com/doc.pdf",
"filename": "doc.pdf",
},
]

View file

@ -1,6 +1,5 @@
from __future__ import annotations
import json
from pathlib import Path
import pytest
@ -8,20 +7,16 @@ import yaml
from pydantic import ValidationError
from app.config.loader import load_config
from app.config.models import OAuthAuth, TokenAuth
from app.config.models import OAuthAuth, TokenAuth, UrlAuth
def _write_yaml(path: Path, data: dict) -> None:
path.write_text(yaml.safe_dump(data), encoding="utf-8")
def _write_json(path: Path, data: dict) -> None:
path.write_text(json.dumps(data), encoding="utf-8")
def test_load_config_success(tmp_path: Path) -> None:
config_path = tmp_path / "config.yml"
auth_path = tmp_path / "auth.json"
auth_path = tmp_path / "auth.yml"
_write_yaml(
config_path,
@ -30,12 +25,14 @@ def test_load_config_success(tmp_path: Path) -> None:
"kilocode": {
"url": "https://api.kilo.ai/api/openrouter",
"type": "openai-completions",
"name": "Kilo Override",
"models": {"minimax/minimax-m2.5:free": {"name": "MiniMax Custom"}},
"whitelist": ["minimax/minimax-m2.5:free"],
}
}
},
)
_write_json(
_write_yaml(
auth_path,
{"providers": {"kilocode": {"token": "public"}}},
)
@ -46,12 +43,14 @@ def test_load_config_success(tmp_path: Path) -> None:
assert isinstance(provider.auth, TokenAuth)
assert provider.type == "openai-completions"
assert provider.auth.token == "public"
assert provider.display_name == "Kilo Override"
assert provider.models == {"minimax/minimax-m2.5:free": {"name": "MiniMax Custom"}}
assert provider.whitelist == ["minimax/minimax-m2.5:free"]
def test_load_config_requires_auth_entry(tmp_path: Path) -> None:
config_path = tmp_path / "config.yml"
auth_path = tmp_path / "auth.json"
auth_path = tmp_path / "auth.yml"
_write_yaml(
config_path,
@ -61,7 +60,7 @@ def test_load_config_requires_auth_entry(tmp_path: Path) -> None:
}
},
)
_write_json(auth_path, {"providers": {}})
_write_yaml(auth_path, {"providers": {}})
with pytest.raises(ValueError, match="Missing auth entry"):
load_config(config_path=config_path, auth_path=auth_path)
@ -69,7 +68,7 @@ def test_load_config_requires_auth_entry(tmp_path: Path) -> None:
def test_load_config_rejects_incomplete_oauth_auth(tmp_path: Path) -> None:
config_path = tmp_path / "config.yml"
auth_path = tmp_path / "auth.json"
auth_path = tmp_path / "auth.yml"
_write_yaml(
config_path,
@ -79,7 +78,7 @@ def test_load_config_rejects_incomplete_oauth_auth(tmp_path: Path) -> None:
}
},
)
_write_json(auth_path, {"providers": {"kilocode": {"refresh": "rt_abc"}}})
_write_yaml(auth_path, {"providers": {"kilocode": {"refresh": "rt_abc"}}})
with pytest.raises(ValidationError):
load_config(config_path=config_path, auth_path=auth_path)
@ -87,7 +86,7 @@ def test_load_config_rejects_incomplete_oauth_auth(tmp_path: Path) -> None:
def test_load_config_supports_oauth_auth(tmp_path: Path) -> None:
config_path = tmp_path / "config.yml"
auth_path = tmp_path / "auth.json"
auth_path = tmp_path / "auth.yml"
_write_yaml(
config_path,
@ -100,7 +99,7 @@ def test_load_config_supports_oauth_auth(tmp_path: Path) -> None:
}
},
)
_write_json(
_write_yaml(
auth_path,
{
"providers": {
@ -119,3 +118,35 @@ def test_load_config_supports_oauth_auth(tmp_path: Path) -> None:
assert provider.auth.access == "acc"
assert provider.auth.refresh == "ref"
assert provider.auth.expires == 123
def test_load_config_supports_url_auth(tmp_path: Path) -> None:
config_path = tmp_path / "config.yml"
auth_path = tmp_path / "auth.yml"
_write_yaml(
config_path,
{
"providers": {
"kilo": {
"url": "https://api.kilo.ai/api/openrouter",
"type": "openai-completions",
}
}
},
)
_write_yaml(
auth_path,
{
"providers": {
"kilo": {
"url": "https://auth.local/token",
}
}
},
)
loaded = load_config(config_path=config_path, auth_path=auth_path)
provider = loaded.providers["kilo"]
assert isinstance(provider.auth, UrlAuth)
assert provider.auth.url == "https://auth.local/token"

View file

@ -10,6 +10,7 @@ from app.core.errors import (
ModelNotAllowedError,
ProviderNotFoundError,
)
from app.core.models_dev import ModelsDevCatalog
from app.core.router import RouterCore, split_routed_model
from app.core.types import (
CoreChatRequest,
@ -25,6 +26,7 @@ class StubProvider(BaseProvider):
def __init__(self, **kwargs) -> None:
super().__init__(api_type="stub", **kwargs)
self.requests: list[ProviderChatRequest] = []
self.list_models_calls = 0
@classmethod
def from_config(cls, config):
@ -38,9 +40,90 @@ class StubProvider(BaseProvider):
yield CoreChunk(finish_reason="stop")
async def list_models(self) -> list[ProviderModel]:
self.list_models_calls += 1
return [ProviderModel(id="model-a", name="Model A", context_length=123)]
class StubModelsDevCatalog(ModelsDevCatalog):
def __init__(self, payload: dict[str, dict[str, object]] | None = None) -> None:
super().__init__(fetch_catalog=lambda: _never_called())
self.payload = payload or {}
self.calls: list[tuple[str, str]] = []
self.provider_display_name = "Stub Provider"
async def get_provider_models(
self, *, provider_name: str, provider_url: str
) -> tuple[str | None, dict[str, dict[str, object]]]:
self.calls.append((provider_name, provider_url))
return self.provider_display_name, self.payload
async def _never_called() -> dict[str, object]:
raise RuntimeError("not expected")
class MissingFieldsProvider(BaseProvider):
def __init__(self, **kwargs) -> None:
super().__init__(api_type="stub", **kwargs)
@classmethod
def from_config(cls, config):
raise NotImplementedError
async def stream_chat(
self, request: ProviderChatRequest
) -> AsyncIterator[CoreChunk]:
yield CoreChunk(finish_reason="stop")
async def list_models(self) -> list[ProviderModel]:
return [ProviderModel(id="model-x")]
class CompleteFieldsProvider(BaseProvider):
def __init__(self, **kwargs) -> None:
super().__init__(api_type="stub", **kwargs)
@classmethod
def from_config(cls, config):
raise NotImplementedError
async def stream_chat(
self, request: ProviderChatRequest
) -> AsyncIterator[CoreChunk]:
yield CoreChunk(finish_reason="stop")
async def list_models(self) -> list[ProviderModel]:
return [
ProviderModel(
id="model-a",
name="Model A",
created=1,
context_length=4096,
architecture={"input_modalities": ["text"]},
pricing={"input": 1.0, "output": 2.0},
supported_parameters=["max_tokens"],
owned_by="kilo",
)
]
class FailingModelsProvider(BaseProvider):
def __init__(self, **kwargs) -> None:
super().__init__(api_type="stub", **kwargs)
@classmethod
def from_config(cls, config):
raise NotImplementedError
async def stream_chat(
self, request: ProviderChatRequest
) -> AsyncIterator[CoreChunk]:
yield CoreChunk(finish_reason="stop")
async def list_models(self) -> list[ProviderModel]:
raise RuntimeError("upstream unavailable")
def _collect(async_iter: AsyncIterator[CoreChunk]) -> list[CoreChunk]:
async def _inner() -> list[CoreChunk]:
out: list[CoreChunk] = []
@ -88,14 +171,12 @@ def test_stream_chat_routes_to_provider_model_without_prefix() -> None:
req = CoreChatRequest(
model="kilo/minimax/minimax-m2.5:free",
messages=[CoreMessage(role="user", content="ping")],
temperature=0.2,
)
chunks = _collect(core.stream_chat(req))
assert [c.content for c in chunks if c.content] == ["hello"]
assert chunks[-1].finish_reason == "stop"
assert provider.requests[0].model == "minimax/minimax-m2.5:free"
assert provider.requests[0].temperature == 0.2
def test_list_models_prefixes_provider_and_applies_defaults() -> None:
@@ -121,3 +202,129 @@ def test_list_models_respects_whitelist() -> None:
models = asyncio.run(core.list_models())
assert models == []
def test_list_models_uses_ttl_cache() -> None:
provider = StubProvider(
name="kilo", base_url="https://kilo", whitelist=None, blacklist=None
)
core = RouterCore(providers={"kilo": provider})
first = asyncio.run(core.list_models())
second = asyncio.run(core.list_models())
assert len(first) == 1
assert len(second) == 1
assert provider.list_models_calls == 1
def test_list_models_cache_can_be_disabled_with_zero_ttl() -> None:
provider = StubProvider(
name="kilo", base_url="https://kilo", whitelist=None, blacklist=None
)
core = RouterCore(providers={"kilo": provider}, models_cache_ttl_seconds=0.0)
asyncio.run(core.list_models())
asyncio.run(core.list_models())
assert provider.list_models_calls == 2
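
These two tests fix the caching contract rather than its implementation: a positive TTL serves the second list_models() call from cache, and a TTL of 0 disables caching outright. A minimal sketch of that contract with a monotonic clock (TtlCache is an illustrative name, not the actual RouterCore internal):

import time


class TtlCache:
    def __init__(self, ttl_seconds: float) -> None:
        self._ttl = ttl_seconds
        self._value: object | None = None
        self._stored_at: float | None = None

    def get(self) -> object | None:
        # ttl <= 0 disables caching entirely, matching the zero-TTL test.
        if self._ttl <= 0 or self._stored_at is None:
            return None
        if time.monotonic() - self._stored_at >= self._ttl:
            return None
        return self._value

    def set(self, value: object) -> None:
        self._value = value
        self._stored_at = time.monotonic()
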
def test_list_models_enriches_missing_fields_from_models_dev() -> None:
provider = MissingFieldsProvider(
name="zai",
base_url="https://api.z.ai/api/coding/paas/v4",
whitelist=None,
blacklist=None,
)
catalog = StubModelsDevCatalog(
{
"model-x": {
"name": "Model X",
"limit": {"context": 65536, "output": 4096},
"modalities": {"input": ["text"], "output": ["text"]},
"family": "glm",
"cost": {"input": 1.0, "output": 2.0},
"tool_call": True,
"reasoning": True,
"structured_output": True,
"release_date": "2026-01-15",
"provider": "z-ai",
}
}
)
core = RouterCore(providers={"zai": provider}, models_dev_catalog=catalog)
models = asyncio.run(core.list_models())
assert len(models) == 1
assert models[0].id == "zai/model-x"
assert models[0].name == "Model X"
assert models[0].context_length == 65536
assert models[0].architecture == {
"input_modalities": ["text"],
"output_modalities": ["text"],
"family": "glm",
}
assert models[0].pricing == {"input": 1.0, "output": 2.0}
assert models[0].owned_by == "z-ai"
assert models[0].created > 0
assert models[0].supported_parameters == [
"max_completion_tokens",
"max_tokens",
"modalities",
"parallel_tool_calls",
"reasoning_effort",
"reasoning_summary",
"response_format",
"tool_choice",
"tools",
]
assert models[0].provider_display_name == "Stub Provider"
assert catalog.calls == [("zai", "https://api.z.ai/api/coding/paas/v4")]
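
The supported_parameters assertion implies a mapping from models.dev capability flags to advertised parameters: a base set is always present, and the tool_call, reasoning, and structured_output flags each unlock more. A sketch of that inferred mapping (the function name is an assumption; only the resulting list comes from the test):

def supported_parameters_from_flags(
    *, tool_call: bool, reasoning: bool, structured_output: bool
) -> list[str]:
    # Base parameters every model gets, per the asserted list.
    params = {"max_completion_tokens", "max_tokens", "modalities"}
    if tool_call:
        params.update({"tools", "tool_choice", "parallel_tool_calls"})
    if reasoning:
        params.update({"reasoning_effort", "reasoning_summary"})
    if structured_output:
        params.add("response_format")
    return sorted(params)
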
def test_list_models_always_calls_models_dev_and_prefers_it() -> None:
provider = CompleteFieldsProvider(
name="kilo", base_url="https://kilo", whitelist=None, blacklist=None
)
catalog = StubModelsDevCatalog({"model-a": {"name": "Never Used"}})
core = RouterCore(providers={"kilo": provider}, models_dev_catalog=catalog)
models = asyncio.run(core.list_models())
assert len(models) == 1
assert models[0].name == "Catalog Name"
assert catalog.calls == [("kilo", "https://kilo")]
def test_list_models_skips_failed_provider_and_returns_others() -> None:
good = StubProvider(
name="good", base_url="https://good", whitelist=None, blacklist=None
)
bad = FailingModelsProvider(
name="bad", base_url="https://bad", whitelist=None, blacklist=None
)
core = RouterCore(providers={"good": good, "bad": bad})
models = asyncio.run(core.list_models())
assert len(models) == 1
assert models[0].id == "good/model-a"
def test_list_models_respects_provider_and_model_name_overrides() -> None:
provider = StubProvider(
name="kilo", base_url="https://kilo", whitelist=None, blacklist=None
)
provider.display_name = "Configured Provider"
provider.models_config = {"model-a": {"name": "Configured Model"}}
catalog = StubModelsDevCatalog({"model-a": {"name": "ModelsDev Model"}})
core = RouterCore(providers={"kilo": provider}, models_dev_catalog=catalog)
models = asyncio.run(core.list_models())
assert len(models) == 1
assert models[0].name == "Configured Model"
assert models[0].provider_display_name == "Configured Provider"
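
Taken together, the three enrichment tests pin a precedence order for model metadata: per-provider models_config overrides win, then models.dev catalog data, then whatever the provider's own /models response reports. A one-function sketch of that order (the name resolve_field is illustrative):

def resolve_field(
    configured: object | None,
    models_dev: object | None,
    provider_reported: object | None,
) -> object | None:
    # Highest precedence first; None means the layer has no opinion.
    for value in (configured, models_dev, provider_reported):
        if value is not None:
            return value
    return None
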

View file

@@ -6,7 +6,7 @@ from typing import Any, cast
import pytest
from app.config.models import LoadedProviderConfig, OAuthAuth, TokenAuth
from app.config.models import LoadedProviderConfig, OAuthAuth, TokenAuth, UrlAuth
from app.core.errors import UpstreamProviderError
from app.core.types import CoreMessage, ProviderChatRequest
from app.providers.codex_responses.oauth import OAuthData
@@ -25,6 +25,8 @@ class _FakeEvent:
refusal: str | None = None,
item_id: str | None = None,
content_index: int | None = None,
output_index: int | None = None,
summary_index: int | None = None,
item: dict[str, Any] | None = None,
part: dict[str, Any] | None = None,
response: dict[str, Any] | None = None,
@@ -37,6 +39,8 @@
self.refusal = refusal
self.item_id = item_id
self.content_index = content_index
self.output_index = output_index
self.summary_index = summary_index
self.item = item
self.part = part
self.response = response
@@ -101,6 +105,14 @@ class _FakeOauth:
async def get(self) -> OAuthData:
return OAuthData(token="acc", headers={"Authorization": "Bearer acc"})
async def get_headers(self) -> dict[str, str]:
return {"Authorization": "Bearer acc"}
class _StaticUrlAuth:
async def get_headers(self) -> dict[str, str]:
return {"Authorization": "Bearer fetched-token"}
class _FailingResponses:
async def create(self, **kwargs):
@@ -165,7 +177,10 @@ def test_codex_provider_streams_openai_responses_events() -> None:
assert chunks[2].content == " world"
assert chunks[-1].finish_reason == "stop"
assert client.responses.last_headers == {"Authorization": "Bearer acc"}
assert client.responses.last_headers is not None
assert client.responses.last_headers["Authorization"] == "Bearer acc"
assert client.responses.last_headers["originator"] == "opencode"
assert client.responses.last_headers["session_id"].startswith("ses_")
payload = client.responses.last_payload
assert payload is not None
assert payload["model"] == "gpt-5-codex"
@@ -220,7 +235,7 @@ def test_codex_provider_requires_oauth_in_from_config() -> None:
type="codex-responses",
auth=TokenAuth(token="bad"),
)
with pytest.raises(ValueError, match="requires oauth auth"):
with pytest.raises(ValueError, match="requires oauth or url auth"):
CodexResponsesProvider.from_config(config)
@@ -266,6 +281,37 @@ def test_codex_provider_from_config_success() -> None:
assert isinstance(provider, CodexResponsesProvider)
def test_codex_provider_from_config_supports_url_auth() -> None:
config = LoadedProviderConfig(
name="codex",
url="https://chatgpt.com/backend-api",
type="codex-responses",
auth=UrlAuth(url="https://auth.local/token"),
)
provider = CodexResponsesProvider.from_config(config)
assert isinstance(provider, CodexResponsesProvider)
def test_codex_provider_uses_url_auth_headers() -> None:
client = _FakeClient()
provider = CodexResponsesProvider(
name="codex",
base_url="https://chatgpt.com/backend-api",
oauth=cast(Any, _StaticUrlAuth()),
client=client,
)
req = ProviderChatRequest(
model="gpt-5-codex",
messages=[CoreMessage(role="user", content="hello")],
)
_collect(provider.stream_chat(req))
assert client.responses.last_headers is not None
assert client.responses.last_headers["Authorization"] == "Bearer fetched-token"
def test_codex_provider_emits_text_from_output_item_done_without_deltas() -> None:
provider = CodexResponsesProvider(
name="codex",
@@ -432,3 +478,64 @@ def test_codex_provider_streams_reasoning_summary_delta() -> None:
assert chunks[0].role == "assistant"
assert chunks[1].reasoning_content == "thinking..."
assert chunks[-1].finish_reason == "stop"
def test_codex_provider_streams_reasoning_details_with_encrypted_content() -> None:
provider = CodexResponsesProvider(
name="codex",
base_url="https://chatgpt.com/backend-api",
oauth=cast(Any, _FakeOauth()),
client=_CustomClient(
[
_FakeEvent(
type="response.output_item.added",
output_index=0,
item={
"id": "rs_1",
"type": "reasoning",
"encrypted_content": "enc_123",
},
),
_FakeEvent(
type="response.reasoning_summary_text.done",
item_id="rs_1",
summary_index=0,
text="**Plan**",
),
_FakeEvent(type="response.completed"),
]
),
)
req = ProviderChatRequest(
model="gpt-5-codex",
messages=[
CoreMessage(role="developer", content="Be concise."),
CoreMessage(role="user", content="Think first"),
],
)
chunks = _collect(provider.stream_chat(req))
details = [
detail
for chunk in chunks
for detail in (chunk.reasoning_details or [])
if isinstance(detail, dict)
]
assert {d.get("type") for d in details} == {
"reasoning.encrypted",
"reasoning.summary",
}
encrypted = next(d for d in details if d.get("type") == "reasoning.encrypted")
assert encrypted["data"] == "enc_123"
assert encrypted["id"] == "rs_1"
assert encrypted["format"] == "openai-responses-v1"
summary = next(d for d in details if d.get("type") == "reasoning.summary")
assert summary["summary"] == "**Plan**"
assert summary["id"] == "rs_1"
assert summary["format"] == "openai-responses-v1"
reasoning_chunks = [c.reasoning_content for c in chunks if c.reasoning_content]
assert reasoning_chunks == ["**Plan**"]
assert chunks[-1].finish_reason == "stop"
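
Both reasoning_details shapes asserted here are plain dicts sharing type, id, and format keys plus one payload field. Spelled out as a constructor (the helper name is illustrative):

def reasoning_detail(kind: str, item_id: str, **payload: object) -> dict[str, object]:
    # reasoning_detail("reasoning.encrypted", "rs_1", data="enc_123")
    # reasoning_detail("reasoning.summary", "rs_1", summary="**Plan**")
    return {"type": kind, "id": item_id, "format": "openai-responses-v1", **payload}
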

View file

@@ -5,15 +5,26 @@ from collections.abc import AsyncIterator
import pytest
from app.config.models import LoadedProviderConfig, OAuthAuth, UrlAuth
from app.core.errors import UpstreamProviderError
from app.core.types import CoreMessage, ProviderChatRequest
from app.providers.openai_completions.provider import OpenAICompletionsProvider
class _FakeDelta:
def __init__(self, role: str | None = None, content: str | None = None) -> None:
def __init__(
self,
role: str | None = None,
content: str | None = None,
reasoning_content: str | None = None,
reasoning: object | None = None,
tool_calls: list[dict[str, object]] | None = None,
) -> None:
self.role = role
self.content = content
self.reasoning_content = reasoning_content
self.reasoning = reasoning
self.tool_calls = tool_calls
class _FakeChoice:
@@ -61,19 +72,73 @@ class _FakeCompletions:
)
class _ReasoningCompletions:
async def create(self, **kwargs):
return _FakeStream(
[
_FakeChunk([_FakeChoice(index=0, delta=_FakeDelta(role="assistant"))]),
_FakeChunk(
[
_FakeChoice(
index=0,
delta=_FakeDelta(reasoning={"summary": "thinking"}),
)
]
),
_FakeChunk(
[
_FakeChoice(
index=0,
delta=_FakeDelta(
tool_calls=[
{
"index": 0,
"id": "call_1",
"type": "function",
"function": {
"name": "question",
"arguments": "{",
},
}
]
),
)
]
),
_FakeChunk([_FakeChoice(index=0, finish_reason="tool_calls")]),
]
)
class _FakeChat:
def __init__(self) -> None:
self.completions = _FakeCompletions()
class _ReasoningChat:
def __init__(self) -> None:
self.completions = _ReasoningCompletions()
class _FakeClient:
def __init__(self) -> None:
self.chat = _FakeChat()
self.models = _FakeModels()
class _ReasoningClient:
def __init__(self) -> None:
self.chat = _ReasoningChat()
self.models = _FakeModels()
class _FakeModels:
async def list(self):
def __init__(self) -> None:
self.last_kwargs: dict[str, object] | None = None
async def list(self, **kwargs):
self.last_kwargs = kwargs
class _Resp:
data = [
{
@@ -112,10 +177,15 @@ class _FailingClient:
class _FailingModels:
async def list(self):
async def list(self, **kwargs):
raise RuntimeError("boom")
class _StaticAuth:
async def get_headers(self) -> dict[str, str]:
return {"Authorization": "Bearer fetched-token"}
def _collect(async_iter) -> list:
async def _inner() -> list:
out = []
@@ -137,7 +207,6 @@ def test_openai_completions_provider_streams_internal_chunks() -> None:
req = ProviderChatRequest(
model="minimax/minimax-m2.5:free",
messages=[CoreMessage(role="user", content="hello")],
temperature=0.1,
top_p=0.9,
max_tokens=123,
)
@@ -152,10 +221,70 @@
assert payload is not None
assert payload["model"] == "minimax/minimax-m2.5:free"
assert payload["stream"] is True
assert payload["temperature"] == 0.1
assert payload["top_p"] == 0.9
assert payload["max_tokens"] == 123
assert payload["messages"] == [{"role": "user", "content": "hello"}]
assert payload["extra_headers"] == {"Authorization": "Bearer public"}
def test_openai_completions_provider_streams_reasoning_details() -> None:
class _ReasoningDetailsChat:
def __init__(self) -> None:
self.completions = _ReasoningDetailsCompletions()
class _ReasoningDetailsCompletions:
async def create(self, **kwargs):
return _FakeStream(
[
_FakeChunk(
[
_FakeChoice(
index=0,
delta=_FakeDelta(
reasoning={
"details": [
{
"type": "reasoning.encrypted",
"data": "enc_123",
"format": "openai-responses-v1",
}
]
}
),
)
]
),
_FakeChunk([_FakeChoice(index=0, finish_reason="stop")]),
]
)
class _ReasoningDetailsClient:
def __init__(self) -> None:
self.chat = _ReasoningDetailsChat()
self.models = _FakeModels()
provider = OpenAICompletionsProvider(
name="kilo",
base_url="https://api.kilo.ai/api/openrouter",
token="public",
client=_ReasoningDetailsClient(),
)
req = ProviderChatRequest(
model="minimax/minimax-m2.5:free",
messages=[CoreMessage(role="user", content="hello")],
)
chunks = _collect(provider.stream_chat(req))
assert chunks[0].role == "assistant"
assert chunks[1].reasoning_details == [
{
"type": "reasoning.encrypted",
"data": "enc_123",
"format": "openai-responses-v1",
}
]
assert chunks[2].finish_reason == "stop"
def test_openai_completions_provider_wraps_upstream_error() -> None:
@@ -174,12 +303,40 @@ def test_openai_completions_provider_wraps_upstream_error() -> None:
_collect(provider.stream_chat(req))
def test_openai_completions_provider_lists_models() -> None:
def test_openai_completions_provider_streams_reasoning_and_tool_calls() -> None:
provider = OpenAICompletionsProvider(
name="kilo",
base_url="https://api.kilo.ai/api/openrouter",
token="public",
client=_FakeClient(),
client=_ReasoningClient(),
)
req = ProviderChatRequest(
model="minimax/minimax-m2.5:free",
messages=[CoreMessage(role="user", content="hello")],
)
chunks = _collect(provider.stream_chat(req))
assert chunks[0].role == "assistant"
assert chunks[1].reasoning_content == "thinking"
assert chunks[2].tool_calls == [
{
"index": 0,
"id": "call_1",
"type": "function",
"function": {"name": "question", "arguments": "{"},
}
]
assert chunks[3].finish_reason == "tool_calls"
def test_openai_completions_provider_lists_models() -> None:
client = _FakeClient()
provider = OpenAICompletionsProvider(
name="kilo",
base_url="https://api.kilo.ai/api/openrouter",
token="public",
client=client,
)
models = asyncio.run(provider.list_models())
@@ -187,6 +344,9 @@ def test_openai_completions_provider_lists_models() -> None:
assert models[0].id == "minimax/minimax-m2.5:free"
assert models[0].context_length == 2048
assert models[0].architecture == {"input_modalities": ["text"]}
assert client.models.last_kwargs == {
"extra_headers": {"Authorization": "Bearer public"}
}
def test_openai_completions_provider_wraps_model_list_errors() -> None:
@@ -199,3 +359,84 @@ def test_openai_completions_provider_wraps_model_list_errors() -> None:
with pytest.raises(UpstreamProviderError, match="failed while listing models"):
asyncio.run(provider.list_models())
def test_openai_completions_provider_from_config_supports_url_auth() -> None:
config = LoadedProviderConfig(
name="kilo",
url="https://api.kilo.ai/api/openrouter",
type="openai-completions",
auth=UrlAuth(url="https://auth.local/token"),
)
provider = OpenAICompletionsProvider.from_config(config)
assert isinstance(provider, OpenAICompletionsProvider)
def test_openai_completions_provider_from_config_rejects_oauth() -> None:
config = LoadedProviderConfig(
name="kilo",
url="https://api.kilo.ai/api/openrouter",
type="openai-completions",
auth=OAuthAuth(access="acc", refresh="ref", expires=1),
)
with pytest.raises(ValueError, match="requires token or url auth"):
OpenAICompletionsProvider.from_config(config)
def test_openai_completions_provider_uses_custom_auth_provider() -> None:
client = _FakeClient()
provider = OpenAICompletionsProvider(
name="kilo",
base_url="https://api.kilo.ai/api/openrouter",
auth_provider=_StaticAuth(),
client=client,
)
req = ProviderChatRequest(
model="minimax/minimax-m2.5:free",
messages=[CoreMessage(role="user", content="hello")],
)
_collect(provider.stream_chat(req))
payload = client.chat.completions.payload
assert payload is not None
assert payload["extra_headers"] == {"Authorization": "Bearer fetched-token"}
def test_openai_completions_provider_passes_schema_fields_via_extra_body() -> None:
client = _FakeClient()
provider = OpenAICompletionsProvider(
name="kilo",
base_url="https://api.kilo.ai/api/openrouter",
token="public",
client=client,
)
req = ProviderChatRequest(
model="minimax/minimax-m2.5:free",
messages=[CoreMessage(role="user", content="hello")],
reasoning={"effort": "high", "summary": "detailed"},
provider={"allow_fallbacks": False},
plugins=[{"id": "web", "enabled": True}],
session_id="ses_123",
trace={"trace_id": "tr_1"},
models=["openai/gpt-5"],
debug={"echo_upstream_body": True},
image_config={"size": "1024x1024"},
)
_collect(provider.stream_chat(req))
payload = client.chat.completions.payload
assert payload is not None
assert payload["extra_body"] == {
"reasoning": {"effort": "high", "summary": "detailed"},
"provider": {"allow_fallbacks": False},
"plugins": [{"id": "web", "enabled": True}],
"session_id": "ses_123",
"trace": {"trace_id": "tr_1"},
"models": ["openai/gpt-5"],
"debug": {"echo_upstream_body": True},
"image_config": {"size": "1024x1024"},
}
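
The extra_body assertion suggests a simple passthrough rule: every OpenRouter-style schema extra set on the request is forwarded verbatim, and unset (None) fields are dropped. A sketch of that rule (the helper name is assumed; the field list mirrors the test):

_PASSTHROUGH_FIELDS = (
    "reasoning", "provider", "plugins", "session_id",
    "trace", "models", "debug", "image_config",
)


def build_extra_body(request: object) -> dict[str, object]:
    # Collect only the fields the caller actually set; None means "absent".
    extras = {name: getattr(request, name, None) for name in _PASSTHROUGH_FIELDS}
    return {name: value for name, value in extras.items() if value is not None}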