Skip to content

feature: model registry integration to model cards to support model packages #3933

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/sagemaker/model_card/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
EvaluationJob,
AdditionalInformation,
ModelCard,
ModelPackage,
)

from sagemaker.model_card.schema_constraints import ( # noqa: F401 # pylint: disable=unused-import
Expand Down
50 changes: 46 additions & 4 deletions src/sagemaker/model_card/evaluation_metric_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class EvaluationMetricTypeEnum(str, Enum):
MODEL_CARD_METRIC_SCHEMA = "Model Card Metric Schema"
CLARIFY_BIAS = "Clarify Bias"
CLARIFY_EXPLAINABILITY = "Clarify Explainability"
MODEL_MONITOR_MODEL_QUALITY = "Model Monitor Model Quality"
REGRESSION = "Model Monitor Model Quality Regression"
BINARY_CLASSIFICATION = "Model Monitor Model Quality Binary Classification"
MULTICLASS_CLASSIFICATION = "Model Monitor Model Quality Multiclass Classification"
Expand Down Expand Up @@ -138,6 +139,7 @@ def _parse(self, json_data: dict):
[
{"name": i["name"], "value": i["value"], "type": "number"}
for i in item["metrics"]
if i["value"] is not None
]
)
for group_name, metric_data in group_data.items():
Expand Down Expand Up @@ -368,9 +370,10 @@ def _parse(self, json_data: dict):
result = {"metric_groups": []}
for group_name, group_data in json_data.items():
metric_data = []
for metric_name, raw_data in group_data.item():
metric_data.extend(self._parse_basic_metric(metric_name, raw_data))
result["metric_groups"].append({"name": group_name, "metric_data": metric_data})
if group_name == "regression_metrics":
for metric_name, raw_data in group_data.items():
metric_data.extend(self._parse_basic_metric(metric_name, raw_data))
result["metric_groups"].append({"name": group_name, "metric_data": metric_data})
return result


