Skip to content

Commit e477d1e

Browse files
samrudsroot
authored and committed
feat: Introduce HF Transformers to ModelBuilder (aws#4368)
* feat: Introduce HF Transformers to ModelBuilder * Add integ test * Revert the change in comment for tgi prepare * Capitalize enum * Address PR feedbacks * Format files * Format files * Address PR feedbacks * Address PR feedbacks * Fix test builds
1 parent 234376f commit e477d1e

File tree

14 files changed

+937
-15
lines changed

14 files changed

+937
-15
lines changed

src/sagemaker/serve/builder/model_builder.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from sagemaker.serve.builder.djl_builder import DJL
3535
from sagemaker.serve.builder.tgi_builder import TGI
3636
from sagemaker.serve.builder.jumpstart_builder import JumpStart
37+
from sagemaker.serve.builder.transformers_builder import Transformers
3738
from sagemaker.predictor import Predictor
3839
from sagemaker.serve.save_retrive.version_1_0_0.metadata.metadata import Metadata
3940
from sagemaker.serve.spec.inference_spec import InferenceSpec
@@ -53,6 +54,7 @@
5354
from sagemaker.serve.validations.check_image_and_hardware_type import (
5455
validate_image_uri_and_hardware,
5556
)
57+
from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata
5658

5759
logger = logging.getLogger(__name__)
5860

@@ -65,7 +67,7 @@
6567

6668
# pylint: disable=attribute-defined-outside-init
6769
@dataclass
68-
class ModelBuilder(Triton, DJL, JumpStart, TGI):
70+
class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
6971
"""Class that builds a deployable model.
7072
7173
Args:
@@ -125,8 +127,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI):
125127
in order for model builder to build the artifacts correctly (according
126128
to the model server). Possible values for this argument are
127129
``TORCHSERVE``, ``MMS``, ``TENSORFLOW_SERVING``, ``DJL_SERVING``,
128-
``TRITON``, and ``TGI``.
129-
130+
``TRITON``, and``TGI``.
130131
"""
131132

132133
model_path: Optional[str] = field(
@@ -535,7 +536,7 @@ def wrapper(*args, **kwargs):
535536
return wrapper
536537

537538
# Model Builder is a class to build the model for deployment.
538-
# It supports three modes of deployment
539+
# It supports two modes of deployment
539540
# 1/ SageMaker Endpoint
540541
# 2/ Local launch with container
541542
def build(
@@ -577,12 +578,20 @@ def build(
577578
)
578579

579580
self.serve_settings = self._get_serve_setting()
581+
582+
hf_model_md = get_huggingface_model_metadata(
583+
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
584+
)
585+
580586
if isinstance(self.model, str):
581587
if self._is_jumpstart_model_id():
582588
return self._build_for_jumpstart()
583589
if self._is_djl():
584590
return self._build_for_djl()
585-
return self._build_for_tgi()
591+
if hf_model_md.get("pipeline_tag") == "text-generation": # pylint: disable=R1705
592+
return self._build_for_tgi()
593+
else:
594+
return self._build_for_transformers()
586595

587596
self._build_validations()
588597

Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Transformers build logic with model builder"""
14+
from __future__ import absolute_import
15+
import logging
16+
from abc import ABC, abstractmethod
17+
from typing import Type
18+
from packaging.version import Version
19+
20+
from sagemaker.model import Model
21+
from sagemaker import image_uris
22+
from sagemaker.serve.utils.local_hardware import (
23+
_get_nb_instance,
24+
)
25+
from sagemaker.djl_inference.model import _get_model_config_properties_from_hf
26+
from sagemaker.huggingface import HuggingFaceModel
27+
from sagemaker.serve.model_server.multi_model_server.prepare import (
28+
_create_dir_structure,
29+
)
30+
from sagemaker.serve.utils.predictors import TransformersLocalModePredictor
31+
from sagemaker.serve.utils.types import ModelServer
32+
from sagemaker.serve.mode.function_pointers import Mode
33+
from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
34+
from sagemaker.base_predictor import PredictorBase
35+
from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata
36+
37+
logger = logging.getLogger(__name__)

# Default model-load timeout (seconds) handed to the local container server
# when the caller does not supply ``model_data_download_timeout``.
DEFAULT_TIMEOUT = 1800


# Retrieves images for different libraries - Pytorch, TensorFlow from HuggingFace hub
# pylint: disable=W0108
class Transformers(ABC):
    """Transformers build logic with ModelBuilder().

    Mixin that resolves a HuggingFace DLC image (PyTorch or TensorFlow flavor,
    chosen from the model's HuggingFace Hub metadata tags), wraps the model in
    a ``HuggingFaceModel``, and wires its ``deploy()`` into ModelBuilder's
    LOCAL_CONTAINER / SAGEMAKER_ENDPOINT flow.
    """

    def __init__(self):
        # All attributes below are populated by ModelBuilder before the build
        # methods run; they are declared here so the mixin is self-describing.
        self.model = None
        self.serve_settings = None
        self.sagemaker_session = None
        self.model_path = None
        self.dependencies = None
        self.modes = None
        self.mode = None
        self.model_server = None
        self.image_uri = None
        self._original_deploy = None
        self.hf_model_config = None
        self._default_data_type = None
        self.pysdk_model = None
        self.env_vars = None
        self.nb_instance_type = None
        self.ram_usage_model_load = None
        self.secret_key = None
        self.role_arn = None
        self.py_version = None
        self.tensorflow_version = None
        self.pytorch_version = None
        self.instance_type = None
        self.schema_builder = None

    @abstractmethod
    def _prepare_for_mode(self):
        """Abstract method implemented by ModelBuilder; prepares mode-specific state."""

    def _create_transformers_model(self) -> Type[Model]:
        """Initializes the model after fetching image.

        1. Get the metadata for deciding framework
        2. Get the supported hugging face versions
        3. Create model
        4. Fetch image

        Returns:
            pysdk_model: Corresponding ``HuggingFaceModel`` instance.

        Raises:
            ValueError: If the HuggingFace Hub metadata cannot be fetched, or
                the model carries neither a PyTorch nor a TensorFlow/Keras tag.
        """
        hf_model_md = get_huggingface_model_metadata(
            self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
        )
        if hf_model_md is None:
            raise ValueError("Could not fetch HF metadata")

        hf_config = image_uris.config_for_framework("huggingface").get("inference")
        config = hf_config["versions"]
        # Oldest available HF container version (PEP 440 ordering); the
        # framework versions below are resolved relative to it.
        base_hf_version = sorted(config.keys(), key=lambda v: Version(v))[0]

        # ``tags`` may be absent from the metadata; treat that as "no tags".
        tags = hf_model_md.get("tags") or []
        if "pytorch" in tags:
            self.pytorch_version = self._get_supported_version(
                hf_config, base_hf_version, "pytorch"
            )
            self.py_version = config[base_hf_version]["pytorch" + self.pytorch_version].get(
                "py_versions"
            )[-1]
            pysdk_model = HuggingFaceModel(
                env=self.env_vars,
                role=self.role_arn,
                sagemaker_session=self.sagemaker_session,
                py_version=self.py_version,
                transformers_version=base_hf_version,
                pytorch_version=self.pytorch_version,
            )
        elif "keras" in tags or "tensorflow" in tags:
            self.tensorflow_version = self._get_supported_version(
                hf_config, base_hf_version, "tensorflow"
            )
            self.py_version = config[base_hf_version]["tensorflow" + self.tensorflow_version].get(
                "py_versions"
            )[-1]
            pysdk_model = HuggingFaceModel(
                env=self.env_vars,
                role=self.role_arn,
                sagemaker_session=self.sagemaker_session,
                py_version=self.py_version,
                transformers_version=base_hf_version,
                tensorflow_version=self.tensorflow_version,
            )
        else:
            # BUG FIX: the original fell through with ``pysdk_model`` unbound
            # and crashed with UnboundLocalError; fail with a clear message.
            raise ValueError(
                "Cannot determine framework for model %s: expected a 'pytorch', "
                "'tensorflow' or 'keras' tag in its HuggingFace Hub metadata." % self.model
            )

        if self.mode == Mode.LOCAL_CONTAINER:
            self.image_uri = pysdk_model.serving_image_uri(
                self.sagemaker_session.boto_region_name, "local"
            )
        else:
            self.image_uri = pysdk_model.serving_image_uri(
                self.sagemaker_session.boto_region_name, self.instance_type
            )

        logger.info("Detected %s. Proceeding with the deployment.", self.image_uri)

        # Intercept deploy() so the mode/instance handling in the wrapper
        # below runs before the real SDK deploy.
        self._original_deploy = pysdk_model.deploy
        pysdk_model.deploy = self._transformers_model_builder_deploy_wrapper
        return pysdk_model

    @_capture_telemetry("transformers.deploy")
    def _transformers_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
        """Returns predictor depending on local or sagemaker endpoint mode.

        Returns:
            TransformersLocalModePredictor: During local mode deployment;
            otherwise the predictor produced by the wrapped ``deploy()``.

        Raises:
            ValueError: If an unsupported ``mode`` override is supplied, or no
                instance type can be determined for endpoint deployment.
        """
        timeout = kwargs.get("model_data_download_timeout")
        if timeout:
            self.env_vars.update({"MODEL_LOADING_TIMEOUT": str(timeout)})

        if "mode" in kwargs and kwargs.get("mode") != self.mode:
            overwrite_mode = kwargs.get("mode")
            # mode overwritten by customer during model.deploy()
            logger.warning(
                "Deploying in %s Mode, overriding existing configurations set for %s mode",
                overwrite_mode,
                self.mode,
            )

            if overwrite_mode == Mode.SAGEMAKER_ENDPOINT:
                self.mode = self.pysdk_model.mode = Mode.SAGEMAKER_ENDPOINT
            elif overwrite_mode == Mode.LOCAL_CONTAINER:
                self._prepare_for_mode()
                self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER
            else:
                raise ValueError("Mode %s is not supported!" % overwrite_mode)

        # BUG FIX: pass the real deploy kwargs so the chosen instance type
        # reaches the underlying deploy() call. The original called
        # ``self._set_instance()`` with no arguments and the helper mutated a
        # throwaway local copy, silently dropping the setting.
        self._set_instance(kwargs)

        serializer = self.schema_builder.input_serializer
        deserializer = self.schema_builder._output_deserializer
        if self.mode == Mode.LOCAL_CONTAINER:
            local_mode = self.modes[str(Mode.LOCAL_CONTAINER)]
            predictor = TransformersLocalModePredictor(local_mode, serializer, deserializer)

            local_mode.create_server(
                self.image_uri,
                timeout if timeout else DEFAULT_TIMEOUT,
                None,
                predictor,
                self.pysdk_model.env,
                jumpstart=False,
            )
            return predictor

        # ``mode`` and ``role`` are ModelBuilder extensions, not arguments the
        # SDK deploy() understands; strip them before delegating.
        if "mode" in kwargs:
            del kwargs["mode"]
        if "role" in kwargs:
            self.pysdk_model.role = kwargs.get("role")
            del kwargs["role"]

        # set model_data to uncompressed s3 dict
        self.pysdk_model.model_data, env_vars = self._prepare_for_mode()
        self.env_vars.update(env_vars)
        self.pysdk_model.env.update(self.env_vars)

        if "endpoint_logging" not in kwargs:
            kwargs["endpoint_logging"] = True

        if "initial_instance_count" not in kwargs:
            kwargs.update({"initial_instance_count": 1})

        predictor = self._original_deploy(*args, **kwargs)

        predictor.serializer = serializer
        predictor.deserializer = deserializer
        return predictor

    def _build_transformers_env(self):
        """Build the HuggingFace model and environment for deployment."""
        self.nb_instance_type = _get_nb_instance()

        _create_dir_structure(self.model_path)
        # NOTE(review): __init__ sets ``self.pysdk_model = None`` so this
        # hasattr guard looks always-True when __init__ ran — confirm whether
        # ModelBuilder bypasses __init__ for this mixin.
        if not hasattr(self, "pysdk_model"):
            self.env_vars.update({"HF_MODEL_ID": self.model})

            logger.info(self.env_vars)

            # TODO: Move to a helper function
            # BUG FIX: the original used ``hasattr(self.env_vars, "HF_API_TOKEN")``
            # which is always False on a dict, so HF_API_TOKEN was never used.
            token = self.env_vars.get("HF_API_TOKEN") or self.env_vars.get(
                "HUGGING_FACE_HUB_TOKEN"
            )
            self.hf_model_config = _get_model_config_properties_from_hf(self.model, token)

            self.pysdk_model = self._create_transformers_model()

        if self.mode == Mode.LOCAL_CONTAINER:
            self._prepare_for_mode()

        return self.pysdk_model

    def _set_instance(self, kwargs=None):
        """Set the instance : Given the detected notebook type or provided instance type.

        Args:
            kwargs (dict): The keyword arguments destined for ``deploy()``;
                mutated in place so the chosen ``instance_type`` actually
                reaches endpoint creation. Defaults to ``None`` (treated as an
                empty dict) for backward compatibility with no-arg callers.

        Raises:
            ValueError: When deploying to a SageMaker endpoint and no instance
                type was detected, configured, or supplied by the caller.
        """
        if kwargs is None:
            kwargs = {}
        if self.mode == Mode.SAGEMAKER_ENDPOINT:
            if self.nb_instance_type and "instance_type" not in kwargs:
                kwargs.update({"instance_type": self.nb_instance_type})
            elif self.instance_type and "instance_type" not in kwargs:
                kwargs.update({"instance_type": self.instance_type})
            elif "instance_type" not in kwargs:
                # BUG FIX: the original raised even when the caller HAD passed
                # instance_type in kwargs.
                raise ValueError(
                    "Instance type must be provided when deploying to SageMaker Endpoint mode."
                )
            # BUG FIX: log the instance type actually selected (the original
            # always logged ``self.instance_type``, even when the notebook
            # instance type was chosen).
            logger.info("Setting instance type to %s", kwargs.get("instance_type"))

    def _get_supported_version(self, hf_config, hugging_face_version, base_fw):
        """Uses the hugging face json config to pick supported versions.

        Args:
            hf_config (dict): The "inference" section of the HuggingFace image config.
            hugging_face_version (str): Transformers container version to look under.
            base_fw (str): Base framework prefix, e.g. ``"pytorch"`` or ``"tensorflow"``.

        Returns:
            str: The lowest supported base-framework version string.
        """
        version_config = hf_config.get("versions").get(hugging_face_version)
        versions_to_return = []
        for key in version_config:
            if key.startswith(base_fw):
                base_fw_version = key[len(base_fw):]
                if len(hugging_face_version.split(".")) == 2:
                    # Two-part HF versions pair with two-part framework
                    # versions: drop the patch component.
                    base_fw_version = ".".join(base_fw_version.split(".")[:-1])
                versions_to_return.append(base_fw_version)
        # BUG FIX: sort numerically (PEP 440) rather than lexicographically,
        # where "10.1" would order before "2.0" — consistent with the
        # Version-keyed sort used for base_hf_version above.
        return sorted(versions_to_return, key=Version)[0]

    def _build_for_transformers(self):
        """Method that triggers model build.

        Returns:
            PySDK model configured for the multi-model (MMS) server.
        """
        self.secret_key = None
        self.model_server = ModelServer.MMS

        self._build_transformers_env()

        return self.pysdk_model

src/sagemaker/serve/mode/local_container_mode.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from sagemaker.serve.model_server.djl_serving.server import LocalDJLServing
2020
from sagemaker.serve.model_server.triton.server import LocalTritonServer
2121
from sagemaker.serve.model_server.tgi.server import LocalTgiServing
22+
from sagemaker.serve.model_server.multi_model_server.server import LocalMultiModelServer
2223
from sagemaker.session import Session
2324

2425
logger = logging.getLogger(__name__)
@@ -31,7 +32,9 @@
3132
)
3233

3334

34-
class LocalContainerMode(LocalTorchServe, LocalDJLServing, LocalTritonServer, LocalTgiServing):
35+
class LocalContainerMode(
36+
LocalTorchServe, LocalDJLServing, LocalTritonServer, LocalTgiServing, LocalMultiModelServer
37+
):
3538
"""A class that holds methods to deploy model to a container in local environment"""
3639

3740
def __init__(
@@ -128,6 +131,15 @@ def create_server(
128131
jumpstart=jumpstart,
129132
)
130133
self._ping_container = self._tgi_deep_ping
134+
elif self.model_server == ModelServer.MMS:
135+
self._start_serving(
136+
client=self.client,
137+
image=image,
138+
model_path=model_path if model_path else self.model_path,
139+
secret_key=secret_key,
140+
env_vars=env_vars if env_vars else self.env_vars,
141+
)
142+
self._ping_container = self._multi_model_server_deep_ping
131143

132144
# allow some time for container to be ready
133145
time.sleep(10)

0 commit comments

Comments
 (0)