Skip to content

gh-123431: Harmonize extension code checks in pickle #123434

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions Lib/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -1086,11 +1086,16 @@ def save_global(self, obj, name=None):

module_name = whichmodule(obj, name)
if self.proto >= 2:
code = _extension_registry.get((module_name, name))
if code:
assert code > 0
code = _extension_registry.get((module_name, name), _NoValue)
if code is not _NoValue:
if code <= 0xff:
write(EXT1 + pack("<B", code))
data = pack("<B", code)
if data == b'\0':
# Should never happen in normal circumstances,
# since the type and the value of the code are
# checked in copyreg.add_extension().
raise RuntimeError("extension code 0 is out of range")
write(EXT1 + data)
elif code <= 0xffff:
write(EXT2 + pack("<H", code))
else:
Expand Down Expand Up @@ -1581,9 +1586,8 @@ def load_ext4(self):
dispatch[EXT4[0]] = load_ext4

def get_extension(self, code):
nil = []
obj = _extension_cache.get(code, nil)
if obj is not nil:
obj = _extension_cache.get(code, _NoValue)
if obj is not _NoValue:
self.append(obj)
return
key = _inverted_registry.get(code)
Expand Down
51 changes: 51 additions & 0 deletions Lib/test/pickletester.py
Original file line number Diff line number Diff line change
Expand Up @@ -1311,6 +1311,35 @@ def find_class(module_name, global_name):
self.assertEqual(loads(b'cmath\nlog\n.'), ('math', 'log'))
self.assertEqual(loads(b'\x8c\x04math\x8c\x03log\x93.'), ('math', 'log'))

def test_bad_ext_code(self):
# unregistered extension code
self.check_unpickling_error(ValueError, b'\x82\x01.')
self.check_unpickling_error(ValueError, b'\x82\xff.')
self.check_unpickling_error(ValueError, b'\x83\x01\x00.')
self.check_unpickling_error(ValueError, b'\x83\xff\xff.')
self.check_unpickling_error(ValueError, b'\x84\x01\x00\x00\x00.')
self.check_unpickling_error(ValueError, b'\x84\xff\xff\xff\x7f.')
# EXT specifies code <= 0
self.check_unpickling_error(pickle.UnpicklingError, b'\x82\x00.')
self.check_unpickling_error(pickle.UnpicklingError, b'\x83\x00\x00.')
self.check_unpickling_error(pickle.UnpicklingError, b'\x84\x00\x00\x00\x00.')
self.check_unpickling_error(pickle.UnpicklingError, b'\x84\x00\x00\x00\x80.')
self.check_unpickling_error(pickle.UnpicklingError, b'\x84\xff\xff\xff\xff.')

@support.cpython_only
def test_bad_ext_inverted_registry(self):
code = 1
def check(key, exc):
with support.swap_item(copyreg._inverted_registry, code, key):
with self.assertRaises(exc):
self.loads(b'\x82\x01.')
check(None, ValueError)
check((), ValueError)
check((__name__,), (TypeError, ValueError))
check((__name__, "MyList", "x"), (TypeError, ValueError))
check((__name__, None), (TypeError, ValueError))
check((None, "MyList"), (TypeError, ValueError))

def test_bad_reduce(self):
self.assertEqual(self.loads(b'cbuiltins\nint\n)R.'), 0)
self.check_unpickling_error(TypeError, b'N)R.')
Expand Down Expand Up @@ -2163,6 +2192,28 @@ def persistent_id(self, obj):
check({Clearer(): 1, Clearer(): 2})
check({1: Clearer(), 2: Clearer()})

@support.cpython_only
def test_bad_ext_code(self):
# This should never happen in normal circumstances, because the type
# and the value of the extension code are checked in copyreg.add_extension().
key = (__name__, 'MyList')
def check(code, exc):
assert key not in copyreg._extension_registry
assert code not in copyreg._inverted_registry
with (support.swap_item(copyreg._extension_registry, key, code),
support.swap_item(copyreg._inverted_registry, code, key)):
for proto in protocols[2:]:
with self.assertRaises(exc):
self.dumps(MyList, proto)
Comment on lines +2206 to +2207
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
with self.assertRaises(exc):
self.dumps(MyList, proto)
with self.subTest(proto=proto):
with self.assertRaises(exc):
self.dumps(MyList, proto)

Just to know which protocol failed (I see tabs and spaces on GH so I hope that if you commit my suggestions, it will be normalized...)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not important, because the code is the same for all protocols >= 2. Original tests for this feature only use protocol 2.

On the other hand, using subTest() will produce worse tracebacks if some test fails. Currently the traceback contains the line with the check() call. subTest() truncates the traceback.


check(object(), TypeError)
check(None, TypeError)
check(-1, (RuntimeError, struct.error))
check(0, RuntimeError)
check(2**31, (RuntimeError, OverflowError, struct.error))
check(2**1000, (OverflowError, struct.error))
check(-2**1000, (OverflowError, struct.error))


class AbstractPickleTests:
# Subclass must define self.dumps, self.loads.
Expand Down
32 changes: 11 additions & 21 deletions Modules/_pickle.c
Original file line number Diff line number Diff line change
Expand Up @@ -3650,34 +3650,24 @@ save_global(PickleState *st, PicklerObject *self, PyObject *obj,
if (extension_key == NULL) {
goto error;
}
code_obj = PyDict_GetItemWithError(st->extension_registry,
extension_key);
if (PyDict_GetItemRef(st->extension_registry, extension_key, &code_obj) < 0) {
Py_DECREF(extension_key);
goto error;
}
Py_DECREF(extension_key);
/* The object is not registered in the extension registry.
This is the most likely code path. */
if (code_obj == NULL) {
if (PyErr_Occurred()) {
goto error;
}
/* The object is not registered in the extension registry.
This is the most likely code path. */
goto gen_global;
}

/* XXX: pickle.py doesn't check neither the type, nor the range
of the value returned by the extension_registry. It should for
consistency. */

/* Verify code_obj has the right type and value. */
if (!PyLong_Check(code_obj)) {
PyErr_Format(st->PicklingError,
"Can't pickle %R: extension code %R isn't an integer",
obj, code_obj);
goto error;
}
code = PyLong_AS_LONG(code_obj);
code = PyLong_AsLong(code_obj);
Py_DECREF(code_obj);
if (code <= 0 || code > 0x7fffffffL) {
/* Should never happen in normal circumstances, since the type and
the value of the code are checked in copyreg.add_extension(). */
if (!PyErr_Occurred())
PyErr_Format(st->PicklingError, "Can't pickle %R: extension "
"code %ld is out of range", obj, code);
PyErr_Format(PyExc_RuntimeError, "extension code %ld is out of range", code);
goto error;
}

Expand Down
Loading