Skip to content

Commit 33f37d8

Browse files
committed
Use distutils.dir_util.copy_tree instead and add unit tests
1 parent 346169c commit 33f37d8

File tree

4 files changed

+15
-32
lines changed

4 files changed

+15
-32
lines changed

.pylintrc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,7 @@ dummy-variables-rgx=_|unused_
8484
# Apply logging string format checks to calls on these modules.
8585
logging-modules=
8686
logging
87+
88+
[TYPECHECK]
89+
ignored-modules=
90+
distutils

src/sagemaker/local/image.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import tarfile
2828
import tempfile
2929

30+
from distutils.dir_util import copy_tree
3031
from six.moves.urllib.parse import urlparse
3132
from threading import Thread
3233

@@ -217,9 +218,9 @@ def retrieve_artifacts(self, compose_data, output_data_config, job_name):
217218
for volume in volumes:
218219
host_dir, container_dir = volume.split(':')
219220
if container_dir == '/opt/ml/model':
220-
sagemaker.local.utils.recursive_copy(host_dir, model_artifacts)
221+
copy_tree(host_dir, model_artifacts)
221222
elif container_dir == '/opt/ml/output':
222-
sagemaker.local.utils.recursive_copy(host_dir, output_artifacts)
223+
copy_tree(host_dir, output_artifacts)
223224

224225
# Tar Artifacts -> model.tar.gz and output.tar.gz
225226
model_files = [os.path.join(model_artifacts, name) for name in os.listdir(model_artifacts)]

src/sagemaker/local/utils.py

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import os
1616
import shutil
1717

18+
from distutils.dir_util import copy_tree
1819
from six.moves.urllib.parse import urlparse
1920

2021

@@ -45,14 +46,15 @@ def move_to_destination(source, destination, job_name, sagemaker_session):
4546
Args:
4647
source (str): root directory to move
4748
destination (str): file:// or s3:// URI that source will be moved to.
49+
job_name (str): SageMaker job name.
4850
sagemaker_session (sagemaker.Session): a sagemaker_session to interact with S3 if needed
4951
5052
Returns:
5153
(str): destination URI
5254
"""
5355
parsed_uri = urlparse(destination)
5456
if parsed_uri.scheme == 'file':
55-
recursive_copy(source, parsed_uri.path)
57+
copy_tree(source, parsed_uri.path)
5658
final_uri = destination
5759
elif parsed_uri.scheme == 's3':
5860
bucket = parsed_uri.netloc
@@ -64,25 +66,3 @@ def move_to_destination(source, destination, job_name, sagemaker_session):
6466

6567
shutil.rmtree(source)
6668
return final_uri
67-
68-
69-
def recursive_copy(source, destination):
70-
"""Similar to shutil.copy but the destination directory can exist. Existing files will be overridden.
71-
Args:
72-
source (str): source path
73-
destination (str): destination path
74-
"""
75-
if not os.path.exists(destination):
76-
os.makedirs(destination, exist_ok=True)
77-
78-
for root, dirs, files in os.walk(source):
79-
root = os.path.relpath(root, source)
80-
current_path = os.path.join(source, root)
81-
target_path = os.path.join(destination, root)
82-
83-
for file in files:
84-
shutil.copy(os.path.join(current_path, file), os.path.join(target_path, file))
85-
for d in dirs:
86-
new_dir = os.path.join(target_path, d)
87-
if not os.path.exists(new_dir):
88-
os.mkdir(os.path.join(target_path, d))

tests/unit/test_local_utils.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,13 @@
1717

1818
import sagemaker.local.utils
1919

20-
BUCKET_NAME = 'some-nice-bucket'
21-
2220

2321
@patch('shutil.rmtree', Mock())
24-
@patch('sagemaker.local.utils.recursive_copy')
25-
def test_move_to_destination(recursive_copy):
26-
# local files will just be recursive copied
27-
sagemaker.local.utils.move_to_destination('/tmp/data', 'file:///target/dir/', 'job', None)
28-
recursive_copy.assert_called()
22+
@patch('sagemaker.local.utils.copy_tree')
23+
def test_move_to_destination(copy_tree):
24+
# local files will just be recursively copied
25+
sagemaker.local.utils.move_to_destination('/tmp/data', 'file:///target/dir', 'job', None)
26+
copy_tree.assert_called_with('/tmp/data', '/target/dir')
2927

3028
# s3 destination will upload to S3
3129
sms = Mock()

0 commit comments

Comments
 (0)