Skip to content

Commit da5c441

Browse files
committed
Do not exclude None-valued fields when creating config JSONs, so that drivers can safely reference every key
1 parent a6d3588 commit da5c441

File tree

6 files changed

+72
-6
lines changed

6 files changed

+72
-6
lines changed

src/sagemaker/modules/distributed.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,20 @@ class Torchrun(DistributedConfig):
128128

129129
@property
130130
def driver_dir(self) -> str:
131+
"""Directory containing the driver script.
132+
133+
Returns:
134+
str: Path to directory containing the driver script
135+
"""
131136
return os.path.join(SM_DRIVERS_LOCAL_PATH, "drivers")
132137

133138
@property
134139
def driver_script(self) -> str:
140+
"""Name of the driver script.
141+
142+
Returns:
143+
str: Name of the driver script file
144+
"""
135145
return "torchrun_driver.py"
136146

137147

@@ -154,8 +164,18 @@ class MPI(DistributedConfig):
154164

155165
@property
156166
def driver_dir(self) -> str:
167+
"""Directory containing the driver script.
168+
169+
Returns:
170+
str: Path to directory containing the driver script
171+
"""
157172
return os.path.join(SM_DRIVERS_LOCAL_PATH, "drivers")
158173

159174
@property
160175
def driver_script(self) -> str:
176+
"""Name of the driver script.
177+
178+
Returns:
179+
str: Name of the driver script file
180+
"""
161181
return "mpi_driver.py"

src/sagemaker/modules/train/container_drivers/drivers/mpi_driver.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def main():
7777

7878
host_list = json.loads(os.environ["SM_HOSTS"])
7979
host_count = int(os.environ["SM_HOST_COUNT"])
80-
process_count = int(distributed_config.get("process_count_per_node", 0))
80+
process_count = int(distributed_config["process_count_per_node"] or 0)
8181
process_count = get_process_count(process_count)
8282

8383
if process_count > 1:
@@ -87,7 +87,7 @@ def main():
8787
host_count=host_count,
8888
host_list=host_list,
8989
num_processes=process_count,
90-
additional_options=distributed_config.get("mpi_additional_options", []),
90+
additional_options=distributed_config["mpi_additional_options"] or [],
9191
entry_script_path=entry_script,
9292
)
9393

src/sagemaker/modules/train/container_drivers/drivers/torchrun_driver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def create_commands():
6969
distributed_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"])
7070
hyperparameters = json.loads(os.environ["SM_HPS"])
7171

72-
process_count = int(distributed_config.get("process_count_per_node", 0))
72+
process_count = int(distributed_config["process_count_per_node"] or 0)
7373
process_count = get_process_count(process_count)
7474
host_count = int(os.environ["SM_HOST_COUNT"])
7575

