50
50
51
51
# NOTE(review): the two constants below are not used in this section of the file;
# judging by their names they list regions without Debugger support and
# instance types with a single GPU -- confirm against their callers.
DEBUGGER_UNSUPPORTED_REGIONS = ("us-iso-east-1",)
SINGLE_GPU_INSTANCE_TYPES = ("ml.p2.xlarge", "ml.p3.2xlarge")

# Instance types accepted by ``_validate_smdataparallel_args``.
SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES = ("ml.p3.16xlarge", "ml.p3dn.24xlarge", "local_gpu")

# Framework name -> framework versions accepted by ``_validate_smdataparallel_args``
# when no custom ``image_uri`` is supplied.
SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSIONS = {
    "tensorflow": ["2.3.0", "2.3.1"],
    "pytorch": ["1.6.0"],
}

# Keys accepted inside ``distribution["smdistributed"]`` by ``validate_smdistributed``.
# Kept as a list because it is interpolated verbatim into user-facing error messages.
SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"]
53
59
54
60
55
61
def validate_source_dir (script , directory ):
@@ -255,9 +261,8 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution
255
261
.. code:: python
256
262
257
263
{
258
- 'parameter_server':
259
- {
260
- 'enabled': True
264
+ "parameter_server": {
265
+ "enabled": True
261
266
}
262
267
}
263
268
@@ -279,6 +284,154 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution
279
284
logger .warning (PARAMETER_SERVER_MULTI_GPU_WARNING )
280
285
281
286
287
def validate_smdistributed(
    instance_type, framework_name, framework_version, py_version, distribution, image_uri=None
):
    """Check if smdistributed strategy is correctly invoked by the user.

    Currently, two strategies are supported: `dataparallel` or `modelparallel`.
    Validate if the user requested strategy is supported.

    Currently, only one strategy can be specified at a time. Validate if the user has requested
    more than one strategy simultaneously.

    Validate if the smdistributed dict arg is syntactically correct.

    Additionally, perform strategy-specific validations.

    Args:
        instance_type (str): A string representing the type of training instance selected.
        framework_name (str): A string representing the name of framework selected.
        framework_version (str): A string representing the framework version selected.
        py_version (str): A string representing the python version selected.
        distribution (dict): A dictionary with information to enable distributed training.
            (Defaults to None if distributed training is not enabled.) For example:

            .. code:: python

                {
                    "smdistributed": {
                        "dataparallel": {
                            "enabled": True
                        }
                    }
                }
        image_uri (str): A string representing a Docker image URI.

    Raises:
        ValueError: if distribution dictionary isn't correctly formatted or
            multiple strategies are requested simultaneously or
            an unsupported strategy is requested or
            strategy-specific inputs are incorrect/unsupported
    """
    if not distribution or "smdistributed" not in distribution:
        # Either distributed training is not enabled (None / empty dict) or a
        # distribution strategy other than smdistributed is selected.
        return

    # distribution contains smdistributed
    smdistributed = distribution["smdistributed"]
    if not isinstance(smdistributed, dict):
        raise ValueError("smdistributed strategy requires a dictionary")

    if len(smdistributed) > 1:
        # more than 1 smdistributed strategy requested by the user
        err_msg = (
            "Cannot use more than 1 smdistributed strategy. \n"
            "Choose one of the following supported strategies: "
            f"{SMDISTRIBUTED_SUPPORTED_STRATEGIES}"
        )
        raise ValueError(err_msg)

    # validate if smdistributed strategy is supported
    # currently this for loop essentially checks for only 1 key
    for strategy in smdistributed:
        if strategy not in SMDISTRIBUTED_SUPPORTED_STRATEGIES:
            err_msg = (
                f"Invalid smdistributed strategy provided: {strategy} \n"
                f"Supported strategies: {SMDISTRIBUTED_SUPPORTED_STRATEGIES}"
            )
            raise ValueError(err_msg)

    # smdataparallel-specific input validation
    if "dataparallel" in smdistributed:
        _validate_smdataparallel_args(
            instance_type, framework_name, framework_version, py_version, distribution, image_uri
        )
360
+
361
+
362
def _validate_smdataparallel_args(
    instance_type, framework_name, framework_version, py_version, distribution, image_uri=None
):
    """Check if request is using unsupported arguments.

    Validate if user specifies a supported instance type, framework version, and python
    version. All problems found are collected and reported together in a single error.

    Args:
        instance_type (str): A string representing the type of training instance selected.
            Ex: `ml.p3.16xlarge`
        framework_name (str): A string representing the name of framework selected.
            Ex: `tensorflow`
        framework_version (str): A string representing the framework version selected.
            Ex: `2.3.1`
        py_version (str): A string representing the python version selected. Ex: `py3`
        distribution (dict): A dictionary with information to enable distributed training.
            (Defaults to None if distributed training is not enabled.) Ex:

            .. code:: python

                {
                    "smdistributed": {
                        "dataparallel": {
                            "enabled": True
                        }
                    }
                }
        image_uri (str): A string representing a Docker image URI. When set,
            `framework_name`, `framework_version` and `py_version` are not checked.

    Raises:
        ValueError: if the `dataparallel` entry is not a dictionary, or
            `instance_type` is not in SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES, or
            (only when `image_uri` is not set) `framework_name` /
            `framework_version` is not in SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSIONS
            or `py_version` is not python3
    """
    smdistributed = (distribution or {}).get("smdistributed")
    if not isinstance(smdistributed, dict) or "dataparallel" not in smdistributed:
        # dataparallel was not requested, so there is nothing to validate
        return

    dataparallel = smdistributed["dataparallel"]
    if not isinstance(dataparallel, dict):
        # fail with a clear message instead of an AttributeError on `.get`
        raise ValueError("smdistributed dataparallel strategy requires a dictionary")

    if not dataparallel.get("enabled", False):
        return

    err_msg = ""

    if instance_type not in SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES:
        # instance_type is required
        err_msg += (
            f"Provided instance_type {instance_type} is not supported by smdataparallel.\n"
            "Please specify one of the supported instance types: "
            f"{SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES}\n"
        )

    if not image_uri:
        # ignore framework_version & py_version if image_uri is set
        # in case image_uri is not set, then both are mandatory
        supported = SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSIONS.get(framework_name)
        if supported is None:
            # unknown framework: report it rather than leaking a KeyError
            err_msg += (
                f"Provided framework_name {framework_name} is not supported by"
                " smdataparallel.\n"
                "Please specify one of the supported frameworks: "
                f"{list(SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSIONS)}\n"
            )
        elif framework_version not in supported:
            err_msg += (
                f"Provided framework_version {framework_version} is not supported by"
                " smdataparallel.\n"
                f"Please specify one of the supported framework versions: {supported}\n"
            )

        # a missing py_version (None / "") is treated as unsupported
        if not py_version or "py3" not in py_version:
            err_msg += (
                f"Provided py_version {py_version} is not supported by smdataparallel.\n"
                "Please specify py_version=py3"
            )

    if err_msg:
        raise ValueError(err_msg)
433
+
434
+
282
435
def python_deprecation_warning (framework , latest_supported_version ):
283
436
"""
284
437
Args:
0 commit comments