17
17
import re
18
18
from typing import Optional , Union , Dict
19
19
20
- from sagemaker .deprecations import renamed_kwargs
21
20
from sagemaker .estimator import Framework , EstimatorBase
22
21
from sagemaker .fw_utils import (
23
22
framework_name_from_image ,
24
- warn_if_parameter_server_with_multi_gpu ,
25
- validate_smdistributed ,
23
+ validate_distribution ,
26
24
)
27
25
from sagemaker .huggingface .model import HuggingFaceModel
28
26
from sagemaker .vpc_utils import VPC_CONFIG_DEFAULT
@@ -37,6 +35,9 @@ class HuggingFace(Framework):
37
35
"""Handle training of custom HuggingFace code."""
38
36
39
37
_framework_name = "huggingface"
38
+ LAUNCH_PYTORCH_DDP_ENV_NAME = "sagemaker_pytorch_ddp_enabled"
39
+ LAUNCH_TORCH_DISTRIBUTED_ENV_NAME = "sagemaker_torch_distributed_enabled"
40
+ INSTANCE_TYPE_ENV_NAME = "sagemaker_instance_type"
40
41
41
42
def __init__ (
42
43
self ,
@@ -142,6 +143,36 @@ def __init__(
142
143
}
143
144
}
144
145
146
+ **To enable PyTorch DDP:**
147
+
148
+ .. code:: python
149
+
150
+ {
151
+ "pytorchddp": {
152
+ "enabled": True
153
+ }
154
+ }
155
+
156
+ To learn more, see `Distributed PyTorch Training
157
+ <https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training>`_.
158
+
159
+ **To enable Torch Distributed:**
160
+
161
+ This is available for general distributed training on
162
+ GPU instances from PyTorch v1.13.1 and later.
163
+
164
+ .. code:: python
165
+
166
+ {
167
+ "torch_distributed": {
168
+ "enabled": True
169
+ }
170
+ }
171
+
172
+ This option also supports distributed training on Trn1.
173
+ To learn more, see `Distributed PyTorch Training on Trainium
174
+ <https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training-on-trainium>`_.
175
+
145
176
To enable distributed training with
146
177
`SageMaker Training Compiler <https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_
147
178
for Hugging Face Transformers with PyTorch:
@@ -182,28 +213,23 @@ def __init__(
182
213
183
214
self ._validate_args (image_uri = image_uri )
184
215
185
- instance_type = renamed_kwargs (
186
- "train_instance_type" , "instance_type" , kwargs .get ("instance_type" ), kwargs
187
- )
188
-
189
- base_framework_name = "tensorflow" if tensorflow_version is not None else "pytorch"
190
- base_framework_version = (
216
+ self .base_framework_name = "tensorflow" if tensorflow_version is not None else "pytorch"
217
+ self .base_framework_version = (
191
218
tensorflow_version if tensorflow_version is not None else pytorch_version
192
219
)
193
220
194
221
if distribution is not None :
195
- validate_smdistributed (
196
- instance_type = instance_type ,
197
- framework_name = base_framework_name ,
198
- framework_version = base_framework_version ,
199
- py_version = self .py_version ,
200
- distribution = distribution ,
201
- image_uri = image_uri ,
222
+ distribution = validate_distribution (
223
+ distribution ,
224
+ self .instance_groups ,
225
+ self .base_framework_name ,
226
+ self .base_framework_version ,
227
+ py_version ,
228
+ image_uri ,
229
+ kwargs ,
202
230
)
203
231
204
- warn_if_parameter_server_with_multi_gpu (
205
- training_instance_type = instance_type , distribution = distribution
206
- )
232
+ self .distribution = distribution or {}
207
233
208
234
if "enable_sagemaker_metrics" not in kwargs :
209
235
kwargs ["enable_sagemaker_metrics" ] = True
@@ -214,8 +240,6 @@ def __init__(
214
240
entry_point , source_dir , hyperparameters , image_uri = image_uri , ** kwargs
215
241
)
216
242
217
- self .distribution = distribution or {}
218
-
219
243
if compiler_config is not None :
220
244
if not isinstance (compiler_config , TrainingCompilerConfig ):
221
245
error_string = (
@@ -267,14 +291,44 @@ def _validate_args(self, image_uri):
267
291
"transformers_version, tensorflow_version and pytorch_version."
268
292
)
269
293
294
def _huggingface_distribution_configuration(self, distribution):
    """Build the distribution-related hyperparameters for Hugging Face training.

    Args:
        distribution (dict): A dictionary with information on how to run
            distributed training. ``None`` is treated as an empty dict.

    Returns:
        dict: If ``pytorchddp`` is enabled, the PyTorch DDP launch flag;
        otherwise, if ``torch_distributed`` is enabled, the Torch Distributed
        launch flag. In both cases the instance type is added when
        ``self.instance_type`` is not None. For any other input, the result of
        ``self._distribution_configuration`` for ``distribution``.
    """
    distribution = distribution or {}

    # Each launcher is looked up independently so that a present-but-disabled
    # "pytorchddp" entry cannot hide an enabled "torch_distributed" entry.
    # ``or {}`` tolerates a key that is present but mapped to None.
    pytorch_ddp_enabled = (distribution.get("pytorchddp") or {}).get("enabled", False)
    torch_distributed_enabled = (distribution.get("torch_distributed") or {}).get(
        "enabled", False
    )

    if pytorch_ddp_enabled:
        launcher_config = {self.LAUNCH_PYTORCH_DDP_ENV_NAME: pytorch_ddp_enabled}
    elif torch_distributed_enabled:
        launcher_config = {self.LAUNCH_TORCH_DISTRIBUTED_ENV_NAME: torch_distributed_enabled}
    else:
        # Neither launcher is enabled: defer to the generic configuration.
        return self._distribution_configuration(distribution=distribution)

    # Both launcher paths also report the instance type when it is known.
    if self.instance_type is not None:
        launcher_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type

    return launcher_config
323
+
270
324
def hyperparameters (self ):
271
325
"""Return hyperparameters used by your custom PyTorch code during model training."""
272
326
hyperparameters = super (HuggingFace , self ).hyperparameters ()
273
- distributed_training_hyperparameters = self ._distribution_configuration (
327
+ additional_hyperparameters = self ._huggingface_distribution_configuration (
274
328
distribution = self .distribution
275
329
)
276
330
hyperparameters .update (
277
- EstimatorBase ._json_encode_hyperparameters (distributed_training_hyperparameters )
331
+ EstimatorBase ._json_encode_hyperparameters (additional_hyperparameters )
278
332
)
279
333
280
334
if self .compiler_config :
0 commit comments