|
23 | 23 | from sagemaker.model import MODEL_SERVER_WORKERS_PARAM_NAME
|
24 | 24 | from sagemaker.session import s3_input
|
25 | 25 | from sagemaker.tensorflow import defaults, TensorFlow, TensorFlowModel, TensorFlowPredictor
|
| 26 | +from sagemaker.estimator import _TrainingJob |
26 | 27 | import sagemaker.tensorflow.estimator as tfe
|
| 28 | +from sagemaker.transformer import Transformer |
27 | 29 |
|
28 | 30 | DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'data')
|
29 | 31 | SCRIPT_FILE = 'dummy_script.py'
|
@@ -264,12 +266,56 @@ def test_create_model_with_optional_params(sagemaker_session):
|
264 | 266 | vpc_config = {'Subnets': ['foo'], 'SecurityGroupIds': ['bar']}
|
265 | 267 | model = tf.create_model(role=new_role, model_server_workers=model_server_workers,
|
266 | 268 | vpc_config_override=vpc_config)
|
267 |
| - |
268 | 269 | assert model.role == new_role
|
269 | 270 | assert model.model_server_workers == model_server_workers
|
270 | 271 | assert model.vpc_config == vpc_config
|
271 | 272 |
|
272 | 273 |
|
| 274 | +@patch('sagemaker.tensorflow.estimator.TensorFlow._create_tfs_model') |
| 275 | +def test_transformer_creation_with_endpoint_type(create_tfs_model, sagemaker_session): |
| 276 | + container_log_level = '"logging.INFO"' |
| 277 | + source_dir = 's3://mybucket/source' |
| 278 | + enable_cloudwatch_metrics = 'true' |
| 279 | + base_name = 'foo' |
| 280 | + tf = TensorFlow(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, |
| 281 | + training_steps=1000, evaluation_steps=10, train_instance_count=INSTANCE_COUNT, |
| 282 | + train_instance_type=INSTANCE_TYPE, container_log_level=container_log_level, base_job_name=base_name, |
| 283 | + source_dir=source_dir, enable_cloudwatch_metrics=enable_cloudwatch_metrics) |
| 284 | + tf.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME) |
| 285 | + assert isinstance(tf, TensorFlow) |
| 286 | + transformer = tf.transformer(INSTANCE_COUNT, INSTANCE_TYPE, endpoint_type='tensorflow-serving') |
| 287 | + create_tfs_model.assert_called_once() |
| 288 | + assert isinstance(transformer, Transformer) |
| 289 | + assert transformer.sagemaker_session == sagemaker_session |
| 290 | + assert transformer.instance_count == INSTANCE_COUNT |
| 291 | + assert transformer.instance_type == INSTANCE_TYPE |
| 292 | + assert transformer.tags is None |
| 293 | + assert tf.script_mode is True |
| 294 | + assert tf._script_mode_enabled() is True |
| 295 | + |
| 296 | +@patch('sagemaker.tensorflow.estimator.TensorFlow._create_default_model') |
| 297 | +def test_transformer_creation_without_endpoint_type(create_default_model, sagemaker_session): |
| 298 | + container_log_level = '"logging.INFO"' |
| 299 | + source_dir = 's3://mybucket/source' |
| 300 | + enable_cloudwatch_metrics = 'true' |
| 301 | + base_name = 'flo' |
| 302 | + tf = TensorFlow(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, |
| 303 | + training_steps=1000, evaluation_steps=10, train_instance_count=INSTANCE_COUNT, |
| 304 | + train_instance_type=INSTANCE_TYPE, container_log_level=container_log_level, base_job_name=base_name, |
| 305 | + source_dir=source_dir, enable_cloudwatch_metrics=enable_cloudwatch_metrics) |
| 306 | + tf.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME) |
| 307 | + assert isinstance(tf, TensorFlow) |
| 308 | + transformer = tf.transformer(INSTANCE_COUNT, INSTANCE_TYPE) |
| 309 | + create_default_model.assert_called_once() |
| 310 | + assert isinstance(transformer, Transformer) |
| 311 | + assert transformer.sagemaker_session == sagemaker_session |
| 312 | + assert transformer.instance_count == INSTANCE_COUNT |
| 313 | + assert transformer.instance_type == INSTANCE_TYPE |
| 314 | + assert transformer.tags is None |
| 315 | + assert tf.script_mode is False |
| 316 | + assert tf._script_mode_enabled() is False |
| 317 | + |
| 318 | + |
273 | 319 | def test_create_model_with_custom_image(sagemaker_session):
|
274 | 320 | container_log_level = '"logging.INFO"'
|
275 | 321 | source_dir = 's3://mybucket/source'
|
|
0 commit comments