Skip to content

Commit 8d0c1a3

Browse files
committed
fix: enforce bridge tool choice semantics
- reject non-object tool arguments instead of coercing them to {}
- skip structured tool planning when tool_choice is none
- validate required/specific tool_choice outputs
- tighten prompt instructions and add regression tests
1 parent 078e390 commit 8d0c1a3

File tree

2 files changed

+205
-4
lines changed

2 files changed

+205
-4
lines changed

examples/subscription_bridge/server.py

Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,58 @@ def _describe_tool_choice(tool_choice: Any) -> str:
136136
return "auto"
137137

138138

139+
def _required_tool_choice_name(tool_choice: Any) -> str | None:
140+
if not isinstance(tool_choice, dict):
141+
return None
142+
if tool_choice.get("type") == "function":
143+
function = tool_choice.get("function")
144+
if isinstance(function, dict):
145+
name = function.get("name")
146+
if isinstance(name, str) and name.strip():
147+
return name.strip()
148+
name = tool_choice.get("name")
149+
if isinstance(name, str) and name.strip():
150+
return name.strip()
151+
return None
152+
153+
154+
def _tool_choice_requires_tool_calls(tool_choice: Any) -> bool:
    """Return True when ``tool_choice`` mandates that the model call a tool.

    That is the case for the literal string ``"required"`` or for a specific
    function choice (one that names a required tool).
    """
    if tool_choice == "required":
        return True
    return _required_tool_choice_name(tool_choice) is not None
156+
157+
158+
def _tool_choice_allows_structured_tool_calls(tool_choice: Any) -> bool:
159+
return tool_choice != "none"
160+
161+
162+
def _validate_tool_choice_decision(decision: dict[str, Any], payload: dict[str, Any]) -> None:
    """Validate a structured backend decision against the request's tool_choice.

    Args:
        decision: Normalized backend decision with a ``type`` key
            (``"tool_calls"`` or a final answer) and, for tool calls,
            a ``tool_calls`` list of dicts carrying ``name`` entries.
        payload: The original request payload; only ``tool_choice`` is read.

    Raises:
        RuntimeError: If the decision violates the declared tool_choice
            semantics — a tool call under ``"none"``, a missing/empty tool
            call under ``"required"`` or a specific function choice, or a
            call to a tool other than the required one.
    """
    tool_choice = payload.get("tool_choice")
    required_tool_name = _required_tool_choice_name(tool_choice)

    if tool_choice == "none":
        # Defensive: callers skip structured planning for 'none', but reject
        # a stray tool_calls decision if one arrives anyway.
        if decision.get("type") == "tool_calls":
            raise RuntimeError("tool_choice='none' forbids tool calls")
        return

    if required_tool_name is not None:
        tool_calls = decision.get("tool_calls") or []
        # Bug fix: a "tool_calls" decision with an EMPTY list previously
        # passed vacuously; a forced tool choice needs at least one call.
        if decision.get("type") != "tool_calls" or not tool_calls:
            raise RuntimeError(
                f"required tool choice {required_tool_name!r} requires a tool call"
            )
        if any(
            tool_call.get("name") != required_tool_name for tool_call in tool_calls
        ):
            raise RuntimeError(
                f"backend violated required tool choice {required_tool_name!r}"
            )
        return

    if _tool_choice_requires_tool_calls(tool_choice):
        # Same non-empty guard for the plain "required" mode.
        if decision.get("type") != "tool_calls" or not decision.get("tool_calls"):
            raise RuntimeError("tool_choice='required' requires a tool call")
189+
190+
139191
def _chat_message_blocks(messages: Any) -> list[str]:
140192
if not isinstance(messages, list) or not messages:
141193
raise ValueError("chat.completions payload must include non-empty messages")
@@ -239,7 +291,9 @@ def build_responses_prompt(payload: dict[str, Any]) -> str:
239291

240292

