Skip to content

Commit 6e2b673

Browse files
committed
fix: Add write permission to job output dirs for remote and step decorator running on non-root job user
1 parent 889d7b7 commit 6e2b673

File tree

5 files changed

+300
-6
lines changed

5 files changed

+300
-6
lines changed

src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from __future__ import absolute_import
1515

1616
import argparse
17+
import getpass
18+
import subprocess
1719
import sys
1820
import os
1921
import shutil
@@ -24,12 +26,14 @@
2426
RuntimeEnvironmentManager,
2527
_DependencySettings,
2628
get_logger,
29+
RuntimeEnvironmentError,
2730
)
2831
else:
2932
from sagemaker.remote_function.runtime_environment.runtime_environment_manager import (
3033
RuntimeEnvironmentManager,
3134
_DependencySettings,
3235
get_logger,
36+
RuntimeEnvironmentError,
3337
)
3438

3539
SUCCESS_EXIT_CODE = 0
@@ -38,6 +42,7 @@
3842
REMOTE_FUNCTION_WORKSPACE = "sm_rf_user_ws"
3943
BASE_CHANNEL_PATH = "/opt/ml/input/data"
4044
FAILURE_REASON_PATH = "/opt/ml/output/failure"
45+
JOB_OUTPUT_DIRS_NEEDED_PERMISSION_CHANGE = ["/opt/ml/output", "/opt/ml/model", "/tmp"]
4146
PRE_EXECUTION_SCRIPT_NAME = "pre_exec.sh"
4247
JOB_REMOTE_FUNCTION_WORKSPACE = "sagemaker_remote_function_workspace"
4348
SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME = "pre_exec_script_and_dependencies"
@@ -63,6 +68,17 @@ def main(sys_args=None):
6368

6469
RuntimeEnvironmentManager()._validate_python_version(client_python_version, conda_env)
6570

71+
user = getpass.getuser()
72+
if user != "root":
73+
log_message = (
74+
"The job is running on non-root user: %s. Adding write permissions to the "
75+
"following job output directories: %s."
76+
)
77+
logger.info(log_message, user, JOB_OUTPUT_DIRS_NEEDED_PERMISSION_CHANGE)
78+
_change_dir_permission(
79+
dirs=JOB_OUTPUT_DIRS_NEEDED_PERMISSION_CHANGE, new_permission="777"
80+
)
81+
6682
if pipeline_execution_id:
6783
_bootstrap_runtime_env_for_pipeline_step(
6884
client_python_version, func_step_workspace, conda_env, dependency_settings
@@ -81,6 +97,33 @@ def main(sys_args=None):
8197
sys.exit(exit_code)
8298

8399

100+
def _change_dir_permission(dirs: list, new_permission: str):
101+
"""Change the permission of given directories
102+
103+
Args:
104+
dirs (list[str]): A list of directories for permission update.
105+
new_permission (str): The new permission for the given directories.
106+
"""
107+
108+
_ERROR_MSG_PREFIX = "Failed to change directory permissions due to: "
109+
command = ["sudo", "chmod", "-R", new_permission] + dirs
110+
logger.info("Executing '%s'.", {" ".join(command)})
111+
112+
try:
113+
subprocess.run(command, check=True, stderr=subprocess.PIPE)
114+
except subprocess.CalledProcessError as called_process_err:
115+
err_msg = called_process_err.stderr.decode("utf-8")
116+
raise RuntimeEnvironmentError(f"{_ERROR_MSG_PREFIX} {err_msg}")
117+
except FileNotFoundError as file_not_found_err:
118+
if "[Errno 2] No such file or directory: 'sudo'" in str(file_not_found_err):
119+
raise RuntimeEnvironmentError(
120+
f"{_ERROR_MSG_PREFIX} {file_not_found_err}. "
121+
"Please contact the image owner to install 'sudo' in the job container "
122+
"and provide sudo privilege to the container user."
123+
)
124+
raise RuntimeEnvironmentError(file_not_found_err)
125+
126+
84127
def _bootstrap_runtime_env_for_remote_function(
85128
client_python_version: str,
86129
conda_env: str = None,

tests/integ/sagemaker/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,12 @@
6868
"RUN curl 'https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip' -o 'awscliv2.zip' \
6969
&& unzip awscliv2.zip \
7070
&& ./aws/install\n\n"
71+
"RUN apt install sudo\n"
7172
"RUN useradd -ms /bin/bash integ-test-user\n"
73+
# Add the user to sudo group
74+
"RUN usermod -aG sudo integ-test-user\n"
75+
# Ensure passwords are not required for sudo group users
76+
"RUN echo '%sudo ALL= (ALL) NOPASSWD:ALL' >> /etc/sudoers\n"
7277
"USER integ-test-user\n"
7378
"WORKDIR /home/integ-test-user\n"
7479
"COPY {source_archive} ./\n"

tests/integ/sagemaker/remote_function/test_decorator.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -747,6 +747,25 @@ def cuberoot(x):
747747
assert cuberoot(27) == 3
748748

749749

750+
def test_with_user_and_workdir_set_in_the_image_client_error_case(
751+
sagemaker_session, dummy_container_with_user_and_workdir, cpu_instance_type
752+
):
753+
client_error_message = "Testing client error in job."
754+
755+
@remote(
756+
role=ROLE,
757+
image_uri=dummy_container_with_user_and_workdir,
758+
instance_type=cpu_instance_type,
759+
sagemaker_session=sagemaker_session,
760+
)
761+
def my_func():
762+
raise RuntimeError(client_error_message)
763+
764+
with pytest.raises(RuntimeError) as error:
765+
my_func()
766+
assert client_error_message in str(error)
767+
768+
750769
@pytest.mark.skip
751770
def test_decorator_with_spark_job(sagemaker_session, cpu_instance_type):
752771
@remote(

tests/integ/sagemaker/workflow/test_step_decorator.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -858,3 +858,49 @@ def cuberoot(x):
858858
pipeline.delete()
859859
except Exception:
860860
pass
861+
862+
863+
def test_with_user_and_workdir_set_in_the_image_client_error_case(
864+
sagemaker_session, role, pipeline_name, region_name, dummy_container_with_user_and_workdir
865+
):
866+
# This test aims to ensure client error in step decorated function
867+
# can be successfully surfaced and the job can be failed.
868+
os.environ["AWS_DEFAULT_REGION"] = region_name
869+
client_error_message = "Testing client error in job."
870+
871+
@step(
872+
role=role,
873+
image_uri=dummy_container_with_user_and_workdir,
874+
instance_type=INSTANCE_TYPE,
875+
)
876+
def my_func():
877+
raise RuntimeError(client_error_message)
878+
879+
step_a = my_func()
880+
881+
pipeline = Pipeline(
882+
name=pipeline_name,
883+
steps=[step_a],
884+
sagemaker_session=sagemaker_session,
885+
)
886+
887+
try:
888+
_, execution_steps = create_and_execute_pipeline(
889+
pipeline=pipeline,
890+
pipeline_name=pipeline_name,
891+
region_name=region_name,
892+
role=role,
893+
no_of_steps=1,
894+
last_step_name=get_step(step_a).name,
895+
execution_parameters=dict(),
896+
step_status="Failed",
897+
)
898+
assert (
899+
f"ClientError: AlgorithmError: RuntimeError('{client_error_message}')"
900+
in execution_steps[0]["FailureReason"]
901+
)
902+
finally:
903+
try:
904+
pipeline.delete()
905+
except Exception:
906+
pass

0 commit comments

Comments
 (0)