Skip to content

Commit 1da9f40

Browse files
nmadanNamrata Madan
andauthored
change: Pathways initial commit (aws#742)
Co-authored-by: Namrata Madan <[email protected]>
1 parent d0a4a56 commit 1da9f40

File tree

10 files changed

+601
-4
lines changed

10 files changed

+601
-4
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
cloudpickle==2.2.0

requirements/extras/test_requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@ requests==2.27.1
2020
sagemaker-experiments==0.1.35
2121
Jinja2==3.0.3
2222
pandas>=1.3.5,<1.5
23+
cloudpickle==2.2.0

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def read_requirements(filename):
6666
extras = {
6767
"local": read_requirements("requirements/extras/local_requirements.txt"),
6868
"scipy": read_requirements("requirements/extras/scipy_requirements.txt"),
69+
"remote_function": read_requirements("requirements/extras/remote_function_requirements.txt"),
6970
}
7071
# Meta dependency groups
7172
extras["all"] = [item for group in extras.values() for item in group]

src/sagemaker/job_decorator.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""SageMaker @job decorator."""
14+
from __future__ import absolute_import
15+
16+
from typing import Dict, List, Tuple
17+
import functools
18+
19+
from sagemaker.job_executor import JobExecutor
20+
21+
22+
def job(
23+
_func=None,
24+
*,
25+
environment_variables: Dict[str, str] = None,
26+
image_uri: str = None,
27+
dependencies: str = None,
28+
instance_type: str = None,
29+
instance_count: int = 1,
30+
volume_size: int = 30,
31+
max_runtime_in_seconds: int = 24 * 60 * 60,
32+
max_retry_attempts: int = 0,
33+
keep_alive_period_in_seconds: int = 0,
34+
role: str = None,
35+
s3_root_uri: str = None,
36+
s3_kms_key: str = None,
37+
volume_kms_key: str = None,
38+
subnets: List[str] = None,
39+
security_group_ids: List[str] = None,
40+
tags: List[Tuple[str, str]] = None
41+
):
42+
"""Function that starts a new SageMaker job asynchronously with overridden runtime settings.
43+
44+
Args:
45+
_func (Optional): Python function to be executed on the SageMaker job runtime environment.
46+
environment_variables (Dict): environment variables
47+
image_uri (str): Docker image URI on ECR.
48+
dependencies (str): Path to dependencies file or a reserved keyword ``AUTO_DETECT``.
49+
instance_type (str): EC2 instance type.
50+
instance_count (int): Number of instance to use. Default is 1.
51+
volume_size (int): Size in GB of the storage volume to use for storing input and output
52+
data. Default is 30.
53+
max_runtime_in_seconds (int): Timeout in seconds for training. After this amount of time
54+
Amazon SageMaker terminates the job regardless of its current status.
55+
Default is 86400 seconds (1 day).
56+
max_retry_attempts (int): Max number of times the job is retried on InternalServerFailure.
57+
Default is 0.
58+
keep_alive_period_in_seconds (int): The duration of time in seconds to retain configured
59+
resources in a warm pool for subsequent training jobs. Default is 0.
60+
role (str): IAM role used for SageMaker execution.
61+
s3_root_uri (str): The root S3 folder where the code archives and data are uploaded to.
62+
s3_kms_key (str): The encryption key used for storing serialized data.
63+
volume_kms_key (str): KMS key used for encrypting EBS volume attached to the training
64+
instance.
65+
subnets (List[str]): List of subnet IDs.
66+
security_group_ids (List[str]): List of security group IDs.
67+
tags (List[Tuple[str, str]]): List of tags attached to the job.
68+
"""
69+
70+
def _job(func):
71+
@functools.wraps(func)
72+
def wrapper(*args, **kwargs):
73+
executor = JobExecutor(
74+
environment_variables=environment_variables,
75+
image_uri=image_uri,
76+
dependencies=dependencies,
77+
instance_type=instance_type,
78+
instance_count=instance_count,
79+
volume_size=volume_size,
80+
max_runtime_in_seconds=max_runtime_in_seconds,
81+
max_retry_attempts=max_retry_attempts,
82+
keep_alive_period_in_seconds=keep_alive_period_in_seconds,
83+
role=role,
84+
s3_root_uri=s3_root_uri,
85+
s3_kms_key=s3_kms_key,
86+
volume_kms_key=volume_kms_key,
87+
subnets=subnets,
88+
security_group_ids=security_group_ids,
89+
tags=tags,
90+
)
91+
training_future = executor.submit(func, *args, **kwargs)
92+
return training_future.result()
93+
94+
return wrapper
95+
96+
if _func is None:
97+
return _job
98+
return _job(_func)

src/sagemaker/job_executor.py

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""SageMaker job Executor."""
14+
from __future__ import absolute_import
15+
16+
from typing import Any, Dict, List, Tuple
17+
import inspect
18+
import os
19+
import boto3
20+
21+
from sagemaker.job_runtime.serialization import JobSerializer
22+
from sagemaker.session import generate_default_sagemaker_bucket_name
23+
from sagemaker.utils import name_from_base, base_name_from_image
24+
25+
JOBS_CONTAINER_ENTRYPOINT = ["python3", "/code/source/job_entrypoint.py"]
26+
27+
28+
class JobExecutor(object):
29+
"""Handles asynchronous SageMaker jobs"""
30+
31+
def __init__(
32+
self,
33+
environment_variables: Dict[str, str] = None,
34+
image_uri: str = None,
35+
dependencies: str = None,
36+
instance_type: str = None,
37+
instance_count: int = 1,
38+
volume_size: int = 30,
39+
max_runtime_in_seconds: int = 24 * 60 * 60,
40+
max_retry_attempts: int = 0,
41+
keep_alive_period_in_seconds: int = 0,
42+
role: str = None,
43+
s3_root_uri: str = None,
44+
s3_kms_key: str = None,
45+
volume_kms_key: str = None,
46+
subnets: List[str] = None,
47+
security_group_ids: List[str] = None,
48+
tags: List[Tuple[str, str]] = None,
49+
boto_session: boto3.session.Session = None,
50+
):
51+
"""Initiates a ``JobExecutor`` instance.
52+
53+
Args:
54+
environment_variables (Dict): Environment variables passed to the underlying sagemaker
55+
job. Defaults to None
56+
image_uri (str): Docker image URI on ECR. Defaults to base Python image.
57+
dependencies (str): Path to dependencies file or a reserved keyword ``AUTO_DETECT``.
58+
Defaults to None.
59+
instance_type (str): EC2 instance type.
60+
instance_count (int): Number of instance to use. Defaults to 1.
61+
volume_size (int): Size in GB of the storage volume to use for storing input and output
62+
data. Defaults to 30.
63+
max_runtime_in_seconds (int): Timeout in seconds for training. After this amount of
64+
time Amazon SageMaker terminates the job regardless of its current status.
65+
Defaults to 86400 seconds (1 day).
66+
max_retry_attempts (int): Max number of times the job is retried on
67+
InternalServerFailure.Defaults to 0.
68+
keep_alive_period_in_seconds (int): The duration of time in seconds to retain configured
69+
resources in a warm pool for subsequent training jobs. Defaults to 0.
70+
role (str): IAM role used for SageMaker execution. Defaults to SageMaker default
71+
execution role.
72+
s3_root_uri (str): The root S3 folder where the code archives and data are uploaded to.
73+
This parameter is autogenerated using information regarding the image uri if not
74+
provided.
75+
s3_kms_key (str): The encryption key used for storing serialized data. Defaults to S3
76+
managed key.
77+
volume_kms_key (str): KMS key used for encrypting EBS volume attached to the training
78+
instance.
79+
subnets (List[str]): List of subnet IDs. Defaults to None.
80+
security_group_ids (List[str]): List of security group IDs. Defaults to None.
81+
tags (List[Tuple[str, str]]): List of tags attached to the job. Defaults to None.
82+
boto_session (boto3.session.Session): The underlying Boto3 session which AWS service
83+
calls are delegated to (default: None). If not provided, one is created with
84+
default AWS configuration chain.
85+
"""
86+
self.environment_variables = environment_variables
87+
self.image_uri = image_uri
88+
self.dependencies = dependencies
89+
self.instance_type = instance_type
90+
self.instance_count = instance_count
91+
self.volume_size = volume_size
92+
self.max_runtime_in_seconds = max_runtime_in_seconds
93+
self.max_retry_attempts = max_retry_attempts
94+
self.keep_alive_period_in_seconds = keep_alive_period_in_seconds
95+
self.role = role
96+
self.boto_session = boto_session or boto3.Session()
97+
self.s3_root_uri = s3_root_uri or os.path.join(
98+
"s3://",
99+
generate_default_sagemaker_bucket_name(self.boto_session),
100+
base_name_from_image(self.image_uri),
101+
)
102+
self.s3_kms_key = s3_kms_key
103+
self.volume_kms_key = volume_kms_key
104+
self.subnets = subnets
105+
self.security_group_ids = security_group_ids
106+
self.tags = [] if tags is None else [{"Key": k, "Value": v} for k, v in tags]
107+
108+
def submit(self, func, *args, **kwargs):
109+
"""Execute the input func as a SageMaker job asynchronously.
110+
111+
Args:
112+
func: Python function to run as a SageMaker job.
113+
*args: Positional arguments to the input function.
114+
**kwargs: keyword arguments to the input function
115+
"""
116+
self._validate_submit_args(func, *args, **kwargs)
117+
118+
# TODO: retrieve and upload to S3 runtime config yaml file
119+
120+
# TODO: retrieve and aggregate sagemaker_config.yaml
121+
122+
job_serializer = JobSerializer(self.boto_session, self.s3_root_uri, self.s3_kms_key)
123+
# TODO: serialize data inputs
124+
function_s3_uri = job_serializer.serialize_function(func)
125+
126+
sagemaker_client = self.boto_session.client("sagemaker")
127+
response = sagemaker_client.create_training_job(
128+
TrainingJobName=name_from_base(func.__name__),
129+
AlgorithmSpecification={
130+
"TrainingImage": self.image_uri,
131+
"TrainingInputMode": "File",
132+
"ContainerEntrypoint": JOBS_CONTAINER_ENTRYPOINT,
133+
# TODO: add additional container args
134+
"ContainerArguments": ["--function_uri", function_s3_uri],
135+
},
136+
RoleArn=self.role,
137+
ResourceConfig={
138+
"InstanceType": self.instance_type,
139+
"InstanceCount": self.instance_count,
140+
"VolumeSizeInGB": self.volume_size,
141+
"VolumeKmsKeyId": self.volume_kms_key,
142+
"KeepAlivePeriodInSeconds": self.keep_alive_period_in_seconds,
143+
},
144+
VpcConfig={"SecurityGroupIds": self.security_group_ids, "Subnets": self.subnets},
145+
StoppingCondition={
146+
"MaxRuntimeInSeconds": self.max_runtime_in_seconds,
147+
},
148+
Tags=[{"Key": k, "Value": v} for k, v in self.tags],
149+
Environment=self.environment_variables,
150+
RetryStrategy={"MaximumRetryAttempts": self.max_retry_attempts},
151+
)
152+
training_arn = response["TrainingJobArn"]
153+
return Future(training_arn)
154+
155+
@staticmethod
156+
def _validate_submit_args(func, *args, **kwargs):
157+
"""Validates input args passed to the submit() method."""
158+
if not inspect.isfunction(func):
159+
raise TypeError("Only python functions can be run as SageMaker jobs.")
160+
161+
full_arg_spec = inspect.getfullargspec(func)
162+
num_provided_args = len(args) + len(kwargs)
163+
minimum_num_expected_args = len(full_arg_spec.args)
164+
165+
is_accepting_variable_args = not (
166+
full_arg_spec.varargs is None and full_arg_spec.varkw is None
167+
)
168+
169+
if is_accepting_variable_args:
170+
if num_provided_args < minimum_num_expected_args:
171+
raise AttributeError(
172+
"Function {} expects at least {} arg(s). {} provided.".format(
173+
func.__name__, minimum_num_expected_args, num_provided_args
174+
)
175+
)
176+
else:
177+
if num_provided_args != minimum_num_expected_args:
178+
raise AttributeError(
179+
"Function {} expects {} arg(s). {} provided.".format(
180+
func.__name__, minimum_num_expected_args, num_provided_args
181+
)
182+
)
183+
184+
185+
class Future(object):
186+
"""Class representing a reference to a sagemaker job result.
187+
188+
The sagemaker job represented may or may not have finished running.
189+
"""
190+
191+
def __init__(self, job_arn):
192+
"""Initialize a JobResultReference object."""
193+
self._arn = job_arn
194+
195+
def result(self, timeout: int = None) -> Any:
196+
"""Dereferences and returns the job result object.
197+
198+
This method blocks on the sagemaker job completing for up to the timeout value (if
199+
specified). If timeout is ``None``, this method will block until the job is completed.
200+
Args:
201+
timeout (int): Timeout in seconds to wait until the job is completed. ``None`` by
202+
default.
203+
204+
Returns:
205+
The Python object returned by the function
206+
"""
207+
# pylint: disable=W0107
208+
pass
209+
210+
def wait(
211+
self,
212+
timeout: int = None,
213+
) -> None:
214+
"""Wait for the underlying sagemaker job to complete.
215+
216+
This method blocks on the sagemaker job completing for up to the timeout value (if
217+
specified). If timeout is ``None``, this method will block until the job is completed.
218+
Args:
219+
timeout (int): Timeout in seconds to wait until the job is completed. ``None`` by
220+
default.
221+
222+
Returns: None
223+
"""
224+
# pylint: disable=W0107
225+
pass
226+
227+
def cancel(self):
228+
"""Stop the underlying sagemaker job early if the it is still in progress.
229+
230+
Returns: None
231+
"""
232+
# pylint: disable=W0107
233+
pass
234+
235+
def running(self):
236+
"""Returns ``True`` if the underlying sagemaker job is still running."""
237+
# pylint: disable=W0107
238+
pass
239+
240+
def cancelled(self):
241+
"""Returns ``True`` if the underlying sagemaker job was cancelled. ``False``, otherwise."""
242+
# pylint: disable=W0107
243+
pass
244+
245+
def done(self):
246+
"""Returns ``True`` if the underlying sagemaker job was cancelled or finished running."""
247+
# pylint: disable=W0107
248+
pass

0 commit comments

Comments
 (0)