Skip to content

Commit 90f8b0f

Browse files
authored
fix: allow S3 folder input to contain a trailing slash in Local Mode (#1437)
1 parent aafae70 commit 90f8b0f

File tree

2 files changed

+21
-12
lines changed

2 files changed

+21
-12
lines changed

src/sagemaker/utils.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -346,16 +346,21 @@ def download_folder(bucket_name, prefix, target, sagemaker_session):
346346

347347
# Try to download the prefix as an object first, in case it is a file and not a 'directory'.
348348
# Do this first, in case the object has broader permissions than the bucket.
349-
try:
350-
s3.Object(bucket_name, prefix).download_file(os.path.join(target, os.path.basename(prefix)))
351-
return
352-
except botocore.exceptions.ClientError as e:
353-
if e.response["Error"]["Code"] == "404" and e.response["Error"]["Message"] == "Not Found":
354-
# S3 also throws this error if the object is a folder,
355-
# so assume that is the case here, and then raise for an actual 404 later.
356-
_download_files_under_prefix(bucket_name, prefix, target, s3)
357-
else:
358-
raise
349+
if not prefix.endswith("/"):
350+
try:
351+
file_destination = os.path.join(target, os.path.basename(prefix))
352+
s3.Object(bucket_name, prefix).download_file(file_destination)
353+
return
354+
except botocore.exceptions.ClientError as e:
355+
err_info = e.response["Error"]
356+
if err_info["Code"] == "404" and err_info["Message"] == "Not Found":
357+
# S3 also throws this error if the object is a folder,
358+
# so assume that is the case here, and then raise for an actual 404 later.
359+
pass
360+
else:
361+
raise
362+
363+
_download_files_under_prefix(bucket_name, prefix, target, s3)
359364

360365

361366
def _download_files_under_prefix(bucket_name, prefix, target, s3):
@@ -370,7 +375,7 @@ def _download_files_under_prefix(bucket_name, prefix, target, s3):
370375
bucket = s3.Bucket(bucket_name)
371376
for obj_sum in bucket.objects.filter(Prefix=prefix):
372377
# if obj_sum is a folder object skip it.
373-
if obj_sum.key != "" and obj_sum.key[-1] == "/":
378+
if obj_sum.key.endswith("/"):
374379
continue
375380
obj = s3.Object(obj_sum.bucket_name, obj_sum.key)
376381
s3_relative_path = obj_sum.key[len(prefix) :].lstrip("/")

tests/unit/test_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3-
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
# Copyright 2018-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License"). You
66
# may not use this file except in compliance with the License. A copy of
@@ -349,12 +349,16 @@ def obj_mock_download(path):
349349
call(os.path.join("/tmp", "train", "validation_data.csv")),
350350
]
351351
obj_mock.download_file.assert_has_calls(calls)
352+
assert s3_mock.Object.call_count == 3
353+
354+
s3_mock.reset_mock()
352355
obj_mock.reset_mock()
353356

354357
# Test with a trailing slash for the prefix.
355358
sagemaker.utils.download_folder(BUCKET_NAME, "/prefix/", "/tmp", session)
356359
obj_mock.download_file.assert_called()
357360
obj_mock.download_file.assert_has_calls(calls)
361+
assert s3_mock.Object.call_count == 2
358362

359363

360364
@patch("os.makedirs")

0 commit comments

Comments
 (0)