Skip to content

Commit cfc54f9

Browse files
authored
feat: #2492 add explicit MultiProvider prefix modes (#2593)
1 parent cda1360 commit cfc54f9

File tree

4 files changed

+242
-8
lines changed

4 files changed

+242
-8
lines changed

src/agents/models/multi_provider.py

Lines changed: 76 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
from __future__ import annotations
22

3+
from typing import Literal, cast
4+
35
from openai import AsyncOpenAI
46

57
from ..exceptions import UserError
68
from .interface import Model, ModelProvider
79
from .openai_provider import OpenAIProvider
810

11+
MultiProviderOpenAIPrefixMode = Literal["alias", "model_id"]
12+
MultiProviderUnknownPrefixMode = Literal["error", "model_id"]
13+
914

1015
class MultiProviderMap:
1116
"""A map of model name prefixes to ModelProviders."""
@@ -57,7 +62,11 @@ class MultiProvider(ModelProvider):
5762
- "openai/" prefix or no prefix -> OpenAIProvider. e.g. "openai/gpt-4.1", "gpt-4.1"
5863
- "litellm/" prefix -> LitellmProvider. e.g. "litellm/openai/gpt-4.1"
5964
60-
You can override or customize this mapping.
65+
You can override or customize this mapping. The ``openai`` prefix is ambiguous for some
66+
OpenAI-compatible backends because a string like ``openai/gpt-4.1`` could mean either "route
67+
to the OpenAI provider and use model ``gpt-4.1``" or "send the literal model ID
68+
``openai/gpt-4.1`` to the configured OpenAI-compatible endpoint." The prefix mode options let
69+
callers opt into the second behavior without breaking the historical alias semantics.
6170
"""
6271

6372
def __init__(
@@ -72,6 +81,8 @@ def __init__(
7281
openai_use_responses: bool | None = None,
7382
openai_use_responses_websocket: bool | None = None,
7483
openai_websocket_base_url: str | None = None,
84+
openai_prefix_mode: MultiProviderOpenAIPrefixMode = "alias",
85+
unknown_prefix_mode: MultiProviderUnknownPrefixMode = "error",
7586
) -> None:
7687
"""Create a new MultiProvider.
7788
@@ -92,6 +103,15 @@ def __init__(
92103
responses API.
93104
openai_websocket_base_url: The websocket base URL to use for the OpenAI provider.
94105
If not provided, the provider will use `OPENAI_WEBSOCKET_BASE_URL` when set.
106+
openai_prefix_mode: Controls how ``openai/...`` model strings are interpreted.
107+
``"alias"`` preserves the historical behavior and strips the ``openai/`` prefix
108+
before calling the OpenAI provider. ``"model_id"`` keeps the full string and is
109+
useful for OpenAI-compatible endpoints that expect literal namespaced model IDs.
110+
unknown_prefix_mode: Controls how prefixes outside the explicit provider map and
111+
built-in fallbacks are handled. ``"error"`` preserves the historical fail-fast
112+
behavior and raises ``UserError``. ``"model_id"`` passes the full string through to
113+
the OpenAI provider so OpenAI-compatible endpoints can receive namespaced model IDs
114+
such as ``openrouter/openai/gpt-4o``.
95115
"""
96116
self.provider_map = provider_map
97117
self.openai_provider = OpenAIProvider(
@@ -104,6 +124,8 @@ def __init__(
104124
use_responses=openai_use_responses,
105125
use_responses_websocket=openai_use_responses_websocket,
106126
)
127+
self._openai_prefix_mode = self._validate_openai_prefix_mode(openai_prefix_mode)
128+
self._unknown_prefix_mode = self._validate_unknown_prefix_mode(unknown_prefix_mode)
107129

108130
self._fallback_providers: dict[str, ModelProvider] = {}
109131

@@ -124,6 +146,20 @@ def _create_fallback_provider(self, prefix: str) -> ModelProvider:
124146
else:
125147
raise UserError(f"Unknown prefix: {prefix}")
126148

149+
@staticmethod
150+
def _validate_openai_prefix_mode(mode: str) -> MultiProviderOpenAIPrefixMode:
151+
if mode not in {"alias", "model_id"}:
152+
raise UserError("MultiProvider openai_prefix_mode must be one of: 'alias', 'model_id'.")
153+
return cast(MultiProviderOpenAIPrefixMode, mode)
154+
155+
@staticmethod
156+
def _validate_unknown_prefix_mode(mode: str) -> MultiProviderUnknownPrefixMode:
157+
if mode not in {"error", "model_id"}:
158+
raise UserError(
159+
"MultiProvider unknown_prefix_mode must be one of: 'error', 'model_id'."
160+
)
161+
return cast(MultiProviderUnknownPrefixMode, mode)
162+
127163
def _get_fallback_provider(self, prefix: str | None) -> ModelProvider:
128164
if prefix is None or prefix == "openai":
129165
return self.openai_provider
@@ -133,6 +169,31 @@ def _get_fallback_provider(self, prefix: str | None) -> ModelProvider:
133169
self._fallback_providers[prefix] = self._create_fallback_provider(prefix)
134170
return self._fallback_providers[prefix]
135171

172+
def _resolve_prefixed_model(
173+
self,
174+
*,
175+
original_model_name: str,
176+
prefix: str,
177+
stripped_model_name: str | None,
178+
) -> tuple[ModelProvider, str | None]:
179+
# Explicit provider_map entries are the least surprising routing mechanism, so they always
180+
# win over the built-in OpenAI alias and unknown-prefix fallback behavior.
181+
if self.provider_map and (provider := self.provider_map.get_provider(prefix)):
182+
return provider, stripped_model_name
183+
184+
if prefix == "litellm":
185+
return self._get_fallback_provider(prefix), stripped_model_name
186+
187+
if prefix == "openai":
188+
if self._openai_prefix_mode == "alias":
189+
return self.openai_provider, stripped_model_name
190+
return self.openai_provider, original_model_name
191+
192+
if self._unknown_prefix_mode == "model_id":
193+
return self.openai_provider, original_model_name
194+
195+
raise UserError(f"Unknown prefix: {prefix}")
196+
136197
def get_model(self, model_name: str | None) -> Model:
137198
"""Returns a Model based on the model name. The model name can have a prefix, ending with
138199
a "/", which will be used to look up the ModelProvider. If there is no prefix, we will use
@@ -144,12 +205,21 @@ def get_model(self, model_name: str | None) -> Model:
144205
Returns:
145206
A Model.
146207
"""
147-
prefix, model_name = self._get_prefix_and_model_name(model_name)
208+
# A missing (``None``) or bare model name is always delegated directly to the OpenAI provider. That provider can
209+
# still point at an OpenAI-compatible endpoint via ``base_url``.
210+
if model_name is None:
211+
return self.openai_provider.get_model(None)
148212

149-
if prefix and self.provider_map and (provider := self.provider_map.get_provider(prefix)):
150-
return provider.get_model(model_name)
151-
else:
152-
return self._get_fallback_provider(prefix).get_model(model_name)
213+
prefix, stripped_model_name = self._get_prefix_and_model_name(model_name)
214+
if prefix is None:
215+
return self.openai_provider.get_model(stripped_model_name)
216+
217+
provider, resolved_model_name = self._resolve_prefixed_model(
218+
original_model_name=model_name,
219+
prefix=prefix,
220+
stripped_model_name=stripped_model_name,
221+
)
222+
return provider.get_model(resolved_model_name)
153223

154224
async def aclose(self) -> None:
155225
"""Close cached resources held by child providers."""

src/agents/responses_websocket_session.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@
77

88
from .agent import Agent
99
from .items import TResponseInputItem
10-
from .models.multi_provider import MultiProvider
10+
from .models.multi_provider import (
11+
MultiProvider,
12+
MultiProviderOpenAIPrefixMode,
13+
MultiProviderUnknownPrefixMode,
14+
)
1115
from .models.openai_provider import OpenAIProvider
1216
from .result import RunResult, RunResultStreaming
1317
from .run import Runner
@@ -80,6 +84,8 @@ async def responses_websocket_session(
8084
websocket_base_url: str | None = None,
8185
organization: str | None = None,
8286
project: str | None = None,
87+
openai_prefix_mode: MultiProviderOpenAIPrefixMode = "alias",
88+
unknown_prefix_mode: MultiProviderUnknownPrefixMode = "error",
8389
) -> AsyncIterator[ResponsesWebSocketSession]:
8490
"""Create a shared OpenAI Responses websocket session for multiple Runner calls.
8591
@@ -89,6 +95,10 @@ async def responses_websocket_session(
8995
connections warm across turns and nested agent-as-tool runs that inherit the same
9096
``run_config``.
9197
98+
Use ``openai_prefix_mode="model_id"`` and/or ``unknown_prefix_mode="model_id"`` when the
99+
configured OpenAI-compatible endpoint expects literal namespaced model IDs instead of the SDK's
100+
historical routing-prefix behavior.
101+
92102
Drain or close streamed iterators before the context exits. Exiting the context while a
93103
websocket request is still in flight may force-close the shared connection.
94104
"""
@@ -100,6 +110,8 @@ async def responses_websocket_session(
100110
openai_project=project,
101111
openai_use_responses=True,
102112
openai_use_responses_websocket=True,
113+
openai_prefix_mode=openai_prefix_mode,
114+
unknown_prefix_mode=unknown_prefix_mode,
103115
)
104116
provider = model_provider.openai_provider
105117
session = ResponsesWebSocketSession(

tests/models/test_map.py

Lines changed: 117 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,17 @@
1-
from agents import Agent, MultiProvider, OpenAIResponsesModel, OpenAIResponsesWSModel, RunConfig
1+
from typing import Any, cast
2+
3+
import pytest
4+
5+
from agents import (
6+
Agent,
7+
MultiProvider,
8+
OpenAIResponsesModel,
9+
OpenAIResponsesWSModel,
10+
RunConfig,
11+
UserError,
12+
)
213
from agents.extensions.models.litellm_model import LitellmModel
14+
from agents.models.multi_provider import MultiProviderMap
315
from agents.run_internal.run_loop import get_model
416

517

@@ -53,3 +65,107 @@ def get_model(self, model_name):
5365

5466
MultiProvider(openai_websocket_base_url="wss://proxy.example.test/v1")
5567
assert captured_kwargs["websocket_base_url"] == "wss://proxy.example.test/v1"
68+
69+
70+
def test_openai_prefix_defaults_to_alias_mode(monkeypatch):
71+
captured_model: dict[str, Any] = {}
72+
73+
class FakeOpenAIProvider:
74+
def __init__(self, **kwargs):
75+
pass
76+
77+
def get_model(self, model_name):
78+
captured_model["value"] = model_name
79+
return object()
80+
81+
monkeypatch.setattr("agents.models.multi_provider.OpenAIProvider", FakeOpenAIProvider)
82+
83+
provider = MultiProvider()
84+
provider.get_model("openai/gpt-4o")
85+
assert captured_model["value"] == "gpt-4o"
86+
87+
88+
def test_openai_prefix_can_be_preserved_as_literal_model_id(monkeypatch):
89+
captured_model: dict[str, Any] = {}
90+
91+
class FakeOpenAIProvider:
92+
def __init__(self, **kwargs):
93+
pass
94+
95+
def get_model(self, model_name):
96+
captured_model["value"] = model_name
97+
return object()
98+
99+
monkeypatch.setattr("agents.models.multi_provider.OpenAIProvider", FakeOpenAIProvider)
100+
101+
provider = MultiProvider(openai_prefix_mode="model_id")
102+
provider.get_model("openai/gpt-4o")
103+
assert captured_model["value"] == "openai/gpt-4o"
104+
105+
106+
def test_unknown_prefix_defaults_to_error():
107+
provider = MultiProvider()
108+
109+
with pytest.raises(UserError, match="Unknown prefix: openrouter"):
110+
provider.get_model("openrouter/openai/gpt-4o")
111+
112+
113+
def test_unknown_prefix_can_be_preserved_for_openai_compatible_model_ids(monkeypatch):
114+
captured_model: dict[str, Any] = {}
115+
captured_result: dict[str, Any] = {}
116+
117+
class FakeOpenAIProvider:
118+
def __init__(self, **kwargs):
119+
pass
120+
121+
def get_model(self, model_name):
122+
captured_model["value"] = model_name
123+
fake_model = object()
124+
captured_result["value"] = fake_model
125+
return fake_model
126+
127+
monkeypatch.setattr("agents.models.multi_provider.OpenAIProvider", FakeOpenAIProvider)
128+
129+
provider = MultiProvider(unknown_prefix_mode="model_id")
130+
result = provider.get_model("openrouter/openai/gpt-4o")
131+
assert result is captured_result["value"]
132+
assert captured_model["value"] == "openrouter/openai/gpt-4o"
133+
134+
135+
def test_provider_map_entries_override_openai_prefix_mode(monkeypatch):
136+
captured_model: dict[str, Any] = {}
137+
138+
class FakeCustomProvider:
139+
def get_model(self, model_name):
140+
captured_model["value"] = model_name
141+
return object()
142+
143+
class FakeOpenAIProvider:
144+
def __init__(self, **kwargs):
145+
pass
146+
147+
def get_model(self, model_name):
148+
raise AssertionError("Expected the explicit provider_map entry to win.")
149+
150+
monkeypatch.setattr("agents.models.multi_provider.OpenAIProvider", FakeOpenAIProvider)
151+
152+
provider_map = MultiProviderMap()
153+
provider_map.add_provider("openai", cast(Any, FakeCustomProvider()))
154+
155+
provider = MultiProvider(
156+
provider_map=provider_map,
157+
openai_prefix_mode="model_id",
158+
)
159+
provider.get_model("openai/gpt-4o")
160+
assert captured_model["value"] == "gpt-4o"
161+
162+
163+
def test_multi_provider_rejects_invalid_prefix_modes():
164+
bad_openai_prefix_mode: Any = "invalid"
165+
bad_unknown_prefix_mode: Any = "invalid"
166+
167+
with pytest.raises(UserError, match="openai_prefix_mode"):
168+
MultiProvider(openai_prefix_mode=bad_openai_prefix_mode)
169+
170+
with pytest.raises(UserError, match="unknown_prefix_mode"):
171+
MultiProvider(unknown_prefix_mode=bad_unknown_prefix_mode)

tests/test_responses_websocket_session.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,42 @@ def fake_get_model(model_name):
3535
assert captured["model_name"] == "gpt-4.1"
3636

3737

38+
@pytest.mark.asyncio
39+
async def test_responses_websocket_session_can_preserve_openai_prefix_model_ids(monkeypatch):
40+
captured: dict[str, object] = {}
41+
sentinel = object()
42+
43+
def fake_get_model(model_name):
44+
captured["model_name"] = model_name
45+
return sentinel
46+
47+
async with responses_websocket_session(openai_prefix_mode="model_id") as ws:
48+
monkeypatch.setattr(ws.provider, "get_model", fake_get_model)
49+
50+
result = ws.run_config.model_provider.get_model("openai/gpt-4.1")
51+
52+
assert result is sentinel
53+
assert captured["model_name"] == "openai/gpt-4.1"
54+
55+
56+
@pytest.mark.asyncio
57+
async def test_responses_websocket_session_can_preserve_unknown_prefix_model_ids(monkeypatch):
58+
captured: dict[str, object] = {}
59+
sentinel = object()
60+
61+
def fake_get_model(model_name):
62+
captured["model_name"] = model_name
63+
return sentinel
64+
65+
async with responses_websocket_session(unknown_prefix_mode="model_id") as ws:
66+
monkeypatch.setattr(ws.provider, "get_model", fake_get_model)
67+
68+
result = ws.run_config.model_provider.get_model("openrouter/openai/gpt-4.1")
69+
70+
assert result is sentinel
71+
assert captured["model_name"] == "openrouter/openai/gpt-4.1"
72+
73+
3874
@pytest.mark.asyncio
3975
async def test_responses_websocket_session_run_streamed_injects_run_config(monkeypatch):
4076
agent = Agent(name="test", instructions="Be concise.", model="gpt-4")

0 commit comments

Comments
 (0)