@@ -48,10 +48,16 @@ def trial_component(trial_component_name):
48
48
"outputArtifacts1" : {"MediaType" : "text/csv" , "Value" : "s3:/sky/far1" },
49
49
"outputArtifacts2" : {"MediaType" : "text/csv" , "Value" : "s3:/sky/far2" },
50
50
},
51
+ "Parents" : [
52
+ {
53
+ "TrialName" : "trial1" ,
54
+ "ExperimentName" : "experiment1"
55
+ }
56
+ ]
51
57
}
52
58
53
59
54
- def test_trial_analytics_dataframe_all_metrics_hyperparams (mock_session ):
60
+ def test_trial_analytics_dataframe_all (mock_session ):
55
61
mock_session .sagemaker_client .search .return_value = {
56
62
"Results" : [
57
63
{"TrialComponent" : trial_component ("trial-1" )},
@@ -88,6 +94,8 @@ def test_trial_analytics_dataframe_all_metrics_hyperparams(mock_session):
88
94
("outputArtifacts1 - Value" , ["s3:/sky/far1" , "s3:/sky/far1" ]),
89
95
("outputArtifacts2 - MediaType" , ["text/csv" , "text/csv" ]),
90
96
("outputArtifacts2 - Value" , ["s3:/sky/far2" , "s3:/sky/far2" ]),
97
+ ("Trials" , [["trial1" ], ["trial1" ]]),
98
+ ("Experiments" , [["experiment1" ], ["experiment1" ]]),
91
99
]
92
100
)
93
101
)
@@ -141,6 +149,8 @@ def test_trial_analytics_dataframe_selected_hyperparams(mock_session):
141
149
("outputArtifacts1 - Value" , ["s3:/sky/far1" , "s3:/sky/far1" ]),
142
150
("outputArtifacts2 - MediaType" , ["text/csv" , "text/csv" ]),
143
151
("outputArtifacts2 - Value" , ["s3:/sky/far2" , "s3:/sky/far2" ]),
152
+ ("Trials" , [["trial1" ], ["trial1" ]]),
153
+ ("Experiments" , [["experiment1" ], ["experiment1" ]]),
144
154
]
145
155
)
146
156
)
@@ -189,6 +199,8 @@ def test_trial_analytics_dataframe_selected_metrics(mock_session):
189
199
("outputArtifacts1 - Value" , ["s3:/sky/far1" , "s3:/sky/far1" ]),
190
200
("outputArtifacts2 - MediaType" , ["text/csv" , "text/csv" ]),
191
201
("outputArtifacts2 - Value" , ["s3:/sky/far2" , "s3:/sky/far2" ]),
202
+ ("Trials" , [["trial1" ], ["trial1" ]]),
203
+ ("Experiments" , [["experiment1" ], ["experiment1" ]]),
192
204
]
193
205
)
194
206
)
@@ -243,6 +255,8 @@ def test_trial_analytics_dataframe_search_pagination(mock_session):
243
255
("outputArtifacts1 - Value" , ["s3:/sky/far1" , "s3:/sky/far1" ]),
244
256
("outputArtifacts2 - MediaType" , ["text/csv" , "text/csv" ]),
245
257
("outputArtifacts2 - Value" , ["s3:/sky/far2" , "s3:/sky/far2" ]),
258
+ ("Trials" , [["trial1" ], ["trial1" ]]),
259
+ ("Experiments" , [["experiment1" ], ["experiment1" ]]),
246
260
]
247
261
)
248
262
)
0 commit comments