Skip to content

Commit 2aa256e

Browse files
committed
Merge branch 'aws:master' into master
2 parents a6a8449 + abe8399 commit 2aa256e

33 files changed

+1159
-67
lines changed

CHANGELOG.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,19 @@
11
# Changelog
22

3+
## v2.41.0 (2021-05-17)
4+
5+
### Features
6+
7+
* add pipeline experiment config
8+
* add data wrangler processor
9+
* support RetryStrategy for training jobs
10+
11+
### Bug Fixes and Other Changes
12+
13+
* fix repack pipeline step by putting inference.py in "code" sub dir
14+
* add data wrangler image uri
15+
* fix black-check errors
16+
317
## v2.40.0 (2021-05-11)
418

519
### Features

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.40.1.dev0
1+
2.41.1.dev0

doc/overview.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ Here are examples of how to use Amazon FSx for Lustre as input for training:
374374
375375
file_system_input = FileSystemInput(file_system_id='fs-2',
376376
file_system_type='FSxLustre',
377-
directory_path='/fsx/tensorflow',
377+
directory_path='/<mount-id>/tensorflow',
378378
file_system_access_mode='ro')
379379
380380
# Start an Amazon SageMaker training job with FSx using the FileSystemInput class
@@ -394,7 +394,7 @@ Here are examples of how to use Amazon FSx for Lustre as input for training:
394394
395395
records = FileSystemRecordSet(file_system_id='fs-2',
396396
file_system_type='FSxLustre',
397-
directory_path='/fsx/kmeans',
397+
directory_path='/<mount-id>/kmeans',
398398
num_records=784,
399399
feature_dim=784)
400400

