Skip to content

Commit 1c514bf

Browse files
authored
feature: Add huggingface-llm 0.6.0 dlc images (#3837)
1 parent 77bda8e commit 1c514bf

File tree

6 files changed

+196
-1
lines changed

6 files changed

+196
-1
lines changed

src/sagemaker/huggingface/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,6 @@
1616
from sagemaker.huggingface.estimator import HuggingFace # noqa: F401
1717
from sagemaker.huggingface.model import HuggingFaceModel, HuggingFacePredictor # noqa: F401
1818
from sagemaker.huggingface.processing import HuggingFaceProcessor # noqa:F401
19+
from sagemaker.huggingface.llm_utils import get_huggingface_llm_image_uri # noqa: F401
1920

2021
from sagemaker.huggingface.training_compiler.config import TrainingCompilerConfig # noqa: F401
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Functions for generating ECR image URIs for pre-built SageMaker Docker images."""
14+
from __future__ import absolute_import
15+
16+
from typing import Optional
17+
18+
from sagemaker import image_uris
19+
from sagemaker.session import Session
20+
21+
22+
def get_huggingface_llm_image_uri(
    backend: str,
    session: Optional["Session"] = None,
    region: Optional[str] = None,
    version: Optional[str] = None,
) -> str:
    """Retrieves the image URI for inference.

    Args:
        backend (str): The backend to use. Valid values are "huggingface" and "lmi".
        session (Session): The SageMaker Session to use. It is only consulted to
            look up the region when ``region`` is not given; if both are None, a
            default ``Session`` is created for that lookup. (Default: None).
        region (str): The AWS region to use for image URI. (default: None).
        version (str): The framework version for which to retrieve an
            image URI. If no version is set, the "huggingface" backend passes
            ``None`` through to ``image_uris.retrieve`` and the "lmi" backend
            falls back to "0.22.1". (default: None).

    Returns:
        str: The image URI string.

    Raises:
        ValueError: If ``backend`` is not one of the supported values.
    """
    # Reject an unknown backend before any region lookup so that a bad argument
    # never triggers construction of a default Session. ``str.format`` is used
    # instead of ``%`` so a tuple-valued backend still yields a ValueError.
    if backend not in ("huggingface", "lmi"):
        raise ValueError("Unsupported backend: {}".format(backend))

    if region is None:
        region_source = Session() if session is None else session
        region = region_source.boto_session.region_name

    if backend == "huggingface":
        return image_uris.retrieve(
            "huggingface-llm",
            region=region,
            version=version,
            image_scope="inference",
        )

    # Remaining case is "lmi": pinned to 0.22.1 unless the caller overrides it.
    return image_uris.retrieve(
        framework="djl-deepspeed",
        region=region,
        version=version or "0.22.1",
    )
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
{
2+
"inference": {
3+
"processors": ["gpu"],
4+
"version_aliases": {
5+
"0.6": "0.6.0"
6+
},
7+
"versions": {
8+
"0.6.0": {
9+
"py_versions": ["py39"],
10+
"registries": {
11+
"af-south-1": "626614931356",
12+
"ap-east-1": "871362719292",
13+
"ap-northeast-1": "763104351884",
14+
"ap-northeast-2": "763104351884",
15+
"ap-northeast-3": "364406365360",
16+
"ap-south-1": "763104351884",
17+
"ap-south-2": "772153158452",
18+
"ap-southeast-1": "763104351884",
19+
"ap-southeast-2": "763104351884",
20+
"ap-southeast-3": "907027046896",
21+
"ap-southeast-4": "457447274322",
22+
"ca-central-1": "763104351884",
23+
"cn-north-1": "727897471807",
24+
"cn-northwest-1": "727897471807",
25+
"eu-central-1": "763104351884",
26+
"eu-central-2": "380420809688",
27+
"eu-north-1": "763104351884",
28+
"eu-west-1": "763104351884",
29+
"eu-west-2": "763104351884",
30+
"eu-west-3": "763104351884",
31+
"eu-south-1": "692866216735",
32+
"eu-south-2": "503227376785",
33+
"me-south-1": "217643126080",
34+
"me-central-1": "914824155844",
35+
"sa-east-1": "763104351884",
36+
"us-east-1": "763104351884",
37+
"us-east-2": "763104351884",
38+
"us-gov-east-1": "446045086412",
39+
"us-gov-west-1": "442386744353",
40+
"us-iso-east-1": "886529160074",
41+
"us-isob-east-1": "094389454867",
42+
"us-west-1": "763104351884",
43+
"us-west-2": "763104351884"
44+
},
45+
"tag_prefix": "2.0.0-tgi0.6.0",
46+
"repository": "huggingface-pytorch-tgi-inference",
47+
"container_version": {"gpu": "cu118-ubuntu20.04"}
48+
}
49+
}
50+
}
51+
}

src/sagemaker/image_uris.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}"
3434
HUGGING_FACE_FRAMEWORK = "huggingface"
35+
HUGGING_FACE_LLM_FRAMEWORK = "huggingface-llm"
3536
XGBOOST_FRAMEWORK = "xgboost"
3637
SKLEARN_FRAMEWORK = "sklearn"
3738
TRAINIUM_ALLOWED_FRAMEWORKS = "pytorch"
@@ -462,7 +463,7 @@ def _validate_version_and_set_if_needed(version, config, framework):
462463

463464
return available_versions[0]
464465

465-
if version is None and framework in [DATA_WRANGLER_FRAMEWORK]:
466+
if version is None and framework in [DATA_WRANGLER_FRAMEWORK, HUGGING_FACE_LLM_FRAMEWORK]:
466467
version = _get_latest_versions(available_versions)
467468

468469
_validate_arg(version, available_versions + aliased_versions, "{} version".format(framework))

tests/unit/sagemaker/image_uris/expected_uris.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,17 @@ def djl_framework_uri(repo, account, djl_version, primary_framework, region=REGI
8080
return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag)
8181

8282

83+
def huggingface_llm_framework_uri(repo, account, version, tag, region=REGION):
    """Build the expected image URI for a Hugging Face LLM container.

    Args:
        repo: Repository name placed in the URI.
        account: Registry account ID placed in the URI.
        version: Accepted for call-site symmetry; not used when building the URI.
        tag: Full image tag placed in the URI.
        region: Region placed in the URI and used to select the domain.

    Returns:
        ``IMAGE_URI_FORMAT`` formatted with account, region, domain, repo and tag.
    """
    # Regions listed in ALTERNATE_DOMAINS override the default DOMAIN.
    if region in ALTERNATE_DOMAINS:
        domain = ALTERNATE_DOMAINS[region]
    else:
        domain = DOMAIN
    return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag)
92+
93+
8394
def base_python_uri(repo, account, region=REGION):
8495
domain = ALTERNATE_DOMAINS.get(region, DOMAIN)
8596
tag = "1.0"
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pytest
16+
17+
from sagemaker.huggingface import get_huggingface_llm_image_uri
18+
from tests.unit.sagemaker.image_uris import expected_uris
19+
20+
# Region -> registry account ID used to build the expected URIs in these tests.
ACCOUNTS = {
    "af-south-1": "626614931356",
    "ap-east-1": "871362719292",
    "ap-northeast-1": "763104351884",
    "ap-northeast-2": "763104351884",
    "ap-northeast-3": "364406365360",
    "ap-south-1": "763104351884",
    "ap-southeast-1": "763104351884",
    "ap-southeast-2": "763104351884",
    "ap-southeast-3": "907027046896",
    "ca-central-1": "763104351884",
    "cn-north-1": "727897471807",
    "cn-northwest-1": "727897471807",
    "eu-central-1": "763104351884",
    "eu-north-1": "763104351884",
    "eu-west-1": "763104351884",
    "eu-west-2": "763104351884",
    "eu-west-3": "763104351884",
    "eu-south-1": "692866216735",
    "me-south-1": "217643126080",
    "sa-east-1": "763104351884",
    "us-east-1": "763104351884",
    "us-east-2": "763104351884",
    "us-west-1": "763104351884",
    "us-west-2": "763104351884",
}

# Versions exercised for each backend.
HF_VERSIONS = ["0.6.0"]
LMI_VERSIONS = ["0.22.1"]

# Version -> tag argument handed to the expected-URI helpers.
HF_VERSIONS_MAPPING = {"0.6.0": "2.0.0-tgi0.6.0-gpu-py39-cu118-ubuntu20.04"}
LMI_VERSIONS_MAPPING = {"0.22.1": "deepspeed0.8.3-cu118"}
50+
51+
52+
@pytest.mark.parametrize("version", HF_VERSIONS)
def test_huggingface(version):
    """The "huggingface" backend returns the expected URI in every ACCOUNTS region."""
    for region, account in ACCOUNTS.items():
        actual = get_huggingface_llm_image_uri("huggingface", region=region, version=version)

        expected = expected_uris.huggingface_llm_framework_uri(
            "huggingface-pytorch-tgi-inference",
            account,
            version,
            HF_VERSIONS_MAPPING[version],
            region=region,
        )
        assert expected == actual
65+
66+
67+
@pytest.mark.parametrize("version", LMI_VERSIONS)
def test_lmi(version):
    """The "lmi" backend returns the expected DJL URI in every ACCOUNTS region."""
    for region, account in ACCOUNTS.items():
        actual = get_huggingface_llm_image_uri("lmi", region=region, version=version)

        expected = expected_uris.djl_framework_uri(
            "djl-inference",
            account,
            version,
            LMI_VERSIONS_MAPPING[version],
            region=region,
        )
        assert expected == actual

0 commit comments

Comments
 (0)