Skip to content

Commit 7c3b30b

Browse files
committed
Fix tagged union for non-scalar enum values
1 parent e7108e8 commit 7c3b30b

File tree

2 files changed

+26
-49
lines changed

2 files changed

+26
-49
lines changed

pydantic_core/core_schema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import sys
44
from collections.abc import Mapping
55
from datetime import date, datetime, time, timedelta
6-
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
6+
from typing import Any, Callable, Dict, Hashable, List, Optional, Set, Type, Union
77

88
if sys.version_info < (3, 11):
99
from typing_extensions import Protocol, Required, TypeAlias
@@ -2376,7 +2376,7 @@ class TaggedUnionSchema(TypedDict, total=False):
23762376

23772377

23782378
def tagged_union_schema(
2379-
choices: Dict[Union[int, str], int | str | CoreSchema],
2379+
choices: Dict[Hashable, CoreSchema],
23802380
discriminator: str | list[str | int] | list[list[str | int]] | Callable[[Any], str | int | None],
23812381
*,
23822382
custom_error_type: str | None = None,

src/validators/union.rs

Lines changed: 24 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use crate::recursion_guard::RecursionGuard;
1818
use crate::tools::{extract_i64, py_err, SchemaDict};
1919

2020
use super::custom_error::CustomError;
21+
use super::literal::LiteralValidator;
2122
use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator};
2223

2324
#[derive(Debug, Clone)]
@@ -265,8 +266,9 @@ impl From<&ChoiceKey> for LocItem {
265266

266267
#[derive(Debug, Clone)]
267268
pub struct TaggedUnionValidator {
268-
choices: AHashMap<ChoiceKey, CombinedValidator>,
269-
repeat_choices: Option<AHashMap<ChoiceKey, ChoiceKey>>,
269+
discriminator_validator: Box<CombinedValidator>,
270+
choices: Py<PyDict>,
271+
validators: Vec<CombinedValidator>,
270272
discriminator: Discriminator,
271273
from_attributes: bool,
272274
strict: bool,
@@ -289,22 +291,18 @@ impl BuildValidator for TaggedUnionValidator {
289291
let discriminator_repr = discriminator.to_string_py(py)?;
290292

291293
let schema_choices: &PyDict = schema.get_as_req(intern!(py, "choices"))?;
292-
let mut choices = AHashMap::with_capacity(schema_choices.len());
293-
let mut repeat_choices_vec: Vec<(ChoiceKey, ChoiceKey)> = Vec::new();
294-
let mut first = true;
294+
let choices = PyDict::new(py);
295+
let mut validators = Vec::with_capacity(choices.len());
295296
let mut tags_repr = String::with_capacity(50);
296297
let mut descr = String::with_capacity(50);
297-
298-
for (key, value) in schema_choices {
299-
let tag = ChoiceKey::from_py(key)?;
300-
301-
if let Ok(repeat_tag) = ChoiceKey::from_py(value) {
302-
repeat_choices_vec.push((tag, repeat_tag));
303-
continue;
304-
}
305-
306-
let validator = build_validator(value, config, definitions)?;
307-
let tag_repr = tag.repr();
298+
let mut first = true;
299+
let mut discriminators = Vec::with_capacity(choices.len());
300+
for (choice_key, choice_schema) in schema {
301+
discriminators.push(choice_key);
302+
let validator = build_validator(choice_schema, config, definitions)?;
303+
choices.set_item(choice_key, validators.len())?;
304+
validators.push(validator);
305+
let tag_repr = choice_key.repr()?.to_string();
308306
if first {
309307
first = false;
310308
write!(tags_repr, "{tag_repr}").unwrap();
@@ -314,32 +312,12 @@ impl BuildValidator for TaggedUnionValidator {
314312
// no spaces in get_name() output to make loc easy to read
315313
write!(descr, ",{}", validator.get_name()).unwrap();
316314
}
317-
choices.insert(tag, validator);
318315
}
319-
let repeat_choices = if repeat_choices_vec.is_empty() {
320-
None
321-
} else {
322-
let mut wrong_values = Vec::with_capacity(repeat_choices_vec.len());
323-
let mut repeat_choices = AHashMap::with_capacity(repeat_choices_vec.len());
324-
for (tag, repeat_tag) in repeat_choices_vec {
325-
match choices.get(&repeat_tag) {
326-
Some(validator) => {
327-
let tag_repr = tag.repr();
328-
write!(tags_repr, ", {tag_repr}").unwrap();
329-
write!(descr, ",{}", validator.get_name()).unwrap();
330-
repeat_choices.insert(tag, repeat_tag);
331-
}
332-
None => wrong_values.push(format!("`{repeat_tag}`")),
333-
}
334-
}
335-
if !wrong_values.is_empty() {
336-
return py_schema_err!(
337-
"String values in choices don't match any keys: {}",
338-
wrong_values.join(", ")
339-
);
340-
}
341-
Some(repeat_choices)
342-
};
316+
317+
let discriminator_validator_schema = PyDict::new(py);
318+
discriminator_validator_schema.set_item(intern!(py, "type"), intern!(py, "literal"))?;
319+
discriminator_validator_schema.set_item(intern!(py, "expected"), discriminators.into_py(py))?;
320+
let discriminator_validator = build_validator(discriminator_validator_schema, config, definitions)?;
343321

344322
let key = intern!(py, "from_attributes");
345323
let from_attributes = schema_or_config(schema, config, key, key)?.unwrap_or(true);
@@ -350,8 +328,9 @@ impl BuildValidator for TaggedUnionValidator {
350328
};
351329

352330
Ok(Self {
353-
choices,
354-
repeat_choices,
331+
choices: choices.into(),
332+
validators,
333+
discriminator_validator: Box::new(discriminator_validator),
355334
discriminator,
356335
from_attributes,
357336
strict: is_strict(schema, config)?,
@@ -434,8 +413,7 @@ impl Validator for TaggedUnionValidator {
434413
definitions: Option<&DefinitionsBuilder<CombinedValidator>>,
435414
ultra_strict: bool,
436415
) -> bool {
437-
self.choices
438-
.values()
416+
self.validators.iter()
439417
.any(|v| v.different_strict_behavior(definitions, ultra_strict))
440418
}
441419

@@ -444,8 +422,7 @@ impl Validator for TaggedUnionValidator {
444422
}
445423

446424
fn complete(&mut self, definitions: &DefinitionsBuilder<CombinedValidator>) -> PyResult<()> {
447-
self.choices
448-
.iter_mut()
425+
self.validators.iter()
449426
.try_for_each(|(_, validator)| validator.complete(definitions))
450427
}
451428
}

0 commit comments

Comments
 (0)