102
102
"item_id" : Or (str , int ),
103
103
"timestamp" : Or (str , int ),
104
104
SchemaOptional ("related_time_series" ): Or ([str ], [int ]),
105
- SchemaOptional ("item_metadata" ): Or ([str ], [int ]),
105
+ SchemaOptional ("static_covariates" ): Or ([str ], [int ]),
106
+ SchemaOptional ("dataset_format" ): And (
107
+ str ,
108
+ Use (str .lower ),
109
+ lambda s : s
110
+ in (
111
+ "columns" ,
112
+ "timestamp_records" ,
113
+ ),
114
+ ),
106
115
},
107
116
"methods" : {
108
117
SchemaOptional ("shap" ): {
@@ -370,6 +379,13 @@ class DatasetType(Enum):
370
379
IMAGE = "application/x-image"
371
380
372
381
382
+ class TimeSeriesJSONDatasetFormat (Enum ):
383
+ """Possible dataset formats for JSON time series data files."""
384
+
385
+ COLUMNS = "columns"
386
+ TIMESTAMP_RECORDS = "timestamp_records"
387
+
388
+
373
389
class SegmentationConfig :
374
390
"""Config object that defines segment(s) of the dataset on which metrics are computed."""
375
391
@@ -447,26 +463,31 @@ def __init__(
447
463
item_id : Union [str , int ],
448
464
timestamp : Union [str , int ],
449
465
related_time_series : Optional [List [Union [str , int ]]] = None ,
450
- item_metadata : Optional [List [Union [str , int ]]] = None ,
466
+ static_covariates : Optional [List [Union [str , int ]]] = None ,
467
+ dataset_format : Optional [TimeSeriesJSONDatasetFormat ] = None ,
451
468
):
452
469
"""Initialises TimeSeries explainability data configuration fields.
453
470
454
471
Args:
455
472
target_time_series (str or int): A string or a zero-based integer index.
456
473
Used to locate the target time series in the shared input dataset.
457
- If this parameter is a string, then all other parameters must also
458
- be strings or lists of strings. If this parameter is an int, then
459
- all others must be ints or lists of ints.
474
+ If this parameter is a string, then all other parameters except
475
+ `dataset_format` must be strings or lists of strings. If
476
+ this parameter is an int, then all other parameters except
477
+ `dataset_format` must be ints or lists of ints.
460
478
item_id (str or int): A string or a zero-based integer index. Used to
461
479
locate item id in the shared input dataset.
462
480
timestamp (str or int): A string or a zero-based integer index. Used to
463
481
locate timestamp in the shared input dataset.
464
482
related_time_series (list[str] or list[int]): Optional. An array of strings
465
483
or array of zero-based integer indices. Used to locate all related time
466
484
series in the shared input dataset (if present).
467
- item_metadata (list[str] or list[int]): Optional. An array of strings or
468
- array of zero-based integer indices. Used to locate all item metadata
485
+ static_covariates (list[str] or list[int]): Optional. An array of strings or
486
+ array of zero-based integer indices. Used to locate all static covariate
469
487
fields in the shared input dataset (if present).
488
+ dataset_format (TimeSeriesJSONDatasetFormat): Describes the format
489
+ of the data files provided for analysis. Should only be provided
490
+ when the dataset is in JSON format.
470
491
471
492
Raises:
472
493
AssertionError: If any required arguments are not provided.
@@ -484,7 +505,7 @@ def __init__(
484
505
raise ValueError (f"Please provide { params_type } for ``item_id``" )
485
506
if not isinstance (timestamp , params_type ):
486
507
raise ValueError (f"Please provide { params_type } for ``timestamp``" )
487
- # add remaining fields to an internal dictionary
508
+ # add mandatory fields to an internal dictionary
488
509
self .time_series_data_config = dict ()
489
510
_set (target_time_series , "target_time_series" , self .time_series_data_config )
490
511
_set (item_id , "item_id" , self .time_series_data_config )
@@ -502,22 +523,32 @@ def __init__(
502
523
raise ValueError (
503
524
related_time_series_error_message
504
525
) # related_time_series is not a list of strings or list of ints
526
+ if params_type == str and not all (related_time_series ):
527
+ raise ValueError ("Please do not provide empty strings in ``related_time_series``." )
505
528
_set (
506
529
related_time_series , "related_time_series" , self .time_series_data_config
507
530
) # related_time_series is valid, add it
508
- item_metadata_series_error_message = (
509
- f"Please provide a list of { params_type } for ``item_metadata ``"
531
+ static_covariates_series_error_message = (
532
+ f"Please provide a list of { params_type } for ``static_covariates ``"
510
533
)
511
- if item_metadata :
512
- if not isinstance (item_metadata , list ):
513
- raise ValueError (item_metadata_series_error_message ) # item_metadata is not a list
514
- if not all ([isinstance (value , params_type ) for value in item_metadata ]):
534
+ if static_covariates :
535
+ if not isinstance (static_covariates , list ):
536
+ raise ValueError (static_covariates_series_error_message ) # static_covariates is not a list
537
+ if not all ([isinstance (value , params_type ) for value in static_covariates ]):
515
538
raise ValueError (
516
- item_metadata_series_error_message
517
- ) # item_metadata is not a list of strings or list of ints
539
+ static_covariates_series_error_message
540
+ ) # static_covariates is not a list of strings or list of ints
541
+ if params_type == str and not all (static_covariates ):
542
+ raise ValueError ("Please do not provide empty strings in ``static_covariates``." )
518
543
_set (
519
- item_metadata , "item_metadata" , self .time_series_data_config
520
- ) # item_metadata is valid, add it
544
+ static_covariates , "static_covariates" , self .time_series_data_config
545
+ ) # static_covariates is valid, add it
546
+ if params_type == str :
547
+ # check dataset_format is provided and valid
548
+ assert isinstance (dataset_format , TimeSeriesJSONDatasetFormat ), "Please provide a valid dataset format."
549
+ _set (dataset_format .value , "dataset_format" , self .time_series_data_config )
550
+ else :
551
+ assert not dataset_format , "Dataset format should only be provided when data files are JSONs."
521
552
522
553
def get_time_series_data_config (self ):
523
554
"""Returns part of an analysis config dictionary."""
@@ -666,8 +697,11 @@ def __init__(
666
697
f" are not supported for dataset_type '{ dataset_type } '."
667
698
f" Please check the API documentation for the supported dataset types."
668
699
)
700
+ # reject any dataset type other than JSON in the time series case
701
+ if time_series_data_config and dataset_type != "application/json" :
702
+ raise ValueError ("Currently time series explainability only supports JSON format data" )
669
703
# features JMESPath is required for JSON as we can't derive it ourselves
670
- if dataset_type == "application/json" and features is None :
704
+ if dataset_type == "application/json" and features is None and not time_series_data_config :
671
705
raise ValueError ("features JMESPath is required for application/json dataset_type" )
672
706
self .s3_data_input_path = s3_data_input_path
673
707
self .s3_output_path = s3_output_path
0 commit comments