@@ -298,48 +298,57 @@ def test_create_model_with_optional_params(sagemaker_session):
298
298
new_role = "role"
299
299
model_server_workers = 2
300
300
301
- vpc_config = {'Subnets' : ['foo' ], 'SecurityGroupIds' : ['bar' ]}
302
- model = tf .create_model (role = new_role , model_server_workers = model_server_workers ,
303
- vpc_config_override = vpc_config )
301
+ vpc_config = {"Subnets" : ["foo" ], "SecurityGroupIds" : ["bar" ]}
302
+ model = tf .create_model (
303
+ role = new_role , model_server_workers = model_server_workers , vpc_config_override = vpc_config
304
+ )
304
305
305
306
assert model .role == new_role
306
307
assert model .model_server_workers == model_server_workers
307
308
assert model .vpc_config == vpc_config
308
309
309
310
310
- @patch (' sagemaker.tensorflow.estimator.TensorFlow._create_tfs_model' )
311
+ @patch (" sagemaker.tensorflow.estimator.TensorFlow._create_tfs_model" )
311
312
def test_transformer_creation_with_endpoint_type (create_tfs_model , sagemaker_session ):
312
- tf = TensorFlow (entry_point = SCRIPT_PATH , role = ROLE , sagemaker_session = sagemaker_session ,
313
- train_instance_count = INSTANCE_COUNT , train_instance_type = INSTANCE_TYPE )
313
+ tf = TensorFlow (
314
+ entry_point = SCRIPT_PATH ,
315
+ role = ROLE ,
316
+ sagemaker_session = sagemaker_session ,
317
+ train_instance_count = INSTANCE_COUNT ,
318
+ train_instance_type = INSTANCE_TYPE ,
319
+ )
314
320
315
321
tf .latest_training_job = _TrainingJob (sagemaker_session , JOB_NAME )
316
- transformer = tf .transformer (INSTANCE_COUNT , INSTANCE_TYPE , endpoint_type = ' tensorflow-serving' )
322
+ transformer = tf .transformer (INSTANCE_COUNT , INSTANCE_TYPE , endpoint_type = " tensorflow-serving" )
317
323
assert isinstance (transformer , Transformer )
318
- create_tfs_model .assert_called_once ()
319
324
assert transformer .sagemaker_session == sagemaker_session
320
325
assert transformer .instance_count == INSTANCE_COUNT
321
326
assert transformer .instance_type == INSTANCE_TYPE
322
- assert transformer .tags is None
323
327
assert tf .script_mode is True
324
328
assert tf ._script_mode_enabled () is True
329
+ create_tfs_model .assert_called_once ()
325
330
326
331
327
- @patch (' sagemaker.tensorflow.estimator.TensorFlow._create_default_model' )
332
+ @patch (" sagemaker.tensorflow.estimator.TensorFlow._create_default_model" )
328
333
def test_transformer_creation_without_endpoint_type (create_default_model , sagemaker_session ):
329
334
330
- tf = TensorFlow (entry_point = SCRIPT_PATH , role = ROLE , sagemaker_session = sagemaker_session ,
331
- train_instance_count = INSTANCE_COUNT , train_instance_type = INSTANCE_TYPE )
335
+ tf = TensorFlow (
336
+ entry_point = SCRIPT_PATH ,
337
+ role = ROLE ,
338
+ sagemaker_session = sagemaker_session ,
339
+ train_instance_count = INSTANCE_COUNT ,
340
+ train_instance_type = INSTANCE_TYPE ,
341
+ )
332
342
333
343
tf .latest_training_job = _TrainingJob (sagemaker_session , JOB_NAME )
334
344
transformer = tf .transformer (INSTANCE_COUNT , INSTANCE_TYPE )
335
345
assert isinstance (transformer , Transformer )
336
- create_default_model .assert_called_once ()
337
346
assert transformer .sagemaker_session == sagemaker_session
338
347
assert transformer .instance_count == INSTANCE_COUNT
339
348
assert transformer .instance_type == INSTANCE_TYPE
340
- assert transformer .tags is None
341
349
assert tf .script_mode is False
342
350
assert tf ._script_mode_enabled () is False
351
+ create_default_model .assert_called_once ()
343
352
344
353
345
354
def test_create_model_with_custom_image (sagemaker_session ):
0 commit comments