Skip to content

Commit e9d6732

Browse files
author
Dewen Qi
committed
change: Add PipelineVariable annotation for all processor subclasses
1 parent 17a479a commit e9d6732

File tree

8 files changed

+259
-223
lines changed

8 files changed

+259
-223
lines changed

src/sagemaker/clarify.py

Lines changed: 117 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,11 @@
2525

2626
import tempfile
2727
from abc import ABC, abstractmethod
28-
from typing import List, Union, Dict
28+
from typing import List, Union, Dict, Optional, Any
2929

3030
from sagemaker import image_uris, s3, utils
31+
from sagemaker.session import Session
32+
from sagemaker.network import NetworkConfig
3133
from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
3234

3335
logger = logging.getLogger(__name__)
@@ -38,21 +40,21 @@ class DataConfig:
3840

3941
def __init__(
4042
self,
41-
s3_data_input_path,
42-
s3_output_path,
43-
s3_analysis_config_output_path=None,
44-
label=None,
45-
headers=None,
46-
features=None,
47-
dataset_type="text/csv",
48-
s3_compression_type="None",
49-
joinsource=None,
50-
facet_dataset_uri=None,
51-
facet_headers=None,
52-
predicted_label_dataset_uri=None,
53-
predicted_label_headers=None,
54-
predicted_label=None,
55-
excluded_columns=None,
43+
s3_data_input_path: str,
44+
s3_output_path: str,
45+
s3_analysis_config_output_path: Optional[str] = None,
46+
label: Optional[str] = None,
47+
headers: Optional[List[str]] = None,
48+
features: Optional[List[str]] = None,
49+
dataset_type: str = "text/csv",
50+
s3_compression_type: str = "None",
51+
joinsource: Optional[Union[str, int]] = None,
52+
facet_dataset_uri: Optional[str] = None,
53+
facet_headers: Optional[List[str]] = None,
54+
predicted_label_dataset_uri: Optional[str] = None,
55+
predicted_label_headers: Optional[List[str]] = None,
56+
predicted_label: Optional[Union[str, int]] = None,
57+
excluded_columns: Optional[Union[List[int], List[str]]] = None,
5658
):
5759
"""Initializes a configuration of both input and output datasets.
5860
@@ -65,7 +67,7 @@ def __init__(
6567
label (str): Target attribute of the model required by bias metrics.
6668
Specified as column name or index for CSV dataset or as JSONPath for JSONLines.
6769
*Required parameter* except for when the input dataset does not contain the label.
68-
features (str): JSONPath for locating the feature columns for bias metrics if the
70+
features (List[str]): JSONPath for locating the feature columns for bias metrics if the
6971
dataset format is JSONLines.
7072
dataset_type (str): Format of the dataset. Valid values are ``"text/csv"`` for CSV,
7173
``"application/jsonlines"`` for JSONLines, and
@@ -191,10 +193,10 @@ class BiasConfig:
191193

192194
def __init__(
193195
self,
194-
label_values_or_threshold,
195-
facet_name,
196-
facet_values_or_threshold=None,
197-
group_name=None,
196+
label_values_or_threshold: Union[int, float, str],
197+
facet_name: Union[str, int, List[str], List[int]],
198+
facet_values_or_threshold: Optional[Union[int, float, str]] = None,
199+
group_name: Optional[str] = None,
198200
):
199201
"""Initializes a configuration of the sensitive groups in the dataset.
200202
@@ -275,17 +277,17 @@ class ModelConfig:
275277

276278
def __init__(
277279
self,
278-
model_name: str = None,
279-
instance_count: int = None,
280-
instance_type: str = None,
281-
accept_type: str = None,
282-
content_type: str = None,
283-
content_template: str = None,
284-
custom_attributes: str = None,
285-
accelerator_type: str = None,
286-
endpoint_name_prefix: str = None,
287-
target_model: str = None,
288-
endpoint_name: str = None,
280+
model_name: Optional[str] = None,
281+
instance_count: Optional[int] = None,
282+
instance_type: Optional[str] = None,
283+
accept_type: Optional[str] = None,
284+
content_type: Optional[str] = None,
285+
content_template: Optional[str] = None,
286+
custom_attributes: Optional[str] = None,
287+
accelerator_type: Optional[str] = None,
288+
endpoint_name_prefix: Optional[str] = None,
289+
target_model: Optional[str] = None,
290+
endpoint_name: Optional[str] = None,
289291
):
290292
r"""Initializes a configuration of a model and the endpoint to be created for it.
291293
@@ -414,10 +416,10 @@ class ModelPredictedLabelConfig:
414416

415417
def __init__(
416418
self,
417-
label=None,
418-
probability=None,
419-
probability_threshold=None,
420-
label_headers=None,
419+
label: Optional[Union[str, int]] = None,
420+
probability: Optional[Union[str, int]] = None,
421+
probability_threshold: Optional[float] = None,
422+
label_headers: Optional[List[str]] = None,
421423
):
422424
"""Initializes a model output config to extract the predicted label or predicted score(s).
423425
@@ -509,7 +511,9 @@ class PDPConfig(ExplainabilityConfig):
509511
and the corresponding values are included in the analysis output.
510512
""" # noqa E501
511513

512-
def __init__(self, features=None, grid_resolution=15, top_k_features=10):
514+
def __init__(
515+
self, features: Optional[List] = None, grid_resolution: int = 15, top_k_features: int = 10
516+
):
513517
"""Initializes PDP config.
514518
515519
Args:
@@ -680,8 +684,8 @@ class TextConfig:
680684

