Skip to content

Commit beb5a9a

Browse files
authored
fix: allow download_folder to download file even if bucket is more restricted (#1295)
1 parent 157a208 commit beb5a9a

File tree

2 files changed

+45
-21
lines changed

2 files changed

+45
-21
lines changed

src/sagemaker/utils.py

Lines changed: 23 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -23,13 +23,12 @@
2323
import tarfile
2424
import tempfile
2525
import time
26-
27-
2826
from datetime import datetime
2927
from functools import wraps
3028

3129
import six
3230
from six.moves.urllib import parse
31+
import botocore
3332

3433

3534
ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(amazonaws.com|c2s.ic.gov)(/)(.*:.*)$"
@@ -338,21 +337,34 @@ def download_folder(bucket_name, prefix, target, sagemaker_session):
338337
interact with S3.
339338
"""
340339
boto_session = sagemaker_session.boto_session
341-
342-
s3 = boto_session.resource("s3")
343-
bucket = s3.Bucket(bucket_name)
340+
s3 = boto_session.resource("s3", region_name=boto_session.region_name)
344341

345342
prefix = prefix.lstrip("/")
346343

347-
# there is a chance that the prefix points to a file and not a 'directory' if that is the case
348-
# we should just download it.
349-
objects = list(bucket.objects.filter(Prefix=prefix))
350-
351-
if len(objects) > 0 and objects[0].key == prefix and prefix[-1] != "/":
344+
# Try to download the prefix as an object first, in case it is a file and not a 'directory'.
345+
# Do this first, in case the object has broader permissions than the bucket.
346+
try:
352347
s3.Object(bucket_name, prefix).download_file(os.path.join(target, os.path.basename(prefix)))
353348
return
349+
except botocore.exceptions.ClientError as e:
350+
if e.response["Error"]["Code"] == "404" and e.response["Error"]["Message"] == "Not Found":
351+
# S3 also throws this error if the object is a folder,
352+
# so assume that is the case here, and then raise for an actual 404 later.
353+
_download_files_under_prefix(bucket_name, prefix, target, s3)
354+
else:
355+
raise
354356

355-
# the prefix points to an s3 'directory' download the whole thing
357+
358+
def _download_files_under_prefix(bucket_name, prefix, target, s3):
359+
"""Download all S3 files which match the given prefix
360+
361+
Args:
362+
bucket_name (str): S3 bucket name
363+
prefix (str): S3 prefix within the bucket that will be downloaded
364+
target (str): destination path where the downloaded items will be placed
365+
s3 (boto3.resources.base.ServiceResource): S3 resource
366+
"""
367+
bucket = s3.Bucket(bucket_name)
356368
for obj_sum in bucket.objects.filter(Prefix=prefix):
357369
# if obj_sum is a folder object skip it.
358370
if obj_sum.key != "" and obj_sum.key[-1] == "/":

tests/unit/test_utils.py

Lines changed: 22 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -310,9 +310,24 @@ def test_generate_tensorboard_url_domain_non_string():
310310
@patch("os.makedirs")
311311
def test_download_folder(makedirs):
312312
boto_mock = Mock(name="boto_session")
313-
boto_mock.client("sts").get_caller_identity.return_value = {"Account": "123"}
314-
315313
session = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock())
314+
s3_mock = boto_mock.resource("s3")
315+
316+
obj_mock = Mock()
317+
s3_mock.Object.return_value = obj_mock
318+
319+
def obj_mock_download(path):
320+
# Mock the S3 object to raise an error when the input to download_file
321+
# is a "folder"
322+
if path in ("/tmp/", os.path.join("/tmp", "prefix")):
323+
raise botocore.exceptions.ClientError(
324+
error_response={"Error": {"Code": "404", "Message": "Not Found"}},
325+
operation_name="HeadObject",
326+
)
327+
else:
328+
return Mock()
329+
330+
obj_mock.download_file.side_effect = obj_mock_download
316331

317332
train_data = Mock()
318333
validation_data = Mock()
@@ -323,23 +338,20 @@ def test_download_folder(makedirs):
323338
validation_data.key = "prefix/train/validation_data.csv"
324339

325340
s3_files = [train_data, validation_data]
326-
boto_mock.resource("s3").Bucket(BUCKET_NAME).objects.filter.return_value = s3_files
327-
328-
obj_mock = Mock()
329-
boto_mock.resource("s3").Object.return_value = obj_mock
341+
s3_mock.Bucket(BUCKET_NAME).objects.filter.return_value = s3_files
330342

331343
# all the S3 mocks are set, the test itself begins now.
332344
sagemaker.utils.download_folder(BUCKET_NAME, "/prefix", "/tmp", session)
333345

334346
obj_mock.download_file.assert_called()
335347
calls = [
336-
call(os.path.join("/tmp", "train/train_data.csv")),
337-
call(os.path.join("/tmp", "train/validation_data.csv")),
348+
call(os.path.join("/tmp", "train", "train_data.csv")),
349+
call(os.path.join("/tmp", "train", "validation_data.csv")),
338350
]
339351
obj_mock.download_file.assert_has_calls(calls)
340352
obj_mock.reset_mock()
341353

342-
# Testing with a trailing slash for the prefix.
354+
# Test with a trailing slash for the prefix.
343355
sagemaker.utils.download_folder(BUCKET_NAME, "/prefix/", "/tmp", session)
344356
obj_mock.download_file.assert_called()
345357
obj_mock.download_file.assert_has_calls(calls)
@@ -369,7 +381,7 @@ def test_download_folder_points_to_single_file(makedirs):
369381
obj_mock.download_file.assert_called()
370382
calls = [call(os.path.join("/tmp", "train_data.csv"))]
371383
obj_mock.download_file.assert_has_calls(calls)
372-
assert boto_mock.resource("s3").Bucket(BUCKET_NAME).objects.filter.call_count == 1
384+
boto_mock.resource("s3").Bucket(BUCKET_NAME).objects.filter.assert_not_called()
373385
obj_mock.reset_mock()
374386

375387

0 commit comments

Comments
 (0)