Skip to content

Commit 1fd6711

Browse files
committed
make args passed to compile() take precedence
1 parent b73917b commit 1fd6711

File tree

2 files changed

+15
-14
lines changed

2 files changed

+15
-14
lines changed

src/sagemaker/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def compile(
294294
sagemaker.model.Model: A SageMaker ``Model`` object. See
295295
:func:`~sagemaker.model.Model` for full details.
296296
"""
297-
framework = self._framework() or framework
297+
framework = framework or self._framework()
298298
if framework is None:
299299
raise ValueError(
300300
"You must specify framework, allowed values {}".format(NEO_ALLOWED_FRAMEWORKS)
@@ -308,7 +308,7 @@ def compile(
308308
if self.model_data is None:
309309
raise ValueError("You must provide an S3 path to the compressed model artifacts.")
310310

311-
framework_version = self._get_framework_version() or framework_version
311+
framework_version = framework_version or self._get_framework_version()
312312

313313
self._init_sagemaker_session_if_does_not_exist(target_instance_family)
314314
config = self._compilation_job_config(

tests/unit/test_mxnet.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -360,30 +360,31 @@ def test_mxnet(
360360
)
361361

362362

363-
@patch("sagemaker.utils.create_tar_file", MagicMock())
363+
@patch("sagemaker.utils.repack_model", MagicMock())
364+
@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
364365
@patch("time.strftime", return_value=TIMESTAMP)
365-
def test_mxnet_neo(
366-
strftime, sagemaker_session, mxnet_inference_version, mxnet_py_version, skip_if_mms_version
367-
):
366+
def test_mxnet_neo(strftime, sagemaker_session, neo_mxnet_version):
368367
mx = MXNet(
369368
entry_point=SCRIPT_PATH,
370-
framework_version=mxnet_inference_version,
371-
py_version=mxnet_py_version,
369+
framework_version="1.6",
370+
py_version="py3",
372371
role=ROLE,
373372
sagemaker_session=sagemaker_session,
374373
instance_count=INSTANCE_COUNT,
375374
instance_type=INSTANCE_TYPE,
375+
base_job_name="sagemaker-mxnet",
376376
)
377-
378-
inputs = "s3://mybucket/train"
379-
380-
mx.fit(inputs=inputs)
377+
mx.fit()
381378

382379
input_shape = {"data": [100, 1, 28, 28]}
383380
output_location = "s3://neo-sdk-test"
384381

385382
compiled_model = mx.compile_model(
386-
target_instance_family="ml_c4", input_shape=input_shape, output_path=output_location
383+
target_instance_family="ml_c4",
384+
input_shape=input_shape,
385+
output_path=output_location,
386+
framework="mxnet",
387+
framework_version=neo_mxnet_version,
387388
)
388389

389390
sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls]
@@ -399,7 +400,7 @@ def test_mxnet_neo(
399400
actual_compile_model_args = sagemaker_session.method_calls[3][2]
400401
assert expected_compile_model_args == actual_compile_model_args
401402

402-
assert compiled_model.image_uri == _neo_inference_image(mxnet_inference_version)
403+
assert compiled_model.image_uri == _neo_inference_image(neo_mxnet_version)
403404

404405
predictor = mx.deploy(1, CPU, use_compiled_model=True)
405406
assert isinstance(predictor, MXNetPredictor)

0 commit comments

Comments
 (0)