Skip to content

Commit 24aad93

Browse files
metrizable authored and Dan Choi committed
feature: include estimator-centric transformer step (aws#467)
1 parent d8b4042 commit 24aad93

File tree

6 files changed

+196
-5
lines changed

6 files changed

+196
-5
lines changed

.pylintrc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ ignored-modules=distutils
314314
# List of class names for which member attributes should not be checked (useful
315315
# for classes with dynamically set attributes). This supports the use of
316316
# qualified names.
317-
ignored-classes=optparse.Values,thread._local,_thread._local,matplotlib.cm,tensorflow.python,tensorflow,tensorflow.train.Example,RunOptions
317+
ignored-classes=optparse.Values,thread._local,_thread._local,matplotlib.cm,tensorflow.python,tensorflow,tensorflow.train.Example,RunOptions,sagemaker.workflow.properties.Properties
318318

319319
# List of members which are set dynamically and missed by pylint inference
320320
# system, and so shouldn't trigger E1101 when accessed. Python regular

src/sagemaker/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
logger = logging.getLogger(__name__)
6464

6565

66-
class EstimatorBase(with_metaclass(ABCMeta, object)):
66+
class EstimatorBase(with_metaclass(ABCMeta, object)): # pylint: disable=too-many-public-methods
6767
"""Handle end-to-end Amazon SageMaker training and deployment tasks.
6868
6969
For introduction to model training and deployment, see

src/sagemaker/workflow/_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def __init__(
9898
"model_archive": self._model_archive,
9999
},
100100
)
101+
repacker.disable_profiler = True
101102
inputs = TrainingInput(self._model_prefix)
102103

103104
# super!

src/sagemaker/workflow/step_collections.py

Lines changed: 136 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,15 @@
1818
import attr
1919

2020
from sagemaker.estimator import EstimatorBase
21+
from sagemaker.model import Model
22+
from sagemaker.predictor import Predictor
23+
from sagemaker.transformer import Transformer
2124
from sagemaker.workflow.entities import RequestType
22-
from sagemaker.workflow.steps import Step
25+
from sagemaker.workflow.steps import (
26+
CreateModelStep,
27+
Step,
28+
TransformStep,
29+
)
2330
from sagemaker.workflow._utils import (
2431
_RegisterModelStep,
2532
_RepackModelStep,
@@ -114,3 +121,131 @@ def __init__(
114121
)
115122
steps.append(register_model_step)
116123
self.steps = steps
124+
125+
126+
class EstimatorTransformer(StepCollection):
    """Step collection that runs a batch transform from an estimator's model.

    Attributes:
        steps (List[Step]): The workflow steps assembled by the constructor,
            in the order they were created.
    """

    def __init__(
        self,
        name: str,
        estimator: EstimatorBase,
        model_data,
        model_inputs,
        instance_count,
        instance_type,
        transform_inputs,
        # model arguments
        image_uri=None,
        predictor_cls=None,
        env=None,
        # transformer arguments
        strategy=None,
        assemble_with=None,
        output_path=None,
        output_kms_key=None,
        accept=None,
        max_concurrent_transforms=None,
        max_payload=None,
        tags=None,
        volume_kms_key=None,
        **kwargs,
    ):
        """Build the steps needed to run a transform job for an estimator.

        This mirrors what happens when `transform()` is invoked on an estimator
        instance. The collection contains, in order:

        1. a `_RepackModelStep`, only when ``entry_point`` is supplied in
           ``kwargs`` (i.e. custom model artifacts are required);
        2. a `CreateModelStep` built from the model data (the repacked
           artifacts when step 1 is present);
        3. a `TransformStep` that runs against the model created in step 2.

        Args:
            name (str): Base name for the collection. Each step name is this
                value plus a suffix (``RepackModel``, ``CreateModelStep``,
                ``TransformStep``); it is also used as the transformer's
                ``base_transform_job_name``.
            estimator: The estimator instance. Supplies the session, the role
                and, when ``image_uri`` is not given, the training image URI.
            model_data: Model artifacts passed in from a training step or
                other training job output.
            model_inputs: Passed as ``inputs`` to the `CreateModelStep`.
            instance_count (int): Number of EC2 instances to use.
            instance_type (str): Type of EC2 instance to use, for example,
                'ml.c4.xlarge'.
            transform_inputs: Passed as ``inputs`` to the `TransformStep`.
            image_uri (str): Image for the model (default: None). Falls back
                to ``estimator.training_image_uri()`` when falsy.
            predictor_cls (callable): Passed to the `Model` (default: None).
                When falsy, a function returning
                ``Predictor(endpoint, session)`` is used.
            env (dict): Environment variables to be set for use during the
                transform job (default: None).
            strategy (str): The strategy used to decide how to batch records in
                a single request (default: None). Valid values: 'MultiRecord'
                and 'SingleRecord'.
            assemble_with (str): How the output is assembled (default: None).
                Valid values: 'Line' or 'None'.
            output_path (str): S3 location for saving the transform result. If
                not specified, results are stored to a default bucket.
            output_kms_key (str): Optional. KMS key ID for encrypting the
                transform output (default: None).
            accept (str): The accept header passed by the client to
                the inference endpoint. If it is supported by the endpoint,
                it will be the format of the batch transform output.
            max_concurrent_transforms: Forwarded unchanged to the
                `Transformer` (default: None).
            max_payload: Forwarded unchanged to the `Transformer`
                (default: None).
            tags: Forwarded unchanged to the `Transformer` (default: None).
            volume_kms_key: Forwarded unchanged to the `Transformer`
                (default: None).
            **kwargs: Extra keyword arguments forwarded to the `Model`.
                ``entry_point`` triggers the repack step, which also reads the
                optional ``source_dir`` and ``dependencies`` entries.
        """
        repack_steps = []
        if "entry_point" in kwargs:
            repack_step = _RepackModelStep(
                name=f"{name}RepackModel",
                estimator=estimator,
                model_data=model_data,
                entry_point=kwargs["entry_point"],
                source_dir=kwargs.get("source_dir"),
                dependencies=kwargs.get("dependencies"),
            )
            repack_steps.append(repack_step)
            # Everything downstream consumes the repacked archive, not the raw one.
            model_data = repack_step.properties.ModelArtifacts.S3ModelArtifacts

        if not predictor_cls:

            def _default_predictor(endpoint, session):
                return Predictor(endpoint, session)

            predictor_cls = _default_predictor

        model = Model(
            image_uri=image_uri or estimator.training_image_uri(),
            model_data=model_data,
            predictor_cls=predictor_cls,
            vpc_config=None,
            sagemaker_session=estimator.sagemaker_session,
            role=estimator.role,
            **kwargs,
        )
        create_model_step = CreateModelStep(
            name=f"{name}CreateModelStep",
            model=model,
            inputs=model_inputs,
        )

        # Optional settings that are handed to the Transformer untouched.
        transformer_options = {
            "strategy": strategy,
            "assemble_with": assemble_with,
            "output_path": output_path,
            "output_kms_key": output_kms_key,
            "accept": accept,
            "max_concurrent_transforms": max_concurrent_transforms,
            "max_payload": max_payload,
            "env": env,
            "tags": tags,
            "volume_kms_key": volume_kms_key,
        }
        transformer = Transformer(
            # The model name is only known at execution time, so reference the
            # create-model step's property rather than a literal name.
            model_name=create_model_step.properties.ModelName,
            instance_count=instance_count,
            instance_type=instance_type,
            base_transform_job_name=name,
            sagemaker_session=estimator.sagemaker_session,
            **transformer_options,
        )
        transform_step = TransformStep(
            name=f"{name}TransformStep",
            transformer=transformer,
            inputs=transform_inputs,
        )

        self.steps = repack_steps + [create_model_step, transform_step]

tests/unit/sagemaker/workflow/test_step_collections.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,14 @@
2323
)
2424

2525
from sagemaker.estimator import Estimator
26+
from sagemaker.inputs import CreateModelInput, TransformInput
2627
from sagemaker.workflow.properties import Properties
2728
from sagemaker.workflow.steps import (
2829
Step,
2930
StepTypeEnum,
3031
)
3132
from sagemaker.workflow.step_collections import (
33+
EstimatorTransformer,
3234
StepCollection,
3335
RegisterModel,
3436
)
@@ -101,7 +103,7 @@ def estimator(sagemaker_session):
101103
image_uri=IMAGE_URI,
102104
role=ROLE,
103105
instance_count=1,
104-
instance_type="c4.4xlarge",
106+
instance_type="ml.c4.4xlarge",
105107
sagemaker_session=sagemaker_session,
106108
)
107109

@@ -145,3 +147,56 @@ def test_register_model(estimator):
145147
},
146148
]
147149
)
150+
151+
152+
def test_estimator_transformer(estimator):
    """Check the request dicts produced by an EstimatorTransformer without repacking.

    No ``entry_point`` is supplied, so the collection must contain exactly one
    CreateModel step and one Transform step.
    """
    model_data = f"s3://{BUCKET}/model.tar.gz"
    transform_data = f"s3://{BUCKET}/transform_manifest"
    model_inputs = CreateModelInput(
        # SageMaker instance types carry the "ml." prefix, matching the
        # estimator fixture and the transformer below.
        instance_type="ml.c4.4xlarge",
        accelerator_type="ml.eia1.medium",
    )
    transform_inputs = TransformInput(data=transform_data)
    estimator_transformer = EstimatorTransformer(
        name="EstimatorTransformerStep",
        estimator=estimator,
        model_data=model_data,
        model_inputs=model_inputs,
        instance_count=1,
        instance_type="ml.c4.4xlarge",
        transform_inputs=transform_inputs,
    )
    request_dicts = estimator_transformer.request_dicts()
    assert len(request_dicts) == 2
    # Guard against two steps of the same type slipping through the loop below.
    assert sorted(request_dict["Type"] for request_dict in request_dicts) == [
        "CreateModel",
        "Transform",
    ]
    for request_dict in request_dicts:
        if request_dict["Type"] == "CreateModel":
            assert request_dict == {
                "Name": "EstimatorTransformerStepCreateModelStep",
                "Type": "CreateModel",
                "Arguments": {
                    "ExecutionRoleArn": "DummyRole",
                    "PrimaryContainer": {
                        "Environment": {},
                        "Image": "fakeimage",
                        # Derived from the input so the expectation cannot
                        # drift from the BUCKET constant.
                        "ModelDataUrl": model_data,
                    },
                },
            }
        elif request_dict["Type"] == "Transform":
            assert request_dict["Name"] == "EstimatorTransformerStepTransformStep"
            arguments = request_dict["Arguments"]
            # ModelName is a reference to the create-model step's property, so
            # it is type-checked and removed before the literal comparison.
            assert isinstance(arguments["ModelName"], Properties)
            arguments.pop("ModelName")
            assert arguments == {
                "TransformInput": {
                    "DataSource": {
                        "S3DataSource": {
                            "S3DataType": "S3Prefix",
                            "S3Uri": transform_data,
                        }
                    }
                },
                "TransformOutput": {"S3OutputPath": None},
                "TransformResources": {"InstanceCount": 1, "InstanceType": "ml.c4.4xlarge"},
            }
        else:
            raise AssertionError("A step exists in the collection of an invalid type.")

tox.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ commands =
147147
[testenv:docstyle]
148148
deps = pydocstyle
149149
commands =
150-
pydocstyle src/sagemaker/{posargs}
150+
pydocstyle src/sagemaker
151151

152152
[testenv:collect-tests]
153153
# this needs to succeed for tests to display in some IDEs

0 commit comments

Comments
 (0)