Skip to content

Commit 5cd4c8e

Browse files
Rohan GujarathiNamrata Madan
authored andcommitted
feature: support conda env creation for pathways
1 parent cc00169 commit 5cd4c8e

File tree

4 files changed

+109
-23
lines changed

4 files changed

+109
-23
lines changed

src/sagemaker/remote_function/core/runtime_environment.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,14 @@ class RuntimeEnvironmentManager:
3131
def __init__(
3232
self, s3_base_uri: str = None, s3_kms_key: str = None, sagemaker_session: Session = None
3333
):
34+
"""Initializes RuntimeEnvironemntManager instance
35+
36+
Args:
37+
sagemaker_session (Session): Current sagemaker session
38+
s3_base_uri (str): Base S3 URI where dependencies file is uploaded
39+
s3_kms_key (str): KMS key to access the S3 bucket
40+
41+
"""
3442
self.s3_base_uri = s3_base_uri
3543
self.s3_kms_key = s3_kms_key
3644
self.sagemaker_session = sagemaker_session
@@ -44,9 +52,7 @@ def snapshot(self, dependencies: str = None):
4452
user's active conda env and upload the yml file to S3
4553
4654
Args:
47-
sagemaker_session (Session): Current sagemaker session
48-
s3_base_uri (str): Base S3 URI where dependencies file is uploaded
49-
s3_kms_key (str): KMS key to access the S3 bucket
55+
dependencies (str): Local path where dependencies file exists.
5056
5157
Returns:
5258
S3 URI where the dependencies file is uploaded or None
@@ -86,14 +92,15 @@ def bootstrap(self, dependencies_s3_uri: str, job_conda_env: str = None):
8692
Returns: None
8793
"""
8894

89-
if dependencies_s3_uri.endswith(".txt"):
90-
local_path = os.getcwd()
91-
S3Downloader.download(
92-
dependencies_s3_uri, local_path, self.s3_kms_key, self.sagemaker_session
93-
)
95+
local_path = os.getcwd()
96+
S3Downloader.download(
97+
dependencies_s3_uri, local_path, self.s3_kms_key, self.sagemaker_session
98+
)
99+
100+
local_dependencies_file = os.path.join(local_path, dependencies_s3_uri.split("/")[-1])
101+
conda_env = job_conda_env or os.getenv("DEFAULT_CONDA_ENV")
94102

95-
local_dependencies_file = os.path.join(local_path, dependencies_s3_uri.split("/")[-1])
96-
conda_env = job_conda_env or os.getenv("DEFAULT_CONDA_ENV")
103+
if dependencies_s3_uri.endswith(".txt"):
97104
if conda_env:
98105
self._install_req_txt_in_conda_env(conda_env, local_dependencies_file)
99106
self._write_conda_env_to_file(conda_env)
@@ -102,12 +109,10 @@ def bootstrap(self, dependencies_s3_uri: str, job_conda_env: str = None):
102109
self._install_requirements_txt(local_dependencies_file, _python_executable())
103110

104111
elif dependencies_s3_uri.endswith(".yml"):
105-
# TODO: implement
106-
# 1. verify is conda exists in the image
107-
# 2. if job_conda_env: activate and update the conda env with yml
108-
# 3. if not, create and activate conda env from conda yml file
109-
return
110-
return
112+
# TODO: implement updating conda env if either job_conda_env or default_conda_env is provided.
113+
conda_env_name = "sagemaker-runtime-env"
114+
self._create_conda_env(conda_env_name, local_dependencies_file)
115+
self._write_conda_env_to_file(conda_env_name)
111116

112117
def _is_file_exists(self, dependencies):
113118
"""Check whether the dependencies file exists at the given location.
@@ -127,10 +132,15 @@ def _install_requirements_txt(self, local_path, python_executable):
127132
_run_shell_cmd(cmd)
128133
logger.info("Command {} ran successfully".format(cmd))
129134

130-
def _create_conda_env(self):
135+
def _create_conda_env(self, env_name, local_path):
131136
"""Create conda env using conda yml file"""
132-
# TODO: implement
133-
pass # pylint: disable=W0107
137+
138+
cmd = '/bin/bash -c "{} env create -n {} --file {}"'.format(
139+
self._get_conda_exe(), env_name, local_path
140+
)
141+
logger.info("Creating conda environment {} using: {}.".format(env_name, cmd))
142+
_run_shell_cmd(cmd)
143+
logger.info("Conda environment {} created successfully.".format(env_name))
134144

135145
def _install_req_txt_in_conda_env(self, env_name, local_path):
136146
"""Install requirements.txt in the given conda environment"""

tests/integ/sagemaker/remote_function/conftest.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,22 @@
4848
"COPY {source_archive} ./\n"
4949
"RUN pip install '{source_archive}[remote_function]' \
5050
&& mamba run -n base pip install '{source_archive}[remote_function]' \
51-
&& mamba run -n integ_test_env pip install '{source_archive}[remote_function]' \
52-
&& rm {source_archive}\n"
51+
&& mamba run -n integ_test_env pip install '{source_archive}[remote_function]'\n"
5352
"ENV SHELL=/bin/bash\n"
5453
"ENV DEFAULT_CONDA_ENV=base\n"
5554
)
5655

