Skip to content

[3.12] gh-123431: Harmonize extension code checks in pickle (GH-123434) #123460

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions Lib/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,8 @@ def decode_long(data):
return int.from_bytes(data, byteorder='little', signed=True)


_NoValue = object()

# Pickling machinery

class _Pickler:
Expand Down Expand Up @@ -1091,11 +1093,16 @@ def save_global(self, obj, name=None):
(obj, module_name, name))

if self.proto >= 2:
code = _extension_registry.get((module_name, name))
if code:
assert code > 0
code = _extension_registry.get((module_name, name), _NoValue)
if code is not _NoValue:
if code <= 0xff:
write(EXT1 + pack("<B", code))
data = pack("<B", code)
if data == b'\0':
# Should never happen in normal circumstances,
# since the type and the value of the code are
# checked in copyreg.add_extension().
raise RuntimeError("extension code 0 is out of range")
write(EXT1 + data)
elif code <= 0xffff:
write(EXT2 + pack("<H", code))
else:
Expand Down Expand Up @@ -1589,9 +1596,8 @@ def load_ext4(self):
dispatch[EXT4[0]] = load_ext4

def get_extension(self, code):
nil = []
obj = _extension_cache.get(code, nil)
if obj is not nil:
obj = _extension_cache.get(code, _NoValue)
if obj is not _NoValue:
self.append(obj)
return
key = _inverted_registry.get(code)
Expand Down
51 changes: 51 additions & 0 deletions Lib/test/pickletester.py
Original file line number Diff line number Diff line change
Expand Up @@ -1297,6 +1297,35 @@ def find_class(module_name, global_name):
self.assertEqual(loads(b'cmath\nlog\n.'), ('math', 'log'))
self.assertEqual(loads(b'\x8c\x04math\x8c\x03log\x93.'), ('math', 'log'))

def test_bad_ext_code(self):
# unregistered extension code
self.check_unpickling_error(ValueError, b'\x82\x01.')
self.check_unpickling_error(ValueError, b'\x82\xff.')
self.check_unpickling_error(ValueError, b'\x83\x01\x00.')
self.check_unpickling_error(ValueError, b'\x83\xff\xff.')
self.check_unpickling_error(ValueError, b'\x84\x01\x00\x00\x00.')
self.check_unpickling_error(ValueError, b'\x84\xff\xff\xff\x7f.')
# EXT specifies code <= 0
self.check_unpickling_error(pickle.UnpicklingError, b'\x82\x00.')
self.check_unpickling_error(pickle.UnpicklingError, b'\x83\x00\x00.')
self.check_unpickling_error(pickle.UnpicklingError, b'\x84\x00\x00\x00\x00.')
self.check_unpickling_error(pickle.UnpicklingError, b'\x84\x00\x00\x00\x80.')
self.check_unpickling_error(pickle.UnpicklingError, b'\x84\xff\xff\xff\xff.')

@support.cpython_only
def test_bad_ext_inverted_registry(self):
code = 1
def check(key, exc):
with support.swap_item(copyreg._inverted_registry, code, key):
with self.assertRaises(exc):
self.loads(b'\x82\x01.')
check(None, ValueError)
check((), ValueError)
check((__name__,), (TypeError, ValueError))
check((__name__, "MyList", "x"), (TypeError, ValueError))
check((__name__, None), (TypeError, ValueError))
check((None, "MyList"), (TypeError, ValueError))

def test_bad_reduce(self):
self.assertEqual(self.loads(b'cbuiltins\nint\n)R.'), 0)
self.check_unpickling_error(TypeError, b'N)R.')
Expand Down Expand Up @@ -2033,6 +2062,28 @@ def persistent_id(self, obj):
check({Clearer(): 1, Clearer(): 2})
check({1: Clearer(), 2: Clearer()})

@support.cpython_only
def test_bad_ext_code(self):
# This should never happen in normal circumstances, because the type
# and the value of the extesion code is checked in copyreg.add_extension().
key = (__name__, 'MyList')
def check(code, exc):
assert key not in copyreg._extension_registry
assert code not in copyreg._inverted_registry
with (support.swap_item(copyreg._extension_registry, key, code),
support.swap_item(copyreg._inverted_registry, code, key)):
for proto in protocols[2:]:
with self.assertRaises(exc):
self.dumps(MyList, proto)

check(object(), TypeError)
check(None, TypeError)
check(-1, (RuntimeError, struct.error))
check(0, RuntimeError)
check(2**31, (RuntimeError, OverflowError, struct.error))
check(2**1000, (OverflowError, struct.error))
check(-2**1000, (OverflowError, struct.error))


class AbstractPickleTests:
# Subclass must define self.dumps, self.loads.
Expand Down
24 changes: 8 additions & 16 deletions Modules/_pickle.c
Original file line number Diff line number Diff line change
Expand Up @@ -3725,31 +3725,23 @@ save_global(PickleState *st, PicklerObject *self, PyObject *obj,
code_obj = PyDict_GetItemWithError(st->extension_registry,
extension_key);
Py_DECREF(extension_key);
/* The object is not registered in the extension registry.
This is the most likely code path. */
if (code_obj == NULL) {
if (PyErr_Occurred()) {
goto error;
}
/* The object is not registered in the extension registry.
This is the most likely code path. */
goto gen_global;
}

/* XXX: pickle.py doesn't check neither the type, nor the range
of the value returned by the extension_registry. It should for
consistency. */

/* Verify code_obj has the right type and value. */
if (!PyLong_Check(code_obj)) {
PyErr_Format(st->PicklingError,
"Can't pickle %R: extension code %R isn't an integer",
obj, code_obj);
goto error;
}
code = PyLong_AS_LONG(code_obj);
Py_INCREF(code_obj);
code = PyLong_AsLong(code_obj);
Py_DECREF(code_obj);
if (code <= 0 || code > 0x7fffffffL) {
/* Should never happen in normal circumstances, since the type and
the value of the code are checked in copyreg.add_extension(). */
if (!PyErr_Occurred())
PyErr_Format(st->PicklingError, "Can't pickle %R: extension "
"code %ld is out of range", obj, code);
PyErr_Format(PyExc_RuntimeError, "extension code %ld is out of range", code);
goto error;
}

Expand Down
Loading