|
25 | 25 |
|
26 | 26 | import tempfile
|
27 | 27 | from abc import ABC, abstractmethod
|
28 |
| -from typing import List, Union, Dict, Optional, Any |
29 |
| - |
30 |
| -from schema import Schema, And, Use, Or, Optional, Regex |
| 28 | +from typing import List, Union, Dict, Any, Optional |
| 29 | + |
| 30 | +from schema import ( |
| 31 | + Schema, |
| 32 | + And, |
| 33 | + Use, |
| 34 | + Or, |
| 35 | + Optional as SchemaOptional, |
| 36 | + Regex |
| 37 | +) |
31 | 38 |
|
32 | 39 | import sagemaker
|
33 | 40 | from sagemaker import image_uris, s3, utils
|
|
43 | 50 |
|
44 | 51 | ANALYSIS_CONFIG_SCHEMA_V1_0 = Schema(
|
45 | 52 | {
|
46 |
| - Optional("version"): str, |
| 53 | + SchemaOptional("version"): str, |
47 | 54 | "dataset_type": And(
|
48 | 55 | str,
|
49 | 56 | Use(str.lower),
|
|
56 | 63 | "application/x-image",
|
57 | 64 | ),
|
58 | 65 | ),
|
59 |
| - Optional("dataset_uri"): str, |
60 |
| - Optional("headers"): [str], |
61 |
| - Optional("label"): Or(str, int), |
| 66 | + SchemaOptional("dataset_uri"): str, |
| 67 | + SchemaOptional("headers"): [str], |
| 68 | + SchemaOptional("label"): Or(str, int), |
62 | 69 | # this field indicates user provides predicted_label in dataset
|
63 |
| - Optional("predicted_label"): Or(str, int), |
64 |
| - Optional("features"): str, |
65 |
| - Optional("label_values_or_threshold"): [Or(int, float, str)], |
66 |
| - Optional("probability_threshold"): float, |
67 |
| - Optional("facet"): [ |
68 |
| - {"name_or_index": Or(str, int), Optional("value_or_threshold"): [Or(int, float, str)]} |
| 70 | + SchemaOptional("predicted_label"): Or(str, int), |
| 71 | + SchemaOptional("features"): str, |
| 72 | + SchemaOptional("label_values_or_threshold"): [Or(int, float, str)], |
| 73 | + SchemaOptional("probability_threshold"): float, |
| 74 | + SchemaOptional("facet"): [ |
| 75 | + {"name_or_index": Or(str, int), SchemaOptional("value_or_threshold"): [Or(int, float, str)]} |
69 | 76 | ],
|
70 |
| - Optional("facet_dataset_uri"): str, |
71 |
| - Optional("facet_headers"): [str], |
72 |
| - Optional("predicted_label_dataset_uri"): str, |
73 |
| - Optional("predicted_label_headers"): [str], |
74 |
| - Optional("excluded_columns"): [Or(int, str)], |
75 |
| - Optional("joinsource_name_or_index"): Or(str, int), |
76 |
| - Optional("group_variable"): Or(str, int), |
| 77 | + SchemaOptional("facet_dataset_uri"): str, |
| 78 | + SchemaOptional("facet_headers"): [str], |
| 79 | + SchemaOptional("predicted_label_dataset_uri"): str, |
| 80 | + SchemaOptional("predicted_label_headers"): [str], |
| 81 | + SchemaOptional("excluded_columns"): [Or(int, str)], |
| 82 | + SchemaOptional("joinsource_name_or_index"): Or(str, int), |
| 83 | + SchemaOptional("group_variable"): Or(str, int), |
77 | 84 | "methods": {
|
78 |
| - Optional("shap"): { |
79 |
| - Optional("baseline"): Or( |
| 85 | + SchemaOptional("shap"): { |
| 86 | + SchemaOptional("baseline"): Or( |
80 | 87 | # URI of the baseline data file
|
81 | 88 | str,
|
82 | 89 | # Inplace baseline data (a list of something)
|
|
93 | 100 | )
|
94 | 101 | ],
|
95 | 102 | ),
|
96 |
| - Optional("num_clusters"): int, |
97 |
| - Optional("use_logit"): bool, |
98 |
| - Optional("num_samples"): int, |
99 |
| - Optional("agg_method"): And( |
| 103 | + SchemaOptional("num_clusters"): int, |
| 104 | + SchemaOptional("use_logit"): bool, |
| 105 | + SchemaOptional("num_samples"): int, |
| 106 | + SchemaOptional("agg_method"): And( |
100 | 107 | str, Use(str.lower), lambda s: s in ("mean_abs", "median", "mean_sq")
|
101 | 108 | ),
|
102 |
| - Optional("save_local_shap_values"): bool, |
103 |
| - Optional("text_config"): { |
| 109 | + SchemaOptional("save_local_shap_values"): bool, |
| 110 | + SchemaOptional("text_config"): { |
104 | 111 | "granularity": And(
|
105 | 112 | str, Use(str.lower), lambda s: s in ("token", "sentence", "paragraph")
|
106 | 113 | ),
|
|
237 | 244 | "yo",
|
238 | 245 | ),
|
239 | 246 | ),
|
240 |
| - Optional("max_top_tokens"): int, |
| 247 | + SchemaOptional("max_top_tokens"): int, |
241 | 248 | },
|
242 |
| - Optional("image_config"): { |
243 |
| - Optional("num_segments"): int, |
244 |
| - Optional("segment_compactness"): int, |
245 |
| - Optional("feature_extraction_method"): str, |
246 |
| - Optional("model_type"): str, |
247 |
| - Optional("max_objects"): int, |
248 |
| - Optional("iou_threshold"): float, |
249 |
| - Optional("context"): float, |
250 |
| - Optional("debug"): { |
251 |
| - Optional("image_names"): [str], |
252 |
| - Optional("class_ids"): [int], |
253 |
| - Optional("sample_from"): int, |
254 |
| - Optional("sample_to"): int, |
| 249 | + SchemaOptional("image_config"): { |
| 250 | + SchemaOptional("num_segments"): int, |
| 251 | + SchemaOptional("segment_compactness"): int, |
| 252 | + SchemaOptional("feature_extraction_method"): str, |
| 253 | + SchemaOptional("model_type"): str, |
| 254 | + SchemaOptional("max_objects"): int, |
| 255 | + SchemaOptional("iou_threshold"): float, |
| 256 | + SchemaOptional("context"): float, |
| 257 | + SchemaOptional("debug"): { |
| 258 | + SchemaOptional("image_names"): [str], |
| 259 | + SchemaOptional("class_ids"): [int], |
| 260 | + SchemaOptional("sample_from"): int, |
| 261 | + SchemaOptional("sample_to"): int, |
255 | 262 | },
|
256 | 263 | },
|
257 |
| - Optional("seed"): int, |
| 264 | + SchemaOptional("seed"): int, |
258 | 265 | },
|
259 |
| - Optional("pre_training_bias"): {"methods": Or(str, [str])}, |
260 |
| - Optional("post_training_bias"): {"methods": Or(str, [str])}, |
261 |
| - Optional("pdp"): { |
| 266 | + SchemaOptional("pre_training_bias"): {"methods": Or(str, [str])}, |
| 267 | + SchemaOptional("post_training_bias"): {"methods": Or(str, [str])}, |
| 268 | + SchemaOptional("pdp"): { |
262 | 269 | "grid_resolution": int,
|
263 |
| - Optional("features"): [Or(str, int)], |
264 |
| - Optional("top_k_features"): int, |
| 270 | + SchemaOptional("features"): [Or(str, int)], |
| 271 | + SchemaOptional("top_k_features"): int, |
265 | 272 | },
|
266 |
| - Optional("report"): {"name": str, Optional("title"): str}, |
| 273 | + SchemaOptional("report"): {"name": str, SchemaOptional("title"): str}, |
267 | 274 | },
|
268 |
| - Optional("predictor"): { |
269 |
| - Optional("endpoint_name"): str, |
270 |
| - Optional("endpoint_name_prefix"): And(str, Regex(ENDPOINT_NAME_PREFIX_PATTERN)), |
271 |
| - Optional("model_name"): str, |
272 |
| - Optional("target_model"): str, |
273 |
| - Optional("instance_type"): str, |
274 |
| - Optional("initial_instance_count"): int, |
275 |
| - Optional("accelerator_type"): str, |
276 |
| - Optional("content_type"): And( |
| 275 | + SchemaOptional("predictor"): { |
| 276 | + SchemaOptional("endpoint_name"): str, |
| 277 | + SchemaOptional("endpoint_name_prefix"): And(str, Regex(ENDPOINT_NAME_PREFIX_PATTERN)), |
| 278 | + SchemaOptional("model_name"): str, |
| 279 | + SchemaOptional("target_model"): str, |
| 280 | + SchemaOptional("instance_type"): str, |
| 281 | + SchemaOptional("initial_instance_count"): int, |
| 282 | + SchemaOptional("accelerator_type"): str, |
| 283 | + SchemaOptional("content_type"): And( |
277 | 284 | str,
|
278 | 285 | Use(str.lower),
|
279 | 286 | lambda s: s
|
|
286 | 293 | "application/x-npy",
|
287 | 294 | ),
|
288 | 295 | ),
|
289 |
| - Optional("accept_type"): And( |
| 296 | + SchemaOptional("accept_type"): And( |
290 | 297 | str,
|
291 | 298 | Use(str.lower),
|
292 | 299 | lambda s: s in ("text/csv", "application/jsonlines", "application/json"),
|
293 | 300 | ),
|
294 |
| - Optional("label"): Or(str, int), |
295 |
| - Optional("probability"): Or(str, int), |
296 |
| - Optional("label_headers"): [Or(str, int)], |
297 |
| - Optional("content_template"): Or(str, {str: str}), |
298 |
| - Optional("custom_attributes"): str, |
| 301 | + SchemaOptional("label"): Or(str, int), |
| 302 | + SchemaOptional("probability"): Or(str, int), |
| 303 | + SchemaOptional("label_headers"): [Or(str, int)], |
| 304 | + SchemaOptional("content_template"): Or(str, {str: str}), |
| 305 | + SchemaOptional("custom_attributes"): str, |
299 | 306 | },
|
300 | 307 | }
|
301 | 308 | )
|
|
0 commit comments