Skip to content

Commit f6180dc

Browse files
author
Haonian Wang
committed
add business details and hyperparameters
1 parent 70ce8fa commit f6180dc

File tree

5 files changed

+122
-0
lines changed

5 files changed

+122
-0
lines changed

doc/api/governance/model_card.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,9 @@ see `Amazon SageMaker Model Cards <https://docs.aws.amazon.com/sagemaker/latest/
4141

4242
.. autoclass:: TrainingJobDetails
4343
:show-inheritance:
44+
45+
.. autoclass:: BusinessDetails
46+
:show-inheritance:
47+
48+
.. autoclass:: HyperParameter
49+
:show-inheritance:

src/sagemaker/model_card/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616
Environment,
1717
ModelOverview,
1818
IntendedUses,
19+
BusinessDetails,
1920
ObjectiveFunction,
2021
TrainingMetric,
22+
HyperParameter,
2123
Metric,
2224
Function,
2325
TrainingJobDetails,

src/sagemaker/model_card/model_card.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
TRAINING_DATASETS_MAX_SIZE,
3535
TRAINING_METRICS_MAX_SIZE,
3636
USER_PROVIDED_TRAINING_METRICS_MAX_SIZE,
37+
HYPER_PARAMETERS_MAX_SIZE,
38+
USER_PROVIDED_HYPER_PARAMETERS_MAX_SIZE,
3739
EVALUATION_DATASETS_MAX_SIZE,
3840
)
3941
from sagemaker.model_card.helpers import (
@@ -235,6 +237,27 @@ def __init__(
235237
self.explanations_for_risk_rating = explanations_for_risk_rating
236238

237239

240+
class BusinessDetails(_DefaultToRequestDict, _DefaultFromDict):
241+
"""The business details of a model."""
242+
243+
def __init__(
244+
self,
245+
business_problem: Optional[str] = None,
246+
business_stakeholders: Optional[str] = None,
247+
line_of_business: Optional[str] = None,
248+
):
249+
"""Initialize an Business Details object.
250+
251+
Args:
252+
business_problem (str, optional): The business problem of this model (default: None).
253+
business_stakeholders (str, optional): The business stakeholders for this model (default: None).
254+
line_of_business (str, optional): The line of business for this model (default: None).
255+
""" # noqa E501 # pylint: disable=line-too-long
256+
self.business_problem = business_problem
257+
self.business_stakeholders = business_stakeholders
258+
self.line_of_business = line_of_business
259+
260+
238261
class Function(_DefaultToRequestDict, _DefaultFromDict):
239262
"""Function details."""
240263

@@ -363,6 +386,27 @@ def __init__(
363386
self.notes = notes
364387

365388

389+
class HyperParameter(_DefaultToRequestDict, _DefaultFromDict):
390+
"""Hyper-Parameters data.
391+
392+
Should only be used during auto-population of parameters details.
393+
"""
394+
395+
def __init__(
396+
self,
397+
name: str,
398+
value: str,
399+
):
400+
"""Initialize a HyperParameter object.
401+
402+
Args:
403+
name (str): The hyper parameter name.
404+
value (str): The hyper parameter value.
405+
"""
406+
self.name = name
407+
self.value = value
408+
409+
366410
class TrainingJobDetails(_DefaultToRequestDict, _DefaultFromDict):
367411
"""The overview of a training job."""
368412

@@ -371,6 +415,10 @@ class TrainingJobDetails(_DefaultToRequestDict, _DefaultFromDict):
371415
user_provided_training_metrics = _IsList(
372416
TrainingMetric, USER_PROVIDED_TRAINING_METRICS_MAX_SIZE
373417
)
418+
hyper_parameters = _IsList(HyperParameter, HYPER_PARAMETERS_MAX_SIZE)
419+
user_provided_hyper_parameters = _IsList(
420+
HyperParameter, USER_PROVIDED_HYPER_PARAMETERS_MAX_SIZE
421+
)
374422
training_environment = _IsModelCardObject(Environment)
375423

376424
def __init__(
@@ -380,6 +428,8 @@ def __init__(
380428
training_environment: Optional[Environment] = None,
381429
training_metrics: Optional[List[TrainingMetric]] = None,
382430
user_provided_training_metrics: Optional[List[TrainingMetric]] = None,
431+
hyper_parameters: Optional[List[HyperParameter]] = None,
432+
user_provided_hyper_parameters: Optional[List[HyperParameter]] = None,
383433
):
384434
"""Initialize a Training Job Details object.
385435
@@ -389,12 +439,16 @@ def __init__(
389439
training_environment (Environment, optional): The SageMaker training image URI. (default: None).
390440
training_metrics (list[TrainingMetric], optional): SageMaker training job results. The maximum `training_metrics` list length is 50 (default: None).
391441
user_provided_training_metrics (list[TrainingMetric], optional): Custom training job results. The maximum `user_provided_training_metrics` list length is 50 (default: None).
442+
hyper_parameters (list[HyperParameter], optional): SageMaker hyper parameter results. The maximum `hyper_parameters` list length is 100 (default: None).
443+
user_provided_hyper_parameters (list[HyperParameter], optional): Custom hyper parameter results. The maximum `user_provided_hyper_parameters` list length is 100 (default: None).
392444
""" # noqa E501 # pylint: disable=line-too-long
393445
self.training_arn = training_arn
394446
self.training_datasets = training_datasets
395447
self.training_environment = training_environment
396448
self.training_metrics = training_metrics
397449
self.user_provided_training_metrics = user_provided_training_metrics
450+
self.hyper_parameters = hyper_parameters
451+
self.user_provided_hyper_parameters = user_provided_hyper_parameters
398452

399453

400454
class TrainingDetails(_DefaultToRequestDict, _DefaultFromDict):
@@ -568,6 +622,16 @@ def add_metric(self, metric: TrainingMetric):
568622
self.training_job_details = TrainingJobDetails()
569623
self.training_job_details.user_provided_training_metrics.append(metric)
570624

625+
def add_parameter(self, parameter: HyperParameter):
626+
"""Add custom hyper-parameter.
627+
628+
Args:
629+
parameter (HyperParameter): The custom parameter to add.
630+
"""
631+
if not self.training_job_details:
632+
self.training_job_details = TrainingJobDetails()
633+
self.training_job_details.user_provided_hyper_parameters.append(parameter)
634+
571635

572636
class MetricGroup(_DefaultToRequestDict, _DefaultFromDict):
573637
"""Group of metric data"""
@@ -777,6 +841,7 @@ class ModelCard(object):
777841
status = _OneOf(ModelCardStatusEnum)
778842
model_overview = _IsModelCardObject(ModelOverview)
779843
intended_uses = _IsModelCardObject(IntendedUses)
844+
business_details = _IsModelCardObject(BusinessDetails)
780845
training_details = _IsModelCardObject(TrainingDetails)
781846
evaluation_details = _IsList(EvaluationJob)
782847
additional_information = _IsModelCardObject(AdditionalInformation)
@@ -793,6 +858,7 @@ def __init__(
793858
last_modified_by: Optional[dict] = None,
794859
model_overview: Optional[ModelOverview] = None,
795860
intended_uses: Optional[IntendedUses] = None,
861+
business_details: Optional[BusinessDetails] = None,
796862
training_details: Optional[TrainingDetails] = None,
797863
evaluation_details: Optional[List[EvaluationJob]] = None,
798864
additional_information: Optional[AdditionalInformation] = None,
@@ -811,6 +877,7 @@ def __init__(
811877
last_modified_by (dict, optional): The group or individual that last modified the model card (default: None).
812878
model_overview (ModelOverview, optional): An overview of the model (default: None).
813879
intended_uses (IntendedUses, optional): The intended uses of the model (default: None).
880+
business_details (BusinessDetails, optional): The business details of the model (default: None).
814881
training_details (TrainingDetails, optional): The training details of the model (default: None).
815882
evaluation_details (List[EvaluationJob], optional): The evaluation details of the model (default: None).
816883
additional_information (AdditionalInformation, optional): Additional information about the model (default: None).
@@ -826,6 +893,7 @@ def __init__(
826893
self.last_modified_by = last_modified_by
827894
self.model_overview = model_overview
828895
self.intended_uses = intended_uses
896+
self.business_details = business_details
829897
self.training_details = training_details
830898
self.evaluation_details = evaluation_details
831899
self.additional_information = additional_information

src/sagemaker/model_card/schema_constraints.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,4 +84,6 @@ class MetricTypeEnum(str, Enum):
8484
TRAINING_DATASETS_MAX_SIZE = 15
8585
TRAINING_METRICS_MAX_SIZE = 50
8686
USER_PROVIDED_TRAINING_METRICS_MAX_SIZE = 50
87+
HYPER_PARAMETERS_MAX_SIZE = 100
88+
USER_PROVIDED_HYPER_PARAMETERS_MAX_SIZE = 100
8789
EVALUATION_DATASETS_MAX_SIZE = 10

tests/unit/test_model_card.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@
2727
Environment,
2828
ModelOverview,
2929
IntendedUses,
30+
BusinessDetails,
3031
ObjectiveFunction,
3132
TrainingMetric,
33+
HyperParameter,
3234
Metric,
3335
TrainingDetails,
3436
MetricGroup,
@@ -75,6 +77,11 @@
7577
RISK_RATING = schema_constraints.RiskRatingEnum.LOW
7678
EXPLANATIONS_FOR_RISK_RATING = "ramdomly the first example"
7779

80+
# business details auguments
81+
BUSINESS_PROBLEM = "mock model for business problem testing"
82+
BUSINESS_STAKEHOLDERS = "business stakeholders testing"
83+
LINE_OF_BUSINESS = "how many business models"
84+
7885
# training details arguments
7986
OBJECITVE_FUNCTION_FUNC = schema_constraints.ObjectiveFunctionEnum.MINIMIZE
8087
OBJECTIVE_FUNCTION_FACET = schema_constraints.FacetEnum.LOSS
@@ -89,6 +96,10 @@
8996
USER_METRIC_NAME = "test_metric"
9097
USER_METRIC = TrainingMetric(name=USER_METRIC_NAME, value=1)
9198
USER_PROVIDED_TRAINING_METRICS = [USER_METRIC]
99+
HYPER_PARAMETER = [HyperParameter(name="binary_f_beta", value=0.965)]
100+
USER_PARAMETER_NAME = "test_parameter"
101+
USER_PARAMETER = HyperParameter(name=USER_PARAMETER_NAME, value=1)
102+
USER_PROVIDED_HYPER_PARAMETER = [USER_PARAMETER]
92103

93104
# evaluation job arguments
94105
EVALUATION_JOB_NAME = "evaluation job 1"
@@ -583,6 +594,17 @@ def fixture_fixture_intended_uses_example():
583594
return test_example
584595

585596

597+
@pytest.fixture(name="business_details_example")
598+
def fixture_fixture_business_details_example():
599+
"""Example business details instance."""
600+
test_example = BusinessDetails(
601+
business_problem=BUSINESS_PROBLEM,
602+
business_stakeholders=BUSINESS_STAKEHOLDERS,
603+
line_of_business=LINE_OF_BUSINESS,
604+
)
605+
return test_example
606+
607+
586608
@pytest.fixture(name="training_details_example")
587609
def fixture_fixture_training_details_example():
588610
"""Example training details instance."""
@@ -601,6 +623,7 @@ def fixture_fixture_training_details_example():
601623
training_datasets=TRAINING_DATASETS,
602624
training_environment=TRAINING_ENVIRONMENT,
603625
training_metrics=TRAINING_METRICS,
626+
hyper_parameters=HYPER_PARAMETER,
604627
),
605628
)
606629
return test_example
@@ -637,6 +660,7 @@ def test_create_model_card(
637660
session,
638661
model_overview_example,
639662
intended_uses_example,
663+
business_details_example,
640664
training_details_example,
641665
evaluation_details_example,
642666
additional_information_example,
@@ -649,6 +673,7 @@ def test_create_model_card(
649673
status=MODEL_CARD_STATUS,
650674
model_overview=model_overview_example,
651675
intended_uses=intended_uses_example,
676+
business_details=business_details_example,
652677
training_details=training_details_example,
653678
evaluation_details=evaluation_details_example,
654679
additional_information=additional_information_example,
@@ -866,6 +891,14 @@ def __init__(self, attr1): # pylint: disable=C0116
866891
):
867892
ExampleClass(attr1=IntendedUses())
868893

894+
with pytest.raises(
895+
ValueError,
896+
match=re.escape(
897+
"Expected <class 'sagemaker.model_card.model_card.BusinessDetails'> instance to be of class ModelOverview" # noqa E501 # pylint: disable=c0301
898+
),
899+
):
900+
ExampleClass(attr1=BusinessDetails())
901+
869902
# decode object from json data
870903
assert ExampleClass({"model_name": "test"})
871904

@@ -1047,6 +1080,7 @@ def test_training_details_autodiscovery_from_model_overview_autopilot(
10471080
)
10481081

10491082
assert len(training_details.training_job_details.training_metrics) == 0
1083+
assert len(training_details.training_job_details.hyper_parameters) == 0
10501084

10511085

10521086
@patch("sagemaker.Session")
@@ -1091,6 +1125,16 @@ def test_add_user_provided_training_metrics(training_details_example):
10911125
)
10921126

10931127

1128+
def test_add_user_provided_hyper_parameters(training_details_example):
1129+
assert len(training_details_example.training_job_details.user_provided_hyper_parameters) == 0
1130+
training_details_example.add_parameter(USER_PARAMETER)
1131+
assert len(training_details_example.training_job_details.user_provided_hyper_parameters) == 1
1132+
assert (
1133+
training_details_example.training_job_details.user_provided_hyper_parameters[0].name
1134+
== USER_PARAMETER_NAME
1135+
)
1136+
1137+
10941138
def test_add_evaluation_metrics_manually():
10951139
evaluation_job = EvaluationJob(name=EVALUATION_JOB_NAME)
10961140

0 commit comments

Comments
 (0)