Skip to content

Commit 640dbd1

Browse files
feat: add ITEM_RECORDS as a supported dataset format
change: remove ``headers`` as a requirement for time series; doc: add example dataset formats to ``TimeSeriesJSONDatasetFormat``
1 parent 9e80a48 commit 640dbd1

File tree

1 file changed

+93
-3
lines changed

1 file changed

+93
-3
lines changed

src/sagemaker/clarify.py

Lines changed: 93 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -389,9 +389,95 @@ class DatasetType(Enum):
389389

390390

391391
class TimeSeriesJSONDatasetFormat(Enum):
392-
"""Possible dataset formats for JSON time series data files."""
392+
"""Possible dataset formats for JSON time series data files.
393+
394+
Below is an example ``COLUMNS`` dataset for time series explainability:
395+
396+
```
397+
{
398+
"ids": [1, 2],
399+
"timestamps": [3, 4],
400+
"target_ts": [5, 6],
401+
"rts1": [0.25, 0.5],
402+
"rts2": [1.25, 1.5],
403+
"scv1": [10, 20],
404+
"scv2": [30, 40]
405+
}
406+
407+
```
408+
409+
For this example, JMESPaths are specified when creating ``TimeSeriesDataConfig`` as follows:
410+
411+
```
412+
item_id="ids"
413+
timestamp="timestamps"
414+
target_time_series="target_ts"
415+
related_time_series=["rts1", "rts2"]
416+
static_covariates=["scv1", "scv2"]
417+
```
418+
419+
Below is an example ``ITEM_RECORDS`` dataset for time series explainability:
420+
421+
```
422+
[
423+
{
424+
"id": 1,
425+
"scv1": 10,
426+
"scv2": "red",
427+
"timeseries": [
428+
{"timestamp": 1, "target_ts": 5, "rts1": 0.25, "rts2": 10},
429+
{"timestamp": 2, "target_ts": 6, "rts1": 0.35, "rts2": 20},
430+
{"timestamp": 3, "target_ts": 4, "rts1": 0.45, "rts2": 30}
431+
]
432+
},
433+
{
434+
"id": 2,
435+
"scv1": 20,
436+
"scv2": "blue",
437+
"timeseries": [
438+
{"timestamp": 1, "target_ts": 4, "rts1": 0.25, "rts2": 40},
439+
{"timestamp": 2, "target_ts": 2, "rts1": 0.35, "rts2": 50}
440+
]
441+
}
442+
]
443+
```
444+
445+
For this example, JMESPaths are specified when creating ``TimeSeriesDataConfig`` as follows:
446+
447+
```
448+
item_id="[*].id"
449+
timestamp="[*].timeseries[].timestamp"
450+
target_time_series="[*].timeseries[].target_ts"
451+
related_time_series=["[*].timeseries[].rts1", "[*].timeseries[].rts2"]
452+
static_covariates=["[*].scv1", "[*].scv2"]
453+
```
454+
455+
Below is an example ``TIMESTAMP_RECORDS`` dataset for time series explainability:
456+
457+
```
458+
[
459+
{"id": 1, "timestamp": 1, "target_ts": 5, "scv1": 10, "rts1": 0.25},
460+
{"id": 1, "timestamp": 2, "target_ts": 6, "scv1": 10, "rts1": 0.5},
461+
{"id": 1, "timestamp": 3, "target_ts": 3, "scv1": 10, "rts1": 0.75},
462+
{"id": 2, "timestamp": 5, "target_ts": 10, "scv1": 20, "rts1": 1}
463+
]
464+
465+
```
466+
467+
For this example, JMESPaths are specified when creating ``TimeSeriesDataConfig`` as follows:
468+
469+
```
470+
item_id="[*].id"
471+
timestamp="[*].timestamp"
472+
target_time_series="[*].target_ts"
473+
related_time_series=["[*].rts1"]
474+
static_covariates=["[*].scv1"]
475+
```
476+
477+
"""
393478

394479
COLUMNS = "columns"
480+
ITEM_RECORDS = "item_records"
395481
TIMESTAMP_RECORDS = "timestamp_records"
396482

397483

@@ -607,6 +693,11 @@ def __init__(
607693
Note: For JSON, the JMESPath query must result in a list of labels for each
608694
sample. For JSON Lines, it must result in the label for each line.
609695
Only a single label per sample is supported at this time.
696+
headers ([str]): List of column names in the dataset. If not provided, Clarify will
697+
generate headers to use internally. For time series explainability cases,
698+
please provide headers in the following order:
699+
item_id, timestamp, target_time_series, all related_time_series columns,
700+
all static_covariates columns.
610701
features (str): JMESPath expression to locate the feature values
611702
if the dataset format is JSON/JSON Lines.
612703
Note: For JSON, the JMESPath query must result in a 2-D list (or a matrix) of
@@ -716,9 +807,8 @@ def __init__(
716807
if time_series_data_config:
717808
if dataset_type != "application/json":
718809
raise ValueError(
719-
"Currently time series explainability only supports JSON format data"
810+
"Currently time series explainability only supports JSON format data."
720811
)
721-
assert headers, "Headers are required for time series explainability"
722812
# features JMESPath is required for JSON as we can't derive it ourselves
723813
if dataset_type == "application/json" and features is None and not time_series_data_config:
724814
raise ValueError("features JMESPath is required for application/json dataset_type")

0 commit comments

Comments
 (0)