Skip to content

Commit d98597b

Browse files
YogevKrnadiayachuyang-deng
committed
fix: Add uri as return statement for upload_string_as_file_body (#1173)
* fix: Adding uri as return statement * Tests for upload_string_as_file_body.py * black-format: commands succeeded flake8: commands succeeded pylint: commands succeeded twine: commands succeeded sphinx: commands succeeded Co-authored-by: Nadia Yakimakha <[email protected]> Co-authored-by: Chuyang <[email protected]>
1 parent b566052 commit d98597b

File tree

2 files changed

+85
-4
lines changed

2 files changed

+85
-4
lines changed

src/sagemaker/session.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -224,10 +224,8 @@ def upload_string_as_file_body(self, body, bucket, key, kms_key=None):
224224
kms_key (str): The KMS key to use for encrypting the file.
225225
226226
Returns:
227-
str: The S3 URI of the uploaded file(s). If a file is specified in the path argument,
228-
the URI format is: ``s3://{bucket name}/{key_prefix}/{original_file_name}``.
229-
If a directory is specified in the path argument, the URI format is
230-
``s3://{bucket name}/{key_prefix}``.
227+
str: The S3 URI of the uploaded file.
228+
The URI format is: ``s3://{bucket name}/{key}``.
231229
"""
232230
s3 = self.boto_session.resource("s3")
233231
s3_object = s3.Object(bucket_name=bucket, key=key)
@@ -237,6 +235,9 @@ def upload_string_as_file_body(self, body, bucket, key, kms_key=None):
237235
else:
238236
s3_object.put(Body=body)
239237

238+
s3_uri = "s3://{}/{}".format(bucket, key)
239+
return s3_uri
240+
240241
def download_data(self, path, bucket, key_prefix="", extra_args=None):
241242
"""Download file or directory from S3.
242243
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import os
16+
17+
from mock import Mock
18+
import pytest
19+
20+
import sagemaker
21+
22+
UPLOAD_DATA_TESTS_FILE_DIR = "upload_data_tests"
23+
SINGLE_FILE_NAME = "file1.py"
24+
BODY = 'print("test")'
25+
DESTINATION_DATA_TESTS_FILE = os.path.join(UPLOAD_DATA_TESTS_FILE_DIR, SINGLE_FILE_NAME)
26+
BUCKET_NAME = "mybucket"
27+
AES_ENCRYPTION_ENABLED = {"ServerSideEncryption": "AES256"}
28+
29+
30+
@pytest.fixture()
31+
def sagemaker_session():
32+
boto_mock = Mock(name="boto_session")
33+
client_mock = Mock()
34+
client_mock.get_caller_identity.return_value = {
35+
"UserId": "mock_user_id",
36+
"Account": "012345678910",
37+
"Arn": "arn:aws:iam::012345678910:user/mock-user",
38+
}
39+
boto_mock.client.return_value = client_mock
40+
ims = sagemaker.Session(boto_session=boto_mock)
41+
ims.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
42+
return ims
43+
44+
45+
def test_upload_string_file(sagemaker_session):
46+
result_s3_uri = sagemaker_session.upload_string_as_file_body(
47+
body=BODY, bucket=BUCKET_NAME, key=DESTINATION_DATA_TESTS_FILE
48+
)
49+
50+
uploaded_files_with_args = [
51+
kwargs
52+
for name, args, kwargs in sagemaker_session.boto_session.mock_calls
53+
if name == "resource().Object().put"
54+
]
55+
56+
assert result_s3_uri == "s3://{}/{}".format(BUCKET_NAME, DESTINATION_DATA_TESTS_FILE)
57+
assert len(uploaded_files_with_args) == 1
58+
kwargs = uploaded_files_with_args[0]
59+
assert kwargs["Body"] == BODY
60+
61+
62+
def test_upload_aes_encrypted_string_file(sagemaker_session):
63+
result_s3_uri = sagemaker_session.upload_string_as_file_body(
64+
body=BODY,
65+
bucket=BUCKET_NAME,
66+
key=DESTINATION_DATA_TESTS_FILE,
67+
kms_key=AES_ENCRYPTION_ENABLED,
68+
)
69+
70+
uploaded_files_with_args = [
71+
kwargs
72+
for name, args, kwargs in sagemaker_session.boto_session.mock_calls
73+
if name == "resource().Object().put"
74+
]
75+
76+
assert result_s3_uri == "s3://{}/{}".format(BUCKET_NAME, DESTINATION_DATA_TESTS_FILE)
77+
assert len(uploaded_files_with_args) == 1
78+
kwargs = uploaded_files_with_args[0]
79+
assert kwargs["Body"] == BODY
80+
assert kwargs["SSEKMSKeyId"] == AES_ENCRYPTION_ENABLED

0 commit comments

Comments
 (0)