@@ -360,30 +360,31 @@ def test_mxnet(
360
360
)
361
361
362
362
363
- @patch ("sagemaker.utils.create_tar_file" , MagicMock ())
363
+ @patch ("sagemaker.utils.repack_model" , MagicMock ())
364
+ @patch ("sagemaker.fw_utils.tar_and_upload_dir" , MagicMock ())
364
365
@patch ("time.strftime" , return_value = TIMESTAMP )
365
- def test_mxnet_neo (
366
- strftime , sagemaker_session , mxnet_inference_version , mxnet_py_version , skip_if_mms_version
367
- ):
366
+ def test_mxnet_neo (strftime , sagemaker_session , neo_mxnet_version ):
368
367
mx = MXNet (
369
368
entry_point = SCRIPT_PATH ,
370
- framework_version = mxnet_inference_version ,
371
- py_version = mxnet_py_version ,
369
+ framework_version = "1.6" ,
370
+ py_version = "py3" ,
372
371
role = ROLE ,
373
372
sagemaker_session = sagemaker_session ,
374
373
instance_count = INSTANCE_COUNT ,
375
374
instance_type = INSTANCE_TYPE ,
375
+ base_job_name = "sagemaker-mxnet" ,
376
376
)
377
-
378
- inputs = "s3://mybucket/train"
379
-
380
- mx .fit (inputs = inputs )
377
+ mx .fit ()
381
378
382
379
input_shape = {"data" : [100 , 1 , 28 , 28 ]}
383
380
output_location = "s3://neo-sdk-test"
384
381
385
382
compiled_model = mx .compile_model (
386
- target_instance_family = "ml_c4" , input_shape = input_shape , output_path = output_location
383
+ target_instance_family = "ml_c4" ,
384
+ input_shape = input_shape ,
385
+ output_path = output_location ,
386
+ framework = "mxnet" ,
387
+ framework_version = neo_mxnet_version ,
387
388
)
388
389
389
390
sagemaker_call_names = [c [0 ] for c in sagemaker_session .method_calls ]
@@ -399,7 +400,7 @@ def test_mxnet_neo(
399
400
actual_compile_model_args = sagemaker_session .method_calls [3 ][2 ]
400
401
assert expected_compile_model_args == actual_compile_model_args
401
402
402
- assert compiled_model .image_uri == _neo_inference_image (mxnet_inference_version )
403
+ assert compiled_model .image_uri == _neo_inference_image (neo_mxnet_version )
403
404
404
405
predictor = mx .deploy (1 , CPU , use_compiled_model = True )
405
406
assert isinstance (predictor , MXNetPredictor )
0 commit comments