34
34
TRAINING_DATASETS_MAX_SIZE ,
35
35
TRAINING_METRICS_MAX_SIZE ,
36
36
USER_PROVIDED_TRAINING_METRICS_MAX_SIZE ,
37
+ HYPER_PARAMETERS_MAX_SIZE ,
38
+ USER_PROVIDED_HYPER_PARAMETERS_MAX_SIZE ,
37
39
EVALUATION_DATASETS_MAX_SIZE ,
38
40
)
39
41
from sagemaker .model_card .helpers import (
@@ -235,6 +237,27 @@ def __init__(
235
237
self .explanations_for_risk_rating = explanations_for_risk_rating
236
238
237
239
240
+ class BusinessDetails (_DefaultToRequestDict , _DefaultFromDict ):
241
+ """The business details of a model."""
242
+
243
+ def __init__ (
244
+ self ,
245
+ business_problem : Optional [str ] = None ,
246
+ business_stakeholders : Optional [str ] = None ,
247
+ line_of_business : Optional [str ] = None ,
248
+ ):
249
+ """Initialize an Business Details object.
250
+
251
+ Args:
252
+ business_problem (str, optional): The business problem of this model (default: None).
253
+ business_stakeholders (str, optional): The business stakeholders for this model (default: None).
254
+ line_of_business (str, optional): The line of business for this model (default: None).
255
+ """ # noqa E501 # pylint: disable=line-too-long
256
+ self .business_problem = business_problem
257
+ self .business_stakeholders = business_stakeholders
258
+ self .line_of_business = line_of_business
259
+
260
+
238
261
class Function (_DefaultToRequestDict , _DefaultFromDict ):
239
262
"""Function details."""
240
263
@@ -363,6 +386,27 @@ def __init__(
363
386
self .notes = notes
364
387
365
388
389
+ class HyperParameter (_DefaultToRequestDict , _DefaultFromDict ):
390
+ """Hyper-Parameters data.
391
+
392
+ Should only be used during auto-population of parameters details.
393
+ """
394
+
395
+ def __init__ (
396
+ self ,
397
+ name : str ,
398
+ value : str ,
399
+ ):
400
+ """Initialize a HyperParameter object.
401
+
402
+ Args:
403
+ name (str): The hyper parameter name.
404
+ value (str): The hyper parameter value.
405
+ """
406
+ self .name = name
407
+ self .value = value
408
+
409
+
366
410
class TrainingJobDetails (_DefaultToRequestDict , _DefaultFromDict ):
367
411
"""The overview of a training job."""
368
412
@@ -371,6 +415,10 @@ class TrainingJobDetails(_DefaultToRequestDict, _DefaultFromDict):
371
415
user_provided_training_metrics = _IsList (
372
416
TrainingMetric , USER_PROVIDED_TRAINING_METRICS_MAX_SIZE
373
417
)
418
+ hyper_parameters = _IsList (HyperParameter , HYPER_PARAMETERS_MAX_SIZE )
419
+ user_provided_hyper_parameters = _IsList (
420
+ HyperParameter , USER_PROVIDED_HYPER_PARAMETERS_MAX_SIZE
421
+ )
374
422
training_environment = _IsModelCardObject (Environment )
375
423
376
424
def __init__ (
@@ -380,6 +428,8 @@ def __init__(
380
428
training_environment : Optional [Environment ] = None ,
381
429
training_metrics : Optional [List [TrainingMetric ]] = None ,
382
430
user_provided_training_metrics : Optional [List [TrainingMetric ]] = None ,
431
+ hyper_parameters : Optional [List [HyperParameter ]] = None ,
432
+ user_provided_hyper_parameters : Optional [List [HyperParameter ]] = None ,
383
433
):
384
434
"""Initialize a Training Job Details object.
385
435
@@ -389,12 +439,16 @@ def __init__(
389
439
training_environment (Environment, optional): The SageMaker training image URI. (default: None).
390
440
training_metrics (list[TrainingMetric], optional): SageMaker training job results. The maximum `training_metrics` list length is 50 (default: None).
391
441
user_provided_training_metrics (list[TrainingMetric], optional): Custom training job results. The maximum `user_provided_training_metrics` list length is 50 (default: None).
442
+ hyper_parameters (list[HyperParameter], optional): SageMaker hyper parameter results. The maximum `hyper_parameters` list length is 100 (default: None).
443
+ user_provided_hyper_parameters (list[HyperParameter], optional): Custom hyper parameter results. The maximum `user_provided_hyper_parameters` list length is 100 (default: None).
392
444
""" # noqa E501 # pylint: disable=line-too-long
393
445
self .training_arn = training_arn
394
446
self .training_datasets = training_datasets
395
447
self .training_environment = training_environment
396
448
self .training_metrics = training_metrics
397
449
self .user_provided_training_metrics = user_provided_training_metrics
450
+ self .hyper_parameters = hyper_parameters
451
+ self .user_provided_hyper_parameters = user_provided_hyper_parameters
398
452
399
453
400
454
class TrainingDetails (_DefaultToRequestDict , _DefaultFromDict ):
@@ -568,6 +622,16 @@ def add_metric(self, metric: TrainingMetric):
568
622
self .training_job_details = TrainingJobDetails ()
569
623
self .training_job_details .user_provided_training_metrics .append (metric )
570
624
625
+ def add_parameter (self , parameter : HyperParameter ):
626
+ """Add custom hyper-parameter.
627
+
628
+ Args:
629
+ parameter (HyperParameter): The custom parameter to add.
630
+ """
631
+ if not self .training_job_details :
632
+ self .training_job_details = TrainingJobDetails ()
633
+ self .training_job_details .user_provided_hyper_parameters .append (parameter )
634
+
571
635
572
636
class MetricGroup (_DefaultToRequestDict , _DefaultFromDict ):
573
637
"""Group of metric data"""
@@ -777,6 +841,7 @@ class ModelCard(object):
777
841
status = _OneOf (ModelCardStatusEnum )
778
842
model_overview = _IsModelCardObject (ModelOverview )
779
843
intended_uses = _IsModelCardObject (IntendedUses )
844
+ business_details = _IsModelCardObject (BusinessDetails )
780
845
training_details = _IsModelCardObject (TrainingDetails )
781
846
evaluation_details = _IsList (EvaluationJob )
782
847
additional_information = _IsModelCardObject (AdditionalInformation )
@@ -793,6 +858,7 @@ def __init__(
793
858
last_modified_by : Optional [dict ] = None ,
794
859
model_overview : Optional [ModelOverview ] = None ,
795
860
intended_uses : Optional [IntendedUses ] = None ,
861
+ business_details : Optional [BusinessDetails ] = None ,
796
862
training_details : Optional [TrainingDetails ] = None ,
797
863
evaluation_details : Optional [List [EvaluationJob ]] = None ,
798
864
additional_information : Optional [AdditionalInformation ] = None ,
@@ -811,6 +877,7 @@ def __init__(
811
877
last_modified_by (dict, optional): The group or individual that last modified the model card (default: None).
812
878
model_overview (ModelOverview, optional): An overview of the model (default: None).
813
879
intended_uses (IntendedUses, optional): The intended uses of the model (default: None).
880
+ business_details (BusinessDetails, optional): The business details of the model (default: None).
814
881
training_details (TrainingDetails, optional): The training details of the model (default: None).
815
882
evaluation_details (List[EvaluationJob], optional): The evaluation details of the model (default: None).
816
883
additional_information (AdditionalInformation, optional): Additional information about the model (default: None).
@@ -826,6 +893,7 @@ def __init__(
826
893
self .last_modified_by = last_modified_by
827
894
self .model_overview = model_overview
828
895
self .intended_uses = intended_uses
896
+ self .business_details = business_details
829
897
self .training_details = training_details
830
898
self .evaluation_details = evaluation_details
831
899
self .additional_information = additional_information
0 commit comments