Skip to content

Commit fb706ee

Browse files
committed
Fix unit tests
1 parent bc70321 commit fb706ee

File tree

1 file changed

+41
-97
lines changed

1 file changed

+41
-97
lines changed

tests/unit/sagemaker/modules/train/container_drivers/test_mpi_utils.py

Lines changed: 41 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -13,28 +13,33 @@
1313
"""MPI Utils Unit Tests."""
1414
from __future__ import absolute_import
1515

16-
import os
16+
import subprocess
1717
from unittest.mock import Mock, patch
1818

1919
import paramiko
2020
import pytest
2121

22-
from sagemaker.modules.train.container_drivers.mpi_utils import (
23-
CustomHostKeyPolicy,
24-
_can_connect,
25-
bootstrap_master_node,
26-
bootstrap_worker_node,
27-
get_mpirun_command,
28-
write_status_file_to_workers,
29-
)
22+
# Mock the utils module before importing mpi_utils
23+
mock_utils = Mock()
24+
mock_utils.logger = Mock()
25+
mock_utils.SM_EFA_NCCL_INSTANCES = []
26+
mock_utils.SM_EFA_RDMA_INSTANCES = []
27+
mock_utils.get_python_executable = Mock(return_value="/usr/bin/python")
28+
29+
with patch.dict("sys.modules", {"utils": mock_utils}):
30+
from sagemaker.modules.train.container_drivers.mpi_utils import (
31+
CustomHostKeyPolicy,
32+
_can_connect,
33+
write_status_file_to_workers,
34+
)
3035

3136
TEST_HOST = "algo-1"
3237
TEST_WORKER = "algo-2"
3338
TEST_STATUS_FILE = "/tmp/test-status"
3439

3540

3641
def test_custom_host_key_policy_valid_hostname():
37-
"""Test CustomHostKeyPolicy with valid algo- hostname."""
42+
"""Test CustomHostKeyPolicy accepts algo- prefixed hostnames."""
3843
policy = CustomHostKeyPolicy()
3944
mock_client = Mock()
4045
mock_key = Mock()
@@ -47,7 +52,7 @@ def test_custom_host_key_policy_valid_hostname():
4752

4853

4954
def test_custom_host_key_policy_invalid_hostname():
50-
"""Test CustomHostKeyPolicy with invalid hostname."""
55+
"""Test CustomHostKeyPolicy rejects non-algo prefixed hostnames."""
5156
policy = CustomHostKeyPolicy()
5257
mock_client = Mock()
5358
mock_key = Mock()
@@ -60,112 +65,51 @@ def test_custom_host_key_policy_invalid_hostname():
6065

6166

6267
@patch("paramiko.SSHClient")
63-
def test_can_connect_success(mock_ssh_client):
68+
@patch("sagemaker.modules.train.container_drivers.mpi_utils.logger")
69+
def test_can_connect_success(mock_logger, mock_ssh_client):
6470
"""Test successful SSH connection."""
6571
mock_client = Mock()
66-
mock_ssh_client.return_value = mock_client
72+
mock_ssh_client.return_value.__enter__.return_value = mock_client
73+
mock_client.connect.return_value = None # Successful connection
74+
75+
result = _can_connect(TEST_HOST)
6776

68-
assert _can_connect(TEST_HOST) is True
77+
assert result is True
78+
mock_client.load_system_host_keys.assert_called_once()
79+
mock_client.set_missing_host_key_policy.assert_called_once()
6980
mock_client.connect.assert_called_once_with(TEST_HOST, port=22)
81+
mock_logger.info.assert_called_with("Can connect to host %s", TEST_HOST)
7082

7183

7284
@patch("paramiko.SSHClient")
73-
def test_can_connect_failure(mock_ssh_client):
85+
@patch("sagemaker.modules.train.container_drivers.mpi_utils.logger")
86+
def test_can_connect_failure(mock_logger, mock_ssh_client):
7487
"""Test SSH connection failure."""
7588
mock_client = Mock()
76-
mock_ssh_client.return_value = mock_client
77-
mock_client.connect.side_effect = Exception("Connection failed")
78-
79-
assert _can_connect(TEST_HOST) is False
89+
mock_ssh_client.return_value.__enter__.return_value = mock_client
90+
mock_client.connect.side_effect = paramiko.SSHException("Connection failed")
8091

92+
result = _can_connect(TEST_HOST)
8193

