Skip to content

Commit 364dd9c

Browse files
committed
remove definitions vec
1 parent 1a966d5 commit 364dd9c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+440
-629
lines changed

src/definitions.rs

Lines changed: 129 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,17 @@
33
/// Unlike json schema we let you put definitions inline, not just in a single '#/$defs/' block or similar.
44
/// We use DefinitionsBuilder to collect the references / definitions into a single vector
55
/// and then get a definition from a reference using an integer id (just for performance of not using a HashMap)
6-
use std::collections::hash_map::Entry;
6+
use std::{
7+
collections::hash_map::Entry,
8+
fmt::Debug,
9+
sync::{Arc, OnceLock},
10+
};
711

8-
use pyo3::prelude::*;
12+
use pyo3::{prelude::*, PyTraverseError, PyVisit};
913

1014
use ahash::AHashMap;
1115

12-
use crate::build_tools::py_schema_err;
13-
14-
// An integer id for the reference
15-
pub type ReferenceId = usize;
16+
use crate::{build_tools::py_schema_err, py_gc::PyGcTraverse};
1617

1718
/// Definitions are validators and serializers that are
1819
/// shared by reference.
@@ -24,91 +25,154 @@ pub type ReferenceId = usize;
2425
/// They get indexed by a ReferenceId, which are integer identifiers
2526
/// that are handed out and managed by DefinitionsBuilder when the Schema{Validator,Serializer}
2627
/// gets built.
27-
pub type Definitions<T> = [T];
28+
#[derive(Clone)]
29+
pub struct Definitions<T>(AHashMap<Arc<String>, Definition<T>>);
2830

29-
#[derive(Clone, Debug)]
30-
struct Definition<T> {
31-
pub id: ReferenceId,
32-
pub value: Option<T>,
31+
impl<T> Definitions<T> {
32+
pub fn values(&self) -> impl Iterator<Item = &Definition<T>> {
33+
self.0.values()
34+
}
35+
}
36+
37+
/// Internal type which contains a definition to be filled
38+
pub struct Definition<T>(Arc<OnceLock<T>>);
39+
40+
impl<T> Definition<T> {
41+
pub fn get(&self) -> Option<&T> {
42+
self.0.get()
43+
}
44+
}
45+
46+
/// Reference to a definition.
47+
pub struct DefinitionRef<T> {
48+
name: Arc<String>,
49+
value: Definition<T>,
50+
}
51+
52+
// DefinitionRef can always be cloned (#[derive(Clone)] would require T: Clone)
53+
impl<T> Clone for DefinitionRef<T> {
54+
fn clone(&self) -> Self {
55+
Self {
56+
name: self.name.clone(),
57+
value: self.value.clone(),
58+
}
59+
}
60+
}
61+
62+
impl<T> DefinitionRef<T> {
63+
pub fn id(&self) -> usize {
64+
Arc::as_ptr(&self.value.0) as usize
65+
}
66+
67+
pub fn name(&self) -> &str {
68+
&self.name
69+
}
70+
71+
pub fn get(&self) -> Option<&T> {
72+
self.value.0.get()
73+
}
74+
}
75+
76+
impl<T: Debug> Debug for DefinitionRef<T> {
77+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78+
// To avoid possible infinite recursion from recursive definitions,
79+
// a DefinitionRef just displays debug as its name
80+
self.name.fmt(f)
81+
}
82+
}
83+
84+
impl<T: Debug> Debug for Definitions<T> {
85+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86+
self.0.fmt(f)
87+
}
88+
}
89+
90+
impl<T> Clone for Definition<T> {
91+
fn clone(&self) -> Self {
92+
Self(self.0.clone())
93+
}
94+
}
95+
96+
impl<T: Debug> Debug for Definition<T> {
97+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98+
match self.0.get() {
99+
Some(value) => value.fmt(f),
100+
None => "...".fmt(f),
101+
}
102+
}
103+
}
104+
105+
impl<T: PyGcTraverse> PyGcTraverse for DefinitionRef<T> {
106+
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
107+
if let Some(value) = self.value.0.get() {
108+
value.py_gc_traverse(visit)?;
109+
}
110+
Ok(())
111+
}
112+
}
113+
114+
impl<T: PyGcTraverse> PyGcTraverse for Definitions<T> {
115+
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
116+
for value in self.0.values() {
117+
if let Some(value) = value.0.get() {
118+
value.py_gc_traverse(visit)?;
119+
}
120+
}
121+
Ok(())
122+
}
33123
}
34124

