Skip to content

Commit 64024cb

Browse files
author
Balaji Veeramani
committed
Fix tests
1 parent 06f69b2 commit 64024cb

File tree

5 files changed

+13
-5
lines changed

5 files changed

+13
-5
lines changed

src/sagemaker/multidatamodel.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,12 +211,12 @@ def deploy(
211211
enable_network_isolation = self.model.enable_network_isolation()
212212
role = self.model.role
213213
vpc_config = self.model.vpc_config
214-
predictor = self.model.predictor_cls
214+
predictor_cls = self.model.predictor_cls
215215
else:
216216
enable_network_isolation = self.enable_network_isolation()
217217
role = self.role
218218
vpc_config = self.vpc_config
219-
predictor = self.predictor_cls
219+
predictor_cls = self.predictor_cls
220220

221221
if role is None:
222222
raise ValueError("Role can not be null for deploying a model")
@@ -255,8 +255,8 @@ def deploy(
255255
data_capture_config_dict=data_capture_config_dict,
256256
)
257257

258-
if predictor:
259-
predictor = self.predictor_cls(self.endpoint_name, self.sagemaker_session)
258+
if predictor_cls:
259+
predictor = predictor_cls(self.endpoint_name, self.sagemaker_session)
260260
if serializer:
261261
predictor.serializer = serializer
262262
if deserializer:

src/sagemaker/tensorflow/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,8 @@ def deploy(
196196
self,
197197
initial_instance_count,
198198
instance_type,
199+
serializer=None,
200+
deserializer=None,
199201
accelerator_type=None,
200202
endpoint_name=None,
201203
tags=None,
@@ -211,6 +213,8 @@ def deploy(
211213
return super(TensorFlowModel, self).deploy(
212214
initial_instance_count=initial_instance_count,
213215
instance_type=instance_type,
216+
serializer=serializer,
217+
deserializer=deserializer,
214218
accelerator_type=accelerator_type,
215219
endpoint_name=endpoint_name,
216220
tags=tags,

tests/unit/sagemaker/automl/test_auto_ml.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,8 @@ def test_deploy_optional_args(candidate_estimator, sagemaker_session, candidate_
591591
mock_pipeline.deploy.assert_called_with(
592592
initial_instance_count=INSTANCE_COUNT,
593593
instance_type=INSTANCE_TYPE,
594+
serializer=None,
595+
deserializer=None,
594596
endpoint_name=JOB_NAME,
595597
tags=TAGS,
596598
wait=False,

tests/unit/test_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2122,7 +2122,7 @@ def test_generic_deploy_accelerator_type(sagemaker_session):
21222122
IMAGE_URI, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, sagemaker_session=sagemaker_session
21232123
)
21242124
e.fit({"train": "s3://bucket/training-prefix"})
2125-
e.deploy(INSTANCE_COUNT, INSTANCE_TYPE, ACCELERATOR_TYPE)
2125+
e.deploy(INSTANCE_COUNT, INSTANCE_TYPE, accelerator_type=ACCELERATOR_TYPE)
21262126

21272127
args = e.sagemaker_session.endpoint_from_production_variants.call_args[1]
21282128
print(args)

tests/unit/test_tuner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -893,6 +893,8 @@ def test_deploy_optional_params(_get_best_training_job, best_estimator, tuner):
893893
estimator.deploy.assert_called_with(
894894
initial_instance_count=INSTANCE_COUNT,
895895
instance_type=INSTANCE_TYPE,
896+
serializer=None,
897+
deserializer=None,
896898
accelerator_type=accelerator,
897899
endpoint_name=endpoint_name,
898900
wait=False,

0 commit comments

Comments
 (0)