Skip to content

Commit 8c056a5

Browse files
committed
feature: Add huggingface-llm 0.5.0 dlc images
1 parent c68f5b0 commit 8c056a5

File tree

6 files changed

+252
-4
lines changed

6 files changed

+252
-4
lines changed

src/sagemaker/huggingface/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,6 @@
1616
from sagemaker.huggingface.estimator import HuggingFace # noqa: F401
1717
from sagemaker.huggingface.model import HuggingFaceModel, HuggingFacePredictor # noqa: F401
1818
from sagemaker.huggingface.processing import HuggingFaceProcessor # noqa:F401
19+
from sagemaker.huggingface.llm_utils import get_huggingface_llm_image_uri # noqa: F401
1920

2021
from sagemaker.huggingface.training_compiler.config import TrainingCompilerConfig # noqa: F401
src/sagemaker/huggingface/llm_utils.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Functions for generating ECR image URIs for pre-built SageMaker Docker images."""
14+
from typing import Optional
15+
16+
from sagemaker import image_uris
17+
from sagemaker.session import Session
18+
19+
20+
def _validate_pt_tf_versions(pytorch_version, tensorflow_version):
21+
"""Validate PyTorch / TensorFlow versions.
22+
23+
Args:
24+
pytorch_version (str): PyTorch version to validate.
25+
Required unless ``tensorflow_version`` is provided. List of supported versions:
26+
https://github.com/aws/sagemaker-python-sdk#huggingface-sagemaker-estimators.
27+
tensorflow_version (str): TensorFlow version to validate.
28+
Required unless ``pytorch_version`` is provided. List of supported versions:
29+
https://github.com/aws/sagemaker-python-sdk#huggingface-sagemaker-estimators.
30+
"""
31+
32+
if tensorflow_version is not None and pytorch_version is not None:
33+
raise ValueError(
34+
"tensorflow_version and pytorch_version are both not None. "
35+
"Specify only tensorflow_version or pytorch_version."
36+
)
37+
if tensorflow_version is None and pytorch_version is None:
38+
raise ValueError(
39+
"tensorflow_version and pytorch_version are both None. "
40+
"Specify either tensorflow_version or pytorch_version."
41+
)
42+
43+
44+
def get_huggingface_llm_image_uri(
    region: Optional[str] = None,
    version: Optional[str] = None,
    py_version: Optional[str] = None,
    instance_type: Optional[str] = None,
    container_version: Optional[str] = None,
    pytorch_version: Optional[str] = None,
    tensorflow_version: Optional[str] = None,
) -> str:
    """Retrieve the inference image URI for the Hugging Face LLM container.

    The lookup is delegated to ``image_uris.retrieve`` with the
    ``huggingface-llm`` framework and the ``inference`` image scope.

    Args:
        region (str): The AWS region to use for image URI. If not set, the region
            of a default ``Session`` is used. (default: None).
        version (str): The framework version for which to retrieve an
            image URI. If no version is set, defaults to latest version. (default: None).
        py_version (str): The Python version. This is required if there is
            more than one supported Python version for the given framework version.
            (default: None).
        instance_type (str): The SageMaker instance type. For supported types, see
            https://aws.amazon.com/sagemaker/pricing. This is required if
            there are different images for different processor types. (default: None)
        container_version (str): the version of docker image.
            Ideally the value of parameter should be created inside the framework.
            For custom use, see the list of supported container versions:
            https://github.com/aws/deep-learning-containers/blob/master/available_images.md
            (default: None).
        pytorch_version (str): PyTorch version you want to use for
            executing your inference code. Required unless ``tensorflow_version``
            is provided. List of supported versions:
            https://github.com/aws/sagemaker-python-sdk#huggingface-sagemaker-estimators.
            (default: None).
        tensorflow_version (str): TensorFlow version you want to use for
            executing your inference code. Required unless ``pytorch_version``
            is provided. List of supported versions:
            https://github.com/aws/sagemaker-python-sdk#huggingface-sagemaker-estimators.
            (default: None).

    Returns:
        str: The image URI string.

    Raises:
        ValueError: If both or neither of ``pytorch_version`` and
            ``tensorflow_version`` are provided.
    """
    # Validate the arguments first so an invalid call fails before a default
    # session has to be created just to look up the region.
    _validate_pt_tf_versions(pytorch_version, tensorflow_version)

    # The base framework is encoded as "<framework><version>", e.g. "pytorch2.0.0".
    if tensorflow_version is not None:
        base_framework_version = f"tensorflow{tensorflow_version}"
    else:
        base_framework_version = f"pytorch{pytorch_version}"

    framework_name = "huggingface-llm"
    region = region or Session().boto_session.region_name

    return image_uris.retrieve(
        framework_name,
        region=region,
        version=version,
        py_version=py_version,
        instance_type=instance_type,
        image_scope="inference",
        container_version=container_version,
        base_framework_version=base_framework_version,
    )
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
{
2+
"inference": {
3+
"processors": ["gpu"],
4+
"version_aliases": {
5+
"0.5": "0.5.0"
6+
},
7+
"versions": {
8+
"0.5.0": {
9+
"version_aliases": {
10+
"pytorch2.0": "pytorch2.0.0"
11+
},
12+
"pytorch2.0.0": {
13+
"py_versions": ["py39"],
14+
"registries": {
15+
"af-south-1": "626614931356",
16+
"ap-east-1": "871362719292",
17+
"ap-northeast-1": "763104351884",
18+
"ap-northeast-2": "763104351884",
19+
"ap-northeast-3": "364406365360",
20+
"ap-south-1": "763104351884",
21+
"ap-south-2": "772153158452",
22+
"ap-southeast-1": "763104351884",
23+
"ap-southeast-2": "763104351884",
24+
"ap-southeast-3": "907027046896",
25+
"ap-southeast-4": "457447274322",
26+
"ca-central-1": "763104351884",
27+
"cn-north-1": "727897471807",
28+
"cn-northwest-1": "727897471807",
29+
"eu-central-1": "763104351884",
30+
"eu-central-2": "380420809688",
31+
"eu-north-1": "763104351884",
32+
"eu-west-1": "763104351884",
33+
"eu-west-2": "763104351884",
34+
"eu-west-3": "763104351884",
35+
"eu-south-1": "692866216735",
36+
"eu-south-2": "503227376785",
37+
"me-south-1": "217643126080",
38+
"me-central-1": "914824155844",
39+
"sa-east-1": "763104351884",
40+
"us-east-1": "763104351884",
41+
"us-east-2": "763104351884",
42+
"us-gov-east-1": "446045086412",
43+
"us-gov-west-1": "442386744353",
44+
"us-iso-east-1": "886529160074",
45+
"us-isob-east-1": "094389454867",
46+
"us-west-1": "763104351884",
47+
"us-west-2": "763104351884"
48+
},
49+
"repository": "huggingface-pytorch-inference",
50+
"container_version": {"gpu": "cu118-ubuntu22.04"}
51+
}
52+
}
53+
}
54+
}
55+
}

src/sagemaker/image_uris.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}"
3434
HUGGING_FACE_FRAMEWORK = "huggingface"
35+
HUGGING_FACE_LLM_FRAMEWORK = "huggingface-llm"
3536
XGBOOST_FRAMEWORK = "xgboost"
3637
SKLEARN_FRAMEWORK = "sklearn"
3738
TRAINIUM_ALLOWED_FRAMEWORKS = "pytorch"
@@ -167,7 +168,7 @@ def retrieve(
167168
version = _validate_version_and_set_if_needed(version, config, framework)
168169
version_config = config["versions"][_version_for_config(version, config)]
169170

170-
if framework == HUGGING_FACE_FRAMEWORK:
171+
if framework in [HUGGING_FACE_FRAMEWORK, HUGGING_FACE_LLM_FRAMEWORK]:
171172
if version_config.get("version_aliases"):
172173
full_base_framework_version = version_config["version_aliases"].get(
173174
base_framework_version, base_framework_version
@@ -198,7 +199,7 @@ def retrieve(
198199
sdk_version = _get_latest_versions(version_config["sdk_versions"])
199200
container_version = sdk_version + "-" + container_version
200201

201-
if framework == HUGGING_FACE_FRAMEWORK:
202+
if framework in [HUGGING_FACE_FRAMEWORK, HUGGING_FACE_LLM_FRAMEWORK]:
202203
pt_or_tf_version = (
203204
re.compile("^(pytorch|tensorflow)(.*)$").match(base_framework_version).group(2)
204205
)
@@ -228,7 +229,11 @@ def retrieve(
228229
re.compile("^(pytorch|tensorflow)(.*)$").match(_base_framework_version).group(2)
229230
)
230231

231-
tag_prefix = f"{pt_or_tf_version}-transformers{_version}"
232+
if framework == HUGGING_FACE_LLM_FRAMEWORK:
233+
_version = version
234+
tag_prefix = f"{pt_or_tf_version}-tgi{_version}"
235+
else:
236+
tag_prefix = f"{pt_or_tf_version}-transformers{_version}"
232237
else:
233238
tag_prefix = version_config.get("tag_prefix", version)
234239

@@ -462,7 +467,7 @@ def _validate_version_and_set_if_needed(version, config, framework):
462467

463468
return available_versions[0]
464469

465-
if version is None and framework in [DATA_WRANGLER_FRAMEWORK]:
470+
if version is None and framework in [DATA_WRANGLER_FRAMEWORK, HUGGING_FACE_LLM_FRAMEWORK]:
466471
version = _get_latest_versions(available_versions)
467472

468473
_validate_arg(version, available_versions + aliased_versions, "{} version".format(framework))

tests/unit/sagemaker/image_uris/expected_uris.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,21 @@ def djl_framework_uri(repo, account, djl_version, primary_framework, region=REGI
8080
return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag)
8181

8282

83+
def huggingface_llm_framework_uri(
    repo,
    account,
    version,
    py_version="py39",
    processor="gpu",
    pytorch_version="2.0.0",
    region=REGION,
    container_version="cu118-ubuntu22.04",
):
    """Build the expected image URI for a Hugging Face LLM container.

    The image tag has the form
    ``<pytorch_version>-tgi<version>-<processor>-<py_version>-<container_version>``.
    """
    image_tag = "{}-tgi{}-{}-{}-{}".format(
        pytorch_version, version, processor, py_version, container_version
    )
    # Regions without an entry in ALTERNATE_DOMAINS fall back to DOMAIN.
    uri_domain = ALTERNATE_DOMAINS.get(region, DOMAIN)
    return IMAGE_URI_FORMAT.format(account, region, uri_domain, repo, image_tag)
96+
97+
8398
def base_python_uri(repo, account, region=REGION):
8499
domain = ALTERNATE_DOMAINS.get(region, DOMAIN)
85100
tag = "1.0"
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pytest
16+
17+
from sagemaker.huggingface import get_huggingface_llm_image_uri
18+
from tests.unit.sagemaker.image_uris import expected_uris
19+
20+
# Registry account ID expected in the image URI for each AWS region.
ACCOUNTS = {
    "af-south-1": "626614931356",
    "ap-east-1": "871362719292",
    "ap-northeast-1": "763104351884",
    "ap-northeast-2": "763104351884",
    "ap-northeast-3": "364406365360",
    "ap-south-1": "763104351884",
    "ap-south-2": "772153158452",
    "ap-southeast-1": "763104351884",
    "ap-southeast-2": "763104351884",
    "ap-southeast-3": "907027046896",
    "ap-southeast-4": "457447274322",
    "ca-central-1": "763104351884",
    "cn-north-1": "727897471807",
    "cn-northwest-1": "727897471807",
    "eu-central-1": "763104351884",
    "eu-central-2": "380420809688",
    "eu-north-1": "763104351884",
    "eu-west-1": "763104351884",
    "eu-west-2": "763104351884",
    "eu-west-3": "763104351884",
    "eu-south-1": "692866216735",
    "eu-south-2": "503227376785",
    "me-south-1": "217643126080",
    "me-central-1": "914824155844",
    "sa-east-1": "763104351884",
    "us-east-1": "763104351884",
    "us-east-2": "763104351884",
    "us-gov-east-1": "446045086412",
    "us-gov-west-1": "442386744353",
    "us-iso-east-1": "886529160074",
    "us-isob-east-1": "094389454867",
    "us-west-1": "763104351884",
    "us-west-2": "763104351884",
}

# Framework versions and PyTorch versions the test is parametrized over.
VERSIONS = ["0.5.0"]
PYTORCH_VERSIONS = ["2.0.0"]
57+
58+
59+
@pytest.mark.parametrize("version", VERSIONS)
@pytest.mark.parametrize("pytorch_version", PYTORCH_VERSIONS)
def test_huggingface_llm(version, pytorch_version):
    """Check the retrieved LLM image URI against the expected one for every region."""
    # Iterate region/account pairs directly instead of re-indexing the dict.
    for region, account in ACCOUNTS.items():
        uri = get_huggingface_llm_image_uri(
            region=region,
            version=version,
            pytorch_version=pytorch_version,
        )

        expected = expected_uris.huggingface_llm_framework_uri(
            "huggingface-pytorch-inference",
            account,
            version=version,
            pytorch_version=pytorch_version,
            region=region,
        )
        # Name the region so a failure inside the loop is easy to locate.
        assert expected == uri, f"unexpected image URI for region {region}"

0 commit comments

Comments
 (0)