Skip to content

Commit 7b0b80e

Browse files
committed
chore: update docstrings, add tests
1 parent 2b6b051 commit 7b0b80e

File tree

4 files changed

+31
-0
lines changed

4 files changed

+31
-0
lines changed

src/sagemaker/jumpstart/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,9 @@ def attach(
412412
model_version: Optional[str] = None,
413413
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
414414
) -> "JumpStartModel":
415+
"""Attaches a JumpStartModel object to an existing SageMaker Endpoint.
416+
417+
The model id, version (and inference component name) can be inferred from the tags."""
415418

416419
inferred_model_id, inferred_model_version, inferred_inference_component_name = (
417420
get_model_id_version_from_endpoint(

src/sagemaker/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,7 @@ def attach(
408408
inference_component_name: Optional[str] = None,
409409
sagemaker_session=None,
410410
) -> "Model":
411+
"""Attaches a Model object to an existing SageMaker Endpoint."""
411412
raise NotImplementedError
412413

413414
@runnable_by_pipeline

tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,11 @@ def test_jumpstart_gated_model_inference_component_enabled(setup):
219219

220220
assert response is not None
221221

222+
model = JumpStartModel.attach(predictor.endpoint_name, sagemaker_session=get_sm_session())
223+
assert model.model_id == model_id
224+
assert model.endpoint_name == predictor.endpoint_name
225+
assert model.inference_component_name == predictor.inference_component_name
226+
222227

223228
@mock.patch("sagemaker.jumpstart.cache.JUMPSTART_LOGGER.warning")
224229
def test_instatiating_model(mock_warning_logger, setup):

tests/unit/sagemaker/jumpstart/model/test_model.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1292,6 +1292,28 @@ def test_model_artifact_variant_model(
12921292
enable_network_isolation=True,
12931293
)
12941294

1295+
@mock.patch("sagemaker.jumpstart.model.get_model_id_version_from_endpoint")
1296+
@mock.patch("sagemaker.jumpstart.model.JumpStartModel.__init__")
1297+
def test_attach(
1298+
self,
1299+
mock_js_model_init,
1300+
mock_get_model_id_version_from_endpoint,
1301+
):
1302+
mock_js_model_init.return_value = None
1303+
mock_get_model_id_version_from_endpoint.return_value = "model-id", "model-version", None
1304+
val = JumpStartModel.attach("some-endpoint")
1305+
mock_get_model_id_version_from_endpoint.assert_called_once_with(
1306+
endpoint_name="some-endpoint",
1307+
inference_component_name=None,
1308+
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
1309+
)
1310+
mock_js_model_init.assert_called_once_with(
1311+
model_id="model-id",
1312+
model_version="model-version",
1313+
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
1314+
)
1315+
assert isinstance(val, JumpStartModel)
1316+
12951317
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
12961318
@mock.patch("sagemaker.jumpstart.factory.model.Session")
12971319
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")

0 commit comments

Comments
 (0)