Skip to content

Commit 2f1c4cd

Browse files
authored
Merge branch 'zwei' into serde-compatability
2 parents 5737872 + 95671e0 commit 2f1c4cd

File tree

3 files changed

+39
-6
lines changed

3 files changed

+39
-6
lines changed

doc/v2.rst

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,15 +242,34 @@ Specify Custom Serving Image
242242

243243
The ``image`` parameter has been renamed to ``image_uri`` for specifying a custom Docker image URI to use with inference.
244244

245+
TensorFlow Serving Model
246+
~~~~~~~~~~~~~~~~~~~~~~~~
247+
248+
``sagemaker.tensorflow.serving.Model`` has been renamed to :class:`sagemaker.tensorflow.model.TensorFlowModel`.
249+
(For the previous implementation of that class, see `Deprecate Legacy TensorFlow <#deprecate-legacy-tensorflow>`_).
250+
245251
Predictors
246252
----------
247253

254+
Generic Predictor Class Name
255+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
256+
248257
``sagemaker.predictor.RealTimePredictor`` has been renamed to :class:`sagemaker.predictor.Predictor`.
249258

250-
In addition, for :class:`sagemaker.predictor.Predictor`, :class:`sagemaker.sparkml.model.SparkMLPredictor`,
259+
Endpoint Argument Name
260+
~~~~~~~~~~~~~~~~~~~~~~
261+
262+
For :class:`sagemaker.predictor.Predictor`, :class:`sagemaker.sparkml.model.SparkMLPredictor`,
251263
and predictors for Amazon algorithm (e.g. Factorization Machines, Linear Learner, etc.),
252264
the ``endpoint`` attribute has been renamed to ``endpoint_name``.
253265

266+
TensorFlow Serving Predictor
267+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
268+
269+
``sagemaker.tensorflow.serving.Predictor`` has been renamed to :class:`sagemaker.tensorflow.model.TensorFlowPredictor`.
270+
(For the previous implementation of that class, see `Deprecate Legacy TensorFlow <#deprecate-legacy-tensorflow>`_).
271+
272+
254273
Airflow
255274
-------
256275

src/sagemaker/deserializers.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,13 +163,15 @@ class NumpyDeserializer(BaseDeserializer):
163163

164164
ACCEPT = "application/x-npy"
165165

166-
def __init__(self, dtype=None):
167-
"""Initialize the dtype.
166+
def __init__(self, dtype=None, allow_pickle=True):
167+
"""Initialize the dtype and allow_pickle arguments.
168168
169169
Args:
170-
dtype (str): The dtype of the data.
170+
dtype (str): The dtype of the data (default: None).
171+
allow_pickle (bool): Allow loading pickled object arrays (default: True).
171172
"""
172173
self.dtype = dtype
174+
self.allow_pickle = allow_pickle
173175

174176
def deserialize(self, stream, content_type):
175177
"""Deserialize data from an inference endpoint into a NumPy array.
@@ -189,7 +191,7 @@ def deserialize(self, stream, content_type):
189191
if content_type == "application/json":
190192
return np.array(json.load(codecs.getreader("utf-8")(stream)), dtype=self.dtype)
191193
if content_type == "application/x-npy":
192-
return np.load(io.BytesIO(stream.read()))
194+
return np.load(io.BytesIO(stream.read()), allow_pickle=self.allow_pickle)
193195
finally:
194196
stream.close()
195197

tests/unit/sagemaker/test_deserializers.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def test_numpy_deserializer_from_npy(numpy_deserializer):
141141

142142

143143
def test_numpy_deserializer_from_npy_object_array(numpy_deserializer):
144-
array = np.array(["one", "two"])
144+
array = np.array([{"a": "", "b": ""}, {"c": "", "d": ""}])
145145
stream = io.BytesIO()
146146
np.save(stream, array)
147147
stream.seek(0)
@@ -151,6 +151,18 @@ def test_numpy_deserializer_from_npy_object_array(numpy_deserializer):
151151
assert np.array_equal(array, result)
152152

153153

154+
def test_numpy_deserializer_from_npy_object_array_with_allow_pickle_false():
155+
numpy_deserializer = NumpyDeserializer(allow_pickle=False)
156+
157+
array = np.array([{"a": "", "b": ""}, {"c": "", "d": ""}])
158+
stream = io.BytesIO()
159+
np.save(stream, array)
160+
stream.seek(0)
161+
162+
with pytest.raises(ValueError):
163+
numpy_deserializer.deserialize(stream, "application/x-npy")
164+
165+
154166
@pytest.fixture
155167
def json_deserializer():
156168
return JSONDeserializer()

0 commit comments

Comments
 (0)