Skip to content

Commit 3c11d03

Browse files
committed
feature: reshape Parents into experiment analytics dataframe
aws/sagemaker-experiments#99
1 parent ca86159 commit 3c11d03

File tree

3 files changed

+61
-6
lines changed

3 files changed

+61
-6
lines changed

src/sagemaker/analytics.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,23 @@ def _reshape_artifacts(self, artifacts, _artifact_names):
536536
out["{} - {}".format(name, "Value")] = value.get("Value")
537537
return out
538538

539+
def _reshape_parents(self, parents):
    """Reshape trial component parents into two list-valued columns.

    Args:
        parents: iterable of parent records for one trial component. Each
            record is indexed with the keys ``"TrialName"`` and
            ``"ExperimentName"``. ``None`` is treated like an empty list.

    Returns:
        OrderedDict: two entries, in this order:
            ``"Trials"``: list of every parent's ``TrialName``.
            ``"Experiments"``: list of every parent's ``ExperimentName``.
        Both lists follow the order of ``parents`` and are empty when there
        are no parents.

    Raises:
        KeyError: if a parent record lacks ``"TrialName"`` or
            ``"ExperimentName"``.
    """
    # Walk the input exactly once so one-shot iterables (generators) work,
    # and so a malformed record fails on "TrialName" before "ExperimentName".
    pairs = [(parent["TrialName"], parent["ExperimentName"]) for parent in parents or ()]

    out = OrderedDict()
    out["Trials"] = [trial for trial, _ in pairs]
    out["Experiments"] = [experiment for _, experiment in pairs]
    return out
