@@ -93,11 +93,7 @@ def _get_full_cpu_image_uri_with_ei(version, py_version=PYTHON_VERSION):
93
93
94
94
95
95
def _pytorch_estimator (
96
- sagemaker_session ,
97
- framework_version = defaults .PYTORCH_VERSION ,
98
- train_instance_type = None ,
99
- base_job_name = None ,
100
- ** kwargs
96
+ sagemaker_session , framework_version , train_instance_type = None , base_job_name = None , ** kwargs
101
97
):
102
98
return PyTorch (
103
99
entry_point = SCRIPT_PATH ,
@@ -572,13 +568,17 @@ def test_model_py2_warning(warning, sagemaker_session, pytorch_version):
572
568
warning .assert_called_with (model .__framework_name__ , defaults .LATEST_PY2_VERSION )
573
569
574
570
575
- def test_pt_enable_sm_metrics (sagemaker_session ):
576
- pytorch = _pytorch_estimator (sagemaker_session , enable_sagemaker_metrics = True )
571
+ def test_pt_enable_sm_metrics (sagemaker_session , pytorch_full_version ):
572
+ pytorch = _pytorch_estimator (
573
+ sagemaker_session , framework_version = pytorch_full_version , enable_sagemaker_metrics = True
574
+ )
577
575
assert pytorch .enable_sagemaker_metrics
578
576
579
577
580
- def test_pt_disable_sm_metrics (sagemaker_session ):
581
- pytorch = _pytorch_estimator (sagemaker_session , enable_sagemaker_metrics = False )
578
+ def test_pt_disable_sm_metrics (sagemaker_session , pytorch_full_version ):
579
+ pytorch = _pytorch_estimator (
580
+ sagemaker_session , framework_version = pytorch_full_version , enable_sagemaker_metrics = False
581
+ )
582
582
assert not pytorch .enable_sagemaker_metrics
583
583
584
584
@@ -594,9 +594,9 @@ def test_pt_enable_sm_metrics_if_fw_ver_is_at_least_1_15(sagemaker_session):
594
594
assert pytorch .enable_sagemaker_metrics
595
595
596
596
597
- def test_custom_image_estimator_deploy (sagemaker_session ):
597
+ def test_custom_image_estimator_deploy (sagemaker_session , pytorch_full_version ):
598
598
custom_image = "mycustomimage:latest"
599
- pytorch = _pytorch_estimator (sagemaker_session )
599
+ pytorch = _pytorch_estimator (sagemaker_session , framework_version = pytorch_full_version )
600
600
pytorch .fit (inputs = "s3://mybucket/train" , job_name = "new_name" )
601
601
model = pytorch .create_model (image = custom_image )
602
602
assert model .image == custom_image
0 commit comments