@@ -77,23 +77,35 @@ def test_can_connect_failure(mock_ssh_client):
77
77
78
78
def test_get_mpirun_command ():
79
79
"""Test MPI command generation."""
80
- os .environ ["SM_NETWORK_INTERFACE_NAME" ] = "eth0"
81
- os .environ ["SM_CURRENT_INSTANCE_TYPE" ] = "ml.p4d.24xlarge"
82
-
83
- command = get_mpirun_command (
84
- host_count = 2 ,
85
- host_list = ["algo-1" , "algo-2" ],
86
- num_processes = 2 ,
87
- additional_options = [],
88
- entry_script_path = "train.py" ,
89
- )
90
-
91
- assert command [0 ] == "mpirun"
92
- assert "--host" in command
93
- assert "algo-1,algo-2" in command
94
- assert "-np" in command
95
- assert "2" in command
96
- assert f"NCCL_SOCKET_IFNAME=eth0" in " " .join (command )
80
+ test_network_interface = "eth0"
81
+ test_instance_type = "ml.p4d.24xlarge"
82
+
83
+ with patch .dict (
84
+ os .environ ,
85
+ {
86
+ "SM_NETWORK_INTERFACE_NAME" : test_network_interface ,
87
+ "SM_CURRENT_INSTANCE_TYPE" : test_instance_type ,
88
+ },
89
+ ):
90
+ command = get_mpirun_command (
91
+ host_count = 2 ,
92
+ host_list = ["algo-1" , "algo-2" ],
93
+ num_processes = 2 ,
94
+ additional_options = [],
95
+ entry_script_path = "train.py" ,
96
+ )
97
+
98
+ # Basic command structure checks
99
+ assert command [0 ] == "mpirun"
100
+ assert "--host" in command
101
+ assert "algo-1,algo-2" in command
102
+ assert "-np" in command
103
+ assert "2" in command
104
+
105
+ # Network interface check
106
+ expected_nccl_config = f"NCCL_SOCKET_IFNAME={ test_network_interface } "
107
+ command_str = " " .join (command )
108
+ assert expected_nccl_config in command_str
97
109
98
110
99
111
@patch ("sagemaker.modules.train.container_drivers.mpi_utils._can_connect" )
0 commit comments