Skip to content

Commit 9013480

Browse files
authored
fix: #2873 preserve computer driver compatibility for modifier keys (#2877)
1 parent 83b3833 commit 9013480

4 files changed

Lines changed: 371 additions & 43 deletions

File tree

examples/tools/computer_use.py

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import asyncio
66
import base64
77
import sys
8+
from collections.abc import AsyncIterator
9+
from contextlib import asynccontextmanager
810
from typing import Any, Literal
911

1012
from playwright.async_api import Browser, Page, Playwright, async_playwright
@@ -118,46 +120,77 @@ async def screenshot(self) -> str:
118120
png_bytes = await self.page.screenshot(full_page=False)
119121
return base64.b64encode(png_bytes).decode("utf-8")
120122

121-
async def click(self, x: int, y: int, button: Button = "left") -> None:
123+
def _normalize_keys(self, keys: list[str] | None) -> list[str]:
124+
if not keys:
125+
return []
126+
return [CUA_KEY_TO_PLAYWRIGHT_KEY.get(key.lower(), key) for key in keys]
127+
128+
@asynccontextmanager
129+
async def _hold_keys(self, keys: list[str] | None) -> AsyncIterator[None]:
130+
mapped_keys = self._normalize_keys(keys)
131+
try:
132+
for key in mapped_keys:
133+
await self.page.keyboard.down(key)
134+
yield
135+
finally:
136+
for key in reversed(mapped_keys):
137+
await self.page.keyboard.up(key)
138+
139+
async def click(
140+
self, x: int, y: int, button: Button = "left", *, keys: list[str] | None = None
141+
) -> None:
122142
playwright_button: Literal["left", "middle", "right"] = "left"
123143

124144
# Playwright only supports left, middle, right buttons
125145
if button in ("left", "right", "middle"):
126146
playwright_button = button # type: ignore
127147

128-
await self.page.mouse.click(x, y, button=playwright_button)
148+
async with self._hold_keys(keys):
149+
await self.page.mouse.click(x, y, button=playwright_button)
129150

130-
async def double_click(self, x: int, y: int) -> None:
131-
await self.page.mouse.dblclick(x, y)
151+
async def double_click(self, x: int, y: int, *, keys: list[str] | None = None) -> None:
152+
async with self._hold_keys(keys):
153+
await self.page.mouse.dblclick(x, y)
132154

133-
async def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
134-
await self.page.mouse.move(x, y)
135-
await self.page.evaluate(f"window.scrollBy({scroll_x}, {scroll_y})")
155+
async def scroll(
156+
self,
157+
x: int,
158+
y: int,
159+
scroll_x: int,
160+
scroll_y: int,
161+
*,
162+
keys: list[str] | None = None,
163+
) -> None:
164+
async with self._hold_keys(keys):
165+
await self.page.mouse.move(x, y)
166+
await self.page.evaluate(f"window.scrollBy({scroll_x}, {scroll_y})")
136167

137168
async def type(self, text: str) -> None:
138169
await self.page.keyboard.type(text)
139170

140171
async def wait(self) -> None:
141172
await asyncio.sleep(1)
142173

143-
async def move(self, x: int, y: int) -> None:
144-
await self.page.mouse.move(x, y)
174+
async def move(self, x: int, y: int, *, keys: list[str] | None = None) -> None:
175+
async with self._hold_keys(keys):
176+
await self.page.mouse.move(x, y)
145177

146178
async def keypress(self, keys: list[str]) -> None:
147-
mapped_keys = [CUA_KEY_TO_PLAYWRIGHT_KEY.get(key.lower(), key) for key in keys]
179+
mapped_keys = self._normalize_keys(keys)
148180
for key in mapped_keys:
149181
await self.page.keyboard.down(key)
150182
for key in reversed(mapped_keys):
151183
await self.page.keyboard.up(key)
152184

153-
async def drag(self, path: list[tuple[int, int]]) -> None:
185+
async def drag(self, path: list[tuple[int, int]], *, keys: list[str] | None = None) -> None:
154186
if not path:
155187
return
156-
await self.page.mouse.move(path[0][0], path[0][1])
157-
await self.page.mouse.down()
158-
for px, py in path[1:]:
159-
await self.page.mouse.move(px, py)
160-
await self.page.mouse.up()
188+
async with self._hold_keys(keys):
189+
await self.page.mouse.move(path[0][0], path[0][1])
190+
await self.page.mouse.down()
191+
for px, py in path[1:]:
192+
await self.page.mouse.move(px, py)
193+
await self.page.mouse.up()
161194

162195

