Skip to content

Commit 4f3c8a5

Browse files
authored
fix: #1876 LiteLLM extra_body forwarding (#2900)
1 parent 5086b24 commit 4f3c8a5

File tree

2 files changed: 62 additions and 7 deletions.

src/agents/extensions/models/litellm_model.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -506,8 +506,14 @@ async def _fetch_response(
506506
extra_kwargs["extra_query"] = copy(model_settings.extra_query)
507507
if model_settings.metadata:
508508
extra_kwargs["metadata"] = copy(model_settings.metadata)
509-
if model_settings.extra_body and isinstance(model_settings.extra_body, dict):
510-
extra_kwargs.update(model_settings.extra_body)
509+
if model_settings.extra_body is not None:
510+
extra_body = copy(model_settings.extra_body)
511+
if isinstance(extra_body, dict) and reasoning_effort is not None:
512+
extra_body.pop("reasoning_effort", None)
513+
if not extra_body:
514+
extra_body = None
515+
if extra_body is not None:
516+
extra_kwargs["extra_body"] = extra_body
511517

512518
# Add kwargs from model_settings.extra_args, filtering out None values
513519
if model_settings.extra_args:

tests/models/test_litellm_extra_body.py

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@
1313
@pytest.mark.asyncio
1414
async def test_extra_body_is_forwarded(monkeypatch):
1515
"""
16-
Forward `extra_body` entries into litellm.acompletion kwargs.
16+
Forward `extra_body` via LiteLLM's dedicated kwarg.
1717
18-
This ensures that user-provided parameters (e.g. cached_content)
19-
arrive alongside default arguments.
18+
This ensures that provider-specific request fields stay nested under `extra_body`
19+
so LiteLLM can merge them into the upstream request body itself.
2020
"""
2121
captured: dict[str, object] = {}
2222

@@ -43,7 +43,9 @@ async def fake_acompletion(model, messages=None, **kwargs):
4343
previous_response_id=None,
4444
)
4545

46-
assert {"cached_content": "some_cache", "foo": 123}.items() <= captured.items()
46+
assert captured["extra_body"] == {"cached_content": "some_cache", "foo": 123}
47+
assert "cached_content" not in captured
48+
assert "foo" not in captured
4749

4850

4951
@pytest.mark.allow_call_model_methods
@@ -79,7 +81,7 @@ async def fake_acompletion(model, messages=None, **kwargs):
7981
)
8082

8183
assert captured["reasoning_effort"] == "none"
82-
assert captured["cached_content"] == "some_cache"
84+
assert captured["extra_body"] == {"cached_content": "some_cache"}
8385
assert settings.extra_body == {"reasoning_effort": "none", "cached_content": "some_cache"}
8486

8587

@@ -119,6 +121,7 @@ async def fake_acompletion(model, messages=None, **kwargs):
119121

120122
# reasoning_effort is string when no summary is provided (backward compatible)
121123
assert captured["reasoning_effort"] == "low"
124+
assert "extra_body" not in captured
122125
assert settings.extra_body == {"reasoning_effort": "high"}
123126

124127

@@ -157,9 +160,55 @@ async def fake_acompletion(model, messages=None, **kwargs):
157160

158161
assert captured["reasoning_effort"] == "none"
159162
assert captured["custom_param"] == "custom"
163+
assert "extra_body" not in captured
160164
assert settings.extra_args == {"reasoning_effort": "low", "custom_param": "custom"}
161165

162166

167+
@pytest.mark.allow_call_model_methods
168+
@pytest.mark.asyncio
169+
async def test_extra_body_metadata_stays_nested(monkeypatch):
170+
"""
171+
Keep extra_body metadata nested even when top-level metadata is also set.
172+
173+
LiteLLM resolves top-level metadata and extra_body separately. Flattening the nested
174+
metadata dict loses the caller's intended request shape for OpenAI-compatible proxies.
175+
"""
176+
captured: dict[str, object] = {}
177+
178+
async def fake_acompletion(model, messages=None, **kwargs):
179+
captured.update(kwargs)
180+
msg = Message(role="assistant", content="ok")
181+
choice = Choices(index=0, message=msg)
182+
return ModelResponse(choices=[choice], usage=Usage(0, 0, 0))
183+
184+
monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
185+
settings = ModelSettings(
186+
metadata={"sdk": "agents"},
187+
extra_body={
188+
"metadata": {"trace_user_id": "user-123", "generation_id": "gen-456"},
189+
"cached_content": "some_cache",
190+
},
191+
)
192+
model = LitellmModel(model="test-model")
193+
194+
await model.get_response(
195+
system_instructions=None,
196+
input=[],
197+
model_settings=settings,
198+
tools=[],
199+
output_schema=None,
200+
handoffs=[],
201+
tracing=ModelTracing.DISABLED,
202+
previous_response_id=None,
203+
)
204+
205+
assert captured["metadata"] == {"sdk": "agents"}
206+
assert captured["extra_body"] == {
207+
"metadata": {"trace_user_id": "user-123", "generation_id": "gen-456"},
208+
"cached_content": "some_cache",
209+
}
210+
211+
163212
@pytest.mark.allow_call_model_methods
164213
@pytest.mark.asyncio
165214
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)