Skip to content

Commit 1883f16

Browse files
committed
feature: reshape Parents into experiment analytics dataframe
aws/sagemaker-experiments#99
1 parent ca86159 commit 1883f16

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

src/sagemaker/analytics.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,23 @@ def _reshape_artifacts(self, artifacts, _artifact_names):
536536
out["{} - {}".format(name, "Value")] = value.get("Value")
537537
return out
538538

539+
def _reshape_parents(self, parents):
540+
"""Reshape trial component parents to a pandas column
541+
Args:
542+
parents: trial component parents (trials and experiments)
543+
Returns:
544+
dict: Key: artifacts name, Value: artifacts value
545+
"""
546+
out = OrderedDict()
547+
trials = []
548+
experiments = []
549+
for parent in parents:
550+
trials.append(parent["TrialName"])
551+
experiments.append(parent["ExperimentName"])
552+
out["Trials"] = trials
553+
out["Experiments"] = experiments
554+
return out
555+
539556
def _reshape(self, trial_component):
540557
"""Reshape trial component data to pandas columns
541558
Args:
@@ -563,6 +580,7 @@ def _reshape(self, trial_component):
563580
trial_component.get("OutputArtifacts", []), self._output_artifact_names
564581
)
565582
)
583+
out.update(self._reshape_parents(trial_component.get("Parents", [])))
566584
return out
567585

568586
def _fetch_dataframe(self):

tests/unit/test_experiments_analytics.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,11 @@ def trial_component(trial_component_name):
4848
"outputArtifacts1": {"MediaType": "text/csv", "Value": "s3:/sky/far1"},
4949
"outputArtifacts2": {"MediaType": "text/csv", "Value": "s3:/sky/far2"},
5050
},
51+
"Parents": [{"TrialName": "trial1", "ExperimentName": "experiment1"}],
5152
}
5253

5354

54-
def test_trial_analytics_dataframe_all_metrics_hyperparams(mock_session):
55+
def test_trial_analytics_dataframe_all(mock_session):
5556
mock_session.sagemaker_client.search.return_value = {
5657
"Results": [
5758
{"TrialComponent": trial_component("trial-1")},
@@ -88,6 +89,8 @@ def test_trial_analytics_dataframe_all_metrics_hyperparams(mock_session):
8889
("outputArtifacts1 - Value", ["s3:/sky/far1", "s3:/sky/far1"]),
8990
("outputArtifacts2 - MediaType", ["text/csv", "text/csv"]),
9091
("outputArtifacts2 - Value", ["s3:/sky/far2", "s3:/sky/far2"]),
92+
("Trials", [["trial1"], ["trial1"]]),
93+
("Experiments", [["experiment1"], ["experiment1"]]),
9194
]
9295
)
9396
)
@@ -141,6 +144,8 @@ def test_trial_analytics_dataframe_selected_hyperparams(mock_session):
141144
("outputArtifacts1 - Value", ["s3:/sky/far1", "s3:/sky/far1"]),
142145
("outputArtifacts2 - MediaType", ["text/csv", "text/csv"]),
143146
("outputArtifacts2 - Value", ["s3:/sky/far2", "s3:/sky/far2"]),
147+
("Trials", [["trial1"], ["trial1"]]),
148+
("Experiments", [["experiment1"], ["experiment1"]]),
144149
]
145150
)
146151
)
@@ -189,6 +194,8 @@ def test_trial_analytics_dataframe_selected_metrics(mock_session):
189194
("outputArtifacts1 - Value", ["s3:/sky/far1", "s3:/sky/far1"]),
190195
("outputArtifacts2 - MediaType", ["text/csv", "text/csv"]),
191196
("outputArtifacts2 - Value", ["s3:/sky/far2", "s3:/sky/far2"]),
197+
("Trials", [["trial1"], ["trial1"]]),
198+
("Experiments", [["experiment1"], ["experiment1"]]),
192199
]
193200
)
194201
)
@@ -243,6 +250,8 @@ def test_trial_analytics_dataframe_search_pagination(mock_session):
243250
("outputArtifacts1 - Value", ["s3:/sky/far1", "s3:/sky/far1"]),
244251
("outputArtifacts2 - MediaType", ["text/csv", "text/csv"]),
245252
("outputArtifacts2 - Value", ["s3:/sky/far2", "s3:/sky/far2"]),
253+
("Trials", [["trial1"], ["trial1"]]),
254+
("Experiments", [["experiment1"], ["experiment1"]]),
246255
]
247256
)
248257
)

0 commit comments

Comments
 (0)