Skip to content

Commit 813625f

Browse files
icywang86ruiRui Wang Napieralski
andauthored
change: refactor out batch transform job input generation (#1955)
* change: refactor out batch transform job input generation * Remove unused imports Co-authored-by: Rui Wang Napieralski <[email protected]>
1 parent 21ef053 commit 813625f

File tree

2 files changed

+129
-17
lines changed

2 files changed

+129
-17
lines changed

src/sagemaker/session.py

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2171,7 +2171,7 @@ def stop_tuning_job(self, name):
21712171
)
21722172
raise
21732173

2174-
def transform(
2174+
def _get_transform_request(
21752175
self,
21762176
job_name,
21772177
model_name,
@@ -2187,7 +2187,7 @@ def transform(
21872187
data_processing,
21882188
model_client_config=None,
21892189
):
2190-
"""Create an Amazon SageMaker transform job.
2190+
"""Construct an dict can be used to create an Amazon SageMaker transform job.
21912191
21922192
Args:
21932193
job_name (str): Name of the transform job being created.
@@ -2213,6 +2213,9 @@ def transform(
22132213
model_client_config (dict): A dictionary describing the model configuration for the
22142214
job. Dictionary contains two optional keys,
22152215
'InvocationsTimeoutInSeconds', and 'InvocationsMaxRetries'.
2216+
2217+
Returns:
2218+
Dict: a create transform job request dict
22162219
"""
22172220
transform_request = {
22182221
"TransformJobName": job_name,
@@ -2246,6 +2249,67 @@ def transform(
22462249
if model_client_config and len(model_client_config) > 0:
22472250
transform_request["ModelClientConfig"] = model_client_config
22482251

2252+
return transform_request
2253+
2254+
def transform(
2255+
self,
2256+
job_name,
2257+
model_name,
2258+
strategy,
2259+
max_concurrent_transforms,
2260+
max_payload,
2261+
env,
2262+
input_config,
2263+
output_config,
2264+
resource_config,
2265+
experiment_config,
2266+
tags,
2267+
data_processing,
2268+
model_client_config=None,
2269+
):
2270+
"""Create an Amazon SageMaker transform job.
2271+
2272+
Args:
2273+
job_name (str): Name of the transform job being created.
2274+
model_name (str): Name of the SageMaker model being used for the transform job.
2275+
strategy (str): The strategy used to decide how to batch records in a single request.
2276+
Possible values are 'MultiRecord' and 'SingleRecord'.
2277+
max_concurrent_transforms (int): The maximum number of HTTP requests to be made to
2278+
each individual transform container at one time.
2279+
max_payload (int): Maximum size of the payload in a single HTTP request to the
2280+
container in MB.
2281+
env (dict): Environment variables to be set for use during the transform job.
2282+
input_config (dict): A dictionary describing the input data (and its location) for the
2283+
job.
2284+
output_config (dict): A dictionary describing the output location for the job.
2285+
resource_config (dict): A dictionary describing the resources to complete the job.
2286+
experiment_config (dict): A dictionary describing the experiment configuration for the
2287+
job. Dictionary contains three optional keys,
2288+
'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
2289+
tags (list[dict]): List of tags for labeling a transform job.
2290+
data_processing(dict): A dictionary describing config for combining the input data and
2291+
transformed data. For more, see
2292+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
2293+
model_client_config (dict): A dictionary describing the model configuration for the
2294+
job. Dictionary contains two optional keys,
2295+
'InvocationsTimeoutInSeconds', and 'InvocationsMaxRetries'.
2296+
"""
2297+
transform_request = self._get_transform_request(
2298+
job_name=job_name,
2299+
model_name=model_name,
2300+
strategy=strategy,
2301+
max_concurrent_transforms=max_concurrent_transforms,
2302+
max_payload=max_payload,
2303+
env=env,
2304+
input_config=input_config,
2305+
output_config=output_config,
2306+
resource_config=resource_config,
2307+
experiment_config=experiment_config,
2308+
tags=tags,
2309+
data_processing=data_processing,
2310+
model_client_config=model_client_config,
2311+
)
2312+
22492313
LOGGER.info("Creating transform job with name: %s", job_name)
22502314
LOGGER.debug("Transform request: %s", json.dumps(transform_request, indent=4))
22512315
self.sagemaker_client.create_transform_job(**transform_request)

src/sagemaker/transformer.py

Lines changed: 63 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -363,30 +363,78 @@ def start_new(
363363
experiment_config:
364364
model_client_config:
365365
"""
366+
367+
transform_args = cls._get_transform_args(
368+
transformer,
369+
data,
370+
data_type,
371+
content_type,
372+
compression_type,
373+
split_type,
374+
input_filter,
375+
output_filter,
376+
join_source,
377+
experiment_config,
378+
model_client_config,
379+
)
380+
transformer.sagemaker_session.transform(**transform_args)
381+
382+
return cls(transformer.sagemaker_session, transformer._current_job_name)
383+
384+
@classmethod
385+
def _get_transform_args(
386+
cls,
387+
transformer,
388+
data,
389+
data_type,
390+
content_type,
391+
compression_type,
392+
split_type,
393+
input_filter,
394+
output_filter,
395+
join_source,
396+
experiment_config,
397+
model_client_config,
398+
):
399+
"""
400+
Args:
401+
transformer:
402+
data:
403+
data_type:
404+
content_type:
405+
compression_type:
406+
split_type:
407+
input_filter:
408+
output_filter:
409+
join_source:
410+
experiment_config:
411+
model_client_config:
412+
"""
413+
366414
config = _TransformJob._load_config(
367415
data, data_type, content_type, compression_type, split_type, transformer
368416
)
369417
data_processing = _TransformJob._prepare_data_processing(
370418
input_filter, output_filter, join_source
371419
)
372420

373-
transformer.sagemaker_session.transform(
374-
job_name=transformer._current_job_name,
375-
model_name=transformer.model_name,
376-
strategy=transformer.strategy,
377-
max_concurrent_transforms=transformer.max_concurrent_transforms,
378-
max_payload=transformer.max_payload,
379-
env=transformer.env,
380-
input_config=config["input_config"],
381-
output_config=config["output_config"],
382-
resource_config=config["resource_config"],
383-
experiment_config=experiment_config,
384-
model_client_config=model_client_config,
385-
tags=transformer.tags,
386-
data_processing=data_processing,
421+
transform_args = config.copy()
422+
transform_args.update(
423+
{
424+
"job_name": transformer._current_job_name,
425+
"model_name": transformer.model_name,
426+
"strategy": transformer.strategy,
427+
"max_concurrent_transforms": transformer.max_concurrent_transforms,
428+
"max_payload": transformer.max_payload,
429+
"env": transformer.env,
430+
"experiment_config": experiment_config,
431+
"model_client_config": model_client_config,
432+
"tags": transformer.tags,
433+
"data_processing": data_processing,
434+
}
387435
)
388436

389-
return cls(transformer.sagemaker_session, transformer._current_job_name)
437+
return transform_args
390438

391439
def wait(self, logs=True):
392440
if logs:

0 commit comments

Comments
 (0)