Skip to content

Commit f824ccc

Browse files
committed
feature: add check for if TrialComponent is already associated with a Trial
1 parent 12b19ee commit f824ccc

File tree

5 files changed

+108
-11
lines changed

5 files changed

+108
-11
lines changed

src/sagemaker/experiments/run.py

Lines changed: 60 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@
2525

2626
from sagemaker.apiutils import _utils
2727
from sagemaker.experiments import _api_types
28-
from sagemaker.experiments._api_types import TrialComponentArtifact, _TrialComponentStatusType
28+
from sagemaker.experiments._api_types import (
29+
TrialComponentArtifact,
30+
_TrialComponentStatusType,
31+
)
2932
from sagemaker.experiments._helper import (
3033
_ArtifactUploader,
3134
_LineageArtifactTracker,
@@ -192,7 +195,31 @@ def __init__(
192195
self.run_name,
193196
self.experiment_name,
194197
)
195-
self._trial.add_trial_component(self._trial_component)
198+
199+
def search_trial_component_associated_trial():
200+
search_results = sagemaker_session.sagemaker_client.search(
201+
Resource="ExperimentTrialComponent",
202+
SearchExpression={
203+
"Filters": [
204+
{
205+
"Name": "TrialComponentName",
206+
"Operator": "Equals",
207+
"Value": str(self._trial_component.trial_component_name),
208+
},
209+
{
210+
"Name": "Parents.TrialName",
211+
"Operator": "Equals",
212+
"Value": str(self._trial.trial_name),
213+
},
214+
]
215+
},
216+
)
217+
if search_results["Results"]:
218+
return True
219+
return False
220+
221+
if not search_trial_component_associated_trial():
222+
self._trial.add_trial_component(self._trial_component)
196223

197224
self._artifact_uploader = _ArtifactUploader(
198225
trial_component_name=self._trial_component.trial_component_name,
@@ -336,7 +363,10 @@ def log_precision_recall(
336363
"noSkill": no_skill,
337364
}
338365
self._log_graph_artifact(
339-
artifact_name=title, data=data, graph_type="PrecisionRecallCurve", is_output=is_output
366+
artifact_name=title,
367+
data=data,
368+
graph_type="PrecisionRecallCurve",
369+
is_output=is_output,
340370
)
341371

342372
@validate_invoked_inside_run_context
@@ -369,7 +399,9 @@ def log_roc_curve(
369399
If set to False then represented as input association.
370400
"""
371401
verify_length_of_true_and_predicted(
372-
true_labels=y_true, predicted_attrs=y_score, predicted_attrs_name="predicted scores"
402+
true_labels=y_true,
403+
predicted_attrs=y_score,
404+
predicted_attrs_name="predicted scores",
373405
)
374406

375407
get_module("sklearn")
@@ -420,7 +452,9 @@ def log_confusion_matrix(
420452
If set to False then represented as input association.
421453
"""
422454
verify_length_of_true_and_predicted(
423-
true_labels=y_true, predicted_attrs=y_pred, predicted_attrs_name="predicted labels"
455+
true_labels=y_true,
456+
predicted_attrs=y_pred,
457+
predicted_attrs_name="predicted labels",
424458
)
425459

426460
get_module("sklearn")
@@ -435,12 +469,19 @@ def log_confusion_matrix(
435469
"confusionMatrix": matrix.tolist(),
436470
}
437471
self._log_graph_artifact(
438-
artifact_name=title, data=data, graph_type="ConfusionMatrix", is_output=is_output
472+
artifact_name=title,
473+
data=data,
474+
graph_type="ConfusionMatrix",
475+
is_output=is_output,
439476
)
440477

441478
@validate_invoked_inside_run_context
442479
def log_artifact(
443-
self, name: str, value: str, media_type: Optional[str] = None, is_output: bool = True
480+
self,
481+
name: str,
482+
value: str,
483+
media_type: Optional[str] = None,
484+
is_output: bool = True,
444485
):
445486
"""Record a single artifact for this run.
446487
@@ -563,11 +604,17 @@ def _log_graph_artifact(self, data, graph_type, is_output, artifact_name=None):
563604
# create an artifact and association for the table
564605
if is_output:
565606
self._lineage_artifact_tracker.add_output_artifact(
566-
name=artifact_name, source_uri=s3_uri, etag=etag, artifact_type=graph_type
607+
name=artifact_name,
608+
source_uri=s3_uri,
609+
etag=etag,
610+
artifact_type=graph_type,
567611
)
568612
else:
569613
self._lineage_artifact_tracker.add_input_artifact(
570-
name=artifact_name, source_uri=s3_uri, etag=etag, artifact_type=graph_type
614+
name=artifact_name,
615+
source_uri=s3_uri,
616+
etag=etag,
617+
artifact_type=graph_type,
571618
)
572619

573620
def _verify_trial_component_artifacts_length(self, is_output):
@@ -707,7 +754,8 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
707754
self._trial_component.end_time = end_time
708755
if exc_value:
709756
self._trial_component.status = _api_types.TrialComponentStatus(
710-
primary_status=_TrialComponentStatusType.Failed.value, message=str(exc_value)
757+
primary_status=_TrialComponentStatusType.Failed.value,
758+
message=str(exc_value),
711759
)
712760
else:
713761
self._trial_component.status = _api_types.TrialComponentStatus(
@@ -816,7 +864,8 @@ def load_run(
816864
run_instance = _RunContext.get_current_run()
817865
elif environment:
818866
exp_config = get_tc_and_exp_config_from_job_env(
819-
environment=environment, sagemaker_session=sagemaker_session or _utils.default_session()
867+
environment=environment,
868+
sagemaker_session=sagemaker_session or _utils.default_session(),
820869
)
821870
run_name = Run._extract_run_name_from_tc_name(
822871
trial_component_name=exp_config[RUN_NAME],

tests/unit/sagemaker/experiments/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def run_obj(sagemaker_session):
7272
"sagemaker.experiments.run._Trial._load_or_create",
7373
MagicMock(side_effect=mock_trial_load_or_create_func),
7474
):
75+
sagemaker_session.sagemaker_client.search.return_value = {"Results": []}
7576
run = Run(
7677
experiment_name=TEST_EXP_NAME,
7778
sagemaker_session=sagemaker_session,

tests/unit/sagemaker/experiments/test_run.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
)
7171
@patch.object(_TrialComponent, "save")
7272
def test_run_init(mock_tc_save, sagemaker_session):
73+
sagemaker_session.sagemaker_client.search.return_value = {"Results": []}
7374
with Run(
7475
experiment_name=TEST_EXP_NAME, run_name=TEST_RUN_NAME, sagemaker_session=sagemaker_session
7576
) as run_obj:
@@ -93,6 +94,7 @@ def test_run_init(mock_tc_save, sagemaker_session):
9394

9495
# trial_component.save is called when entering/exiting the with block
9596
mock_tc_save.assert_called()
97+
run_obj._trial.add_trial_component.assert_called()
9698

9799

98100
def test_run_init_name_length_exceed_limit(sagemaker_session):
@@ -158,6 +160,18 @@ def test_run_load_no_run_name_and_in_train_job(mock_run_env, sagemaker_session):
158160
# The Run object has been created elsewhere
159161
"ExperimentConfig": exp_config,
160162
}
163+
sagemaker_session.sagemaker_client.search.return_value = {
164+
"Results": [
165+
{
166+
"TrialComponent": {
167+
"Parents": [
168+
{"ExperimentName": TEST_EXP_NAME, "TrialName": exp_config[TRIAL_NAME]}
169+
],
170+
"TrialComponentName": expected_tc_name,
171+
}
172+
}
173+
]
174+
}
161175
with load_run(sagemaker_session=sagemaker_session) as run_obj:
162176
assert run_obj._in_load
163177
assert not run_obj._inside_init_context
@@ -171,6 +185,7 @@ def test_run_load_no_run_name_and_in_train_job(mock_run_env, sagemaker_session):
171185
assert run_obj.experiment_config == exp_config
172186

173187
client.describe_training_job.assert_called_once_with(TrainingJobName=job_name)
188+
run_obj._trial.add_trial_component.assert_not_called()
174189

175190

176191
@patch("sagemaker.experiments.run._RunEnvironment")
@@ -230,6 +245,7 @@ def test_run_load_no_run_name_and_not_in_train_job_but_no_obj_in_context(sagemak
230245
MagicMock(side_effect=mock_tc_load_or_create_func),
231246
)
232247
def test_run_load_with_run_name_and_exp_name(sagemaker_session):
248+
sagemaker_session.sagemaker_client.search.return_value = {"Results": []}
233249
with load_run(
234250
run_name=TEST_RUN_NAME,
235251
experiment_name=TEST_EXP_NAME,
@@ -250,6 +266,8 @@ def test_run_load_with_run_name_and_exp_name(sagemaker_session):
250266
assert run_obj._experiment
251267
assert run_obj.experiment_config == expected_exp_config
252268

269+
run_obj._trial.add_trial_component.assert_called()
270+
253271

254272
def test_run_load_with_run_name_but_no_exp_name(sagemaker_session):
255273
with pytest.raises(ValueError) as err:
@@ -296,11 +314,24 @@ def test_run_load_in_sm_processing_job(mock_run_env, sagemaker_session):
296314
# The Run object has been created elsewhere
297315
"ExperimentConfig": exp_config,
298316
}
317+
sagemaker_session.sagemaker_client.search.return_value = {
318+
"Results": [
319+
{
320+
"TrialComponent": {
321+
"Parents": [
322+
{"ExperimentName": TEST_EXP_NAME, "TrialName": exp_config[TRIAL_NAME]}
323+
],
324+
"TrialComponentName": expected_tc_name,
325+
}
326+
}
327+
]
328+
}
299329

300330
with load_run(sagemaker_session=sagemaker_session):
301331
pass
302332

303333
client.describe_processing_job.assert_called_once_with(ProcessingJobName=job_name)
334+
mock_run_env._trial.add_trial_component.assert_not_called()
304335

305336

306337
@patch(
@@ -337,11 +368,24 @@ def test_run_load_in_sm_transform_job(mock_run_env, sagemaker_session):
337368
# The Run object has been created elsewhere
338369
"ExperimentConfig": exp_config,
339370
}
371+
sagemaker_session.sagemaker_client.search.return_value = {
372+
"Results": [
373+
{
374+
"TrialComponent": {
375+
"Parents": [
376+
{"ExperimentName": TEST_EXP_NAME, "TrialName": exp_config[TRIAL_NAME]}
377+
],
378+
"TrialComponentName": expected_tc_name,
379+
}
380+
}
381+
]
382+
}
340383

341384
with load_run(sagemaker_session=sagemaker_session):
342385
pass
343386

344387
client.describe_transform_job.assert_called_once_with(TransformJobName=job_name)
388+
mock_run_env._trial.add_trial_component.assert_not_called()
345389

346390

347391
@patch(
@@ -359,6 +403,7 @@ def test_run_load_in_sm_transform_job(mock_run_env, sagemaker_session):
359403
)
360404
@patch.object(_TrialComponent, "save")
361405
def test_run_object_serialize_deserialize(mock_tc_save, sagemaker_session):
406+
sagemaker_session.sagemaker_client.search.return_value = {"Results": []}
362407
run_obj = Run(
363408
experiment_name=TEST_EXP_NAME,
364409
run_name=TEST_RUN_NAME,

tests/unit/sagemaker/remote_function/core/test_stored_function.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ def test_save_with_parameter_of_run_type(
109109
):
110110
session = Mock()
111111
s3_base_uri = random_s3_uri()
112+
session.sagemaker_client.search.return_value = {"Results": []}
112113

113114
run = Run(
114115
experiment_name=TEST_EXP_NAME,

tests/unit/sagemaker/remote_function/test_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ def run_obj(sagemaker_session):
113113
"sagemaker.experiments.run._Trial._load_or_create",
114114
MagicMock(side_effect=mock_trial_load_or_create_func),
115115
):
116+
sagemaker_session.sagemaker_client.search.return_value = {"Results": []}
116117
run = Run(
117118
experiment_name="test-exp",
118119
sagemaker_session=sagemaker_session,

0 commit comments

Comments
 (0)