Skip to content

Commit 6a200c9

Browse files
Add field_name to ValidatorInfo (#439)
* Add field to ValidatorInfo * minimize changes * cleanup * pr feedback * use field_name * fix types * Fix defaults * wip, can't get self schema to work * handle edge cases * update docstrings * fix docstrings * fix tests * don't use a pipe * replace double with single quotes * make type first field and remove Required[] * rename model_field_function_plain_schema to method_plain_schema * move logic into destructure_function_schema * implement pr feedback on getters * rename Extra.field to Extra.field_we_are_currently_assigning_to * rename field_we_are_currently_assigning_to to assignee_field * rename methods * fix logic * rename type * rename callback -> function, fix error message * rename callback -> function, more * tests for info.data --------- Co-authored-by: Samuel Colvin <[email protected]>
1 parent cc79b05 commit 6a200c9

20 files changed

+1061
-196
lines changed

pydantic_core/core_schema.py

Lines changed: 316 additions & 43 deletions
Large diffs are not rendered by default.

src/validators/function.rs

Lines changed: 67 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
use pyo3::exceptions::{PyAssertionError, PyTypeError, PyValueError};
1+
use pyo3::exceptions::{PyAssertionError, PyAttributeError, PyRuntimeError, PyTypeError, PyValueError};
22
use pyo3::intern;
33
use pyo3::prelude::*;
4-
use pyo3::types::{PyAny, PyDict};
4+
use pyo3::types::{PyAny, PyDict, PyString};
55

66
use crate::build_tools::{function_name, py_err, SchemaDict};
77
use crate::errors::{
@@ -16,6 +16,18 @@ use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Ex
1616

1717
pub struct FunctionBuilder;
1818

19+
fn destructure_function_schema(schema: &PyDict) -> PyResult<(bool, &PyAny)> {
20+
let func_dict: &PyDict = schema.get_as_req(intern!(schema.py(), "function"))?;
21+
let function: &PyAny = func_dict.get_as_req(intern!(schema.py(), "function"))?;
22+
let func_type: &str = func_dict.get_as_req(intern!(schema.py(), "type"))?;
23+
let is_field_validator = match func_type {
24+
"field" => true,
25+
"general" => false,
26+
_ => unreachable!(),
27+
};
28+
Ok((is_field_validator, function))
29+
}
30+
1931
impl BuildValidator for FunctionBuilder {
2032
const EXPECTED_TYPE: &'static str = "function";
2133

@@ -45,7 +57,7 @@ macro_rules! impl_build {
4557
) -> PyResult<CombinedValidator> {
4658
let py = schema.py();
4759
let validator = build_validator(schema.get_as_req(intern!(py, "schema"))?, config, build_context)?;
48-
let function = schema.get_as_req::<&PyAny>(intern!(py, "function"))?;
60+
let (is_field_validator, function) = destructure_function_schema(schema)?;
4961
let name = format!(
5062
"{}[{}(), {}]",
5163
$name,
@@ -60,6 +72,7 @@ macro_rules! impl_build {
6072
None => py.None(),
6173
},
6274
name,
75+
is_field_validator,
6376
}
6477
.into())
6578
}
@@ -73,6 +86,7 @@ pub struct FunctionBeforeValidator {
7386
func: PyObject,
7487
config: PyObject,
7588
name: String,
89+
is_field_validator: bool,
7690
}
7791

7892
impl_build!(FunctionBeforeValidator, "function-before");
@@ -86,7 +100,7 @@ impl Validator for FunctionBeforeValidator {
86100
slots: &'data [CombinedValidator],
87101
recursion_guard: &'s mut RecursionGuard,
88102
) -> ValResult<'data, PyObject> {
89-
let info = ValidationInfo::new(py, extra, &self.config);
103+
let info = ValidationInfo::new(py, extra, &self.config, self.is_field_validator)?;
90104
let value = self
91105
.func
92106
.call1(py, (input.to_object(py), info))
@@ -115,6 +129,7 @@ pub struct FunctionAfterValidator {
115129
func: PyObject,
116130
config: PyObject,
117131
name: String,
132+
is_field_validator: bool,
118133
}
119134

120135
impl_build!(FunctionAfterValidator, "function-after");
@@ -129,7 +144,7 @@ impl Validator for FunctionAfterValidator {
129144
recursion_guard: &'s mut RecursionGuard,
130145
) -> ValResult<'data, PyObject> {
131146
let v = self.validator.validate(py, input, extra, slots, recursion_guard)?;
132-
let info = ValidationInfo::new(py, extra, &self.config);
147+
let info = ValidationInfo::new(py, extra, &self.config, self.is_field_validator)?;
133148
self.func.call1(py, (v, info)).map_err(|e| convert_err(py, e, input))
134149
}
135150

@@ -151,19 +166,21 @@ pub struct FunctionPlainValidator {
151166
func: PyObject,
152167
config: PyObject,
153168
name: String,
169+
is_field_validator: bool,
154170
}
155171

156172
impl FunctionPlainValidator {
157173
pub fn build(schema: &PyDict, config: Option<&PyDict>) -> PyResult<CombinedValidator> {
158174
let py = schema.py();
159-
let function = schema.get_as_req::<&PyAny>(intern!(py, "function"))?;
175+
let (is_field_validator, function) = destructure_function_schema(schema)?;
160176
Ok(Self {
161177
func: function.into_py(py),
162178
config: match config {
163179
Some(c) => c.into(),
164180
None => py.None(),
165181
},
166182
name: format!("function-plain[{}()]", function_name(function)?),
183+
is_field_validator,
167184
}
168185
.into())
169186
}
@@ -178,7 +195,7 @@ impl Validator for FunctionPlainValidator {
178195
_slots: &'data [CombinedValidator],
179196
_recursion_guard: &'s mut RecursionGuard,
180197
) -> ValResult<'data, PyObject> {
181-
let info = ValidationInfo::new(py, extra, &self.config);
198+
let info = ValidationInfo::new(py, extra, &self.config, self.is_field_validator)?;
182199
self.func
183200
.call1(py, (input.to_object(py), info))
184201
.map_err(|e| convert_err(py, e, input))
@@ -195,6 +212,7 @@ pub struct FunctionWrapValidator {
195212
func: PyObject,
196213
config: PyObject,
197214
name: String,
215+
is_field_validator: bool,
198216
}
199217

200218
impl_build!(FunctionWrapValidator, "function-wrap");
@@ -211,7 +229,7 @@ impl Validator for FunctionWrapValidator {
211229
let call_next_validator = ValidatorCallable {
212230
validator: InternalValidator::new(py, "ValidatorCallable", &self.validator, slots, extra, recursion_guard),
213231
};
214-
let info = ValidationInfo::new(py, extra, &self.config);
232+
let info = ValidationInfo::new(py, extra, &self.config, self.is_field_validator)?;
215233
self.func
216234
.call1(py, (input.to_object(py), call_next_validator, info))
217235
.map_err(|e| convert_err(py, e, input))
@@ -300,20 +318,54 @@ pub fn convert_err<'a>(py: Python<'a>, err: PyErr, input: &'a impl Input<'a>) ->
300318

301319
#[pyclass(module = "pydantic_core._pydantic_core")]
302320
pub struct ValidationInfo {
303-
#[pyo3(get)]
304-
data: Option<Py<PyDict>>,
305321
#[pyo3(get)]
306322
config: PyObject,
307323
#[pyo3(get)]
308324
context: Option<PyObject>,
325+
data: Option<Py<PyDict>>,
326+
field_name: Option<String>,
309327
}
310328

311329
impl ValidationInfo {
312-
fn new(py: Python, extra: &Extra, config: &PyObject) -> Self {
313-
Self {
314-
data: extra.data.map(|v| v.into()),
315-
config: config.clone_ref(py),
316-
context: extra.context.map(|v| v.into()),
330+
fn new(py: Python, extra: &Extra, config: &PyObject, is_field_validator: bool) -> PyResult<Self> {
331+
if is_field_validator {
332+
match extra.field_name {
333+
Some(field_name) => Ok(
334+
Self {
335+
config: config.clone_ref(py),
336+
context: extra.context.map(|v| v.into()),
337+
field_name: Some(field_name.to_string()),
338+
data: extra.data.map(|v| v.into()),
339+
}
340+
),
341+
_ => Err(PyRuntimeError::new_err("This validator expected to be run inside the context of a model field but no model field was found")),
342+
}
343+
} else {
344+
Ok(Self {
345+
config: config.clone_ref(py),
346+
context: extra.context.map(|v| v.into()),
347+
field_name: None,
348+
data: None,
349+
})
350+
}
351+
}
352+
}
353+
354+
#[pymethods]
355+
impl ValidationInfo {
356+
#[getter]
357+
fn get_data(&self, py: Python) -> PyResult<Py<PyDict>> {
358+
match self.data {
359+
Some(ref data) => Ok(data.clone_ref(py)),
360+
None => Err(PyAttributeError::new_err("No attribute named 'data'")),
361+
}
362+
}
363+
364+
#[getter]
365+
fn get_field_name<'py>(&self, py: Python<'py>) -> PyResult<&'py PyString> {
366+
match self.field_name {
367+
Some(ref field_name) => Ok(PyString::new(py, field_name)),
368+
None => Err(PyAttributeError::new_err("No attribute named 'field_name'")),
317369
}
318370
}
319371
}

src/validators/generator.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ impl InternalValidator {
218218
validator: validator.clone(),
219219
slots: slots.to_vec(),
220220
data: extra.data.map(|d| d.into_py(py)),
221-
field: extra.field.map(|f| f.to_string()),
221+
field: extra.assignee_field.map(|f| f.to_string()),
222222
strict: extra.strict,
223223
context: extra.context.map(|d| d.into_py(py)),
224224
recursion_guard: recursion_guard.clone(),
@@ -236,9 +236,10 @@ impl InternalValidator {
236236
{
237237
let extra = Extra {
238238
data: self.data.as_ref().map(|data| data.as_ref(py)),
239-
field: self.field.as_deref(),
239+
assignee_field: self.field.as_deref(),
240240
strict: self.strict,
241241
context: self.context.as_ref().map(|data| data.as_ref(py)),
242+
field_name: None,
242243
};
243244
self.validator
244245
.validate(py, input, &extra, &self.slots, &mut self.recursion_guard)

src/validators/mod.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,9 +194,10 @@ impl SchemaValidator {
194194
) -> PyResult<PyObject> {
195195
let extra = Extra {
196196
data: Some(data),
197-
field: Some(field.as_str()),
197+
assignee_field: Some(field.as_str()),
198198
strict,
199199
context,
200+
field_name: None,
200201
};
201202
let r = self
202203
.validator
@@ -448,8 +449,11 @@ pub struct Extra<'a> {
448449
/// This is used as the `data` kwargs to validator functions, it also represents the current model
449450
/// data when validating assignment
450451
pub data: Option<&'a PyDict>,
452+
/// Represents the fields of the model we are currently validating
453+
/// If there is no model this will be None
454+
pub field_name: Option<&'a str>,
451455
/// The field being assigned to when validating assignment
452-
pub field: Option<&'a str>,
456+
pub assignee_field: Option<&'a str>,
453457
/// whether we're in strict or lax mode
454458
pub strict: Option<bool>,
455459
/// context used in validator functions
@@ -470,9 +474,10 @@ impl<'a> Extra<'a> {
470474
pub fn as_strict(&self) -> Self {
471475
Self {
472476
data: self.data,
473-
field: self.field,
477+
assignee_field: self.assignee_field,
474478
strict: Some(true),
475479
context: self.context,
480+
field_name: self.field_name,
476481
}
477482
}
478483
}

src/validators/typed_dict.rs

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ impl Validator for TypedDictValidator {
162162
slots: &'data [CombinedValidator],
163163
recursion_guard: &'s mut RecursionGuard,
164164
) -> ValResult<'data, PyObject> {
165-
if let Some(field) = extra.field {
165+
if let Some(field) = extra.assignee_field {
166166
// we're validating assignment, completely different logic
167167
return self.validate_assignment(py, field, input, extra, slots, recursion_guard);
168168
}
@@ -183,16 +183,14 @@ impl Validator for TypedDictValidator {
183183
false => None,
184184
};
185185

186-
let extra = Extra {
187-
data: Some(output_dict),
188-
field: None,
189-
strict: extra.strict,
190-
context: extra.context,
191-
};
192-
193186
macro_rules! process {
194187
($dict:ident, $get_method:ident, $iter:ty $(,$kwargs:ident)?) => {{
195188
for field in &self.fields {
189+
let extra = Extra {
190+
data: Some(output_dict),
191+
field_name: Some(&field.name),
192+
..*extra
193+
};
196194
let op_key_value = match field.lookup_key.$get_method($dict $(, $kwargs )? ) {
197195
Ok(v) => v,
198196
Err(err) => {
@@ -348,6 +346,11 @@ impl TypedDictValidator {
348346
where
349347
'data: 's,
350348
{
349+
let extra = Extra {
350+
field_name: Some(field),
351+
assignee_field: None,
352+
..*extra
353+
};
351354
// TODO probably we should set location on errors here
352355
let data = match extra.data {
353356
Some(data) => data,
@@ -380,12 +383,12 @@ impl TypedDictValidator {
380383
if field.frozen {
381384
Err(ValError::new_with_loc(ErrorType::Frozen, input, field.name.to_string()))
382385
} else {
383-
prepare_result(field.validator.validate(py, input, extra, slots, recursion_guard))
386+
prepare_result(field.validator.validate(py, input, &extra, slots, recursion_guard))
384387
}
385388
} else if self.check_extra && !self.forbid_extra {
386389
// this is the "allow" case of extra_behavior
387390
match self.extra_validator {
388-
Some(ref validator) => prepare_result(validator.validate(py, input, extra, slots, recursion_guard)),
391+
Some(ref validator) => prepare_result(validator.validate(py, input, &extra, slots, recursion_guard)),
389392
None => prepare_tuple(input.to_object(py)),
390393
}
391394
} else {

tests/benchmarks/complete_schema.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,27 +110,33 @@ def wrap_function(input_value, validator, info):
110110
'schema': {
111111
'type': 'function',
112112
'mode': 'before',
113-
'function': append_func,
113+
'function': {'type': 'general', 'function': append_func},
114114
'schema': {'type': 'str'},
115115
}
116116
},
117117
'field_after': {
118118
'schema': {
119119
'type': 'function',
120120
'mode': 'after',
121-
'function': append_func,
121+
'function': {'type': 'general', 'function': append_func},
122122
'schema': {'type': 'str'},
123123
}
124124
},
125125
'field_wrap': {
126126
'schema': {
127127
'type': 'function',
128128
'mode': 'wrap',
129-
'function': wrap_function,
129+
'function': {'type': 'general', 'function': wrap_function},
130130
'schema': {'type': 'str'},
131131
}
132132
},
133-
'field_plain': {'schema': {'type': 'function', 'mode': 'plain', 'function': append_func}},
133+
'field_plain': {
134+
'schema': {
135+
'type': 'function',
136+
'mode': 'plain',
137+
'function': {'type': 'general', 'function': append_func},
138+
}
139+
},
134140
},
135141
}
136142
},

0 commit comments

Comments
 (0)