Skip to content

Commit d33cc1e

Browse files
committed
fix review comments
1 parent 202de56 commit d33cc1e

6 files changed

Lines changed: 222 additions & 25 deletions

File tree

src/agents/memory/openai_responses_compaction_session.py

Lines changed: 30 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,14 @@
2828
OpenAIResponsesCompactionMode = Literal["previous_response_id", "input", "auto"]
2929

3030

31+
def _is_user_message_item(item: TResponseInputItem) -> bool:
32+
if not isinstance(item, dict):
33+
return False
34+
if item.get("type") == "message":
35+
return item.get("role") == "user"
36+
return item.get("role") == "user" and "content" in item
37+
38+
3139
def select_compaction_candidate_items(
3240
items: list[TResponseInputItem],
3341
) -> list[TResponseInputItem]:
@@ -36,18 +44,12 @@ def select_compaction_candidate_items(
3644
Excludes user messages and compaction items.
3745
"""
3846

39-
def _is_user_message(item: TResponseInputItem) -> bool:
40-
if not isinstance(item, dict):
41-
return False
42-
if item.get("type") == "message":
43-
return item.get("role") == "user"
44-
return item.get("role") == "user" and "content" in item
45-
4647
return [
4748
item
4849
for item in items
4950
if not (
50-
_is_user_message(item) or (isinstance(item, dict) and item.get("type") == "compaction")
51+
_is_user_message_item(item)
52+
or (isinstance(item, dict) and item.get("type") == "compaction")
5153
)
5254
]
5355

@@ -273,12 +275,12 @@ async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None
273275
)
274276
return
275277

276-
unresolved_function_calls = _find_unresolved_function_calls_without_results(session_items)
277-
if unresolved_function_calls:
278+
frontier_unresolved_function_calls = _find_frontier_unresolved_function_calls(session_items)
279+
if frontier_unresolved_function_calls:
278280
logger.debug(
279281
"compact: blocked unresolved function calls for %s: %s",
280282
self._response_id,
281-
unresolved_function_calls,
283+
frontier_unresolved_function_calls,
282284
)
283285
return
284286

@@ -476,12 +478,19 @@ def _normalize_compaction_session_items(
476478
_ResolvedCompactionMode = Literal["previous_response_id", "input"]
477479

478480

479-
def _find_unresolved_function_calls_without_results(items: list[TResponseInputItem]) -> list[str]:
480-
"""Return function-call ids that do not yet have matching outputs."""
481-
function_calls: dict[str, TResponseInputItem] = {}
481+
def _find_frontier_unresolved_function_calls(items: list[TResponseInputItem]) -> list[str]:
482+
"""Return unresolved function-call ids that remain in the active conversation frontier.
483+
484+
Once a later user message appears, earlier unresolved tool calls are considered abandoned and
485+
should no longer block future compaction for the session.
486+
"""
487+
function_call_indices: dict[str, int] = {}
482488
resolved_call_ids: set[str] = set()
489+
last_user_message_index = -1
483490

484-
for item in items:
491+
for index, item in enumerate(items):
492+
if _is_user_message_item(item):
493+
last_user_message_index = index
485494
if isinstance(item, dict):
486495
item_type = item.get("type")
487496
call_id = item.get("call_id")
@@ -492,11 +501,15 @@ def _find_unresolved_function_calls_without_results(items: list[TResponseInputIt
492501
if not isinstance(call_id, str):
493502
continue
494503
if item_type == "function_call":
495-
function_calls[call_id] = item
504+
function_call_indices[call_id] = index
496505
elif item_type == "function_call_output":
497506
resolved_call_ids.add(call_id)
498507

499-
return [call_id for call_id in function_calls if call_id not in resolved_call_ids]
508+
return [
509+
call_id
510+
for call_id, index in function_call_indices.items()
511+
if call_id not in resolved_call_ids and index > last_user_message_index
512+
]
500513

501514

502515
def _resolve_compaction_mode(

src/agents/result.py

Lines changed: 38 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -107,12 +107,36 @@ def _populate_state_from_result(
107107
if trace_state is None:
108108
trace_state = TraceState.from_trace(getattr(result, "trace", None))
109109
state._trace_state = copy.deepcopy(trace_state) if trace_state else None
110-
state._trace_include_sensitive_data = getattr(
111-
source_state,
112-
"_trace_include_sensitive_data",
113-
True,
110+
trace_include_sensitive_data_snapshot = getattr(
111+
result,
112+
"_trace_include_sensitive_data_snapshot",
113+
None,
114114
)
115-
if isinstance(source_state, RunState):
115+
if trace_include_sensitive_data_snapshot is not None:
116+
state._trace_include_sensitive_data = trace_include_sensitive_data_snapshot
117+
else:
118+
state._trace_include_sensitive_data = getattr(
119+
source_state,
120+
"_trace_include_sensitive_data",
121+
True,
122+
)
123+
124+
session_history_mutations_snapshot = getattr(
125+
result,
126+
"_session_history_mutations_snapshot",
127+
None,
128+
)
129+
execution_only_approval_override_call_ids_snapshot = getattr(
130+
result,
131+
"_execution_only_approval_override_call_ids_snapshot",
132+
None,
133+
)
134+
if session_history_mutations_snapshot is not None:
135+
state._session_history_mutations = copy.deepcopy(session_history_mutations_snapshot)
136+
state._execution_only_approval_override_call_ids = list(
137+
execution_only_approval_override_call_ids_snapshot or []
138+
)
139+
elif isinstance(source_state, RunState):
116140
state._session_history_mutations = source_state.get_session_history_mutations()
117141
state._execution_only_approval_override_call_ids = list(
118142
source_state._execution_only_approval_override_call_ids
@@ -332,6 +356,15 @@ class RunResult(RunResultBase):
332356
to preserve the correct originalInput when serializing state."""
333357
_state: Any = field(default=None, repr=False)
334358
"""Internal reference to the originating RunState when available."""
359+
_trace_include_sensitive_data_snapshot: bool | None = field(default=None, repr=False)
360+
"""Snapshot of the trace redaction setting used when rebuilding state from a completed
361+
result."""
362+
_session_history_mutations_snapshot: list[Any] | None = field(default=None, repr=False)
363+
"""Snapshot of pending session-history rewrites needed by `to_state()`."""
364+
_execution_only_approval_override_call_ids_snapshot: list[str] | None = field(
365+
default=None, repr=False
366+
)
367+
"""Snapshot of execution-only approval overrides needed by `to_state()`."""
335368
_conversation_id: str | None = field(default=None, repr=False)
336369
"""Conversation identifier for server-managed runs."""
337370
_previous_response_id: str | None = field(default=None, repr=False)

src/agents/run_internal/agent_runner_helpers.py

Lines changed: 17 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import copy
56
from typing import Any, cast
67

78
from ..agent import Agent
@@ -185,9 +186,16 @@ def resolve_trace_include_sensitive_data(
185186
run_config: RunConfig,
186187
run_config_was_supplied: bool,
187188
) -> bool:
188-
"""Resolve whether traces may include sensitive data for this run."""
189-
if run_state is None or run_config_was_supplied:
189+
"""Resolve whether traces may include sensitive data for this run.
190+
191+
Resumed runs preserve the stored setting unless the new RunConfig explicitly narrows it by
192+
setting `trace_include_sensitive_data=False`.
193+
"""
194+
del run_config_was_supplied
195+
if run_state is None:
190196
return run_config.trace_include_sensitive_data
197+
if run_config.trace_include_sensitive_data is False:
198+
return False
191199
return run_state._trace_include_sensitive_data
192200

193201

@@ -295,9 +303,15 @@ def attach_run_state_metadata(result: RunResult, *, run_state: RunState | None)
295303
if run_state is None:
296304
return result
297305

298-
result._state = run_state
299306
result._current_turn_persisted_item_count = run_state._current_turn_persisted_item_count
300307
result._trace_state = run_state._trace_state
308+
result._trace_include_sensitive_data_snapshot = run_state._trace_include_sensitive_data
309+
result._session_history_mutations_snapshot = copy.deepcopy(
310+
run_state.get_session_history_mutations()
311+
)
312+
result._execution_only_approval_override_call_ids_snapshot = list(
313+
run_state._execution_only_approval_override_call_ids
314+
)
301315
return result
302316

303317

tests/memory/test_openai_responses_compaction_session.py

Lines changed: 75 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -623,6 +623,81 @@ async def test_run_compaction_auto_uses_default_store_when_unset(self) -> None:
623623
assert second_kwargs.get("previous_response_id") == "resp-stored"
624624
assert "input" not in second_kwargs
625625

626+
@pytest.mark.asyncio
async def test_run_compaction_ignores_abandoned_unresolved_function_calls(self) -> None:
    """Compaction proceeds when an unanswered tool call precedes a later user message."""
    # History: user -> function_call with no output -> user follow-up -> assistant reply.
    raw_history = [
        {"type": "message", "role": "user", "content": "first"},
        {
            "type": "function_call",
            "call_id": "call-abandoned",
            "id": "fc_1",
            "name": "test_tool",
            "arguments": "{}",
        },
        {"type": "message", "role": "user", "content": "followup"},
        {"type": "message", "role": "assistant", "content": "latest response"},
    ]
    history: list[TResponseInputItem] = [
        cast(TResponseInputItem, entry) for entry in raw_history
    ]
    underlying = self.create_mock_session()
    underlying.get_items.return_value = history

    # The compact endpoint answers with a response whose output list is empty.
    client = MagicMock()
    client.responses.compact = AsyncMock(return_value=MagicMock(output=[]))

    compaction_session = OpenAIResponsesCompactionSession(
        session_id="test",
        underlying_session=underlying,
        client=client,
        compaction_mode="auto",
    )
    await compaction_session.run_compaction({"response_id": "resp-latest", "force": True})

    # The stale call must not block compaction: exactly one compact request is issued.
    client.responses.compact.assert_called_once_with(
        previous_response_id="resp-latest",
        model="gpt-4.1",
    )
669+
@pytest.mark.asyncio
async def test_run_compaction_still_blocks_active_unresolved_function_calls(self) -> None:
    """Compaction is skipped while the newest tool call has no output and no later user turn."""
    # History ends on a function_call that has no matching function_call_output.
    raw_history = [
        {"type": "message", "role": "user", "content": "hello"},
        {
            "type": "function_call",
            "call_id": "call-pending",
            "id": "fc_1",
            "name": "test_tool",
            "arguments": "{}",
        },
    ]
    history: list[TResponseInputItem] = [
        cast(TResponseInputItem, entry) for entry in raw_history
    ]
    underlying = self.create_mock_session()
    underlying.get_items.return_value = history

    client = MagicMock()
    client.responses.compact = AsyncMock()

    compaction_session = OpenAIResponsesCompactionSession(
        session_id="test",
        underlying_session=underlying,
        client=client,
        compaction_mode="auto",
    )
    await compaction_session.run_compaction({"response_id": "resp-pending", "force": True})

    # Even with force=True the pending call keeps the compact endpoint from being hit.
    client.responses.compact.assert_not_called()
700+
626701
@pytest.mark.asyncio
627702
async def test_run_compaction_auto_uses_input_when_last_response_unstored(self) -> None:
628703
mock_session = self.create_mock_session()

tests/test_agent_tracing.py

Lines changed: 39 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -410,6 +410,45 @@ def send_email(recipient: str) -> str:
410410
assert function_span["span_data"]["output"] is None
411411

412412

413+
@pytest.mark.asyncio
async def test_resumed_run_preserves_sensitive_trace_flag_for_unrelated_run_config() -> None:
    """A resumed run keeps tracing redacted when the new RunConfig only sets a workflow name."""
    fake_model = FakeModel()

    @function_tool(name_override="send_email", needs_approval=True)
    def send_email(recipient: str) -> str:
        return recipient

    trace_agent = Agent(name="trace_agent", model=fake_model, tools=[send_email])

    # Turn one requests the approval-gated tool; turn two yields the final message.
    scripted_turns = [
        [
            get_function_tool_call(
                "send_email", '{"recipient":"alice@example.com"}', call_id="call-1"
            )
        ],
        [get_text_message("done")],
    ]
    fake_model.add_multiple_turn_outputs(scripted_turns)

    interrupted = await Runner.run(trace_agent, input="first_test")
    assert interrupted.interruptions

    # Turn redaction on in the saved state, then approve the call with replaced arguments.
    resume_state = interrupted.to_state()
    resume_state.set_trace_include_sensitive_data(False)
    resume_state.approve(
        interrupted.interruptions[0],
        override_arguments={"recipient": "bob@example.com"},
    )

    # The supplied RunConfig names a workflow but says nothing about sensitive data.
    workflow_only_config = RunConfig(workflow_name="override_workflow")
    resumed_result = await Runner.run(trace_agent, resume_state, run_config=workflow_only_config)

    assert resumed_result.final_output == "done"
    span_export = _get_last_function_span_export("send_email")
    span_data = span_export["span_data"]
    assert span_data["input"] is None
    assert span_data["output"] is None
450+
451+
413452
@pytest.mark.asyncio
414453
async def test_wrapped_trace_is_single_trace():
415454
model = FakeModel()

tests/test_result_cast.py

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,12 +15,16 @@
1515
MessageOutputItem,
1616
RunContextWrapper,
1717
RunItem,
18+
Runner,
1819
RunResult,
1920
RunResultStreaming,
2021
)
2122
from agents.exceptions import AgentsException
2223
from agents.tool_context import ToolContext
2324

25+
from .fake_model import FakeModel
26+
from .test_responses import get_text_message
27+
2428

2529
def create_run_result(
2630
final_output: Any | None,
@@ -261,6 +265,25 @@ def test_run_result_streaming_release_agents_releases_current_agent() -> None:
261265
_ = streaming_result.last_agent
262266

263267

268+
@pytest.mark.asyncio
async def test_runner_result_does_not_retain_live_run_state() -> None:
    """A result from ``Runner.run`` has no ``_state`` and lets its agent be collected."""
    runner_agent = Agent(
        name="runner-result-agent",
        model=FakeModel(initial_output=[get_text_message("done")]),
    )

    run_result = await Runner.run(runner_agent, "hello")
    assert run_result._state is None

    # Track the agent weakly, then drop every strong reference this test controls.
    weak_agent = weakref.ref(runner_agent)
    run_result.release_agents()
    del runner_agent
    gc.collect()

    # A dead weak reference shows nothing reachable from the result still holds the agent.
    assert weak_agent() is None
285+
286+
264287
def test_run_result_agent_tool_invocation_returns_none_for_plain_context() -> None:
265288
result = create_run_result("ok")
266289

0 commit comments

Comments
 (0)