1
- use pyo3:: exceptions:: { PyAssertionError , PyTypeError , PyValueError } ;
1
+ use pyo3:: exceptions:: { PyAssertionError , PyAttributeError , PyRuntimeError , PyTypeError , PyValueError } ;
2
2
use pyo3:: intern;
3
3
use pyo3:: prelude:: * ;
4
- use pyo3:: types:: { PyAny , PyDict } ;
4
+ use pyo3:: types:: { PyAny , PyDict , PyString } ;
5
5
6
6
use crate :: build_tools:: { function_name, py_err, SchemaDict } ;
7
7
use crate :: errors:: {
@@ -16,6 +16,18 @@ use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Ex
16
16
17
17
pub struct FunctionBuilder ;
18
18
19
+ fn destructure_function_schema ( schema : & PyDict ) -> PyResult < ( bool , & PyAny ) > {
20
+ let func_dict: & PyDict = schema. get_as_req ( intern ! ( schema. py( ) , "function" ) ) ?;
21
+ let function: & PyAny = func_dict. get_as_req ( intern ! ( schema. py( ) , "function" ) ) ?;
22
+ let func_type: & str = func_dict. get_as_req ( intern ! ( schema. py( ) , "type" ) ) ?;
23
+ let is_field_validator = match func_type {
24
+ "field" => true ,
25
+ "general" => false ,
26
+ _ => unreachable ! ( ) ,
27
+ } ;
28
+ Ok ( ( is_field_validator, function) )
29
+ }
30
+
19
31
impl BuildValidator for FunctionBuilder {
20
32
const EXPECTED_TYPE : & ' static str = "function" ;
21
33
@@ -45,7 +57,7 @@ macro_rules! impl_build {
45
57
) -> PyResult <CombinedValidator > {
46
58
let py = schema. py( ) ;
47
59
let validator = build_validator( schema. get_as_req( intern!( py, "schema" ) ) ?, config, build_context) ?;
48
- let function = schema . get_as_req :: < & PyAny > ( intern! ( py , " function" ) ) ?;
60
+ let ( is_field_validator , function) = destructure_function_schema ( schema ) ?;
49
61
let name = format!(
50
62
"{}[{}(), {}]" ,
51
63
$name,
@@ -60,6 +72,7 @@ macro_rules! impl_build {
60
72
None => py. None ( ) ,
61
73
} ,
62
74
name,
75
+ is_field_validator,
63
76
}
64
77
. into( ) )
65
78
}
@@ -73,6 +86,7 @@ pub struct FunctionBeforeValidator {
73
86
func : PyObject ,
74
87
config : PyObject ,
75
88
name : String ,
89
+ is_field_validator : bool ,
76
90
}
77
91
78
92
impl_build ! ( FunctionBeforeValidator , "function-before" ) ;
@@ -86,7 +100,7 @@ impl Validator for FunctionBeforeValidator {
86
100
slots : & ' data [ CombinedValidator ] ,
87
101
recursion_guard : & ' s mut RecursionGuard ,
88
102
) -> ValResult < ' data , PyObject > {
89
- let info = ValidationInfo :: new ( py, extra, & self . config ) ;
103
+ let info = ValidationInfo :: new ( py, extra, & self . config , self . is_field_validator ) ? ;
90
104
let value = self
91
105
. func
92
106
. call1 ( py, ( input. to_object ( py) , info) )
@@ -115,6 +129,7 @@ pub struct FunctionAfterValidator {
115
129
func : PyObject ,
116
130
config : PyObject ,
117
131
name : String ,
132
+ is_field_validator : bool ,
118
133
}
119
134
120
135
impl_build ! ( FunctionAfterValidator , "function-after" ) ;
@@ -129,7 +144,7 @@ impl Validator for FunctionAfterValidator {
129
144
recursion_guard : & ' s mut RecursionGuard ,
130
145
) -> ValResult < ' data , PyObject > {
131
146
let v = self . validator . validate ( py, input, extra, slots, recursion_guard) ?;
132
- let info = ValidationInfo :: new ( py, extra, & self . config ) ;
147
+ let info = ValidationInfo :: new ( py, extra, & self . config , self . is_field_validator ) ? ;
133
148
self . func . call1 ( py, ( v, info) ) . map_err ( |e| convert_err ( py, e, input) )
134
149
}
135
150
@@ -151,19 +166,21 @@ pub struct FunctionPlainValidator {
151
166
func : PyObject ,
152
167
config : PyObject ,
153
168
name : String ,
169
+ is_field_validator : bool ,
154
170
}
155
171
156
172
impl FunctionPlainValidator {
157
173
pub fn build ( schema : & PyDict , config : Option < & PyDict > ) -> PyResult < CombinedValidator > {
158
174
let py = schema. py ( ) ;
159
- let function = schema . get_as_req :: < & PyAny > ( intern ! ( py , " function" ) ) ?;
175
+ let ( is_field_validator , function) = destructure_function_schema ( schema ) ?;
160
176
Ok ( Self {
161
177
func : function. into_py ( py) ,
162
178
config : match config {
163
179
Some ( c) => c. into ( ) ,
164
180
None => py. None ( ) ,
165
181
} ,
166
182
name : format ! ( "function-plain[{}()]" , function_name( function) ?) ,
183
+ is_field_validator,
167
184
}
168
185
. into ( ) )
169
186
}
@@ -178,7 +195,7 @@ impl Validator for FunctionPlainValidator {
178
195
_slots : & ' data [ CombinedValidator ] ,
179
196
_recursion_guard : & ' s mut RecursionGuard ,
180
197
) -> ValResult < ' data , PyObject > {
181
- let info = ValidationInfo :: new ( py, extra, & self . config ) ;
198
+ let info = ValidationInfo :: new ( py, extra, & self . config , self . is_field_validator ) ? ;
182
199
self . func
183
200
. call1 ( py, ( input. to_object ( py) , info) )
184
201
. map_err ( |e| convert_err ( py, e, input) )
@@ -195,6 +212,7 @@ pub struct FunctionWrapValidator {
195
212
func : PyObject ,
196
213
config : PyObject ,
197
214
name : String ,
215
+ is_field_validator : bool ,
198
216
}
199
217
200
218
impl_build ! ( FunctionWrapValidator , "function-wrap" ) ;
@@ -211,7 +229,7 @@ impl Validator for FunctionWrapValidator {
211
229
let call_next_validator = ValidatorCallable {
212
230
validator : InternalValidator :: new ( py, "ValidatorCallable" , & self . validator , slots, extra, recursion_guard) ,
213
231
} ;
214
- let info = ValidationInfo :: new ( py, extra, & self . config ) ;
232
+ let info = ValidationInfo :: new ( py, extra, & self . config , self . is_field_validator ) ? ;
215
233
self . func
216
234
. call1 ( py, ( input. to_object ( py) , call_next_validator, info) )
217
235
. map_err ( |e| convert_err ( py, e, input) )
@@ -300,20 +318,54 @@ pub fn convert_err<'a>(py: Python<'a>, err: PyErr, input: &'a impl Input<'a>) ->
300
318
301
319
#[ pyclass( module = "pydantic_core._pydantic_core" ) ]
302
320
pub struct ValidationInfo {
303
- #[ pyo3( get) ]
304
- data : Option < Py < PyDict > > ,
305
321
#[ pyo3( get) ]
306
322
config : PyObject ,
307
323
#[ pyo3( get) ]
308
324
context : Option < PyObject > ,
325
+ data : Option < Py < PyDict > > ,
326
+ field_name : Option < String > ,
309
327
}
310
328
311
329
impl ValidationInfo {
312
- fn new ( py : Python , extra : & Extra , config : & PyObject ) -> Self {
313
- Self {
314
- data : extra. data . map ( |v| v. into ( ) ) ,
315
- config : config. clone_ref ( py) ,
316
- context : extra. context . map ( |v| v. into ( ) ) ,
330
+ fn new ( py : Python , extra : & Extra , config : & PyObject , is_field_validator : bool ) -> PyResult < Self > {
331
+ if is_field_validator {
332
+ match extra. field_name {
333
+ Some ( field_name) => Ok (
334
+ Self {
335
+ config : config. clone_ref ( py) ,
336
+ context : extra. context . map ( |v| v. into ( ) ) ,
337
+ field_name : Some ( field_name. to_string ( ) ) ,
338
+ data : extra. data . map ( |v| v. into ( ) ) ,
339
+ }
340
+ ) ,
341
+ _ => Err ( PyRuntimeError :: new_err ( "This validator expected to be run inside the context of a model field but no model field was found" ) ) ,
342
+ }
343
+ } else {
344
+ Ok ( Self {
345
+ config : config. clone_ref ( py) ,
346
+ context : extra. context . map ( |v| v. into ( ) ) ,
347
+ field_name : None ,
348
+ data : None ,
349
+ } )
350
+ }
351
+ }
352
+ }
353
+
354
+ #[ pymethods]
355
+ impl ValidationInfo {
356
+ #[ getter]
357
+ fn get_data ( & self , py : Python ) -> PyResult < Py < PyDict > > {
358
+ match self . data {
359
+ Some ( ref data) => Ok ( data. clone_ref ( py) ) ,
360
+ None => Err ( PyAttributeError :: new_err ( "No attribute named 'data'" ) ) ,
361
+ }
362
+ }
363
+
364
+ #[ getter]
365
+ fn get_field_name < ' py > ( & self , py : Python < ' py > ) -> PyResult < & ' py PyString > {
366
+ match self . field_name {
367
+ Some ( ref field_name) => Ok ( PyString :: new ( py, field_name) ) ,
368
+ None => Err ( PyAttributeError :: new_err ( "No attribute named 'field_name'" ) ) ,
317
369
}
318
370
}
319
371
}
0 commit comments