1
- use std:: borrow:: Cow ;
2
- use std:: fmt;
3
1
use std:: fmt:: Write ;
4
2
5
- use pyo3:: exceptions:: PyTypeError ;
6
3
use pyo3:: intern;
7
4
use pyo3:: prelude:: * ;
8
5
use pyo3:: types:: { PyDict , PyList , PyString } ;
9
6
10
- use ahash:: AHashMap ;
11
-
12
7
use crate :: build_tools:: py_schema_err;
13
8
use crate :: build_tools:: { is_strict, schema_or_config} ;
14
9
use crate :: errors:: { ErrorType , LocItem , ValError , ValLineError , ValResult } ;
15
10
use crate :: input:: { GenericMapping , Input } ;
16
11
use crate :: lookup_key:: LookupKey ;
17
12
use crate :: recursion_guard:: RecursionGuard ;
18
- use crate :: tools:: { extract_i64 , py_err , SchemaDict } ;
13
+ use crate :: tools:: SchemaDict ;
19
14
20
15
use super :: custom_error:: CustomError ;
21
16
use super :: literal:: LiteralValidator ;
@@ -221,55 +216,12 @@ impl Discriminator {
221
216
}
222
217
}
223
218
224
- #[ derive( Debug , Clone , Eq , PartialEq , Hash ) ]
225
- enum ChoiceKey {
226
- Int ( i64 ) ,
227
- Str ( String ) ,
228
- }
229
-
230
- impl ChoiceKey {
231
- fn from_py ( raw : & PyAny ) -> PyResult < Self > {
232
- if let Ok ( py_int) = extract_i64 ( raw) {
233
- Ok ( Self :: Int ( py_int) )
234
- } else if let Ok ( py_str) = raw. downcast :: < PyString > ( ) {
235
- Ok ( Self :: Str ( py_str. to_str ( ) ?. to_string ( ) ) )
236
- } else {
237
- py_err ! ( PyTypeError ; "Expected int or str, got {}" , raw. get_type( ) . name( ) . unwrap_or( "<unknown python object>" ) )
238
- }
239
- }
240
-
241
- fn repr ( & self ) -> String {
242
- match self {
243
- Self :: Int ( i) => i. to_string ( ) ,
244
- Self :: Str ( s) => format ! ( "'{s}'" ) ,
245
- }
246
- }
247
- }
248
-
249
- impl fmt:: Display for ChoiceKey {
250
- fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
251
- match self {
252
- Self :: Int ( i) => write ! ( f, "{i}" ) ,
253
- Self :: Str ( s) => write ! ( f, "{s}" ) ,
254
- }
255
- }
256
- }
257
-
258
- impl From < & ChoiceKey > for LocItem {
259
- fn from ( key : & ChoiceKey ) -> Self {
260
- match key {
261
- ChoiceKey :: Str ( s) => s. as_str ( ) . into ( ) ,
262
- ChoiceKey :: Int ( i) => ( * i) . into ( ) ,
263
- }
264
- }
265
- }
266
-
267
219
#[ derive( Debug , Clone ) ]
268
220
pub struct TaggedUnionValidator {
269
- discriminator_validator : Box < CombinedValidator > ,
270
221
choices : Py < PyDict > ,
271
222
validators : Vec < CombinedValidator > ,
272
223
discriminator : Discriminator ,
224
+ discriminator_validator : Box < CombinedValidator > ,
273
225
from_attributes : bool ,
274
226
strict : bool ,
275
227
custom_error : Option < CustomError > ,
@@ -290,35 +242,43 @@ impl BuildValidator for TaggedUnionValidator {
290
242
let discriminator = Discriminator :: new ( py, schema. get_as_req ( intern ! ( py, "discriminator" ) ) ?) ?;
291
243
let discriminator_repr = discriminator. to_string_py ( py) ?;
292
244
293
- let schema_choices: & PyDict = schema. get_as_req ( intern ! ( py, "choices" ) ) ?;
294
245
let choices = PyDict :: new ( py) ;
295
246
let mut validators = Vec :: with_capacity ( choices. len ( ) ) ;
296
247
let mut tags_repr = String :: with_capacity ( 50 ) ;
297
248
let mut descr = String :: with_capacity ( 50 ) ;
298
249
let mut first = true ;
299
250
let mut discriminators = Vec :: with_capacity ( choices. len ( ) ) ;
300
- for ( choice_key, choice_schema) in schema {
251
+ let schema_choices: & PyDict = schema. get_as_req ( intern ! ( py, "choices" ) ) ?;
252
+ let schema_choice_keys = schema_choices. keys ( ) ;
253
+ let discriminator_validator = LiteralValidator :: build (
254
+ PyDict :: from_sequence (
255
+ py,
256
+ vec ! [
257
+ ( intern!( py, "type" ) , intern!( py, "literal" ) . as_ref( ) ) ,
258
+ ( intern!( py, "expected" ) , schema_choice_keys. as_ref( ) ) ,
259
+ ]
260
+ . into_py ( py) ,
261
+ ) ?,
262
+ config,
263
+ definitions,
264
+ ) ?;
265
+ for ( choice_key, choice_schema) in schema_choices. iter ( ) {
301
266
discriminators. push ( choice_key) ;
302
267
let validator = build_validator ( choice_schema, config, definitions) ?;
303
268
choices. set_item ( choice_key, validators. len ( ) ) ?;
304
- validators. push ( validator) ;
305
269
let tag_repr = choice_key. repr ( ) ?. to_string ( ) ;
306
270
if first {
307
271
first = false ;
308
272
write ! ( tags_repr, "{tag_repr}" ) . unwrap ( ) ;
309
- descr. push_str ( validator. get_name ( ) ) ;
273
+ descr. push_str ( validator. get_name ( ) . clone ( ) ) ;
310
274
} else {
311
275
write ! ( tags_repr, ", {tag_repr}" ) . unwrap ( ) ;
312
276
// no spaces in get_name() output to make loc easy to read
313
277
write ! ( descr, ",{}" , validator. get_name( ) ) . unwrap ( ) ;
314
278
}
279
+ validators. push ( validator) ;
315
280
}
316
281
317
- let discriminator_validator_schema = PyDict :: new ( py) ;
318
- discriminator_validator_schema. set_item ( intern ! ( py, "type" ) , intern ! ( py, "literal" ) ) ?;
319
- discriminator_validator_schema. set_item ( intern ! ( py, "expected" ) , discriminators. into_py ( py) ) ?;
320
- let discriminator_validator = build_validator ( discriminator_validator_schema, config, definitions) ?;
321
-
322
282
let key = intern ! ( py, "from_attributes" ) ;
323
283
let from_attributes = schema_or_config ( schema, config, key, key) ?. unwrap_or ( true ) ;
324
284
@@ -330,8 +290,8 @@ impl BuildValidator for TaggedUnionValidator {
330
290
Ok ( Self {
331
291
choices : choices. into ( ) ,
332
292
validators,
333
- discriminator_validator : Box :: new ( discriminator_validator) ,
334
293
discriminator,
294
+ discriminator_validator : Box :: new ( discriminator_validator) ,
335
295
from_attributes,
336
296
strict : is_strict ( schema, config) ?,
337
297
custom_error : CustomError :: build ( schema, config, definitions) ?,
@@ -360,12 +320,7 @@ impl Validator for TaggedUnionValidator {
360
320
// errors when getting attributes which should be "raised"
361
321
match lookup_key. $get_method( $( $dict ) ,+) ? {
362
322
Some ( ( _, value) ) => {
363
- if let Ok ( either_int) = value. validate_int( self . strict) {
364
- let int = either_int. into_i64( py) ?;
365
- Ok ( ChoiceKey :: Int ( int) )
366
- } else {
367
- Ok ( ChoiceKey :: Str ( value. validate_str( self . strict) ?. as_cow( ) ?. as_ref( ) . to_string( ) ) )
368
- }
323
+ Ok ( value. to_object( py) . into_ref( py) )
369
324
}
370
325
None => Err ( self . tag_not_found( input) ) ,
371
326
}
@@ -379,27 +334,19 @@ impl Validator for TaggedUnionValidator {
379
334
GenericMapping :: PyMapping ( mapping) => find_validator ! ( py_get_mapping_item, mapping) ,
380
335
GenericMapping :: JsonObject ( mapping) => find_validator ! ( json_get, mapping) ,
381
336
} ?;
382
- self . find_call_validator ( py, & tag, input, extra, definitions, recursion_guard)
337
+ self . find_call_validator ( py, tag, input, extra, definitions, recursion_guard)
383
338
}
384
339
Discriminator :: Function ( ref func) => {
385
340
let tag = func. call1 ( py, ( input. to_object ( py) , ) ) ?;
386
341
if tag. is_none ( py) {
387
342
Err ( self . tag_not_found ( input) )
388
343
} else {
389
- let tag: & PyAny = tag. downcast ( py) ?;
390
- self . find_call_validator (
391
- py,
392
- & ( ChoiceKey :: from_py ( tag) ?) ,
393
- input,
394
- extra,
395
- definitions,
396
- recursion_guard,
397
- )
344
+ self . find_call_validator ( py, tag. into_ref ( py) , input, extra, definitions, recursion_guard)
398
345
}
399
346
}
400
347
Discriminator :: SelfSchema => self . find_call_validator (
401
348
py,
402
- & ChoiceKey :: Str ( self . self_schema_tag ( py, input) ?. into_owned ( ) ) ,
349
+ & self . self_schema_tag ( py, input) ?. as_ref ( ) ,
403
350
input,
404
351
extra,
405
352
definitions,
@@ -413,7 +360,8 @@ impl Validator for TaggedUnionValidator {
413
360
definitions : Option < & DefinitionsBuilder < CombinedValidator > > ,
414
361
ultra_strict : bool ,
415
362
) -> bool {
416
- self . validators . iter ( )
363
+ self . validators
364
+ . iter ( )
417
365
. any ( |v| v. different_strict_behavior ( definitions, ultra_strict) )
418
366
}
419
367
@@ -422,8 +370,9 @@ impl Validator for TaggedUnionValidator {
422
370
}
423
371
424
372
fn complete ( & mut self , definitions : & DefinitionsBuilder < CombinedValidator > ) -> PyResult < ( ) > {
425
- self . validators . iter ( )
426
- . try_for_each ( |( _, validator) | validator. complete ( definitions) )
373
+ self . validators
374
+ . iter_mut ( )
375
+ . try_for_each ( |validator| validator. complete ( definitions) )
427
376
}
428
377
}
429
378
@@ -432,7 +381,7 @@ impl TaggedUnionValidator {
432
381
& ' s self ,
433
382
py : Python < ' data > ,
434
383
input : & ' data impl Input < ' data > ,
435
- ) -> ValResult < ' data , Cow < ' data , str > > {
384
+ ) -> ValResult < ' data , & ' data PyString > {
436
385
let dict = input. strict_dict ( ) ?;
437
386
let either_tag = match dict {
438
387
GenericMapping :: PyDict ( dict) => match dict. get_item ( intern ! ( py, "type" ) ) {
@@ -455,44 +404,44 @@ impl TaggedUnionValidator {
455
404
if tag == "function" {
456
405
let mode = mode. ok_or_else ( || self . tag_not_found ( input) ) ?;
457
406
match mode. as_cow ( ) ?. as_ref ( ) {
458
- "plain" => Ok ( Cow :: Borrowed ( "function-plain" ) ) ,
459
- "wrap" => Ok ( Cow :: Borrowed ( "function-wrap" ) ) ,
460
- _ => Ok ( Cow :: Borrowed ( "function" ) ) ,
407
+ "plain" => Ok ( intern ! ( py , "function-plain" ) ) ,
408
+ "wrap" => Ok ( intern ! ( py , "function-wrap" ) ) ,
409
+ _ => Ok ( intern ! ( py , "function" ) ) ,
461
410
}
462
411
} else {
463
412
// tag == "tuple"
464
413
if let Some ( mode) = mode {
465
414
if mode. as_cow ( ) ?. as_ref ( ) == "positional" {
466
- return Ok ( Cow :: Borrowed ( "tuple-positional" ) ) ;
415
+ return Ok ( intern ! ( py , "tuple-positional" ) ) ;
467
416
}
468
417
}
469
- Ok ( Cow :: Borrowed ( "tuple-variable" ) )
418
+ Ok ( intern ! ( py , "tuple-variable" ) )
470
419
}
471
420
} else {
472
- Ok ( Cow :: Owned ( tag. to_string ( ) ) )
421
+ Ok ( PyString :: new ( py , tag) )
473
422
}
474
423
}
475
424
476
425
fn find_call_validator < ' s , ' data > (
477
426
& ' s self ,
478
427
py : Python < ' data > ,
479
- tag : & ChoiceKey ,
428
+ tag : & ' data PyAny ,
480
429
input : & ' data impl Input < ' data > ,
481
430
extra : & Extra ,
482
431
definitions : & ' data Definitions < CombinedValidator > ,
483
432
recursion_guard : & ' s mut RecursionGuard ,
484
433
) -> ValResult < ' data , PyObject > {
485
- if let Some ( validator ) = self . choices . get ( tag ) {
486
- return match validator . validate ( py , input , extra , definitions , recursion_guard ) {
487
- Ok ( res ) => Ok ( res ) ,
488
- Err ( err ) => Err ( err . with_outer_location ( tag . into ( ) ) ) ,
489
- } ;
490
- } else if let Some ( ref repeat_choices ) = self . repeat_choices {
491
- if let Some ( choice_tag ) = repeat_choices . get ( tag ) {
492
- let validator = & self . choices [ choice_tag ] ;
434
+ if let Ok ( tag ) = self
435
+ . discriminator_validator
436
+ . validate ( py , tag , extra , definitions , recursion_guard )
437
+ {
438
+ let tag = tag . as_ref ( py ) ;
439
+ if let Some ( validator_idx ) = self . choices . as_ref ( py ) . get_item ( tag ) {
440
+ // We know this will always be a usize because we put it there ourselves
441
+ let validator = & self . validators [ usize :: extract ( validator_idx ) . unwrap ( ) ] ;
493
442
return match validator. validate ( py, input, extra, definitions, recursion_guard) {
494
443
Ok ( res) => Ok ( res) ,
495
- Err ( err) => Err ( err. with_outer_location ( tag . into ( ) ) ) ,
444
+ Err ( err) => Err ( err. with_outer_location ( LocItem :: try_from ( tag ) ? ) ) ,
496
445
} ;
497
446
}
498
447
}
0 commit comments