35125
#[derive(Clone, Debug)]
36126
pub struct DefinitionsBuilder<T> {
37-
definitions: AHashMap<String, Definition<T>>,
127+
definitions: Definitions<T>,
38128
}
39129

40-
impl<T: Clone + std::fmt::Debug> DefinitionsBuilder<T> {
130+
impl<T: std::fmt::Debug> DefinitionsBuilder<T> {
41131
pub fn new() -> Self {
42132
Self {
43-
definitions: AHashMap::new(),
133+
definitions: Definitions(AHashMap::new()),
44134
}
45135
}
46136

47137
/// Get a `DefinitionRef` for the given reference string, creating an unfilled entry if none exists yet.
48-
// This ReferenceId can later be used to retrieve a definition
49-
pub fn get_reference_id(&mut self, reference: &str) -> ReferenceId {
50-
let next_id = self.definitions.len();
138+
pub fn get_definition(&mut self, reference: &str) -> DefinitionRef<T> {
51139
// We either need a String copy or two hashmap lookups
52140
// Neither is better than the other
53141
// We opted for the easier outward facing API
54-
match self.definitions.entry(reference.to_string()) {
55-
Entry::Occupied(entry) => entry.get().id,
56-
Entry::Vacant(entry) => {
57-
entry.insert(Definition {
58-
id: next_id,
59-
value: None,
60-
});
61-
next_id
62-
}
142+
let name = Arc::new(reference.to_string());
143+
let value = match self.definitions.0.entry(name.clone()) {
144+
Entry::Occupied(entry) => entry.into_mut(),
145+
Entry::Vacant(entry) => entry.insert(Definition(Arc::new(OnceLock::new()))),
146+
};
147+
DefinitionRef {
148+
name,
149+
value: value.clone(),
63150
}
64151
}
65152

66153
/// Add a definition, returning the `DefinitionRef` that refers to it (errors on a duplicate ref)
67-
pub fn add_definition(&mut self, reference: String, value: T) -> PyResult<ReferenceId> {
68-
let next_id = self.definitions.len();
69-
match self.definitions.entry(reference.clone()) {
70-
Entry::Occupied(mut entry) => match entry.get_mut().value.replace(value) {
71-
Some(_) => py_schema_err!("Duplicate ref: `{}`", reference),
72-
None => Ok(entry.get().id),
73-
},
74-
Entry::Vacant(entry) => {
75-
entry.insert(Definition {
76-
id: next_id,
77-
value: Some(value),
78-
});
79-
Ok(next_id)
154+
pub fn add_definition(&mut self, reference: String, value: T) -> PyResult<DefinitionRef<T>> {
155+
let name = Arc::new(reference);
156+
let value = match self.definitions.0.entry(name.clone()) {
157+
Entry::Occupied(entry) => {
158+
let definition = entry.into_mut();
159+
match definition.0.set(value) {
160+
Ok(()) => definition.clone(),
161+
Err(_) => return py_schema_err!("Duplicate ref: `{}`", name),
162+
}
80163
}
81-
}
82-
}
83-
84-
/// Retrieve an item definition using a ReferenceId
85-
/// If the definition doesn't yet exist (as happens in recursive types) then we create it
86-
/// At the end (in finish()) we check that there are no undefined definitions
87-
pub fn get_definition(&self, reference_id: ReferenceId) -> PyResult<&T> {
88-
let (reference, def) = match self.definitions.iter().find(|(_, def)| def.id == reference_id) {
89-
Some(v) => v,
90-
None => return py_schema_err!("Definitions error: no definition for ReferenceId `{}`", reference_id),
164+
Entry::Vacant(entry) => entry.insert(Definition(Arc::new(OnceLock::from(value)))).clone(),
91165
};
92-
match def.value.as_ref() {
93-
Some(v) => Ok(v),
94-
None => py_schema_err!(
95-
"Definitions error: attempted to use `{}` before it was filled",
96-
reference
97-
),
98-
}
166+
Ok(DefinitionRef { name, value })
99167
}
100168

101169
/// Consume this builder, returning the collected `Definitions` once every definition is confirmed filled
102-
pub fn finish(self) -> PyResult<Vec<T>> {
103-
// We need to create a vec of defs according to the order in their ids
104-
let mut defs: Vec<(usize, T)> = Vec::new();
105-
for (reference, def) in self.definitions {
106-
match def.value {
107-
None => return py_schema_err!("Definitions error: definition {} was never filled", reference),
108-
Some(v) => defs.push((def.id, v)),
170+
pub fn finish(self) -> PyResult<Definitions<T>> {
171+
for (reference, def) in &self.definitions.0 {
172+
if def.0.get().is_none() {
173+
return py_schema_err!("Definitions error: definition `{}` was never filled", reference);
109174
}
110175
}
111-
defs.sort_by_key(|(id, _)| *id);
112-
Ok(defs.into_iter().map(|(_, v)| v).collect())
176+
Ok(self.definitions)
113177
}
114178
}

src/serializers/extra.rs

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@ use serde::ser::Error;
1010
use super::config::SerializationConfig;
1111
use super::errors::{PydanticSerializationUnexpectedValue, UNEXPECTED_TYPE_SER_MARKER};
1212
use super::ob_type::ObTypeLookup;
13-
use super::shared::CombinedSerializer;
14-
use crate::definitions::Definitions;
1513
use crate::recursion_guard::RecursionGuard;
1614

1715
/// this is ugly, would be much better if extra could be stored in `SerializationState`
@@ -48,7 +46,6 @@ impl SerializationState {
4846
Extra::new(
4947
py,
5048
mode,
51-
&[],
5249
by_alias,
5350
&self.warnings,
5451
false,
@@ -72,7 +69,6 @@ impl SerializationState {
7269
#[cfg_attr(debug_assertions, derive(Debug))]
7370
pub(crate) struct Extra<'a> {
7471
pub mode: &'a SerMode,
75-
pub definitions: &'a Definitions<CombinedSerializer>,
7672
pub ob_type_lookup: &'a ObTypeLookup,
7773
pub warnings: &'a CollectWarnings,
7874
pub by_alias: bool,
@@ -98,7 +94,6 @@ impl<'a> Extra<'a> {
9894
pub fn new(
9995
py: Python<'a>,
10096
mode: &'a SerMode,
101-
definitions: &'a Definitions<CombinedSerializer>,
10297
by_alias: bool,
10398
warnings: &'a CollectWarnings,
10499
exclude_unset: bool,
@@ -112,7 +107,6 @@ impl<'a> Extra<'a> {
112107
) -> Self {
113108
Self {
114109
mode,
115-
definitions,
116110
ob_type_lookup: ObTypeLookup::cached(py),
117111
warnings,
118112
by_alias,
@@ -156,7 +150,6 @@ impl SerCheck {
156150
#[cfg_attr(debug_assertions, derive(Debug))]
157151
pub(crate) struct ExtraOwned {
158152
mode: SerMode,
159-
definitions: Vec<CombinedSerializer>,
160153
warnings: CollectWarnings,
161154
by_alias: bool,
162155
exclude_unset: bool,
@@ -176,7 +169,6 @@ impl ExtraOwned {
176169
pub fn new(extra: &Extra) -> Self {
177170
Self {
178171
mode: extra.mode.clone(),
179-
definitions: extra.definitions.to_vec(),
180172
warnings: extra.warnings.clone(),
181173
by_alias: extra.by_alias,
182174
exclude_unset: extra.exclude_unset,
@@ -196,7 +188,6 @@ impl ExtraOwned {
196188
pub fn to_extra<'py>(&'py self, py: Python<'py>) -> Extra<'py> {
197189
Extra {
198190
mode: &self.mode,
199-
definitions: &self.definitions,
200191
ob_type_lookup: ObTypeLookup::cached(py),
201192
warnings: &self.warnings,
202193
by_alias: self.by_alias,

src/serializers/mod.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use pyo3::prelude::*;
55
use pyo3::types::{PyBytes, PyDict};
66
use pyo3::{PyTraverseError, PyVisit};
77

8-
use crate::definitions::DefinitionsBuilder;
8+
use crate::definitions::{Definitions, DefinitionsBuilder};
99
use crate::py_gc::PyGcTraverse;
1010

1111
use config::SerializationConfig;
@@ -30,7 +30,7 @@ mod type_serializers;
3030
#[derive(Debug)]
3131
pub struct SchemaSerializer {
3232
serializer: CombinedSerializer,
33-
definitions: Vec<CombinedSerializer>,
33+
definitions: Definitions<CombinedSerializer>,
3434
expected_json_size: AtomicUsize,
3535
config: SerializationConfig,
3636
}
@@ -54,7 +54,6 @@ impl SchemaSerializer {
5454
Extra::new(
5555
py,
5656
mode,
57-
&self.definitions,
5857
by_alias,
5958
warnings,
6059
exclude_unset,
@@ -184,9 +183,7 @@ impl SchemaSerializer {
184183

185184
fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
186185
self.serializer.py_gc_traverse(&visit)?;
187-
for slot in &self.definitions {
188-
slot.py_gc_traverse(&visit)?;
189-
}
186+
self.definitions.py_gc_traverse(&visit)?;
190187
Ok(())
191188
}
192189
}

src/serializers/shared.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use serde_json::ser::PrettyFormatter;
1313

1414
use crate::build_tools::py_schema_err;
1515
use crate::build_tools::py_schema_error_type;
16-
use crate::definitions::{Definitions, DefinitionsBuilder};
16+
use crate::definitions::DefinitionsBuilder;
1717
use crate::py_gc::PyGcTraverse;
1818
use crate::tools::{py_err, SchemaDict};
1919

@@ -293,7 +293,7 @@ pub(crate) trait TypeSerializer: Send + Sync + Clone + Debug {
293293
fn get_name(&self) -> &str;
294294

295295
/// Used by union serializers to decide if it's worth trying again while allowing subclasses
296-
fn retry_with_lax_check(&self, _definitions: &Definitions<CombinedSerializer>) -> bool {
296+
fn retry_with_lax_check(&self) -> bool {
297297
false
298298
}
299299

src/serializers/type_serializers/dataclass.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use std::borrow::Cow;
66
use ahash::AHashMap;
77

88
use crate::build_tools::{py_schema_error_type, ExtraBehavior};
9-
use crate::definitions::{Definitions, DefinitionsBuilder};
9+
use crate::definitions::DefinitionsBuilder;
1010
use crate::tools::SchemaDict;
1111

1212
use super::{
@@ -179,7 +179,7 @@ impl TypeSerializer for DataclassSerializer {
179179
&self.name
180180
}
181181

182-
fn retry_with_lax_check(&self, _definitions: &Definitions<CombinedSerializer>) -> bool {
182+
fn retry_with_lax_check(&self) -> bool {
183183
true
184184
}
185185
}

0 commit comments

Comments
 (0)