Skip to content

Commit f67918a

Browse files
Rohan GujarathiNamrata Madan
authored andcommitted
Feature: support requirements.txt installation on conda env
1 parent 5718b88 commit f67918a

File tree

5 files changed

+161
-31
lines changed

5 files changed

+161
-31
lines changed

src/sagemaker/remote_function/core/runtime_environment.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import sys
1818
import os
19-
import shlex
2019
import subprocess
2120
from sagemaker.s3 import s3_path_join, S3Uploader, S3Downloader
2221
from sagemaker.session import Session
@@ -88,20 +87,19 @@ def bootstrap(self, dependencies_s3_uri: str, job_conda_env: str = None):
8887
"""
8988

9089
if dependencies_s3_uri.endswith(".txt"):
91-
if job_conda_env:
92-
# TODO:
93-
# 1. verify if conda exists in the image
94-
# 2. activate the given conda env
95-
# 3. update the conda env with req.txt file
96-
return
97-
9890
local_path = os.getcwd()
9991
S3Downloader.download(
10092
dependencies_s3_uri, local_path, self.s3_kms_key, self.sagemaker_session
10193
)
10294

10395
local_dependencies_file = os.path.join(local_path, dependencies_s3_uri.split("/")[-1])
104-
self._install_requirements_txt(local_dependencies_file, _python_executable())
96+
conda_env = job_conda_env or os.getenv("DEFAULT_CONDA_ENV")
97+
if conda_env:
98+
self._install_req_txt_in_conda_env(conda_env, local_dependencies_file)
99+
self._write_conda_env_to_file(conda_env)
100+
101+
else:
102+
self._install_requirements_txt(local_dependencies_file, _python_executable())
105103

106104
elif dependencies_s3_uri.endswith(".yml"):
107105
# TODO: implement
@@ -134,31 +132,40 @@ def _create_conda_env(self):
134132
# TODO: implement
135133
pass # pylint: disable=W0107
136134

137-
def _activate_conda_env(self):
138-
"""Activate conda environment"""
139-
# TODO: implement
140-
pass # pylint: disable=W0107
135+
def _install_req_txt_in_conda_env(self, env_name, local_path):
136+
"""Install requirements.txt in the given conda environment"""
137+
138+
cmd = '/bin/bash -c "mamba run -n {} pip install -r {}"'.format(env_name, local_path)
139+
logger.info("Activating conda env and installing requirements: {}".format(cmd))
140+
_run_shell_cmd(cmd)
141+
logger.info("Requirements installed successfully in conda env {}".format(env_name))
141142

142143
def _update_conda_env(self):
143144
"""Update conda env using conda yml file"""
144145
# TODO: implement
145146
pass # pylint: disable=W0107
146147

148+
def _write_conda_env_to_file(self, env_name):
149+
file_name = "remote_function_conda_env.txt"
150+
file_path = os.path.join(os.getcwd(), file_name)
151+
with open(file_path, "w") as output_file:
152+
output_file.write(env_name)
153+
147154

148155
def _run_shell_cmd(cmd: str):
149156
"""This method runs a given shell command using subprocess
150157
151158
Raises RuntimeEnvironmentError if the command fails
152159
"""
153160

154-
process = subprocess.Popen(shlex.split(cmd), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
161+
process = subprocess.Popen((cmd), stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
155162

156163
_log_output(process)
157164
error_logs = _log_error(process)
158165
return_code = process.wait()
159166
if return_code:
160-
error_message = "Encountered error while installing dependencies. Reason: {}".format(
161-
error_logs
167+
error_message = "Encountered error while running command '{}'. Reason: {}".format(
168+
cmd, error_logs
162169
)
163170
raise RuntimeEnvironmentError(error_message)
164171

src/sagemaker/remote_function/job_driver.sh

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,16 @@
44

55
set -eu
66

7-
printf "INFO: Bootstraping runtime environment\n"
7+
printf "INFO: Bootstraping runtime environment.\n"
88
python -m sagemaker.remote_function.bootstrap "$@"
99

10-
printf "INFO: Invoking remote function\n"
11-
python -m sagemaker.remote_function.invoke_function "$@"
10+
if [ -f "remote_function_conda_env.txt" ]
11+
then
12+
conda_env=$(cat remote_function_conda_env.txt)
13+
printf "INFO: Invoking remote function inside conda environment: $conda_env.\n"
14+
mamba run -n $conda_env python -m sagemaker.remote_function.invoke_function "$@"
15+
else
16+
printf "INFO: No conda env provided. Invoking remote function\n"
17+
python -m sagemaker.remote_function.invoke_function "$@"
18+
fi
19+

tests/integ/sagemaker/remote_function/conftest.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,21 +34,47 @@
3434
"RUN rm {source_archive}\n"
3535
)
3636

37+
DOCKERFILE_TEMPLATE_WITH_CONDA = (
38+
"FROM public.ecr.aws/docker/library/python:{py_version}-slim\n\n"
39+
"WORKDIR /opt/ml/remote_function/\n"
40+
'SHELL ["/bin/bash", "-c"]\n'
41+
"RUN apt-get update -y \
42+
&& apt-get install -y unzip curl\n\n"
43+
"RUN curl -L -O 'https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-Linux-x86_64.sh' \
44+
&& bash Mambaforge-Linux-x86_64.sh -b -p '/opt/conda' \
45+
&& /opt/conda/bin/conda init bash\n\n"
46+
"ENV PATH $PATH:/opt/conda/bin\n"
47+
"RUN mamba create -n integ_test_env python=3.10 -y\n"
48+
"COPY {source_archive} ./\n"
49+
"RUN pip install '{source_archive}[remote_function]' \
50+
&& mamba run -n base pip install '{source_archive}[remote_function]' \
51+
&& mamba run -n integ_test_env pip install '{source_archive}[remote_function]' \
52+
&& rm {source_archive}\n"
53+
"ENV SHELL=/bin/bash\n"
54+
"ENV DEFAULT_CONDA_ENV=base\n"
55+
)
56+
3757

3858
@pytest.fixture(scope="package")
3959
def dummy_container_without_error(sagemaker_session):
4060
# TODO: the python version should be dynamically specified instead of hardcoding
41-
ecr_uri = _build_container(sagemaker_session, "3.10")
61+
ecr_uri = _build_container(sagemaker_session, "3.10", DOCKERFILE_TEMPLATE)
4262
return ecr_uri
4363

4464

4565
@pytest.fixture(scope="package")
4666
def dummy_container_incompatible_python_runtime(sagemaker_session):
47-
ecr_uri = _build_container(sagemaker_session, "3.7")
67+
ecr_uri = _build_container(sagemaker_session, "3.7", DOCKERFILE_TEMPLATE)
68+
return ecr_uri
69+
70+
71+
@pytest.fixture(scope="package")
72+
def dummy_container_with_conda(sagemaker_session):
73+
ecr_uri = _build_container(sagemaker_session, "3.10", DOCKERFILE_TEMPLATE_WITH_CONDA)
4874
return ecr_uri
4975

5076

51-
def _build_container(sagemaker_session, py_version):
77+
def _build_container(sagemaker_session, py_version, docker_templete):
5278
"""Build a dummy test container locally and push a container to an ecr repo"""
5379

5480
region = sagemaker_session.boto_region_name
@@ -62,7 +88,7 @@ def _build_container(sagemaker_session, py_version):
6288
source_archive = _generate_and_move_sagemaker_sdk_tar(tmpdir)
6389
with open(os.path.join(tmpdir, "Dockerfile"), "w") as file:
6490
file.writelines(
65-
DOCKERFILE_TEMPLATE.format(py_version=py_version, source_archive=source_archive)
91+
docker_templete.format(py_version=py_version, source_archive=source_archive)
6692
)
6793

6894
docker_client = docker.from_env()

tests/integ/sagemaker/remote_function/test_decorator.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,69 @@ def cuberoot(x):
118118
assert cuberoot(27) == 3
119119

120120

121+
def test_additional_dependencies_with_job_conda_env(
122+
sagemaker_session, dummy_container_with_conda, cpu_instance_type
123+
):
124+
dependencies_path = os.path.join(DATA_DIR, "remote_function/requirements.txt")
125+
126+
@remote(
127+
role=ROLE,
128+
image_uri=dummy_container_with_conda,
129+
dependencies=dependencies_path,
130+
instance_type=cpu_instance_type,
131+
sagemaker_session=sagemaker_session,
132+
job_conda_env="integ_test_env",
133+
)
134+
def cuberoot(x):
135+
from scipy.special import cbrt
136+
137+
return cbrt(x)
138+
139+
assert cuberoot(27) == 3
140+
141+
142+
def test_additional_dependencies_with_default_conda_env(
143+
sagemaker_session, dummy_container_with_conda, cpu_instance_type
144+
):
145+
dependencies_path = os.path.join(DATA_DIR, "remote_function/requirements.txt")
146+
147+
@remote(
148+
role=ROLE,
149+
image_uri=dummy_container_with_conda,
150+
dependencies=dependencies_path,
151+
instance_type=cpu_instance_type,
152+
sagemaker_session=sagemaker_session,
153+
)
154+
def cuberoot(x):
155+
from scipy.special import cbrt
156+
157+
return cbrt(x)
158+
159+
assert cuberoot(27) == 3
160+
161+
162+
def test_additional_dependencies_with_non_existent_conda_env(
163+
sagemaker_session, dummy_container_with_conda, cpu_instance_type
164+
):
165+
dependencies_path = os.path.join(DATA_DIR, "remote_function/requirements.txt")
166+
167+
@remote(
168+
role=ROLE,
169+
image_uri=dummy_container_with_conda,
170+
dependencies=dependencies_path,
171+
instance_type=cpu_instance_type,
172+
sagemaker_session=sagemaker_session,
173+
job_conda_env="non_existent_env",
174+
)
175+
def cuberoot(x):
176+
from scipy.special import cbrt
177+
178+
return cbrt(x)
179+
180+
with pytest.raises(RuntimeEnvironmentError):
181+
cuberoot(27) == 3
182+
183+
121184
def test_with_non_existent_dependencies(
122185
sagemaker_session, dummy_container_without_error, cpu_instance_type
123186
):

tests/unit/sagemaker/remote_function/core/test_runtime_environment.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,8 @@ def test_bootstrap_req_txt(mock_s3_download, mock_cwd, mock_session):
135135
call_args = popen.call_args[0][0]
136136
assert call_args is not None
137137

138-
expected = [python_exe, "-m", "pip", "install", "-r", local_file_path]
139-
for i, v in enumerate(expected):
140-
assert call_args[i] == v
138+
expected_cmd = "{} -m pip install -r {}".format(python_exe, local_file_path)
139+
assert call_args == expected_cmd
141140

142141

143142
@patch("sagemaker.remote_function.core.runtime_environment._log_error", Mock())
@@ -164,13 +163,40 @@ def test_bootstrap_req_txt_error(mock_s3_download, mock_cwd, mock_session):
164163
call_args = popen.call_args[0][0]
165164
assert call_args is not None
166165

167-
expected = [python_exe, "-m", "pip", "install", "-r", local_file_path]
168-
for i, v in enumerate(expected):
169-
assert call_args[i] == v
166+
expected_cmd = "{} -m pip install -r {}".format(python_exe, local_file_path)
167+
assert call_args == expected_cmd
170168

171169

172-
def test_bootstrap_req_txt_with_conda_env():
173-
pass
170+
@patch("sagemaker.remote_function.core.runtime_environment._log_error", Mock())
171+
@patch("sagemaker.remote_function.core.runtime_environment._log_output", Mock())
172+
@patch(
173+
"sagemaker.remote_function.core.runtime_environment.RuntimeEnvironmentManager._write_conda_env_to_file",
174+
Mock(),
175+
)
176+
@patch("os.getcwd", return_value="/usr/local/path")
177+
@patch("sagemaker.s3.S3Downloader.download")
178+
def test_bootstrap_req_txt_with_conda_env(mock_s3_download, mock_cwd, mock_session):
179+
with patch("subprocess.Popen") as popen:
180+
popen.return_value.wait.return_value = 0
181+
runtime = RuntimeEnvironmentManager(TEST_S3_BASE_URI, TEST_S3_KMS_KEY, mock_session)
182+
183+
dependencies_s3_uri = TEST_S3_BASE_URI + "/additional_dependencies/requirements.txt"
184+
job_conda_env = "conda_env"
185+
runtime.bootstrap(dependencies_s3_uri, job_conda_env)
186+
187+
mock_s3_download.assert_called_once_with(
188+
dependencies_s3_uri, mock_cwd.return_value, TEST_S3_KMS_KEY, mock_session
189+
)
190+
191+
local_file_path = "/usr/local/path/requirements.txt"
192+
193+
call_args = popen.call_args[0][0]
194+
assert call_args is not None
195+
196+
expected_cmd = '/bin/bash -c "mamba run -n {} pip install -r {}"'.format(
197+
job_conda_env, local_file_path
198+
)
199+
assert call_args == expected_cmd
174200

175201

176202
def test_bootstrap_conda_yml():

0 commit comments

Comments
 (0)