Skip to content

Commit 8573440

Browse files
prkrishnan1keerthanvasistPranav Krishnan
authored andcommitted
feature: add CV shap explainability for SageMaker Clarify
Enabled CV explainability for SHAP in SageMaker Clarify. - Since CV explainability parameters are part of SHAP parameters, SHAPConfig includes a new parameter 'image_config' which can be set to a non-None value. Default is None. - To handle image config parameters, this change created a new class ImageConfig which accepts the following parameters: model_type, num_segments, feature_extraction_method, segment_compactness, max_objects, iou_threshold, context - To enable image data to be used with Clarify, added a new accepted 'dataset_type' 'application/x-image' to the list of valid dataset_types in DataConfig Co-authored-by: keerthanvasist <[email protected]> Co-authored-by: Pranav Krishnan <[email protected]>
1 parent 7620749 commit 8573440

File tree

2 files changed

+226
-32
lines changed

2 files changed

+226
-32
lines changed

src/sagemaker/clarify.py

Lines changed: 85 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818
import logging
1919
import os
2020
import re
21+
2122
import tempfile
2223
from abc import ABC, abstractmethod
23-
2424
from sagemaker import image_uris, s3, utils
2525
from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
2626

@@ -67,7 +67,12 @@ def __init__(
6767
optional field in all cases except when the dataset contains more than one file,
6868
and `save_local_shap_values` is set to true in SHAPConfig.
6969
"""
70-
if dataset_type not in ["text/csv", "application/jsonlines", "application/x-parquet"]:
70+
if dataset_type not in [
71+
"text/csv",
72+
"application/jsonlines",
73+
"application/x-parquet",
74+
"application/x-image",
75+
]:
7176
raise ValueError(
7277
f"Invalid dataset_type '{dataset_type}'."
7378
f" Please check the API documentation for the supported dataset types."
@@ -212,7 +217,14 @@ def __init__(
212217
)
213218
self.predictor_config["accept_type"] = accept_type
214219
if content_type is not None:
215-
if content_type not in ["text/csv", "application/jsonlines"]:
220+
if content_type not in [
221+
"text/csv",
222+
"application/jsonlines",
223+
"image/jpeg",
224+
"image/jpg",
225+
"image/png",
226+
"application/x-npy",
227+
]:
216228
raise ValueError(
217229
f"Invalid content_type {content_type}."
218230
f" Please choose text/csv or application/jsonlines."
@@ -456,6 +468,65 @@ def get_text_config(self):
456468
return copy.deepcopy(self.text_config)
457469

458470

471+
class ImageConfig:
472+
"""Config object for handling images"""
473+
474+
def __init__(
475+
self,
476+
model_type,
477+
num_segments=None,
478+
feature_extraction_method=None,
479+
segment_compactness=None,
480+
max_objects=None,
481+
iou_threshold=None,
482+
context=None,
483+
):
484+
"""Initializes all configuration parameters needed for SHAP CV explainability
485+
486+
Args:
487+
model_type (str): Specifies the type of CV model. Options:
488+
(IMAGE_CLASSIFICATION | OBJECT_DETECTION).
489+
num_segments (None or int): Clarify uses SKLearn's SLIC method for image segmentation
490+
to generate features/superpixels. num_segments specifies approximate
491+
number of segments to be generated. Default is None. SLIC will default to
492+
100 segments.
493+
feature_extraction_method (None or str): method used for extracting features from the
494+
image.ex. "segmentation". Default is segmentation.
495+
segment_compactness (None or float): Balances color proximity and space proximity.
496+
Higher values give more weight to space proximity, making superpixel
497+
shapes more square/cubic. We recommend exploring possible values on a log
498+
scale, e.g., 0.01, 0.1, 1, 10, 100, before refining around a chosen value.
499+
max_objects (None or int): maximum number of objects displayed. Object detection
500+
algorithm may detect more than max_objects number of objects in a single
501+
image. The top max_objects number of objects according to confidence score
502+
will be displayed.
503+
iou_threshold (None or float): minimum intersection over union for the object
504+
bounding box to consider its confidence score for computing SHAP values [0.0, 1.0].
505+
This parameter is used for the object detection case.
506+
context (None or float): refers to the portion of the image outside of the bounding box.
507+
Scale is [0.0, 1.0]. If set to 1.0, whole image is considered, if set to
508+
0.0 only the image inside bounding box is considered.
509+
"""
510+
self.image_config = {}
511+
512+
if model_type not in ["OBJECT_DETECTION", "IMAGE_CLASSIFICATION"]:
513+
raise ValueError(
514+
"Clarify SHAP only supports object detection and image classification methods. "
515+
"Please set model_type to OBJECT_DETECTION or IMAGE_CLASSIFICATION."
516+
)
517+
self.image_config["model_type"] = model_type
518+
_set(num_segments, "num_segments", self.image_config)
519+
_set(feature_extraction_method, "feature_extraction_method", self.image_config)
520+
_set(segment_compactness, "segment_compactness", self.image_config)
521+
_set(max_objects, "max_objects", self.image_config)
522+
_set(iou_threshold, "iou_threshold", self.image_config)
523+
_set(context, "context", self.image_config)
524+
525+
def get_image_config(self):
526+
"""Returns the image config part of an analysis config dictionary."""
527+
return copy.deepcopy(self.image_config)
528+
529+
459530
class SHAPConfig(ExplainabilityConfig):
460531
"""Config class of SHAP."""
461532

@@ -469,6 +540,7 @@ def __init__(
469540
seed=None,
470541
num_clusters=None,
471542
text_config=None,
543+
image_config=None,
472544
):
473545
"""Initializes config for SHAP.
474546
@@ -497,7 +569,10 @@ def __init__(
497569
computes a baseline dataset via a clustering algorithm (K-means/K-prototypes).
498570
num_clusters is a parameter for this algorithm. num_clusters will be the resulting
499571
size of the baseline dataset. If not provided, Clarify job will use a default value.
500-
text_config (:class:`~sagemaker.clarify.TextConfig`): Config to handle text features
572+
text_config (:class:`~sagemaker.clarify.TextConfig`): Config to handle text features.
573+
Default is None
574+
image_config (:class:`~sagemaker.clarify.ImageConfig`): Config to handle image features.
575+
Default is None
501576
"""
502577
if agg_method is not None and agg_method not in ["mean_abs", "median", "mean_sq"]:
503578
raise ValueError(
@@ -512,17 +587,11 @@ def __init__(
512587
"use_logit": use_logit,
513588
"save_local_shap_values": save_local_shap_values,
514589
}
515-
if baseline is not None:
516-
self.shap_config["baseline"] = baseline
517-
if num_samples is not None:
518-
self.shap_config["num_samples"] = num_samples
519-
if agg_method is not None:
520-
self.shap_config["agg_method"] = agg_method
521-
if seed is not None:
522-
self.shap_config["seed"] = seed
523-
if num_clusters is not None:
524-
self.shap_config["num_clusters"] = num_clusters
590+
_set(baseline, "baseline", self.shap_config)
591+
_set(num_samples, "num_samples", self.shap_config)
592+
_set(agg_method, "agg_method", self.shap_config)
525593
_set(seed, "seed", self.shap_config)
594+
_set(num_clusters, "num_clusters", self.shap_config)
526595
if text_config:
527596
_set(text_config.get_text_config(), "text_config", self.shap_config)
528597
if not save_local_shap_values:
@@ -531,6 +600,8 @@ def __init__(
531600
"Consider setting save_local_shap_values=True to inspect local text "
532601
"explanations."
533602
)
603+
if image_config:
604+
_set(image_config.get_image_config(), "image_config", self.shap_config)
534605

535606
def get_explainability_config(self):
536607
"""Returns config."""

tests/unit/test_clarify.py

Lines changed: 141 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
SageMakerClarifyProcessor,
2929
SHAPConfig,
3030
TextConfig,
31+
ImageConfig,
3132
)
3233

3334
JOB_NAME_PREFIX = "my-prefix"
@@ -254,17 +255,34 @@ def test_shap_config():
254255
seed = 123
255256
granularity = "sentence"
256257
language = "german"
258+
model_type = "IMAGE_CLASSIFICATION"
259+
num_segments = 2
260+
feature_extraction_method = "segmentation"
261+
segment_compactness = 10
262+
max_objects = 4
263+
iou_threshold = 0.5
264+
context = 1.0
257265
text_config = TextConfig(
258266
granularity=granularity,
259267
language=language,
260268
)
269+
image_config = ImageConfig(
270+
model_type=model_type,
271+
num_segments=num_segments,
272+
feature_extraction_method=feature_extraction_method,
273+
segment_compactness=segment_compactness,
274+
max_objects=max_objects,
275+
iou_threshold=iou_threshold,
276+
context=context,
277+
)
261278
shap_config = SHAPConfig(
262279
baseline=baseline,
263280
num_samples=num_samples,
264281
agg_method=agg_method,
265282
use_logit=use_logit,
266283
seed=seed,
267284
text_config=text_config,
285+
image_config=image_config,
268286
)
269287
expected_config = {
270288
"shap": {
@@ -278,6 +296,15 @@ def test_shap_config():
278296
"granularity": granularity,
279297
"language": language,
280298
},
299+
"image_config": {
300+
"model_type": model_type,
301+
"num_segments": num_segments,
302+
"feature_extraction_method": feature_extraction_method,
303+
"segment_compactness": segment_compactness,
304+
"max_objects": max_objects,
305+
"iou_threshold": iou_threshold,
306+
"context": context,
307+
},
281308
}
282309
}
283310
assert expected_config == shap_config.get_explainability_config()
@@ -359,6 +386,50 @@ def test_invalid_text_config():
359386
assert "Invalid language invalid. Please choose among ['chinese'," in str(error.value)
360387

361388

389+
def test_image_config():
390+
model_type = "IMAGE_CLASSIFICATION"
391+
num_segments = 2
392+
feature_extraction_method = "segmentation"
393+
segment_compactness = 10
394+
max_objects = 4
395+
iou_threshold = 0.5
396+
context = 1.0
397+
image_config = ImageConfig(
398+
model_type=model_type,
399+
num_segments=num_segments,
400+
feature_extraction_method=feature_extraction_method,
401+
segment_compactness=segment_compactness,
402+
max_objects=max_objects,
403+
iou_threshold=iou_threshold,
404+
context=context,
405+
)
406+
expected_config = {
407+
"model_type": model_type,
408+
"num_segments": num_segments,
409+
"feature_extraction_method": feature_extraction_method,
410+
"segment_compactness": segment_compactness,
411+
"max_objects": max_objects,
412+
"iou_threshold": iou_threshold,
413+
"context": context,
414+
}
415+
416+
assert expected_config == image_config.get_image_config()
417+
418+
419+
def test_invalid_image_config():
420+
model_type = "OBJECT_SEGMENTATION"
421+
num_segments = 2
422+
with pytest.raises(ValueError) as error:
423+
ImageConfig(
424+
model_type=model_type,
425+
num_segments=num_segments,
426+
)
427+
assert (
428+
"Clarify SHAP only supports object detection and image classification methods. "
429+
"Please set model_type to OBJECT_DETECTION or IMAGE_CLASSIFICATION." in str(error.value)
430+
)
431+
432+
362433
def test_invalid_shap_config():
363434
with pytest.raises(ValueError) as error:
364435
SHAPConfig(
@@ -665,6 +736,7 @@ def _run_test_explain(
665736
model_scores,
666737
expected_predictor_config,
667738
expected_text_config=None,
739+
expected_image_config=None,
668740
):
669741
with patch.object(SageMakerClarifyProcessor, "_run", return_value=None) as mock_method:
670742
explanation_configs = None
@@ -684,21 +756,6 @@ def _run_test_explain(
684756
job_name="test",
685757
experiment_config={"ExperimentName": "AnExperiment"},
686758
)
687-
expected_shap_config = {
688-
"baseline": [
689-
[
690-
0.26124998927116394,
691-
0.2824999988079071,
692-
0.06875000149011612,
693-
]
694-
],
695-
"num_samples": 100,
696-
"agg_method": "mean_sq",
697-
"use_logit": False,
698-
"save_local_shap_values": True,
699-
}
700-
if expected_text_config:
701-
expected_shap_config["text_config"] = expected_text_config
702759
expected_analysis_config = {
703760
"dataset_type": "text/csv",
704761
"headers": [
@@ -710,9 +767,6 @@ def _run_test_explain(
710767
],
711768
"label": "Label",
712769
"joinsource_name_or_index": "F4",
713-
"methods": {
714-
"shap": expected_shap_config,
715-
},
716770
"predictor": expected_predictor_config,
717771
}
718772
expected_explanation_configs = {}
@@ -732,6 +786,8 @@ def _run_test_explain(
732786
}
733787
if expected_text_config:
734788
expected_explanation_configs["shap"]["text_config"] = expected_text_config
789+
if expected_image_config:
790+
expected_explanation_configs["shap"]["image_config"] = expected_image_config
735791
if pdp_config:
736792
expected_explanation_configs["pdp"] = {
737793
"features": ["F1", "F2"],
@@ -963,3 +1019,70 @@ def test_shap_with_text_config(
9631019
expected_predictor_config,
9641020
expected_text_config=expected_text_config,
9651021
)
1022+
1023+
1024+
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
1025+
def test_shap_with_image_config(
1026+
name_from_base,
1027+
clarify_processor,
1028+
clarify_processor_with_job_name_prefix,
1029+
data_config,
1030+
model_config,
1031+
):
1032+
model_type = "IMAGE_CLASSIFICATION"
1033+
num_segments = 2
1034+
feature_extraction_method = "segmentation"
1035+
segment_compactness = 10
1036+
max_objects = 4
1037+
iou_threshold = 0.5
1038+
context = 1.0
1039+
image_config = ImageConfig(
1040+
model_type=model_type,
1041+
num_segments=num_segments,
1042+
feature_extraction_method=feature_extraction_method,
1043+
segment_compactness=segment_compactness,
1044+
max_objects=max_objects,
1045+
iou_threshold=iou_threshold,
1046+
context=context,
1047+
)
1048+
1049+
shap_config = SHAPConfig(
1050+
baseline=[
1051+
[
1052+
0.26124998927116394,
1053+
0.2824999988079071,
1054+
0.06875000149011612,
1055+
]
1056+
],
1057+
num_samples=100,
1058+
agg_method="mean_sq",
1059+
image_config=image_config,
1060+
)
1061+
1062+
expected_image_config = {
1063+
"model_type": model_type,
1064+
"num_segments": num_segments,
1065+
"feature_extraction_method": feature_extraction_method,
1066+
"segment_compactness": segment_compactness,
1067+
"max_objects": max_objects,
1068+
"iou_threshold": iou_threshold,
1069+
"context": context,
1070+
}
1071+
expected_predictor_config = {
1072+
"model_name": "xgboost-model",
1073+
"instance_type": "ml.c5.xlarge",
1074+
"initial_instance_count": 1,
1075+
}
1076+
1077+
_run_test_explain(
1078+
name_from_base,
1079+
clarify_processor,
1080+
clarify_processor_with_job_name_prefix,
1081+
data_config,
1082+
model_config,
1083+
shap_config,
1084+
None,
1085+
None,
1086+
expected_predictor_config,
1087+
expected_image_config=expected_image_config,
1088+
)

0 commit comments

Comments
 (0)