Skip to content

Commit d3d07b4

Browse files
authored
feat: add experiment_config for clarify processing job (#2287)
* feat: add experiment_config for clarify processing job * chore: add comment
1 parent 564a061 commit d3d07b4

File tree

2 files changed

+59
-9
lines changed

2 files changed

+59
-9
lines changed

src/sagemaker/clarify.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,7 @@ def _run(
403403
logs,
404404
job_name,
405405
kms_key,
406+
experiment_config,
406407
):
407408
"""Runs a ProcessingJob with the Sagemaker Clarify container and an analysis config.
408409
@@ -415,6 +416,9 @@ def _run(
415416
job_name (str): Processing job name.
416417
kms_key (str): The ARN of the KMS key that is used to encrypt the
417418
user code file (default: None).
419+
experiment_config (dict[str, str]): Experiment management configuration.
420+
Dictionary contains three optional keys:
421+
'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
418422
"""
419423
analysis_config["methods"]["report"] = {"name": "report", "title": "Analysis Report"}
420424
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -457,6 +461,7 @@ def _run(
457461
logs=logs,
458462
job_name=job_name,
459463
kms_key=kms_key,
464+
experiment_config=experiment_config,
460465
)
461466

462467
def run_pre_training_bias(
@@ -468,6 +473,7 @@ def run_pre_training_bias(
468473
logs=True,
469474
job_name=None,
470475
kms_key=None,
476+
experiment_config=None,
471477
):
472478
"""Runs a ProcessingJob to compute the requested bias 'methods' of the input data.
473479
@@ -487,13 +493,16 @@ def run_pre_training_bias(
487493
"Clarify-Pretraining-Bias" and current timestamp.
488494
kms_key (str): The ARN of the KMS key that is used to encrypt the
489495
user code file (default: None).
496+
experiment_config (dict[str, str]): Experiment management configuration.
497+
Dictionary contains three optional keys:
498+
'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
490499
"""
491500
analysis_config = data_config.get_config()
492501
analysis_config.update(data_bias_config.get_config())
493502
analysis_config["methods"] = {"pre_training_bias": {"methods": methods}}
494503
if job_name is None:
495504
job_name = utils.name_from_base("Clarify-Pretraining-Bias")
496-
self._run(data_config, analysis_config, wait, logs, job_name, kms_key)
505+
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)
497506

498507
def run_post_training_bias(
499508
self,
@@ -506,6 +515,7 @@ def run_post_training_bias(
506515
logs=True,
507516
job_name=None,
508517
kms_key=None,
518+
experiment_config=None,
509519
):
510520
"""Runs a ProcessingJob to compute the requested bias 'methods' of the model predictions.
511521
@@ -532,6 +542,9 @@ def run_post_training_bias(
532542
"Clarify-Posttraining-Bias" and current timestamp.
533543
kms_key (str): The ARN of the KMS key that is used to encrypt the
534544
user code file (default: None).
545+
experiment_config (dict[str, str]): Experiment management configuration.
546+
Dictionary contains three optional keys:
547+
'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
535548
"""
536549
analysis_config = data_config.get_config()
537550
analysis_config.update(data_bias_config.get_config())
@@ -545,7 +558,7 @@ def run_post_training_bias(
545558
_set(probability_threshold, "probability_threshold", analysis_config)
546559
if job_name is None:
547560
job_name = utils.name_from_base("Clarify-Posttraining-Bias")
548-
self._run(data_config, analysis_config, wait, logs, job_name, kms_key)
561+
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)
549562

550563
def run_bias(
551564
self,
@@ -559,6 +572,7 @@ def run_bias(
559572
logs=True,
560573
job_name=None,
561574
kms_key=None,
575+
experiment_config=None,
562576
):
563577
"""Runs a ProcessingJob to compute the requested bias 'methods' of the model predictions.
564578
@@ -589,6 +603,9 @@ def run_bias(
589603
"Clarify-Bias" and current timestamp.
590604
kms_key (str): The ARN of the KMS key that is used to encrypt the
591605
user code file (default: None).
606+
experiment_config (dict[str, str]): Experiment management configuration.
607+
Dictionary contains three optional keys:
608+
'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
592609
"""
593610
analysis_config = data_config.get_config()
594611
analysis_config.update(bias_config.get_config())
@@ -609,7 +626,7 @@ def run_bias(
609626
}
610627
if job_name is None:
611628
job_name = utils.name_from_base("Clarify-Bias")
612-
self._run(data_config, analysis_config, wait, logs, job_name, kms_key)
629+
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)
613630

614631
def run_explainability(
615632
self,
@@ -621,6 +638,7 @@ def run_explainability(
621638
logs=True,
622639
job_name=None,
623640
kms_key=None,
641+
experiment_config=None,
624642
):
625643
"""Runs a ProcessingJob computing for each example in the input the feature importance.
626644
@@ -649,6 +667,9 @@ def run_explainability(
649667
"Clarify-Explainability" and current timestamp.
650668
kms_key (str): The ARN of the KMS key that is used to encrypt the
651669
user code file (default: None).
670+
experiment_config (dict[str, str]): Experiment management configuration.
671+
Dictionary contains three optional keys:
672+
'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
652673
"""
653674
analysis_config = data_config.get_config()
654675
predictor_config = model_config.get_predictor_config()
@@ -657,7 +678,7 @@ def run_explainability(
657678
analysis_config["predictor"] = predictor_config
658679
if job_name is None:
659680
job_name = utils.name_from_base("Clarify-Explainability")
660-
self._run(data_config, analysis_config, wait, logs, job_name, kms_key)
681+
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)
661682

662683

663684
def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_session, kms_key):

tests/unit/test_clarify.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,11 @@ def shap_config():
287287
def test_pre_training_bias(clarify_processor, data_config, data_bias_config):
288288
with patch.object(SageMakerClarifyProcessor, "_run", return_value=None) as mock_method:
289289
clarify_processor.run_pre_training_bias(
290-
data_config, data_bias_config, wait=True, job_name="test"
290+
data_config,
291+
data_bias_config,
292+
wait=True,
293+
job_name="test",
294+
experiment_config={"ExperimentName": "AnExperiment"},
291295
)
292296
expected_analysis_config = {
293297
"dataset_type": "text/csv",
@@ -304,7 +308,13 @@ def test_pre_training_bias(clarify_processor, data_config, data_bias_config):
304308
"methods": {"pre_training_bias": {"methods": "all"}},
305309
}
306310
mock_method.assert_called_once_with(
307-
data_config, expected_analysis_config, True, True, "test", None
311+
data_config,
312+
expected_analysis_config,
313+
True,
314+
True,
315+
"test",
316+
None,
317+
{"ExperimentName": "AnExperiment"},
308318
)
309319

310320

@@ -319,6 +329,7 @@ def test_post_training_bias(
319329
model_predicted_label_config,
320330
wait=True,
321331
job_name="test",
332+
experiment_config={"ExperimentName": "AnExperiment"},
322333
)
323334
expected_analysis_config = {
324335
"dataset_type": "text/csv",
@@ -340,14 +351,26 @@ def test_post_training_bias(
340351
},
341352
}
342353
mock_method.assert_called_once_with(
343-
data_config, expected_analysis_config, True, True, "test", None
354+
data_config,
355+
expected_analysis_config,
356+
True,
357+
True,
358+
"test",
359+
None,
360+
{"ExperimentName": "AnExperiment"},
344361
)
345362

346363

347364
def test_shap(clarify_processor, data_config, model_config, shap_config):
348365
with patch.object(SageMakerClarifyProcessor, "_run", return_value=None) as mock_method:
349366
clarify_processor.run_explainability(
350-
data_config, model_config, shap_config, model_scores=None, wait=True, job_name="test"
367+
data_config,
368+
model_config,
369+
shap_config,
370+
model_scores=None,
371+
wait=True,
372+
job_name="test",
373+
experiment_config={"ExperimentName": "AnExperiment"},
351374
)
352375
expected_analysis_config = {
353376
"dataset_type": "text/csv",
@@ -380,5 +403,11 @@ def test_shap(clarify_processor, data_config, model_config, shap_config):
380403
},
381404
}
382405
mock_method.assert_called_once_with(
383-
data_config, expected_analysis_config, True, True, "test", None
406+
data_config,
407+
expected_analysis_config,
408+
True,
409+
True,
410+
"test",
411+
None,
412+
{"ExperimentName": "AnExperiment"},
384413
)

0 commit comments

Comments
 (0)