-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feat: Enable customizing artifact output path #3965
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,6 +29,7 @@ | |
from sagemaker.experiments._helper import ( | ||
_ArtifactUploader, | ||
_LineageArtifactTracker, | ||
_DEFAULT_ARTIFACT_PREFIX, | ||
) | ||
from sagemaker.experiments._environment import _RunEnvironment | ||
from sagemaker.experiments._run_context import _RunContext | ||
|
@@ -95,6 +96,8 @@ def __init__( | |
run_display_name: Optional[str] = None, | ||
tags: Optional[List[Dict[str, str]]] = None, | ||
sagemaker_session: Optional["Session"] = None, | ||
artifact_bucket: Optional[str] = None, | ||
artifact_prefix: Optional[str] = None, | ||
): | ||
"""Construct a `Run` instance. | ||
|
||
|
@@ -152,6 +155,11 @@ def __init__( | |
manages interactions with Amazon SageMaker APIs and any other | ||
AWS services needed. If not specified, one is created using the | ||
default AWS configuration chain. | ||
artifact_bucket (str): The S3 bucket to upload the artifact to. | ||
If not specified, the default bucket defined in `sagemaker_session` | ||
will be used. | ||
artifact_prefix (str): The S3 key prefix used to generate the S3 path | ||
to upload the artifact to (default: "trial-component-artifacts"). | ||
""" | ||
# TODO: we should revert the lower casting once backend fix reaches prod | ||
self.experiment_name = experiment_name.lower() | ||
|
@@ -197,6 +205,10 @@ def __init__( | |
self._artifact_uploader = _ArtifactUploader( | ||
trial_component_name=self._trial_component.trial_component_name, | ||
sagemaker_session=sagemaker_session, | ||
artifact_bucket=artifact_bucket, | ||
artifact_prefix=_DEFAULT_ARTIFACT_PREFIX | ||
if artifact_prefix is None | ||
else artifact_prefix, | ||
Comment on lines
+209
to
+211
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: this indentation is unintuitive. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nm, looks like there are not many better alternatives: https://stackoverflow.com/questions/28897010/how-to-make-a-line-break-on-the-python-ternary-operator |
||
) | ||
self._lineage_artifact_tracker = _LineageArtifactTracker( | ||
trial_component_arn=self._trial_component.trial_component_arn, | ||
|
@@ -729,6 +741,8 @@ def load_run( | |
run_name: Optional[str] = None, | ||
experiment_name: Optional[str] = None, | ||
sagemaker_session: Optional["Session"] = None, | ||
artifact_bucket: Optional[str] = None, | ||
artifact_prefix: Optional[str] = None, | ||
) -> Run: | ||
"""Load an existing run. | ||
|
||
|
@@ -792,6 +806,11 @@ def load_run( | |
manages interactions with Amazon SageMaker APIs and any other | ||
AWS services needed. If not specified, one is created using the | ||
default AWS configuration chain. | ||
artifact_bucket (str): The S3 bucket to upload the artifact to. | ||
If not specified, the default bucket defined in `sagemaker_session` | ||
will be used. | ||
artifact_prefix (str): The S3 key prefix used to generate the S3 path | ||
to upload the artifact to (default: "trial-component-artifacts"). | ||
|
||
Returns: | ||
Run: The loaded Run object. | ||
|
@@ -811,6 +830,8 @@ def load_run( | |
experiment_name=experiment_name, | ||
run_name=run_name, | ||
sagemaker_session=sagemaker_session or _utils.default_session(), | ||
artifact_bucket=artifact_bucket, | ||
artifact_prefix=artifact_prefix, | ||
) | ||
elif _RunContext.get_current_run(): | ||
run_instance = _RunContext.get_current_run() | ||
|
@@ -827,6 +848,8 @@ def load_run( | |
experiment_name=experiment_name, | ||
run_name=run_name, | ||
sagemaker_session=sagemaker_session or _utils.default_session(), | ||
artifact_bucket=artifact_bucket, | ||
artifact_prefix=artifact_prefix, | ||
) | ||
else: | ||
raise RuntimeError( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: the name of this sounds like it will prefix the actual name of the artifact, but it's really a prefix for an S3 path where the associated files will be stored.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it's a bit confusing to use this name here. But `_DEFAULT_ARTIFACT_PREFIX` is not introduced in this PR; it's already defined in `_helper.py`. There is a variable called `artifact_prefix`, so I'm assuming that is why the default variable was named `_DEFAULT_ARTIFACT_PREFIX`.
A workaround could be to change the following method defined in `_ArtifactUploader`.
In this way, we don't need to import `_DEFAULT_ARTIFACT_PREFIX` into `run.py`.