163196
async def run_agent(

src/agents/computer.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,12 @@
66

77

88
class Computer(abc.ABC):
9-
"""A computer implemented with sync operations. The Computer interface abstracts the
10-
operations needed to control a computer or browser."""
9+
"""A computer implemented with sync operations.
10+
11+
Subclasses provide the local runtime behind `ComputerTool`. Mouse action methods may
12+
also accept a keyword-only `keys` argument to receive held modifier keys when the
13+
driver supports them.
14+
"""
1115

1216
@property
1317
def environment(self) -> Environment | None:
@@ -21,44 +25,57 @@ def dimensions(self) -> tuple[int, int] | None:
2125

2226
@abc.abstractmethod
2327
def screenshot(self) -> str:
28+
"""Return a base64-encoded PNG screenshot of the current display."""
2429
pass
2530

2631
@abc.abstractmethod
2732
def click(self, x: int, y: int, button: Button) -> None:
33+
"""Click `button` at the given `(x, y)` screen coordinates."""
2834
pass
2935

3036
@abc.abstractmethod
3137
def double_click(self, x: int, y: int) -> None:
38+
"""Double-click at the given `(x, y)` screen coordinates."""
3239
pass
3340

3441
@abc.abstractmethod
3542
def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
43+
"""Scroll at `(x, y)` by `(scroll_x, scroll_y)` units."""
3644
pass
3745

3846
@abc.abstractmethod
3947
def type(self, text: str) -> None:
48+
"""Type `text` into the currently focused target."""
4049
pass
4150

4251
@abc.abstractmethod
4352
def wait(self) -> None:
53+
"""Wait until the computer is ready for the next action."""
4454
pass
4555

4656
@abc.abstractmethod
4757
def move(self, x: int, y: int) -> None:
58+
"""Move the mouse cursor to the given `(x, y)` screen coordinates."""
4859
pass
4960

5061
@abc.abstractmethod
5162
def keypress(self, keys: list[str]) -> None:
63+
"""Press the provided keys, such as `["ctrl", "c"]`."""
5264
pass
5365

5466
@abc.abstractmethod
5567
def drag(self, path: list[tuple[int, int]]) -> None:
68+
"""Click-and-drag the mouse along the given sequence of `(x, y)` waypoints."""
5669
pass
5770

5871

5972
class AsyncComputer(abc.ABC):
60-
"""A computer implemented with async operations. The Computer interface abstracts the
61-
operations needed to control a computer or browser."""
73+
"""A computer implemented with async operations.
74+
75+
Subclasses provide the local runtime behind `ComputerTool`. Mouse action methods may
76+
also accept a keyword-only `keys` argument to receive held modifier keys when the
77+
driver supports them.
78+
"""
6279

6380
@property
6481
def environment(self) -> Environment | None:
@@ -72,36 +89,45 @@ def dimensions(self) -> tuple[int, int] | None:
7289

7390
@abc.abstractmethod
7491
async def screenshot(self) -> str:
92+
"""Return a base64-encoded PNG screenshot of the current display."""
7593
pass
7694

7795
@abc.abstractmethod
7896
async def click(self, x: int, y: int, button: Button) -> None:
97+
"""Click `button` at the given `(x, y)` screen coordinates."""
7998
pass
8099

81100
@abc.abstractmethod
82101
async def double_click(self, x: int, y: int) -> None:
102+
"""Double-click at the given `(x, y)` screen coordinates."""
83103
pass
84104

85105
@abc.abstractmethod
86106
async def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
107+
"""Scroll at `(x, y)` by `(scroll_x, scroll_y)` units."""
87108
pass
88109

89110
@abc.abstractmethod
90111
async def type(self, text: str) -> None:
112+
"""Type `text` into the currently focused target."""
91113
pass
92114

93115
@abc.abstractmethod
94116
async def wait(self) -> None:
117+
"""Wait until the computer is ready for the next action."""
95118
pass
96119

97120
@abc.abstractmethod
98121
async def move(self, x: int, y: int) -> None:
122+
"""Move the mouse cursor to the given `(x, y)` screen coordinates."""
99123
pass
100124

101125
@abc.abstractmethod
102126
async def keypress(self, keys: list[str]) -> None:
127+
"""Press the provided keys, such as `["ctrl", "c"]`."""
103128
pass
104129

105130
@abc.abstractmethod
106131
async def drag(self, path: list[tuple[int, int]]) -> None:
132+
"""Click-and-drag the mouse along the given sequence of `(x, y)` waypoints."""
107133
pass

src/agents/run_internal/tool_actions.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,30 +189,38 @@ async def _execute_action_and_capture(
189189
) -> str:
190190
"""Execute computer actions (sync or async drivers) and return the final screenshot."""
191191

192-
async def maybe_call(method_name: str, *args: Any) -> Any:
192+
async def maybe_call(method_name: str, *args: Any, **kwargs: Any) -> Any:
193193
method = getattr(computer, method_name, None)
194194
if method is None or not callable(method):
195195
raise ModelBehaviorError(f"Computer driver missing method {method_name}")
196-
result = method(*args)
196+
filtered_kwargs = cls._filter_supported_kwargs(
197+
method_name=method_name,
198+
method=method,
199+
kwargs=kwargs,
200+
)
201+
result = method(*args, **filtered_kwargs)
197202
return await result if inspect.isawaitable(result) else result
198203

199204
last_action_was_screenshot = False
200205
last_screenshot_result: Any = None
201206
for action in cls._iter_actions(tool_call):
202207
action_type = get_mapping_or_attr(action, "type")
208+
action_keys = cls._normalize_modifier_keys(get_mapping_or_attr(action, "keys"))
203209
last_action_was_screenshot = False
204210
if action_type == "click":
205211
await maybe_call(
206212
"click",
207213
get_mapping_or_attr(action, "x"),
208214
get_mapping_or_attr(action, "y"),
209215
get_mapping_or_attr(action, "button"),
216+
keys=action_keys,
210217
)
211218
elif action_type == "double_click":
212219
await maybe_call(
213220
"double_click",
214221
get_mapping_or_attr(action, "x"),
215222
get_mapping_or_attr(action, "y"),
223+
keys=action_keys,
216224
)
217225
elif action_type == "drag":
218226
path = get_mapping_or_attr(action, "path") or []
@@ -225,6 +233,7 @@ async def maybe_call(method_name: str, *args: Any) -> Any:
225233
)
226234
for point in path
227235
],
236+
keys=action_keys,
228237
)
229238
elif action_type == "keypress":
230239
await maybe_call("keypress", get_mapping_or_attr(action, "keys"))
@@ -233,6 +242,7 @@ async def maybe_call(method_name: str, *args: Any) -> Any:
233242
"move",
234243
get_mapping_or_attr(action, "x"),
235244
get_mapping_or_attr(action, "y"),
245+
keys=action_keys,
236246
)
237247
elif action_type == "screenshot":
238248
last_screenshot_result = await maybe_call("screenshot")
@@ -244,6 +254,7 @@ async def maybe_call(method_name: str, *args: Any) -> Any:
244254
get_mapping_or_attr(action, "y"),
245255
get_mapping_or_attr(action, "scroll_x"),
246256
get_mapping_or_attr(action, "scroll_y"),
257+
keys=action_keys,
247258
)
248259
elif action_type == "type":
249260
await maybe_call("type", get_mapping_or_attr(action, "text"))
@@ -289,6 +300,64 @@ def _serialize_action_payload(action: Any) -> Any:
289300
return dataclasses.asdict(action)
290301
return action
291302

