Skip to content

Commit 5954ecb

Browse files
committed
feat: Make DistributedConfig Extensible
1 parent 13ad978 commit 5954ecb

File tree

18 files changed

+219
-192
lines changed

18 files changed

+219
-192
lines changed

src/sagemaker/modules/distributed.py

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,13 @@
1313
"""Distributed module."""
1414
from __future__ import absolute_import
1515

16+
import os
17+
18+
from abc import ABC, abstractmethod
1619
from typing import Optional, Dict, Any, List
17-
from pydantic import BaseModel, PrivateAttr
20+
from pydantic import BaseModel
1821
from sagemaker.modules.utils import safe_serialize
22+
from sagemaker.modules.constants import SM_DRIVERS_LOCAL_PATH
1923

2024

2125
class SMP(BaseModel):
@@ -72,16 +76,39 @@ def _to_mp_hyperparameters(self) -> Dict[str, Any]:
7276
return hyperparameters
7377

7478

75-
class DistributedConfig(BaseModel):
76-
"""Base class for distributed training configurations."""
79+
class DistributedConfig(BaseModel, ABC):
80+
"""Abstract base class for distributed training configurations.
81+
82+
This class defines the interface that all distributed training configurations
83+
must implement. It provides a standardized way to specify driver scripts and
84+
their locations for distributed training jobs.
85+
"""
86+
87+
@property
88+
@abstractmethod
89+
def driver_dir(self) -> str:
90+
"""Directory containing the driver script.
91+
92+
This property should return the path to the directory containing
93+
the driver script, relative to the container's working directory.
94+
95+
Returns:
96+
str: Path to directory containing the driver script
97+
"""
98+
pass
99+
100+
@property
101+
@abstractmethod
102+
def driver_script(self) -> str:
103+
"""Name of the driver script.
77104
78-
_type: str = PrivateAttr()
105+
This property should return the name of the Python script that implements
106+
the distributed training driver logic.
79107
80-
def model_dump(self, *args, **kwargs):
81-
"""Dump the model to a dictionary."""
82-
result = super().model_dump(*args, **kwargs)
83-
result["_type"] = self._type
84-
return result
108+
Returns:
109+
str: Name of the driver script file
110+
"""
111+
pass
85112

86113

87114
class Torchrun(DistributedConfig):
@@ -98,11 +125,17 @@ class Torchrun(DistributedConfig):
98125
The SageMaker Model Parallelism v2 parameters.
99126
"""
100127

101-
_type: str = PrivateAttr(default="torchrun")
102-
103128
process_count_per_node: Optional[int] = None
104129
smp: Optional["SMP"] = None
105130

131+
@property
132+
def driver_dir(self) -> str:
133+
return os.path.join(SM_DRIVERS_LOCAL_PATH, "drivers")
134+
135+
@property
136+
def driver_script(self) -> str:
137+
return "torchrun_driver.py"
138+
106139

107140
class MPI(DistributedConfig):
108141
"""MPI.
@@ -118,7 +151,13 @@ class MPI(DistributedConfig):
118151
The custom MPI options to use for the training job.
119152
"""
120153

121-
_type: str = PrivateAttr(default="mpi")
122-
123154
process_count_per_node: Optional[int] = None
124155
mpi_additional_options: Optional[List[str]] = None
156+
157+
@property
158+
def driver_dir(self) -> str:
159+
return os.path.join(SM_DRIVERS_LOCAL_PATH, "drivers")
160+
161+
@property
162+
def driver_script(self) -> str:
163+
return "mpi_driver.py"

src/sagemaker/modules/templates.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,12 @@
2121

2222
EXECUTE_BASIC_SCRIPT_DRIVER = """
2323
echo "Running Basic Script driver"
24-
$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/basic_script_driver.py
24+
$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/drivers/basic_script_driver.py
2525
"""
2626

27-
EXEUCTE_TORCHRUN_DRIVER = """
28-
echo "Running Torchrun driver"
29-
$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/torchrun_driver.py
30-
"""
31-
32-
EXECUTE_MPI_DRIVER = """
33-
echo "Running MPI driver"
34-
$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/mpi_driver.py
27+
EXEUCTE_DISTRIBUTED_DRIVER = """
28+
echo "Running {driver_name} Driver"
29+
$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/drivers/{driver_script}
3530
"""
3631