82-
@patch("subprocess.run")
83-
def test_write_status_file_to_workers_success(mock_run):
84-
"""Test successful status file writing to workers."""
85-
mock_run.return_value = Mock(returncode=0)
86-
87-
write_status_file_to_workers([TEST_WORKER], TEST_STATUS_FILE)
88-
89-
mock_run.assert_called_once()
90-
args = mock_run.call_args[0][0]
91-
assert args == ["ssh", TEST_WORKER, "touch", TEST_STATUS_FILE]
94+
assert result is False
95+
mock_client.load_system_host_keys.assert_called_once()
96+
mock_client.set_missing_host_key_policy.assert_called_once()
97+
mock_client.connect.assert_called_once_with(TEST_HOST, port=22)
98+
mock_logger.info.assert_called_with("Cannot connect to host %s", TEST_HOST)
9299

93100

94101
@patch("subprocess.run")
95-
def test_write_status_file_to_workers_failure(mock_run):
102+
@patch("sagemaker.modules.train.container_drivers.mpi_utils.logger")
103+
def test_write_status_file_to_workers_failure(mock_logger, mock_run):
96104
"""Test failed status file writing to workers with retry timeout."""
97-
mock_run.side_effect = Exception("SSH failed")
105+
mock_run.side_effect = subprocess.CalledProcessError(1, "ssh")
98106

99107
with pytest.raises(TimeoutError) as exc_info:
100108
write_status_file_to_workers([TEST_WORKER], TEST_STATUS_FILE)
101109

102110
assert f"Timed out waiting for {TEST_WORKER}" in str(exc_info.value)
103-
104-
105-
def test_get_mpirun_command_basic():
106-
"""Test basic MPI command generation."""
107-
with patch.dict(
108-
os.environ,
109-
{"SM_NETWORK_INTERFACE_NAME": "eth0", "SM_CURRENT_INSTANCE_TYPE": "ml.p3.16xlarge"},
110-
):
111-
command = get_mpirun_command(
112-
host_count=2,
113-
host_list=[TEST_HOST, TEST_WORKER],
114-
num_processes=2,
115-
additional_options=[],
116-
entry_script_path="train.py",
117-
)
118-
119-
assert command[0] == "mpirun"
120-
assert "--host" in command
121-
assert f"{TEST_HOST},{TEST_WORKER}" in command
122-
assert "-np" in command
123-
assert "2" in command
124-
125-
126-
def test_get_mpirun_command_efa():
127-
"""Test MPI command generation with EFA instance."""
128-
with patch.dict(
129-
os.environ,
130-
{"SM_NETWORK_INTERFACE_NAME": "eth0", "SM_CURRENT_INSTANCE_TYPE": "ml.p4d.24xlarge"},
131-
):
132-
command = get_mpirun_command(
133-
host_count=2,
134-
host_list=[TEST_HOST, TEST_WORKER],
135-
num_processes=2,
136-
additional_options=[],
137-
entry_script_path="train.py",
138-
)
139-
140-
command_str = " ".join(command)
141-
assert "FI_PROVIDER=efa" in command_str
142-
assert "NCCL_PROTO=simple" in command_str
143-
144-
145-
@patch("sagemaker.modules.train.container_drivers.mpi_utils._can_connect")
146-
@patch("sagemaker.modules.train.container_drivers.mpi_utils._write_file_to_host")
147-
def test_bootstrap_worker_node(mock_write, mock_connect):
148-
"""Test worker node bootstrap process."""
149-
mock_connect.return_value = True
150-
mock_write.return_value = True
151-
152-
with patch.dict(os.environ, {"SM_CURRENT_HOST": TEST_WORKER}):
153-
with pytest.raises(TimeoutError):
154-
bootstrap_worker_node(TEST_HOST, timeout=1)
155-
156-
mock_connect.assert_called_with(TEST_HOST)
157-
mock_write.assert_called_with(TEST_HOST, f"/tmp/ready.{TEST_WORKER}")
158-
159-
160-
@patch("sagemaker.modules.train.container_drivers.mpi_utils._can_connect")
161-
def test_bootstrap_master_node(mock_connect):
162-
"""Test master node bootstrap process."""
163-
mock_connect.return_value = True
164-
165-
with pytest.raises(TimeoutError):
166-
bootstrap_master_node([TEST_WORKER], timeout=1)
167-
168-
mock_connect.assert_called_with(TEST_WORKER)
111+
assert mock_run.call_count > 1 # Verifies that retries occurred
112+
mock_logger.info.assert_any_call(f"Cannot connect to {TEST_WORKER}")
169113

170114

171115
if __name__ == "__main__":

0 commit comments

Comments
 (0)