Skip to content

Commit c176134

Browse files
authored
fix: pass name from modelbuilder constructor to created model (#4859)
1 parent 7aa39f9 commit c176134

12 files changed

+117
-9
lines changed

src/sagemaker/serve/builder/djl_builder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def __init__(self):
9292
self.nb_instance_type = None
9393
self.ram_usage_model_load = None
9494
self.role_arn = None
95+
self.name = None
9596

9697
@abstractmethod
9798
def _prepare_for_mode(self):
@@ -130,6 +131,7 @@ def _create_djl_model(self) -> Type[Model]:
130131
huggingface_hub_token=self.env_vars.get("HF_TOKEN"),
131132
image_config=self.image_config,
132133
vpc_config=self.vpc_config,
134+
name=self.name,
133135
)
134136

135137
if not self.image_uri:

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def __init__(self):
121121
self.is_compiled = False
122122
self.is_quantized = False
123123
self.speculative_decoding_draft_model_source = None
124+
self.name = None
124125

125126
@abstractmethod
126127
def _prepare_for_mode(self, **kwargs):
@@ -147,7 +148,10 @@ def _is_jumpstart_model_id(self) -> bool:
147148
def _create_pre_trained_js_model(self) -> Type[Model]:
148149
"""Placeholder docstring"""
149150
pysdk_model = JumpStartModel(
150-
self.model, vpc_config=self.vpc_config, sagemaker_session=self.sagemaker_session
151+
self.model,
152+
vpc_config=self.vpc_config,
153+
sagemaker_session=self.sagemaker_session,
154+
name=self.name,
151155
)
152156

153157
self._original_deploy = pysdk_model.deploy

src/sagemaker/serve/builder/model_builder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,7 @@ def _create_model(self):
492492
env=self.env_vars,
493493
sagemaker_session=self.sagemaker_session,
494494
predictor_cls=self._get_predictor,
495+
name=self.name,
495496
)
496497

497498
# store the modes in the model so that we may

src/sagemaker/serve/builder/tei_builder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def __init__(self):
6565
self.ram_usage_model_load = None
6666
self.secret_key = None
6767
self.role_arn = None
68+
self.name = None
6869

6970
@abstractmethod
7071
def _prepare_for_mode(self, *args, **kwargs):
@@ -105,6 +106,7 @@ def _create_tei_model(self, **kwargs) -> Type[Model]:
105106
env=self.env_vars,
106107
role=self.role_arn,
107108
sagemaker_session=self.sagemaker_session,
109+
name=self.name,
108110
)
109111

110112
logger.info("Detected %s. Proceeding with the the deployment.", self.image_uri)

src/sagemaker/serve/builder/tf_serving_builder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def __init__(self):
5151
self.pysdk_model = None
5252
self.schema_builder = None
5353
self.env_vars = None
54+
self.name = None
5455

5556
@abstractmethod
5657
def _prepare_for_mode(self):
@@ -97,6 +98,7 @@ def _create_tensorflow_model(self):
9798
env=self.env_vars,
9899
sagemaker_session=self.sagemaker_session,
99100
predictor_cls=self._get_tensorflow_predictor,
101+
name=self.name,
100102
)
101103

102104
self.pysdk_model.mode = self.mode

src/sagemaker/serve/builder/tgi_builder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def __init__(self):
9292
self.ram_usage_model_load = None
9393
self.secret_key = None
9494
self.role_arn = None
95+
self.name = None
9596

9697
@abstractmethod
9798
def _prepare_for_mode(self, *args, **kwargs):
@@ -142,6 +143,7 @@ def _create_tgi_model(self) -> Type[Model]:
142143
env=self.env_vars,
143144
role=self.role_arn,
144145
sagemaker_session=self.sagemaker_session,
146+
name=self.name,
145147
)
146148

147149
self._original_deploy = pysdk_model.deploy

