File tree Expand file tree Collapse file tree 4 files changed +31
-0
lines changed
integ/sagemaker/jumpstart/model
unit/sagemaker/jumpstart/model Expand file tree Collapse file tree 4 files changed +31
-0
lines changed Original file line number Diff line number Diff line change @@ -412,6 +412,9 @@ def attach(
412
412
model_version : Optional [str ] = None ,
413
413
sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
414
414
) -> "JumpStartModel" :
415
+ """Attaches a JumpStartModel object to an existing SageMaker Endpoint.
416
+
417
+ The model id, version (and inference component name) can be inferred from the tags."""
415
418
416
419
inferred_model_id , inferred_model_version , inferred_inference_component_name = (
417
420
get_model_id_version_from_endpoint (
Original file line number Diff line number Diff line change @@ -408,6 +408,7 @@ def attach(
408
408
inference_component_name : Optional [str ] = None ,
409
409
sagemaker_session = None ,
410
410
) -> "Model" :
411
+ """Attaches a Model object to an existing SageMaker Endpoint."""
411
412
raise NotImplementedError
412
413
413
414
@runnable_by_pipeline
Original file line number Diff line number Diff line change @@ -219,6 +219,11 @@ def test_jumpstart_gated_model_inference_component_enabled(setup):
219
219
220
220
assert response is not None
221
221
222
+ model = JumpStartModel .attach (predictor .endpoint_name , sagemaker_session = get_sm_session ())
223
+ assert model .model_id == model_id
224
+ assert model .endpoint_name == predictor .endpoint_name
225
+ assert model .inference_component_name == predictor .inference_component_name
226
+
222
227
223
228
@mock .patch ("sagemaker.jumpstart.cache.JUMPSTART_LOGGER.warning" )
224
229
def test_instatiating_model (mock_warning_logger , setup ):
Original file line number Diff line number Diff line change @@ -1292,6 +1292,28 @@ def test_model_artifact_variant_model(
1292
1292
enable_network_isolation = True ,
1293
1293
)
1294
1294
1295
+ @mock .patch ("sagemaker.jumpstart.model.get_model_id_version_from_endpoint" )
1296
+ @mock .patch ("sagemaker.jumpstart.model.JumpStartModel.__init__" )
1297
+ def test_attach (
1298
+ self ,
1299
+ mock_js_model_init ,
1300
+ mock_get_model_id_version_from_endpoint ,
1301
+ ):
1302
+ mock_js_model_init .return_value = None
1303
+ mock_get_model_id_version_from_endpoint .return_value = "model-id" , "model-version" , None
1304
+ val = JumpStartModel .attach ("some-endpoint" )
1305
+ mock_get_model_id_version_from_endpoint .assert_called_once_with (
1306
+ endpoint_name = "some-endpoint" ,
1307
+ inference_component_name = None ,
1308
+ sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
1309
+ )
1310
+ mock_js_model_init .assert_called_once_with (
1311
+ model_id = "model-id" ,
1312
+ model_version = "model-version" ,
1313
+ sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
1314
+ )
1315
+ assert isinstance (val , JumpStartModel )
1316
+
1295
1317
@mock .patch ("sagemaker.jumpstart.model.validate_model_id_and_get_type" )
1296
1318
@mock .patch ("sagemaker.jumpstart.factory.model.Session" )
1297
1319
@mock .patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs" )
You can’t perform that action at this time.
0 commit comments