Skip to content

Commit 9075107

Browse files
committed
Allow non-scalar values as tagged union keys
1 parent e7108e8 commit 9075107

File tree

8 files changed

+185
-303
lines changed

8 files changed

+185
-303
lines changed

generate_self_schema.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def tagged_union(std_union_schema: Dict[str, Any], discriminator_key: str, ref:
9595
first, *rest = literal
9696
tagged_choices[first] = choice
9797
for arg in rest:
98-
tagged_choices[arg] = first
98+
tagged_choices[arg] = choice
9999
s = {'type': 'tagged-union', 'discriminator': discriminator_key, 'choices': tagged_choices}
100100
if ref is not None:
101101
s['ref'] = ref
@@ -129,15 +129,8 @@ def type_dict_schema(typed_dict) -> dict[str, Any]: # noqa: C901
129129
schema = {'type': 'list', 'items_schema': schema_ref_validator}
130130
elif fr_arg == 'Dict[str, CoreSchema]':
131131
schema = {'type': 'dict', 'keys_schema': {'type': 'str'}, 'values_schema': schema_ref_validator}
132-
elif fr_arg == 'Dict[Union[str, int], Union[str, int, CoreSchema]]':
133-
schema = {
134-
'type': 'dict',
135-
'keys_schema': {'type': 'union', 'choices': [{'type': 'str'}, {'type': 'int'}]},
136-
'values_schema': {
137-
'type': 'union',
138-
'choices': [{'type': 'str'}, {'type': 'int'}, schema_ref_validator],
139-
},
140-
}
132+
elif fr_arg == 'Dict[Hashable, CoreSchema]':
133+
schema = {'type': 'dict', 'keys_schema': {'type': 'any'}, 'values_schema': schema_ref_validator}
141134
else:
142135
raise ValueError(f'Unknown Schema forward ref: {fr_arg}')
143136
else:

pydantic_core/core_schema.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import sys
44
from collections.abc import Mapping
55
from datetime import date, datetime, time, timedelta
6-
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
6+
from typing import Any, Callable, Dict, Hashable, List, Optional, Set, Type, Union
77

88
if sys.version_info < (3, 11):
99
from typing_extensions import Protocol, Required, TypeAlias
@@ -2361,7 +2361,7 @@ def union_schema(
23612361

23622362
class TaggedUnionSchema(TypedDict, total=False):
23632363
type: Required[Literal['tagged-union']]
2364-
choices: Required[Dict[Union[str, int], Union[str, int, CoreSchema]]]
2364+
choices: Required[Dict[Hashable, CoreSchema]]
23652365
discriminator: Required[
23662366
Union[str, List[Union[str, int]], List[List[Union[str, int]]], Callable[[Any], Optional[Union[str, int]]]]
23672367
]
@@ -2376,7 +2376,7 @@ class TaggedUnionSchema(TypedDict, total=False):
23762376

23772377

23782378
def tagged_union_schema(
2379-
choices: Dict[Union[int, str], int | str | CoreSchema],
2379+
choices: Dict[Hashable, CoreSchema],
23802380
discriminator: str | list[str | int] | list[list[str | int]] | Callable[[Any], str | int | None],
23812381
*,
23822382
custom_error_type: str | None = None,

src/input/input_abstract.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,18 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
124124
self.strict_int()
125125
}
126126

127+
/// Extract an EitherInt from the input, only allowing exact
128+
/// matches for an Int (no subclasses)
129+
fn exact_int(&'a self) -> ValResult<EitherInt<'a>> {
130+
self.strict_int()
131+
}
132+
133+
/// Extract a String from the input, only allowing exact
134+
/// matches for a String (no subclasses)
135+
fn exact_str(&'a self) -> ValResult<EitherString<'a>> {
136+
self.strict_str()
137+
}
138+
127139
fn validate_float(&self, strict: bool, ultra_strict: bool) -> ValResult<f64> {
128140
if ultra_strict {
129141
self.ultra_strict_float()

src/input/input_python.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,14 @@ impl<'a> Input<'a> for PyAny {
195195
}
196196
}
197197

198+
fn exact_int(&'a self) -> ValResult<EitherInt<'a>> {
199+
if PyInt::is_exact_type_of(self) {
200+
Ok(EitherInt::Py(self))
201+
} else {
202+
Err(ValError::new(ErrorType::IntType, self))
203+
}
204+
}
205+
198206
fn lax_str(&'a self) -> ValResult<EitherString<'a>> {
199207
if let Ok(py_str) = <PyString as PyTryFrom>::try_from_exact(self) {
200208
Ok(py_str.into())

src/lookup_key.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use crate::errors::{ErrorType, ValLineError};
1010
use crate::input::{Input, JsonInput, JsonObject};
1111
use crate::tools::{extract_i64, py_err};
1212

13-
/// Used got getting items from python dicts, python objects, or JSON objects, in different ways
13+
/// Used for getting items from python dicts, python objects, or JSON objects, in different ways
1414
#[derive(Debug, Clone)]
1515
pub(crate) enum LookupKey {
1616
/// simply look up a key in a dict, equivalent to `d.get(key)`
@@ -29,7 +29,7 @@ pub(crate) enum LookupKey {
2929
py_key2: Py<PyString>,
3030
path2: LookupPath,
3131
},
32-
/// look up keys buy one or more "paths" a path might be `['foo', 'bar']` to get `d.?foo.?bar`
32+
/// look up keys by one or more "paths"; a path might be `['foo', 'bar']` to get `d.?foo.?bar`
3333
/// ints are also supported to index arrays/lists/tuples and dicts with int keys
3434
/// we reuse Location as the enum is the same, and the meaning is the same
3535
PathChoices(Vec<LookupPath>),

src/validators/literal.rs

Lines changed: 96 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
// Validator for things inside of a typing.Literal[]
22
// which can be an int, a string, bytes or an Enum value (including `class Foo(str, Enum)` type enums)
3+
use core::fmt::Debug;
34

4-
use ahash::AHashSet;
5+
use ahash::AHashMap;
56
use pyo3::intern;
67
use pyo3::prelude::*;
78
use pyo3::types::{PyDict, PyList};
@@ -15,15 +16,96 @@ use crate::tools::SchemaDict;
1516
use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator};
1617

1718
#[derive(Debug, Clone)]
18-
pub struct LiteralValidator {
19+
pub struct LiteralLookup<T: Clone + Debug> {
1920
// Specialized lookups for ints and strings because they
2021
// (1) are easy to convert between Rust and Python
2122
// (2) hashing them in Rust is very fast
2223
// (3) are the most commonly used things in Literal[...]
23-
expected_int: Option<AHashSet<i64>>,
24-
expected_str: Option<AHashSet<String>>,
24+
expected_int: Option<AHashMap<i64, usize>>,
25+
expected_str: Option<AHashMap<String, usize>>,
2526
// Catch all for Enum and bytes (the latter only because it is seldom used)
2627
expected_py: Option<Py<PyDict>>,
28+
pub values: Vec<T>,
29+
}
30+
31+
impl<T: Clone + Debug> LiteralLookup<T> {
32+
pub fn new<'py>(py: Python<'py>, expected: impl Iterator<Item = (&'py PyAny, T)>) -> PyResult<Self> {
33+
let mut expected_int = AHashMap::new();
34+
let mut expected_str = AHashMap::new();
35+
let expected_py = PyDict::new(py);
36+
let mut values = Vec::new();
37+
for (k, v) in expected {
38+
let id = values.len();
39+
values.push(v);
40+
if let Ok(either_int) = k.exact_int() {
41+
let int = either_int
42+
.into_i64(py)
43+
.map_err(|_| py_schema_error_type!("error extracting int {:?}", k))?;
44+
expected_int.insert(int, id);
45+
} else if let Ok(either_str) = k.exact_str() {
46+
let str = either_str
47+
.as_cow()
48+
.map_err(|_| py_schema_error_type!("error extracting str {:?}", k))?;
49+
expected_str.insert(str.to_string(), id);
50+
} else {
51+
expected_py.set_item(k, id)?;
52+
}
53+
}
54+
55+
Ok(Self {
56+
expected_int: match expected_int.is_empty() {
57+
true => None,
58+
false => Some(expected_int),
59+
},
60+
expected_str: match expected_str.is_empty() {
61+
true => None,
62+
false => Some(expected_str),
63+
},
64+
expected_py: match expected_py.is_empty() {
65+
true => None,
66+
false => Some(expected_py.into()),
67+
},
68+
values,
69+
})
70+
}
71+
72+
pub fn validate<'data, I: Input<'data>>(
73+
&self,
74+
py: Python<'data>,
75+
input: &'data I,
76+
) -> ValResult<'data, Option<(&'data I, &T)>> {
77+
// dbg!(input.to_object(py).as_ref(py).repr().unwrap());
78+
if let Some(expected_ints) = &self.expected_int {
79+
if let Ok(either_int) = input.exact_int() {
80+
let int = either_int.into_i64(py)?;
81+
if let Some(id) = expected_ints.get(&int) {
82+
return Ok(Some((input, &self.values[*id])));
83+
}
84+
}
85+
}
86+
if let Some(expected_strings) = &self.expected_str {
87+
// dbg!(expected_strings);
88+
if let Ok(either_str) = input.exact_str() {
89+
let cow = either_str.as_cow()?;
90+
if let Some(id) = expected_strings.get(cow.as_ref()) {
91+
return Ok(Some((input, &self.values[*id])));
92+
}
93+
}
94+
}
95+
// fall back to the Python dict, which holds every key that is not an exact int or str (e.g. enums, bytes)
96+
if let Some(expected_py) = &self.expected_py {
97+
if let Some(v) = expected_py.as_ref(py).get_item(input) {
98+
let id: usize = v.extract().unwrap();
99+
return Ok(Some((input, &self.values[id])));
100+
}
101+
};
102+
Ok(None)
103+
}
104+
}
105+
106+
#[derive(Debug, Clone)]
107+
pub struct LiteralValidator {
108+
lookup: LiteralLookup<PyObject>,
27109
expected_repr: String,
28110
name: String,
29111
}
@@ -41,32 +123,14 @@ impl BuildValidator for LiteralValidator {
41123
return py_schema_err!("`expected` should have length > 0");
42124
}
43125
let py = expected.py();
44-
// Literal[...] only supports int, str, bytes or enums, all of which can be hashed
45-
let mut expected_int = AHashSet::new();
46-
let mut expected_str = AHashSet::new();
47-
let expected_py = PyDict::new(py);
48126
let mut repr_args: Vec<String> = Vec::new();
49127
for item in expected.iter() {
50128
repr_args.push(item.repr()?.extract()?);
51-
if let Ok(either_int) = item.strict_int() {
52-
let int = either_int
53-
.into_i64(py)
54-
.map_err(|_| py_schema_error_type!("error extracting int {:?}", item))?;
55-
expected_int.insert(int);
56-
} else if let Ok(either_str) = item.strict_str() {
57-
let str = either_str
58-
.as_cow()
59-
.map_err(|_| py_schema_error_type!("error extracting str {:?}", item))?;
60-
expected_str.insert(str.to_string());
61-
} else {
62-
expected_py.set_item(item, item)?;
63-
}
64129
}
65130
let (expected_repr, name) = expected_repr_name(repr_args, "literal");
131+
let lookup = LiteralLookup::new(py, expected.iter().map(|v| (v, v.to_object(py))))?;
66132
Ok(CombinedValidator::Literal(Self {
67-
expected_int: (!expected_int.is_empty()).then_some(expected_int),
68-
expected_str: (!expected_str.is_empty()).then_some(expected_str),
69-
expected_py: (!expected_py.is_empty()).then_some(expected_py.into()),
133+
lookup,
70134
expected_repr,
71135
name,
72136
}))
@@ -82,34 +146,15 @@ impl Validator for LiteralValidator {
82146
_definitions: &'data Definitions<CombinedValidator>,
83147
_recursion_guard: &'s mut RecursionGuard,
84148
) -> ValResult<'data, PyObject> {
85-
if let Some(expected_ints) = &self.expected_int {
86-
if let Ok(either_int) = input.strict_int() {
87-
let int = either_int.into_i64(py)?;
88-
if expected_ints.contains(&int) {
89-
return Ok(input.to_object(py));
90-
}
91-
}
149+
match self.lookup.validate(py, input)? {
150+
Some((_, v)) => Ok(v.clone()),
151+
None => Err(ValError::new(
152+
ErrorType::LiteralError {
153+
expected: self.expected_repr.clone(),
154+
},
155+
input,
156+
)),
92157
}
93-
if let Some(expected_strings) = &self.expected_str {
94-
if let Ok(either_str) = input.strict_str() {
95-
let cow = either_str.as_cow()?;
96-
if expected_strings.contains(cow.as_ref()) {
97-
return Ok(input.to_object(py));
98-
}
99-
}
100-
}
101-
// must be an enum or bytes
102-
if let Some(expected_py) = &self.expected_py {
103-
if let Some(v) = expected_py.as_ref(py).get_item(input) {
104-
return Ok(v.into());
105-
}
106-
};
107-
Err(ValError::new(
108-
ErrorType::LiteralError {
109-
expected: self.expected_repr.clone(),
110-
},
111-
input,
112-
))
113158
}
114159

115160
fn different_strict_behavior(

0 commit comments

Comments
 (0)