|
24 | 24 | import sagemaker.fw_utils as fw
|
25 | 25 | from sagemaker.tensorflow import defaults
|
26 | 26 | from sagemaker.tensorflow.serving import Model
|
| 27 | +from sagemaker.transformer import Transformer |
27 | 28 | from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
|
28 | 29 |
|
29 | 30 | logger = logging.getLogger("sagemaker")
|
@@ -384,3 +385,126 @@ def train_image(self):
|
384 | 385 | )
|
385 | 386 |
|
386 | 387 | return super(TensorFlow, self).train_image()
|
| 388 | + |
def transformer(
    self,
    instance_count,
    instance_type,
    strategy=None,
    assemble_with=None,
    output_path=None,
    output_kms_key=None,
    accept=None,
    env=None,
    max_concurrent_transforms=None,
    max_payload=None,
    tags=None,
    role=None,
    volume_kms_key=None,
    entry_point=None,
    vpc_config_override=VPC_CONFIG_DEFAULT,
    enable_network_isolation=None,
    model_name=None,
):
    """Return a ``Transformer`` that uses a SageMaker Model based on the training job.

    It reuses the SageMaker Session and base job name used by the Estimator.

    Two paths exist:

    * When ``self.latest_training_job`` is ``None``, a warning is logged and a
      ``Transformer`` is constructed directly, without creating a model. In this path
      ``role``, ``entry_point``, ``vpc_config_override`` and ``enable_network_isolation``
      are not used.
    * Otherwise a model is built with ``self.create_model`` and the result of that
      model's ``transformer`` method is returned.

    Args:
        instance_count (int): Number of EC2 instances to use.
        instance_type (str): Type of EC2 instance to use, for example, 'ml.c4.xlarge'.
        strategy (str): The strategy used to decide how to batch records in a single request
            (default: None). Valid values: 'MultiRecord' and 'SingleRecord'.
        assemble_with (str): How the output is assembled (default: None). Valid values: 'Line'
            or 'None'.
        output_path (str): S3 location for saving the transform result. If not specified,
            results are stored to a default bucket.
        output_kms_key (str): Optional. KMS key ID for encrypting the transform output
            (default: None).
        accept (str): The accept header passed by the client to
            the inference endpoint. If it is supported by the endpoint,
            it will be the format of the batch transform output.
        env (dict): Environment variables to be set for use during the transform job
            (default: None). When no training job exists, ``None`` is replaced by an
            empty dict before being handed to ``Transformer``.
        max_concurrent_transforms (int): The maximum number of HTTP requests to be made to
            each individual transform container at one time.
        max_payload (int): Maximum size of the payload in a single HTTP request to the
            container in MB.
        tags (list[dict]): List of tags for labeling a transform job (default: None).
            NOTE(review): this method forwards ``tags`` unchanged and does not itself fall
            back to the tags of the training job -- confirm whether that fallback is
            expected to happen here.
        role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also
            used during transform jobs. If not specified, the role from the Estimator will be
            used.
        volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
            compute instance (default: None).
        entry_point (str): Path (absolute or relative) to the local Python source file to use
            as the entry point of the model (default: None). The value is forwarded to
            ``create_model``, which decides how ``None`` is resolved.
        vpc_config_override (dict[str, list[str]]): Optional override for
            the VpcConfig set on the model.
            Default: use subnets and security groups from this Estimator.

            * 'Subnets' (list[str]): List of subnet ids.
            * 'SecurityGroupIds' (list[str]): List of security group ids.

        enable_network_isolation (bool): Specifies whether container will
            run in network isolation mode. Network isolation mode restricts
            the container access to outside networks (such as the internet).
            The container does not make any inbound or outbound network
            calls. If True, a channel named "code" will be created for any
            user entry script for inference. Also known as Internet-free mode.
            If not specified, this setting is taken from the estimator's
            current configuration.
        model_name (str): Name to use for creating an Amazon SageMaker
            model. When no training job exists and no name is given,
            ``self._current_job_name`` is used as the model name; otherwise the value
            (possibly ``None``) is passed to ``create_model`` as ``name``.

    Returns:
        A ``Transformer`` built directly when there is no training job, otherwise the
        object returned by the created model's ``transformer`` method.
    """
    role = role or self.role

    if self.latest_training_job is None:
        # Use the module-level logger rather than the root-logger helper
        # ``logging.warning``: the latter implicitly calls ``basicConfig()`` when the
        # root logger has no handlers, which a library should not do on behalf of
        # the application.
        logger.warning(
            "No finished training job found associated with this estimator. Please make sure "
            "this estimator is only used for building workflow config"
        )
        # No model can be created without a training job, so describe the transform
        # job against a model name only.
        return Transformer(
            model_name or self._current_job_name,
            instance_count,
            instance_type,
            strategy=strategy,
            assemble_with=assemble_with,
            output_path=output_path,
            output_kms_key=output_kms_key,
            accept=accept,
            max_concurrent_transforms=max_concurrent_transforms,
            max_payload=max_payload,
            env=env or {},
            tags=tags,
            base_transform_job_name=self.base_job_name,
            volume_kms_key=volume_kms_key,
            sagemaker_session=self.sagemaker_session,
        )

    # Only consult the estimator's configuration when the caller expressed no
    # preference; an explicit ``False`` must be respected.
    if enable_network_isolation is None:
        enable_network_isolation = self.enable_network_isolation()

    model = self.create_model(
        role=role,
        vpc_config_override=vpc_config_override,
        entry_point=entry_point,
        enable_network_isolation=enable_network_isolation,
        name=model_name,
    )

    return model.transformer(
        instance_count,
        instance_type,
        strategy=strategy,
        assemble_with=assemble_with,
        output_path=output_path,
        output_kms_key=output_kms_key,
        accept=accept,
        env=env,
        max_concurrent_transforms=max_concurrent_transforms,
        max_payload=max_payload,
        tags=tags,
        volume_kms_key=volume_kms_key,
    )
0 commit comments