28
28
SageMakerClarifyProcessor ,
29
29
SHAPConfig ,
30
30
TextConfig ,
31
+ ImageConfig ,
31
32
)
32
33
33
34
JOB_NAME_PREFIX = "my-prefix"
@@ -254,17 +255,34 @@ def test_shap_config():
254
255
seed = 123
255
256
granularity = "sentence"
256
257
language = "german"
258
+ model_type = "IMAGE_CLASSIFICATION"
259
+ num_segments = 2
260
+ feature_extraction_method = "segmentation"
261
+ segment_compactness = 10
262
+ max_objects = 4
263
+ iou_threshold = 0.5
264
+ context = 1.0
257
265
text_config = TextConfig (
258
266
granularity = granularity ,
259
267
language = language ,
260
268
)
269
+ image_config = ImageConfig (
270
+ model_type = model_type ,
271
+ num_segments = num_segments ,
272
+ feature_extraction_method = feature_extraction_method ,
273
+ segment_compactness = segment_compactness ,
274
+ max_objects = max_objects ,
275
+ iou_threshold = iou_threshold ,
276
+ context = context ,
277
+ )
261
278
shap_config = SHAPConfig (
262
279
baseline = baseline ,
263
280
num_samples = num_samples ,
264
281
agg_method = agg_method ,
265
282
use_logit = use_logit ,
266
283
seed = seed ,
267
284
text_config = text_config ,
285
+ image_config = image_config ,
268
286
)
269
287
expected_config = {
270
288
"shap" : {
@@ -278,6 +296,15 @@ def test_shap_config():
278
296
"granularity" : granularity ,
279
297
"language" : language ,
280
298
},
299
+ "image_config" : {
300
+ "model_type" : model_type ,
301
+ "num_segments" : num_segments ,
302
+ "feature_extraction_method" : feature_extraction_method ,
303
+ "segment_compactness" : segment_compactness ,
304
+ "max_objects" : max_objects ,
305
+ "iou_threshold" : iou_threshold ,
306
+ "context" : context ,
307
+ },
281
308
}
282
309
}
283
310
assert expected_config == shap_config .get_explainability_config ()
@@ -359,6 +386,50 @@ def test_invalid_text_config():
359
386
assert "Invalid language invalid. Please choose among ['chinese'," in str (error .value )
360
387
361
388
389
+ def test_image_config ():
390
+ model_type = "IMAGE_CLASSIFICATION"
391
+ num_segments = 2
392
+ feature_extraction_method = "segmentation"
393
+ segment_compactness = 10
394
+ max_objects = 4
395
+ iou_threshold = 0.5
396
+ context = 1.0
397
+ image_config = ImageConfig (
398
+ model_type = model_type ,
399
+ num_segments = num_segments ,
400
+ feature_extraction_method = feature_extraction_method ,
401
+ segment_compactness = segment_compactness ,
402
+ max_objects = max_objects ,
403
+ iou_threshold = iou_threshold ,
404
+ context = context ,
405
+ )
406
+ expected_config = {
407
+ "model_type" : model_type ,
408
+ "num_segments" : num_segments ,
409
+ "feature_extraction_method" : feature_extraction_method ,
410
+ "segment_compactness" : segment_compactness ,
411
+ "max_objects" : max_objects ,
412
+ "iou_threshold" : iou_threshold ,
413
+ "context" : context ,
414
+ }
415
+
416
+ assert expected_config == image_config .get_image_config ()
417
+
418
+
419
+ def test_invalid_image_config ():
420
+ model_type = "OBJECT_SEGMENTATION"
421
+ num_segments = 2
422
+ with pytest .raises (ValueError ) as error :
423
+ ImageConfig (
424
+ model_type = model_type ,
425
+ num_segments = num_segments ,
426
+ )
427
+ assert (
428
+ "Clarify SHAP only supports object detection and image classification methods. "
429
+ "Please set model_type to OBJECT_DETECTION or IMAGE_CLASSIFICATION." in str (error .value )
430
+ )
431
+
432
+
362
433
def test_invalid_shap_config ():
363
434
with pytest .raises (ValueError ) as error :
364
435
SHAPConfig (
@@ -665,6 +736,7 @@ def _run_test_explain(
665
736
model_scores ,
666
737
expected_predictor_config ,
667
738
expected_text_config = None ,
739
+ expected_image_config = None ,
668
740
):
669
741
with patch .object (SageMakerClarifyProcessor , "_run" , return_value = None ) as mock_method :
670
742
explanation_configs = None
@@ -684,21 +756,6 @@ def _run_test_explain(
684
756
job_name = "test" ,
685
757
experiment_config = {"ExperimentName" : "AnExperiment" },
686
758
)
687
- expected_shap_config = {
688
- "baseline" : [
689
- [
690
- 0.26124998927116394 ,
691
- 0.2824999988079071 ,
692
- 0.06875000149011612 ,
693
- ]
694
- ],
695
- "num_samples" : 100 ,
696
- "agg_method" : "mean_sq" ,
697
- "use_logit" : False ,
698
- "save_local_shap_values" : True ,
699
- }
700
- if expected_text_config :
701
- expected_shap_config ["text_config" ] = expected_text_config
702
759
expected_analysis_config = {
703
760
"dataset_type" : "text/csv" ,
704
761
"headers" : [
@@ -710,9 +767,6 @@ def _run_test_explain(
710
767
],
711
768
"label" : "Label" ,
712
769
"joinsource_name_or_index" : "F4" ,
713
- "methods" : {
714
- "shap" : expected_shap_config ,
715
- },
716
770
"predictor" : expected_predictor_config ,
717
771
}
718
772
expected_explanation_configs = {}
@@ -732,6 +786,8 @@ def _run_test_explain(
732
786
}
733
787
if expected_text_config :
734
788
expected_explanation_configs ["shap" ]["text_config" ] = expected_text_config
789
+ if expected_image_config :
790
+ expected_explanation_configs ["shap" ]["image_config" ] = expected_image_config
735
791
if pdp_config :
736
792
expected_explanation_configs ["pdp" ] = {
737
793
"features" : ["F1" , "F2" ],
@@ -963,3 +1019,70 @@ def test_shap_with_text_config(
963
1019
expected_predictor_config ,
964
1020
expected_text_config = expected_text_config ,
965
1021
)
1022
+
1023
+
1024
+ @patch ("sagemaker.utils.name_from_base" , return_value = JOB_NAME )
1025
+ def test_shap_with_image_config (
1026
+ name_from_base ,
1027
+ clarify_processor ,
1028
+ clarify_processor_with_job_name_prefix ,
1029
+ data_config ,
1030
+ model_config ,
1031
+ ):
1032
+ model_type = "IMAGE_CLASSIFICATION"
1033
+ num_segments = 2
1034
+ feature_extraction_method = "segmentation"
1035
+ segment_compactness = 10
1036
+ max_objects = 4
1037
+ iou_threshold = 0.5
1038
+ context = 1.0
1039
+ image_config = ImageConfig (
1040
+ model_type = model_type ,
1041
+ num_segments = num_segments ,
1042
+ feature_extraction_method = feature_extraction_method ,
1043
+ segment_compactness = segment_compactness ,
1044
+ max_objects = max_objects ,
1045
+ iou_threshold = iou_threshold ,
1046
+ context = context ,
1047
+ )
1048
+
1049
+ shap_config = SHAPConfig (
1050
+ baseline = [
1051
+ [
1052
+ 0.26124998927116394 ,
1053
+ 0.2824999988079071 ,
1054
+ 0.06875000149011612 ,
1055
+ ]
1056
+ ],
1057
+ num_samples = 100 ,
1058
+ agg_method = "mean_sq" ,
1059
+ image_config = image_config ,
1060
+ )
1061
+
1062
+ expected_image_config = {
1063
+ "model_type" : model_type ,
1064
+ "num_segments" : num_segments ,
1065
+ "feature_extraction_method" : feature_extraction_method ,
1066
+ "segment_compactness" : segment_compactness ,
1067
+ "max_objects" : max_objects ,
1068
+ "iou_threshold" : iou_threshold ,
1069
+ "context" : context ,
1070
+ }
1071
+ expected_predictor_config = {
1072
+ "model_name" : "xgboost-model" ,
1073
+ "instance_type" : "ml.c5.xlarge" ,
1074
+ "initial_instance_count" : 1 ,
1075
+ }
1076
+
1077
+ _run_test_explain (
1078
+ name_from_base ,
1079
+ clarify_processor ,
1080
+ clarify_processor_with_job_name_prefix ,
1081
+ data_config ,
1082
+ model_config ,
1083
+ shap_config ,
1084
+ None ,
1085
+ None ,
1086
+ expected_predictor_config ,
1087
+ expected_image_config = expected_image_config ,
1088
+ )
0 commit comments