22
22
import time
23
23
24
24
from sagemaker .estimator import Framework
25
- from sagemaker .fw_utils import framework_name_from_image , framework_version_from_tag , empty_framework_version_warning
26
- from sagemaker .utils import get_config_value
27
- from sagemaker .vpc_utils import VPC_CONFIG_DEFAULT
28
-
25
+ from sagemaker .fw_utils import framework_name_from_image , framework_version_from_tag , \
26
+ empty_framework_version_warning
29
27
from sagemaker .tensorflow .defaults import TF_VERSION
30
28
from sagemaker .tensorflow .model import TensorFlowModel
29
+ from sagemaker .tensorflow .serving import Model
30
+ from sagemaker .utils import get_config_value
31
+ from sagemaker .vpc_utils import VPC_CONFIG_DEFAULT
31
32
32
33
logging .basicConfig ()
33
34
LOGGER = logging .getLogger ('sagemaker' )
@@ -103,12 +104,14 @@ def validate_requirements(self):
103
104
EnvironmentError: If at least one requirement is not installed.
104
105
"""
105
106
if not self ._cmd_exists ('tensorboard' ):
106
- raise EnvironmentError ('TensorBoard is not installed in the system. Please install TensorBoard using the'
107
- ' following command: \n pip install tensorboard' )
107
+ raise EnvironmentError (
108
+ 'TensorBoard is not installed in the system. Please install TensorBoard using the'
109
+ ' following command: \n pip install tensorboard' )
108
110
109
111
if not self ._cmd_exists ('aws' ):
110
- raise EnvironmentError ('The AWS CLI is not installed in the system. Please install the AWS CLI using the'
111
- ' following command: \n pip install awscli' )
112
+ raise EnvironmentError (
113
+ 'The AWS CLI is not installed in the system. Please install the AWS CLI using the'
114
+ ' following command: \n pip install awscli' )
112
115
113
116
def create_tensorboard_process (self ):
114
117
"""Create a TensorBoard process.
@@ -125,7 +128,8 @@ def create_tensorboard_process(self):
125
128
126
129
for i in range (100 ):
127
130
p = subprocess .Popen (
128
- ["tensorboard" , "--logdir" , self .logdir , "--host" , "localhost" , "--port" , str (port )],
131
+ ["tensorboard" , "--logdir" , self .logdir , "--host" , "localhost" , "--port" ,
132
+ str (port )],
129
133
stdout = subprocess .PIPE ,
130
134
stderr = subprocess .PIPE
131
135
)
@@ -135,7 +139,8 @@ def create_tensorboard_process(self):
135
139
else :
136
140
return port , p
137
141
138
- raise OSError ('No available ports to start TensorBoard. Attempted all ports between 6006 and 6105' )
142
+ raise OSError (
143
+ 'No available ports to start TensorBoard. Attempted all ports between 6006 and 6105' )
139
144
140
145
def run (self ):
141
146
"""Run TensorBoard process."""
@@ -158,7 +163,8 @@ class TensorFlow(Framework):
158
163
159
164
__framework_name__ = 'tensorflow'
160
165
161
- def __init__ (self , training_steps = None , evaluation_steps = None , checkpoint_path = None , py_version = 'py2' ,
166
+ def __init__ (self , training_steps = None , evaluation_steps = None , checkpoint_path = None ,
167
+ py_version = 'py2' ,
162
168
framework_version = None , requirements_file = '' , image_name = None , ** kwargs ):
163
169
"""Initialize an ``TensorFlow`` estimator.
164
170
Args:
@@ -202,7 +208,8 @@ def _validate_requirements_file(self, requirements_file):
202
208
raise ValueError ('Must specify source_dir along with a requirements file.' )
203
209
204
210
if os .path .isabs (requirements_file ):
205
- raise ValueError ('Requirements file {} is not a path relative to source_dir.' .format (requirements_file ))
211
+ raise ValueError ('Requirements file {} is not a path relative to source_dir.' .format (
212
+ requirements_file ))
206
213
207
214
if not os .path .exists (os .path .join (self .source_dir , requirements_file )):
208
215
raise ValueError ('Requirements file {} does not exist.' .format (requirements_file ))
@@ -231,6 +238,7 @@ def fit(self, inputs=None, wait=True, logs=True, job_name=None, run_tensorboard_
231
238
downloaded checkpoint information (default: False). This is an experimental feature, and requires
232
239
TensorBoard and AWS CLI to be installed. It terminates TensorBoard when execution ends.
233
240
"""
241
+
234
242
def fit_super ():
235
243
super (TensorFlow , self ).fit (inputs , wait , logs , job_name )
236
244
@@ -263,7 +271,8 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
263
271
dictionary: The transformed init_params
264
272
265
273
"""
266
- init_params = super (TensorFlow , cls )._prepare_init_params_from_job_description (job_details , model_channel_name )
274
+ init_params = super (TensorFlow , cls )._prepare_init_params_from_job_description (job_details ,
275
+ model_channel_name )
267
276
268
277
# Move some of the tensorflow specific init params from hyperparameters into the main init params.
269
278
for argument in ['checkpoint_path' , 'training_steps' , 'evaluation_steps' ]:
@@ -285,15 +294,18 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
285
294
# containing framework version, device type and python version (e.g. '1.5-gpu-py2').
286
295
# For backward compatibility map deprecated image tag '1.0' to a '1.4' framework version
287
296
# otherwise extract framework version from the tag itself.
288
- init_params ['framework_version' ] = '1.4' if tag == '1.0' else framework_version_from_tag (tag )
297
+ init_params ['framework_version' ] = '1.4' if tag == '1.0' else framework_version_from_tag (
298
+ tag )
289
299
290
300
training_job_name = init_params ['base_job_name' ]
291
301
if framework != cls .__framework_name__ :
292
- raise ValueError ("Training job: {} didn't use image for requested framework" .format (training_job_name ))
302
+ raise ValueError ("Training job: {} didn't use image for requested framework" .format (
303
+ training_job_name ))
293
304
294
305
return init_params
295
306
296
- def create_model (self , model_server_workers = None , role = None , vpc_config_override = VPC_CONFIG_DEFAULT ):
307
+ def create_model (self , model_server_workers = None , role = None ,
308
+ vpc_config_override = VPC_CONFIG_DEFAULT , endpoint_type = None ):
297
309
"""Create a SageMaker ``TensorFlowModel`` object that can be deployed to an ``Endpoint``.
298
310
299
311
Args:
@@ -305,18 +317,44 @@ def create_model(self, model_server_workers=None, role=None, vpc_config_override
305
317
Default: use subnets and security groups from this Estimator.
306
318
* 'Subnets' (list[str]): List of subnet ids.
307
319
* 'SecurityGroupIds' (list[str]): List of security group ids.
320
+ endpoint_type: Optional. Selects the software stack used by the inference server.
321
+ If not specified, the model will be configured to use the default
322
+ SageMaker model server. If 'tensorflow-serving', the model will be configured to
323
+ use the SageMaker Tensorflow Serving container.
308
324
309
325
Returns:
310
326
sagemaker.tensorflow.model.TensorFlowModel: A SageMaker ``TensorFlowModel`` object.
311
327
See :func:`~sagemaker.tensorflow.model.TensorFlowModel` for full details.
312
328
"""
313
- env = { 'SAGEMAKER_REQUIREMENTS' : self . requirements_file }
329
+
314
330
role = role or self .role
315
- return TensorFlowModel (self .model_data , role , self .entry_point , source_dir = self ._model_source_dir (),
316
- enable_cloudwatch_metrics = self .enable_cloudwatch_metrics , env = env , image = self .image_name ,
317
- name = self ._current_job_name , container_log_level = self .container_log_level ,
331
+ if endpoint_type == 'tensorflow-serving' :
332
+ return self ._create_tfs_model (role = role , vpc_config_override = vpc_config_override )
333
+
334
+ return self ._create_default_model (model_server_workers = model_server_workers , role = role ,
335
+ vpc_config_override = vpc_config_override )
336
+
337
+ def _create_tfs_model (self , role = None , vpc_config_override = VPC_CONFIG_DEFAULT ):
338
+ return Model (model_data = self .model_data ,
339
+ role = role ,
340
+ image = self .image_name ,
341
+ name = self ._current_job_name ,
342
+ container_log_level = self .container_log_level ,
343
+ framework_version = self .framework_version ,
344
+ sagemaker_session = self .sagemaker_session ,
345
+ vpc_config = self .get_vpc_config (vpc_config_override ))
346
+
347
+ def _create_default_model (self , model_server_workers , role , vpc_config_override ):
348
+ return TensorFlowModel (self .model_data , role , self .entry_point ,
349
+ source_dir = self ._model_source_dir (),
350
+ enable_cloudwatch_metrics = self .enable_cloudwatch_metrics ,
351
+ env = {'SAGEMAKER_REQUIREMENTS' : self .requirements_file },
352
+ image = self .image_name ,
353
+ name = self ._current_job_name ,
354
+ container_log_level = self .container_log_level ,
318
355
code_location = self .code_location , py_version = self .py_version ,
319
- framework_version = self .framework_version , model_server_workers = model_server_workers ,
356
+ framework_version = self .framework_version ,
357
+ model_server_workers = model_server_workers ,
320
358
sagemaker_session = self .sagemaker_session ,
321
359
vpc_config = self .get_vpc_config (vpc_config_override ))
322
360
0 commit comments