@@ -308,29 +308,26 @@ def test_create_model_with_optional_params(sagemaker_session):
308
308
assert model .vpc_config == vpc_config
309
309
310
310
311
- @patch ("sagemaker.tensorflow.estimator.TensorFlow._create_tfs_model " )
312
- def test_transformer_creation_with_endpoint_type (create_tfs_model , sagemaker_session ):
311
+ @patch ("sagemaker.tensorflow.estimator.TensorFlow.create_model " )
312
+ def test_transformer_creation_with_endpoint_type (create_model , sagemaker_session ):
313
313
tf = TensorFlow (
314
314
entry_point = SCRIPT_PATH ,
315
315
role = ROLE ,
316
316
sagemaker_session = sagemaker_session ,
317
317
train_instance_count = INSTANCE_COUNT ,
318
318
train_instance_type = INSTANCE_TYPE ,
319
319
)
320
-
321
320
tf .latest_training_job = _TrainingJob (sagemaker_session , JOB_NAME )
322
- transformer = tf .transformer (INSTANCE_COUNT , INSTANCE_TYPE , endpoint_type = "tensorflow-serving" )
323
- assert isinstance (transformer , Transformer )
324
- assert transformer .sagemaker_session == sagemaker_session
325
- assert transformer .instance_count == INSTANCE_COUNT
326
- assert transformer .instance_type == INSTANCE_TYPE
327
- assert tf .script_mode is True
328
- assert tf ._script_mode_enabled () is True
329
- create_tfs_model .assert_called_once ()
321
+
322
+ tf .transformer (INSTANCE_COUNT , INSTANCE_TYPE , model_server_workers = 2 , endpoint_type = "tensorflow-serving" )
323
+ create_model .assert_called_with (endpoint_type = 'tensorflow-serving' ,
324
+ model_server_workers = 2 ,
325
+ role = 'Dummy' ,
326
+ vpc_config_override = 'VPC_CONFIG_DEFAULT' )
330
327
331
328
332
- @patch ("sagemaker.tensorflow.estimator.TensorFlow._create_default_model " )
333
- def test_transformer_creation_without_endpoint_type (create_default_model , sagemaker_session ):
329
+ @patch ("sagemaker.tensorflow.estimator.TensorFlow.create_model " )
330
+ def test_transformer_creation_without_endpoint_type (create_model , sagemaker_session ):
334
331
335
332
tf = TensorFlow (
336
333
entry_point = SCRIPT_PATH ,
@@ -341,14 +338,11 @@ def test_transformer_creation_without_endpoint_type(create_default_model, sagema
341
338
)
342
339
343
340
tf .latest_training_job = _TrainingJob (sagemaker_session , JOB_NAME )
344
- transformer = tf .transformer (INSTANCE_COUNT , INSTANCE_TYPE )
345
- assert isinstance (transformer , Transformer )
346
- assert transformer .sagemaker_session == sagemaker_session
347
- assert transformer .instance_count == INSTANCE_COUNT
348
- assert transformer .instance_type == INSTANCE_TYPE
349
- assert tf .script_mode is False
350
- assert tf ._script_mode_enabled () is False
351
- create_default_model .assert_called_once ()
341
+ transformer = tf .transformer (INSTANCE_COUNT , INSTANCE_TYPE , model_server_workers = 4 )
342
+ create_model .assert_called_with (endpoint_type = None ,
343
+ model_server_workers = 4 ,
344
+ role = 'Dummy' ,
345
+ vpc_config_override = 'VPC_CONFIG_DEFAULT' )
352
346
353
347
354
348
def test_create_model_with_custom_image (sagemaker_session ):
0 commit comments