Skip to content

Commit d81a2cd

Browse files
TF-2.16 test modification and handling (#4830)
* [DO NOT MERGE] Experimenting TF image_uri configs * Add logic to Tf estimator * Comment previous data * Test with making changes to model.py * Add override FW version * Make changes to prevent non 2.16 from making the change * Add Print * Use hasattr * Fix print and use getattr * Print statement * Only define override when needed * Change to net.export() for TF2.16 * Skip tests failing due to TF-IO * Reformatting * Reformatting pylint * Change tf_full_vesion fixture * Add Version * Revert unit test changes * Preventing re-initialization of Version * Revert config JSON changes * Handle in case inf and training have different major.minor * Introduce return version concept in tests * Add TF2.16.1 inf to config * Revert temp changes --------- Co-authored-by: Erick Benitez-Ramos <[email protected]>
1 parent 1f21668 commit d81a2cd

File tree

7 files changed

+52
-7
lines changed

7 files changed

+52
-7
lines changed

src/sagemaker/tensorflow/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def __init__(
199199
# patch versions, but end up hosting the model of same TF version. For eg., the upstream
200200
# TFS-2.12.0 release was a bad release and hence a new TFS-2.12.1 release was made to host
201201
# models from TF-2.12.0.
202-
training_inference_version_mismatch_dict = {"2.12.0": "2.12.1"}
202+
training_inference_version_mismatch_dict = {"2.12.0": "2.12.1", "2.16.2": "2.16.1"}
203203
self.inference_framework_version = training_inference_version_mismatch_dict.get(
204204
framework_version, framework_version
205205
)

tests/conftest.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from botocore.config import Config
2424
from packaging.version import Version
25+
from packaging.specifiers import SpecifierSet
2526

2627
from sagemaker import Session, image_uris, utils, get_execution_role
2728
from sagemaker.local import LocalSession
@@ -555,11 +556,18 @@ def tf_full_version(tensorflow_training_latest_version, tensorflow_inference_lat
555556
Fixture exists as such, since TF training and TFS have different latest versions.
556557
Otherwise, this would simply be a single latest version.
557558
"""
558-
return str(
559-
min(
560-
Version(tensorflow_training_latest_version),
561-
Version(tensorflow_inference_latest_version),
562-
)
559+
tensorflow_training_latest_version = Version(tensorflow_training_latest_version)
560+
tensorflow_inference_latest_version = Version(tensorflow_inference_latest_version)
561+
562+
return_version = min(
563+
tensorflow_training_latest_version,
564+
tensorflow_inference_latest_version,
565+
)
566+
567+
return (
568+
f"{return_version.major}.{return_version.minor}"
569+
if return_version in SpecifierSet(">=2.16")
570+
else str(return_version)
563571
)
564572

565573

tests/data/tensorflow_mnist/mnist_v2.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,10 @@ def main(args):
198198

199199
if args.current_host == args.hosts[0]:
200200
ckpt_manager.save()
201-
net.save("/opt/ml/model/1")
201+
if int(tf_major) > 2 or (int(tf_major) == 2 and int(tf_minor) >= 16):
202+
net.export("/opt/ml/model/1")
203+
else:
204+
net.save("/opt/ml/model/1")
202205

203206

204207
if __name__ == "__main__":

tests/integ/sagemaker/workflow/test_model_create_and_registration.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@
2525

2626
import pytest
2727

28+
from packaging.version import Version
29+
from packaging.specifiers import SpecifierSet
30+
2831
from sagemaker.model_card.model_card import ModelCard, ModelOverview, ModelPackageModelCard
2932
from sagemaker.model_card.schema_constraints import ModelCardStatusEnum
3033
import tests
@@ -1250,6 +1253,11 @@ def test_model_registration_with_tensorflow_model_with_pipeline_model(
12501253
pipeline_name,
12511254
region_name,
12521255
):
1256+
if Version(tf_full_version) in SpecifierSet("==2.16.*"):
1257+
pytest.skip(
1258+
"This test is failing in TensorFlow 2.16 beacuse of an upstream bug: "
1259+
"https://github.com/tensorflow/io/issues/2039"
1260+
)
12531261
base_dir = os.path.join(DATA_DIR, "tensorflow_mnist")
12541262
entry_point = os.path.join(base_dir, "mnist_v2.py")
12551263
input_path = sagemaker_session_for_pipeline.upload_data(

tests/integ/sagemaker/workflow/test_model_steps.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717

1818
import pytest
1919

20+
from packaging.version import Version
21+
from packaging.specifiers import SpecifierSet
22+
2023
from tests.integ.sagemaker.workflow.helpers import wait_pipeline_execution
2124
from sagemaker.workflow.fail_step import FailStep
2225
from sagemaker.workflow.functions import Join
@@ -589,6 +592,11 @@ def test_model_registration_with_drift_check_baselines_and_model_metrics(
589592
def test_model_registration_with_tensorflow_model_with_pipeline_model(
590593
pipeline_session, role, tf_full_version, tf_full_py_version, pipeline_name
591594
):
595+
if Version(tf_full_version) in SpecifierSet("==2.16.*"):
596+
pytest.skip(
597+
"This test is failing in TensorFlow 2.16 beacuse of an upstream bug: "
598+
"https://github.com/tensorflow/io/issues/2039"
599+
)
592600
base_dir = os.path.join(DATA_DIR, "tensorflow_mnist")
593601
entry_point = os.path.join(base_dir, "mnist_v2.py")
594602
input_path = pipeline_session.upload_data(

tests/integ/sagemaker/workflow/test_training_steps.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818

1919
import pytest
2020

21+
from packaging.version import Version
22+
from packaging.specifiers import SpecifierSet
23+
2124
from tests.integ.sagemaker.workflow.helpers import wait_pipeline_execution
2225
from sagemaker import TrainingInput, get_execution_role, utils, image_uris
2326
from sagemaker.debugger import (
@@ -235,6 +238,12 @@ def test_training_step_with_output_path_as_join(
235238
def test_tensorflow_training_step_with_parameterized_code_input(
236239
pipeline_session, role, tf_full_version, tf_full_py_version, pipeline_name
237240
):
241+
if Version(tf_full_version) in SpecifierSet("==2.16.*"):
242+
pytest.skip(
243+
"This test is failing in TensorFlow 2.16 beacuse of an upstream bug: "
244+
"https://github.com/tensorflow/io/issues/2039"
245+
)
246+
238247
base_dir = os.path.join(DATA_DIR, "tensorflow_mnist")
239248
entry_point1 = "mnist_v2.py"
240249
entry_point2 = "mnist_dummy.py"

tests/integ/test_transformer.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818

1919
import pytest
2020

21+
from packaging.version import Version
22+
from packaging.specifiers import SpecifierSet
23+
2124
from sagemaker import KMeans, s3, get_execution_role
2225
from sagemaker.mxnet import MXNet
2326
from sagemaker.pytorch import PyTorchModel
@@ -553,6 +556,12 @@ def test_transform_mxnet_logs(
553556
def test_transform_tf_kms_network_isolation(
554557
sagemaker_session, cpu_instance_type, tmpdir, tf_full_version, tf_full_py_version
555558
):
559+
if Version(tf_full_version) in SpecifierSet("==2.16.*"):
560+
pytest.skip(
561+
"This test is failing in TensorFlow 2.16 beacuse of an upstream bug: "
562+
"https://github.com/tensorflow/io/issues/2039"
563+
)
564+
556565
data_path = os.path.join(DATA_DIR, "tensorflow_mnist")
557566

558567
tf = TensorFlow(

0 commit comments

Comments
 (0)