Skip to content

Commit af57e4e

Browse files
committed
feature: Add support for TF 2.7 and TF 2.8
1 parent 7e2c7ab commit af57e4e

File tree

4 files changed

+150
-4
lines changed

4 files changed

+150
-4
lines changed

src/sagemaker/fw_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@
7474
"2.6",
7575
"2.6.0",
7676
"2.6.2",
77+
"2.7",
78+
"2.7.0",
79+
"2.8",
80+
"2.8.0",
7781
],
7882
"pytorch": [
7983
"1.6",

src/sagemaker/image_uri_config/tensorflow.json

Lines changed: 132 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,9 @@
279279
"2.3": "2.3.2",
280280
"2.4": "2.4.3",
281281
"2.5": "2.5.1",
282-
"2.6": "2.6.0"
282+
"2.6": "2.6.0",
283+
"2.7": "2.7.0",
284+
"2.8": "2.8.0"
283285
},
284286
"versions": {
285287
"1.10.0": {
@@ -1343,6 +1345,66 @@
13431345
"us-west-2": "763104351884"
13441346
},
13451347
"repository": "tensorflow-inference"
1348+
},
1349+
"2.7.0": {
1350+
"registries": {
1351+
"af-south-1": "626614931356",
1352+
"ap-east-1": "871362719292",
1353+
"ap-northeast-1": "763104351884",
1354+
"ap-northeast-2": "763104351884",
1355+
"ap-northeast-3": "364406365360",
1356+
"ap-south-1": "763104351884",
1357+
"ap-southeast-1": "763104351884",
1358+
"ap-southeast-2": "763104351884",
1359+
"ca-central-1": "763104351884",
1360+
"cn-north-1": "727897471807",
1361+
"cn-northwest-1": "727897471807",
1362+
"eu-central-1": "763104351884",
1363+
"eu-north-1": "763104351884",
1364+
"eu-south-1": "692866216735",
1365+
"eu-west-1": "763104351884",
1366+
"eu-west-2": "763104351884",
1367+
"eu-west-3": "763104351884",
1368+
"me-south-1": "217643126080",
1369+
"sa-east-1": "763104351884",
1370+
"us-east-1": "763104351884",
1371+
"us-east-2": "763104351884",
1372+
"us-gov-west-1": "442386744353",
1373+
"us-iso-east-1": "886529160074",
1374+
"us-west-1": "763104351884",
1375+
"us-west-2": "763104351884"
1376+
},
1377+
"repository": "tensorflow-inference"
1378+
},
1379+
"2.8.0": {
1380+
"registries": {
1381+
"af-south-1": "626614931356",
1382+
"ap-east-1": "871362719292",
1383+
"ap-northeast-1": "763104351884",
1384+
"ap-northeast-2": "763104351884",
1385+
"ap-northeast-3": "364406365360",
1386+
"ap-south-1": "763104351884",
1387+
"ap-southeast-1": "763104351884",
1388+
"ap-southeast-2": "763104351884",
1389+
"ca-central-1": "763104351884",
1390+
"cn-north-1": "727897471807",
1391+
"cn-northwest-1": "727897471807",
1392+
"eu-central-1": "763104351884",
1393+
"eu-north-1": "763104351884",
1394+
"eu-south-1": "692866216735",
1395+
"eu-west-1": "763104351884",
1396+
"eu-west-2": "763104351884",
1397+
"eu-west-3": "763104351884",
1398+
"me-south-1": "217643126080",
1399+
"sa-east-1": "763104351884",
1400+
"us-east-1": "763104351884",
1401+
"us-east-2": "763104351884",
1402+
"us-gov-west-1": "442386744353",
1403+
"us-iso-east-1": "886529160074",
1404+
"us-west-1": "763104351884",
1405+
"us-west-2": "763104351884"
1406+
},
1407+
"repository": "tensorflow-inference"
13461408
}
13471409
}
13481410
},
@@ -1370,7 +1432,9 @@
13701432
"2.3": "2.3.2",
13711433
"2.4": "2.4.3",
13721434
"2.5": "2.5.1",
1373-
"2.6": "2.6.2"
1435+
"2.6": "2.6.2",
1436+
"2.7": "2.7.1",
1437+
"2.8": "2.8.0"
13741438
},
13751439
"versions": {
13761440
"1.10.0": {
@@ -2629,6 +2693,72 @@
26292693
"us-west-2": "763104351884"
26302694
},
26312695
"repository": "tensorflow-training"
2696+
},
2697+
"2.7.1": {
2698+
"py_versions": [
2699+
"py38"
2700+
],
2701+
"registries": {
2702+
"af-south-1": "626614931356",
2703+
"ap-east-1": "871362719292",
2704+
"ap-northeast-1": "763104351884",
2705+
"ap-northeast-2": "763104351884",
2706+
"ap-northeast-3": "364406365360",
2707+
"ap-south-1": "763104351884",
2708+
"ap-southeast-1": "763104351884",
2709+
"ap-southeast-2": "763104351884",
2710+
"ca-central-1": "763104351884",
2711+
"cn-north-1": "727897471807",
2712+
"cn-northwest-1": "727897471807",
2713+
"eu-central-1": "763104351884",
2714+
"eu-north-1": "763104351884",
2715+
"eu-south-1": "692866216735",
2716+
"eu-west-1": "763104351884",
2717+
"eu-west-2": "763104351884",
2718+
"eu-west-3": "763104351884",
2719+
"me-south-1": "217643126080",
2720+
"sa-east-1": "763104351884",
2721+
"us-east-1": "763104351884",
2722+
"us-east-2": "763104351884",
2723+
"us-gov-west-1": "442386744353",
2724+
"us-iso-east-1": "886529160074",
2725+
"us-west-1": "763104351884",
2726+
"us-west-2": "763104351884"
2727+
},
2728+
"repository": "tensorflow-training"
2729+
},
2730+
"2.8.0": {
2731+
"py_versions": [
2732+
"py39"
2733+
],
2734+
"registries": {
2735+
"af-south-1": "626614931356",
2736+
"ap-east-1": "871362719292",
2737+
"ap-northeast-1": "763104351884",
2738+
"ap-northeast-2": "763104351884",
2739+
"ap-northeast-3": "364406365360",
2740+
"ap-south-1": "763104351884",
2741+
"ap-southeast-1": "763104351884",
2742+
"ap-southeast-2": "763104351884",
2743+
"ca-central-1": "763104351884",
2744+
"cn-north-1": "727897471807",
2745+
"cn-northwest-1": "727897471807",
2746+
"eu-central-1": "763104351884",
2747+
"eu-north-1": "763104351884",
2748+
"eu-south-1": "692866216735",
2749+
"eu-west-1": "763104351884",
2750+
"eu-west-2": "763104351884",
2751+
"eu-west-3": "763104351884",
2752+
"me-south-1": "217643126080",
2753+
"sa-east-1": "763104351884",
2754+
"us-east-1": "763104351884",
2755+
"us-east-2": "763104351884",
2756+
"us-gov-west-1": "442386744353",
2757+
"us-iso-east-1": "886529160074",
2758+
"us-west-1": "763104351884",
2759+
"us-west-2": "763104351884"
2760+
},
2761+
"repository": "tensorflow-training"
26322762
}
26332763
}
26342764
}

