
Commit b9c6b5f

Merge branch 'master' into clarify-acc

2 parents 7933173 + 6adb716

29 files changed: +461 −28 lines

CHANGELOG.md

Lines changed: 31 additions & 0 deletions
@@ -1,5 +1,36 @@
 # Changelog
 
+## v2.33.0 (2021-04-05)
+
+### Features
+
+ * Add environment variable support for SageMaker training job
+
+### Bug Fixes and Other Changes
+
+ * add version length mismatch validation for HuggingFace
+ * Disable debugger when checkpointing is enabled with distributed training
+ * map user context in list associations response
+
+### Testing and Release Infrastructure
+
+ * disable_profiler on mx-horovod test
+
+## v2.32.1 (2021-04-01)
+
+### Bug Fixes and Other Changes
+
+ * disable profiler in some release tests
+ * remove outdated notebook from test
+ * add compilation option for ml_eia2
+ * add short version to smdataparallel supported list
+
+### Documentation Changes
+
+ * creating a "latest" version of the sm distributed docs
+ * add docs for SageMaker Model Parallel 1.3, released with PT 1.8
+ * update PyTorch version in doc
+
 ## v2.32.0 (2021-03-26)
 
 ### Features

VERSION

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-2.32.1.dev0
+2.33.1.dev0

src/sagemaker/estimator.py

Lines changed: 25 additions & 1 deletion
@@ -123,6 +123,7 @@ def __init__(
         enable_network_isolation=False,
         profiler_config=None,
         disable_profiler=False,
+        environment=None,
         **kwargs,
     ):
         """Initialize an ``EstimatorBase`` instance.
@@ -266,6 +267,8 @@ def __init__(
                 ``disable_profiler`` parameter to ``True``.
             disable_profiler (bool): Specifies whether Debugger monitoring and profiling
                 will be disabled (default: ``False``).
+            environment (dict[str, str]) : Environment variables to be set for
+                use during training job (default: ``None``)
 
         """
@@ -352,6 +355,8 @@ def __init__(
         self.profiler_config = profiler_config
         self.disable_profiler = disable_profiler
 
+        self.environment = environment
+
         if not _region_supports_profiler(self.sagemaker_session.boto_region_name):
             self.disable_profiler = True
 
@@ -1471,6 +1476,7 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
         train_args["tags"] = estimator.tags
         train_args["metric_definitions"] = estimator.metric_definitions
         train_args["experiment_config"] = experiment_config
+        train_args["environment"] = estimator.environment
 
         if isinstance(inputs, TrainingInput):
             if "InputMode" in inputs.config:
@@ -1659,6 +1665,7 @@ def __init__(
         enable_sagemaker_metrics=None,
         profiler_config=None,
         disable_profiler=False,
+        environment=None,
         **kwargs,
     ):
         """Initialize an ``Estimator`` instance.
@@ -1807,6 +1814,8 @@ def __init__(
                 ``disable_profiler`` parameter to ``True``.
             disable_profiler (bool): Specifies whether Debugger monitoring and profiling
                 will be disabled (default: ``False``).
+            environment (dict[str, str]) : Environment variables to be set for
+                use during training job (default: ``None``)
         """
         self.image_uri = image_uri
         self.hyperparam_dict = hyperparameters.copy() if hyperparameters else {}
@@ -1840,6 +1849,7 @@ def __init__(
             enable_network_isolation=enable_network_isolation,
             profiler_config=profiler_config,
             disable_profiler=disable_profiler,
+            environment=environment,
             **kwargs,
         )
 
@@ -2209,7 +2219,21 @@ def _validate_and_set_debugger_configs(self):
         ):
             self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)
         elif not self.debugger_hook_config:
-            self.debugger_hook_config = None
+            # set hook config to False if _region_supports_debugger is False
+            self.debugger_hook_config = False
+
+        # Disable debugger if checkpointing is enabled by the customer
+        if self.checkpoint_s3_uri and self.checkpoint_local_path and self.debugger_hook_config:
+            if self._framework_name in {"mxnet", "pytorch", "tensorflow"}:
+                if self.instance_count > 1 or (
+                    hasattr(self, "distribution")
+                    and self.distribution is not None  # pylint: disable=no-member
+                ):
+                    logger.info(
+                        "SMDebug Does Not Currently Support "
+                        "Distributed Training Jobs With Checkpointing Enabled"
+                    )
+                    self.debugger_hook_config = False
 
     def _stage_user_code_in_s3(self):
         """Upload the user training script to s3 and return the location.

src/sagemaker/huggingface/estimator.py

Lines changed: 12 additions & 0 deletions
@@ -189,6 +189,7 @@ def _validate_args(self, image_uri):
         """Placeholder docstring"""
         if image_uri is not None:
             return
+
         if self.framework_version is None and image_uri is None:
             raise ValueError(
                 "transformers_version, and image_uri are both None. "
@@ -204,6 +205,17 @@ def _validate_args(self, image_uri):
                 "tensorflow_version and pytorch_version are both None. "
                 "Specify either tensorflow_version or pytorch_version."
             )
+        base_framework_version_len = (
+            len(self.tensorflow_version.split("."))
+            if self.tensorflow_version is not None
+            else len(self.pytorch_version.split("."))
+        )
+        transformers_version_len = len(self.framework_version.split("."))
+        if transformers_version_len != base_framework_version_len:
+            raise ValueError(
+                "Please use either full version or shortened version for both "
+                "transformers_version, tensorflow_version and pytorch_version."
+            )
 
     def hyperparameters(self):
         """Return hyperparameters used by your custom PyTorch code during model training."""

src/sagemaker/lineage/_api_types.py

Lines changed: 33 additions & 1 deletion
@@ -181,6 +181,36 @@ class ContextSummary(_base_types.ApiObject):
     last_modified_time = None
 
 
+class UserContext(_base_types.ApiObject):
+    """Summary model of a user context.
+
+    Attributes:
+        user_profile_arn (str): User profile ARN.
+        user_profile_name (str): User profile name.
+        domain_id (str): DomainId.
+    """
+
+    user_profile_arn = None
+    user_profile_name = None
+    domain_id = None
+
+    def __init__(self, user_profile_arn=None, user_profile_name=None, domain_id=None, **kwargs):
+        """Initialize UserContext.
+
+        Args:
+            user_profile_arn (str): User profile ARN.
+            user_profile_name (str): User profile name.
+            domain_id (str): DomainId.
+            **kwargs: Arbitrary keyword arguments.
+        """
+        super(UserContext, self).__init__(
+            user_profile_arn=user_profile_arn,
+            user_profile_name=user_profile_name,
+            domain_id=domain_id,
+            **kwargs
+        )
+
+
 class AssociationSummary(_base_types.ApiObject):
     """Summary model of an association.
 
@@ -196,6 +226,9 @@ class AssociationSummary(_base_types.ApiObject):
         created_by (obj): Context on creator.
     """
 
+    _custom_boto_types = {
+        "created_by": (UserContext, False),
+    }
     source_arn = None
     source_name = None
     destination_arn = None
@@ -204,4 +237,3 @@ class AssociationSummary(_base_types.ApiObject):
     destination_type = None
     association_type = None
     creation_time = None
-    created_by = None
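With `created_by` registered in `_custom_boto_types`, the `CreatedBy` block of a list-associations entry is intended to be materialized as a `UserContext` object rather than a plain attribute value (the old `created_by = None` class attribute is removed). A small sketch of the new type; the ARN, profile name, and domain id below are made up:

    from sagemaker.lineage._api_types import UserContext

    creator = UserContext(
        user_profile_arn="arn:aws:sagemaker:us-west-2:111122223333:user-profile/d-example/jane",
        user_profile_name="jane",
        domain_id="d-example",
    )
    print(creator.user_profile_name)  # -> "jane"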

src/sagemaker/session.py

Lines changed: 12 additions & 1 deletion
@@ -456,6 +456,7 @@ def train( # noqa: C901
         enable_sagemaker_metrics=None,
         profiler_rule_configs=None,
         profiler_config=None,
+        environment=None,
     ):
         """Create an Amazon SageMaker training job.
 
@@ -522,9 +523,12 @@ def train( # noqa: C901
                 Series. For more information see:
                 https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
                 (default: ``None``).
-            profiler_rule_configs (list[dict]): A list of profiler rule configurations.
+            profiler_rule_configs (list[dict]): A list of profiler rule
+                configurations.
             profiler_config (dict): Configuration for how profiling information is emitted
                 with SageMaker Profiler. (default: ``None``).
+            environment (dict[str, str]) : Environment variables to be set for
+                use during training job (default: ``None``)
 
         Returns:
             str: ARN of the training job, if it is created.
@@ -556,6 +560,7 @@ def train( # noqa: C901
             enable_sagemaker_metrics=enable_sagemaker_metrics,
             profiler_rule_configs=profiler_rule_configs,
             profiler_config=profiler_config,
+            environment=environment,
         )
         LOGGER.info("Creating training-job with name: %s", job_name)
         LOGGER.debug("train request: %s", json.dumps(train_request, indent=4))
@@ -588,6 +593,7 @@ def _get_train_request( # noqa: C901
         enable_sagemaker_metrics=None,
         profiler_rule_configs=None,
         profiler_config=None,
+        environment=None,
     ):
         """Constructs a request compatible for creating an Amazon SageMaker training job.
 
@@ -657,6 +663,8 @@ def _get_train_request( # noqa: C901
             profiler_rule_configs (list[dict]): A list of profiler rule configurations.
             profiler_config(dict): Configuration for how profiling information is emitted with
                 SageMaker Profiler. (default: ``None``).
+            environment (dict[str, str]) : Environment variables to be set for
+                use during training job (default: ``None``)
 
         Returns:
             Dict: a training request dict
@@ -699,6 +707,9 @@ def _get_train_request( # noqa: C901
         if hyperparameters and len(hyperparameters) > 0:
             train_request["HyperParameters"] = hyperparameters
 
+        if environment is not None:
+            train_request["Environment"] = environment
+
         if tags is not None:
             train_request["Tags"] = tags
 
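On the wire, the new argument becomes the `Environment` field of the training request built by `_get_train_request`, and it is only set when a dict is supplied. A sketch of the resulting request fragment; the job name is a placeholder and the other required fields are omitted:

    environment = {"LOG_LEVEL": "DEBUG"}

    train_request = {"TrainingJobName": "example-job"}  # other required fields omitted for brevity
    if environment is not None:
        train_request["Environment"] = environment      # forwarded to CreateTrainingJob as Environment
    print(train_request)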

src/sagemaker/tensorflow/estimator.py

Lines changed: 1 addition & 6 deletions
@@ -18,7 +18,6 @@
 from packaging import version
 
 from sagemaker import image_uris, s3, utils
-from sagemaker.debugger import DebuggerHookConfig
 from sagemaker.deprecations import renamed_kwargs
 from sagemaker.estimator import Framework
 import sagemaker.fw_utils as fw
@@ -347,6 +346,7 @@ def _validate_and_set_debugger_configs(self):
 
         Else, set default HookConfig
         """
+        super(TensorFlow, self)._validate_and_set_debugger_configs()
         ps_enabled = "parameter_server" in self.distribution and self.distribution[
             "parameter_server"
         ].get("enabled", False)
@@ -358,11 +358,6 @@ def _validate_and_set_debugger_configs(self):
                 )
             self.debugger_hook_config = None
             self.debugger_rule_configs = None
-        elif self.debugger_hook_config is None and fw._region_supports_debugger(
-            self.sagemaker_session.boto_session.region_name
-        ):
-            # Set defaults for debugging.
-            self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)
 
     def transformer(
         self,

tests/conftest.py

Lines changed: 23 additions & 4 deletions
@@ -190,12 +190,31 @@ def pytorch_inference_py_version(pytorch_inference_version, request):
     return "py3"
 
 
+def _huggingface_pytorch_version(huggingface_vesion):
+    config = image_uris.config_for_framework("huggingface")
+    training_config = config.get("training")
+    original_version = huggingface_vesion
+    if "version_aliases" in training_config:
+        huggingface_vesion = training_config.get("version_aliases").get(
+            huggingface_vesion, huggingface_vesion
+        )
+    version_config = training_config.get("versions").get(huggingface_vesion)
+    for key in list(version_config.keys()):
+        if key.startswith("pytorch"):
+            pt_version = key[7:]
+            if len(original_version.split(".")) == 2:
+                pt_version = ".".join(pt_version.split(".")[:-1])
+            return pt_version
+
+
 @pytest.fixture(scope="module")
 def huggingface_pytorch_version(huggingface_training_version):
-    if Version(huggingface_training_version) <= Version("4.4.2"):
-        return "1.6.0"
-    else:
-        pytest.skip("Skipping Huggingface version.")
+    return _huggingface_pytorch_version(huggingface_training_version)
+
+
+@pytest.fixture(scope="module")
+def huggingface_pytorch_latest_version(huggingface_training_latest_version):
+    return _huggingface_pytorch_version(huggingface_training_latest_version)
 
 
 @pytest.fixture(scope="module")
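The new `_huggingface_pytorch_version` helper reads the HuggingFace image-uri config, resolves version aliases, and pulls the bundled PyTorch version out of the `pytorch*` key, shortening it when the caller passed a two-part Transformers version. A self-contained walk-through with a made-up config entry (real values come from `image_uris.config_for_framework("huggingface")`):

    # Made-up config entry for illustration only.
    training_config = {
        "version_aliases": {"4.4": "4.4.2"},
        "versions": {"4.4.2": {"pytorch1.6.0": {"py_versions": ["py36"]}}},
    }

    original = "4.4"  # shortened transformers version
    resolved = training_config["version_aliases"].get(original, original)  # -> "4.4.2"
    for key in training_config["versions"][resolved]:
        if key.startswith("pytorch"):
            pt_version = key[len("pytorch"):]           # -> "1.6.0"
            if len(original.split(".")) == 2:           # caller used the short form
                pt_version = ".".join(pt_version.split(".")[:-1])  # -> "1.6"
            print(pt_version)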

tests/integ/sagemaker/lineage/test_association.py

Lines changed: 3 additions & 0 deletions
@@ -56,6 +56,9 @@ def test_list(association_objs, sagemaker_session):
     # sanity check
     assert association_keys_listed
 
+    for listed_asscn in listed:
+        assert listed_asscn.created_by is None
+
 
 @pytest.mark.timeout(30)
 def test_set_tag(association_obj, sagemaker_session):
