Skip to content

Commit 3764325

Browse files
authored
Merge branch 'master' into fix-remove-kwargs
2 parents 63dab51 + ee6ef13 commit 3764325

File tree

11 files changed

+737
-58
lines changed

11 files changed

+737
-58
lines changed

src/sagemaker/serve/builder/tei_builder.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __init__(self):
6767
self.role_arn = None
6868

6969
@abstractmethod
70-
def _prepare_for_mode(self):
70+
def _prepare_for_mode(self, *args, **kwargs):
7171
"""Placeholder docstring"""
7272

7373
@abstractmethod
@@ -164,15 +164,24 @@ def _tei_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa
164164
del kwargs["role"]
165165

166166
if not _is_optimized(self.pysdk_model):
167-
self._prepare_for_mode()
167+
env_vars = {}
168+
if str(Mode.LOCAL_CONTAINER) in self.modes:
169+
# upload model artifacts to S3 if LOCAL_CONTAINER -> SAGEMAKER_ENDPOINT
170+
self.pysdk_model.model_data, env_vars = self._prepare_for_mode(
171+
model_path=self.model_path, should_upload_artifacts=True
172+
)
173+
else:
174+
_, env_vars = self._prepare_for_mode()
175+
176+
self.env_vars.update(env_vars)
177+
self.pysdk_model.env.update(self.env_vars)
168178

169179
# if the weights have been cached via local container mode -> set to offline
170180
if str(Mode.LOCAL_CONTAINER) in self.modes:
171-
self.pysdk_model.env.update({"TRANSFORMERS_OFFLINE": "1"})
181+
self.pysdk_model.env.update({"HF_HUB_OFFLINE": "1"})
172182
else:
173183
# if has not been built for local container we must use cache
174184
# that hosting has write access to.
175-
self.pysdk_model.env["TRANSFORMERS_CACHE"] = "/tmp"
176185
self.pysdk_model.env["HF_HOME"] = "/tmp"
177186
self.pysdk_model.env["HUGGINGFACE_HUB_CACHE"] = "/tmp"
178187

@@ -191,6 +200,9 @@ def _tei_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa
191200

192201
predictor = self._original_deploy(*args, **kwargs)
193202

203+
if "HF_HUB_OFFLINE" in self.pysdk_model.env:
204+
self.pysdk_model.env.update({"HF_HUB_OFFLINE": "0"})
205+
194206
predictor.serializer = serializer
195207
predictor.deserializer = deserializer
196208
return predictor

src/sagemaker/serve/builder/tgi_builder.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def __init__(self):
9494
self.role_arn = None
9595

9696
@abstractmethod
97-
def _prepare_for_mode(self):
97+
def _prepare_for_mode(self, *args, **kwargs):
9898
"""Placeholder docstring"""
9999

100100
@abstractmethod
@@ -203,15 +203,24 @@ def _tgi_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa
203203
del kwargs["role"]
204204

205205
if not _is_optimized(self.pysdk_model):
206-
self._prepare_for_mode()
206+
env_vars = {}
207+
if str(Mode.LOCAL_CONTAINER) in self.modes:
208+
# upload model artifacts to S3 if LOCAL_CONTAINER -> SAGEMAKER_ENDPOINT
209+
self.pysdk_model.model_data, env_vars = self._prepare_for_mode(
210+
model_path=self.model_path, should_upload_artifacts=True
211+
)
212+
else:
213+
_, env_vars = self._prepare_for_mode()
214+
215+
self.env_vars.update(env_vars)
216+
self.pysdk_model.env.update(self.env_vars)
207217

208218
# if the weights have been cached via local container mode -> set to offline
209219
if str(Mode.LOCAL_CONTAINER) in self.modes:
210-
self.pysdk_model.env.update({"TRANSFORMERS_OFFLINE": "1"})
220+
self.pysdk_model.env.update({"HF_HUB_OFFLINE": "1"})
211221
else:
212222
# if has not been built for local container we must use cache
213223
# that hosting has write access to.
214-
self.pysdk_model.env["TRANSFORMERS_CACHE"] = "/tmp"
215224
self.pysdk_model.env["HF_HOME"] = "/tmp"
216225
self.pysdk_model.env["HUGGINGFACE_HUB_CACHE"] = "/tmp"
217226

