
Commit c8104af

ChaiBapchya authored and ChoiByungWook committed
feature: add data parallelism support (#454) (#511) (#495)
1 parent f8c5287 commit c8104af

File tree: 16 files changed, +830 / -66 lines


src/sagemaker/debugger/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@
 from sagemaker.debugger.metrics_config import (  # noqa: F401
     DataloaderProfilingConfig,
     DetailedProfilingConfig,
-    HerringProfilingConfig,
+    SMDataParallelProfilingConfig,
     HorovodProfilingConfig,
     PythonProfilingConfig,
 )

src/sagemaker/debugger/framework_profile.py

Lines changed: 6 additions & 6 deletions
@@ -16,7 +16,7 @@
 from sagemaker.debugger.metrics_config import (
     DetailedProfilingConfig,
     DataloaderProfilingConfig,
-    HerringProfilingConfig,
+    SMDataParallelProfilingConfig,
     HorovodProfilingConfig,
     PythonProfilingConfig,
 )
@@ -33,7 +33,7 @@
     DataloaderProfilingConfig,
     PythonProfilingConfig,
     HorovodProfilingConfig,
-    HerringProfilingConfig,
+    SMDataParallelProfilingConfig,
 ]


@@ -53,7 +53,7 @@ def __init__(
         dataloader_profiling_config=None,
         python_profiling_config=None,
         horovod_profiling_config=None,
-        herring_profiling_config=None,
+        smdataparallel_profiling_config=None,
         start_step=None,
         num_steps=None,
         start_unix_time=None,
@@ -88,8 +88,8 @@ def __init__(
                 collected by the Python profiler (cProfile or Pyinstrument).
             horovod_profiling_config (HorovodProfilingConfig): The configuration for metrics
                 collected by horovod when using horovod for distributed training.
-            herring_profiling_config (HerringProfilingConfig): The configuration for metrics
-                collected by herring when using herring for distributed training.
+            smdataparallel_profiling_config (SMDataParallelProfilingConfig): The configuration for
+                metrics collected by SageMaker Distributed training.
             start_step (int): The step at which to start profiling.
             num_steps (int): The number of steps to profile.
             start_unix_time (int): The UNIX time at which to start profiling.
@@ -108,7 +108,7 @@ def __init__(
             dataloader_profiling_config,
             python_profiling_config,
             horovod_profiling_config,
-            herring_profiling_config,
+            smdataparallel_profiling_config,
         )

         use_one_config_for_all_metrics = (
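With the rename, the profiling configuration for SageMaker distributed data parallel training is passed to FrameworkProfile through the new keyword, just as herring_profiling_config was before. A minimal sketch, assuming only the public names touched by this diff (the step values are illustrative):

from sagemaker.debugger import FrameworkProfile, SMDataParallelProfilingConfig

# Profile the SMDataParallel collectives for 5 steps starting at step 10.
profiler_params = FrameworkProfile(
    smdataparallel_profiling_config=SMDataParallelProfilingConfig(start_step=10, num_steps=5)
)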

src/sagemaker/debugger/metrics_config.py

Lines changed: 12 additions & 9 deletions
@@ -18,8 +18,8 @@
     DATALOADER_PROFILING_START_STEP_DEFAULT,
     DETAILED_PROFILING_CONFIG_NAME,
     DETAILED_PROFILING_START_STEP_DEFAULT,
-    HERRING_PROFILING_CONFIG_NAME,
-    HERRING_PROFILING_START_STEP_DEFAULT,
+    SMDATAPARALLEL_PROFILING_CONFIG_NAME,
+    SMDATAPARALLEL_PROFILING_START_STEP_DEFAULT,
     HOROVOD_PROFILING_CONFIG_NAME,
     HOROVOD_PROFILING_START_STEP_DEFAULT,
     PROFILING_NUM_STEPS_DEFAULT,
@@ -298,6 +298,9 @@ def __init__(
             start_step = PYTHON_PROFILING_START_STEP_DEFAULT
             num_steps = PYTHON_PROFILING_NUM_STEPS_DEFAULT

+        if profile_default_steps:
+            cprofile_timer = cProfileTimer.DEFAULT
+
         super().__init__(
             PYTHON_PROFILING_CONFIG_NAME, start_step, num_steps, start_unix_time, duration
         )
@@ -367,8 +370,8 @@
         )


-class HerringProfilingConfig(MetricsConfigBase):
-    """Configuration for metrics collected by herring when using herring for distributed training.
+class SMDataParallelProfilingConfig(MetricsConfigBase):
+    """Configuration for metrics collected by SageMaker Distributed training.

     By default, profile step 15 of training.
     """
@@ -381,9 +384,9 @@ def __init__(
         duration=None,
         profile_default_steps=False,
     ):
-        """If profile_default_steps is set to True or none of the range fields are specified,
-        use the default config for herring profiling. Otherwise, profile according to the
-        specified range fields.
+        """If profile_default_steps is set to True or none of the range fields are specified, use
+        the default profiling config for SageMaker Distributed training. Otherwise, profile
+        according to the specified range fields.

         Args:
             start_step (int): The step at which to start profiling.
@@ -396,9 +399,9 @@ def __init__(
             profile_default_steps, bool
         ), ErrorMessages.INVALID_PROFILE_DEFAULT_STEPS.value
         if profile_default_steps or start_step is num_steps is start_unix_time is duration is None:
-            start_step = HERRING_PROFILING_START_STEP_DEFAULT
+            start_step = SMDATAPARALLEL_PROFILING_START_STEP_DEFAULT
             num_steps = PROFILING_NUM_STEPS_DEFAULT

         super().__init__(
-            HERRING_PROFILING_CONFIG_NAME, start_step, num_steps, start_unix_time, duration
+            SMDATAPARALLEL_PROFILING_CONFIG_NAME, start_step, num_steps, start_unix_time, duration
         )
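As with the other metrics configs, omitting every range field (or passing profile_default_steps=True) falls back to the defaults introduced below in profiler_constants.py: start at step 15 and profile one step. A short sketch under that assumption:

from sagemaker.debugger import SMDataParallelProfilingConfig

# No range fields given: defaults apply (start_step=15, num_steps=1).
default_config = SMDataParallelProfilingConfig()

# Explicit range fields override the defaults.
custom_config = SMDataParallelProfilingConfig(start_step=20, num_steps=3)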

src/sagemaker/debugger/profiler_constants.py

Lines changed: 2 additions & 2 deletions
@@ -22,13 +22,13 @@
 DATALOADER_PROFILING_CONFIG_NAME = "DataloaderProfilingConfig"
 PYTHON_PROFILING_CONFIG_NAME = "PythonProfilingConfig"
 HOROVOD_PROFILING_CONFIG_NAME = "HorovodProfilingConfig"
-HERRING_PROFILING_CONFIG_NAME = "HerringProfilingConfig"
+SMDATAPARALLEL_PROFILING_CONFIG_NAME = "SMDataParallelProfilingConfig"

 DETAILED_PROFILING_START_STEP_DEFAULT = 5
 DATALOADER_PROFILING_START_STEP_DEFAULT = 7
 PYTHON_PROFILING_START_STEP_DEFAULT = 9
 HOROVOD_PROFILING_START_STEP_DEFAULT = 13
-HERRING_PROFILING_START_STEP_DEFAULT = 15
+SMDATAPARALLEL_PROFILING_START_STEP_DEFAULT = 15
 PROFILING_NUM_STEPS_DEFAULT = 1
 START_STEP_DEFAULT = 0
 PYTHON_PROFILING_NUM_STEPS_DEFAULT = 3

src/sagemaker/debugger/utils.py

Lines changed: 1 addition & 0 deletions
@@ -118,3 +118,4 @@ class cProfileTimer(Enum):
     TOTAL_TIME = "total_time"
     CPU_TIME = "cpu_time"
     OFF_CPU_TIME = "off_cpu_time"
+    DEFAULT = "default"
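The new enum member backs the PythonProfilingConfig change above: when profile_default_steps is set, the cProfile timer is switched to "default". A hedged sketch of both paths, assuming the existing cprofile_timer keyword on PythonProfilingConfig:

from sagemaker.debugger import PythonProfilingConfig
from sagemaker.debugger.utils import cProfileTimer

# Explicitly select the new "default" cProfile timer.
explicit = PythonProfilingConfig(cprofile_timer=cProfileTimer.DEFAULT)

# profile_default_steps=True now implies cProfileTimer.DEFAULT as well (see metrics_config.py above).
implicit = PythonProfilingConfig(profile_default_steps=True)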

src/sagemaker/estimator.py

Lines changed: 2 additions & 0 deletions
@@ -1903,6 +1903,8 @@ class Framework(EstimatorBase):

     LAUNCH_PS_ENV_NAME = "sagemaker_parameter_server_enabled"
     LAUNCH_MPI_ENV_NAME = "sagemaker_mpi_enabled"
+    LAUNCH_SM_DDP_ENV_NAME = "sagemaker_distributed_dataparallel_enabled"
+    INSTANCE_TYPE = "sagemaker_instance_type"
     MPI_NUM_PROCESSES_PER_HOST = "sagemaker_mpi_num_of_processes_per_host"
     MPI_CUSTOM_MPI_OPTIONS = "sagemaker_mpi_custom_mpi_options"
     CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH = "/opt/ml/input/data/code/sourcedir.tar.gz"
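The two new class attributes name the hyperparameter keys that the framework containers read; the plumbing that sets them is outside this hunk. Purely as illustration, this is roughly what they are expected to translate to in a training job when smdistributed dataparallel is enabled (the instance type value is hypothetical):

# Illustrative only; keys come from the constants above, values are examples.
expected_hyperparameters = {
    "sagemaker_distributed_dataparallel_enabled": True,  # Framework.LAUNCH_SM_DDP_ENV_NAME
    "sagemaker_instance_type": "ml.p3.16xlarge",  # Framework.INSTANCE_TYPE
}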

src/sagemaker/fw_utils.py

Lines changed: 156 additions & 3 deletions
@@ -50,6 +50,12 @@

 DEBUGGER_UNSUPPORTED_REGIONS = ("us-iso-east-1",)
 SINGLE_GPU_INSTANCE_TYPES = ("ml.p2.xlarge", "ml.p3.2xlarge")
+SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES = ("ml.p3.16xlarge", "ml.p3dn.24xlarge", "local_gpu")
+SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSIONS = {
+    "tensorflow": ["2.3.0", "2.3.1"],
+    "pytorch": ["1.6.0"],
+}
+SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"]


 def validate_source_dir(script, directory):
@@ -255,9 +261,8 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution
         .. code:: python

             {
-                'parameter_server':
-                {
-                    'enabled': True
+                "parameter_server": {
+                    "enabled": True
                 }
             }
@@ -279,6 +284,154 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution
         logger.warning(PARAMETER_SERVER_MULTI_GPU_WARNING)


+def validate_smdistributed(
+    instance_type, framework_name, framework_version, py_version, distribution, image_uri=None
+):
+    """Check if smdistributed strategy is correctly invoked by the user.
+
+    Currently, two strategies are supported: `dataparallel` or `modelparallel`.
+    Validate if the user requested strategy is supported.
+
+    Currently, only one strategy can be specified at a time. Validate if the user has requested
+    more than one strategy simultaneously.
+
+    Validate if the smdistributed dict arg is syntactically correct.
+
+    Additionally, perform strategy-specific validations.
+
+    Args:
+        instance_type (str): A string representing the type of training instance selected.
+        framework_name (str): A string representing the name of framework selected.
+        framework_version (str): A string representing the framework version selected.
+        py_version (str): A string representing the python version selected.
+        distribution (dict): A dictionary with information to enable distributed training.
+            (Defaults to None if distributed training is not enabled.) For example:
+
+            .. code:: python
+
+                {
+                    "smdistributed": {
+                        "dataparallel": {
+                            "enabled": True
+                        }
+                    }
+                }
+        image_uri (str): A string representing a Docker image URI.
+
+    Raises:
+        ValueError: if distribution dictionary isn't correctly formatted or
+            multiple strategies are requested simultaneously or
+            an unsupported strategy is requested or
+            strategy-specific inputs are incorrect/unsupported
+    """
+    if "smdistributed" not in distribution:
+        # Distribution strategy other than smdistributed is selected
+        return
+
+    # distribution contains smdistributed
+    smdistributed = distribution["smdistributed"]
+    if not isinstance(smdistributed, dict):
+        raise ValueError("smdistributed strategy requires a dictionary")
+
+    if len(smdistributed) > 1:
+        # more than 1 smdistributed strategy requested by the user
+        err_msg = (
+            "Cannot use more than 1 smdistributed strategy. \n"
+            "Choose one of the following supported strategies:"
+            f"{SMDISTRIBUTED_SUPPORTED_STRATEGIES}"
+        )
+        raise ValueError(err_msg)
+
+    # validate if smdistributed strategy is supported
+    # currently this for loop essentially checks for only 1 key
+    for strategy in smdistributed:
+        if strategy not in SMDISTRIBUTED_SUPPORTED_STRATEGIES:
+            err_msg = (
+                f"Invalid smdistributed strategy provided: {strategy} \n"
+                f"Supported strategies: {SMDISTRIBUTED_SUPPORTED_STRATEGIES}"
+            )
+            raise ValueError(err_msg)
+
+    # smdataparallel-specific input validation
+    if "dataparallel" in smdistributed:
+        _validate_smdataparallel_args(
+            instance_type, framework_name, framework_version, py_version, distribution, image_uri
+        )
+
+
+def _validate_smdataparallel_args(
+    instance_type, framework_name, framework_version, py_version, distribution, image_uri=None
+):
+    """Check if request is using unsupported arguments.
+
+    Validate if user specifies a supported instance type, framework version, and python
+    version.
+
+    Args:
+        instance_type (str): A string representing the type of training instance selected. Ex: `ml.p3.16xlarge`
+        framework_name (str): A string representing the name of framework selected. Ex: `tensorflow`
+        framework_version (str): A string representing the framework version selected. Ex: `2.3.1`
+        py_version (str): A string representing the python version selected. Ex: `py3`
+        distribution (dict): A dictionary with information to enable distributed training.
+            (Defaults to None if distributed training is not enabled.) Ex:
+
+            .. code:: python
+
+                {
+                    "smdistributed": {
+                        "dataparallel": {
+                            "enabled": True
+                        }
+                    }
+                }
+        image_uri (str): A string representing a Docker image URI.
+
+    Raises:
+        ValueError: if
+            (`instance_type` is not in SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES or
+            `py_version` is not python3 or
+            `framework_version` is not in SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSIONS)
+    """
+    smdataparallel_enabled = (
+        distribution.get("smdistributed").get("dataparallel").get("enabled", False)
+    )
+
+    if not smdataparallel_enabled:
+        return
+
+    is_instance_type_supported = instance_type in SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES
+
+    err_msg = ""
+
+    if not is_instance_type_supported:
+        # instance_type is required
+        err_msg += (
+            f"Provided instance_type {instance_type} is not supported by smdataparallel.\n"
+            "Please specify one of the supported instance types:"
+            f"{SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES}\n"
+        )
+
+    if not image_uri:
+        # ignore framework_version & py_version if image_uri is set
+        # in case image_uri is not set, then both are mandatory
+        supported = SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSIONS[framework_name]
+        if framework_version not in supported:
+            err_msg += (
+                f"Provided framework_version {framework_version} is not supported by"
+                " smdataparallel.\n"
+                f"Please specify one of the supported framework versions: {supported} \n"
+            )
+
+        if "py3" not in py_version:
+            err_msg += (
+                f"Provided py_version {py_version} is not supported by smdataparallel.\n"
+                "Please specify py_version=py3"
+            )
+
+    if err_msg:
+        raise ValueError(err_msg)
+
+
 def python_deprecation_warning(framework, latest_supported_version):
     """
     Args:
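A short usage sketch of the new validator, assuming it is imported from sagemaker.fw_utils: a supported instance type, framework version, and Python version pass silently, while an unsupported instance type raises ValueError listing the supported types.

from sagemaker.fw_utils import validate_smdistributed

distribution = {"smdistributed": {"dataparallel": {"enabled": True}}}

# Supported combination: returns without raising.
validate_smdistributed(
    instance_type="ml.p3.16xlarge",
    framework_name="pytorch",
    framework_version="1.6.0",
    py_version="py3",
    distribution=distribution,
)

# Unsupported instance type: raises ValueError listing the supported instance types.
try:
    validate_smdistributed(
        instance_type="ml.p2.xlarge",
        framework_name="pytorch",
        framework_version="1.6.0",
        py_version="py3",
        distribution=distribution,
    )
except ValueError as err:
    print(err)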
