Skip to content

Commit c0aee6e

Browse files
authored
feat: Add check for if TrialComponent is already associated with a Trial in Run (#3956)
* feature: add a check for whether a TrialComponent is already associated with a Trial * make _trial_component_is_associated_to_trial private and move it to a method under trial_component; add docstring and unit tests
1 parent d71880f commit c0aee6e

File tree

7 files changed

+150
-11
lines changed

7 files changed

+150
-11
lines changed

src/sagemaker/experiments/run.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@
2525

2626
from sagemaker.apiutils import _utils
2727
from sagemaker.experiments import _api_types
28-
from sagemaker.experiments._api_types import TrialComponentArtifact, _TrialComponentStatusType
28+
from sagemaker.experiments._api_types import (
29+
TrialComponentArtifact,
30+
_TrialComponentStatusType,
31+
)
2932
from sagemaker.experiments._helper import (
3033
_ArtifactUploader,
3134
_LineageArtifactTracker,
@@ -200,7 +203,11 @@ def __init__(
200203
self.run_name,
201204
self.experiment_name,
202205
)
203-
self._trial.add_trial_component(self._trial_component)
206+
207+
if not _TrialComponent._trial_component_is_associated_to_trial(
208+
self._trial_component.trial_component_name, self._trial.trial_name, sagemaker_session
209+
):
210+
self._trial.add_trial_component(self._trial_component)
204211

205212
self._artifact_uploader = _ArtifactUploader(
206213
trial_component_name=self._trial_component.trial_component_name,
@@ -348,7 +355,10 @@ def log_precision_recall(
348355
"noSkill": no_skill,
349356
}
350357
self._log_graph_artifact(
351-
artifact_name=title, data=data, graph_type="PrecisionRecallCurve", is_output=is_output
358+
artifact_name=title,
359+
data=data,
360+
graph_type="PrecisionRecallCurve",
361+
is_output=is_output,
352362
)
353363

354364
@validate_invoked_inside_run_context
@@ -381,7 +391,9 @@ def log_roc_curve(
381391
If set to False then represented as input association.
382392
"""
383393
verify_length_of_true_and_predicted(
384-
true_labels=y_true, predicted_attrs=y_score, predicted_attrs_name="predicted scores"
394+
true_labels=y_true,
395+
predicted_attrs=y_score,
396+
predicted_attrs_name="predicted scores",
385397
)
386398

387399
get_module("sklearn")
@@ -432,7 +444,9 @@ def log_confusion_matrix(
432444
If set to False then represented as input association.
433445
"""
434446
verify_length_of_true_and_predicted(
435-
true_labels=y_true, predicted_attrs=y_pred, predicted_attrs_name="predicted labels"
447+
true_labels=y_true,
448+
predicted_attrs=y_pred,
449+
predicted_attrs_name="predicted labels",
436450
)
437451

438452
get_module("sklearn")
@@ -447,12 +461,19 @@ def log_confusion_matrix(
447461
"confusionMatrix": matrix.tolist(),
448462
}
449463
self._log_graph_artifact(
450-
artifact_name=title, data=data, graph_type="ConfusionMatrix", is_output=is_output
464+
artifact_name=title,
465+
data=data,
466+
graph_type="ConfusionMatrix",
467+
is_output=is_output,
451468
)
452469

453470
@validate_invoked_inside_run_context
454471
def log_artifact(
455-
self, name: str, value: str, media_type: Optional[str] = None, is_output: bool = True
472+
self,
473+
name: str,
474+
value: str,
475+
media_type: Optional[str] = None,
476+
is_output: bool = True,
456477
):
457478
"""Record a single artifact for this run.
458479
@@ -575,11 +596,17 @@ def _log_graph_artifact(self, data, graph_type, is_output, artifact_name=None):
575596
# create an artifact and association for the table
576597
if is_output:
577598
self._lineage_artifact_tracker.add_output_artifact(
578-
name=artifact_name, source_uri=s3_uri, etag=etag, artifact_type=graph_type
599+
name=artifact_name,
600+
source_uri=s3_uri,
601+
etag=etag,
602+
artifact_type=graph_type,
579603
)
580604
else:
581605
self._lineage_artifact_tracker.add_input_artifact(
582-
name=artifact_name, source_uri=s3_uri, etag=etag, artifact_type=graph_type
606+
name=artifact_name,
607+
source_uri=s3_uri,
608+
etag=etag,
609+
artifact_type=graph_type,
583610
)
584611

585612
def _verify_trial_component_artifacts_length(self, is_output):
@@ -719,7 +746,8 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
719746
self._trial_component.end_time = end_time
720747
if exc_value:
721748
self._trial_component.status = _api_types.TrialComponentStatus(
722-
primary_status=_TrialComponentStatusType.Failed.value, message=str(exc_value)
749+
primary_status=_TrialComponentStatusType.Failed.value,
750+
message=str(exc_value),
723751
)
724752
else:
725753
self._trial_component.status = _api_types.TrialComponentStatus(
@@ -837,7 +865,8 @@ def load_run(
837865
run_instance = _RunContext.get_current_run()
838866
elif environment:
839867
exp_config = get_tc_and_exp_config_from_job_env(
840-
environment=environment, sagemaker_session=sagemaker_session or _utils.default_session()
868+
environment=environment,
869+
sagemaker_session=sagemaker_session or _utils.default_session(),
841870
)
842871
run_name = Run._extract_run_name_from_tc_name(
843872
trial_component_name=exp_config[RUN_NAME],

src/sagemaker/experiments/trial_component.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,3 +345,42 @@ def _load_or_create(
345345
run_tc = _TrialComponent.load(trial_component_name, sagemaker_session)
346346
is_existed = True
347347
return run_tc, is_existed
348+
349+
@classmethod
def _trial_component_is_associated_to_trial(
    cls, trial_component_name, trial_name=None, sagemaker_session=None
):
    """Check whether a trial component is already associated with a trial.

    Issues a ``search`` request for ``ExperimentTrialComponent`` resources,
    filtered on both the trial component name and the parent trial name, and
    reports whether anything came back.

    Args:
        trial_component_name (str): The name of the trial component.
        trial_name (str): The name of the trial.
        sagemaker_session (sagemaker.session.Session): Session object which
            manages interactions with Amazon SageMaker APIs and any other
            AWS services needed.

    Returns:
        bool: True when the search response contains at least one result,
            False when its ``Results`` entry is empty.

    """
    # NOTE(review): both optional parameters default to None, but the session is
    # dereferenced unconditionally (None raises AttributeError) and the names are
    # passed through str() (a None trial_name is searched as the text "None").
    # Callers appear to be expected to supply both -- confirm.
    search = sagemaker_session.sagemaker_client.search

    component_name_filter = {
        "Name": "TrialComponentName",
        "Operator": "Equals",
        "Value": str(trial_component_name),
    }
    parent_trial_filter = {
        "Name": "Parents.TrialName",
        "Operator": "Equals",
        "Value": str(trial_name),
    }

    response = search(
        Resource="ExperimentTrialComponent",
        SearchExpression={"Filters": [component_name_filter, parent_trial_filter]},
    )
    # Any hit means the component already lists this trial as a parent.
    return bool(response["Results"])

tests/unit/sagemaker/experiments/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def run_obj(sagemaker_session):
7272
"sagemaker.experiments.run._Trial._load_or_create",
7373
MagicMock(side_effect=mock_trial_load_or_create_func),
7474
):
75+
sagemaker_session.sagemaker_client.search.return_value = {"Results": []}
7576
run = Run(
7677
experiment_name=TEST_EXP_NAME,
7778
sagemaker_session=sagemaker_session,

tests/unit/sagemaker/experiments/test_run.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def test_run_init(
9393
expected_artifact_bucket,
9494
expected_artifact_prefix,
9595
):
96+
sagemaker_session.sagemaker_client.search.return_value = {"Results": []}
9697
with Run(
9798
experiment_name=TEST_EXP_NAME,
9899
run_name=TEST_RUN_NAME,
@@ -121,6 +122,7 @@ def test_run_init(
121122

122123
# trial_component.save is called when entering/exiting the with block
123124
mock_tc_save.assert_called()
125+
run_obj._trial.add_trial_component.assert_called()
124126

125127

126128
def test_run_init_name_length_exceed_limit(sagemaker_session):
@@ -206,6 +208,18 @@ def test_run_load_no_run_name_and_in_train_job(
206208
# The Run object has been created else where
207209
"ExperimentConfig": exp_config,
208210
}
211+
sagemaker_session.sagemaker_client.search.return_value = {
212+
"Results": [
213+
{
214+
"TrialComponent": {
215+
"Parents": [
216+
{"ExperimentName": TEST_EXP_NAME, "TrialName": exp_config[TRIAL_NAME]}
217+
],
218+
"TrialComponentName": expected_tc_name,
219+
}
220+
}
221+
]
222+
}
209223
with load_run(sagemaker_session=sagemaker_session, **kwargs) as run_obj:
210224
assert run_obj._in_load
211225
assert not run_obj._inside_init_context
@@ -221,6 +235,7 @@ def test_run_load_no_run_name_and_in_train_job(
221235
assert run_obj._artifact_uploader.artifact_prefix == expected_artifact_prefix
222236

223237
client.describe_training_job.assert_called_once_with(TrainingJobName=job_name)
238+
run_obj._trial.add_trial_component.assert_not_called()
224239

225240

226241
@patch("sagemaker.experiments.run._RunEnvironment")
@@ -296,6 +311,7 @@ def test_run_load_no_run_name_and_not_in_train_job_but_no_obj_in_context(sagemak
296311
def test_run_load_with_run_name_and_exp_name(
297312
sagemaker_session, kwargs, expected_artifact_bucket, expected_artifact_prefix
298313
):
314+
sagemaker_session.sagemaker_client.search.return_value = {"Results": []}
299315
with load_run(
300316
run_name=TEST_RUN_NAME,
301317
experiment_name=TEST_EXP_NAME,
@@ -319,6 +335,8 @@ def test_run_load_with_run_name_and_exp_name(
319335
assert run_obj._artifact_uploader.artifact_bucket == expected_artifact_bucket
320336
assert run_obj._artifact_uploader.artifact_prefix == expected_artifact_prefix
321337

338+
run_obj._trial.add_trial_component.assert_called()
339+
322340

323341
def test_run_load_with_run_name_but_no_exp_name(sagemaker_session):
324342
with pytest.raises(ValueError) as err:
@@ -365,11 +383,24 @@ def test_run_load_in_sm_processing_job(mock_run_env, sagemaker_session):
365383
# The Run object has been created elsewhere
366384
"ExperimentConfig": exp_config,
367385
}
386+
sagemaker_session.sagemaker_client.search.return_value = {
387+
"Results": [
388+
{
389+
"TrialComponent": {
390+
"Parents": [
391+
{"ExperimentName": TEST_EXP_NAME, "TrialName": exp_config[TRIAL_NAME]}
392+
],
393+
"TrialComponentName": expected_tc_name,
394+
}
395+
}
396+
]
397+
}
368398

369399
with load_run(sagemaker_session=sagemaker_session):
370400
pass
371401

372402
client.describe_processing_job.assert_called_once_with(ProcessingJobName=job_name)
403+
mock_run_env._trial.add_trial_component.assert_not_called()
373404

374405

375406
@patch(
@@ -406,11 +437,24 @@ def test_run_load_in_sm_transform_job(mock_run_env, sagemaker_session):
406437
# The Run object has been created elsewhere
407438
"ExperimentConfig": exp_config,
408439
}
440+
sagemaker_session.sagemaker_client.search.return_value = {
441+
"Results": [
442+
{
443+
"TrialComponent": {
444+
"Parents": [
445+
{"ExperimentName": TEST_EXP_NAME, "TrialName": exp_config[TRIAL_NAME]}
446+
],
447+
"TrialComponentName": expected_tc_name,
448+
}
449+
}
450+
]
451+
}
409452

410453
with load_run(sagemaker_session=sagemaker_session):
411454
pass
412455

413456
client.describe_transform_job.assert_called_once_with(TransformJobName=job_name)
457+
mock_run_env._trial.add_trial_component.assert_not_called()
414458

415459

416460
@patch(
@@ -428,6 +472,7 @@ def test_run_load_in_sm_transform_job(mock_run_env, sagemaker_session):
428472
)
429473
@patch.object(_TrialComponent, "save")
430474
def test_run_object_serialize_deserialize(mock_tc_save, sagemaker_session):
475+
sagemaker_session.sagemaker_client.search.return_value = {"Results": []}
431476
run_obj = Run(
432477
experiment_name=TEST_EXP_NAME,
433478
run_name=TEST_RUN_NAME,

tests/unit/sagemaker/experiments/test_trial_component.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,3 +396,26 @@ def test_search(sagemaker_session):
396396
),
397397
]
398398
assert expected == list(_TrialComponent.search(sagemaker_session=sagemaker_session))
399+
400+
401+
def test_trial_component_is_associated_to_trial(sagemaker_session):
    """A non-empty search response is reported as an existing association."""
    trial_component = _TrialComponent(sagemaker_session, trial_component_name="tc-1")

    # One search hit: trial component "tc-1" with trial "t-1" listed as a parent.
    matching_hit = {
        "TrialComponent": {
            "Parents": [{"ExperimentName": "e-1", "TrialName": "t-1"}],
            "TrialComponentName": "tc-1",
        }
    }
    sagemaker_session.sagemaker_client.search.return_value = {"Results": [matching_hit]}

    is_associated = trial_component._trial_component_is_associated_to_trial(
        "tc-1", "t-1", sagemaker_session
    )
    assert is_associated is True
415+
416+
417+
def test_trial_component_is_not_associated_to_trial(sagemaker_session):
    """An empty search response is reported as no association."""
    trial_component = _TrialComponent(sagemaker_session, trial_component_name="tc-1")

    # The search finds nothing for this component/trial pair.
    empty_response = {"Results": []}
    sagemaker_session.sagemaker_client.search.return_value = empty_response

    is_associated = trial_component._trial_component_is_associated_to_trial(
        "tc-1", "t-1", sagemaker_session
    )
    assert is_associated is False

tests/unit/sagemaker/remote_function/core/test_stored_function.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ def test_save_with_parameter_of_run_type(
109109
):
110110
session = Mock()
111111
s3_base_uri = random_s3_uri()
112+
session.sagemaker_client.search.return_value = {"Results": []}
112113

113114
run = Run(
114115
experiment_name=TEST_EXP_NAME,

tests/unit/sagemaker/remote_function/test_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ def run_obj(sagemaker_session):
114114
"sagemaker.experiments.run._Trial._load_or_create",
115115
MagicMock(side_effect=mock_trial_load_or_create_func),
116116
):
117+
sagemaker_session.sagemaker_client.search.return_value = {"Results": []}
117118
run = Run(
118119
experiment_name="test-exp",
119120
sagemaker_session=sagemaker_session,

0 commit comments

Comments
 (0)