Skip to content

Commit 80858e7

Browse files
martinRenou authored and root committed
Change: Allow extra_args to be passed to uploader (aws#4338)
* Change: Allow extra_args to be passed to uploader
* Fix tests
* Black
* Fix test
1 parent fee50e5 commit 80858e7

File tree

4 files changed

+29
-11
lines changed

4 files changed

+29
-11
lines changed

src/sagemaker/experiments/_helper.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,15 @@ def __init__(
5959
self.artifact_prefix = artifact_prefix
6060
self._s3_client = self.sagemaker_session.boto_session.client("s3")
6161

62-
def upload_artifact(self, file_path):
62+
def upload_artifact(self, file_path, extra_args=None):
6363
"""Upload an artifact file to S3.
6464
6565
Args:
6666
file_path (str): the file path of the artifact
67+
extra_args (dict): Optional extra arguments that may be passed to the upload operation.
68+
Similar to ExtraArgs parameter in S3 upload_file function. Please refer to the
69+
ExtraArgs parameter documentation here:
70+
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/s3-uploading-files.html#the-extraargs-parameter
6771
6872
Returns:
6973
(str, str): The s3 URI of the uploaded file and the etag of the file.
@@ -91,7 +95,12 @@ def upload_artifact(self, file_path):
9195
artifact_s3_key = "{}/{}/{}".format(
9296
self.artifact_prefix, self.trial_component_name, artifact_name
9397
)
94-
self._s3_client.upload_file(file_path, self.artifact_bucket, artifact_s3_key)
98+
self._s3_client.upload_file(
99+
file_path,
100+
self.artifact_bucket,
101+
artifact_s3_key,
102+
ExtraArgs=extra_args,
103+
)
95104
etag = self._try_get_etag(artifact_s3_key)
96105
return "s3://{}/{}".format(self.artifact_bucket, artifact_s3_key), etag
97106

src/sagemaker/experiments/run.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,8 @@ def log_file(
508508
file_path: str,
509509
name: Optional[str] = None,
510510
media_type: Optional[str] = None,
511-
is_output: bool = True,
511+
is_output: Optional[bool] = True,
512+
extra_args: Optional[dict] = None,
512513
):
513514
"""Upload a file to s3 and store it as an input/output artifact in this run.
514515
@@ -521,11 +522,15 @@ def log_file(
521522
is_output (bool): Determines direction of association to the
522523
run. Defaults to True (output artifact).
523524
If set to False then represented as input association.
525+
extra_args (dict): Optional extra arguments that may be passed to the upload operation.
526+
Similar to ExtraArgs parameter in S3 upload_file function. Please refer to the
527+
ExtraArgs parameter documentation here:
528+
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/s3-uploading-files.html#the-extraargs-parameter
524529
"""
525530
self._verify_trial_component_artifacts_length(is_output)
526531
media_type = media_type or guess_media_type(file_path)
527532
name = name or resolve_artifact_name(file_path)
528-
s3_uri, _ = self._artifact_uploader.upload_artifact(file_path)
533+
s3_uri, _ = self._artifact_uploader.upload_artifact(file_path, extra_args=extra_args)
529534
if is_output:
530535
self._trial_component.output_artifacts[name] = TrialComponentArtifact(
531536
value=s3_uri, media_type=media_type

tests/unit/sagemaker/experiments/test_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def test_artifact_uploader_upload_artifact(tempdir, artifact_uploader):
171171
)
172172

173173
artifact_uploader._s3_client.upload_file.assert_called_with(
174-
path, artifact_uploader.artifact_bucket, expected_key
174+
path, artifact_uploader.artifact_bucket, expected_key, ExtraArgs=None
175175
)
176176

177177
expected_uri = "s3://{}/{}".format(artifact_uploader.artifact_bucket, expected_key)

tests/unit/sagemaker/experiments/test_run.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -592,11 +592,11 @@ def test_log_output_artifact(run_obj):
592592
run_obj._artifact_uploader.upload_artifact.return_value = ("s3uri_value", "etag_value")
593593
with run_obj:
594594
run_obj.log_file("foo.txt", "name", "whizz/bang")
595-
run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt")
595+
run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt", extra_args=None)
596596
assert "whizz/bang" == run_obj._trial_component.output_artifacts["name"].media_type
597597

598598
run_obj.log_file("foo.txt")
599-
run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt")
599+
run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt", extra_args=None)
600600
assert "foo.txt" in run_obj._trial_component.output_artifacts
601601
assert "text/plain" == run_obj._trial_component.output_artifacts["foo.txt"].media_type
602602

@@ -611,11 +611,11 @@ def test_log_input_artifact(run_obj):
611611
run_obj._artifact_uploader.upload_artifact.return_value = ("s3uri_value", "etag_value")
612612
with run_obj:
613613
run_obj.log_file("foo.txt", "name", "whizz/bang", is_output=False)
614-
run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt")
614+
run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt", extra_args=None)
615615
assert "whizz/bang" == run_obj._trial_component.input_artifacts["name"].media_type
616616

617617
run_obj.log_file("foo.txt", is_output=False)
618-
run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt")
618+
run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt", extra_args=None)
619619
assert "foo.txt" in run_obj._trial_component.input_artifacts
620620
assert "text/plain" == run_obj._trial_component.input_artifacts["foo.txt"].media_type
621621

@@ -655,7 +655,9 @@ def test_log_multiple_input_artifacts(run_obj):
655655
run_obj.log_file(
656656
file_path, "name" + str(index), "whizz/bang" + str(index), is_output=False
657657
)
658-
run_obj._artifact_uploader.upload_artifact.assert_called_with(file_path)
658+
run_obj._artifact_uploader.upload_artifact.assert_called_with(
659+
file_path, extra_args=None
660+
)
659661

660662
run_obj._artifact_uploader.upload_artifact.return_value = (
661663
"s3uri_value",
@@ -680,7 +682,9 @@ def test_log_multiple_output_artifacts(run_obj):
680682
"etag_value" + str(index),
681683
)
682684
run_obj.log_file(file_path, "name" + str(index), "whizz/bang" + str(index))
683-
run_obj._artifact_uploader.upload_artifact.assert_called_with(file_path)
685+
run_obj._artifact_uploader.upload_artifact.assert_called_with(
686+
file_path, extra_args=None
687+
)
684688

685689
run_obj._artifact_uploader.upload_artifact.return_value = (
686690
"s3uri_value",

0 commit comments

Comments
 (0)