Skip to content

feat: Enable customizing artifact output path #3965

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions src/sagemaker/experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from sagemaker.experiments._helper import (
_ArtifactUploader,
_LineageArtifactTracker,
_DEFAULT_ARTIFACT_PREFIX,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: the name of this sounds like it will prefix the actual name of the artifact, but it's really a prefix for the S3 path where the associated files will be stored.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's a bit confusing to use this name here. But _DEFAULT_ARTIFACT_PREFIX is not introduced in this PR; it's already defined in _helper.py. There is a variable called artifact_prefix, so I assume that is why the default was named _DEFAULT_ARTIFACT_PREFIX.

A workaround could be to change the following method defined in _ArtifactUploader:

class _ArtifactUploader(object):
    """Artifact uploader"""
    def __init__(
        self,
        trial_component_name,
        sagemaker_session,
        artifact_bucket=None,
        artifact_prefix=None, # HERE!!!
    ):
        self.sagemaker_session = sagemaker_session
        self.trial_component_name = trial_component_name
        self.artifact_bucket = artifact_bucket
        self.artifact_prefix = (
            _DEFAULT_ARTIFACT_PREFIX if artifact_prefix is None else artifact_prefix # HERE!!!
        )
        self._s3_client = self.sagemaker_session.boto_session.client("s3")

This way, we don't need to import _DEFAULT_ARTIFACT_PREFIX into run.py.

)
from sagemaker.experiments._environment import _RunEnvironment
from sagemaker.experiments._run_context import _RunContext
Expand Down Expand Up @@ -95,6 +96,8 @@ def __init__(
run_display_name: Optional[str] = None,
tags: Optional[List[Dict[str, str]]] = None,
sagemaker_session: Optional["Session"] = None,
artifact_bucket: Optional[str] = None,
artifact_prefix: Optional[str] = None,
):
"""Construct a `Run` instance.

Expand Down Expand Up @@ -152,6 +155,11 @@ def __init__(
manages interactions with Amazon SageMaker APIs and any other
AWS services needed. If not specified, one is created using the
default AWS configuration chain.
artifact_bucket (str): The S3 bucket to upload the artifact to.
If not specified, the default bucket defined in `sagemaker_session`
will be used.
artifact_prefix (str): The S3 key prefix used to generate the S3 path
to upload the artifact to (default: "trial-component-artifacts").
"""
# TODO: we should revert the lower casting once backend fix reaches prod
self.experiment_name = experiment_name.lower()
Expand Down Expand Up @@ -197,6 +205,10 @@ def __init__(
self._artifact_uploader = _ArtifactUploader(
trial_component_name=self._trial_component.trial_component_name,
sagemaker_session=sagemaker_session,
artifact_bucket=artifact_bucket,
artifact_prefix=_DEFAULT_ARTIFACT_PREFIX
if artifact_prefix is None
else artifact_prefix,
Comment on lines +209 to +211
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: this indentation is unintuitive.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

)
self._lineage_artifact_tracker = _LineageArtifactTracker(
trial_component_arn=self._trial_component.trial_component_arn,
Expand Down Expand Up @@ -729,6 +741,8 @@ def load_run(
run_name: Optional[str] = None,
experiment_name: Optional[str] = None,
sagemaker_session: Optional["Session"] = None,
artifact_bucket: Optional[str] = None,
artifact_prefix: Optional[str] = None,
) -> Run:
"""Load an existing run.

Expand Down Expand Up @@ -792,6 +806,11 @@ def load_run(
manages interactions with Amazon SageMaker APIs and any other
AWS services needed. If not specified, one is created using the
default AWS configuration chain.
artifact_bucket (str): The S3 bucket to upload the artifact to.
If not specified, the default bucket defined in `sagemaker_session`
will be used.
artifact_prefix (str): The S3 key prefix used to generate the S3 path
to upload the artifact to (default: "trial-component-artifacts").

Returns:
Run: The loaded Run object.
Expand All @@ -811,6 +830,8 @@ def load_run(
experiment_name=experiment_name,
run_name=run_name,
sagemaker_session=sagemaker_session or _utils.default_session(),
artifact_bucket=artifact_bucket,
artifact_prefix=artifact_prefix,
)
elif _RunContext.get_current_run():
run_instance = _RunContext.get_current_run()
Expand All @@ -827,6 +848,8 @@ def load_run(
experiment_name=experiment_name,
run_name=run_name,
sagemaker_session=sagemaker_session or _utils.default_session(),
artifact_bucket=artifact_bucket,
artifact_prefix=artifact_prefix,
)
else:
raise RuntimeError(
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/sagemaker/experiments/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
TEST_EXP_DISPLAY_NAME = "my-experiment-display-name"
TEST_RUN_DISPLAY_NAME = "my-run-display-name"
TEST_TAGS = [{"Key": "some-key", "Value": "some-value"}]
TEST_ARTIFACT_BUCKET = "my-artifact-bucket"
TEST_ARTIFACT_PREFIX = "my-artifact-prefix"


def mock_tc_load_or_create_func(
Expand Down
79 changes: 74 additions & 5 deletions tests/unit/sagemaker/experiments/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from sagemaker.experiments import Run, load_run, list_runs
from sagemaker.experiments.trial import _Trial
from sagemaker.experiments.trial_component import _TrialComponent
from sagemaker.experiments._helper import _DEFAULT_ARTIFACT_PREFIX
from tests.unit.sagemaker.experiments.helpers import (
mock_trial_load_or_create_func,
mock_tc_load_or_create_func,
Expand All @@ -52,9 +53,25 @@
TEST_RUN_NAME,
TEST_EXP_DISPLAY_NAME,
TEST_RUN_DISPLAY_NAME,
TEST_ARTIFACT_BUCKET,
TEST_ARTIFACT_PREFIX,
)


@pytest.mark.parametrize(
("kwargs", "expected_artifact_bucket", "expected_artifact_prefix"),
[
({}, None, _DEFAULT_ARTIFACT_PREFIX),
(
{
"artifact_bucket": TEST_ARTIFACT_BUCKET,
"artifact_prefix": TEST_ARTIFACT_PREFIX,
},
TEST_ARTIFACT_BUCKET,
TEST_ARTIFACT_PREFIX,
),
],
)
@patch(
"sagemaker.experiments.run.Experiment._load_or_create",
MagicMock(return_value=Experiment(experiment_name=TEST_EXP_NAME)),
Expand All @@ -69,9 +86,18 @@
MagicMock(side_effect=mock_tc_load_or_create_func),
)
@patch.object(_TrialComponent, "save")
def test_run_init(mock_tc_save, sagemaker_session):
def test_run_init(
mock_tc_save,
sagemaker_session,
kwargs,
expected_artifact_bucket,
expected_artifact_prefix,
):
with Run(
experiment_name=TEST_EXP_NAME, run_name=TEST_RUN_NAME, sagemaker_session=sagemaker_session
experiment_name=TEST_EXP_NAME,
run_name=TEST_RUN_NAME,
sagemaker_session=sagemaker_session,
**kwargs,
) as run_obj:
assert not run_obj._in_load
assert not run_obj._inside_load_context
Expand All @@ -90,6 +116,8 @@ def test_run_init(mock_tc_save, sagemaker_session):
TRIAL_NAME: run_obj.run_group_name,
RUN_NAME: expected_tc_name,
}
assert run_obj._artifact_uploader.artifact_bucket == expected_artifact_bucket
assert run_obj._artifact_uploader.artifact_prefix == expected_artifact_prefix

# trial_component.save is called when entering/exiting the with block
mock_tc_save.assert_called()
Expand Down Expand Up @@ -124,6 +152,20 @@ def test_run_init_name_length_exceed_limit(sagemaker_session):
)


@pytest.mark.parametrize(
("kwargs", "expected_artifact_bucket", "expected_artifact_prefix"),
[
({}, None, _DEFAULT_ARTIFACT_PREFIX),
(
{
"artifact_bucket": TEST_ARTIFACT_BUCKET,
"artifact_prefix": TEST_ARTIFACT_PREFIX,
},
TEST_ARTIFACT_BUCKET,
TEST_ARTIFACT_PREFIX,
),
],
)
@patch.object(_TrialComponent, "save", MagicMock(return_value=None))
@patch(
"sagemaker.experiments.run.Experiment._load_or_create",
Expand All @@ -139,7 +181,13 @@ def test_run_init_name_length_exceed_limit(sagemaker_session):
MagicMock(side_effect=mock_tc_load_or_create_func),
)
@patch("sagemaker.experiments.run._RunEnvironment")
def test_run_load_no_run_name_and_in_train_job(mock_run_env, sagemaker_session):
def test_run_load_no_run_name_and_in_train_job(
mock_run_env,
sagemaker_session,
kwargs,
expected_artifact_bucket,
expected_artifact_prefix,
):
client = sagemaker_session.sagemaker_client
job_name = "my-train-job"
rv = Mock()
Expand All @@ -158,7 +206,7 @@ def test_run_load_no_run_name_and_in_train_job(mock_run_env, sagemaker_session):
# The Run object has been created elsewhere
"ExperimentConfig": exp_config,
}
with load_run(sagemaker_session=sagemaker_session) as run_obj:
with load_run(sagemaker_session=sagemaker_session, **kwargs) as run_obj:
assert run_obj._in_load
assert not run_obj._inside_init_context
assert run_obj._inside_load_context
Expand All @@ -169,6 +217,8 @@ def test_run_load_no_run_name_and_in_train_job(mock_run_env, sagemaker_session):
assert run_obj.experiment_name == TEST_EXP_NAME
assert run_obj._experiment
assert run_obj.experiment_config == exp_config
assert run_obj._artifact_uploader.artifact_bucket == expected_artifact_bucket
assert run_obj._artifact_uploader.artifact_prefix == expected_artifact_prefix

client.describe_training_job.assert_called_once_with(TrainingJobName=job_name)

Expand Down Expand Up @@ -215,6 +265,20 @@ def test_run_load_no_run_name_and_not_in_train_job_but_no_obj_in_context(sagemak
assert "Failed to load a Run object" in str(err)


@pytest.mark.parametrize(
("kwargs", "expected_artifact_bucket", "expected_artifact_prefix"),
[
({}, None, _DEFAULT_ARTIFACT_PREFIX),
(
{
"artifact_bucket": TEST_ARTIFACT_BUCKET,
"artifact_prefix": TEST_ARTIFACT_PREFIX,
},
TEST_ARTIFACT_BUCKET,
TEST_ARTIFACT_PREFIX,
),
],
)
@patch.object(_TrialComponent, "save", MagicMock(return_value=None))
@patch(
"sagemaker.experiments.run.Experiment._load_or_create",
Expand All @@ -229,11 +293,14 @@ def test_run_load_no_run_name_and_not_in_train_job_but_no_obj_in_context(sagemak
"sagemaker.experiments.run._TrialComponent._load_or_create",
MagicMock(side_effect=mock_tc_load_or_create_func),
)
def test_run_load_with_run_name_and_exp_name(sagemaker_session):
def test_run_load_with_run_name_and_exp_name(
sagemaker_session, kwargs, expected_artifact_bucket, expected_artifact_prefix
):
with load_run(
run_name=TEST_RUN_NAME,
experiment_name=TEST_EXP_NAME,
sagemaker_session=sagemaker_session,
**kwargs,
) as run_obj:
expected_tc_name = f"{TEST_EXP_NAME}{DELIMITER}{TEST_RUN_NAME}"
expected_exp_config = {
Expand All @@ -249,6 +316,8 @@ def test_run_load_with_run_name_and_exp_name(sagemaker_session):
assert run_obj._trial
assert run_obj._experiment
assert run_obj.experiment_config == expected_exp_config
assert run_obj._artifact_uploader.artifact_bucket == expected_artifact_bucket
assert run_obj._artifact_uploader.artifact_prefix == expected_artifact_prefix


def test_run_load_with_run_name_but_no_exp_name(sagemaker_session):
Expand Down