@@ -57,13 +57,13 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
57
57
58
58
59
59
@pytest .fixture (scope = 'module' )
60
- def mxnet_model (sagemaker_local_session ):
60
+ def mxnet_model (sagemaker_local_session , mxnet_full_version ):
61
61
script_path = os .path .join (DATA_DIR , 'mxnet_mnist' , 'mnist.py' )
62
62
data_path = os .path .join (DATA_DIR , 'mxnet_mnist' )
63
63
64
64
mx = MXNet (entry_point = script_path , role = 'SageMakerRole' ,
65
- train_instance_count = 1 , train_instance_type = 'local' ,
66
- sagemaker_session = sagemaker_local_session )
65
+ train_instance_count = 1 , train_instance_type = 'local' , launch_parameter_server = True ,
66
+ sagemaker_session = sagemaker_local_session , framework_version = mxnet_full_version )
67
67
68
68
train_input = mx .sagemaker_session .upload_data (path = os .path .join (data_path , 'train' ),
69
69
key_prefix = 'integ-test-data/mxnet_mnist/train' )
@@ -306,16 +306,16 @@ def test_local_mode_serving_from_local_model(sagemaker_local_session, mxnet_mode
306
306
fcntl .lockf (local_mode_lock , fcntl .LOCK_UN )
307
307
308
308
309
- def test_mxnet_local_mode (sagemaker_local_session ):
309
+ def test_mxnet_local_mode (sagemaker_local_session , mxnet_full_version ):
310
310
local_mode_lock_fd = open (LOCK_PATH , 'w' )
311
311
local_mode_lock = local_mode_lock_fd .fileno ()
312
312
313
313
script_path = os .path .join (DATA_DIR , 'mxnet_mnist' , 'mnist.py' )
314
314
data_path = os .path .join (DATA_DIR , 'mxnet_mnist' )
315
315
316
316
mx = MXNet (entry_point = script_path , role = 'SageMakerRole' , py_version = PYTHON_VERSION ,
317
- train_instance_count = 1 , train_instance_type = 'local' ,
318
- sagemaker_session = sagemaker_local_session )
317
+ train_instance_count = 1 , train_instance_type = 'local' , launch_parameter_server = True ,
318
+ sagemaker_session = sagemaker_local_session , framework_version = mxnet_full_version )
319
319
320
320
train_input = mx .sagemaker_session .upload_data (path = os .path .join (data_path , 'train' ),
321
321
key_prefix = 'integ-test-data/mxnet_mnist/train' )
@@ -338,15 +338,15 @@ def test_mxnet_local_mode(sagemaker_local_session):
338
338
fcntl .lockf (local_mode_lock , fcntl .LOCK_UN )
339
339
340
340
341
- def test_mxnet_local_data_local_script ():
341
+ def test_mxnet_local_data_local_script (mxnet_full_version ):
342
342
local_mode_lock_fd = open (LOCK_PATH , 'w' )
343
343
local_mode_lock = local_mode_lock_fd .fileno ()
344
344
345
345
script_path = os .path .join (DATA_DIR , 'mxnet_mnist' , 'mnist.py' )
346
346
data_path = os .path .join (DATA_DIR , 'mxnet_mnist' )
347
347
348
- mx = MXNet (entry_point = script_path , role = 'SageMakerRole' ,
349
- train_instance_count = 1 , train_instance_type = 'local' ,
348
+ mx = MXNet (entry_point = script_path , role = 'SageMakerRole' , framework_version = mxnet_full_version ,
349
+ train_instance_count = 1 , train_instance_type = 'local' , launch_parameter_server = True ,
350
350
sagemaker_session = LocalNoS3Session ())
351
351
352
352
train_input = 'file://' + os .path .join (data_path , 'train' )
@@ -368,14 +368,15 @@ def test_mxnet_local_data_local_script():
368
368
fcntl .lockf (local_mode_lock , fcntl .LOCK_UN )
369
369
370
370
371
- def test_local_transform_mxnet (sagemaker_local_session , tmpdir ):
371
+ def test_local_transform_mxnet (sagemaker_local_session , tmpdir , mxnet_full_version ):
372
372
local_mode_lock_fd = open (LOCK_PATH , 'w' )
373
373
local_mode_lock = local_mode_lock_fd .fileno ()
374
374
data_path = os .path .join (DATA_DIR , 'mxnet_mnist' )
375
375
script_path = os .path .join (data_path , 'mnist.py' )
376
376
377
377
mx = MXNet (entry_point = script_path , role = 'SageMakerRole' , train_instance_count = 1 ,
378
- train_instance_type = 'ml.c4.xlarge' , sagemaker_session = sagemaker_local_session )
378
+ train_instance_type = 'ml.c4.xlarge' , framework_version = mxnet_full_version ,
379
+ sagemaker_session = sagemaker_local_session , launch_parameter_server = True )
379
380
380
381
train_input = mx .sagemaker_session .upload_data (path = os .path .join (data_path , 'train' ),
381
382
key_prefix = 'integ-test-data/mxnet_mnist/train' )
0 commit comments