@@ -525,10 +525,12 @@ def test_fit_mpi(time, strftime, sagemaker_session):
525
525
@patch ("time.time" , return_value = TIME )
526
526
@patch ("sagemaker.utils.create_tar_file" , MagicMock ())
527
527
def test_fit_mwms (time , strftime , sagemaker_session ):
528
+ framework_version = "2.9.1"
529
+ py_version = "py39"
528
530
tf = TensorFlow (
529
531
entry_point = SCRIPT_FILE ,
530
- framework_version = "2.9.1" ,
531
- py_version = "py39" ,
532
+ framework_version = framework_version ,
533
+ py_version = py_version ,
532
534
role = ROLE ,
533
535
sagemaker_session = sagemaker_session ,
534
536
instance_type = INSTANCE_TYPE ,
@@ -546,6 +548,14 @@ def test_fit_mwms(time, strftime, sagemaker_session):
546
548
expected_train_args = _create_train_job ("2.9.1" , py_version = "py39" )
547
549
expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = inputs
548
550
expected_train_args ["hyperparameters" ][TensorFlow .LAUNCH_MWMS_ENV_NAME ] = json .dumps (True )
551
+ expected_train_args [
552
+ "image_uri"
553
+ ] = f"763104351884.dkr.ecr.{ REGION } .amazonaws.com/tensorflow-training:{ framework_version } -cpu-{ py_version } "
554
+ expected_train_args ["job_name" ] = f"tensorflow-training-{ TIMESTAMP } "
555
+ expected_train_args ["hyperparameters" ]["sagemaker_job_name" ] = expected_train_args ["job_name" ]
556
+ expected_train_args ["hyperparameters" ][
557
+ "sagemaker_submit_directory"
558
+ ] = f"s3://{ BUCKET_NAME } /{ expected_train_args ['job_name' ]} /source/sourcedir.tar.gz"
549
559
550
560
actual_train_args = sagemaker_session .method_calls [0 ][2 ]
551
561
assert actual_train_args == expected_train_args
0 commit comments