Skip to content

Commit 57474fa

Browse files
committed
feat: added endpoint_name to clarify.ModelConfig
1 parent 6fc02f4 commit 57474fa

File tree

2 files changed

+126
-18
lines changed

2 files changed

+126
-18
lines changed

src/sagemaker/clarify.py

Lines changed: 49 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -276,26 +276,33 @@ class ModelConfig:
276276

277277
def __init__(
278278
self,
279-
model_name,
280-
instance_count,
281-
instance_type,
282-
accept_type=None,
283-
content_type=None,
284-
content_template=None,
285-
custom_attributes=None,
286-
accelerator_type=None,
287-
endpoint_name_prefix=None,
288-
target_model=None,
279+
model_name: str = None,
280+
instance_count: int = None,
281+
instance_type: str = None,
282+
accept_type: str = None,
283+
content_type: str = None,
284+
content_template: str = None,
285+
custom_attributes: str = None,
286+
accelerator_type: str = None,
287+
endpoint_name_prefix: str = None,
288+
target_model: str = None,
289+
endpoint_name: str = None,
289290
):
290291
r"""Initializes a configuration of a model and the endpoint to be created for it.
291292
292293
Args:
293294
model_name (str): Model name (as created by
294295
`CreateModel <https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateModel.html>`_.
296+
Cannot be set when ``endpoint_name`` is set.
297+
Must be set with ``instance_count``, ``instance_type``
295298
instance_count (int): The number of instances of a new endpoint for model inference.
299+
Cannot be set when ``endpoint_name`` is set.
300+
Must be set with ``model_name``, ``instance_type``
296301
instance_type (str): The type of
297302
`EC2 instance <https://aws.amazon.com/ec2/instance-types/>`_
298303
to use for model inference; for example, ``"ml.c5.xlarge"``.
304+
Cannot be set when ``endpoint_name`` is set.
305+
Must be set with ``instance_count``, ``model_name``
299306
accept_type (str): The model output format to be used for getting inferences with the
300307
shadow endpoint. Valid values are ``"text/csv"`` for CSV and
301308
``"application/jsonlines"``. Default is the same as ``content_type``.
@@ -325,17 +332,41 @@ def __init__(
325332
target_model (str): Sets the target model name when using a multi-model endpoint. For
326333
more information about multi-model endpoints, see
327334
https://docs.aws.amazon.com/sagemaker/latest/dg/multi-model-endpoints.html
335+
endpoint_name (str): Sets the endpoint name when re-using an existing endpoint.
336+
Cannot be set when ``model_name``, ``instance_count``,
337+
and ``instance_type`` are set.
328338
329339
Raises:
330-
ValueError: when the ``endpoint_name_prefix`` is invalid, ``accept_type`` is invalid,
331-
``content_type`` is invalid, or ``content_template`` has no placeholder "features"
340+
ValueError: when the
341+
- ``endpoint_name_prefix`` is invalid,
342+
- ``accept_type`` is invalid,
343+
- ``content_type`` is invalid,
344+
- ``content_template`` has no placeholder "features"
345+
- both [``endpoint_name``]
346+
AND [``model_name``, ``instance_count``, ``instance_type``] are set
347+
- both [``endpoint_name``] AND [``endpoint_name_prefix``] are set
332348
"""
333-
self.predictor_config = {
334-
"model_name": model_name,
335-
"instance_type": instance_type,
336-
"initial_instance_count": instance_count,
337-
}
338-
if endpoint_name_prefix is not None:
349+
350+
# validation
351+
_model_endpoint_config_rule = (
352+
all([model_name, instance_count, instance_type]),
353+
all([endpoint_name]),
354+
)
355+
assert any(_model_endpoint_config_rule) and not all(_model_endpoint_config_rule)
356+
if endpoint_name:
357+
assert not endpoint_name_prefix
358+
359+
# main init logic
360+
self.predictor_config = (
361+
{
362+
"model_name": model_name,
363+
"instance_type": instance_type,
364+
"initial_instance_count": instance_count,
365+
}
366+
if not endpoint_name
367+
else {"endpoint_name": endpoint_name}
368+
)
369+
if endpoint_name_prefix:
339370
if re.search("^[a-zA-Z0-9](-*[a-zA-Z0-9])", endpoint_name_prefix) is None:
340371
raise ValueError(
341372
"Invalid endpoint_name_prefix."

tests/unit/test_clarify.py

Lines changed: 77 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -739,6 +739,42 @@ def pdp_config():
739739
return PDPConfig(features=["F1", "F2"], grid_resolution=20)
740740

741741

742+
def test_model_config_validations():
    """Exercise ModelConfig's new-model vs. existing-endpoint argument validation."""
    new_model_kwargs = dict(
        model_name="xgboost-model",
        instance_type="ml.c5.xlarge",
        instance_count=1,
    )
    existing_endpoint_kwargs = dict(endpoint_name="existing_endpoint")

    # Each combination mixes arguments that may not be given together,
    # so construction is expected to fail with an AssertionError.
    conflicting_kwargs = (
        {**new_model_kwargs, **existing_endpoint_kwargs},
        {"endpoint_name_prefix": "prefix", **existing_endpoint_kwargs},
    )
    for kwargs in conflicting_kwargs:
        with pytest.raises(AssertionError):
            ModelConfig(**kwargs)

    # Success path: a new model yields the model/instance predictor config.
    new_model_config = ModelConfig(**new_model_kwargs)
    assert new_model_config.predictor_config == {
        "initial_instance_count": 1,
        "instance_type": "ml.c5.xlarge",
        "model_name": "xgboost-model",
    }

    # Success path: an existing endpoint yields only the endpoint name.
    existing_endpoint_config = ModelConfig(**existing_endpoint_kwargs)
    assert existing_endpoint_config.predictor_config == existing_endpoint_kwargs
776+
777+
742778
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
743779
def test_pre_training_bias(
744780
name_from_base,
@@ -1406,6 +1442,47 @@ def test_analysis_config_generator_for_bias_explainability(
14061442
assert actual == expected
14071443

14081444

1445+
def test_analysis_config_generator_for_bias_explainability_with_existing_endpoint(
    data_config, data_bias_config
):
    """Check the bias + explainability analysis config built for an existing endpoint."""
    existing_endpoint_config = ModelConfig(endpoint_name="existing_endpoint_name")
    predicted_label_config = ModelPredictedLabelConfig(
        probability="pr",
        label_headers=["success"],
    )
    explainability_configs = [SHAPConfig(), PDPConfig()]

    generated = _AnalysisConfigGenerator.bias_and_explainability(
        data_config,
        existing_endpoint_config,
        predicted_label_config,
        explainability_configs,
        data_bias_config,
        pre_training_methods="all",
        post_training_methods="all",
    )

    # The expected "predictor" section carries the endpoint name plus the
    # predicted-label settings, with no model/instance keys.
    expected_predictor = {
        "label_headers": ["success"],
        "endpoint_name": "existing_endpoint_name",
        "probability": "pr",
    }
    expected_methods = {
        "pdp": {"grid_resolution": 15, "top_k_features": 10},
        "post_training_bias": {"methods": "all"},
        "pre_training_bias": {"methods": "all"},
        "report": {"name": "report", "title": "Analysis Report"},
        "shap": {"save_local_shap_values": True, "use_logit": False},
    }
    wanted = {
        "dataset_type": "text/csv",
        "facet": [{"name_or_index": "F1"}],
        "group_variable": "F2",
        "headers": ["Label", "F1", "F2", "F3", "F4"],
        "joinsource_name_or_index": "F4",
        "label": "Label",
        "label_values_or_threshold": [1],
        "methods": expected_methods,
        "predictor": expected_predictor,
    }
    assert generated == wanted
1484+
1485+
14091486
def test_analysis_config_generator_for_bias_pre_training(data_config, data_bias_config):
14101487
actual = _AnalysisConfigGenerator.bias_pre_training(
14111488
data_config, data_bias_config, methods="all"

0 commit comments

Comments
 (0)