Skip to content

Commit dbec613

Browse files
authored
fix: create the correct session for MultiDataModel (#1255)
* fix: create the correct session for MultiDataModel * remove unused param in test * fix black-format
1 parent f7cd477 commit dbec613

File tree

3 files changed

+69
-1
lines changed

3 files changed

+69
-1
lines changed

src/sagemaker/local/local_session.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,7 @@ def invoke_endpoint(
341341
ContentType=None,
342342
Accept=None,
343343
CustomAttributes=None,
344+
TargetModel=None,
344345
):
345346
"""
346347
@@ -365,6 +366,9 @@ def invoke_endpoint(
365366
if CustomAttributes is not None:
366367
headers["X-Amzn-SageMaker-Custom-Attributes"] = CustomAttributes
367368

369+
if TargetModel is not None:
370+
headers["X-Amzn-SageMaker-Target-Model"] = TargetModel
371+
368372
r = self.http.request("POST", url, body=Body, preload_content=False, headers=headers)
369373

370374
return {"Body": r, "ContentType": Accept}

src/sagemaker/multidatamodel.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from six.moves.urllib.parse import urlparse
1818

1919
import sagemaker
20-
from sagemaker import s3
20+
from sagemaker import local, s3
2121
from sagemaker.model import Model
2222
from sagemaker.session import Session
2323

@@ -210,6 +210,9 @@ def deploy(
210210
if role is None:
211211
raise ValueError("Role can not be null for deploying a model")
212212

213+
if instance_type == "local" and not isinstance(self.sagemaker_session, local.LocalSession):
214+
self.sagemaker_session = local.LocalSession()
215+
213216
container_def = self.prepare_container_def(instance_type, accelerator_type=accelerator_type)
214217
self.sagemaker_session.create_model(
215218
self.name,

tests/integ/test_multidatamodel.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,67 @@ def test_multi_data_model_deploy_pretrained_models(
177177
assert "Could not find endpoint" in str(exception.value)
178178

179179

180+
@pytest.mark.local_mode
181+
def test_multi_data_model_deploy_pretrained_models_local_mode(container_image, sagemaker_session):
182+
timestamp = sagemaker_timestamp()
183+
endpoint_name = "test-multimodel-endpoint-{}".format(timestamp)
184+
model_name = "test-multimodel-{}".format(timestamp)
185+
186+
# Define pretrained model local path
187+
pretrained_model_data_local_path = os.path.join(DATA_DIR, "sparkml_model", "mleap_model.tar.gz")
188+
189+
with timeout(minutes=30):
190+
model_data_prefix = os.path.join(
191+
"s3://", sagemaker_session.default_bucket(), "multimodel-{}/".format(timestamp)
192+
)
193+
multi_data_model = MultiDataModel(
194+
name=model_name,
195+
model_data_prefix=model_data_prefix,
196+
image=container_image,
197+
role=ROLE,
198+
sagemaker_session=sagemaker_session,
199+
)
200+
201+
# Add model before deploy
202+
multi_data_model.add_model(pretrained_model_data_local_path, PRETRAINED_MODEL_PATH_1)
203+
# Deploy model to an endpoint
204+
multi_data_model.deploy(1, "local", endpoint_name=endpoint_name)
205+
# Add models after deploy
206+
multi_data_model.add_model(pretrained_model_data_local_path, PRETRAINED_MODEL_PATH_2)
207+
208+
endpoint_models = []
209+
for model_path in multi_data_model.list_models():
210+
endpoint_models.append(model_path)
211+
assert PRETRAINED_MODEL_PATH_1 in endpoint_models
212+
assert PRETRAINED_MODEL_PATH_2 in endpoint_models
213+
214+
predictor = RealTimePredictor(
215+
endpoint=endpoint_name,
216+
sagemaker_session=multi_data_model.sagemaker_session,
217+
serializer=npy_serializer,
218+
deserializer=string_deserializer,
219+
)
220+
221+
data = numpy.zeros(shape=(1, 1, 28, 28))
222+
result = predictor.predict(data, target_model=PRETRAINED_MODEL_PATH_1)
223+
assert result == "Invoked model: {}".format(PRETRAINED_MODEL_PATH_1)
224+
225+
result = predictor.predict(data, target_model=PRETRAINED_MODEL_PATH_2)
226+
assert result == "Invoked model: {}".format(PRETRAINED_MODEL_PATH_2)
227+
228+
# Cleanup
229+
multi_data_model.sagemaker_session.sagemaker_client.delete_endpoint_config(
230+
EndpointConfigName=endpoint_name
231+
)
232+
multi_data_model.sagemaker_session.delete_endpoint(endpoint_name)
233+
multi_data_model.delete_model()
234+
with pytest.raises(Exception) as exception:
235+
sagemaker_session.sagemaker_client.describe_model(ModelName=multi_data_model.name)
236+
assert "Could not find model" in str(exception.value)
237+
sagemaker_session.sagemaker_client.describe_endpoint_config(name=endpoint_name)
238+
assert "Could not find endpoint" in str(exception.value)
239+
240+
180241
def test_multi_data_model_deploy_trained_model_from_framework_estimator(
181242
container_image, sagemaker_session, cpu_instance_type
182243
):

0 commit comments

Comments
 (0)