56+
CONDA_YML_FILE_TEMPLATE = (
57+
"name: test_conda_env\n"
58+
"channels:\n"
59+
" - defaults\n"
60+
"dependencies:\n"
61+
" - scipy=1.10.0\n"
62+
" - pip:\n"
63+
" - sagemaker-{sagemaker_version}.tar.gz[remote_function]\n"
64+
"prefix: /opt/conda/bin/conda\n"
65+
)
66+
5767

5868
@pytest.fixture(scope="package")
5969
def dummy_container_without_error(sagemaker_session):
@@ -74,6 +84,23 @@ def dummy_container_with_conda(sagemaker_session):
7484
return ecr_uri
7585

7686

87+
@pytest.fixture(scope="package")
88+
def conda_env_yml():
89+
"""Write conda yml file needed for tests"""
90+
91+
conda_yml_file_name = "conda_env.yml"
92+
with open(os.path.join(os.getcwd(), "VERSION"), "r") as version_file:
93+
sagemaker_version = version_file.readline().strip()
94+
conda_file_path = os.path.join(os.getcwd(), conda_yml_file_name)
95+
with open(conda_file_path, "w") as yml_file:
96+
yml_file.writelines(CONDA_YML_FILE_TEMPLATE.format(sagemaker_version=sagemaker_version))
97+
yield conda_file_path
98+
99+
# cleanup
100+
if os.path.isfile(conda_yml_file_name):
101+
os.remove(conda_yml_file_name)
102+
103+
77104
def _build_container(sagemaker_session, py_version, docker_templete):
78105
"""Build a dummy test container locally and push a container to an ecr repo"""
79106

tests/integ/sagemaker/remote_function/test_decorator.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,24 @@ def cuberoot(x):
203203
cuberoot(27) == 3
204204

205205

206+
def test_additional_dependencies_with_conda_yml_file(
207+
sagemaker_session, dummy_container_with_conda, cpu_instance_type, conda_env_yml
208+
):
209+
@remote(
210+
role=ROLE,
211+
image_uri=dummy_container_with_conda,
212+
dependencies=conda_env_yml,
213+
instance_type=cpu_instance_type,
214+
sagemaker_session=sagemaker_session,
215+
)
216+
def cuberoot(x):
217+
from scipy.special import cbrt
218+
219+
return cbrt(x)
220+
221+
assert cuberoot(27) == 3
222+
223+
206224
def test_with_non_existent_dependencies(
207225
sagemaker_session, dummy_container_without_error, cpu_instance_type
208226
):

tests/unit/sagemaker/remote_function/core/test_runtime_environment.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,5 +203,36 @@ def test_bootstrap_req_txt_with_conda_env(mock_s3_download, mock_cwd, mock_conda
203203
assert call_args == expected_cmd
204204

205205

206-
def test_bootstrap_conda_yml():
207-
pass
206+
@patch("sagemaker.remote_function.core.runtime_environment._log_error", Mock())
207+
@patch("sagemaker.remote_function.core.runtime_environment._log_output", Mock())
208+
@patch(
209+
"sagemaker.remote_function.core.runtime_environment.RuntimeEnvironmentManager._write_conda_env_to_file",
210+
Mock(),
211+
)
212+
@patch(
213+
"sagemaker.remote_function.core.runtime_environment.RuntimeEnvironmentManager._get_conda_exe",
214+
return_value="conda",
215+
)
216+
@patch("os.getcwd", return_value="/usr/local/path")
217+
@patch("sagemaker.s3.S3Downloader.download")
218+
def test_bootstrap_conda_yml_create_env(mock_s3_download, mock_cwd, mock_conda_exe, mock_session):
219+
with patch("subprocess.Popen") as popen:
220+
popen.return_value.wait.return_value = 0
221+
runtime = RuntimeEnvironmentManager(TEST_S3_BASE_URI, TEST_S3_KMS_KEY, mock_session)
222+
223+
dependencies_s3_uri = TEST_S3_BASE_URI + "/additional_dependencies/conda_env.yml"
224+
runtime.bootstrap(dependencies_s3_uri)
225+
226+
mock_s3_download.assert_called_once_with(
227+
dependencies_s3_uri, mock_cwd.return_value, TEST_S3_KMS_KEY, mock_session
228+
)
229+
230+
local_file_path = "/usr/local/path/conda_env.yml"
231+
232+
call_args = popen.call_args[0][0]
233+
assert call_args is not None
234+
235+
expected_cmd = '/bin/bash -c "{} env create -n {} --file {}"'.format(
236+
mock_conda_exe.return_value, "sagemaker-runtime-env", local_file_path
237+
)
238+
assert call_args == expected_cmd

0 commit comments

Comments
 (0)