Skip to content

Commit fccff81

Browse files
author
EC2 Default User
committed
fix: enable uncompressed model artifact uploads to S3 when overriding to SAGEMAKER_ENDPOINT mode for the TGI, TEI, and MMS model servers
1 parent e09693c commit fccff81

File tree

8 files changed

+424
-54
lines changed

8 files changed

+424
-54
lines changed

src/sagemaker/serve/builder/tei_builder.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __init__(self):
6767
self.role_arn = None
6868

6969
@abstractmethod
70-
def _prepare_for_mode(self):
70+
def _prepare_for_mode(self, *args, **kwargs):
7171
"""Placeholder docstring"""
7272

7373
@abstractmethod
@@ -164,15 +164,24 @@ def _tei_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa
164164
del kwargs["role"]
165165

166166
if not _is_optimized(self.pysdk_model):
167-
self._prepare_for_mode()
167+
env_vars = {}
168+
if str(Mode.LOCAL_CONTAINER) in self.modes:
169+
# upload model artifacts to S3 if LOCAL_CONTAINER -> SAGEMAKER_ENDPOINT
170+
self.pysdk_model.model_data, env_vars = self._prepare_for_mode(
171+
model_path=self.model_path, should_upload_artifacts=True
172+
)
173+
else:
174+
_, env_vars = self._prepare_for_mode()
175+
176+
self.env_vars.update(env_vars)
177+
self.pysdk_model.env.update(self.env_vars)
168178

169179
# if the weights have been cached via local container mode -> set to offline
170180
if str(Mode.LOCAL_CONTAINER) in self.modes:
171-
self.pysdk_model.env.update({"TRANSFORMERS_OFFLINE": "1"})
181+
self.pysdk_model.env.update({"HF_HUB_OFFLINE": "1"})
172182
else:
173183
# if has not been built for local container we must use cache
174184
# that hosting has write access to.
175-
self.pysdk_model.env["TRANSFORMERS_CACHE"] = "/tmp"
176185
self.pysdk_model.env["HF_HOME"] = "/tmp"
177186
self.pysdk_model.env["HUGGINGFACE_HUB_CACHE"] = "/tmp"
178187

@@ -191,6 +200,9 @@ def _tei_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa
191200

192201
predictor = self._original_deploy(*args, **kwargs)
193202

203+
if "HF_HUB_OFFLINE" in self.pysdk_model.env:
204+
self.pysdk_model.env.update({"HF_HUB_OFFLINE": "0"})
205+
194206
predictor.serializer = serializer
195207
predictor.deserializer = deserializer
196208
return predictor

src/sagemaker/serve/builder/tgi_builder.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def __init__(self):
9494
self.role_arn = None
9595

9696
@abstractmethod
97-
def _prepare_for_mode(self):
97+
def _prepare_for_mode(self, *args, **kwargs):
9898
"""Placeholder docstring"""
9999

100100
@abstractmethod
@@ -203,15 +203,24 @@ def _tgi_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa
203203
del kwargs["role"]
204204

205205
if not _is_optimized(self.pysdk_model):
206-
self._prepare_for_mode()
206+
env_vars = {}
207+
if str(Mode.LOCAL_CONTAINER) in self.modes:
208+
# upload model artifacts to S3 if LOCAL_CONTAINER -> SAGEMAKER_ENDPOINT
209+
self.pysdk_model.model_data, env_vars = self._prepare_for_mode(
210+
model_path=self.model_path, should_upload_artifacts=True
211+
)
212+
else:
213+
_, env_vars = self._prepare_for_mode()
214+
215+
self.env_vars.update(env_vars)
216+
self.pysdk_model.env.update(self.env_vars)
207217

208218
# if the weights have been cached via local container mode -> set to offline
209219
if str(Mode.LOCAL_CONTAINER) in self.modes:
210-
self.pysdk_model.env.update({"TRANSFORMERS_OFFLINE": "1"})
220+
self.pysdk_model.env.update({"HF_HUB_OFFLINE": "1"})
211221
else:
212222
# if has not been built for local container we must use cache
213223
# that hosting has write access to.
214-
self.pysdk_model.env["TRANSFORMERS_CACHE"] = "/tmp"
215224
self.pysdk_model.env["HF_HOME"] = "/tmp"
216225
self.pysdk_model.env["HUGGINGFACE_HUB_CACHE"] = "/tmp"
217226

@@ -242,7 +251,8 @@ def _tgi_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa
242251

243252
predictor = self._original_deploy(*args, **kwargs)
244253

245-
self.pysdk_model.env.update({"TRANSFORMERS_OFFLINE": "0"})
254+
if "HF_HUB_OFFLINE" in self.pysdk_model.env:
255+
self.pysdk_model.env.update({"HF_HUB_OFFLINE": "0"})
246256

247257
predictor.serializer = serializer
248258
predictor.deserializer = deserializer

src/sagemaker/serve/builder/transformers_builder.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def __init__(self):
8484
self.shared_libs = None
8585

