Skip to content

Commit bd32a9b

Browse files
rohangujarathiRohan Gujarathi
authored andcommitted
pathways: refactor job_driver entrypoint (aws#825)
Co-authored-by: Rohan Gujarathi <[email protected]>
1 parent 1a7dd7c commit bd32a9b

File tree

8 files changed

+192
-34
lines changed

8 files changed

+192
-34
lines changed

setup.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,12 @@ def read_requirements(filename):
103103
],
104104
install_requires=required_packages,
105105
extras_require=extras,
106+
scripts=["src/sagemaker/remote_function/job_driver.sh"],
106107
entry_points={
107108
"console_scripts": [
108109
"sagemaker-upgrade-v2=sagemaker.cli.compatibility.v2.sagemaker_upgrade_v2:main",
109-
"invoke-remote-function=sagemaker.remote_function.job_driver:main",
110+
"invoke-remote-function=sagemaker.remote_function.invoke_function:main",
111+
"bootstrap-runtime-env=sagemaker.remote_function.bootstrap:main",
110112
]
111113
},
112114
)

src/sagemaker/remote_function/job_driver.py renamed to src/sagemaker/remote_function/bootstrap.py

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""An entry point for invoking remote function inside a job."""
13+
"""An entry point for runtime environment."""
1414
from __future__ import absolute_import
1515

1616
import argparse
@@ -20,9 +20,8 @@
2020

2121
from sagemaker.session import Session
2222
from sagemaker.remote_function.errors import handle_error
23-
from sagemaker.remote_function.core.runtime_environment import RuntimeEnvironmentManager
2423
from sagemaker.remote_function import logging_config
25-
24+
from sagemaker.remote_function.core.runtime_environment import RuntimeEnvironmentManager
2625

2726
SUCCESS_EXIT_CODE = 0
2827
DEFAULT_FAILURE_CODE = 1
@@ -40,12 +39,6 @@ def _parse_agrs():
4039
return parser.parse_args()
4140

4241

43-
def _get_sagemaker_session(region):
44-
"""Get sagemaker session for interacting with AWS or Sagemaker services"""
45-
boto_session = boto3.session.Session(region_name=region)
46-
return Session(boto_session=boto_session)
47-
48-
4942
def _execute_pre_exec_cmds():
5043
"""Execute pre-flight commands before invkoing remote function"""
5144
# TODO: complete me
@@ -58,12 +51,19 @@ def _uncompress_src_dir():
5851
pass # pylint: disable=W0107
5952

6053

61-
def _bootstrap_runtime_environment(
54+
def _get_sagemaker_session(region):
55+
"""Get sagemaker session for interacting with AWS or Sagemaker services"""
56+
boto_session = boto3.session.Session(region_name=region)
57+
return Session(boto_session=boto_session)
58+
59+
60+
def bootstrap_runtime_environment(
6261
runtime_manager: RuntimeEnvironmentManager,
6362
dependencies: str,
6463
job_conda_env: str = None,
6564
):
6665
"""Bootstrap runtime environment for remote function invocation"""
66+
6767
_execute_pre_exec_cmds()
6868

6969
if dependencies:
@@ -72,21 +72,14 @@ def _bootstrap_runtime_environment(
7272
_uncompress_src_dir()
7373

7474

75-
def _execute_remote_function(sagemaker_session, s3_base_uri, s3_kms_key):
76-
"""Execute stored remote function"""
77-
from sagemaker.remote_function.core.stored_function import StoredFunction
78-
79-
stored_function = StoredFunction(sagemaker_session, s3_base_uri, s3_kms_key)
80-
stored_function.load_and_invoke()
81-
82-
8375
def main():
84-
"""Entry point for job driver script"""
76+
"""Entry point for bootstrap script"""
8577

8678
logging_config.basic_config()
8779
logger = logging_config.get_logger()
8880

8981
exit_code = DEFAULT_FAILURE_CODE
82+
9083
try:
9184
args = _parse_agrs()
9285
region = args.region
@@ -103,16 +96,15 @@ def main():
10396
sagemaker_session=sagemaker_session,
10497
)
10598

106-
_bootstrap_runtime_environment(
99+
bootstrap_runtime_environment(
107100
runtime_manager=runtime_environment_manager,
108101
dependencies=dependencies,
109102
job_conda_env=job_conda_env,
110103
)
111104

112-
_execute_remote_function(sagemaker_session, s3_base_uri, s3_kms_key)
113105
exit_code = SUCCESS_EXIT_CODE
114106
except Exception as e: # pylint: disable=broad-except
115-
logger.exception("Error encountered when invoking the remote function.")
107+
logger.exception("Error encountered while bootstrapping runtime environment.")
116108
exit_code = handle_error(e, sagemaker_session, s3_base_uri, s3_kms_key)
117109
finally:
118110
sys.exit(exit_code)
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""An entry point for invoking remote function inside a job."""
14+
15+
from __future__ import absolute_import
16+
17+
import argparse
18+
import sys
19+
20+
import boto3
21+
22+
from sagemaker.session import Session
23+
from sagemaker.remote_function.errors import handle_error
24+
from sagemaker.remote_function import logging_config
25+
26+
27+
SUCCESS_EXIT_CODE = 0
28+
29+
30+
def _parse_agrs():
31+
"""Parses CLI arguments."""
32+
parser = argparse.ArgumentParser()
33+
parser.add_argument("--region", type=str, required=True)
34+
parser.add_argument("--s3_base_uri", type=str, required=True)
35+
parser.add_argument("--s3_kms_key", type=str)
36+
parser.add_argument("--dependencies", type=str)
37+
parser.add_argument("--job_conda_env", type=str)
38+
39+
return parser.parse_args()
40+
41+
42+
def _get_sagemaker_session(region):
43+
"""Get sagemaker session for interacting with AWS or Sagemaker services"""
44+
boto_session = boto3.session.Session(region_name=region)
45+
return Session(boto_session=boto_session)
46+
47+
48+
def _execute_remote_function(sagemaker_session, s3_base_uri, s3_kms_key):
49+
"""Execute stored remote function"""
50+
from sagemaker.remote_function.core.stored_function import StoredFunction
51+
52+
stored_function = StoredFunction(sagemaker_session, s3_base_uri, s3_kms_key)
53+
stored_function.load_and_invoke()
54+
55+
56+
def main():
57+
"""Entry point for invoke function script"""
58+
59+
logging_config.basic_config()
60+
logger = logging_config.get_logger()
61+
62+
exit_code = SUCCESS_EXIT_CODE
63+
64+
try:
65+
args = _parse_agrs()
66+
region = args.region
67+
s3_base_uri = args.s3_base_uri
68+
s3_kms_key = args.s3_kms_key
69+
70+
sagemaker_session = _get_sagemaker_session(region)
71+
_execute_remote_function(sagemaker_session, s3_base_uri, s3_kms_key)
72+
73+
except Exception as e: # pylint: disable=broad-except
74+
logger.exception("Error encountered while invoking the remote function.")
75+
exit_code = handle_error(e, sagemaker_session, s3_base_uri, s3_kms_key)
76+
finally:
77+
sys.exit(exit_code)
78+
79+
80+
if __name__ == "__main__":
81+
main()

