@@ -403,6 +403,7 @@ def _run(
403
403
logs ,
404
404
job_name ,
405
405
kms_key ,
406
+ experiment_config ,
406
407
):
407
408
"""Runs a ProcessingJob with the Sagemaker Clarify container and an analysis config.
408
409
@@ -415,6 +416,9 @@ def _run(
415
416
job_name (str): Processing job name.
416
417
kms_key (str): The ARN of the KMS key that is used to encrypt the
417
418
user code file (default: None).
419
+ experiment_config (dict[str, str]): Experiment management configuration.
420
+ Dictionary contains three optional keys:
421
+ 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
418
422
"""
419
423
analysis_config ["methods" ]["report" ] = {"name" : "report" , "title" : "Analysis Report" }
420
424
with tempfile .TemporaryDirectory () as tmpdirname :
@@ -457,6 +461,7 @@ def _run(
457
461
logs = logs ,
458
462
job_name = job_name ,
459
463
kms_key = kms_key ,
464
+ experiment_config = experiment_config ,
460
465
)
461
466
462
467
def run_pre_training_bias (
@@ -468,6 +473,7 @@ def run_pre_training_bias(
468
473
logs = True ,
469
474
job_name = None ,
470
475
kms_key = None ,
476
+ experiment_config = None ,
471
477
):
472
478
"""Runs a ProcessingJob to compute the requested bias 'methods' of the input data.
473
479
@@ -487,13 +493,16 @@ def run_pre_training_bias(
487
493
"Clarify-Pretraining-Bias" and current timestamp.
488
494
kms_key (str): The ARN of the KMS key that is used to encrypt the
489
495
user code file (default: None).
496
+ experiment_config (dict[str, str]): Experiment management configuration.
497
+ Dictionary contains three optional keys:
498
+ 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
490
499
"""
491
500
analysis_config = data_config .get_config ()
492
501
analysis_config .update (data_bias_config .get_config ())
493
502
analysis_config ["methods" ] = {"pre_training_bias" : {"methods" : methods }}
494
503
if job_name is None :
495
504
job_name = utils .name_from_base ("Clarify-Pretraining-Bias" )
496
- self ._run (data_config , analysis_config , wait , logs , job_name , kms_key )
505
+ self ._run (data_config , analysis_config , wait , logs , job_name , kms_key , experiment_config )
497
506
498
507
def run_post_training_bias (
499
508
self ,
@@ -506,6 +515,7 @@ def run_post_training_bias(
506
515
logs = True ,
507
516
job_name = None ,
508
517
kms_key = None ,
518
+ experiment_config = None ,
509
519
):
510
520
"""Runs a ProcessingJob to compute the requested bias 'methods' of the model predictions.
511
521
@@ -532,6 +542,9 @@ def run_post_training_bias(
532
542
"Clarify-Posttraining-Bias" and current timestamp.
533
543
kms_key (str): The ARN of the KMS key that is used to encrypt the
534
544
user code file (default: None).
545
+ experiment_config (dict[str, str]): Experiment management configuration.
546
+ Dictionary contains three optional keys:
547
+ 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
535
548
"""
536
549
analysis_config = data_config .get_config ()
537
550
analysis_config .update (data_bias_config .get_config ())
@@ -545,7 +558,7 @@ def run_post_training_bias(
545
558
_set (probability_threshold , "probability_threshold" , analysis_config )
546
559
if job_name is None :
547
560
job_name = utils .name_from_base ("Clarify-Posttraining-Bias" )
548
- self ._run (data_config , analysis_config , wait , logs , job_name , kms_key )
561
+ self ._run (data_config , analysis_config , wait , logs , job_name , kms_key , experiment_config )
549
562
550
563
def run_bias (
551
564
self ,
@@ -559,6 +572,7 @@ def run_bias(
559
572
logs = True ,
560
573
job_name = None ,
561
574
kms_key = None ,
575
+ experiment_config = None ,
562
576
):
563
577
"""Runs a ProcessingJob to compute the requested bias 'methods' of the model predictions.
564
578
@@ -589,6 +603,9 @@ def run_bias(
589
603
"Clarify-Bias" and current timestamp.
590
604
kms_key (str): The ARN of the KMS key that is used to encrypt the
591
605
user code file (default: None).
606
+ experiment_config (dict[str, str]): Experiment management configuration.
607
+ Dictionary contains three optional keys:
608
+ 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
592
609
"""
593
610
analysis_config = data_config .get_config ()
594
611
analysis_config .update (bias_config .get_config ())
@@ -609,7 +626,7 @@ def run_bias(
609
626
}
610
627
if job_name is None :
611
628
job_name = utils .name_from_base ("Clarify-Bias" )
612
- self ._run (data_config , analysis_config , wait , logs , job_name , kms_key )
629
+ self ._run (data_config , analysis_config , wait , logs , job_name , kms_key , experiment_config )
613
630
614
631
def run_explainability (
615
632
self ,
@@ -621,6 +638,7 @@ def run_explainability(
621
638
logs = True ,
622
639
job_name = None ,
623
640
kms_key = None ,
641
+ experiment_config = None ,
624
642
):
625
643
"""Runs a ProcessingJob computing for each example in the input the feature importance.
626
644
@@ -657,7 +675,7 @@ def run_explainability(
657
675
analysis_config ["predictor" ] = predictor_config
658
676
if job_name is None :
659
677
job_name = utils .name_from_base ("Clarify-Explainability" )
660
- self ._run (data_config , analysis_config , wait , logs , job_name , kms_key )
678
+ self ._run (data_config , analysis_config , wait , logs , job_name , kms_key , experiment_config )
661
679
662
680
663
681
def _upload_analysis_config (analysis_config_file , s3_output_path , sagemaker_session , kms_key ):
0 commit comments