555+
539556
def _reshape(self, trial_component):
540557
"""Reshape trial component data to pandas columns
541558
Args:
@@ -563,6 +580,7 @@ def _reshape(self, trial_component):
563580
trial_component.get("OutputArtifacts", []), self._output_artifact_names
564581
)
565582
)
583+
out.update(self._reshape_parents(trial_component.get("Parents", [])))
566584
return out
567585

568586
def _fetch_dataframe(self):

tests/integ/test_experiments_analytics.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ def test_experiment_analytics_artifacts(sagemaker_session):
9999
"inputArtifacts1 - Value",
100100
"outputArtifacts1 - MediaType",
101101
"outputArtifacts1 - Value",
102+
"Trials",
103+
"Experiments",
102104
]
103105

104106

@@ -109,7 +111,13 @@ def test_experiment_analytics(sagemaker_session):
109111
experiment_name=experiment_name, sagemaker_session=sagemaker_session
110112
)
111113

112-
assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"]
114+
assert list(analytics.dataframe().columns) == [
115+
"TrialComponentName",
116+
"DisplayName",
117+
"hp1",
118+
"Trials",
119+
"Experiments",
120+
]
113121

114122

115123
def test_experiment_analytics_pagination(sagemaker_session):
@@ -118,7 +126,12 @@ def test_experiment_analytics_pagination(sagemaker_session):
118126
experiment_name=experiment_name, sagemaker_session=sagemaker_session
119127
)
120128

121-
assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"]
129+
assert list(analytics.dataframe().columns) == [
130+
"TrialComponentName",
131+
"DisplayName",
132+
"hp1",
133+
"Trials",
134+
"Experiments"]
122135
assert (
123136
len(analytics.dataframe()) > 10
124137
) # TODO [owen-t] Replace with == 20 and put test in retry block
@@ -137,7 +150,12 @@ def test_experiment_analytics_search_by_nested_filter(sagemaker_session):
137150
sagemaker_session=sagemaker_session, search_expression=search_exp
138151
)
139152

140-
assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"]
153+
assert list(analytics.dataframe().columns) == [
154+
"TrialComponentName",
155+
"DisplayName",
156+
"hp1",
157+
"Trials",
158+
"Experiments"]
141159
assert (
142160
len(analytics.dataframe()) > 5
143161
) # TODO [owen-t] Replace with == 10 and put test in retry block
@@ -159,7 +177,12 @@ def test_experiment_analytics_search_by_nested_filter_sort_ascending(sagemaker_s
159177
sort_order="Ascending",
160178
)
161179

162-
assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"]
180+
assert list(analytics.dataframe().columns) == [
181+
"TrialComponentName",
182+
"DisplayName",
183+
"hp1",
184+
"Trials",
185+
"Experiments"]
163186
assert (
164187
len(analytics.dataframe()) > 5
165188
) # TODO [owen-t] Replace with == 10 and put test in retry block
@@ -183,7 +206,12 @@ def test_experiment_analytics_search_by_nested_filter_sort_descending(sagemaker_
183206
sort_by="Parameters.hp1",
184207
)
185208

186-
assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"]
209+
assert list(analytics.dataframe().columns) == [
210+
"TrialComponentName",
211+
"DisplayName",
212+
"hp1",
213+
"Trials",
214+
"Experiments"]
187215
assert (
188216
len(analytics.dataframe()) > 5
189217
) # TODO [owen-t] Replace with == 10 and put test in retry block

tests/unit/test_experiments_analytics.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,11 @@ def trial_component(trial_component_name):
4848
"outputArtifacts1": {"MediaType": "text/csv", "Value": "s3:/sky/far1"},
4949
"outputArtifacts2": {"MediaType": "text/csv", "Value": "s3:/sky/far2"},
5050
},
51+
"Parents": [{"TrialName": "trial1", "ExperimentName": "experiment1"}],
5152
}
5253

5354

54-
def test_trial_analytics_dataframe_all_metrics_hyperparams(mock_session):
55+
def test_trial_analytics_dataframe_all(mock_session):
5556
mock_session.sagemaker_client.search.return_value = {
5657
"Results": [
5758
{"TrialComponent": trial_component("trial-1")},
@@ -88,6 +89,8 @@ def test_trial_analytics_dataframe_all_metrics_hyperparams(mock_session):
8889
("outputArtifacts1 - Value", ["s3:/sky/far1", "s3:/sky/far1"]),
8990
("outputArtifacts2 - MediaType", ["text/csv", "text/csv"]),
9091
("outputArtifacts2 - Value", ["s3:/sky/far2", "s3:/sky/far2"]),
92+
("Trials", [["trial1"], ["trial1"]]),
93+
("Experiments", [["experiment1"], ["experiment1"]]),
9194
]
9295
)
9396
)
@@ -141,6 +144,8 @@ def test_trial_analytics_dataframe_selected_hyperparams(mock_session):
141144
("outputArtifacts1 - Value", ["s3:/sky/far1", "s3:/sky/far1"]),
142145
("outputArtifacts2 - MediaType", ["text/csv", "text/csv"]),
143146
("outputArtifacts2 - Value", ["s3:/sky/far2", "s3:/sky/far2"]),
147+
("Trials", [["trial1"], ["trial1"]]),
148+
("Experiments", [["experiment1"], ["experiment1"]]),
144149
]
145150
)
146151
)
@@ -189,6 +194,8 @@ def test_trial_analytics_dataframe_selected_metrics(mock_session):
189194
("outputArtifacts1 - Value", ["s3:/sky/far1", "s3:/sky/far1"]),
190195
("outputArtifacts2 - MediaType", ["text/csv", "text/csv"]),
191196
("outputArtifacts2 - Value", ["s3:/sky/far2", "s3:/sky/far2"]),
197+
("Trials", [["trial1"], ["trial1"]]),
198+
("Experiments", [["experiment1"], ["experiment1"]]),
192199
]
193200
)
194201
)
@@ -243,6 +250,8 @@ def test_trial_analytics_dataframe_search_pagination(mock_session):
243250
("outputArtifacts1 - Value", ["s3:/sky/far1", "s3:/sky/far1"]),
244251
("outputArtifacts2 - MediaType", ["text/csv", "text/csv"]),
245252
("outputArtifacts2 - Value", ["s3:/sky/far2", "s3:/sky/far2"]),
253+
("Trials", [["trial1"], ["trial1"]]),
254+
("Experiments", [["experiment1"], ["experiment1"]]),
246255
]
247256
)
248257
)

0 commit comments

Comments
 (0)