Skip to content

Commit f923b13

Browse files
authored
feat: add run-context thread reuse for codex_tool (#2425)
1 parent 22c57b1 commit f923b13

File tree

5 files changed

+1285
-34
lines changed

5 files changed

+1285
-34
lines changed
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import asyncio
2+
from collections.abc import Mapping
3+
from datetime import datetime
4+
5+
from pydantic import BaseModel
6+
7+
from agents import Agent, ModelSettings, Runner, gen_trace_id, trace
8+
9+
# This tool is still experimental; its details may change before it reaches general availability (GA).
10+
from agents.extensions.experimental.codex import (
11+
CodexToolStreamEvent,
12+
ThreadErrorEvent,
13+
ThreadOptions,
14+
ThreadStartedEvent,
15+
TurnCompletedEvent,
16+
TurnFailedEvent,
17+
TurnStartedEvent,
18+
codex_tool,
19+
)
20+
21+
# Run-context key under which the Codex thread ID is looked up in this example.
# It is the key derived from codex_tool(name="codex_engineer") when
# run_context_thread_id_key is omitted.
THREAD_ID_KEY = "codex_thread_id_engineer"
23+
24+
25+
async def on_codex_stream(payload: CodexToolStreamEvent) -> None:
    """Log a short summary line for a Codex stream event.

    Handles thread-started, turn-started, turn-completed, turn-failed and
    thread-error events; any other event type is ignored without logging.

    Args:
        payload: Stream payload whose ``event`` attribute carries the event.
    """
    evt = payload.event

    # Pick the message for the first matching event type; only the matching
    # branch touches the event's attributes.
    if isinstance(evt, ThreadStartedEvent):
        summary = f"codex thread started: {evt.thread_id}"
    elif isinstance(evt, TurnStartedEvent):
        summary = "codex turn started"
    elif isinstance(evt, TurnCompletedEvent):
        summary = f"codex turn completed, usage: {evt.usage}"
    elif isinstance(evt, TurnFailedEvent):
        summary = f"codex turn failed: {evt.error.message}"
    elif isinstance(evt, ThreadErrorEvent):
        summary = f"codex stream error: {evt.message}"
    else:
        # Unhandled event types produce no output.
        return

    log(summary)
42+
43+
44+
def _timestamp() -> str:
    """Return the current local time formatted as ``YYYY-MM-DD HH:MM:SS``."""
    return f"{datetime.now():%Y-%m-%d %H:%M:%S}"
46+
47+
48+
def log(message: str) -> None:
    """Print ``message`` to stdout, prefixing every line with one timestamp.

    The timestamp is taken once per call, so all lines of a multi-line
    message share it. An empty message still prints a single timestamped
    line.

    Args:
        message: Text to print; it is converted with ``str()`` first.
    """
    stamp = _timestamp()
    rows = str(message).splitlines()
    if not rows:
        # splitlines() yields nothing for an empty string; still emit one line.
        rows = [""]
    for row in rows:
        # print's default separator is a single space, matching "<stamp> <row>".
        print(stamp, row)
53+
54+
55+
def read_context_value(context: Mapping[str, str] | BaseModel, key: str) -> str | None:
    """Look up ``key`` on a run context that is either a mapping or an object.

    Args:
        context: A mapping (read with ``.get``) or any other object, such as
            a pydantic model (read with ``getattr``).
        key: Mapping key or attribute name to read.

    Returns:
        The stored value, or ``None`` when the key/attribute is absent.
    """
    if not isinstance(context, Mapping):
        # Non-mapping contexts expose the value as an attribute.
        return getattr(context, key, None)
    return context.get(key)
60+
61+
62+
async def main() -> None:
    """Run two agent turns that share one run context and compare thread IDs.

    Builds an agent whose only tool is ``codex_tool(name="codex_engineer")``
    with ``use_run_context_thread_id=True``, runs it twice with the same
    context object inside one trace, and after each run logs the final output
    and the value stored in the context under ``THREAD_ID_KEY``. Finally logs
    whether both runs reported the same, non-empty thread ID.
    """
    agent = Agent(
        name="Codex Agent (same thread)",
        instructions=(
            "Always use the Codex tool to answer the user's question. "
            "Even when you don't have enough context, the Codex tool may know. "
            "In that case, you can simply forward the question to the Codex tool."
        ),
        tools=[
            codex_tool(
                # Give each Codex tool a unique `codex_` name when you run multiple tools in one agent.
                # Name-based defaults keep their run-context thread IDs separated.
                name="codex_engineer",
                sandbox_mode="workspace-write",
                default_thread_options=ThreadOptions(
                    model="gpt-5.2-codex",
                    model_reasoning_effort="low",
                    network_access_enabled=True,
                    web_search_enabled=False,
                    approval_policy="never",
                ),
                on_stream=on_codex_stream,
                # Reuse the same Codex thread across runs that share this context object.
                use_run_context_thread_id=True,
            )
        ],
        # Force a tool call so every turn goes through the Codex tool.
        model_settings=ModelSettings(tool_choice="required"),
    )

    class MyContext(BaseModel):
        something: str | None = None
        # The default key is "codex_thread_id"; omitting this field works as well.
        codex_thread_id_engineer: str | None = None  # aligns with run_context_thread_id_key

    context = MyContext()

    # Simple dict object works as well:
    # context: dict[str, str] = {}

    trace_id = gen_trace_id()
    log(f"View trace: https://platform.openai.com/traces/trace?trace_id={trace_id}")

    with trace("Codex same thread example", trace_id=trace_id):
        log("Turn 1: ask writing python code")
        first_prompt = "Write working python code example demonstrating how to call OpenAI's Responses API with web search tool."
        first_result = await Runner.run(agent, first_prompt, context=context)
        # Read the thread ID back from the shared context after the run.
        first_thread_id = read_context_value(context, THREAD_ID_KEY)
        log(first_result.final_output)
        log(f"thread id after turn 1: {first_thread_id}")

        log("Turn 2: continue with the same Codex thread.")
        second_prompt = "Write the same code in TypeScript."
        # Passing the same context object is what lets the second run see the stored thread ID.
        second_result = await Runner.run(agent, second_prompt, context=context)
        second_thread_id = read_context_value(context, THREAD_ID_KEY)
        log(second_result.final_output)
        log(f"thread id after turn 2: {second_thread_id}")
        # Reuse only counts when an ID was actually recorded; two missing IDs are not a match.
        log(
            "same thread reused: "
            + str(first_thread_id is not None and first_thread_id == second_thread_id)
        )
122+
123+
124+
if __name__ == "__main__":
    # Drive the async example to completion when the file is executed as a script.
    asyncio.run(main())

src/agents/agent.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
peek_agent_tool_run_result,
2626
record_agent_tool_run_result,
2727
)
28-
from .exceptions import ModelBehaviorError
28+
from .exceptions import ModelBehaviorError, UserError
2929
from .guardrail import InputGuardrail, OutputGuardrail
3030
from .handoffs import Handoff
3131
from .logger import logger
@@ -88,6 +88,32 @@ class ToolsToFinalOutputResult:
8888
"""
8989

9090

91+
def _validate_codex_tool_name_collisions(tools: list[Tool]) -> None:
    """Reject tool lists in which a Codex tool's name is used more than once.

    A tool counts as a Codex tool when it is a ``FunctionTool`` whose
    ``_is_codex_tool`` attribute is truthy. Names are counted across every
    tool in ``tools``, so a Codex tool colliding with a non-Codex tool of the
    same name is also reported.

    Args:
        tools: The tools to check.

    Raises:
        UserError: If any Codex tool name occurs more than once in ``tools``.
            The message lists the offending names in sorted order.
    """
    # First pass: collect the names of Codex tools only.
    codex_names: set[str] = set()
    for candidate in tools:
        if not isinstance(candidate, FunctionTool):
            continue
        if getattr(candidate, "_is_codex_tool", False):
            codex_names.add(candidate.name)

    # Nothing to validate when no Codex tool is present.
    if not codex_names:
        return

    # Second pass: count how often each non-empty string name occurs overall.
    occurrences: dict[str, int] = {}
    for candidate in tools:
        label = getattr(candidate, "name", None)
        if not isinstance(label, str) or not label:
            continue
        occurrences[label] = occurrences.get(label, 0) + 1

    clashing = sorted(label for label in codex_names if occurrences.get(label, 0) > 1)
    if not clashing:
        return

    raise UserError(
        "Duplicate Codex tool names found: "
        + ", ".join(clashing)
        + ". Provide a unique codex_tool(name=...) per tool instance."
    )
115+
116+
91117
class AgentToolStreamEvent(TypedDict):
92118
"""Streaming event emitted when an agent is invoked as a tool."""
93119

@@ -182,7 +208,9 @@ async def _check_tool_enabled(tool: Tool) -> bool:
182208

183209
results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools))
184210
enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok]
185-
return [*mcp_tools, *enabled]
211+
all_tools: list[Tool] = [*mcp_tools, *enabled]
212+
_validate_codex_tool_name_collisions(all_tools)
213+
return all_tools
186214

187215

188216
@dataclass

0 commit comments

Comments
 (0)