Skip to content

Commit 9d386de

Browse files
ananth102akrishna1995
authored andcommitted
feat: Supporting tbac in load_run
1 parent ef7c5a0 commit 9d386de

File tree

2 files changed

+67
-30
lines changed

2 files changed

+67
-30
lines changed

src/sagemaker/experiments/run.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,9 @@ def __init__(
205205
)
206206

207207
if not _TrialComponent._trial_component_is_associated_to_trial(
208-
self._trial_component.trial_component_name, self._trial.trial_name, sagemaker_session
208+
self._trial_component.trial_component_name,
209+
self._trial.trial_name,
210+
sagemaker_session,
209211
):
210212
self._trial.add_trial_component(self._trial_component)
211213

@@ -771,6 +773,7 @@ def load_run(
771773
sagemaker_session: Optional["Session"] = None,
772774
artifact_bucket: Optional[str] = None,
773775
artifact_prefix: Optional[str] = None,
776+
tags: Optional[List[Dict[str, str]]] = None,
774777
) -> Run:
775778
"""Load an existing run.
776779
@@ -839,6 +842,8 @@ def load_run(
839842
will be used.
840843
artifact_prefix (str): The S3 key prefix used to generate the S3 path
841844
to upload the artifact to (default: "trial-component-artifacts").
845+
tags (List[Dict[str, str]]): A list of tags to be used for all create calls,
846+
e.g. to create an experiment, a run group, etc. (default: None).
842847
843848
Returns:
844849
Run: The loaded Run object.
@@ -860,6 +865,7 @@ def load_run(
860865
sagemaker_session=sagemaker_session or _utils.default_session(),
861866
artifact_bucket=artifact_bucket,
862867
artifact_prefix=artifact_prefix,
868+
tags=tags,
863869
)
864870
elif _RunContext.get_current_run():
865871
run_instance = _RunContext.get_current_run()
@@ -879,6 +885,7 @@ def load_run(
879885
sagemaker_session=sagemaker_session or _utils.default_session(),
880886
artifact_bucket=artifact_bucket,
881887
artifact_prefix=artifact_prefix,
888+
tags=tags,
882889
)
883890
else:
884891
raise RuntimeError(

tests/unit/sagemaker/experiments/test_run.py

Lines changed: 59 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
TEST_RUN_DISPLAY_NAME,
5656
TEST_ARTIFACT_BUCKET,
5757
TEST_ARTIFACT_PREFIX,
58+
TEST_TAGS,
5859
)
5960

6061

@@ -155,24 +156,22 @@ def test_run_init_name_length_exceed_limit(sagemaker_session):
155156

156157

157158
@pytest.mark.parametrize(
158-
("kwargs", "expected_artifact_bucket", "expected_artifact_prefix"),
159+
("kwargs", "expected_artifact_bucket", "expected_artifact_prefix", "expected_tags"),
159160
[
160-
({}, None, _DEFAULT_ARTIFACT_PREFIX),
161+
({}, None, _DEFAULT_ARTIFACT_PREFIX, None),
161162
(
162163
{
163164
"artifact_bucket": TEST_ARTIFACT_BUCKET,
164165
"artifact_prefix": TEST_ARTIFACT_PREFIX,
166+
"tags": TEST_TAGS,
165167
},
166168
TEST_ARTIFACT_BUCKET,
167169
TEST_ARTIFACT_PREFIX,
170+
TEST_TAGS,
168171
),
169172
],
170173
)
171174
@patch.object(_TrialComponent, "save", MagicMock(return_value=None))
172-
@patch(
173-
"sagemaker.experiments.run.Experiment._load_or_create",
174-
MagicMock(return_value=Experiment(experiment_name=TEST_EXP_NAME)),
175-
)
176175
@patch(
177176
"sagemaker.experiments.run._Trial._load_or_create",
178177
MagicMock(side_effect=mock_trial_load_or_create_func),
@@ -189,6 +188,7 @@ def test_run_load_no_run_name_and_in_train_job(
189188
kwargs,
190189
expected_artifact_bucket,
191190
expected_artifact_prefix,
191+
expected_tags,
192192
):
193193
client = sagemaker_session.sagemaker_client
194194
job_name = "my-train-job"
@@ -213,26 +213,32 @@ def test_run_load_no_run_name_and_in_train_job(
213213
{
214214
"TrialComponent": {
215215
"Parents": [
216-
{"ExperimentName": TEST_EXP_NAME, "TrialName": exp_config[TRIAL_NAME]}
216+
{
217+
"ExperimentName": TEST_EXP_NAME,
218+
"TrialName": exp_config[TRIAL_NAME],
219+
}
217220
],
218221
"TrialComponentName": expected_tc_name,
219222
}
220223
}
221224
]
222225
}
223-
with load_run(sagemaker_session=sagemaker_session, **kwargs) as run_obj:
224-
assert run_obj._in_load
225-
assert not run_obj._inside_init_context
226-
assert run_obj._inside_load_context
227-
assert run_obj.run_name == TEST_RUN_NAME
228-
assert run_obj._trial_component.trial_component_name == expected_tc_name
229-
assert run_obj.run_group_name == Run._generate_trial_name(TEST_EXP_NAME)
230-
assert run_obj._trial
231-
assert run_obj.experiment_name == TEST_EXP_NAME
232-
assert run_obj._experiment
233-
assert run_obj.experiment_config == exp_config
234-
assert run_obj._artifact_uploader.artifact_bucket == expected_artifact_bucket
235-
assert run_obj._artifact_uploader.artifact_prefix == expected_artifact_prefix
226+
expmock = MagicMock(return_value=Experiment(experiment_name=TEST_EXP_NAME, tags=expected_tags))
227+
with patch("sagemaker.experiments.run.Experiment._load_or_create", expmock):
228+
with load_run(sagemaker_session=sagemaker_session, **kwargs) as run_obj:
229+
assert run_obj._in_load
230+
assert not run_obj._inside_init_context
231+
assert run_obj._inside_load_context
232+
assert run_obj.run_name == TEST_RUN_NAME
233+
assert run_obj._trial_component.trial_component_name == expected_tc_name
234+
assert run_obj.run_group_name == Run._generate_trial_name(TEST_EXP_NAME)
235+
assert run_obj._trial
236+
assert run_obj.experiment_name == TEST_EXP_NAME
237+
assert run_obj._experiment
238+
assert run_obj.experiment_config == exp_config
239+
assert run_obj._artifact_uploader.artifact_bucket == expected_artifact_bucket
240+
assert run_obj._artifact_uploader.artifact_prefix == expected_artifact_prefix
241+
assert run_obj._experiment.tags == expected_tags
236242

237243
client.describe_training_job.assert_called_once_with(TrainingJobName=job_name)
238244
run_obj._trial.add_trial_component.assert_not_called()
@@ -265,7 +271,9 @@ def test_run_load_no_run_name_and_not_in_train_job(run_obj, sagemaker_session):
265271
assert run_obj == run
266272

267273

268-
def test_run_load_no_run_name_and_not_in_train_job_but_no_obj_in_context(sagemaker_session):
274+
def test_run_load_no_run_name_and_not_in_train_job_but_no_obj_in_context(
275+
sagemaker_session,
276+
):
269277
with pytest.raises(RuntimeError) as err:
270278
with load_run(sagemaker_session=sagemaker_session):
271279
pass
@@ -388,7 +396,10 @@ def test_run_load_in_sm_processing_job(mock_run_env, sagemaker_session):
388396
{
389397
"TrialComponent": {
390398
"Parents": [
391-
{"ExperimentName": TEST_EXP_NAME, "TrialName": exp_config[TRIAL_NAME]}
399+
{
400+
"ExperimentName": TEST_EXP_NAME,
401+
"TrialName": exp_config[TRIAL_NAME],
402+
}
392403
],
393404
"TrialComponentName": expected_tc_name,
394405
}
@@ -442,7 +453,10 @@ def test_run_load_in_sm_transform_job(mock_run_env, sagemaker_session):
442453
{
443454
"TrialComponent": {
444455
"Parents": [
445-
{"ExperimentName": TEST_EXP_NAME, "TrialName": exp_config[TRIAL_NAME]}
456+
{
457+
"ExperimentName": TEST_EXP_NAME,
458+
"TrialName": exp_config[TRIAL_NAME],
459+
}
446460
],
447461
"TrialComponentName": expected_tc_name,
448462
}
@@ -589,7 +603,10 @@ def test_log_output_artifact_outside_run_context(run_obj):
589603

590604

591605
def test_log_output_artifact(run_obj):
592-
run_obj._artifact_uploader.upload_artifact.return_value = ("s3uri_value", "etag_value")
606+
run_obj._artifact_uploader.upload_artifact.return_value = (
607+
"s3uri_value",
608+
"etag_value",
609+
)
593610
with run_obj:
594611
run_obj.log_file("foo.txt", "name", "whizz/bang")
595612
run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt")
@@ -608,7 +625,10 @@ def test_log_input_artifact_outside_run_context(run_obj):
608625

609626

610627
def test_log_input_artifact(run_obj):
611-
run_obj._artifact_uploader.upload_artifact.return_value = ("s3uri_value", "etag_value")
628+
run_obj._artifact_uploader.upload_artifact.return_value = (
629+
"s3uri_value",
630+
"etag_value",
631+
)
612632
with run_obj:
613633
run_obj.log_file("foo.txt", "name", "whizz/bang", is_output=False)
614634
run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt")
@@ -653,7 +673,10 @@ def test_log_multiple_input_artifacts(run_obj):
653673
"etag_value" + str(index),
654674
)
655675
run_obj.log_file(
656-
file_path, "name" + str(index), "whizz/bang" + str(index), is_output=False
676+
file_path,
677+
"name" + str(index),
678+
"whizz/bang" + str(index),
679+
is_output=False,
657680
)
658681
run_obj._artifact_uploader.upload_artifact.assert_called_with(file_path)
659682

@@ -753,7 +776,12 @@ def test_log_precision_recall_invalid_input(run_obj):
753776
with run_obj:
754777
with pytest.raises(ValueError) as error:
755778
run_obj.log_precision_recall(
756-
y_true, y_scores, 0, title="TestPrecisionRecall", no_skill=no_skill, is_output=False
779+
y_true,
780+
y_scores,
781+
0,
782+
title="TestPrecisionRecall",
783+
no_skill=no_skill,
784+
is_output=False,
757785
)
758786
assert "Lengths mismatch between true labels and predicted probabilities" in str(error)
759787

@@ -901,7 +929,8 @@ def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_ses
901929
display_name="C" + str(i),
902930
source_arn="D" + str(i),
903931
status=TrialComponentStatus(
904-
primary_status=_TrialComponentStatusType.InProgress.value, message="E" + str(i)
932+
primary_status=_TrialComponentStatusType.InProgress.value,
933+
message="E" + str(i),
905934
),
906935
start_time=start_time + datetime.timedelta(hours=i),
907936
end_time=end_time + datetime.timedelta(hours=i),
@@ -921,7 +950,8 @@ def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_ses
921950
display_name="C" + str(i),
922951
source_arn="D" + str(i),
923952
status=TrialComponentStatus(
924-
primary_status=_TrialComponentStatusType.InProgress.value, message="E" + str(i)
953+
primary_status=_TrialComponentStatusType.InProgress.value,
954+
message="E" + str(i),
925955
),
926956
start_time=start_time + datetime.timedelta(hours=i),
927957
end_time=end_time + datetime.timedelta(hours=i),

0 commit comments

Comments
 (0)