Skip to content

Commit b764c16

Browse files
Support negative indices in serializer include/exclude (#627)
Co-authored-by: Samuel Colvin <[email protected]>
1 parent c9a83c8 commit b764c16

File tree

10 files changed

+116
-16
lines changed

10 files changed

+116
-16
lines changed

src/serializers/filter.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use ahash::AHashSet;
2+
use pyo3::exceptions::PyValueError;
23
use std::hash::Hash;
34

45
use pyo3::exceptions::PyTypeError;
@@ -14,6 +15,46 @@ pub(crate) struct SchemaFilter<T> {
1415
exclude: Option<AHashSet<T>>,
1516
}
1617

18+
fn map_negative_index(value: &PyAny, len: Option<usize>) -> PyResult<&PyAny> {
19+
let py = value.py();
20+
match len {
21+
Some(len) => Ok(value
22+
.call_method1(intern!(py, "__mod__"), (len,))
23+
.map_or_else(|_| value, |v| v)),
24+
None => {
25+
// check that it's not negative
26+
let negative = value.call_method1(intern!(py, "__lt__"), (0,))?.is_true()?;
27+
if negative {
28+
Err(PyValueError::new_err(
29+
"Negative indices cannot be used to exclude items on unsized iterables",
30+
))
31+
} else {
32+
Ok(value)
33+
}
34+
}
35+
}
36+
}
37+
38+
fn map_negative_indices(include_or_exclude: &PyAny, len: Option<usize>) -> PyResult<&PyAny> {
39+
let py = include_or_exclude.py();
40+
if let Ok(exclude_dict) = include_or_exclude.downcast::<PyDict>() {
41+
let out = PyDict::new(py);
42+
for (k, v) in exclude_dict.iter() {
43+
out.set_item(map_negative_index(k, len)?, v)?;
44+
}
45+
Ok(out)
46+
} else if let Ok(exclude_set) = include_or_exclude.downcast::<PySet>() {
47+
let mut values = Vec::with_capacity(exclude_set.len());
48+
for v in exclude_set {
49+
values.push(map_negative_index(v, len)?)
50+
}
51+
Ok(PySet::new(py, values)?)
52+
} else {
53+
// return as is and deal with the error later
54+
Ok(include_or_exclude)
55+
}
56+
}
57+
1758
impl SchemaFilter<usize> {
1859
pub fn from_schema(schema: &PyDict) -> PyResult<Self> {
1960
let py = schema.py();
@@ -51,7 +92,10 @@ impl SchemaFilter<usize> {
5192
index: usize,
5293
include: Option<&'py PyAny>,
5394
exclude: Option<&'py PyAny>,
95+
len: Option<usize>,
5496
) -> PyResult<Option<(Option<&'py PyAny>, Option<&'py PyAny>)>> {
97+
let include = include.map(|v| map_negative_indices(v, len)).transpose()?;
98+
let exclude = exclude.map(|v| map_negative_indices(v, len)).transpose()?;
5599
self.filter(index, index, include, exclude)
56100
}
57101
}
@@ -232,7 +276,10 @@ impl AnyFilter {
232276
index: usize,
233277
include: Option<&'py PyAny>,
234278
exclude: Option<&'py PyAny>,
279+
len: Option<usize>,
235280
) -> PyResult<Option<(Option<&'py PyAny>, Option<&'py PyAny>)>> {
281+
let include = include.map(|v| map_negative_indices(v, len)).transpose()?;
282+
let exclude = exclude.map(|v| map_negative_indices(v, len)).transpose()?;
236283
self.filter(index, index, include, exclude)
237284
}
238285
}

src/serializers/infer.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,10 @@ pub(crate) fn infer_to_python_known(
7070
let py_seq: &$t = value.downcast()?;
7171
let mut items = Vec::with_capacity(py_seq.len());
7272
let filter = AnyFilter::new();
73+
let len = value.len().ok();
7374

7475
for (index, element) in py_seq.iter().enumerate() {
75-
let op_next = filter.index_filter(index, include, exclude)?;
76+
let op_next = filter.index_filter(index, include, exclude, len)?;
7677
if let Some((next_include, next_exclude)) = op_next {
7778
items.push(infer_to_python(element, next_include, next_exclude, extra)?);
7879
}
@@ -204,7 +205,7 @@ pub(crate) fn infer_to_python_known(
204205

205206
for (index, r) in py_seq.iter()?.enumerate() {
206207
let element = r?;
207-
let op_next = filter.index_filter(index, include, exclude)?;
208+
let op_next = filter.index_filter(index, include, exclude, None)?;
208209
if let Some((next_include, next_exclude)) = op_next {
209210
items.push(infer_to_python(element, next_include, next_exclude, extra)?);
210211
}
@@ -378,8 +379,12 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
378379
let py_seq: &$t = value.downcast().map_err(py_err_se_err)?;
379380
let mut seq = serializer.serialize_seq(Some(py_seq.len()))?;
380381
let filter = AnyFilter::new();
382+
let len = value.len().ok();
383+
381384
for (index, element) in py_seq.iter().enumerate() {
382-
let op_next = filter.index_filter(index, include, exclude).map_err(py_err_se_err)?;
385+
let op_next = filter
386+
.index_filter(index, include, exclude, len)
387+
.map_err(py_err_se_err)?;
383388
if let Some((next_include, next_exclude)) = op_next {
384389
let item_serializer = SerializeInfer::new(element, next_include, next_exclude, extra);
385390
seq.serialize_element(&item_serializer)?
@@ -502,7 +507,9 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
502507
let filter = AnyFilter::new();
503508
for (index, r) in py_seq.iter().map_err(py_err_se_err)?.enumerate() {
504509
let element = r.map_err(py_err_se_err)?;
505-
let op_next = filter.index_filter(index, include, exclude).map_err(py_err_se_err)?;
510+
let op_next = filter
511+
.index_filter(index, include, exclude, None)
512+
.map_err(py_err_se_err)?;
506513
if let Some((next_include, next_exclude)) = op_next {
507514
let item_serializer = SerializeInfer::new(element, next_include, next_exclude, extra);
508515
seq.serialize_element(&item_serializer)?

src/serializers/type_serializers/function.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,7 @@ impl SerializationCallable {
422422

423423
if let Some(index_key) = index_key {
424424
let filter = if let Ok(index) = index_key.extract::<usize>() {
425-
self.filter.index_filter(index, include, exclude)?
425+
self.filter.index_filter(index, include, exclude, None)?
426426
} else {
427427
self.filter.key_filter(index_key, include, exclude)?
428428
};

src/serializers/type_serializers/generator.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ impl TypeSerializer for GeneratorSerializer {
6363
};
6464
for (index, iter_result) in py_iter.enumerate() {
6565
let element = iter_result?;
66-
let op_next = self.filter.index_filter(index, include, exclude)?;
66+
let op_next = self.filter.index_filter(index, include, exclude, None)?;
6767
if let Some((next_include, next_exclude)) = op_next {
6868
items.push(item_serializer.to_python(element, next_include, next_exclude, extra)?);
6969
}
@@ -115,7 +115,7 @@ impl TypeSerializer for GeneratorSerializer {
115115
let element = iter_result.map_err(py_err_se_err)?;
116116
let op_next = self
117117
.filter
118-
.index_filter(index, include, exclude)
118+
.index_filter(index, include, exclude, None)
119119
.map_err(py_err_se_err)?;
120120
if let Some((next_include, next_exclude)) = op_next {
121121
let item_serialize =
@@ -187,7 +187,7 @@ impl SerializationIterator {
187187

188188
for iter_result in iterator {
189189
let element = iter_result?;
190-
let filter = self.filter.index_filter(self.index, include, exclude)?;
190+
let filter = self.filter.index_filter(self.index, include, exclude, None)?;
191191
self.index += 1;
192192
if let Some((next_include, next_exclude)) = filter {
193193
let v = self

src/serializers/type_serializers/list.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ impl TypeSerializer for ListSerializer {
6060

6161
let mut items = Vec::with_capacity(py_list.len());
6262
for (index, element) in py_list.iter().enumerate() {
63-
let op_next = self.filter.index_filter(index, include, exclude)?;
63+
let op_next = self.filter.index_filter(index, include, exclude, value.len().ok())?;
6464
if let Some((next_include, next_exclude)) = op_next {
6565
items.push(item_serializer.to_python(element, next_include, next_exclude, extra)?);
6666
}
@@ -94,7 +94,7 @@ impl TypeSerializer for ListSerializer {
9494
for (index, element) in py_list.iter().enumerate() {
9595
let op_next = self
9696
.filter
97-
.index_filter(index, include, exclude)
97+
.index_filter(index, include, exclude, Some(py_list.len()))
9898
.map_err(py_err_se_err)?;
9999
if let Some((next_include, next_exclude)) = op_next {
100100
let item_serialize =

src/serializers/type_serializers/tuple.rs

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ impl TypeSerializer for TupleVariableSerializer {
6262

6363
let mut items = Vec::with_capacity(py_tuple.len());
6464
for (index, element) in py_tuple.iter().enumerate() {
65-
let op_next = self.filter.index_filter(index, include, exclude)?;
65+
let op_next = self
66+
.filter
67+
.index_filter(index, include, exclude, Some(py_tuple.len()))?;
6668
if let Some((next_include, next_exclude)) = op_next {
6769
items.push(item_serializer.to_python(element, next_include, next_exclude, extra)?);
6870
}
@@ -114,7 +116,7 @@ impl TypeSerializer for TupleVariableSerializer {
114116
for (index, element) in py_tuple.iter().enumerate() {
115117
let op_next = self
116118
.filter
117-
.index_filter(index, include, exclude)
119+
.index_filter(index, include, exclude, Some(py_tuple.len()))
118120
.map_err(py_err_se_err)?;
119121
if let Some((next_include, next_exclude)) = op_next {
120122
let item_serialize =
@@ -198,7 +200,9 @@ impl TypeSerializer for TuplePositionalSerializer {
198200
Some(value) => value,
199201
None => break,
200202
};
201-
let op_next = self.filter.index_filter(index, include, exclude)?;
203+
let op_next = self
204+
.filter
205+
.index_filter(index, include, exclude, Some(py_tuple.len()))?;
202206
if let Some((next_include, next_exclude)) = op_next {
203207
items.push(serializer.to_python(element, next_include, next_exclude, extra)?);
204208
}
@@ -207,7 +211,9 @@ impl TypeSerializer for TuplePositionalSerializer {
207211
let extra_serializer = self.extra_serializer.as_ref();
208212
for (index2, element) in py_tuple_iter.enumerate() {
209213
let index = index2 + expected_length;
210-
let op_next = self.filter.index_filter(index, include, exclude)?;
214+
let op_next = self
215+
.filter
216+
.index_filter(index, include, exclude, Some(py_tuple.len()))?;
211217
if let Some((next_include, next_exclude)) = op_next {
212218
items.push(extra_serializer.to_python(element, next_include, next_exclude, extra)?);
213219
}
@@ -272,7 +278,7 @@ impl TypeSerializer for TuplePositionalSerializer {
272278
};
273279
let op_next = self
274280
.filter
275-
.index_filter(index, include, exclude)
281+
.index_filter(index, include, exclude, Some(py_tuple.len()))
276282
.map_err(py_err_se_err)?;
277283
if let Some((next_include, next_exclude)) = op_next {
278284
let item_serialize =
@@ -287,7 +293,7 @@ impl TypeSerializer for TuplePositionalSerializer {
287293
let index = index2 + expected_length;
288294
let op_next = self
289295
.filter
290-
.index_filter(index, include, exclude)
296+
.index_filter(index, include, exclude, Some(py_tuple.len()))
291297
.map_err(py_err_se_err)?;
292298
if let Some((next_include, next_exclude)) = op_next {
293299
let item_serialize =

tests/benchmarks/test_serialization_micro.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,3 +459,12 @@ def test_to_string_direct(benchmark):
459459
assert s.to_json(123) == b'"123"'
460460

461461
benchmark(s.to_json, 123)
462+
463+
464+
def test_filter(benchmark):
465+
v = SchemaSerializer(core_schema.list_schema(core_schema.any_schema()))
466+
assert v.to_python(['a', 'b', 'c', 'd', 'e'], include={-1, -2}) == ['d', 'e']
467+
468+
@benchmark
469+
def t():
470+
v.to_python(['a', 'b', 'c', 'd', 'e'], include={-1, -2})

tests/serializers/test_any.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ def test_include_list_tuple(any_serializer, seq_f):
203203
assert any_serializer.to_json(seq_f('a', 'b', 'c')) == b'["a","b","c"]'
204204

205205
assert any_serializer.to_python(seq_f(0, 1, 2, 3), include={1, 2}) == seq_f(1, 2)
206+
assert any_serializer.to_python(seq_f(0, 1, 2, 3), include={-1, -2}) == seq_f(2, 3)
206207
assert any_serializer.to_python(seq_f(0, 1, 2, 3), include={1, 2}, mode='json') == [1, 2]
207208
assert any_serializer.to_python(seq_f('a', 'b', 'c', 'd'), include={1, 2}) == seq_f('b', 'c')
208209
assert any_serializer.to_python(seq_f('a', 'b', 'c', 'd'), include={1, 2}, mode='json') == ['b', 'c']
@@ -242,6 +243,12 @@ def test_include_dict(any_serializer):
242243
assert any_serializer.to_json(MyDataclass(a=1, b='foo', frog=2), include={'a'}) == b'{"a":1}'
243244

244245

246+
def test_exclude_dict(any_serializer):
247+
assert any_serializer.to_python({1: 2, '3': 4}) == {1: 2, '3': 4}
248+
assert any_serializer.to_python({1: 2, 3: 4}, exclude={1}) == {3: 4}
249+
assert any_serializer.to_python({1: 2, 3: 4}, exclude={-1}) == {1: 2, 3: 4}
250+
251+
245252
class FieldsSetModel:
246253
__pydantic_serializer__ = 42
247254
__slots__ = '__dict__', '__pydantic_extra__', '__pydantic_fields_set__'

tests/serializers/test_generator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ def test_include():
122122
'f',
123123
'g',
124124
]
125+
with pytest.raises(ValueError, match='Negative indices cannot be used to exclude items on unsized iterables'):
126+
v.to_python(gen_ok('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'), include={-1: None, -2: None}, mode='json')
125127

126128

127129
def test_custom_serializer():

tests/serializers/test_list_tuple.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,26 @@ def test_include(schema_func, seq_f):
8787
assert v.to_python(seq_f('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'), include=[6]) == seq_f('b', 'd', 'f', 'g')
8888
assert v.to_json(seq_f('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'), include={6}) == b'["b","d","f","g"]'
8989
assert v.to_python(seq_f('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'), include={6: None}) == seq_f('b', 'd', 'f', 'g')
90+
assert v.to_python(seq_f('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'), include={-1: None, -2: None}, mode='json') == [
91+
'b',
92+
'd',
93+
'f',
94+
'g',
95+
'h',
96+
]
97+
98+
99+
@pytest.mark.parametrize(
100+
'schema_func,seq_f', [(core_schema.list_schema, as_list), (core_schema.tuple_variable_schema, as_tuple)]
101+
)
102+
def test_negative(schema_func, seq_f):
103+
v = SchemaSerializer(schema_func(core_schema.any_schema()))
104+
assert v.to_python(seq_f('a', 'b', 'c', 'd', 'e')) == seq_f('a', 'b', 'c', 'd', 'e')
105+
assert v.to_python(seq_f('a', 'b', 'c', 'd', 'e'), include={-1, -2}) == seq_f('d', 'e')
106+
assert v.to_python(seq_f('a', 'b', 'c', 'd', 'e'), include={-1: None, -2: None}) == seq_f('d', 'e')
107+
assert v.to_python(seq_f('a', 'b', 'c', 'd', 'e'), include={-1, -2}, mode='json') == ['d', 'e']
108+
assert v.to_python(seq_f('a', 'b', 'c', 'd', 'e'), include={-1: None, -2: None}, mode='json') == ['d', 'e']
109+
assert v.to_json(seq_f('a', 'b', 'c', 'd', 'e'), include={-1, -2}) == b'["d","e"]'
90110

91111

92112
@pytest.mark.parametrize(
@@ -116,7 +136,9 @@ def test_exclude(schema_func, seq_f):
116136
assert v.to_json(seq_f('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h')) == b'["a","c","e","g","h"]'
117137
# the two exclude lists are combined via union as they used to be
118138
assert v.to_python(seq_f('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'), exclude={6}) == seq_f('a', 'c', 'e', 'h')
139+
assert v.to_python(seq_f('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'), exclude={-1, -2}) == seq_f('a', 'c', 'e')
119140
assert v.to_json(seq_f('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'), exclude={6}) == b'["a","c","e","h"]'
141+
assert v.to_json(seq_f('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'), exclude={-1, -2}) == b'["a","c","e"]'
120142

121143

122144
@pytest.mark.parametrize('include,exclude', [({1, 3, 5}, {5, 6}), ([1, 3, 5], [5, 6])])

0 commit comments

Comments
 (0)