Skip to content

Commit fcc77f8

Browse files
Improvements in union matching logic during validation (#1332)
1 parent 46379ac commit fcc77f8

File tree

6 files changed

+551
-26
lines changed

6 files changed

+551
-26
lines changed

src/validators/dataclass.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ impl Validator for DataclassArgsValidator {
154154
let mut used_keys: AHashSet<&str> = AHashSet::with_capacity(self.fields.len());
155155

156156
let state = &mut state.rebind_extra(|extra| extra.data = Some(output_dict.clone()));
157+
let mut fields_set_count: usize = 0;
157158

158159
macro_rules! set_item {
159160
($field:ident, $value:expr) => {{
@@ -175,6 +176,7 @@ impl Validator for DataclassArgsValidator {
175176
Ok(Some(value)) => {
176177
// Default value exists, and passed validation if required
177178
set_item!(field, value);
179+
fields_set_count += 1;
178180
}
179181
Ok(None) | Err(ValError::Omit) => continue,
180182
// Note: this will always use the field name even if there is an alias
@@ -214,15 +216,21 @@ impl Validator for DataclassArgsValidator {
214216
}
215217
// found a positional argument, validate it
216218
(Some(pos_value), None) => match field.validator.validate(py, pos_value.borrow_input(), state) {
217-
Ok(value) => set_item!(field, value),
219+
Ok(value) => {
220+
set_item!(field, value);
221+
fields_set_count += 1;
222+
}
218223
Err(ValError::LineErrors(line_errors)) => {
219224
errors.extend(line_errors.into_iter().map(|err| err.with_outer_location(index)));
220225
}
221226
Err(err) => return Err(err),
222227
},
223228
// found a keyword argument, validate it
224229
(None, Some((lookup_path, kw_value))) => match field.validator.validate(py, kw_value, state) {
225-
Ok(value) => set_item!(field, value),
230+
Ok(value) => {
231+
set_item!(field, value);
232+
fields_set_count += 1;
233+
}
226234
Err(ValError::LineErrors(line_errors)) => {
227235
errors.extend(
228236
line_errors
@@ -336,6 +344,8 @@ impl Validator for DataclassArgsValidator {
336344
}
337345
}
338346

347+
state.add_fields_set(fields_set_count);
348+
339349
if errors.is_empty() {
340350
if let Some(init_only_args) = init_only_args {
341351
Ok((output_dict, PyTuple::new_bound(py, init_only_args)).to_object(py))

src/validators/model.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ impl Validator for ModelValidator {
204204
for field_name in validated_fields_set {
205205
fields_set.add(field_name)?;
206206
}
207+
state.add_fields_set(fields_set.len());
207208
}
208209

209210
force_setattr(py, model, intern!(py, DUNDER_DICT), validated_dict.to_object(py))?;
@@ -241,10 +242,13 @@ impl ModelValidator {
241242
} else {
242243
PySet::new_bound(py, [&String::from(ROOT_FIELD)])?
243244
};
244-
force_setattr(py, self_instance, intern!(py, DUNDER_FIELDS_SET_KEY), fields_set)?;
245+
force_setattr(py, self_instance, intern!(py, DUNDER_FIELDS_SET_KEY), &fields_set)?;
245246
force_setattr(py, self_instance, intern!(py, ROOT_FIELD), &output)?;
247+
state.add_fields_set(fields_set.len());
246248
} else {
247-
let (model_dict, model_extra, fields_set) = output.extract(py)?;
249+
let (model_dict, model_extra, fields_set): (Bound<PyAny>, Bound<PyAny>, Bound<PyAny>) =
250+
output.extract(py)?;
251+
state.add_fields_set(fields_set.len().unwrap_or(0));
248252
set_model_attrs(self_instance, &model_dict, &model_extra, &fields_set)?;
249253
}
250254
self.call_post_init(py, self_instance.clone(), input, state.extra())
@@ -281,11 +285,13 @@ impl ModelValidator {
281285
} else {
282286
PySet::new_bound(py, [&String::from(ROOT_FIELD)])?
283287
};
284-
force_setattr(py, &instance, intern!(py, DUNDER_FIELDS_SET_KEY), fields_set)?;
288+
force_setattr(py, &instance, intern!(py, DUNDER_FIELDS_SET_KEY), &fields_set)?;
285289
force_setattr(py, &instance, intern!(py, ROOT_FIELD), output)?;
290+
state.add_fields_set(fields_set.len());
286291
} else {
287292
let (model_dict, model_extra, val_fields_set) = output.extract(py)?;
288293
let fields_set = existing_fields_set.unwrap_or(&val_fields_set);
294+
state.add_fields_set(fields_set.len().unwrap_or(0));
289295
set_model_attrs(&instance, &model_dict, &model_extra, fields_set)?;
290296
}
291297
self.call_post_init(py, instance, input, state.extra())

src/validators/typed_dict.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ impl Validator for TypedDictValidator {
165165

166166
{
167167
let state = &mut state.rebind_extra(|extra| extra.data = Some(output_dict.clone()));
168+
let mut fields_set_count: usize = 0;
168169

169170
for field in &self.fields {
170171
let op_key_value = match dict.get_item(&field.lookup_key) {
@@ -186,6 +187,7 @@ impl Validator for TypedDictValidator {
186187
match field.validator.validate(py, value.borrow_input(), state) {
187188
Ok(value) => {
188189
output_dict.set_item(&field.name_py, value)?;
190+
fields_set_count += 1;
189191
}
190192
Err(ValError::Omit) => continue,
191193
Err(ValError::LineErrors(line_errors)) => {
@@ -227,6 +229,8 @@ impl Validator for TypedDictValidator {
227229
Err(err) => return Err(err),
228230
}
229231
}
232+
233+
state.add_fields_set(fields_set_count);
230234
}
231235

232236
if let Some(used_keys) = used_keys {

src/validators/union.rs

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,12 @@ impl UnionValidator {
108108
state: &mut ValidationState<'_, 'py>,
109109
) -> ValResult<PyObject> {
110110
let old_exactness = state.exactness;
111+
let old_fields_set_count = state.fields_set_count;
112+
111113
let strict = state.strict_or(self.strict);
112114
let mut errors = MaybeErrors::new(self.custom_error.as_ref());
113115

114-
let mut success = None;
116+
let mut best_match: Option<(Py<PyAny>, Exactness, Option<usize>)> = None;
115117

116118
for (choice, label) in &self.choices {
117119
let state = &mut state.rebind_extra(|extra| {
@@ -120,47 +122,67 @@ impl UnionValidator {
120122
}
121123
});
122124
state.exactness = Some(Exactness::Exact);
125+
state.fields_set_count = None;
123126
let result = choice.validate(py, input, state);
124127
match result {
125-
Ok(new_success) => match state.exactness {
126-
// exact match, return
127-
Some(Exactness::Exact) => {
128+
Ok(new_success) => match (state.exactness, state.fields_set_count) {
129+
(Some(Exactness::Exact), None) => {
130+
// exact match with no fields set data, return immediately
128131
return {
129132
// exact match, return, restore any previous exactness
130133
state.exactness = old_exactness;
134+
state.fields_set_count = old_fields_set_count;
131135
Ok(new_success)
132136
};
133137
}
134138
_ => {
135139
// success should always have an exactness
136140
debug_assert_ne!(state.exactness, None);
141+
137142
let new_exactness = state.exactness.unwrap_or(Exactness::Lax);
138-
// if the new result has higher exactness than the current success, replace it
139-
if success
140-
.as_ref()
141-
.map_or(true, |(_, current_exactness)| *current_exactness < new_exactness)
142-
{
143-
// TODO: is there a possible optimization here, where once there has
144-
// been one success, we turn on strict mode, to avoid unnecessary
145-
// coercions for further validation?
146-
success = Some((new_success, new_exactness));
143+
let new_fields_set_count = state.fields_set_count;
144+
145+
// we use both the exactness and the fields_set_count to determine the best union member match
146+
// if fields_set_count is available for the current best match and the new candidate, we use this
147+
// as the primary metric. If the new fields_set_count is greater, the new candidate is better.
148+
// if the fields_set_count is the same, we use the exactness as a tie breaker to determine the best match.
149+
// if the fields_set_count is not available for either the current best match or the new candidate,
150+
// we use the exactness to determine the best match.
151+
let new_success_is_best_match: bool =
152+
best_match
153+
.as_ref()
154+
.map_or(true, |(_, cur_exactness, cur_fields_set_count)| {
155+
match (*cur_fields_set_count, new_fields_set_count) {
156+
(Some(cur), Some(new)) if cur != new => cur < new,
157+
_ => *cur_exactness < new_exactness,
158+
}
159+
});
160+
161+
if new_success_is_best_match {
162+
best_match = Some((new_success, new_exactness, new_fields_set_count));
147163
}
148164
}
149165
},
150166
Err(ValError::LineErrors(lines)) => {
151167
// if we don't yet know this validation will succeed, record the error
152-
if success.is_none() {
168+
if best_match.is_none() {
153169
errors.push(choice, label.as_deref(), lines);
154170
}
155171
}
156172
otherwise => return otherwise,
157173
}
158174
}
175+
176+
// restore previous validation state to prepare for any future validations
159177
state.exactness = old_exactness;
178+
state.fields_set_count = old_fields_set_count;
160179

161-
if let Some((success, exactness)) = success {
180+
if let Some((best_match, exactness, fields_set_count)) = best_match {
162181
state.floor_exactness(exactness);
163-
return Ok(success);
182+
if let Some(count) = fields_set_count {
183+
state.add_fields_set(count);
184+
}
185+
return Ok(best_match);
164186
}
165187

166188
// no matches, build errors

src/validators/validation_state.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ pub enum Exactness {
1818
pub struct ValidationState<'a, 'py> {
1919
pub recursion_guard: &'a mut RecursionState,
2020
pub exactness: Option<Exactness>,
21+
pub fields_set_count: Option<usize>,
2122
// deliberately make Extra readonly
2223
extra: Extra<'a, 'py>,
2324
}
@@ -27,6 +28,7 @@ impl<'a, 'py> ValidationState<'a, 'py> {
2728
Self {
2829
recursion_guard, // Don't care about exactness unless doing union validation
2930
exactness: None,
31+
fields_set_count: None,
3032
extra,
3133
}
3234
}
@@ -68,6 +70,10 @@ impl<'a, 'py> ValidationState<'a, 'py> {
6870
}
6971
}
7072

73+
pub fn add_fields_set(&mut self, fields_set_count: usize) {
74+
*self.fields_set_count.get_or_insert(0) += fields_set_count;
75+
}
76+
7177
pub fn cache_str(&self) -> StringCacheMode {
7278
self.extra.cache_str
7379
}

0 commit comments

Comments
 (0)