Skip to content

Commit f3f436e

Browse files
Viicossydney-runkledavidhewitt
authored
Add support for unpacked TypedDict to type hint variadic keyword arguments in ArgumentsValidator (#1451)
Co-authored-by: Sydney Runkle <[email protected]> Co-authored-by: David Hewitt <[email protected]>
1 parent 8c1a0da commit f3f436e

File tree

3 files changed

+152
-22
lines changed

3 files changed

+152
-22
lines changed

python/pydantic_core/core_schema.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3372,11 +3372,15 @@ def arguments_parameter(
33723372
return _dict_not_none(name=name, schema=schema, mode=mode, alias=alias)
33733373

33743374

3375+
VarKwargsMode: TypeAlias = Literal['uniform', 'unpacked-typed-dict']
3376+
3377+
33753378
class ArgumentsSchema(TypedDict, total=False):
33763379
type: Required[Literal['arguments']]
33773380
arguments_schema: Required[List[ArgumentsParameter]]
33783381
populate_by_name: bool
33793382
var_args_schema: CoreSchema
3383+
var_kwargs_mode: VarKwargsMode
33803384
var_kwargs_schema: CoreSchema
33813385
ref: str
33823386
metadata: Dict[str, Any]
@@ -3388,6 +3392,7 @@ def arguments_schema(
33883392
*,
33893393
populate_by_name: bool | None = None,
33903394
var_args_schema: CoreSchema | None = None,
3395+
var_kwargs_mode: VarKwargsMode | None = None,
33913396
var_kwargs_schema: CoreSchema | None = None,
33923397
ref: str | None = None,
33933398
metadata: Dict[str, Any] | None = None,
@@ -3414,6 +3419,9 @@ def arguments_schema(
34143419
arguments: The arguments to use for the arguments schema
34153420
populate_by_name: Whether to populate by name
34163421
var_args_schema: The variable args schema to use for the arguments schema
3422+
var_kwargs_mode: The validation mode to use for variadic keyword arguments. If `'uniform'`, every value of the
3423+
keyword arguments will be validated against the `var_kwargs_schema` schema. If `'unpacked-typed-dict'`,
3424+
the `var_kwargs_schema` argument must be a [`typed_dict_schema`][pydantic_core.core_schema.typed_dict_schema]
34173425
var_kwargs_schema: The variable kwargs schema to use for the arguments schema
34183426
ref: optional unique identifier of the schema, used to reference the schema in other places
34193427
metadata: Any other information you want to include with the schema, not used by pydantic-core
@@ -3424,6 +3432,7 @@ def arguments_schema(
34243432
arguments_schema=arguments,
34253433
populate_by_name=populate_by_name,
34263434
var_args_schema=var_args_schema,
3435+
var_kwargs_mode=var_kwargs_mode,
34273436
var_kwargs_schema=var_kwargs_schema,
34283437
ref=ref,
34293438
metadata=metadata,

src/validators/arguments.rs

Lines changed: 87 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::str::FromStr;
2+
13
use pyo3::intern;
24
use pyo3::prelude::*;
35
use pyo3::types::{PyDict, PyList, PyString, PyTuple};
@@ -15,6 +17,27 @@ use crate::tools::SchemaDict;
1517
use super::validation_state::ValidationState;
1618
use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator};
1719

20+
#[derive(Debug, PartialEq)]
21+
enum VarKwargsMode {
22+
Uniform,
23+
UnpackedTypedDict,
24+
}
25+
26+
impl FromStr for VarKwargsMode {
27+
type Err = PyErr;
28+
29+
fn from_str(s: &str) -> Result<Self, Self::Err> {
30+
match s {
31+
"uniform" => Ok(Self::Uniform),
32+
"unpacked-typed-dict" => Ok(Self::UnpackedTypedDict),
33+
s => py_schema_err!(
34+
"Invalid var_kwargs mode: `{}`, expected `uniform` or `unpacked-typed-dict`",
35+
s
36+
),
37+
}
38+
}
39+
}
40+
1841
#[derive(Debug)]
1942
struct Parameter {
2043
positional: bool,
@@ -29,6 +52,7 @@ pub struct ArgumentsValidator {
2952
parameters: Vec<Parameter>,
3053
positional_params_count: usize,
3154
var_args_validator: Option<Box<CombinedValidator>>,
55+
var_kwargs_mode: VarKwargsMode,
3256
var_kwargs_validator: Option<Box<CombinedValidator>>,
3357
loc_by_alias: bool,
3458
extra: ExtraBehavior,
@@ -117,17 +141,31 @@ impl BuildValidator for ArgumentsValidator {
117141
});
118142
}
119143

144+
let py_var_kwargs_mode: Bound<PyString> = schema
145+
.get_as(intern!(py, "var_kwargs_mode"))?
146+
.unwrap_or_else(|| PyString::new_bound(py, "uniform"));
147+
148+
let var_kwargs_mode = VarKwargsMode::from_str(py_var_kwargs_mode.to_str()?)?;
149+
let var_kwargs_validator = match schema.get_item(intern!(py, "var_kwargs_schema"))? {
150+
Some(v) => Some(Box::new(build_validator(&v, config, definitions)?)),
151+
None => None,
152+
};
153+
154+
if var_kwargs_mode == VarKwargsMode::UnpackedTypedDict && var_kwargs_validator.is_none() {
155+
return py_schema_err!(
156+
"`var_kwargs_schema` must be specified when `var_kwargs_mode` is `'unpacked-typed-dict'`"
157+
);
158+
}
159+
120160
Ok(Self {
121161
parameters,
122162
positional_params_count,
123163
var_args_validator: match schema.get_item(intern!(py, "var_args_schema"))? {
124164
Some(v) => Some(Box::new(build_validator(&v, config, definitions)?)),
125165
None => None,
126166
},
127-
var_kwargs_validator: match schema.get_item(intern!(py, "var_kwargs_schema"))? {
128-
Some(v) => Some(Box::new(build_validator(&v, config, definitions)?)),
129-
None => None,
130-
},
167+
var_kwargs_mode,
168+
var_kwargs_validator,
131169
loc_by_alias: config.get_as(intern!(py, "loc_by_alias"))?.unwrap_or(true),
132170
extra: ExtraBehavior::from_schema_or_config(py, schema, config, ExtraBehavior::Forbid)?,
133171
}
@@ -255,6 +293,9 @@ impl Validator for ArgumentsValidator {
255293
}
256294
}
257295
}
296+
297+
let remaining_kwargs = PyDict::new_bound(py);
298+
258299
// if there are kwargs check any that haven't been processed yet
259300
if let Some(kwargs) = args.kwargs() {
260301
if kwargs.len() > used_kwargs.len() {
@@ -278,33 +319,58 @@ impl Validator for ArgumentsValidator {
278319
Err(err) => return Err(err),
279320
};
280321
if !used_kwargs.contains(either_str.as_cow()?.as_ref()) {
281-
match self.var_kwargs_validator {
282-
Some(ref validator) => match validator.validate(py, value.borrow_input(), state) {
283-
Ok(value) => {
284-
output_kwargs.set_item(either_str.as_py_string(py, state.cache_str()), value)?;
285-
}
286-
Err(ValError::LineErrors(line_errors)) => {
287-
for err in line_errors {
288-
errors.push(err.with_outer_location(raw_key.clone()));
322+
match self.var_kwargs_mode {
323+
VarKwargsMode::Uniform => match &self.var_kwargs_validator {
324+
Some(validator) => match validator.validate(py, value.borrow_input(), state) {
325+
Ok(value) => {
326+
output_kwargs
327+
.set_item(either_str.as_py_string(py, state.cache_str()), value)?;
328+
}
329+
Err(ValError::LineErrors(line_errors)) => {
330+
for err in line_errors {
331+
errors.push(err.with_outer_location(raw_key.clone()));
332+
}
333+
}
334+
Err(err) => return Err(err),
335+
},
336+
None => {
337+
if let ExtraBehavior::Forbid = self.extra {
338+
errors.push(ValLineError::new_with_loc(
339+
ErrorTypeDefaults::UnexpectedKeywordArgument,
340+
value,
341+
raw_key.clone(),
342+
));
289343
}
290344
}
291-
Err(err) => return Err(err),
292345
},
293-
None => {
294-
if let ExtraBehavior::Forbid = self.extra {
295-
errors.push(ValLineError::new_with_loc(
296-
ErrorTypeDefaults::UnexpectedKeywordArgument,
297-
value,
298-
raw_key.clone(),
299-
));
300-
}
346+
VarKwargsMode::UnpackedTypedDict => {
347+
// Save to the remaining kwargs, we will validate as a single dict:
348+
remaining_kwargs.set_item(either_str.as_py_string(py, state.cache_str()), value)?;
301349
}
302350
}
303351
}
304352
}
305353
}
306354
}
307355

356+
if self.var_kwargs_mode == VarKwargsMode::UnpackedTypedDict {
357+
// `var_kwargs_validator` is guaranteed to be `Some`:
358+
match self
359+
.var_kwargs_validator
360+
.as_ref()
361+
.unwrap()
362+
.validate(py, remaining_kwargs.as_any(), state)
363+
{
364+
Ok(value) => {
365+
output_kwargs.update(value.downcast_bound::<PyDict>(py).unwrap().as_mapping())?;
366+
}
367+
Err(ValError::LineErrors(line_errors)) => {
368+
errors.extend(line_errors);
369+
}
370+
Err(err) => return Err(err),
371+
}
372+
}
373+
308374
if !errors.is_empty() {
309375
Err(ValError::LineErrors(errors))
310376
} else {

tests/validators/test_arguments.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,19 @@ def test_build_non_default_follows():
769769
)
770770

771771

772+
def test_build_missing_var_kwargs():
773+
with pytest.raises(
774+
SchemaError, match="`var_kwargs_schema` must be specified when `var_kwargs_mode` is `'unpacked-typed-dict'`"
775+
):
776+
SchemaValidator(
777+
{
778+
'type': 'arguments',
779+
'arguments_schema': [],
780+
'var_kwargs_mode': 'unpacked-typed-dict',
781+
}
782+
)
783+
784+
772785
@pytest.mark.parametrize(
773786
'input_value,expected',
774787
[
@@ -778,7 +791,7 @@ def test_build_non_default_follows():
778791
],
779792
ids=repr,
780793
)
781-
def test_kwargs(py_and_json: PyAndJson, input_value, expected):
794+
def test_kwargs_uniform(py_and_json: PyAndJson, input_value, expected):
782795
v = py_and_json(
783796
{
784797
'type': 'arguments',
@@ -796,6 +809,48 @@ def test_kwargs(py_and_json: PyAndJson, input_value, expected):
796809
assert v.validate_test(input_value) == expected
797810

798811

812+
@pytest.mark.parametrize(
813+
'input_value,expected',
814+
[
815+
[ArgsKwargs((), {'x': 1}), ((), {'x': 1})],
816+
[ArgsKwargs((), {'x': 1.0}), Err('x\n Input should be a valid integer [type=int_type,')],
817+
[ArgsKwargs((), {'x': 1, 'z': 'str'}), ((), {'x': 1, 'y': 'str'})],
818+
[ArgsKwargs((), {'x': 1, 'y': 'str'}), Err('y\n Extra inputs are not permitted [type=extra_forbidden,')],
819+
],
820+
)
821+
def test_kwargs_typed_dict(py_and_json: PyAndJson, input_value, expected):
822+
v = py_and_json(
823+
{
824+
'type': 'arguments',
825+
'arguments_schema': [],
826+
'var_kwargs_mode': 'unpacked-typed-dict',
827+
'var_kwargs_schema': {
828+
'type': 'typed-dict',
829+
'fields': {
830+
'x': {
831+
'type': 'typed-dict-field',
832+
'schema': {'type': 'int', 'strict': True},
833+
'required': True,
834+
},
835+
'y': {
836+
'type': 'typed-dict-field',
837+
'schema': {'type': 'str'},
838+
'required': False,
839+
'validation_alias': 'z',
840+
},
841+
},
842+
'config': {'extra_fields_behavior': 'forbid'},
843+
},
844+
}
845+
)
846+
847+
if isinstance(expected, Err):
848+
with pytest.raises(ValidationError, match=re.escape(expected.message)):
849+
v.validate_test(input_value)
850+
else:
851+
assert v.validate_test(input_value) == expected
852+
853+
799854
@pytest.mark.parametrize(
800855
'input_value,expected',
801856
[

0 commit comments

Comments
 (0)