681685
def __init__(
682686
self,
683-
granularity,
684-
language,
687+
granularity: str,
688+
language: str,
685689
):
686690
"""Initializes a text configuration.
687691
@@ -736,13 +740,13 @@ class ImageConfig:
736740

737741
def __init__(
738742
self,
739-
model_type,
740-
num_segments=None,
741-
feature_extraction_method=None,
742-
segment_compactness=None,
743-
max_objects=None,
744-
iou_threshold=None,
745-
context=None,
743+
model_type: str,
744+
num_segments: Optional[int] = None,
745+
feature_extraction_method: Optional[str] = None,
746+
segment_compactness: Optional[float] = None,
747+
max_objects: Optional[int] = None,
748+
iou_threshold: Optional[float] = None,
749+
context: Optional[float] = None,
746750
):
747751
"""Initializes a config object for Computer Vision (CV) Image explainability.
748752
@@ -817,15 +821,15 @@ class SHAPConfig(ExplainabilityConfig):
817821

818822
def __init__(
819823
self,
820-
baseline=None,
821-
num_samples=None,
822-
agg_method=None,
823-
use_logit=False,
824-
save_local_shap_values=True,
825-
seed=None,
826-
num_clusters=None,
827-
text_config=None,
828-
image_config=None,
824+
baseline: Optional[Union[str, List]] = None,
825+
num_samples: Optional[int] = None,
826+
agg_method: Optional[str] = None,
827+
use_logit: Optional[bool] = None,
828+
save_local_shap_values: Optional[bool] = None,
829+
seed: Optional[int] = None,
830+
num_clusters: Optional[int] = None,
831+
text_config: Optional[TextConfig] = None,
832+
image_config: Optional[ImageConfig] = None,
829833
):
830834
"""Initializes config for SHAP analysis.
831835
@@ -909,19 +913,19 @@ class SageMakerClarifyProcessor(Processor):
909913

910914
def __init__(
911915
self,
912-
role,
913-
instance_count,
914-
instance_type,
915-
volume_size_in_gb=30,
916-
volume_kms_key=None,
917-
output_kms_key=None,
918-
max_runtime_in_seconds=None,
919-
sagemaker_session=None,
920-
env=None,
921-
tags=None,
922-
network_config=None,
923-
job_name_prefix=None,
924-
version=None,
916+
role: str,
917+
instance_count: int,
918+
instance_type: str,
919+
volume_size_in_gb: int = 30,
920+
volume_kms_key: Optional[str] = None,
921+
output_kms_key: Optional[str] = None,
922+
max_runtime_in_seconds: Optional[int] = None,
923+
sagemaker_session: Optional[Session] = None,
924+
env: Optional[Dict[str, str]] = None,
925+
tags: Optional[List[Dict[str, str]]] = None,
926+
network_config: Optional[NetworkConfig] = None,
927+
job_name_prefix: Optional[str] = None,
928+
version: Optional[str] = None,
925929
):
926930
"""Initializes a SageMakerClarifyProcessor to compute bias metrics and model explanations.
927931
@@ -993,13 +997,13 @@ def run(self, **_):
993997