@@ -242,7 +251,8 @@ def _tgi_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa
242251

243252
predictor = self._original_deploy(*args, **kwargs)
244253

245-
self.pysdk_model.env.update({"TRANSFORMERS_OFFLINE": "0"})
254+
if "HF_HUB_OFFLINE" in self.pysdk_model.env:
255+
self.pysdk_model.env.update({"HF_HUB_OFFLINE": "0"})
246256

247257
predictor.serializer = serializer
248258
predictor.deserializer = deserializer

src/sagemaker/serve/builder/transformers_builder.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def __init__(self):
8484
self.shared_libs = None
8585

8686
@abstractmethod
87-
def _prepare_for_mode(self):
87+
def _prepare_for_mode(self, *args, **kwargs):
8888
"""Abstract method"""
8989

9090
def _create_transformers_model(self) -> Type[Model]:
@@ -206,8 +206,6 @@ def _transformers_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[Pr
206206
else:
207207
raise ValueError("Mode %s is not supported!" % overwrite_mode)
208208

209-
self._set_instance()
210-
211209
serializer = self.schema_builder.input_serializer
212210
deserializer = self.schema_builder._output_deserializer
213211
if self.mode == Mode.LOCAL_CONTAINER:
@@ -227,14 +225,32 @@ def _transformers_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[Pr
227225
)
228226
return predictor
229227

228+
self._set_instance(kwargs)
229+
230230
if "mode" in kwargs:
231231
del kwargs["mode"]
232232
if "role" in kwargs:
233233
self.pysdk_model.role = kwargs.get("role")
234234
del kwargs["role"]
235235

236236
if not _is_optimized(self.pysdk_model):
237-
self._prepare_for_mode()
237+
env_vars = {}
238+
if str(Mode.LOCAL_CONTAINER) in self.modes:
239+
# upload model artifacts to S3 if LOCAL_CONTAINER -> SAGEMAKER_ENDPOINT
240+
self.pysdk_model.model_data, env_vars = self._prepare_for_mode(
241+
model_path=self.model_path, should_upload_artifacts=True
242+
)
243+
else:
244+
_, env_vars = self._prepare_for_mode()
245+
246+
self.env_vars.update(env_vars)
247+
self.pysdk_model.env.update(self.env_vars)
248+
249+
if (
250+
"SAGEMAKER_SERVE_SECRET_KEY" in self.pysdk_model.env
251+
and not self.pysdk_model.env["SAGEMAKER_SERVE_SECRET_KEY"]
252+
):
253+
del self.pysdk_model.env["SAGEMAKER_SERVE_SECRET_KEY"]
238254

239255
if "endpoint_logging" not in kwargs:
240256
kwargs["endpoint_logging"] = True
@@ -279,9 +295,11 @@ def _build_transformers_env(self):
279295

280296
return self.pysdk_model
281297

282-
def _set_instance(self, **kwargs):
298+
def _set_instance(self, kwargs):
283299
"""Set the instance : Given the detected notebook type or provided instance type"""
284300
if self.mode == Mode.SAGEMAKER_ENDPOINT:
301+
if "instance_type" in kwargs:
302+
return
285303
if self.nb_instance_type and "instance_type" not in kwargs:
286304
kwargs.update({"instance_type": self.nb_instance_type})
287305
logger.info("Setting instance type to %s", self.nb_instance_type)

src/sagemaker/serve/model_server/tei/server.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
MODE_DIR_BINDING = "/opt/ml/model/"
1818
_SHM_SIZE = "2G"
1919
_DEFAULT_ENV_VARS = {
20-
"TRANSFORMERS_CACHE": "/opt/ml/model/",
2120
"HF_HOME": "/opt/ml/model/",
2221
"HUGGINGFACE_HUB_CACHE": "/opt/ml/model/",
2322
}

src/sagemaker/serve/model_server/tgi/server.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
MODE_DIR_BINDING = "/opt/ml/model/"
1818
_SHM_SIZE = "2G"
1919
_DEFAULT_ENV_VARS = {
20-
"TRANSFORMERS_CACHE": "/opt/ml/model/",
2120
"HF_HOME": "/opt/ml/model/",
2221
"HUGGINGFACE_HUB_CACHE": "/opt/ml/model/",
2322
}

tests/unit/sagemaker/serve/builder/test_djl_builder.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
"HF_MODEL_ID": "TheBloke/Llama-2-7b-chat-fp16",
4545
"TENSOR_PARALLEL_DEGREE": "1",
4646
"OPTION_DTYPE": "bf16",
47+
"MODEL_LOADING_TIMEOUT": "1800",
4748
}
4849

