Skip to content

Commit dd08313

Browse files
authored
feature: reshape Parents into experiment analytics dataframe (#1884)
aws/sagemaker-experiments#99
1 parent cc98dd0 commit dd08313

File tree

3 files changed

+65
-6
lines changed

3 files changed

+65
-6
lines changed

src/sagemaker/analytics.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,23 @@ def _reshape_artifacts(self, artifacts, _artifact_names):
547547
out["{} - {}".format(name, "Value")] = value.get("Value")
548548
return out
549549

550+
def _reshape_parents(self, parents):
551+
"""Reshape trial component parents to a pandas column
552+
Args:
553+
parents: trial component parents (trials and experiments)
554+
Returns:
555+
dict: Key: artifacts name, Value: artifacts value
556+
"""
557+
out = OrderedDict()
558+
trials = []
559+
experiments = []
560+
for parent in parents:
561+
trials.append(parent["TrialName"])
562+
experiments.append(parent["ExperimentName"])
563+
out["Trials"] = trials
564+
out["Experiments"] = experiments
565+
return out
566+
550567
def _reshape(self, trial_component):
551568
"""Reshape trial component data to pandas columns
552569
Args:
@@ -574,6 +591,7 @@ def _reshape(self, trial_component):
574591
trial_component.get("OutputArtifacts", []), self._output_artifact_names
575592
)
576593
)
594+
out.update(self._reshape_parents(trial_component.get("Parents", [])))
577595
return out
578596

579597
def _fetch_dataframe(self):

tests/integ/test_experiments_analytics.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ def test_experiment_analytics_artifacts(sagemaker_session):
9999
"inputArtifacts1 - Value",
100100
"outputArtifacts1 - MediaType",
101101
"outputArtifacts1 - Value",
102+
"Trials",
103+
"Experiments",
102104
]
103105

104106

@@ -109,7 +111,13 @@ def test_experiment_analytics(sagemaker_session):
109111
experiment_name=experiment_name, sagemaker_session=sagemaker_session
110112
)
111113

112-
assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"]
114+
assert list(analytics.dataframe().columns) == [
115+
"TrialComponentName",
116+
"DisplayName",
117+
"hp1",
118+
"Trials",
119+
"Experiments",
120+
]
113121

114122

115123
def test_experiment_analytics_pagination(sagemaker_session):
@@ -118,7 +126,13 @@ def test_experiment_analytics_pagination(sagemaker_session):
118126
experiment_name=experiment_name, sagemaker_session=sagemaker_session
119127
)
120128

121-
assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"]
129+
assert list(analytics.dataframe().columns) == [
130+
"TrialComponentName",
131+
"DisplayName",
132+
"hp1",
133+
"Trials",
134+
"Experiments",
135+
]
122136
assert (
123137
len(analytics.dataframe()) > 10
124138
) # TODO [owen-t] Replace with == 20 and put test in retry block
@@ -137,7 +151,13 @@ def test_experiment_analytics_search_by_nested_filter(sagemaker_session):
137151
sagemaker_session=sagemaker_session, search_expression=search_exp
138152
)
139153

140-
assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"]
154+
assert list(analytics.dataframe().columns) == [
155+
"TrialComponentName",
156+
"DisplayName",
157+
"hp1",
158+
"Trials",
159+
"Experiments",
160+
]
141161
assert (
142162
len(analytics.dataframe()) > 5
143163
) # TODO [owen-t] Replace with == 10 and put test in retry block
@@ -159,7 +179,13 @@ def test_experiment_analytics_search_by_nested_filter_sort_ascending(sagemaker_s
159179
sort_order="Ascending",
160180
)
161181

162-
assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"]
182+
assert list(analytics.dataframe().columns) == [
183+
"TrialComponentName",
184+
"DisplayName",
185+
"hp1",
186+
"Trials",
187+
"Experiments",
188+
]
163189
assert (
164190
len(analytics.dataframe()) > 5
165191
) # TODO [owen-t] Replace with == 10 and put test in retry block
@@ -183,7 +209,13 @@ def test_experiment_analytics_search_by_nested_filter_sort_descending(sagemaker_
183209
sort_by="Parameters.hp1",
184210
)
185211

186-
assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"]
212+
assert list(analytics.dataframe().columns) == [
213+
"TrialComponentName",
214+
"DisplayName",
215+
"hp1",
216+
"Trials",
217+
"Experiments",
218+
]
187219
assert (
188220
len(analytics.dataframe()) > 5
189221
) # TODO [owen-t] Replace with == 10 and put test in retry block

tests/unit/test_experiments_analytics.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,11 @@ def trial_component(trial_component_name):
4848
"outputArtifacts1": {"MediaType": "text/csv", "Value": "s3:/sky/far1"},
4949
"outputArtifacts2": {"MediaType": "text/csv", "Value": "s3:/sky/far2"},
5050
},
51+
"Parents": [{"TrialName": "trial1", "ExperimentName": "experiment1"}],
5152
}
5253

5354

54-
def test_trial_analytics_dataframe_all_metrics_hyperparams(mock_session):
55+
def test_trial_analytics_dataframe_all(mock_session):
5556
mock_session.sagemaker_client.search.return_value = {
5657
"Results": [
5758
{"TrialComponent": trial_component("trial-1")},
@@ -88,6 +89,8 @@ def test_trial_analytics_dataframe_all_metrics_hyperparams(mock_session):
8889
("outputArtifacts1 - Value", ["s3:/sky/far1", "s3:/sky/far1"]),
8990
("outputArtifacts2 - MediaType", ["text/csv", "text/csv"]),
9091
("outputArtifacts2 - Value", ["s3:/sky/far2", "s3:/sky/far2"]),
92+
("Trials", [["trial1"], ["trial1"]]),
93+
("Experiments", [["experiment1"], ["experiment1"]]),
9194
]
9295
)
9396
)
@@ -141,6 +144,8 @@ def test_trial_analytics_dataframe_selected_hyperparams(mock_session):
141144
("outputArtifacts1 - Value", ["s3:/sky/far1", "s3:/sky/far1"]),
142145
("outputArtifacts2 - MediaType", ["text/csv", "text/csv"]),
143146
("outputArtifacts2 - Value", ["s3:/sky/far2", "s3:/sky/far2"]),
147+
("Trials", [["trial1"], ["trial1"]]),
148+
("Experiments", [["experiment1"], ["experiment1"]]),
144149
]
145150
)
146151
)
@@ -189,6 +194,8 @@ def test_trial_analytics_dataframe_selected_metrics(mock_session):
189194
("outputArtifacts1 - Value", ["s3:/sky/far1", "s3:/sky/far1"]),
190195
("outputArtifacts2 - MediaType", ["text/csv", "text/csv"]),
191196
("outputArtifacts2 - Value", ["s3:/sky/far2", "s3:/sky/far2"]),
197+
("Trials", [["trial1"], ["trial1"]]),
198+
("Experiments", [["experiment1"], ["experiment1"]]),
192199
]
193200
)
194201
)
@@ -243,6 +250,8 @@ def test_trial_analytics_dataframe_search_pagination(mock_session):
243250
("outputArtifacts1 - Value", ["s3:/sky/far1", "s3:/sky/far1"]),
244251
("outputArtifacts2 - MediaType", ["text/csv", "text/csv"]),
245252
("outputArtifacts2 - Value", ["s3:/sky/far2", "s3:/sky/far2"]),
253+
("Trials", [["trial1"], ["trial1"]]),
254+
("Experiments", [["experiment1"], ["experiment1"]]),
246255
]
247256
)
248257
)

0 commit comments

Comments
 (0)