Skip to content

Commit b4d1d70

Browse files
committed
fix: allow download_folder to download a file even if the bucket is more restricted than the object
1 parent 66e517b commit b4d1d70

File tree

2 files changed

+35
-20
lines changed

2 files changed

+35
-20
lines changed

src/sagemaker/utils.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,12 @@
2323
import tarfile
2424
import tempfile
2525
import time
26-
27-
2826
from datetime import datetime
2927
from functools import wraps
3028

3129
import six
3230
from six.moves.urllib import parse
31+
import botocore
3332

3433

3534
ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(amazonaws.com|c2s.ic.gov)(/)(.*:.*)$"
@@ -338,21 +337,25 @@ def download_folder(bucket_name, prefix, target, sagemaker_session):
338337
interact with S3.
339338
"""
340339
boto_session = sagemaker_session.boto_session
340+
s3 = boto_session.resource("s3", region_name=boto_session.region_name)
341341

342-
s3 = boto_session.resource("s3")
343342
bucket = s3.Bucket(bucket_name)
344-
345343
prefix = prefix.lstrip("/")
346344

347-
# there is a chance that the prefix points to a file and not a 'directory' if that is the case
348-
# we should just download it.
349-
objects = list(bucket.objects.filter(Prefix=prefix))
350-
351-
if len(objects) > 0 and objects[0].key == prefix and prefix[-1] != "/":
345+
# Try to download the prefix as an object first, in case it is a file and not a 'directory'.
346+
# Do this first, in case the object has broader permissions than the bucket.
347+
try:
352348
s3.Object(bucket_name, prefix).download_file(os.path.join(target, os.path.basename(prefix)))
353349
return
350+
except botocore.exceptions.ClientError as e:
351+
if e.response["Error"]["Code"] == "404" and e.response["Error"]["Message"] == "Not Found":
352+
# This 404 error is also raised when the key is a 'folder' rather than an object,
353+
# so assume that is the case here, and then raise for an actual 404 later.
354+
pass
355+
else:
356+
raise
354357

355-
# the prefix points to an s3 'directory' download the whole thing
358+
# Assume the prefix points to an S3 'directory' and download the whole thing
356359
for obj_sum in bucket.objects.filter(Prefix=prefix):
357360
# if obj_sum is a folder object skip it.
358361
if obj_sum.key != "" and obj_sum.key[-1] == "/":

tests/unit/test_utils.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -310,9 +310,24 @@ def test_generate_tensorboard_url_domain_non_string():
310310
@patch("os.makedirs")
311311
def test_download_folder(makedirs):
312312
boto_mock = Mock(name="boto_session")
313-
boto_mock.client("sts").get_caller_identity.return_value = {"Account": "123"}
314-
315313
session = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock())
314+
s3_mock = boto_mock.resource("s3")
315+
316+
obj_mock = Mock()
317+
s3_mock.Object.return_value = obj_mock
318+
319+
def obj_mock_download(path):
320+
# Mock the S3 object to raise an error when the input to download_file
321+
# is a "folder"
322+
if path in ("/tmp/", os.path.join("/tmp", "prefix")):
323+
raise botocore.exceptions.ClientError(
324+
error_response={"Error": {"Code": "404", "Message": "Not Found"}},
325+
operation_name="HeadObject",
326+
)
327+
else:
328+
return Mock()
329+
330+
obj_mock.download_file.side_effect = obj_mock_download
316331

317332
train_data = Mock()
318333
validation_data = Mock()
@@ -323,23 +338,20 @@ def test_download_folder(makedirs):
323338
validation_data.key = "prefix/train/validation_data.csv"
324339

325340
s3_files = [train_data, validation_data]
326-
boto_mock.resource("s3").Bucket(BUCKET_NAME).objects.filter.return_value = s3_files
327-
328-
obj_mock = Mock()
329-
boto_mock.resource("s3").Object.return_value = obj_mock
341+
s3_mock.Bucket(BUCKET_NAME).objects.filter.return_value = s3_files
330342

331343
# all the S3 mocks are set, the test itself begins now.
332344
sagemaker.utils.download_folder(BUCKET_NAME, "/prefix", "/tmp", session)
333345

334346
obj_mock.download_file.assert_called()
335347
calls = [
336-
call(os.path.join("/tmp", "train/train_data.csv")),
337-
call(os.path.join("/tmp", "train/validation_data.csv")),
348+
call(os.path.join("/tmp", "train", "train_data.csv")),
349+
call(os.path.join("/tmp", "train", "validation_data.csv")),
338350
]
339351
obj_mock.download_file.assert_has_calls(calls)
340352
obj_mock.reset_mock()
341353

342-
# Testing with a trailing slash for the prefix.
354+
# Test with a trailing slash for the prefix.
343355
sagemaker.utils.download_folder(BUCKET_NAME, "/prefix/", "/tmp", session)
344356
obj_mock.download_file.assert_called()
345357
obj_mock.download_file.assert_has_calls(calls)
@@ -369,7 +381,7 @@ def test_download_folder_points_to_single_file(makedirs):
369381
obj_mock.download_file.assert_called()
370382
calls = [call(os.path.join("/tmp", "train_data.csv"))]
371383
obj_mock.download_file.assert_has_calls(calls)
372-
assert boto_mock.resource("s3").Bucket(BUCKET_NAME).objects.filter.call_count == 1
384+
boto_mock.resource("s3").Bucket(BUCKET_NAME).objects.filter.assert_not_called()
373385
obj_mock.reset_mock()
374386

375387

0 commit comments

Comments
 (0)