8686
@abstractmethod
87-
def _prepare_for_mode(self):
87+
def _prepare_for_mode(self, *args, **kwargs):
8888
"""Abstract method"""
8989

9090
def _create_transformers_model(self) -> Type[Model]:
@@ -234,7 +234,17 @@ def _transformers_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[Pr
234234
del kwargs["role"]
235235

236236
if not _is_optimized(self.pysdk_model):
237-
self._prepare_for_mode()
237+
env_vars = {}
238+
if str(Mode.LOCAL_CONTAINER) in self.modes:
239+
# upload model artifacts to S3 if LOCAL_CONTAINER -> SAGEMAKER_ENDPOINT
240+
self.pysdk_model.model_data, env_vars = self._prepare_for_mode(
241+
model_path=self.model_path, should_upload_artifacts=True
242+
)
243+
else:
244+
_, env_vars = self._prepare_for_mode()
245+
246+
self.env_vars.update(env_vars)
247+
self.pysdk_model.env.update(self.env_vars)
238248

239249
if "endpoint_logging" not in kwargs:
240250
kwargs["endpoint_logging"] = True

src/sagemaker/serve/model_server/tei/server.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
MODE_DIR_BINDING = "/opt/ml/model/"
1818
_SHM_SIZE = "2G"
1919
_DEFAULT_ENV_VARS = {
20-
"TRANSFORMERS_CACHE": "/opt/ml/model/",
2120
"HF_HOME": "/opt/ml/model/",
2221
"HUGGINGFACE_HUB_CACHE": "/opt/ml/model/",
2322
}

src/sagemaker/serve/model_server/tgi/server.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
MODE_DIR_BINDING = "/opt/ml/model/"
1818
_SHM_SIZE = "2G"
1919
_DEFAULT_ENV_VARS = {
20-
"TRANSFORMERS_CACHE": "/opt/ml/model/",
2120
"HF_HOME": "/opt/ml/model/",
2221
"HUGGINGFACE_HUB_CACHE": "/opt/ml/model/",
2322
}

tests/unit/sagemaker/serve/builder/test_tei_builder.py

Lines changed: 98 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@
2020

2121
from sagemaker.serve.utils.predictors import TeiLocalModePredictor
2222

23-
mock_model_id = "bert-base-uncased"
24-
mock_prompt = "The man worked as a [MASK]."
25-
mock_sample_input = {"inputs": mock_prompt}
26-
mock_sample_output = [
23+
MOCK_MODEL_ID = "bert-base-uncased"
24+
MOCK_PROMPT = "The man worked as a [MASK]."
25+
MOCK_SAMPLE_INPUT = {"inputs": MOCK_PROMPT}
26+
MOCK_SAMPLE_OUTPUT = [
2727
{
2828
"score": 0.0974755585193634,
2929
"token": 10533,
@@ -55,13 +55,14 @@
5555
"sequence": "the man worked as a salesman.",
5656
},
5757
]
58-
mock_schema_builder = MagicMock()
59-
mock_schema_builder.sample_input = mock_sample_input
60-
mock_schema_builder.sample_output = mock_sample_output
58+
MOCK_SCHEMA_BUILDER = MagicMock()
59+
MOCK_SCHEMA_BUILDER.sample_input = MOCK_SAMPLE_INPUT
60+
MOCK_SCHEMA_BUILDER.sample_output = MOCK_SAMPLE_OUTPUT
6161
MOCK_IMAGE_CONFIG = (
6262
"763104351884.dkr.ecr.us-west-2.amazonaws.com/"
6363
"huggingface-pytorch-inference:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04-v1.0"
6464
)
65+
MOCK_MODEL_PATH = "mock model path"
6566

6667

6768
class TestTEIBuilder(unittest.TestCase):
@@ -70,57 +71,136 @@ class TestTEIBuilder(unittest.TestCase):
7071
return_value="ml.g5.24xlarge",
7172
)
7273
@patch("sagemaker.serve.builder.tei_builder._capture_telemetry", side_effect=None)
73-
def test_build_deploy_for_tei_local_container_and_remote_container(
74+
def test_tei_builder_sagemaker_endpoint_mode_no_s3_upload_success(
7475
self,
7576
mock_get_nb_instance,
7677
mock_telemetry,
7778
):
79+
# verify SAGEMAKER_ENDPOINT deploy
7880
builder = ModelBuilder(
79-
model=mock_model_id,
80-
schema_builder=mock_schema_builder,
81+
model=MOCK_MODEL_ID,
82+
schema_builder=MOCK_SCHEMA_BUILDER,
83+
mode=Mode.SAGEMAKER_ENDPOINT,
84+
model_metadata={
85+
"HF_TASK": "sentence-similarity",
86+
},
87+
)
88+
89+
builder._prepare_for_mode = MagicMock()
90+
builder._prepare_for_mode.return_value = (None, {})
91+
model = builder.build()
92+
builder.serve_settings.telemetry_opt_out = True
93+
builder._original_deploy = MagicMock()
94+
95+
model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn")
96+
97+
assert "HF_MODEL_ID" in model.env
98+
with self.assertRaises(ValueError) as _:
99+
model.deploy(mode=Mode.IN_PROCESS)
100+
builder._prepare_for_mode.assert_called_with()
101+
102+
@patch(
103+
"sagemaker.serve.builder.tei_builder._get_nb_instance",
104+
return_value="ml.g5.24xlarge",
105+
)
106+
@patch("sagemaker.serve.builder.tei_builder._capture_telemetry", side_effect=None)
107+
def test_tei_builder_overwritten_deploy_from_local_container_to_sagemaker_endpoint_success(
108+
self,
109+
mock_get_nb_instance,
110+
mock_telemetry,
111+
):
112+
# verify LOCAL_CONTAINER deploy
113+
builder = ModelBuilder(
114+
model=MOCK_MODEL_ID,
115+
schema_builder=MOCK_SCHEMA_BUILDER,
81116
mode=Mode.LOCAL_CONTAINER,
82117
vpc_config=MOCK_VPC_CONFIG,
83118
model_metadata={
84119
"HF_TASK": "sentence-similarity",
85120
},
121+
model_path=MOCK_MODEL_PATH,
86122
)
87123

88124
builder._prepare_for_mode = MagicMock()
89125
builder._prepare_for_mode.side_effect = None
90-
91126
model = builder.build()
92127
builder.serve_settings.telemetry_opt_out = True
93-
94128
builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock()
129+
95130
predictor = model.deploy(model_data_download_timeout=1800)
96131

97132
assert model.vpc_config == MOCK_VPC_CONFIG
98133
assert builder.env_vars["MODEL_LOADING_TIMEOUT"] == "1800"
99134
assert isinstance(predictor, TeiLocalModePredictor)
100-
101135
assert builder.nb_instance_type == "ml.g5.24xlarge"
102136

137+
# verify SAGEMAKER_ENDPOINT overwritten deploy
103138
builder._original_deploy = MagicMock()
104139
builder._prepare_for_mode.return_value = (None, {})
105-
predictor = model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn")
106-
assert "HF_MODEL_ID" in model.env
107140

141+
model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn")
142+
143+
assert "HF_MODEL_ID" in model.env
108144
with self.assertRaises(ValueError) as _:
109145
model.deploy(mode=Mode.IN_PROCESS)
146+
builder._prepare_for_mode.call_args_list[1].assert_called_once_with(
147+
model_path=MOCK_MODEL_PATH, should_upload_artifacts=True
148+
)
149+
150+
@patch(
151+
"sagemaker.serve.builder.tei_builder._get_nb_instance",
152+
return_value="ml.g5.24xlarge",
153+
)
154+
@patch("sagemaker.serve.builder.tei_builder._capture_telemetry", side_effect=None)
155+
@patch("sagemaker.serve.builder.tei_builder._is_optimized", return_value=True)
156+
def test_tei_builder_optimized_sagemaker_endpoint_mode_no_s3_upload_success(
157+
self,
158+
mock_is_optimized,
159+
mock_get_nb_instance,
160+
mock_telemetry,
161+
):
162+
# verify LOCAL_CONTAINER deploy
163+
builder = ModelBuilder(
164+
model=MOCK_MODEL_ID,
165+
schema_builder=MOCK_SCHEMA_BUILDER,
166+
mode=Mode.LOCAL_CONTAINER,
167+
vpc_config=MOCK_VPC_CONFIG,
168+
model_metadata={
169+
"HF_TASK": "sentence-similarity",
170+
},
171+
model_path=MOCK_MODEL_PATH,
172+
)
173+
174+
builder._prepare_for_mode = MagicMock()
175+
builder._prepare_for_mode.side_effect = None
176+
model = builder.build()
177+
builder.serve_settings.telemetry_opt_out = True
178+
builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock()
179+
180+
model.deploy(model_data_download_timeout=1800)
181+
182+
# verify SAGEMAKER_ENDPOINT overwritten deploy
183+
builder._original_deploy = MagicMock()
184+
builder._prepare_for_mode.return_value = (None, {})
185+
186+
model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn")
187+
188+
# verify that if optimized, no s3 upload occurs
189+
builder._prepare_for_mode.assert_called_with()
110190

111191
@patch(
112192
"sagemaker.serve.builder.tei_builder._get_nb_instance",
113193
return_value="ml.g5.24xlarge",
114194
)
115195
@patch("sagemaker.serve.builder.tei_builder._capture_telemetry", side_effect=None)
116-
def test_image_uri_override(
196+
def test_tei_builder_image_uri_override_success(
117197
self,
118198
mock_get_nb_instance,
119199
mock_telemetry,
120200
):
121201
builder = ModelBuilder(
122-
model=mock_model_id,
123-
schema_builder=mock_schema_builder,
202+
model=MOCK_MODEL_ID,
203+
schema_builder=MOCK_SCHEMA_BUILDER,
124204
mode=Mode.LOCAL_CONTAINER,
125205
image_uri=MOCK_IMAGE_CONFIG,
126206
model_metadata={

0 commit comments

Comments (0)