Skip to content

Commit 1149fe4

Browse files
committed
Flatten everything under the arguments validator
1 parent 5969925 commit 1149fe4

File tree

5 files changed

+152
-146
lines changed

5 files changed

+152
-146
lines changed

python/pydantic_core/core_schema.py

Lines changed: 8 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from collections.abc import Mapping
1111
from datetime import date, datetime, time, timedelta
1212
from decimal import Decimal
13-
from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, List, Pattern, Set, Tuple, Type, Union, overload
13+
from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, List, Pattern, Set, Tuple, Type, Union
1414

1515
from typing_extensions import deprecated
1616

@@ -3372,53 +3372,15 @@ def arguments_parameter(
33723372
return _dict_not_none(name=name, schema=schema, mode=mode, alias=alias)
33733373

33743374

3375-
class VarKwargsSchema(TypedDict):
3376-
type: Literal['var_kwargs']
3377-
mode: Literal['single', 'typed_dict']
3378-
schema: CoreSchema
3379-
3380-
3381-
@overload
3382-
def var_kwargs_schema(
3383-
*,
3384-
mode: Literal['single'],
3385-
schema: CoreSchema,
3386-
) -> VarKwargsSchema: ...
3387-
3388-
3389-
@overload
3390-
def var_kwargs_schema(
3391-
*,
3392-
mode: Literal['typed_dict'],
3393-
schema: TypedDictSchema,
3394-
) -> VarKwargsSchema: ...
3395-
3396-
3397-
def var_kwargs_schema(
3398-
*,
3399-
mode: Literal['single', 'typed_dict'],
3400-
schema: CoreSchema,
3401-
) -> VarKwargsSchema:
3402-
"""Returns a schema describing the variadic keyword arguments of a callable.
3403-
3404-
Args:
3405-
mode: The validation mode to use. If `'single'`, every value of the keyword arguments will
3406-
be validated against the core schema from the `schema` argument. If `'typed_dict'`, the
3407-
`schema` argument must be a [`typed_dict_schema`][pydantic_core.core_schema.typed_dict_schema].
3408-
"""
3409-
3410-
return _dict_not_none(
3411-
type='var_kwargs',
3412-
mode=mode,
3413-
schema=schema,
3414-
)
3375+
VarKwargsMode: TypeAlias = Literal['single', 'unpacked-typed-dict']
34153376

34163377

34173378
class ArgumentsSchema(TypedDict, total=False):
34183379
type: Required[Literal['arguments']]
34193380
arguments_schema: Required[List[ArgumentsParameter]]
34203381
populate_by_name: bool
34213382
var_args_schema: CoreSchema
3383+
var_kwargs_mode: VarKwargsMode
34223384
var_kwargs_schema: CoreSchema
34233385
ref: str
34243386
metadata: Dict[str, Any]
@@ -3430,6 +3392,7 @@ def arguments_schema(
34303392
*,
34313393
populate_by_name: bool | None = None,
34323394
var_args_schema: CoreSchema | None = None,
3395+
var_kwargs_mode: VarKwargsMode | None = None,
34333396
var_kwargs_schema: CoreSchema | None = None,
34343397
ref: str | None = None,
34353398
metadata: Dict[str, Any] | None = None,
@@ -3456,6 +3419,9 @@ def arguments_schema(
34563419
arguments: The arguments to use for the arguments schema
34573420
populate_by_name: Whether to populate by name
34583421
var_args_schema: The variable args schema to use for the arguments schema
3422+
var_kwargs_mode: The validation mode to use for variadic keyword arguments. If `'single'`, every value of the
3423+
keyword arguments will be validated against the `var_kwargs_schema` schema. If `'unpacked-typed-dict'`,
3424+
the `schema` argument must be a [`typed_dict_schema`][pydantic_core.core_schema.typed_dict_schema]
34593425
var_kwargs_schema: The variable kwargs schema to use for the arguments schema
34603426
ref: optional unique identifier of the schema, used to reference the schema in other places
34613427
metadata: Any other information you want to include with the schema, not used by pydantic-core
@@ -3466,6 +3432,7 @@ def arguments_schema(
34663432
arguments_schema=arguments,
34673433
populate_by_name=populate_by_name,
34683434
var_args_schema=var_args_schema,
3435+
var_kwargs_mode=var_kwargs_mode,
34693436
var_kwargs_schema=var_kwargs_schema,
34703437
ref=ref,
34713438
metadata=metadata,

src/validators/arguments.rs

Lines changed: 88 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::str::FromStr;
2+
13
use pyo3::intern;
24
use pyo3::prelude::*;
35
use pyo3::types::{PyDict, PyList, PyString, PyTuple};
@@ -15,6 +17,27 @@ use crate::tools::SchemaDict;
1517
use super::validation_state::ValidationState;
1618
use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator};
1719

20+
#[derive(Debug, PartialEq)]
21+
enum VarKwargsMode {
22+
Single,
23+
UnpackedTypedDict,
24+
}
25+
26+
impl FromStr for VarKwargsMode {
27+
type Err = PyErr;
28+
29+
fn from_str(s: &str) -> Result<Self, Self::Err> {
30+
match s {
31+
"single" => Ok(Self::Single),
32+
"unpacked-typed-dict" => Ok(Self::UnpackedTypedDict),
33+
s => py_schema_err!(
34+
"Invalid var_kwargs mode: `{}`, expected `single` or `unpacked-typed-dict`",
35+
s
36+
),
37+
}
38+
}
39+
}
40+
1841
#[derive(Debug)]
1942
struct Parameter {
2043
positional: bool,
@@ -29,6 +52,7 @@ pub struct ArgumentsValidator {
2952
parameters: Vec<Parameter>,
3053
positional_params_count: usize,
3154
var_args_validator: Option<Box<CombinedValidator>>,
55+
var_kwargs_mode: VarKwargsMode,
3256
var_kwargs_validator: Option<Box<CombinedValidator>>,
3357
loc_by_alias: bool,
3458
extra: ExtraBehavior,
@@ -117,17 +141,31 @@ impl BuildValidator for ArgumentsValidator {
117141
});
118142
}
119143

144+
let py_var_kwargs_mode: Bound<PyString> = match schema.get_as(intern!(py, "var_kwargs_mode"))? {
145+
Some(v) => v,
146+
None => PyString::new_bound(py, "single"),
147+
};
148+
let var_kwargs_mode = VarKwargsMode::from_str(py_var_kwargs_mode.to_string().as_str())?;
149+
let var_kwargs_validator = match schema.get_item(intern!(py, "var_kwargs_schema"))? {
150+
Some(v) => Some(Box::new(build_validator(&v, config, definitions)?)),
151+
None => None,
152+
};
153+
154+
if var_kwargs_mode == VarKwargsMode::UnpackedTypedDict && var_kwargs_validator.is_none() {
155+
return py_schema_err!(
156+
"`var_kwargs_schema` must be specified when `var_kwargs_mode` is `'unpacked-typed-dict'`"
157+
);
158+
}
159+
120160
Ok(Self {
121161
parameters,
122162
positional_params_count,
123163
var_args_validator: match schema.get_item(intern!(py, "var_args_schema"))? {
124164
Some(v) => Some(Box::new(build_validator(&v, config, definitions)?)),
125165
None => None,
126166
},
127-
var_kwargs_validator: match schema.get_item(intern!(py, "var_kwargs_schema"))? {
128-
Some(v) => Some(Box::new(build_validator(&v, config, definitions)?)),
129-
None => None,
130-
},
167+
var_kwargs_mode,
168+
var_kwargs_validator,
131169
loc_by_alias: config.get_as(intern!(py, "loc_by_alias"))?.unwrap_or(true),
132170
extra: ExtraBehavior::from_schema_or_config(py, schema, config, ExtraBehavior::Forbid)?,
133171
}
@@ -258,6 +296,8 @@ impl Validator for ArgumentsValidator {
258296
// if there are kwargs check any that haven't been processed yet
259297
if let Some(kwargs) = args.kwargs() {
260298
if kwargs.len() > used_kwargs.len() {
299+
let remaining_kwargs = PyDict::new_bound(py);
300+
261301
for result in kwargs.iter() {
262302
let (raw_key, value) = result?;
263303
let either_str = match raw_key
@@ -278,28 +318,55 @@ impl Validator for ArgumentsValidator {
278318
Err(err) => return Err(err),
279319
};
280320
if !used_kwargs.contains(either_str.as_cow()?.as_ref()) {
281-
match self.var_kwargs_validator {
282-
Some(ref validator) => match validator.validate(py, value.borrow_input(), state) {
283-
Ok(value) => {
284-
output_kwargs.set_item(either_str.as_py_string(py, state.cache_str()), value)?;
285-
}
286-
Err(ValError::LineErrors(line_errors)) => {
287-
for err in line_errors {
288-
errors.push(err.with_outer_location(raw_key.clone()));
321+
match self.var_kwargs_mode {
322+
VarKwargsMode::Single => match self.var_kwargs_validator {
323+
Some(ref validator) => match validator.validate(py, value.borrow_input(), state) {
324+
Ok(value) => {
325+
output_kwargs
326+
.set_item(either_str.as_py_string(py, state.cache_str()), value)?;
327+
}
328+
Err(ValError::LineErrors(line_errors)) => {
329+
for err in line_errors {
330+
errors.push(err.with_outer_location(raw_key.clone()));
331+
}
332+
}
333+
Err(err) => return Err(err),
334+
},
335+
None => {
336+
if let ExtraBehavior::Forbid = self.extra {
337+
errors.push(ValLineError::new_with_loc(
338+
ErrorTypeDefaults::UnexpectedKeywordArgument,
339+
value,
340+
raw_key.clone(),
341+
));
289342
}
290343
}
291-
Err(err) => return Err(err),
292344
},
293-
None => {
294-
if let ExtraBehavior::Forbid = self.extra {
295-
errors.push(ValLineError::new_with_loc(
296-
ErrorTypeDefaults::UnexpectedKeywordArgument,
297-
value,
298-
raw_key.clone(),
299-
));
300-
}
345+
VarKwargsMode::UnpackedTypedDict => {
346+
// Save to the remaining kwargs, we will validate as a single dict:
347+
remaining_kwargs.set_item(either_str.as_py_string(py, state.cache_str()), value)?;
348+
}
349+
}
350+
}
351+
}
352+
353+
if self.var_kwargs_mode == VarKwargsMode::UnpackedTypedDict {
354+
// `var_kwargs_validator` is guaranteed to be `Some`:
355+
match self
356+
.var_kwargs_validator
357+
.as_ref()
358+
.unwrap()
359+
.validate(py, remaining_kwargs.as_any(), state)
360+
{
361+
Ok(value) => {
362+
output_kwargs.update(value.downcast_bound::<PyDict>(py).unwrap().as_mapping())?;
363+
}
364+
Err(ValError::LineErrors(line_errors)) => {
365+
for error in line_errors {
366+
errors.push(error);
301367
}
302368
}
369+
Err(err) => return Err(err),
303370
}
304371
}
305372
}

src/validators/mod.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ mod union;
6161
mod url;
6262
mod uuid;
6363
mod validation_state;
64-
mod var_kwargs;
6564
mod with_default;
6665

6766
pub use self::validation_state::{Exactness, ValidationState};
@@ -562,7 +561,6 @@ pub fn build_validator(
562561
callable::CallableValidator,
563562
// arguments
564563
arguments::ArgumentsValidator,
565-
var_kwargs::VarKwargsValidator,
566564
// default value
567565
with_default::WithDefaultValidator,
568566
// chain validators
@@ -718,7 +716,6 @@ pub enum CombinedValidator {
718716
Callable(callable::CallableValidator),
719717
// arguments
720718
Arguments(arguments::ArgumentsValidator),
721-
VarKwargs(var_kwargs::VarKwargsValidator),
722719
// default value
723720
WithDefault(with_default::WithDefaultValidator),
724721
// chain validators

src/validators/var_kwargs.rs

Lines changed: 0 additions & 80 deletions
This file was deleted.

0 commit comments

Comments
 (0)