Skip to content

Commit ec3d870

Browse files
committed
support pass through on functions
1 parent 51b6ee1 commit ec3d870

File tree

7 files changed

+305
-96
lines changed

7 files changed

+305
-96
lines changed

pydantic_core/_pydantic_core.pyi

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,26 +38,26 @@ class SchemaValidator:
3838
def title(self) -> str: ...
3939
def __init__(self, schema: CoreSchema, config: 'CoreConfig | None' = None) -> None: ...
4040
def validate_python(
41-
self, input: Any, *, strict: 'bool | None' = None, context: Any = None, in_init: bool = False
41+
self, input: Any, *, strict: 'bool | None' = None, context: Any = None, init_mode: bool = False
4242
) -> Any: ...
4343
def isinstance_python(
44-
self, input: Any, *, strict: 'bool | None' = None, context: Any = None, in_init: bool = False
44+
self, input: Any, *, strict: 'bool | None' = None, context: Any = None, init_mode: bool = False
4545
) -> bool: ...
4646
def validate_json(
4747
self,
4848
input: 'str | bytes | bytearray',
4949
*,
5050
strict: 'bool | None' = None,
5151
context: Any = None,
52-
in_init: bool = False,
52+
init_mode: bool = False,
5353
) -> Any: ...
5454
def isinstance_json(
5555
self,
5656
input: 'str | bytes | bytearray',
5757
*,
5858
strict: 'bool | None' = None,
5959
context: Any = None,
60-
in_init: bool = False,
60+
init_mode: bool = False,
6161
) -> bool: ...
6262
def validate_assignment(
6363
self, field: str, input: Any, data: 'dict[str, Any]', strict: 'bool | None' = None, context: Any = None

src/validators/function.rs

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,24 @@ impl Validator for FunctionBeforeValidator {
110110
.validate(py, value.into_ref(py), extra, slots, recursion_guard)
111111
}
112112

113+
fn validate_init<'s, 'data>(
114+
&'s self,
115+
py: Python<'data>,
116+
input: &'data impl Input<'data>,
117+
extra: &Extra,
118+
slots: &'data [CombinedValidator],
119+
recursion_guard: &'s mut RecursionGuard,
120+
) -> ValResult<'data, PyObject> {
121+
let info = ValidationInfo::new(py, extra, &self.config);
122+
let value = self
123+
.func
124+
.call1(py, (input.to_object(py), info))
125+
.map_err(|e| convert_err(py, e, input))?;
126+
127+
self.validator
128+
.validate_init(py, value.into_ref(py), extra, slots, recursion_guard)
129+
}
130+
113131
fn get_name(&self) -> &str {
114132
&self.name
115133
}
@@ -148,6 +166,19 @@ impl Validator for FunctionAfterValidator {
148166
self.func.call1(py, (v, info)).map_err(|e| convert_err(py, e, input))
149167
}
150168

169+
fn validate_init<'s, 'data>(
170+
&'s self,
171+
py: Python<'data>,
172+
input: &'data impl Input<'data>,
173+
extra: &Extra,
174+
slots: &'data [CombinedValidator],
175+
recursion_guard: &'s mut RecursionGuard,
176+
) -> ValResult<'data, PyObject> {
177+
let v = self.validator.validate_init(py, input, extra, slots, recursion_guard)?;
178+
let info = ValidationInfo::new(py, extra, &self.config);
179+
self.func.call1(py, (v, info)).map_err(|e| convert_err(py, e, input))
180+
}
181+
151182
fn get_name(&self) -> &str {
152183
&self.name
153184
}
@@ -227,7 +258,40 @@ impl Validator for FunctionWrapValidator {
227258
recursion_guard: &'s mut RecursionGuard,
228259
) -> ValResult<'data, PyObject> {
229260
let call_next_validator = ValidatorCallable {
230-
validator: InternalValidator::new(py, "ValidatorCallable", &self.validator, slots, extra, recursion_guard),
261+
validator: InternalValidator::new(
262+
py,
263+
"ValidatorCallable",
264+
&self.validator,
265+
slots,
266+
extra,
267+
recursion_guard,
268+
false,
269+
),
270+
};
271+
let info = ValidationInfo::new(py, extra, &self.config);
272+
self.func
273+
.call1(py, (input.to_object(py), call_next_validator, info))
274+
.map_err(|e| convert_err(py, e, input))
275+
}
276+
277+
fn validate_init<'s, 'data>(
278+
&'s self,
279+
py: Python<'data>,
280+
input: &'data impl Input<'data>,
281+
extra: &Extra,
282+
slots: &'data [CombinedValidator],
283+
recursion_guard: &'s mut RecursionGuard,
284+
) -> ValResult<'data, PyObject> {
285+
let call_next_validator = ValidatorCallable {
286+
validator: InternalValidator::new(
287+
py,
288+
"ValidatorCallable",
289+
&self.validator,
290+
slots,
291+
extra,
292+
recursion_guard,
293+
true,
294+
),
231295
};
232296
let info = ValidationInfo::new(py, extra, &self.config, self.is_field_validator)?;
233297
self.func

