Skip to content

Commit 0bc4f41

Browse files
shrestha-bikash and Bikash Shrestha authored
feature: model registry integration to model cards to support model packages (#3933)
Co-authored-by: Bikash Shrestha <[email protected]>
1 parent c1b2465 commit 0bc4f41

File tree

9 files changed

+1811
-70
lines changed

9 files changed

+1811
-70
lines changed

src/sagemaker/model_card/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
EvaluationJob,
2929
AdditionalInformation,
3030
ModelCard,
31+
ModelPackage,
3132
)
3233

3334
from sagemaker.model_card.schema_constraints import ( # noqa: F401 # pylint: disable=unused-import

src/sagemaker/model_card/evaluation_metric_parsers.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class EvaluationMetricTypeEnum(str, Enum):
3131
MODEL_CARD_METRIC_SCHEMA = "Model Card Metric Schema"
3232
CLARIFY_BIAS = "Clarify Bias"
3333
CLARIFY_EXPLAINABILITY = "Clarify Explainability"
34+
MODEL_MONITOR_MODEL_QUALITY = "Model Monitor Model Quality"
3435
REGRESSION = "Model Monitor Model Quality Regression"
3536
BINARY_CLASSIFICATION = "Model Monitor Model Quality Binary Classification"
3637
MULTICLASS_CLASSIFICATION = "Model Monitor Model Quality Multiclass Classification"
@@ -138,6 +139,7 @@ def _parse(self, json_data: dict):
138139
[
139140
{"name": i["name"], "value": i["value"], "type": "number"}
140141
for i in item["metrics"]
142+
if i["value"] is not None
141143
]
142144
)
143145
for group_name, metric_data in group_data.items():
@@ -368,9 +370,10 @@ def _parse(self, json_data: dict):
368370
result = {"metric_groups": []}
369371
for group_name, group_data in json_data.items():
370372
metric_data = []
371-
for metric_name, raw_data in group_data.item():
372-
metric_data.extend(self._parse_basic_metric(metric_name, raw_data))
373-
result["metric_groups"].append({"name": group_name, "metric_data": metric_data})
373+
if group_name == "regression_metrics":
374+
for metric_name, raw_data in group_data.items():
375+
metric_data.extend(self._parse_basic_metric(metric_name, raw_data))
376+
result["metric_groups"].append({"name": group_name, "metric_data": metric_data})
374377
return result
375378

376379