241293
def _build_structured_decision_prompt(base_prompt: str, payload: dict[str, Any]) -> str:
242-
tool_choice = _describe_tool_choice(payload.get("tool_choice"))
294+
raw_tool_choice = payload.get("tool_choice")
295+
tool_choice = _describe_tool_choice(raw_tool_choice)
296+
required_tool_name = _required_tool_choice_name(raw_tool_choice)
243297
parallel_tool_calls = bool(payload.get("parallel_tool_calls"))
244298
instructions = [
245299
"Return JSON only.",
@@ -256,6 +310,11 @@ def _build_structured_decision_prompt(base_prompt: str, payload: dict[str, Any])
256310
"When you emit tool_calls, arguments_json must be a valid JSON string encoding an object that matches the tool schema.",
257311
"Do not invent tools.",
258312
]
313+
if raw_tool_choice == "required":
314+
instructions.append("You must return at least one tool call.")
315+
if required_tool_name is not None:
316+
instructions.append("You must return at least one tool call.")
317+
instructions.append(f"Every tool call name must be exactly {required_tool_name}.")
259318
return f"{base_prompt}\n\nDecision rules:\n- " + "\n- ".join(instructions)
260319

261320

@@ -285,8 +344,10 @@ def _coerce_tool_calls(tool_calls: list[dict[str, Any]]) -> list[dict[str, Any]]
285344
arguments = json.loads(arguments)
286345
except json.JSONDecodeError:
287346
arguments = {"value": arguments}
288-
if not isinstance(arguments, dict):
347+
if arguments is None:
289348
arguments = {}
349+
elif not isinstance(arguments, dict):
350+
raise ValueError("tool call arguments must decode to a JSON object")
290351
normalized.append(
291352
{
292353
"call_id": tool_call.get("call_id") or f"call_{uuid.uuid4().hex}",
@@ -625,7 +686,9 @@ def _respond_for_chat_request(
625686
payload: dict[str, Any], *, backend: str, model: str, workdir: Path, request_id: str
626687
) -> dict[str, Any]:
627688
prompt = build_chat_prompt(payload)
628-
if _normalize_tools(payload.get("tools")):
689+
if _normalize_tools(payload.get("tools")) and _tool_choice_allows_structured_tool_calls(
690+
payload.get("tool_choice")
691+
):
629692
try:
630693
decision = run_backend_structured(
631694
backend=backend,
@@ -634,6 +697,7 @@ def _respond_for_chat_request(
634697
workdir=workdir,
635698
schema=DecisionSchema,
636699
)
700+
_validate_tool_choice_decision(decision, payload)
637701
if decision.get("type") == "tool_calls":
638702
return build_chat_completion_response(
639703
model=model,
@@ -656,7 +720,9 @@ def _respond_for_responses_request(
656720
payload: dict[str, Any], *, backend: str, model: str, workdir: Path, request_id: str
657721
) -> dict[str, Any]:
658722
prompt = build_responses_prompt(payload)
659-
if _normalize_tools(payload.get("tools")):
723+
if _normalize_tools(payload.get("tools")) and _tool_choice_allows_structured_tool_calls(
724+
payload.get("tool_choice")
725+
):
660726
try:
661727
decision = run_backend_structured(
662728
backend=backend,
@@ -665,6 +731,7 @@ def _respond_for_responses_request(
665731
workdir=workdir,
666732
schema=DecisionSchema,
667733
)
734+
_validate_tool_choice_decision(decision, payload)
668735
if decision.get("type") == "tool_calls":
669736
return build_responses_api_response(
670737
model=model,

tests/examples/test_subscription_bridge.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,119 @@ def test_build_chat_completion_response_can_emit_tool_calls() -> None:
230230
assert json.loads(tool_call["function"]["arguments"]) == {"city": "Tokyo"}
231231

232232

233+
def test_build_chat_completion_response_rejects_non_object_tool_arguments() -> None:
    """A tool call whose arguments decode to a JSON array must be rejected."""
    bad_tool_call = {"name": "get_weather", "arguments_json": "[]"}

    with pytest.raises(ValueError, match="must decode to a JSON object"):
        server.build_chat_completion_response(
            model="codex/gpt-5.4",
            request_id="req_bad",
            tool_calls=[bad_tool_call],
        )
240+
241+
242+
def test_respond_for_chat_request_skips_structured_tool_mode_when_tool_choice_is_none(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
    """With tool_choice='none' the plain backend runs; the structured one never does."""
    invocations: dict[str, bool] = {"structured": False, "plain": False}

    def fail_structured(**_: Any) -> dict[str, Any]:
        invocations["structured"] = True
        raise AssertionError("structured backend should not run when tool_choice is none")

    def plain_backend(*, backend: str, prompt: str, model: str | None, workdir: Path) -> str:
        invocations["plain"] = True
        return "No tool call emitted."

    monkeypatch.setattr(server, "run_backend_structured", fail_structured)
    monkeypatch.setattr(server, "run_backend", plain_backend)

    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the weather for a city.",
            "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
        },
    }
    response = server._respond_for_chat_request(
        {
            "messages": [{"role": "user", "content": "Just answer directly."}],
            "tools": [weather_tool],
            "tool_choice": "none",
        },
        backend="codex",
        model="codex/gpt-5.4",
        workdir=tmp_path,
        request_id="req_none",
    )

    assert invocations == {"structured": False, "plain": True}
    first_choice = response["choices"][0]
    assert first_choice["finish_reason"] == "stop"
    assert first_choice["message"]["content"] == "No tool call emitted."
282+
283+
284+
def test_respond_for_chat_request_rejects_tool_calls_outside_required_tool_choice(
285+
monkeypatch: pytest.MonkeyPatch, tmp_path: Path
286+
) -> None:
287+
def fake_run_backend_structured(**_: Any) -> dict[str, Any]:
288+
return {"type": "tool_calls", "tool_calls": [{"name": "other_tool", "arguments": {}}]}
289+
290+
monkeypatch.setattr(server, "run_backend_structured", fake_run_backend_structured)
291+
292+
with pytest.raises(RuntimeError, match="required tool choice"):
293+
server._respond_for_chat_request(
294+
{
295+
"messages": [{"role": "user", "content": "Use the weather tool."}],
296+
"tools": [
297+
{
298+
"type": "function",
299+
"function": {
300+
"name": "get_weather",
301+
"description": "Get the weather for a city.",
302+
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
303+
},
304+
}
305+
],
306+
"tool_choice": {"type": "function", "function": {"name": "get_weather"}},
307+
},
308+
backend="codex",
309+
model="codex/gpt-5.4",
310+
workdir=tmp_path,
311+
request_id="req_specific",
312+
)
313+
314+
315+
def test_respond_for_chat_request_requires_tool_calls_when_tool_choice_is_required(
316+
monkeypatch: pytest.MonkeyPatch, tmp_path: Path
317+
) -> None:
318+
def fake_run_backend_structured(**_: Any) -> dict[str, Any]:
319+
return {"type": "final", "content": "Here is a direct answer."}
320+
321+
monkeypatch.setattr(server, "run_backend_structured", fake_run_backend_structured)
322+
323+
with pytest.raises(RuntimeError, match="requires a tool call"):
324+
server._respond_for_chat_request(
325+
{
326+
"messages": [{"role": "user", "content": "Use a tool."}],
327+
"tools": [
328+
{
329+
"type": "function",
330+
"function": {
331+
"name": "get_weather",
332+
"description": "Get the weather for a city.",
333+
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
334+
},
335+
}
336+
],
337+
"tool_choice": "required",
338+
},
339+
backend="codex",
340+
model="codex/gpt-5.4",
341+
workdir=tmp_path,
342+
request_id="req_required",
343+
)
344+
345+
233346
def test_build_responses_api_response_can_emit_function_calls() -> None:
234347
response = server.build_responses_api_response(
235348
model="codex/gpt-5.4",
@@ -330,6 +443,27 @@ def test_structured_decision_prompt_requires_plain_text_final_content() -> None:
330443
assert "Do not wrap the final answer in JSON" in prompt
331444

332445

446+
def test_structured_decision_prompt_requires_tool_calls_when_tool_choice_is_required() -> None:
    """tool_choice='required' must inject the mandatory-tool-call instruction."""
    rendered = server._build_structured_decision_prompt(
        "Conversation transcript:\n\n[user]\nUse a tool.",
        {"tool_choice": "required", "parallel_tool_calls": False},
    )

    assert "You must return at least one tool call." in rendered
453+
454+
455+
def test_structured_decision_prompt_limits_specific_tool_choice() -> None:
    """A specific function tool_choice must pin every tool call to that name."""
    specific_choice = {"type": "function", "function": {"name": "get_weather"}}
    rendered = server._build_structured_decision_prompt(
        "Conversation transcript:\n\n[user]\nUse the weather tool.",
        {"tool_choice": specific_choice, "parallel_tool_calls": False},
    )

    assert "Every tool call name must be exactly get_weather." in rendered
465+
466+
333467
def test_normalize_decision_payload_unwraps_nested_final_json_content() -> None:
334468
payload = {
335469
"type": "final",

0 commit comments

Comments
 (0)