@@ -136,6 +136,58 @@ def _describe_tool_choice(tool_choice: Any) -> str:
136136 return "auto"
137137
138138
139+ def _required_tool_choice_name (tool_choice : Any ) -> str | None :
140+ if not isinstance (tool_choice , dict ):
141+ return None
142+ if tool_choice .get ("type" ) == "function" :
143+ function = tool_choice .get ("function" )
144+ if isinstance (function , dict ):
145+ name = function .get ("name" )
146+ if isinstance (name , str ) and name .strip ():
147+ return name .strip ()
148+ name = tool_choice .get ("name" )
149+ if isinstance (name , str ) and name .strip ():
150+ return name .strip ()
151+ return None
152+
153+
154+ def _tool_choice_requires_tool_calls (tool_choice : Any ) -> bool :
155+ return tool_choice == "required" or _required_tool_choice_name (tool_choice ) is not None
156+
157+
158+ def _tool_choice_allows_structured_tool_calls (tool_choice : Any ) -> bool :
159+ return tool_choice != "none"
160+
161+
162+ def _validate_tool_choice_decision (decision : dict [str , Any ], payload : dict [str , Any ]) -> None :
163+ tool_choice = payload .get ("tool_choice" )
164+ required_tool_name = _required_tool_choice_name (tool_choice )
165+
166+ if tool_choice == "none" :
167+ if decision .get ("type" ) == "tool_calls" :
168+ raise RuntimeError ("tool_choice='none' forbids tool calls" )
169+ return
170+
171+ if required_tool_name is not None :
172+ if decision .get ("type" ) != "tool_calls" :
173+ raise RuntimeError (
174+ f"required tool choice { required_tool_name !r} requires a tool call"
175+ )
176+ invalid_names = [
177+ tool_call .get ("name" )
178+ for tool_call in decision .get ("tool_calls" , [])
179+ if tool_call .get ("name" ) != required_tool_name
180+ ]
181+ if invalid_names :
182+ raise RuntimeError (
183+ f"backend violated required tool choice { required_tool_name !r} "
184+ )
185+ return
186+
187+ if _tool_choice_requires_tool_calls (tool_choice ) and decision .get ("type" ) != "tool_calls" :
188+ raise RuntimeError ("tool_choice='required' requires a tool call" )
189+
190+
139191def _chat_message_blocks (messages : Any ) -> list [str ]:
140192 if not isinstance (messages , list ) or not messages :
141193 raise ValueError ("chat.completions payload must include non-empty messages" )
@@ -239,7 +291,9 @@ def build_responses_prompt(payload: dict[str, Any]) -> str:
239291
240292
241293def _build_structured_decision_prompt (base_prompt : str , payload : dict [str , Any ]) -> str :
242- tool_choice = _describe_tool_choice (payload .get ("tool_choice" ))
294+ raw_tool_choice = payload .get ("tool_choice" )
295+ tool_choice = _describe_tool_choice (raw_tool_choice )
296+ required_tool_name = _required_tool_choice_name (raw_tool_choice )
243297 parallel_tool_calls = bool (payload .get ("parallel_tool_calls" ))
244298 instructions = [
245299 "Return JSON only." ,
@@ -256,6 +310,11 @@ def _build_structured_decision_prompt(base_prompt: str, payload: dict[str, Any])
256310 "When you emit tool_calls, arguments_json must be a valid JSON string encoding an object that matches the tool schema." ,
257311 "Do not invent tools." ,
258312 ]
313+ if raw_tool_choice == "required" :
314+ instructions .append ("You must return at least one tool call." )
315+ if required_tool_name is not None :
316+ instructions .append ("You must return at least one tool call." )
317+ instructions .append (f"Every tool call name must be exactly { required_tool_name } ." )
259318 return f"{ base_prompt } \n \n Decision rules:\n - " + "\n - " .join (instructions )
260319
261320
@@ -285,8 +344,10 @@ def _coerce_tool_calls(tool_calls: list[dict[str, Any]]) -> list[dict[str, Any]]
285344 arguments = json .loads (arguments )
286345 except json .JSONDecodeError :
287346 arguments = {"value" : arguments }
288- if not isinstance ( arguments , dict ) :
347+ if arguments is None :
289348 arguments = {}
349+ elif not isinstance (arguments , dict ):
350+ raise ValueError ("tool call arguments must decode to a JSON object" )
290351 normalized .append (
291352 {
292353 "call_id" : tool_call .get ("call_id" ) or f"call_{ uuid .uuid4 ().hex } " ,
@@ -625,7 +686,9 @@ def _respond_for_chat_request(
625686 payload : dict [str , Any ], * , backend : str , model : str , workdir : Path , request_id : str
626687) -> dict [str , Any ]:
627688 prompt = build_chat_prompt (payload )
628- if _normalize_tools (payload .get ("tools" )):
689+ if _normalize_tools (payload .get ("tools" )) and _tool_choice_allows_structured_tool_calls (
690+ payload .get ("tool_choice" )
691+ ):
629692 try :
630693 decision = run_backend_structured (
631694 backend = backend ,
@@ -634,6 +697,7 @@ def _respond_for_chat_request(
634697 workdir = workdir ,
635698 schema = DecisionSchema ,
636699 )
700+ _validate_tool_choice_decision (decision , payload )
637701 if decision .get ("type" ) == "tool_calls" :
638702 return build_chat_completion_response (
639703 model = model ,
@@ -656,7 +720,9 @@ def _respond_for_responses_request(
656720 payload : dict [str , Any ], * , backend : str , model : str , workdir : Path , request_id : str
657721) -> dict [str , Any ]:
658722 prompt = build_responses_prompt (payload )
659- if _normalize_tools (payload .get ("tools" )):
723+ if _normalize_tools (payload .get ("tools" )) and _tool_choice_allows_structured_tool_calls (
724+ payload .get ("tool_choice" )
725+ ):
660726 try :
661727 decision = run_backend_structured (
662728 backend = backend ,
@@ -665,6 +731,7 @@ def _respond_for_responses_request(
665731 workdir = workdir ,
666732 schema = DecisionSchema ,
667733 )
734+ _validate_tool_choice_decision (decision , payload )
668735 if decision .get ("type" ) == "tool_calls" :
669736 return build_responses_api_response (
670737 model = model ,
0 commit comments