Skip to content

Commit 96dd74e

Browse files
author
Haonian Wang
committed
feature: Add business details and hyper parameters fields and update test_model_card.py
1 parent 8d2c16b commit 96dd74e

File tree

5 files changed

+140
-0
lines changed

5 files changed

+140
-0
lines changed

doc/api/governance/model_card.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,9 @@ see `Amazon SageMaker Model Cards <https://docs.aws.amazon.com/sagemaker/latest/
4141

4242
.. autoclass:: TrainingJobDetails
4343
:show-inheritance:
44+
45+
.. autoclass:: BusinessDetails
46+
:show-inheritance:
47+
48+
.. autoclass:: HyperParameter
49+
:show-inheritance:

src/sagemaker/model_card/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616
Environment,
1717
ModelOverview,
1818
IntendedUses,
19+
BusinessDetails,
1920
ObjectiveFunction,
2021
TrainingMetric,
22+
HyperParameter,
2123
Metric,
2224
Function,
2325
TrainingJobDetails,

src/sagemaker/model_card/model_card.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
TRAINING_DATASETS_MAX_SIZE,
3535
TRAINING_METRICS_MAX_SIZE,
3636
USER_PROVIDED_TRAINING_METRICS_MAX_SIZE,
37+
HYPER_PARAMETERS_MAX_SIZE,
38+
USER_PROVIDED_HYPER_PARAMETERS_MAX_SIZE,
3739
EVALUATION_DATASETS_MAX_SIZE,
3840
)
3941
from sagemaker.model_card.helpers import (
@@ -235,6 +237,27 @@ def __init__(
235237
self.explanations_for_risk_rating = explanations_for_risk_rating
236238

237239

240+
class BusinessDetails(_DefaultToRequestDict, _DefaultFromDict):
241+
"""The business details of a model."""
242+
243+
def __init__(
244+
self,
245+
business_problem: Optional[str] = None,
246+
business_stakeholders: Optional[str] = None,
247+
line_of_business: Optional[str] = None,
248+
):
249+
"""Initialize an Business Details object.
250+
251+
Args:
252+
business_problem (str, optional): The business problem of this model (default: None).
253+
business_stakeholders (str, optional): The business stakeholders for this model (default: None).
254+
line_of_business (str, optional): The line of business for this model (default: None).
255+
""" # noqa E501 # pylint: disable=line-too-long
256+
self.business_problem = business_problem
257+
self.business_stakeholders = business_stakeholders
258+
self.line_of_business = line_of_business
259+
260+
238261
class Function(_DefaultToRequestDict, _DefaultFromDict):
239262
"""Function details."""
240263

@@ -363,6 +386,24 @@ def __init__(
363386
self.notes = notes
364387

365388

389+
class HyperParameter(_DefaultToRequestDict, _DefaultFromDict):
390+
"""Hyper-Parameters data."""
391+
392+
def __init__(
393+
self,
394+
name: str,
395+
value: str,
396+
):
397+
"""Initialize a HyperParameter object.
398+
399+
Args:
400+
name (str): The hyper parameter name.
401+
value (str): The hyper parameter value.
402+
"""
403+
self.name = name
404+
self.value = value
405+
406+
366407
class TrainingJobDetails(_DefaultToRequestDict, _DefaultFromDict):
367408
"""The overview of a training job."""
368409

@@ -371,6 +412,10 @@ class TrainingJobDetails(_DefaultToRequestDict, _DefaultFromDict):
371412
user_provided_training_metrics = _IsList(
372413
TrainingMetric, USER_PROVIDED_TRAINING_METRICS_MAX_SIZE
373414
)
415+
hyper_parameters = _IsList(HyperParameter, HYPER_PARAMETERS_MAX_SIZE)
416+
user_provided_hyper_parameters = _IsList(
417+
HyperParameter, USER_PROVIDED_HYPER_PARAMETERS_MAX_SIZE
418+
)
374419
training_environment = _IsModelCardObject(Environment)
375420

376421
def __init__(
@@ -380,6 +425,8 @@ def __init__(
380425
training_environment: Optional[Environment] = None,
381426
training_metrics: Optional[List[TrainingMetric]] = None,
382427
user_provided_training_metrics: Optional[List[TrainingMetric]] = None,
428+
hyper_parameters: Optional[List[HyperParameter]] = None,
429+
user_provided_hyper_parameters: Optional[List[HyperParameter]] = None,
383430
):
384431
"""Initialize a Training Job Details object.
385432
@@ -389,12 +436,16 @@ def __init__(
389436
training_environment (Environment, optional): The SageMaker training image URI. (default: None).
390437
training_metrics (list[TrainingMetric], optional): SageMaker training job results. The maximum `training_metrics` list length is 50 (default: None).
391438
user_provided_training_metrics (list[TrainingMetric], optional): Custom training job results. The maximum `user_provided_training_metrics` list length is 50 (default: None).
439+
hyper_parameters (list[HyperParameter], optional): SageMaker hyper parameter results. The maximum `hyper_parameters` list length is 100 (default: None).
440+
user_provided_hyper_parameters (list[HyperParameter], optional): Custom hyper parameter results. The maximum `user_provided_hyper_parameters` list length is 100 (default: None).
392441
""" # noqa E501 # pylint: disable=line-too-long
393442
self.training_arn = training_arn
394443
self.training_datasets = training_datasets
395444
self.training_environment = training_environment
396445
self.training_metrics = training_metrics
397446
self.user_provided_training_metrics = user_provided_training_metrics
447+
self.hyper_parameters = hyper_parameters
448+
self.user_provided_hyper_parameters = user_provided_hyper_parameters
398449

399450

400451
class TrainingDetails(_DefaultToRequestDict, _DefaultFromDict):
@@ -568,6 +619,16 @@ def add_metric(self, metric: TrainingMetric):
568619
self.training_job_details = TrainingJobDetails()
569620
self.training_job_details.user_provided_training_metrics.append(metric)
570621

622+
def add_parameter(self, parameter: HyperParameter):
623+
"""Add custom hyper-parameter.
624+
625+
Args:
626+
parameter (HyperParameter): The custom parameter to add.
627+
"""
628+
if not self.training_job_details:
629+
self.training_job_details = TrainingJobDetails()
630+
self.training_job_details.user_provided_hyper_parameters.append(parameter)
631+
571632

572633
class MetricGroup(_DefaultToRequestDict, _DefaultFromDict):
573634
"""Group of metric data"""
@@ -777,6 +838,7 @@ class ModelCard(object):
777838
status = _OneOf(ModelCardStatusEnum)
778839
model_overview = _IsModelCardObject(ModelOverview)
779840
intended_uses = _IsModelCardObject(IntendedUses)
841+
business_details = _IsModelCardObject(BusinessDetails)
780842
training_details = _IsModelCardObject(TrainingDetails)
781843
evaluation_details = _IsList(EvaluationJob)
782844
additional_information = _IsModelCardObject(AdditionalInformation)
@@ -793,6 +855,7 @@ def __init__(
793855
last_modified_by: Optional[dict] = None,
794856
model_overview: Optional[ModelOverview] = None,
795857
intended_uses: Optional[IntendedUses] = None,
858+
business_details: Optional[BusinessDetails] = None,
796859
training_details: Optional[TrainingDetails] = None,
797860
evaluation_details: Optional[List[EvaluationJob]] = None,
798861
additional_information: Optional[AdditionalInformation] = None,
@@ -811,6 +874,7 @@ def __init__(
811874
last_modified_by (dict, optional): The group or individual that last modified the model card (default: None).
812875
model_overview (ModelOverview, optional): An overview of the model (default: None).
813876
intended_uses (IntendedUses, optional): The intended uses of the model (default: None).
877+
business_details (BusinessDetails, optional): The business details of the model (default: None).
814878
training_details (TrainingDetails, optional): The training details of the model (default: None).
815879
evaluation_details (List[EvaluationJob], optional): The evaluation details of the model (default: None).
816880
additional_information (AdditionalInformation, optional): Additional information about the model (default: None).
@@ -826,6 +890,7 @@ def __init__(
826890
self.last_modified_by = last_modified_by
827891
self.model_overview = model_overview
828892
self.intended_uses = intended_uses
893+
self.business_details = business_details
829894
self.training_details = training_details
830895
self.evaluation_details = evaluation_details
831896
self.additional_information = additional_information

src/sagemaker/model_card/schema_constraints.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,4 +84,6 @@ class MetricTypeEnum(str, Enum):
8484
TRAINING_DATASETS_MAX_SIZE = 15
8585
TRAINING_METRICS_MAX_SIZE = 50
8686
USER_PROVIDED_TRAINING_METRICS_MAX_SIZE = 50
87+
HYPER_PARAMETERS_MAX_SIZE = 100
88+
USER_PROVIDED_HYPER_PARAMETERS_MAX_SIZE = 100
8789
EVALUATION_DATASETS_MAX_SIZE = 10

tests/unit/test_model_card.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@
2727
Environment,
2828
ModelOverview,
2929
IntendedUses,
30+
BusinessDetails,
3031
ObjectiveFunction,
3132
TrainingMetric,
33+
HyperParameter,
3234
Metric,
3335
TrainingDetails,
3436
MetricGroup,
@@ -75,6 +77,11 @@
7577
RISK_RATING = schema_constraints.RiskRatingEnum.LOW
7678
EXPLANATIONS_FOR_RISK_RATING = "ramdomly the first example"
7779

80+
# business details auguments
81+
BUSINESS_PROBLEM = "mock model for business problem testing"
82+
BUSINESS_STAKEHOLDERS = "business stakeholders testing"
83+
LINE_OF_BUSINESS = "how many business models"
84+
7885
# training details arguments
7986
OBJECITVE_FUNCTION_FUNC = schema_constraints.ObjectiveFunctionEnum.MINIMIZE
8087
OBJECTIVE_FUNCTION_FACET = schema_constraints.FacetEnum.LOSS
@@ -89,6 +96,10 @@
8996
USER_METRIC_NAME = "test_metric"
9097
USER_METRIC = TrainingMetric(name=USER_METRIC_NAME, value=1)
9198
USER_PROVIDED_TRAINING_METRICS = [USER_METRIC]
99+
HYPER_PARAMETER = [HyperParameter(name="binary_f_beta", value=0.965)]
100+
USER_PARAMETER_NAME = "test_parameter"
101+
USER_PARAMETER = HyperParameter(name=USER_PARAMETER_NAME, value=1)
102+
USER_PROVIDED_HYPER_PARAMETER = [USER_PARAMETER]
92103

93104
# evaluation job arguments
94105
EVALUATION_JOB_NAME = "evaluation job 1"
@@ -350,6 +361,29 @@
350361
"Timestamp": datetime.datetime(2022, 9, 5, 19, 18, 40),
351362
},
352363
],
364+
"HyperParameters": [
365+
{
366+
"feature_dim": "2",
367+
"mini_batch_size": "10",
368+
"predictor_type": "binary_classifier",
369+
},
370+
{
371+
"_kfold": "5",
372+
"_tuning_objective_metric": "validation:accuracy",
373+
"alpha": "0.0037170512924477993",
374+
"colsample_bytree": "0.7476726040667319",
375+
"eta": "0.011391935592233605",
376+
"eval_metric": "accuracy,f1,balanced_accuracy,precision_macro,recall_macro,mlogloss",
377+
"gamma": "1.8903517751689445",
378+
"lambda": "0.5098604662224621",
379+
"max_depth": "3",
380+
"min_child_weight": "5.081388147234708e-06",
381+
"num_class": "28",
382+
"num_round": "165",
383+
"objective": "multi:softprob",
384+
"subsample": "0.8828549481113146",
385+
},
386+
],
353387
"CreatedBy": {},
354388
}
355389
}
@@ -583,6 +617,17 @@ def fixture_fixture_intended_uses_example():
583617
return test_example
584618

585619

620+
@pytest.fixture(name="business_details_example")
621+
def fixture_fixture_business_details_example():
622+
"""Example business details instance."""
623+
test_example = BusinessDetails(
624+
business_problem=BUSINESS_PROBLEM,
625+
business_stakeholders=BUSINESS_STAKEHOLDERS,
626+
line_of_business=LINE_OF_BUSINESS,
627+
)
628+
return test_example
629+
630+
586631
@pytest.fixture(name="training_details_example")
587632
def fixture_fixture_training_details_example():
588633
"""Example training details instance."""
@@ -601,6 +646,7 @@ def fixture_fixture_training_details_example():
601646
training_datasets=TRAINING_DATASETS,
602647
training_environment=TRAINING_ENVIRONMENT,
603648
training_metrics=TRAINING_METRICS,
649+
hyper_parameters=HYPER_PARAMETER,
604650
),
605651
)
606652
return test_example
@@ -637,6 +683,7 @@ def test_create_model_card(
637683
session,
638684
model_overview_example,
639685
intended_uses_example,
686+
business_details_example,
640687
training_details_example,
641688
evaluation_details_example,
642689
additional_information_example,
@@ -649,6 +696,7 @@ def test_create_model_card(
649696
status=MODEL_CARD_STATUS,
650697
model_overview=model_overview_example,
651698
intended_uses=intended_uses_example,
699+
business_details=business_details_example,
652700
training_details=training_details_example,
653701
evaluation_details=evaluation_details_example,
654702
additional_information=additional_information_example,
@@ -1017,6 +1065,9 @@ def test_training_details_autodiscovery_from_model_overview(
10171065
assert len(training_details.training_job_details.training_metrics) == len(
10181066
SEARCH_TRAINING_JOB_EXAMPLE["Results"][0]["TrainingJob"]["FinalMetricDataList"]
10191067
)
1068+
assert len(training_details.training_job_details.training_metrics) == len(
1069+
SEARCH_TRAINING_JOB_EXAMPLE["Results"][0]["TrainingJob"]["HyperParameters"]
1070+
)
10201071
assert training_details.training_job_details.training_environment.container_image == [
10211072
TRAINING_IMAGE
10221073
]
@@ -1047,6 +1098,7 @@ def test_training_details_autodiscovery_from_model_overview_autopilot(
10471098
)
10481099

10491100
assert len(training_details.training_job_details.training_metrics) == 0
1101+
assert len(training_details.training_job_details.hyper_parameters) == 0
10501102

10511103

10521104
@patch("sagemaker.Session")
@@ -1063,6 +1115,9 @@ def test_training_details_autodiscovery_from_job_name(session):
10631115
assert len(training_details.training_job_details.training_metrics) == len(
10641116
SEARCH_TRAINING_JOB_EXAMPLE["Results"][0]["TrainingJob"]["FinalMetricDataList"]
10651117
)
1118+
assert len(training_details.training_job_details.training_metrics) == len(
1119+
SEARCH_TRAINING_JOB_EXAMPLE["Results"][0]["TrainingJob"]["HyperParameters"]
1120+
)
10661121
assert training_details.training_job_details.training_environment.container_image == [
10671122
TRAINING_IMAGE
10681123
]
@@ -1091,6 +1146,16 @@ def test_add_user_provided_training_metrics(training_details_example):
10911146
)
10921147

10931148

1149+
def test_add_user_provided_hyper_parameters(training_details_example):
1150+
assert len(training_details_example.training_job_details.user_provided_hyper_parameters) == 0
1151+
training_details_example.add_parameter(USER_PARAMETER)
1152+
assert len(training_details_example.training_job_details.user_provided_hyper_parameters) == 1
1153+
assert (
1154+
training_details_example.training_job_details.user_provided_hyper_parameters[0].name
1155+
== USER_PARAMETER_NAME
1156+
)
1157+
1158+
10941159
def test_add_evaluation_metrics_manually():
10951160
evaluation_job = EvaluationJob(name=EVALUATION_JOB_NAME)
10961161

0 commit comments

Comments
 (0)