44
44
from sagemaker .experiments import Run , load_run , list_runs
45
45
from sagemaker .experiments .trial import _Trial
46
46
from sagemaker .experiments .trial_component import _TrialComponent
47
+ from sagemaker .experiments ._helper import _DEFAULT_ARTIFACT_PREFIX
47
48
from tests .unit .sagemaker .experiments .helpers import (
48
49
mock_trial_load_or_create_func ,
49
50
mock_tc_load_or_create_func ,
52
53
TEST_RUN_NAME ,
53
54
TEST_EXP_DISPLAY_NAME ,
54
55
TEST_RUN_DISPLAY_NAME ,
56
+ TEST_ARTIFACT_BUCKET ,
57
+ TEST_ARTIFACT_PREFIX ,
55
58
)
56
59
57
60
61
+ @pytest .mark .parametrize (
62
+ ("kwargs" , "expected_artifact_bucket" , "expected_artifact_prefix" ),
63
+ [
64
+ ({}, None , _DEFAULT_ARTIFACT_PREFIX ),
65
+ ({"artifact_bucket" : TEST_ARTIFACT_BUCKET , "artifact_prefix" : TEST_ARTIFACT_PREFIX }, TEST_ARTIFACT_BUCKET , TEST_ARTIFACT_PREFIX ),
66
+ ],
67
+ )
58
68
@patch (
59
69
"sagemaker.experiments.run.Experiment._load_or_create" ,
60
70
MagicMock (return_value = Experiment (experiment_name = TEST_EXP_NAME )),
69
79
MagicMock (side_effect = mock_tc_load_or_create_func ),
70
80
)
71
81
@patch .object (_TrialComponent , "save" )
72
- def test_run_init (mock_tc_save , sagemaker_session ):
82
+ def test_run_init (mock_tc_save , sagemaker_session , kwargs , expected_artifact_bucket , expected_artifact_prefix ):
73
83
with Run (
74
- experiment_name = TEST_EXP_NAME , run_name = TEST_RUN_NAME , sagemaker_session = sagemaker_session
84
+ experiment_name = TEST_EXP_NAME , run_name = TEST_RUN_NAME , sagemaker_session = sagemaker_session , ** kwargs
75
85
) as run_obj :
76
86
assert not run_obj ._in_load
77
87
assert not run_obj ._inside_load_context
@@ -90,6 +100,8 @@ def test_run_init(mock_tc_save, sagemaker_session):
90
100
TRIAL_NAME : run_obj .run_group_name ,
91
101
RUN_NAME : expected_tc_name ,
92
102
}
103
+ assert run_obj ._artifact_uploader .artifact_bucket == expected_artifact_bucket
104
+ assert run_obj ._artifact_uploader .artifact_prefix == expected_artifact_prefix
93
105
94
106
# trail_component.save is called when entering/ exiting the with block
95
107
mock_tc_save .assert_called ()
@@ -123,7 +135,13 @@ def test_run_init_name_length_exceed_limit(sagemaker_session):
123
135
err
124
136
)
125
137
126
-
138
+ @pytest .mark .parametrize (
139
+ ("kwargs" , "expected_artifact_bucket" , "expected_artifact_prefix" ),
140
+ [
141
+ ({}, None , _DEFAULT_ARTIFACT_PREFIX ),
142
+ ({"artifact_bucket" : TEST_ARTIFACT_BUCKET , "artifact_prefix" : TEST_ARTIFACT_PREFIX }, TEST_ARTIFACT_BUCKET , TEST_ARTIFACT_PREFIX ),
143
+ ],
144
+ )
127
145
@patch .object (_TrialComponent , "save" , MagicMock (return_value = None ))
128
146
@patch (
129
147
"sagemaker.experiments.run.Experiment._load_or_create" ,
@@ -139,7 +157,7 @@ def test_run_init_name_length_exceed_limit(sagemaker_session):
139
157
MagicMock (side_effect = mock_tc_load_or_create_func ),
140
158
)
141
159
@patch ("sagemaker.experiments.run._RunEnvironment" )
142
- def test_run_load_no_run_name_and_in_train_job (mock_run_env , sagemaker_session ):
160
+ def test_run_load_no_run_name_and_in_train_job (mock_run_env , sagemaker_session , kwargs , expected_artifact_bucket , expected_artifact_prefix ):
143
161
client = sagemaker_session .sagemaker_client
144
162
job_name = "my-train-job"
145
163
rv = Mock ()
@@ -158,7 +176,7 @@ def test_run_load_no_run_name_and_in_train_job(mock_run_env, sagemaker_session):
158
176
# The Run object has been created else where
159
177
"ExperimentConfig" : exp_config ,
160
178
}
161
- with load_run (sagemaker_session = sagemaker_session ) as run_obj :
179
+ with load_run (sagemaker_session = sagemaker_session , ** kwargs ) as run_obj :
162
180
assert run_obj ._in_load
163
181
assert not run_obj ._inside_init_context
164
182
assert run_obj ._inside_load_context
@@ -169,6 +187,8 @@ def test_run_load_no_run_name_and_in_train_job(mock_run_env, sagemaker_session):
169
187
assert run_obj .experiment_name == TEST_EXP_NAME
170
188
assert run_obj ._experiment
171
189
assert run_obj .experiment_config == exp_config
190
+ assert run_obj ._artifact_uploader .artifact_bucket == expected_artifact_bucket
191
+ assert run_obj ._artifact_uploader .artifact_prefix == expected_artifact_prefix
172
192
173
193
client .describe_training_job .assert_called_once_with (TrainingJobName = job_name )
174
194
@@ -214,7 +234,13 @@ def test_run_load_no_run_name_and_not_in_train_job_but_no_obj_in_context(sagemak
214
234
215
235
assert "Failed to load a Run object" in str (err )
216
236
217
-
237
+ @pytest .mark .parametrize (
238
+ ("kwargs" , "expected_artifact_bucket" , "expected_artifact_prefix" ),
239
+ [
240
+ ({}, None , _DEFAULT_ARTIFACT_PREFIX ),
241
+ ({"artifact_bucket" : TEST_ARTIFACT_BUCKET , "artifact_prefix" : TEST_ARTIFACT_PREFIX }, TEST_ARTIFACT_BUCKET , TEST_ARTIFACT_PREFIX ),
242
+ ],
243
+ )
218
244
@patch .object (_TrialComponent , "save" , MagicMock (return_value = None ))
219
245
@patch (
220
246
"sagemaker.experiments.run.Experiment._load_or_create" ,
@@ -229,11 +255,12 @@ def test_run_load_no_run_name_and_not_in_train_job_but_no_obj_in_context(sagemak
229
255
"sagemaker.experiments.run._TrialComponent._load_or_create" ,
230
256
MagicMock (side_effect = mock_tc_load_or_create_func ),
231
257
)
232
- def test_run_load_with_run_name_and_exp_name (sagemaker_session ):
258
+ def test_run_load_with_run_name_and_exp_name (sagemaker_session , kwargs , expected_artifact_bucket , expected_artifact_prefix ):
233
259
with load_run (
234
260
run_name = TEST_RUN_NAME ,
235
261
experiment_name = TEST_EXP_NAME ,
236
262
sagemaker_session = sagemaker_session ,
263
+ ** kwargs ,
237
264
) as run_obj :
238
265
expected_tc_name = f"{ TEST_EXP_NAME } { DELIMITER } { TEST_RUN_NAME } "
239
266
expected_exp_config = {
@@ -249,6 +276,8 @@ def test_run_load_with_run_name_and_exp_name(sagemaker_session):
249
276
assert run_obj ._trial
250
277
assert run_obj ._experiment
251
278
assert run_obj .experiment_config == expected_exp_config
279
+ assert run_obj ._artifact_uploader .artifact_bucket == expected_artifact_bucket
280
+ assert run_obj ._artifact_uploader .artifact_prefix == expected_artifact_prefix
252
281
253
282
254
283
def test_run_load_with_run_name_but_no_exp_name (sagemaker_session ):
0 commit comments