src/sagemaker/remote_function/job.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from sagemaker.remote_function import logging_config
2727

2828

29-
JOBS_CONTAINER_ENTRYPOINT = ["invoke-remote-function"]
29+
JOBS_CONTAINER_ENTRYPOINT = ["/bin/bash", "job_driver.sh"]
3030

3131

3232
logger = logging_config.get_logger()
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#!/bin/bash
2+
3+
# Entry point for bootstrapping runtime environment and invoking remote function
4+
5+
set -eu
6+
7+
printf "INFO: Bootstraping runtime environment\n"
8+
bootstrap-runtime-env "$@"
9+
10+
printf "INFO: Invoking remote function\n"
11+
invoke-remote-function "$@"
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from mock import patch, Mock
16+
from sagemaker.remote_function import bootstrap
17+
from sagemaker.remote_function.errors import RuntimeEnvironmentError
18+
19+
TEST_REGION = "us-west-2"
20+
TEST_S3_BASE_URI = "s3://my-bucket/"
21+
TEST_S3_KMS_KEY = "my-kms-key"
22+
TEST_DEPENDENCIES = "s3://my-bucket/requirements.txt"
23+
24+
25+
def mock_args():
26+
args = Mock()
27+
args.region = TEST_REGION
28+
args.s3_base_uri = TEST_S3_BASE_URI
29+
args.s3_kms_key = TEST_S3_KMS_KEY
30+
return args
31+
32+
33+
def mock_session():
34+
return Mock()
35+
36+
37+
@patch("sagemaker.remote_function.bootstrap._parse_agrs", new=mock_args)
38+
@patch("sys.exit")
39+
@patch("sagemaker.remote_function.core.runtime_environment.RuntimeEnvironmentManager.bootstrap")
40+
@patch(
41+
"sagemaker.remote_function.bootstrap._get_sagemaker_session",
42+
return_value=mock_session(),
43+
)
44+
def test_main_success(_get_sagemaker_session, bootstrap_runtime, _exit_process):
45+
bootstrap.main()
46+
47+
_get_sagemaker_session.assert_called_with(TEST_REGION)
48+
bootstrap_runtime.assert_called()
49+
_exit_process.assert_called_with(0)
50+
51+
52+
@patch("sagemaker.remote_function.bootstrap._parse_agrs", new=mock_args)
53+
@patch("sagemaker.remote_function.bootstrap.handle_error")
54+
@patch("sys.exit")
55+
@patch("sagemaker.remote_function.core.runtime_environment.RuntimeEnvironmentManager.bootstrap")
56+
@patch(
57+
"sagemaker.remote_function.bootstrap._get_sagemaker_session",
58+
return_value=mock_session(),
59+
)
60+
def test_main_failure(_get_sagemaker_session, bootstrap_runtime, _exit_process, handle_error):
61+
runtime_err = RuntimeEnvironmentError("some failure reason")
62+
bootstrap_runtime.side_effect = runtime_err
63+
handle_error.return_value = 1
64+
65+
bootstrap.main()
66+
67+
_get_sagemaker_session.assert_called_with(TEST_REGION)
68+
bootstrap_runtime.assert_called()
69+
handle_error.assert_called_with(
70+
runtime_err, _get_sagemaker_session(), TEST_S3_BASE_URI, TEST_S3_KMS_KEY
71+
)
72+
_exit_process.assert_called_with(1)

