Skip to content

Commit 6ebc51e

Browse files
feat: add protocol= + base_url= for compatible APIs (#241)
* feat: add protocol= + base_url= for compatible APIs (#230) Add support for pointing celeste at any OpenAI-compatible or Anthropic-compatible API using protocol= and base_url= parameters. - Add Protocol enum (openresponses, chatcompletions) to core.py - Remove Provider.OPENRESPONSES (protocols aren't providers) - Separate provider and protocol as independent fields on ModalityClient - Wire base_url as instance field, used by _build_url() - Add protocol= and base_url= to create_client() and domain namespaces - Default to openresponses when base_url given without protocol - BYOA auth: protocol path defaults to NoAuth, accepts api_key or auth - Make Model.provider optional (None for protocol-path models) - Remove dead base_url plumbing from per-call method signatures - Fix credentials.get_auth to use api_key without registry entry - Clean up OllamaGenerateClient to use _build_url() pattern - Remove redundant .value on StrEnum in integration tests Closes #230 * fix: remove redundant | str from StrEnum parameters (#237) StrEnum values are already strings — the | str union on Modality, Operation, Protocol, and Provider parameters was unnecessary. Also remove isinstance(x, str) coercions that are no-ops on StrEnums.
1 parent 912b800 commit 6ebc51e

File tree

26 files changed

+361
-123
lines changed

26 files changed

+361
-123
lines changed

src/celeste/__init__.py

Lines changed: 73 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@
66
from pydantic import SecretStr
77

88
from celeste import providers as _providers # noqa: F401
9-
from celeste.auth import APIKey, Authentication
9+
from celeste.auth import APIKey, Authentication, AuthHeader, NoAuth
1010
from celeste.client import ModalityClient
1111
from celeste.core import (
1212
Capability,
1313
Modality,
1414
Operation,
1515
Parameter,
16+
Protocol,
1617
Provider,
1718
UsageField,
1819
)
@@ -42,6 +43,8 @@
4243
from celeste.modalities.images.models import MODELS as _images_models
4344
from celeste.modalities.images.providers import PROVIDERS as _images_providers
4445
from celeste.modalities.text.models import MODELS as _text_models
46+
from celeste.modalities.text.protocols.chatcompletions import ChatCompletionsTextClient
47+
from celeste.modalities.text.protocols.openresponses import OpenResponsesTextClient
4548
from celeste.modalities.text.providers import PROVIDERS as _text_providers
4649
from celeste.modalities.videos.models import MODELS as _videos_models
4750
from celeste.modalities.videos.providers import PROVIDERS as _videos_providers
@@ -58,12 +61,15 @@
5861

5962
logger = logging.getLogger(__name__)
6063

61-
_CLIENT_MAP: dict[tuple[Modality, Provider], type[ModalityClient]] = {
64+
_CLIENT_MAP: dict[tuple[Modality, Provider | Protocol], type[ModalityClient]] = {
6265
**{(Modality.TEXT, p): c for p, c in _text_providers.items()},
6366
**{(Modality.IMAGES, p): c for p, c in _images_providers.items()},
6467
**{(Modality.VIDEOS, p): c for p, c in _videos_providers.items()},
6568
**{(Modality.AUDIO, p): c for p, c in _audio_providers.items()},
6669
**{(Modality.EMBEDDINGS, p): c for p, c in _embeddings_providers.items()},
70+
# Protocol entries (for compatible APIs via protocol= + base_url=)
71+
(Modality.TEXT, Protocol.OPENRESPONSES): OpenResponsesTextClient,
72+
(Modality.TEXT, Protocol.CHATCOMPLETIONS): ChatCompletionsTextClient,
6773
}
6874

6975
for _model in [
@@ -73,6 +79,7 @@
7379
*_audio_models,
7480
*_embeddings_models,
7581
]:
82+
assert _model.provider is not None
7683
_models[(_model.id, _model.provider)] = _model
7784

7885
_CAPABILITY_TO_MODALITY_OPERATION: dict[Capability, tuple[Modality, Operation]] = {
@@ -89,6 +96,7 @@ def _resolve_model(
8996
operation: Operation | None = None,
9097
provider: Provider | None = None,
9198
model: Model | str | None = None,
99+
protocol: Protocol | None = None,
92100
) -> Model:
93101
"""Resolve model parameter to Model object (auto-select if None, lookup if string)."""
94102
if model is None:
@@ -110,6 +118,21 @@ def _resolve_model(
110118
if isinstance(model, str):
111119
found = get_model(model, provider)
112120
if not found:
121+
# Protocol path: unregistered models are expected
122+
if protocol is not None:
123+
if modality is None:
124+
msg = f"Model '{model}' not registered. Specify 'modality' explicitly."
125+
raise ValueError(msg)
126+
operations: dict[Modality, set[Operation]] = {}
127+
if modality is not None:
128+
operations[modality] = {operation} if operation else set()
129+
return Model(
130+
id=model,
131+
provider=provider,
132+
display_name=model,
133+
operations=operations,
134+
streaming=True,
135+
)
113136
if provider is None:
114137
raise ModelNotFoundError(model_id=model, provider=provider)
115138
if modality is None:
@@ -121,7 +144,7 @@ def _resolve_model(
121144
UserWarning,
122145
stacklevel=3,
123146
)
124-
operations: dict[Modality, set[Operation]] = {}
147+
operations = {}
125148
if modality is not None:
126149
operations[modality] = {operation} if operation else set()
127150
return Model(
@@ -158,38 +181,37 @@ def _infer_operation(model: Model, modality: Modality) -> Operation:
158181

159182
def create_client(
160183
capability: Capability | None = None,
161-
modality: Modality | str | None = None,
162-
operation: Operation | str | None = None,
184+
modality: Modality | None = None,
185+
operation: Operation | None = None,
163186
provider: Provider | None = None,
164187
model: Model | str | None = None,
165188
api_key: str | SecretStr | None = None,
166189
auth: Authentication | None = None,
190+
protocol: Protocol | None = None,
191+
base_url: str | None = None,
167192
) -> ModalityClient:
168193
"""Create an async client for the specified AI capability or modality.
169194
170195
Args:
171196
capability: The AI capability to use (deprecated, use modality instead).
172-
If not provided and model is specified, capability is inferred
173-
from the model (if unambiguous).
174197
modality: The modality to use (e.g., Modality.IMAGES, "images").
175-
Preferred over capability for new code.
176198
operation: The operation to use (e.g., Operation.GENERATE, "generate").
177-
If not provided and model supports exactly one operation for the
178-
modality, it is inferred automatically.
179-
provider: Optional provider. If not specified and model ID matches multiple
180-
providers, the first match is used with a warning.
199+
provider: Optional provider (e.g., Provider.OPENAI).
181200
model: Model object, string model ID, or None for auto-selection.
182201
api_key: Optional API key override (string or SecretStr).
183202
auth: Optional Authentication object for custom auth (e.g., GoogleADC).
203+
protocol: Wire format protocol for compatible APIs (e.g., "openresponses",
204+
"chatcompletions"). Use with base_url for third-party compatible APIs.
205+
base_url: Custom base URL override. Use with protocol for compatible APIs,
206+
or with provider to proxy through a custom endpoint.
184207
185208
Returns:
186209
Configured client instance ready for generation operations.
187210
188211
Raises:
189212
ModelNotFoundError: If no model found for the specified capability/provider.
190-
ClientNotFoundError: If no client registered for capability/provider.
213+
ClientNotFoundError: If no client registered for capability/provider/protocol.
191214
MissingCredentialsError: If required credentials are not configured.
192-
UnsupportedCapabilityError: If the resolved model doesn't support the requested capability.
193215
ValueError: If capability/operation cannot be inferred from model.
194216
"""
195217
# Translation layer: convert deprecated capability to modality/operation
@@ -208,37 +230,56 @@ def create_client(
208230
msg = "Either 'modality' or 'model' must be provided"
209231
raise ValueError(msg)
210232

211-
resolved_modality = Modality(modality) if isinstance(modality, str) else modality
212-
resolved_operation = (
213-
Operation(operation) if isinstance(operation, str) else operation
214-
)
215-
resolved_provider = Provider(provider) if isinstance(provider, str) else provider
233+
resolved_modality = modality
234+
resolved_operation = operation
235+
resolved_provider = provider
236+
resolved_protocol = protocol
237+
238+
# Default to openresponses when base_url is given without protocol or provider
239+
if base_url is not None and resolved_protocol is None and resolved_provider is None:
240+
resolved_protocol = Protocol.OPENRESPONSES
216241

217242
resolved_model = _resolve_model(
218243
modality=resolved_modality,
219244
operation=resolved_operation,
220245
provider=resolved_provider,
221246
model=model,
247+
protocol=resolved_protocol,
222248
)
223249

224-
key = (resolved_modality, resolved_model.provider)
225-
if key not in _CLIENT_MAP:
226-
raise ClientNotFoundError(
227-
modality=resolved_modality, provider=resolved_model.provider
228-
)
229-
modality_client_class = _CLIENT_MAP[key]
230-
231-
resolved_auth = credentials.get_auth(
232-
resolved_model.provider,
233-
override_auth=auth,
234-
override_key=api_key,
250+
# Client lookup: protocol takes precedence for compatible API path
251+
target = (
252+
resolved_protocol if resolved_protocol is not None else resolved_model.provider
235253
)
254+
if target is None:
255+
raise ClientNotFoundError(modality=resolved_modality)
256+
257+
if (resolved_modality, target) not in _CLIENT_MAP:
258+
raise ClientNotFoundError(modality=resolved_modality, provider=target)
259+
modality_client_class = _CLIENT_MAP[(resolved_modality, target)]
260+
261+
# Auth resolution: BYOA for protocol path, credentials for provider path
262+
if resolved_protocol is not None and resolved_provider is None:
263+
if auth is not None:
264+
resolved_auth = auth
265+
elif api_key is not None:
266+
resolved_auth = AuthHeader(secret=api_key) # type: ignore[arg-type] # validator converts str
267+
else:
268+
resolved_auth = NoAuth()
269+
else:
270+
resolved_auth = credentials.get_auth(
271+
resolved_model.provider, # type: ignore[arg-type] # always Provider in this branch
272+
override_auth=auth,
273+
override_key=api_key,
274+
)
236275

237276
return modality_client_class(
238277
modality=resolved_modality,
239278
model=resolved_model,
240-
provider=resolved_model.provider,
279+
provider=resolved_provider,
280+
protocol=resolved_protocol,
241281
auth=resolved_auth,
282+
base_url=base_url,
242283
)
243284

244285

@@ -264,6 +305,7 @@ def create_client(
264305
"Output",
265306
"Parameter",
266307
"Parameters",
308+
"Protocol",
267309
"Provider",
268310
"RefResolvingJsonSchemaGenerator",
269311
"Role",

src/celeste/client.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,12 @@
1010
from pydantic import BaseModel, ConfigDict, Field
1111

1212
from celeste.auth import Authentication
13-
from celeste.core import Modality, Provider
14-
from celeste.exceptions import StreamingNotSupportedError, UnsupportedParameterWarning
13+
from celeste.core import Modality, Protocol, Provider
14+
from celeste.exceptions import (
15+
ClientNotFoundError,
16+
StreamingNotSupportedError,
17+
UnsupportedParameterWarning,
18+
)
1519
from celeste.http import HTTPClient, get_http_client
1620
from celeste.io import Chunk as ChunkBase
1721
from celeste.io import FinishReason, Input, Output, Usage
@@ -47,7 +51,9 @@ class OpenAITextClient(OpenAIResponsesMixin, TextClient):
4751

4852
model: Model
4953
auth: Authentication
50-
provider: Provider
54+
provider: Provider | None
55+
protocol: Protocol | None
56+
base_url: str | None
5157
_content_fields: ClassVar[set[str]] = set()
5258

5359
@property
@@ -160,13 +166,19 @@ async def generate(self, prompt: str, **parameters) -> ImageGenerationOutput:
160166

161167
modality: Modality
162168
model: Model
163-
provider: Provider
169+
provider: Provider | None = None
170+
protocol: Protocol | None = None
164171
auth: Authentication = Field(exclude=True)
172+
base_url: str | None = Field(None, exclude=True)
165173

166174
@property
167175
def http_client(self) -> HTTPClient:
168176
"""Shared HTTP client with connection pooling."""
169-
return get_http_client(self.provider, self.modality)
177+
if self.provider is not None:
178+
return get_http_client(self.provider, self.modality)
179+
if self.protocol is not None:
180+
return get_http_client(self.protocol, self.modality)
181+
raise ClientNotFoundError(modality=self.modality)
170182

171183
# Namespace properties - implemented by modality clients
172184
@property
@@ -226,7 +238,6 @@ def _stream(
226238
stream_class: type[Stream[Out, Params, Chunk]],
227239
*,
228240
endpoint: str | None = None,
229-
base_url: str | None = None,
230241
extra_body: dict[str, Any] | None = None,
231242
extra_headers: dict[str, str] | None = None,
232243
**parameters: Unpack[Params], # type: ignore[misc]
@@ -259,7 +270,6 @@ def _stream(
259270
sse_iterator = self._make_stream_request(
260271
request_body,
261272
endpoint=endpoint,
262-
base_url=base_url,
263273
extra_headers=extra_headers,
264274
**parameters,
265275
)
@@ -269,7 +279,7 @@ def _stream(
269279
transform_output=self._transform_output,
270280
stream_metadata={
271281
"model": self.model.id,
272-
"provider": self.provider,
282+
"provider": self.provider or self.protocol,
273283
"modality": self.modality,
274284
},
275285
**parameters,
@@ -361,12 +371,16 @@ def _build_metadata(self, response_data: dict[str, Any]) -> dict[str, Any]:
361371
response_data = {
362372
k: v for k, v in response_data.items() if k not in self._content_fields
363373
}
364-
return {
374+
metadata: dict[str, Any] = {
365375
"model": self.model.id,
366-
"provider": self.provider,
367376
"modality": self.modality,
368377
"raw_response": response_data,
369378
}
379+
if self.provider is not None:
380+
metadata["provider"] = self.provider
381+
if self.protocol is not None:
382+
metadata["protocol"] = self.protocol
383+
return metadata
370384

371385
def _handle_error_response(self, response: httpx.Response) -> None:
372386
"""Handle error responses from provider APIs."""
@@ -383,7 +397,7 @@ def _handle_error_response(self, response: httpx.Response) -> None:
383397
error_msg = response.text or f"HTTP {response.status_code}"
384398

385399
raise httpx.HTTPStatusError(
386-
f"{self.provider} API error: {error_msg}",
400+
f"{self.provider or self.protocol} API error: {error_msg}",
387401
request=response.request,
388402
response=response,
389403
)

src/celeste/core.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,16 @@ class Provider(StrEnum):
2525
ELEVENLABS = "elevenlabs"
2626
GROQ = "groq"
2727
GRADIUM = "gradium"
28-
OPENRESPONSES = "openresponses"
2928
OLLAMA = "ollama"
3029

3130

31+
class Protocol(StrEnum):
32+
"""Wire format protocols for compatible APIs."""
33+
34+
OPENRESPONSES = "openresponses"
35+
CHATCOMPLETIONS = "chatcompletions"
36+
37+
3238
class Modality(StrEnum):
3339
"""Supported modalities."""
3440

@@ -167,6 +173,7 @@ def infer_modality(domain: Domain, operation: Operation) -> Modality:
167173
"Modality",
168174
"Operation",
169175
"Parameter",
176+
"Protocol",
170177
"Provider",
171178
"UsageField",
172179
"infer_modality",

src/celeste/credentials.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,8 @@ def get_auth(
176176

177177
registered = _auth_registry.get(provider)
178178
if registered is None:
179+
if override_key is not None:
180+
return AuthHeader(secret=override_key) # type: ignore[arg-type] # validator converts str
179181
raise UnsupportedProviderError(provider=provider)
180182

181183
# Auth class (GoogleADC, OAuth, etc.) → instantiate

src/celeste/http.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import httpx
1010
from httpx_sse import aconnect_sse
1111

12-
from celeste.core import Modality, Provider
12+
from celeste.core import Modality, Protocol, Provider
1313

1414
logger = logging.getLogger(__name__)
1515

@@ -246,10 +246,10 @@ async def __aexit__(self, *args: Any) -> None: # noqa: ANN401
246246

247247

248248
# Module-level registry of shared HTTPClient instances
249-
_http_clients: dict[tuple[Provider, Modality], HTTPClient] = {}
249+
_http_clients: dict[tuple[Provider | Protocol, Modality], HTTPClient] = {}
250250

251251

252-
def get_http_client(provider: Provider, modality: Modality) -> HTTPClient:
252+
def get_http_client(provider: Provider | Protocol, modality: Modality) -> HTTPClient:
253253
"""Get or create shared HTTP client for provider and modality combination.
254254
255255
Args:

0 commit comments

Comments
 (0)