Skip to content

Commit bc948e5

Browse files
authored
feature: Add segment config for Clarify (#3923)
1 parent de7e204 commit bc948e5

File tree

2 files changed

+164
-0
lines changed

2 files changed

+164
-0
lines changed

src/sagemaker/clarify.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,14 @@
6363
SchemaOptional("features"): str,
6464
SchemaOptional("label_values_or_threshold"): [Or(int, float, str)],
6565
SchemaOptional("probability_threshold"): float,
66+
SchemaOptional("segment_config"): [
67+
{
68+
SchemaOptional("config_name"): str,
69+
"name_or_index": Or(str, int),
70+
"segments": [[Or(str, int)]],
71+
SchemaOptional("display_aliases"): [str],
72+
}
73+
],
6674
SchemaOptional("facet"): [
6775
{
6876
"name_or_index": Or(str, int),
@@ -316,6 +324,74 @@ class DatasetType(Enum):
316324
IMAGE = "application/x-image"
317325

318326

327+
class SegmentationConfig:
    """Config object that defines segment(s) of the dataset on which metrics are computed."""

    def __init__(
        self,
        name_or_index: Union[str, int],
        segments: List[List[Union[str, int]]],
        config_name: Optional[str] = None,
        display_aliases: Optional[List[str]] = None,
    ):
        """Initializes a segmentation configuration for a dataset column.

        Args:
            name_or_index (str or int): The name or index of the column in the dataset on which
                the segment(s) is defined.
            segments (List[List[str or int]]): Each List of values represents one segment. If N
                Lists are provided, we generate N+1 segments - the additional segment, denoted as
                the '__default__' segment, is for the rest of the values that are not covered by
                these lists. For continuous columns, a segment must be given as strings in interval
                notation (e.g.: ["[1, 4]"] or ["(2, 5]"]). A segment can also be composed of
                multiple intervals (e.g.: ["[1, 4]", "(5, 6]"] is one segment). For categorical
                columns, each segment should contain one or more of the categorical values for
                the categorical column, which may be strings or integers.
                E.g.: For a continuous column, ``segments`` could be
                [["[1, 4]", "(5, 6]"], ["(7, 9)"]] - this generates 3 segments including the
                default segment. For a categorical column with values ("A", "B", "C", "D"),
                ``segments`` could be [["A", "B"]]. This generates 2 segments, including the
                default segment.
            config_name (str): Optional name for the segment config to identify the config.
            display_aliases (List[str]): Optional list of display names for the ``segments`` for
                the analysis output and report. This list should be the same length as the number
                of lists provided in ``segments`` or with one additional display alias for the
                default segment.

        Raises:
            ValueError: when ``name_or_index`` is None, ``segments`` is invalid,
                ``display_aliases`` is not a list, or a wrong number of ``display_aliases``
                are specified.
        """
        if name_or_index is None:
            raise ValueError("`name_or_index` cannot be None")
        self.name_or_index = name_or_index

        # ``segments`` must be a non-empty list whose every element is itself a list.
        if (
            not segments
            or not isinstance(segments, list)
            or not all(isinstance(segment, list) for segment in segments)
        ):
            raise ValueError("`segments` must be a list of lists of values or intervals.")
        self.segments = segments
        self.config_name = config_name

        if display_aliases is not None:
            # Check the container type first, mirroring the ``segments`` check: otherwise a
            # non-sized value would surface as a TypeError from len(), and a bare string would
            # be counted character by character.
            if not isinstance(display_aliases, list):
                raise ValueError("`display_aliases` must be a list of strings.")
            # One alias per explicit segment, optionally plus one for the default segment.
            if len(display_aliases) not in (len(segments), len(segments) + 1):
                raise ValueError(
                    "Number of `display_aliases` must equal the number of segments"
                    " specified or with one additional default segment display alias."
                )
        self.display_aliases = display_aliases

    def to_dict(self) -> Dict[str, Any]:  # pragma: no cover
        """Returns SegmentationConfig as a dict.

        ``name_or_index`` and ``segments`` are always present; ``config_name`` and
        ``display_aliases`` are included only when they are set to a truthy value.
        """
        segment_config_dict = {"name_or_index": self.name_or_index, "segments": self.segments}
        if self.config_name:
            segment_config_dict["config_name"] = self.config_name
        if self.display_aliases:
            segment_config_dict["display_aliases"] = self.display_aliases
        return segment_config_dict
393+
394+
319395
class DataConfig:
320396
"""Config object related to configurations of the input and output dataset."""
321397

@@ -336,6 +412,7 @@ def __init__(
336412
predicted_label_headers: Optional[List[str]] = None,
337413
predicted_label: Optional[Union[str, int]] = None,
338414
excluded_columns: Optional[Union[List[int], List[str]]] = None,
415+
segmentation_config: Optional[List[SegmentationConfig]] = None,
339416
):
340417
"""Initializes a configuration of both input and output datasets.
341418
@@ -402,6 +479,8 @@ def __init__(
402479
Only a single predicted label per sample is supported at this time.
403480
excluded_columns (list[int] or list[str]): A list of names or indices of the columns
404481
which are to be excluded from making model inference API calls.
482+
segmentation_config (list[SegmentationConfig]): A list of ``SegmentationConfig``
483+
objects.
405484
406485
Raises:
407486
ValueError: when the ``dataset_type`` is invalid, predicted label dataset parameters
@@ -469,6 +548,7 @@ def __init__(
469548
self.predicted_label_headers = predicted_label_headers
470549
self.predicted_label = predicted_label
471550
self.excluded_columns = excluded_columns
551+
self.segmentation_configs = segmentation_config
472552
self.analysis_config = {
473553
"dataset_type": dataset_type,
474554
}
@@ -486,6 +566,12 @@ def __init__(
486566
_set(predicted_label_headers, "predicted_label_headers", self.analysis_config)
487567
_set(predicted_label, "predicted_label", self.analysis_config)
488568
_set(excluded_columns, "excluded_columns", self.analysis_config)
569+
if segmentation_config:
570+
_set(
571+
[item.to_dict() for item in segmentation_config],
572+
"segment_config",
573+
self.analysis_config,
574+
)
489575

490576
def get_config(self):
491577
"""Returns part of an analysis config dictionary."""

tests/unit/test_clarify.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
_AnalysisConfigGenerator,
3333
DatasetType,
3434
ProcessingOutputHandler,
35+
SegmentationConfig,
3536
)
3637

3738
JOB_NAME_PREFIX = "my-prefix"
@@ -59,6 +60,15 @@ def test_data_config(dataset_type, features, excluded_columns, predicted_label):
5960
s3_output_path = "s3://path/to/output"
6061
label_name = "Label"
6162
headers = ["Label", "F1", "F2", "F3", "F4", "Predicted Label"]
63+
segment_config = [
64+
SegmentationConfig(
65+
name_or_index="F1",
66+
segments=[[0]],
67+
config_name="c1",
68+
display_aliases=["a1"],
69+
)
70+
]
71+
6272
data_config = DataConfig(
6373
s3_data_input_path=s3_data_input_path,
6474
s3_output_path=s3_output_path,
@@ -68,12 +78,21 @@ def test_data_config(dataset_type, features, excluded_columns, predicted_label):
6878
dataset_type=dataset_type,
6979
excluded_columns=excluded_columns,
7080
predicted_label=predicted_label,
81+
segmentation_config=segment_config,
7182
)
7283

7384
expected_config = {
7485
"dataset_type": dataset_type,
7586
"headers": headers,
7687
"label": "Label",
88+
"segment_config": [
89+
{
90+
"config_name": "c1",
91+
"display_aliases": ["a1"],
92+
"name_or_index": "F1",
93+
"segments": [[0]],
94+
}
95+
],
7796
}
7897
if features:
7998
expected_config["features"] = features
@@ -209,6 +228,65 @@ def test_invalid_data_config():
209228
)
210229

211230

231+
@pytest.mark.parametrize(
    ("name_or_index", "segments", "config_name", "display_aliases"),
    [
        ("feature1", [[0]], None, None),
        ("feature1", [[0], ["[1, 3)", "(5, 10]"]], None, None),
        ("feature1", [[0], ["[1, 3)", "(5, 10]"]], "config1", None),
        ("feature1", [["A", "B"]], "config1", ["seg1"]),
        ("feature1", [["A", "B"]], "config1", ["seg1", "default_seg"]),
    ],
)
def test_segmentation_config(name_or_index, segments, config_name, display_aliases):
    """Valid arguments are stored unchanged on the resulting SegmentationConfig."""
    segmentation_config = SegmentationConfig(
        name_or_index=name_or_index,
        segments=segments,
        config_name=config_name,
        display_aliases=display_aliases,
    )

    assert segmentation_config.name_or_index == name_or_index
    assert segmentation_config.segments == segments
    # Compare unconditionally: guarding these on the stored attribute being truthy would let
    # a constructor that silently dropped the value pass the test.
    assert segmentation_config.config_name == config_name
    assert segmentation_config.display_aliases == display_aliases
255+
256+
257+
@pytest.mark.parametrize(
    ("name_or_index", "segments", "config_name", "display_aliases", "error_msg"),
    [
        (None, [[0]], "config1", None, "`name_or_index` cannot be None"),
        (
            "feature1",
            "0",
            "config1",
            ["seg1"],
            "`segments` must be a list of lists of values or intervals.",
        ),
        (
            "feature1",
            [[0]],
            "config1",
            ["seg1", "seg2", "seg3"],
            "Number of `display_aliases` must equal the number of segments specified or with one "
            "additional default segment display alias.",
        ),
    ],
)
def test_invalid_segmentation_config(
    name_or_index, segments, config_name, display_aliases, error_msg
):
    """Invalid arguments make SegmentationConfig raise ValueError with the expected message."""
    import re  # local import: only needed to escape the literal message below

    # ``match`` is interpreted as a regular expression, so escape the literal message;
    # otherwise characters such as "." act as wildcards instead of being matched exactly.
    with pytest.raises(ValueError, match=re.escape(error_msg)):
        SegmentationConfig(
            name_or_index=name_or_index,
            segments=segments,
            config_name=config_name,
            display_aliases=display_aliases,
        )
288+
289+
212290
# features JMESPath is required for JSON dataset types
213291
def test_json_type_data_config_missing_features():
214292
# facets in input dataset

0 commit comments

Comments
 (0)