66from pydantic import SecretStr
77
88from celeste import providers as _providers # noqa: F401
9- from celeste .auth import APIKey , Authentication
9+ from celeste .auth import APIKey , Authentication , AuthHeader , NoAuth
1010from celeste .client import ModalityClient
1111from celeste .core import (
1212 Capability ,
1313 Modality ,
1414 Operation ,
1515 Parameter ,
16+ Protocol ,
1617 Provider ,
1718 UsageField ,
1819)
4243from celeste .modalities .images .models import MODELS as _images_models
4344from celeste .modalities .images .providers import PROVIDERS as _images_providers
4445from celeste .modalities .text .models import MODELS as _text_models
46+ from celeste .modalities .text .protocols .chatcompletions import ChatCompletionsTextClient
47+ from celeste .modalities .text .protocols .openresponses import OpenResponsesTextClient
4548from celeste .modalities .text .providers import PROVIDERS as _text_providers
4649from celeste .modalities .videos .models import MODELS as _videos_models
4750from celeste .modalities .videos .providers import PROVIDERS as _videos_providers
5861
5962logger = logging .getLogger (__name__ )
6063
61- _CLIENT_MAP : dict [tuple [Modality , Provider ], type [ModalityClient ]] = {
64+ _CLIENT_MAP : dict [tuple [Modality , Provider | Protocol ], type [ModalityClient ]] = {
6265 ** {(Modality .TEXT , p ): c for p , c in _text_providers .items ()},
6366 ** {(Modality .IMAGES , p ): c for p , c in _images_providers .items ()},
6467 ** {(Modality .VIDEOS , p ): c for p , c in _videos_providers .items ()},
6568 ** {(Modality .AUDIO , p ): c for p , c in _audio_providers .items ()},
6669 ** {(Modality .EMBEDDINGS , p ): c for p , c in _embeddings_providers .items ()},
70+ # Protocol entries (for compatible APIs via protocol= + base_url=)
71+ (Modality .TEXT , Protocol .OPENRESPONSES ): OpenResponsesTextClient ,
72+ (Modality .TEXT , Protocol .CHATCOMPLETIONS ): ChatCompletionsTextClient ,
6773}
6874
6975for _model in [
7379 * _audio_models ,
7480 * _embeddings_models ,
7581]:
82+ assert _model .provider is not None
7683 _models [(_model .id , _model .provider )] = _model
7784
7885_CAPABILITY_TO_MODALITY_OPERATION : dict [Capability , tuple [Modality , Operation ]] = {
@@ -89,6 +96,7 @@ def _resolve_model(
8996 operation : Operation | None = None ,
9097 provider : Provider | None = None ,
9198 model : Model | str | None = None ,
99+ protocol : Protocol | None = None ,
92100) -> Model :
93101 """Resolve model parameter to Model object (auto-select if None, lookup if string)."""
94102 if model is None :
@@ -110,6 +118,21 @@ def _resolve_model(
110118 if isinstance (model , str ):
111119 found = get_model (model , provider )
112120 if not found :
121+ # Protocol path: unregistered models are expected
122+ if protocol is not None :
123+ if modality is None :
124+ msg = f"Model '{ model } ' not registered. Specify 'modality' explicitly."
125+ raise ValueError (msg )
126+ operations : dict [Modality , set [Operation ]] = {}
127+ if modality is not None :
128+ operations [modality ] = {operation } if operation else set ()
129+ return Model (
130+ id = model ,
131+ provider = provider ,
132+ display_name = model ,
133+ operations = operations ,
134+ streaming = True ,
135+ )
113136 if provider is None :
114137 raise ModelNotFoundError (model_id = model , provider = provider )
115138 if modality is None :
@@ -121,7 +144,7 @@ def _resolve_model(
121144 UserWarning ,
122145 stacklevel = 3 ,
123146 )
124- operations : dict [ Modality , set [ Operation ]] = {}
147+ operations = {}
125148 if modality is not None :
126149 operations [modality ] = {operation } if operation else set ()
127150 return Model (
@@ -158,38 +181,37 @@ def _infer_operation(model: Model, modality: Modality) -> Operation:
158181
159182def create_client (
160183 capability : Capability | None = None ,
161- modality : Modality | str | None = None ,
162- operation : Operation | str | None = None ,
184+ modality : Modality | None = None ,
185+ operation : Operation | None = None ,
163186 provider : Provider | None = None ,
164187 model : Model | str | None = None ,
165188 api_key : str | SecretStr | None = None ,
166189 auth : Authentication | None = None ,
190+ protocol : Protocol | None = None ,
191+ base_url : str | None = None ,
167192) -> ModalityClient :
168193 """Create an async client for the specified AI capability or modality.
169194
170195 Args:
171196 capability: The AI capability to use (deprecated, use modality instead).
172- If not provided and model is specified, capability is inferred
173- from the model (if unambiguous).
174197 modality: The modality to use (e.g., Modality.IMAGES, "images").
175- Preferred over capability for new code.
176198 operation: The operation to use (e.g., Operation.GENERATE, "generate").
177- If not provided and model supports exactly one operation for the
178- modality, it is inferred automatically.
179- provider: Optional provider. If not specified and model ID matches multiple
180- providers, the first match is used with a warning.
199+ provider: Optional provider (e.g., Provider.OPENAI).
181200 model: Model object, string model ID, or None for auto-selection.
182201 api_key: Optional API key override (string or SecretStr).
183202 auth: Optional Authentication object for custom auth (e.g., GoogleADC).
203+ protocol: Wire format protocol for compatible APIs (e.g., "openresponses",
204+ "chatcompletions"). Use with base_url for third-party compatible APIs.
205+ base_url: Custom base URL override. Use with protocol for compatible APIs,
206+ or with provider to proxy through a custom endpoint.
184207
185208 Returns:
186209 Configured client instance ready for generation operations.
187210
188211 Raises:
189212 ModelNotFoundError: If no model found for the specified capability/provider.
190- ClientNotFoundError: If no client registered for capability/provider.
213+ ClientNotFoundError: If no client registered for capability/provider/protocol .
191214 MissingCredentialsError: If required credentials are not configured.
192- UnsupportedCapabilityError: If the resolved model doesn't support the requested capability.
193215 ValueError: If capability/operation cannot be inferred from model.
194216 """
195217 # Translation layer: convert deprecated capability to modality/operation
@@ -208,37 +230,56 @@ def create_client(
208230 msg = "Either 'modality' or 'model' must be provided"
209231 raise ValueError (msg )
210232
211- resolved_modality = Modality (modality ) if isinstance (modality , str ) else modality
212- resolved_operation = (
213- Operation (operation ) if isinstance (operation , str ) else operation
214- )
215- resolved_provider = Provider (provider ) if isinstance (provider , str ) else provider
233+ resolved_modality = modality
234+ resolved_operation = operation
235+ resolved_provider = provider
236+ resolved_protocol = protocol
237+
238+ # Default to openresponses when base_url is given without protocol or provider
239+ if base_url is not None and resolved_protocol is None and resolved_provider is None :
240+ resolved_protocol = Protocol .OPENRESPONSES
216241
217242 resolved_model = _resolve_model (
218243 modality = resolved_modality ,
219244 operation = resolved_operation ,
220245 provider = resolved_provider ,
221246 model = model ,
247+ protocol = resolved_protocol ,
222248 )
223249
224- key = (resolved_modality , resolved_model .provider )
225- if key not in _CLIENT_MAP :
226- raise ClientNotFoundError (
227- modality = resolved_modality , provider = resolved_model .provider
228- )
229- modality_client_class = _CLIENT_MAP [key ]
230-
231- resolved_auth = credentials .get_auth (
232- resolved_model .provider ,
233- override_auth = auth ,
234- override_key = api_key ,
250+ # Client lookup: protocol takes precedence for compatible API path
251+ target = (
252+ resolved_protocol if resolved_protocol is not None else resolved_model .provider
235253 )
254+ if target is None :
255+ raise ClientNotFoundError (modality = resolved_modality )
256+
257+ if (resolved_modality , target ) not in _CLIENT_MAP :
258+ raise ClientNotFoundError (modality = resolved_modality , provider = target )
259+ modality_client_class = _CLIENT_MAP [(resolved_modality , target )]
260+
261+ # Auth resolution: BYOA for protocol path, credentials for provider path
262+ if resolved_protocol is not None and resolved_provider is None :
263+ if auth is not None :
264+ resolved_auth = auth
265+ elif api_key is not None :
266+ resolved_auth = AuthHeader (secret = api_key ) # type: ignore[arg-type] # validator converts str
267+ else :
268+ resolved_auth = NoAuth ()
269+ else :
270+ resolved_auth = credentials .get_auth (
271+ resolved_model .provider , # type: ignore[arg-type] # always Provider in this branch
272+ override_auth = auth ,
273+ override_key = api_key ,
274+ )
236275
237276 return modality_client_class (
238277 modality = resolved_modality ,
239278 model = resolved_model ,
240- provider = resolved_model .provider ,
279+ provider = resolved_provider ,
280+ protocol = resolved_protocol ,
241281 auth = resolved_auth ,
282+ base_url = base_url ,
242283 )
243284
244285
@@ -264,6 +305,7 @@ def create_client(
264305 "Output" ,
265306 "Parameter" ,
266307 "Parameters" ,
308+ "Protocol" ,
267309 "Provider" ,
268310 "RefResolvingJsonSchemaGenerator" ,
269311 "Role" ,
0 commit comments