@@ -310,9 +310,24 @@ def test_generate_tensorboard_url_domain_non_string():
310
310
@patch ("os.makedirs" )
311
311
def test_download_folder (makedirs ):
312
312
boto_mock = Mock (name = "boto_session" )
313
- boto_mock .client ("sts" ).get_caller_identity .return_value = {"Account" : "123" }
314
-
315
313
session = sagemaker .Session (boto_session = boto_mock , sagemaker_client = Mock ())
314
+ s3_mock = boto_mock .resource ("s3" )
315
+
316
+ obj_mock = Mock ()
317
+ s3_mock .Object .return_value = obj_mock
318
+
319
+ def obj_mock_download (path ):
320
+ # Mock the S3 object to raise an error when the input to download_file
321
+ # is a "folder"
322
+ if path in ("/tmp/" , os .path .join ("/tmp" , "prefix" )):
323
+ raise botocore .exceptions .ClientError (
324
+ error_response = {"Error" : {"Code" : "404" , "Message" : "Not Found" }},
325
+ operation_name = "HeadObject" ,
326
+ )
327
+ else :
328
+ return Mock ()
329
+
330
+ obj_mock .download_file .side_effect = obj_mock_download
316
331
317
332
train_data = Mock ()
318
333
validation_data = Mock ()
@@ -323,23 +338,20 @@ def test_download_folder(makedirs):
323
338
validation_data .key = "prefix/train/validation_data.csv"
324
339
325
340
s3_files = [train_data , validation_data ]
326
- boto_mock .resource ("s3" ).Bucket (BUCKET_NAME ).objects .filter .return_value = s3_files
327
-
328
- obj_mock = Mock ()
329
- boto_mock .resource ("s3" ).Object .return_value = obj_mock
341
+ s3_mock .Bucket (BUCKET_NAME ).objects .filter .return_value = s3_files
330
342
331
343
# all the S3 mocks are set, the test itself begins now.
332
344
sagemaker .utils .download_folder (BUCKET_NAME , "/prefix" , "/tmp" , session )
333
345
334
346
obj_mock .download_file .assert_called ()
335
347
calls = [
336
- call (os .path .join ("/tmp" , "train/ train_data.csv" )),
337
- call (os .path .join ("/tmp" , "train/ validation_data.csv" )),
348
+ call (os .path .join ("/tmp" , "train" , " train_data.csv" )),
349
+ call (os .path .join ("/tmp" , "train" , " validation_data.csv" )),
338
350
]
339
351
obj_mock .download_file .assert_has_calls (calls )
340
352
obj_mock .reset_mock ()
341
353
342
- # Testing with a trailing slash for the prefix.
354
+ # Test with a trailing slash for the prefix.
343
355
sagemaker .utils .download_folder (BUCKET_NAME , "/prefix/" , "/tmp" , session )
344
356
obj_mock .download_file .assert_called ()
345
357
obj_mock .download_file .assert_has_calls (calls )
@@ -369,7 +381,7 @@ def test_download_folder_points_to_single_file(makedirs):
369
381
obj_mock .download_file .assert_called ()
370
382
calls = [call (os .path .join ("/tmp" , "train_data.csv" ))]
371
383
obj_mock .download_file .assert_has_calls (calls )
372
- assert boto_mock .resource ("s3" ).Bucket (BUCKET_NAME ).objects .filter .call_count == 1
384
+ boto_mock .resource ("s3" ).Bucket (BUCKET_NAME ).objects .filter .assert_not_called ()
373
385
obj_mock .reset_mock ()
374
386
375
387
0 commit comments