Skip to content

Commit c7ee5d1

Browse files
Teng-xuhuilgolr
authored andcommitted
feat: Use specific images for SMP v2 jobs (aws#4333)
* Add check for smp lib * update syntax * Remove unused images * Update repo name and regions * Update account number * Update framework name and check for None distribution * Add unit tests for smp v2 uri * Check enabled * Remove logging * Add cuda version in uri * Update cu121 * Update syntax * Fix black check * Fix black --------- Co-authored-by: huilgolr <[email protected]>
1 parent f53b2e6 commit c7ee5d1

File tree

3 files changed

+107
-2
lines changed

3 files changed

+107
-2
lines changed
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
{
2+
"training": {
3+
"processors": [
4+
"gpu"
5+
],
6+
"version_aliases": {
7+
"2.0": "2.0.1"
8+
},
9+
"versions": {
10+
"2.0.1": {
11+
"py_versions": [
12+
"py310"
13+
],
14+
"registries": {
15+
"ap-northeast-1": "658645717510",
16+
"ap-northeast-2": "658645717510",
17+
"ap-northeast-3": "658645717510",
18+
"ap-south-1": "658645717510",
19+
"ap-southeast-1": "658645717510",
20+
"ap-southeast-2": "658645717510",
21+
"ca-central-1": "658645717510",
22+
"eu-central-1": "658645717510",
23+
"eu-north-1": "658645717510",
24+
"eu-west-1": "658645717510",
25+
"eu-west-2": "658645717510",
26+
"eu-west-3": "658645717510",
27+
"sa-east-1": "658645717510",
28+
"us-east-1": "658645717510",
29+
"us-east-2": "658645717510",
30+
"us-west-1": "658645717510",
31+
"us-west-2": "658645717510"
32+
},
33+
"repository": "smdistributed-modelparallel"
34+
}
35+
}
36+
}
37+
}

src/sagemaker/image_uris.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,10 @@
2727
from sagemaker.jumpstart import artifacts
2828
from sagemaker.workflow import is_pipeline_variable
2929
from sagemaker.workflow.utilities import override_pipeline_parameter_var
30-
from sagemaker.fw_utils import GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY, GRAVITON_ALLOWED_FRAMEWORKS
30+
from sagemaker.fw_utils import (
31+
GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY,
32+
GRAVITON_ALLOWED_FRAMEWORKS,
33+
)
3134

3235
logger = logging.getLogger(__name__)
3336

@@ -343,7 +346,8 @@ def _config_for_framework_and_scope(framework, image_scope, accelerator_type=Non
343346

344347
if image_scope not in ("eia", "inference"):
345348
logger.warning(
346-
"Elastic inference is for inference only. Ignoring image scope: %s.", image_scope
349+
"Elastic inference is for inference only. Ignoring image scope: %s.",
350+
image_scope,
347351
)
348352
image_scope = "eia"
349353

@@ -660,6 +664,17 @@ def get_training_image_uri(
660664
container_version = None
661665
base_framework_version = None
662666

667+
# Check for smp library
668+
if distribution is not None:
669+
if "torch_distributed" in distribution and "smdistributed" in distribution:
670+
if "modelparallel" in distribution["smdistributed"]:
671+
if distribution["smdistributed"]["modelparallel"].get("enabled", True):
672+
framework = "pytorch-smp"
673+
if "p5" in instance_type:
674+
container_version = "cu121"
675+
else:
676+
container_version = "cu118"
677+
663678
return retrieve(
664679
framework,
665680
region,
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pytest
16+
from sagemaker import image_uris
17+
from tests.unit.sagemaker.image_uris import expected_uris
18+
19+
CONTAINER_VERSIONS = {"ml.p4d.24xlarge": "cu118", "ml.p5d.24xlarge": "cu121"}
20+
21+
22+
@pytest.mark.parametrize("load_config", ["pytorch-smp.json"], indirect=True)
23+
def test_smp_v2(load_config):
24+
VERSIONS = load_config["training"]["versions"]
25+
PROCESSORS = load_config["training"]["processors"]
26+
distribution = {
27+
"torch_distributed": {"enabled": True},
28+
"smdistributed": {"modelparallel": {"enabled": True}},
29+
}
30+
for processor in PROCESSORS:
31+
for version in VERSIONS:
32+
ACCOUNTS = load_config["training"]["versions"][version]["registries"]
33+
PY_VERSIONS = load_config["training"]["versions"][version]["py_versions"]
34+
for py_version in PY_VERSIONS:
35+
for region in ACCOUNTS.keys():
36+
for instance_type in CONTAINER_VERSIONS.keys():
37+
uri = image_uris.get_training_image_uri(
38+
region,
39+
framework="pytorch",
40+
framework_version=version,
41+
py_version=py_version,
42+
distribution=distribution,
43+
instance_type=instance_type,
44+
)
45+
expected = expected_uris.framework_uri(
46+
repo="smdistributed-modelparallel",
47+
fw_version=version,
48+
py_version=f"{py_version}-{CONTAINER_VERSIONS[instance_type]}",
49+
processor=processor,
50+
region=region,
51+
account=ACCOUNTS[region],
52+
)
53+
assert expected == uri

0 commit comments

Comments
 (0)