994998
def _run(
995999
self,
996-
data_config,
997-
analysis_config,
998-
wait,
999-
logs,
1000-
job_name,
1001-
kms_key,
1002-
experiment_config,
1000+
data_config: DataConfig,
1001+
analysis_config: Dict[str, Any],
1002+
wait: bool,
1003+
logs: bool,
1004+
job_name: str,
1005+
kms_key: str,
1006+
experiment_config: Dict[str, str],
10031007
):
10041008
"""Runs a :class:`~sagemaker.processing.ProcessingJob` with the SageMaker Clarify container
10051009
@@ -1077,14 +1081,14 @@ def _run(
10771081

10781082
def run_pre_training_bias(
10791083
self,
1080-
data_config,
1081-
data_bias_config,
1082-
methods="all",
1083-
wait=True,
1084-
logs=True,
1085-
job_name=None,
1086-
kms_key=None,
1087-
experiment_config=None,
1084+
data_config: DataConfig,
1085+
data_bias_config: BiasConfig,
1086+
methods: Union[str, List[str]] = "all",
1087+
wait: bool = True,
1088+
logs: bool = True,
1089+
job_name: Optional[str] = None,
1090+
kms_key: Optional[str] = None,
1091+
experiment_config: Optional[Dict[str, str]] = None,
10881092
):
10891093
"""Runs a :class:`~sagemaker.processing.ProcessingJob` to compute pre-training bias methods
10901094
@@ -1146,16 +1150,16 @@ def run_pre_training_bias(
11461150

11471151
def run_post_training_bias(
11481152
self,
1149-
data_config,
1150-
data_bias_config,
1151-
model_config,
1152-
model_predicted_label_config,
1153-
methods="all",
1154-
wait=True,
1155-
logs=True,
1156-
job_name=None,
1157-
kms_key=None,
1158-
experiment_config=None,
1153+
data_config: DataConfig,
1154+
data_bias_config: BiasConfig,
1155+
model_config: ModelConfig,
1156+
model_predicted_label_config: ModelPredictedLabelConfig,
1157+
methods: Union[str, List[str]] = "all",
1158+
wait: bool = True,
1159+
logs: bool = True,
1160+
job_name: Optional[str] = None,
1161+
kms_key: Optional[str] = None,
1162+
experiment_config: Optional[Dict[str, str]] = None,
11591163
):
11601164
"""Runs a :class:`~sagemaker.processing.ProcessingJob` to compute post-training bias
11611165
@@ -1231,17 +1235,17 @@ def run_post_training_bias(
12311235

12321236
def run_bias(
12331237
self,
1234-
data_config,
1235-
bias_config,
1236-
model_config,
1237-
model_predicted_label_config=None,
1238-
pre_training_methods="all",
1239-
post_training_methods="all",
1240-
wait=True,
1241-
logs=True,
1242-
job_name=None,
1243-
kms_key=None,
1244-
experiment_config=None,
1238+
data_config: DataConfig,
1239+
bias_config: BiasConfig,
1240+
model_config: ModelConfig,
1241+
model_predicted_label_config: Optional[ModelPredictedLabelConfig] = None,
1242+
pre_training_methods: Union[str, List[str]] = "all",
1243+
post_training_methods: Union[str, List[str]] = "all",
1244+
wait: bool = True,
1245+
logs: bool = True,
1246+
job_name: Optional[str] = None,
1247+
kms_key: Optional[str] = None,
1248+
experiment_config: Optional[Dict[str, str]] = None,
12451249
):
12461250
"""Runs a :class:`~sagemaker.processing.ProcessingJob` to compute the requested bias methods
12471251
@@ -1325,15 +1329,15 @@ def run_bias(
13251329

13261330
def run_explainability(
13271331
self,
1328-
data_config,
1329-
model_config,
1330-
explainability_config,
1331-
model_scores=None,
1332-
wait=True,
1333-
logs=True,
1334-
job_name=None,
1335-
kms_key=None,
1336-
experiment_config=None,
1332+
data_config: DataConfig,
1333+
model_config: ModelConfig,
1334+
explainability_config: Union[ExplainabilityConfig, List],
1335+
model_scores: Optional[Union[int, str, ModelPredictedLabelConfig]] = None,
1336+
wait: bool = True,
1337+
logs: bool = True,
1338+
job_name: Optional[str] = None,
1339+
kms_key: Optional[str] = None,
1340+
experiment_config: Optional[Dict[str, str]] = None,
13371341
):
13381342
"""Runs a :class:`~sagemaker.processing.ProcessingJob` computing feature attributions.
13391343

src/sagemaker/fw_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,7 @@ def validate_smdistributed(
575575
if "smdistributed" not in distribution:
576576
# Distribution strategy other than smdistributed is selected
577577
return
578-
if is_pipeline_variable(instance_type):
578+
if is_pipeline_variable(instance_type) or is_pipeline_variable(image_uri):
579579
# The instance_type or image_uri is not available at compile time.
580580
# Rather, it's given at Pipeline execution time
581581
return

0 commit comments

Comments
 (0)