src/sagemaker/modules/train/model_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -773,7 +773,7 @@ def _write_source_code_json(self, tmp_dir: TemporaryDirectory, source_code: Sour
773773
"""Write the source code configuration to a JSON file."""
774774
file_path = os.path.join(tmp_dir.name, SOURCE_CODE_JSON)
775775
with open(file_path, "w") as f:
776-
dump = source_code.model_dump(exclude_none=True) if source_code else {}
776+
dump = source_code.model_dump() if source_code else {}
777777
f.write(json.dumps(dump))
778778

779779
def _write_distributed_json(
@@ -784,7 +784,7 @@ def _write_distributed_json(
784784
"""Write the distributed runner configuration to a JSON file."""
785785
file_path = os.path.join(tmp_dir.name, DISTRIBUTED_JSON)
786786
with open(file_path, "w") as f:
787-
dump = distributed.model_dump(exclude_none=True) if distributed else {}
787+
dump = distributed.model_dump() if distributed else {}
788788
f.write(json.dumps(dump))
789789

790790
def _prepare_train_script(

tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,15 @@
7373
},
7474
}
7575

76+
SOURCE_CODE = {
77+
"source_dir": "code",
78+
"entry_script": "train.py",
79+
}
80+
81+
DISTRIBUTED_CONFIG = {
82+
"process_count_per_node": 2,
83+
}
84+
7685
OUTPUT_FILE = os.path.join(os.path.dirname(__file__), "sm_training.env")
7786

7887
# flake8: noqa
@@ -87,6 +96,10 @@
8796
export SM_LOG_LEVEL='20'
8897
export SM_MASTER_ADDR='algo-1'
8998
export SM_MASTER_PORT='7777'
99+
export SM_SOURCE_DIR='/opt/ml/input/data/code'
100+
export SM_ENTRY_SCRIPT='train.py'
101+
export SM_DRIVER_DIR='/opt/ml/input/data/sm_drivers/drivers'
102+
export SM_DISTRIBUTED_CONFIG='{"process_count_per_node": 2}'
90103
export SM_CHANNEL_TRAIN='/opt/ml/input/data/train'
91104
export SM_CHANNEL_VALIDATION='/opt/ml/input/data/validation'
92105
export SM_CHANNELS='["train", "validation"]'
@@ -110,6 +123,14 @@
110123
"""
111124

112125

126+
@patch(
127+
"sagemaker.modules.train.container_drivers.scripts.environment.read_source_code_json",
128+
return_value=SOURCE_CODE,
129+
)
130+
@patch(
131+
"sagemaker.modules.train.container_drivers.scripts.environment.read_distributed_json",
132+
return_value=DISTRIBUTED_CONFIG,
133+
)
113134
@patch("sagemaker.modules.train.container_drivers.scripts.environment.num_cpus", return_value=8)
114135
@patch("sagemaker.modules.train.container_drivers.scripts.environment.num_gpus", return_value=0)
115136
@patch("sagemaker.modules.train.container_drivers.scripts.environment.num_neurons", return_value=0)
@@ -122,7 +143,13 @@
122143
side_effect=safe_deserialize,
123144
)
124145
def test_set_env(
125-
mock_safe_deserialize, mock_safe_serialize, mock_num_cpus, mock_num_gpus, mock_num_neurons
146+
mock_safe_deserialize,
147+
mock_safe_serialize,
148+
mock_num_neurons,
149+
mock_num_gpus,
150+
mock_num_cpus,
151+
mock_read_distributed_json,
152+
mock_read_source_code_json,
126153
):
127154
with patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}):
128155
set_env(
@@ -135,6 +162,8 @@ def test_set_env(
135162
mock_num_cpus.assert_called_once()
136163
mock_num_gpus.assert_called_once()
137164
mock_num_neurons.assert_called_once()
165+
mock_read_distributed_json.assert_called_once()
166+
mock_read_source_code_json.assert_called_once()
138167

139168
with open(OUTPUT_FILE, "r") as f:
140169
env_file = f.read().strip()

tests/unit/sagemaker/modules/train/container_drivers/test_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@
1212
# language governing permissions and limitations under the License.
1313
"""Container Utils Unit Tests."""
1414
from __future__ import absolute_import
15+
import os
1516

1617
from sagemaker.modules.train.container_drivers.common.utils import (
1718
safe_deserialize,
1819
safe_serialize,
1920
hyperparameters_to_cli_args,
21+
get_process_count,
2022
)
2123

2224
SM_HPS = {
@@ -119,3 +121,18 @@ def test_safe_serialize_empty_data():
119121
assert safe_serialize("") == ""
120122
assert safe_serialize([]) == "[]"
121123
assert safe_serialize({}) == "{}"
124+
125+
126+
def test_get_process_count():
127+
assert get_process_count() == 1
128+
assert get_process_count(2) == 2
129+
os.environ["SM_NUM_GPUS"] = "4"
130+
assert get_process_count() == 4
131+
os.environ["SM_NUM_GPUS"] = "0"
132+
os.environ["SM_NUM_NEURONS"] = "8"
133+
assert get_process_count() == 8
134+
os.environ["SM_NUM_NEURONS"] = "0"
135+
assert get_process_count() == 1
136+
del os.environ["SM_NUM_GPUS"]
137+
del os.environ["SM_NUM_NEURONS"]
138+
assert get_process_count() == 1

0 commit comments

Comments
 (0)