Skip to content

Commit 870d8b1

Browse files
author
Haonian Wang
committed
feature: Add business details and hyper parameters fields and update test_model_card.py
1 parent 96dd74e commit 870d8b1

File tree

2 files changed

+26
-26
lines changed

2 files changed

+26
-26
lines changed

src/sagemaker/model_card/model_card.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,11 @@ def _create_training_details(training_job_data: dict, cls: "TrainingDetails", **
493493
]
494494
if "FinalMetricDataList" in training_job_data
495495
else [],
496+
# "hyper_parameters": print(training_job_data["HyperParameters"]),
497+
"hyper_parameters": [
498+
HyperParameter(key, value)
499+
for key, value in training_job_data["HyperParameters"].items()
500+
],
496501
}
497502
kwargs.update({"training_job_details": TrainingJobDetails(**job)})
498503
instance = cls(**kwargs)

tests/unit/test_model_card.py

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -361,29 +361,22 @@
361361
"Timestamp": datetime.datetime(2022, 9, 5, 19, 18, 40),
362362
},
363363
],
364-
"HyperParameters": [
365-
{
366-
"feature_dim": "2",
367-
"mini_batch_size": "10",
368-
"predictor_type": "binary_classifier",
369-
},
370-
{
371-
"_kfold": "5",
372-
"_tuning_objective_metric": "validation:accuracy",
373-
"alpha": "0.0037170512924477993",
374-
"colsample_bytree": "0.7476726040667319",
375-
"eta": "0.011391935592233605",
376-
"eval_metric": "accuracy,f1,balanced_accuracy,precision_macro,recall_macro,mlogloss",
377-
"gamma": "1.8903517751689445",
378-
"lambda": "0.5098604662224621",
379-
"max_depth": "3",
380-
"min_child_weight": "5.081388147234708e-06",
381-
"num_class": "28",
382-
"num_round": "165",
383-
"objective": "multi:softprob",
384-
"subsample": "0.8828549481113146",
385-
},
386-
],
364+
"HyperParameters": {
365+
"_kfold": "5",
366+
"_tuning_objective_metric": "validation:accuracy",
367+
"alpha": "0.0037170512924477993",
368+
"colsample_bytree": "0.7476726040667319",
369+
"eta": "0.011391935592233605",
370+
"eval_metric": "accuracy,f1,balanced_accuracy,precision_macro,recall_macro,mlogloss",
371+
"gamma": "1.8903517751689445",
372+
"lambda": "0.5098604662224621",
373+
"max_depth": "3",
374+
"min_child_weight": "5.081388147234708e-06",
375+
"num_class": "28",
376+
"num_round": "165",
377+
"objective": "multi:softprob",
378+
"subsample": "0.8828549481113146",
379+
},
387380
"CreatedBy": {},
388381
}
389382
}
@@ -1065,7 +1058,7 @@ def test_training_details_autodiscovery_from_model_overview(
10651058
assert len(training_details.training_job_details.training_metrics) == len(
10661059
SEARCH_TRAINING_JOB_EXAMPLE["Results"][0]["TrainingJob"]["FinalMetricDataList"]
10671060
)
1068-
assert len(training_details.training_job_details.training_metrics) == len(
1061+
assert len(training_details.training_job_details.hyper_parameters) == len(
10691062
SEARCH_TRAINING_JOB_EXAMPLE["Results"][0]["TrainingJob"]["HyperParameters"]
10701063
)
10711064
assert training_details.training_job_details.training_environment.container_image == [
@@ -1097,8 +1090,10 @@ def test_training_details_autodiscovery_from_model_overview_autopilot(
10971090
model_overview=model_overview_example, sagemaker_session=session
10981091
)
10991092

1093+
# There are 0 required keys in training_metrics field in SEARCH_TRAINING_JOB_AUTOPILOT_EXAMPLE has
11001094
assert len(training_details.training_job_details.training_metrics) == 0
1101-
assert len(training_details.training_job_details.hyper_parameters) == 0
1095+
# There are 3 required keys in hyper_parameters field in SEARCH_TRAINING_JOB_AUTOPILOT_EXAMPLE has
1096+
assert len(training_details.training_job_details.hyper_parameters) == 3
11021097

11031098

11041099
@patch("sagemaker.Session")
@@ -1115,7 +1110,7 @@ def test_training_details_autodiscovery_from_job_name(session):
11151110
assert len(training_details.training_job_details.training_metrics) == len(
11161111
SEARCH_TRAINING_JOB_EXAMPLE["Results"][0]["TrainingJob"]["FinalMetricDataList"]
11171112
)
1118-
assert len(training_details.training_job_details.training_metrics) == len(
1113+
assert len(training_details.training_job_details.hyper_parameters) == len(
11191114
SEARCH_TRAINING_JOB_EXAMPLE["Results"][0]["TrainingJob"]["HyperParameters"]
11201115
)
11211116
assert training_details.training_job_details.training_environment.container_image == [

0 commit comments

Comments
 (0)