@@ -309,19 +309,13 @@ def test_create_model_with_optional_params(sagemaker_session):
309
309
310
310
@patch ('sagemaker.tensorflow.estimator.TensorFlow._create_tfs_model' )
311
311
def test_transformer_creation_with_endpoint_type (create_tfs_model , sagemaker_session ):
312
- container_log_level = '"logging.INFO"'
313
- source_dir = 's3://mybucket/source'
314
- enable_cloudwatch_metrics = 'true'
315
- base_name = 'foo'
316
312
tf = TensorFlow (entry_point = SCRIPT_PATH , role = ROLE , sagemaker_session = sagemaker_session ,
317
- training_steps = 1000 , evaluation_steps = 10 , train_instance_count = INSTANCE_COUNT ,
318
- train_instance_type = INSTANCE_TYPE , container_log_level = container_log_level , base_job_name = base_name ,
319
- source_dir = source_dir , enable_cloudwatch_metrics = enable_cloudwatch_metrics )
313
+ train_instance_count = INSTANCE_COUNT , train_instance_type = INSTANCE_TYPE )
314
+
320
315
tf .latest_training_job = _TrainingJob (sagemaker_session , JOB_NAME )
321
- assert isinstance (tf , TensorFlow )
322
316
transformer = tf .transformer (INSTANCE_COUNT , INSTANCE_TYPE , endpoint_type = 'tensorflow-serving' )
323
- create_tfs_model .assert_called_once ()
324
317
assert isinstance (transformer , Transformer )
318
+ create_tfs_model .assert_called_once ()
325
319
assert transformer .sagemaker_session == sagemaker_session
326
320
assert transformer .instance_count == INSTANCE_COUNT
327
321
assert transformer .instance_type == INSTANCE_TYPE
@@ -332,19 +326,14 @@ def test_transformer_creation_with_endpoint_type(create_tfs_model, sagemaker_ses
332
326
333
327
@patch ('sagemaker.tensorflow.estimator.TensorFlow._create_default_model' )
334
328
def test_transformer_creation_without_endpoint_type (create_default_model , sagemaker_session ):
335
- container_log_level = '"logging.INFO"'
336
- source_dir = 's3://mybucket/source'
337
- enable_cloudwatch_metrics = 'true'
338
- base_name = 'flo'
329
+
339
330
tf = TensorFlow (entry_point = SCRIPT_PATH , role = ROLE , sagemaker_session = sagemaker_session ,
340
- training_steps = 1000 , evaluation_steps = 10 , train_instance_count = INSTANCE_COUNT ,
341
- train_instance_type = INSTANCE_TYPE , container_log_level = container_log_level , base_job_name = base_name ,
342
- source_dir = source_dir , enable_cloudwatch_metrics = enable_cloudwatch_metrics )
331
+ train_instance_count = INSTANCE_COUNT , train_instance_type = INSTANCE_TYPE )
332
+
343
333
tf .latest_training_job = _TrainingJob (sagemaker_session , JOB_NAME )
344
- assert isinstance (tf , TensorFlow )
345
334
transformer = tf .transformer (INSTANCE_COUNT , INSTANCE_TYPE )
346
- create_default_model .assert_called_once ()
347
335
assert isinstance (transformer , Transformer )
336
+ create_default_model .assert_called_once ()
348
337
assert transformer .sagemaker_session == sagemaker_session
349
338
assert transformer .instance_count == INSTANCE_COUNT
350
339
assert transformer .instance_type == INSTANCE_TYPE
0 commit comments