tests/conftest.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,9 @@ def _tf_py_version(tf_version, request):
352352
return request.param
353353
if Version("2.2") <= version < Version("2.6"):
354354
return "py37"
355-
return "py38"
355+
if Version("2.6") <= version < Version("2.8"):
356+
return "py38"
357+
return "py39"
356358

357359

358360
@pytest.fixture(scope="module")
@@ -384,7 +386,9 @@ def tf_full_py_version(tf_full_version):
384386
return "py3"
385387
if version < Version("2.6"):
386388
return "py37"
387-
return "py38"
389+
if version < Version("2.8"):
390+
return "py38"
391+
return "py39"
388392

389393

390394
@pytest.fixture(scope="session")

tests/unit/test_fw_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,12 +678,18 @@ def test_validate_smdataparallel_args_not_raises():
678678
("ml.p3.16xlarge", "tensorflow", "2.3.2", "py37", smdataparallel_enabled),
679679
("ml.p3.16xlarge", "tensorflow", "2.3", "py37", smdataparallel_enabled),
680680
("ml.p3.16xlarge", "tensorflow", "2.4.1", "py37", smdataparallel_enabled),
681+
("ml.p3.16xlarge", "tensorflow", "2.4.3", "py37", smdataparallel_enabled),
681682
("ml.p3.16xlarge", "tensorflow", "2.4", "py37", smdataparallel_enabled),
682683
("ml.p3.16xlarge", "tensorflow", "2.5.0", "py37", smdataparallel_enabled),
683684
("ml.p3.16xlarge", "tensorflow", "2.5.1", "py37", smdataparallel_enabled),
684685
("ml.p3.16xlarge", "tensorflow", "2.5", "py37", smdataparallel_enabled),
685686
("ml.p3.16xlarge", "tensorflow", "2.6.0", "py38", smdataparallel_enabled),
687+
("ml.p3.16xlarge", "tensorflow", "2.6.2", "py38", smdataparallel_enabled),
686688
("ml.p3.16xlarge", "tensorflow", "2.6", "py38", smdataparallel_enabled),
689+
("ml.p3.16xlarge", "tensorflow", "2.7.0", "py38", smdataparallel_enabled),
690+
("ml.p3.16xlarge", "tensorflow", "2.7", "py38", smdataparallel_enabled),
691+
("ml.p3.16xlarge", "tensorflow", "2.8.0", "py38", smdataparallel_enabled),
692+
("ml.p3.16xlarge", "tensorflow", "2.8", "py39", smdataparallel_enabled),
687693
("ml.p3.16xlarge", "pytorch", "1.6.0", "py3", smdataparallel_enabled),
688694
("ml.p3.16xlarge", "pytorch", "1.6", "py3", smdataparallel_enabled),
689695
("ml.p3.16xlarge", "pytorch", "1.7.1", "py3", smdataparallel_enabled),
@@ -698,6 +704,8 @@ def test_validate_smdataparallel_args_not_raises():
698704
("ml.p3.16xlarge", "tensorflow", "2.4.1", "py37", smdataparallel_enabled_custom_mpi),
699705
("ml.p3.16xlarge", "tensorflow", "2.5.1", "py37", smdataparallel_enabled_custom_mpi),
700706
("ml.p3.16xlarge", "tensorflow", "2.6.0", "py38", smdataparallel_enabled_custom_mpi),
707+
("ml.p3.16xlarge", "tensorflow", "2.7.0", "py38", smdataparallel_enabled_custom_mpi),
708+
("ml.p3.16xlarge", "tensorflow", "2.8.0", "py38", smdataparallel_enabled_custom_mpi),
701709
("ml.p3.16xlarge", "pytorch", "1.8.0", "py3", smdataparallel_enabled_custom_mpi),
702710
("ml.p3.16xlarge", "pytorch", "1.9.1", "py38", smdataparallel_enabled_custom_mpi),
703711
]

0 commit comments

Comments
 (0)