@@ -593,3 +593,75 @@ def train_image(self):
593
593
)
594
594
595
595
return super (TensorFlow , self ).train_image ()
596
+
597
+ def transformer (
598
+ self ,
599
+ instance_count ,
600
+ instance_type ,
601
+ strategy = None ,
602
+ assemble_with = None ,
603
+ output_path = None ,
604
+ output_kms_key = None ,
605
+ accept = None ,
606
+ env = None ,
607
+ max_concurrent_transforms = None ,
608
+ max_payload = None ,
609
+ tags = None ,
610
+ role = None ,
611
+ model_server_workers = None ,
612
+ volume_kms_key = None ,
613
+ endpoint_type = None ,
614
+ ):
615
+ """Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the
616
+ SageMaker Session and base job name used by the Estimator.
617
+
618
+ Args:
619
+ instance_count (int): Number of EC2 instances to use.
620
+ instance_type (str): Type of EC2 instance to use, for example, 'ml.c4.xlarge'.
621
+ strategy (str): The strategy used to decide how to batch records in a single request (default: None).
622
+ Valid values: 'MULTI_RECORD' and 'SINGLE_RECORD'.
623
+ assemble_with (str): How the output is assembled (default: None). Valid values: 'Line' or 'None'.
624
+ output_path (str): S3 location for saving the transform result. If not specified, results are stored to
625
+ a default bucket.
626
+ output_kms_key (str): Optional. KMS key ID for encrypting the transform output (default: None).
627
+ accept (str): The content type accepted by the endpoint deployed during the transform job.
628
+ env (dict): Environment variables to be set for use during the transform job (default: None).
629
+ max_concurrent_transforms (int): The maximum number of HTTP requests to be made to
630
+ each individual transform container at one time.
631
+ max_payload (int): Maximum size of the payload in a single HTTP request to the container in MB.
632
+ tags (list[dict]): List of tags for labeling a transform job. If none specified, then the tags used for
633
+ the training job are used for the transform job.
634
+ role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
635
+ transform jobs. If not specified, the role from the Estimator will be used.
636
+ model_server_workers (int): Optional. The number of worker processes used by the inference server.
637
+ If None, server will use one worker per vCPU.
638
+ volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
639
+ compute instance (default: None).
640
+ endpoint_type (str): Optional. Selects the software stack used by the inference server.
641
+ If not specified, the model will be configured to use the default
642
+ SageMaker model server.
643
+ If 'tensorflow-serving', the model will be configured to
644
+ use the SageMaker Tensorflow Serving container.
645
+ """
646
+
647
+ role = role or self .role
648
+ model = self .create_model (
649
+ model_server_workers = model_server_workers ,
650
+ role = role ,
651
+ vpc_config_override = VPC_CONFIG_DEFAULT ,
652
+ endpoint_type = endpoint_type ,
653
+ )
654
+ return model .transformer (
655
+ instance_count ,
656
+ instance_type ,
657
+ strategy = strategy ,
658
+ assemble_with = assemble_with ,
659
+ output_path = output_path ,
660
+ output_kms_key = output_kms_key ,
661
+ accept = accept ,
662
+ env = env ,
663
+ max_concurrent_transforms = max_concurrent_transforms ,
664
+ max_payload = max_payload ,
665
+ tags = tags ,
666
+ volume_kms_key = volume_kms_key ,
667
+ )
0 commit comments