Skip to content

Commit 5e725e8

Browse files
committed
Fixing MWMS unit test for TF2
1 parent c1c8fc2 commit 5e725e8

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

tests/unit/sagemaker/tensorflow/test_estimator.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -525,10 +525,12 @@ def test_fit_mpi(time, strftime, sagemaker_session):
525525
@patch("time.time", return_value=TIME)
526526
@patch("sagemaker.utils.create_tar_file", MagicMock())
527527
def test_fit_mwms(time, strftime, sagemaker_session):
528+
framework_version = "2.9.1"
529+
py_version = "py39"
528530
tf = TensorFlow(
529531
entry_point=SCRIPT_FILE,
530-
framework_version="2.9.1",
531-
py_version="py39",
532+
framework_version=framework_version,
533+
py_version=py_version,
532534
role=ROLE,
533535
sagemaker_session=sagemaker_session,
534536
instance_type=INSTANCE_TYPE,
@@ -546,6 +548,14 @@ def test_fit_mwms(time, strftime, sagemaker_session):
546548
expected_train_args = _create_train_job("2.9.1", py_version="py39")
547549
expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
548550
expected_train_args["hyperparameters"][TensorFlow.LAUNCH_MWMS_ENV_NAME] = json.dumps(True)
551+
expected_train_args[
552+
"image_uri"
553+
] = f"763104351884.dkr.ecr.{REGION}.amazonaws.com/tensorflow-training:{framework_version}-cpu-{py_version}"
554+
expected_train_args["job_name"] = f"tensorflow-training-{TIMESTAMP}"
555+
expected_train_args["hyperparameters"]["sagemaker_job_name"] = expected_train_args["job_name"]
556+
expected_train_args["hyperparameters"][
557+
"sagemaker_submit_directory"
558+
] = f"s3://{BUCKET_NAME}/{expected_train_args['job_name']}/source/sourcedir.tar.gz"
549559

550560
actual_train_args = sagemaker_session.method_calls[0][2]
551561
assert actual_train_args == expected_train_args

0 commit comments

Comments
 (0)