Skip to content

Commit 311f486

Browse files
fix: replace all added asserts with ValueError
1 parent 222fb75 commit 311f486

File tree

2 files changed

+100
-104
lines changed

2 files changed

+100
-104
lines changed

src/sagemaker/clarify.py

Lines changed: 67 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -587,13 +587,15 @@ def __init__(
587587
when dataset is in JSON format.
588588
589589
Raises:
590-
AssertionError: If any required arguments are not provided.
591-
ValueError: If any provided arguments are the wrong type.
590+
ValueError: If any required arguments are not provided or are the wrong type.
592591
"""
593592
# check target_time_series, item_id, and timestamp are provided
594-
assert target_time_series, "Please provide a target time series."
595-
assert item_id, "Please provide an item id."
596-
assert timestamp, "Please provide a timestamp."
593+
if not target_time_series:
594+
raise ValueError("Please provide a target time series.")
595+
if not item_id:
596+
raise ValueError("Please provide an item id.")
597+
if not timestamp:
598+
raise ValueError("Please provide a timestamp.")
597599
# check all arguments are the right types
598600
if not isinstance(target_time_series, (str, int)):
599601
raise ValueError("Please provide a string or an int for ``target_time_series``")
@@ -644,14 +646,14 @@ def __init__(
644646
) # static_covariates is valid, add it
645647
if params_type == str:
646648
# check dataset_format is provided and valid
647-
assert isinstance(
648-
dataset_format, TimeSeriesJSONDatasetFormat
649-
), "Please provide a valid dataset format."
649+
if not isinstance(dataset_format, TimeSeriesJSONDatasetFormat):
650+
raise ValueError("Please provide a valid dataset format.")
650651
_set(dataset_format.value, "dataset_format", self.time_series_data_config)
651652
else:
652-
assert (
653-
not dataset_format
654-
), "Dataset format should only be provided when data files are JSONs."
653+
if dataset_format:
654+
raise ValueError(
655+
"Dataset format should only be provided when data files are JSONs."
656+
)
655657

656658
def get_time_series_data_config(self):
657659
"""Returns part of an analysis config dictionary."""
@@ -960,16 +962,14 @@ def __init__(
960962
forecast (str): JMESPath expression to extract the forecast result.
961963
962964
Raises:
963-
AssertionError: when ``forecast`` is not provided
964-
ValueError: when any provided argument are not of specified type
965+
ValueError: when ``forecast`` is not a string or not provided
965966
"""
966-
# assert forecast is provided
967-
assert (
968-
forecast
969-
), "Please provide ``forecast``, a JMESPath expression to extract the forecast result."
970-
# check provided arguments are of the right type
967+
# check string forecast is provided
971968
if not isinstance(forecast, str):
972-
raise ValueError("Please provide a string JMESPath expression for ``forecast``.")
969+
raise ValueError(
970+
"Please provide a string JMESPath expression for ``forecast`` "
971+
"to extract the forecast result."
972+
)
973973
# add fields to an internal config dictionary
974974
self.time_series_model_config = dict()
975975
_set(forecast, "forecast", self.time_series_model_config)
@@ -1796,45 +1796,49 @@ def __init__(
17961796
}
17971797
17981798
Raises:
1799-
AssertionError: when ``direction`` or ``granularity`` are not valid,
1800-
or ``num_samples`` is not provided for fine-grained explanations
1801-
ValueError: when ``num_samples`` is provided for non fine-grained explanations, or
1802-
when direction is not ``"chronological"`` when granularity is
1803-
``"fine_grained"``.
1799+
ValueError: when ``direction`` or ``granularity`` are not valid, ``num_samples`` is not
1800+
provided for fine-grained explanations, ``num_samples`` is provided for non
1801+
fine-grained explanations, or when ``direction`` is not ``"chronological"`` while
1802+
``granularity`` is ``"fine_grained"``.
18041803
"""
18051804
self.asymmetric_shapley_value_config = dict()
18061805
# validate explanation direction
1807-
assert (
1808-
direction in ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS
1809-
), "Please provide a valid explanation direction from: " + ", ".join(
1810-
ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS
1811-
)
1806+
if direction not in ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS:
1807+
raise ValueError(
1808+
"Please provide a valid explanation direction from: "
1809+
+ ", ".join(ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS)
1810+
)
18121811
# validate granularity
1813-
assert (
1814-
granularity in ASYM_SHAP_VAL_GRANULARITIES
1815-
), "Please provide a valid granularity from: " + ", ".join(ASYM_SHAP_VAL_GRANULARITIES)
1812+
if granularity not in ASYM_SHAP_VAL_GRANULARITIES:
1813+
raise ValueError(
1814+
"Please provide a valid granularity from: " + ", ".join(ASYM_SHAP_VAL_GRANULARITIES)
1815+
)
18161816
if granularity == "fine_grained":
1817-
assert isinstance(num_samples, int), "Please provide an integer for ``num_samples``."
1818-
assert (
1819-
direction == "chronological"
1820-
), f"{direction} and {granularity} granularity are not supported together."
1817+
if not isinstance(num_samples, int):
1818+
raise ValueError("Please provide an integer for ``num_samples``.")
1819+
if direction != "chronological":
1820+
raise ValueError(
1821+
f"{direction} and {granularity} granularity are not supported together."
1822+
)
18211823
elif num_samples: # validate num_samples is not provided when unnecessary
18221824
raise ValueError("``num_samples`` is only used for fine-grained explanations.")
18231825
# validate baseline if provided as a dictionary
18241826
if isinstance(baseline, dict):
18251827
temporal_baselines = ["zero", "mean"] # possible baseline options for temporal fields
18261828
if "target_time_series" in baseline:
18271829
target_baseline = baseline.get("target_time_series")
1828-
assert target_baseline in temporal_baselines, (
1829-
f"Provided value {target_baseline} for ``target_time_series`` is "
1830-
f"invalid. Please select one of {temporal_baselines}."
1831-
)
1830+
if target_baseline not in temporal_baselines:
1831+
raise ValueError(
1832+
f"Provided value {target_baseline} for ``target_time_series`` is "
1833+
f"invalid. Please select one of {temporal_baselines}."
1834+
)
18321835
if "related_time_series" in baseline:
18331836
related_baseline = baseline.get("related_time_series")
1834-
assert related_baseline in temporal_baselines, (
1835-
f"Provided value {related_baseline} for ``related_time_series`` is "
1836-
f"invalid. Please select one of {temporal_baselines}."
1837-
)
1837+
if related_baseline not in temporal_baselines:
1838+
raise ValueError(
1839+
f"Provided value {related_baseline} for ``related_time_series`` is "
1840+
f"invalid. Please select one of {temporal_baselines}."
1841+
)
18381842
# set explanation type and (if provided) num_samples in internal config dictionary
18391843
_set(direction, "direction", self.asymmetric_shapley_value_config)
18401844
_set(granularity, "granularity", self.asymmetric_shapley_value_config)
@@ -2550,25 +2554,27 @@ def explainability(
25502554
"""Generates a config for Explainability"""
25512555
# determine if this is a time series explainability case by checking
25522556
# if *both* TimeSeriesDataConfig and TimeSeriesModelConfig were given
2553-
ts_data_config_present = "time_series_data_config" in data_config.analysis_config
2554-
ts_model_config_present = "time_series_predictor_config" in model_config.predictor_config
2557+
ts_data_conf_absent = "time_series_data_config" not in data_config.analysis_config
2558+
ts_model_conf_absent = "time_series_predictor_config" not in model_config.predictor_config
25552559

25562560
if isinstance(explainability_config, AsymmetricShapleyValueConfig):
2557-
assert ts_data_config_present, "Please provide a TimeSeriesDataConfig to DataConfig."
2558-
assert ts_model_config_present, "Please provide a TimeSeriesModelConfig to ModelConfig."
2561+
if ts_data_conf_absent:
2562+
raise ValueError("Please provide a TimeSeriesDataConfig to DataConfig.")
2563+
if ts_model_conf_absent:
2564+
raise ValueError("Please provide a TimeSeriesModelConfig to ModelConfig.")
25592565
# Check static covariates baseline matches number of provided static covariate columns
25602566
_AnalysisConfigGenerator._validate_time_series_static_covariates_baseline(
25612567
explainability_config=explainability_config,
25622568
data_config=data_config,
25632569
)
25642570
else:
2565-
if ts_data_config_present:
2571+
if not ts_data_conf_absent:
25662572
raise ValueError(
25672573
"Please provide an AsymmetricShapleyValueConfig for time series "
25682574
"explainability cases. For non time series cases, please do not provide a "
25692575
"TimeSeriesDataConfig."
25702576
)
2571-
if ts_model_config_present:
2577+
if not ts_model_conf_absent:
25722578
raise ValueError(
25732579
"Please provide an AsymmetricShapleyValueConfig for time series "
25742580
"explainability cases. For non time series cases, please do not provide a "
@@ -2786,15 +2792,17 @@ def _validate_time_series_static_covariates_baseline(
27862792
if covariate_count > 0:
27872793
for item_id in baseline.get("static_covariates", []):
27882794
baseline_entry = baseline["static_covariates"][item_id]
2789-
assert isinstance(baseline_entry, list), (
2790-
f"Baseline entry for {item_id} must be a list, is "
2791-
f"{type(baseline_entry)}."
2792-
)
2793-
assert len(baseline_entry) == covariate_count, (
2794-
f"Length of baseline entry for {item_id} does not match number "
2795-
f"of static covariate columns. Please ensure every covariate "
2796-
f"has a baseline value for every item id."
2797-
)
2795+
if not isinstance(baseline_entry, list):
2796+
raise ValueError(
2797+
f"Baseline entry for {item_id} must be a list, is "
2798+
f"{type(baseline_entry)}."
2799+
)
2800+
if len(baseline_entry) != covariate_count:
2801+
raise ValueError(
2802+
f"Length of baseline entry for {item_id} does not match number "
2803+
f"of static covariate columns. Please ensure every covariate "
2804+
f"has a baseline value for every item id."
2805+
)
27982806
else:
27992807
raise ValueError(
28002808
"Static covariate baselines are provided in AsymmetricShapleyValueConfig "

0 commit comments

Comments
 (0)