tests/unit/sagemaker/remote_function/test_job_driver.py renamed to tests/unit/sagemaker/remote_function/test_invoke_function.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from __future__ import absolute_import
1414

1515
from mock import patch, Mock
16-
from sagemaker.remote_function import job_driver
16+
from sagemaker.remote_function import invoke_function
1717
from sagemaker.remote_function.errors import SerializationError
1818

1919
TEST_REGION = "us-west-2"
@@ -33,35 +33,35 @@ def mock_session():
3333
return Mock()
3434

3535

36-
@patch("sagemaker.remote_function.job_driver._parse_agrs", new=mock_args)
36+
@patch("sagemaker.remote_function.invoke_function._parse_agrs", new=mock_args)
3737
@patch("sys.exit")
3838
@patch("sagemaker.remote_function.core.stored_function.StoredFunction.load_and_invoke")
3939
@patch(
40-
"sagemaker.remote_function.job_driver._get_sagemaker_session",
40+
"sagemaker.remote_function.invoke_function._get_sagemaker_session",
4141
return_value=mock_session(),
4242
)
4343
def test_main_success(_get_sagemaker_session, load_and_invoke, _exit_process):
44-
job_driver.main()
44+
invoke_function.main()
4545

4646
_get_sagemaker_session.assert_called_with(TEST_REGION)
4747
load_and_invoke.assert_called()
4848
_exit_process.assert_called_with(0)
4949

5050

51-
@patch("sagemaker.remote_function.job_driver._parse_agrs", new=mock_args)
52-
@patch("sagemaker.remote_function.job_driver.handle_error")
51+
@patch("sagemaker.remote_function.invoke_function._parse_agrs", new=mock_args)
52+
@patch("sagemaker.remote_function.invoke_function.handle_error")
5353
@patch("sys.exit")
5454
@patch("sagemaker.remote_function.core.stored_function.StoredFunction.load_and_invoke")
5555
@patch(
56-
"sagemaker.remote_function.job_driver._get_sagemaker_session",
56+
"sagemaker.remote_function.invoke_function._get_sagemaker_session",
5757
return_value=mock_session(),
5858
)
5959
def test_main_failure(_get_sagemaker_session, load_and_invoke, _exit_process, handle_error):
6060
ser_err = SerializationError("some failure reason")
6161
load_and_invoke.side_effect = ser_err
6262
handle_error.return_value = 1
6363

64-
job_driver.main()
64+
invoke_function.main()
6565

6666
_get_sagemaker_session.assert_called_with(TEST_REGION)
6767
load_and_invoke.assert_called()

tests/unit/sagemaker/remote_function/test_job.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def test_start(session, mock_stored_function, mock_runtime_manager):
128128
AlgorithmSpecification=dict(
129129
TrainingImage=IMAGE,
130130
TrainingInputMode="File",
131-
ContainerEntrypoint=["invoke-remote-function"],
131+
ContainerEntrypoint=["/bin/bash", "job_driver.sh"],
132132
ContainerArguments=[
133133
"--s3_base_uri",
134134
f"{S3_URI}/{job.job_name}",
@@ -190,7 +190,7 @@ def test_start_with_complete_job_settings(session, mock_stored_function, mock_ru
190190
AlgorithmSpecification=dict(
191191
TrainingImage=IMAGE,
192192
TrainingInputMode="File",
193-
ContainerEntrypoint=["invoke-remote-function"],
193+
ContainerEntrypoint=["/bin/bash", "job_driver.sh"],
194194
ContainerArguments=[
195195
"--s3_base_uri",
196196
f"{S3_URI}/{job.job_name}",

0 commit comments

Comments
 (0)