Skip to content

fix: local mode - support relative file structure #2768

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/sagemaker/local/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,8 @@ def serve(self, model_dir, environment):
script_dir = environment[sagemaker.estimator.DIR_PARAM_NAME.upper()]
parsed_uri = urlparse(script_dir)
if parsed_uri.scheme == "file":
volumes.append(_Volume(parsed_uri.path, "/opt/ml/code"))
host_dir = os.path.abspath(parsed_uri.netloc + parsed_uri.path)
volumes.append(_Volume(host_dir, "/opt/ml/code"))
# Update path to mount location
environment = environment.copy()
environment[sagemaker.estimator.DIR_PARAM_NAME.upper()] = "/opt/ml/code"
Expand Down Expand Up @@ -495,7 +496,8 @@ def _prepare_training_volumes(
training_dir = json.loads(hyperparameters[sagemaker.estimator.DIR_PARAM_NAME])
parsed_uri = urlparse(training_dir)
if parsed_uri.scheme == "file":
volumes.append(_Volume(parsed_uri.path, "/opt/ml/code"))
host_dir = os.path.abspath(parsed_uri.netloc + parsed_uri.path)
volumes.append(_Volume(host_dir, "/opt/ml/code"))
# Also mount a directory that all the containers can access.
volumes.append(_Volume(shared_dir, "/opt/ml/shared"))

Expand All @@ -504,7 +506,8 @@ def _prepare_training_volumes(
parsed_uri.scheme == "file"
and sagemaker.model.SAGEMAKER_OUTPUT_LOCATION in hyperparameters
):
intermediate_dir = os.path.join(parsed_uri.path, "output", "intermediate")
dir_path = os.path.abspath(parsed_uri.netloc + parsed_uri.path)
intermediate_dir = os.path.join(dir_path, "output", "intermediate")
if not os.path.exists(intermediate_dir):
os.makedirs(intermediate_dir)
volumes.append(_Volume(intermediate_dir, "/opt/ml/output/intermediate"))
Expand Down
8 changes: 4 additions & 4 deletions src/sagemaker/local/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def move_to_destination(source, destination, job_name, sagemaker_session):
"""
parsed_uri = urlparse(destination)
if parsed_uri.scheme == "file":
recursive_copy(source, parsed_uri.path)
dir_path = os.path.abspath(parsed_uri.netloc + parsed_uri.path)
recursive_copy(source, dir_path)
final_uri = destination
elif parsed_uri.scheme == "s3":
bucket = parsed_uri.netloc
Expand Down Expand Up @@ -116,9 +117,8 @@ def get_child_process_ids(pid):
(List[int]): Child process ids
"""
cmd = f"pgrep -P {pid}".split()
output, err = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
).communicate()
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, err = process.communicate()
if err:
return []
pids = [int(pid) for pid in output.decode("utf-8").split()]
Expand Down
44 changes: 42 additions & 2 deletions tests/unit/test_local_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,31 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import os
import pytest
from mock import patch, Mock

import sagemaker.local.utils


@patch("sagemaker.local.utils.os.path")
@patch("sagemaker.local.utils.os")
def test_copy_directory_structure(m_os, m_os_path):
    """Check that copy_directory_structure calls makedirs when the path is absent.

    Both ``os`` and ``os.path`` are replaced with mocks inside
    ``sagemaker.local.utils``, so nothing is created on disk.
    """
    root, subdir = "/tmp/", "code/"
    # Make the mocked os.path report that the path does not exist yet.
    m_os_path.exists.return_value = False

    sagemaker.local.utils.copy_directory_structure(root, subdir)

    m_os.makedirs.assert_called_with(root, subdir)


@patch("shutil.rmtree", Mock())
@patch("sagemaker.local.utils.recursive_copy")
def test_move_to_destination_local(recursive_copy):
# Checks that move_to_destination passes the source and the local path taken
# from a file:// destination to recursive_copy (mocked here, so no files move).
# local files will just be recursively copied
# NOTE(review): the next two statements look like the pre-change lines that
# this diff removes; the expected path keeps its trailing slash, which
# os.path.abspath would strip. Confirm before treating them as live code.
sagemaker.local.utils.move_to_destination("/tmp/data", "file:///target/dir/", "job", None)
recursive_copy.assert_called_with("/tmp/data", "/target/dir/")
# given absolute path
# (empty netloc, so the URI path "/target/dir" is expected as-is)
sagemaker.local.utils.move_to_destination("/tmp/data", "file:///target/dir", "job", None)
recursive_copy.assert_called_with("/tmp/data", "/target/dir")
# given relative path
# ("file://root/target/dir" parses with netloc "root" and path "/target/dir";
# the expected destination is that pair resolved against the working directory)
sagemaker.local.utils.move_to_destination("/tmp/data", "file://root/target/dir", "job", None)
recursive_copy.assert_called_with("/tmp/data", os.path.abspath("./root/target/dir"))


@patch("shutil.rmtree", Mock())
Expand Down Expand Up @@ -52,3 +65,30 @@ def test_move_to_destination_s3(recursive_copy):
def test_move_to_destination_illegal_destination():
    """Check that an ``ftp://`` destination makes move_to_destination raise ValueError."""
    source, job_name, session = "/tmp/data", "job", None
    with pytest.raises(ValueError):
        sagemaker.local.utils.move_to_destination(
            source, "ftp://ftp/in/2018", job_name, session
        )


@patch("sagemaker.local.utils.os.path")
@patch("sagemaker.local.utils.copy_tree")
def test_recursive_copy(copy_tree, m_os_path):
    """Check that recursive_copy forwards a directory source to copy_tree."""
    src, dst = "source", "destination"
    # Make the mocked os.path report the source as a directory.
    m_os_path.isdir.return_value = True

    sagemaker.local.utils.recursive_copy(src, dst)

    copy_tree.assert_called_with(src, dst)


@patch("sagemaker.local.utils.os")
@patch("sagemaker.local.utils.get_child_process_ids")
def test_kill_child_processes(m_get_child_process_ids, m_os):
    """Check that kill_child_processes calls os.kill with signal 15 for a reported child."""
    child = "child_pids"
    # The mocked lookup reports exactly one child for any parent pid.
    m_get_child_process_ids.return_value = [child]

    sagemaker.local.utils.kill_child_processes("pid")

    m_os.kill.assert_called_with(child, 15)


@patch("sagemaker.local.utils.subprocess")
def test_get_child_process_ids(m_subprocess):
    """Check that get_child_process_ids launches ``pgrep -P <pid>`` with piped stdout and stderr."""
    expected_cmd = "pgrep -P pid".split()

    # Fake process: stdout is a lone newline and stderr is falsy.
    fake_process = Mock()
    fake_process.communicate.return_value = (b"\n", False)
    fake_process.returncode = 0
    m_subprocess.Popen.return_value = fake_process

    sagemaker.local.utils.get_child_process_ids("pid")

    m_subprocess.Popen.assert_called_with(
        expected_cmd, stdout=m_subprocess.PIPE, stderr=m_subprocess.PIPE
    )