Skip to content

Commit 0a0d7ec

Browse files
authored
change: Add allow_pickle parameter to NumpyDeserializer (#1755)
1 parent be1deba commit 0a0d7ec

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

src/sagemaker/deserializers.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,13 +163,15 @@ class NumpyDeserializer(BaseDeserializer):
163163

164164
ACCEPT = "application/x-npy"
165165

166-
def __init__(self, dtype=None):
167-
"""Initialize the dtype.
166+
def __init__(self, dtype=None, allow_pickle=True):
167+
"""Initialize the dtype and allow_pickle arguments.
168168
169169
Args:
170-
dtype (str): The dtype of the data.
170+
dtype (str): The dtype of the data (default: None).
171+
allow_pickle (bool): Allow loading pickled object arrays (default: True).
171172
"""
172173
self.dtype = dtype
174+
self.allow_pickle = allow_pickle
173175

174176
def deserialize(self, stream, content_type):
175177
"""Deserialize data from an inference endpoint into a NumPy array.
@@ -189,7 +191,7 @@ def deserialize(self, stream, content_type):
189191
if content_type == "application/json":
190192
return np.array(json.load(codecs.getreader("utf-8")(stream)), dtype=self.dtype)
191193
if content_type == "application/x-npy":
192-
return np.load(io.BytesIO(stream.read()))
194+
return np.load(io.BytesIO(stream.read()), allow_pickle=self.allow_pickle)
193195
finally:
194196
stream.close()
195197

tests/unit/sagemaker/test_deserializers.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def test_numpy_deserializer_from_npy(numpy_deserializer):
141141

142142

143143
def test_numpy_deserializer_from_npy_object_array(numpy_deserializer):
144-
array = np.array(["one", "two"])
144+
array = np.array([{"a": "", "b": ""}, {"c": "", "d": ""}])
145145
stream = io.BytesIO()
146146
np.save(stream, array)
147147
stream.seek(0)
@@ -151,6 +151,18 @@ def test_numpy_deserializer_from_npy_object_array(numpy_deserializer):
151151
assert np.array_equal(array, result)
152152

153153

154+
def test_numpy_deserializer_from_npy_object_array_with_allow_pickle_false():
155+
numpy_deserializer = NumpyDeserializer(allow_pickle=False)
156+
157+
array = np.array([{"a": "", "b": ""}, {"c": "", "d": ""}])
158+
stream = io.BytesIO()
159+
np.save(stream, array)
160+
stream.seek(0)
161+
162+
with pytest.raises(ValueError):
163+
numpy_deserializer.deserialize(stream, "application/x-npy")
164+
165+
154166
@pytest.fixture
155167
def json_deserializer():
156168
return JSONDeserializer()

0 commit comments

Comments
 (0)