|
18 | 18 | import attr
|
19 | 19 |
|
20 | 20 | from sagemaker.estimator import EstimatorBase
|
| 21 | +from sagemaker.model import Model |
| 22 | +from sagemaker.predictor import Predictor |
| 23 | +from sagemaker.transformer import Transformer |
21 | 24 | from sagemaker.workflow.entities import RequestType
|
22 |
| -from sagemaker.workflow.steps import Step |
| 25 | +from sagemaker.workflow.steps import ( |
| 26 | + CreateModelStep, |
| 27 | + Step, |
| 28 | + TransformStep, |
| 29 | +) |
23 | 30 | from sagemaker.workflow._utils import (
|
24 | 31 | _RegisterModelStep,
|
25 | 32 | _RepackModelStep,
|
@@ -114,3 +121,131 @@ def __init__(
|
114 | 121 | )
|
115 | 122 | steps.append(register_model_step)
|
116 | 123 | self.steps = steps
|
| 124 | + |
| 125 | + |
| 126 | +class EstimatorTransformer(StepCollection): |
| 127 | + """Creates a Transformer step collection for workflow. |
| 128 | +
|
| 129 | + Attributes: |
| 130 | + steps (List[Step]): A list of steps. |
| 131 | + """ |
| 132 | + |
| 133 | + def __init__( |
| 134 | + self, |
| 135 | + name: str, |
| 136 | + estimator: EstimatorBase, |
| 137 | + model_data, |
| 138 | + model_inputs, |
| 139 | + instance_count, |
| 140 | + instance_type, |
| 141 | + transform_inputs, |
| 142 | + # model arguments |
| 143 | + image_uri=None, |
| 144 | + predictor_cls=None, |
| 145 | + env=None, |
| 146 | + # transformer arguments |
| 147 | + strategy=None, |
| 148 | + assemble_with=None, |
| 149 | + output_path=None, |
| 150 | + output_kms_key=None, |
| 151 | + accept=None, |
| 152 | + max_concurrent_transforms=None, |
| 153 | + max_payload=None, |
| 154 | + tags=None, |
| 155 | + volume_kms_key=None, |
| 156 | + **kwargs, |
| 157 | + ): |
| 158 | + """Constructs steps required for transformation: |
| 159 | +
|
| 160 | + An estimator-centric step collection, it models what occurs in current workflows |
| 161 | + with invoking the `transform()` method on an estimator instance: first, if custom |
| 162 | + model artifacts are required, a `_RepackModelStep` is included; second, a |
| 163 | + `CreateModelStep` with the model data passed in from a training step or other |
| 164 | + training job output; finally, a `TransformerStep`. |
| 165 | +
|
| 166 | + If repacking |
| 167 | + the model artifacts is not necessary, only the CreateModelStep and TransformerStep |
| 168 | + are in the step collection. |
| 169 | + Args: |
| 170 | + name (str): The name of the Transform Step. |
| 171 | + estimator: The estimator instance. |
| 172 | + instance_count (int): Number of EC2 instances to use. |
| 173 | + instance_type (str): Type of EC2 instance to use, for example, |
| 174 | + 'ml.c4.xlarge'. |
| 175 | + strategy (str): The strategy used to decide how to batch records in |
| 176 | + a single request (default: None). Valid values: 'MultiRecord' |
| 177 | + and 'SingleRecord'. |
| 178 | + assemble_with (str): How the output is assembled (default: None). |
| 179 | + Valid values: 'Line' or 'None'. |
| 180 | + output_path (str): S3 location for saving the transform result. If |
| 181 | + not specified, results are stored to a default bucket. |
| 182 | + output_kms_key (str): Optional. KMS key ID for encrypting the |
| 183 | + transform output (default: None). |
| 184 | + accept (str): The accept header passed by the client to |
| 185 | + the inference endpoint. If it is supported by the endpoint, |
| 186 | + it will be the format of the batch transform output. |
| 187 | + env (dict): Environment variables to be set for use during the |
| 188 | + transform job (default: None). |
| 189 | + """ |
| 190 | + steps = [] |
| 191 | + if "entry_point" in kwargs: |
| 192 | + entry_point = kwargs["entry_point"] |
| 193 | + source_dir = kwargs.get("source_dir") |
| 194 | + dependencies = kwargs.get("dependencies") |
| 195 | + repack_model_step = _RepackModelStep( |
| 196 | + name=f"{name}RepackModel", |
| 197 | + estimator=estimator, |
| 198 | + model_data=model_data, |
| 199 | + entry_point=entry_point, |
| 200 | + source_dir=source_dir, |
| 201 | + dependencies=dependencies, |
| 202 | + ) |
| 203 | + steps.append(repack_model_step) |
| 204 | + model_data = repack_model_step.properties.ModelArtifacts.S3ModelArtifacts |
| 205 | + |
| 206 | + def predict_wrapper(endpoint, session): |
| 207 | + return Predictor(endpoint, session) |
| 208 | + |
| 209 | + predictor_cls = predictor_cls or predict_wrapper |
| 210 | + |
| 211 | + model = Model( |
| 212 | + image_uri=image_uri or estimator.training_image_uri(), |
| 213 | + model_data=model_data, |
| 214 | + predictor_cls=predictor_cls, |
| 215 | + vpc_config=None, |
| 216 | + sagemaker_session=estimator.sagemaker_session, |
| 217 | + role=estimator.role, |
| 218 | + **kwargs, |
| 219 | + ) |
| 220 | + model_step = CreateModelStep( |
| 221 | + name=f"{name}CreateModelStep", |
| 222 | + model=model, |
| 223 | + inputs=model_inputs, |
| 224 | + ) |
| 225 | + steps.append(model_step) |
| 226 | + |
| 227 | + transformer = Transformer( |
| 228 | + model_name=model_step.properties.ModelName, |
| 229 | + instance_count=instance_count, |
| 230 | + instance_type=instance_type, |
| 231 | + strategy=strategy, |
| 232 | + assemble_with=assemble_with, |
| 233 | + output_path=output_path, |
| 234 | + output_kms_key=output_kms_key, |
| 235 | + accept=accept, |
| 236 | + max_concurrent_transforms=max_concurrent_transforms, |
| 237 | + max_payload=max_payload, |
| 238 | + env=env, |
| 239 | + tags=tags, |
| 240 | + base_transform_job_name=name, |
| 241 | + volume_kms_key=volume_kms_key, |
| 242 | + sagemaker_session=estimator.sagemaker_session, |
| 243 | + ) |
| 244 | + transform_step = TransformStep( |
| 245 | + name=f"{name}TransformStep", |
| 246 | + transformer=transformer, |
| 247 | + inputs=transform_inputs, |
| 248 | + ) |
| 249 | + steps.append(transform_step) |
| 250 | + |
| 251 | + self.steps = steps |
0 commit comments