Skip to content

Commit ecd1e05

Browse files
committed
feat: added endpoint_name to clarify.ModelConfig
1 parent 5dcaca3 commit ecd1e05

File tree

2 files changed

+99
-9
lines changed

2 files changed

+99
-9
lines changed

src/sagemaker/clarify.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -272,16 +272,17 @@ class ModelConfig:
272272

273273
def __init__(
274274
self,
275-
model_name,
276-
instance_count,
277-
instance_type,
275+
model_name=None,
276+
instance_count=None,
277+
instance_type=None,
278278
accept_type=None,
279279
content_type=None,
280280
content_template=None,
281281
custom_attributes=None,
282282
accelerator_type=None,
283283
endpoint_name_prefix=None,
284284
target_model=None,
285+
endpoint_name=None,
285286
):
286287
r"""Initializes a configuration of a model and the endpoint to be created for it.
287288
@@ -326,12 +327,25 @@ def __init__(
326327
ValueError: when the ``endpoint_name_prefix`` is invalid, ``accept_type`` is invalid,
327328
``content_type`` is invalid, or ``content_template`` has no placeholder "features"
328329
"""
329-
self.predictor_config = {
330-
"model_name": model_name,
331-
"instance_type": instance_type,
332-
"initial_instance_count": instance_count,
333-
}
334-
if endpoint_name_prefix is not None:
330+
331+
_model_endpoint_config_rule = (
332+
all([model_name, instance_count, instance_type]),
333+
all([endpoint_name]),
334+
)
335+
assert any(_model_endpoint_config_rule) and not all(_model_endpoint_config_rule)
336+
if endpoint_name:
337+
assert not endpoint_name_prefix
338+
339+
self.predictor_config = (
340+
{
341+
"model_name": model_name,
342+
"instance_type": instance_type,
343+
"initial_instance_count": instance_count,
344+
}
345+
if not endpoint_name
346+
else {"endpoint_name": endpoint_name}
347+
)
348+
if endpoint_name_prefix:
335349
if re.search("^[a-zA-Z0-9](-*[a-zA-Z0-9])", endpoint_name_prefix) is None:
336350
raise ValueError(
337351
"Invalid endpoint_name_prefix."

tests/unit/test_clarify.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,41 @@ def pdp_config():
735735
return PDPConfig(features=["F1", "F2"], grid_resolution=20)
736736

737737

738+
def test_model_config_validations():
739+
new_model_endpoint_definition = {
740+
"model_name": "xgboost-model",
741+
"instance_type": "ml.c5.xlarge",
742+
"instance_count": 1,
743+
}
744+
existing_endpoint_definition = {"endpoint_name": "existing_endpoint"}
745+
746+
with pytest.raises(AssertionError):
747+
# should be one of them
748+
ModelConfig(
749+
**new_model_endpoint_definition,
750+
**existing_endpoint_definition,
751+
)
752+
753+
with pytest.raises(AssertionError):
754+
# should be one of them
755+
ModelConfig(
756+
endpoint_name_prefix="prefix",
757+
**existing_endpoint_definition,
758+
)
759+
760+
# success path for new model
761+
assert ModelConfig(**new_model_endpoint_definition).predictor_config == {
762+
"initial_instance_count": 1,
763+
"instance_type": "ml.c5.xlarge",
764+
"model_name": "xgboost-model",
765+
}
766+
767+
# success path for existing endpoint
768+
assert (
769+
ModelConfig(**existing_endpoint_definition).predictor_config == existing_endpoint_definition
770+
)
771+
772+
738773
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
739774
def test_pre_training_bias(
740775
name_from_base,
@@ -1396,6 +1431,47 @@ def test_analysis_config_generator_for_bias_explainability(
13961431
assert actual == expected
13971432

13981433

1434+
def test_analysis_config_generator_for_bias_explainability_with_existing_endpoint(
1435+
data_config, data_bias_config
1436+
):
1437+
model_config = ModelConfig(endpoint_name="existing_endpoint_name")
1438+
model_predicted_label_config = ModelPredictedLabelConfig(
1439+
probability="pr",
1440+
label_headers=["success"],
1441+
)
1442+
actual = _AnalysisConfigGenerator.bias_and_explainability(
1443+
data_config,
1444+
model_config,
1445+
model_predicted_label_config,
1446+
[SHAPConfig(), PDPConfig()],
1447+
data_bias_config,
1448+
pre_training_methods="all",
1449+
post_training_methods="all",
1450+
)
1451+
expected = {
1452+
"dataset_type": "text/csv",
1453+
"facet": [{"name_or_index": "F1"}],
1454+
"group_variable": "F2",
1455+
"headers": ["Label", "F1", "F2", "F3", "F4"],
1456+
"joinsource_name_or_index": "F4",
1457+
"label": "Label",
1458+
"label_values_or_threshold": [1],
1459+
"methods": {
1460+
"pdp": {"grid_resolution": 15, "top_k_features": 10},
1461+
"post_training_bias": {"methods": "all"},
1462+
"pre_training_bias": {"methods": "all"},
1463+
"report": {"name": "report", "title": "Analysis Report"},
1464+
"shap": {"save_local_shap_values": True, "use_logit": False},
1465+
},
1466+
"predictor": {
1467+
"label_headers": ["success"],
1468+
"endpoint_name": "existing_endpoint_name",
1469+
"probability": "pr",
1470+
},
1471+
}
1472+
assert actual == expected
1473+
1474+
13991475
def test_analysis_config_generator_for_bias_pre_training(data_config, data_bias_config):
14001476
actual = _AnalysisConfigGenerator.bias_pre_training(
14011477
data_config, data_bias_config, methods="all"

0 commit comments

Comments
 (0)