|
| 1 | +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"). You |
| 4 | +# may not use this file except in compliance with the License. A copy of |
| 5 | +# the License is located at |
| 6 | +# |
| 7 | +# http://aws.amazon.com/apache2.0/ |
| 8 | +# |
| 9 | +# or in the "license" file accompanying this file. This file is |
| 10 | +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF |
| 11 | +# ANY KIND, either express or implied. See the License for the specific |
| 12 | +# language governing permissions and limitations under the License. |
| 13 | +"""SageMaker job Executor.""" |
| 14 | +from __future__ import absolute_import |
| 15 | + |
| 16 | +from typing import Any, Dict, List, Tuple |
| 17 | +import inspect |
| 18 | +import os |
| 19 | +import boto3 |
| 20 | + |
| 21 | +from sagemaker.job_runtime.serialization import JobSerializer |
| 22 | +from sagemaker.session import generate_default_sagemaker_bucket_name |
| 23 | +from sagemaker.utils import name_from_base, base_name_from_image |
| 24 | + |
| 25 | +JOBS_CONTAINER_ENTRYPOINT = ["python3", "/code/source/job_entrypoint.py"] |
| 26 | + |
| 27 | + |
| 28 | +class JobExecutor(object): |
| 29 | + """Handles asynchronous SageMaker jobs""" |
| 30 | + |
| 31 | + def __init__( |
| 32 | + self, |
| 33 | + environment_variables: Dict[str, str] = None, |
| 34 | + image_uri: str = None, |
| 35 | + dependencies: str = None, |
| 36 | + instance_type: str = None, |
| 37 | + instance_count: int = 1, |
| 38 | + volume_size: int = 30, |
| 39 | + max_runtime_in_seconds: int = 24 * 60 * 60, |
| 40 | + max_retry_attempts: int = 0, |
| 41 | + keep_alive_period_in_seconds: int = 0, |
| 42 | + role: str = None, |
| 43 | + s3_root_uri: str = None, |
| 44 | + s3_kms_key: str = None, |
| 45 | + volume_kms_key: str = None, |
| 46 | + subnets: List[str] = None, |
| 47 | + security_group_ids: List[str] = None, |
| 48 | + tags: List[Tuple[str, str]] = None, |
| 49 | + boto_session: boto3.session.Session = None, |
| 50 | + ): |
| 51 | + """Initiates a ``JobExecutor`` instance. |
| 52 | +
|
| 53 | + Args: |
| 54 | + environment_variables (Dict): Environment variables passed to the underlying sagemaker |
| 55 | + job. Defaults to None |
| 56 | + image_uri (str): Docker image URI on ECR. Defaults to base Python image. |
| 57 | + dependencies (str): Path to dependencies file or a reserved keyword ``AUTO_DETECT``. |
| 58 | + Defaults to None. |
| 59 | + instance_type (str): EC2 instance type. |
| 60 | + instance_count (int): Number of instance to use. Defaults to 1. |
| 61 | + volume_size (int): Size in GB of the storage volume to use for storing input and output |
| 62 | + data. Defaults to 30. |
| 63 | + max_runtime_in_seconds (int): Timeout in seconds for training. After this amount of |
| 64 | + time Amazon SageMaker terminates the job regardless of its current status. |
| 65 | + Defaults to 86400 seconds (1 day). |
| 66 | + max_retry_attempts (int): Max number of times the job is retried on |
| 67 | + InternalServerFailure.Defaults to 0. |
| 68 | + keep_alive_period_in_seconds (int): The duration of time in seconds to retain configured |
| 69 | + resources in a warm pool for subsequent training jobs. Defaults to 0. |
| 70 | + role (str): IAM role used for SageMaker execution. Defaults to SageMaker default |
| 71 | + execution role. |
| 72 | + s3_root_uri (str): The root S3 folder where the code archives and data are uploaded to. |
| 73 | + This parameter is autogenerated using information regarding the image uri if not |
| 74 | + provided. |
| 75 | + s3_kms_key (str): The encryption key used for storing serialized data. Defaults to S3 |
| 76 | + managed key. |
| 77 | + volume_kms_key (str): KMS key used for encrypting EBS volume attached to the training |
| 78 | + instance. |
| 79 | + subnets (List[str]): List of subnet IDs. Defaults to None. |
| 80 | + security_group_ids (List[str]): List of security group IDs. Defaults to None. |
| 81 | + tags (List[Tuple[str, str]]): List of tags attached to the job. Defaults to None. |
| 82 | + boto_session (boto3.session.Session): The underlying Boto3 session which AWS service |
| 83 | + calls are delegated to (default: None). If not provided, one is created with |
| 84 | + default AWS configuration chain. |
| 85 | + """ |
| 86 | + self.environment_variables = environment_variables |
| 87 | + self.image_uri = image_uri |
| 88 | + self.dependencies = dependencies |
| 89 | + self.instance_type = instance_type |
| 90 | + self.instance_count = instance_count |
| 91 | + self.volume_size = volume_size |
| 92 | + self.max_runtime_in_seconds = max_runtime_in_seconds |
| 93 | + self.max_retry_attempts = max_retry_attempts |
| 94 | + self.keep_alive_period_in_seconds = keep_alive_period_in_seconds |
| 95 | + self.role = role |
| 96 | + self.boto_session = boto_session or boto3.Session() |
| 97 | + self.s3_root_uri = s3_root_uri or os.path.join( |
| 98 | + "s3://", |
| 99 | + generate_default_sagemaker_bucket_name(self.boto_session), |
| 100 | + base_name_from_image(self.image_uri), |
| 101 | + ) |
| 102 | + self.s3_kms_key = s3_kms_key |
| 103 | + self.volume_kms_key = volume_kms_key |
| 104 | + self.subnets = subnets |
| 105 | + self.security_group_ids = security_group_ids |
| 106 | + self.tags = [] if tags is None else [{"Key": k, "Value": v} for k, v in tags] |
| 107 | + |
| 108 | + def submit(self, func, *args, **kwargs): |
| 109 | + """Execute the input func as a SageMaker job asynchronously. |
| 110 | +
|
| 111 | + Args: |
| 112 | + func: Python function to run as a SageMaker job. |
| 113 | + *args: Positional arguments to the input function. |
| 114 | + **kwargs: keyword arguments to the input function |
| 115 | + """ |
| 116 | + self._validate_submit_args(func, *args, **kwargs) |
| 117 | + |
| 118 | + # TODO: retrieve and upload to S3 runtime config yaml file |
| 119 | + |
| 120 | + # TODO: retrieve and aggregate sagemaker_config.yaml |
| 121 | + |
| 122 | + job_serializer = JobSerializer(self.boto_session, self.s3_root_uri, self.s3_kms_key) |
| 123 | + # TODO: serialize data inputs |
| 124 | + function_s3_uri = job_serializer.serialize_function(func) |
| 125 | + |
| 126 | + sagemaker_client = self.boto_session.client("sagemaker") |
| 127 | + response = sagemaker_client.create_training_job( |
| 128 | + TrainingJobName=name_from_base(func.__name__), |
| 129 | + AlgorithmSpecification={ |
| 130 | + "TrainingImage": self.image_uri, |
| 131 | + "TrainingInputMode": "File", |
| 132 | + "ContainerEntrypoint": JOBS_CONTAINER_ENTRYPOINT, |
| 133 | + # TODO: add additional container args |
| 134 | + "ContainerArguments": ["--function_uri", function_s3_uri], |
| 135 | + }, |
| 136 | + RoleArn=self.role, |
| 137 | + ResourceConfig={ |
| 138 | + "InstanceType": self.instance_type, |
| 139 | + "InstanceCount": self.instance_count, |
| 140 | + "VolumeSizeInGB": self.volume_size, |
| 141 | + "VolumeKmsKeyId": self.volume_kms_key, |
| 142 | + "KeepAlivePeriodInSeconds": self.keep_alive_period_in_seconds, |
| 143 | + }, |
| 144 | + VpcConfig={"SecurityGroupIds": self.security_group_ids, "Subnets": self.subnets}, |
| 145 | + StoppingCondition={ |
| 146 | + "MaxRuntimeInSeconds": self.max_runtime_in_seconds, |
| 147 | + }, |
| 148 | + Tags=[{"Key": k, "Value": v} for k, v in self.tags], |
| 149 | + Environment=self.environment_variables, |
| 150 | + RetryStrategy={"MaximumRetryAttempts": self.max_retry_attempts}, |
| 151 | + ) |
| 152 | + training_arn = response["TrainingJobArn"] |
| 153 | + return Future(training_arn) |
| 154 | + |
| 155 | + @staticmethod |
| 156 | + def _validate_submit_args(func, *args, **kwargs): |
| 157 | + """Validates input args passed to the submit() method.""" |
| 158 | + if not inspect.isfunction(func): |
| 159 | + raise TypeError("Only python functions can be run as SageMaker jobs.") |
| 160 | + |
| 161 | + full_arg_spec = inspect.getfullargspec(func) |
| 162 | + num_provided_args = len(args) + len(kwargs) |
| 163 | + minimum_num_expected_args = len(full_arg_spec.args) |
| 164 | + |
| 165 | + is_accepting_variable_args = not ( |
| 166 | + full_arg_spec.varargs is None and full_arg_spec.varkw is None |
| 167 | + ) |
| 168 | + |
| 169 | + if is_accepting_variable_args: |
| 170 | + if num_provided_args < minimum_num_expected_args: |
| 171 | + raise AttributeError( |
| 172 | + "Function {} expects at least {} arg(s). {} provided.".format( |
| 173 | + func.__name__, minimum_num_expected_args, num_provided_args |
| 174 | + ) |
| 175 | + ) |
| 176 | + else: |
| 177 | + if num_provided_args != minimum_num_expected_args: |
| 178 | + raise AttributeError( |
| 179 | + "Function {} expects {} arg(s). {} provided.".format( |
| 180 | + func.__name__, minimum_num_expected_args, num_provided_args |
| 181 | + ) |
| 182 | + ) |
| 183 | + |
| 184 | + |
| 185 | +class Future(object): |
| 186 | + """Class representing a reference to a sagemaker job result. |
| 187 | +
|
| 188 | + The sagemaker job represented may or may not have finished running. |
| 189 | + """ |
| 190 | + |
| 191 | + def __init__(self, job_arn): |
| 192 | + """Initialize a JobResultReference object.""" |
| 193 | + self._arn = job_arn |
| 194 | + |
| 195 | + def result(self, timeout: int = None) -> Any: |
| 196 | + """Dereferences and returns the job result object. |
| 197 | +
|
| 198 | + This method blocks on the sagemaker job completing for up to the timeout value (if |
| 199 | + specified). If timeout is ``None``, this method will block until the job is completed. |
| 200 | + Args: |
| 201 | + timeout (int): Timeout in seconds to wait until the job is completed. ``None`` by |
| 202 | + default. |
| 203 | +
|
| 204 | + Returns: |
| 205 | + The Python object returned by the function |
| 206 | + """ |
| 207 | + # pylint: disable=W0107 |
| 208 | + pass |
| 209 | + |
| 210 | + def wait( |
| 211 | + self, |
| 212 | + timeout: int = None, |
| 213 | + ) -> None: |
| 214 | + """Wait for the underlying sagemaker job to complete. |
| 215 | +
|
| 216 | + This method blocks on the sagemaker job completing for up to the timeout value (if |
| 217 | + specified). If timeout is ``None``, this method will block until the job is completed. |
| 218 | + Args: |
| 219 | + timeout (int): Timeout in seconds to wait until the job is completed. ``None`` by |
| 220 | + default. |
| 221 | +
|
| 222 | + Returns: None |
| 223 | + """ |
| 224 | + # pylint: disable=W0107 |
| 225 | + pass |
| 226 | + |
| 227 | + def cancel(self): |
| 228 | + """Stop the underlying sagemaker job early if the it is still in progress. |
| 229 | +
|
| 230 | + Returns: None |
| 231 | + """ |
| 232 | + # pylint: disable=W0107 |
| 233 | + pass |
| 234 | + |
| 235 | + def running(self): |
| 236 | + """Returns ``True`` if the underlying sagemaker job is still running.""" |
| 237 | + # pylint: disable=W0107 |
| 238 | + pass |
| 239 | + |
| 240 | + def cancelled(self): |
| 241 | + """Returns ``True`` if the underlying sagemaker job was cancelled. ``False``, otherwise.""" |
| 242 | + # pylint: disable=W0107 |
| 243 | + pass |
| 244 | + |
| 245 | + def done(self): |
| 246 | + """Returns ``True`` if the underlying sagemaker job was cancelled or finished running.""" |
| 247 | + # pylint: disable=W0107 |
| 248 | + pass |
0 commit comments