44
44
from sagemaker .experiments import Run , load_run , list_runs
45
45
from sagemaker .experiments .trial import _Trial
46
46
from sagemaker .experiments .trial_component import _TrialComponent
47
+ from sagemaker .experiments ._helper import _DEFAULT_ARTIFACT_PREFIX
47
48
from tests .unit .sagemaker .experiments .helpers import (
48
49
mock_trial_load_or_create_func ,
49
50
mock_tc_load_or_create_func ,
52
53
TEST_RUN_NAME ,
53
54
TEST_EXP_DISPLAY_NAME ,
54
55
TEST_RUN_DISPLAY_NAME ,
56
+ TEST_ARTIFACT_BUCKET ,
57
+ TEST_ARTIFACT_PREFIX ,
55
58
)
56
59
57
60
61
+ @pytest .mark .parametrize (
62
+ ("kwargs" , "expected_artifact_bucket" , "expected_artifact_prefix" ),
63
+ [
64
+ ({}, None , _DEFAULT_ARTIFACT_PREFIX ),
65
+ (
66
+ {
67
+ "artifact_bucket" : TEST_ARTIFACT_BUCKET ,
68
+ "artifact_prefix" : TEST_ARTIFACT_PREFIX ,
69
+ },
70
+ TEST_ARTIFACT_BUCKET ,
71
+ TEST_ARTIFACT_PREFIX ,
72
+ ),
73
+ ],
74
+ )
58
75
@patch (
59
76
"sagemaker.experiments.run.Experiment._load_or_create" ,
60
77
MagicMock (return_value = Experiment (experiment_name = TEST_EXP_NAME )),
69
86
MagicMock (side_effect = mock_tc_load_or_create_func ),
70
87
)
71
88
@patch .object (_TrialComponent , "save" )
72
- def test_run_init (mock_tc_save , sagemaker_session ):
89
+ def test_run_init (
90
+ mock_tc_save ,
91
+ sagemaker_session ,
92
+ kwargs ,
93
+ expected_artifact_bucket ,
94
+ expected_artifact_prefix ,
95
+ ):
73
96
with Run (
74
- experiment_name = TEST_EXP_NAME , run_name = TEST_RUN_NAME , sagemaker_session = sagemaker_session
97
+ experiment_name = TEST_EXP_NAME ,
98
+ run_name = TEST_RUN_NAME ,
99
+ sagemaker_session = sagemaker_session ,
100
+ ** kwargs ,
75
101
) as run_obj :
76
102
assert not run_obj ._in_load
77
103
assert not run_obj ._inside_load_context
@@ -90,6 +116,8 @@ def test_run_init(mock_tc_save, sagemaker_session):
90
116
TRIAL_NAME : run_obj .run_group_name ,
91
117
RUN_NAME : expected_tc_name ,
92
118
}
119
+ assert run_obj ._artifact_uploader .artifact_bucket == expected_artifact_bucket
120
+ assert run_obj ._artifact_uploader .artifact_prefix == expected_artifact_prefix
93
121
94
122
# trail_component.save is called when entering/ exiting the with block
95
123
mock_tc_save .assert_called ()
@@ -124,6 +152,20 @@ def test_run_init_name_length_exceed_limit(sagemaker_session):
124
152
)
125
153
126
154
155
+ @pytest .mark .parametrize (
156
+ ("kwargs" , "expected_artifact_bucket" , "expected_artifact_prefix" ),
157
+ [
158
+ ({}, None , _DEFAULT_ARTIFACT_PREFIX ),
159
+ (
160
+ {
161
+ "artifact_bucket" : TEST_ARTIFACT_BUCKET ,
162
+ "artifact_prefix" : TEST_ARTIFACT_PREFIX ,
163
+ },
164
+ TEST_ARTIFACT_BUCKET ,
165
+ TEST_ARTIFACT_PREFIX ,
166
+ ),
167
+ ],
168
+ )
127
169
@patch .object (_TrialComponent , "save" , MagicMock (return_value = None ))
128
170
@patch (
129
171
"sagemaker.experiments.run.Experiment._load_or_create" ,
@@ -139,7 +181,13 @@ def test_run_init_name_length_exceed_limit(sagemaker_session):
139
181
MagicMock (side_effect = mock_tc_load_or_create_func ),
140
182
)
141
183
@patch ("sagemaker.experiments.run._RunEnvironment" )
142
- def test_run_load_no_run_name_and_in_train_job (mock_run_env , sagemaker_session ):
184
+ def test_run_load_no_run_name_and_in_train_job (
185
+ mock_run_env ,
186
+ sagemaker_session ,
187
+ kwargs ,
188
+ expected_artifact_bucket ,
189
+ expected_artifact_prefix ,
190
+ ):
143
191
client = sagemaker_session .sagemaker_client
144
192
job_name = "my-train-job"
145
193
rv = Mock ()
@@ -158,7 +206,7 @@ def test_run_load_no_run_name_and_in_train_job(mock_run_env, sagemaker_session):
158
206
# The Run object has been created else where
159
207
"ExperimentConfig" : exp_config ,
160
208
}
161
- with load_run (sagemaker_session = sagemaker_session ) as run_obj :
209
+ with load_run (sagemaker_session = sagemaker_session , ** kwargs ) as run_obj :
162
210
assert run_obj ._in_load
163
211
assert not run_obj ._inside_init_context
164
212
assert run_obj ._inside_load_context
@@ -169,6 +217,8 @@ def test_run_load_no_run_name_and_in_train_job(mock_run_env, sagemaker_session):
169
217
assert run_obj .experiment_name == TEST_EXP_NAME
170
218
assert run_obj ._experiment
171
219
assert run_obj .experiment_config == exp_config
220
+ assert run_obj ._artifact_uploader .artifact_bucket == expected_artifact_bucket
221
+ assert run_obj ._artifact_uploader .artifact_prefix == expected_artifact_prefix
172
222
173
223
client .describe_training_job .assert_called_once_with (TrainingJobName = job_name )
174
224
@@ -215,6 +265,20 @@ def test_run_load_no_run_name_and_not_in_train_job_but_no_obj_in_context(sagemak
215
265
assert "Failed to load a Run object" in str (err )
216
266
217
267
268
+ @pytest .mark .parametrize (
269
+ ("kwargs" , "expected_artifact_bucket" , "expected_artifact_prefix" ),
270
+ [
271
+ ({}, None , _DEFAULT_ARTIFACT_PREFIX ),
272
+ (
273
+ {
274
+ "artifact_bucket" : TEST_ARTIFACT_BUCKET ,
275
+ "artifact_prefix" : TEST_ARTIFACT_PREFIX ,
276
+ },
277
+ TEST_ARTIFACT_BUCKET ,
278
+ TEST_ARTIFACT_PREFIX ,
279
+ ),
280
+ ],
281
+ )
218
282
@patch .object (_TrialComponent , "save" , MagicMock (return_value = None ))
219
283
@patch (
220
284
"sagemaker.experiments.run.Experiment._load_or_create" ,
@@ -229,11 +293,14 @@ def test_run_load_no_run_name_and_not_in_train_job_but_no_obj_in_context(sagemak
229
293
"sagemaker.experiments.run._TrialComponent._load_or_create" ,
230
294
MagicMock (side_effect = mock_tc_load_or_create_func ),
231
295
)
232
- def test_run_load_with_run_name_and_exp_name (sagemaker_session ):
296
+ def test_run_load_with_run_name_and_exp_name (
297
+ sagemaker_session , kwargs , expected_artifact_bucket , expected_artifact_prefix
298
+ ):
233
299
with load_run (
234
300
run_name = TEST_RUN_NAME ,
235
301
experiment_name = TEST_EXP_NAME ,
236
302
sagemaker_session = sagemaker_session ,
303
+ ** kwargs ,
237
304
) as run_obj :
238
305
expected_tc_name = f"{ TEST_EXP_NAME } { DELIMITER } { TEST_RUN_NAME } "
239
306
expected_exp_config = {
@@ -249,6 +316,8 @@ def test_run_load_with_run_name_and_exp_name(sagemaker_session):
249
316
assert run_obj ._trial
250
317
assert run_obj ._experiment
251
318
assert run_obj .experiment_config == expected_exp_config
319
+ assert run_obj ._artifact_uploader .artifact_bucket == expected_artifact_bucket
320
+ assert run_obj ._artifact_uploader .artifact_prefix == expected_artifact_prefix
252
321
253
322
254
323
def test_run_load_with_run_name_but_no_exp_name (sagemaker_session ):
0 commit comments