Skip to content

Commit 1da20ad

Browse files
author
Balaji Veeramani
committed
Update deserializers.py
1 parent be1deba commit 1da20ad

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

src/sagemaker/deserializers.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,13 +163,15 @@ class NumpyDeserializer(BaseDeserializer):
163163

164164
ACCEPT = "application/x-npy"
165165

166-
def __init__(self, dtype=None):
167-
"""Initialize the dtype.
166+
def __init__(self, dtype=None, allow_pickle=True):
167+
"""Initialize the dtype and allow_pickle arguments.
168168
169169
Args:
170-
dtype (str): The dtype of the data.
170+
dtype (str): The dtype of the data (default: None).
171+
allow_pickle (bool): Allow loading pickled object arrays (default: True).
171172
"""
172173
self.dtype = dtype
174+
self.allow_pickle = allow_pickle
173175

174176
def deserialize(self, stream, content_type):
175177
"""Deserialize data from an inference endpoint into a NumPy array.
@@ -189,7 +191,7 @@ def deserialize(self, stream, content_type):
189191
if content_type == "application/json":
190192
return np.array(json.load(codecs.getreader("utf-8")(stream)), dtype=self.dtype)
191193
if content_type == "application/x-npy":
192-
return np.load(io.BytesIO(stream.read()))
194+
return np.load(io.BytesIO(stream.read()), allow_pickle=self.allow_pickle)
193195
finally:
194196
stream.close()
195197

0 commit comments

Comments
 (0)