34
34
TRAINING_DATASETS_MAX_SIZE ,
35
35
TRAINING_METRICS_MAX_SIZE ,
36
36
USER_PROVIDED_TRAINING_METRICS_MAX_SIZE ,
37
+ HYPER_PARAMETERS_MAX_SIZE ,
38
+ USER_PROVIDED_HYPER_PARAMETERS_MAX_SIZE ,
37
39
EVALUATION_DATASETS_MAX_SIZE ,
38
40
)
39
41
from sagemaker .model_card .helpers import (
@@ -235,6 +237,27 @@ def __init__(
235
237
self .explanations_for_risk_rating = explanations_for_risk_rating
236
238
237
239
240
+ class BusinessDetails (_DefaultToRequestDict , _DefaultFromDict ):
241
+ """The business details of a model."""
242
+
243
+ def __init__ (
244
+ self ,
245
+ business_problem : Optional [str ] = None ,
246
+ business_stakeholders : Optional [str ] = None ,
247
+ line_of_business : Optional [str ] = None ,
248
+ ):
249
+ """Initialize an Business Details object.
250
+
251
+ Args:
252
+ business_problem (str, optional): The business problem of this model (default: None).
253
+ business_stakeholders (str, optional): The business stakeholders for this model (default: None).
254
+ line_of_business (str, optional): The line of business for this model (default: None).
255
+ """ # noqa E501 # pylint: disable=line-too-long
256
+ self .business_problem = business_problem
257
+ self .business_stakeholders = business_stakeholders
258
+ self .line_of_business = line_of_business
259
+
260
+
238
261
class Function (_DefaultToRequestDict , _DefaultFromDict ):
239
262
"""Function details."""
240
263
@@ -363,6 +386,24 @@ def __init__(
363
386
self .notes = notes
364
387
365
388
389
+ class HyperParameter (_DefaultToRequestDict , _DefaultFromDict ):
390
+ """Hyper-Parameters data."""
391
+
392
+ def __init__ (
393
+ self ,
394
+ name : str ,
395
+ value : str ,
396
+ ):
397
+ """Initialize a HyperParameter object.
398
+
399
+ Args:
400
+ name (str): The hyper parameter name.
401
+ value (str): The hyper parameter value.
402
+ """
403
+ self .name = name
404
+ self .value = value
405
+
406
+
366
407
class TrainingJobDetails (_DefaultToRequestDict , _DefaultFromDict ):
367
408
"""The overview of a training job."""
368
409
@@ -371,6 +412,10 @@ class TrainingJobDetails(_DefaultToRequestDict, _DefaultFromDict):
371
412
user_provided_training_metrics = _IsList (
372
413
TrainingMetric , USER_PROVIDED_TRAINING_METRICS_MAX_SIZE
373
414
)
415
+ hyper_parameters = _IsList (HyperParameter , HYPER_PARAMETERS_MAX_SIZE )
416
+ user_provided_hyper_parameters = _IsList (
417
+ HyperParameter , USER_PROVIDED_HYPER_PARAMETERS_MAX_SIZE
418
+ )
374
419
training_environment = _IsModelCardObject (Environment )
375
420
376
421
def __init__ (
@@ -380,6 +425,8 @@ def __init__(
380
425
training_environment : Optional [Environment ] = None ,
381
426
training_metrics : Optional [List [TrainingMetric ]] = None ,
382
427
user_provided_training_metrics : Optional [List [TrainingMetric ]] = None ,
428
+ hyper_parameters : Optional [List [HyperParameter ]] = None ,
429
+ user_provided_hyper_parameters : Optional [List [HyperParameter ]] = None ,
383
430
):
384
431
"""Initialize a Training Job Details object.
385
432
@@ -389,12 +436,16 @@ def __init__(
389
436
training_environment (Environment, optional): The SageMaker training image URI. (default: None).
390
437
training_metrics (list[TrainingMetric], optional): SageMaker training job results. The maximum `training_metrics` list length is 50 (default: None).
391
438
user_provided_training_metrics (list[TrainingMetric], optional): Custom training job results. The maximum `user_provided_training_metrics` list length is 50 (default: None).
439
+ hyper_parameters (list[HyperParameter], optional): SageMaker hyper parameter results. The maximum `hyper_parameters` list length is 100 (default: None).
440
+ user_provided_hyper_parameters (list[HyperParameter], optional): Custom hyper parameter results. The maximum `user_provided_hyper_parameters` list length is 100 (default: None).
392
441
""" # noqa E501 # pylint: disable=line-too-long
393
442
self .training_arn = training_arn
394
443
self .training_datasets = training_datasets
395
444
self .training_environment = training_environment
396
445
self .training_metrics = training_metrics
397
446
self .user_provided_training_metrics = user_provided_training_metrics
447
+ self .hyper_parameters = hyper_parameters
448
+ self .user_provided_hyper_parameters = user_provided_hyper_parameters
398
449
399
450
400
451
class TrainingDetails (_DefaultToRequestDict , _DefaultFromDict ):
@@ -442,6 +493,11 @@ def _create_training_details(training_job_data: dict, cls: "TrainingDetails", **
442
493
]
443
494
if "FinalMetricDataList" in training_job_data
444
495
else [],
496
+ # "hyper_parameters": print(training_job_data["HyperParameters"]),
497
+ "hyper_parameters" : [
498
+ HyperParameter (key , value )
499
+ for key , value in training_job_data ["HyperParameters" ].items ()
500
+ ],
445
501
}
446
502
kwargs .update ({"training_job_details" : TrainingJobDetails (** job )})
447
503
instance = cls (** kwargs )
@@ -568,6 +624,16 @@ def add_metric(self, metric: TrainingMetric):
568
624
self .training_job_details = TrainingJobDetails ()
569
625
self .training_job_details .user_provided_training_metrics .append (metric )
570
626
627
+ def add_parameter (self , parameter : HyperParameter ):
628
+ """Add custom hyper-parameter.
629
+
630
+ Args:
631
+ parameter (HyperParameter): The custom parameter to add.
632
+ """
633
+ if not self .training_job_details :
634
+ self .training_job_details = TrainingJobDetails ()
635
+ self .training_job_details .user_provided_hyper_parameters .append (parameter )
636
+
571
637
572
638
class MetricGroup (_DefaultToRequestDict , _DefaultFromDict ):
573
639
"""Group of metric data"""
@@ -777,6 +843,7 @@ class ModelCard(object):
777
843
status = _OneOf (ModelCardStatusEnum )
778
844
model_overview = _IsModelCardObject (ModelOverview )
779
845
intended_uses = _IsModelCardObject (IntendedUses )
846
+ business_details = _IsModelCardObject (BusinessDetails )
780
847
training_details = _IsModelCardObject (TrainingDetails )
781
848
evaluation_details = _IsList (EvaluationJob )
782
849
additional_information = _IsModelCardObject (AdditionalInformation )
@@ -793,6 +860,7 @@ def __init__(
793
860
last_modified_by : Optional [dict ] = None ,
794
861
model_overview : Optional [ModelOverview ] = None ,
795
862
intended_uses : Optional [IntendedUses ] = None ,
863
+ business_details : Optional [BusinessDetails ] = None ,
796
864
training_details : Optional [TrainingDetails ] = None ,
797
865
evaluation_details : Optional [List [EvaluationJob ]] = None ,
798
866
additional_information : Optional [AdditionalInformation ] = None ,
@@ -811,6 +879,7 @@ def __init__(
811
879
last_modified_by (dict, optional): The group or individual that last modified the model card (default: None).
812
880
model_overview (ModelOverview, optional): An overview of the model (default: None).
813
881
intended_uses (IntendedUses, optional): The intended uses of the model (default: None).
882
+ business_details (BusinessDetails, optional): The business details of the model (default: None).
814
883
training_details (TrainingDetails, optional): The training details of the model (default: None).
815
884
evaluation_details (List[EvaluationJob], optional): The evaluation details of the model (default: None).
816
885
additional_information (AdditionalInformation, optional): Additional information about the model (default: None).
@@ -826,6 +895,7 @@ def __init__(
826
895
self .last_modified_by = last_modified_by
827
896
self .model_overview = model_overview
828
897
self .intended_uses = intended_uses
898
+ self .business_details = business_details
829
899
self .training_details = training_details
830
900
self .evaluation_details = evaluation_details
831
901
self .additional_information = additional_information
0 commit comments