@@ -388,7 +391,7 @@ def _validate(self, json_data: dict):
388391
"""
389392
if (
390393
"binary_classification_metrics" not in json_data
391-
and "multiclass_classification_metrics" in json_data
394+
and "multiclass_classification_metrics" not in json_data
392395
):
393396
raise ValueError("Missing *_classification_metrics from the metric data.")
394397

@@ -401,6 +404,11 @@ def _parse(self, json_data: dict):
401404
result = {"metric_groups": []}
402405
for group_name, group_data in json_data.items():
403406
metric_data = []
407+
if group_name not in (
408+
"binary_classification_metrics",
409+
"multiclass_classification_metrics",
410+
):
411+
continue
404412
for metric_name, raw_data in group_data.items():
405413
metric_data.extend(self._parse_confusion_matrix(metric_name, raw_data))
406414
metric_data.extend(
@@ -506,11 +514,45 @@ def _parse_precision_recall_curve(self, metric_name, raw_data):
506514
return metric_data
507515

508516

517+
class ModelMonitorModelQualityParser(ParserBase):
    """Top level parser for the model monitor model quality metric type.

    Dispatches to ``RegressionParser`` or ``ClassificationParser`` depending on
    which metric group key is present in the raw metric data.
    """

    def _validate(self, json_data: dict):
        """Implement ParserBase._validate.

        Args:
            json_data (dict): Metric data to be validated.

        Raises:
            ValueError: the metric data contains no entries.
        """
        if len(json_data) < 1:
            raise ValueError("Missing model monitor model quality metrics from the metric data.")

    def _parse(self, json_data: dict):
        """Implement ParserBase._parse.

        Args:
            json_data (dict): Raw metric data.

        Returns:
            The result of the delegated parser's ``run`` call, or a dict with an
            empty ``metric_groups`` list when no known metric group is present.
        """
        # Regression takes precedence when both kinds of groups are present.
        if "regression_metrics" in json_data:
            return RegressionParser().run(json_data)

        classification_groups = (
            "binary_classification_metrics",
            "multiclass_classification_metrics",
        )
        if any(group in json_data for group in classification_groups):
            return ClassificationParser().run(json_data)

        # No recognised metric group: hand back an empty container.
        return {"metric_groups": []}
548+
549+
509550
EVALUATION_METRIC_PARSERS = {
510551
EvaluationMetricTypeEnum.MODEL_CARD_METRIC_SCHEMA: DefaultParser(),
511552
EvaluationMetricTypeEnum.CLARIFY_BIAS: ClarifyBiasParser(),
512553
EvaluationMetricTypeEnum.CLARIFY_EXPLAINABILITY: ClarifyExplainabilityParser(),
513554
EvaluationMetricTypeEnum.REGRESSION: RegressionParser(),
514555
EvaluationMetricTypeEnum.BINARY_CLASSIFICATION: ClassificationParser(),
515556
EvaluationMetricTypeEnum.MULTICLASS_CLASSIFICATION: ClassificationParser(),
557+
EvaluationMetricTypeEnum.MODEL_MONITOR_MODEL_QUALITY: ModelMonitorModelQualityParser(),
516558
}

src/sagemaker/model_card/helpers.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,21 @@ def _clean_descriptor_name(self, name: str):
6262

6363
return name
6464

65+
def _skip_encoding(self, attr: str):
    """Tell whether an attribute must be left out of the encoded request.

    An attribute is skipped when the class-level object declared under the
    same name is a ``_SkipEncodingDecoding`` descriptor. The lookup walks the
    method resolution order so that descriptors declared on a base class are
    honoured as well, not only those declared directly on ``type(self)``.

    Args:
        attr (str): Public (cleaned) attribute name to look up.

    Returns:
        bool: True if the attribute is backed by a ``_SkipEncodingDecoding``
        descriptor, False otherwise (including when no class in the MRO
        declares the name).
    """
    for klass in type(self).__mro__:
        # The first class that declares the name wins, mirroring normal
        # attribute resolution.
        if attr in klass.__dict__:
            return isinstance(klass.__dict__[attr], _SkipEncodingDecoding)

    return False
71+
6572
def _to_request_dict(self):
6673
"""Implement this method in a subclass to return a custom request_dict."""
6774
request_data = {}
6875
for attr, value in self.__dict__.items():
6976
if value is not None:
7077
name = self._clean_descriptor_name(attr)
71-
request_data[name] = value
78+
if not self._skip_encoding(name):
79+
request_data[name] = value
7280

7381
return request_data
7482

@@ -149,6 +157,38 @@ def decode(self, value: dict):
149157
pass # pylint: disable=W0107
150158

151159

160+
class _SkipEncodingDecoding(_DescriptorBase):
    """Descriptor for model card attributes that skip encoding/decoding.

    The assigned value is only type-checked; it never requires decoding.
    """

    def __init__(self, value_type: Any):
        """Initialize a _SkipEncodingDecoding descriptor.

        Args:
            value_type (Any): Type that assigned values must be an instance of.
        """
        self.value_type = value_type

    def validate(self, value: Any):
        """Check that the value is None or an instance of the expected type.

        Args:
            value (Any): Value being assigned; expected type is self.value_type.

        Raises:
            ValueError: value is neither None nor a self.value_type instance.
        """
        if value is None or isinstance(value, self.value_type):
            return

        # private_name is presumably set by the base descriptor; the leading
        # character is dropped to report the public attribute name.
        raise ValueError(f"Please assign a {self.value_type} to {self.private_name[1:]}")

    def require_decode(self, value: Any):
        """Report that no decoding is ever required for this descriptor."""
        return False

    def decode(self, value: Any):
        """Do nothing; placeholder that satisfies the abstract method."""
        pass  # pylint: disable=W0107
190+
191+
152192
class _OneOf(_DescriptorBase):
153193
"""Verifies that a value is one of a restricted set of options"""
154194

@@ -463,9 +503,12 @@ def _read_s3_json(session: Session, bucket: str, key: str):
463503
raise
464504

465505
result = {}
466-
if data["ContentType"] == "application/json":
506+
if data["ContentType"] == "application/json" or data["ContentType"] == "binary/octet-stream":
467507
result = json.loads(data["Body"].read().decode("utf-8"))
468508
else:
469-
logger.warning("Invalid file type %s. application/json is expected.", data["ContentType"])
509+
logger.warning(
510+
"Invalid file type %s. application/json or binary/octet-stream is expected.",
511+
data["ContentType"],
512+
)
470513

471514
return result

0 commit comments

Comments
 (0)