File tree Expand file tree Collapse file tree 6 files changed +21
-3
lines changed Expand file tree Collapse file tree 6 files changed +21
-3
lines changed Original file line number Diff line number Diff line change @@ -203,6 +203,9 @@ def create_model(
203
203
sagemaker.chainer.model.ChainerModel: A SageMaker ``ChainerModel``
204
204
object. See :func:`~sagemaker.chainer.model.ChainerModel` for full details.
205
205
"""
206
+ if "image" not in kwargs :
207
+ kwargs ["image" ] = self .image_name
208
+
206
209
return ChainerModel (
207
210
self .model_data ,
208
211
role or self .role ,
@@ -215,10 +218,10 @@ def create_model(
215
218
py_version = self .py_version ,
216
219
framework_version = self .framework_version ,
217
220
model_server_workers = model_server_workers ,
218
- image = kwargs ["image" ] if "image" in kwargs else self .image_name ,
219
221
sagemaker_session = self .sagemaker_session ,
220
222
vpc_config = self .get_vpc_config (vpc_config_override ),
221
223
dependencies = (dependencies or self .dependencies ),
224
+ ** kwargs
222
225
)
223
226
224
227
@classmethod
Original file line number Diff line number Diff line change @@ -206,6 +206,9 @@ def create_model(
206
206
sagemaker.mxnet.model.MXNetModel: A SageMaker ``MXNetModel`` object.
207
207
See :func:`~sagemaker.mxnet.model.MXNetModel` for full details.
208
208
"""
209
+ if "image" not in kwargs :
210
+ kwargs ["image" ] = image_name or self .image_name
211
+
209
212
return MXNetModel (
210
213
self .model_data ,
211
214
role or self .role ,
@@ -217,11 +220,11 @@ def create_model(
217
220
code_location = self .code_location ,
218
221
py_version = self .py_version ,
219
222
framework_version = self .framework_version ,
220
- image = kwargs ["image" ] if "image" in kwargs else (image_name or self .image_name ),
221
223
model_server_workers = model_server_workers ,
222
224
sagemaker_session = self .sagemaker_session ,
223
225
vpc_config = self .get_vpc_config (vpc_config_override ),
224
226
dependencies = (dependencies or self .dependencies ),
227
+ ** kwargs
225
228
)
226
229
227
230
@classmethod
Original file line number Diff line number Diff line change @@ -164,6 +164,9 @@ def create_model(
164
164
sagemaker.pytorch.model.PyTorchModel: A SageMaker ``PyTorchModel``
165
165
object. See :func:`~sagemaker.pytorch.model.PyTorchModel` for full details.
166
166
"""
167
+ if "image" not in kwargs :
168
+ kwargs ["image" ] = self .image_name
169
+
167
170
return PyTorchModel (
168
171
self .model_data ,
169
172
role or self .role ,
@@ -175,11 +178,11 @@ def create_model(
175
178
code_location = self .code_location ,
176
179
py_version = self .py_version ,
177
180
framework_version = self .framework_version ,
178
- image = kwargs ["image" ] if "image" in kwargs else self .image_name ,
179
181
model_server_workers = model_server_workers ,
180
182
sagemaker_session = self .sagemaker_session ,
181
183
vpc_config = self .get_vpc_config (vpc_config_override ),
182
184
dependencies = (dependencies or self .dependencies ),
185
+ ** kwargs
183
186
)
184
187
185
188
@classmethod
Original file line number Diff line number Diff line change 31
31
SCRIPT_PATH = os .path .join (DATA_DIR , "dummy_script.py" )
32
32
SERVING_SCRIPT_FILE = "another_dummy_script.py"
33
33
MODEL_DATA = "s3://some/data.tar.gz"
34
+ ENV = {"DUMMY_ENV_VAR" : "dummy_value" }
34
35
TIMESTAMP = "2017-11-06-14:14:15.672"
35
36
TIME = 1507167947
36
37
BUCKET_NAME = "mybucket"
@@ -326,12 +327,14 @@ def test_create_model_with_optional_params(sagemaker_session):
326
327
model_server_workers = model_server_workers ,
327
328
vpc_config_override = vpc_config ,
328
329
entry_point = SERVING_SCRIPT_FILE ,
330
+ env = ENV ,
329
331
)
330
332
331
333
assert model .role == new_role
332
334
assert model .model_server_workers == model_server_workers
333
335
assert model .vpc_config == vpc_config
334
336
assert model .entry_point == SERVING_SCRIPT_FILE
337
+ assert model .env == ENV
335
338
336
339
337
340
def test_create_model_with_custom_image (sagemaker_session ):
Original file line number Diff line number Diff line change 30
30
SCRIPT_PATH = os .path .join (DATA_DIR , "dummy_script.py" )
31
31
SERVING_SCRIPT_FILE = "another_dummy_script.py"
32
32
MODEL_DATA = "s3://mybucket/model"
33
+ ENV = {"DUMMY_ENV_VAR" : "dummy_value" }
33
34
TIMESTAMP = "2017-11-06-14:14:15.672"
34
35
TIME = 1507167947
35
36
BUCKET_NAME = "mybucket"
@@ -231,12 +232,14 @@ def test_create_model_with_optional_params(sagemaker_session):
231
232
model_server_workers = model_server_workers ,
232
233
vpc_config_override = vpc_config ,
233
234
entry_point = SERVING_SCRIPT_FILE ,
235
+ env = ENV ,
234
236
)
235
237
236
238
assert model .role == new_role
237
239
assert model .model_server_workers == model_server_workers
238
240
assert model .vpc_config == vpc_config
239
241
assert model .entry_point == SERVING_SCRIPT_FILE
242
+ assert model .env == ENV
240
243
241
244
242
245
def test_create_model_with_custom_image (sagemaker_session ):
Original file line number Diff line number Diff line change 28
28
SCRIPT_PATH = os .path .join (DATA_DIR , "dummy_script.py" )
29
29
SERVING_SCRIPT_FILE = "another_dummy_script.py"
30
30
MODEL_DATA = "s3://some/data.tar.gz"
31
+ ENV = {"DUMMY_ENV_VAR" : "dummy_value" }
31
32
TIMESTAMP = "2017-11-06-14:14:15.672"
32
33
TIME = 1507167947
33
34
BUCKET_NAME = "mybucket"
@@ -212,12 +213,14 @@ def test_create_model_with_optional_params(sagemaker_session):
212
213
model_server_workers = model_server_workers ,
213
214
vpc_config_override = vpc_config ,
214
215
entry_point = SERVING_SCRIPT_FILE ,
216
+ env = ENV ,
215
217
)
216
218
217
219
assert model .role == new_role
218
220
assert model .model_server_workers == model_server_workers
219
221
assert model .vpc_config == vpc_config
220
222
assert model .entry_point == SERVING_SCRIPT_FILE
223
+ assert model .env == ENV
221
224
222
225
223
226
def test_create_model_with_custom_image (sagemaker_session ):
You can’t perform that action at this time.
0 commit comments