Skip to content

Commit 19b349a

Browse files
committed
feature: reshape Parents into experiment analytics dataframe
aws/sagemaker-experiments#99
1 parent ca86159 commit 19b349a

File tree

2 files changed

+33
-1
lines changed

2 files changed

+33
-1
lines changed

src/sagemaker/analytics.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,23 @@ def _reshape_artifacts(self, artifacts, _artifact_names):
536536
out["{} - {}".format(name, "Value")] = value.get("Value")
537537
return out
538538

539+
def _reshape_parents(self, parents):
540+
"""Reshape trial component parents to a pandas column
541+
Args:
542+
parents: trial component parents (trials and experiments)
543+
Returns:
544+
dict: Key: artifacts name, Value: artifacts value
545+
"""
546+
out = OrderedDict()
547+
trials = []
548+
experiments = []
549+
for parent in parents:
550+
trials.append(parent["TrialName"])
551+
experiments.append(parent["ExperimentName"])
552+
out["Trials"] = trials
553+
out["Experiments"] = experiments
554+
return out
555+
539556
def _reshape(self, trial_component):
540557
"""Reshape trial component data to pandas columns
541558
Args:
@@ -563,6 +580,7 @@ def _reshape(self, trial_component):
563580
trial_component.get("OutputArtifacts", []), self._output_artifact_names
564581
)
565582
)
583+
out.update(self._reshape_parents(trial_component.get("Parents", [])))
566584
return out
567585

568586
def _fetch_dataframe(self):

tests/unit/test_experiments_analytics.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,16 @@ def trial_component(trial_component_name):
4848
"outputArtifacts1": {"MediaType": "text/csv", "Value": "s3:/sky/far1"},
4949
"outputArtifacts2": {"MediaType": "text/csv", "Value": "s3:/sky/far2"},
5050
},
51+
"Parents": [
52+
{
53+
"TrialName": "trial1",
54+
"ExperimentName": "experiment1"
55+
}
56+
]
5157
}
5258

5359

54-
def test_trial_analytics_dataframe_all_metrics_hyperparams(mock_session):
60+
def test_trial_analytics_dataframe_all(mock_session):
5561
mock_session.sagemaker_client.search.return_value = {
5662
"Results": [
5763
{"TrialComponent": trial_component("trial-1")},
@@ -88,6 +94,8 @@ def test_trial_analytics_dataframe_all_metrics_hyperparams(mock_session):
8894
("outputArtifacts1 - Value", ["s3:/sky/far1", "s3:/sky/far1"]),
8995
("outputArtifacts2 - MediaType", ["text/csv", "text/csv"]),
9096
("outputArtifacts2 - Value", ["s3:/sky/far2", "s3:/sky/far2"]),
97+
("Trials", [["trial1"], ["trial1"]]),
98+
("Experiments", [["experiment1"], ["experiment1"]]),
9199
]
92100
)
93101
)
@@ -141,6 +149,8 @@ def test_trial_analytics_dataframe_selected_hyperparams(mock_session):
141149
("outputArtifacts1 - Value", ["s3:/sky/far1", "s3:/sky/far1"]),
142150
("outputArtifacts2 - MediaType", ["text/csv", "text/csv"]),
143151
("outputArtifacts2 - Value", ["s3:/sky/far2", "s3:/sky/far2"]),
152+
("Trials", [["trial1"], ["trial1"]]),
153+
("Experiments", [["experiment1"], ["experiment1"]]),
144154
]
145155
)
146156
)
@@ -189,6 +199,8 @@ def test_trial_analytics_dataframe_selected_metrics(mock_session):
189199
("outputArtifacts1 - Value", ["s3:/sky/far1", "s3:/sky/far1"]),
190200
("outputArtifacts2 - MediaType", ["text/csv", "text/csv"]),
191201
("outputArtifacts2 - Value", ["s3:/sky/far2", "s3:/sky/far2"]),
202+
("Trials", [["trial1"], ["trial1"]]),
203+
("Experiments", [["experiment1"], ["experiment1"]]),
192204
]
193205
)
194206
)
@@ -243,6 +255,8 @@ def test_trial_analytics_dataframe_search_pagination(mock_session):
243255
("outputArtifacts1 - Value", ["s3:/sky/far1", "s3:/sky/far1"]),
244256
("outputArtifacts2 - MediaType", ["text/csv", "text/csv"]),
245257
("outputArtifacts2 - Value", ["s3:/sky/far2", "s3:/sky/far2"]),
258+
("Trials", [["trial1"], ["trial1"]]),
259+
("Experiments", [["experiment1"], ["experiment1"]]),
246260
]
247261
)
248262
)

0 commit comments

Comments
 (0)