Skip to content

Commit 688007f

Browse files
authored
fix: Pass kwargs from create_model to Model constructors (aws#1377)
* fix: Pass kwargs from create_model to Model constructors * Test kwargs with dummy env
1 parent 4d0c2e6 commit 688007f

File tree

6 files changed

+21
-3
lines changed

6 files changed

+21
-3
lines changed

src/sagemaker/chainer/estimator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,9 @@ def create_model(
203203
sagemaker.chainer.model.ChainerModel: A SageMaker ``ChainerModel``
204204
object. See :func:`~sagemaker.chainer.model.ChainerModel` for full details.
205205
"""
206+
if "image" not in kwargs:
207+
kwargs["image"] = self.image_name
208+
206209
return ChainerModel(
207210
self.model_data,
208211
role or self.role,
@@ -215,10 +218,10 @@ def create_model(
215218
py_version=self.py_version,
216219
framework_version=self.framework_version,
217220
model_server_workers=model_server_workers,
218-
image=kwargs["image"] if "image" in kwargs else self.image_name,
219221
sagemaker_session=self.sagemaker_session,
220222
vpc_config=self.get_vpc_config(vpc_config_override),
221223
dependencies=(dependencies or self.dependencies),
224+
**kwargs
222225
)
223226

224227
@classmethod

src/sagemaker/mxnet/estimator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,9 @@ def create_model(
206206
sagemaker.mxnet.model.MXNetModel: A SageMaker ``MXNetModel`` object.
207207
See :func:`~sagemaker.mxnet.model.MXNetModel` for full details.
208208
"""
209+
if "image" not in kwargs:
210+
kwargs["image"] = image_name or self.image_name
211+
209212
return MXNetModel(
210213
self.model_data,
211214
role or self.role,
@@ -217,11 +220,11 @@ def create_model(
217220
code_location=self.code_location,
218221
py_version=self.py_version,
219222
framework_version=self.framework_version,
220-
image=kwargs["image"] if "image" in kwargs else (image_name or self.image_name),
221223
model_server_workers=model_server_workers,
222224
sagemaker_session=self.sagemaker_session,
223225
vpc_config=self.get_vpc_config(vpc_config_override),
224226
dependencies=(dependencies or self.dependencies),
227+
**kwargs
225228
)
226229

227230
@classmethod

src/sagemaker/pytorch/estimator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,9 @@ def create_model(
164164
sagemaker.pytorch.model.PyTorchModel: A SageMaker ``PyTorchModel``
165165
object. See :func:`~sagemaker.pytorch.model.PyTorchModel` for full details.
166166
"""
167+
if "image" not in kwargs:
168+
kwargs["image"] = self.image_name
169+
167170
return PyTorchModel(
168171
self.model_data,
169172
role or self.role,
@@ -175,11 +178,11 @@ def create_model(
175178
code_location=self.code_location,
176179
py_version=self.py_version,
177180
framework_version=self.framework_version,
178-
image=kwargs["image"] if "image" in kwargs else self.image_name,
179181
model_server_workers=model_server_workers,
180182
sagemaker_session=self.sagemaker_session,
181183
vpc_config=self.get_vpc_config(vpc_config_override),
182184
dependencies=(dependencies or self.dependencies),
185+
**kwargs
183186
)
184187

185188
@classmethod

tests/unit/test_chainer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
3232
SERVING_SCRIPT_FILE = "another_dummy_script.py"
3333
MODEL_DATA = "s3://some/data.tar.gz"
34+
ENV = {"DUMMY_ENV_VAR": "dummy_value"}
3435
TIMESTAMP = "2017-11-06-14:14:15.672"
3536
TIME = 1507167947
3637
BUCKET_NAME = "mybucket"
@@ -326,12 +327,14 @@ def test_create_model_with_optional_params(sagemaker_session):
326327
model_server_workers=model_server_workers,
327328
vpc_config_override=vpc_config,
328329
entry_point=SERVING_SCRIPT_FILE,
330+
env=ENV,
329331
)
330332

331333
assert model.role == new_role
332334
assert model.model_server_workers == model_server_workers
333335
assert model.vpc_config == vpc_config
334336
assert model.entry_point == SERVING_SCRIPT_FILE
337+
assert model.env == ENV
335338

336339

337340
def test_create_model_with_custom_image(sagemaker_session):

tests/unit/test_mxnet.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
3131
SERVING_SCRIPT_FILE = "another_dummy_script.py"
3232
MODEL_DATA = "s3://mybucket/model"
33+
ENV = {"DUMMY_ENV_VAR": "dummy_value"}
3334
TIMESTAMP = "2017-11-06-14:14:15.672"
3435
TIME = 1507167947
3536
BUCKET_NAME = "mybucket"
@@ -231,12 +232,14 @@ def test_create_model_with_optional_params(sagemaker_session):
231232
model_server_workers=model_server_workers,
232233
vpc_config_override=vpc_config,
233234
entry_point=SERVING_SCRIPT_FILE,
235+
env=ENV,
234236
)
235237

236238
assert model.role == new_role
237239
assert model.model_server_workers == model_server_workers
238240
assert model.vpc_config == vpc_config
239241
assert model.entry_point == SERVING_SCRIPT_FILE
242+
assert model.env == ENV
240243

241244

242245
def test_create_model_with_custom_image(sagemaker_session):

tests/unit/test_pytorch.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
2929
SERVING_SCRIPT_FILE = "another_dummy_script.py"
3030
MODEL_DATA = "s3://some/data.tar.gz"
31+
ENV = {"DUMMY_ENV_VAR": "dummy_value"}
3132
TIMESTAMP = "2017-11-06-14:14:15.672"
3233
TIME = 1507167947
3334
BUCKET_NAME = "mybucket"
@@ -212,12 +213,14 @@ def test_create_model_with_optional_params(sagemaker_session):
212213
model_server_workers=model_server_workers,
213214
vpc_config_override=vpc_config,
214215
entry_point=SERVING_SCRIPT_FILE,
216+
env=ENV,
215217
)
216218

217219
assert model.role == new_role
218220
assert model.model_server_workers == model_server_workers
219221
assert model.vpc_config == vpc_config
220222
assert model.entry_point == SERVING_SCRIPT_FILE
223+
assert model.env == ENV
221224

222225

223226
def test_create_model_with_custom_image(sagemaker_session):

0 commit comments

Comments
 (0)