src/sagemaker/estimator.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def __init__(
124124
profiler_config=None,
125125
disable_profiler=False,
126126
environment=None,
127+
max_retry_attempts=None,
127128
**kwargs,
128129
):
129130
"""Initialize an ``EstimatorBase`` instance.
@@ -269,6 +270,13 @@ def __init__(
269270
will be disabled (default: ``False``).
270271
environment (dict[str, str]) : Environment variables to be set for
271272
use during training job (default: ``None``)
273+
max_retry_attempts (int): The number of times to move a job to the STARTING status.
274+
You can specify between 1 and 30 attempts.
275+
If the value of attempts is greater than zero,
276+
the job is retried up to that number of times
277+
when an InternalServerFailure occurs.
278+
You can cap the total duration for your job by setting ``max_wait`` and ``max_run``
279+
(default: ``None``)
272280
273281
"""
274282
instance_count = renamed_kwargs(
@@ -357,6 +365,8 @@ def __init__(
357365

358366
self.environment = environment
359367

368+
self.max_retry_attempts = max_retry_attempts
369+
360370
if not _region_supports_profiler(self.sagemaker_session.boto_region_name):
361371
self.disable_profiler = True
362372

@@ -1114,6 +1124,13 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
11141124
if max_wait:
11151125
init_params["max_wait"] = max_wait
11161126

1127+
if job_details.get("RetryStrategy", False):
1128+
init_params["max_retry_attempts"] = job_details.get("RetryStrategy", {}).get(
1129+
"MaximumRetryAttempts"
1130+
)
1131+
max_wait = job_details.get("StoppingCondition", {}).get("MaxWaitTimeInSeconds")
1132+
if max_wait:
1133+
init_params["max_wait"] = max_wait
11171134
return init_params
11181135

11191136
def transformer(
@@ -1489,6 +1506,11 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
14891506
if estimator.enable_network_isolation():
14901507
train_args["enable_network_isolation"] = True
14911508

1509+
if estimator.max_retry_attempts is not None:
1510+
train_args["retry_strategy"] = {"MaximumRetryAttempts": estimator.max_retry_attempts}
1511+
else:
1512+
train_args["retry_strategy"] = None
1513+
14921514
if estimator.encrypt_inter_container_traffic:
14931515
train_args["encrypt_inter_container_traffic"] = True
14941516

@@ -1666,6 +1688,7 @@ def __init__(
16661688
profiler_config=None,
16671689
disable_profiler=False,
16681690
environment=None,
1691+
max_retry_attempts=None,
16691692
**kwargs,
16701693
):
16711694
"""Initialize an ``Estimator`` instance.
@@ -1816,6 +1839,13 @@ def __init__(
18161839
will be disabled (default: ``False``).
18171840
environment (dict[str, str]) : Environment variables to be set for
18181841
use during training job (default: ``None``)
1842+
max_retry_attempts (int): The number of times to move a job to the STARTING status.
1843+
You can specify between 1 and 30 attempts.
1844+
If the value of attempts is greater than zero,
1845+
the job is retried up to that number of times
1846+
when an InternalServerFailure occurs.
1847+
You can cap the total duration for your job by setting ``max_wait`` and ``max_run``
1848+
(default: ``None``)
18191849
"""
18201850
self.image_uri = image_uri
18211851
self.hyperparam_dict = hyperparameters.copy() if hyperparameters else {}
@@ -1850,6 +1880,7 @@ def __init__(
18501880
profiler_config=profiler_config,
18511881
disable_profiler=disable_profiler,
18521882
environment=environment,
1883+
max_retry_attempts=max_retry_attempts,
18531884
**kwargs,
18541885
)
18551886

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
{
2+
"processing": {
3+
"versions": {
4+
"1.x": {
5+
"registries": {
6+
"af-south-1": "143210264188",
7+
"ap-east-1": "707077482487",
8+
"ap-northeast-1": "649008135260",
9+
"ap-northeast-2": "131546521161",
10+
"ap-south-1": "089933028263",
11+
"ap-southeast-1": "119527597002",
12+
"ap-southeast-2": "422173101802",
13+
"ca-central-1": "557239378090",
14+
"eu-central-1": "024640144536",
15+
"eu-north-1": "054986407534",
16+
"eu-south-1": "488287956546",
17+
"eu-west-1": "245179582081",
18+
"eu-west-2": "894491911112",
19+
"eu-west-3": "807237891255",
20+
"me-south-1": "376037874950",
21+
"sa-east-1": "424196993095",
22+
"us-east-1": "663277389841",
23+
"us-east-2": "415577184552",
24+
"us-west-1": "926135532090",
25+
"us-west-2": "174368400705",
26+
"cn-north-1": "245909111842",
27+
"cn-northwest-1": "249157047649"
28+
},
29+
"repository": "sagemaker-data-wrangler-container"
30+
}
31+
}
32+
}
33+
}

src/sagemaker/session.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,7 @@ def train( # noqa: C901
457457
profiler_rule_configs=None,
458458
profiler_config=None,
459459
environment=None,
460+
retry_strategy=None,
460461
):
461462
"""Create an Amazon SageMaker training job.
462463
@@ -529,6 +530,9 @@ def train( # noqa: C901
529530
with SageMaker Profiler. (default: ``None``).
530531
environment (dict[str, str]) : Environment variables to be set for
531532
use during training job (default: ``None``)
533+
retry_strategy (dict): Defines RetryStrategy for InternalServerFailures.
534+
* max_retry_attempts (int): Number of times a job should be retried.
535+
The key in RetryStrategy is 'MaximumRetryAttempts'.
532536
533537
Returns:
534538
str: ARN of the training job, if it is created.
@@ -561,6 +565,7 @@ def train( # noqa: C901
561565
profiler_rule_configs=profiler_rule_configs,
562566
profiler_config=profiler_config,
563567
environment=environment,
568+
retry_strategy=retry_strategy,
564569
)
565570
LOGGER.info("Creating training-job with name: %s", job_name)
566571
LOGGER.debug("train request: %s", json.dumps(train_request, indent=4))
@@ -594,6 +599,7 @@ def _get_train_request( # noqa: C901
594599
profiler_rule_configs=None,
595600
profiler_config=None,
596601
environment=None,
602+
retry_strategy=None,
597603
):
598604
"""Constructs a request compatible for creating an Amazon SageMaker training job.
599605
@@ -665,6 +671,9 @@ def _get_train_request( # noqa: C901
665671
SageMaker Profiler. (default: ``None``).
666672
environment (dict[str, str]) : Environment variables to be set for
667673
use during training job (default: ``None``)
674+
retry_strategy (dict): Defines RetryStrategy for InternalServerFailures.
675+
* max_retry_attempts (int): Number of times a job should be retried.
676+
The key in RetryStrategy is 'MaximumRetryAttempts'.
668677
669678
Returns:
670679
Dict: a training request dict
@@ -749,6 +758,9 @@ def _get_train_request( # noqa: C901
749758
if profiler_config is not None:
750759
train_request["ProfilerConfig"] = profiler_config
751760

761+
if retry_strategy is not None:
762+
train_request["RetryStrategy"] = retry_strategy
763+
752764
return train_request
753765

754766
def update_training_job(

src/sagemaker/workflow/_repack_model.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@
1919
import tarfile
2020
import tempfile
2121

22+
# Repack Model
23+
# The following script is run via a training job which takes an existing model and a custom
24+
# entry point script as arguments. The script creates a new model archive with the custom
25+
# entry point in the "code" directory along with the existing model. Subsequently, when the model
26+
# is unpacked for inference, the custom entry point will be used.
27+
# Reference: https://docs.aws.amazon.com/sagemaker/latest/dg/amazon-sagemaker-toolkits.html
28+
2229
# distutils.dir_util.copy_tree works way better than the half-baked
2330
# shutil.copytree which bombs on previously existing target dirs...
2431
# alas ... https://bugs.python.org/issue10948
@@ -33,17 +40,28 @@
3340
parser.add_argument("--model_archive", type=str, default="model.tar.gz")
3441
args = parser.parse_args()
3542

43+
# the data directory contains a model archive generated by a previous training job
3644
data_directory = "/opt/ml/input/data/training"
3745
model_path = os.path.join(data_directory, args.model_archive)
3846

47+
# create a temporary directory
3948
with tempfile.TemporaryDirectory() as tmp:
4049
local_path = os.path.join(tmp, "local.tar.gz")
50+
# copy the previous training job's model archive to the temporary directory
4151
shutil.copy2(model_path, local_path)
4252
src_dir = os.path.join(tmp, "src")
53+
# create the "code" directory which will contain the inference script
54+
os.makedirs(os.path.join(src_dir, "code"))
55+
# extract the contents of the previous training job's model archive to the "src"
56+
# directory of this training job
4357
with tarfile.open(name=local_path, mode="r:gz") as tf:
4458
tf.extractall(path=src_dir)
4559

60+
# generate a path to the custom inference script
4661
entry_point = os.path.join("/opt/ml/code", args.inference_script)
47-
shutil.copy2(entry_point, os.path.join(src_dir, args.inference_script))
62+
# copy the custom inference script to the "src" dir
63+
shutil.copy2(entry_point, os.path.join(src_dir, "code", args.inference_script))
4864

65+
# copy the "src" dir, which includes the previous training job's model and the
66+
# custom inference script, to the output of this training job
4967
copy_tree(src_dir, "/opt/ml/model")

src/sagemaker/workflow/conditions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,8 @@ def to_request(self) -> RequestType:
186186
"""Get the request structure for workflow service calls."""
187187
return {
188188
"Type": self.condition_type.value,
189-
"Value": self.value.expr,
190-
"In": [primitive_or_expr(in_value) for in_value in self.in_values],
189+
"QueryValue": self.value.expr,
190+
"Values": [primitive_or_expr(in_value) for in_value in self.in_values],
191191
}
192192

193193

src/sagemaker/workflow/execution_variables.py

Lines changed: 4 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -13,56 +13,27 @@
1313
"""Pipeline parameters and conditions for workflow."""
1414
from __future__ import absolute_import
1515

16-
from typing import Dict
17-
1816
from sagemaker.workflow.entities import (
19-
Entity,
17+
Expression,
2018
RequestType,
2119
)
2220

2321

24-
class ExecutionVariable(Entity, str):
22+
class ExecutionVariable(Expression):
2523
"""Pipeline execution variables for workflow."""
2624

27-
def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument
28-
"""Subclass str"""
29-
value = ""
30-
if len(args) == 1:
31-
value = args[0] or value
32-
elif kwargs:
33-
value = kwargs.get("name", value)
34-
return str.__new__(cls, ExecutionVariable._expr(value))
35-
3625
def __init__(self, name: str):
3726
"""Create a pipeline execution variable.
3827
3928
Args:
4029
name (str): The name of the execution variable.
4130
"""
42-
super(ExecutionVariable, self).__init__()
4331
self.name = name
4432

45-
def __hash__(self):
46-
"""Hash function for execution variable types"""
47-
return hash(tuple(self.to_request()))
48-
49-
def to_request(self) -> RequestType:
50-
"""Get the request structure for workflow service calls."""
51-
return self.expr
52-
5333
@property
54-
def expr(self) -> Dict[str, str]:
34+
def expr(self) -> RequestType:
5535
"""The 'Get' expression dict for an `ExecutionVariable`."""
56-
return ExecutionVariable._expr(self.name)
57-
58-
@classmethod
59-
def _expr(cls, name):
60-
"""An internal classmethod for the 'Get' expression dict for an `ExecutionVariable`.
61-
62-
Args:
63-
name (str): The name of the execution variable.
64-
"""
65-
return {"Get": f"Execution.{name}"}
36+
return {"Get": f"Execution.{self.name}"}
6637

6738

6839
class ExecutionVariables:

0 commit comments

Comments
 (0)