25
25
26
26
import tempfile
27
27
from abc import ABC , abstractmethod
28
- from typing import List , Union , Dict
28
+ from typing import List , Union , Dict , Optional , Any
29
29
30
30
from sagemaker import image_uris , s3 , utils
31
+ from sagemaker .session import Session
32
+ from sagemaker .network import NetworkConfig
31
33
from sagemaker .processing import ProcessingInput , ProcessingOutput , Processor
32
34
33
35
logger = logging .getLogger (__name__ )
@@ -38,21 +40,21 @@ class DataConfig:
38
40
39
41
def __init__ (
40
42
self ,
41
- s3_data_input_path ,
42
- s3_output_path ,
43
- s3_analysis_config_output_path = None ,
44
- label = None ,
45
- headers = None ,
46
- features = None ,
47
- dataset_type = "text/csv" ,
48
- s3_compression_type = "None" ,
49
- joinsource = None ,
50
- facet_dataset_uri = None ,
51
- facet_headers = None ,
52
- predicted_label_dataset_uri = None ,
53
- predicted_label_headers = None ,
54
- predicted_label = None ,
55
- excluded_columns = None ,
43
+ s3_data_input_path : str ,
44
+ s3_output_path : str ,
45
+ s3_analysis_config_output_path : Optional [ str ] = None ,
46
+ label : Optional [ str ] = None ,
47
+ headers : Optional [ List [ str ]] = None ,
48
+ features : Optional [ List [ str ]] = None ,
49
+ dataset_type : str = "text/csv" ,
50
+ s3_compression_type : str = "None" ,
51
+ joinsource : Optional [ Union [ str , int ]] = None ,
52
+ facet_dataset_uri : Optional [ str ] = None ,
53
+ facet_headers : Optional [ List [ str ]] = None ,
54
+ predicted_label_dataset_uri : Optional [ str ] = None ,
55
+ predicted_label_headers : Optional [ List [ str ]] = None ,
56
+ predicted_label : Optional [ Union [ str , int ]] = None ,
57
+ excluded_columns : Optional [ Union [ List [ int ], List [ str ]]] = None ,
56
58
):
57
59
"""Initializes a configuration of both input and output datasets.
58
60
@@ -65,7 +67,7 @@ def __init__(
65
67
label (str): Target attribute of the model required by bias metrics.
66
68
Specified as column name or index for CSV dataset or as JSONPath for JSONLines.
67
69
*Required parameter* except for when the input dataset does not contain the label.
68
- features (str): JSONPath for locating the feature columns for bias metrics if the
70
+ features (List[str]): JSONPath for locating the feature columns for bias metrics if the
69
71
dataset format is JSONLines.
70
72
dataset_type (str): Format of the dataset. Valid values are ``"text/csv"`` for CSV,
71
73
``"application/jsonlines"`` for JSONLines, and
@@ -191,10 +193,10 @@ class BiasConfig:
191
193
192
194
def __init__ (
193
195
self ,
194
- label_values_or_threshold ,
195
- facet_name ,
196
- facet_values_or_threshold = None ,
197
- group_name = None ,
196
+ label_values_or_threshold : Union [ int , float , str ] ,
197
+ facet_name : Union [ str , int , List [ str ], List [ int ]] ,
198
+ facet_values_or_threshold : Optional [ Union [ int , float , str ]] = None ,
199
+ group_name : Optional [ str ] = None ,
198
200
):
199
201
"""Initializes a configuration of the sensitive groups in the dataset.
200
202
@@ -275,17 +277,17 @@ class ModelConfig:
275
277
276
278
def __init__ (
277
279
self ,
278
- model_name : str = None ,
279
- instance_count : int = None ,
280
- instance_type : str = None ,
281
- accept_type : str = None ,
282
- content_type : str = None ,
283
- content_template : str = None ,
284
- custom_attributes : str = None ,
285
- accelerator_type : str = None ,
286
- endpoint_name_prefix : str = None ,
287
- target_model : str = None ,
288
- endpoint_name : str = None ,
280
+ model_name : str ,
281
+ instance_count : int ,
282
+ instance_type : str ,
283
+ accept_type : Optional [ str ] = None ,
284
+ content_type : Optional [ str ] = None ,
285
+ content_template : Optional [ str ] = None ,
286
+ custom_attributes : Optional [ str ] = None ,
287
+ accelerator_type : Optional [ str ] = None ,
288
+ endpoint_name_prefix : Optional [ str ] = None ,
289
+ target_model : Optional [ str ] = None ,
290
+ endpoint_name : Optional [ str ] = None ,
289
291
):
290
292
r"""Initializes a configuration of a model and the endpoint to be created for it.
291
293
@@ -414,10 +416,10 @@ class ModelPredictedLabelConfig:
414
416
415
417
def __init__ (
416
418
self ,
417
- label = None ,
418
- probability = None ,
419
- probability_threshold = None ,
420
- label_headers = None ,
419
+ label : Optional [ Union [ str , int ]] = None ,
420
+ probability : Optional [ Union [ str , int ]] = None ,
421
+ probability_threshold : Optional [ float ] = None ,
422
+ label_headers : Optional [ List [ str ]] = None ,
421
423
):
422
424
"""Initializes a model output config to extract the predicted label or predicted score(s).
423
425
@@ -509,7 +511,9 @@ class PDPConfig(ExplainabilityConfig):
509
511
and the corresponding values are included in the analysis output.
510
512
""" # noqa E501
511
513
512
- def __init__ (self , features = None , grid_resolution = 15 , top_k_features = 10 ):
514
+ def __init__ (
515
+ self , features : Optional [List ] = None , grid_resolution : int = 15 , top_k_features : int = 10
516
+ ):
513
517
"""Initializes PDP config.
514
518
515
519
Args:
@@ -680,8 +684,8 @@ class TextConfig:
680
684
681
685
def __init__ (
682
686
self ,
683
- granularity ,
684
- language ,
687
+ granularity : str ,
688
+ language : str ,
685
689
):
686
690
"""Initializes a text configuration.
687
691
@@ -736,13 +740,13 @@ class ImageConfig:
736
740
737
741
def __init__ (
738
742
self ,
739
- model_type ,
740
- num_segments = None ,
741
- feature_extraction_method = None ,
742
- segment_compactness = None ,
743
- max_objects = None ,
744
- iou_threshold = None ,
745
- context = None ,
743
+ model_type : str ,
744
+ num_segments : Optional [ int ] = None ,
745
+ feature_extraction_method : Optional [ str ] = None ,
746
+ segment_compactness : Optional [ float ] = None ,
747
+ max_objects : Optional [ int ] = None ,
748
+ iou_threshold : Optional [ float ] = None ,
749
+ context : Optional [ float ] = None ,
746
750
):
747
751
"""Initializes a config object for Computer Vision (CV) Image explainability.
748
752
@@ -817,15 +821,15 @@ class SHAPConfig(ExplainabilityConfig):
817
821
818
822
def __init__ (
819
823
self ,
820
- baseline = None ,
821
- num_samples = None ,
822
- agg_method = None ,
823
- use_logit = False ,
824
- save_local_shap_values = True ,
825
- seed = None ,
826
- num_clusters = None ,
827
- text_config = None ,
828
- image_config = None ,
824
+ baseline : Optional [ Union [ str , List ]] = None ,
825
+ num_samples : Optional [ int ] = None ,
826
+ agg_method : Optional [ str ] = None ,
827
+ use_logit : Optional [ bool ] = None ,
828
+ save_local_shap_values : Optional [ bool ] = None ,
829
+ seed : Optional [ int ] = None ,
830
+ num_clusters : Optional [ int ] = None ,
831
+ text_config : Optional [ TextConfig ] = None ,
832
+ image_config : Optional [ ImageConfig ] = None ,
829
833
):
830
834
"""Initializes config for SHAP analysis.
831
835
@@ -909,19 +913,19 @@ class SageMakerClarifyProcessor(Processor):
909
913
910
914
def __init__ (
911
915
self ,
912
- role ,
913
- instance_count ,
914
- instance_type ,
915
- volume_size_in_gb = 30 ,
916
- volume_kms_key = None ,
917
- output_kms_key = None ,
918
- max_runtime_in_seconds = None ,
919
- sagemaker_session = None ,
920
- env = None ,
921
- tags = None ,
922
- network_config = None ,
923
- job_name_prefix = None ,
924
- version = None ,
916
+ role : str ,
917
+ instance_count : int ,
918
+ instance_type : str ,
919
+ volume_size_in_gb : int = 30 ,
920
+ volume_kms_key : Optional [ str ] = None ,
921
+ output_kms_key : Optional [ str ] = None ,
922
+ max_runtime_in_seconds : Optional [ int ] = None ,
923
+ sagemaker_session : Optional [ Session ] = None ,
924
+ env : Optional [ Dict [ str , str ]] = None ,
925
+ tags : Optional [ List [ Dict [ str , str ]]] = None ,
926
+ network_config : Optional [ NetworkConfig ] = None ,
927
+ job_name_prefix : Optional [ str ] = None ,
928
+ version : Optional [ str ] = None ,
925
929
):
926
930
"""Initializes a SageMakerClarifyProcessor to compute bias metrics and model explanations.
927
931
@@ -993,13 +997,13 @@ def run(self, **_):
993
997
994
998
def _run (
995
999
self ,
996
- data_config ,
997
- analysis_config ,
998
- wait ,
999
- logs ,
1000
- job_name ,
1001
- kms_key ,
1002
- experiment_config ,
1000
+ data_config : DataConfig ,
1001
+ analysis_config : Dict [ str , Any ] ,
1002
+ wait : bool ,
1003
+ logs : bool ,
1004
+ job_name : str ,
1005
+ kms_key : str ,
1006
+ experiment_config : Dict [ str , str ] ,
1003
1007
):
1004
1008
"""Runs a :class:`~sagemaker.processing.ProcessingJob` with the SageMaker Clarify container
1005
1009
@@ -1077,14 +1081,14 @@ def _run(
1077
1081
1078
1082
def run_pre_training_bias (
1079
1083
self ,
1080
- data_config ,
1081
- data_bias_config ,
1082
- methods = "all" ,
1083
- wait = True ,
1084
- logs = True ,
1085
- job_name = None ,
1086
- kms_key = None ,
1087
- experiment_config = None ,
1084
+ data_config : DataConfig ,
1085
+ data_bias_config : BiasConfig ,
1086
+ methods : str = "all" ,
1087
+ wait : bool = True ,
1088
+ logs : bool = True ,
1089
+ job_name : Optional [ str ] = None ,
1090
+ kms_key : Optional [ str ] = None ,
1091
+ experiment_config : Optional [ Dict [ str , str ]] = None ,
1088
1092
):
1089
1093
"""Runs a :class:`~sagemaker.processing.ProcessingJob` to compute pre-training bias methods
1090
1094
@@ -1146,16 +1150,16 @@ def run_pre_training_bias(
1146
1150
1147
1151
def run_post_training_bias (
1148
1152
self ,
1149
- data_config ,
1150
- data_bias_config ,
1151
- model_config ,
1152
- model_predicted_label_config ,
1153
- methods = "all" ,
1154
- wait = True ,
1155
- logs = True ,
1156
- job_name = None ,
1157
- kms_key = None ,
1158
- experiment_config = None ,
1153
+ data_config : DataConfig ,
1154
+ data_bias_config : BiasConfig ,
1155
+ model_config : ModelConfig ,
1156
+ model_predicted_label_config : ModelPredictedLabelConfig ,
1157
+ methods : str = "all" ,
1158
+ wait : bool = True ,
1159
+ logs : bool = True ,
1160
+ job_name : Optional [ str ] = None ,
1161
+ kms_key : Optional [ str ] = None ,
1162
+ experiment_config : Optional [ Dict [ str , str ]] = None ,
1159
1163
):
1160
1164
"""Runs a :class:`~sagemaker.processing.ProcessingJob` to compute posttraining bias
1161
1165
@@ -1231,17 +1235,17 @@ def run_post_training_bias(
1231
1235
1232
1236
def run_bias (
1233
1237
self ,
1234
- data_config ,
1235
- bias_config ,
1236
- model_config ,
1237
- model_predicted_label_config = None ,
1238
- pre_training_methods = "all" ,
1239
- post_training_methods = "all" ,
1240
- wait = True ,
1241
- logs = True ,
1242
- job_name = None ,
1243
- kms_key = None ,
1244
- experiment_config = None ,
1238
+ data_config : DataConfig ,
1239
+ bias_config : BiasConfig ,
1240
+ model_config : ModelConfig ,
1241
+ model_predicted_label_config : Optional [ ModelPredictedLabelConfig ] = None ,
1242
+ pre_training_methods : str = "all" ,
1243
+ post_training_methods : str = "all" ,
1244
+ wait : bool = True ,
1245
+ logs : bool = True ,
1246
+ job_name : Optional [ str ] = None ,
1247
+ kms_key : Optional [ str ] = None ,
1248
+ experiment_config : Optional [ Dict [ str , str ]] = None ,
1245
1249
):
1246
1250
"""Runs a :class:`~sagemaker.processing.ProcessingJob` to compute the requested bias methods
1247
1251
@@ -1325,15 +1329,15 @@ def run_bias(
1325
1329
1326
1330
def run_explainability (
1327
1331
self ,
1328
- data_config ,
1329
- model_config ,
1330
- explainability_config ,
1331
- model_scores = None ,
1332
- wait = True ,
1333
- logs = True ,
1334
- job_name = None ,
1335
- kms_key = None ,
1336
- experiment_config = None ,
1332
+ data_config : DataConfig ,
1333
+ model_config : ModelConfig ,
1334
+ explainability_config : Union [ ExplainabilityConfig , List ] ,
1335
+ model_scores : Optional [ Union [ int , ModelPredictedLabelConfig ]] = None ,
1336
+ wait : bool = True ,
1337
+ logs : bool = True ,
1338
+ job_name : Optional [ str ] = None ,
1339
+ kms_key : Optional [ str ] = None ,
1340
+ experiment_config : Optional [ Dict [ str , str ]] = None ,
1337
1341
):
1338
1342
"""Runs a :class:`~sagemaker.processing.ProcessingJob` computing feature attributions.
1339
1343
0 commit comments