Skip to content

Commit f8ff838

Browse files
evakravi and benieric authored
feat: JumpStartModel attach (aws#4680)
* feat: JumpStartModel attach * fix: unit tests * chore: change order of kwargs to pass unit tests * chore: update docstrings, add tests * fix: docstring * fix: integ tests * chore: address PR comments --------- Co-authored-by: Erick Benitez-Ramos <[email protected]>
1 parent 908daa1 commit f8ff838

File tree

7 files changed

+122
-4
lines changed

7 files changed

+122
-4
lines changed

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,7 @@ def get_deploy_kwargs(
258258
deserializer: Optional[BaseDeserializer] = None,
259259
accelerator_type: Optional[str] = None,
260260
endpoint_name: Optional[str] = None,
261+
inference_component_name: Optional[str] = None,
261262
tags: Optional[Tags] = None,
262263
kms_key: Optional[str] = None,
263264
wait: Optional[bool] = None,
@@ -302,6 +303,7 @@ def get_deploy_kwargs(
302303
deserializer=deserializer,
303304
accelerator_type=accelerator_type,
304305
endpoint_name=endpoint_name,
306+
inference_component_name=inference_component_name,
305307
tags=format_tags(tags),
306308
kms_key=kms_key,
307309
wait=wait,

src/sagemaker/jumpstart/factory/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,7 @@ def get_deploy_kwargs(
542542
deserializer: Optional[BaseDeserializer] = None,
543543
accelerator_type: Optional[str] = None,
544544
endpoint_name: Optional[str] = None,
545+
inference_component_name: Optional[str] = None,
545546
tags: Optional[Tags] = None,
546547
kms_key: Optional[str] = None,
547548
wait: Optional[bool] = None,
@@ -576,6 +577,7 @@ def get_deploy_kwargs(
576577
deserializer=deserializer,
577578
accelerator_type=accelerator_type,
578579
endpoint_name=endpoint_name,
580+
inference_component_name=inference_component_name,
579581
tags=format_tags(tags),
580582
kms_key=kms_key,
581583
wait=wait,

src/sagemaker/jumpstart/model.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,13 @@
3636
get_init_kwargs,
3737
get_register_kwargs,
3838
)
39+
from sagemaker.jumpstart.session_utils import get_model_id_version_from_endpoint
3940
from sagemaker.jumpstart.types import JumpStartSerializablePayload
4041
from sagemaker.jumpstart.utils import (
4142
validate_model_id_and_get_type,
4243
verify_model_region_and_return_specs,
4344
)
44-
from sagemaker.jumpstart.constants import JUMPSTART_LOGGER
45+
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_LOGGER
4546
from sagemaker.jumpstart.enums import JumpStartModelType
4647
from sagemaker.model_card import (
4748
ModelCard,
@@ -406,6 +407,45 @@ def retrieve_example_payload(self) -> JumpStartSerializablePayload:
406407
sagemaker_session=self.sagemaker_session,
407408
)
408409

410+
@classmethod
def attach(
    cls,
    endpoint_name: str,
    inference_component_name: Optional[str] = None,
    model_id: Optional[str] = None,
    model_version: Optional[str] = None,
    sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> "JumpStartModel":
    """Attaches a JumpStartModel object to an existing SageMaker Endpoint.

    The model id, version (and inference component name) can be inferred from the tags.

    Args:
        endpoint_name (str): Name of the existing endpoint to attach to.
        inference_component_name (Optional[str]): Name of the inference component.
            If not provided, the inferred value (possibly ``None``) is used.
        model_id (Optional[str]): JumpStart model id. If not provided, the
            inferred value is used.
        model_version (Optional[str]): JumpStart model version. If neither provided
            nor inferred, ``"*"`` is used.
        sagemaker_session: Session used for the lookup and passed to the model
            constructor. Defaults to ``DEFAULT_JUMPSTART_SAGEMAKER_SESSION``.

    Returns:
        An instance of ``cls`` whose ``endpoint_name`` and
        ``inference_component_name`` attributes are set.
    """
    inferred_model_id = inferred_model_version = inferred_inference_component_name = None

    # The lookup is skipped only when the caller supplied all three values explicitly.
    if inference_component_name is None or model_id is None or model_version is None:
        inferred_model_id, inferred_model_version, inferred_inference_component_name = (
            get_model_id_version_from_endpoint(
                endpoint_name=endpoint_name,
                inference_component_name=inference_component_name,
                sagemaker_session=sagemaker_session,
            )
        )

    # Explicit (truthy) arguments take precedence over inferred values.
    model_id = model_id or inferred_model_id
    model_version = model_version or inferred_model_version or "*"
    inference_component_name = inference_component_name or inferred_inference_component_name

    # Build through ``cls`` rather than a hard-coded class so that calling
    # ``attach`` on a subclass returns an instance of that subclass.
    model = cls(
        model_id=model_id,
        model_version=model_version,
        sagemaker_session=sagemaker_session,
    )
    model.endpoint_name = endpoint_name
    model.inference_component_name = inference_component_name

    return model
448+
409449
def _create_sagemaker_model(
410450
self,
411451
instance_type=None,
@@ -484,6 +524,7 @@ def deploy(
484524
deserializer: Optional[BaseDeserializer] = None,
485525
accelerator_type: Optional[str] = None,
486526
endpoint_name: Optional[str] = None,
527+
inference_component_name: Optional[str] = None,
487528
tags: Optional[Tags] = None,
488529
kms_key: Optional[str] = None,
489530
wait: Optional[bool] = True,
@@ -614,6 +655,7 @@ def deploy(
614655
deserializer=deserializer,
615656
accelerator_type=accelerator_type,
616657
endpoint_name=endpoint_name,
658+
inference_component_name=inference_component_name,
617659
tags=format_tags(tags),
618660
kms_key=kms_key,
619661
wait=wait,

src/sagemaker/jumpstart/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1596,6 +1596,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
15961596
"deserializer",
15971597
"accelerator_type",
15981598
"endpoint_name",
1599+
"inference_component_name",
15991600
"tags",
16001601
"kms_key",
16011602
"wait",
@@ -1641,6 +1642,7 @@ def __init__(
16411642
deserializer: Optional[Any] = None,
16421643
accelerator_type: Optional[str] = None,
16431644
endpoint_name: Optional[str] = None,
1645+
inference_component_name: Optional[str] = None,
16441646
tags: Optional[Tags] = None,
16451647
kms_key: Optional[str] = None,
16461648
wait: Optional[bool] = None,
@@ -1674,6 +1676,7 @@ def __init__(
16741676
self.deserializer = deserializer
16751677
self.accelerator_type = accelerator_type
16761678
self.endpoint_name = endpoint_name
1679+
self.inference_component_name = inference_component_name
16771680
self.tags = format_tags(tags)
16781681
self.kms_key = kms_key
16791682
self.wait = wait

src/sagemaker/model.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,7 @@ def __init__(
358358
sagemaker_config=self._sagemaker_config,
359359
)
360360
self.endpoint_name = None
361+
self.inference_component_name = None
361362
self._is_compiled_model = False
362363
self._compilation_job_name = None
363364
self._is_edge_packaged_model = False
@@ -405,6 +406,16 @@ def __init__(
405406
self.response_types = None
406407
self.accept_eula = None
407408

409+
@classmethod
def attach(
    cls,
    endpoint_name: str,
    inference_component_name: Optional[str] = None,
    sagemaker_session=None,
) -> "Model":
    """Attaches a Model object to an existing SageMaker Endpoint.

    This implementation performs no work and always raises.
    NOTE(review): presumably subclasses that can rebuild themselves from an
    endpoint override this method -- confirm which ones do.

    Args:
        endpoint_name (str): Name of the existing endpoint.
        inference_component_name (Optional[str]): Name of the inference
            component, if any.
        sagemaker_session: Session to use; unused by this implementation.

    Raises:
        NotImplementedError: Always.
    """
    # Name the concrete class in the message so the caller can tell which
    # model type lacks an ``attach`` implementation.
    raise NotImplementedError(
        f"{cls.__name__}.attach is not implemented for this model class."
    )
418+
408419
@runnable_by_pipeline
409420
def register(
410421
self,
@@ -1318,6 +1329,7 @@ def deploy(
13181329
resources: Optional[ResourceRequirements] = None,
13191330
endpoint_type: EndpointType = EndpointType.MODEL_BASED,
13201331
managed_instance_scaling: Optional[str] = None,
1332+
inference_component_name=None,
13211333
routing_config: Optional[Dict[str, Any]] = None,
13221334
**kwargs,
13231335
):
@@ -1602,11 +1614,15 @@ def deploy(
16021614
"ComputeResourceRequirements": resources.get_compute_resource_requirements(),
16031615
}
16041616
runtime_config = {"CopyCount": resources.copy_count}
1605-
inference_component_name = unique_name_from_base(self.name)
1617+
self.inference_component_name = (
1618+
inference_component_name
1619+
or self.inference_component_name
1620+
or unique_name_from_base(self.name)
1621+
)
16061622

16071623
# [TODO]: Add endpoint_logging support
16081624
self.sagemaker_session.create_inference_component(
1609-
inference_component_name=inference_component_name,
1625+
inference_component_name=self.inference_component_name,
16101626
endpoint_name=self.endpoint_name,
16111627
variant_name="AllTraffic", # default variant name
16121628
specification=inference_component_spec,
@@ -1619,7 +1635,7 @@ def deploy(
16191635
predictor = self.predictor_cls(
16201636
self.endpoint_name,
16211637
self.sagemaker_session,
1622-
component_name=inference_component_name,
1638+
component_name=self.inference_component_name,
16231639
)
16241640
if serializer:
16251641
predictor.serializer = serializer

tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,11 @@ def test_jumpstart_gated_model_inference_component_enabled(setup):
219219

220220
assert response is not None
221221

222+
model = JumpStartModel.attach(predictor.endpoint_name, sagemaker_session=get_sm_session())
223+
assert model.model_id == model_id
224+
assert model.endpoint_name == predictor.endpoint_name
225+
assert model.inference_component_name == predictor.component_name
226+
222227

223228
@mock.patch("sagemaker.jumpstart.cache.JUMPSTART_LOGGER.warning")
224229
def test_instatiating_model(mock_warning_logger, setup):

tests/unit/sagemaker/jumpstart/model/test_model.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1324,6 +1324,54 @@ def test_model_artifact_variant_model(
13241324
enable_network_isolation=True,
13251325
)
13261326

1327+
@mock.patch("sagemaker.jumpstart.model.get_model_id_version_from_endpoint")
@mock.patch("sagemaker.jumpstart.model.JumpStartModel.__init__")
def test_attach(
    self,
    mock_js_model_init,
    mock_get_model_id_version_from_endpoint,
):
    """Check when ``JumpStartModel.attach`` consults the endpoint lookup helper.

    The helper must be called whenever any of model id, model version or
    inference component name is missing, and must not be called when all
    three are supplied.
    """
    mock_js_model_init.return_value = None
    mock_get_model_id_version_from_endpoint.return_value = (
        "model-id",
        "model-version",
        None,
    )

    # Keyword arguments the lookup helper is expected to receive every time
    # it is invoked in this test.
    expected_lookup_kwargs = {
        "endpoint_name": "some-endpoint",
        "inference_component_name": None,
        "sagemaker_session": DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    }

    # Nothing supplied: everything is taken from the lookup result.
    attached = JumpStartModel.attach("some-endpoint")
    mock_get_model_id_version_from_endpoint.assert_called_once_with(**expected_lookup_kwargs)
    mock_js_model_init.assert_called_once_with(
        model_id="model-id",
        model_version="model-version",
        sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    )
    assert isinstance(attached, JumpStartModel)

    # Partially supplied identifiers still trigger the lookup.
    partial_overrides = (
        {"model_id": "some-id"},
        {"model_id": "some-id", "model_version": "some-version"},
    )
    for overrides in partial_overrides:
        mock_get_model_id_version_from_endpoint.reset_mock()
        JumpStartModel.attach("some-endpoint", **overrides)
        mock_get_model_id_version_from_endpoint.assert_called_once_with(
            **expected_lookup_kwargs
        )

    # providing model id, version, and ic name should bypass check with endpoint tags
    mock_get_model_id_version_from_endpoint.reset_mock()
    JumpStartModel.attach(
        "some-endpoint",
        model_id="some-id",
        model_version="some-version",
        inference_component_name="some-ic-name",
    )
    mock_get_model_id_version_from_endpoint.assert_not_called()
1374+
13271375
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
13281376
@mock.patch(
13291377
"sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix"

0 commit comments

Comments
 (0)