Expand All @@ -388,7 +391,7 @@ def _validate(self, json_data: dict):
"""
if (
"binary_classification_metrics" not in json_data
and "multiclass_classification_metrics" in json_data
and "multiclass_classification_metrics" not in json_data
):
raise ValueError("Missing *_classification_metrics from the metric data.")

Expand All @@ -401,6 +404,11 @@ def _parse(self, json_data: dict):
result = {"metric_groups": []}
for group_name, group_data in json_data.items():
metric_data = []
if group_name not in (
"binary_classification_metrics",
"multiclass_classification_metrics",
):
continue
for metric_name, raw_data in group_data.items():
metric_data.extend(self._parse_confusion_matrix(metric_name, raw_data))
metric_data.extend(
Expand Down Expand Up @@ -506,11 +514,45 @@ def _parse_precision_recall_curve(self, metric_name, raw_data):
return metric_data


class ModelMonitorModelQualityParser(ParserBase):
    """Entry-point parser for the model monitor model quality metric type.

    Parsing is delegated to RegressionParser or ClassificationParser,
    depending on which metric group keys the data contains.
    """

    def _validate(self, json_data: dict):
        """Implement ParserBase._validate.

        Args:
            json_data (dict): Metric data to be validated.

        Raises:
            ValueError: the metric data is empty.
        """
        if len(json_data) > 0:
            return

        raise ValueError("Missing model monitor model quality metrics from the metric data.")

    def _parse(self, json_data: dict):
        """Implement ParserBase._parse.

        Regression metrics take precedence; classification metrics are only
        considered when no "regression_metrics" key is present.

        Args:
            json_data (dict): Raw metric data.

        Returns:
            The delegated parser's result, or a dict holding an empty
            "metric_groups" list when no known metric group key is found.
        """
        if "regression_metrics" in json_data:
            return RegressionParser().run(json_data)

        classification_keys = (
            "binary_classification_metrics",
            "multiclass_classification_metrics",
        )
        if any(key in json_data for key in classification_keys):
            return ClassificationParser().run(json_data)

        return {"metric_groups": []}


# Lookup table from each EvaluationMetricTypeEnum member to the parser instance
# that handles it. BINARY_CLASSIFICATION and MULTICLASS_CLASSIFICATION each get
# their own ClassificationParser instance.
EVALUATION_METRIC_PARSERS = {
EvaluationMetricTypeEnum.MODEL_CARD_METRIC_SCHEMA: DefaultParser(),
EvaluationMetricTypeEnum.CLARIFY_BIAS: ClarifyBiasParser(),
EvaluationMetricTypeEnum.CLARIFY_EXPLAINABILITY: ClarifyExplainabilityParser(),
EvaluationMetricTypeEnum.REGRESSION: RegressionParser(),
EvaluationMetricTypeEnum.BINARY_CLASSIFICATION: ClassificationParser(),
EvaluationMetricTypeEnum.MULTICLASS_CLASSIFICATION: ClassificationParser(),
EvaluationMetricTypeEnum.MODEL_MONITOR_MODEL_QUALITY: ModelMonitorModelQualityParser(),
}
49 changes: 46 additions & 3 deletions src/sagemaker/model_card/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,21 @@ def _clean_descriptor_name(self, name: str):

return name

def _skip_encoding(self, attr: str):
    """Report whether an attribute must be left out of the request dict.

    An attribute is skipped when the class-level object registered under
    ``attr`` is a _SkipEncodingDecoding descriptor. The lookup walks the
    method resolution order, so a descriptor declared on a base class is
    honoured for subclass instances too. The first class that defines
    ``attr`` decides, mirroring normal attribute resolution.

    Args:
        attr (str): Attribute name to look up on the class.

    Returns:
        bool: True if the attribute is backed by a _SkipEncodingDecoding
            descriptor, False otherwise.
    """
    for klass in self.__class__.__mro__:
        # Read the raw class dict so the descriptor's __get__ is never triggered.
        if attr in klass.__dict__:
            return isinstance(klass.__dict__[attr], _SkipEncodingDecoding)

    return False

def _to_request_dict(self):
    """Build the request dict from the instance attributes.

    Implement this method in a subclass to return a custom request_dict.

    Returns:
        dict: Every non-None instance attribute keyed by its cleaned name,
            except attributes that _skip_encoding flags.
    """
    request_data = {}
    for attr, value in self.__dict__.items():
        if value is None:
            continue
        name = self._clean_descriptor_name(attr)
        # Only the guarded assignment may write to the request: an
        # unconditional assignment here would defeat the skip check.
        if not self._skip_encoding(name):
            request_data[name] = value

    return request_data

Expand Down Expand Up @@ -149,6 +157,38 @@ def decode(self, value: dict):
pass # pylint: disable=W0107


class _SkipEncodingDecoding(_DescriptorBase):
    """Descriptor for model card attributes that skip encoding/decoding."""

    def __init__(self, value_type: Any):
        """Initialize a SkipEncodingDecoding descriptor.

        Args:
            value_type (Any): Type that assigned values are checked against.
        """
        self.value_type = value_type

    def validate(self, value: Any):
        """Check that the value is None or an instance of self.value_type.

        Args:
            value (Any): Value being assigned; its expected type is self.value_type.

        Raises:
            ValueError: value is neither None nor an instance of self.value_type.
        """
        if value is None or isinstance(value, self.value_type):
            return

        # NOTE(review): private_name is not set in this class; presumably the
        # base descriptor provides it - confirm.
        raise ValueError(f"Please assign a {self.value_type} to {self.private_name[1:]}")

    def require_decode(self, value: Any):
        """Always report that no decoding is required."""
        return False

    def decode(self, value: Any):
        """No decoding is required. Required placeholder for abstractmethod"""


class _OneOf(_DescriptorBase):
"""Verifies that a value is one of a restricted set of options"""

Expand Down Expand Up @@ -463,9 +503,12 @@ def _read_s3_json(session: Session, bucket: str, key: str):
raise

result = {}
if data["ContentType"] == "application/json":
if data["ContentType"] == "application/json" or data["ContentType"] == "binary/octet-stream":
result = json.loads(data["Body"].read().decode("utf-8"))
else:
logger.warning("Invalid file type %s. application/json is expected.", data["ContentType"])
logger.warning(
"Invalid file type %s. application/json or binary/octet-stream is expected.",
data["ContentType"],
)

return result
Loading