303+
@staticmethod
304+
def _normalize_modifier_keys(keys: Any) -> list[str] | None:
305+
if not keys:
306+
return None
307+
return cast(list[str], keys)
308+
309+
@classmethod
310+
def _filter_supported_kwargs(
311+
cls,
312+
*,
313+
method_name: str,
314+
method: Any,
315+
kwargs: dict[str, Any],
316+
) -> dict[str, Any]:
317+
filtered_kwargs = {key: value for key, value in kwargs.items() if value is not None}
318+
if not filtered_kwargs:
319+
return {}
320+
321+
supported_kwargs = cls._supported_keyword_arguments(method)
322+
unsupported_kwargs = [
323+
key
324+
for key in filtered_kwargs
325+
if key not in supported_kwargs and None not in supported_kwargs
326+
]
327+
if unsupported_kwargs:
328+
logger.warning(
329+
"Computer driver method %r does not accept keyword argument(s) %s; "
330+
"dropping them and continuing.",
331+
method_name,
332+
", ".join(sorted(unsupported_kwargs)),
333+
)
334+
for key in unsupported_kwargs:
335+
filtered_kwargs.pop(key, None)
336+
337+
return filtered_kwargs
338+
339+
@staticmethod
340+
def _supported_keyword_arguments(method: Any) -> set[str | None]:
341+
try:
342+
signature = inspect.signature(method)
343+
except (TypeError, ValueError):
344+
return set()
345+
supported: set[str | None] = {
346+
parameter.name
347+
for parameter in signature.parameters.values()
348+
if parameter.kind
349+
in {
350+
inspect.Parameter.KEYWORD_ONLY,
351+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
352+
}
353+
}
354+
if any(
355+
parameter.kind == inspect.Parameter.VAR_KEYWORD
356+
for parameter in signature.parameters.values()
357+
):
358+
supported.add(None)
359+
return supported
360+
292361

293362
class LocalShellAction:
294363
"""Execute local shell commands via the LocalShellTool with lifecycle hooks."""

0 commit comments

Comments
 (0)