Skip to content

Commit c2ab16a

Browse files
committed
fix review comments
1 parent 2bb98d8 commit c2ab16a

5 files changed

Lines changed: 200 additions & 13 deletions

File tree

src/agents/memory/openai_responses_compaction_session.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -168,15 +168,6 @@ def _resolve_compaction_mode(
168168
if not self._has_pending_local_history_rewrite:
169169
return resolved_mode
170170

171-
if (
172-
self._local_history_rewrite_response_id is not None
173-
and response_id is not None
174-
and response_id != self._local_history_rewrite_response_id
175-
):
176-
self._has_pending_local_history_rewrite = False
177-
self._local_history_rewrite_response_id = None
178-
return resolved_mode
179-
180171
if resolved_mode == "previous_response_id":
181172
if self._local_history_rewrite_response_id is None and response_id is not None:
182173
self._local_history_rewrite_response_id = response_id
@@ -321,6 +312,8 @@ async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None
321312

322313
self._compaction_candidate_items = select_compaction_candidate_items(output_items)
323314
self._session_items = output_items
315+
if resolved_mode == "input":
316+
self._clear_pending_local_history_rewrite()
324317

325318
logger.debug(
326319
f"compact: done for {self._response_id} "
@@ -435,6 +428,10 @@ def _mark_local_history_rewrite(self) -> None:
435428
self._has_pending_local_history_rewrite = True
436429
self._local_history_rewrite_response_id = self._response_id
437430

431+
def _clear_pending_local_history_rewrite(self) -> None:
432+
self._has_pending_local_history_rewrite = False
433+
self._local_history_rewrite_response_id = None
434+
438435

439436
_ResolvedCompactionMode = Literal["previous_response_id", "input"]
440437

src/agents/run.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
input_guardrails_triggered,
5656
resolve_processed_response,
5757
resolve_resumed_context,
58+
resolve_trace_include_sensitive_data,
5859
resolve_trace_settings,
5960
save_turn_items_if_needed,
6061
should_cancel_parallel_model_task_on_input_guardrail_trip,
@@ -412,6 +413,7 @@ async def run(
412413
auto_previous_response_id = kwargs.get("auto_previous_response_id", False)
413414
conversation_id = kwargs.get("conversation_id")
414415
session = kwargs.get("session")
416+
run_config_was_supplied = run_config is not None
415417

416418
if run_config is None:
417419
run_config = RunConfig()
@@ -511,10 +513,15 @@ async def run(
511513
history_is_server_managed=history_is_server_managed,
512514
)
513515

514-
if is_resumed_state and run_state is not None:
516+
resolved_trace_include_sensitive_data = resolve_trace_include_sensitive_data(
517+
run_state=run_state,
518+
run_config=run_config,
519+
run_config_was_supplied=run_config_was_supplied,
520+
)
521+
if resolved_trace_include_sensitive_data != run_config.trace_include_sensitive_data:
515522
run_config = dataclasses.replace(
516523
run_config,
517-
trace_include_sensitive_data=run_state._trace_include_sensitive_data,
524+
trace_include_sensitive_data=resolved_trace_include_sensitive_data,
518525
)
519526

520527
resolved_reasoning_item_id_policy: ReasoningItemIdPolicy | None = (
@@ -1462,6 +1469,7 @@ def run_streamed(
14621469
auto_previous_response_id = kwargs.get("auto_previous_response_id", False)
14631470
conversation_id = kwargs.get("conversation_id")
14641471
session = kwargs.get("session")
1472+
run_config_was_supplied = run_config is not None
14651473

14661474
if run_config is None:
14671475
run_config = RunConfig()
@@ -1553,10 +1561,15 @@ def run_streamed(
15531561
session=session,
15541562
history_is_server_managed=history_is_server_managed,
15551563
)
1556-
if is_resumed_state and run_state is not None:
1564+
resolved_trace_include_sensitive_data = resolve_trace_include_sensitive_data(
1565+
run_state=run_state,
1566+
run_config=run_config,
1567+
run_config_was_supplied=run_config_was_supplied,
1568+
)
1569+
if resolved_trace_include_sensitive_data != run_config.trace_include_sensitive_data:
15571570
run_config = dataclasses.replace(
15581571
run_config,
1559-
trace_include_sensitive_data=run_state._trace_include_sensitive_data,
1572+
trace_include_sensitive_data=resolved_trace_include_sensitive_data,
15601573
)
15611574

15621575
resolved_reasoning_item_id_policy: ReasoningItemIdPolicy | None = (

src/agents/run_internal/agent_runner_helpers.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
"ensure_context_wrapper",
4444
"finalize_conversation_tracking",
4545
"input_guardrails_triggered",
46+
"resolve_trace_include_sensitive_data",
4647
"validate_session_conversation_settings",
4748
"resolve_trace_settings",
4849
"resolve_processed_response",
@@ -178,6 +179,18 @@ def resolve_trace_settings(
178179
return workflow_name, trace_id, group_id, metadata, tracing
179180

180181

182+
def resolve_trace_include_sensitive_data(
    *,
    run_state: RunState[TContext] | None,
    run_config: RunConfig,
    run_config_was_supplied: bool,
) -> bool:
    """Pick the ``trace_include_sensitive_data`` value that applies to this run.

    Args:
        run_state: The state being resumed, or ``None`` when there is none.
        run_config: The run configuration in effect for this run.
        run_config_was_supplied: Whether the caller passed a run config explicitly.

    Returns:
        The flag stored on ``run_state`` when a state is present and no run config
        was supplied by the caller; otherwise the flag from ``run_config``.
    """
    # The value carried by the state is only a fallback: an explicitly supplied
    # run config always takes precedence over it.
    if run_state is not None and not run_config_was_supplied:
        return run_state._trace_include_sensitive_data
    return run_config.trace_include_sensitive_data
192+
193+
181194
def resolve_resumed_context(
182195
*,
183196
run_state: RunState[TContext],

tests/memory/test_openai_responses_compaction_session.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,75 @@ async def test_run_compaction_forces_input_mode_after_local_history_rewrite(self
416416
second_call_kwargs = mock_client.responses.compact.call_args.kwargs
417417
assert second_call_kwargs.get("previous_response_id") == "resp-2"
418418

419+
@pytest.mark.asyncio
async def test_run_compaction_keeps_local_rewrite_pending_until_input_compaction_succeeds(
    self,
) -> None:
    """A local rewrite still drives input-mode compaction on a later response id.

    After a function call is rewritten locally, an unforced compaction for
    "resp-1" is expected not to reach the compact endpoint. The following forced
    compaction for "resp-2" must send the rewritten history as ``input`` rather
    than passing ``previous_response_id``.
    """
    user_message = cast(
        TResponseInputItem,
        {"type": "message", "role": "user", "content": "hello"},
    )
    original_call = cast(
        TResponseInputItem,
        {
            "type": "function_call",
            "call_id": "call-1",
            "id": "fc_1",
            "name": "test_tool",
            "arguments": '{"value":"foo"}',
        },
    )
    call_output = cast(
        TResponseInputItem,
        {
            "type": "function_call_output",
            "call_id": "call-1",
            "output": "ok",
        },
    )
    backing_session = RewriteAwareSimpleSession(
        history=[user_message, original_call, call_output]
    )

    compact_client = MagicMock()
    compact_client.responses.compact = AsyncMock(return_value=MagicMock(output=[]))
    session = OpenAIResponsesCompactionSession(
        session_id="test",
        underlying_session=backing_session,
        client=compact_client,
        compaction_mode="auto",
    )

    # Same call id and item id as the stored call; only the arguments differ.
    rewritten_call = cast(
        TResponseInputItem,
        {
            "type": "function_call",
            "call_id": "call-1",
            "id": "fc_1",
            "name": "test_tool",
            "arguments": '{"value":"bar"}',
        },
    )
    await session.apply_history_mutations(
        {
            "mutations": [
                {
                    "type": "replace_function_call",
                    "call_id": "call-1",
                    "replacement": rewritten_call,
                }
            ]
        }
    )

    # First pass is not forced and must not call the compact endpoint.
    await session.run_compaction({"response_id": "resp-1"})
    compact_client.responses.compact.assert_not_called()

    # Second pass uses a different response id and is forced.
    await session.run_compaction({"response_id": "resp-2", "force": True})

    sent_kwargs = compact_client.responses.compact.call_args.kwargs
    assert "previous_response_id" not in sent_kwargs
    sent_input = sent_kwargs.get("input")
    assert isinstance(sent_input, list)
    assert cast(dict[str, Any], sent_input[1])["arguments"] == '{"value":"bar"}'
487+
419488
@pytest.mark.asyncio
420489
async def test_run_compaction_auto_uses_default_store_when_unset(self) -> None:
421490
mock_session = self.create_mock_session()

tests/test_agent_tracing.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import asyncio
4+
from typing import Any
45
from uuid import uuid4
56

67
import pytest
@@ -27,6 +28,18 @@ def approval_tool() -> str:
2728
return Agent(name="test_agent", model=model, tools=[approval_tool])
2829

2930

31+
def _get_last_function_span_export(name: str) -> dict[str, Any]:
    """Return the export of the last function span named ``name``.

    Spans are scanned in the order produced by ``fetch_ordered_spans()``. Spans
    whose export is ``None`` are ignored. Raises ``AssertionError`` when no span
    matches.
    """
    latest: dict[str, Any] | None = None
    # Walk every span front to back so each one is exported exactly once and the
    # final match wins.
    for span in fetch_ordered_spans():
        exported = span.export()
        if exported is None:
            continue
        span_data = exported["span_data"]
        if span_data["type"] == "function" and span_data["name"] == name:
            latest = exported
    assert latest is not None
    return latest
41+
42+
3043
@pytest.mark.asyncio
3144
async def test_single_run_is_single_trace():
3245
agent = Agent(
@@ -358,6 +371,45 @@ async def test_completed_result_to_state_preserves_sensitive_trace_flag() -> Non
358371
assert state._trace_include_sensitive_data is False
359372

360373

374+
@pytest.mark.asyncio
async def test_resumed_run_honors_explicit_trace_include_sensitive_data() -> None:
    """An explicit RunConfig on resume decides whether span payloads are recorded.

    The first run pauses on an approval-gated tool call. The resumed run passes
    ``trace_include_sensitive_data=False`` and the resulting "send_email" function
    span must carry neither input nor output.
    """
    fake_model = FakeModel()

    @function_tool(name_override="send_email", needs_approval=True)
    def send_email(recipient: str) -> str:
        return recipient

    agent = Agent(name="trace_agent", model=fake_model, tools=[send_email])

    tool_call_turn = [
        get_function_tool_call(
            "send_email", '{"recipient":"alice@example.com"}', call_id="call-1"
        )
    ]
    final_turn = [get_text_message("done")]
    fake_model.add_multiple_turn_outputs([tool_call_turn, final_turn])

    paused = await Runner.run(agent, input="first_test")
    assert paused.interruptions

    state = paused.to_state()
    state.approve(paused.interruptions[0], override_arguments={"recipient": "bob@example.com"})

    explicit_config = RunConfig(trace_include_sensitive_data=False)
    finished = await Runner.run(agent, state, run_config=explicit_config)

    assert finished.final_output == "done"
    assert state._trace_include_sensitive_data is False
    span_data = _get_last_function_span_export("send_email")["span_data"]
    assert span_data["input"] is None
    assert span_data["output"] is None
411+
412+
361413
@pytest.mark.asyncio
362414
async def test_wrapped_trace_is_single_trace():
363415
model = FakeModel()
@@ -643,6 +695,49 @@ async def test_resumed_streaming_run_reuses_original_trace_without_duplicate_tra
643695
assert all(span.trace_id == traces[0].trace_id for span in fetch_ordered_spans())
644696

645697

698+
@pytest.mark.asyncio
async def test_resumed_streaming_run_honors_explicit_trace_include_sensitive_data() -> None:
    """Streaming variant: an explicit RunConfig on resume controls span payloads.

    The first streamed run pauses on an approval-gated tool call. The resumed
    streamed run passes ``trace_include_sensitive_data=False`` and the resulting
    "send_email" function span must carry neither input nor output.
    """
    fake_model = FakeModel()

    @function_tool(name_override="send_email", needs_approval=True)
    def send_email(recipient: str) -> str:
        return recipient

    agent = Agent(name="trace_agent", model=fake_model, tools=[send_email])

    tool_call_turn = [
        get_function_tool_call(
            "send_email", '{"recipient":"alice@example.com"}', call_id="call-1"
        )
    ]
    final_turn = [get_text_message("done")]
    fake_model.add_multiple_turn_outputs([tool_call_turn, final_turn])

    paused = Runner.run_streamed(agent, input="first_test")
    # Drain the stream so the run proceeds until it stops on the interruption.
    async for _event in paused.stream_events():
        pass
    assert paused.interruptions

    state = paused.to_state()
    state.approve(paused.interruptions[0], override_arguments={"recipient": "bob@example.com"})

    explicit_config = RunConfig(trace_include_sensitive_data=False)
    finished = Runner.run_streamed(agent, state, run_config=explicit_config)
    # Drain the resumed stream so the final output is available.
    async for _event in finished.stream_events():
        pass

    assert finished.final_output == "done"
    assert state._trace_include_sensitive_data is False
    span_data = _get_last_function_span_export("send_email")["span_data"]
    assert span_data["input"] is None
    assert span_data["output"] is None
739+
740+
646741
@pytest.mark.asyncio
647742
async def test_wrapped_streaming_trace_is_single_trace():
648743
model = FakeModel()

0 commit comments

Comments
 (0)