Skip to content

Commit f7a7f98

Browse files
authored
Merge branch 'master' into feat/jumpstart-training-metrics
2 parents 9aefb21 + 7f823e1 commit f7a7f98

File tree

2 files changed

+107
-28
lines changed

2 files changed

+107
-28
lines changed

src/sagemaker/clarify.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
in (
5050
"text/csv",
5151
"application/jsonlines",
52+
"application/json",
5253
"application/sagemakercapturejson",
5354
"application/x-parquet",
5455
"application/x-image",
@@ -311,7 +312,7 @@ def __init__(
311312
s3_analysis_config_output_path: Optional[str] = None,
312313
label: Optional[str] = None,
313314
headers: Optional[List[str]] = None,
314-
features: Optional[List[str]] = None,
315+
features: Optional[str] = None,
315316
dataset_type: str = "text/csv",
316317
s3_compression_type: str = "None",
317318
joinsource: Optional[Union[str, int]] = None,
@@ -331,12 +332,18 @@ def __init__(
331332
If this field is None, then the ``s3_output_path`` will be used
332333
to store the ``analysis_config`` output.
333334
label (str): Target attribute of the model required by bias metrics. Specified as
334-
column name or index for CSV dataset or as JMESPath expression for JSONLines.
335+
column name or index for CSV dataset or a JMESPath expression for JSON/JSON Lines.
335336
*Required parameter* except for when the input dataset does not contain the label.
336-
features (List[str]): JMESPath expression to locate the feature columns for
337-
bias metrics if the dataset format is JSONLines.
337+
Note: For JSON, the JMESPath query must result in a list of labels for each
338+
sample. For JSON Lines, it must result in the label for each line.
339+
Only a single label per sample is supported at this time.
340+
features (str): JMESPath expression to locate the feature values
341+
if the dataset format is JSON/JSON Lines.
342+
Note: For JSON, the JMESPath query must result in a 2-D list (or a matrix) of
343+
feature values. For JSON Lines, it must result in a 1-D list of features for each
344+
line.
338345
dataset_type (str): Format of the dataset. Valid values are ``"text/csv"`` for CSV,
339-
``"application/jsonlines"`` for JSONLines, and
346+
``"application/jsonlines"`` for JSON Lines, ``"application/json"`` for JSON, and
340347
``"application/x-parquet"`` for Parquet.
341348
s3_compression_type (str): Valid options are "None" or ``"Gzip"``.
342349
joinsource (str or int): The name or index of the column in the dataset that
@@ -359,6 +366,7 @@ def __init__(
359366
360367
Clarify will not use the ``joinsource`` column and columns present in the facet
361368
dataset when calling model inference APIs.
369+
Note: this is only supported for ``"text/csv"`` dataset type.
362370
facet_headers (list[str]): List of column names in the facet dataset.
363371
predicted_label_dataset_uri (str): Dataset S3 prefix/object URI with predicted labels,
364372
which are used directly for analysis instead of making model inference API calls.
@@ -368,11 +376,16 @@ def __init__(
368376
* If the dataset and predicted label dataset are in multiple files (either one),
369377
then an index column, ``joinsource``, is required to join the two datasets.
370378
379+
Note: this is only supported for ``"text/csv"`` dataset type.
371380
predicted_label_headers (list[str]): List of column names in the predicted label dataset
372381
predicted_label (str or int): Predicted label of the target attribute of the model
373-
required for running bias analysis. Specified as column name or index for CSV data.
382+
required for running bias analysis. Specified as column name or index for CSV data,
383+
or a JMESPath expression for JSON/JSON Lines.
374384
Clarify uses the predicted labels directly instead of making model inference API
375385
calls.
386+
Note: For JSON, the JMESPath query must result in a list of predicted labels for
387+
each sample. For JSON Lines, it must result in the predicted label for each line.
388+
Only a single predicted label per sample is supported at this time.
376389
excluded_columns (list[int] or list[str]): A list of names or indices of the columns
377390
which are to be excluded from making model inference API calls.
378391
@@ -384,15 +397,21 @@ def __init__(
384397
if dataset_type not in [
385398
"text/csv",
386399
"application/jsonlines",
400+
"application/json",
387401
"application/x-parquet",
388402
"application/x-image",
389403
]:
390404
raise ValueError(
391405
f"Invalid dataset_type '{dataset_type}'."
392406
f" Please check the API documentation for the supported dataset types."
393407
)
394-
# parameters for analysis on datasets without facets are only supported for CSV datasets
395-
if dataset_type != "text/csv":
408+
# predicted_label and excluded_columns are only supported for tabular datasets
409+
if dataset_type not in [
410+
"text/csv",
411+
"application/jsonlines",
412+
"application/json",
413+
"application/x-parquet",
414+
]:
396415
if predicted_label:
397416
raise ValueError(
398417
f"The parameter 'predicted_label' is not supported"
@@ -405,6 +424,8 @@ def __init__(
405424
f" for dataset_type '{dataset_type}'."
406425
f" Please check the API documentation for the supported dataset types."
407426
)
427+
# parameters for analysis on datasets without facets are only supported for CSV datasets
428+
if dataset_type != "text/csv":
408429
if facet_dataset_uri or facet_headers:
409430
raise ValueError(
410431
f"The parameters 'facet_dataset_uri' and 'facet_headers'"
@@ -417,6 +438,9 @@ def __init__(
417438
f" are not supported for dataset_type '{dataset_type}'."
418439
f" Please check the API documentation for the supported dataset types."
419440
)
441+
# features JMESPath is required for JSON as we can't derive it ourselves
442+
if dataset_type == "application/json" and features is None:
443+
raise ValueError("features JMESPath is required for application/json dataset_type")
420444
self.s3_data_input_path = s3_data_input_path
421445
self.s3_output_path = s3_output_path
422446
self.s3_analysis_config_output_path = s3_analysis_config_output_path
@@ -571,11 +595,13 @@ def __init__(
571595
Cannot be set when ``endpoint_name`` is set.
572596
Must be set with ``instance_count``, ``model_name``
573597
accept_type (str): The model output format to be used for getting inferences with the
574-
shadow endpoint. Valid values are ``"text/csv"`` for CSV and
575-
``"application/jsonlines"``. Default is the same as ``content_type``.
598+
shadow endpoint. Valid values are ``"text/csv"`` for CSV,
599+
``"application/jsonlines"`` for JSON Lines, and ``"application/json"`` for JSON.
600+
Default is the same as ``content_type``.
576601
content_type (str): The model input format to be used for getting inferences with the
577602
shadow endpoint. Valid values are ``"text/csv"`` for CSV and
578-
``"application/jsonlines"``. Default is the same as ``dataset_format``.
603+
``"application/jsonlines"`` for JSON Lines. Default is the same as
604+
``dataset_format``.
579605
content_template (str): A template string to be used to construct the model input from
580606
dataset instances. It is only used when ``model_content_type`` is
581607
``"application/jsonlines"``. The template should have one and only one placeholder,
@@ -641,7 +667,7 @@ def __init__(
641667
)
642668
self.predictor_config["endpoint_name_prefix"] = endpoint_name_prefix
643669
if accept_type is not None:
644-
if accept_type not in ["text/csv", "application/jsonlines"]:
670+
if accept_type not in ["text/csv", "application/jsonlines", "application/json"]:
645671
raise ValueError(
646672
f"Invalid accept_type {accept_type}."
647673
f" Please choose text/csv, application/jsonlines or application/json."

tests/unit/test_clarify.py

Lines changed: 69 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -42,39 +42,57 @@ def test_uri():
4242
assert "306415355426.dkr.ecr.us-west-2.amazonaws.com/sagemaker-clarify-processing:1.0" == uri
4343

4444

45-
def test_data_config():
45+
@pytest.mark.parametrize(
46+
("dataset_type", "features", "excluded_columns", "predicted_label"),
47+
[
48+
("text/csv", None, ["F4"], "Predicted Label"),
49+
("application/jsonlines", None, ["F4"], "Predicted Label"),
50+
("application/json", "[*].[F1,F2,F3]", ["F4"], "Predicted Label"),
51+
("application/x-parquet", None, ["F4"], "Predicted Label"),
52+
],
53+
)
54+
def test_data_config(dataset_type, features, excluded_columns, predicted_label):
4655
# facets in input dataset
4756
s3_data_input_path = "s3://path/to/input.csv"
4857
s3_output_path = "s3://path/to/output"
4958
label_name = "Label"
50-
headers = [
51-
"Label",
52-
"F1",
53-
"F2",
54-
"F3",
55-
"F4",
56-
]
57-
dataset_type = "text/csv"
59+
headers = ["Label", "F1", "F2", "F3", "F4", "Predicted Label"]
5860
data_config = DataConfig(
5961
s3_data_input_path=s3_data_input_path,
6062
s3_output_path=s3_output_path,
63+
features=features,
6164
label=label_name,
6265
headers=headers,
6366
dataset_type=dataset_type,
67+
excluded_columns=excluded_columns,
68+
predicted_label=predicted_label,
6469
)
6570

6671
expected_config = {
67-
"dataset_type": "text/csv",
72+
"dataset_type": dataset_type,
6873
"headers": headers,
6974
"label": "Label",
7075
}
76+
if features:
77+
expected_config["features"] = features
78+
if excluded_columns:
79+
expected_config["excluded_columns"] = excluded_columns
80+
if predicted_label:
81+
expected_config["predicted_label"] = predicted_label
7182

7283
assert expected_config == data_config.get_config()
7384
assert s3_data_input_path == data_config.s3_data_input_path
7485
assert s3_output_path == data_config.s3_output_path
7586
assert "None" == data_config.s3_compression_type
7687
assert "FullyReplicated" == data_config.s3_data_distribution_type
7788

89+
90+
def test_data_config_with_separate_facet_dataset():
91+
s3_data_input_path = "s3://path/to/input.csv"
92+
s3_output_path = "s3://path/to/output"
93+
label_name = "Label"
94+
headers = ["Label", "F1", "F2", "F3", "F4"]
95+
7896
# facets NOT in input dataset
7997
joinsource = 5
8098
facet_dataset_uri = "s3://path/to/facet.csv"
@@ -89,7 +107,7 @@ def test_data_config():
89107
s3_output_path=s3_output_path,
90108
label=label_name,
91109
headers=headers,
92-
dataset_type=dataset_type,
110+
dataset_type="text/csv",
93111
joinsource=joinsource,
94112
facet_dataset_uri=facet_dataset_uri,
95113
facet_headers=facet_headers,
@@ -126,7 +144,7 @@ def test_data_config():
126144
s3_output_path=s3_output_path,
127145
label=label_name,
128146
headers=headers,
129-
dataset_type=dataset_type,
147+
dataset_type="text/csv",
130148
joinsource=joinsource,
131149
excluded_columns=excluded_columns,
132150
)
@@ -158,7 +176,7 @@ def test_invalid_data_config():
158176
DataConfig(
159177
s3_data_input_path="s3://bucket/inputpath",
160178
s3_output_path="s3://bucket/outputpath",
161-
dataset_type="application/x-parquet",
179+
dataset_type="application/x-image",
162180
predicted_label="label",
163181
)
164182
error_msg = r"^The parameter 'excluded_columns' is not supported for dataset_type"
@@ -189,6 +207,28 @@ def test_invalid_data_config():
189207
)
190208

191209

210+
# features JMESPath is required for JSON dataset types
211+
def test_json_type_data_config_missing_features():
212+
# facets in input dataset
213+
s3_data_input_path = "s3://path/to/input.csv"
214+
s3_output_path = "s3://path/to/output"
215+
label_name = "Label"
216+
headers = ["Label", "F1", "F2", "F3", "F4", "Predicted Label"]
217+
with pytest.raises(
218+
ValueError, match="features JMESPath is required for application/json dataset_type"
219+
):
220+
DataConfig(
221+
s3_data_input_path=s3_data_input_path,
222+
s3_output_path=s3_output_path,
223+
features=None,
224+
label=label_name,
225+
headers=headers,
226+
dataset_type="application/json",
227+
excluded_columns=["F4"],
228+
predicted_label="Predicted Label",
229+
)
230+
231+
192232
def test_s3_data_distribution_type_ignorance():
193233
data_config = DataConfig(
194234
s3_data_input_path="s3://input/train.csv",
@@ -344,12 +384,25 @@ def test_facet_of_bias_config(facet_name, facet_values_or_threshold, expected_re
344384
assert bias_config.get_config() == expected_config
345385

346386

347-
def test_model_config():
387+
@pytest.mark.parametrize(
388+
("content_type", "accept_type"),
389+
[
390+
# All the combinations of content_type and accept_type should be acceptable
391+
("text/csv", "text/csv"),
392+
("application/jsonlines", "application/jsonlines"),
393+
("text/csv", "application/json"),
394+
("application/jsonlines", "application/json"),
395+
("application/jsonlines", "text/csv"),
396+
("image/jpeg", "text/csv"),
397+
("image/jpg", "text/csv"),
398+
("image/png", "text/csv"),
399+
("application/x-npy", "text/csv"),
400+
],
401+
)
402+
def test_valid_model_config(content_type, accept_type):
348403
model_name = "xgboost-model"
349404
instance_type = "ml.c5.xlarge"
350405
instance_count = 1
351-
accept_type = "text/csv"
352-
content_type = "application/jsonlines"
353406
custom_attributes = "c000b4f9-df62-4c85-a0bf-7c525f9104a4"
354407
target_model = "target_model_name"
355408
accelerator_type = "ml.eia1.medium"

0 commit comments

Comments
 (0)