Skip to content

Commit b1e63d9

Browse files
committed
feat: added endpoint_name to clarify.ModelConfig
1 parent 67708bf commit b1e63d9

File tree

2 files changed

+120
-18
lines changed

2 files changed

+120
-18
lines changed

src/sagemaker/clarify.py

Lines changed: 44 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -272,26 +272,30 @@ class ModelConfig:
272272

273273
def __init__(
274274
self,
275-
model_name,
276-
instance_count,
277-
instance_type,
278-
accept_type=None,
279-
content_type=None,
280-
content_template=None,
281-
custom_attributes=None,
282-
accelerator_type=None,
283-
endpoint_name_prefix=None,
284-
target_model=None,
275+
model_name: str = None,
276+
instance_count: int = None,
277+
instance_type: str = None,
278+
accept_type: str = None,
279+
content_type: str = None,
280+
content_template: str = None,
281+
custom_attributes: str = None,
282+
accelerator_type: str = None,
283+
endpoint_name_prefix: str = None,
284+
target_model: str = None,
285+
endpoint_name: str = None,
285286
):
286287
r"""Initializes a configuration of a model and the endpoint to be created for it.
287288
288289
Args:
289290
model_name (str): Model name (as created by
290291
`CreateModel <https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateModel.html>`_.
292+
Cannot be set when endpoint_name is set. Must be set with `instance_count`, `instance_type`
291293
instance_count (int): The number of instances of a new endpoint for model inference.
294+
Cannot be set when endpoint_name is set. Must be set with `model_name`, `instance_type`
292295
instance_type (str): The type of
293296
`EC2 instance <https://aws.amazon.com/ec2/instance-types/>`_
294297
to use for model inference; for example, ``"ml.c5.xlarge"``.
298+
Cannot be set when endpoint_name is set. Must be set with `instance_count`, `model_name`
295299
accept_type (str): The model output format to be used for getting inferences with the
296300
shadow endpoint. Valid values are ``"text/csv"`` for CSV and
297301
``"application/jsonlines"``. Default is the same as ``content_type``.
@@ -321,17 +325,39 @@ def __init__(
321325
target_model (str): Sets the target model name when using a multi-model endpoint. For
322326
more information about multi-model endpoints, see
323327
https://docs.aws.amazon.com/sagemaker/latest/dg/multi-model-endpoints.html
328+
endpoint_name (str): Sets the endpoint_name when re-uses an existing endpoint. Cannot be set
329+
when `model_name`, `instance_count`, and `instance_type` set
324330
325331
Raises:
326-
ValueError: when the ``endpoint_name_prefix`` is invalid, ``accept_type`` is invalid,
327-
``content_type`` is invalid, or ``content_template`` has no placeholder "features"
332+
ValueError: when the
333+
- ``endpoint_name_prefix`` is invalid,
334+
- ``accept_type`` is invalid,
335+
- ``content_type`` is invalid,
336+
- ``content_template`` has no placeholder "features"
337+
- both [``endpoint_name``] AND [``model_name``, ``instance_count``, ``instance_type``] are set
338+
- both [``endpoint_name``] AND [``endpoint_name_prefix``] are set
328339
"""
329-
self.predictor_config = {
330-
"model_name": model_name,
331-
"instance_type": instance_type,
332-
"initial_instance_count": instance_count,
333-
}
334-
if endpoint_name_prefix is not None:
340+
341+
# validation
342+
_model_endpoint_config_rule = (
343+
all([model_name, instance_count, instance_type]),
344+
all([endpoint_name]),
345+
)
346+
assert any(_model_endpoint_config_rule) and not all(_model_endpoint_config_rule)
347+
if endpoint_name:
348+
assert not endpoint_name_prefix
349+
350+
# main init logic
351+
self.predictor_config = (
352+
{
353+
"model_name": model_name,
354+
"instance_type": instance_type,
355+
"initial_instance_count": instance_count,
356+
}
357+
if not endpoint_name
358+
else {"endpoint_name": endpoint_name}
359+
)
360+
if endpoint_name_prefix:
335361
if re.search("^[a-zA-Z0-9](-*[a-zA-Z0-9])", endpoint_name_prefix) is None:
336362
raise ValueError(
337363
"Invalid endpoint_name_prefix."

tests/unit/test_clarify.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,41 @@ def pdp_config():
735735
return PDPConfig(features=["F1", "F2"], grid_resolution=20)
736736

737737

738+
def test_model_config_validations():
739+
new_model_endpoint_definition = {
740+
"model_name": "xgboost-model",
741+
"instance_type": "ml.c5.xlarge",
742+
"instance_count": 1,
743+
}
744+
existing_endpoint_definition = {"endpoint_name": "existing_endpoint"}
745+
746+
with pytest.raises(AssertionError):
747+
# should be one of them
748+
ModelConfig(
749+
**new_model_endpoint_definition,
750+
**existing_endpoint_definition,
751+
)
752+
753+
with pytest.raises(AssertionError):
754+
# should be one of them
755+
ModelConfig(
756+
endpoint_name_prefix="prefix",
757+
**existing_endpoint_definition,
758+
)
759+
760+
# success path for new model
761+
assert ModelConfig(**new_model_endpoint_definition).predictor_config == {
762+
"initial_instance_count": 1,
763+
"instance_type": "ml.c5.xlarge",
764+
"model_name": "xgboost-model",
765+
}
766+
767+
# success path for existing endpoint
768+
assert (
769+
ModelConfig(**existing_endpoint_definition).predictor_config == existing_endpoint_definition
770+
)
771+
772+
738773
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
739774
def test_pre_training_bias(
740775
name_from_base,
@@ -1396,6 +1431,47 @@ def test_analysis_config_generator_for_bias_explainability(
13961431
assert actual == expected
13971432

13981433

1434+
def test_analysis_config_generator_for_bias_explainability_with_existing_endpoint(
1435+
data_config, data_bias_config
1436+
):
1437+
model_config = ModelConfig(endpoint_name="existing_endpoint_name")
1438+
model_predicted_label_config = ModelPredictedLabelConfig(
1439+
probability="pr",
1440+
label_headers=["success"],
1441+
)
1442+
actual = _AnalysisConfigGenerator.bias_and_explainability(
1443+
data_config,
1444+
model_config,
1445+
model_predicted_label_config,
1446+
[SHAPConfig(), PDPConfig()],
1447+
data_bias_config,
1448+
pre_training_methods="all",
1449+
post_training_methods="all",
1450+
)
1451+
expected = {
1452+
"dataset_type": "text/csv",
1453+
"facet": [{"name_or_index": "F1"}],
1454+
"group_variable": "F2",
1455+
"headers": ["Label", "F1", "F2", "F3", "F4"],
1456+
"joinsource_name_or_index": "F4",
1457+
"label": "Label",
1458+
"label_values_or_threshold": [1],
1459+
"methods": {
1460+
"pdp": {"grid_resolution": 15, "top_k_features": 10},
1461+
"post_training_bias": {"methods": "all"},
1462+
"pre_training_bias": {"methods": "all"},
1463+
"report": {"name": "report", "title": "Analysis Report"},
1464+
"shap": {"save_local_shap_values": True, "use_logit": False},
1465+
},
1466+
"predictor": {
1467+
"label_headers": ["success"],
1468+
"endpoint_name": "existing_endpoint_name",
1469+
"probability": "pr",
1470+
},
1471+
}
1472+
assert actual == expected
1473+
1474+
13991475
def test_analysis_config_generator_for_bias_pre_training(data_config, data_bias_config):
14001476
actual = _AnalysisConfigGenerator.bias_pre_training(
14011477
data_config, data_bias_config, methods="all"

0 commit comments

Comments
 (0)