3732
TRAIN_SCRIPT_TEMPLATE = """
Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +0,0 @@
1-
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2-
#
3-
# Licensed under the Apache License, Version 2.0 (the "License"). You
4-
# may not use this file except in compliance with the License. A copy of
5-
# the License is located at
6-
#
7-
# http://aws.amazon.com/apache2.0/
8-
#
9-
# or in the "license" file accompanying this file. This file is
10-
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11-
# ANY KIND, either express or implied. See the License for the specific
12-
# language governing permissions and limitations under the License.
13-
"""Sagemaker modules container_drivers directory."""
14-
from __future__ import absolute_import

src/sagemaker/modules/train/container_drivers/common/__init__.py

Whitespace-only changes.

src/sagemaker/modules/train/container_drivers/utils.py renamed to src/sagemaker/modules/train/container_drivers/common/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,10 @@ def read_hyperparameters_json(hyperparameters_json: Dict[str, Any] = HYPERPARAME
9999
return hyperparameters_dict
100100

101101

102-
def get_process_count(distributed_dict: Dict[str, Any]) -> int:
102+
def get_process_count(process_count: Optional[int] = None) -> int:
103103
"""Get the number of processes to run on each node in the training job."""
104104
return (
105-
int(distributed_dict.get("process_count_per_node", 0))
105+
process_count
106106
or int(os.environ.get("SM_NUM_GPUS", 0))
107107
or int(os.environ.get("SM_NUM_NEURONS", 0))
108108
or 1

src/sagemaker/modules/train/container_drivers/drivers/__init__.py

Whitespace-only changes.

src/sagemaker/modules/train/container_drivers/basic_script_driver.py renamed to src/sagemaker/modules/train/container_drivers/drivers/basic_script_driver.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,19 @@
1313
"""This module is the entry point for the Basic Script Driver."""
1414
from __future__ import absolute_import
1515

16+
import os
1617
import sys
18+
import json
1719
import shlex
1820

21+
from pathlib import Path
1922
from typing import List
2023

21-
from utils import (
24+
sys.path.insert(0, str(Path(__file__).parent.parent))
25+
26+
from common.utils import ( # noqa: E402
2227
logger,
2328
get_python_executable,
24-
read_source_code_json,
25-
read_hyperparameters_json,
2629
execute_commands,
2730
write_failure_file,
2831
hyperparameters_to_cli_args,
@@ -31,11 +34,10 @@
3134

3235
def create_commands() -> List[str]:
3336
"""Create the commands to execute."""
34-
source_code = read_source_code_json()
35-
hyperparameters = read_hyperparameters_json()
37+
entry_script = os.environ["SM_ENTRY_SCRIPT"]
38+
hyperparameters = json.loads(os.environ["SM_HPS"])
3639
python_executable = get_python_executable()
3740

38-
entry_script = source_code["entry_script"]
3941
args = hyperparameters_to_cli_args(hyperparameters)
4042
if entry_script.endswith(".py"):
4143
commands = [python_executable, entry_script]

src/sagemaker/modules/train/container_drivers/mpi_driver.py renamed to src/sagemaker/modules/train/container_drivers/drivers/mpi_driver.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,8 @@
1616
import os
1717
import sys
1818
import json
19+
from pathlib import Path
1920

20-
from utils import (
21-
logger,
22-
read_source_code_json,
23-
read_distributed_json,
24-
read_hyperparameters_json,
25-
hyperparameters_to_cli_args,
26-
get_process_count,
27-
execute_commands,
28-
write_failure_file,
29-
USER_CODE_PATH,
30-
)
3121
from mpi_utils import (
3222
start_sshd_daemon,
3323
bootstrap_master_node,
@@ -38,6 +28,16 @@
3828
)
3929

4030

31+
sys.path.insert(0, str(Path(__file__).parent.parent))
32+
from common.utils import ( # noqa: E402
33+
logger,
34+
hyperparameters_to_cli_args,
35+
get_process_count,
36+
execute_commands,
37+
write_failure_file,
38+
)
39+
40+
4141
def main():
4242
"""Main function for the MPI driver script.
4343
@@ -58,9 +58,9 @@ def main():
5858
5. Exit
5959
6060
"""
61-
source_code = read_source_code_json()
62-
distribution = read_distributed_json()
63-
hyperparameters = read_hyperparameters_json()
61+
entry_script = os.environ["SM_ENTRY_SCRIPT"]
62+
distributed_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"])
63+
hyperparameters = json.loads(os.environ["SM_HPS"])
6464

6565
sm_current_host = os.environ["SM_CURRENT_HOST"]
6666
sm_hosts = json.loads(os.environ["SM_HOSTS"])
@@ -77,7 +77,8 @@ def main():
7777

7878
host_list = json.loads(os.environ["SM_HOSTS"])
7979
host_count = int(os.environ["SM_HOST_COUNT"])
80-
process_count = get_process_count(distribution)
80+
process_count = int(distributed_config.get("process_count_per_node", 0))
81+
process_count = get_process_count(process_count)
8182

8283
if process_count > 1:
8384
host_list = ["{}:{}".format(host, process_count) for host in host_list]
@@ -86,8 +87,8 @@ def main():
8687
host_count=host_count,
8788
host_list=host_list,
8889
num_processes=process_count,
89-
additional_options=distribution.get("mpi_additional_options", []),
90-
entry_script_path=os.path.join(USER_CODE_PATH, source_code["entry_script"]),
90+
additional_options=distributed_config.get("mpi_additional_options", []),
91+
entry_script_path=entry_script,
9192
)
9293

9394
args = hyperparameters_to_cli_args(hyperparameters)

src/sagemaker/modules/train/container_drivers/mpi_utils.py renamed to src/sagemaker/modules/train/container_drivers/drivers/mpi_utils.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,22 @@
1414
from __future__ import absolute_import
1515

1616
import os
17+
import sys
1718
import subprocess
1819
import time
20+
import paramiko
21+
22+
from pathlib import Path
1923
from typing import List
2024

21-
import paramiko
22-
from utils import SM_EFA_NCCL_INSTANCES, SM_EFA_RDMA_INSTANCES, get_python_executable, logger
25+
sys.path.insert(0, str(Path(__file__).parent.parent))
26+
27+
from common.utils import ( # noqa: E402
28+
SM_EFA_NCCL_INSTANCES,
29+
SM_EFA_RDMA_INSTANCES,
30+
get_python_executable,
31+
logger,
32+
)
2333

2434
FINISHED_STATUS_FILE = "/tmp/done.algo-1"
2535
READY_FILE = "/tmp/ready.%s"

src/sagemaker/modules/train/container_drivers/torchrun_driver.py renamed to src/sagemaker/modules/train/container_drivers/drivers/torchrun_driver.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,20 @@
1515

1616
import os
1717
import sys
18+
import json
1819

20+
from pathlib import Path
1921
from typing import List, Tuple
2022

21-
from utils import (
23+
sys.path.insert(0, str(Path(__file__).parent.parent))
24+
25+
from common.utils import ( # noqa: E402
2226
logger,
23-
read_source_code_json,
24-
read_distributed_json,
25-
read_hyperparameters_json,
2627
hyperparameters_to_cli_args,
2728
get_process_count,
2829
get_python_executable,
2930
execute_commands,
3031
write_failure_file,
31-
USER_CODE_PATH,
3232
SM_EFA_NCCL_INSTANCES,
3333
SM_EFA_RDMA_INSTANCES,
3434
)
@@ -65,11 +65,12 @@ def setup_env():
6565

6666
def create_commands():
6767
"""Create the Torch Distributed command to execute"""
68-
source_code = read_source_code_json()
69-
distribution = read_distributed_json()
70-
hyperparameters = read_hyperparameters_json()
68+
entry_script = os.environ["SM_ENTRY_SCRIPT"]
69+
distributed_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"])
70+
hyperparameters = json.loads(os.environ["SM_HPS"])
7171

72-
process_count = get_process_count(distribution)
72+
process_count = int(distributed_config.get("process_count_per_node", 0))
73+
process_count = get_process_count(process_count)
7374
host_count = int(os.environ["SM_HOST_COUNT"])
7475

7576
torch_cmd = []
@@ -94,7 +95,7 @@ def create_commands():
9495
]
9596
)
9697

97-
torch_cmd.extend([os.path.join(USER_CODE_PATH, source_code["entry_script"])])
98+
torch_cmd.extend([entry_script])
9899

99100
args = hyperparameters_to_cli_args(hyperparameters)
100101
torch_cmd += args

src/sagemaker/modules/train/container_drivers/scripts/environment.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,17 @@
1919
import json
2020
import os
2121
import sys
22+
from pathlib import Path
2223
import logging
2324

24-
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
25-
sys.path.insert(0, parent_dir)
25+
sys.path.insert(0, str(Path(__file__).parent.parent))
2626

27-
from utils import safe_serialize, safe_deserialize # noqa: E402 # pylint: disable=C0413
27+
from common.utils import ( # noqa: E402
28+
safe_serialize,
29+
safe_deserialize,
30+
read_distributed_json,
31+
read_source_code_json,
32+
)
2833

2934
# Initialize logger
3035
SM_LOG_LEVEL = os.environ.get("SM_LOG_LEVEL", 20)
@@ -42,6 +47,8 @@
4247
SM_OUTPUT_DIR = "/opt/ml/output"
4348
SM_OUTPUT_FAILURE = "/opt/ml/output/failure"
4449
SM_OUTPUT_DATA_DIR = "/opt/ml/output/data"
50+
SM_SOURCE_DIR_PATH = "/opt/ml/input/data/code"
51+
SM_DRIVER_DIR_PATH = "/opt/ml/input/data/sm_drivers/drivers"
4552

4653
SM_MASTER_ADDR = "algo-1"
4754
SM_MASTER_PORT = 7777
@@ -158,6 +165,17 @@ def set_env(
158165
"SM_MASTER_PORT": SM_MASTER_PORT,
159166
}
160167

168+
# SourceCode and DistributedConfig Environment Variables
169+
source_code = read_source_code_json()
170+
if source_code:
171+
env_vars["SM_SOURCE_DIR"] = SM_SOURCE_DIR_PATH
172+
env_vars["SM_ENTRY_SCRIPT"] = source_code.get("entry_script", "")
173+
174+
distributed = read_distributed_json()
175+
if distributed:
176+
env_vars["SM_DRIVER_DIR"] = SM_DRIVER_DIR_PATH
177+
env_vars["SM_DISTRIBUTED_CONFIG"] = distributed
178+
161179
# Data Channels
162180
channels = list(input_data_config.keys())
163181
for channel in channels:

0 commit comments

Comments
 (0)