ai/ai
1
0
Fork 0
ai/app/providers/base.py

65 lines
2.2 KiB
Python

"""Provider abstraction and shared behavior."""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING
from typing import Any
from app.core.errors import ModelNotAllowedError
from app.core.types import CoreChunk, ProviderChatRequest, ProviderModel
if TYPE_CHECKING:
from app.config.models import LoadedProviderConfig
class BaseProvider(ABC):
    """Common interface for all provider implementations.

    Concrete providers implement config-based construction, model listing,
    and streaming chat on top of the shared whitelist/blacklist model
    filtering defined here.
    """

    def __init__(
        self,
        *,
        name: str,
        base_url: str,
        api_type: str,
        whitelist: list[str] | None = None,
        blacklist: list[str] | None = None,
    ) -> None:
        self.name = name
        self.base_url = base_url
        self.api_type = api_type
        # None means "no restriction"; note an empty whitelist is distinct
        # from None — it would reject every model.
        self.whitelist = whitelist
        self.blacklist = blacklist
        # Presentation/config metadata, populated after construction
        # (e.g. by from_config implementations), not via __init__.
        self.display_name: str | None = None
        self.models_config: dict[str, dict[str, Any]] | None = None

    def ensure_model_allowed(self, model: str) -> None:
        """Raise ModelNotAllowedError if *model* fails either filter check.

        Whitelist is checked first (must contain the model when set), then
        blacklist (must not contain it when set). Returns None on success.
        """
        if self.whitelist is not None and model not in self.whitelist:
            raise ModelNotAllowedError(
                f"Model '{model}' is not in whitelist for provider '{self.name}'"
            )
        if self.blacklist is not None and model in self.blacklist:
            raise ModelNotAllowedError(
                f"Model '{model}' is blacklisted for provider '{self.name}'"
            )

    def is_model_allowed(self, model: str) -> bool:
        """Return True if *model* passes both whitelist and blacklist checks."""
        if self.whitelist is not None and model not in self.whitelist:
            return False
        if self.blacklist is not None and model in self.blacklist:
            return False
        return True

    @classmethod
    @abstractmethod
    def from_config(cls, config: "LoadedProviderConfig") -> "BaseProvider":
        """Create provider instance from merged runtime config."""

    @abstractmethod
    def stream_chat(self, request: ProviderChatRequest) -> AsyncIterator[CoreChunk]:
        """Execute streaming chat request and yield internal chunks."""

    @abstractmethod
    async def list_models(self) -> list[ProviderModel]:
        """List upstream models supported by provider."""