@@ -4,15 +4,17 @@ use std::marker::PhantomData;
4
4
use pyo3:: exceptions:: PyTypeError ;
5
5
use pyo3:: intern;
6
6
use pyo3:: prelude:: * ;
7
- use pyo3:: types:: { PyDict , PyFloat , PyInt , PyList , PyString , PyTuple , PyType } ;
7
+ use pyo3:: types:: { PyDict , PyFloat , PyInt , PyList , PyString , PyType } ;
8
8
9
9
use crate :: build_tools:: { is_strict, py_schema_err} ;
10
10
use crate :: errors:: { ErrorType , ValError , ValResult } ;
11
11
use crate :: input:: Input ;
12
+ use crate :: serializers:: { to_jsonable_python, SerializationConfig } ;
12
13
use crate :: tools:: { safe_repr, SchemaDict } ;
13
14
14
15
use super :: is_instance:: class_repr;
15
16
use super :: literal:: { expected_repr_name, LiteralLookup } ;
17
+ use super :: InputType ;
16
18
use super :: { BuildValidator , CombinedValidator , DefinitionsBuilder , Exactness , ValidationState , Validator } ;
17
19
18
20
#[ derive( Debug , Clone ) ]
@@ -33,36 +35,55 @@ impl BuildValidator for BuildEnumValidator {
33
35
34
36
let py = schema. py ( ) ;
35
37
let value_str = intern ! ( py, "value" ) ;
36
- let mut expected : Vec < ( Bound < ' _ , PyAny > , PyObject ) > = members
38
+ let expected_py : Vec < ( Bound < ' _ , PyAny > , PyObject ) > = members
37
39
. iter ( )
38
40
. map ( |v| Ok ( ( v. getattr ( value_str) ?, v. into ( ) ) ) )
39
41
. collect :: < PyResult < _ > > ( ) ?;
42
+ let ser_config = SerializationConfig :: from_config ( config) . unwrap_or_default ( ) ;
43
+ let expected_json: Vec < ( Bound < ' _ , PyAny > , PyObject ) > = members
44
+ . iter ( )
45
+ . map ( |v| {
46
+ Ok ( (
47
+ to_jsonable_python (
48
+ py,
49
+ & v. getattr ( value_str) ?,
50
+ None ,
51
+ None ,
52
+ false ,
53
+ false ,
54
+ false ,
55
+ & ser_config. timedelta_mode . to_string ( ) ,
56
+ & ser_config. bytes_mode . to_string ( ) ,
57
+ & ser_config. inf_nan_mode . to_string ( ) ,
58
+ false ,
59
+ None ,
60
+ true ,
61
+ None ,
62
+ ) ?
63
+ . into_bound ( py) ,
64
+ v. into ( ) ,
65
+ ) )
66
+ } )
67
+ . collect :: < PyResult < _ > > ( ) ?;
40
68
41
- let repr_args: Vec < String > = expected
69
+ let repr_args: Vec < String > = expected_py
42
70
. iter ( )
43
71
. map ( |( k, _) | k. repr ( ) ?. extract ( ) )
44
72
. collect :: < PyResult < _ > > ( ) ?;
45
73
46
- let mut addition = vec ! [ ] ;
47
- for ( k, v) in & expected {
48
- if let Ok ( ss) = k. downcast :: < PyTuple > ( ) {
49
- let list = ss. to_list ( ) ;
50
- addition. push ( ( list. into_any ( ) , v. clone ( ) ) ) ;
51
- }
52
- }
53
- expected. append ( & mut addition) ;
54
-
55
74
let class: Bound < PyType > = schema. get_as_req ( intern ! ( py, "cls" ) ) ?;
56
75
let class_repr = class_repr ( schema, & class) ?;
57
76
58
- let lookup = LiteralLookup :: new ( py, expected. into_iter ( ) ) ?;
77
+ let py_lookup = LiteralLookup :: new ( py, expected_py. into_iter ( ) ) ?;
78
+ let json_lookup = LiteralLookup :: new ( py, expected_json. into_iter ( ) ) ?;
59
79
60
80
macro_rules! build {
61
81
( $vv: ty, $name_prefix: literal) => {
62
82
EnumValidator {
63
83
phantom: PhantomData :: <$vv>,
64
84
class: class. clone( ) . into( ) ,
65
- lookup,
85
+ py_lookup,
86
+ json_lookup,
66
87
missing: schema. get_as( intern!( py, "missing" ) ) ?,
67
88
expected_repr: expected_repr_name( repr_args, "" ) . 0 ,
68
89
strict: is_strict( schema, config) ?,
@@ -96,7 +117,8 @@ pub trait EnumValidateValue: std::fmt::Debug + Clone + Send + Sync {
96
117
pub struct EnumValidator < T : EnumValidateValue > {
97
118
phantom : PhantomData < T > ,
98
119
class : Py < PyType > ,
99
- lookup : LiteralLookup < PyObject > ,
120
+ py_lookup : LiteralLookup < PyObject > ,
121
+ json_lookup : LiteralLookup < PyObject > ,
100
122
missing : Option < PyObject > ,
101
123
expected_repr : String ,
102
124
strict : bool ,
@@ -129,7 +151,11 @@ impl<T: EnumValidateValue> Validator for EnumValidator<T> {
129
151
130
152
state. floor_exactness ( Exactness :: Lax ) ;
131
153
132
- if let Some ( v) = T :: validate_value ( py, input, & self . lookup , strict) ? {
154
+ let lookup = match state. extra ( ) . input_type {
155
+ InputType :: Json => & self . json_lookup ,
156
+ _ => & self . py_lookup ,
157
+ } ;
158
+ if let Some ( v) = T :: validate_value ( py, input, lookup, strict) ? {
133
159
return Ok ( v) ;
134
160
} else if let Ok ( res) = class. as_unbound ( ) . call1 ( py, ( input. as_python ( ) , ) ) {
135
161
return Ok ( res) ;
0 commit comments