4950
mock_schema_builder = MagicMock()
@@ -63,8 +64,13 @@ class TestDjlBuilder(unittest.TestCase):
6364
)
6465
@patch("sagemaker.serve.builder.djl_builder._get_ram_usage_mb", return_value=1024)
6566
@patch("sagemaker.serve.builder.djl_builder._get_nb_instance", return_value="ml.g5.24xlarge")
67+
@patch(
68+
"sagemaker.serve.builder.djl_builder._get_default_djl_configurations",
69+
return_value=(mock_default_configs, 128),
70+
)
6671
def test_build_deploy_for_djl_local_container(
6772
self,
73+
mock_default_djl_config,
6874
mock_get_nb_instance,
6975
mock_get_ram_usage_mb,
7076
mock_is_jumpstart_model,
@@ -125,8 +131,13 @@ def test_build_deploy_for_djl_local_container(
125131
"sagemaker.serve.builder.djl_builder._concurrent_benchmark",
126132
side_effect=[(0.03, 16), (0.10, 4), (0.15, 2)],
127133
)
134+
@patch(
135+
"sagemaker.serve.builder.djl_builder._get_default_djl_configurations",
136+
return_value=(mock_default_configs, 128),
137+
)
128138
def test_tune_for_djl_local_container(
129139
self,
140+
mock_default_djl_config,
130141
mock_concurrent_benchmarks,
131142
mock_serial_benchmarks,
132143
mock_admissible_tensor_parallel_degrees,
@@ -165,8 +176,10 @@ def test_tune_for_djl_local_container(
165176
"sagemaker.serve.builder.djl_builder._get_admissible_tensor_parallel_degrees",
166177
return_value=[4],
167178
)
179+
@patch("sagemaker.serve.model_server.djl_serving.utils._get_available_gpus", return_value=None)
168180
def test_tune_for_djl_local_container_deep_ping_ex(
169181
self,
182+
mock_get_available_gpus,
170183
mock_get_admissible_tensor_parallel_degrees,
171184
mock_serial_benchmarks,
172185
mock_get_nb_instance,
@@ -204,8 +217,10 @@ def test_tune_for_djl_local_container_deep_ping_ex(
204217
"sagemaker.serve.builder.djl_builder._get_admissible_tensor_parallel_degrees",
205218
return_value=[4],
206219
)
220+
@patch("sagemaker.serve.model_server.djl_serving.utils._get_available_gpus", return_value=None)
207221
def test_tune_for_djl_local_container_load_ex(
208222
self,
223+
mock_get_available_gpus,
209224
mock_get_admissible_tensor_parallel_degrees,
210225
mock_serial_benchmarks,
211226
mock_get_nb_instance,
@@ -245,8 +260,10 @@ def test_tune_for_djl_local_container_load_ex(
245260
"sagemaker.serve.builder.djl_builder._get_admissible_tensor_parallel_degrees",
246261
return_value=[4],
247262
)
263+
@patch("sagemaker.serve.model_server.djl_serving.utils._get_available_gpus", return_value=None)
248264
def test_tune_for_djl_local_container_oom_ex(
249265
self,
266+
mock_get_available_gpus,
250267
mock_get_admissible_tensor_parallel_degrees,
251268
mock_serial_benchmarks,
252269
mock_get_nb_instance,
@@ -283,8 +300,10 @@ def test_tune_for_djl_local_container_oom_ex(
283300
"sagemaker.serve.builder.djl_builder._get_admissible_tensor_parallel_degrees",
284301
return_value=[4],
285302
)
303+
@patch("sagemaker.serve.model_server.djl_serving.utils._get_available_gpus", return_value=None)
286304
def test_tune_for_djl_local_container_invoke_ex(
287305
self,
306+
mock_get_available_gpus,
288307
mock_get_admissible_tensor_parallel_degrees,
289308
mock_serial_benchmarks,
290309
mock_get_nb_instance,

0 commit comments

Comments
 (0)