Skip to content

Commit b4125e8

Browse files
rohangujarathiRohan Gujarathi
authored andcommitted
update: validate python version (aws#888)
Co-authored-by: Rohan Gujarathi <[email protected]>
1 parent c6d1c1a commit b4125e8

File tree

10 files changed

+235
-34
lines changed

10 files changed

+235
-34
lines changed

src/sagemaker/remote_function/client.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,12 @@ def wrapper(*args, **kwargs):
170170
if (
171171
"FailureReason" in describe_result
172172
and describe_result["FailureReason"]
173+
and "RuntimeEnvironmentError: " in describe_result["FailureReason"]
173174
):
174-
raise RuntimeEnvironmentError(describe_result["FailureReason"])
175+
failure_msg = describe_result["FailureReason"].replace(
176+
"RuntimeEnvironmentError: ", ""
177+
)
178+
raise RuntimeEnvironmentError(failure_msg)
175179
raise RemoteFunctionError(
176180
"Failed to execute remote function. "
177181
+ "Check corresponding job for details."
@@ -594,10 +598,13 @@ def from_describe_response(describe_training_job_response, sagemaker_session):
594598
if (
595599
"FailureReason" in describe_training_job_response
596600
and describe_training_job_response["FailureReason"]
601+
and "RuntimeEnvironmentError: "
602+
in describe_training_job_response["FailureReason"]
597603
):
598-
job_exception = RuntimeEnvironmentError(
599-
describe_training_job_response["FailureReason"]
604+
failure_msg = describe_training_job_response["FailureReason"].replace(
605+
"RuntimeEnvironmentError: ", ""
600606
)
607+
job_exception = RuntimeEnvironmentError(failure_msg)
601608
else:
602609
job_exception = RemoteFunctionError(
603610
"Failed to execute remote function. "
@@ -687,10 +694,13 @@ def result(self, timeout: float = None) -> Any:
687694
if (
688695
"FailureReason" in self._job.describe()
689696
and self._job.describe()["FailureReason"]
697+
and "RuntimeEnvironmentError: "
698+
in self._job.describe()["FailureReason"]
690699
):
691-
self._exception = RuntimeEnvironmentError(
692-
self._job.describe()["FailureReason"]
700+
failure_msg = self._job.describe()["FailureReason"].replace(
701+
"RuntimeEnvironmentError: ", ""
693702
)
703+
self._exception = RuntimeEnvironmentError(failure_msg)
694704
else:
695705
self._exception = RemoteFunctionError(
696706
"Failed to execute remote function. "

src/sagemaker/remote_function/job.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,9 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non
354354

355355
container_args = ["--s3_base_uri", s3_base_uri]
356356
container_args.extend(["--region", job_settings.sagemaker_session.boto_region_name])
357+
container_args.extend(
358+
["--client_python_version", RuntimeEnvironmentManager()._current_python_version()]
359+
)
357360
if job_settings.s3_kms_key:
358361
container_args.extend(["--s3_kms_key", job_settings.s3_kms_key])
359362

src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,22 @@ def main():
5050

5151
try:
5252
args = _parse_agrs()
53+
client_python_version = args.client_python_version
54+
job_conda_env = args.job_conda_env
55+
56+
conda_env = job_conda_env or os.getenv("SAGEMAKER_JOB_CONDA_ENV")
57+
58+
RuntimeEnvironmentManager()._validate_python_version(client_python_version, conda_env)
5359

5460
_execute_pre_exec_cmds()
5561

56-
_bootstrap_runtime_environment(job_conda_env=args.job_conda_env)
62+
_bootstrap_runtime_environment(client_python_version, conda_env)
5763

5864
exit_code = SUCCESS_EXIT_CODE
5965
except Exception as e: # pylint: disable=broad-except
6066
logger.exception("Error encountered while bootstrapping runtime environment: %s", e)
61-
_write_failure_reason_file(repr(e))
67+
68+
_write_failure_reason_file(str(e))
6269
finally:
6370
sys.exit(exit_code)
6471

@@ -70,12 +77,13 @@ def _execute_pre_exec_cmds():
7077

7178

7279
def _bootstrap_runtime_environment(
73-
job_conda_env: str = None,
80+
client_python_version: str,
81+
conda_env: str = None,
7482
):
7583
"""Bootstrap runtime environment for remote function invocation
7684
7785
Args:
78-
job_conda_env (str): conda environment to be activated. Default is None.
86+
conda_env (str): conda environment to be activated. Default is None.
7987
"""
8088
workspace_archive_dir_path = os.path.join(BASE_CHANNEL_PATH, REMOTE_FUNCTION_WORKSPACE)
8189

@@ -109,7 +117,9 @@ def _bootstrap_runtime_environment(
109117
if dependencies_file:
110118
_create_sm_installer_symlinks()
111119
RuntimeEnvironmentManager().bootstrap(
112-
local_dependencies_file=dependencies_file, job_conda_env=job_conda_env
120+
local_dependencies_file=dependencies_file,
121+
conda_env=conda_env,
122+
client_python_version=client_python_version,
113123
)
114124
else:
115125
logger.info(
@@ -151,13 +161,14 @@ def _write_failure_reason_file(failure_msg):
151161
"""
152162
if not os.path.exists(FAILURE_REASON_PATH):
153163
with open(FAILURE_REASON_PATH, "w") as f:
154-
f.write(failure_msg)
164+
f.write("RuntimeEnvironmentError: " + failure_msg)
155165

156166

157167
def _parse_agrs():
158168
"""Parses CLI arguments."""
159169
parser = argparse.ArgumentParser()
160170
parser.add_argument("--job_conda_env", type=str)
171+
parser.add_argument("--client_python_version")
161172
args, _ = parser.parse_known_args()
162173
return args
163174

src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import logging
1818
import sys
19+
import shlex
1920
import os
2021
import subprocess
2122
import tempfile
@@ -76,18 +77,18 @@ def snapshot(self, dependencies: str = None) -> str:
7677

7778
raise ValueError(f'Invalid dependencies provided: "{dependencies}"')
7879

79-
def bootstrap(self, local_dependencies_file: str, job_conda_env: str = None):
80+
def bootstrap(
81+
self, local_dependencies_file: str, client_python_version: str, conda_env: str = None
82+
):
8083
"""Bootstraps the runtime environment by installing the additional dependencies if any.
8184
8285
Args:
8386
dependencies_s3_uri (str): S3 URI where dependencies file exists.
84-
job_conda_env (str): conda environment to be activated. Default is None.
87+
conda_env (str): conda environment to be activated. Default is None.
8588
8689
Returns: None
8790
"""
8891

89-
conda_env = job_conda_env or os.getenv("JOB_CONDA_ENV")
90-
9192
if local_dependencies_file.endswith(".txt"):
9293
if conda_env:
9394
self._install_req_txt_in_conda_env(conda_env, local_dependencies_file)
@@ -102,6 +103,7 @@ def bootstrap(self, local_dependencies_file: str, job_conda_env: str = None):
102103
else:
103104
conda_env = "sagemaker-runtime-env"
104105
self._create_conda_env(conda_env, local_dependencies_file)
106+
self._validate_python_version(client_python_version, conda_env)
105107
self._write_conda_env_to_file(conda_env)
106108

107109
def _is_file_exists(self, dependencies):
@@ -176,6 +178,44 @@ def _get_conda_exe(self):
176178
return "conda"
177179
raise ValueError("Neither conda nor mamba is installed on the image")
178180

181+
def _python_version_in_conda_env(self, env_name):
182+
"""Returns python version inside a conda environment"""
183+
cmd = f"{self._get_conda_exe()} run -n {env_name} python --version"
184+
try:
185+
output = (
186+
subprocess.check_output(shlex.split(cmd), stderr=subprocess.STDOUT)
187+
.decode("utf-8")
188+
.strip()
189+
)
190+
# convert 'Python 3.7.16' to [3, 7, 16]
191+
version = output.split("Python ")[1].split(".")
192+
return version[0] + "." + version[1]
193+
except subprocess.CalledProcessError as e:
194+
raise RuntimeEnvironmentError(e.output)
195+
196+
def _current_python_version(self):
197+
"""Returns the current python version where program is running"""
198+
199+
return f"{sys.version_info.major}.{sys.version_info.minor}"
200+
201+
def _validate_python_version(self, client_python_version: str, conda_env: str = None):
202+
"""Validate the python version
203+
204+
Validates if the python version where remote function runs
205+
matches the one used on client side.
206+
"""
207+
if conda_env:
208+
job_python_version = self._python_version_in_conda_env(conda_env)
209+
else:
210+
job_python_version = self._current_python_version()
211+
if client_python_version != job_python_version:
212+
raise RuntimeEnvironmentError(
213+
f"Python version found in the container is {job_python_version} which "
214+
f"does not match python version {client_python_version} on the local client . "
215+
f"Please make sure that the python version used in the training container "
216+
f"is same as the local python version."
217+
)
218+
179219

180220
def _run_shell_cmd(cmd: str):
181221
"""This method runs a given shell command using subprocess

tests/integ/sagemaker/remote_function/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
&& mamba run -n default_env pip install '{source_archive}[remote_function]' \
5252
&& mamba run -n integ_test_env pip install '{source_archive}[remote_function]'\n"
5353
"ENV SHELL=/bin/bash\n"
54-
"ENV JOB_CONDA_ENV=default_env\n"
54+
"ENV SAGEMAKER_JOB_CONDA_ENV=default_env\n"
5555
)
5656

5757
CONDA_YML_FILE_TEMPLATE = (

tests/integ/sagemaker/remote_function/test_decorator.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@ def divide(x, y):
7474
divide(10, 0)
7575

7676

77-
@pytest.mark.skip
7877
def test_remote_python_runtime_is_incompatible(
7978
sagemaker_session, dummy_container_incompatible_python_runtime, cpu_instance_type
8079
):
@@ -87,8 +86,13 @@ def test_remote_python_runtime_is_incompatible(
8786
def divide(x, y):
8887
return x / y
8988

90-
# TODO: should raise serialization error
91-
with pytest.raises(RuntimeError):
89+
with pytest.raises(
90+
RuntimeEnvironmentError,
91+
match=(
92+
"Please make sure that the python version used in the training container "
93+
"is same as the local python version."
94+
),
95+
):
9296
divide(10, 2)
9397

9498

tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,26 @@
2222

2323
TEST_JOB_CONDA_ENV = "conda_env"
2424
TEST_DEPENDENCIES_PATH = "/user/set/workdir"
25+
TEST_PYTHON_VERSION = "3.10"
2526
TEST_WORKSPACE_ARCHIVE_DIR_PATH = "/opt/ml/input/data/sm_rf_user_ws"
2627
TEST_WORKSPACE_ARCHIVE_PATH = "/opt/ml/input/data/sm_rf_user_ws/workspace.zip"
2728

2829

2930
def mock_args():
3031
args = Mock()
3132
args.job_conda_env = TEST_JOB_CONDA_ENV
33+
args.client_python_version = TEST_PYTHON_VERSION
3234
return args
3335

3436

3537
@patch(
3638
"sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment._parse_agrs",
3739
new=mock_args,
3840
)
41+
@patch(
42+
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
43+
"RuntimeEnvironmentManager._validate_python_version"
44+
)
3945
@patch(
4046
"sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment._create_sm_installer_symlinks"
4147
)
@@ -49,9 +55,17 @@ def mock_args():
4955
"sagemaker.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager.bootstrap"
5056
)
5157
def test_main_success(
52-
bootstrap_runtime, list_dir, file_exists, path_exists, getcwd, _exit_process, symlink
58+
bootstrap_runtime,
59+
list_dir,
60+
file_exists,
61+
path_exists,
62+
getcwd,
63+
_exit_process,
64+
symlink,
65+
validate_python,
5366
):
5467
bootstrap.main()
68+
validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV)
5569
path_exists.assert_called_once_with(TEST_WORKSPACE_ARCHIVE_DIR_PATH)
5670
file_exists.assert_called_once_with(TEST_WORKSPACE_ARCHIVE_PATH)
5771
getcwd.assert_called()
@@ -64,6 +78,10 @@ def test_main_success(
6478
"sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment._parse_agrs",
6579
new=mock_args,
6680
)
81+
@patch(
82+
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
83+
"RuntimeEnvironmentManager._validate_python_version"
84+
)
6785
@patch(
6886
"sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment._create_sm_installer_symlinks"
6987
)
@@ -74,14 +92,15 @@ def test_main_success(
7492
@patch(
7593
"sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment._bootstrap_runtime_environment"
7694
)
77-
def test_main_failure(bootstrap_runtime, write_failure, _exit_process, symlink):
95+
def test_main_failure(bootstrap_runtime, write_failure, _exit_process, symlink, validate_python):
7896
runtime_err = RuntimeEnvironmentError("some failure reason")
7997
bootstrap_runtime.side_effect = runtime_err
8098

8199
bootstrap.main()
82100

101+
validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV)
83102
bootstrap_runtime.assert_called()
84-
write_failure.assert_called_with(repr(runtime_err))
103+
write_failure.assert_called_with(str(runtime_err))
85104
_exit_process.assert_called_with(1)
86105

87106

@@ -90,13 +109,20 @@ def test_main_failure(bootstrap_runtime, write_failure, _exit_process, symlink):
90109
new=mock_args,
91110
)
92111
@patch("sys.exit")
112+
@patch(
113+
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
114+
"RuntimeEnvironmentManager._validate_python_version"
115+
)
93116
@patch(
94117
"sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment._write_failure_reason_file"
95118
)
96119
@patch("os.path.exists", return_value=False)
97-
def test_main_channel_folder_does_not_exist(path_exists, write_failure, _exit_process):
120+
def test_main_channel_folder_does_not_exist(
121+
path_exists, write_failure, validate_python, _exit_process
122+
):
98123
bootstrap.main()
99124
path_exists.assert_called_once_with(TEST_WORKSPACE_ARCHIVE_DIR_PATH)
125+
validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV)
100126
write_failure.assert_not_called()
101127
_exit_process.assert_called_with(0)
102128

@@ -105,14 +131,21 @@ def test_main_channel_folder_does_not_exist(path_exists, write_failure, _exit_pr
105131
"sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment._parse_agrs",
106132
new=mock_args,
107133
)
134+
@patch(
135+
"sagemaker.remote_function.runtime_environment.runtime_environment_manager"
136+
".RuntimeEnvironmentManager._validate_python_version"
137+
)
108138
@patch("sys.exit")
109139
@patch("os.path.exists", return_value=True)
110140
@patch("os.path.isfile", return_value=False)
111141
@patch(
112142
"sagemaker.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager.bootstrap"
113143
)
114-
def test_main_no_workspace_archive(bootstrap_runtime, file_exists, path_exists, _exit_process):
144+
def test_main_no_workspace_archive(
145+
bootstrap_runtime, file_exists, path_exists, _exit_process, validate_python
146+
):
115147
bootstrap.main()
148+
validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV)
116149
path_exists.assert_called_once_with(TEST_WORKSPACE_ARCHIVE_DIR_PATH)
117150
file_exists.assert_called_once_with(TEST_WORKSPACE_ARCHIVE_PATH)
118151
bootstrap_runtime.assert_not_called()
@@ -123,6 +156,10 @@ def test_main_no_workspace_archive(bootstrap_runtime, file_exists, path_exists,
123156
"sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment._parse_agrs",
124157
new=mock_args,
125158
)
159+
@patch(
160+
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
161+
"RuntimeEnvironmentManager._validate_python_version"
162+
)
126163
@patch("sys.exit")
127164
@patch("shutil.unpack_archive", Mock())
128165
@patch("os.path.exists", return_value=True)
@@ -133,9 +170,10 @@ def test_main_no_workspace_archive(bootstrap_runtime, file_exists, path_exists,
133170
"sagemaker.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager.bootstrap"
134171
)
135172
def test_main_no_dependency_file(
136-
bootstrap_runtime, list_dir, get_cwd, file_exists, path_exists, _exit_process
173+
bootstrap_runtime, list_dir, get_cwd, file_exists, path_exists, _exit_process, validate_python
137174
):
138175
bootstrap.main()
176+
validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV)
139177
path_exists.assert_called_once_with(TEST_WORKSPACE_ARCHIVE_DIR_PATH)
140178
file_exists.assert_called_once_with(TEST_WORKSPACE_ARCHIVE_PATH)
141179
get_cwd.assert_called_once()

0 commit comments

Comments
 (0)