Skip to content

Commit 3f2cea6

Browse files
committed
feat: added endpoint_name to clarify.ModelConfig
1 parent 5dcaca3 commit 3f2cea6

File tree

2 files changed

+58
-9
lines changed

2 files changed

+58
-9
lines changed

src/sagemaker/clarify.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -272,16 +272,17 @@ class ModelConfig:
272272

273273
def __init__(
274274
self,
275-
model_name,
276-
instance_count,
277-
instance_type,
275+
model_name=None,
276+
instance_count=None,
277+
instance_type=None,
278278
accept_type=None,
279279
content_type=None,
280280
content_template=None,
281281
custom_attributes=None,
282282
accelerator_type=None,
283283
endpoint_name_prefix=None,
284284
target_model=None,
285+
endpoint_name=None,
285286
):
286287
r"""Initializes a configuration of a model and the endpoint to be created for it.
287288
@@ -326,12 +327,25 @@ def __init__(
326327
ValueError: when the ``endpoint_name_prefix`` is invalid, ``accept_type`` is invalid,
327328
``content_type`` is invalid, or ``content_template`` has no placeholder "features"
328329
"""
329-
self.predictor_config = {
330-
"model_name": model_name,
331-
"instance_type": instance_type,
332-
"initial_instance_count": instance_count,
333-
}
334-
if endpoint_name_prefix is not None:
330+
331+
_model_endpoint_config_rule = (
332+
all([model_name, instance_count, instance_type]),
333+
all([endpoint_name]),
334+
)
335+
assert any(_model_endpoint_config_rule) and not all(_model_endpoint_config_rule)
336+
if endpoint_name:
337+
assert not endpoint_name_prefix
338+
339+
self.predictor_config = (
340+
{
341+
"model_name": model_name,
342+
"instance_type": instance_type,
343+
"initial_instance_count": instance_count,
344+
}
345+
if not endpoint_name
346+
else {"endpoint_name": endpoint_name}
347+
)
348+
if not endpoint_name and endpoint_name_prefix:
335349
if re.search("^[a-zA-Z0-9](-*[a-zA-Z0-9])", endpoint_name_prefix) is None:
336350
raise ValueError(
337351
"Invalid endpoint_name_prefix."

tests/unit/test_clarify.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,41 @@ def pdp_config():
735735
return PDPConfig(features=["F1", "F2"], grid_resolution=20)
736736

737737

738+
def test_model_config_validations():
739+
new_model_endpoint_definition = {
740+
"model_name": "xgboost-model",
741+
"instance_type": "ml.c5.xlarge",
742+
"instance_count": 1,
743+
}
744+
existing_endpoint_definition = {"endpoint_name": "existing_endpoint"}
745+
746+
with pytest.raises(AssertionError):
747+
# should be one of them
748+
ModelConfig(
749+
**new_model_endpoint_definition,
750+
**existing_endpoint_definition,
751+
)
752+
753+
with pytest.raises(AssertionError):
754+
# should be one of them
755+
ModelConfig(
756+
endpoint_name_prefix="prefix",
757+
**existing_endpoint_definition,
758+
)
759+
760+
# success path for new model
761+
assert ModelConfig(**new_model_endpoint_definition).predictor_config == {
762+
"initial_instance_count": 1,
763+
"instance_type": "ml.c5.xlarge",
764+
"model_name": "xgboost-model",
765+
}
766+
767+
# success path for existing endpoint
768+
assert (
769+
ModelConfig(**existing_endpoint_definition).predictor_config == existing_endpoint_definition
770+
)
771+
772+
738773
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
739774
def test_pre_training_bias(
740775
name_from_base,

0 commit comments

Comments
 (0)