Skip to content

Commit 77ed288

Browse files
milah and ahsan-z-khan authored
change: Add configuration option with headers for Clarify Explainability (#2446)
Co-authored-by: Ahsan Khan <[email protected]>
1 parent 4c0d3cf commit 77ed288

File tree

2 files changed

+76
-12
lines changed

2 files changed

+76
-12
lines changed

src/sagemaker/clarify.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -723,8 +723,10 @@ def run_explainability(
723723
endpoint to be created.
724724
explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig`): Config of the
725725
specific explainability method. Currently, only SHAP is supported.
726-
model_scores: Index or JSONPath location in the model output for the predicted scores
727-
to be explained. This is not required if the model output is a single score.
726+
model_scores(str|int|ModelPredictedLabelConfig): Index or JSONPath location in the
727+
model output for the predicted scores to be explained. This is not required if the
728+
model output is a single score. Alternatively, an instance of
729+
ModelPredictedLabelConfig can be provided.
728730
wait (bool): Whether the call should wait until the job completes (default: True).
729731
logs (bool): Whether to show the logs produced by the job.
730732
Only meaningful when ``wait`` is True (default: True).
@@ -740,7 +742,12 @@ def run_explainability(
740742
"""
741743
analysis_config = data_config.get_config()
742744
predictor_config = model_config.get_predictor_config()
743-
_set(model_scores, "label", predictor_config)
745+
if isinstance(model_scores, ModelPredictedLabelConfig):
746+
probability_threshold, predicted_label_config = model_scores.get_predictor_config()
747+
_set(probability_threshold, "probability_threshold", analysis_config)
748+
predictor_config.update(predicted_label_config)
749+
else:
750+
_set(model_scores, "label", predictor_config)
744751
analysis_config["methods"] = explainability_config.get_explainability_config()
745752
analysis_config["predictor"] = predictor_config
746753
if job_name is None:

tests/unit/test_clarify.py

Lines changed: 66 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -442,21 +442,22 @@ def test_post_training_bias(
442442
)
443443

444444

445-
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
446-
def test_shap(
445+
def _run_test_shap(
447446
name_from_base,
448447
clarify_processor,
449448
clarify_processor_with_job_name_prefix,
450449
data_config,
451450
model_config,
452451
shap_config,
452+
model_scores,
453+
expected_predictor_config,
453454
):
454455
with patch.object(SageMakerClarifyProcessor, "_run", return_value=None) as mock_method:
455456
clarify_processor.run_explainability(
456457
data_config,
457458
model_config,
458459
shap_config,
459-
model_scores=None,
460+
model_scores=model_scores,
460461
wait=True,
461462
job_name="test",
462463
experiment_config={"ExperimentName": "AnExperiment"},
@@ -485,11 +486,7 @@ def test_shap(
485486
"save_local_shap_values": True,
486487
}
487488
},
488-
"predictor": {
489-
"model_name": "xgboost-model",
490-
"instance_type": "ml.c5.xlarge",
491-
"initial_instance_count": 1,
492-
},
489+
"predictor": expected_predictor_config,
493490
}
494491
mock_method.assert_called_with(
495492
data_config,
@@ -504,7 +501,7 @@ def test_shap(
504501
data_config,
505502
model_config,
506503
shap_config,
507-
model_scores=None,
504+
model_scores=model_scores,
508505
wait=True,
509506
experiment_config={"ExperimentName": "AnExperiment"},
510507
)
@@ -518,3 +515,63 @@ def test_shap(
518515
None,
519516
{"ExperimentName": "AnExperiment"},
520517
)
518+
519+
520+
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
521+
def test_shap(
522+
name_from_base,
523+
clarify_processor,
524+
clarify_processor_with_job_name_prefix,
525+
data_config,
526+
model_config,
527+
shap_config,
528+
):
529+
expected_predictor_config = {
530+
"model_name": "xgboost-model",
531+
"instance_type": "ml.c5.xlarge",
532+
"initial_instance_count": 1,
533+
}
534+
_run_test_shap(
535+
name_from_base,
536+
clarify_processor,
537+
clarify_processor_with_job_name_prefix,
538+
data_config,
539+
model_config,
540+
shap_config,
541+
None,
542+
expected_predictor_config,
543+
)
544+
545+
546+
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
547+
def test_shap_with_predicted_label(
548+
name_from_base,
549+
clarify_processor,
550+
clarify_processor_with_job_name_prefix,
551+
data_config,
552+
model_config,
553+
shap_config,
554+
):
555+
probability = "pr"
556+
label_headers = ["success"]
557+
model_scores = ModelPredictedLabelConfig(
558+
probability=probability,
559+
label_headers=label_headers,
560+
)
561+
expected_predictor_config = {
562+
"model_name": "xgboost-model",
563+
"instance_type": "ml.c5.xlarge",
564+
"initial_instance_count": 1,
565+
"probability": probability,
566+
"label_headers": label_headers,
567+
}
568+
_run_test_shap(
569+
name_from_base,
570+
clarify_processor,
571+
clarify_processor_with_job_name_prefix,
572+
data_config,
573+
model_config,
574+
shap_config,
575+
model_scores,
576+
expected_predictor_config,
577+
)

0 commit comments

Comments (0)