Skip to content

Commit 2c48152

Browse files
committed
Fixing unit tests for MWMS
1 parent 5f4959e commit 2c48152

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

tests/unit/sagemaker/tensorflow/test_estimator.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -523,9 +523,11 @@ def test_fit_mpi(time, strftime, sagemaker_session):
523523
@patch("time.strftime", return_value=TIMESTAMP)
524524
@patch("time.time", return_value=TIME)
525525
@patch("sagemaker.utils.create_tar_file", MagicMock())
526-
def test_fit_mwms(time, strftime, sagemaker_session):
527-
framework_version = "2.9.1"
528-
py_version = "py39"
526+
def test_fit_mwms(
527+
time, strftime, sagemaker_session, tensorflow_training_version, tensorflow_training_py_version
528+
):
529+
framework_version = tensorflow_training_version
530+
py_version = tensorflow_training_py_version
529531
tf = TensorFlow(
530532
entry_point=SCRIPT_FILE,
531533
framework_version=framework_version,
@@ -544,7 +546,7 @@ def test_fit_mwms(time, strftime, sagemaker_session):
544546
call_names = [c[0] for c in sagemaker_session.method_calls]
545547
assert call_names == ["train", "logs_for_job"]
546548

547-
expected_train_args = _create_train_job("2.9.1", py_version="py39")
549+
expected_train_args = _create_train_job(framework_version, py_version=py_version)
548550
expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
549551
expected_train_args[
550552
"image_uri"

0 commit comments

Comments
 (0)