Skip to content

Commit 543cf61

Browse files
author
Balaji Veeramani
committed
Fix integration tests
1 parent 5bb6dcc commit 543cf61

File tree

3 files changed

+34
-24
lines changed

3 files changed

+34
-24
lines changed

tests/integ/test_byo_estimator.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import sagemaker
2121
from sagemaker import image_uris
2222
from sagemaker.estimator import Estimator
23+
from sagemaker.serializers import BaseSerializer
2324
from sagemaker.utils import unique_name_from_base
2425
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES, datasets
2526
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
@@ -35,11 +36,15 @@ def training_set():
3536
return datasets.one_p_mnist()
3637

3738

38-
def fm_serializer(data):
39-
js = {"instances": []}
40-
for row in data:
41-
js["instances"].append({"features": row.tolist()})
42-
return json.dumps(js)
39+
class _FactorizationMachineSerializer(BaseSerializer):
40+
41+
CONTENT_TYPE = "application/json"
42+
43+
def serialize(data):
44+
js = {"instances": []}
45+
for row in data:
46+
js["instances"].append({"features": row.tolist()})
47+
return json.dumps(js)
4348

4449

4550
@pytest.mark.canary_quick
@@ -84,8 +89,7 @@ def test_byo_estimator(sagemaker_session, region, cpu_instance_type, training_se
8489
with timeout_and_delete_endpoint_by_name(job_name, sagemaker_session):
8590
model = estimator.create_model()
8691
predictor = model.deploy(1, cpu_instance_type, endpoint_name=job_name)
87-
predictor.serializer = fm_serializer
88-
predictor.content_type = "application/json"
92+
predictor.serializer = _FactorizationMachineSerializer()
8993
predictor.deserializer = sagemaker.deserializers.JSONDeserializer()
9094

9195
result = predictor.predict(training_set[0][:10])
@@ -130,8 +134,7 @@ def test_async_byo_estimator(sagemaker_session, region, cpu_instance_type, train
130134
)
131135
model = estimator.create_model()
132136
predictor = model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
133-
predictor.serializer = fm_serializer
134-
predictor.content_type = "application/json"
137+
predictor.serializer = _FactorizationMachineSerializer()
135138
predictor.deserializer = sagemaker.deserializers.JSONDeserializer()
136139

137140
result = predictor.predict(training_set[0][:10])

tests/integ/test_tuner.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from sagemaker.estimator import Estimator
2929
from sagemaker.mxnet.estimator import MXNet
3030
from sagemaker.pytorch import PyTorch
31+
from sagemaker.serializers import BaseSerializer
3132
from sagemaker.tensorflow import TensorFlow
3233
from sagemaker.tuner import (
3334
IntegerParameter,
@@ -900,8 +901,7 @@ def test_tuning_byo_estimator(sagemaker_session, cpu_instance_type):
900901
best_training_job = tuner.best_training_job()
901902
with timeout_and_delete_endpoint_by_name(best_training_job, sagemaker_session):
902903
predictor = tuner.deploy(1, cpu_instance_type, endpoint_name=best_training_job)
903-
predictor.serializer = _fm_serializer
904-
predictor.content_type = "application/json"
904+
predictor.serializer = _FactorizationMachineSerializer()
905905
predictor.deserializer = JSONDeserializer()
906906

907907
result = predictor.predict(datasets.one_p_mnist()[0][:10])
@@ -912,11 +912,15 @@ def test_tuning_byo_estimator(sagemaker_session, cpu_instance_type):
912912

913913

914914
# Serializer for the Factorization Machines predictor (for BYO example)
915-
def _fm_serializer(data):
916-
js = {"instances": []}
917-
for row in data:
918-
js["instances"].append({"features": row.tolist()})
919-
return json.dumps(js)
915+
class _FactorizationMachineSerializer(BaseSerializer):
916+
917+
CONTENT_TYPE = "application/json"
918+
919+
def serializer(data):
920+
js = {"instances": []}
921+
for row in data:
922+
js["instances"].append({"features": row.tolist()})
923+
return json.dumps(js)
920924

921925

922926
def _assert_model_name_match(sagemaker_client, endpoint_config_name, model_name):

tests/integ/test_tuner_multi_algo.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919

2020
from sagemaker import image_uris, utils
2121
from sagemaker.analytics import HyperparameterTuningJobAnalytics
22-
from sagemaker.content_types import CONTENT_TYPE_JSON
2322
from sagemaker.deserializers import JSONDeserializer
2423
from sagemaker.estimator import Estimator
24+
from sagemaker.serializers import BaseSerializer
2525
from sagemaker.tuner import ContinuousParameter, IntegerParameter, HyperparameterTuner
2626
from tests.integ import datasets, DATA_DIR, TUNING_DEFAULT_TIMEOUT_MINUTES
2727
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
@@ -214,14 +214,17 @@ def _create_training_inputs(sagemaker_session):
214214

215215

216216
def _make_prediction(predictor, data):
217-
predictor.serializer = _prediction_data_serializer
218-
predictor.content_type = CONTENT_TYPE_JSON
217+
predictor.serializer = PredictionDataSerializer()
219218
predictor.deserializer = JSONDeserializer()
220219
return predictor.predict(data)
221220

222221

223-
def _prediction_data_serializer(data):
224-
js = {"instances": []}
225-
for row in data:
226-
js["instances"].append({"features": row.tolist()})
227-
return json.dumps(js)
222+
class PredictionDataSerializer(BaseSerializer):
223+
224+
CONTENT_TYPE = "application/json"
225+
226+
def serialize(data):
227+
js = {"instances": []}
228+
for row in data:
229+
js["instances"].append({"features": row.tolist()})
230+
return json.dumps(js)

0 commit comments

Comments
 (0)