Skip to content

Commit 9317e08

Browse files
feature: Support TF2.12 SageMaker DLC (#3776)
1 parent 97eaed3 commit 9317e08

File tree

6 files changed

+66
-17
lines changed

6 files changed

+66
-17
lines changed

src/sagemaker/image_uri_config/tensorflow.json

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2021,7 +2021,8 @@
20212021
"2.8": "2.8.0",
20222022
"2.9": "2.9.2",
20232023
"2.10": "2.10.1",
2024-
"2.11": "2.11.0"
2024+
"2.11": "2.11.0",
2025+
"2.12": "2.12.0"
20252026
},
20262027
"versions": {
20272028
"1.10.0": {
@@ -3755,6 +3756,37 @@
37553756
"us-west-2": "763104351884"
37563757
},
37573758
"repository": "tensorflow-training"
3759+
},
3760+
"2.12.0": {
3761+
"py_versions": [
3762+
"py310"
3763+
],
3764+
"registries": {
3765+
"af-south-1": "626614931356",
3766+
"ap-east-1": "871362719292",
3767+
"ap-northeast-1": "763104351884",
3768+
"ap-northeast-2": "763104351884",
3769+
"ap-northeast-3": "364406365360",
3770+
"ap-south-1": "763104351884",
3771+
"ap-southeast-1": "763104351884",
3772+
"ap-southeast-2": "763104351884",
3773+
"ap-southeast-3": "907027046896",
3774+
"ap-southeast-4": "457447274322",
3775+
"ca-central-1": "763104351884",
3776+
"eu-central-1": "763104351884",
3777+
"eu-north-1": "763104351884",
3778+
"eu-south-1": "692866216735",
3779+
"eu-west-1": "763104351884",
3780+
"eu-west-2": "763104351884",
3781+
"eu-west-3": "763104351884",
3782+
"me-south-1": "217643126080",
3783+
"sa-east-1": "763104351884",
3784+
"us-east-1": "763104351884",
3785+
"us-east-2": "763104351884",
3786+
"us-west-1": "763104351884",
3787+
"us-west-2": "763104351884"
3788+
},
3789+
"repository": "tensorflow-training"
37583790
}
37593791
}
37603792
}

src/sagemaker/image_uris.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -369,16 +369,15 @@ def _config_for_framework_and_scope(framework, image_scope, accelerator_type=Non
369369

370370
def _validate_instance_deprecation(framework, instance_type, version):
371371
"""Check if instance type is deprecated for a certain framework with a certain version"""
372-
if (
373-
framework == "pytorch"
374-
and _get_instance_type_family(instance_type) == "p2"
375-
and Version(version) >= Version("1.13")
376-
):
377-
raise ValueError(
378-
"P2 instances have been deprecated for sagemaker jobs with PyTorch 1.13 and above. "
379-
"For information about supported instance types please refer to "
380-
"https://aws.amazon.com/sagemaker/pricing/"
381-
)
372+
if _get_instance_type_family(instance_type) == "p2":
373+
if (framework == "pytorch" and Version(version) >= Version("1.13")) or (
374+
framework == "tensorflow" and Version(version) >= Version("2.12")
375+
):
376+
raise ValueError(
377+
"P2 instances have been deprecated for sagemaker jobs starting PyTorch 1.13 and TensorFlow 2.12"
378+
"For information about supported instance types please refer to "
379+
"https://aws.amazon.com/sagemaker/pricing/"
380+
)
382381

383382

384383
def _validate_for_suppported_frameworks_and_instance_type(framework, instance_type):

tests/conftest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,9 @@ def _tf_py_version(tf_version, request):
493493
return "py37"
494494
if Version("2.6") <= version < Version("2.8"):
495495
return "py38"
496-
return "py39"
496+
if Version("2.8") <= version < Version("2.12"):
497+
return "py39"
498+
return "py310"
497499

498500

499501
@pytest.fixture(scope="module")

tests/integ/test_training_compiler.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,8 +224,10 @@ def test_tensorflow(
224224
"""
225225
Test the TensorFlow estimator
226226
"""
227-
if version.parse(tensorflow_training_latest_version) < version.parse("2.9"):
228-
pytest.skip("Training Compiler only supports TF >= 2.9")
227+
if version.parse(tensorflow_training_latest_version) >= version.parse("2.12") or version.parse(
228+
tensorflow_training_latest_version
229+
) < version.parse("2.9"):
230+
pytest.skip("Training Compiler only supports TF >= 2.9 and < 2.12")
229231
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
230232
epochs = 10
231233
batch = 256

tests/unit/sagemaker/image_uris/test_dlc_frameworks.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from sagemaker import image_uris
1818
from tests.unit.sagemaker.image_uris import expected_uris
1919

20+
import pytest
21+
2022
INSTANCE_TYPES_AND_PROCESSORS = (("ml.c4.xlarge", "cpu"), ("ml.p2.xlarge", "gpu"))
2123
RENEWED_PYTORCH_INSTANCE_TYPES_AND_PROCESSORS = (("ml.c4.xlarge", "cpu"), ("ml.g4dn.xlarge", "gpu"))
2224
REGION = "us-west-2"
@@ -72,7 +74,9 @@ def _test_image_uris(
7274
}
7375

7476
TYPES_AND_PROCESSORS = INSTANCE_TYPES_AND_PROCESSORS
75-
if framework == "pytorch" and Version(fw_version) >= Version("1.13"):
77+
if (framework == "pytorch" and Version(fw_version) >= Version("1.13")) or (
78+
framework == "tensorflow" and Version(fw_version) >= Version("2.12")
79+
):
7680
"""Handle P2 deprecation"""
7781
TYPES_AND_PROCESSORS = RENEWED_PYTORCH_INSTANCE_TYPES_AND_PROCESSORS
7882

@@ -83,6 +87,14 @@ def _test_image_uris(
8387
assert expected == uri
8488

8589
for region in SAGEMAKER_ALTERNATE_REGION_ACCOUNTS.keys():
90+
if (
91+
scope == "training"
92+
and framework == "tensorflow"
93+
and Version(fw_version) == Version("2.12")
94+
):
95+
if region in ["cn-north-1", "cn-northwest-1", "us-iso-east-1", "us-isob-east-1"]:
96+
pytest.skip(f"TF 2.12 SM DLC is not available in {region} region")
97+
8698
uri = image_uris.retrieve(region=region, instance_type="ml.c4.xlarge", **base_args)
8799

88100
expected = expected_fn(region=region, **expected_fn_args)

tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,10 @@
5757

5858
@pytest.fixture(scope="module", autouse=True)
5959
def skip_if_incompatible(tensorflow_training_version, request):
60-
if version.parse(tensorflow_training_version) < version.parse("2.9"):
61-
pytest.skip("Training Compiler only supports TF >= 2.9")
60+
if version.parse(tensorflow_training_version) >= version.parse("2.12") or version.parse(
61+
tensorflow_training_version
62+
) < version.parse("2.9"):
63+
pytest.skip("Training Compiler only supports TF >= 2.9 and < 2.12")
6264

6365

6466
@pytest.fixture(scope="module")

0 commit comments

Comments
 (0)