13
13
"""MPI Utils Unit Tests."""
14
14
from __future__ import absolute_import
15
15
16
- import os
16
+ import subprocess
17
17
from unittest .mock import Mock , patch
18
18
19
19
import paramiko
20
20
import pytest
21
21
22
- from sagemaker .modules .train .container_drivers .mpi_utils import (
23
- CustomHostKeyPolicy ,
24
- _can_connect ,
25
- bootstrap_master_node ,
26
- bootstrap_worker_node ,
27
- get_mpirun_command ,
28
- write_status_file_to_workers ,
29
- )
22
+ # Mock the utils module before importing mpi_utils
23
+ mock_utils = Mock ()
24
+ mock_utils .logger = Mock ()
25
+ mock_utils .SM_EFA_NCCL_INSTANCES = []
26
+ mock_utils .SM_EFA_RDMA_INSTANCES = []
27
+ mock_utils .get_python_executable = Mock (return_value = "/usr/bin/python" )
28
+
29
+ with patch .dict ("sys.modules" , {"utils" : mock_utils }):
30
+ from sagemaker .modules .train .container_drivers .mpi_utils import (
31
+ CustomHostKeyPolicy ,
32
+ _can_connect ,
33
+ write_status_file_to_workers ,
34
+ )
30
35
31
36
TEST_HOST = "algo-1"
32
37
TEST_WORKER = "algo-2"
33
38
TEST_STATUS_FILE = "/tmp/test-status"
34
39
35
40
36
41
def test_custom_host_key_policy_valid_hostname ():
37
- """Test CustomHostKeyPolicy with valid algo- hostname ."""
42
+ """Test CustomHostKeyPolicy accepts algo- prefixed hostnames ."""
38
43
policy = CustomHostKeyPolicy ()
39
44
mock_client = Mock ()
40
45
mock_key = Mock ()
@@ -47,7 +52,7 @@ def test_custom_host_key_policy_valid_hostname():
47
52
48
53
49
54
def test_custom_host_key_policy_invalid_hostname ():
50
- """Test CustomHostKeyPolicy with invalid hostname ."""
55
+ """Test CustomHostKeyPolicy rejects non-algo prefixed hostnames ."""
51
56
policy = CustomHostKeyPolicy ()
52
57
mock_client = Mock ()
53
58
mock_key = Mock ()
@@ -60,112 +65,51 @@ def test_custom_host_key_policy_invalid_hostname():
60
65
61
66
62
67
@patch ("paramiko.SSHClient" )
63
- def test_can_connect_success (mock_ssh_client ):
68
+ @patch ("sagemaker.modules.train.container_drivers.mpi_utils.logger" )
69
+ def test_can_connect_success (mock_logger , mock_ssh_client ):
64
70
"""Test successful SSH connection."""
65
71
mock_client = Mock ()
66
- mock_ssh_client .return_value = mock_client
72
+ mock_ssh_client .return_value .__enter__ .return_value = mock_client
73
+ mock_client .connect .return_value = None # Successful connection
74
+
75
+ result = _can_connect (TEST_HOST )
67
76
68
- assert _can_connect (TEST_HOST ) is True
77
+ assert result is True
78
+ mock_client .load_system_host_keys .assert_called_once ()
79
+ mock_client .set_missing_host_key_policy .assert_called_once ()
69
80
mock_client .connect .assert_called_once_with (TEST_HOST , port = 22 )
81
+ mock_logger .info .assert_called_with ("Can connect to host %s" , TEST_HOST )
70
82
71
83
72
84
@patch ("paramiko.SSHClient" )
73
- def test_can_connect_failure (mock_ssh_client ):
85
+ @patch ("sagemaker.modules.train.container_drivers.mpi_utils.logger" )
86
+ def test_can_connect_failure (mock_logger , mock_ssh_client ):
74
87
"""Test SSH connection failure."""
75
88
mock_client = Mock ()
76
- mock_ssh_client .return_value = mock_client
77
- mock_client .connect .side_effect = Exception ("Connection failed" )
78
-
79
- assert _can_connect (TEST_HOST ) is False
89
+ mock_ssh_client .return_value .__enter__ .return_value = mock_client
90
+ mock_client .connect .side_effect = paramiko .SSHException ("Connection failed" )
80
91
92
+ result = _can_connect (TEST_HOST )
81
93
82
- @patch ("subprocess.run" )
83
- def test_write_status_file_to_workers_success (mock_run ):
84
- """Test successful status file writing to workers."""
85
- mock_run .return_value = Mock (returncode = 0 )
86
-
87
- write_status_file_to_workers ([TEST_WORKER ], TEST_STATUS_FILE )
88
-
89
- mock_run .assert_called_once ()
90
- args = mock_run .call_args [0 ][0 ]
91
- assert args == ["ssh" , TEST_WORKER , "touch" , TEST_STATUS_FILE ]
94
+ assert result is False
95
+ mock_client .load_system_host_keys .assert_called_once ()
96
+ mock_client .set_missing_host_key_policy .assert_called_once ()
97
+ mock_client .connect .assert_called_once_with (TEST_HOST , port = 22 )
98
+ mock_logger .info .assert_called_with ("Cannot connect to host %s" , TEST_HOST )
92
99
93
100
94
101
@patch ("subprocess.run" )
95
- def test_write_status_file_to_workers_failure (mock_run ):
102
+ @patch ("sagemaker.modules.train.container_drivers.mpi_utils.logger" )
103
+ def test_write_status_file_to_workers_failure (mock_logger , mock_run ):
96
104
"""Test failed status file writing to workers with retry timeout."""
97
- mock_run .side_effect = Exception ( "SSH failed " )
105
+ mock_run .side_effect = subprocess . CalledProcessError ( 1 , "ssh " )
98
106
99
107
with pytest .raises (TimeoutError ) as exc_info :
100
108
write_status_file_to_workers ([TEST_WORKER ], TEST_STATUS_FILE )
101
109
102
110
assert f"Timed out waiting for { TEST_WORKER } " in str (exc_info .value )
103
-
104
-
105
- def test_get_mpirun_command_basic ():
106
- """Test basic MPI command generation."""
107
- with patch .dict (
108
- os .environ ,
109
- {"SM_NETWORK_INTERFACE_NAME" : "eth0" , "SM_CURRENT_INSTANCE_TYPE" : "ml.p3.16xlarge" },
110
- ):
111
- command = get_mpirun_command (
112
- host_count = 2 ,
113
- host_list = [TEST_HOST , TEST_WORKER ],
114
- num_processes = 2 ,
115
- additional_options = [],
116
- entry_script_path = "train.py" ,
117
- )
118
-
119
- assert command [0 ] == "mpirun"
120
- assert "--host" in command
121
- assert f"{ TEST_HOST } ,{ TEST_WORKER } " in command
122
- assert "-np" in command
123
- assert "2" in command
124
-
125
-
126
- def test_get_mpirun_command_efa ():
127
- """Test MPI command generation with EFA instance."""
128
- with patch .dict (
129
- os .environ ,
130
- {"SM_NETWORK_INTERFACE_NAME" : "eth0" , "SM_CURRENT_INSTANCE_TYPE" : "ml.p4d.24xlarge" },
131
- ):
132
- command = get_mpirun_command (
133
- host_count = 2 ,
134
- host_list = [TEST_HOST , TEST_WORKER ],
135
- num_processes = 2 ,
136
- additional_options = [],
137
- entry_script_path = "train.py" ,
138
- )
139
-
140
- command_str = " " .join (command )
141
- assert "FI_PROVIDER=efa" in command_str
142
- assert "NCCL_PROTO=simple" in command_str
143
-
144
-
145
- @patch ("sagemaker.modules.train.container_drivers.mpi_utils._can_connect" )
146
- @patch ("sagemaker.modules.train.container_drivers.mpi_utils._write_file_to_host" )
147
- def test_bootstrap_worker_node (mock_write , mock_connect ):
148
- """Test worker node bootstrap process."""
149
- mock_connect .return_value = True
150
- mock_write .return_value = True
151
-
152
- with patch .dict (os .environ , {"SM_CURRENT_HOST" : TEST_WORKER }):
153
- with pytest .raises (TimeoutError ):
154
- bootstrap_worker_node (TEST_HOST , timeout = 1 )
155
-
156
- mock_connect .assert_called_with (TEST_HOST )
157
- mock_write .assert_called_with (TEST_HOST , f"/tmp/ready.{ TEST_WORKER } " )
158
-
159
-
160
- @patch ("sagemaker.modules.train.container_drivers.mpi_utils._can_connect" )
161
- def test_bootstrap_master_node (mock_connect ):
162
- """Test master node bootstrap process."""
163
- mock_connect .return_value = True
164
-
165
- with pytest .raises (TimeoutError ):
166
- bootstrap_master_node ([TEST_WORKER ], timeout = 1 )
167
-
168
- mock_connect .assert_called_with (TEST_WORKER )
111
+ assert mock_run .call_count > 1 # Verifies that retries occurred
112
+ mock_logger .info .assert_any_call (f"Cannot connect to { TEST_WORKER } " )
169
113
170
114
171
115
if __name__ == "__main__" :
0 commit comments