Skip to content

fix: allow S3 folder input to contain a trailing slash in Local Mode #1437

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions src/sagemaker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,16 +346,21 @@ def download_folder(bucket_name, prefix, target, sagemaker_session):

# Try to download the prefix as an object first, in case it is a file and not a 'directory'.
# Do this first, in case the object has broader permissions than the bucket.
try:
s3.Object(bucket_name, prefix).download_file(os.path.join(target, os.path.basename(prefix)))
return
except botocore.exceptions.ClientError as e:
if e.response["Error"]["Code"] == "404" and e.response["Error"]["Message"] == "Not Found":
# S3 also throws this error if the object is a folder,
# so assume that is the case here, and then raise for an actual 404 later.
_download_files_under_prefix(bucket_name, prefix, target, s3)
else:
raise
if not prefix.endswith("/"):
try:
file_destination = os.path.join(target, os.path.basename(prefix))
s3.Object(bucket_name, prefix).download_file(file_destination)
return
except botocore.exceptions.ClientError as e:
err_info = e.response["Error"]
if err_info["Code"] == "404" and err_info["Message"] == "Not Found":
# S3 also throws this error if the object is a folder,
# so assume that is the case here, and then raise for an actual 404 later.
pass
else:
raise

_download_files_under_prefix(bucket_name, prefix, target, s3)


def _download_files_under_prefix(bucket_name, prefix, target, s3):
Expand All @@ -370,7 +375,7 @@ def _download_files_under_prefix(bucket_name, prefix, target, s3):
bucket = s3.Bucket(bucket_name)
for obj_sum in bucket.objects.filter(Prefix=prefix):
# If obj_sum is a folder object (its key ends with "/"), skip it.
if obj_sum.key != "" and obj_sum.key[-1] == "/":
if obj_sum.key.endswith("/"):
continue
obj = s3.Object(obj_sum.bucket_name, obj_sum.key)
s3_relative_path = obj_sum.key[len(prefix) :].lstrip("/")
Expand Down
6 changes: 5 additions & 1 deletion tests/unit/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2018-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
Expand Down Expand Up @@ -349,12 +349,16 @@ def obj_mock_download(path):
call(os.path.join("/tmp", "train", "validation_data.csv")),
]
obj_mock.download_file.assert_has_calls(calls)
assert s3_mock.Object.call_count == 3

s3_mock.reset_mock()
obj_mock.reset_mock()

# Test with a trailing slash for the prefix.
sagemaker.utils.download_folder(BUCKET_NAME, "/prefix/", "/tmp", session)
obj_mock.download_file.assert_called()
obj_mock.download_file.assert_has_calls(calls)
assert s3_mock.Object.call_count == 2


@patch("os.makedirs")
Expand Down