Skip to content

Commit 521a255

Browse files
authored
feat: Enable customizing artifact output path (#3965)
1 parent 45cdd70 commit 521a255

File tree

3 files changed

+99
-5
lines changed

3 files changed

+99
-5
lines changed

src/sagemaker/experiments/run.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from sagemaker.experiments._helper import (
3030
_ArtifactUploader,
3131
_LineageArtifactTracker,
32+
_DEFAULT_ARTIFACT_PREFIX,
3233
)
3334
from sagemaker.experiments._environment import _RunEnvironment
3435
from sagemaker.experiments._run_context import _RunContext
@@ -95,6 +96,8 @@ def __init__(
9596
run_display_name: Optional[str] = None,
9697
tags: Optional[List[Dict[str, str]]] = None,
9798
sagemaker_session: Optional["Session"] = None,
99+
artifact_bucket: Optional[str] = None,
100+
artifact_prefix: Optional[str] = None,
98101
):
99102
"""Construct a `Run` instance.
100103
@@ -152,6 +155,11 @@ def __init__(
152155
manages interactions with Amazon SageMaker APIs and any other
153156
AWS services needed. If not specified, one is created using the
154157
default AWS configuration chain.
158+
artifact_bucket (str): The S3 bucket to upload the artifact to.
159+
If not specified, the default bucket defined in `sagemaker_session`
160+
will be used.
161+
artifact_prefix (str): The S3 key prefix used to generate the S3 path
162+
to upload the artifact to (default: "trial-component-artifacts").
155163
"""
156164
# TODO: we should revert the lower casting once backend fix reaches prod
157165
self.experiment_name = experiment_name.lower()
@@ -197,6 +205,10 @@ def __init__(
197205
self._artifact_uploader = _ArtifactUploader(
198206
trial_component_name=self._trial_component.trial_component_name,
199207
sagemaker_session=sagemaker_session,
208+
artifact_bucket=artifact_bucket,
209+
artifact_prefix=_DEFAULT_ARTIFACT_PREFIX
210+
if artifact_prefix is None
211+
else artifact_prefix,
200212
)
201213
self._lineage_artifact_tracker = _LineageArtifactTracker(
202214
trial_component_arn=self._trial_component.trial_component_arn,
@@ -729,6 +741,8 @@ def load_run(
729741
run_name: Optional[str] = None,
730742
experiment_name: Optional[str] = None,
731743
sagemaker_session: Optional["Session"] = None,
744+
artifact_bucket: Optional[str] = None,
745+
artifact_prefix: Optional[str] = None,
732746
) -> Run:
733747
"""Load an existing run.
734748
@@ -792,6 +806,11 @@ def load_run(
792806
manages interactions with Amazon SageMaker APIs and any other
793807
AWS services needed. If not specified, one is created using the
794808
default AWS configuration chain.
809+
artifact_bucket (str): The S3 bucket to upload the artifact to.
810+
If not specified, the default bucket defined in `sagemaker_session`
811+
will be used.
812+
artifact_prefix (str): The S3 key prefix used to generate the S3 path
813+
to upload the artifact to (default: "trial-component-artifacts").
795814
796815
Returns:
797816
Run: The loaded Run object.
@@ -811,6 +830,8 @@ def load_run(
811830
experiment_name=experiment_name,
812831
run_name=run_name,
813832
sagemaker_session=sagemaker_session or _utils.default_session(),
833+
artifact_bucket=artifact_bucket,
834+
artifact_prefix=artifact_prefix,
814835
)
815836
elif _RunContext.get_current_run():
816837
run_instance = _RunContext.get_current_run()
@@ -827,6 +848,8 @@ def load_run(
827848
experiment_name=experiment_name,
828849
run_name=run_name,
829850
sagemaker_session=sagemaker_session or _utils.default_session(),
851+
artifact_bucket=artifact_bucket,
852+
artifact_prefix=artifact_prefix,
830853
)
831854
else:
832855
raise RuntimeError(

tests/unit/sagemaker/experiments/helpers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
TEST_EXP_DISPLAY_NAME = "my-experiment-display-name"
2323
TEST_RUN_DISPLAY_NAME = "my-run-display-name"
2424
TEST_TAGS = [{"Key": "some-key", "Value": "some-value"}]
25+
TEST_ARTIFACT_BUCKET = "my-artifact-bucket"
26+
TEST_ARTIFACT_PREFIX = "my-artifact-prefix"
2527

2628

2729
def mock_tc_load_or_create_func(

tests/unit/sagemaker/experiments/test_run.py

Lines changed: 74 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from sagemaker.experiments import Run, load_run, list_runs
4545
from sagemaker.experiments.trial import _Trial
4646
from sagemaker.experiments.trial_component import _TrialComponent
47+
from sagemaker.experiments._helper import _DEFAULT_ARTIFACT_PREFIX
4748
from tests.unit.sagemaker.experiments.helpers import (
4849
mock_trial_load_or_create_func,
4950
mock_tc_load_or_create_func,
@@ -52,9 +53,25 @@
5253
TEST_RUN_NAME,
5354
TEST_EXP_DISPLAY_NAME,
5455
TEST_RUN_DISPLAY_NAME,
56+
TEST_ARTIFACT_BUCKET,
57+
TEST_ARTIFACT_PREFIX,
5558
)
5659

5760

61+
@pytest.mark.parametrize(
62+
("kwargs", "expected_artifact_bucket", "expected_artifact_prefix"),
63+
[
64+
({}, None, _DEFAULT_ARTIFACT_PREFIX),
65+
(
66+
{
67+
"artifact_bucket": TEST_ARTIFACT_BUCKET,
68+
"artifact_prefix": TEST_ARTIFACT_PREFIX,
69+
},
70+
TEST_ARTIFACT_BUCKET,
71+
TEST_ARTIFACT_PREFIX,
72+
),
73+
],
74+
)
5875
@patch(
5976
"sagemaker.experiments.run.Experiment._load_or_create",
6077
MagicMock(return_value=Experiment(experiment_name=TEST_EXP_NAME)),
@@ -69,9 +86,18 @@
6986
MagicMock(side_effect=mock_tc_load_or_create_func),
7087
)
7188
@patch.object(_TrialComponent, "save")
72-
def test_run_init(mock_tc_save, sagemaker_session):
89+
def test_run_init(
90+
mock_tc_save,
91+
sagemaker_session,
92+
kwargs,
93+
expected_artifact_bucket,
94+
expected_artifact_prefix,
95+
):
7396
with Run(
74-
experiment_name=TEST_EXP_NAME, run_name=TEST_RUN_NAME, sagemaker_session=sagemaker_session
97+
experiment_name=TEST_EXP_NAME,
98+
run_name=TEST_RUN_NAME,
99+
sagemaker_session=sagemaker_session,
100+
**kwargs,
75101
) as run_obj:
76102
assert not run_obj._in_load
77103
assert not run_obj._inside_load_context
@@ -90,6 +116,8 @@ def test_run_init(mock_tc_save, sagemaker_session):
90116
TRIAL_NAME: run_obj.run_group_name,
91117
RUN_NAME: expected_tc_name,
92118
}
119+
assert run_obj._artifact_uploader.artifact_bucket == expected_artifact_bucket
120+
assert run_obj._artifact_uploader.artifact_prefix == expected_artifact_prefix
93121

94122
# trial_component.save is called when entering/exiting the with block
95123
mock_tc_save.assert_called()
@@ -124,6 +152,20 @@ def test_run_init_name_length_exceed_limit(sagemaker_session):
124152
)
125153

126154

155+
@pytest.mark.parametrize(
156+
("kwargs", "expected_artifact_bucket", "expected_artifact_prefix"),
157+
[
158+
({}, None, _DEFAULT_ARTIFACT_PREFIX),
159+
(
160+
{
161+
"artifact_bucket": TEST_ARTIFACT_BUCKET,
162+
"artifact_prefix": TEST_ARTIFACT_PREFIX,
163+
},
164+
TEST_ARTIFACT_BUCKET,
165+
TEST_ARTIFACT_PREFIX,
166+
),
167+
],
168+
)
127169
@patch.object(_TrialComponent, "save", MagicMock(return_value=None))
128170
@patch(
129171
"sagemaker.experiments.run.Experiment._load_or_create",
@@ -139,7 +181,13 @@ def test_run_init_name_length_exceed_limit(sagemaker_session):
139181
MagicMock(side_effect=mock_tc_load_or_create_func),
140182
)
141183
@patch("sagemaker.experiments.run._RunEnvironment")
142-
def test_run_load_no_run_name_and_in_train_job(mock_run_env, sagemaker_session):
184+
def test_run_load_no_run_name_and_in_train_job(
185+
mock_run_env,
186+
sagemaker_session,
187+
kwargs,
188+
expected_artifact_bucket,
189+
expected_artifact_prefix,
190+
):
143191
client = sagemaker_session.sagemaker_client
144192
job_name = "my-train-job"
145193
rv = Mock()
@@ -158,7 +206,7 @@ def test_run_load_no_run_name_and_in_train_job(mock_run_env, sagemaker_session):
158206
# The Run object has been created elsewhere
159207
"ExperimentConfig": exp_config,
160208
}
161-
with load_run(sagemaker_session=sagemaker_session) as run_obj:
209+
with load_run(sagemaker_session=sagemaker_session, **kwargs) as run_obj:
162210
assert run_obj._in_load
163211
assert not run_obj._inside_init_context
164212
assert run_obj._inside_load_context
@@ -169,6 +217,8 @@ def test_run_load_no_run_name_and_in_train_job(mock_run_env, sagemaker_session):
169217
assert run_obj.experiment_name == TEST_EXP_NAME
170218
assert run_obj._experiment
171219
assert run_obj.experiment_config == exp_config
220+
assert run_obj._artifact_uploader.artifact_bucket == expected_artifact_bucket
221+
assert run_obj._artifact_uploader.artifact_prefix == expected_artifact_prefix
172222

173223
client.describe_training_job.assert_called_once_with(TrainingJobName=job_name)
174224

@@ -215,6 +265,20 @@ def test_run_load_no_run_name_and_not_in_train_job_but_no_obj_in_context(sagemak
215265
assert "Failed to load a Run object" in str(err)
216266

217267

268+
@pytest.mark.parametrize(
269+
("kwargs", "expected_artifact_bucket", "expected_artifact_prefix"),
270+
[
271+
({}, None, _DEFAULT_ARTIFACT_PREFIX),
272+
(
273+
{
274+
"artifact_bucket": TEST_ARTIFACT_BUCKET,
275+
"artifact_prefix": TEST_ARTIFACT_PREFIX,
276+
},
277+
TEST_ARTIFACT_BUCKET,
278+
TEST_ARTIFACT_PREFIX,
279+
),
280+
],
281+
)
218282
@patch.object(_TrialComponent, "save", MagicMock(return_value=None))
219283
@patch(
220284
"sagemaker.experiments.run.Experiment._load_or_create",
@@ -229,11 +293,14 @@ def test_run_load_no_run_name_and_not_in_train_job_but_no_obj_in_context(sagemak
229293
"sagemaker.experiments.run._TrialComponent._load_or_create",
230294
MagicMock(side_effect=mock_tc_load_or_create_func),
231295
)
232-
def test_run_load_with_run_name_and_exp_name(sagemaker_session):
296+
def test_run_load_with_run_name_and_exp_name(
297+
sagemaker_session, kwargs, expected_artifact_bucket, expected_artifact_prefix
298+
):
233299
with load_run(
234300
run_name=TEST_RUN_NAME,
235301
experiment_name=TEST_EXP_NAME,
236302
sagemaker_session=sagemaker_session,
303+
**kwargs,
237304
) as run_obj:
238305
expected_tc_name = f"{TEST_EXP_NAME}{DELIMITER}{TEST_RUN_NAME}"
239306
expected_exp_config = {
@@ -249,6 +316,8 @@ def test_run_load_with_run_name_and_exp_name(sagemaker_session):
249316
assert run_obj._trial
250317
assert run_obj._experiment
251318
assert run_obj.experiment_config == expected_exp_config
319+
assert run_obj._artifact_uploader.artifact_bucket == expected_artifact_bucket
320+
assert run_obj._artifact_uploader.artifact_prefix == expected_artifact_prefix
252321

253322

254323
def test_run_load_with_run_name_but_no_exp_name(sagemaker_session):

0 commit comments

Comments
 (0)