Skip to content

Commit bc70321

Browse files
committed
Change to unit test
1 parent ced988f commit bc70321

File tree

2 files changed

+172
-140
lines changed

2 files changed

+172
-140
lines changed

tests/integ/sagemaker/modules/train/container_drivers/test_mpi_utils.py

Lines changed: 0 additions & 140 deletions
This file was deleted.
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""MPI Utils Unit Tests."""
14+
from __future__ import absolute_import
15+
16+
import os
17+
from unittest.mock import Mock, patch
18+
19+
import paramiko
20+
import pytest
21+
22+
from sagemaker.modules.train.container_drivers.mpi_utils import (
23+
CustomHostKeyPolicy,
24+
_can_connect,
25+
bootstrap_master_node,
26+
bootstrap_worker_node,
27+
get_mpirun_command,
28+
write_status_file_to_workers,
29+
)
30+
31+
# Host names and status-file path shared by the tests below.
TEST_HOST, TEST_WORKER = "algo-1", "algo-2"
TEST_STATUS_FILE = "/tmp/test-status"
34+
35+
36+
def test_custom_host_key_policy_valid_hostname():
    """An ``algo-`` hostname gets its key added to the client's host keys."""
    hostname = "algo-1"
    ssh_client = Mock()
    # Configure the key type reported by the fake key in one step.
    host_key = Mock(**{"get_name.return_value": "ssh-rsa"})

    CustomHostKeyPolicy().missing_host_key(ssh_client, hostname, host_key)

    ssh_client.get_host_keys.assert_called_once()
    # Inspect the object the policy received from get_host_keys() without
    # invoking the mock again.
    stored_keys = ssh_client.get_host_keys.return_value
    stored_keys.add.assert_called_once_with(hostname, "ssh-rsa", host_key)
47+
48+
49+
def test_custom_host_key_policy_invalid_hostname():
    """A hostname such as ``invalid-1`` is rejected with an SSHException."""
    ssh_client = Mock()
    host_key = Mock()
    policy = CustomHostKeyPolicy()

    with pytest.raises(paramiko.SSHException) as raised:
        policy.missing_host_key(ssh_client, "invalid-1", host_key)

    # The rejection must happen before the client's host keys are touched.
    ssh_client.get_host_keys.assert_not_called()
    assert "Unknown host key for invalid-1" in str(raised.value)
60+
61+
62+
@patch("paramiko.SSHClient")
def test_can_connect_success(mock_ssh_client):
    """_can_connect returns True when the patched SSH client connects cleanly."""
    ssh_client = Mock()
    mock_ssh_client.return_value = ssh_client

    connected = _can_connect(TEST_HOST)

    assert connected is True
    ssh_client.connect.assert_called_once_with(TEST_HOST, port=22)
70+
71+
72+
@patch("paramiko.SSHClient")
def test_can_connect_failure(mock_ssh_client):
    """_can_connect returns False when the SSH client's connect() raises."""
    failing_client = Mock()
    failing_client.connect.side_effect = Exception("Connection failed")
    mock_ssh_client.return_value = failing_client

    connected = _can_connect(TEST_HOST)

    assert connected is False
80+
81+
82+
@patch("subprocess.run")
def test_write_status_file_to_workers_success(mock_run):
    """One worker results in exactly one ``ssh <worker> touch <file>`` call."""
    mock_run.return_value = Mock(returncode=0)
    expected_argv = ["ssh", TEST_WORKER, "touch", TEST_STATUS_FILE]

    write_status_file_to_workers([TEST_WORKER], TEST_STATUS_FILE)

    mock_run.assert_called_once()
    # call_args unpacks into (positional args, keyword args); the command list
    # is the first positional argument handed to subprocess.run.
    positional, _keywords = mock_run.call_args
    assert positional[0] == expected_argv
92+
93+
94+
@patch("subprocess.run")
def test_write_status_file_to_workers_failure(mock_run):
    """When every subprocess.run call raises, a TimeoutError names the worker."""
    # NOTE(review): the retry behaviour lives in mpi_utils and is not visible
    # here; if it sleeps between attempts this test waits in real time -- confirm.
    mock_run.side_effect = Exception("SSH failed")
    expected_fragment = f"Timed out waiting for {TEST_WORKER}"

    with pytest.raises(TimeoutError) as raised:
        write_status_file_to_workers([TEST_WORKER], TEST_STATUS_FILE)

    assert expected_fragment in str(raised.value)
103+
104+
105+
def test_get_mpirun_command_basic():
    """The generated command starts with mpirun and carries host/process options."""
    env = {
        "SM_NETWORK_INTERFACE_NAME": "eth0",
        "SM_CURRENT_INSTANCE_TYPE": "ml.p3.16xlarge",
    }
    hosts = [TEST_HOST, TEST_WORKER]

    with patch.dict(os.environ, env):
        command = get_mpirun_command(
            host_count=2,
            host_list=hosts,
            num_processes=2,
            additional_options=[],
            entry_script_path="train.py",
        )

    assert command[0] == "mpirun"
    # Each of these must appear as its own element of the command list.
    for token in ("--host", f"{TEST_HOST},{TEST_WORKER}", "-np", "2"):
        assert token in command
124+
125+
126+
def test_get_mpirun_command_efa():
    """For an ml.p4d.24xlarge instance the command includes the EFA settings."""
    env = {
        "SM_NETWORK_INTERFACE_NAME": "eth0",
        "SM_CURRENT_INSTANCE_TYPE": "ml.p4d.24xlarge",
    }
    hosts = [TEST_HOST, TEST_WORKER]

    with patch.dict(os.environ, env):
        command = get_mpirun_command(
            host_count=2,
            host_list=hosts,
            num_processes=2,
            additional_options=[],
            entry_script_path="train.py",
        )

    # Join the argv so a setting is found wherever it sits inside an element.
    rendered = " ".join(command)
    for setting in ("FI_PROVIDER=efa", "NCCL_PROTO=simple"):
        assert setting in rendered
143+
144+
145+
@patch("sagemaker.modules.train.container_drivers.mpi_utils._can_connect")
@patch("sagemaker.modules.train.container_drivers.mpi_utils._write_file_to_host")
def test_bootstrap_worker_node(mock_write, mock_connect):
    """Worker bootstrap contacts the master, writes its ready file, then times out."""
    # Decorators apply bottom-up: mock_write patches _write_file_to_host and
    # mock_connect patches _can_connect.
    mock_write.return_value = True
    mock_connect.return_value = True
    ready_file = f"/tmp/ready.{TEST_WORKER}"

    with patch.dict(os.environ, {"SM_CURRENT_HOST": TEST_WORKER}), pytest.raises(TimeoutError):
        bootstrap_worker_node(TEST_HOST, timeout=1)

    mock_connect.assert_called_with(TEST_HOST)
    mock_write.assert_called_with(TEST_HOST, ready_file)
158+
159+
160+
@patch("sagemaker.modules.train.container_drivers.mpi_utils._can_connect")
def test_bootstrap_master_node(mock_connect):
    """Master bootstrap probes the worker and raises TimeoutError with timeout=1."""
    mock_connect.return_value = True
    workers = [TEST_WORKER]

    with pytest.raises(TimeoutError):
        bootstrap_master_node(workers, timeout=1)

    mock_connect.assert_called_with(TEST_WORKER)
169+
170+
171+
if __name__ == "__main__":
    # pytest.main() returns the run's exit code instead of exiting the process.
    # Raise it so a failing run gives a non-zero status to the shell or CI.
    raise SystemExit(pytest.main([__file__]))

0 commit comments

Comments
 (0)