Commit ce6f566

add back transformer() because TFS doesn't accept model_server_workers
1 parent ed158e0 commit ce6f566

File tree

4 files changed (+249, -10 lines)

src/sagemaker/tensorflow/estimator.py

Lines changed: 124 additions & 0 deletions
@@ -24,6 +24,7 @@
 import sagemaker.fw_utils as fw
 from sagemaker.tensorflow import defaults
 from sagemaker.tensorflow.serving import Model
+from sagemaker.transformer import Transformer
 from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT

 logger = logging.getLogger("sagemaker")
@@ -384,3 +385,126 @@ def train_image(self):
            )

        return super(TensorFlow, self).train_image()
+
+    def transformer(
+        self,
+        instance_count,
+        instance_type,
+        strategy=None,
+        assemble_with=None,
+        output_path=None,
+        output_kms_key=None,
+        accept=None,
+        env=None,
+        max_concurrent_transforms=None,
+        max_payload=None,
+        tags=None,
+        role=None,
+        volume_kms_key=None,
+        entry_point=None,
+        vpc_config_override=VPC_CONFIG_DEFAULT,
+        enable_network_isolation=None,
+        model_name=None,
+    ):
+        """Return a ``Transformer`` that uses a SageMaker Model based on the training job. It
+        reuses the SageMaker Session and base job name used by the Estimator.
+
+        Args:
+            instance_count (int): Number of EC2 instances to use.
+            instance_type (str): Type of EC2 instance to use, for example, 'ml.c4.xlarge'.
+            strategy (str): The strategy used to decide how to batch records in a single request
+                (default: None). Valid values: 'MultiRecord' and 'SingleRecord'.
+            assemble_with (str): How the output is assembled (default: None). Valid values: 'Line'
+                or 'None'.
+            output_path (str): S3 location for saving the transform result. If not specified,
+                results are stored to a default bucket.
+            output_kms_key (str): Optional. KMS key ID for encrypting the transform output
+                (default: None).
+            accept (str): The accept header passed by the client to
+                the inference endpoint. If it is supported by the endpoint,
+                it will be the format of the batch transform output.
+            env (dict): Environment variables to be set for use during the transform job
+                (default: None).
+            max_concurrent_transforms (int): The maximum number of HTTP requests to be made to
+                each individual transform container at one time.
+            max_payload (int): Maximum size of the payload in a single HTTP request to the
+                container in MB.
+            tags (list[dict]): List of tags for labeling a transform job. If none specified, then
+                the tags used for the training job are used for the transform job.
+            role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also
+                used during transform jobs. If not specified, the role from the Estimator will be
+                used.
+            volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
+                compute instance (default: None).
+            entry_point (str): Path (absolute or relative) to the local Python source file which
+                should be executed as the entry point to training. If not specified and
+                ``endpoint_type`` is 'tensorflow-serving', no entry point is used. If
+                ``endpoint_type`` is also ``None``, then the training entry point is used.
+            vpc_config_override (dict[str, list[str]]): Optional override for
+                the VpcConfig set on the model.
+                Default: use subnets and security groups from this Estimator.
+
+                * 'Subnets' (list[str]): List of subnet ids.
+                * 'SecurityGroupIds' (list[str]): List of security group ids.
+
+            enable_network_isolation (bool): Specifies whether container will
+                run in network isolation mode. Network isolation mode restricts
+                the container access to outside networks (such as the internet).
+                The container does not make any inbound or outbound network
+                calls. If True, a channel named "code" will be created for any
+                user entry script for inference. Also known as Internet-free mode.
+                If not specified, this setting is taken from the estimator's
+                current configuration.
+            model_name (str): Name to use for creating an Amazon SageMaker
+                model. If not specified, the name of the training job is used.
+        """
+        role = role or self.role
+
+        if self.latest_training_job is None:
+            logging.warning(
+                "No finished training job found associated with this estimator. Please make sure "
+                "this estimator is only used for building workflow config"
+            )
+            return Transformer(
+                model_name or self._current_job_name,
+                instance_count,
+                instance_type,
+                strategy=strategy,
+                assemble_with=assemble_with,
+                output_path=output_path,
+                output_kms_key=output_kms_key,
+                accept=accept,
+                max_concurrent_transforms=max_concurrent_transforms,
+                max_payload=max_payload,
+                env=env or {},
+                tags=tags,
+                base_transform_job_name=self.base_job_name,
+                volume_kms_key=volume_kms_key,
+                sagemaker_session=self.sagemaker_session,
+            )
+
+        if enable_network_isolation is None:
+            enable_network_isolation = self.enable_network_isolation()
+
+        model = self.create_model(
+            role=role,
+            vpc_config_override=vpc_config_override,
+            entry_point=entry_point,
+            enable_network_isolation=enable_network_isolation,
+            name=model_name,
+        )
+
+        return model.transformer(
+            instance_count,
+            instance_type,
+            strategy=strategy,
+            assemble_with=assemble_with,
+            output_path=output_path,
+            output_kms_key=output_kms_key,
+            accept=accept,
+            env=env,
+            max_concurrent_transforms=max_concurrent_transforms,
+            max_payload=max_payload,
+            tags=tags,
+            volume_kms_key=volume_kms_key,
+        )
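A minimal usage sketch of the re-added method (not part of the diff): it assumes a sagemaker.tensorflow.TensorFlow estimator whose training job has already finished, and the bucket names, instance types, and max_payload value are placeholders. Unlike the generic Framework.transformer(), this override never forwards model_server_workers, which the TensorFlow Serving Model does not accept.

    # Sketch only: `estimator` is a fitted sagemaker.tensorflow.TensorFlow estimator;
    # S3 paths and instance types below are illustrative, not values from this commit.
    transformer = estimator.transformer(
        instance_count=1,
        instance_type="ml.c4.xlarge",
        output_path="s3://my-bucket/batch-output",  # hypothetical output location
        max_payload=6,                              # MB per HTTP request
    )
    transformer.transform("s3://my-bucket/batch-input", content_type="application/json")
    transformer.wait()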

src/sagemaker/workflow/airflow.py

Lines changed: 13 additions & 8 deletions
@@ -19,6 +19,7 @@
 import sagemaker
 from sagemaker import fw_utils, job, utils, session, vpc_utils
 from sagemaker.amazon import amazon_estimator
+from sagemaker.tensorflow import TensorFlow


 def prepare_framework(estimator, s3_operations):
@@ -646,15 +647,19 @@ def model_config_from_estimator(
        )
    elif isinstance(estimator, sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase):
        model = estimator.create_model(vpc_config_override=vpc_config_override)
+    elif isinstance(estimator, TensorFlow):
+        model = estimator.create_model(
+            role=role,
+            vpc_config_override=vpc_config_override,
+            entry_point=estimator.entry_point,
+        )
    elif isinstance(estimator, sagemaker.estimator.Framework):
-        model_kwargs = {
-            "role": role,
-            "vpc_config_override": vpc_config_override,
-            "entry_point": estimator.entry_point,
-        }
-        if model_server_workers:
-            model_kwargs["model_server_workers"] = model_server_workers
-        model = estimator.create_model(**model_kwargs)
+        model = estimator.create_model(
+            model_server_workers=model_server_workers,
+            role=role,
+            vpc_config_override=vpc_config_override,
+            entry_point=estimator.entry_point,
+        )
    else:
        raise TypeError(
            "Estimator must be one of sagemaker.estimator.Estimator, sagemaker.estimator.Framework"

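For context, a rough sketch of the Airflow workflow-config path the new TensorFlow branch serves. The keyword names of model_config_from_estimator used here (in particular task_id and task_type) are assumed from the 1.x SDK of this period rather than shown in the diff, and the estimator, task id, and role ARN are placeholders.

    from sagemaker.workflow import airflow as sm_airflow

    # `tf_estimator` is a sagemaker.tensorflow.TensorFlow estimator used earlier in the DAG.
    model_config = sm_airflow.model_config_from_estimator(
        instance_type="ml.c4.xlarge",
        estimator=tf_estimator,
        task_id="tf_training",      # hypothetical id of the Airflow training task
        task_type="training",
        role="arn:aws:iam::123456789012:role/SageMakerRole",
    )
    # With the new elif branch, no model_server_workers kwarg reaches
    # TensorFlow.create_model(), which the TFS-based serving Model cannot accept.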
tests/integ/test_tfs.py

Lines changed: 3 additions & 2 deletions
@@ -13,11 +13,11 @@
 from __future__ import absolute_import

 import tarfile
-import os

 import botocore.exceptions
-import pytest
+import os

+import pytest
 import sagemaker
 import sagemaker.predictor
 import sagemaker.utils
@@ -104,6 +104,7 @@ def tfs_predictor_with_model_and_entry_point_and_dependencies(

    predictor = model.deploy(1, "local", endpoint_name=endpoint_name)
    try:
+
        yield predictor
    finally:
        predictor.delete_endpoint()

tests/unit/test_tf_estimator.py

Lines changed: 109 additions & 0 deletions
@@ -19,6 +19,7 @@
 import pytest
 from mock import patch, Mock, MagicMock

+from sagemaker.estimator import _TrainingJob
 from sagemaker.tensorflow import defaults, serving, TensorFlow

 DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
@@ -477,6 +478,114 @@ def test_attach_wrong_framework(sagemaker_session):
    assert "didn't use image for requested framework" in str(error)


+@patch("sagemaker.tensorflow.estimator.TensorFlow.create_model")
+def test_transformer_creation_with_optional_args(create_model, sagemaker_session):
+    model = Mock()
+    create_model.return_value = model
+
+    tf = TensorFlow(
+        entry_point=SCRIPT_PATH,
+        role=ROLE,
+        sagemaker_session=sagemaker_session,
+        train_instance_count=INSTANCE_COUNT,
+        train_instance_type=INSTANCE_TYPE,
+    )
+    tf.latest_training_job = _TrainingJob(sagemaker_session, "some-job-name")
+
+    strategy = "SingleRecord"
+    assemble_with = "Line"
+    output_path = "s3://{}/batch-output".format(BUCKET_NAME)
+    kms_key = "kms"
+    accept_type = "text/bytes"
+    env = {"foo": "bar"}
+    max_concurrent_transforms = 3
+    max_payload = 100
+    tags = {"Key": "foo", "Value": "bar"}
+    new_role = "role"
+    vpc_config = {"Subnets": ["1234"], "SecurityGroupIds": ["5678"]}
+    model_name = "model-name"
+
+    tf.transformer(
+        INSTANCE_COUNT,
+        INSTANCE_TYPE,
+        strategy=strategy,
+        assemble_with=assemble_with,
+        output_path=output_path,
+        output_kms_key=kms_key,
+        accept=accept_type,
+        env=env,
+        max_concurrent_transforms=max_concurrent_transforms,
+        max_payload=max_payload,
+        tags=tags,
+        role=new_role,
+        volume_kms_key=kms_key,
+        entry_point=SERVING_SCRIPT_FILE,
+        vpc_config_override=vpc_config,
+        enable_network_isolation=True,
+        model_name=model_name,
+    )
+
+    create_model.assert_called_with(
+        role=new_role,
+        vpc_config_override=vpc_config,
+        entry_point=SERVING_SCRIPT_FILE,
+        enable_network_isolation=True,
+        name=model_name,
+    )
+    model.transformer.assert_called_with(
+        INSTANCE_COUNT,
+        INSTANCE_TYPE,
+        accept=accept_type,
+        assemble_with=assemble_with,
+        env=env,
+        max_concurrent_transforms=max_concurrent_transforms,
+        max_payload=max_payload,
+        output_kms_key=kms_key,
+        output_path=output_path,
+        strategy=strategy,
+        tags=tags,
+        volume_kms_key=kms_key,
+    )
+
+
+@patch("sagemaker.tensorflow.estimator.TensorFlow.create_model")
+def test_transformer_creation_without_optional_args(create_model, sagemaker_session):
+    model = Mock()
+    create_model.return_value = model
+
+    tf = TensorFlow(
+        entry_point=SCRIPT_PATH,
+        role=ROLE,
+        sagemaker_session=sagemaker_session,
+        train_instance_count=INSTANCE_COUNT,
+        train_instance_type=INSTANCE_TYPE,
+    )
+    tf.latest_training_job = _TrainingJob(sagemaker_session, "some-job-name")
+    tf.transformer(INSTANCE_COUNT, INSTANCE_TYPE)
+
+    create_model.assert_called_with(
+        role=ROLE,
+        vpc_config_override="VPC_CONFIG_DEFAULT",
+        entry_point=None,
+        enable_network_isolation=False,
+        name=None,
+    )
+    model.transformer.assert_called_with(
+        INSTANCE_COUNT,
+        INSTANCE_TYPE,
+        accept=None,
+        assemble_with=None,
+        env=None,
+        max_concurrent_transforms=None,
+        max_payload=None,
+        output_kms_key=None,
+        output_path=None,
+        strategy=None,
+        tags=None,
+        volume_kms_key=None,
+    )
+
+
 def test_attach_custom_image(sagemaker_session):
    training_image = "1.dkr.ecr.us-west-2.amazonaws.com/tensorflow_with_custom_binary:1.0"
    rjd = {

0 commit comments
