@@ -91,7 +91,7 @@ def _get_full_cpu_image_uri_with_ei(version, py_version=PYTHON_VERSION):
91
91
92
92
def _chainer_estimator (
93
93
sagemaker_session ,
94
- framework_version = defaults . CHAINER_VERSION ,
94
+ framework_version ,
95
95
train_instance_type = None ,
96
96
base_job_name = None ,
97
97
use_mpi = None ,
@@ -202,13 +202,14 @@ def _create_train_job_with_additional_hyperparameters(version):
202
202
}
203
203
204
204
205
- def test_additional_hyperparameters (sagemaker_session ):
205
+ def test_additional_hyperparameters (sagemaker_session , chainer_version ):
206
206
chainer = _chainer_estimator (
207
207
sagemaker_session ,
208
208
use_mpi = True ,
209
209
num_processes = 4 ,
210
210
process_slots_per_host = 10 ,
211
211
additional_mpi_options = "-x MY_ENVIRONMENT_VARIABLE" ,
212
+ framework_version = chainer_version ,
212
213
)
213
214
assert bool (strtobool (chainer .hyperparameters ()["sagemaker_use_mpi" ]))
214
215
assert int (chainer .hyperparameters ()["sagemaker_num_processes" ]) == 4
@@ -300,7 +301,7 @@ def test_create_model(sagemaker_session, chainer_version):
300
301
assert model .vpc_config is None
301
302
302
303
303
- def test_create_model_with_optional_params (sagemaker_session ):
304
+ def test_create_model_with_optional_params (sagemaker_session , chainer_version ):
304
305
container_log_level = '"logging.INFO"'
305
306
source_dir = "s3://mybucket/source"
306
307
enable_cloudwatch_metrics = "true"
@@ -311,6 +312,7 @@ def test_create_model_with_optional_params(sagemaker_session):
311
312
train_instance_count = INSTANCE_COUNT ,
312
313
train_instance_type = INSTANCE_TYPE ,
313
314
container_log_level = container_log_level ,
315
+ framework_version = chainer_version ,
314
316
py_version = PYTHON_VERSION ,
315
317
base_job_name = "job" ,
316
318
source_dir = source_dir ,
@@ -372,8 +374,8 @@ def test_chainer(strftime, sagemaker_session, chainer_version):
372
374
sagemaker_session = sagemaker_session ,
373
375
train_instance_count = INSTANCE_COUNT ,
374
376
train_instance_type = INSTANCE_TYPE ,
375
- py_version = PYTHON_VERSION ,
376
377
framework_version = chainer_version ,
378
+ py_version = PYTHON_VERSION ,
377
379
)
378
380
379
381
inputs = "s3://mybucket/train"
@@ -414,62 +416,72 @@ def test_chainer(strftime, sagemaker_session, chainer_version):
414
416
415
417
416
418
@patch ("sagemaker.utils.create_tar_file" , MagicMock ())
417
- def test_model (sagemaker_session ):
419
+ def test_model (sagemaker_session , chainer_version ):
418
420
model = ChainerModel (
419
421
"s3://some/data.tar.gz" ,
420
422
role = ROLE ,
421
423
entry_point = SCRIPT_PATH ,
422
424
sagemaker_session = sagemaker_session ,
425
+ framework_version = chainer_version ,
426
+ py_version = PYTHON_VERSION ,
423
427
)
424
428
predictor = model .deploy (1 , GPU )
425
429
assert isinstance (predictor , ChainerPredictor )
426
430
427
431
428
432
@patch ("sagemaker.fw_utils.tar_and_upload_dir" , MagicMock ())
429
- def test_model_prepare_container_def_accelerator_error (sagemaker_session ):
433
+ def test_model_prepare_container_def_accelerator_error (sagemaker_session , chainer_version ):
430
434
model = ChainerModel (
431
- MODEL_DATA , role = ROLE , entry_point = SCRIPT_PATH , sagemaker_session = sagemaker_session
435
+ MODEL_DATA ,
436
+ role = ROLE ,
437
+ entry_point = SCRIPT_PATH ,
438
+ sagemaker_session = sagemaker_session ,
439
+ framework_version = chainer_version ,
440
+ py_version = PYTHON_VERSION ,
432
441
)
433
442
with pytest .raises (ValueError ):
434
443
model .prepare_container_def (INSTANCE_TYPE , accelerator_type = ACCELERATOR_TYPE )
435
444
436
445
437
- def test_train_image_default (sagemaker_session ):
446
+ def test_train_image_default (sagemaker_session , chainer_version ):
438
447
chainer = Chainer (
439
448
entry_point = SCRIPT_PATH ,
440
449
role = ROLE ,
441
450
sagemaker_session = sagemaker_session ,
442
451
train_instance_count = INSTANCE_COUNT ,
443
452
train_instance_type = INSTANCE_TYPE ,
453
+ framework_version = chainer_version ,
444
454
py_version = PYTHON_VERSION ,
445
455
)
446
456
447
- assert _get_full_cpu_image_uri (defaults . CHAINER_VERSION ) in chainer .train_image ()
457
+ assert _get_full_cpu_image_uri (chainer_version ) in chainer .train_image ()
448
458
449
459
450
460
def test_train_image_cpu_instances (sagemaker_session , chainer_version ):
451
461
chainer = _chainer_estimator (
452
- sagemaker_session , chainer_version , train_instance_type = "ml.c2.2xlarge"
462
+ sagemaker_session , framework_version = chainer_version , train_instance_type = "ml.c2.2xlarge"
453
463
)
454
464
assert chainer .train_image () == _get_full_cpu_image_uri (chainer_version )
455
465
456
466
chainer = _chainer_estimator (
457
- sagemaker_session , chainer_version , train_instance_type = "ml.c4.2xlarge"
467
+ sagemaker_session , framework_version = chainer_version , train_instance_type = "ml.c4.2xlarge"
458
468
)
459
469
assert chainer .train_image () == _get_full_cpu_image_uri (chainer_version )
460
470
461
- chainer = _chainer_estimator (sagemaker_session , chainer_version , train_instance_type = "ml.m16" )
471
+ chainer = _chainer_estimator (
472
+ sagemaker_session , framework_version = chainer_version , train_instance_type = "ml.m16"
473
+ )
462
474
assert chainer .train_image () == _get_full_cpu_image_uri (chainer_version )
463
475
464
476
465
477
def test_train_image_gpu_instances (sagemaker_session , chainer_version ):
466
478
chainer = _chainer_estimator (
467
- sagemaker_session , chainer_version , train_instance_type = "ml.g2.2xlarge"
479
+ sagemaker_session , framework_version = chainer_version , train_instance_type = "ml.g2.2xlarge"
468
480
)
469
481
assert chainer .train_image () == _get_full_gpu_image_uri (chainer_version )
470
482
471
483
chainer = _chainer_estimator (
472
- sagemaker_session , chainer_version , train_instance_type = "ml.p2.2xlarge"
484
+ sagemaker_session , framework_version = chainer_version , train_instance_type = "ml.p2.2xlarge"
473
485
)
474
486
assert chainer .train_image () == _get_full_gpu_image_uri (chainer_version )
475
487
@@ -597,13 +609,14 @@ def test_attach_custom_image(sagemaker_session):
597
609
598
610
599
611
@patch ("sagemaker.chainer.estimator.python_deprecation_warning" )
600
- def test_estimator_py2_warning (warning , sagemaker_session ):
612
+ def test_estimator_py2_warning (warning , sagemaker_session , chainer_version ):
601
613
estimator = Chainer (
602
614
entry_point = SCRIPT_PATH ,
603
615
role = ROLE ,
604
616
sagemaker_session = sagemaker_session ,
605
617
train_instance_count = INSTANCE_COUNT ,
606
618
train_instance_type = INSTANCE_TYPE ,
619
+ framework_version = chainer_version ,
607
620
py_version = "py2" ,
608
621
)
609
622
@@ -612,49 +625,22 @@ def test_estimator_py2_warning(warning, sagemaker_session):
612
625
613
626
614
627
@patch ("sagemaker.chainer.model.python_deprecation_warning" )
615
- def test_model_py2_warning (warning , sagemaker_session ):
628
+ def test_model_py2_warning (warning , sagemaker_session , chainer_version ):
616
629
model = ChainerModel (
617
630
MODEL_DATA ,
618
631
role = ROLE ,
619
632
entry_point = SCRIPT_PATH ,
620
633
sagemaker_session = sagemaker_session ,
634
+ framework_version = chainer_version ,
621
635
py_version = "py2" ,
622
636
)
623
637
assert model .py_version == "py2"
624
638
warning .assert_called_with (model .__framework_name__ , defaults .LATEST_PY2_VERSION )
625
639
626
640
627
- @patch ("sagemaker.chainer.estimator.empty_framework_version_warning" )
628
- def test_empty_framework_version (warning , sagemaker_session ):
629
- estimator = Chainer (
630
- entry_point = SCRIPT_PATH ,
631
- role = ROLE ,
632
- sagemaker_session = sagemaker_session ,
633
- train_instance_count = INSTANCE_COUNT ,
634
- train_instance_type = INSTANCE_TYPE ,
635
- framework_version = None ,
636
- )
637
-
638
- assert estimator .framework_version == defaults .CHAINER_VERSION
639
- warning .assert_called_with (defaults .CHAINER_VERSION , Chainer .LATEST_VERSION )
640
-
641
-
642
- @patch ("sagemaker.chainer.model.empty_framework_version_warning" )
643
- def test_model_empty_framework_version (warning , sagemaker_session ):
644
- model = ChainerModel (
645
- MODEL_DATA ,
646
- role = ROLE ,
647
- entry_point = SCRIPT_PATH ,
648
- sagemaker_session = sagemaker_session ,
649
- framework_version = None ,
650
- )
651
- assert model .framework_version == defaults .CHAINER_VERSION
652
- warning .assert_called_with (defaults .CHAINER_VERSION , defaults .LATEST_VERSION )
653
-
654
-
655
- def test_custom_image_estimator_deploy (sagemaker_session ):
641
+ def test_custom_image_estimator_deploy (sagemaker_session , chainer_version ):
656
642
custom_image = "mycustomimage:latest"
657
- chainer = _chainer_estimator (sagemaker_session )
643
+ chainer = _chainer_estimator (sagemaker_session , framework_version = chainer_version )
658
644
chainer .fit (inputs = "s3://mybucket/train" , job_name = "new_name" )
659
645
model = chainer .create_model (image = custom_image )
660
646
assert model .image == custom_image
0 commit comments