Skip to content

fix: Add uri as return statement for upload_string_as_file_body #1173

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
9 changes: 5 additions & 4 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,8 @@ def upload_string_as_file_body(self, body, bucket, key, kms_key=None):
kms_key (str): The KMS key to use for encrypting the file.

Returns:
str: The S3 URI of the uploaded file(s). If a file is specified in the path argument,
the URI format is: ``s3://{bucket name}/{key_prefix}/{original_file_name}``.
If a directory is specified in the path argument, the URI format is
``s3://{bucket name}/{key_prefix}``.
str: The S3 URI of the uploaded file.
The URI format is: ``s3://{bucket name}/{key}``.
"""
s3 = self.boto_session.resource("s3")
s3_object = s3.Object(bucket_name=bucket, key=key)
Expand All @@ -237,6 +235,9 @@ def upload_string_as_file_body(self, body, bucket, key, kms_key=None):
else:
s3_object.put(Body=body)

s3_uri = "s3://{}/{}".format(bucket, key)
return s3_uri

def download_data(self, path, bucket, key_prefix="", extra_args=None):
"""Download file or directory from S3.

Expand Down
80 changes: 80 additions & 0 deletions tests/unit/test_upload_string_as_file_body.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import os

from mock import Mock
import pytest

import sagemaker

UPLOAD_DATA_TESTS_FILE_DIR = "upload_data_tests"
SINGLE_FILE_NAME = "file1.py"
BODY = 'print("test")'
DESTINATION_DATA_TESTS_FILE = os.path.join(UPLOAD_DATA_TESTS_FILE_DIR, SINGLE_FILE_NAME)
BUCKET_NAME = "mybucket"
AES_ENCRYPTION_ENABLED = {"ServerSideEncryption": "AES256"}


@pytest.fixture()
def sagemaker_session():
boto_mock = Mock(name="boto_session")
client_mock = Mock()
client_mock.get_caller_identity.return_value = {
"UserId": "mock_user_id",
"Account": "012345678910",
"Arn": "arn:aws:iam::012345678910:user/mock-user",
}
boto_mock.client.return_value = client_mock
ims = sagemaker.Session(boto_session=boto_mock)
ims.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
return ims


def test_upload_string_file(sagemaker_session):
result_s3_uri = sagemaker_session.upload_string_as_file_body(
body=BODY, bucket=BUCKET_NAME, key=DESTINATION_DATA_TESTS_FILE
)

uploaded_files_with_args = [
kwargs
for name, args, kwargs in sagemaker_session.boto_session.mock_calls
if name == "resource().Object().put"
]

assert result_s3_uri == "s3://{}/{}".format(BUCKET_NAME, DESTINATION_DATA_TESTS_FILE)
assert len(uploaded_files_with_args) == 1
kwargs = uploaded_files_with_args[0]
assert kwargs["Body"] == BODY


def test_upload_aes_encrypted_string_file(sagemaker_session):
result_s3_uri = sagemaker_session.upload_string_as_file_body(
body=BODY,
bucket=BUCKET_NAME,
key=DESTINATION_DATA_TESTS_FILE,
kms_key=AES_ENCRYPTION_ENABLED,
)

uploaded_files_with_args = [
kwargs
for name, args, kwargs in sagemaker_session.boto_session.mock_calls
if name == "resource().Object().put"
]

assert result_s3_uri == "s3://{}/{}".format(BUCKET_NAME, DESTINATION_DATA_TESTS_FILE)
assert len(uploaded_files_with_args) == 1
kwargs = uploaded_files_with_args[0]
assert kwargs["Body"] == BODY
assert kwargs["SSEKMSKeyId"] == AES_ENCRYPTION_ENABLED