Skip to content

Commit 8be45e6

Browse files
authored
fix: 8405 pattern serialization (#1168)
1 parent c670b2b commit 8be45e6

File tree

3 files changed

+24
-1
lines changed

3 files changed

+24
-1
lines changed

src/serializers/infer.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ pub(crate) fn infer_to_python_known(
219219
PyList::new(py, items).into_py(py)
220220
}
221221
ObType::Path => value.str()?.into_py(py),
222+
ObType::Pattern => value.getattr(intern!(py, "pattern"))?.into_py(py),
222223
ObType::Unknown => {
223224
if let Some(fallback) = extra.fallback {
224225
let next_value = fallback.call1((value,))?;
@@ -505,6 +506,16 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
505506
let s = value.str().map_err(py_err_se_err)?.to_str().map_err(py_err_se_err)?;
506507
serializer.serialize_str(s)
507508
}
509+
ObType::Pattern => {
510+
let s = value
511+
.getattr(intern!(value.py(), "pattern"))
512+
.map_err(py_err_se_err)?
513+
.str()
514+
.map_err(py_err_se_err)?
515+
.to_str()
516+
.map_err(py_err_se_err)?;
517+
serializer.serialize_str(s)
518+
}
508519
ObType::Unknown => {
509520
if let Some(fallback) = extra.fallback {
510521
let next_value = fallback.call1((value,)).map_err(py_err_se_err)?;
@@ -628,6 +639,7 @@ pub(crate) fn infer_json_key_known<'py>(ob_type: ObType, key: &'py PyAny, extra:
628639
infer_json_key(k, extra)
629640
}
630641
ObType::Path => Ok(key.str()?.to_string_lossy()),
642+
ObType::Pattern => Ok(key.getattr(intern!(key.py(), "pattern"))?.str()?.to_string_lossy()),
631643
ObType::Unknown => {
632644
if let Some(fallback) = extra.fallback {
633645
let next_key = fallback.call1((key,))?;

src/serializers/ob_type.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ pub struct ObTypeLookup {
4444
generator_object: PyObject,
4545
// path
4646
path_object: PyObject,
47+
// pattern
48+
pattern_object: PyObject,
4749
// uuid type
4850
uuid_object: PyObject,
4951
}
@@ -87,6 +89,7 @@ impl ObTypeLookup {
8789
.unwrap()
8890
.to_object(py),
8991
path_object: py.import("pathlib").unwrap().getattr("Path").unwrap().to_object(py),
92+
pattern_object: py.import("re").unwrap().getattr("Pattern").unwrap().to_object(py),
9093
uuid_object: py.import("uuid").unwrap().getattr("UUID").unwrap().to_object(py),
9194
}
9295
}
@@ -150,6 +153,7 @@ impl ObTypeLookup {
150153
ObType::Enum => self.enum_object.as_ptr() as usize == ob_type,
151154
ObType::Generator => self.generator_object.as_ptr() as usize == ob_type,
152155
ObType::Path => self.path_object.as_ptr() as usize == ob_type,
156+
ObType::Pattern => self.path_object.as_ptr() as usize == ob_type,
153157
ObType::Uuid => self.uuid_object.as_ptr() as usize == ob_type,
154158
ObType::Unknown => false,
155159
};
@@ -242,6 +246,8 @@ impl ObTypeLookup {
242246
ObType::Generator
243247
} else if ob_type == self.path_object.as_ptr() as usize {
244248
ObType::Path
249+
} else if ob_type == self.pattern_object.as_ptr() as usize {
250+
ObType::Pattern
245251
} else {
246252
// this allows for subtypes of the supported class types,
247253
// if `ob_type` didn't match any member of self, we try again with the next base type pointer
@@ -319,6 +325,8 @@ impl ObTypeLookup {
319325
ObType::Generator
320326
} else if value.is_instance(self.path_object.as_ref(py)).unwrap_or(false) {
321327
ObType::Path
328+
} else if value.is_instance(self.pattern_object.as_ref(py)).unwrap_or(false) {
329+
ObType::Pattern
322330
} else {
323331
ObType::Unknown
324332
}
@@ -396,6 +404,8 @@ pub enum ObType {
396404
Generator,
397405
// Path
398406
Path,
407+
//Pattern,
408+
Pattern,
399409
// Uuid
400410
Uuid,
401411
// unknown type

tests/serializers/test_any.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import dataclasses
22
import json
33
import platform
4+
import re
45
import sys
56
from collections import namedtuple
67
from datetime import date, datetime, time, timedelta, timezone
@@ -437,7 +438,7 @@ def test_base64():
437438
(lambda: MyEnum.a, {}, b'1'),
438439
(lambda: MyEnum.b, {}, b'"b"'),
439440
(lambda: [MyDataclass(1, 'a', 2), MyModel(a=2, b='b')], {}, b'[{"a":1,"b":"a"},{"a":2,"b":"b"}]'),
440-
# # (lambda: re.compile('^regex$'), b'"^regex$"'),
441+
(lambda: re.compile('^regex$'), {}, b'"^regex$"'),
441442
],
442443
)
443444
def test_encoding(any_serializer, gen_input, kwargs, expected_json):

0 commit comments

Comments
 (0)