Skip to content

Commit ced988f

Browse files
committed
Fix unit test
1 parent 0d33f36 commit ced988f

File tree

1 file changed

+29
-17
lines changed

1 file changed

+29
-17
lines changed

tests/integ/sagemaker/modules/train/container_drivers/test_mpi_utils.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -77,23 +77,35 @@ def test_can_connect_failure(mock_ssh_client):
7777

7878
def test_get_mpirun_command():
7979
"""Test MPI command generation."""
80-
os.environ["SM_NETWORK_INTERFACE_NAME"] = "eth0"
81-
os.environ["SM_CURRENT_INSTANCE_TYPE"] = "ml.p4d.24xlarge"
82-
83-
command = get_mpirun_command(
84-
host_count=2,
85-
host_list=["algo-1", "algo-2"],
86-
num_processes=2,
87-
additional_options=[],
88-
entry_script_path="train.py",
89-
)
90-
91-
assert command[0] == "mpirun"
92-
assert "--host" in command
93-
assert "algo-1,algo-2" in command
94-
assert "-np" in command
95-
assert "2" in command
96-
assert f"NCCL_SOCKET_IFNAME=eth0" in " ".join(command)
80+
test_network_interface = "eth0"
81+
test_instance_type = "ml.p4d.24xlarge"
82+
83+
with patch.dict(
84+
os.environ,
85+
{
86+
"SM_NETWORK_INTERFACE_NAME": test_network_interface,
87+
"SM_CURRENT_INSTANCE_TYPE": test_instance_type,
88+
},
89+
):
90+
command = get_mpirun_command(
91+
host_count=2,
92+
host_list=["algo-1", "algo-2"],
93+
num_processes=2,
94+
additional_options=[],
95+
entry_script_path="train.py",
96+
)
97+
98+
# Basic command structure checks
99+
assert command[0] == "mpirun"
100+
assert "--host" in command
101+
assert "algo-1,algo-2" in command
102+
assert "-np" in command
103+
assert "2" in command
104+
105+
# Network interface check
106+
expected_nccl_config = f"NCCL_SOCKET_IFNAME={test_network_interface}"
107+
command_str = " ".join(command)
108+
assert expected_nccl_config in command_str
97109

98110

99111
@patch("sagemaker.modules.train.container_drivers.mpi_utils._can_connect")

0 commit comments

Comments
 (0)