Skip to content

Commit 40b0e77

Browse files
committed
refactor according to Nadia's suggestion
1 parent d548d9f commit 40b0e77

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

src/sagemaker/utils.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,6 @@ def download_folder(bucket_name, prefix, target, sagemaker_session):
339339
boto_session = sagemaker_session.boto_session
340340
s3 = boto_session.resource("s3", region_name=boto_session.region_name)
341341

342-
bucket = s3.Bucket(bucket_name)
343342
prefix = prefix.lstrip("/")
344343

345344
# Try to download the prefix as an object first, in case it is a file and not a 'directory'.
@@ -351,11 +350,21 @@ def download_folder(bucket_name, prefix, target, sagemaker_session):
351350
if e.response["Error"]["Code"] == "404" and e.response["Error"]["Message"] == "Not Found":
352351
# S3 also throws this error if the object is a folder,
353352
# so assume that is the case here, and then raise for an actual 404 later.
354-
pass
353+
_download_files_under_prefix(bucket_name, prefix, target, s3)
355354
else:
356355
raise
357356

358-
# Assume the prefix points to an S3 'directory' and download the whole thing
357+
358+
def _download_files_under_prefix(bucket_name, prefix, target, s3):
359+
"""Download all S3 files which match the given prefix
360+
361+
Args:
362+
bucket_name (str): S3 bucket name
363+
prefix (str): S3 prefix within the bucket that will be downloaded
364+
target (str): destination path where the downloaded items will be placed
365+
s3 (boto3.resources.base.ServiceResource): S3 resource
366+
"""
367+
bucket = s3.Bucket(bucket_name)
359368
for obj_sum in bucket.objects.filter(Prefix=prefix):
360369
# if obj_sum is a folder object skip it.
361370
if obj_sum.key != "" and obj_sum.key[-1] == "/":

0 commit comments

Comments
 (0)