Skip to content

Commit 566ce38

Browse files
committed
Change dir path to distributed_drivers
1 parent c5b1e12 commit 566ce38

File tree

14 files changed

+15
-13
lines changed

14 files changed

+15
-13
lines changed

src/sagemaker/modules/distributed.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def driver_dir(self) -> str:
133133
Returns:
134134
str: Path to directory containing the driver script
135135
"""
136-
return os.path.join(SM_DRIVERS_LOCAL_PATH, "drivers")
136+
return os.path.join(SM_DRIVERS_LOCAL_PATH, "distributed_drivers")
137137

138138
@property
139139
def driver_script(self) -> str:
@@ -169,7 +169,7 @@ def driver_dir(self) -> str:
169169
Returns:
170170
str: Path to directory containing the driver script
171171
"""
172-
return os.path.join(SM_DRIVERS_LOCAL_PATH, "drivers")
172+
return os.path.join(SM_DRIVERS_LOCAL_PATH, "distributed_drivers")
173173

174174
@property
175175
def driver_script(self) -> str:

src/sagemaker/modules/templates.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@
2121

2222
EXECUTE_BASIC_SCRIPT_DRIVER = """
2323
echo "Running Basic Script driver"
24-
$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/drivers/basic_script_driver.py
24+
$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/distributed_drivers/basic_script_driver.py
2525
"""
2626

2727
EXEUCTE_DISTRIBUTED_DRIVER = """
2828
echo "Running {driver_name} Driver"
29-
$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/drivers/{driver_script}
29+
$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/distributed_drivers/{driver_script}
3030
"""
3131

3232
TRAIN_SCRIPT_TEMPLATE = """

src/sagemaker/modules/train/container_drivers/scripts/environment.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
SM_OUTPUT_FAILURE = "/opt/ml/output/failure"
4949
SM_OUTPUT_DATA_DIR = "/opt/ml/output/data"
5050
SM_SOURCE_DIR_PATH = "/opt/ml/input/data/code"
51-
SM_DRIVER_DIR_PATH = "/opt/ml/input/data/sm_drivers/drivers"
51+
SM_DISTRIBUTED_DRIVER_DIR_PATH = "/opt/ml/input/data/sm_drivers/distributed_drivers"
5252

5353
SM_MASTER_ADDR = "algo-1"
5454
SM_MASTER_PORT = 7777
@@ -173,7 +173,7 @@ def set_env(
173173

174174
distributed = read_distributed_json()
175175
if distributed:
176-
env_vars["SM_DRIVER_DIR"] = SM_DRIVER_DIR_PATH
176+
env_vars["SM_DISTRIBUTED_DRIVER_DIR"] = SM_DISTRIBUTED_DRIVER_DIR_PATH
177177
env_vars["SM_DISTRIBUTED_CONFIG"] = distributed
178178

179179
# Data Channels

src/sagemaker/modules/train/model_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,7 @@ def train(
569569
# If distributed is provided, overwrite code under <root>/drivers
570570
if self.distributed:
571571
distributed_driver_dir = self.distributed.driver_dir
572-
driver_dir = os.path.join(tmp_dir.name, "drivers")
572+
driver_dir = os.path.join(tmp_dir.name, "distributed_drivers")
573573
shutil.copytree(distributed_driver_dir, driver_dir, dirs_exist_ok=True)
574574

575575
# If source code is provided, create a channel for the source code

tests/data/modules/custom_drivers/driver.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ def main():
1515

1616
source_dir = os.environ["SM_SOURCE_DIR"]
1717
assert source_dir == "/opt/ml/input/data/code"
18-
sm_drivers_dir = os.environ["SM_DRIVER_DIR"]
19-
assert sm_drivers_dir == "/opt/ml/input/data/sm_drivers/drivers"
18+
sm_drivers_dir = os.environ["SM_DISTRIBUTED_DRIVER_DIR"]
19+
assert sm_drivers_dir == "/opt/ml/input/data/sm_drivers/distributed_drivers"
2020

2121
entry_script = os.environ["SM_ENTRY_SCRIPT"]
2222
assert entry_script != None

tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@
9898
export SM_MASTER_PORT='7777'
9999
export SM_SOURCE_DIR='/opt/ml/input/data/code'
100100
export SM_ENTRY_SCRIPT='train.py'
101-
export SM_DRIVER_DIR='/opt/ml/input/data/sm_drivers/drivers'
101+
export SM_DISTRIBUTED_DRIVER_DIR='/opt/ml/input/data/sm_drivers/distributed_drivers'
102102
export SM_DISTRIBUTED_CONFIG='{"process_count_per_node": 2}'
103103
export SM_CHANNEL_TRAIN='/opt/ml/input/data/train'
104104
export SM_CHANNEL_VALIDATION='/opt/ml/input/data/validation'

tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
sys.modules["utils"] = MagicMock()
2323
sys.modules["mpi_utils"] = MagicMock()
2424

25-
from sagemaker.modules.train.container_drivers.drivers import mpi_driver # noqa: E402
25+
from sagemaker.modules.train.container_drivers.distributed_drivers import mpi_driver # noqa: E402
2626

2727

2828
DUMMY_MPI_COMMAND = [

tests/unit/sagemaker/modules/train/container_drivers/test_mpi_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
mock_utils.get_python_executable = Mock(return_value="/usr/bin/python")
2828

2929
with patch.dict("sys.modules", {"utils": mock_utils}):
30-
from sagemaker.modules.train.container_drivers.drivers.mpi_utils import (
30+
from sagemaker.modules.train.container_drivers.distributed_drivers.mpi_utils import (
3131
CustomHostKeyPolicy,
3232
_can_connect,
3333
write_status_file_to_workers,

tests/unit/sagemaker/modules/train/container_drivers/test_torchrun_driver.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121

2222
sys.modules["utils"] = MagicMock()
2323

24-
from sagemaker.modules.train.container_drivers.drivers import torchrun_driver # noqa: E402
24+
from sagemaker.modules.train.container_drivers.distributed_drivers import ( # noqa: E402
25+
torchrun_driver,
26+
)
2527

2628
DUMMY_DISTRIBUTED = {"process_count_per_node": 2}
2729

0 commit comments

Comments
 (0)