Skip to content

Commit c96de5e

Browse files
authored
feat: support Annotated[T, Field(...)] in function schema (#2435)
1 parent 195b53d commit c96de5e

File tree

2 files changed

+204
-1
lines changed

2 files changed

+204
-1
lines changed

src/agents/function_schema.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,15 @@ def _extract_description_from_metadata(metadata: tuple[Any, ...]) -> str | None:
210210
return None
211211

212212

213+
def _extract_field_info_from_metadata(metadata: tuple[Any, ...]) -> FieldInfo | None:
214+
"""Returns the first FieldInfo in Annotated metadata, or None."""
215+
216+
for item in metadata:
217+
if isinstance(item, FieldInfo):
218+
return item
219+
return None
220+
221+
213222
def function_schema(
214223
func: Callable[..., Any],
215224
docstring_style: DocstringStyle | None = None,
@@ -252,13 +261,15 @@ def function_schema(
252261
type_hints_with_extras = get_type_hints(func, include_extras=True)
253262
type_hints: dict[str, Any] = {}
254263
annotated_param_descs: dict[str, str] = {}
264+
param_metadata: dict[str, tuple[Any, ...]] = {}
255265

256266
for name, annotation in type_hints_with_extras.items():
257267
if name == "return":
258268
continue
259269

260270
stripped_ann, metadata = _strip_annotated(annotation)
261271
type_hints[name] = stripped_ann
272+
param_metadata[name] = metadata
262273

263274
description = _extract_description_from_metadata(metadata)
264275
if description is not None:
@@ -356,7 +367,20 @@ def function_schema(
356367

357368
else:
358369
# Normal parameter
359-
if default == inspect._empty:
370+
metadata = param_metadata.get(name, ())
371+
field_info_from_annotated = _extract_field_info_from_metadata(metadata)
372+
373+
if field_info_from_annotated is not None:
374+
merged = FieldInfo.merge_field_infos(
375+
field_info_from_annotated,
376+
description=field_description or field_info_from_annotated.description,
377+
)
378+
if default != inspect._empty and not isinstance(default, FieldInfo):
379+
merged = FieldInfo.merge_field_infos(merged, default=default)
380+
elif isinstance(default, FieldInfo):
381+
merged = FieldInfo.merge_field_infos(merged, default)
382+
fields[name] = (ann, merged)
383+
elif default == inspect._empty:
360384
# Required field
361385
fields[name] = (
362386
ann,

tests/test_function_schema.py

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,3 +706,182 @@ def func_with_multiple_field_constraints(
706706

707707
with pytest.raises(ValidationError): # zero factor
708708
fs.params_pydantic_model(**{"score": 50, "factor": 0.0})
709+
710+
711+
# --- Annotated + Field: same behavior as Field as default ---
712+
713+
714+
def test_function_with_annotated_field_required_constraints():
715+
"""Test function with required Annotated[int, Field(...)] parameter that has constraints."""
716+
717+
def func_with_annotated_field_constraints(
718+
my_number: Annotated[int, Field(..., gt=10, le=100)],
719+
) -> int:
720+
return my_number * 2
721+
722+
fs = function_schema(func_with_annotated_field_constraints, use_docstring_info=False)
723+
724+
# Check that the schema includes the constraints
725+
properties = fs.params_json_schema.get("properties", {})
726+
my_number_schema = properties.get("my_number", {})
727+
assert my_number_schema.get("type") == "integer"
728+
assert my_number_schema.get("exclusiveMinimum") == 10 # gt=10
729+
assert my_number_schema.get("maximum") == 100 # le=100
730+
731+
# Valid input should work
732+
valid_input = {"my_number": 50}
733+
parsed = fs.params_pydantic_model(**valid_input)
734+
args, kwargs_dict = fs.to_call_args(parsed)
735+
result = func_with_annotated_field_constraints(*args, **kwargs_dict)
736+
assert result == 100
737+
738+
# Invalid input: too small (should violate gt=10)
739+
with pytest.raises(ValidationError):
740+
fs.params_pydantic_model(**{"my_number": 5})
741+
742+
# Invalid input: too large (should violate le=100)
743+
with pytest.raises(ValidationError):
744+
fs.params_pydantic_model(**{"my_number": 150})
745+
746+
747+
def test_function_with_annotated_field_optional_with_default():
748+
"""Optional Annotated[float, Field(...)] param with default and constraints."""
749+
750+
def func_with_annotated_optional_field(
751+
required_param: str,
752+
optional_param: Annotated[float, Field(default=5.0, ge=0.0)],
753+
) -> str:
754+
return f"{required_param}: {optional_param}"
755+
756+
fs = function_schema(func_with_annotated_optional_field, use_docstring_info=False)
757+
758+
# Check that the schema includes the constraints and description
759+
properties = fs.params_json_schema.get("properties", {})
760+
optional_schema = properties.get("optional_param", {})
761+
assert optional_schema.get("type") == "number"
762+
assert optional_schema.get("minimum") == 0.0 # ge=0.0
763+
assert optional_schema.get("default") == 5.0
764+
765+
# Valid input with default
766+
valid_input = {"required_param": "test"}
767+
parsed = fs.params_pydantic_model(**valid_input)
768+
args, kwargs_dict = fs.to_call_args(parsed)
769+
result = func_with_annotated_optional_field(*args, **kwargs_dict)
770+
assert result == "test: 5.0"
771+
772+
# Valid input with explicit value
773+
valid_input2 = {"required_param": "test", "optional_param": 10.5}
774+
parsed2 = fs.params_pydantic_model(**valid_input2)
775+
args2, kwargs_dict2 = fs.to_call_args(parsed2)
776+
result2 = func_with_annotated_optional_field(*args2, **kwargs_dict2)
777+
assert result2 == "test: 10.5"
778+
779+
# Invalid input: negative value (should violate ge=0.0)
780+
with pytest.raises(ValidationError):
781+
fs.params_pydantic_model(**{"required_param": "test", "optional_param": -1.0})
782+
783+
784+
def test_function_with_annotated_field_string_constraints():
785+
"""Annotated[str, Field(...)] parameter with string constraints (min/max length, pattern)."""
786+
787+
def func_with_annotated_string_field(
788+
name: Annotated[
789+
str,
790+
Field(..., min_length=3, max_length=20, pattern=r"^[A-Za-z]+$"),
791+
],
792+
) -> str:
793+
return f"Hello, {name}!"
794+
795+
fs = function_schema(func_with_annotated_string_field, use_docstring_info=False)
796+
797+
# Check that the schema includes string constraints
798+
properties = fs.params_json_schema.get("properties", {})
799+
name_schema = properties.get("name", {})
800+
assert name_schema.get("type") == "string"
801+
assert name_schema.get("minLength") == 3
802+
assert name_schema.get("maxLength") == 20
803+
assert name_schema.get("pattern") == r"^[A-Za-z]+$"
804+
805+
# Valid input
806+
valid_input = {"name": "Alice"}
807+
parsed = fs.params_pydantic_model(**valid_input)
808+
args, kwargs_dict = fs.to_call_args(parsed)
809+
result = func_with_annotated_string_field(*args, **kwargs_dict)
810+
assert result == "Hello, Alice!"
811+
812+
# Invalid input: too short
813+
with pytest.raises(ValidationError):
814+
fs.params_pydantic_model(**{"name": "Al"})
815+
816+
# Invalid input: too long
817+
with pytest.raises(ValidationError):
818+
fs.params_pydantic_model(**{"name": "A" * 25})
819+
820+
# Invalid input: doesn't match pattern (contains numbers)
821+
with pytest.raises(ValidationError):
822+
fs.params_pydantic_model(**{"name": "Alice123"})
823+
824+
825+
def test_function_with_annotated_field_multiple_constraints():
826+
"""Test function with multiple Annotated params with Field having different constraint types."""
827+
828+
def func_with_annotated_multiple_field_constraints(
829+
score: Annotated[
830+
int,
831+
Field(..., ge=0, le=100, description="Score from 0 to 100"),
832+
],
833+
name: Annotated[str, Field(default="Unknown", min_length=1, max_length=50)],
834+
factor: Annotated[float, Field(default=1.0, gt=0.0, description="Positive multiplier")],
835+
) -> str:
836+
final_score = score * factor
837+
return f"{name} scored {final_score}"
838+
839+
fs = function_schema(func_with_annotated_multiple_field_constraints, use_docstring_info=False)
840+
841+
# Check schema structure
842+
properties = fs.params_json_schema.get("properties", {})
843+
844+
# Check score field
845+
score_schema = properties.get("score", {})
846+
assert score_schema.get("type") == "integer"
847+
assert score_schema.get("minimum") == 0
848+
assert score_schema.get("maximum") == 100
849+
assert score_schema.get("description") == "Score from 0 to 100"
850+
851+
# Check name field
852+
name_schema = properties.get("name", {})
853+
assert name_schema.get("type") == "string"
854+
assert name_schema.get("minLength") == 1
855+
assert name_schema.get("maxLength") == 50
856+
assert name_schema.get("default") == "Unknown"
857+
858+
# Check factor field
859+
factor_schema = properties.get("factor", {})
860+
assert factor_schema.get("type") == "number"
861+
assert factor_schema.get("exclusiveMinimum") == 0.0
862+
assert factor_schema.get("default") == 1.0
863+
assert factor_schema.get("description") == "Positive multiplier"
864+
865+
# Valid input with defaults
866+
valid_input = {"score": 85}
867+
parsed = fs.params_pydantic_model(**valid_input)
868+
args, kwargs_dict = fs.to_call_args(parsed)
869+
result = func_with_annotated_multiple_field_constraints(*args, **kwargs_dict)
870+
assert result == "Unknown scored 85.0"
871+
872+
# Valid input with all parameters
873+
valid_input2 = {"score": 90, "name": "Alice", "factor": 1.5}
874+
parsed2 = fs.params_pydantic_model(**valid_input2)
875+
args2, kwargs_dict2 = fs.to_call_args(parsed2)
876+
result2 = func_with_annotated_multiple_field_constraints(*args2, **kwargs_dict2)
877+
assert result2 == "Alice scored 135.0"
878+
879+
# Test various validation errors
880+
with pytest.raises(ValidationError): # score too high
881+
fs.params_pydantic_model(**{"score": 150})
882+
883+
with pytest.raises(ValidationError): # empty name
884+
fs.params_pydantic_model(**{"score": 50, "name": ""})
885+
886+
with pytest.raises(ValidationError): # zero factor
887+
fs.params_pydantic_model(**{"score": 50, "factor": 0.0})

0 commit comments

Comments
 (0)