@@ -18,6 +18,7 @@ use crate::recursion_guard::RecursionGuard;
18
18
use crate :: tools:: { extract_i64, py_err, SchemaDict } ;
19
19
20
20
use super :: custom_error:: CustomError ;
21
+ use super :: literal:: LiteralValidator ;
21
22
use super :: { build_validator, BuildValidator , CombinedValidator , Definitions , DefinitionsBuilder , Extra , Validator } ;
22
23
23
24
#[ derive( Debug , Clone ) ]
@@ -265,8 +266,9 @@ impl From<&ChoiceKey> for LocItem {
265
266
266
267
#[ derive( Debug , Clone ) ]
267
268
pub struct TaggedUnionValidator {
268
- choices : AHashMap < ChoiceKey , CombinedValidator > ,
269
- repeat_choices : Option < AHashMap < ChoiceKey , ChoiceKey > > ,
269
+ discriminator_validator : Box < CombinedValidator > ,
270
+ choices : Py < PyDict > ,
271
+ validators : Vec < CombinedValidator > ,
270
272
discriminator : Discriminator ,
271
273
from_attributes : bool ,
272
274
strict : bool ,
@@ -289,22 +291,18 @@ impl BuildValidator for TaggedUnionValidator {
289
291
let discriminator_repr = discriminator. to_string_py ( py) ?;
290
292
291
293
let schema_choices: & PyDict = schema. get_as_req ( intern ! ( py, "choices" ) ) ?;
292
- let mut choices = AHashMap :: with_capacity ( schema_choices. len ( ) ) ;
293
- let mut repeat_choices_vec: Vec < ( ChoiceKey , ChoiceKey ) > = Vec :: new ( ) ;
294
- let mut first = true ;
294
+ let choices = PyDict :: new ( py) ;
295
+ let mut validators = Vec :: with_capacity ( choices. len ( ) ) ;
295
296
let mut tags_repr = String :: with_capacity ( 50 ) ;
296
297
let mut descr = String :: with_capacity ( 50 ) ;
297
-
298
- for ( key, value) in schema_choices {
299
- let tag = ChoiceKey :: from_py ( key) ?;
300
-
301
- if let Ok ( repeat_tag) = ChoiceKey :: from_py ( value) {
302
- repeat_choices_vec. push ( ( tag, repeat_tag) ) ;
303
- continue ;
304
- }
305
-
306
- let validator = build_validator ( value, config, definitions) ?;
307
- let tag_repr = tag. repr ( ) ;
298
+ let mut first = true ;
299
+ let mut discriminators = Vec :: with_capacity ( choices. len ( ) ) ;
300
+ for ( choice_key, choice_schema) in schema {
301
+ discriminators. push ( choice_key) ;
302
+ let validator = build_validator ( choice_schema, config, definitions) ?;
303
+ choices. set_item ( choice_key, validators. len ( ) ) ?;
304
+ validators. push ( validator) ;
305
+ let tag_repr = choice_key. repr ( ) ?. to_string ( ) ;
308
306
if first {
309
307
first = false ;
310
308
write ! ( tags_repr, "{tag_repr}" ) . unwrap ( ) ;
@@ -314,32 +312,12 @@ impl BuildValidator for TaggedUnionValidator {
314
312
// no spaces in get_name() output to make loc easy to read
315
313
write ! ( descr, ",{}" , validator. get_name( ) ) . unwrap ( ) ;
316
314
}
317
- choices. insert ( tag, validator) ;
318
315
}
319
- let repeat_choices = if repeat_choices_vec. is_empty ( ) {
320
- None
321
- } else {
322
- let mut wrong_values = Vec :: with_capacity ( repeat_choices_vec. len ( ) ) ;
323
- let mut repeat_choices = AHashMap :: with_capacity ( repeat_choices_vec. len ( ) ) ;
324
- for ( tag, repeat_tag) in repeat_choices_vec {
325
- match choices. get ( & repeat_tag) {
326
- Some ( validator) => {
327
- let tag_repr = tag. repr ( ) ;
328
- write ! ( tags_repr, ", {tag_repr}" ) . unwrap ( ) ;
329
- write ! ( descr, ",{}" , validator. get_name( ) ) . unwrap ( ) ;
330
- repeat_choices. insert ( tag, repeat_tag) ;
331
- }
332
- None => wrong_values. push ( format ! ( "`{repeat_tag}`" ) ) ,
333
- }
334
- }
335
- if !wrong_values. is_empty ( ) {
336
- return py_schema_err ! (
337
- "String values in choices don't match any keys: {}" ,
338
- wrong_values. join( ", " )
339
- ) ;
340
- }
341
- Some ( repeat_choices)
342
- } ;
316
+
317
+ let discriminator_validator_schema = PyDict :: new ( py) ;
318
+ discriminator_validator_schema. set_item ( intern ! ( py, "type" ) , intern ! ( py, "literal" ) ) ?;
319
+ discriminator_validator_schema. set_item ( intern ! ( py, "expected" ) , discriminators. into_py ( py) ) ?;
320
+ let discriminator_validator = build_validator ( discriminator_validator_schema, config, definitions) ?;
343
321
344
322
let key = intern ! ( py, "from_attributes" ) ;
345
323
let from_attributes = schema_or_config ( schema, config, key, key) ?. unwrap_or ( true ) ;
@@ -350,8 +328,9 @@ impl BuildValidator for TaggedUnionValidator {
350
328
} ;
351
329
352
330
Ok ( Self {
353
- choices,
354
- repeat_choices,
331
+ choices : choices. into ( ) ,
332
+ validators,
333
+ discriminator_validator : Box :: new ( discriminator_validator) ,
355
334
discriminator,
356
335
from_attributes,
357
336
strict : is_strict ( schema, config) ?,
@@ -434,8 +413,7 @@ impl Validator for TaggedUnionValidator {
434
413
definitions : Option < & DefinitionsBuilder < CombinedValidator > > ,
435
414
ultra_strict : bool ,
436
415
) -> bool {
437
- self . choices
438
- . values ( )
416
+ self . validators . iter ( )
439
417
. any ( |v| v. different_strict_behavior ( definitions, ultra_strict) )
440
418
}
441
419
@@ -444,8 +422,7 @@ impl Validator for TaggedUnionValidator {
444
422
}
445
423
446
424
fn complete ( & mut self , definitions : & DefinitionsBuilder < CombinedValidator > ) -> PyResult < ( ) > {
447
- self . choices
448
- . iter_mut ( )
425
+ self . validators . iter ( )
449
426
. try_for_each ( |( _, validator) | validator. complete ( definitions) )
450
427
}
451
428
}
0 commit comments