
change: Add label_headers option for Clarify ModelExplainabilityMonitor #2707


Merged · 14 commits · Jan 7, 2022
14 changes: 9 additions & 5 deletions src/sagemaker/clarify.py
@@ -290,11 +290,15 @@ def __init__(
probability_threshold (float): An optional value for binary prediction tasks in which
the model returns a probability, to indicate the threshold to convert the
prediction to a boolean value. Default is 0.5.
-            label_headers (list): List of label values - one for each score of the ``probability``.
+            label_headers (list[str]): List of headers, each for a predicted score in model output.
+                For bias analysis, it is used to extract the label value with the highest score as
+                the predicted label. For an explainability job, it is used to beautify the analysis
+                report by replacing placeholders like "label0".
"""
        self.label = label
        self.probability = probability
        self.probability_threshold = probability_threshold
+       self.label_headers = label_headers
        if probability_threshold is not None:
            try:
                float(probability_threshold)
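For context, a minimal sketch of how the new argument might be passed; the score index, threshold, and header name below are hypothetical, not part of this PR:

```python
from sagemaker.clarify import ModelPredictedLabelConfig

# The model returns one probability per class; label_headers names each score,
# so reports can say "Decision" instead of the auto-generated "label0".
predicted_label_config = ModelPredictedLabelConfig(
    probability=0,              # index of the score in the model output (hypothetical)
    probability_threshold=0.8,  # binary task: scores above this count as positive
    label_headers=["Decision"],
)
```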
@@ -1060,10 +1064,10 @@ def run_explainability(
explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig` or list):
Config of the specific explainability method or a list of ExplainabilityConfig
objects. Currently, SHAP and PDP are the two methods supported.
-            model_scores(str|int|ModelPredictedLabelConfig): Index or JSONPath location in the
-                model output for the predicted scores to be explained. This is not required if the
-                model output is a single score. Alternatively, an instance of
-                ModelPredictedLabelConfig can be provided.
+            model_scores (int or str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
+                Index or JSONPath to locate the predicted scores in the model output. This is not
+                required if the model output is a single score. Alternatively, it can be an instance
+                of ModelPredictedLabelConfig to provide more parameters like label_headers.
wait (bool): Whether the call should wait until the job completes (default: True).
logs (bool): Whether to show the logs produced by the job.
Only meaningful when ``wait`` is True (default: True).
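Taken together, the changes to this file let run_explainability accept the richer config. A sketch under assumed names (bucket, role ARN, and model name are hypothetical; the signatures follow the docstrings above):

```python
from sagemaker import Session
from sagemaker.clarify import (
    DataConfig,
    ModelConfig,
    ModelPredictedLabelConfig,
    SageMakerClarifyProcessor,
    SHAPConfig,
)

session = Session()
processor = SageMakerClarifyProcessor(
    role="arn:aws:iam::111122223333:role/ClarifyRole",  # hypothetical role
    instance_count=1,
    instance_type="ml.m5.xlarge",
    sagemaker_session=session,
)
data_config = DataConfig(
    s3_data_input_path="s3://my-bucket/clarify/input",  # hypothetical paths
    s3_output_path="s3://my-bucket/clarify/output",
    headers=["F1", "F2", "F3", "Label"],
    label="Label",
    dataset_type="text/csv",
)
model_config = ModelConfig(
    model_name="my-model",  # hypothetical model behind the endpoint
    instance_count=1,
    instance_type="ml.m5.xlarge",
    accept_type="text/csv",
)
# Baseline has one value per feature (three features here).
shap_config = SHAPConfig(baseline=[[0.5, 0.5, 0.5]], num_samples=100, agg_method="mean_abs")

processor.run_explainability(
    data_config=data_config,
    model_config=model_config,
    explainability_config=shap_config,
    # New: a ModelPredictedLabelConfig instead of a bare index or JSONPath,
    # so the report shows "Decision" rather than "label0".
    model_scores=ModelPredictedLabelConfig(probability=0, label_headers=["Decision"]),
)
```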
31 changes: 24 additions & 7 deletions src/sagemaker/model_monitor/clarify_model_monitoring.py
@@ -26,7 +26,7 @@
from sagemaker import image_uris, s3
from sagemaker.session import Session
from sagemaker.utils import name_from_base
-from sagemaker.clarify import SageMakerClarifyProcessor
+from sagemaker.clarify import SageMakerClarifyProcessor, ModelPredictedLabelConfig

_LOGGER = logging.getLogger(__name__)

@@ -833,9 +833,10 @@ def suggest_baseline(
specific explainability method. Currently, only SHAP is supported.
model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
endpoint to be created.
-            model_scores (int or str): Index or JSONPath location in the model output for the
-                predicted scores to be explained. This is not required if the model output is
-                a single score.
+            model_scores (int or str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
+                Index or JSONPath to locate the predicted scores in the model output. This is not
+                required if the model output is a single score. Alternatively, it can be an instance
+                of ModelPredictedLabelConfig to provide more parameters like label_headers.
wait (bool): Whether the call should wait until the job completes (default: False).
logs (bool): Whether to show the logs produced by the job.
Only meaningful when wait is True (default: False).
@@ -865,14 +866,24 @@ def suggest_baseline(
        headers = copy.deepcopy(data_config.headers)
        if headers and data_config.label in headers:
            headers.remove(data_config.label)
+       if model_scores is None:
+           inference_attribute = None
+           label_headers = None
+       elif isinstance(model_scores, ModelPredictedLabelConfig):
+           inference_attribute = str(model_scores.label)
+           label_headers = model_scores.label_headers
+       else:
+           inference_attribute = str(model_scores)
+           label_headers = None
        self.latest_baselining_job_config = ClarifyBaseliningConfig(
            analysis_config=ExplainabilityAnalysisConfig(
                explainability_config=explainability_config,
                model_config=model_config,
                headers=headers,
+               label_headers=label_headers,
            ),
            features_attribute=data_config.features,
-           inference_attribute=model_scores if model_scores is None else str(model_scores),
+           inference_attribute=inference_attribute,
        )
        self.latest_baselining_job_name = baselining_job_name
        self.latest_baselining_job = ClarifyBaseliningJob(
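A sketch of how a caller reaches this branch, reusing the hypothetical session and configs from the earlier sketch; only the model_scores argument exercises the new behavior:

```python
from sagemaker.clarify import ModelPredictedLabelConfig
from sagemaker.model_monitor import ModelExplainabilityMonitor

monitor = ModelExplainabilityMonitor(
    role="arn:aws:iam::111122223333:role/MonitorRole",  # hypothetical role
    sagemaker_session=session,
)
monitor.suggest_baseline(
    data_config=data_config,
    explainability_config=shap_config,
    model_config=model_config,
    # label is the index/JSONPath of the predicted score (it becomes the
    # inference_attribute); label_headers flows into the analysis config.
    model_scores=ModelPredictedLabelConfig(label=0, label_headers=["Decision"]),
)
```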
@@ -1166,7 +1177,7 @@ def attach(cls, monitor_schedule_name, sagemaker_session=None):
class ExplainabilityAnalysisConfig:
"""Analysis configuration for ModelExplainabilityMonitor."""

-   def __init__(self, explainability_config, model_config, headers=None):
+   def __init__(self, explainability_config, model_config, headers=None, label_headers=None):
"""Creates an analysis config dictionary.

Args:
@@ -1175,13 +1186,19 @@ def __init__(self, explainability_config, model_config, headers=None):
model_config (sagemaker.clarify.ModelConfig): Config object related to bias
configurations.
            headers (list[str]): A list of feature names (without label) of model/endpoint input.
+           label_headers (list[str]): List of headers, each for a predicted score in model output.
+               It is used to beautify the analysis report by replacing placeholders like "label0".

"""
+       predictor_config = model_config.get_predictor_config()
        self.analysis_config = {
            "methods": explainability_config.get_explainability_config(),
-           "predictor": model_config.get_predictor_config(),
+           "predictor": predictor_config,
        }
        if headers is not None:
            self.analysis_config["headers"] = headers
+       if label_headers is not None:
+           predictor_config["label_headers"] = label_headers

def _to_dict(self):
"""Generates a request dictionary using the parameters provided to the class."""
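The effect on the generated analysis config can be checked with a quick sketch (shap_config and model_config as in the earlier sketches; header names hypothetical):

```python
from sagemaker.model_monitor.clarify_model_monitoring import ExplainabilityAnalysisConfig

analysis_config = ExplainabilityAnalysisConfig(
    explainability_config=shap_config,
    model_config=model_config,
    headers=["F1", "F2", "F3"],
    label_headers=["Decision"],
)
# label_headers is nested under "predictor", next to the model settings:
# {
#     "methods": {...},
#     "predictor": {"model_name": ..., "label_headers": ["Decision"], ...},
#     "headers": ["F1", "F2", "F3"],
# }
print(analysis_config._to_dict())
```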
3 changes: 2 additions & 1 deletion tests/integ/test_clarify_model_monitor.py
@@ -53,6 +53,7 @@
HEADER_OF_LABEL = "Label"
HEADERS_OF_FEATURES = ["F1", "F2", "F3", "F4", "F5", "F6", "F7"]
ALL_HEADERS = [*HEADERS_OF_FEATURES, HEADER_OF_LABEL]
+HEADER_OF_PREDICTION = "Decision"
DATASET_TYPE = "text/csv"
CONTENT_TYPE = DATASET_TYPE
ACCEPT_TYPE = DATASET_TYPE
@@ -325,7 +326,7 @@ def scheduled_explainability_monitor(
):
monitor_schedule_name = utils.unique_name_from_base("explainability-monitor")
analysis_config = ExplainabilityAnalysisConfig(
-       shap_config, model_config, headers=HEADERS_OF_FEATURES
+       shap_config, model_config, headers=HEADERS_OF_FEATURES, label_headers=[HEADER_OF_PREDICTION]
)
s3_uri_monitoring_output = os.path.join(
"s3://",
38 changes: 33 additions & 5 deletions tests/unit/sagemaker/monitor/test_clarify_model_monitor.py
@@ -279,6 +279,7 @@
# for bias
ANALYSIS_CONFIG_LABEL = "Label"
ANALYSIS_CONFIG_HEADERS_OF_FEATURES = ["F1", "F2", "F3"]
+ANALYSIS_CONFIG_LABEL_HEADERS = ["Decision"]
ANALYSIS_CONFIG_ALL_HEADERS = [*ANALYSIS_CONFIG_HEADERS_OF_FEATURES, ANALYSIS_CONFIG_LABEL]
ANALYSIS_CONFIG_LABEL_VALUES = [1]
ANALYSIS_CONFIG_FACET_NAME = "F1"
@@ -330,6 +331,11 @@
"content_type": CONTENT_TYPE,
},
}
+EXPLAINABILITY_ANALYSIS_CONFIG_WITH_LABEL_HEADERS = copy.deepcopy(EXPLAINABILITY_ANALYSIS_CONFIG)
+# noinspection PyTypeChecker
+EXPLAINABILITY_ANALYSIS_CONFIG_WITH_LABEL_HEADERS["predictor"][
+    "label_headers"
+] = ANALYSIS_CONFIG_LABEL_HEADERS


@pytest.fixture()
@@ -1048,25 +1054,44 @@ def test_explainability_analysis_config(shap_config, model_config):
        explainability_config=shap_config,
        model_config=model_config,
        headers=ANALYSIS_CONFIG_HEADERS_OF_FEATURES,
+       label_headers=ANALYSIS_CONFIG_LABEL_HEADERS,
    )
-   assert EXPLAINABILITY_ANALYSIS_CONFIG == config._to_dict()
+   assert EXPLAINABILITY_ANALYSIS_CONFIG_WITH_LABEL_HEADERS == config._to_dict()


+@pytest.mark.parametrize(
+    "model_scores,explainability_analysis_config",
+    [
+        (INFERENCE_ATTRIBUTE, EXPLAINABILITY_ANALYSIS_CONFIG),
+        (
+            ModelPredictedLabelConfig(
+                label=INFERENCE_ATTRIBUTE, label_headers=ANALYSIS_CONFIG_LABEL_HEADERS
+            ),
+            EXPLAINABILITY_ANALYSIS_CONFIG_WITH_LABEL_HEADERS,
+        ),
+    ],
+)
def test_model_explainability_monitor_suggest_baseline(
-   model_explainability_monitor, sagemaker_session, data_config, shap_config, model_config
+   model_explainability_monitor,
+   sagemaker_session,
+   data_config,
+   shap_config,
+   model_config,
+   model_scores,
+   explainability_analysis_config,
):
clarify_model_monitor = model_explainability_monitor
# suggest baseline
clarify_model_monitor.suggest_baseline(
data_config=data_config,
explainability_config=shap_config,
model_config=model_config,
-       model_scores=INFERENCE_ATTRIBUTE,
+       model_scores=model_scores,
job_name=BASELINING_JOB_NAME,
)
assert isinstance(clarify_model_monitor.latest_baselining_job, ClarifyBaseliningJob)
assert (
-       EXPLAINABILITY_ANALYSIS_CONFIG
+       explainability_analysis_config
== clarify_model_monitor.latest_baselining_job_config.analysis_config._to_dict()
)
clarify_baselining_job = clarify_model_monitor.latest_baselining_job
@@ -1081,6 +1106,7 @@ def test_model_explainability_monitor_suggest_baseline(
analysis_config=None, # will pick up config from baselining job
baseline_job_name=BASELINING_JOB_NAME,
endpoint_input=ENDPOINT_NAME,
+       explainability_analysis_config=explainability_analysis_config,
# will pick up attributes from baselining job
)

@@ -1133,6 +1159,7 @@ def test_model_explainability_monitor_created_with_config(
sagemaker_session=sagemaker_session,
analysis_config=analysis_config,
constraints=CONSTRAINTS,
+       explainability_analysis_config=EXPLAINABILITY_ANALYSIS_CONFIG,
)

# update schedule
@@ -1263,6 +1290,7 @@ def _test_model_explainability_monitor_create_schedule(
features_attribute=FEATURES_ATTRIBUTE,
inference_attribute=str(INFERENCE_ATTRIBUTE),
),
+   explainability_analysis_config=None,
):
# create schedule
with patch(
@@ -1278,7 +1306,7 @@ def _test_model_explainability_monitor_create_schedule(
)
if not isinstance(analysis_config, str):
upload.assert_called_once()
-       assert json.loads(upload.call_args[0][0]) == EXPLAINABILITY_ANALYSIS_CONFIG
+       assert json.loads(upload.call_args[0][0]) == explainability_analysis_config

# validation
expected_arguments = {