Skip to content

Allow non-scalar values as tagged union keys #655

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 3 additions & 10 deletions generate_self_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def tagged_union(std_union_schema: Dict[str, Any], discriminator_key: str, ref:
first, *rest = literal
tagged_choices[first] = choice
for arg in rest:
tagged_choices[arg] = first
tagged_choices[arg] = choice
s = {'type': 'tagged-union', 'discriminator': discriminator_key, 'choices': tagged_choices}
if ref is not None:
s['ref'] = ref
Expand Down Expand Up @@ -129,15 +129,8 @@ def type_dict_schema(typed_dict) -> dict[str, Any]: # noqa: C901
schema = {'type': 'list', 'items_schema': schema_ref_validator}
elif fr_arg == 'Dict[str, CoreSchema]':
schema = {'type': 'dict', 'keys_schema': {'type': 'str'}, 'values_schema': schema_ref_validator}
elif fr_arg == 'Dict[Union[str, int], Union[str, int, CoreSchema]]':
schema = {
'type': 'dict',
'keys_schema': {'type': 'union', 'choices': [{'type': 'str'}, {'type': 'int'}]},
'values_schema': {
'type': 'union',
'choices': [{'type': 'str'}, {'type': 'int'}, schema_ref_validator],
},
}
elif fr_arg == 'Dict[Hashable, CoreSchema]':
schema = {'type': 'dict', 'keys_schema': {'type': 'any'}, 'values_schema': schema_ref_validator}
else:
raise ValueError(f'Unknown Schema forward ref: {fr_arg}')
else:
Expand Down
6 changes: 3 additions & 3 deletions pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys
from collections.abc import Mapping
from datetime import date, datetime, time, timedelta
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
from typing import Any, Callable, Dict, Hashable, List, Optional, Set, Type, Union

if sys.version_info < (3, 11):
from typing_extensions import Protocol, Required, TypeAlias
Expand Down Expand Up @@ -2361,7 +2361,7 @@ def union_schema(

class TaggedUnionSchema(TypedDict, total=False):
type: Required[Literal['tagged-union']]
choices: Required[Dict[Union[str, int], Union[str, int, CoreSchema]]]
choices: Required[Dict[Hashable, CoreSchema]]
discriminator: Required[
Union[str, List[Union[str, int]], List[List[Union[str, int]]], Callable[[Any], Optional[Union[str, int]]]]
]
Expand All @@ -2376,7 +2376,7 @@ class TaggedUnionSchema(TypedDict, total=False):


def tagged_union_schema(
choices: Dict[Union[int, str], int | str | CoreSchema],
choices: Dict[Hashable, CoreSchema],
discriminator: str | list[str | int] | list[list[str | int]] | Callable[[Any], str | int | None],
*,
custom_error_type: str | None = None,
Expand Down
12 changes: 12 additions & 0 deletions src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,18 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
self.strict_int()
}

/// Extract an `EitherInt` from the input, accepting only values whose type is
/// exactly an int (no subclasses).
///
/// This default implementation simply delegates to `strict_int`; an input type
/// that can tell an exact int apart from a subclass is expected to override it.
fn exact_int(&'a self) -> ValResult<EitherInt<'a>> {
self.strict_int()
}

/// Extract an `EitherString` from the input, accepting only values whose type
/// is exactly a string (no subclasses).
///
/// This default implementation simply delegates to `strict_str`; an input type
/// that can tell an exact string apart from a subclass is expected to override it.
fn exact_str(&'a self) -> ValResult<EitherString<'a>> {
self.strict_str()
}

fn validate_float(&self, strict: bool, ultra_strict: bool) -> ValResult<f64> {
if ultra_strict {
self.ultra_strict_float()
Expand Down
8 changes: 8 additions & 0 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,14 @@ impl<'a> Input<'a> for PyAny {
}
}

/// Accept `self` only when its Python type is exactly `int`; anything else
/// (including an instance of an `int` subclass) is rejected with an
/// `IntType` error.
fn exact_int(&'a self) -> ValResult<EitherInt<'a>> {
    // Guard clause: bail out early on anything that is not an exact int.
    if !PyInt::is_exact_type_of(self) {
        return Err(ValError::new(ErrorType::IntType, self));
    }
    Ok(EitherInt::Py(self))
}

fn lax_str(&'a self) -> ValResult<EitherString<'a>> {
if let Ok(py_str) = <PyString as PyTryFrom>::try_from_exact(self) {
Ok(py_str.into())
Expand Down
4 changes: 2 additions & 2 deletions src/lookup_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::errors::{ErrorType, ValLineError};
use crate::input::{Input, JsonInput, JsonObject};
use crate::tools::{extract_i64, py_err};

/// Used got getting items from python dicts, python objects, or JSON objects, in different ways
/// Used for getting items from python dicts, python objects, or JSON objects, in different ways
#[derive(Debug, Clone)]
pub(crate) enum LookupKey {
/// simply look up a key in a dict, equivalent to `d.get(key)`
Expand All @@ -29,7 +29,7 @@ pub(crate) enum LookupKey {
py_key2: Py<PyString>,
path2: LookupPath,
},
/// look up keys buy one or more "paths" a path might be `['foo', 'bar']` to get `d.?foo.?bar`
/// look up keys by one or more "paths"; a path might be `['foo', 'bar']` to get `d.?foo.?bar`
/// ints are also supported to index arrays/lists/tuples and dicts with int keys
/// we reuse Location as the enum is the same, and the meaning is the same
PathChoices(Vec<LookupPath>),
Expand Down
147 changes: 96 additions & 51 deletions src/validators/literal.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
// Validator for things inside of a typing.Literal[]
// which can be an int, a string, bytes or an Enum value (including `class Foo(str, Enum)` type enums)
use core::fmt::Debug;

use ahash::AHashSet;
use ahash::AHashMap;
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList};
Expand All @@ -15,15 +16,96 @@ use crate::tools::SchemaDict;
use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator};

#[derive(Debug, Clone)]
pub struct LiteralValidator {
pub struct LiteralLookup<T: Clone + Debug> {
// Specialized lookups for ints and strings because they
// (1) are easy to convert between Rust and Python
// (2) hashing them in Rust is very fast
// (3) are the most commonly used things in Literal[...]
expected_int: Option<AHashSet<i64>>,
expected_str: Option<AHashSet<String>>,
expected_int: Option<AHashMap<i64, usize>>,
expected_str: Option<AHashMap<String, usize>>,
// Catch all for Enum and bytes (the latter only because it is seldom used)
expected_py: Option<Py<PyDict>>,
pub values: Vec<T>,
}

impl<T: Clone + Debug> LiteralLookup<T> {
    /// Build a lookup table from `(key, value)` pairs.
    ///
    /// Each value is appended to `values`, and its key is indexed by the value's
    /// position in that vector:
    /// * keys accepted by `exact_int` are stored in the Rust int map,
    /// * keys accepted by `exact_str` are stored in the Rust string map,
    /// * any other key is stored in a Python dict.
    ///
    /// An index that ends up empty is stored as `None`, letting `validate` skip it.
    ///
    /// # Errors
    /// Returns a schema error when an int key cannot be converted to `i64` or a
    /// str key cannot be read, and propagates any error raised while inserting a
    /// key into the Python dict.
    pub fn new<'py>(py: Python<'py>, expected: impl Iterator<Item = (&'py PyAny, T)>) -> PyResult<Self> {
        let mut expected_int = AHashMap::new();
        let mut expected_str = AHashMap::new();
        let expected_py = PyDict::new(py);
        let mut values = Vec::new();
        for (k, v) in expected {
            // The position of `v` in `values` is what every index maps its key to.
            let id = values.len();
            values.push(v);
            if let Ok(either_int) = k.exact_int() {
                let int = either_int
                    .into_i64(py)
                    .map_err(|_| py_schema_error_type!("error extracting int {:?}", k))?;
                expected_int.insert(int, id);
            } else if let Ok(either_str) = k.exact_str() {
                let key_str = either_str
                    .as_cow()
                    .map_err(|_| py_schema_error_type!("error extracting str {:?}", k))?;
                expected_str.insert(key_str.to_string(), id);
            } else {
                expected_py.set_item(k, id)?;
            }
        }

        Ok(Self {
            expected_int: (!expected_int.is_empty()).then_some(expected_int),
            expected_str: (!expected_str.is_empty()).then_some(expected_str),
            expected_py: (!expected_py.is_empty()).then_some(expected_py.into()),
            values,
        })
    }

    /// Look `input` up in the table.
    ///
    /// The int map is tried first, then the string map, then the Python dict;
    /// an index that was empty at construction time is skipped entirely.
    ///
    /// Returns `Ok(Some((input, value)))` for the first index that contains the
    /// input and `Ok(None)` when no index does.
    ///
    /// # Errors
    /// Propagates the error from `into_i64` / `as_cow` when the input passes the
    /// exact-type check but its value cannot be extracted.
    pub fn validate<'data, I: Input<'data>>(
        &self,
        py: Python<'data>,
        input: &'data I,
    ) -> ValResult<'data, Option<(&'data I, &T)>> {
        if let Some(expected_ints) = &self.expected_int {
            if let Ok(either_int) = input.exact_int() {
                let int = either_int.into_i64(py)?;
                if let Some(id) = expected_ints.get(&int) {
                    return Ok(Some((input, &self.values[*id])));
                }
            }
        }
        if let Some(expected_strings) = &self.expected_str {
            if let Ok(either_str) = input.exact_str() {
                let cow = either_str.as_cow()?;
                if let Some(id) = expected_strings.get(cow.as_ref()) {
                    return Ok(Some((input, &self.values[*id])));
                }
            }
        }
        // Anything that is not an exact int or str (e.g. an enum member or bytes)
        // can only have been registered in the Python dict.
        if let Some(expected_py) = &self.expected_py {
            if let Some(v) = expected_py.as_ref(py).get_item(input) {
                let id: usize = v
                    .extract()
                    .expect("`expected_py` values are always `usize` indexes written by `new`");
                return Ok(Some((input, &self.values[id])));
            }
        }
        Ok(None)
    }
}

/// Validator for literal schemas; matching is delegated to a `LiteralLookup`.
#[derive(Debug, Clone)]
pub struct LiteralValidator {
// maps each permitted literal to the Python object stored for it
lookup: LiteralLookup<PyObject>,
// text placed in the `expected` field of the `LiteralError` raised on a mismatch
expected_repr: String,
// NOTE(review): presumably the validator's display name; it is produced together
// with `expected_repr` by `expected_repr_name` — confirm against the `Validator` impl
name: String,
}
Expand All @@ -41,32 +123,14 @@ impl BuildValidator for LiteralValidator {
return py_schema_err!("`expected` should have length > 0");
}
let py = expected.py();
// Literal[...] only supports int, str, bytes or enums, all of which can be hashed
let mut expected_int = AHashSet::new();
let mut expected_str = AHashSet::new();
let expected_py = PyDict::new(py);
let mut repr_args: Vec<String> = Vec::new();
for item in expected.iter() {
repr_args.push(item.repr()?.extract()?);
if let Ok(either_int) = item.strict_int() {
let int = either_int
.into_i64(py)
.map_err(|_| py_schema_error_type!("error extracting int {:?}", item))?;
expected_int.insert(int);
} else if let Ok(either_str) = item.strict_str() {
let str = either_str
.as_cow()
.map_err(|_| py_schema_error_type!("error extracting str {:?}", item))?;
expected_str.insert(str.to_string());
} else {
expected_py.set_item(item, item)?;
}
}
let (expected_repr, name) = expected_repr_name(repr_args, "literal");
let lookup = LiteralLookup::new(py, expected.iter().map(|v| (v, v.to_object(py))))?;
Ok(CombinedValidator::Literal(Self {
expected_int: (!expected_int.is_empty()).then_some(expected_int),
expected_str: (!expected_str.is_empty()).then_some(expected_str),
expected_py: (!expected_py.is_empty()).then_some(expected_py.into()),
lookup,
expected_repr,
name,
}))
Expand All @@ -82,34 +146,15 @@ impl Validator for LiteralValidator {
_definitions: &'data Definitions<CombinedValidator>,
_recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'data, PyObject> {
if let Some(expected_ints) = &self.expected_int {
if let Ok(either_int) = input.strict_int() {
let int = either_int.into_i64(py)?;
if expected_ints.contains(&int) {
return Ok(input.to_object(py));
}
}
match self.lookup.validate(py, input)? {
Some((_, v)) => Ok(v.clone()),
None => Err(ValError::new(
ErrorType::LiteralError {
expected: self.expected_repr.clone(),
},
input,
)),
}
if let Some(expected_strings) = &self.expected_str {
if let Ok(either_str) = input.strict_str() {
let cow = either_str.as_cow()?;
if expected_strings.contains(cow.as_ref()) {
return Ok(input.to_object(py));
}
}
}
// must be an enum or bytes
if let Some(expected_py) = &self.expected_py {
if let Some(v) = expected_py.as_ref(py).get_item(input) {
return Ok(v.into());
}
};
Err(ValError::new(
ErrorType::LiteralError {
expected: self.expected_repr.clone(),
},
input,
))
}

fn different_strict_behavior(
Expand Down
Loading