Skip to content

Commit 3b000c9

Browse files
committed
Fix tagged union for non-scalar enum values
1 parent 7c3b30b commit 3b000c9

File tree

5 files changed

+77
-215
lines changed

5 files changed

+77
-215
lines changed

generate_self_schema.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def tagged_union(std_union_schema: Dict[str, Any], discriminator_key: str, ref:
9595
first, *rest = literal
9696
tagged_choices[first] = choice
9797
for arg in rest:
98-
tagged_choices[arg] = first
98+
tagged_choices[arg] = choice
9999
s = {'type': 'tagged-union', 'discriminator': discriminator_key, 'choices': tagged_choices}
100100
if ref is not None:
101101
s['ref'] = ref
@@ -129,15 +129,8 @@ def type_dict_schema(typed_dict) -> dict[str, Any]: # noqa: C901
129129
schema = {'type': 'list', 'items_schema': schema_ref_validator}
130130
elif fr_arg == 'Dict[str, CoreSchema]':
131131
schema = {'type': 'dict', 'keys_schema': {'type': 'str'}, 'values_schema': schema_ref_validator}
132-
elif fr_arg == 'Dict[Union[str, int], Union[str, int, CoreSchema]]':
133-
schema = {
134-
'type': 'dict',
135-
'keys_schema': {'type': 'union', 'choices': [{'type': 'str'}, {'type': 'int'}]},
136-
'values_schema': {
137-
'type': 'union',
138-
'choices': [{'type': 'str'}, {'type': 'int'}, schema_ref_validator],
139-
},
140-
}
132+
elif fr_arg == 'Dict[Hashable, CoreSchema]':
133+
schema = {'type': 'dict', 'keys_schema': {'type': 'any'}, 'values_schema': schema_ref_validator}
141134
else:
142135
raise ValueError(f'Unknown Schema forward ref: {fr_arg}')
143136
else:

pydantic_core/core_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2361,7 +2361,7 @@ def union_schema(
23612361

23622362
class TaggedUnionSchema(TypedDict, total=False):
23632363
type: Required[Literal['tagged-union']]
2364-
choices: Required[Dict[Union[str, int], Union[str, int, CoreSchema]]]
2364+
choices: Required[Dict[Hashable, CoreSchema]]
23652365
discriminator: Required[
23662366
Union[str, List[Union[str, int]], List[List[Union[str, int]]], Callable[[Any], Optional[Union[str, int]]]]
23672367
]

src/lookup_key.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use crate::errors::{ErrorType, ValLineError};
1010
use crate::input::{Input, JsonInput, JsonObject};
1111
use crate::tools::{extract_i64, py_err};
1212

13-
/// Used got getting items from python dicts, python objects, or JSON objects, in different ways
13+
/// Used for getting items from python dicts, python objects, or JSON objects, in different ways
1414
#[derive(Debug, Clone)]
1515
pub(crate) enum LookupKey {
1616
/// simply look up a key in a dict, equivalent to `d.get(key)`
@@ -29,7 +29,7 @@ pub(crate) enum LookupKey {
2929
py_key2: Py<PyString>,
3030
path2: LookupPath,
3131
},
32-
/// look up keys buy one or more "paths" a path might be `['foo', 'bar']` to get `d.?foo.?bar`
32+
/// look up keys by one or more "paths" a path might be `['foo', 'bar']` to get `d.?foo.?bar`
3333
/// ints are also supported to index arrays/lists/tuples and dicts with int keys
3434
/// we reuse Location as the enum is the same, and the meaning is the same
3535
PathChoices(Vec<LookupPath>),

src/validators/union.rs

Lines changed: 46 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,16 @@
1-
use std::borrow::Cow;
2-
use std::fmt;
31
use std::fmt::Write;
42

5-
use pyo3::exceptions::PyTypeError;
63
use pyo3::intern;
74
use pyo3::prelude::*;
85
use pyo3::types::{PyDict, PyList, PyString};
96

10-
use ahash::AHashMap;
11-
127
use crate::build_tools::py_schema_err;
138
use crate::build_tools::{is_strict, schema_or_config};
149
use crate::errors::{ErrorType, LocItem, ValError, ValLineError, ValResult};
1510
use crate::input::{GenericMapping, Input};
1611
use crate::lookup_key::LookupKey;
1712
use crate::recursion_guard::RecursionGuard;
18-
use crate::tools::{extract_i64, py_err, SchemaDict};
13+
use crate::tools::SchemaDict;
1914

2015
use super::custom_error::CustomError;
2116
use super::literal::LiteralValidator;
@@ -221,55 +216,12 @@ impl Discriminator {
221216
}
222217
}
223218

224-
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
225-
enum ChoiceKey {
226-
Int(i64),
227-
Str(String),
228-
}
229-
230-
impl ChoiceKey {
231-
fn from_py(raw: &PyAny) -> PyResult<Self> {
232-
if let Ok(py_int) = extract_i64(raw) {
233-
Ok(Self::Int(py_int))
234-
} else if let Ok(py_str) = raw.downcast::<PyString>() {
235-
Ok(Self::Str(py_str.to_str()?.to_string()))
236-
} else {
237-
py_err!(PyTypeError; "Expected int or str, got {}", raw.get_type().name().unwrap_or("<unknown python object>"))
238-
}
239-
}
240-
241-
fn repr(&self) -> String {
242-
match self {
243-
Self::Int(i) => i.to_string(),
244-
Self::Str(s) => format!("'{s}'"),
245-
}
246-
}
247-
}
248-
249-
impl fmt::Display for ChoiceKey {
250-
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
251-
match self {
252-
Self::Int(i) => write!(f, "{i}"),
253-
Self::Str(s) => write!(f, "{s}"),
254-
}
255-
}
256-
}
257-
258-
impl From<&ChoiceKey> for LocItem {
259-
fn from(key: &ChoiceKey) -> Self {
260-
match key {
261-
ChoiceKey::Str(s) => s.as_str().into(),
262-
ChoiceKey::Int(i) => (*i).into(),
263-
}
264-
}
265-
}
266-
267219
#[derive(Debug, Clone)]
268220
pub struct TaggedUnionValidator {
269-
discriminator_validator: Box<CombinedValidator>,
270221
choices: Py<PyDict>,
271222
validators: Vec<CombinedValidator>,
272223
discriminator: Discriminator,
224+
discriminator_validator: Box<CombinedValidator>,
273225
from_attributes: bool,
274226
strict: bool,
275227
custom_error: Option<CustomError>,
@@ -290,35 +242,43 @@ impl BuildValidator for TaggedUnionValidator {
290242
let discriminator = Discriminator::new(py, schema.get_as_req(intern!(py, "discriminator"))?)?;
291243
let discriminator_repr = discriminator.to_string_py(py)?;
292244

293-
let schema_choices: &PyDict = schema.get_as_req(intern!(py, "choices"))?;
294245
let choices = PyDict::new(py);
295246
let mut validators = Vec::with_capacity(choices.len());
296247
let mut tags_repr = String::with_capacity(50);
297248
let mut descr = String::with_capacity(50);
298249
let mut first = true;
299250
let mut discriminators = Vec::with_capacity(choices.len());
300-
for (choice_key, choice_schema) in schema {
251+
let schema_choices: &PyDict = schema.get_as_req(intern!(py, "choices"))?;
252+
let schema_choice_keys = schema_choices.keys();
253+
let discriminator_validator = LiteralValidator::build(
254+
PyDict::from_sequence(
255+
py,
256+
vec![
257+
(intern!(py, "type"), intern!(py, "literal").as_ref()),
258+
(intern!(py, "expected"), schema_choice_keys.as_ref()),
259+
]
260+
.into_py(py),
261+
)?,
262+
config,
263+
definitions,
264+
)?;
265+
for (choice_key, choice_schema) in schema_choices.iter() {
301266
discriminators.push(choice_key);
302267
let validator = build_validator(choice_schema, config, definitions)?;
303268
choices.set_item(choice_key, validators.len())?;
304-
validators.push(validator);
305269
let tag_repr = choice_key.repr()?.to_string();
306270
if first {
307271
first = false;
308272
write!(tags_repr, "{tag_repr}").unwrap();
309-
descr.push_str(validator.get_name());
273+
descr.push_str(validator.get_name().clone());
310274
} else {
311275
write!(tags_repr, ", {tag_repr}").unwrap();
312276
// no spaces in get_name() output to make loc easy to read
313277
write!(descr, ",{}", validator.get_name()).unwrap();
314278
}
279+
validators.push(validator);
315280
}
316281

317-
let discriminator_validator_schema = PyDict::new(py);
318-
discriminator_validator_schema.set_item(intern!(py, "type"), intern!(py, "literal"))?;
319-
discriminator_validator_schema.set_item(intern!(py, "expected"), discriminators.into_py(py))?;
320-
let discriminator_validator = build_validator(discriminator_validator_schema, config, definitions)?;
321-
322282
let key = intern!(py, "from_attributes");
323283
let from_attributes = schema_or_config(schema, config, key, key)?.unwrap_or(true);
324284

@@ -330,8 +290,8 @@ impl BuildValidator for TaggedUnionValidator {
330290
Ok(Self {
331291
choices: choices.into(),
332292
validators,
333-
discriminator_validator: Box::new(discriminator_validator),
334293
discriminator,
294+
discriminator_validator: Box::new(discriminator_validator),
335295
from_attributes,
336296
strict: is_strict(schema, config)?,
337297
custom_error: CustomError::build(schema, config, definitions)?,
@@ -360,12 +320,7 @@ impl Validator for TaggedUnionValidator {
360320
// errors when getting attributes which should be "raised"
361321
match lookup_key.$get_method($( $dict ),+)? {
362322
Some((_, value)) => {
363-
if let Ok(either_int) = value.validate_int(self.strict) {
364-
let int = either_int.into_i64(py)?;
365-
Ok(ChoiceKey::Int(int))
366-
} else {
367-
Ok(ChoiceKey::Str(value.validate_str(self.strict)?.as_cow()?.as_ref().to_string()))
368-
}
323+
Ok(value.to_object(py).into_ref(py))
369324
}
370325
None => Err(self.tag_not_found(input)),
371326
}
@@ -379,27 +334,19 @@ impl Validator for TaggedUnionValidator {
379334
GenericMapping::PyMapping(mapping) => find_validator!(py_get_mapping_item, mapping),
380335
GenericMapping::JsonObject(mapping) => find_validator!(json_get, mapping),
381336
}?;
382-
self.find_call_validator(py, &tag, input, extra, definitions, recursion_guard)
337+
self.find_call_validator(py, tag, input, extra, definitions, recursion_guard)
383338
}
384339
Discriminator::Function(ref func) => {
385340
let tag = func.call1(py, (input.to_object(py),))?;
386341
if tag.is_none(py) {
387342
Err(self.tag_not_found(input))
388343
} else {
389-
let tag: &PyAny = tag.downcast(py)?;
390-
self.find_call_validator(
391-
py,
392-
&(ChoiceKey::from_py(tag)?),
393-
input,
394-
extra,
395-
definitions,
396-
recursion_guard,
397-
)
344+
self.find_call_validator(py, tag.into_ref(py), input, extra, definitions, recursion_guard)
398345
}
399346
}
400347
Discriminator::SelfSchema => self.find_call_validator(
401348
py,
402-
&ChoiceKey::Str(self.self_schema_tag(py, input)?.into_owned()),
349+
&self.self_schema_tag(py, input)?.as_ref(),
403350
input,
404351
extra,
405352
definitions,
@@ -413,7 +360,8 @@ impl Validator for TaggedUnionValidator {
413360
definitions: Option<&DefinitionsBuilder<CombinedValidator>>,
414361
ultra_strict: bool,
415362
) -> bool {
416-
self.validators.iter()
363+
self.validators
364+
.iter()
417365
.any(|v| v.different_strict_behavior(definitions, ultra_strict))
418366
}
419367

@@ -422,8 +370,9 @@ impl Validator for TaggedUnionValidator {
422370
}
423371

424372
fn complete(&mut self, definitions: &DefinitionsBuilder<CombinedValidator>) -> PyResult<()> {
425-
self.validators.iter()
426-
.try_for_each(|(_, validator)| validator.complete(definitions))
373+
self.validators
374+
.iter_mut()
375+
.try_for_each(|validator| validator.complete(definitions))
427376
}
428377
}
429378

@@ -432,7 +381,7 @@ impl TaggedUnionValidator {
432381
&'s self,
433382
py: Python<'data>,
434383
input: &'data impl Input<'data>,
435-
) -> ValResult<'data, Cow<'data, str>> {
384+
) -> ValResult<'data, &'data PyString> {
436385
let dict = input.strict_dict()?;
437386
let either_tag = match dict {
438387
GenericMapping::PyDict(dict) => match dict.get_item(intern!(py, "type")) {
@@ -455,44 +404,44 @@ impl TaggedUnionValidator {
455404
if tag == "function" {
456405
let mode = mode.ok_or_else(|| self.tag_not_found(input))?;
457406
match mode.as_cow()?.as_ref() {
458-
"plain" => Ok(Cow::Borrowed("function-plain")),
459-
"wrap" => Ok(Cow::Borrowed("function-wrap")),
460-
_ => Ok(Cow::Borrowed("function")),
407+
"plain" => Ok(intern!(py, "function-plain")),
408+
"wrap" => Ok(intern!(py, "function-wrap")),
409+
_ => Ok(intern!(py, "function")),
461410
}
462411
} else {
463412
// tag == "tuple"
464413
if let Some(mode) = mode {
465414
if mode.as_cow()?.as_ref() == "positional" {
466-
return Ok(Cow::Borrowed("tuple-positional"));
415+
return Ok(intern!(py, "tuple-positional"));
467416
}
468417
}
469-
Ok(Cow::Borrowed("tuple-variable"))
418+
Ok(intern!(py, "tuple-variable"))
470419
}
471420
} else {
472-
Ok(Cow::Owned(tag.to_string()))
421+
Ok(PyString::new(py, tag))
473422
}
474423
}
475424

476425
fn find_call_validator<'s, 'data>(
477426
&'s self,
478427
py: Python<'data>,
479-
tag: &ChoiceKey,
428+
tag: &'data PyAny,
480429
input: &'data impl Input<'data>,
481430
extra: &Extra,
482431
definitions: &'data Definitions<CombinedValidator>,
483432
recursion_guard: &'s mut RecursionGuard,
484433
) -> ValResult<'data, PyObject> {
485-
if let Some(validator) = self.choices.get(tag) {
486-
return match validator.validate(py, input, extra, definitions, recursion_guard) {
487-
Ok(res) => Ok(res),
488-
Err(err) => Err(err.with_outer_location(tag.into())),
489-
};
490-
} else if let Some(ref repeat_choices) = self.repeat_choices {
491-
if let Some(choice_tag) = repeat_choices.get(tag) {
492-
let validator = &self.choices[choice_tag];
434+
if let Ok(tag) = self
435+
.discriminator_validator
436+
.validate(py, tag, extra, definitions, recursion_guard)
437+
{
438+
let tag = tag.as_ref(py);
439+
if let Some(validator_idx) = self.choices.as_ref(py).get_item(tag) {
440+
// We know this will always be a usize because we put it there ourselves
441+
let validator = &self.validators[usize::extract(validator_idx).unwrap()];
493442
return match validator.validate(py, input, extra, definitions, recursion_guard) {
494443
Ok(res) => Ok(res),
495-
Err(err) => Err(err.with_outer_location(tag.into())),
444+
Err(err) => Err(err.with_outer_location(LocItem::try_from(tag)?)),
496445
};
497446
}
498447
}

0 commit comments

Comments
 (0)