src/validators/generator.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ impl Validator for GeneratorValidator {
5757
let validator = self
5858
.item_validator
5959
.as_ref()
60-
.map(|v| InternalValidator::new(py, "ValidatorIterator", v, slots, extra, recursion_guard));
60+
.map(|v| InternalValidator::new(py, "ValidatorIterator", v, slots, extra, recursion_guard, false));
6161

6262
let v_iterator = ValidatorIterator {
6363
iterator,
@@ -196,6 +196,7 @@ pub struct InternalValidator {
196196
strict: Option<bool>,
197197
context: Option<PyObject>,
198198
recursion_guard: RecursionGuard,
199+
init_mode: bool,
199200
}
200201

201202
impl fmt::Debug for InternalValidator {
@@ -212,6 +213,7 @@ impl InternalValidator {
212213
slots: &[CombinedValidator],
213214
extra: &Extra,
214215
recursion_guard: &RecursionGuard,
216+
init_mode: bool,
215217
) -> Self {
216218
Self {
217219
name: name.to_string(),
@@ -222,6 +224,7 @@ impl InternalValidator {
222224
strict: extra.strict,
223225
context: extra.context.map(|d| d.into_py(py)),
224226
recursion_guard: recursion_guard.clone(),
227+
init_mode,
225228
}
226229
}
227230

@@ -241,8 +244,13 @@ impl InternalValidator {
241244
context: self.context.as_ref().map(|data| data.as_ref(py)),
242245
field_name: None,
243246
};
244-
self.validator
245-
.validate(py, input, &extra, &self.slots, &mut self.recursion_guard)
246-
.map_err(|e| ValidationError::from_val_error(py, self.name.to_object(py), e, outer_location))
247+
let r = if self.init_mode {
248+
self.validator
249+
.validate_init(py, input, &extra, &self.slots, &mut self.recursion_guard)
250+
} else {
251+
self.validator
252+
.validate(py, input, &extra, &self.slots, &mut self.recursion_guard)
253+
};
254+
r.map_err(|e| ValidationError::from_val_error(py, self.name.to_object(py), e, outer_location))
247255
}
248256
}

src/validators/mod.rs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -97,65 +97,65 @@ impl SchemaValidator {
9797
Ok((cls, args).into_py(py))
9898
}
9999

100-
#[pyo3(signature = (input, *, strict=None, context=None, in_init=false))]
100+
#[pyo3(signature = (input, *, strict=None, context=None, init_mode=false))]
101101
pub fn validate_python(
102102
&self,
103103
py: Python,
104104
input: &PyAny,
105105
strict: Option<bool>,
106106
context: Option<&PyAny>,
107-
in_init: bool,
107+
init_mode: bool,
108108
) -> PyResult<PyObject> {
109-
let r = self._validate(py, input, strict, context, in_init);
109+
let r = self._validate(py, input, strict, context, init_mode);
110110
r.map_err(|e| self.prepare_validation_err(py, e))
111111
}
112112

113-
#[pyo3(signature = (input, *, strict=None, context=None, in_init=false))]
113+
#[pyo3(signature = (input, *, strict=None, context=None, init_mode=false))]
114114
pub fn isinstance_python(
115115
&self,
116116
py: Python,
117117
input: &PyAny,
118118
strict: Option<bool>,
119119
context: Option<&PyAny>,
120-
in_init: bool,
120+
init_mode: bool,
121121
) -> PyResult<bool> {
122-
match self._validate(py, input, strict, context, in_init) {
122+
match self._validate(py, input, strict, context, init_mode) {
123123
Ok(_) => Ok(true),
124124
Err(ValError::InternalErr(err)) => Err(err),
125125
Err(ValError::Omit) => Err(ValidationError::omit_error()),
126126
Err(ValError::LineErrors(_)) => Ok(false),
127127
}
128128
}
129129

130-
#[pyo3(signature = (input, *, strict=None, context=None, in_init=false))]
130+
#[pyo3(signature = (input, *, strict=None, context=None, init_mode=false))]
131131
pub fn validate_json(
132132
&self,
133133
py: Python,
134134
input: &PyAny,
135135
strict: Option<bool>,
136136
context: Option<&PyAny>,
137-
in_init: bool,
137+
init_mode: bool,
138138
) -> PyResult<PyObject> {
139139
match input.parse_json() {
140140
Ok(input) => {
141-
let r = self._validate(py, &input, strict, context, in_init);
141+
let r = self._validate(py, &input, strict, context, init_mode);
142142
r.map_err(|e| self.prepare_validation_err(py, e))
143143
}
144144
Err(err) => Err(self.prepare_validation_err(py, err)),
145145
}
146146
}
147147

148-
#[pyo3(signature = (input, *, strict=None, context=None, in_init=false))]
148+
#[pyo3(signature = (input, *, strict=None, context=None, init_mode=false))]
149149
pub fn isinstance_json(
150150
&self,
151151
py: Python,
152152
input: &PyAny,
153153
strict: Option<bool>,
154154
context: Option<&PyAny>,
155-
in_init: bool,
155+
init_mode: bool,
156156
) -> PyResult<bool> {
157157
match input.parse_json() {
158-
Ok(input) => match self._validate(py, &input, strict, context, in_init) {
158+
Ok(input) => match self._validate(py, &input, strict, context, init_mode) {
159159
Ok(_) => Ok(true),
160160
Err(ValError::InternalErr(err)) => Err(err),
161161
Err(ValError::Omit) => Err(ValidationError::omit_error()),
@@ -220,12 +220,12 @@ impl SchemaValidator {
220220
input: &'data impl Input<'data>,
221221
strict: Option<bool>,
222222
context: Option<&'data PyAny>,
223-
in_init: bool,
223+
init_mode: bool,
224224
) -> ValResult<'data, PyObject>
225225
where
226226
's: 'data,
227227
{
228-
if in_init {
228+
if init_mode {
229229
self.validator.validate_init(
230230
py,
231231
input,

tests/test_isinstance.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ def test_isinstance():
1111
assert v.isinstance_python(123) is True
1212
assert v.validate_python('123') == 123
1313
assert v.isinstance_python('123') is True
14-
assert v.validate_python('123', in_init=True) == 123
15-
assert v.isinstance_python('123', in_init=True) is True
14+
assert v.validate_python('123', init_mode=True) == 123
15+
assert v.isinstance_python('123', init_mode=True) is True
1616

1717
with pytest.raises(ValidationError, match='Input should be a valid integer'):
1818
v.validate_python('foo')

tests/validators/test_model.py

Lines changed: 0 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -638,74 +638,3 @@ def call_me_baby(self, context, **kwargs):
638638
assert m.field_b == 12
639639
assert m.__fields_set__ == {'field_a'}
640640
assert m.__dict__ == {'field_a': 'testtest', 'field_b': 12}
641-
642-
643-
def test_model_init():
644-
class MyModel:
645-
# this is not required, but it avoids `__fields_set__` being included in `__dict__`
646-
__slots__ = '__dict__', '__fields_set__'
647-
field_a: str
648-
field_b: int
649-
650-
v = SchemaValidator(
651-
{
652-
'type': 'model',
653-
'cls': MyModel,
654-
'schema': {
655-
'type': 'typed-dict',
656-
'return_fields_set': True,
657-
'fields': {'field_a': {'schema': {'type': 'str'}}, 'field_b': {'schema': {'type': 'int'}}},
658-
},
659-
}
660-
)
661-
m = v.validate_python({'field_a': 'test', 'field_b': 12})
662-
assert isinstance(m, MyModel)
663-
assert m.field_a == 'test'
664-
assert m.field_b == 12
665-
d, fields_set = v.validate_python({'field_a': 'test', 'field_b': 12}, in_init=True)
666-
assert d == {'field_a': 'test', 'field_b': 12}
667-
assert fields_set == {'field_a', 'field_b'}
668-
669-
670-
def test_model_init_nested():
671-
class MyModel:
672-
# this is not required, but it avoids `__fields_set__` being included in `__dict__`
673-
__slots__ = '__dict__', '__fields_set__'
674-
675-
v = SchemaValidator(
676-
{
677-
'type': 'model',
678-
'cls': MyModel,
679-
'schema': {
680-
'type': 'typed-dict',
681-
'return_fields_set': True,
682-
'fields': {
683-
'field_a': {'schema': {'type': 'str'}},
684-
'field_b': {
685-
'schema': {
686-
'type': 'model',
687-
'cls': MyModel,
688-
'schema': {
689-
'type': 'typed-dict',
690-
'return_fields_set': True,
691-
'fields': {'x_a': {'schema': {'type': 'str'}}, 'x_b': {'schema': {'type': 'int'}}},
692-
},
693-
}
694-
},
695-
},
696-
},
697-
}
698-
)
699-
m = v.validate_python({'field_a': 'test', 'field_b': {'x_a': 'foo', 'x_b': 12}})
700-
assert isinstance(m, MyModel)
701-
assert m.field_a == 'test'
702-
assert isinstance(m.field_b, MyModel)
703-
assert m.field_b.x_a == 'foo'
704-
assert m.field_b.x_b == 12
705-
d, fields_set = v.validate_python({'field_a': 'test', 'field_b': {'x_a': 'foo', 'x_b': 12}}, in_init=True)
706-
assert d['field_a'] == 'test'
707-
assert isinstance(d['field_b'], MyModel)
708-
assert d['field_b'].x_a == 'foo'
709-
assert d['field_b'].x_b == 12
710-
711-
assert fields_set == {'field_a', 'field_b'}

0 commit comments

Comments
 (0)