Skip to content

Commit d0002bc

Browse files
change: rename item_metadata to static_covariates
change: add ``dataset_format`` as a parameter for time series cases
change: allow features JMESPaths to be None for time series cases
change: add validation to prevent non-JSON dataset formats for time series cases
test: update unit tests to reflect the above changes
1 parent 8003172 commit d0002bc

File tree

2 files changed

+184
-92
lines changed

2 files changed

+184
-92
lines changed

src/sagemaker/clarify.py

Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,16 @@
102102
"item_id": Or(str, int),
103103
"timestamp": Or(str, int),
104104
SchemaOptional("related_time_series"): Or([str], [int]),
105-
SchemaOptional("item_metadata"): Or([str], [int]),
105+
SchemaOptional("static_covariates"): Or([str], [int]),
106+
SchemaOptional("dataset_format"): And(
107+
str,
108+
Use(str.lower),
109+
lambda s: s
110+
in (
111+
"columns",
112+
"timestamp_records",
113+
),
114+
),
106115
},
107116
"methods": {
108117
SchemaOptional("shap"): {
@@ -370,6 +379,13 @@ class DatasetType(Enum):
370379
IMAGE = "application/x-image"
371380

372381

382+
class TimeSeriesJSONDatasetFormat(Enum):
383+
"""Possible dataset formats for JSON time series data files."""
384+
385+
COLUMNS = "columns"
386+
TIMESTAMP_RECORDS = "timestamp_records"
387+
388+
373389
class SegmentationConfig:
374390
"""Config object that defines segment(s) of the dataset on which metrics are computed."""
375391

@@ -447,26 +463,31 @@ def __init__(
447463
item_id: Union[str, int],
448464
timestamp: Union[str, int],
449465
related_time_series: Optional[List[Union[str, int]]] = None,
450-
item_metadata: Optional[List[Union[str, int]]] = None,
466+
static_covariates: Optional[List[Union[str, int]]] = None,
467+
dataset_format: Optional[TimeSeriesJSONDatasetFormat] = None,
451468
):
452469
"""Initialises TimeSeries explainability data configuration fields.
453470
454471
Args:
455472
target_time_series (str or int): A string or a zero-based integer index.
456473
Used to locate the target time series in the shared input dataset.
457-
If this parameter is a string, then all other parameters must also
458-
be strings or lists of strings. If this parameter is an int, then
459-
all others must be ints or lists of ints.
474+
If this parameter is a string, then all other parameters except
475+
`dataset_format` must be strings or lists of strings. If
476+
this parameter is an int, then all other parameters except
477+
`dataset_format` must be ints or lists of ints.
460478
item_id (str or int): A string or a zero-based integer index. Used to
461479
locate item id in the shared input dataset.
462480
timestamp (str or int): A string or a zero-based integer index. Used to
463481
locate timestamp in the shared input dataset.
464482
related_time_series (list[str] or list[int]): Optional. An array of strings
465483
or array of zero-based integer indices. Used to locate all related time
466484
series in the shared input dataset (if present).
467-
item_metadata (list[str] or list[int]): Optional. An array of strings or
468-
array of zero-based integer indices. Used to locate all item metadata
485+
static_covariates (list[str] or list[int]): Optional. An array of strings or
486+
array of zero-based integer indices. Used to locate all static covariate
469487
fields in the shared input dataset (if present).
488+
dataset_format (TimeSeriesJSONDatasetFormat): Describes the format
489+
of the data files provided for analysis. Should only be provided
490+
when dataset is in JSON format.
470491
471492
Raises:
472493
AssertionError: If any required arguments are not provided.
@@ -484,7 +505,7 @@ def __init__(
484505
raise ValueError(f"Please provide {params_type} for ``item_id``")
485506
if not isinstance(timestamp, params_type):
486507
raise ValueError(f"Please provide {params_type} for ``timestamp``")
487-
# add remaining fields to an internal dictionary
508+
# add mandatory fields to an internal dictionary
488509
self.time_series_data_config = dict()
489510
_set(target_time_series, "target_time_series", self.time_series_data_config)
490511
_set(item_id, "item_id", self.time_series_data_config)
@@ -502,22 +523,32 @@ def __init__(
502523
raise ValueError(
503524
related_time_series_error_message
504525
) # related_time_series is not a list of strings or list of ints
526+
if params_type == str and not all(related_time_series):
527+
raise ValueError("Please do not provide empty strings in ``related_time_series``.")
505528
_set(
506529
related_time_series, "related_time_series", self.time_series_data_config
507530
) # related_time_series is valid, add it
508-
item_metadata_series_error_message = (
509-
f"Please provide a list of {params_type} for ``item_metadata``"
531+
static_covariates_series_error_message = (
532+
f"Please provide a list of {params_type} for ``static_covariates``"
510533
)
511-
if item_metadata:
512-
if not isinstance(item_metadata, list):
513-
raise ValueError(item_metadata_series_error_message) # item_metadata is not a list
514-
if not all([isinstance(value, params_type) for value in item_metadata]):
534+
if static_covariates:
535+
if not isinstance(static_covariates, list):
536+
raise ValueError(static_covariates_series_error_message) # static_covariates is not a list
537+
if not all([isinstance(value, params_type) for value in static_covariates]):
515538
raise ValueError(
516-
item_metadata_series_error_message
517-
) # item_metadata is not a list of strings or list of ints
539+
static_covariates_series_error_message
540+
) # static_covariates is not a list of strings or list of ints
541+
if params_type == str and not all(static_covariates):
542+
raise ValueError("Please do not provide empty strings in ``static_covariates``.")
518543
_set(
519-
item_metadata, "item_metadata", self.time_series_data_config
520-
) # item_metadata is valid, add it
544+
static_covariates, "static_covariates", self.time_series_data_config
545+
) # static_covariates is valid, add it
546+
if params_type == str:
547+
# check dataset_format is provided and valid
548+
assert isinstance(dataset_format, TimeSeriesJSONDatasetFormat), "Please provide a valid dataset format."
549+
_set(dataset_format.value, "dataset_format", self.time_series_data_config)
550+
else:
551+
assert not dataset_format, "Dataset format should only be provided when data files are JSONs."
521552

522553
def get_time_series_data_config(self):
523554
"""Returns part of an analysis config dictionary."""
@@ -666,8 +697,11 @@ def __init__(
666697
f" are not supported for dataset_type '{dataset_type}'."
667698
f" Please check the API documentation for the supported dataset types."
668699
)
700+
# check if any other format other than JSON is provided for time series case
701+
if time_series_data_config and dataset_type != "application/json":
702+
raise ValueError("Currently time series explainability only supports JSON format data")
669703
# features JMESPath is required for JSON as we can't derive it ourselves
670-
if dataset_type == "application/json" and features is None:
704+
if dataset_type == "application/json" and features is None and not time_series_data_config:
671705
raise ValueError("features JMESPath is required for application/json dataset_type")
672706
self.s3_data_input_path = s3_data_input_path
673707
self.s3_output_path = s3_output_path

0 commit comments

Comments
 (0)