Skip to content

Commit 2dfe7d3

Browse files
authored
change: allow serving script to be defined for deploy() and transformer() with frameworks (#944)
1 parent 356283e commit 2dfe7d3

File tree

13 files changed

+204
-65
lines changed

13 files changed

+204
-65
lines changed

src/sagemaker/chainer/estimator.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,13 @@ def hyperparameters(self):
155155
return hyperparameters
156156

157157
def create_model(
158-
self, model_server_workers=None, role=None, vpc_config_override=VPC_CONFIG_DEFAULT
158+
self,
159+
model_server_workers=None,
160+
role=None,
161+
vpc_config_override=VPC_CONFIG_DEFAULT,
162+
entry_point=None,
163+
source_dir=None,
164+
dependencies=None,
159165
):
160166
"""Create a SageMaker ``ChainerModel`` object that can be deployed to an
161167
``Endpoint``.
@@ -171,17 +177,24 @@ def create_model(
171177
the model. Default: use subnets and security groups from this Estimator.
172178
* 'Subnets' (list[str]): List of subnet ids.
173179
* 'SecurityGroupIds' (list[str]): List of security group ids.
180+
entry_point (str): Path (absolute or relative) to the local Python source file which should be executed
181+
as the entry point to training. If not specified, the training entry point is used.
182+
source_dir (str): Path (absolute or relative) to a directory with any other serving
183+
source code dependencies aside from the entry point file.
184+
If not specified, the model source directory from training is used.
185+
dependencies (list[str]): A list of paths to directories (absolute or relative) with
186+
any additional libraries that will be exported to the container.
187+
If not specified, the dependencies from training are used.
174188
175189
Returns:
176190
sagemaker.chainer.model.ChainerModel: A SageMaker ``ChainerModel``
177191
object. See :func:`~sagemaker.chainer.model.ChainerModel` for full details.
178192
"""
179-
role = role or self.role
180193
return ChainerModel(
181194
self.model_data,
182-
role,
183-
self.entry_point,
184-
source_dir=self._model_source_dir(),
195+
role or self.role,
196+
entry_point or self.entry_point,
197+
source_dir=(source_dir or self._model_source_dir()),
185198
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
186199
name=self._current_job_name,
187200
container_log_level=self.container_log_level,
@@ -192,7 +205,7 @@ def create_model(
192205
image=self.image_name,
193206
sagemaker_session=self.sagemaker_session,
194207
vpc_config=self.get_vpc_config(vpc_config_override),
195-
dependencies=self.dependencies,
208+
dependencies=(dependencies or self.dependencies),
196209
)
197210

198211
@classmethod

src/sagemaker/estimator.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
1+
# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License"). You
44
# may not use this file except in compliance with the License. A copy of
@@ -1524,6 +1524,7 @@ def transformer(
15241524
role=None,
15251525
model_server_workers=None,
15261526
volume_kms_key=None,
1527+
entry_point=None,
15271528
):
15281529
"""Return a ``Transformer`` that uses a SageMaker Model based on the
15291530
training job. It reuses the SageMaker Session and base job name used by
@@ -1561,11 +1562,19 @@ def transformer(
15611562
worker per vCPU.
15621563
volume_kms_key (str): Optional. KMS key ID for encrypting the volume
15631564
attached to the ML compute instance (default: None).
1565+
entry_point (str): Path (absolute or relative) to the local Python source file which should be executed
1566+
as the entry point to training. If not specified, the training entry point is used.
1567+
1568+
Returns:
1569+
sagemaker.transformer.Transformer: a ``Transformer`` object that can be used to start a
1570+
SageMaker Batch Transform job.
15641571
"""
15651572
role = role or self.role
15661573

15671574
if self.latest_training_job is not None:
1568-
model = self.create_model(role=role, model_server_workers=model_server_workers)
1575+
model = self.create_model(
1576+
role=role, model_server_workers=model_server_workers, entry_point=entry_point
1577+
)
15691578

15701579
container_def = model.prepare_container_def(instance_type)
15711580
model_name = model.name or name_from_image(container_def["Image"])

src/sagemaker/mxnet/estimator.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,13 @@ def _configure_distribution(self, distributions):
135135
self._hyperparameters[self.LAUNCH_PS_ENV_NAME] = enabled
136136

137137
def create_model(
138-
self, model_server_workers=None, role=None, vpc_config_override=VPC_CONFIG_DEFAULT
138+
self,
139+
model_server_workers=None,
140+
role=None,
141+
vpc_config_override=VPC_CONFIG_DEFAULT,
142+
entry_point=None,
143+
source_dir=None,
144+
dependencies=None,
139145
):
140146
"""Create a SageMaker ``MXNetModel`` object that can be deployed to an
141147
``Endpoint``.
@@ -151,17 +157,24 @@ def create_model(
151157
the model. Default: use subnets and security groups from this Estimator.
152158
* 'Subnets' (list[str]): List of subnet ids.
153159
* 'SecurityGroupIds' (list[str]): List of security group ids.
160+
entry_point (str): Path (absolute or relative) to the local Python source file which should be executed
161+
as the entry point to training. If not specified, the training entry point is used.
162+
source_dir (str): Path (absolute or relative) to a directory with any other serving
163+
source code dependencies aside from the entry point file.
164+
If not specified, the model source directory from training is used.
165+
dependencies (list[str]): A list of paths to directories (absolute or relative) with
166+
any additional libraries that will be exported to the container.
167+
If not specified, the dependencies from training are used.
154168
155169
Returns:
156170
sagemaker.mxnet.model.MXNetModel: A SageMaker ``MXNetModel`` object.
157171
See :func:`~sagemaker.mxnet.model.MXNetModel` for full details.
158172
"""
159-
role = role or self.role
160173
return MXNetModel(
161174
self.model_data,
162-
role,
163-
self.entry_point,
164-
source_dir=self._model_source_dir(),
175+
role or self.role,
176+
entry_point or self.entry_point,
177+
source_dir=(source_dir or self._model_source_dir()),
165178
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
166179
name=self._current_job_name,
167180
container_log_level=self.container_log_level,
@@ -172,7 +185,7 @@ def create_model(
172185
model_server_workers=model_server_workers,
173186
sagemaker_session=self.sagemaker_session,
174187
vpc_config=self.get_vpc_config(vpc_config_override),
175-
dependencies=self.dependencies,
188+
dependencies=(dependencies or self.dependencies),
176189
)
177190

178191
@classmethod

src/sagemaker/pytorch/estimator.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,13 @@ def __init__(
108108
self.py_version = py_version
109109

110110
def create_model(
111-
self, model_server_workers=None, role=None, vpc_config_override=VPC_CONFIG_DEFAULT
111+
self,
112+
model_server_workers=None,
113+
role=None,
114+
vpc_config_override=VPC_CONFIG_DEFAULT,
115+
entry_point=None,
116+
source_dir=None,
117+
dependencies=None,
112118
):
113119
"""Create a SageMaker ``PyTorchModel`` object that can be deployed to an
114120
``Endpoint``.
@@ -124,17 +130,24 @@ def create_model(
124130
the model. Default: use subnets and security groups from this Estimator.
125131
* 'Subnets' (list[str]): List of subnet ids.
126132
* 'SecurityGroupIds' (list[str]): List of security group ids.
133+
entry_point (str): Path (absolute or relative) to the local Python source file which should be executed
134+
as the entry point to training. If not specified, the training entry point is used.
135+
source_dir (str): Path (absolute or relative) to a directory with any other serving
136+
source code dependencies aside from the entry point file.
137+
If not specified, the model source directory from training is used.
138+
dependencies (list[str]): A list of paths to directories (absolute or relative) with
139+
any additional libraries that will be exported to the container.
140+
If not specified, the dependencies from training are used.
127141
128142
Returns:
129143
sagemaker.pytorch.model.PyTorchModel: A SageMaker ``PyTorchModel``
130144
object. See :func:`~sagemaker.pytorch.model.PyTorchModel` for full details.
131145
"""
132-
role = role or self.role
133146
return PyTorchModel(
134147
self.model_data,
135-
role,
136-
self.entry_point,
137-
source_dir=self._model_source_dir(),
148+
role or self.role,
149+
entry_point or self.entry_point,
150+
source_dir=(source_dir or self._model_source_dir()),
138151
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
139152
name=self._current_job_name,
140153
container_log_level=self.container_log_level,
@@ -145,7 +158,7 @@ def create_model(
145158
model_server_workers=model_server_workers,
146159
sagemaker_session=self.sagemaker_session,
147160
vpc_config=self.get_vpc_config(vpc_config_override),
148-
dependencies=self.dependencies,
161+
dependencies=(dependencies or self.dependencies),
149162
)
150163

151164
@classmethod

src/sagemaker/tensorflow/estimator.py

Lines changed: 59 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -470,8 +470,12 @@ def create_model(
470470
role=None,
471471
vpc_config_override=VPC_CONFIG_DEFAULT,
472472
endpoint_type=None,
473+
entry_point=None,
474+
source_dir=None,
475+
dependencies=None,
473476
):
474-
"""Create a SageMaker ``TensorFlowModel`` object that can be deployed to an ``Endpoint``.
477+
"""Create a ``Model`` object that can be used for creating SageMaker model entities,
478+
deploying to a SageMaker endpoint, or starting SageMaker Batch Transform jobs.
475479
476480
Args:
477481
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
@@ -482,27 +486,55 @@ def create_model(
482486
Default: use subnets and security groups from this Estimator.
483487
* 'Subnets' (list[str]): List of subnet ids.
484488
* 'SecurityGroupIds' (list[str]): List of security group ids.
485-
endpoint_type: Optional. Selects the software stack used by the inference server.
489+
endpoint_type (str): Optional. Selects the software stack used by the inference server.
486490
If not specified, the model will be configured to use the default
487491
SageMaker model server. If 'tensorflow-serving', the model will be configured to
488492
use the SageMaker Tensorflow Serving container.
493+
entry_point (str): Path (absolute or relative) to the local Python source file which should be executed
494+
as the entry point to training. If not specified and ``endpoint_type`` is 'tensorflow-serving',
495+
no entry point is used. If ``endpoint_type`` is also ``None``, then the training entry point is used.
496+
source_dir (str): Path (absolute or relative) to a directory with any other serving
497+
source code dependencies aside from the entry point file. If not specified and
498+
``endpoint_type`` is 'tensorflow-serving', no source_dir is used. If ``endpoint_type`` is also ``None``,
499+
then the model source directory from training is used.
500+
dependencies (list[str]): A list of paths to directories (absolute or relative) with
501+
any additional libraries that will be exported to the container.
502+
If not specified and ``endpoint_type`` is 'tensorflow-serving', ``dependencies`` is set to ``None``.
503+
If ``endpoint_type`` is also ``None``, then the dependencies from training are used.
489504
490505
Returns:
491-
sagemaker.tensorflow.model.TensorFlowModel: A SageMaker ``TensorFlowModel`` object.
492-
See :func:`~sagemaker.tensorflow.model.TensorFlowModel` for full details.
506+
sagemaker.tensorflow.model.TensorFlowModel or sagemaker.tensorflow.serving.Model: A ``Model`` object.
507+
See :class:`~sagemaker.tensorflow.serving.Model` or :class:`~sagemaker.tensorflow.model.TensorFlowModel`
508+
for full details.
493509
"""
494-
495510
role = role or self.role
511+
496512
if endpoint_type == "tensorflow-serving" or self._script_mode_enabled():
497-
return self._create_tfs_model(role=role, vpc_config_override=vpc_config_override)
513+
return self._create_tfs_model(
514+
role=role,
515+
vpc_config_override=vpc_config_override,
516+
entry_point=entry_point,
517+
source_dir=source_dir,
518+
dependencies=dependencies,
519+
)
498520

499521
return self._create_default_model(
500522
model_server_workers=model_server_workers,
501523
role=role,
502524
vpc_config_override=vpc_config_override,
525+
entry_point=entry_point,
526+
source_dir=source_dir,
527+
dependencies=dependencies,
503528
)
504529

505-
def _create_tfs_model(self, role=None, vpc_config_override=VPC_CONFIG_DEFAULT):
530+
def _create_tfs_model(
531+
self,
532+
role=None,
533+
vpc_config_override=VPC_CONFIG_DEFAULT,
534+
entry_point=None,
535+
source_dir=None,
536+
dependencies=None,
537+
):
506538
"""Placeholder docstring"""
507539
return Model(
508540
model_data=self.model_data,
@@ -513,15 +545,26 @@ def _create_tfs_model(self, role=None, vpc_config_override=VPC_CONFIG_DEFAULT):
513545
framework_version=utils.get_short_version(self.framework_version),
514546
sagemaker_session=self.sagemaker_session,
515547
vpc_config=self.get_vpc_config(vpc_config_override),
548+
entry_point=entry_point,
549+
source_dir=source_dir,
550+
dependencies=dependencies,
516551
)
517552

518-
def _create_default_model(self, model_server_workers, role, vpc_config_override):
553+
def _create_default_model(
554+
self,
555+
model_server_workers,
556+
role,
557+
vpc_config_override,
558+
entry_point=None,
559+
source_dir=None,
560+
dependencies=None,
561+
):
519562
"""Placeholder docstring"""
520563
return TensorFlowModel(
521564
self.model_data,
522565
role,
523-
self.entry_point,
524-
source_dir=self._model_source_dir(),
566+
entry_point or self.entry_point,
567+
source_dir=source_dir or self._model_source_dir(),
525568
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
526569
env={"SAGEMAKER_REQUIREMENTS": self.requirements_file},
527570
image=self.image_name,
@@ -533,7 +576,7 @@ def _create_default_model(self, model_server_workers, role, vpc_config_override)
533576
model_server_workers=model_server_workers,
534577
sagemaker_session=self.sagemaker_session,
535578
vpc_config=self.get_vpc_config(vpc_config_override),
536-
dependencies=self.dependencies,
579+
dependencies=dependencies or self.dependencies,
537580
)
538581

539582
def hyperparameters(self):
@@ -625,6 +668,7 @@ def transformer(
625668
model_server_workers=None,
626669
volume_kms_key=None,
627670
endpoint_type=None,
671+
entry_point=None,
628672
):
629673
"""Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the
630674
SageMaker Session and base job name used by the Estimator.
@@ -656,6 +700,9 @@ def transformer(
656700
SageMaker model server.
657701
If 'tensorflow-serving', the model will be configured to
658702
use the SageMaker Tensorflow Serving container.
703+
entry_point (str): Path (absolute or relative) to the local Python source file which should be executed
704+
as the entry point to training. If not specified and ``endpoint_type`` is 'tensorflow-serving',
705+
no entry point is used. If ``endpoint_type`` is also ``None``, then the training entry point is used.
659706
"""
660707

661708
role = role or self.role
@@ -664,6 +711,7 @@ def transformer(
664711
role=role,
665712
vpc_config_override=VPC_CONFIG_DEFAULT,
666713
endpoint_type=endpoint_type,
714+
entry_point=entry_point,
667715
)
668716
return model.transformer(
669717
instance_count,
Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
1+
# Copyright 2018-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License"). You
44
# may not use this file except in compliance with the License. A copy of
@@ -10,19 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
14-
"""Exports a toy TensorFlow model.
15-
Exports a TensorFlow model to /opt/ml/model/
16-
This graph calculates,
17-
y = a*x + b
18-
where a and b are variables with a=0.5 and b=2.
19-
"""
2013
import json
21-
import shutil
22-
23-
24-
def save_model():
25-
shutil.copytree("/opt/ml/code/123", "/opt/ml/model/123")
2614

2715

2816
def input_handler(data, context):
@@ -36,7 +24,3 @@ def output_handler(data, context):
3624
response_content_type = context.accept_header
3725
prediction = data.content
3826
return prediction, response_content_type
39-
40-
41-
if __name__ == "__main__":
42-
save_model()
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
14+
"""Exports a toy TensorFlow model.
15+
Exports a TensorFlow model to /opt/ml/model/
16+
This graph calculates,
17+
y = a*x + b
18+
where a and b are variables with a=0.5 and b=2.
19+
"""
20+
import shutil
21+
22+
23+
def save_model():
24+
shutil.copytree("/opt/ml/code/123", "/opt/ml/model/123")
25+
26+
27+
if __name__ == "__main__":
28+
save_model()

0 commit comments

Comments
 (0)