src/sagemaker/serve/builder/transformers_builder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def __init__(self):
8989
self.schema_builder = None
9090
self.inference_spec = None
9191
self.shared_libs = None
92+
self.name = None
9293

9394
@abstractmethod
9495
def _prepare_for_mode(self, *args, **kwargs):
@@ -105,6 +106,7 @@ def _create_transformers_model(self) -> Type[Model]:
105106
env=self.env_vars,
106107
role=self.role_arn,
107108
sagemaker_session=self.sagemaker_session,
109+
name=self.name,
108110
)
109111

110112
logger.info("Detected %s. Proceeding with the the deployment.", self.image_uri)

tests/unit/sagemaker/serve/builder/test_djl_builder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def test_build_deploy_for_djl_local_container(
7878
):
7979
builder = ModelBuilder(
8080
model=mock_model_id,
81+
name="mock_model_name",
8182
schema_builder=mock_schema_builder,
8283
mode=Mode.LOCAL_CONTAINER,
8384
model_server=ModelServer.DJL_SERVING,
@@ -89,6 +90,8 @@ def test_build_deploy_for_djl_local_container(
8990
builder._prepare_for_mode.side_effect = None
9091

9192
model = builder.build()
93+
assert model.name == "mock_model_name"
94+
9295
builder.serve_settings.telemetry_opt_out = True
9396

9497
assert isinstance(model, DJLModel)

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc(
317317
)
318318

319319
mock_model_obj = Mock()
320-
mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: ( # noqa E501
320+
mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls, name: ( # noqa E501
321321
mock_model_obj
322322
if image_uri == mock_image_uri
323323
and image_config == MOCK_IMAGE_CONFIG
@@ -326,6 +326,7 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc(
326326
and role == mock_role_arn
327327
and env == ENV_VARS
328328
and sagemaker_session == mock_session
329+
and "model-name-" in name
329330
else None
330331
)
331332

@@ -425,13 +426,14 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_1p_dlc_as_byoc(
425426
)
426427

427428
mock_model_obj = Mock()
428-
mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: ( # noqa E501
429+
mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls, name: ( # noqa E501
429430
mock_model_obj
430431
if image_uri == mock_1p_dlc_image_uri
431432
and model_data == model_data
432433
and role == mock_role_arn
433434
and env == ENV_VARS
434435
and sagemaker_session == mock_session
436+
and "model-name-" in name
435437
else None
436438
)
437439

@@ -532,13 +534,14 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_inference_spec(
532534
)
533535

534536
mock_model_obj = Mock()
535-
mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: ( # noqa E501
537+
mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls, name: ( # noqa E501
536538
mock_model_obj
537539
if image_uri == mock_image_uri
538540
and model_data == model_data
539541
and role == mock_role_arn
540542
and env == ENV_VARS_INF_SPEC
541543
and sagemaker_session == mock_session
544+
and "model-name-" in name
542545
else None
543546
)
544547

@@ -633,13 +636,14 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_model(
633636
)
634637

635638
mock_model_obj = Mock()
636-
mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: ( # noqa E501
639+
mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls, name: ( # noqa E501
637640
mock_model_obj
638641
if image_uri == mock_image_uri
639642
and model_data == model_data
640643
and role == mock_role_arn
641644
and env == ENV_VARS
642645
and sagemaker_session == mock_session
646+
and "model-name-" in name
643647
else None
644648
)
645649

@@ -742,13 +746,14 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_xgboost_model(
742746
)
743747

744748
mock_model_obj = Mock()
745-
mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: ( # noqa E501
749+
mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls, name: ( # noqa E501
746750
mock_model_obj
747751
if image_uri == mock_image_uri
748752
and model_data == model_data
749753
and role == mock_role_arn
750754
and env == ENV_VARS
751755
and sagemaker_session == mock_session
756+
and "model-name-" in name
752757
else None
753758
)
754759

@@ -847,13 +852,14 @@ def test_build_happy_path_with_local_container_mode(
847852
mock_mode.prepare.side_effect = lambda: None
848853

849854
mock_model_obj = Mock()
850-
mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: ( # noqa E501
855+
mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls, name: ( # noqa E501
851856
mock_model_obj
852857
if image_uri == mock_image_uri
853858
and model_data is None
854859
and role == mock_role_arn
855860
and env == {}
856861
and sagemaker_session == mock_session
862+
and "model-name-" in name
857863
else None
858864
)
859865

@@ -968,13 +974,14 @@ def test_build_happy_path_with_localContainer_mode_overwritten_with_sagemaker_mo
968974
)
969975

970976
mock_model_obj = Mock()
971-
mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: ( # noqa E501
977+
mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls, name: ( # noqa E501
972978
mock_model_obj
973979
if image_uri == mock_image_uri
974980
and model_data is None
975981
and role == mock_role_arn
976982
and env == {}
977983
and sagemaker_session == mock_session
984+
and "model-name-" in name
978985
else None
979986
)
980987

@@ -1119,13 +1126,14 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_co
11191126
)
11201127

11211128
mock_model_obj = Mock()
1122-
mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: ( # noqa E501
1129+
mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls, name: ( # noqa E501
11231130
mock_model_obj
11241131
if image_uri == mock_image_uri
11251132
and model_data == model_data
11261133
and role == mock_role_arn
11271134
and env == ENV_VARS
11281135
and sagemaker_session == mock_session
1136+
and "model-name-" in name
11291137
else None
11301138
)
11311139

tests/unit/sagemaker/serve/builder/test_tei_builder.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def test_tei_builder_sagemaker_endpoint_mode_no_s3_upload_success(
7979
# verify SAGEMAKER_ENDPOINT deploy
8080
builder = ModelBuilder(
8181
model=MOCK_MODEL_ID,
82+
name="mock_model_name",
8283
schema_builder=MOCK_SCHEMA_BUILDER,
8384
mode=Mode.SAGEMAKER_ENDPOINT,
8485
model_metadata={
@@ -88,7 +89,10 @@ def test_tei_builder_sagemaker_endpoint_mode_no_s3_upload_success(
8889

8990
builder._prepare_for_mode = MagicMock()
9091
builder._prepare_for_mode.return_value = (None, {})
92+
9193
model = builder.build()
94+
assert model.name == "mock_model_name"
95+
9296
builder.serve_settings.telemetry_opt_out = True
9397
builder._original_deploy = MagicMock()
9498

tests/unit/sagemaker/serve/builder/test_tensorflow_serving_builder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def setUp(self):
3333
self.instance.image_config = {}
3434
self.instance.vpc_config = {}
3535
self.instance.modes = {}
36+
self.instance.name = "model-name-mock-uuid-hex"
3637

3738
@patch("os.makedirs")
3839
@patch("os.path.exists")
@@ -71,5 +72,6 @@ def test_create_tensorflow_model(self, mock_model):
7172
env=self.instance.env_vars,
7273
sagemaker_session=self.instance.sagemaker_session,
7374
predictor_cls=self.instance._get_tensorflow_predictor,
75+
name="model-name-mock-uuid-hex",
7476
)
7577
self.assertEqual(model, mock_model.return_value)

tests/unit/sagemaker/serve/builder/test_tgi_builder.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,17 @@ def test_tgi_builder_sagemaker_endpoint_mode_no_s3_upload_success(
6161
# verify SAGEMAKER_ENDPOINT deploy
6262
builder = ModelBuilder(
6363
model=MOCK_MODEL_ID,
64+
name="mock_model_name",
6465
schema_builder=MOCK_SCHEMA_BUILDER,
6566
mode=Mode.SAGEMAKER_ENDPOINT,
6667
)
6768

6869
builder._prepare_for_mode = MagicMock()
6970
builder._prepare_for_mode.return_value = (None, {})
71+
7072
model = builder.build()
73+
assert model.name == "mock_model_name"
74+
7175
builder.serve_settings.telemetry_opt_out = True
7276
builder._original_deploy = MagicMock()
7377

@@ -187,3 +191,75 @@ def test_tgi_builder_optimized_sagemaker_endpoint_mode_no_s3_upload_success(
187191

188192
# verify that if optimized, no s3 upload occurs
189193
builder._prepare_for_mode.assert_called_with()
194+
195+
@patch(
196+
"sagemaker.serve.builder.tgi_builder._get_nb_instance",
197+
return_value="ml.g5.24xlarge",
198+
)
199+
@patch("sagemaker.serve.builder.tgi_builder._capture_telemetry", side_effect=None)
200+
@patch(
201+
"sagemaker.serve.builder.model_builder.get_huggingface_model_metadata",
202+
return_value={"pipeline_tag": "text-generation"},
203+
)
204+
@patch(
205+
"sagemaker.serve.builder.tgi_builder._get_model_config_properties_from_hf",
206+
return_value=({}, None),
207+
)
208+
@patch(
209+
"sagemaker.serve.builder.tgi_builder._get_default_tgi_configurations",
210+
return_value=({}, None),
211+
)
212+
@patch(
213+
"sagemaker.serve.builder.tgi_builder._get_admissible_tensor_parallel_degrees",
214+
return_value=[4, 8],
215+
)
216+
@patch("sagemaker.serve.builder.tgi_builder._get_admissible_dtypes", return_value=["fp16"])
217+
@patch("sagemaker.serve.builder.tgi_builder.datetime")
218+
@patch("sagemaker.serve.builder.tgi_builder.timedelta", return_value=1800)
219+
@patch("sagemaker.serve.builder.tgi_builder._serial_benchmark")
220+
@patch("sagemaker.serve.builder.tgi_builder._concurrent_benchmark")
221+
def test_tgi_builder_tune_success(
222+
self,
223+
mock_concurrent_benchmark,
224+
mock_serial_benchmark,
225+
mock_timedelta,
226+
mock_datetime,
227+
mock_get_admissible_dtypes,
228+
mock_get_admissible_tensor_parallel_degrees,
229+
mock_default_tgi_configurations,
230+
mock_hf_model_config,
231+
mock_hf_model_md,
232+
mock_get_nb_instance,
233+
mock_telemetry,
234+
):
235+
# WHERE
236+
mock_datetime.now.side_effect = [0, 100, 200]
237+
mock_serial_benchmark.side_effect = [(1000, 10000, 10), (500, 5000, 50)]
238+
mock_concurrent_benchmark.side_effect = [(10, 10), (50, 5)]
239+
240+
builder = ModelBuilder(
241+
model=MOCK_MODEL_ID,
242+
schema_builder=MOCK_SCHEMA_BUILDER,
243+
mode=Mode.LOCAL_CONTAINER,
244+
model_path=MOCK_MODEL_PATH,
245+
)
246+
builder._prepare_for_mode = MagicMock()
247+
builder._prepare_for_mode.side_effect = None
248+
249+
model = builder.build()
250+
251+
builder.serve_settings.telemetry_opt_out = True
252+
builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock()
253+
builder.pysdk_model = MagicMock()
254+
255+
# WHEN
256+
ret_new_model = model.tune(max_tuning_duration=1800)
257+
258+
# THEN
259+
assert ret_new_model != model
260+
assert len(mock_datetime.now.call_args_list) == 3
261+
assert len(mock_serial_benchmark.call_args_list) == 2
262+
assert len(mock_concurrent_benchmark.call_args_list) == 2
263+
assert ret_new_model.env["NUM_SHARD"] == "8"
264+
assert ret_new_model.env["DTYPE"] == "fp16"
265+
assert ret_new_model.env["SHARDED"] == "true"

0 commit comments

Comments
 (0)